├── .DS_Store
├── .idea
│   ├── .gitignore
│   ├── TextKG.iml
│   ├── deployment.xml
│   ├── inspectionProfiles
│   │   ├── Project_Default.xml
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   └── modules.xml
├── README.md
├── densevid_eval
│   ├── .DS_Store
│   ├── LICENCE
│   ├── README.md
│   ├── __init__.py
│   ├── anet_data
│   │   ├── anet_entities_test_1.json
│   │   ├── anet_entities_test_1_para.json
│   │   ├── anet_entities_test_2.json
│   │   ├── anet_entities_test_2_para.json
│   │   ├── anet_entities_val_1.json
│   │   ├── anet_entities_val_1_para.json
│   │   ├── anet_entities_val_2.json
│   │   ├── anet_entities_val_2_para.json
│   │   ├── readme.txt
│   │   └── train.json
│   ├── coco-caption
│   │   ├── .gitignore
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── annotations
│   │   │   └── captions_val2014.json
│   │   ├── cocoEvalCapDemo.ipynb
│   │   ├── license.txt
│   │   ├── pycocoevalcap
│   │   │   ├── __init__.py
│   │   │   ├── bleu
│   │   │   │   ├── LICENSE
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bleu.py
│   │   │   │   └── bleu_scorer.py
│   │   │   ├── cider
│   │   │   │   ├── __init__.py
│   │   │   │   ├── cider.py
│   │   │   │   └── cider_scorer.py
│   │   │   ├── eval.py
│   │   │   ├── meteor
│   │   │   │   ├── __init__.py
│   │   │   │   └── meteor.py
│   │   │   ├── rouge
│   │   │   │   ├── __init__.py
│   │   │   │   └── rouge.py
│   │   │   └── tokenizer
│   │   │       ├── __init__.py
│   │   │       └── ptbtokenizer.py
│   │   └── pycocotools
│   │       ├── __init__.py
│   │       └── coco.py
│   ├── evaluate.py
│   ├── evaluateCaptionsDiversity.py
│   ├── evaluateRepetition.py
│   ├── get_caption_stat.py
│   ├── merge_dicts_by_prefix.py
│   ├── para-evaluate.py
│   └── yc2_data
│       ├── train_list.txt
│       ├── val_list.txt
│       ├── val_yc2.json
│       ├── yc2_annotations_trainval.json
│       ├── yc2_duration_frame.csv
│       ├── yc2_train_anet_format.json
│       ├── yc2_val_anet_format.json
│       └── yc2_val_anet_format_para.json
├── figures
│   └── network.png
└── src
    ├── __init__.py
    ├── build_vocab.py
    ├── caption_eval.py
    ├── rtransformer
    │   ├── __init__.py
    │   ├── beam_search.py
    │   ├── decode_strategy.py
    │   ├── masked_transformer.py
    │   ├── model.py
    │   ├── optimization.py
    │   └── recursive_caption_dataset.py
    ├── train.py
    ├── translate.py
    ├── translator.py
    ├── utils
    │   ├── __init__.py
    │   ├── checkpoint.py
    │   ├── json.py
    │   ├── logging.py
    │   ├── register.py
    │   ├── train_utils.py
    │   └── writer.py
    └── utils_func.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/.DS_Store
--------------------------------------------------------------------------------
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Files ignored by default
2 | /shelf/
3 | /workspace.xml
4 | # Datasource local storage ignored files
5 | /dataSources/
6 | /dataSources.local.xml
7 | # Editor-based HTTP client requests
8 | /httpRequests/
9 |
--------------------------------------------------------------------------------
/.idea/TextKG.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/deployment.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
96 |
97 |
98 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/profiles_settings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # TextKG
2 |
3 | Text with Knowledge Graph Augmented Transformer for Video Captioning
4 |
5 | [[Paper]](https://arxiv.org/abs/2303.12423)
6 |
7 | Official code for **Text with Knowledge Graph Augmented Transformer for Video Captioning**.
8 |
9 | *Xin Gu, Guang Chen, Yufei Wang, Libo Zhang, Tiejian Luo, Longyin Wen*
10 |
11 | Accepted by CVPR 2023.
12 |
13 | ## Introduction
14 | Existing video captioning methods generally suffer from the long-tail word problem. We present TextKG, a knowledge graph (KG) augmented transformer for video captioning that integrates external knowledge and exploits multi-modal information in videos to address the challenge of long-tail words.
15 |
16 | ## Approach
17 | ### Knowledge Graphs Construction
18 | - The general knowledge graph (G-KG) is designed to cover the most important information in the general scenarios of interest, such as cooking and activities. It is built from the publicly available large-scale knowledge graph ConceptNet by extracting keywords from ConceptNet together with their connected edges and neighboring nodes.
19 | - The specific knowledge graph (S-KG) is built to cover key information in specific scenarios. We extract speech transcripts from videos using an automatic speech recognition (ASR) model and gather phrases such as "adjective + noun", "noun + noun", and "adverb + verb" to construct the S-KG (a toy extraction sketch is shown below).
20 |
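The exact mining pipeline is described in the paper; the snippet below is only a toy sketch of the POS-pattern idea, assuming NLTK's tokenizer and tagger (which are not part of this repo).

```python
# Toy sketch of POS-pattern phrase mining for the S-KG (illustrative only;
# the actual extraction pipeline may differ). Requires the NLTK "punkt" and
# "averaged_perceptron_tagger" data packages.
import nltk

PATTERNS = {("JJ", "NN"), ("NN", "NN"), ("RB", "VB")}  # adj+noun, noun+noun, adv+verb

def mine_phrases(transcript):
    tagged = nltk.pos_tag(nltk.word_tokenize(transcript))
    pairs = []
    for (w1, t1), (w2, t2) in zip(tagged, tagged[1:]):
        # compare coarse tag prefixes so NN/NNS, VB/VBD, JJ/JJR all match
        if (t1[:2], t2[:2]) in PATTERNS:
            pairs.append((w1.lower(), w2.lower()))
    return pairs

print(mine_phrases("now slowly stir the tomato sauce in the hot pan"))
# e.g. [('slowly', 'stir'), ('tomato', 'sauce'), ('hot', 'pan')], depending on the tagger
```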
21 | ### Two-Stream Transformer
22 | Our approach comprises an external stream that exploits external knowledge and an internal stream that leverages the multi-modal information in the video.
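The architecture figure below shows the actual design; the following is only a generic, minimal sketch of how two transformer streams can share information through cross-attention, not the TextKG implementation itself.

```python
# Generic two-stream sketch (illustrative only, NOT the actual TextKG model):
# each stream runs self-attention, then attends to the other stream.
import torch
import torch.nn as nn

class TwoStreamLayer(nn.Module):
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        self.self_int = nn.MultiheadAttention(d_model, n_heads)
        self.self_ext = nn.MultiheadAttention(d_model, n_heads)
        self.cross_int = nn.MultiheadAttention(d_model, n_heads)
        self.cross_ext = nn.MultiheadAttention(d_model, n_heads)
        self.norm = nn.LayerNorm(d_model)

    def forward(self, internal, external):
        # internal: multi-modal video/ASR tokens; external: knowledge-graph tokens
        internal = self.norm(internal + self.self_int(internal, internal, internal)[0])
        external = self.norm(external + self.self_ext(external, external, external)[0])
        # cross-attention lets knowledge flow between the two streams
        internal = self.norm(internal + self.cross_int(internal, external, external)[0])
        external = self.norm(external + self.cross_ext(external, internal, internal)[0])
        return internal, external

x_int = torch.randn(40, 2, 512)  # (seq_len, batch, d_model) internal-stream tokens
x_ext = torch.randn(20, 2, 512)  # external-stream (KG) tokens
out_int, out_ext = TwoStreamLayer()(x_int, x_ext)
```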
23 |
24 | ## Architecture
25 |
26 | ![TextKG network](figures/network.png)
27 |
28 | Figure 1. TextKG Network
29 |
30 | ## Usage
31 |
32 | Our proposed TextKG is implemented with PyTorch.
33 |
34 | #### Environment
35 |
36 | - Python = 3.7
37 | - PyTorch = 1.4
38 | - pycocoevalcap
39 |
40 | #### 1. Installation
41 |
42 | - Clone this repo:
43 |
44 | ```
45 | git clone https://github.com/GX77/TextKG.git
46 | cd TextKG
47 | ```
48 |
49 | #### 2. Download datasets
50 |
51 | - [YouCookII](https://drive.google.com/file/d/1mj76DwNexFCYovUt8BREeHccQn_z_By9/view?usp=sharing)
52 |
53 | ## Training & Testing
54 |
55 | #### YouCookII
56 |
57 | ```bash
58 | # Training
59 | python3 train.py --res_root_dir YOUR_DIR --dset_name yc2
60 |
61 | # Test
62 | python3 translate.py --res_dir YOUR_DIR
63 | ```
64 | We will add other datasets later.
65 |
66 |
67 | ## Citation
68 |
69 | If our research and this repository are helpful to your work, please cite:
70 |
71 | ```
72 | @InProceedings{Gu_2023_CVPR,
73 | author = {Gu, Xin and Chen, Guang and Wang, Yufei and Zhang, Libo and Luo, Tiejian and Wen, Longyin},
74 | title = {Text With Knowledge Graph Augmented Transformer for Video Captioning},
75 | booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
76 | month = {June},
77 | year = {2023},
78 | pages = {18941-18951}
79 | }
80 | ```
81 |
82 |
83 |
84 | ## Acknowledgements
85 |
86 | The decoding code is based on [MART](https://github.com/jayleicn/recurrent-transformer).
87 |
--------------------------------------------------------------------------------
/densevid_eval/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/.DS_Store
--------------------------------------------------------------------------------
/densevid_eval/LICENCE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Ranjay Krishna
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/densevid_eval/README.md:
--------------------------------------------------------------------------------
1 | # Dense Captioning Events in Video - Evaluation Code
2 |
3 | This is a modified copy of https://github.com/jamespark3922/densevid_eval,
4 | which is in turn a modified copy of [densevid_eval](https://github.com/ranjaykrishna/densevid_eval).
5 | Instead of using sentence-level metrics, we evaluate captions at the paragraph level,
6 | as described in [Move Forward and Tell (ECCV18)](https://arxiv.org/abs/1807.10018); a toy sketch of the idea follows below.
7 |
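The gist of the paragraph-level protocol (see `para-evaluate.py` for the real implementation) is that per-event captions are joined, in temporal order, into one paragraph per video before the captioning metrics are applied. A toy sketch:

```python
# Toy sketch of the paragraph-level idea (see para-evaluate.py for the real code):
# join per-event captions in temporal order into one paragraph per video,
# then apply the captioning metrics to the paragraphs.
def to_paragraph(events):
    """events: list of (start_time, sentence) pairs for one video."""
    ordered = sorted(events, key=lambda e: e[0])
    return " ".join(s.strip().rstrip(".") + "." for _, s in ordered)

print(to_paragraph([(12.0, "the person stirs the soup"),
                    (3.5, "a man chops onions")]))
# -> "a man chops onions. the person stirs the soup."
```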
8 | ## Usage
9 | ```
10 | python para-evaluate.py -s YOUR_SUBMISSION_FILE.JSON --verbose
11 | ```
12 |
13 | ## Paper
14 | Visit [the project page](http://cs.stanford.edu/people/ranjaykrishna/densevid) for details on ActivityNet Captions.
15 |
16 | ## Citation
17 | ```
18 | @inproceedings{krishna2017dense,
19 | title={Dense-Captioning Events in Videos},
20 | author={Krishna, Ranjay and Hata, Kenji and Ren, Frederic and Fei-Fei, Li and Niebles, Juan Carlos},
21 | booktitle={IEEE International Conference on Computer Vision (ICCV)},
22 | year={2017}
23 | }
24 | ```
25 |
26 | ## License
27 |
28 | MIT License copyright Ranjay Krishna
29 |
30 |
--------------------------------------------------------------------------------
/densevid_eval/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/__init__.py
--------------------------------------------------------------------------------
/densevid_eval/anet_data/readme.txt:
--------------------------------------------------------------------------------
1 | ANet-Entities val/test splits (re-split from ANet-caption val_1 and val_2 splits):
2 | https://dl.fbaipublicfiles.com/ActivityNet-Entities/ActivityNet-Entities/anet_entities_captions.tar.gz
3 |
4 | ANet-caption original splits:
5 | http://cs.stanford.edu/people/ranjaykrishna/densevid/captions.zip
6 |
7 | Experiment settings:
8 | Training: use GT segments/sentences in `train.json`,
9 | Validation: use GT segments in `anet_entities_val_1.json`, evaluate against references `anet_entities_val_1_para.json` and `anet_entities_val_2_para.json`
10 | Test: use GT segments in `anet_entities_test_1.json`, evaluate against references `anet_entities_test_1_para.json` and `anet_entities_test_2_para.json`
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/README.md:
--------------------------------------------------------------------------------
1 | Microsoft COCO Caption Evaluation
2 | ===================
3 |
4 | Evaluation codes for MS COCO caption generation.
5 |
6 | ## Requirements ##
7 | - java 1.8.0
8 | - python 2.7
9 |
10 | ## Files ##
11 | ./
12 | - cocoEvalCapDemo.py (demo script)
13 |
14 | ./annotation
15 | - captions_val2014.json (MS COCO 2014 caption validation set)
16 | - Visit MS COCO [download](http://mscoco.org/dataset/#download) page for more details.
17 |
18 | ./results
19 | - captions_val2014_fakecap_results.json (an example of fake results for running demo)
20 | - Visit MS COCO [format](http://mscoco.org/dataset/#format) page for more details.
21 |
22 | ./pycocoevalcap: The folder where all evaluation codes are stored.
23 | - eval.py: includes the COCOEvalCap class that can be used to evaluate results on COCO (a minimal usage sketch follows below)
24 | - tokenizer: Python wrapper of Stanford CoreNLP PTBTokenizer
25 | - bleu: BLEU evaluation codes
26 | - meteor: Meteor evaluation codes
27 | - rouge: Rouge-L evaluation codes
28 | - cider: CIDEr evaluation codes
29 |
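A minimal usage sketch of the evaluation API (file paths are placeholders; see cocoEvalCapDemo.ipynb for the full demo), assuming the pycocotools and pycocoevalcap packages are importable:

```python
# Minimal usage sketch (file paths are placeholders; see cocoEvalCapDemo.ipynb).
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

coco = COCO('annotations/captions_val2014.json')                          # ground-truth captions
coco_res = coco.loadRes('results/captions_val2014_fakecap_results.json')  # generated captions

coco_eval = COCOEvalCap(coco, coco_res)
coco_eval.params['image_id'] = coco_res.getImgIds()  # score only images present in the results
coco_eval.evaluate()

for metric, score in coco_eval.eval.items():
    print('%s: %.3f' % (metric, score))
```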
30 | ## References ##
31 |
32 | - [Microsoft COCO Captions: Data Collection and Evaluation Server](http://arxiv.org/abs/1504.00325)
33 | - PTBTokenizer: We use the [Stanford Tokenizer](http://nlp.stanford.edu/software/tokenizer.shtml) which is included in [Stanford CoreNLP 3.4.1](http://nlp.stanford.edu/software/corenlp.shtml).
34 | - BLEU: [BLEU: a Method for Automatic Evaluation of Machine Translation](http://www.aclweb.org/anthology/P02-1040.pdf)
35 | - Meteor: [Project page](http://www.cs.cmu.edu/~alavie/METEOR/) with related publications. We use the latest version (1.5) of the [Code](https://github.com/mjdenkowski/meteor). Changes have been made to the source code to properly aggregate the statistics for the entire corpus.
36 | - Rouge-L: [ROUGE: A Package for Automatic Evaluation of Summaries](http://anthology.aclweb.org/W/W04/W04-1013.pdf)
37 | - CIDEr: [CIDEr: Consensus-based Image Description Evaluation](http://arxiv.org/pdf/1411.5726.pdf)
38 |
39 | ## Developers ##
40 | - Xinlei Chen (CMU)
41 | - Hao Fang (University of Washington)
42 | - Tsung-Yi Lin (Cornell)
43 | - Ramakrishna Vedantam (Virginia Tech)
44 |
45 | ## Acknowledgement ##
46 | - David Chiang (University of Notre Dame)
47 | - Michael Denkowski (CMU)
48 | - Alexander Rush (Harvard University)
49 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/densevid_eval/coco-caption/__init__.py
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/license.txt:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015, Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 | All rights reserved.
3 |
4 | Redistribution and use in source and binary forms, with or without
5 | modification, are permitted provided that the following conditions are met:
6 |
7 | 1. Redistributions of source code must retain the above copyright notice, this
8 | list of conditions and the following disclaimer.
9 | 2. Redistributions in binary form must reproduce the above copyright notice,
10 | this list of conditions and the following disclaimer in the documentation
11 | and/or other materials provided with the distribution.
12 |
13 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
14 | ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
15 | WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
16 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
17 | ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
18 | (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
19 | LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
20 | ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
21 | (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
22 | SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23 |
24 | The views and conclusions contained in the software and documentation are those
25 | of the authors and should not be interpreted as representing official policies,
26 | either expressed or implied, of the FreeBSD Project.
27 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/bleu/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2 |
3 | Permission is hereby granted, free of charge, to any person obtaining a copy
4 | of this software and associated documentation files (the "Software"), to deal
5 | in the Software without restriction, including without limitation the rights
6 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7 | copies of the Software, and to permit persons to whom the Software is
8 | furnished to do so, subject to the following conditions:
9 |
10 | The above copyright notice and this permission notice shall be included in
11 | all copies or substantial portions of the Software.
12 |
13 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19 | THE SOFTWARE.
20 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/bleu/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/bleu/bleu.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : bleu.py
4 | #
5 | # Description : Wrapper for BLEU scorer.
6 | #
7 | # Creation Date : 06-01-2015
8 | # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 | import os, sys
11 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))  # locate bleu_scorer next to this file instead of a hard-coded absolute path
12 | from bleu_scorer import BleuScorer
13 |
14 |
15 | class Bleu:
16 | def __init__(self, n=4):
17 | # by default compute BLEU score up to 4-grams
18 | self._n = n
19 | self._hypo_for_image = {}
20 | self.ref_for_image = {}
21 |
22 | def compute_score(self, gts, res):
23 |
24 | assert(gts.keys() == res.keys())
25 | imgIds = gts.keys()
26 |
27 | bleu_scorer = BleuScorer(n=self._n)
28 | for id in imgIds:
29 | hypo = res[id]
30 | ref = gts[id]
31 |
32 | # Sanity check.
33 | assert(type(hypo) is list)
34 | assert(len(hypo) == 1)
35 | assert(type(ref) is list)
36 | assert(len(ref) >= 1)
37 |
38 | bleu_scorer += (hypo[0], ref)
39 |
40 | #score, scores = bleu_scorer.compute_score(option='shortest')
41 | score, scores = bleu_scorer.compute_score(option='closest', verbose=1)
42 | #score, scores = bleu_scorer.compute_score(option='average', verbose=1)
43 |
44 | # return (bleu, bleu_info)
45 | return score, scores
46 |
47 | def method(self):
48 | return "Bleu"
49 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/bleu/bleu_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # bleu_scorer.py
4 | # David Chiang
5 |
6 | # Copyright (c) 2004-2006 University of Maryland. All rights
7 | # reserved. Do not redistribute without permission from the
8 | # author. Not for commercial use.
9 |
10 | # Modified by:
11 | # Hao Fang
12 | # Tsung-Yi Lin
13 |
14 | '''Provides:
15 | cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16 | cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17 | '''
18 |
19 | import copy
20 | import sys, math, re
21 | from collections import defaultdict
22 |
23 | def precook(s, n=4, out=False):
24 | """Takes a string as input and returns an object that can be given to
25 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
26 | can take string arguments as well."""
27 | words = s.split()
28 | counts = defaultdict(int)
29 | for k in range(1,n+1):
30 | for i in range(len(words)-k+1):
31 | ngram = tuple(words[i:i+k])
32 | counts[ngram] += 1
33 | return (len(words), counts)
34 |
35 | def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36 | '''Takes a list of reference sentences for a single segment
37 | and returns an object that encapsulates everything that BLEU
38 | needs to know about them.'''
39 |
40 | reflen = []
41 | maxcounts = {}
42 | for ref in refs:
43 | rl, counts = precook(ref, n)
44 | reflen.append(rl)
45 | for (ngram,count) in counts.items():
46 | maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47 |
48 | # Calculate effective reference sentence length.
49 | if eff == "shortest":
50 | reflen = min(reflen)
51 | elif eff == "average":
52 | reflen = float(sum(reflen))/len(reflen)
53 |
54 | ## lhuang: N.B.: leave reflen computation to the very end!!
55 |
56 | ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57 |
58 | return (reflen, maxcounts)
59 |
60 | def cook_test(test, xxx_todo_changeme, eff=None, n=4):
61 | '''Takes a test sentence and returns an object that
62 | encapsulates everything that BLEU needs to know about it.'''
63 | (reflen, refmaxcounts) = xxx_todo_changeme
64 | testlen, counts = precook(test, n, True)
65 |
66 | result = {}
67 |
68 | # Calculate effective reference sentence length.
69 |
70 | if eff == "closest":
71 | result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
72 | else: ## i.e., "average" or "shortest" or None
73 | result["reflen"] = reflen
74 |
75 | result["testlen"] = testlen
76 |
77 | result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
78 |
79 | result['correct'] = [0]*n
80 | for (ngram, count) in counts.items():
81 | result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
82 |
83 | return result
84 |
85 | class BleuScorer(object):
86 | """Bleu scorer.
87 | """
88 |
89 | __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
90 | # special_reflen is used in oracle (proportional effective ref len for a node).
91 |
92 | def copy(self):
93 | ''' copy the refs.'''
94 | new = BleuScorer(n=self.n)
95 | new.ctest = copy.copy(self.ctest)
96 | new.crefs = copy.copy(self.crefs)
97 | new._score = None
98 | return new
99 |
100 | def __init__(self, test=None, refs=None, n=4, special_reflen=None):
101 | ''' singular instance '''
102 |
103 | self.n = n
104 | self.crefs = []
105 | self.ctest = []
106 | self.cook_append(test, refs)
107 | self.special_reflen = special_reflen
108 |
109 | def cook_append(self, test, refs):
110 | '''called by constructor and __iadd__ to avoid creating new instances.'''
111 |
112 | if refs is not None:
113 | self.crefs.append(cook_refs(refs))
114 | if test is not None:
115 | cooked_test = cook_test(test, self.crefs[-1])
116 | self.ctest.append(cooked_test) ## N.B.: -1
117 | else:
118 | self.ctest.append(None) # lens of crefs and ctest have to match
119 |
120 | self._score = None ## need to recompute
121 |
122 | def ratio(self, option=None):
123 | self.compute_score(option=option)
124 | return self._ratio
125 |
126 | def score_ratio(self, option=None):
127 | '''return (bleu, len_ratio) pair'''
128 | return (self.fscore(option=option), self.ratio(option=option))
129 |
130 | def score_ratio_str(self, option=None):
131 | return "%.4f (%.2f)" % self.score_ratio(option)
132 |
133 | def reflen(self, option=None):
134 | self.compute_score(option=option)
135 | return self._reflen
136 |
137 | def testlen(self, option=None):
138 | self.compute_score(option=option)
139 | return self._testlen
140 |
141 | def retest(self, new_test):
142 | if type(new_test) is str:
143 | new_test = [new_test]
144 | assert len(new_test) == len(self.crefs), new_test
145 | self.ctest = []
146 | for t, rs in zip(new_test, self.crefs):
147 | self.ctest.append(cook_test(t, rs))
148 | self._score = None
149 |
150 | return self
151 |
152 | def rescore(self, new_test):
153 | ''' replace test(s) with new test(s), and returns the new score.'''
154 |
155 | return self.retest(new_test).compute_score()
156 |
157 | def size(self):
158 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
159 | return len(self.crefs)
160 |
161 | def __iadd__(self, other):
162 | '''add an instance (e.g., from another sentence).'''
163 |
164 | if type(other) is tuple:
165 | ## avoid creating new BleuScorer instances
166 | self.cook_append(other[0], other[1])
167 | else:
168 | assert self.compatible(other), "incompatible BLEUs."
169 | self.ctest.extend(other.ctest)
170 | self.crefs.extend(other.crefs)
171 | self._score = None ## need to recompute
172 |
173 | return self
174 |
175 | def compatible(self, other):
176 | return isinstance(other, BleuScorer) and self.n == other.n
177 |
178 | def single_reflen(self, option="average"):
179 | return self._single_reflen(self.crefs[0][0], option)
180 |
181 | def _single_reflen(self, reflens, option=None, testlen=None):
182 |
183 | if option == "shortest":
184 | reflen = min(reflens)
185 | elif option == "average":
186 | reflen = float(sum(reflens))/len(reflens)
187 | elif option == "closest":
188 | reflen = min((abs(l-testlen), l) for l in reflens)[1]
189 | else:
190 | assert False, "unsupported reflen option %s" % option
191 |
192 | return reflen
193 |
194 | def recompute_score(self, option=None, verbose=0):
195 | self._score = None
196 | return self.compute_score(option, verbose)
197 |
198 | def compute_score(self, option=None, verbose=0):
199 | n = self.n
200 | small = 1e-9
201 | tiny = 1e-15 ## so that if guess is 0 still return 0
202 | bleu_list = [[] for _ in range(n)]
203 |
204 | if self._score is not None:
205 | return self._score
206 |
207 | if option is None:
208 | option = "average" if len(self.crefs) == 1 else "closest"
209 |
210 | self._testlen = 0
211 | self._reflen = 0
212 | totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
213 |
214 | # for each sentence
215 | for comps in self.ctest:
216 | testlen = comps['testlen']
217 | self._testlen += testlen
218 |
219 | if self.special_reflen is None: ## need computation
220 | reflen = self._single_reflen(comps['reflen'], option, testlen)
221 | else:
222 | reflen = self.special_reflen
223 |
224 | self._reflen += reflen
225 |
226 | for key in ['guess','correct']:
227 | for k in range(n):
228 | totalcomps[key][k] += comps[key][k]
229 |
230 | # append per image bleu score
231 | bleu = 1.
232 | for k in range(n):
233 | bleu *= (float(comps['correct'][k]) + tiny) \
234 | /(float(comps['guess'][k]) + small)
235 | bleu_list[k].append(bleu ** (1./(k+1)))
236 | ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
237 | if ratio < 1:
238 | for k in range(n):
239 | bleu_list[k][-1] *= math.exp(1 - 1/ratio)
240 |
241 | if verbose > 1:
242 | print(comps, reflen)
243 |
244 | totalcomps['reflen'] = self._reflen
245 | totalcomps['testlen'] = self._testlen
246 |
247 | bleus = []
248 | bleu = 1.
249 | for k in range(n):
250 | bleu *= float(totalcomps['correct'][k] + tiny) \
251 | / (totalcomps['guess'][k] + small)
252 | bleus.append(bleu ** (1./(k+1)))
253 | ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
254 | if ratio < 1:
255 | for k in range(n):
256 | bleus[k] *= math.exp(1 - 1/ratio)
257 |
258 | if verbose > 0:
259 | print(totalcomps)
260 | print("ratio:", ratio)
261 |
262 | self._score = bleus
263 | return self._score, bleu_list
264 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/cider/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/cider/cider.py:
--------------------------------------------------------------------------------
1 | # Filename: cider.py
2 | #
3 | # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4 | # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5 | #
6 | # Creation Date: Sun Feb 8 14:16:54 2015
7 | #
8 | # Authors: Ramakrishna Vedantam and Tsung-Yi Lin
9 | import os, sys
10 | sys.path.append(os.path.dirname(os.path.abspath(__file__)))  # locate cider_scorer next to this file instead of a hard-coded absolute path
11 | from cider_scorer import CiderScorer
12 | import pdb
13 |
14 | class Cider:
15 | """
16 | Main Class to compute the CIDEr metric
17 |
18 | """
19 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
20 | # set cider to sum over 1 to 4-grams
21 | self._n = n
22 | # set the standard deviation parameter for gaussian penalty
23 | self._sigma = sigma
24 |
25 | def compute_score(self, gts, res):
26 | """
27 | Main function to compute CIDEr score
28 | :param hypo_for_image (dict) : dictionary with key and value
29 | ref_for_image (dict) : dictionary with key and value
30 | :return: cider (float) : computed CIDEr score for the corpus
31 | """
32 |
33 | assert(gts.keys() == res.keys())
34 | imgIds = gts.keys()
35 |
36 | cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
37 |
38 | for id in imgIds:
39 | hypo = res[id]
40 | ref = gts[id]
41 |
42 | # Sanity check.
43 | assert(type(hypo) is list)
44 | assert(len(hypo) == 1)
45 | assert(type(ref) is list)
46 | assert(len(ref) > 0)
47 |
48 | cider_scorer += (hypo[0], ref)
49 |
50 | (score, scores) = cider_scorer.compute_score()
51 |
52 | return(score, scores)
53 |
54 | def method(self):
55 | return "CIDEr"
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/cider/cider_scorer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # Tsung-Yi Lin
3 | # Ramakrishna Vedantam
4 |
5 | import copy
6 | from collections import defaultdict
7 | import numpy as np
8 | import pdb
9 | import math
10 |
11 | def precook(s, n=4, out=False):
12 | """
13 | Takes a string as input and returns an object that can be given to
14 | either cook_refs or cook_test. This is optional: cook_refs and cook_test
15 | can take string arguments as well.
16 | :param s: string : sentence to be converted into ngrams
17 | :param n: int : number of ngrams for which representation is calculated
18 | :return: term frequency vector for occuring ngrams
19 | """
20 | words = s.split()
21 | counts = defaultdict(int)
22 | for k in range(1,n+1):
23 | for i in range(len(words)-k+1):
24 | ngram = tuple(words[i:i+k])
25 | counts[ngram] += 1
26 | return counts
27 |
28 | def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29 | '''Takes a list of reference sentences for a single segment
30 | and returns an object that encapsulates everything that BLEU
31 | needs to know about them.
32 | :param refs: list of string : reference sentences for some image
33 | :param n: int : number of ngrams for which (ngram) representation is calculated
34 | :return: result (list of dict)
35 | '''
36 | return [precook(ref, n) for ref in refs]
37 |
38 | def cook_test(test, n=4):
39 | '''Takes a test sentence and returns an object that
40 | encapsulates everything that BLEU needs to know about it.
41 | :param test: list of string : hypothesis sentence for some image
42 | :param n: int : number of ngrams for which (ngram) representation is calculated
43 | :return: result (dict)
44 | '''
45 | return precook(test, n, True)
46 |
47 | class CiderScorer(object):
48 | """CIDEr scorer.
49 | """
50 |
51 | def copy(self):
52 | ''' copy the refs.'''
53 | new = CiderScorer(n=self.n)
54 | new.ctest = copy.copy(self.ctest)
55 | new.crefs = copy.copy(self.crefs)
56 | return new
57 |
58 | def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59 | ''' singular instance '''
60 | self.n = n
61 | self.sigma = sigma
62 | self.crefs = []
63 | self.ctest = []
64 | self.document_frequency = defaultdict(float)
65 | self.cook_append(test, refs)
66 | self.ref_len = None
67 |
68 | def cook_append(self, test, refs):
69 | '''called by constructor and __iadd__ to avoid creating new instances.'''
70 |
71 | if refs is not None:
72 | self.crefs.append(cook_refs(refs))
73 | if test is not None:
74 | self.ctest.append(cook_test(test)) ## N.B.: -1
75 | else:
76 | self.ctest.append(None) # lens of crefs and ctest have to match
77 |
78 | def size(self):
79 | assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80 | return len(self.crefs)
81 |
82 | def __iadd__(self, other):
83 | '''add an instance (e.g., from another sentence).'''
84 |
85 | if type(other) is tuple:
86 | ## avoid creating new CiderScorer instances
87 | self.cook_append(other[0], other[1])
88 | else:
89 | self.ctest.extend(other.ctest)
90 | self.crefs.extend(other.crefs)
91 |
92 | return self
93 | def compute_doc_freq(self):
94 | '''
95 | Compute document frequency for the reference data.
96 | This will be used to compute idf (inverse document frequency) later.
97 | The document frequency is stored in the object.
98 | :return: None
99 | '''
100 | for refs in self.crefs:
101 | # refs, k ref captions of one image
102 | for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103 | self.document_frequency[ngram] += 1
104 | # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105 |
106 | def compute_cider(self):
107 | def counts2vec(cnts):
108 | """
109 | Function maps counts of ngram to vector of tfidf weights.
110 | The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111 | The n-th entry of array denotes length of n-grams.
112 | :param cnts:
113 | :return: vec (array of dict), norm (array of float), length (int)
114 | """
115 | vec = [defaultdict(float) for _ in range(self.n)]
116 | length = 0
117 | norm = [0.0 for _ in range(self.n)]
118 | for (ngram,term_freq) in cnts.items():
119 | # give word count 1 if it doesn't appear in reference corpus
120 | df = np.log(max(1.0, self.document_frequency[ngram]))
121 | # ngram index
122 | n = len(ngram)-1
123 | # tf (term_freq) * idf (precomputed idf) for n-grams
124 | vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125 | # compute norm for the vector. the norm will be used for computing similarity
126 | norm[n] += pow(vec[n][ngram], 2)
127 |
128 | if n == 1:
129 | length += term_freq
130 | norm = [np.sqrt(n) for n in norm]
131 | return vec, norm, length
132 |
133 | def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134 | '''
135 | Compute the cosine similarity of two vectors.
136 | :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137 | :param vec_ref: array of dictionary for vector corresponding to reference
138 | :param norm_hyp: array of float for vector corresponding to hypothesis
139 | :param norm_ref: array of float for vector corresponding to reference
140 | :param length_hyp: int containing length of hypothesis
141 | :param length_ref: int containing length of reference
142 | :return: array of score for each n-grams cosine similarity
143 | '''
144 | delta = float(length_hyp - length_ref)
145 | # measure cosine similarity
146 | val = np.array([0.0 for _ in range(self.n)])
147 | for n in range(self.n):
148 | # ngram
149 | for (ngram,count) in vec_hyp[n].items():
150 | # vrama91 : added clipping
151 | val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152 |
153 | if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154 | val[n] /= (norm_hyp[n]*norm_ref[n])
155 |
156 | assert(not math.isnan(val[n]))
157 | # vrama91: added a length based gaussian penalty
158 | val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159 | return val
160 |
161 | # compute log reference length
162 | self.ref_len = np.log(float(len(self.crefs)))
163 |
164 | scores = []
165 | for test, refs in zip(self.ctest, self.crefs):
166 | # compute vector for test captions
167 | vec, norm, length = counts2vec(test)
168 | # compute vector for ref captions
169 | score = np.array([0.0 for _ in range(self.n)])
170 | for ref in refs:
171 | vec_ref, norm_ref, length_ref = counts2vec(ref)
172 | score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
173 | # change by vrama91 - mean of ngram scores, instead of sum
174 | score_avg = np.mean(score)
175 | # divide by number of references
176 | score_avg /= len(refs)
177 | # multiply score by 10
178 | score_avg *= 10.0
179 | # append score of an image to the score list
180 | scores.append(score_avg)
181 | return scores
182 |
183 | def compute_score(self, option=None, verbose=0):
184 | # compute idf
185 | self.compute_doc_freq()
186 | # assert to check document frequency
187 | assert(len(self.ctest) >= max(self.document_frequency.values()))
188 | # compute cider score
189 | score = self.compute_cider()
190 | # debug
191 | # print score
192 | return np.mean(np.array(score)), np.array(score)
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/eval.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | from tokenizer.ptbtokenizer import PTBTokenizer
3 | from bleu.bleu import Bleu
4 | from meteor.meteor import Meteor
5 | from rouge.rouge import Rouge
6 | from cider.cider import Cider
7 |
8 | class COCOEvalCap:
9 | def __init__(self, coco, cocoRes):
10 | self.evalImgs = []
11 | self.eval = {}
12 | self.imgToEval = {}
13 | self.coco = coco
14 | self.cocoRes = cocoRes
15 | self.params = {'image_id': coco.getImgIds()}
16 |
17 | def evaluate(self):
18 | imgIds = self.params['image_id']
19 | # imgIds = self.coco.getImgIds()
20 | gts = {}
21 | res = {}
22 | for imgId in imgIds:
23 | gts[imgId] = self.coco.imgToAnns[imgId]
24 | res[imgId] = self.cocoRes.imgToAnns[imgId]
25 |
26 | # =================================================
27 | # Set up scorers
28 | # =================================================
29 | print('tokenization...')
30 | tokenizer = PTBTokenizer()
31 | gts = tokenizer.tokenize(gts)
32 | res = tokenizer.tokenize(res)
33 |
34 | # =================================================
35 | # Set up scorers
36 | # =================================================
37 | print('setting up scorers...')
38 | scorers = [
39 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
40 | (Meteor(),"METEOR"),
41 | (Rouge(), "ROUGE_L"),
42 | (Cider(), "CIDEr")
43 | ]
44 |
45 | # =================================================
46 | # Compute scores
47 | # =================================================
48 | for scorer, method in scorers:
49 | print('computing %s score...'%(scorer.method()))
50 | score, scores = scorer.compute_score(gts, res)
51 | if type(method) == list:
52 | for sc, scs, m in zip(score, scores, method):
53 | self.setEval(sc, m)
54 | self.setImgToEvalImgs(scs, gts.keys(), m)
55 | print('%s: %0.3f'%(m, sc))
56 | else:
57 | self.setEval(score, method)
58 | self.setImgToEvalImgs(scores, gts.keys(), method)
59 | print("%s: %0.3f"%(method, score))
60 | self.setEvalImgs()
61 |
62 | def setEval(self, score, method):
63 | self.eval[method] = score
64 |
65 | def setImgToEvalImgs(self, scores, imgIds, method):
66 | for imgId, score in zip(imgIds, scores):
67 | if not imgId in self.imgToEval:
68 | self.imgToEval[imgId] = {}
69 | self.imgToEval[imgId]["image_id"] = imgId
70 | self.imgToEval[imgId][method] = score
71 |
72 | def setEvalImgs(self):
73 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/meteor/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/meteor/meteor.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | # Python wrapper for METEOR implementation, by Xinlei Chen
4 | # Acknowledge Michael Denkowski for the generous discussion and help
5 |
6 | import os
7 | import sys
8 | import subprocess
9 | import threading
10 | from signal import signal, SIGPIPE, SIG_DFL
11 |
12 | # make Python ignore the SIGPIPE signal instead of raising an exception
13 | signal(SIGPIPE,SIG_DFL)
14 |
15 | # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
16 | METEOR_JAR = 'meteor-1.5.jar'
17 | # print METEOR_JAR
18 |
19 | class Meteor:
20 |
21 | def __init__(self):
22 | self.meteor_cmd = ['java', '-jar', '-Xmx2G', METEOR_JAR, \
23 | '-', '-', '-stdio', '-l', 'en', '-norm']
24 | self.meteor_p = subprocess.Popen(self.meteor_cmd, \
25 | cwd=os.path.dirname(os.path.abspath(__file__)), \
26 | stdin=subprocess.PIPE, \
27 | stdout=subprocess.PIPE, \
28 | stderr=subprocess.PIPE)
29 | # Used to guarantee thread safety
30 | self.lock = threading.Lock()
31 |
32 | def compute_score(self, gts, res):
33 | assert(gts.keys() == res.keys())
34 | imgIds = gts.keys()
35 | scores = []
36 |
37 | eval_line = 'EVAL'
38 | self.lock.acquire()
39 | for i in imgIds:
40 | assert(len(res[i]) == 1)
41 | stat = self._stat(res[i][0], gts[i])
42 | eval_line += ' ||| {}'.format(stat)
43 |
44 | self.meteor_p.stdin.write('{}\n'.format(eval_line).encode('utf-8')); self.meteor_p.stdin.flush()  # Python 3: the subprocess pipe expects bytes
45 | for i in range(0,len(imgIds)):
46 | scores.append(float(self.meteor_p.stdout.readline().strip()))
47 | score = float(self.meteor_p.stdout.readline().strip())
48 | self.lock.release()
49 |
50 | return score, scores
51 |
52 | def method(self):
53 | return "METEOR"
54 |
55 | def _stat(self, hypothesis_str, reference_list):
56 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
57 | hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ')
58 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
59 | w = bytes(score_line + '\n', encoding='utf-8')  # METEOR's stdio protocol expects one request per line
60 | self.meteor_p.stdin.write(w); self.meteor_p.stdin.flush()
61 | return self.meteor_p.stdout.readline().decode().strip()
62 |
63 |
64 | def _score(self, hypothesis_str, reference_list):
65 | self.lock.acquire()
66 | # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
67 | hypothesis_str = hypothesis_str.replace('|||','').replace(' ',' ')
68 | score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
69 | self.meteor_p.stdin.write('{}\n'.format(score_line))
70 | stats = self.meteor_p.stdout.readline().strip()
71 | eval_line = 'EVAL ||| {}'.format(stats)
72 | # EVAL ||| stats
73 | self.meteor_p.stdin.write('{}\n'.format(eval_line))
74 | score = float(self.meteor_p.stdout.readline().strip())
75 | # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice
76 | # thanks for Andrej for pointing this out
77 | score = float(self.meteor_p.stdout.readline().strip())
78 | self.lock.release()
79 | return score
80 |
81 | def __del__(self):
82 | self.lock.acquire()
83 | self.meteor_p.stdin.close()
84 | self.meteor_p.kill()
85 | self.meteor_p.wait()
86 | self.lock.release()
87 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/rouge/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'vrama91'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/rouge/rouge.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : rouge.py
4 | #
5 | # Description : Computes ROUGE-L metric as described by Lin and Hovey (2004)
6 | #
7 | # Creation Date : 2015-01-07 06:03
8 | # Author : Ramakrishna Vedantam
9 |
10 | import numpy as np
11 | import pdb
12 |
13 | def my_lcs(string, sub):
14 | """
15 | Calculates longest common subsequence for a pair of tokenized strings
16 | :param string : list of str : tokens from a string split using whitespace
17 | :param sub : list of str : shorter string, also split using whitespace
18 | :returns: length (list of int): length of the longest common subsequence between the two strings
19 |
20 | Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21 | """
22 | if(len(string)< len(sub)):
23 | sub, string = string, sub
24 |
25 | lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26 |
27 | for j in range(1,len(sub)+1):
28 | for i in range(1,len(string)+1):
29 | if(string[i-1] == sub[j-1]):
30 | lengths[i][j] = lengths[i-1][j-1] + 1
31 | else:
32 | lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33 |
34 | return lengths[len(string)][len(sub)]
35 |
36 | class Rouge():
37 | '''
38 | Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39 |
40 | '''
41 | def __init__(self):
42 | # vrama91: updated the value below based on discussion with Hovey
43 | self.beta = 1.2
44 |
45 | def calc_score(self, candidate, refs):
46 | """
47 | Compute ROUGE-L score given one candidate and references for an image
48 | :param candidate: str : candidate sentence to be evaluated
49 | :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50 | :returns score: int (ROUGE-L score for the candidate evaluated against references)
51 | """
52 | assert(len(candidate)==1)
53 | assert(len(refs)>0)
54 | prec = []
55 | rec = []
56 |
57 | # split into tokens
58 | token_c = candidate[0].split(" ")
59 |
60 | for reference in refs:
61 | # split into tokens
62 | token_r = reference.split(" ")
63 | # compute the longest common subsequence
64 | lcs = my_lcs(token_r, token_c)
65 | prec.append(lcs/float(len(token_c)))
66 | rec.append(lcs/float(len(token_r)))
67 |
68 | prec_max = max(prec)
69 | rec_max = max(rec)
70 |
71 | if(prec_max!=0 and rec_max !=0):
72 | score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73 | else:
74 | score = 0.0
75 | return score
76 |
77 | def compute_score(self, gts, res):
78 | """
79 | Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80 | Invoked by evaluate_captions.py
81 | :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
82 | :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83 | :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84 | """
85 | assert(gts.keys() == res.keys())
86 | imgIds = gts.keys()
87 |
88 | score = []
89 | for id in imgIds:
90 | hypo = res[id]
91 | ref = gts[id]
92 |
93 | score.append(self.calc_score(hypo, ref))
94 |
95 | # Sanity check.
96 | assert(type(hypo) is list)
97 | assert(len(hypo) == 1)
98 | assert(type(ref) is list)
99 | assert(len(ref) > 0)
100 |
101 | average_score = np.mean(np.array(score))
102 | return average_score, np.array(score)
103 |
104 | def method(self):
105 | return "Rouge"
106 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/tokenizer/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'hfang'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocoevalcap/tokenizer/ptbtokenizer.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | #
3 | # File Name : ptbtokenizer.py
4 | #
5 | # Description : Do the PTB Tokenization and remove punctuations.
6 | #
7 | # Creation Date : 29-12-2014
8 | # Last Modified : Thu Mar 19 09:53:35 2015
9 | # Authors : Hao Fang and Tsung-Yi Lin
10 |
11 | import os
12 | import sys
13 | import subprocess
14 | import tempfile
15 | import itertools
16 |
17 | # path to the stanford corenlp jar
18 | STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19 |
20 | # punctuations to be removed from the sentences
21 | PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22 | ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23 |
24 | class PTBTokenizer:
25 | """Python wrapper of Stanford PTBTokenizer"""
26 |
27 | def tokenize(self, captions_for_image):
28 | cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29 | 'edu.stanford.nlp.process.PTBTokenizer', \
30 | '-preserveLines', '-lowerCase']
31 |
32 | # ======================================================
33 | # prepare data for PTB Tokenizer
34 | # ======================================================
35 | final_tokenized_captions_for_image = {}
36 | image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
37 | sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
38 |
39 | # ======================================================
40 | # save sentences to temporary file
41 | # ======================================================
42 | path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43 | tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname)
44 | tmp_file.write(sentences.encode('utf-8'))  # NamedTemporaryFile is opened in binary mode by default
45 | tmp_file.close()
46 |
47 | # ======================================================
48 | # tokenize sentence
49 | # ======================================================
50 | cmd.append(os.path.basename(tmp_file.name))
51 | p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52 | stdout=subprocess.PIPE)
53 | token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0].decode('utf-8')  # decode bytes output before splitting
54 | lines = token_lines.split('\n')
55 | # remove temp file
56 | os.remove(tmp_file.name)
57 |
58 | # ======================================================
59 | # create dictionary for tokenized captions
60 | # ======================================================
61 | for k, line in zip(image_id, lines):
62 | if not k in final_tokenized_captions_for_image:
63 | final_tokenized_captions_for_image[k] = []
64 | tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
65 | if w not in PUNCTUATIONS])
66 | final_tokenized_captions_for_image[k].append(tokenized_caption)
67 |
68 | return final_tokenized_captions_for_image
69 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocotools/__init__.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 |
--------------------------------------------------------------------------------
/densevid_eval/coco-caption/pycocotools/coco.py:
--------------------------------------------------------------------------------
1 | __author__ = 'tylin'
2 | __version__ = '1.0.1'
3 | # Interface for accessing the Microsoft COCO dataset.
4 |
5 | # Microsoft COCO is a large image dataset designed for object detection,
6 | # segmentation, and caption generation. pycocotools is a Python API that
7 | # assists in loading, parsing and visualizing the annotations in COCO.
8 | # Please visit http://mscoco.org/ for more information on COCO, including
9 | # for the data, paper, and tutorials. The exact format of the annotations
10 | # is also described on the COCO website. For example usage of the pycocotools
11 | # please see pycocotools_demo.ipynb. In addition to this API, please download both
12 | # the COCO images and annotations in order to run the demo.
13 |
14 | # An alternative to using the API is to load the annotations directly
15 | # into Python dictionary
16 | # Using the API provides additional utility functions. Note that this API
17 | # supports both *instance* and *caption* annotations. In the case of
18 | # captions not all functions are defined (e.g. categories are undefined).
19 |
20 | # The following API functions are defined:
21 | # COCO - COCO api class that loads COCO annotation file and prepare data structures.
22 | # decodeMask - Decode binary mask M encoded via run-length encoding.
23 | # encodeMask - Encode binary mask M using run-length encoding.
24 | # getAnnIds - Get ann ids that satisfy given filter conditions.
25 | # getCatIds - Get cat ids that satisfy given filter conditions.
26 | # getImgIds - Get img ids that satisfy given filter conditions.
27 | # loadAnns - Load anns with the specified ids.
28 | # loadCats - Load cats with the specified ids.
29 | # loadImgs - Load imgs with the specified ids.
30 | # segToMask - Convert polygon segmentation to binary mask.
31 | # showAnns - Display the specified annotations.
32 | # loadRes - Load result file and create result api object.
33 | # Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
34 | # Help on each functions can be accessed by: "help COCO>function".
35 |
36 | # See also COCO>decodeMask,
37 | # COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
38 | # COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
39 | # COCO>loadImgs, COCO>segToMask, COCO>showAnns
40 |
41 | # Microsoft COCO Toolbox. Version 1.0
42 | # Data, paper, and tutorials available at: http://mscoco.org/
43 | # Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
44 | # Licensed under the Simplified BSD License [see bsd.txt]
45 |
46 | import json
47 | import datetime
48 | import matplotlib.pyplot as plt
49 | from matplotlib.collections import PatchCollection
50 | from matplotlib.patches import Polygon
51 | import numpy as np
52 | from skimage.draw import polygon
53 | import copy
54 |
55 | class COCO:
56 | def __init__(self, annotation_file=None):
57 | """
58 | Constructor of Microsoft COCO helper class for reading and visualizing annotations.
59 | :param annotation_file (str): location of annotation file
60 | :param image_folder (str): location to the folder that hosts images.
61 | :return:
62 | """
63 | # load dataset
64 | self.dataset = {}
65 | self.anns = []
66 | self.imgToAnns = {}
67 | self.catToImgs = {}
68 | self.imgs = []
69 | self.cats = []
70 | if not annotation_file == None:
71 | print('loading annotations into memory...')
72 | time_t = datetime.datetime.utcnow()
73 | dataset = json.load(open(annotation_file, 'r'))
74 | print(datetime.datetime.utcnow() - time_t)
75 | self.dataset = dataset
76 | self.createIndex()
77 |
78 | def createIndex(self):
79 | # create index
80 | print('creating index...')
81 | imgToAnns = {ann['image_id']: [] for ann in self.dataset['annotations']}
82 | anns = {ann['id']: [] for ann in self.dataset['annotations']}
83 | for ann in self.dataset['annotations']:
84 | imgToAnns[ann['image_id']] += [ann]
85 | anns[ann['id']] = ann
86 |
87 | imgs = {im['id']: {} for im in self.dataset['images']}
88 | for img in self.dataset['images']:
89 | imgs[img['id']] = img
90 |
91 | cats = []
92 | catToImgs = []
93 | if self.dataset['type'] == 'instances':
94 | cats = {cat['id']: [] for cat in self.dataset['categories']}
95 | for cat in self.dataset['categories']:
96 | cats[cat['id']] = cat
97 | catToImgs = {cat['id']: [] for cat in self.dataset['categories']}
98 | for ann in self.dataset['annotations']:
99 | catToImgs[ann['category_id']] += [ann['image_id']]
100 |
101 | print('index created!')
102 |
103 | # create class members
104 | self.anns = anns
105 | self.imgToAnns = imgToAnns
106 | self.catToImgs = catToImgs
107 | self.imgs = imgs
108 | self.cats = cats
109 |
110 | def info(self):
111 | """
112 | Print information about the annotation file.
113 | :return:
114 | """
115 | for key, value in self.dataset['info'].items():
116 | print('%s: %s'%(key, value))
117 |
118 | def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
119 | """
120 | Get ann ids that satisfy given filter conditions. default skips that filter
121 | :param imgIds (int array) : get anns for given imgs
122 | catIds (int array) : get anns for given cats
123 | areaRng (float array) : get anns for given area range (e.g. [0 inf])
124 | iscrowd (boolean) : get anns for given crowd label (False or True)
125 | :return: ids (int array) : integer array of ann ids
126 | """
127 | imgIds = imgIds if type(imgIds) == list else [imgIds]
128 | catIds = catIds if type(catIds) == list else [catIds]
129 |
130 | if len(imgIds) == len(catIds) == len(areaRng) == 0:
131 | anns = self.dataset['annotations']
132 | else:
133 | if not len(imgIds) == 0:
134 | anns = sum([self.imgToAnns[imgId] for imgId in imgIds if imgId in self.imgToAnns],[])
135 | else:
136 | anns = self.dataset['annotations']
137 | anns = anns if len(catIds) == 0 else [ann for ann in anns if ann['category_id'] in catIds]
138 | anns = anns if len(areaRng) == 0 else [ann for ann in anns if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]]
139 | if self.dataset['type'] == 'instances':
140 | if not iscrowd == None:
141 | ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
142 | else:
143 | ids = [ann['id'] for ann in anns]
144 | else:
145 | ids = [ann['id'] for ann in anns]
146 | return ids
147 |
148 | def getCatIds(self, catNms=[], supNms=[], catIds=[]):
149 | """
150 | Get cat ids that satisfy the given filtering parameters. default skips that filter.
151 | :param catNms (str array) : get cats for given cat names
152 | :param supNms (str array) : get cats for given supercategory names
153 | :param catIds (int array) : get cats for given cat ids
154 | :return: ids (int array) : integer array of cat ids
155 | """
156 | catNms = catNms if type(catNms) == list else [catNms]
157 | supNms = supNms if type(supNms) == list else [supNms]
158 | catIds = catIds if type(catIds) == list else [catIds]
159 |
160 | if len(catNms) == len(supNms) == len(catIds) == 0:
161 | cats = self.dataset['categories']
162 | else:
163 | cats = self.dataset['categories']
164 | cats = cats if len(catNms) == 0 else [cat for cat in cats if cat['name'] in catNms]
165 | cats = cats if len(supNms) == 0 else [cat for cat in cats if cat['supercategory'] in supNms]
166 | cats = cats if len(catIds) == 0 else [cat for cat in cats if cat['id'] in catIds]
167 | ids = [cat['id'] for cat in cats]
168 | return ids
169 |
170 | def getImgIds(self, imgIds=[], catIds=[]):
171 | '''
172 | Get img ids that satisfy given filter conditions.
173 | :param imgIds (int array) : get imgs for given ids
174 | :param catIds (int array) : get imgs with all given cats
175 | :return: ids (int array) : integer array of img ids
176 | '''
177 | imgIds = imgIds if type(imgIds) == list else [imgIds]
178 | catIds = catIds if type(catIds) == list else [catIds]
179 |
180 | if len(imgIds) == len(catIds) == 0:
181 | ids = self.imgs.keys()
182 | else:
183 | ids = set(imgIds)
184 | for catId in catIds:
185 | if len(ids) == 0:
186 | ids = set(self.catToImgs[catId])
187 | else:
188 | ids &= set(self.catToImgs[catId])
189 | return list(ids)
190 |
191 | def loadAnns(self, ids=[]):
192 | """
193 | Load anns with the specified ids.
194 | :param ids (int array) : integer ids specifying anns
195 | :return: anns (object array) : loaded ann objects
196 | """
197 | if type(ids) == list:
198 | return [self.anns[id] for id in ids]
199 | elif type(ids) == int:
200 | return [self.anns[ids]]
201 |
202 | def loadCats(self, ids=[]):
203 | """
204 | Load cats with the specified ids.
205 | :param ids (int array) : integer ids specifying cats
206 | :return: cats (object array) : loaded cat objects
207 | """
208 | if type(ids) == list:
209 | return [self.cats[id] for id in ids]
210 | elif type(ids) == int:
211 | return [self.cats[ids]]
212 |
213 | def loadImgs(self, ids=[]):
214 | """
215 | Load imgs with the specified ids.
216 | :param ids (int array) : integer ids specifying img
217 | :return: imgs (object array) : loaded img objects
218 | """
219 | if type(ids) == list:
220 | return [self.imgs[id] for id in ids]
221 | elif type(ids) == int:
222 | return [self.imgs[ids]]
223 |
224 | def showAnns(self, anns):
225 | """
226 | Display the specified annotations.
227 | :param anns (array of object): annotations to display
228 | :return: None
229 | """
230 | if len(anns) == 0:
231 | return 0
232 | if self.dataset['type'] == 'instances':
233 | ax = plt.gca()
234 | polygons = []
235 | color = []
236 | for ann in anns:
237 | c = np.random.random((1, 3)).tolist()[0]
238 | if type(ann['segmentation']) == list:
239 | # polygon
240 | for seg in ann['segmentation']:
241 | poly = np.array(seg).reshape((len(seg)//2, 2))
242 | polygons.append(Polygon(poly, True,alpha=0.4))
243 | color.append(c)
244 | else:
245 | # mask
246 | mask = COCO.decodeMask(ann['segmentation'])
247 | img = np.ones( (mask.shape[0], mask.shape[1], 3) )
248 | if ann['iscrowd'] == 1:
249 | color_mask = np.array([2.0,166.0,101.0])/255
250 | if ann['iscrowd'] == 0:
251 | color_mask = np.random.random((1, 3)).tolist()[0]
252 | for i in range(3):
253 | img[:,:,i] = color_mask[i]
254 | ax.imshow(np.dstack( (img, mask*0.5) ))
255 | p = PatchCollection(polygons, facecolors=color, edgecolors=(0,0,0,1), linewidths=3, alpha=0.4)
256 | ax.add_collection(p)
257 | if self.dataset['type'] == 'captions':
258 | for ann in anns:
259 | print(ann['caption'])
260 |
261 | def loadRes(self, resFile):
262 | """
263 | Load result file and return a result api object.
264 | :param resFile (str) : file name of result file
265 | :return: res (obj) : result api object
266 | """
267 | res = COCO()
268 | res.dataset['images'] = [img for img in self.dataset['images']]
269 | res.dataset['info'] = copy.deepcopy(self.dataset['info'])
270 | res.dataset['type'] = copy.deepcopy(self.dataset['type'])
271 | res.dataset['licenses'] = copy.deepcopy(self.dataset['licenses'])
272 |
273 | print('Loading and preparing results... ')
274 | time_t = datetime.datetime.utcnow()
275 | anns = json.load(open(resFile))
276 | assert type(anns) == list, 'results is not an array of objects'
277 | annsImgIds = [ann['image_id'] for ann in anns]
278 | assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
279 | 'Results do not correspond to current coco set'
280 | if 'caption' in anns[0]:
281 | imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns])
282 | res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds]
283 | for id, ann in enumerate(anns):
284 | ann['id'] = id
285 | elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
286 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
287 | for id, ann in enumerate(anns):
288 | bb = ann['bbox']
289 | x1, x2, y1, y2 = [bb[0], bb[0]+bb[2], bb[1], bb[1]+bb[3]]
290 | ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
291 | ann['area'] = bb[2]*bb[3]
292 | ann['id'] = id
293 | ann['iscrowd'] = 0
294 | elif 'segmentation' in anns[0]:
295 | res.dataset['categories'] = copy.deepcopy(self.dataset['categories'])
296 | for id, ann in enumerate(anns):
297 | ann['area']=sum(ann['segmentation']['counts'][2:-1:2])
298 | ann['bbox'] = []
299 | ann['id'] = id
300 | ann['iscrowd'] = 0
301 | print('DONE (t=%0.2fs)'%((datetime.datetime.utcnow() - time_t).total_seconds()))
302 |
303 | res.dataset['annotations'] = anns
304 | res.createIndex()
305 | return res
306 |
307 |
308 | @staticmethod
309 | def decodeMask(R):
310 | """
311 | Decode binary mask M encoded via run-length encoding.
312 | :param R (object RLE) : run-length encoding of binary mask
313 | :return: M (bool 2D array) : decoded binary mask
314 | """
315 | N = len(R['counts'])
316 | M = np.zeros( (R['size'][0]*R['size'][1], ))
317 | n = 0
318 | val = 1
319 | for pos in range(N):
320 | val = not val
321 | for c in range(R['counts'][pos]):
322 | R['counts'][pos]
323 | M[n] = val
324 | n += 1
325 | return M.reshape((R['size']), order='F')
326 |
327 | @staticmethod
328 | def encodeMask(M):
329 | """
330 | Encode binary mask M using run-length encoding.
331 | :param M (bool 2D array) : binary mask to encode
332 | :return: R (object RLE) : run-length encoding of binary mask
333 | """
334 | [h, w] = M.shape
335 | M = M.flatten(order='F')
336 | N = len(M)
337 | counts_list = []
338 | pos = 0
339 | # counts
340 | counts_list.append(1)
341 | diffs = np.logical_xor(M[0:N-1], M[1:N])
342 | for diff in diffs:
343 | if diff:
344 | pos +=1
345 | counts_list.append(1)
346 | else:
347 | counts_list[pos] += 1
348 | # if array starts from 1. start with 0 counts for 0
349 | if M[0] == 1:
350 | counts_list = [0] + counts_list
351 | return {'size': [h, w],
352 | 'counts': counts_list ,
353 | }
354 |
355 | @staticmethod
356 | def segToMask( S, h, w ):
357 | """
358 | Convert polygon segmentation to binary mask.
359 | :param S (float array) : polygon segmentation mask
360 | :param h (int) : target mask height
361 | :param w (int) : target mask width
362 | :return: M (bool 2D array) : binary mask
363 | """
364 | M = np.zeros((h,w), dtype=bool)
365 | for s in S:
366 | N = len(s)
367 | rr, cc = polygon(np.array(s[1:N:2]), np.array(s[0:N:2])) # (y, x)
368 | M[rr, cc] = 1
369 | return M
--------------------------------------------------------------------------------
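
A quick way to sanity-check the two RLE helpers above is a round trip: encode a small boolean mask and decode it back. The sketch below is a minimal illustration (it assumes the pycocotools package bundled in this repo, together with its matplotlib/skimage imports, is importable); the column-major `order='F'` flattening is what makes the round trip exact.

    import numpy as np
    from pycocotools.coco import COCO

    # toy binary mask with a single rectangular blob (values invented for illustration)
    mask = np.zeros((4, 6), dtype=bool)
    mask[1:3, 2:5] = True

    rle = COCO.encodeMask(mask)        # {'size': [4, 6], 'counts': [...]}
    recovered = COCO.decodeMask(rle)   # array of 0s/1s with shape (4, 6)

    assert np.array_equal(recovered.astype(bool), mask)
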
/densevid_eval/evaluate.py:
--------------------------------------------------------------------------------
1 | # --------------------------------------------------------
2 | # Dense-Captioning Events in Videos Eval
3 | # Copyright (c) 2017 Ranjay Krishna
4 | # Licensed under The MIT License [see LICENSE for details]
5 | # Written by Ranjay Krishna
6 | # --------------------------------------------------------
7 |
8 | import argparse
9 | import string
10 | import json
11 | import sys
12 | sys.path.insert(0, './coco-caption') # Hack to allow the import of pycocoeval
13 |
14 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
15 | from pycocoevalcap.bleu.bleu import Bleu
16 | from pycocoevalcap.meteor.meteor import Meteor
17 | from pycocoevalcap.rouge.rouge import Rouge
18 | from pycocoevalcap.cider.cider import Cider
19 | from pycocoevalcap.spice.spice import Spice
20 | # the Python 2 "sets" module is unavailable in Python 3; the built-in set is used below
21 | import numpy as np
22 |
23 | def remove_nonascii(text):
24 | return ''.join([i if ord(i) < 128 else ' ' for i in text])
25 |
26 | class ANETcaptions(object):
27 | PREDICTION_FIELDS = ['results', 'version', 'external_data']
28 |
29 | def __init__(self, ground_truth_filenames=None, prediction_filename=None,
30 | tious=None, max_proposals=1000,
31 | prediction_fields=PREDICTION_FIELDS, verbose=False):
32 | # Check that the gt and submission files exist and load them
33 | if len(tious) == 0:
34 | raise IOError('Please input a valid tIoU.')
35 | if not ground_truth_filenames:
36 | raise IOError('Please input a valid ground truth file.')
37 | if not prediction_filename:
38 | raise IOError('Please input a valid prediction file.')
39 |
40 | self.verbose = verbose
41 | self.tious = tious
42 | self.max_proposals = max_proposals
43 | self.pred_fields = prediction_fields
44 | self.ground_truths = self.import_ground_truths(ground_truth_filenames)
45 | self.prediction = self.import_prediction(prediction_filename)
46 | self.tokenizer = PTBTokenizer()
47 |
48 | # Set up scorers, if not verbose, we only use the one we're
49 | # testing on: METEOR
50 | if self.verbose:
51 | self.scorers = [
52 | (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
53 | (Meteor(),"METEOR"),
54 | (Rouge(), "ROUGE_L"),
55 | (Cider(), "CIDEr"),
56 | (Spice(), "SPICE")
57 | ]
58 | else:
59 | self.scorers = [(Meteor(), "METEOR")]
60 |
61 | def import_prediction(self, prediction_filename):
62 | if self.verbose:
63 | print("| Loading submission...")
64 | submission = json.load(open(prediction_filename))
65 | if not all([field in submission.keys() for field in self.pred_fields]):
66 | raise IOError('Please input a valid prediction file.')
67 | # Ensure that every video is limited to the correct maximum number of proposals.
68 | results = {}
69 | len_captions = 0
70 | for vid_id in submission['results']:
71 | results[vid_id] = submission['results'][vid_id][:self.max_proposals]
72 | len_captions+= len(submission['results'][vid_id][:self.max_proposals])
73 | print('len of results:', len(results))
74 | print('len of captions:', len_captions)
75 | return results
76 |
77 | def import_ground_truths(self, filenames):
78 | gts = []
79 | self.n_ref_vids = set()
80 | for filename in filenames:
81 | gt = json.load(open(filename))
82 | self.n_ref_vids.update(gt.keys())
83 | gts.append(gt)
84 | if self.verbose:
85 | print("| Loading GT. #files: %d, #videos: %d" % (len(filenames), len(self.n_ref_vids)))
86 | return gts
87 |
88 | def iou(self, interval_1, interval_2):
89 | start_i, end_i = interval_1[0], interval_1[1]
90 | start, end = interval_2[0], interval_2[1]
91 | intersection = max(0, min(end, end_i) - max(start, start_i))
92 | union = min(max(end, end_i) - min(start, start_i), end-start + end_i-start_i)
93 | iou = float(intersection) / (union + 1e-8)
94 | return iou
95 |
96 | def check_gt_exists(self, vid_id):
97 | for gt in self.ground_truths:
98 | if vid_id in gt:
99 | return True
100 | return False
101 |
102 | def get_gt_vid_ids(self):
103 | vid_ids = set([])
104 | for gt in self.ground_truths:
105 | vid_ids |= set(gt.keys())
106 | return list(vid_ids)
107 |
108 | def evaluate(self):
109 | aggregator = {}
110 | self.scores = {}
111 | for tiou in self.tious:
112 | scores = self.evaluate_tiou(tiou)
113 | for metric, score in scores.items():
114 | if metric not in self.scores:
115 | self.scores[metric] = []
116 | self.scores[metric].append(score)
117 | if self.verbose:
118 | self.scores['Recall'] = []
119 | self.scores['Precision'] = []
120 | for tiou in self.tious:
121 | precision, recall = self.evaluate_detection(tiou)
122 | self.scores['Recall'].append(recall)
123 | self.scores['Precision'].append(precision)
124 |
125 | def evaluate_detection(self, tiou):
126 | gt_vid_ids = self.get_gt_vid_ids()
127 | # Recall is the percentage of ground truth that is covered by the predictions
128 | # Precision is the percentage of predictions that are valid
129 | recall = [0] * len(gt_vid_ids)
130 | precision = [0] * len(gt_vid_ids)
131 | for vid_i, vid_id in enumerate(gt_vid_ids):
132 | best_recall = 0
133 | best_precision = 0
134 | for gt in self.ground_truths:
135 | if vid_id not in gt:
136 | continue
137 | refs = gt[vid_id]
138 | ref_set_covered = set([])
139 | pred_set_covered = set([])
140 | num_gt = 0
141 | num_pred = 0
142 | if vid_id in self.prediction:
143 | for pred_i, pred in enumerate(self.prediction[vid_id]):
144 | pred_timestamp = pred['timestamp']
145 | for ref_i, ref_timestamp in enumerate(refs['timestamps']):
146 | if self.iou(pred_timestamp, ref_timestamp) > tiou:
147 | ref_set_covered.add(ref_i)
148 | pred_set_covered.add(pred_i)
149 |
150 | new_precision = float(len(pred_set_covered)) / (pred_i + 1)
151 | best_precision = max(best_precision, new_precision)
152 | new_recall = float(len(ref_set_covered)) / len(refs['timestamps'])
153 | best_recall = max(best_recall, new_recall)
154 | recall[vid_i] = best_recall
155 | precision[vid_i] = best_precision
156 | return sum(precision) / len(precision), sum(recall) / len(recall)
157 |
158 | def evaluate_tiou(self, tiou):
159 | # This method averages the tIoU precision from METEOR, Bleu, etc. across videos
160 | res = {}
161 | gts = {}
162 | gt_vid_ids = self.get_gt_vid_ids()
163 |
164 | unique_index = 0
165 |
166 | # video id to unique caption ids mapping
167 | vid2capid = {}
168 |
169 | cur_res = {}
170 | cur_gts = {}
171 |
172 | for vid_id in gt_vid_ids:
173 |
174 | vid2capid[vid_id] = []
175 |
176 | # If the video does not have a prediction, then we give it no matches
177 | # We set it to empty, and use this as a sanity check later on
178 | if vid_id not in self.prediction:
179 | pass
180 |
181 | # If we do have a prediction, then we find the scores based on all the
182 | # valid tIoU overlaps
183 | else:
184 | # For each prediction, we look at the tIoU with ground truth
185 | for i,pred in enumerate(self.prediction[vid_id]):
186 | has_added = False
187 | for gt in self.ground_truths:
188 | if vid_id not in gt:
189 | print('skipped')
190 | continue
191 | gt_captions = gt[vid_id]
192 | for caption_idx, caption_timestamp in enumerate(gt_captions['timestamps']):
193 | if True or self.iou(pred['timestamp'], caption_timestamp) >= tiou:
194 | gt_caption = gt_captions['sentences'][i] # for now we use gt proposal
195 | cur_res[unique_index] = [{'caption': remove_nonascii(pred['sentence'])}]
196 | cur_gts[unique_index] = [{'caption': remove_nonascii(gt_caption)}] # for now we use gt proposal
197 | #cur_gts[unique_index] = [{'caption': remove_nonascii(gt_captions['sentences'][caption_idx])}]
198 | vid2capid[vid_id].append(unique_index)
199 | unique_index += 1
200 | has_added = True
201 | break # for now we use gt proposal
202 |
203 | # If the predicted caption does not overlap with any ground truth,
204 | # we should compare it with garbage
205 | if not has_added:
206 | cur_res[unique_index] = [{'caption': remove_nonascii(pred['sentence'])}]
207 | cur_gts[unique_index] = [{'caption': 'abc123!@#'}]
208 | vid2capid[vid_id].append(unique_index)
209 | unique_index += 1
210 |
211 | # Each scorer will compute across all videos and take average score
212 | output = {}
213 |
214 | # call tokenizer here for all predictions and gts
215 | tokenize_res = self.tokenizer.tokenize(cur_res)
216 | tokenize_gts = self.tokenizer.tokenize(cur_gts)
217 |
218 | # reshape back
219 | for vid in vid2capid.keys():
220 | res[vid] = {index:tokenize_res[index] for index in vid2capid[vid]}
221 | gts[vid] = {index:tokenize_gts[index] for index in vid2capid[vid]}
222 |
223 | for scorer, method in self.scorers:
224 | if self.verbose:
225 | print('computing %s score...'%(scorer.method()))
226 |
227 | # For each video, take all the valid pairs (based from tIoU) and compute the score
228 | all_scores = {}
229 |
230 | if method == "SPICE": # don't want to compute spice for 10000 times
231 | print("getting spice score...")
232 | score, scores = scorer.compute_score(tokenize_gts, tokenize_res)
233 | all_scores[0] = score
234 | else:
235 | for i,vid_id in enumerate(gt_vid_ids):
236 | if len(res[vid_id]) == 0 or len(gts[vid_id]) == 0:
237 | if type(method) == list:
238 | score = [0] * len(method)
239 | else:
240 | score = 0
241 | else:
242 | score, scores = scorer.compute_score(gts[vid_id], res[vid_id])
243 | all_scores[vid_id] = score
244 |
245 | #print all_scores.values()
246 | if type(method) == list:
247 | scores = np.mean(list(all_scores.values()), axis=0)
248 | for m in range(len(method)):
249 | output[method[m]] = scores[m]
250 | if self.verbose:
251 | print("Calculated tIoU: %1.1f, %s: %0.3f" % (tiou, method[m], output[method[m]]))
252 | else:
253 | output[method] = np.mean(list(all_scores.values()))
254 | if self.verbose:
255 | print("Calculated tIoU: %1.1f, %s: %0.3f" % (tiou, method, output[method]))
256 | return output
257 |
258 | def main(args):
259 | # Call coco eval
260 | evaluator = ANETcaptions(ground_truth_filenames=args.references,
261 | prediction_filename=args.submission,
262 | tious=args.tious,
263 | max_proposals=args.max_proposals_per_video,
264 | verbose=args.verbose)
265 | evaluator.evaluate()
266 |
267 | # Output the results
268 | if args.verbose:
269 | for i, tiou in enumerate(args.tious):
270 | print('-' * 80)
271 | print("tIoU: " , tiou)
272 | print('-' * 80)
273 | for metric in evaluator.scores:
274 | score = evaluator.scores[metric][i]
275 | print('| %s: %2.4f'%(metric, 100*score))
276 |
277 | # Print the averages
278 | print('-' * 80)
279 | print("Average across all tIoUs")
280 | print('-' * 80)
281 | output = {}
282 | for metric in evaluator.scores:
283 | score = evaluator.scores[metric]
284 | print('| %s: %2.4f'%(metric, 100 * sum(score) / float(len(score))))
285 | output[metric] = 100 * sum(score) / float(len(score))
286 | json.dump(output,open(args.output,'w'))
287 | print(output)
288 | if __name__=='__main__':
289 | parser = argparse.ArgumentParser(description='Evaluate the results stored in a submissions file.')
290 | parser.add_argument('-s', '--submission', type=str, default='sample_submission.json',
291 | help='sample submission file for ActivityNet Captions Challenge.')
292 | parser.add_argument('-r', '--references', type=str, nargs='+', default=['data/val_1.json'],
293 | help='reference files with ground truth captions to compare results against. delimited (,) str')
294 | parser.add_argument('-o', '--output', type=str, default='result.json',
295 | help='output file with final language metrics.')
296 | parser.add_argument('--tious', type=float, nargs='+', default=[0.3],
297 | help='Choose the tIoUs to average over.')
298 | parser.add_argument('-ppv', '--max-proposals-per-video', type=int, default=1000,
299 | help='maximum proposals per video.')
300 | parser.add_argument('-v', '--verbose', action='store_true',
301 | help='Print intermediate steps.')
302 | args = parser.parse_args()
303 |
304 | main(args)
305 |
--------------------------------------------------------------------------------
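
The segment matching above hinges on `ANETcaptions.iou`, which clips the overlap at zero and caps the union by the sum of the two segment lengths. The standalone sketch below repeats the same arithmetic on two invented segments, so the numbers can be checked by hand.

    def temporal_iou(a, b):
        """IoU of two [start, end] segments, mirroring ANETcaptions.iou."""
        inter = max(0.0, min(a[1], b[1]) - max(a[0], b[0]))
        union = min(max(a[1], b[1]) - min(a[0], b[0]),
                    (a[1] - a[0]) + (b[1] - b[0]))
        return inter / (union + 1e-8)

    print(temporal_iou([0.0, 10.0], [5.0, 15.0]))   # ~0.333: 5s overlap over a 15s span
    print(temporal_iou([0.0, 10.0], [20.0, 30.0]))  # 0.0: disjoint segments never match
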
/densevid_eval/evaluateCaptionsDiversity.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import sys
4 |
5 | def getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams):
6 | # N=1
7 | for w in words_pred:
8 | if w not in unigrams:
9 | unigrams[w] = 0
10 | unigrams[w] += 1
11 | # N=2
12 | for i, w in enumerate(words_pred):
13 | if i', ' ')
76 | query = query.replace('`', ' ')
77 | query = query.replace('#', ' ')
78 | query = query.replace(u'\u2019', "'")
79 | while query[-1] == ' ':
80 | query = query[0:-1]
81 | while query[0] == ' ':
82 | query = query[1:]
83 | while ' ' in query:
84 | query = query.replace(' ', ' ')
85 | # print(query)
86 | if query not in trainingQueries:
87 | trainingQueries[query] = 0
88 | trainingQueries[query] += 1
89 |
90 | vocab = {}
91 | novel_sentence = []
92 | uniq_sentence = {}
93 | count_sent = 0
94 | sent_length = []
95 |
96 | for vid in data_gt['results']:
97 | for i, _ in enumerate(data_gt['results'][vid]):
98 |
99 | try:
100 | pred_sentence = data_predicted['results'][vid][i]['sentence'].lower()
101 | except:
102 | continue
103 |
104 | if pred_sentence[-1] == '.':
105 | pred_sentence = pred_sentence[0:-1]
106 | while pred_sentence[-1] == ' ':
107 | pred_sentence = pred_sentence[0:-1]
108 | pred_sentence = pred_sentence.replace(',', ' ')
109 | while ' ' in pred_sentence:
110 | pred_sentence = pred_sentence.replace(' ', ' ')
111 |
112 | if pred_sentence in trainingQueries:
113 | novel_sentence.append(0)
114 | else:
115 | novel_sentence.append(1)
116 |
117 | if pred_sentence not in uniq_sentence:
118 | uniq_sentence[pred_sentence] = 0
119 | uniq_sentence[pred_sentence] += 1
120 |
121 | words_pred = pred_sentence.split(' ')
122 | for w in words_pred:
123 | if w not in vocab:
124 | vocab[w] = 0
125 | vocab[w] += 1
126 |
127 | sent_length.append(len(words_pred))
128 | count_sent += 1
129 |
130 | print ('Vocab: %d\t Novel Sent: %.2f\t Uniq Sent: %.2f\t Sent length: %.2f' %
131 | (len(vocab), np.mean(novel_sentence), len(uniq_sentence)/float(count_sent), np.mean(sent_length)))
132 |
133 | def activity_stats(data_predicted, data_gt, data_annos):
134 | print('#### Per activity ####')
135 |
136 | # Per activity
137 | data_annos = data_annos['database']
138 | activities = {}
139 | vid2act = {}
140 | for vid_id in data_annos:
141 | vid2act[vid_id] = []
142 | annos = data_annos[vid_id]['annotations']
143 | for anno in annos:
144 | if anno['label'] not in vid2act[vid_id]:
145 | vid2act[vid_id].append(anno['label'])
146 | if anno['label'] not in activities:
147 | activities[anno['label']] = True
148 | div1 = {}
149 | div2 = {}
150 | div3 = {}
151 | div4 = {}
152 | re = {}
153 | sentences = {}
154 |
155 | for act in activities:
156 | div1[act] = -1
157 | div2[act] = -1
158 | div3[act] = -1
159 | div4[act] = -1
160 | re[act] = -1
161 | sentences[act] = []
162 |
163 | for vid in data_gt['results']:
164 |
165 | act = vid2act[vid[2:]][0]
166 |
167 | if vid not in data_predicted['results']:
168 | continue
169 |
170 | for i, _ in enumerate(data_gt['results'][vid]):
171 |
172 | try:
173 | pred_sentence = data_predicted['results'][vid][i]['sentence']
174 | except:
175 | continue
176 |
177 | if pred_sentence[-1] == '.':
178 | pred_sentence = pred_sentence[0:-1]
179 | while pred_sentence[-1] == ' ':
180 | pred_sentence = pred_sentence[0:-1]
181 | pred_sentence = pred_sentence.replace(',', ' ')
182 | while ' ' in pred_sentence:
183 | pred_sentence = pred_sentence.replace(' ', ' ')
184 |
185 | sentences[act].append(pred_sentence)
186 |
187 | for act in activities:
188 | unigrams = {}
189 | bigrams = {}
190 | trigrams = {}
191 | fourgrams = {}
192 |
193 | for pred_sentence in sentences[act]:
194 | words_pred = pred_sentence.split(' ')
195 | unigrams, bigrams, trigrams, fourgrams = getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams)
196 |
197 | sum_unigrams = sum([unigrams[un] for un in unigrams])
198 | vid_div1 = float(len(unigrams)) / float(sum_unigrams)
199 | vid_div2 = float(len(bigrams)) / float(sum_unigrams)
200 | vid_div3 = float(len(trigrams)) / float(sum_unigrams)
201 | vid_div4 = float(len(fourgrams)) / float(sum_unigrams)
202 |
203 | vid_re = float(sum([max(fourgrams[f]-1,0) for f in fourgrams])) / float(sum([fourgrams[f] for f in fourgrams]))
204 |
205 | div1[act] = vid_div1
206 | div2[act] = vid_div2
207 | div3[act] = vid_div3
208 | div4[act] = vid_div4
209 | re[act] = vid_re
210 |
211 | mean_div1 = np.mean([div1[act] for act in activities])
212 | mean_div2 = np.mean([div2[act] for act in activities])
213 | mean_div3 = np.mean([div3[act] for act in activities])
214 | mean_div4 = np.mean([div4[act] for act in activities])
215 | mean_re = np.mean([re[act] for act in activities])
216 |
217 | print ('Div-1: %.4f\t Div-2: %.4f\t RE: %.4f' % (mean_div1, mean_div2, mean_re))
218 |
219 | def video_stats(data_predicted, data_gt):
220 | # print('#### Per video ####')
221 |
222 | # Per video
223 |
224 | div1 = []
225 | div2 = []
226 | div3 = []
227 | div4 = []
228 | re1 = []
229 | re2 = []
230 | re3 = []
231 | re4 = []
232 |
233 | for vid in data_gt['results']:
234 |
235 | unigrams = {}
236 | bigrams = {}
237 | trigrams = {}
238 | fourgrams = {}
239 |
240 | if vid not in data_predicted['results']:
241 | continue
242 |
243 | for i, _ in enumerate(data_gt['results'][vid]):
244 |
245 | try:
246 | pred_sentence = data_predicted['results'][vid][i]['sentence']
247 | except:
248 | continue
249 |
250 | if pred_sentence[-1] == '.':
251 | pred_sentence = pred_sentence[0:-1]
252 | while pred_sentence[-1] == ' ':
253 | pred_sentence = pred_sentence[0:-1]
254 | pred_sentence = pred_sentence.replace(',', ' ')
255 | while ' ' in pred_sentence:
256 | pred_sentence = pred_sentence.replace(' ', ' ')
257 |
258 | words_pred = pred_sentence.split(' ')
259 | unigrams, bigrams, trigrams, fourgrams = getNgrams(words_pred, unigrams, bigrams, trigrams, fourgrams)
260 |
261 | sum_unigrams = sum([unigrams[un] for un in unigrams])
262 | vid_div1 = float(len(unigrams)) / float(sum_unigrams)
263 | vid_div2 = float(len(bigrams)) / float(sum_unigrams)
264 | vid_div3 = float(len(trigrams)) / float(sum_unigrams)
265 | vid_div4 = float(len(fourgrams)) / float(sum_unigrams)
266 |
267 | vid_re1 = float(sum([max(unigrams[f] - 1, 0) for f in unigrams])) / float(sum([unigrams[f] for f in unigrams]))
268 | vid_re2 = float(sum([max(bigrams[f] - 1, 0) for f in bigrams])) / float(sum([bigrams[f] for f in bigrams]))
269 | vid_re3 = float(sum([max(trigrams[f] - 1, 0) for f in trigrams])) / float(sum([trigrams[f] for f in trigrams]))
270 | vid_re4 = float(sum([max(fourgrams[f]-1,0) for f in fourgrams])) / float(sum([fourgrams[f] for f in fourgrams]))
271 |
272 | div1.append(vid_div1)
273 | div2.append(vid_div2)
274 | div3.append(vid_div3)
275 | div4.append(vid_div4)
276 | re1.append(vid_re1)
277 | re2.append(vid_re2)
278 | re3.append(vid_re3)
279 | re4.append(vid_re4)
280 |
281 | #print ('tDiv-1: %.4f\t Div-2: %.4f\t RE-4: %.4f' % (np.mean(div1), np.mean(div2),np.mean(re4)))
282 |
283 | if __name__=='__main__':
284 | submission = sys.argv[1]
285 | evaluateDiversity(submission)
286 |
--------------------------------------------------------------------------------
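
The diversity (Div-n) and repetition statistics above reduce to counting distinct versus total n-grams over a pool of captions: Div-1 is the number of distinct unigrams divided by the total unigram count, and the RE score is the fraction of n-gram occurrences that are repeats. The toy sketch below (sentences invented for illustration) reproduces those two ratios without the file's per-video bookkeeping.

    from collections import Counter

    def ngram_counts(words, n):
        return Counter(tuple(words[i:i + n]) for i in range(len(words) - n + 1))

    captions = ["add the onion to the pan", "add the onion to the pan", "stir the mixture"]
    unigrams, fourgrams = Counter(), Counter()
    for cap in captions:
        words = cap.split()
        unigrams += ngram_counts(words, 1)
        fourgrams += ngram_counts(words, 4)

    div1 = len(unigrams) / sum(unigrams.values())   # distinct / total unigrams
    re4 = sum(max(c - 1, 0) for c in fourgrams.values()) / sum(fourgrams.values())
    print(div1, re4)   # low Div-1 and high RE-4 both signal repetitive captions
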
/densevid_eval/evaluateRepetition.py:
--------------------------------------------------------------------------------
1 | import json
2 | import numpy as np
3 | import sys
4 |
5 |
6 | def save_json(data, filepath):
7 | with open(filepath, "w") as f:
8 | json.dump(data, f)
9 |
10 |
11 | def save_json_pretty(data, filepath):
12 | with open(filepath, "w") as f:
13 | f.write(json.dumps(data, indent=4, sort_keys=True))
14 |
15 |
16 | def get_ngrams(words_pred, unigrams, bigrams, trigrams, fourgrams):
17 | # N=1
18 | for w in words_pred:
19 | if w not in unigrams:
20 | unigrams[w] = 0
21 | unigrams[w] += 1
22 | # N=2
23 | for i, w in enumerate(words_pred):
24 | if i 0 else [''] for k in gt_vid_ids}
124 | para_res = {vid2idx[k]: [' '.join(parse_para(self.prediction[k]))] \
125 | if k in self.prediction and len(self.prediction[k]) > 0 else [''] for k in gt_vid_ids}
126 |
127 | # Each scorer will compute across all videos and take average score
128 | output = {}
129 | num = len(res)
130 | hard_samples = {}
131 | easy_samples = {}
132 | for scorer, method in self.scorers:
133 | if self.verbose:
134 | print('computing %s score...'%(scorer.method()))
135 |
136 | if method != 'Self_Bleu':
137 | score, scores = scorer.compute_score(gts, res)
138 | else:
139 | score, scores = scorer.compute_score(gts, para_res)
140 | scores = np.asarray(scores)
141 |
142 | if type(method) == list:
143 | for m in range(len(method)):
144 | output[method[m]] = score[m]
145 | if self.verbose:
146 | print("%s: %0.3f" % (method[m], output[method[m]]))
147 | for m, i in enumerate(scores.argmin(1)):
148 | if i not in hard_samples:
149 | hard_samples[i] = []
150 | hard_samples[i].append(method[m])
151 | for m, i in enumerate(scores.argmax(1)):
152 | if i not in easy_samples:
153 | easy_samples[i] = []
154 | easy_samples[i].append(method[m])
155 | else:
156 | output[method] = score
157 | if self.verbose:
158 | print("%s: %0.3f" % (method, output[method]))
159 | i = scores.argmin()
160 | if i not in hard_samples:
161 | hard_samples[i] = []
162 | hard_samples[i].append(method)
163 | i = scores.argmax()
164 | if i not in easy_samples:
165 | easy_samples[i] = []
166 | easy_samples[i].append(method)
167 | # print('# scored video =', num)
168 |
169 | self.hard_samples = {gt_vid_ids[i]: v for i, v in hard_samples.items()}
170 | self.easy_samples = {gt_vid_ids[i]: v for i, v in easy_samples.items()}
171 | return output
172 |
173 | def main(args):
174 | # Call coco eval
175 | evaluator = ANETcaptions(ground_truth_filenames=args.references,
176 | prediction_filename=args.submission,
177 | verbose=args.verbose,
178 | all_scorer=args.all_scorer)
179 | evaluator.evaluate()
180 | output = {}
181 | # Output the results
182 | for metric, score in evaluator.scores.items():
183 | # print('| %s: %2.4f'%(metric, 100*score))
184 | output[metric] = score
185 | json.dump(output, open(args.output, 'w'))
186 | #print(output)
187 |
188 | import time
189 | if __name__=='__main__':
190 | parser = argparse.ArgumentParser(description='Evaluate the results stored in a submissions file.')
191 | parser.add_argument('-s', '--submission', type=str, default='sample_submission.json',
192 | help='sample submission file for ActivityNet Captions Challenge.')
193 | parser.add_argument('-r', '--references', type=str, nargs='+', required=True,
194 | help='reference files with ground truth captions to compare results against. delimited (,) str')
195 | parser.add_argument('-o', '--output', type=str, default=None, help='output file with final language metrics.')
196 | parser.add_argument('-v', '--verbose', action='store_true',
197 | help='Print intermediate steps.')
198 | parser.add_argument('--time', '--t', action = 'store_true',
199 | help = 'Count running time.')
200 | parser.add_argument('--all_scorer', '--a', action = 'store_true',
201 | help = 'Use all scorer.')
202 | args = parser.parse_args()
203 |
204 | if args.output is None:
205 | r_path = args.submission
206 | r_path_splits = r_path.split(".")
207 | r_path_splits = r_path_splits[:-1] + ["_metric", r_path_splits[-1]]
208 | args.output = ".".join(r_path_splits)
209 |
210 | if args.time:
211 | start_time = time.time()
212 | main(args)
213 | if args.time:
214 | print('time = %.2f' % (time.time() - start_time))
215 |
--------------------------------------------------------------------------------
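
Each `(scorer, method)` pair above follows the same coco-caption interface: `compute_score(gts, res)` returns a corpus-level score plus per-id scores, and the per-id scores are what the hard/easy-sample bookkeeping inspects. A minimal usage sketch with the bundled BLEU scorer, assuming `./coco-caption` is on the path as in the scripts above, with toy captions invented for illustration:

    import sys
    sys.path.insert(0, './coco-caption')  # same path hack the evaluation scripts use

    from pycocoevalcap.bleu.bleu import Bleu

    # in the real scripts these dicts come out of PTBTokenizer; keys must match
    gts = {"vid1": ["a man slices an onion on a cutting board"]}
    res = {"vid1": ["a man cuts an onion"]}

    score, scores = Bleu(4).compute_score(gts, res)
    # score  -> [Bleu_1, Bleu_2, Bleu_3, Bleu_4] corpus-level values
    # scores -> per-id lists used for the argmin/argmax sample selection above
    print(dict(zip(["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"], score)))
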
/densevid_eval/yc2_data/val_list.txt:
--------------------------------------------------------------------------------
1 | 405/sdB8qBlLS2E
2 | 405/fn9anlEL4FI
3 | 405/-dh_uGahzYo
4 | 405/BktdaTg6_E4
5 | 325/U_yVc8Dl048
6 | 325/RnSl1LVrItI
7 | 325/vVZsj1t9R70
8 | 325/ysRLGUndzgg
9 | 306/HF49t8uVJOE
10 | 306/GXnzgRC3sd4
11 | 306/pxQd53yvSaA
12 | 306/uAzzevo-FME
13 | 306/vU2lND4YQjM
14 | 306/Mzn6Q4gUDBo
15 | 303/JylDlRtH9Tc
16 | 303/9pJToG30LdM
17 | 303/2IkN3hTEZ2Y
18 | 303/ulrh6C5V_VI
19 | 303/2zFAZy0zSbw
20 | 303/6uHoTJSLoL8
21 | 303/EkuM7L31bMQ
22 | 302/PQ97HXmsFR0
23 | 302/vDDeMg2dhEM
24 | 302/WlHWRPyA7_g
25 | 302/c4WaDsqP38k
26 | 302/9guuyTr8EUg
27 | 302/LYj5-CdRIz0
28 | 321/v_dkYNq8G9Y
29 | 321/vXlmXrKC0FE
30 | 321/UHhuaRTF1UY
31 | 221/zzT6RoI4JPU
32 | 221/cCWDR-jUv9U
33 | 221/GmkRlWA2kGI
34 | 221/G-AUY-jWzck
35 | 221/PV93b0xisN8
36 | 221/2heP32bqOV0
37 | 210/LjfTvZ-cmzs
38 | 210/YMYNv3cZ9SE
39 | 210/_GTwKEPmB-U
40 | 210/nHZsE7T7hwI
41 | 210/gTqhgReBDw0
42 | 210/N35UyfIwhVI
43 | 225/NAMZY2LbeFY
44 | 225/KYoelaJY5LA
45 | 225/2HsWZdKKBGg
46 | 225/-Ju39A-G0Dk
47 | 225/pTjoGIvSfE8
48 | 225/lwdypoLpMW4
49 | 107/peld2w63tpM
50 | 107/VLS3ZJt9GMg
51 | 107/zljhtdoqpv0
52 | 107/3jDAyeKeYFA
53 | 107/-k7trpuj3X8
54 | 201/7D4uMKxLDT0
55 | 201/zBexcthy_tA
56 | 201/-ErPSunMfcs
57 | 201/30Q8k57Kbz4
58 | 201/2SxbO4VAgN8
59 | 201/mZwK0TBI1iY
60 | 119/bAC0cZIQVOk
61 | 119/0ShsPjf9shQ
62 | 119/4bEtf7u4YtE
63 | 119/9BNRMHGepS4
64 | 119/wqpqx-Qm7lk
65 | 119/C73qiF138VU
66 | 419/yizxI2Gf_ww
67 | 419/bY4_F8J8HOM
68 | 419/JK0DTF9Edtk
69 | 419/5W3jHo5d7hM
70 | 419/4K9h7ojJYkc
71 | 120/iqcnbNqVc7U
72 | 120/ucaCmhNo78k
73 | 120/QYl_wwBKt18
74 | 120/CWxjNRIKjA0
75 | 120/OrXZqt42OVs
76 | 418/D4mU_NtbneA
77 | 418/LeCwqp8Bic8
78 | 418/DpuofwnCI8A
79 | 418/DrXVuj1Qowo
80 | 204/cF45-iVw--w
81 | 204/EJm2J0WqRcY
82 | 204/nfVXBQwOCMc
83 | 204/OF-Zh5FrxGc
84 | 204/LpBsoQ6TAL0
85 | 319/RUxugNYxFqg
86 | 319/tkuST4Ku37s
87 | 319/FcjEswcaJW4
88 | 319/X4GOx3EW3Rw
89 | 309/Re46osq_NkI
90 | 309/H6acK-N2wMs
91 | 309/abfhnSaZFlA
92 | 309/-AwyG1JcMp8
93 | 309/tYg3lQ5aZv8
94 | 212/9GX8f5EwwE4
95 | 212/XUyqiWN8WFI
96 | 212/HBUz55JRRm8
97 | 212/0uaKitJaqmI
98 | 314/TO_W2RYL2mA
99 | 314/0hb6NShH9hY
100 | 314/p6LSW9kuRCE
101 | 314/2rJ3KKx0oRk
102 | 314/qRSZEN6g8jY
103 | 314/RqgN6iWMkb0
104 | 317/b_uKIQ4dn3A
105 | 317/R-EnNr_oH8A
106 | 317/noS_n5k3oxM
107 | 317/H5NPxWpfYNU
108 | 317/AcWeYhS3cDs
109 | 209/btikV_DUoCM
110 | 209/xPiv3hP5888
111 | 209/5nh2CP22dgY
112 | 228/Odv6ltYAMw4
113 | 228/jT75QMjRkD0
114 | 228/YP4B9gLNOIM
115 | 228/R5IAGR2SeaE
116 | 228/WQlMXudBGT4
117 | 228/IDiovuOcKW8
118 | 122/LfSYF1N5i_Q
119 | 122/viwpmylgps0
120 | 122/lKm5Ji1Fr4U
121 | 122/GgM8IIglBLw
122 | 122/9F5FvWheSrg
123 | 307/YRZ8zZElALQ
124 | 307/eQZEf3NCCo4
125 | 307/Vq5gxXh9zLM
126 | 307/FliMoBfG72Y
127 | 216/Z5bpo2sBsl8
128 | 216/1iv2xhPN3vk
129 | 216/xx698BRyqG4
130 | 216/EsQbw20TQPA
131 | 216/QUt050AXQMw
132 | 216/JPbFE731Y0c
133 | 215/EpNUSTO2BI4
134 | 215/95WMX64RIBc
135 | 215/woTrhsB_bcA
136 | 215/xwQBrf2CAvc
137 | 215/-ju7_ZORsZw
138 | 215/Hh-uza7bwgE
139 | 313/uHv9xRooPMc
140 | 313/mi8NwUqf7nM
141 | 313/CotdlwupDSI
142 | 313/524UzHtbAcY
143 | 108/LWuuCndtJr0
144 | 108/NjAtxfaLwCk
145 | 108/tPLVNKgs8Lk
146 | 108/7ebZWviUfUA
147 | 108/PYjrGqPHGhY
148 | 108/2iWUUcW08ac
149 | 124/f2uDKzq8WM0
150 | 124/-goI2-eJO1w
151 | 124/NTyhMGmuWik
152 | 124/soLZjUyn0CI
153 | 124/1vJp-jaIaeE
154 | 124/Jtusyjv7GiY
155 | 124/RKhfv-spUaI
156 | 213/YX6v3tY7OPg
157 | 213/4Zl5NvXPi-0
158 | 213/dMhoqii0Cq0
159 | 213/4Y8vVGsv4JE
160 | 213/FzhJGCaaYVs
161 | 213/c3JFGGhkArA
162 | 213/i0qYuhtSQHI
163 | 423/AMBH5L6x3dQ
164 | 423/p-NnIyGFZVw
165 | 423/QISvGTL2VDc
166 | 423/SjA7PFoZcNQ
167 | 110/VPFmudvabUg
168 | 110/PTpRTJKAEoI
169 | 110/9GIPE0aeVNI
170 | 110/OWtnI3m-p8g
171 | 110/I1JgU6TK-yc
172 | 110/FSWZXBbEyFw
173 | 403/SOMsxGGSTUk
174 | 403/ACyY0jTrm5c
175 | 403/vp_dOhmfGcs
176 | 403/VmaEuPzlPII
177 | 403/m88rF0rwHo8
178 | 403/gYWqhml_YJQ
179 | 117/hLTNXDKU_Pk
180 | 117/oAE7nqQeMBQ
181 | 117/8XcSP7kKOIo
182 | 117/MCtF5tRCRUk
183 | 117/wk0nfwGyPBI
184 | 425/OUhxy5BANfk
185 | 425/G-spzGkKIHM
186 | 425/57e54HEcrUE
187 | 425/6nVIgasiUtw
188 | 425/W6DgS0s0qcI
189 | 425/yreC9D4yYiM
190 | 106/xhXcJ6bhX2w
191 | 106/EedEYHqLfP8
192 | 106/e8S1vFC8zYk
193 | 106/RY10IUcz3bk
194 | 106/UmJk0WSl9Uc
195 | 106/88YovCsnMxs
196 | 219/R3Jc1fXwSnU
197 | 219/o8HaMr9E8J8
198 | 219/OpURFOTdycE
199 | 219/i9CMFh31Bs0
200 | 219/5Pa79r5Q-ZI
201 | 219/tQ6-_e59Zrk
202 | 218/1Ihxcua2HBc
203 | 218/ljyO7IaGWLY
204 | 218/TMpt-41UTOk
205 | 127/OEfzgobszUA
206 | 127/sj4BJSnjubc
207 | 127/Xz3-xRyBBog
208 | 127/F2qYQZ7Q68s
209 | 127/g6eV_7U5HX8
210 | 127/3WXM2FAueb8
211 | 127/KTQeLdmlzBo
212 | 203/XAHNVoKV1Bc
213 | 203/bmZB3aszZlA
214 | 203/QKjmdrMA2t8
215 | 203/XbTA0SGOdwk
216 | 203/aCvIo-M06xI
217 | 101/SVo2W3ux1pU
218 | 101/uOXlG8Tglc8
219 | 101/10dZTHlkb8w
220 | 101/4eWzsx1vAi8
221 | 101/c9eELn4axpg
222 | 406/BAoQWVV-bh4
223 | 406/8CaadFo3sw0
224 | 406/Pf4UNA-izQo
225 | 406/T_o_T3LEYLY
226 | 406/ysUibvVCpP8
227 | 406/6H8tPeQGhMY
228 | 412/mV3m2svj3XE
229 | 412/eWBSMD3BiHM
230 | 412/sBJJ0Cj0GG4
231 | 412/U_2DFd2ZMfs
232 | 116/eMsfAhVj2e4
233 | 116/cMzyB4m3VHY
234 | 116/5cn9KJfaQXk
235 | 116/kWLYcM3uVVc
236 | 116/9iH8GK1pcEM
237 | 116/MPCU71Hg-i4
238 | 421/ekgZfuxsz_4
239 | 421/4SnAlRlxlFk
240 | 421/c00gy-NVzaw
241 | 421/bxgdUWKOwtQ
242 | 421/s4CktGpWaZE
243 | 422/rf_mGLJPnDk
244 | 422/lBguj96fa5w
245 | 422/3dUm-m3iFaI
246 | 422/N1-rqFfCm9M
247 | 214/mR0inCVvBzY
248 | 214/ZQGfcC62Pys
249 | 214/RHddz6qeJKk
250 | 214/luDzsPatsGw
251 | 214/aCkbw-aI4xU
252 | 103/gZuDMKXWU_E
253 | 103/o42iehActZo
254 | 103/JxCBGlPgr5o
255 | 103/y4y22RQH05c
256 | 103/rwYaDqXFH88
257 | 103/v174YTbr2N8
258 | 103/zLBRrWd4DTo
259 | 207/nz_LHDf0uqE
260 | 207/wHWDBQ9_7FU
261 | 207/4h33GFHLPNg
262 | 207/sjh57ujp52M
263 | 207/J5Tw7KRnSyc
264 | 324/a4RwXrA1hiE
265 | 324/hAzH-GS4cvc
266 | 324/7r6JQycloEs
267 | 324/Pk88LQ7hxbg
268 | 324/4apR0YypAGc
269 | 324/o9kndEZvsnY
270 | 121/p-PFp1c0FKs
271 | 121/-GlSSp5ZOCQ
272 | 121/2-mxsib6pJo
273 | 121/7R5MVNE-ePU
274 | 121/_Vzpj0cXoSM
275 | 121/hs2h7nb5PHQ
276 | 318/yTPJ_u_qxDU
277 | 318/aYjy__xnegM
278 | 318/2Zr72r4OCe8
279 | 318/EnP2j1caRVs
280 | 318/3z_QhNnSFtM
281 | 409/oC5OvA4BK-E
282 | 409/NXnQys_ejeg
283 | 409/DBgap0YANhs
284 | 409/HdVETeyupXE
285 | 409/so-RuJQY1d0
286 | 112/7NptUiW8hJw
287 | 112/jEo9VXYVrxs
288 | 112/5VnaolWGIy4
289 | 112/UkqQAynrM2g
290 | 112/TgttBprZXDY
291 | 112/e1gtgMczUwE
292 | 304/wii9jNiNl9Y
293 | 304/u95xkc4DfAs
294 | 304/cQ8mt5ACO0A
295 | 304/sSO2wO-yaHw
296 | 404/im-aWyUQGrg
297 | 404/wW_kszdGIJw
298 | 404/1rMT2uMF78E
299 | 404/RFE7qdhjgXc
300 | 404/HJHV2nYz1L8
301 | 111/TAXAVvroOgk
302 | 111/xw9aAfqanDo
303 | 111/wQc0xmPurDc
304 | 111/gXINt_KMK3M
305 | 111/83uz_q4_nyk
306 | 111/JWcAs8biQFU
307 | 111/wokMK-w7XiA
308 | 301/Y2HYSmo4KaI
309 | 301/ntiGX3X-spA
310 | 301/4nxbRG6-sfw
311 | 109/5Oq5giRXtag
312 | 109/vLcBGs389k4
313 | 109/A8eDWlCYaq8
314 | 109/RznLeKVI3yo
315 | 109/NYhsc9ikk4I
316 | 109/WqfselLH4MQ
317 | 104/M8SHMUBnm4A
318 | 104/Nbh64ntT3EM
319 | 104/NZtwPf32YN4
320 | 323/8QblSYQpAoM
321 | 323/mUk0FmDrBb8
322 | 323/tKsGWxiWWCg
323 | 308/jnewhlK2USg
324 | 308/WYAFPvlDB_A
325 | 308/zF3TOfktwd4
326 | 308/Ws7JgPJsVjs
327 | 308/hkVfzjA1HA0
328 | 308/wR8Ybxpnbwc
329 | 224/yWEq4_EG1us
330 | 224/02nUKT0A7uE
331 | 224/nuwCjQVlBrg
332 | 224/E9O9-6TQUw0
333 | 305/PNlctwVmbLY
334 | 305/lpWOv7Y3JHM
335 | 305/_mL1gihKDw0
336 | 305/7-FatJyHj_g
337 | 305/x3if1znl5Fg
338 | 416/nVERaEFJWLQ
339 | 416/oJZUxU9szWA
340 | 416/_xIIpW8iMps
341 | 416/fnbXolhuE7k
342 | 416/a5FoLWnEiAI
343 | 416/DVW7nZeeVlk
344 | 114/IDu5czNIM1w
345 | 114/dMbb10O9hGs
346 | 114/84i8Qdnyd0k
347 | 202/iuQjb1-WAzs
348 | 202/2IcWR76i1bo
349 | 202/oR2QDpoatcQ
350 | 202/eHk6NSLGAkc
351 | 202/awQYyYgulLw
352 | 113/YNpVeU1pVZA
353 | 113/S07Fr83GcBI
354 | 113/RllWJUvrxEY
355 | 113/7jO6rYyhuJk
356 | 113/oDsUh1es_lo
357 | 310/gEYyWqs1oL0
358 | 310/7E8Lj_Ktfok
359 | 310/fpPQcbr5VC0
360 | 310/0EuykeOvGg4
361 | 230/3meb_5kcPFg
362 | 230/vWrOd9Ur0po
363 | 230/x6noOknBPDI
364 | 230/3V4MxH2GuIU
365 | 230/6XBocXgvfTs
366 | 230/3aFiXsrKSoQ
367 | 208/bmxWJNbqCk8
368 | 208/We2CzpjPD3k
369 | 208/LQDP3xm8aRk
370 | 208/186EQzPPHW8
371 | 208/T_fPNAK5Ecg
372 | 208/ffyHeyRpYvo
373 | 401/lRwMt_eHjxU
374 | 401/ikmPrpgWQ5M
375 | 401/6seOEuK0ojg
376 | 401/YhevdroG7a4
377 | 401/mNhj7SA7c4g
378 | 401/NRovp9c9e-4
379 | 401/m9gNbLw0Dcg
380 | 226/lC8B_Yx6Qzg
381 | 226/VH0SmCfAov4
382 | 226/xHr8X2Wpmno
383 | 226/Ew5YKc6xmLE
384 | 226/ffhliBglDhY
385 | 226/05ZSU-5UkXw
386 | 311/BRqTCiAc7uk
387 | 311/B1YQYS9BMdk
388 | 311/lkmVVQIsdEE
389 | 311/vq8C5DTfOKc
390 | 311/W2gnFLOi_AQ
391 | 229/XEifm-iXMvs
392 | 229/HdQzPLk_KiA
393 | 229/cMMoRNhHJrI
394 | 229/cDYCtBwin5g
395 | 229/rKtI8FQGhHo
396 | 229/YA6lhxwLrUI
397 | 413/paiJGvLILKE
398 | 413/r9AtdDfDVmo
399 | 413/m3kFrdCHitg
400 | 102/TF1iWaX2-DM
401 | 102/O7ONcb3qhMU
402 | 102/J7gBorrGvDU
403 | 102/pNAwkqm4t3A
404 | 410/29Wkj1LqaK8
405 | 410/oYLrSflCI2g
406 | 410/86Mb6cYFJig
407 | 222/_ilIn1kmNSA
408 | 222/D7K6_0gtpHQ
409 | 222/RubyHelAHBE
410 | 115/F564e476ULM
411 | 115/igF8D7iE46o
412 | 115/DHpQOhQhW3A
413 | 115/qAoqhmjk3iY
414 | 115/kchoaU2HL-o
415 | 115/9ekEjxd-A_Y
416 | 115/xkKuIlYSMMU
417 | 115/yxjnWx6TaQ8
418 | 223/SkawoKeyNoQ
419 | 223/uYBTguvz4tc
420 | 316/lH7pgsnyGrI
421 | 316/kGxmudExRVk
422 | 211/VwtkHIturro
423 | 211/ZjKY9v48fTc
424 | 211/DeiX_otgD1Q
425 | 211/WlkaUxBwURQ
426 | 211/qkluMpILLdQ
427 | 105/PTUxCvCz8Bc
428 | 105/V53XmPeyjIU
429 | 105/jbjg6w5taGU
430 | 105/FrzEHqqi1RY
431 | 105/wlq30WwXwSM
432 | 105/Dao0vasGPMQ
433 | 126/iq7aiv9MPvA
434 | 126/PHpk4ITk-SE
435 | 126/mhEVgpfF-IU
436 | 126/NK2tAXi3cT4
437 | 126/tGaAAI3aAUs
438 | 205/zqTXQ-YqrgQ
439 | 205/RWtVm_5_D2s
440 | 205/zPCtV7YcmkA
441 | 205/efnHOsT7k9s
442 | 205/InDwfZmSikI
443 | 206/LKrI9pGpM78
444 | 206/cdsDsUcLJZM
445 | 206/UIElE5H_iHc
446 | 206/rWdhkAXfEAY
447 | 206/9Y9_OBnJub0
448 | 206/vSRZRp2Ovqc
449 | 227/TfITvKr5M3k
450 | 227/4ZbNtfqKkiI
451 | 227/sv8jRCmi3Ro
452 | 227/eYOn2ZVB4nc
453 | 403/JqjwJIV6pI0
454 | 404/XXUmUPDosYQ
455 | 410/p-gN4cbmunQ
456 | 422/Ky0zf0v2F5A
457 | 401/sGzBQrg1adY
458 |
--------------------------------------------------------------------------------
/figures/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/figures/network.png
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/src/__init__.py
--------------------------------------------------------------------------------
/src/build_vocab.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import nltk
4 | import numpy as np
5 | from tqdm import tqdm
6 | import torch
7 | from src.utils import load_json, flat_list_of_lists, save_json
8 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset
9 |
10 |
11 | def build_vocab_idx(word_insts, min_word_count):
12 | full_vocab = set(w for sent in word_insts for w in sent)
13 | print("[Info] Original Vocabulary size =", len(full_vocab))
14 |
15 | word2idx = {
16 | RCDataset.PAD_TOKEN: RCDataset.PAD,
17 | RCDataset.CLS_TOKEN: RCDataset.CLS,
18 | RCDataset.SEP_TOKEN: RCDataset.SEP,
19 | RCDataset.VID_TOKEN: RCDataset.VID,
20 | RCDataset.BOS_TOKEN: RCDataset.BOS,
21 | RCDataset.EOS_TOKEN: RCDataset.EOS,
22 | RCDataset.UNK_TOKEN: RCDataset.UNK,
23 | }
24 |
25 | word_count = {w: 0 for w in full_vocab}
26 |
27 | for sent in word_insts:
28 | for word in sent:
29 | word_count[word] += 1
30 |
31 | ignored_word_count = 0
32 | for word, count in word_count.items():
33 | if word not in word2idx:
34 | if count > min_word_count:
35 | word2idx[word] = len(word2idx)
36 | else:
37 | ignored_word_count += 1
38 |
39 | print("[Info] Trimmed vocabulary size = {},".format(len(word2idx)),
40 | "each with minimum occurrence = {}".format(min_word_count))
41 | print("[Info] Ignored word count = {}".format(ignored_word_count))
42 | return word2idx
43 |
44 |
45 | def load_transform_data(data_path):
46 | data = load_json(data_path)
47 | transformed_data = []
48 | for v_id, cap in data.items():
49 | cap["v_id"] = v_id
50 | transformed_data.append(cap)
51 | return transformed_data
52 |
53 |
54 | def load_glove(filename):
55 | """ returns { word (str) : vector_embedding (torch.FloatTensor) }
56 | """
57 | glove = {}
58 | with open(filename,encoding='utf-8') as f:
59 | for line in f.readlines():
60 | values = line.strip("\n").split(" ") # space separator
61 | word = values[0]
62 | vector = np.asarray([float(e) for e in values[1:]])
63 | glove[word] = vector
64 | return glove
65 |
66 |
67 | def extract_glove(word2idx, raw_glove_path, vocab_glove_path, glove_dim=300):
68 | # Make glove embedding.
69 | print("Loading glove embedding at path : {}.\n".format(raw_glove_path))
70 | glove_full = load_glove(raw_glove_path)
71 | print("Glove Loaded, building word2idx, idx2word mapping.\n")
72 | idx2word = {v: k for k, v in word2idx.items()}
73 |
74 | glove_matrix = np.zeros([len(word2idx), glove_dim])
75 | glove_keys = glove_full.keys()
76 | for i in tqdm(range(len(idx2word))):
77 | w = idx2word[i]
78 | w_embed = glove_full[w] if w in glove_keys else np.random.randn(glove_dim) * 0.4
79 | glove_matrix[i, :] = w_embed
80 | print("vocab embedding size is :", glove_matrix.shape)
81 | torch.save(glove_matrix, vocab_glove_path)
82 |
83 |
84 | def main():
85 | parser = argparse.ArgumentParser()
86 | parser.add_argument("--train_path", type=str, default='../densevid_eval/yc2_data/yc2_train_anet_format.json')
87 | parser.add_argument("--dset_name", type=str, default="yc2", choices=["anet", "yc2"])
88 | parser.add_argument("--cache", type=str, default="./cache")
89 | parser.add_argument("--min_word_count", type=int, default=3)
90 | parser.add_argument("--raw_glove_path", type=str, default='/mnt/Video_feature/path/glove.6B.300d.txt')
91 |
92 | opt = parser.parse_args()
93 | if not os.path.exists(opt.cache):
94 | os.makedirs(opt.cache)
95 |
96 | # load, merge, clean, split data
97 | train_data = load_json(opt.train_path)
98 | all_sentences = flat_list_of_lists([v["sentences"] for k, v in train_data.items()])
99 | all_sentences = [nltk.tokenize.word_tokenize(sen.lower())
100 | for sen in all_sentences]
101 | word2idx = build_vocab_idx(all_sentences, opt.min_word_count)
102 | print("[Info] Dumping the processed data to json file", opt.cache)
103 | word2idx_path = os.path.join(opt.cache, "{}_word2idx_3.json".format(opt.dset_name))
104 | save_json(word2idx, word2idx_path, save_pretty=True)
105 | print("[Info] Finish.")
106 |
107 | vocab_glove_path = os.path.join(opt.cache, "{}_vocab_glove_3.pt".format(opt.dset_name))
108 | extract_glove(word2idx, opt.raw_glove_path, vocab_glove_path)
109 |
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
--------------------------------------------------------------------------------
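
`build_vocab_idx` keeps a word only if it occurs strictly more than `--min_word_count` times, after reserving slots for the special tokens. The toy sketch below (standalone; the special tokens are reduced to a plain PAD/UNK pair purely for illustration) reproduces that thresholding.

    from collections import Counter

    sentences = [["add", "the", "onion"], ["add", "salt"], ["add", "the", "salt"]]
    min_word_count = 1

    word2idx = {"[PAD]": 0, "[UNK]": 1}   # stand-ins for RCDataset's special tokens
    counts = Counter(w for sent in sentences for w in sent)
    for word, count in counts.items():
        if word not in word2idx and count > min_word_count:
            word2idx[word] = len(word2idx)

    print(word2idx)   # 'add' (3), 'the' and 'salt' (2) survive; 'onion' (1) is dropped
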
/src/caption_eval.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/21 18:44
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : caption_test.py
6 | import os
7 | import logging
8 | from collections import defaultdict
9 | from tqdm import tqdm
10 |
11 | import torch.distributed as dist
12 | import torch.utils.data
13 |
14 | from utils.train_utils import gather_object_multiple_gpu, get_timestamp, CudaPreFetcher
15 | from utils.json import save_json, load_json
16 | from utils.train_utils import Timer
17 |
18 | from pycocoevalcap.tokenizer.ptbtokenizer import PTBTokenizer
19 | from pycocoevalcap.bleu.bleu import Bleu
20 | from pycocoevalcap.meteor.meteor import Meteor
21 | from pycocoevalcap.rouge.rouge import Rouge
22 | from pycocoevalcap.cider.cider import Cider
23 | from pycocoevalcap.spice.spice import Spice
24 |
25 | logger = logging.getLogger(__name__)
26 |
27 |
28 | class Translator(object):
29 | """Load with trained model and handle the beam search"""
30 |
31 | def __init__(self, checkpoint, model=None):
32 | self.max_v_len = checkpoint['opt'].max_v_len
33 | self.max_t_len = checkpoint['opt'].max_t_len
34 | self.PAD = 0
35 | self.BOS = 4
36 |
37 | self.model = model
38 | self.model.eval()
39 |
40 | self.timer = Timer(synchronize=True, history_size=500, precision=3)
41 |
42 | def translate_batch_single_sentence_greedy(self, inputs, model):
43 | inputs_ids = inputs["input_ids"]
44 | input_masks = inputs["input_mask"]
45 | # max_t_len = self.max_t_len
46 | max_t_len = 21
47 | inputs_ids[:, :] = 0.
48 | input_masks[:, :] = 0.
49 | assert torch.sum(input_masks[:, :]) == 0, "Initially, all text tokens should be masked"
50 | bsz = len(inputs_ids)
51 | next_symbols = torch.IntTensor([self.BOS] * bsz) # (N, )
52 |
53 | self.timer.reset()
54 | for dec_idx in range(max_t_len):
55 | inputs_ids[:, dec_idx] = next_symbols.clone()
56 | input_masks[:, dec_idx] = 1
57 | outputs = model(inputs)
58 | pred_scores = outputs["prediction_scores"]
59 | # pred_scores[:, :, 49406] = -1e10
60 | next_words = pred_scores[:, dec_idx].max(1)[1] # TODO / NOTE changed
61 | next_symbols = next_words.cpu()
62 | if "visual_output" in outputs:
63 | inputs["visual_output"] = outputs["visual_output"]
64 | else:
65 | logger.debug("visual_output is not in the output of model, this may slow down the caption test")
66 | self.timer(stage_name="inference")
67 | logger.debug(f"inference toke {self.timer.get_info()['average']['inference']} ms")
68 | return inputs_ids
69 |
70 | def translate_batch(self, model_inputs):
71 | """while we used *_list as the input names, they could be non-list for single sentence decoding case"""
72 | return self.translate_batch_single_sentence_greedy(model_inputs, self.model)
73 |
74 |
75 | def convert_ids_to_sentence(tokens):
76 | from .clip.clip import _tokenizer
77 | text = _tokenizer.decode(tokens)
78 | text_list = text.split(" ")
79 | new = []
80 | for i in range(len(text_list)):
81 | if i == 0:
82 | new.append(text_list[i].split(">")[-1])
83 | elif "<|endoftext|>" in text_list[i]:
84 | break
85 | else:
86 | new.append(text_list[i])
87 | return " ".join(new)
88 |
89 |
90 | def run_translate(data_loader, translator, epoch, opt):
91 | # submission template
92 | batch_res = {"version": "VERSION 1.0",
93 | "results": defaultdict(list),
94 | "external_data": {"used": "true", "details": "ay"}}
95 | for bid, batch in enumerate(tqdm(data_loader,
96 | dynamic_ncols=True,
97 | disable=dist.is_initialized() and dist.get_rank() != 0)):
98 | if torch.cuda.is_available():
99 | batch = CudaPreFetcher.cuda(batch)
100 | dec_seq = translator.translate_batch(batch)
101 |
102 | # example_idx indicates which example is in the batch
103 | for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, batch['metadata'][1])):
104 | cur_data = {
105 | "sentence": convert_ids_to_sentence(cur_gen_sen.tolist()),
106 | "gt_sentence": cur_meta
107 | }
108 | # print(cur_data)
109 | batch_res["results"][batch['metadata'][0][example_idx].split("video")[-1]].append(cur_data)
110 | translator.timer.print()
111 | return batch_res
112 |
113 |
114 | class EvalCap:
115 | def __init__(self, annos, rests, cls_tokenizer=PTBTokenizer,
116 | use_scorers=['Bleu', 'METEOR', 'ROUGE_L', 'CIDEr']): # ,'SPICE']):
117 | self.evalImgs = []
118 | self.eval = {}
119 | self.imgToEval = {}
120 | self.annos = annos
121 | self.rests = rests
122 | self.Tokenizer = cls_tokenizer
123 | self.use_scorers = use_scorers
124 |
125 | def evaluate(self):
126 | res = {}
127 | for r in self.rests:
128 | res[str(r['image_id'])] = [{'caption': r['caption']}]
129 |
130 | gts = {}
131 | for imgId in self.annos:
132 | gts[str(imgId)] = [{'caption': self.annos[imgId]}]
133 |
134 | # =================================================
135 | # Set up scorers
136 | # =================================================
137 | # print('tokenization...')
138 | tokenizer = self.Tokenizer()
139 | gts = tokenizer.tokenize(gts)
140 | res = tokenizer.tokenize(res)
141 | # gts = {k: v for k, v in gts.items() if k in res.keys()}
142 | # =================================================
143 | # Set up scorers
144 | # =================================================
145 | # print('setting up scorers...')
146 | use_scorers = self.use_scorers
147 | scorers = []
148 | if 'Bleu' in use_scorers:
149 | scorers.append((Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]))
150 | if 'METEOR' in use_scorers:
151 | scorers.append((Meteor(), "METEOR"))
152 | if 'ROUGE_L' in use_scorers:
153 | scorers.append((Rouge(), "ROUGE_L"))
154 | if 'CIDEr' in use_scorers:
155 | scorers.append((Cider(), "CIDEr"))
156 | if 'SPICE' in use_scorers:
157 | scorers.append((Spice(), "SPICE"))
158 |
159 | # =================================================
160 | # Compute scores
161 | # =================================================
162 | for scorer, method in scorers:
163 | score, scores = scorer.compute_score(gts, res)
164 | if type(method) == list:
165 | for sc, scs, m in zip(score, scores, method):
166 | self.setEval(sc, m)
167 | self.setImgToEvalImgs(scs, gts.keys(), m)
168 | # print("%s: %0.1f" % (m, sc*100))
169 | else:
170 | self.setEval(score, method)
171 | self.setImgToEvalImgs(scores, gts.keys(), method)
172 | # print("%s: %0.1f" % (method, score*100))
173 | self.setEvalImgs()
174 |
175 | def setEval(self, score, method):
176 | self.eval[method] = score
177 |
178 | def setImgToEvalImgs(self, scores, imgIds, method):
179 | for imgId, score in zip(imgIds, scores):
180 | if not imgId in self.imgToEval:
181 | self.imgToEval[imgId] = {}
182 | self.imgToEval[imgId]["image_id"] = imgId
183 | self.imgToEval[imgId][method] = score
184 |
185 | def setEvalImgs(self):
186 | self.evalImgs = [eval for imgId, eval in self.imgToEval.items()]
187 |
188 |
189 | def evaluate(submission, reference):
190 | tokenizer = PTBTokenizer # for English
191 | annos = reference
192 | data = submission['results']
193 | rests = []
194 | for name, value in data.items():
195 | rests.append({'image_id': str(name), 'caption': value[0]['sentence']})
196 | eval_cap = EvalCap(annos, rests, tokenizer)
197 |
198 | eval_cap.evaluate()
199 |
200 | all_score = {}
201 | for metric, score in eval_cap.eval.items():
202 | all_score[metric] = score
203 | return all_score
204 |
205 | if __name__ == "__main__":
206 | ours = load_json("")
207 | gt = load_json("")
208 | evaluate(ours, gt)
--------------------------------------------------------------------------------
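
`translate_batch_single_sentence_greedy` decodes by filling one text slot per step: it writes the previously chosen token into position `dec_idx`, unmasks that position, re-runs the model, and takes the argmax at that position as the next token. The toy sketch below mirrors that loop with a random-score stand-in for the captioning model (vocabulary size, lengths, and token ids are all invented).

    import torch

    BOS, vocab_size, max_t_len, bsz = 4, 20, 6, 2

    def toy_step(input_ids, input_mask):
        # stand-in for the captioning model: random scores over the vocabulary
        return torch.randn(bsz, max_t_len, vocab_size)

    input_ids = torch.zeros(bsz, max_t_len, dtype=torch.long)
    input_mask = torch.zeros(bsz, max_t_len)
    next_symbols = torch.full((bsz,), BOS, dtype=torch.long)

    for dec_idx in range(max_t_len):
        input_ids[:, dec_idx] = next_symbols          # reveal the token chosen last step
        input_mask[:, dec_idx] = 1                    # unmask this position
        scores = toy_step(input_ids, input_mask)      # (bsz, max_t_len, vocab_size)
        next_symbols = scores[:, dec_idx].max(1)[1]   # greedy argmax for the next slot

    print(input_ids)   # each row starts with BOS followed by greedily chosen token ids
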
/src/rtransformer/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/GX77/TextKG/2dec9157c1ebcdfecd0e3d1a1b769c195c4d9245/src/rtransformer/__init__.py
--------------------------------------------------------------------------------
/src/rtransformer/beam_search.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/beam_search.py
3 | """
4 | import torch
5 |
6 | from rtransformer.decode_strategy import DecodeStrategy, length_penalty_builder
7 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset
8 |
9 | import logging
10 | logger = logging.getLogger(__name__)
11 |
12 |
13 | class BeamSearch(DecodeStrategy):
14 | """Generation beam search.
15 |
16 | Note that the attributes list is not exhaustive. Rather, it highlights
17 | tensors to document their shape. (Since the state variables' "batch"
18 | size decreases as beams finish, we denote this axis with a B rather than
19 | ``batch_size``).
20 |
21 | Args:
22 | beam_size (int): Number of beams to use (see base ``parallel_paths``).
23 | batch_size (int): See base.
24 | pad (int): See base.
25 | bos (int): See base.
26 | eos (int): See base.
27 | n_best (int): Don't stop until at least this many beams have
28 | reached EOS.
29 | mb_device (torch.device or str): See base ``device``.
30 | min_length (int): See base.
31 | max_length (int): See base.
32 | block_ngram_repeat (int): See base.
33 | exclusion_tokens (set[int]): See base.
34 |
35 | Attributes:
36 | top_beam_finished (ByteTensor): Shape ``(B,)``.
37 | _batch_offset (LongTensor): Shape ``(B,)``.
38 | _beam_offset (LongTensor): Shape ``(batch_size x beam_size,)``.
39 | alive_seq (LongTensor): See base.
40 | topk_log_probs (FloatTensor): Shape ``(B x beam_size,)``. These
41 | are the scores used for the topk operation.
42 | select_indices (LongTensor or NoneType): Shape
43 | ``(B x beam_size,)``. This is just a flat view of the
44 | ``_batch_index``.
45 | topk_scores (FloatTensor): Shape
46 | ``(B, beam_size)``. These are the
47 | scores a sequence will receive if it finishes.
48 | topk_ids (LongTensor): Shape ``(B, beam_size)``. These are the
49 | word indices of the topk predictions.
50 | _batch_index (LongTensor): Shape ``(B, beam_size)``.
51 | _prev_penalty (FloatTensor or NoneType): Shape
52 | ``(B, beam_size)``. Initialized to ``None``.
53 | _coverage (FloatTensor or NoneType): Shape
54 | ``(1, B x beam_size, inp_seq_len)``.
55 | hypotheses (list[list[Tuple[Tensor]]]): Contains a tuple
56 | of score (float), sequence (long), and attention (float or None).
57 | """
58 |
59 | def __init__(self, beam_size, batch_size, pad, bos, eos, n_best, mb_device,
60 | min_length, max_length, block_ngram_repeat, exclusion_tokens,
61 | length_penalty_name=None, length_penalty_alpha=0.):
62 | super(BeamSearch, self).__init__(
63 | pad, bos, eos, batch_size, mb_device, beam_size, min_length,
64 | block_ngram_repeat, exclusion_tokens, max_length)
65 | # beam parameters
66 | self.beam_size = beam_size
67 | self.n_best = n_best
68 | self.batch_size = batch_size
69 | self.length_penalty_name = length_penalty_name
70 | self.length_penalty_func = length_penalty_builder(length_penalty_name)
71 | self.length_penalty_alpha = length_penalty_alpha
72 |
73 | # result caching
74 | self.hypotheses = [[] for _ in range(batch_size)]
75 |
76 | # beam state
77 | self.top_beam_finished = torch.zeros([batch_size], dtype=torch.uint8)
78 | self.best_scores = torch.full([batch_size], -1e10, dtype=torch.float,
79 | device=mb_device) # (N, )
80 |
81 | self._batch_offset = torch.arange(batch_size, dtype=torch.long) # (N, )
82 | self._beam_offset = torch.arange(
83 | 0, batch_size * beam_size, step=beam_size, dtype=torch.long,
84 | device=mb_device) # (N, )
85 | self.topk_log_probs = torch.tensor(
86 | [0.0] + [float("-inf")] * (beam_size - 1), device=mb_device
87 | ).repeat(batch_size) # (B*N), guess: store the current beam probabilities
88 | self.select_indices = None
89 |
90 | # buffers for the topk scores and 'backpointer'
91 | self.topk_scores = torch.empty((batch_size, beam_size),
92 | dtype=torch.float, device=mb_device) # (N, B)
93 | self.topk_ids = torch.empty((batch_size, beam_size), dtype=torch.long,
94 | device=mb_device) # (N, B)
95 | self._batch_index = torch.empty([batch_size, beam_size],
96 | dtype=torch.long, device=mb_device) # (N, B)
97 | self.done = False
98 | # "global state" of the old beam
99 | self._prev_penalty = None
100 | self._coverage = None
101 |
102 | @property
103 | def current_predictions(self):
104 | return self.alive_seq[:, -1]
105 |
106 | @property
107 | def current_origin(self):
108 | return self.select_indices
109 |
110 | @property
111 | def current_backptr(self):
112 | # for testing
113 | return self.select_indices.view(self.batch_size, self.beam_size)\
114 | .fmod(self.beam_size)
115 |
116 | def advance(self, log_probs):
117 | """ current step log_probs (N * B, vocab_size), attn (1, N * B, L)
118 | Which attention is this??? Guess: the one with the encoder outputs
119 | """
120 | vocab_size = log_probs.size(-1)
121 |
122 | # using integer division to get an integer _B without casting
123 | _B = log_probs.shape[0] // self.beam_size # batch_size
124 |
125 | # force the output to be longer than self.min_length,
126 | # by setting prob(EOS) to be a very small number when < min_length
127 | self.ensure_min_length(log_probs)
128 |
129 | # Multiply probs by the beam probability.
130 | # logger.info("after log_probs {} {}".format(log_probs.shape, log_probs))
131 | log_probs += self.topk_log_probs.view(_B * self.beam_size, 1)
132 | # logger.info("after log_probs {} {}".format(log_probs.shape, log_probs))
133 |
134 | self.block_ngram_repeats(log_probs)
135 |
136 | # if the sequence ends now, then the penalty is the current
137 | # length + 1, to include the EOS token, length_penalty is a float number
138 | step = len(self)
139 | length_penalty = self.length_penalty_func(step+1, self.length_penalty_alpha)
140 |
141 | # Flatten probs into a list of possibilities.
142 | # pick topk in all the paths
143 | curr_scores = log_probs / length_penalty
144 | curr_scores = curr_scores.reshape(_B, self.beam_size * vocab_size)
145 | # self.topk_scores and self.topk_ids => (N, B)
146 | torch.topk(curr_scores, self.beam_size, dim=-1,
147 | out=(self.topk_scores, self.topk_ids))
148 |
149 | # Recover log probs.
150 | # Length penalty is just a scalar. It doesn't matter if it's applied
151 | # before or after the topk.
152 | torch.mul(self.topk_scores, length_penalty, out=self.topk_log_probs)
153 |
154 | # Resolve beam origin and map to batch index flat representation.
155 | torch.div(self.topk_ids, vocab_size, out=self._batch_index) # _batch_index (N * B)
156 | self._batch_index += self._beam_offset[:_B].unsqueeze(1)
157 | self.select_indices = self._batch_index.view(_B * self.beam_size)
158 |
159 | self.topk_ids.fmod_(vocab_size) # resolve true word ids
160 |
161 | # Append last prediction.
162 | self.alive_seq = torch.cat(
163 | [self.alive_seq.index_select(0, self.select_indices),
164 | self.topk_ids.view(_B * self.beam_size, 1)], -1) # (N * B, step_size)
165 |
166 | self.is_finished = self.topk_ids.eq(self.eos) # (N, B)
167 | self.ensure_max_length()
168 |
169 | def update_finished(self):
170 | # Penalize beams that finished.
171 | _B_old = self.topk_log_probs.shape[0] # batch_size might be changing??? as the beams finished
172 | step = self.alive_seq.shape[-1] # 1 greater than the step in advance, as we advanced 1 step
173 | self.topk_log_probs.masked_fill_(self.is_finished, -1e10)
174 | # on real data (newstest2017) with the pretrained transformer,
175 | # it's faster to not move this back to the original device
176 | self.is_finished = self.is_finished.to('cpu') # (N, B)
177 | self.top_beam_finished |= self.is_finished[:, 0].eq(1) # (N, ) initialized as zeros
178 | predictions = self.alive_seq.view(_B_old, self.beam_size, step)
179 | non_finished_batch = []
180 | for i in range(self.is_finished.size(0)): # (N, )
181 | b = self._batch_offset[i] # (0, ..., N-1)
182 | finished_hyp = self.is_finished[i].nonzero().view(-1)
183 | # Store finished hypotheses for this batch.
184 | for j in finished_hyp:
185 | self.hypotheses[b].append([self.topk_scores[i, j],
186 | predictions[i, j, 1:]])
187 | # End condition is the top beam finished and we can return
188 | # n_best hypotheses.
189 | finish_flag = self.top_beam_finished[i] != 0
190 | if finish_flag and len(self.hypotheses[b]) >= self.n_best:
191 | best_hyp = sorted(
192 | self.hypotheses[b], key=lambda x: x[0], reverse=True) # sort by scores
193 | for n, (score, pred) in enumerate(best_hyp):
194 | if n >= self.n_best:
195 | break
196 | self.scores[b].append(score)
197 | self.predictions[b].append(pred)
198 | else:
199 | non_finished_batch.append(i)
200 | non_finished = torch.tensor(non_finished_batch)
201 | # If all sentences are translated, no need to go further.
202 | if len(non_finished) == 0:
203 | self.done = True
204 | return
205 |
206 | _B_new = non_finished.shape[0]
207 | # Remove finished batches for the next step. (Not finished beam!!!)
208 | self.top_beam_finished = self.top_beam_finished.index_select(
209 | 0, non_finished)
210 | self._batch_offset = self._batch_offset.index_select(0, non_finished)
211 | non_finished = non_finished.to(self.topk_ids.device)
212 | self.topk_log_probs = self.topk_log_probs.index_select(0,
213 | non_finished)
214 | self._batch_index = self._batch_index.index_select(0, non_finished)
215 | self.select_indices = self._batch_index.view(_B_new * self.beam_size)
216 | self.alive_seq = predictions.index_select(0, non_finished) \
217 | .view(-1, self.alive_seq.size(-1))
218 | self.topk_scores = self.topk_scores.index_select(0, non_finished)
219 | self.topk_ids = self.topk_ids.index_select(0, non_finished)
220 |
--------------------------------------------------------------------------------
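A minimal driver sketch for the BeamSearch class in beam_search.py above, assuming the PyTorch version the repo targets (where `torch.div` on integer tensors still floor-divides). `decode_step` is a stand-in for the real decoder, and the PAD/BOS/EOS ids and vocabulary size are illustrative; the actual loop lives in translator.py and additionally reorders the decoder state with `beam.select_indices` after each `advance`.

    import torch
    from rtransformer.beam_search import BeamSearch

    vocab_size, batch_size, beam_size = 100, 4, 2

    def decode_step(alive_seq):
        # stand-in for the real decoder: random per-step log-probabilities
        return torch.log_softmax(torch.randn(alive_seq.size(0), vocab_size), dim=-1)

    beam = BeamSearch(beam_size=beam_size, batch_size=batch_size, pad=0, bos=1, eos=2,
                      n_best=1, mb_device="cpu", min_length=3, max_length=20,
                      block_ngram_repeat=0, exclusion_tokens=set(),
                      length_penalty_name="none", length_penalty_alpha=0.)

    while not beam.done:
        log_probs = decode_step(beam.alive_seq)   # (alive_batch * beam, vocab_size)
        beam.advance(log_probs)
        if beam.is_finished.any():
            beam.update_finished()

    # beam.predictions[b] holds up to n_best finished token-id sequences per input
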
/src/rtransformer/decode_strategy.py:
--------------------------------------------------------------------------------
1 | """
2 | https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/translate/decode_strategy.py
3 | """
4 | import torch
5 | import logging
6 | logger = logging.getLogger(__name__)
7 |
8 |
9 | class DecodeStrategy(object):
10 | """Base class for generation strategies.
11 |
12 | Args:
13 | pad (int): Magic integer in output vocab.
14 | bos (int): Magic integer in output vocab.
15 | eos (int): Magic integer in output vocab.
16 | batch_size (int): Current batch size.
17 | device (torch.device or str): Device for memory bank (encoder).
18 | parallel_paths (int): Decoding strategies like beam search
19 | use parallel paths. Each batch is repeated ``parallel_paths``
20 | times in relevant state tensors.
21 | min_length (int): Shortest acceptable generation, not counting
22 | begin-of-sentence or end-of-sentence.
23 | max_length (int): Longest acceptable sequence, not counting
24 | begin-of-sentence (presumably there has been no EOS
25 | yet if max_length is used as a cutoff).
26 | block_ngram_repeat (int): Block beams where
27 | ``block_ngram_repeat``-grams repeat.
28 | exclusion_tokens (set[int]): If a gram contains any of these
29 | tokens, it may repeat.
30 |
31 | Attributes:
32 | pad (int): See above.
33 | bos (int): See above.
34 | eos (int): See above.
35 | predictions (list[list[LongTensor]]): For each batch, holds a
36 | list of beam prediction sequences.
37 | scores (list[list[FloatTensor]]): For each batch, holds a
38 | list of scores.
39 | attention (list[list[FloatTensor or list[]]]): For each
40 | batch, holds a list of attention sequence tensors
41 | (or empty lists) having shape ``(step, inp_seq_len)`` where
42 | ``inp_seq_len`` is the length of the sample (not the max
43 | length of all inp seqs).
44 | alive_seq (LongTensor): Shape ``(B x parallel_paths, step)``.
45 | This sequence grows in the ``step`` axis on each call to
46 | :func:`advance()`.
47 | is_finished (ByteTensor or NoneType): Shape
48 | ``(B, parallel_paths)``. Initialized to ``None``.
49 | alive_attn (FloatTensor or NoneType): If tensor, shape is
50 | ``(step, B x parallel_paths, inp_seq_len)``, where ``inp_seq_len``
51 | is the (max) length of the input sequence.
52 | min_length (int): See above.
53 | max_length (int): See above.
54 | block_ngram_repeat (int): See above.
55 | exclusion_tokens (set[int]): See above.
56 | done (bool): See above.
57 | """
58 |
59 | def __init__(self, pad, bos, eos, batch_size, device, parallel_paths,
60 | min_length, block_ngram_repeat, exclusion_tokens, max_length):
61 |
62 | # magic indices
63 | self.pad = pad
64 | self.bos = bos
65 | self.eos = eos
66 |
67 | # result caching
68 | self.predictions = [[] for _ in range(batch_size)]
69 | self.scores = [[] for _ in range(batch_size)]
70 | self.attention = [[] for _ in range(batch_size)]
71 |
72 | self.alive_seq = torch.full(
73 | [batch_size * parallel_paths, 1], self.bos,
74 | dtype=torch.long, device=device) # (N * B, step_size=1)
75 | self.is_finished = torch.zeros(
76 | [batch_size, parallel_paths],
77 | dtype=torch.uint8, device=device)
78 | self.alive_attn = None
79 |
80 | self.min_length = min_length
81 | self.max_length = max_length
82 | self.block_ngram_repeat = block_ngram_repeat
83 | self.exclusion_tokens = exclusion_tokens
84 |
85 | self.done = False
86 |
87 | def __len__(self):
88 | return self.alive_seq.shape[1] # steps length
89 |
90 | def ensure_min_length(self, log_probs):
91 | if len(self) <= self.min_length:
92 | log_probs[:, self.eos] = -1e20
93 |
94 | def ensure_max_length(self):
95 | # add one to account for BOS. Don't account for EOS because hitting
96 | # this implies it hasn't been found.
97 | if len(self) == self.max_length + 1:
98 | self.is_finished.fill_(1)
99 |
100 | def block_ngram_repeats(self, log_probs):
101 | # log_probs (N * B, vocab_size)
102 | cur_len = len(self)
103 | if self.block_ngram_repeat > 0 and cur_len > 1:
104 | for path_idx in range(self.alive_seq.shape[0]): # N * B
105 | # skip BOS
106 | hyp = self.alive_seq[path_idx, 1:]
107 | ngrams = set()
108 | fail = False
109 | gram = []
110 | for i in range(cur_len - 1):
111 | # Last n tokens, n = block_ngram_repeat
112 | gram = (gram + [hyp[i].item()])[-self.block_ngram_repeat:]
113 | # skip the blocking if any token in gram is excluded
114 | if set(gram) & self.exclusion_tokens:
115 | continue
116 | if tuple(gram) in ngrams:
117 | fail = True
118 | ngrams.add(tuple(gram))
119 | if fail:
120 | log_probs[path_idx] = -10e20 # all the words in this path
121 |
122 | def advance(self, log_probs):
123 | """DecodeStrategy subclasses should override :func:`advance()`.
124 |
125 | Advance is used to update ``self.alive_seq``, ``self.is_finished``,
126 | and, when appropriate, ``self.alive_attn``.
127 | """
128 |
129 | raise NotImplementedError()
130 |
131 | def update_finished(self):
132 | """DecodeStrategy subclasses should override :func:`update_finished()`.
133 |
134 | ``update_finished`` is used to update ``self.predictions``,
135 | ``self.scores``, and other "output" attributes.
136 | """
137 |
138 | raise NotImplementedError()
139 |
140 |
141 | def length_penalty_builder(length_penalty_name="none"):
142 | """implement length penalty"""
143 | def length_wu(cur_len, alpha=0.):
144 | """GNMT length re-ranking score.
145 | See "Google's Neural Machine Translation System" :cite:`wu2016google`.
146 | """
147 | return ((5 + cur_len) / 6.0) ** alpha
148 |
149 | def length_average(cur_len, alpha=0.):
150 | """Returns the current sequence length."""
151 | return cur_len
152 |
153 | def length_none(cur_len, alpha=0.):
154 | """Returns unmodified scores."""
155 | return 1.0
156 |
157 | if length_penalty_name == "none":
158 | return length_none
159 | elif length_penalty_name == "wu":
160 | return length_wu
161 | elif length_penalty_name == "avg":
162 | return length_average
163 | else:
164 | raise NotImplementedError
165 |
--------------------------------------------------------------------------------
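A small worked example of the three penalties returned by `length_penalty_builder` above. `BeamSearch.advance` divides the cumulative log-probability by this value, so a penalty that grows with length offsets the usual bias toward short hypotheses (printed values rounded):

    from rtransformer.decode_strategy import length_penalty_builder

    wu = length_penalty_builder("wu")
    avg = length_penalty_builder("avg")
    none = length_penalty_builder("none")

    print(wu(10, alpha=0.7))   # ((5 + 10) / 6) ** 0.7 ≈ 1.90
    print(avg(10))             # 10 -> dividing by it averages the score per token
    print(none(10))            # 1.0 -> the raw cumulative log-probability is used
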
/src/rtransformer/masked_transformer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (c) 2018, salesforce.com, inc.
3 | All rights reserved.
4 | SPDX-License-Identifier: BSD-3-Clause
5 | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
6 |
7 | References:
8 | https://github.com/salesforce/densecap/blob/master/model/transformer.py
9 |
10 | Modified by Jie Lei
11 | """
12 |
13 | import torch
14 | from torch import nn
15 | from torch.nn import functional as F
16 |
17 | import math
18 | import numpy as np
19 | from rtransformer.model import LabelSmoothingLoss
20 |
21 |
22 | INF = 1e10
23 |
24 |
25 | def positional_encodings_like(x, t=None):
26 | if t is None:
27 | positions = torch.arange(0, x.size(1)).float()
28 | if x.is_cuda:
29 | positions = positions.cuda(x.get_device())
30 | else:
31 | positions = t
32 | encodings = torch.zeros(*x.size()[1:])
33 | if x.is_cuda:
34 | encodings = encodings.cuda(x.get_device())
35 |
36 | for channel in range(x.size(-1)):
37 | if channel % 2 == 0:
38 | encodings[:, channel] = torch.sin(positions / 10000 ** (channel / x.size(2)))
39 | else:
40 | encodings[:, channel] = torch.cos(positions / 10000 ** ((channel - 1) / x.size(2)))
41 | return encodings
42 |
43 |
44 | class LayerNorm(nn.Module):
45 | def __init__(self, d_model, eps=1e-6):
46 | super(LayerNorm, self).__init__()
47 | self.gamma = nn.Parameter(torch.ones(d_model))
48 | self.beta = nn.Parameter(torch.zeros(d_model))
49 | self.eps = eps
50 |
51 | def forward(self, x):
52 | mean = x.mean(-1, keepdim=True)
53 | std = x.std(-1, keepdim=True)
54 | return self.gamma * (x - mean) / (std + self.eps) + self.beta
55 |
56 |
57 | class ResidualBlock(nn.Module):
58 | def __init__(self, layer, d_model, drop_ratio):
59 | super(ResidualBlock, self).__init__()
60 | self.layer = layer
61 | self.dropout = nn.Dropout(drop_ratio)
62 | self.layernorm = LayerNorm(d_model)
63 |
64 | def forward(self, *x):
65 | return self.layernorm(x[0] + self.dropout(self.layer(*x)))
66 |
67 |
68 | class Attention(nn.Module):
69 |
70 | def __init__(self, d_key, drop_ratio, causal):
71 | super(Attention, self).__init__()
72 | self.scale = math.sqrt(d_key)
73 | self.dropout = nn.Dropout(drop_ratio)
74 | self.causal = causal
75 |
76 | def forward(self, query, key, value):
77 | dot_products = torch.bmm(query, key.transpose(1, 2))
78 |         if query.dim() == 3 and self.causal:
79 | tri = torch.ones(key.size(1), key.size(1)).triu(1) * INF
80 | if key.is_cuda:
81 | tri = tri.cuda(key.get_device())
82 | dot_products.data.sub_(tri.unsqueeze(0))
83 | return torch.bmm(self.dropout(F.softmax(dot_products / self.scale, dim=-1)), value)
84 |
85 |
86 | class MultiHead(nn.Module):
87 | def __init__(self, d_key, d_value, n_heads, drop_ratio, causal=False):
88 | super(MultiHead, self).__init__()
89 | self.attention = Attention(d_key, drop_ratio, causal=causal)
90 | self.wq = nn.Linear(d_key, d_key, bias=False)
91 | self.wk = nn.Linear(d_key, d_key, bias=False)
92 | self.wv = nn.Linear(d_value, d_value, bias=False)
93 | self.wo = nn.Linear(d_value, d_key, bias=False)
94 | self.n_heads = n_heads
95 |
96 | def forward(self, query, key, value):
97 | query, key, value = self.wq(query), self.wk(key), self.wv(value)
98 | query, key, value = (
99 | x.chunk(self.n_heads, -1) for x in (query, key, value))
100 | return self.wo(torch.cat([self.attention(q, k, v)
101 | for q, k, v in zip(query, key, value)], -1))
102 |
103 |
104 | class FeedForward(nn.Module):
105 | def __init__(self, d_model, d_hidden):
106 | super(FeedForward, self).__init__()
107 | self.linear1 = nn.Linear(d_model, d_hidden)
108 | self.linear2 = nn.Linear(d_hidden, d_model)
109 |
110 | def forward(self, x):
111 | return self.linear2(F.relu(self.linear1(x)))
112 |
113 |
114 | class EncoderLayer(nn.Module):
115 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio):
116 | super(EncoderLayer, self).__init__()
117 | self.selfattn = ResidualBlock(
118 | MultiHead(d_model, d_model, n_heads, drop_ratio, causal=False),
119 | d_model, drop_ratio)
120 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden),
121 | d_model, drop_ratio)
122 |
123 | def forward(self, x):
124 | return self.feedforward(self.selfattn(x, x, x))
125 |
126 |
127 | class DecoderLayer(nn.Module):
128 | def __init__(self, d_model, d_hidden, n_heads, drop_ratio):
129 | super(DecoderLayer, self).__init__()
130 | self.selfattn = ResidualBlock(
131 | MultiHead(d_model, d_model, n_heads, drop_ratio, causal=True),
132 | d_model, drop_ratio)
133 | self.attention = ResidualBlock(
134 | MultiHead(d_model, d_model, n_heads, drop_ratio),
135 | d_model, drop_ratio)
136 | self.feedforward = ResidualBlock(FeedForward(d_model, d_hidden),
137 | d_model, drop_ratio)
138 |
139 | def forward(self, x, encoding):
140 | """
141 | Args:
142 | x: (N, Lt, D)
143 | encoding: (N, Lv, D)
144 | """
145 | x = self.selfattn(x, x, x) # (N, Lt, D)
146 | return self.feedforward(self.attention(x, encoding, encoding)) # (N, Lt, D)
147 |
148 |
149 | class Encoder(nn.Module):
150 | def __init__(self, vfeat_size, d_model, d_hidden, n_layers, n_heads, drop_ratio):
151 | super(Encoder, self).__init__()
152 | self.video_embeddings = nn.Sequential(
153 | LayerNorm(vfeat_size),
154 | nn.Dropout(drop_ratio),
155 | nn.Linear(vfeat_size, d_model)
156 | )
157 | self.layers = nn.ModuleList(
158 | [EncoderLayer(d_model, d_hidden, n_heads, drop_ratio)
159 | for i in range(n_layers)])
160 | self.dropout = nn.Dropout(drop_ratio)
161 |
162 | def forward(self, x, mask=None):
163 | """
164 |
165 | Args:
166 | x: (N, Lv, Dv)
167 | mask: (N, Lv)
168 |
169 | Returns:
170 |
171 | """
172 | x = self.video_embeddings(x) # (N, Lv, D)
173 | x = x + positional_encodings_like(x)
174 | x = self.dropout(x)
175 |         if mask is not None:
176 |             mask.unsqueeze_(-1)
177 |             x = x * mask
178 | encoding = []
179 | for layer in self.layers:
180 | x = layer(x)
181 | if mask is not None:
182 | x = x*mask
183 | encoding.append(x)
184 | return encoding
185 |
186 |
187 | class Decoder(nn.Module):
188 | def __init__(self, d_model, d_hidden, vocab_size, n_layers, n_heads,
189 | drop_ratio):
190 | super(Decoder, self).__init__()
191 | self.layers = nn.ModuleList(
192 | [DecoderLayer(d_model, d_hidden, n_heads, drop_ratio)
193 | for i in range(n_layers)])
194 | self.out = nn.Linear(d_model, vocab_size)
195 | self.dropout = nn.Dropout(drop_ratio)
196 | self.d_model = d_model
197 | self.d_out = vocab_size
198 |
199 | def forward(self, x, encoding):
200 | """
201 | Args:
202 | x: (N, Lt)
203 | encoding: [(N, Lv, D), ] * num_hidden_layers
204 |
205 | """
206 | x = F.embedding(x, self.out.weight * math.sqrt(self.d_model)) # (N, Lt, D)
207 | x = x + positional_encodings_like(x) # (N, Lt, D)
208 | x = self.dropout(x) # (N, Lt, D)
209 | for layer, enc in zip(self.layers, encoding):
210 | x = layer(x, enc) # (N, Lt, D)
211 | return x # (N, Lt, D) at last layer
212 |
213 |
214 | class MTransformer(nn.Module):
215 | def __init__(self, config):
216 | super(MTransformer, self).__init__()
217 | self.config = config
218 | vfeat_size = config.video_feature_size
219 | d_model = config.hidden_size # 1024
220 | d_hidden = config.intermediate_size # 2048
221 | n_layers = config.num_hidden_layers # 6
222 | n_heads = config.num_attention_heads # 8
223 | drop_ratio = config.hidden_dropout_prob # 0.1
224 | self.vocab_size = config.vocab_size
225 | self.encoder = Encoder(vfeat_size, d_model, d_hidden, n_layers,
226 | n_heads, drop_ratio)
227 | self.decoder = Decoder(d_model, d_hidden, self.vocab_size,
228 | n_layers, n_heads, drop_ratio)
229 | self.loss_func = LabelSmoothingLoss(config.label_smoothing, config.vocab_size, ignore_index=-1) \
230 | if "label_smoothing" in config and config.label_smoothing > 0 else nn.CrossEntropyLoss(ignore_index=-1)
231 |
232 | def encode(self, video_features, video_masks):
233 | """
234 | Args:
235 | video_features: (N, Lv, Dv)
236 | video_masks: (N, Lv) with 1 indicates valid bits
237 | """
238 | return self.encoder(video_features, video_masks)
239 |
240 | def decode(self, text_input_ids, text_masks, text_input_labels, encoder_outputs, video_masks):
241 | """
242 | Args:
243 | text_input_ids: (N, Lt)
244 | text_masks: (N, Lt) with 1 indicates valid bits,
245 | text_input_labels: (N, Lt) with `-1` on ignored positions
246 | encoder_outputs: (N, Lv, D)
247 | video_masks: not used, leave here to maintain a common API with untied model
248 | """
249 | # the triangular mask is generated and applied inside the attention module
250 | h = self.decoder(text_input_ids, encoder_outputs) # (N, Lt, D)
251 | prediction_scores = self.decoder.out(h) # (N, Lt, vocab_size)
252 | caption_loss = self.loss_func(prediction_scores.view(-1, self.config.vocab_size),
253 | text_input_labels.view(-1)) # float
254 | return caption_loss, prediction_scores
255 |
256 | def forward(self, video_features, video_masks, text_input_ids, text_masks, text_input_labels):
257 | """
258 | Args:
259 | video_features: (N, Lv, Dv)
260 | video_masks: (N, Lv) with 1 indicates valid bits
261 | text_input_ids: (N, Lt)
262 | text_masks: (N, Lt) with 1 indicates valid bits
263 | text_input_labels: (N, Lt) with `-1` on ignored positions (in some sense duplicate with text_masks)
264 | """
265 | encoder_layer_outputs = self.encode(video_features, video_masks) # [(N, Lv, D), ] * num_hidden_layers
266 | caption_loss, prediction_scores = self.decode(
267 | text_input_ids, text_masks, text_input_labels, encoder_layer_outputs, None) # float, (N, Lt, vocab_size)
268 | return caption_loss, prediction_scores
269 |
--------------------------------------------------------------------------------
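A forward-pass smoke-test sketch for MTransformer above. It assumes the `easydict` package for the config object (train options are handled as attribute dicts elsewhere in the repo) and that `LabelSmoothingLoss` from rtransformer/model.py is importable; all sizes below are illustrative.

    import torch
    from easydict import EasyDict as EDict
    from rtransformer.masked_transformer import MTransformer

    config = EDict(video_feature_size=500, hidden_size=128, intermediate_size=256,
                   num_hidden_layers=2, num_attention_heads=4,
                   hidden_dropout_prob=0.1, vocab_size=1000, label_smoothing=0.1)
    model = MTransformer(config)

    N, Lv, Lt = 2, 10, 8
    video_features = torch.randn(N, Lv, config.video_feature_size)
    video_masks = torch.ones(N, Lv)
    text_input_ids = torch.randint(0, config.vocab_size, (N, Lt))
    text_masks = torch.ones(N, Lt)
    text_input_labels = text_input_ids.clone()

    loss, scores = model(video_features, video_masks, text_input_ids,
                         text_masks, text_input_labels)
    print(loss.item(), scores.shape)  # scalar loss, (N, Lt, vocab_size)
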
/src/rtransformer/optimization.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 | # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3 | #
4 | # Licensed under the Apache License, Version 2.0 (the "License");
5 | # you may not use this file except in compliance with the License.
6 | # You may obtain a copy of the License at
7 | #
8 | # http://www.apache.org/licenses/LICENSE-2.0
9 | #
10 | # Unless required by applicable law or agreed to in writing, software
11 | # distributed under the License is distributed on an "AS IS" BASIS,
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 | # See the License for the specific language governing permissions and
14 | # limitations under the License.
15 | """PyTorch optimization for BERT model."""
16 |
17 | import math
18 | import torch
19 | from torch.optim import Optimizer
20 | from torch.optim.optimizer import required
21 | from torch.nn.utils import clip_grad_norm_
22 | import logging
23 | import abc
24 | import sys
25 |
26 | logger = logging.getLogger(__name__)
27 |
28 |
29 | if sys.version_info >= (3, 4):
30 | ABC = abc.ABC
31 | else:
32 | ABC = abc.ABCMeta('ABC', (), {})
33 |
34 |
35 | class _LRSchedule(ABC):
36 | """ Parent of all LRSchedules here. """
37 | warn_t_total = False # is set to True for schedules where progressing beyond t_total steps doesn't make sense
38 | def __init__(self, warmup=0.002, t_total=-1, **kw):
39 | """
40 | :param warmup: what fraction of t_total steps will be used for linear warmup
41 | :param t_total: how many training steps (updates) are planned
42 | :param kw:
43 | """
44 | super(_LRSchedule, self).__init__(**kw)
45 | if t_total < 0:
46 | logger.warning("t_total value of {} results in schedule not being applied".format(t_total))
47 | if not 0.0 <= warmup < 1.0 and not warmup == -1:
48 | raise ValueError("Invalid warmup: {} - should be in [0.0, 1.0[ or -1".format(warmup))
49 | warmup = max(warmup, 0.)
50 | self.warmup, self.t_total = float(warmup), float(t_total)
51 | self.warned_for_t_total_at_progress = -1
52 |
53 | def get_lr(self, step, nowarn=False):
54 | """
55 | :param step: which of t_total steps we're on
56 | :param nowarn: set to True to suppress warning regarding training beyond specified 't_total' steps
57 | :return: learning rate multiplier for current update
58 | """
59 | if self.t_total < 0:
60 | return 1.
61 | progress = float(step) / self.t_total
62 | ret = self.get_lr_(progress)
63 |         # warn when exceeding t_total (only for schedules with warn_t_total set)
64 | if not nowarn and self.warn_t_total and progress > 1. and progress > self.warned_for_t_total_at_progress:
65 | logger.warning(
66 | "Training beyond specified 't_total'. Learning rate multiplier set to {}. Please set 't_total' of {} correctly."
67 | .format(ret, self.__class__.__name__))
68 | self.warned_for_t_total_at_progress = progress
69 | # end warning
70 | return ret
71 |
72 | @abc.abstractmethod
73 | def get_lr_(self, progress):
74 | """
75 | :param progress: value between 0 and 1 (unless going beyond t_total steps) specifying training progress
76 | :return: learning rate multiplier for current update
77 | """
78 | return 1.
79 |
80 |
81 | class ConstantLR(_LRSchedule):
82 | def get_lr_(self, progress):
83 | return 1.
84 |
85 |
86 | class WarmupCosineSchedule(_LRSchedule):
87 | """
88 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
89 | Decreases learning rate from 1. to 0. over remaining `1 - warmup` steps following a cosine curve.
90 |     If `cycles` is different from its default of 0.5, the learning rate follows `cycles` full cosine periods after warmup instead of a single half-period decay.
91 | """
92 | warn_t_total = True
93 | def __init__(self, warmup=0.002, t_total=-1, cycles=.5, **kw):
94 | """
95 | :param warmup: see LRSchedule
96 | :param t_total: see LRSchedule
97 | :param cycles: number of cycles. Default: 0.5, corresponding to cosine decay from 1. at progress==warmup and 0 at progress==1.
98 | :param kw:
99 | """
100 | super(WarmupCosineSchedule, self).__init__(warmup=warmup, t_total=t_total, **kw)
101 | self.cycles = cycles
102 |
103 | def get_lr_(self, progress):
104 | if progress < self.warmup:
105 | return progress / self.warmup
106 | else:
107 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
108 | return 0.5 * (1. + math.cos(math.pi * self.cycles * 2 * progress))
109 |
110 |
111 | class WarmupCosineWithHardRestartsSchedule(WarmupCosineSchedule):
112 | """
113 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
114 | If `cycles` (default=1.) is different from default, learning rate follows `cycles` times a cosine decaying
115 | learning rate (with hard restarts).
116 | """
117 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
118 | super(WarmupCosineWithHardRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
119 | assert(cycles >= 1.)
120 |
121 | def get_lr_(self, progress):
122 | if progress < self.warmup:
123 | return progress / self.warmup
124 | else:
125 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
126 | ret = 0.5 * (1. + math.cos(math.pi * ((self.cycles * progress) % 1)))
127 | return ret
128 |
129 |
130 | class WarmupCosineWithWarmupRestartsSchedule(WarmupCosineWithHardRestartsSchedule):
131 | """
132 | All training progress is divided in `cycles` (default=1.) parts of equal length.
133 | Every part follows a schedule with the first `warmup` fraction of the training steps linearly increasing from 0. to 1.,
134 | followed by a learning rate decreasing from 1. to 0. following a cosine curve.
135 | """
136 | def __init__(self, warmup=0.002, t_total=-1, cycles=1., **kw):
137 | assert(warmup * cycles < 1.)
138 | warmup = warmup * cycles if warmup >= 0 else warmup
139 | super(WarmupCosineWithWarmupRestartsSchedule, self).__init__(warmup=warmup, t_total=t_total, cycles=cycles, **kw)
140 |
141 | def get_lr_(self, progress):
142 | progress = progress * self.cycles % 1.
143 | if progress < self.warmup:
144 | return progress / self.warmup
145 | else:
146 | progress = (progress - self.warmup) / (1 - self.warmup) # progress after warmup
147 | ret = 0.5 * (1. + math.cos(math.pi * progress))
148 | return ret
149 |
150 |
151 | class WarmupConstantSchedule(_LRSchedule):
152 | """
153 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
154 | Keeps learning rate equal to 1. after warmup.
155 | """
156 | def get_lr_(self, progress):
157 | if progress < self.warmup:
158 | return progress / self.warmup
159 | return 1.
160 |
161 |
162 | class WarmupLinearSchedule(_LRSchedule):
163 | """
164 | Linearly increases learning rate from 0 to 1 over `warmup` fraction of training steps.
165 | Linearly decreases learning rate from 1. to 0. over remaining `1 - warmup` steps.
166 | """
167 | warn_t_total = True
168 | def get_lr_(self, progress):
169 | if progress < self.warmup:
170 | return progress / self.warmup
171 | return max((progress - 1.) / (self.warmup - 1.), 0.)
172 |
173 |
174 | SCHEDULES = {
175 | None: ConstantLR,
176 | "none": ConstantLR,
177 | "warmup_cosine": WarmupCosineSchedule,
178 | "warmup_constant": WarmupConstantSchedule,
179 | "warmup_linear": WarmupLinearSchedule
180 | }
181 |
182 |
183 | class EMA(object):
184 | """ Exponential Moving Average for model parameters.
185 | references:
186 | [1] https://github.com/BangLiu/QANet-PyTorch/blob/master/model/modules/ema.py
187 | [2] https://github.com/hengruo/QANet-pytorch/blob/e2de07cd2c711d525f5ffee35c3764335d4b501d/main.py"""
188 | def __init__(self, decay):
189 | self.decay = decay
190 | self.shadow = {}
191 | self.original = {}
192 |
193 | def register(self, name, val):
194 | self.shadow[name] = val.clone()
195 |
196 | def __call__(self, model, step):
197 | decay = min(self.decay, (1 + step) / (10.0 + step))
198 | for name, param in model.named_parameters():
199 | if param.requires_grad:
200 | assert name in self.shadow
201 | new_average = \
202 | (1.0 - decay) * param.data + decay * self.shadow[name]
203 | self.shadow[name] = new_average.clone()
204 |
205 | def assign(self, model):
206 | for name, param in model.named_parameters():
207 | if param.requires_grad:
208 | assert name in self.shadow
209 | self.original[name] = param.data.clone()
210 | param.data = self.shadow[name]
211 |
212 | def resume(self, model):
213 | for name, param in model.named_parameters():
214 | if param.requires_grad:
215 | assert name in self.shadow
216 | param.data = self.original[name]
217 |
218 |
219 | class BertAdam(Optimizer):
220 | """Implements BERT version of Adam algorithm with weight decay fix.
221 | Params:
222 | lr: learning rate
223 | warmup: portion of t_total for the warmup, -1 means no warmup. Default: -1
224 | t_total: total number of training steps for the learning
225 | rate schedule, -1 means constant learning rate of 1. (no warmup regardless of warmup setting). Default: -1
226 | schedule: schedule to use for the warmup (see above).
227 | Can be `'warmup_linear'`, `'warmup_constant'`, `'warmup_cosine'`, `'none'`, `None` or a `_LRSchedule` object (see below).
228 | If `None` or `'none'`, learning rate is always kept constant.
229 | Default : `'warmup_linear'`
230 | b1: Adams b1. Default: 0.9
231 | b2: Adams b2. Default: 0.999
232 | e: Adams epsilon. Default: 1e-6
233 | weight_decay: Weight decay. Default: 0.01
234 | max_grad_norm: Maximum norm for the gradients (-1 means no clipping). Default: 1.0
235 | """
236 | def __init__(self, params, lr=required, warmup=-1, t_total=-1, schedule='warmup_linear',
237 | b1=0.9, b2=0.999, e=1e-6, weight_decay=0.01, max_grad_norm=1.0, **kwargs):
238 | if lr is not required and lr < 0.0:
239 | raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
240 | if not isinstance(schedule, _LRSchedule) and schedule not in SCHEDULES:
241 | raise ValueError("Invalid schedule parameter: {}".format(schedule))
242 | if not 0.0 <= b1 < 1.0:
243 | raise ValueError("Invalid b1 parameter: {} - should be in [0.0, 1.0[".format(b1))
244 | if not 0.0 <= b2 < 1.0:
245 | raise ValueError("Invalid b2 parameter: {} - should be in [0.0, 1.0[".format(b2))
246 | if not e >= 0.0:
247 | raise ValueError("Invalid epsilon value: {} - should be >= 0.0".format(e))
248 | # initialize schedule object
249 | if not isinstance(schedule, _LRSchedule):
250 | schedule_type = SCHEDULES[schedule]
251 | schedule = schedule_type(warmup=warmup, t_total=t_total)
252 | else:
253 | if warmup != -1 or t_total != -1:
254 | logger.warning("warmup and t_total on the optimizer are ineffective when _LRSchedule object is provided as schedule. "
255 | "Please specify custom warmup and t_total in _LRSchedule object.")
256 | defaults = dict(lr=lr, schedule=schedule,
257 | b1=b1, b2=b2, e=e, weight_decay=weight_decay,
258 | max_grad_norm=max_grad_norm)
259 | super(BertAdam, self).__init__(params, defaults)
260 |
261 | def get_lr(self):
262 | lr = []
263 | for group in self.param_groups:
264 | for p in group['params']:
265 | state = self.state[p]
266 | if len(state) == 0:
267 | return [0]
268 | lr_scheduled = group['lr']
269 | lr_scheduled *= group['schedule'].get_lr(state['step'])
270 | lr.append(lr_scheduled)
271 | return lr
272 |
273 | def step(self, closure=None):
274 | """Performs a single optimization step.
275 |
276 | Arguments:
277 | closure (callable, optional): A closure that reevaluates the model
278 | and returns the loss.
279 | """
280 | loss = None
281 | if closure is not None:
282 | loss = closure()
283 |
284 | for group in self.param_groups:
285 | for p in group['params']:
286 | if p.grad is None:
287 | continue
288 | grad = p.grad.data
289 | if grad.is_sparse:
290 | raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
291 |
292 | state = self.state[p]
293 |
294 | # State initialization
295 | if len(state) == 0:
296 | state['step'] = 0
297 | # Exponential moving average of gradient values
298 | state['next_m'] = torch.zeros_like(p.data)
299 | # Exponential moving average of squared gradient values
300 | state['next_v'] = torch.zeros_like(p.data)
301 |
302 | next_m, next_v = state['next_m'], state['next_v']
303 | beta1, beta2 = group['b1'], group['b2']
304 |
305 | # Add grad clipping
306 | if group['max_grad_norm'] > 0:
307 | clip_grad_norm_(p, group['max_grad_norm'])
308 |
309 | # Decay the first and second moment running average coefficient
310 | # In-place operations to update the averages at the same time
311 | #next_m.mul_(beta1).add_(1 - beta1, grad)
312 | #next_v.mul_(beta2).addcmul_(1 - beta2, grad, grad)
313 | next_m.mul_(beta1).add_(grad, alpha=1 - beta1)
314 | next_v.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
315 |
316 | update = next_m / (next_v.sqrt() + group['e'])
317 |
318 | # Just adding the square of the weights to the loss function is *not*
319 | # the correct way of using L2 regularization/weight decay with Adam,
320 | # since that will interact with the m and v parameters in strange ways.
321 | #
322 | # Instead we want to decay the weights in a manner that doesn't interact
323 | # with the m/v parameters. This is equivalent to adding the square
324 | # of the weights to the loss with plain (non-momentum) SGD.
325 | if group['weight_decay'] > 0.0:
326 | update += group['weight_decay'] * p.data
327 |
328 | lr_scheduled = group['lr']
329 | lr_scheduled *= group['schedule'].get_lr(state['step'])
330 |
331 | update_with_lr = lr_scheduled * update
332 | p.data.add_(-update_with_lr)
333 |
334 | state['step'] += 1
335 |
336 | # step_size = lr_scheduled * math.sqrt(bias_correction2) / bias_correction1
337 | # No bias correction
338 | # bias_correction1 = 1 - beta1 ** state['step']
339 | # bias_correction2 = 1 - beta2 ** state['step']
340 |
341 | return loss
342 |
--------------------------------------------------------------------------------
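A short sketch of the warmup-linear schedule and BertAdam above; the multiplier returned by `get_lr` scales the base learning rate, and the numbers are illustrative.

    import torch
    from rtransformer.optimization import BertAdam, WarmupLinearSchedule

    sched = WarmupLinearSchedule(warmup=0.1, t_total=1000)
    print(sched.get_lr(50))    # 0.5 -> progress 0.05 is still inside the warmup fraction
    print(sched.get_lr(100))   # 1.0 -> warmup just finished
    print(sched.get_lr(550))   # 0.5 -> (0.55 - 1) / (0.1 - 1), linear decay toward 0

    model = torch.nn.Linear(16, 4)
    optimizer = BertAdam(model.parameters(), lr=1e-4, warmup=0.1, t_total=1000,
                         schedule="warmup_linear", weight_decay=0.01, max_grad_norm=1.0)
    loss = model(torch.randn(8, 16)).pow(2).mean()
    loss.backward()
    optimizer.step()
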
/src/translate.py:
--------------------------------------------------------------------------------
1 | """ Translate input text with trained model. """
2 |
3 | import os
4 | import torch
5 | from torch.utils.data import DataLoader
6 | import argparse
7 | from tqdm import tqdm
8 | import random
9 | import numpy as np
10 | import subprocess
11 | from collections import defaultdict
12 |
13 | from translator import Translator
14 | from rtransformer.recursive_caption_dataset import \
15 | single_sentence_collate, prepare_batch_inputs
16 | from rtransformer.recursive_caption_dataset import RecursiveCaptionDataset as RCDataset
17 | from utils_func import load_json, merge_dicts, save_json
18 |
19 |
20 | def sort_res(res_dict):
21 | """res_dict: the submission json entry `results`"""
22 | final_res_dict = {}
23 | for k, v in res_dict.items():
24 | final_res_dict[k] = sorted(v, key=lambda x: float(x["timestamp"][0]))
25 | return final_res_dict
26 |
27 | def run_translate(eval_data_loader, translator, opt):
28 | # submission template
29 | batch_res = {"version": "VERSION 1.0",
30 | "results": defaultdict(list),
31 | "external_data": {"used": "true", "details": "ay"}}
32 |
33 | for raw_batch in tqdm(eval_data_loader, mininterval=2, desc=" - (Translate)"):
34 | meta = raw_batch[2] # list(dict), len == bsz
35 | batched_data = prepare_batch_inputs(raw_batch[0], device=translator.device)
36 | model_inputs = [
37 | batched_data["input_ids"],
38 | batched_data["input_ids2"],
39 | batched_data["video_feature"],
40 | batched_data['region_feature'],
41 | batched_data["input_mask"],
42 | batched_data["input_mask2"],
43 | batched_data["token_type_ids"],
44 | batched_data["token_type_ids2"],
45 | batched_data['kg_mask'],
46 | batched_data['l_mask'],
47 | ]
48 |
49 | dec_seq = translator.translate_batch(
50 | model_inputs, use_beam=opt.use_beam, recurrent=False, untied=opt.untied or opt.mtrans)
51 |
52 | # example_idx indicates which example is in the batch
53 | for example_idx, (cur_gen_sen, cur_meta) in enumerate(zip(dec_seq, meta)):
54 | cur_data = {
55 | "sentence": eval_data_loader.dataset.convert_ids_to_sentence(
56 | cur_gen_sen.cpu().tolist()),
57 | "timestamp": cur_meta["timestamp"],
58 | "gt_sentence": cur_meta["gt_sentence"]
59 | }
60 | batch_res["results"][cur_meta["name"]].append(cur_data)
61 |
62 | batch_res["results"] = sort_res(batch_res["results"])
63 | return batch_res
64 |
65 |
66 | def get_data_loader(opt, eval_mode="val"):
67 | eval_dataset = RCDataset(
68 | dset_name=opt.dset_name,
69 | data_dir=opt.data_dir, video_feature_dir=opt.video_feature_dir,
70 | duration_file=opt.v_duration_file,
71 | word2idx_path=opt.word2idx_path, max_t_len=opt.max_t_len,
72 | max_v_len=opt.max_v_len, max_n_sen=opt.max_n_sen + 10, mode=eval_mode,
73 | recurrent=opt.recurrent, untied=opt.untied or opt.mtrans)
74 |
75 | collate_fn = single_sentence_collate
76 | eval_data_loader = DataLoader(eval_dataset, collate_fn=collate_fn,
77 | batch_size=opt.batch_size, shuffle=False, num_workers=8)
78 | return eval_data_loader
79 |
80 |
81 | def main():
82 | parser = argparse.ArgumentParser(description="translate.py")
83 |
84 | parser.add_argument("--eval_splits", type=str, nargs="+", default=["val", ],
85 | choices=["val", "test"], help="evaluate on val/test set, yc2 only has val")
86 | parser.add_argument("--res_dir", required=True, help="path to dir containing model .pt file")
87 | parser.add_argument("--batch_size", type=int, default=100, help="batch size")
88 |
89 | # beam search configs
90 | parser.add_argument("--use_beam", action="store_true", help="use beam search, otherwise greedy search")
91 | parser.add_argument("--beam_size", type=int, default=2, help="beam size")
92 | parser.add_argument("--n_best", type=int, default=1, help="stop searching when get n_best from beam search")
93 | parser.add_argument("--min_sen_len", type=int, default=5, help="minimum length of the decoded sentences")
94 | parser.add_argument("--max_sen_len", type=int, default=30, help="maximum length of the decoded sentences")
95 | parser.add_argument("--block_ngram_repeat", type=int, default=0, help="block repetition of ngrams during decoding.")
96 | parser.add_argument("--length_penalty_name", default="none",
97 | choices=["none", "wu", "avg"], help="length penalty to use.")
98 | parser.add_argument("--length_penalty_alpha", type=float, default=0.,
99 | help="Google NMT length penalty parameter (higher = longer generation)")
100 | parser.add_argument("--eval_tool_dir", type=str, default="/mnt/Pycharm_Remote/recurrent_transformer/densevid_eval")
101 |
102 | parser.add_argument("--no_cuda", action="store_true")
103 | parser.add_argument("--seed", default=2019, type=int)
104 | parser.add_argument("--debug", action="store_true")
105 |
106 | opt = parser.parse_args()
107 | opt.cuda = not opt.no_cuda
108 |
109 | # random seed
110 | random.seed(opt.seed)
111 | np.random.seed(opt.seed)
112 | torch.manual_seed(opt.seed)
113 |
114 | checkpoint = torch.load(os.path.join(opt.res_dir, "model.chkpt"))
115 |
116 | # add some of the train configs
117 | train_opt = checkpoint["opt"] # EDict(load_json(os.path.join(opt.res_dir, "model.cfg.json")))
118 | for k in train_opt.__dict__:
119 | if k not in opt.__dict__:
120 | setattr(opt, k, getattr(train_opt, k))
121 | print("train_opt", train_opt)
122 |
123 | decoding_strategy = "beam{}_lp_{}_la_{}".format(
124 | opt.beam_size, opt.length_penalty_name, opt.length_penalty_alpha) if opt.use_beam else "greedy"
125 | save_json(vars(opt),
126 | os.path.join(opt.res_dir, "{}_eval_cfg.json".format(decoding_strategy)),
127 | save_pretty=True)
128 |
129 | if opt.dset_name == "anet":
130 | reference_files_map = {
131 | "val": [os.path.join(opt.data_dir, e) for e in
132 | ["anet_entities_val_1_para.json", "anet_entities_val_2_para.json"]],
133 | "test": [os.path.join(opt.data_dir, e) for e in
134 | ["anet_entities_test_1_para.json", "anet_entities_test_2_para.json"]]}
135 | else: # yc2
136 | reference_files_map = {"val": [os.path.join(opt.data_dir, "yc2_val_anet_format_para.json")]}
137 | for eval_mode in opt.eval_splits:
138 | print("Start evaluating {}".format(eval_mode))
139 | # add 10 at max_n_sen to make the inference stage use all the segments
140 | eval_data_loader = get_data_loader(opt, eval_mode=eval_mode)
141 | eval_references = reference_files_map[eval_mode]
142 |
143 | # setup model
144 | translator = Translator(opt, checkpoint)
145 |
146 | pred_file = os.path.join(opt.res_dir, "{}_pred_{}.json".format(decoding_strategy, eval_mode))
147 | pred_file = os.path.abspath(pred_file)
148 | if not os.path.exists(pred_file):
149 | json_res = run_translate(eval_data_loader, translator, opt=opt)
150 | save_json(json_res, pred_file, save_pretty=True)
151 | else:
152 | print("Using existing prediction file at {}".format(pred_file))
153 |
154 | # COCO language evaluation
155 | lang_file = pred_file.replace(".json", "_lang.json")
156 | eval_command = ["python", "para-evaluate.py", "-s", pred_file, "-o", lang_file,
157 | "-v", "-r"] + eval_references
158 | subprocess.call(eval_command, cwd=opt.eval_tool_dir)
159 |
160 | # basic stats
161 | stat_filepath = pred_file.replace(".json", "_stat.json")
162 | eval_stat_cmd = ["python", "get_caption_stat.py", "-s", pred_file, "-r", eval_references[0],
163 | "-o", stat_filepath, "-v"]
164 | subprocess.call(eval_stat_cmd, cwd=opt.eval_tool_dir)
165 |
166 | # repetition evaluation
167 | rep_filepath = pred_file.replace(".json", "_rep.json")
168 | eval_rep_cmd = ["python", "evaluateRepetition.py", "-s", pred_file,
169 | "-r", eval_references[0], "-o", rep_filepath]
170 | subprocess.call(eval_rep_cmd, cwd=opt.eval_tool_dir)
171 |
172 | metric_filepaths = [lang_file, stat_filepath, rep_filepath]
173 | all_metrics = merge_dicts([load_json(e) for e in metric_filepaths])
174 | all_metrics_filepath = pred_file.replace(".json", "_all_metrics.json")
175 | save_json(all_metrics, all_metrics_filepath, save_pretty=True)
176 |
177 | print("pred_file {} lang_file {}".format(pred_file, lang_file))
178 | print("[Info] Finished {}.".format(eval_mode))
179 |
180 |
181 | if __name__ == "__main__":
182 | main()
183 |
--------------------------------------------------------------------------------
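A tiny illustration of `sort_res` from translate.py above, assuming the src directory is importable; the video id and captions are made up. Segments come back in decoding order and are re-ordered by their start timestamp before being written to the submission file.

    from translate import sort_res

    res = {"v_demo_0001": [
        {"sentence": "then serve the dish", "timestamp": [20.0, 25.0]},
        {"sentence": "chop the onion", "timestamp": [0.0, 5.0]},
    ]}
    print(sort_res(res)["v_demo_0001"][0]["sentence"])  # "chop the onion"
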
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/12 22:30
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : __init__.py
6 |
--------------------------------------------------------------------------------
/src/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/13 00:25
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : checkpoint.py
6 |
7 |
8 | import os
9 | import torch
10 | import logging
11 | from typing import *
12 | from collections import OrderedDict
13 |
14 | logger = logging.getLogger(__name__)
15 |
16 |
17 | def auto_resume(ckpt_folder):
18 | try:
19 | ckpt_files = [ckpt for ckpt in os.listdir(ckpt_folder) if ckpt.endswith(".pth")]
20 | except FileNotFoundError:
21 | ckpt_files = []
22 | if len(ckpt_files) > 0:
23 | return max([os.path.join(ckpt_folder, file) for file in ckpt_files], key=os.path.getmtime)
24 | else:
25 | return None
26 |
27 |
28 | def save_checkpoint(ckpt_folder, epoch, model, optimizer, scheduler, config, prefix="checkpoint"):
29 | if hasattr(model, 'module'):
30 | model = model.module
31 | stat_dict = {
32 | "epoch": epoch,
33 | "model": model.state_dict(),
34 | "optimizer": optimizer.state_dict(),
35 | "scheduler": scheduler.state_dict() if scheduler is not None else "",
36 | "config": config
37 | }
38 | ckpt_path = os.path.join(ckpt_folder, f"{prefix}_{epoch}.pth")
39 | os.makedirs(ckpt_folder, exist_ok=True)
40 | torch.save(stat_dict, ckpt_path)
41 | return ckpt_path
42 |
43 |
44 | def load_checkpoint(ckpt_file, model: torch.nn.Module, optimizer: Union[torch.optim.Optimizer, None], scheduler: Any,
45 | restart_train=False, rewrite: Tuple[str, str] = None):
46 | if hasattr(model, 'module'):
47 | model = model.module
48 | state_dict = torch.load(ckpt_file, map_location="cpu")
49 | if rewrite is not None:
50 | logger.info("rewrite model checkpoint prefix: %s->%s", *rewrite)
51 | state_dict["model"] = {k.replace(*rewrite) if k.startswith(rewrite[0]) else k: v
52 | for k, v in state_dict["model"].items()}
53 | try:
54 | missing = model.load_state_dict(state_dict["model"], strict=False)
55 | logger.debug(f"checkpoint key missing: {missing}")
56 | except RuntimeError:
57 |         print("failed to recover from checkpoint directly, trying to match layers one by one...")
58 | net_dict = model.state_dict()
59 |         print("find %d layers" % len(state_dict["model"]))
60 | missing_keys = [k for k, v in state_dict["model"].items() if k not in net_dict or net_dict[k].shape != v.shape]
61 |         print("missing key: %s" % missing_keys)
62 | state_dict["model"] = {k: v for k, v in state_dict["model"].items() if
63 | (k in net_dict and net_dict[k].shape == v.shape)}
64 |         print("resume %d layers from checkpoint" % len(state_dict["model"]))
65 | net_dict.update(state_dict["model"])
66 | model.load_state_dict(OrderedDict(net_dict))
67 |
68 | if not restart_train:
69 | if optimizer is not None and state_dict["optimizer"]:
70 | optimizer.load_state_dict(state_dict["optimizer"])
71 | if scheduler is not None and state_dict["scheduler"]:
72 | scheduler.load_state_dict(state_dict["scheduler"])
73 | epoch = state_dict["epoch"]
74 | else:
75 | logger.info("restart train, optimizer and scheduler will not be resumed")
76 | epoch = 0
77 |
78 | del state_dict
79 | torch.cuda.empty_cache()
80 | return epoch # start epoch
81 |
82 |
83 | def save_model(model_file: str, model: torch.nn.Module):
84 | if hasattr(model, "module"):
85 | model = model.module
86 | torch.save(model.state_dict(), model_file)
87 |
88 |
89 | def load_model(model_file: str, model: torch.nn.Module, strict=True):
90 | if hasattr(model, "module"):
91 | model = model.module
92 | state_dict = torch.load(model_file, map_location="cpu")
93 |
94 | missing_keys: List[str] = []
95 | unexpected_keys: List[str] = []
96 | error_msgs: List[str] = []
97 |
98 | # copy state_dict so _load_from_state_dict can modify it
99 | metadata = getattr(state_dict, '_metadata', None)
100 | state_dict = state_dict.copy()
101 | if metadata is not None:
102 | # mypy isn't aware that "_metadata" exists in state_dict
103 | state_dict._metadata = metadata # type: ignore[attr-defined]
104 |
105 | def load(module, prefix=''):
106 | local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
107 | module._load_from_state_dict(
108 | state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
109 | for name, child in module._modules.items():
110 | if child is not None:
111 | load(child, prefix + name + '.')
112 |
113 | load(model)
114 | del load
115 |
116 | if len(missing_keys) > 0:
117 | logger.info("Weights of {} not initialized from pretrained model: {}"
118 | .format(model.__class__.__name__, "\n " + "\n ".join(missing_keys)))
119 | if len(unexpected_keys) > 0:
120 | logger.info("Weights from pretrained model not used in {}: {}"
121 | .format(model.__class__.__name__, "\n " + "\n ".join(unexpected_keys)))
122 | if len(error_msgs) > 0:
123 | logger.info("Weights from pretrained model cause errors in {}: {}"
124 | .format(model.__class__.__name__, "\n " + "\n ".join(error_msgs)))
125 |
126 | if len(missing_keys) == 0 and len(unexpected_keys) == 0 and len(error_msgs) == 0:
127 | logger.info("All keys loaded successfully for {}".format(model.__class__.__name__))
128 |
129 | if strict and len(error_msgs) > 0:
130 | raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
131 | model.__class__.__name__, "\n\t".join(error_msgs)))
132 |
--------------------------------------------------------------------------------
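A save/resume round-trip sketch for the helpers in checkpoint.py above, assuming the src directory is importable; the model, optimizer, and folder name are stand-ins.

    import torch
    from utils.checkpoint import save_checkpoint, auto_resume, load_checkpoint

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    ckpt_path = save_checkpoint("ckpt_demo", epoch=3, model=model, optimizer=optimizer,
                                scheduler=None, config={"note": "demo"})
    latest = auto_resume("ckpt_demo")                          # newest .pth file, or None
    start_epoch = load_checkpoint(latest, model, optimizer, scheduler=None)
    print(ckpt_path, start_epoch)                              # .../checkpoint_3.pth, 3
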
/src/utils/json.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/2/18 06:03
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : json.py
6 |
7 | import json
8 |
9 |
10 | def load_json(file_path):
11 | with open(file_path, "r") as f:
12 | return json.load(f)
13 |
14 |
15 | def save_json(data, filename, save_pretty=False, sort_keys=False):
16 | class MyEncoder(json.JSONEncoder):
17 |
18 | def default(self, obj):
19 | if isinstance(obj, bytes): # bytes->str
20 | return str(obj, encoding='utf-8')
21 | return json.JSONEncoder.default(self, obj)
22 |
23 | with open(filename, "w") as f:
24 | if save_pretty:
25 | f.write(json.dumps(data, cls=MyEncoder, indent=4, sort_keys=sort_keys))
26 | else:
27 | json.dump(data, f)
28 |
--------------------------------------------------------------------------------
/src/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/12 22:32
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : logging.py
6 |
7 |
8 | import os
9 | import logging
10 | import colorlog
11 | import torch.distributed as dist
12 |
13 | level_dict = {
14 | "critical": logging.CRITICAL,
15 | "error": logging.ERROR,
16 | "warning": logging.WARNING,
17 | "info": logging.INFO,
18 | "debug": logging.DEBUG,
19 | "notset": logging.NOTSET
20 | }
21 |
22 |
23 | # noinspection SpellCheckingInspection
24 | def setup_logging(cfg):
25 | # log file
26 | if len(str(cfg.LOG.LOGGER_FILE).split(".")) == 2:
27 | file_name, extension = str(cfg.LOG.LOGGER_FILE).split(".")
28 | log_file_debug = os.path.join(cfg.LOG.DIR, f"{file_name}_debug.{extension}")
29 | log_file_info = os.path.join(cfg.LOG.DIR, f"{file_name}_info.{extension}")
30 | elif len(str(cfg.LOG.LOGGER_FILE).split(".")) == 1:
31 | file_name = cfg.LOG.LOGGER_FILE
32 | log_file_debug = os.path.join(cfg.LOG.DIR, f"{file_name}_debug")
33 | log_file_info = os.path.join(cfg.LOG.DIR, f"{file_name}_info")
34 | else:
35 |         raise ValueError("cfg.LOG.LOGGER_FILE is invalid: {}".format(cfg.LOG.LOGGER_FILE))
36 | logger = logging.getLogger(__name__.split(".")[0])
37 | logger.setLevel(logging.DEBUG)
38 | logger.handlers.clear()
39 | formatter = logging.Formatter(
40 | f"[%(asctime)s][%(levelname)s]{f'[Rank {dist.get_rank()}]' if dist.is_initialized() else ''} "
41 | "%(filename)s: %(lineno)3d: %(message)s",
42 | datefmt="%m/%d %H:%M:%S",
43 | )
44 | color_formatter = colorlog.ColoredFormatter(
45 | f"%(log_color)s%(bold)s%(levelname)-8s%(reset)s"
46 | f"%(log_color)s[%(asctime)s]"
47 | f"{f'[Rank {dist.get_rank()}]' if dist.is_initialized() else ''}"
48 | "[%(filename)s: %(lineno)3d]:%(reset)s "
49 | "%(message)s",
50 | datefmt="%m/%d %H:%M:%S",
51 | )
52 | # log file
53 | if os.path.dirname(log_file_debug): # dir name is not empty
54 | os.makedirs(os.path.dirname(log_file_debug), exist_ok=True)
55 | # console
56 | handler_console = logging.StreamHandler()
57 | assert cfg.LOG.LOGGER_CONSOLE_LEVEL.lower() in level_dict, \
58 | f"Log level {cfg.LOG.LOGGER_CONSOLE_LEVEL} is invalid"
59 | handler_console.setLevel(level_dict[cfg.LOG.LOGGER_CONSOLE_LEVEL.lower()])
60 | handler_console.setFormatter(color_formatter if cfg.LOG.LOGGER_CONSOLE_COLORFUL else formatter)
61 | logger.addHandler(handler_console)
62 | # debug level
63 | handler_debug = logging.FileHandler(log_file_debug, mode="a")
64 | handler_debug.setLevel(logging.DEBUG)
65 | handler_debug.setFormatter(formatter)
66 | logger.addHandler(handler_debug)
67 | # info level
68 | handler_info = logging.FileHandler(log_file_info, mode="a")
69 | handler_info.setLevel(logging.INFO)
70 | handler_info.setFormatter(formatter)
71 | logger.addHandler(handler_info)
72 |
73 | logger.propagate = False
74 |
75 |
76 | def show_registry():
77 | from mm_video.data.build import DATASET_REGISTRY, COLLATE_FN_REGISTER
78 | from mm_video.modeling.model import MODEL_REGISTRY
79 | from mm_video.modeling.optimizer import OPTIMIZER_REGISTRY
80 | from mm_video.modeling.loss import LOSS_REGISTRY
81 | from mm_video.modeling.meter import METER_REGISTRY
82 |
83 | logger = logging.getLogger(__name__)
84 | logger.debug(DATASET_REGISTRY)
85 | logger.debug(COLLATE_FN_REGISTER)
86 | logger.debug(MODEL_REGISTRY)
87 | logger.debug(OPTIMIZER_REGISTRY)
88 | logger.debug(LOSS_REGISTRY)
89 | logger.debug(METER_REGISTRY)
90 |
--------------------------------------------------------------------------------
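A minimal sketch of driving `setup_logging` above with an fvcore CfgNode (the same config type used in train_utils.py); the field names mirror those read inside the function and the directory name is illustrative.

    import logging
    from fvcore.common.config import CfgNode
    from utils.logging import setup_logging

    cfg = CfgNode()
    cfg.LOG = CfgNode()
    cfg.LOG.DIR = "log_demo"
    cfg.LOG.LOGGER_FILE = "train.log"
    cfg.LOG.LOGGER_CONSOLE_LEVEL = "info"
    cfg.LOG.LOGGER_CONSOLE_COLORFUL = True

    setup_logging(cfg)
    # children of the package logger propagate to the configured handlers
    logging.getLogger("utils.demo").info("logging configured")
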
/src/utils/register.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2023/2/19 23:23
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : register.py
6 |
7 |
8 |
9 | class LooseRegister:
10 | def __init__(self, init_dict=None):
11 | self._dict = init_dict if init_dict is not None else {}
12 |
13 | def register(self, name, target):
14 | self._dict[name] = target
15 |
16 | def __getitem__(self, item: str):
17 | for k in self._dict.keys():
18 | if item.startswith(k):
19 | return self._dict[k]
20 | raise KeyError(f"Key {item} not found in {list(self._dict.keys())}")
21 |
22 | def __contains__(self, key):
23 | return key in self._dict
24 |
25 |
26 | if __name__ == '__main__':
27 |     LOOSE_REGISTER = LooseRegister()
28 | 
29 |     def func_a():
30 |         print("a")
31 | 
32 |     def show_b():
33 |         print("b")
34 | 
35 |     LOOSE_REGISTER.register("func_a", func_a)
36 |     LOOSE_REGISTER.register("show_b", show_b)
37 |     LOOSE_REGISTER["func_a_train"]()  # prefix match on "func_a" resolves to func_a
--------------------------------------------------------------------------------
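The registry above resolves lookups by key prefix (`__getitem__`), while membership tests (`__contains__`) stay exact, so a lookup can succeed for a name whose exact key was never registered. A tiny illustration with hypothetical names, assuming `LooseRegister` is importable from `utils.register`:

from utils.register import LooseRegister  # assumed import path

reg = LooseRegister({"anet": lambda: "activitynet handler"})
print(reg["anet_train"]())   # prefix match on "anet" -> "activitynet handler"
print("anet_train" in reg)   # False: membership is exact-key, not prefix-based
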
/src/utils/train_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/12 22:31
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : train_utils.py
6 |
7 | import datetime
8 | import os
9 | import typing
10 | import torch
11 | import time
12 | import random
13 | import itertools
14 | import numpy as np
15 | import logging
16 | from tabulate import tabulate
17 | from collections import defaultdict
18 | from typing import *
19 |
20 | import torch.distributed as dist
21 | from fvcore.common.config import CfgNode
22 |
23 | logger = logging.getLogger(__name__)
24 |
25 |
26 | class Timer:
27 | def __init__(self, synchronize=False, history_size=1000, precision=3):
28 | self._precision = precision
29 | self._stage_index = 0
30 | self._time_info = {}
31 | self._time_history = defaultdict(list)
32 | self._history_size = history_size
33 | if synchronize:
34 | assert torch.cuda.is_available(), "cuda is not available for synchronize"
35 | self._synchronize = synchronize
36 | self._time = self._get_time()
37 |
38 | def _get_time(self):
39 | return round(time.time() * 1000, self._precision)
40 |
41 | def __call__(self, stage_name=None, reset=True):
42 | if self._synchronize:
43 | torch.cuda.synchronize(torch.cuda.current_device())
44 |
45 | current_time = self._get_time()
46 | duration = (current_time - self._time)
47 | if reset:
48 | self._time = current_time
49 |
50 | if stage_name is None:
51 | self._time_info[self._stage_index] = duration
52 | else:
53 | self._time_info[stage_name] = duration
54 |             self._time_history[stage_name].append(duration)
55 |             self._time_history[stage_name] = self._time_history[stage_name][-self._history_size:]
56 |
57 | return duration
58 |
59 | def reset(self):
60 | if self._synchronize:
61 | torch.cuda.synchronize(torch.cuda.current_device())
62 | self._time = self._get_time()
63 |
64 | def __str__(self):
65 | return str(self.get_info())
66 |
67 | def get_info(self):
68 | info = {
69 | "current": {k: round(v, self._precision) for k, v in self._time_info.items()},
70 | "average": {k: round(sum(v) / len(v), self._precision) for k, v in self._time_history.items()}
71 | }
72 | return info
73 |
74 | def print(self):
75 | data = [[k, round(sum(v) / len(v), self._precision)] for k, v in self._time_history.items()]
76 | print(tabulate(data, headers=["Stage", "Time (ms)"], tablefmt="simple"))
77 |
78 |
79 | class CudaPreFetcher:
80 | def __init__(self, data_loader):
81 | self.dl = data_loader
82 | self.loader = iter(data_loader)
83 | self.stream = torch.cuda.Stream()
84 | self.batch = None
85 |
86 | def preload(self):
87 | try:
88 | self.batch = next(self.loader)
89 | except StopIteration:
90 | self.batch = None
91 | return
92 | with torch.cuda.stream(self.stream):
93 | self.batch = self.cuda(self.batch)
94 |
95 | @staticmethod
96 | def cuda(x: typing.Any):
97 | if isinstance(x, list) or isinstance(x, tuple):
98 | return [CudaPreFetcher.cuda(i) for i in x]
99 | elif isinstance(x, dict):
100 | return {k: CudaPreFetcher.cuda(v) for k, v in x.items()}
101 | elif isinstance(x, torch.Tensor):
102 | return x.cuda(non_blocking=True)
103 | else:
104 | return x
105 |
106 | def __next__(self):
107 | torch.cuda.current_stream().wait_stream(self.stream)
108 | batch = self.batch
109 | if batch is None:
110 | raise StopIteration
111 | self.preload()
112 | return batch
113 |
114 | def __iter__(self):
115 | self.preload()
116 | return self
117 |
118 | def __len__(self):
119 | return len(self.dl)
120 |
121 |
122 | def manual_seed(cfg: CfgNode):
123 | if cfg.SYS.DETERMINISTIC:
124 | torch.manual_seed(cfg.SYS.SEED)
125 | random.seed(cfg.SYS.SEED)
126 | np.random.seed(cfg.SYS.SEED)
127 | torch.cuda.manual_seed(cfg.SYS.SEED)
128 | torch.backends.cudnn.deterministic = True
129 |         torch.backends.cudnn.benchmark = False
130 | logger.debug("Manual seed is set")
131 | else:
132 | logger.warning("Manual seed is not used")
133 |
134 |
135 | def init_distributed(proc: int, cfg: CfgNode):
136 | if cfg.SYS.MULTIPROCESS: # initialize multiprocess
137 |         world_size = cfg.SYS.NUM_GPU * cfg.SYS.NUM_SHARDS
138 |         rank = cfg.SYS.NUM_GPU * cfg.SYS.SHARD_ID + proc
139 |         dist.init_process_group(backend="nccl", init_method=cfg.SYS.INIT_METHOD, world_size=world_size, rank=rank)
140 | torch.cuda.set_device(cfg.SYS.GPU_DEVICES[proc])
141 |
142 |
143 | def save_config(cfg: CfgNode):
144 | if not dist.is_initialized() or dist.get_rank() == 0:
145 | config_file = os.path.join(cfg.LOG.DIR, f"config_{get_timestamp()}.yaml")
146 | with open(config_file, "w") as f:
147 | f.write(cfg.dump())
148 | logger.debug("config is saved to %s", config_file)
149 |
150 |
151 | def get_timestamp():
152 | return datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S')
153 |
154 |
155 | def gather_object_multiple_gpu(list_object: List[Any]):
156 |     """
157 |     Gather a list of objects from all ranks and concatenate them in rank order.
158 |     :param list_object: list of picklable objects collected on the current rank
159 |     """
160 | gathered_objects = [None for _ in range(dist.get_world_size())]
161 | dist.all_gather_object(gathered_objects, list_object)
162 | return list(itertools.chain(*gathered_objects))
163 |
--------------------------------------------------------------------------------
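A sketch of how the training utilities above are typically combined: `CudaPreFetcher` copies the next batch to the GPU on a side stream while the current one is being processed, and `Timer` reports per-stage averages. The toy dataset and model are placeholders, and a CUDA device is assumed to be available (the prefetcher creates a CUDA stream unconditionally):

import torch
from torch.utils.data import DataLoader, TensorDataset
from utils.train_utils import CudaPreFetcher, Timer  # assumed import path

loader = DataLoader(TensorDataset(torch.randn(64, 8)), batch_size=16)
model = torch.nn.Linear(8, 2).cuda()

timer = Timer(synchronize=True)  # synchronize=True gives accurate GPU timings
for (batch,) in CudaPreFetcher(loader):  # batches arrive already on the GPU
    timer.reset()
    loss = model(batch).sum()
    timer("forward")
    loss.backward()
    timer("backward")
timer.print()
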
/src/utils/writer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2022/11/13 00:41
3 | # @Author : Yaojie Shen
4 | # @Project : MM-Video
5 | # @File : writer.py
6 |
7 | from torch.utils.tensorboard import SummaryWriter
8 | import torch.distributed as dist
9 |
10 |
11 | class DummySummaryWriter:
12 | """
13 | Issue: https://github.com/pytorch/pytorch/issues/24236
14 | """
15 |     def __init__(self, *args, **kwargs):
16 | pass
17 |
18 | def __call__(self, *args, **kwargs):
19 | return self
20 |
21 | def __getattr__(self, *args, **kwargs):
22 | return self
23 |
24 |
25 | def get_writer(*args, **kwargs):
26 | if not dist.is_initialized() or dist.get_rank() == 0:
27 | return SummaryWriter(*args, **kwargs)
28 | else:
29 | return DummySummaryWriter()
30 |
--------------------------------------------------------------------------------
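A minimal sketch of the rank-aware writer above: rank 0 (or a run without a process group) gets a real `SummaryWriter`, every other rank gets the no-op dummy, so training code can log unconditionally. The log directory is arbitrary:

from utils.writer import get_writer  # assumed import path

writer = get_writer(log_dir="log/tensorboard")
writer.add_scalar("train/loss", 0.5, global_step=1)  # silently ignored on non-zero ranks
writer.flush()
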
/src/utils_func.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | 
4 | import json
5 | 
6 | 
7 | 
8 |
9 | class MyEncoder(json.JSONEncoder):
10 |
11 | def default(self, obj):
12 |         """
13 |         Convert any bytes object to str so it can be JSON-serialized.
14 |         :param obj: object that the default encoder cannot handle
15 |         :return: utf-8 decoded string if obj is bytes, otherwise defer to the base class
16 |         """
17 | if isinstance(obj, bytes):
18 | return str(obj, encoding='utf-8')
19 | return json.JSONEncoder.default(self, obj)
20 |
21 |
22 | def save_json(data, filename, save_pretty=False, sort_keys=False):
23 | with open(filename, "w") as f:
24 | if save_pretty:
25 | f.write(json.dumps(data, cls=MyEncoder, indent=4, sort_keys=sort_keys))
26 | else:
27 |             json.dump(data, f, cls=MyEncoder)
28 |
29 |
30 | def save_parsed_args_to_json(parsed_args, file_path, pretty=True):
31 | args_dict = vars(parsed_args)
32 | save_json(args_dict, file_path, save_pretty=pretty)
33 |
34 |
35 | def load_json(file_path):
36 | with open(file_path, "r") as f:
37 | return json.load(f)
38 |
39 |
40 | def set_lr(optimizer, decay_factor):
41 | for group in optimizer.param_groups:
42 | group["lr"] = group["lr"] * decay_factor
43 |
44 |
45 | def flat_list_of_lists(l):
46 | """flatten a list of lists [[1,2], [3,4]] to [1,2,3,4]"""
47 | return [item for sublist in l for item in sublist]
48 |
49 |
50 | def count_parameters(model, verbose=True):
51 |     """Count the total and trainable parameters of a PyTorch model.
52 |     References: https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/7.
53 | 
54 |     Usage:
55 |         n_all, n_trainable = count_parameters(model)
56 | 
57 |     :return: (n_all, n_trainable)
58 |     """
59 | n_all = sum(p.numel() for p in model.parameters())
60 | n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
61 |     if verbose:
62 |         print("Parameter Count: all {:,d}; trainable {:,d}".format(n_all, n_trainable))
63 | return n_all, n_trainable
64 |
65 |
66 | def sum_parameters(model, verbose=True):
67 |     """Sum the values of all parameters of a PyTorch model.
68 |     Useful as a quick checksum, e.g. to verify that weights match across runs or ranks.
69 | 
70 |     Usage:
71 |         p_sum = sum_parameters(model)
72 | 
73 |     :return: sum over all parameter elements
74 |     """
75 | p_sum = sum(p.sum().item() for p in model.parameters())
76 | if verbose:
77 | print("Parameter sum {}".format(p_sum))
78 | return p_sum
79 |
80 |
81 | def merge_dicts(list_dicts):
82 | merged_dict = list_dicts[0].copy()
83 | for i in range(1, len(list_dicts)):
84 | merged_dict.update(list_dicts[i])
85 | return merged_dict
86 |
87 |
88 | def merge_json_files(paths, merged_path):
89 | merged_dict = merge_dicts([load_json(e) for e in paths])
90 | save_json(merged_dict, merged_path)
91 |
92 |
93 |
--------------------------------------------------------------------------------
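A short sketch of the JSON helpers above: `MyEncoder` decodes bytes values while saving, and `merge_json_files` combines several files into one (later files win on duplicate keys). The file names are illustrative only:

from utils_func import save_json, load_json, merge_json_files  # assumed import path

save_json({"caption": b"a man is cooking"}, "part1.json", save_pretty=True)
save_json({"video_id": "v_001"}, "part2.json", save_pretty=True)

merge_json_files(["part1.json", "part2.json"], "merged.json")
print(load_json("merged.json"))  # {'caption': 'a man is cooking', 'video_id': 'v_001'}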