├── .gitignore
├── MANIFEST.in
├── README.md
├── RELEASE_GUIDE.md
├── data
│   ├── factual_mr
│   │   ├── factual_mr.csv
│   │   └── meta.json
│   ├── factual_sg
│   │   ├── factual_sg.csv
│   │   ├── length
│   │   │   ├── dev.csv
│   │   │   ├── test.csv
│   │   │   └── train.csv
│   │   └── random
│   │       ├── dev.csv
│   │       ├── test.csv
│   │       └── train.csv
│   └── factual_sg_id
│       ├── factual_sg_id.csv
│       └── random
│           ├── dev.csv
│           ├── test.csv
│           └── train.csv
├── interface
│   └── sg_collect.html
├── logo
│   ├── adobe_logo.png
│   ├── monash_logo.png
│   └── wuhan_logo.png
├── release.sh
├── requirements.txt
├── setup.py
├── src
│   └── factual_scene_graph
│       ├── __init__.py
│       ├── evaluation
│       │   ├── __init__.py
│       │   ├── evaluator.py
│       │   ├── resources
│       │   │   ├── __init__.py
│       │   │   ├── english.exceptions
│       │   │   └── english.synsets
│       │   ├── set_match_evaluation.py
│       │   ├── soft_spice_evaluation.py
│       │   ├── spice_evaluation.py
│       │   └── synonym_dictionary.py
│       ├── parser
│       │   ├── __init__.py
│       │   └── scene_graph_parser.py
│       └── utils.py
└── tests
    ├── test_data
    │   ├── SPICE_parsing_outputs.txt
    │   ├── crowdflower_flickr8k.json
    │   └── flickr8k.json
    ├── test_eval_spice.py
    ├── test_evaluator.py
    ├── test_metric_human_correlation.py
    ├── test_parser.py
    ├── test_spice_parser.py
    ├── test_synonym_dictionary.py
    └── train_parser.py

/.gitignore:
--------------------------------------------------------------------------------
 1 | # Byte-compiled / optimized / DLL files
 2 | __pycache__/
 3 | *.py[cod]
 4 | *$py.class
 5 |
 6 | # pycharm
 7 | .idea/
 8 | .idea/*
 9 | # C extensions
10 | *.so
11 |
12 | # Distribution / packaging
13 | .Python
14 | build/
15 | develop-eggs/
16 | dist/
17 | downloads/
18 | eggs/
19 | .eggs/
20 | lib/
21 | lib64/
22 | parts/
23 | sdist/
24 | var/
25 | wheels/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 |
31 | # PyInstaller
32 | # Usually these files are written by a python script from a template
33 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
34 | *.manifest 35 | *.spec 36 | 37 | # Installer logs 38 | pip-log.txt 39 | pip-delete-this-directory.txt 40 | 41 | # Unit test / coverage reports 42 | htmlcov/ 43 | .tox/ 44 | .nox/ 45 | .coverage 46 | .coverage.* 47 | .cache 48 | nosetests.xml 49 | coverage.xml 50 | *.cover 51 | *.py,cover 52 | .hypothesis/ 53 | .pytest_cache/ 54 | cover/ 55 | 56 | 57 | *.xml 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | db.sqlite3-journal 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | 121 | # pytype static type analyzer 122 | .pytype/ 123 | 124 | # Cython debug symbols 125 | cython_debug/ 126 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include src/factual_scene_graph/evaluation/resources/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # [ACL 2023 Findings] FACTUAL: A Benchmark for Faithful and Consistent Textual Scene Graph Parsing 2 | 3 | Welcome to the official repository for the ACL 2023 Findings paper: 4 | [**FACTUAL: A Benchmark for Faithful and Consistent Textual Scene Graph Parsing**](https://arxiv.org/pdf/2305.17497.pdf). Here, you'll find both the code and dataset associated with our research. 5 | 6 |
 7 | <p align="center">
 8 |   <img src="logo/monash_logo.png" alt="Monash University Logo">
 9 |   <img src="logo/adobe_logo.png" alt="Adobe Logo">
10 |   <img src="logo/wuhan_logo.png" alt="Wuhan University Logo">
11 | </p>
12 |
18 |
19 | ## 🆕 New Feature: Multi-Sentence Scene Graph Parsing
20 |
21 | > **✨ Now supports parsing complex, multi-sentence descriptions!**
22 | > Use `parser_type='sentence_merge'` to automatically split text into sentences, parse each individually, and merge the results into a unified scene graph.
23 |
24 | **Key Benefits:**
25 | - 🔄 **Automatic sentence segmentation** using NLTK
26 | - ⚡ **Efficient batch processing** for optimal performance
27 | - 🧹 **Smart deduplication** of entities and relations
28 | - 🔗 **Maintains relationships** across sentence boundaries
29 | - 📝 **Perfect for complex descriptions** like image captions, stories, or detailed scene descriptions
30 |
31 | **Quick Example:**
32 | ```python
33 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser
34 |
35 | # Multi-sentence parser
36 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', parser_type='sentence_merge')
37 |
38 | # Parse a complex, multi-sentence description
39 | result = parser.parse(["""The cat sits on a mat. The mat is red and soft.
40 | A dog runs nearby."""])
41 | ```
42 |
43 | [👀 See full documentation below](#multi-sentence-parsing-with-sentence-merge)
44 |
45 | ---
46 |
47 | ## Installation
48 |
49 | ```sh
50 | pip install FactualSceneGraph
51 | ```
52 |
53 | ## Dataset
54 |
55 | The FACTUAL Scene Graph dataset includes 40,369 instances with lemmatized predicates/relations. A minimal loading example is given at the end of this section.
56 |
57 | ### FACTUAL Scene Graph dataset:
58 |
59 | - **Storage**: `data/factual_sg/factual_sg.csv`
60 | - **From Huggingface**: `load_dataset('lizhuang144/FACTUAL_Scene_Graph')`
61 |
62 | **Splits**:
63 | - Random Split:
64 |   - Train: `data/factual_sg/random/train.csv`
65 |   - Test: `data/factual_sg/random/test.csv`
66 |   - Dev: `data/factual_sg/random/dev.csv`
67 | - Length Split:
68 |   - Train: `data/factual_sg/length/train.csv`
69 |   - Test: `data/factual_sg/length/test.csv`
70 |   - Dev: `data/factual_sg/length/dev.csv`
71 |
72 | **Data Fields**:
73 |
74 | - `image_id`: The ID of the image in Visual Genome.
75 | - `region_id`: The ID of the region in Visual Genome.
76 | - `caption`: The caption of the image region.
77 | - `scene_graph`: The scene graph of the image region and caption.
78 |
79 | **Related Resources**: Details of the images and regions can be looked up in [Visual Genome](https://huggingface.co/datasets/visual_genome) via their corresponding IDs.
80 |
81 | ### FACTUAL-MR dataset:
82 |
83 | - `data/factual_mr/factual_mr.csv`
84 | - `data/factual_mr/meta.json`: the metadata for mapping the abbreviations of quantifiers in `factual_mr.csv` to their complete names.
85 |
86 | ### VG Scene Graph dataset:
87 |
88 | - **From Huggingface**: `load_dataset('lizhuang144/VG_scene_graph_clean')`
89 | - **Details**: Cleaned to exclude empty instances; includes 2.9 million instances.
90 |
91 | ### FACTUAL Scene Graph dataset with identifiers:
92 |
93 | - **From Huggingface**: `load_dataset('lizhuang144/FACTUAL_Scene_Graph_ID')`
94 | - **Enhancements**: Contains verb identifiers, passive voice indicators, and node indexes.
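As a quick check, the datasets above can be loaded either from Huggingface or from the CSV splits shipped in this repository. A minimal sketch (assuming the `datasets` library is installed; the column names follow the Data Fields list above):

```python
from datasets import load_dataset
import pandas as pd

# From Huggingface
factual = load_dataset('lizhuang144/FACTUAL_Scene_Graph')

# Or directly from the CSV splits in this repository
train = pd.read_csv('data/factual_sg/random/train.csv')
print(train[['caption', 'scene_graph']].head())
```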
 95 |
 96 | ## Scene Graph Parsing Models
 97 |
 98 | ### Simplified Model Training Without Node Indexes and Passive Identifiers
 99 |
100 | The following table compares the performance of several scene graph parsing models. Notably, the original SPICE parser performs markedly worse than our more recent models.
101 |
102 | #### Performance Metrics Explained:
103 | - **SPICE F-score**: Measures the similarity between candidate and reference scene graph representations derived from captions, i.e., how well the parser's output matches the ground-truth graph in terms of propositional content.
104 | - **Exact Set Match**: Adapted from the methodology of [Yu et al., 2019](https://aclanthology.org/P19-1443/), this metric counts a parse as correct only if the strings of its parsed facts exactly match the ground-truth facts, without considering the ordering of those facts. It is a stringent, all-or-nothing accuracy measure (a short code sketch is given at the end of this section).
105 |
106 | > **Note**: In the original work of Yu et al., 2019, the metric was applied to SQL clauses; in our context, it has been tailored to assess scene graph facts.
107 |
108 |
109 | | Model | Set Match | SPICE | Soft-SPICE | Model Weight |
110 | |-------|-----------|-------|------------|--------------|
111 | | SPICE/Stanford Parser | 19.30 | 64.77 | 92.60 | [modified-SPICE-score](https://github.com/yychai74/modified-SPICE-score) |
112 | | (pre) Flan-T5-large | 81.63 | 93.20 | 98.75 | [flan-t5-large-VG-factual-sg](https://huggingface.co/lizhuang144/flan-t5-large-VG-factual-sg) |
113 | | (pre) Flan-T5-base | 81.37 | 93.27 | 98.83 | [flan-t5-base-VG-factual-sg](https://huggingface.co/lizhuang144/flan-t5-base-VG-factual-sg) |
114 | | (pre) Flan-T5-small | 78.18 | 92.26 | 98.67 | [flan-t5-small-VG-factual-sg](https://huggingface.co/lizhuang144/flan-t5-small-VG-factual-sg) |
115 |
116 | The prefix "(pre)" indicates models that were pre-trained on the VG scene graph dataset before being fine-tuned on the FACTUAL dataset. The outdated SPICE parser, despite its historical significance, achieves a Set Match rate of only 19.30% and a SPICE score of 64.77, significantly lower than the more recent Flan-T5 models fine-tuned on FACTUAL data.
117 |
118 | > **Note**:
119 | > 1. **Model Training Adjustments**: In training these models, the node index has been removed, so different nodes with identical names are not distinguished by their indexes. Additionally, passive identifiers such as 'p:' are excluded, and verbs and prepositions have been merged. While this format loses some information from the FACTUAL-MR dataset, it remains compatible with Visual Genome scene graphs and is directly usable for downstream scene graph tasks.
120 | > 2. **SPICE Parser Performance**: The SPICE parser's performance in the table above differs significantly from the results reported in our paper. The SPICE parser is based on dependency parsing, so for a fair comparison the paper aligned its parsing outputs with the ground truth used in dependency-parsing-based scene graph parsing research (see [Scene Graph Parsing as Dependency Parsing](https://arxiv.org/abs/1803.09189)), making our comparison consistent with their findings. In the table above, we instead re-evaluate the SPICE parser outputs against the ground truth from our dataset, which yields the new results shown. To replicate the SPICE results, see `tests/test_spice_parser.py`.
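To make the two metrics above concrete, here is a minimal sketch using this package's `Evaluator` (introduced in detail later in this README). No parser is needed when both sides are already scene-graph strings:

```python
from factual_scene_graph.evaluation.evaluator import Evaluator

# A parser is only required when the inputs are raw captions, not graphs.
evaluator = Evaluator(parser=None)

cand = ['( pigs , is , 2 ) , ( pigs , fly on , sky )']
same = [['( pigs , fly on , sky ) , ( pigs , is , 2 )']]       # same facts, reordered
partial = [['( pigs , is , 2 ) , ( pigs , run on , grass )']]  # one fact differs

print(evaluator.evaluate(cand, same, method='set_match'))     # [100] - fact order is ignored
print(evaluator.evaluate(cand, partial, method='set_match'))  # [0]   - any mismatch fails the whole parse
print(evaluator.evaluate(cand, partial, method='spice'))      # partial credit: F1 over matched facts
```

Set Match is all-or-nothing per example, whereas SPICE awards partial credit for the facts that overlap.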
121 |
122 | ### Enhanced Scene Graph Parsing with Node Indexes and Verb Identifiers
123 |
124 | Enhanced scene graph parsing includes detailed annotations such as verb identifiers and node indexes, which offer a more nuanced understanding of the relationships within the input text. For example:
125 |
126 | - The sentence "A monkey is sitting next to another monkey" is parsed as:
127 |   `( monkey, v:sit next to, monkey:1 )`
128 |   Here, "v:" indicates a verb, and ":1" differentiates the second "monkey" as a unique entity.
129 |
130 | - For "A car is parked on the ground", the scene graph is:
131 |   `( car, pv:park on, ground )`
132 |   The "pv:" prefix marks "park" as a passive verb, underscoring the significance of node order in the graph.
133 |
134 | This advanced parsing technique offers substantial enhancements over the original Visual Genome (VG) scene graphs by:
135 |
136 | - **Uniquely Identifying Similar Entities**: Assigning indexes to nodes with the same name allows for clear differentiation between identical entities.
137 | - **Detailing Predicates**: Annotating each predicate with the specific verb and its voice provides richer contextual information.
138 |
139 | Such improvements are invaluable for complex downstream tasks, as they facilitate a deeper semantic understanding of the scenes.
140 |
141 | #### Model Performance with Advanced Parsing:
142 |
143 | | Model | Set Match | SPICE | Soft-SPICE | Model Weight |
144 | |-------|-----------|-------|------------|--------------|
145 | | (pre) Flan-T5-large | 81.03 | 93.00 | 98.66 | [flan-t5-large-VG-factual-sg-id](https://huggingface.co/lizhuang144/flan-t5-large-VG-factual-sg-id) |
146 | | (pre) Flan-T5-base | 81.37 | 93.29 | 98.76 | [flan-t5-base-VG-factual-sg-id](https://huggingface.co/lizhuang144/flan-t5-base-VG-factual-sg-id) |
147 | | (pre) Flan-T5-small | 79.64 | 92.40 | 98.53 | [flan-t5-small-VG-factual-sg-id](https://huggingface.co/lizhuang144/flan-t5-small-VG-factual-sg-id) |
148 |
149 | As above, the prefix "(pre)" denotes models pre-trained on VG and then fine-tuned on FACTUAL, a two-phase learning process that enhances model performance.
150 |
151 | ## Usage Examples
152 |
153 | This section demonstrates how to use our models for scene graph parsing. We provide examples for basic usage, advanced usage with the `SceneGraphParser` class, and the **new sentence merge functionality** for multi-sentence text processing.
154 |
155 | ### Basic Usage
156 |
157 | ```python
158 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
159 |
160 | tokenizer = AutoTokenizer.from_pretrained("lizhuang144/flan-t5-base-VG-factual-sg")
161 | model = AutoModelForSeq2SeqLM.from_pretrained("lizhuang144/flan-t5-base-VG-factual-sg")
162 |
163 | text = tokenizer(
164 |     "Generate Scene Graph: 2 pigs are flying on the sky with 2 bags on their backs",
165 |     max_length=200,
166 |     return_tensors="pt",
167 |     truncation=True
168 | )
169 |
170 | generated_ids = model.generate(
171 |     text["input_ids"],
172 |     attention_mask=text["attention_mask"],
173 |     use_cache=True,
174 |     decoder_start_token_id=tokenizer.pad_token_id,
175 |     num_beams=1,
176 |     max_length=200,
177 |     early_stopping=True
178 | )
179 |
180 | print(tokenizer.decode(generated_ids[0], skip_special_tokens=True, clean_up_tokenization_spaces=True))
181 | # Output: `( pigs , is , 2 ) , ( bags , on back of , pigs ) , ( bags , is , 2 ) , ( pigs , fly on , sky )`
182 | ```
183 | Note: in the output above, the predicate `is` stands for `has_attribute`.
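The `-id` checkpoints listed in the table above plug into the same `SceneGraphParser` interface used throughout this README. A minimal sketch; the expected output follows the indexed notation introduced earlier, though the exact spacing may vary:

```python
from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser

# The "-id" checkpoints emit verb identifiers (v:/pv:) and node indexes (:1, :2, ...).
parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg-id', device='cpu')
print(parser.parse(["A monkey is sitting next to another monkey"], return_text=True)[0])
# Expected output of the form: ( monkey , v:sit next to , monkey:1 )
```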
184 | 185 | ### Advanced Usage with SceneGraphParser 186 | 187 | For more advanced parsing, utilize the `SceneGraphParser` class: 188 | 189 | ```python 190 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 191 | 192 | # Default parser for single sentences or simple text 193 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device='cpu') 194 | text_graph = parser.parse(["2 beautiful pigs are flying on the sky with 2 bags on their backs"], 195 | beam_size=1, return_text=True) 196 | graph_obj = parser.parse(["2 beautiful and strong pigs are flying on the sky with 2 bags on their backs"], 197 | beam_size=5, return_text=False, max_output_len=128) 198 | 199 | print(text_graph[0]) 200 | # Output: ( pigs , is , 2 ) , ( pigs , is , beautiful ) , ( bags , on back of , pigs ) , ( pigs , fly on , sky ) , ( bags , is , 2 ) 201 | 202 | from factual_scene_graph.utils import tprint 203 | tprint(graph_obj[0]) 204 | ``` 205 | This will produce a formatted scene graph output: 206 | ``` 207 | Entities: 208 | +----------+------------+------------------+ 209 | | Entity | Quantity | Attributes | 210 | |----------+------------+------------------| 211 | | pigs | 2 | beautiful,strong | 212 | | bags | 2 | | 213 | | sky | | | 214 | +----------+------------+------------------+ 215 | Relations: 216 | +-----------+------------+----------+ 217 | | Subject | Relation | Object | 218 | |-----------+------------+----------| 219 | | pigs | fly on | sky | 220 | | bags | on back of | pigs | 221 | +-----------+------------+----------+ 222 | ``` 223 | 224 | ### 🌟 Multi-Sentence Parsing with Sentence Merge 225 | 226 | **NEW:** For parsing multi-sentence descriptions, use the sentence merge functionality: 227 | 228 | ```python 229 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 230 | 231 | # Sentence merge parser for multi-sentence text 232 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device='cpu', parser_type='sentence_merge') 233 | 234 | # Multi-sentence description 235 | multi_sentence_text = """The image captures a serene scene in a park. A gravel path, dappled with sunlight 236 | filtering through the tall trees on either side, winds its way towards a white bridge. The bridge arches 237 | over a small body of water, possibly a stream or a pond. The sky above is a clear blue, with a few clouds 238 | scattered across it.""" 239 | 240 | text_graph = parser.parse([multi_sentence_text], beam_size=1, return_text=True) 241 | graph_obj = parser.parse([multi_sentence_text], beam_size=5, return_text=False, max_output_len=128) 242 | 243 | print("Parsed scene graph:") 244 | print(text_graph[0]) 245 | 246 | print("\nFormatted output:") 247 | from factual_scene_graph.utils import tprint 248 | tprint(graph_obj[0]) 249 | ``` 250 | 251 | **How Sentence Merge Works:** 252 | 1. **Sentence Tokenization**: Uses NLTK to split the input text into individual sentences 253 | 2. **Individual Parsing**: Each sentence is parsed separately using the base model 254 | 3. **Graph Merging**: The resulting scene graphs are merged and deduplicated 255 | 4. 
**Efficient Processing**: All sentences are processed in batches for optimal performance 256 | 257 | **Benefits:** 258 | - Handles complex, multi-sentence descriptions 259 | - Maintains relationships across sentence boundaries 260 | - Automatic deduplication of repeated entities and relations 261 | - Efficient batch processing for better GPU utilization 262 | 263 | ## A Comprehensive Toolkit for Scene Graph Parsing Evaluation 264 | 265 | This package provides implementations for evaluating scene graphs using SPICE, SoftSPICE, and Set Match metrics. These evaluations can be performed on various inputs, including captions and scene graphs in both list and nested list formats. 266 | 267 | ### Supported Input Formats 268 | 269 | - `(list of candidate_captions, list of list reference_captions)` 270 | - `(list of candidate_captions, list of list reference_graphs)` 271 | - `(list of candidate_graphs, list of list reference_graphs)` 272 | 273 | ### Usage 274 | 275 | Below are examples demonstrating how to use the evaluation methods provided in this package. 276 | 277 | #### Example 1: Testing Scene Graph Parsing 278 | 279 | This example demonstrates evaluating a single scene graph using the SPICE method. 280 | 281 | ```python 282 | import pandas as pd 283 | import torch 284 | from factual_scene_graph.evaluation.evaluator import Evaluator 285 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 286 | 287 | def test_scene_graph_parsing(): 288 | device = "cuda" if torch.cuda.is_available() else "cpu" 289 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device) 290 | evaluator = Evaluator(parser=parser, device=device) 291 | 292 | scores = evaluator.evaluate( 293 | ["2 beautiful pigs are flying on the sky with 2 bags on their backs"], 294 | [['( pigs , is , beautiful ) , ( bags , on back of , pigs ) , ( bags , is , 2 ) , ( pigs , is , 2 ) , ( pigs , fly on , sky )']], 295 | method='spice', 296 | beam_size=1, 297 | max_output_len=128 298 | ) 299 | print(scores) 300 | 301 | # Uncomment to run the example 302 | # test_scene_graph_parsing() 303 | ``` 304 | 305 | #### Example 2: Testing Scene Graph Parsing on the Test Set of FACTUAL Random Split 306 | 307 | This example demonstrates evaluating a dataset of scene graphs using SPICE, Set Match, and SoftSPICE methods. 
308 | 309 | ```python 310 | import pandas as pd 311 | import torch 312 | from factual_scene_graph.evaluation.evaluator import Evaluator 313 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 314 | 315 | def test_scene_graph_parsing_on_random(): 316 | device = "cuda" if torch.cuda.is_available() else "cpu" 317 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device, lemmatize=False) 318 | evaluator = Evaluator(parser=parser, text_encoder_checkpoint='all-MiniLM-L6-v2', device=device, lemmatize=True) 319 | 320 | random_data_pd = pd.read_csv('data/factual_sg/random/test.csv') 321 | random_data_captions = random_data_pd['caption'].tolist() 322 | random_data_graphs = [[scene] for scene in random_data_pd['scene_graph'].tolist()] 323 | 324 | # Evaluating using SPICE 325 | spice_scores, cand_graphs, ref_graphs = evaluator.evaluate( 326 | random_data_captions, 327 | random_data_graphs, 328 | method='spice', 329 | beam_size=1, 330 | batch_size=128, 331 | max_input_len=256, 332 | max_output_len=256, 333 | return_graphs=True 334 | ) 335 | print('SPICE scores for random test set:', sum(spice_scores)/len(spice_scores)) 336 | 337 | # Evaluating using Set Match 338 | set_match_scores = evaluator.evaluate(cand_graphs, ref_graphs, method='set_match', beam_size=1) 339 | print('Set Match scores for random test set:', sum(set_match_scores)/len(set_match_scores)) 340 | 341 | # Evaluating using Soft-SPICE 342 | soft_spice_scores = evaluator.evaluate(cand_graphs, ref_graphs, method='soft_spice', beam_size=1) 343 | print('Soft-SPICE scores for random test set:', sum(soft_spice_scores)/len(soft_spice_scores)) 344 | 345 | # Uncomment to run the example 346 | # test_scene_graph_parsing_on_random() 347 | ``` 348 | 349 | ### Human Correlation Performance on the Flickr8k Dataset 350 | 351 | In our study, we evaluated the correlation between various metrics and human judgment in image caption generation on the Flickr8k dataset using Kendall's tau. This comparison helps in understanding how well each metric aligns with human perception. 352 | 353 | #### Results 354 | 355 | Below is a table showing the Tau-c correlation values for different models: 356 | 357 | | Model | Tau-c | 358 | |------------------|-------| 359 | | SPICE(Official-Original) | 44.77 | 360 | | SPICE(Official-Factual) | 45.13 | 361 | | SPICE(Ours-Factual) | 45.25 | 362 | | Soft-SPICE | 54.20 | 363 | | RefCLIPScore | 53.00 | 364 | | BERTScore | 36.71 | 365 | 366 | #### SPICE Implementations 367 | 368 | This section provides an overview of the different SPICE implementations used in our project. 369 | 370 | - **1. SPICE(Official-Original)**: 371 | 372 | - Uses the original parser from the [Modified SPICE Score](https://github.com/yychai74/modified-SPICE-score) repository. 373 | - Follows the official SPICE implementation as provided in the repository. 374 | - Employs the original parser to process the input and generate the SPICE score. 375 | 376 | - **2. SPICE(Official-Factual)**: 377 | 378 | - Follows the official SPICE implementation from the [Modified SPICE Score](https://github.com/yychai74/modified-SPICE-score) repository. 379 | - Uses the `lizhuang144/flan-t5-base-VG-factual-sg` checkpoint as the parser instead of the original parser. 380 | 381 | - **3. SPICE(Ours-Factual)**: 382 | 383 | - Our own SPICE implementation, denoted by the "Ours" prefix. 384 | - Utilizes the `lizhuang144/flan-t5-base-VG-factual-sg` checkpoint as the parser. 
385 | - Updated with an improved synonym-matching dictionary, resulting in closer alignment with the official SPICE synonym-matching version. 386 | - The update, now the default setting in SPICE(Ours-Factual), shows a stronger correlation with human judgment than the official SPICE version. 387 | - Recommended for better performance in relevant applications. 388 | 389 | - **4. Soft-SPICE**: 390 | 391 | - A variant of the SPICE score that incorporates a soft matching mechanism. 392 | - Uses the `lizhuang144/flan-t5-base-VG-factual-sg` checkpoint as the parser. 393 | - The default text encoder is `all-MiniLM-L6-v2` from the `SentenceTransformer` library. 394 | - Aims to provide a more flexible and nuanced evaluation of the generated text by considering soft matches between the reference and the generated content. 395 | 396 | These SPICE implementations offer various options for evaluating the quality of the generated text, each with its own characteristics and parser choices. The "Official" implementations follow the original SPICE repository, while our implementation (SPICE(Ours-Factual)) introduces improvements and updates for enhanced performance. 397 | 398 | #### Replicating the Results 399 | 400 | To replicate the human correlation results for Our SPICE and Soft-SPICE, please refer to the script located at `tests/test_metric_human_correlation.py`. This script provides a straightforward way to validate our findings. 401 | 402 | ## Citation 403 | 404 | If you find the paper or the accompanying code beneficial, please acknowledge our work in your own research. Please use the following BibTeX entry for citation: 405 | 406 | ``` 407 | @inproceedings{li-etal-2023-factual, 408 | title = "{FACTUAL}: A Benchmark for Faithful and Consistent Textual Scene Graph Parsing", 409 | author = "Li, Zhuang and 410 | Chai, Yuyang and 411 | Zhuo, Terry Yue and 412 | Qu, Lizhen and 413 | Haffari, Gholamreza and 414 | Li, Fei and 415 | Ji, Donghong and 416 | Tran, Quan Hung", 417 | booktitle = "Findings of the Association for Computational Linguistics: ACL 2023", 418 | month = jul, 419 | year = "2023", 420 | address = "Toronto, Canada", 421 | publisher = "Association for Computational Linguistics", 422 | url = "https://aclanthology.org/2023.findings-acl.398", 423 | pages = "6377--6390", 424 | } 425 | ``` 426 | 427 | ## Acknowledgments 428 | 429 | This project has been developed with the use of code from the [SceneGraphParser](https://github.com/vacancy/SceneGraphParser) repository by [Jiayuan Mao](https://github.com/vacancy). We gratefully acknowledge their pioneering work and contributions to the open-source community. 430 | 431 | 432 | 433 | -------------------------------------------------------------------------------- /RELEASE_GUIDE.md: -------------------------------------------------------------------------------- 1 | # 📦 Release Guide for FactualSceneGraph 2 | 3 | ## 🚀 Quick Release (Automated) 4 | 5 | Use the provided release script: 6 | 7 | ```bash 8 | # Default release 9 | ./release.sh 10 | 11 | # Custom version/message 12 | ./release.sh "0.6.1" "Add multi-sentence scene graph parsing" 13 | ``` 14 | 15 | ## 📝 Manual Release Process 16 | 17 | ### Prerequisites 18 | ```bash 19 | pip install twine build wheel 20 | ``` 21 | 22 | ### Step-by-Step Commands 23 | 24 | ```bash 25 | # 1. Ensure you're on main branch and up to date 26 | git checkout main 27 | git pull origin main 28 | 29 | # 2. Clean previous builds 30 | rm -rf build/ dist/ *.egg-info/ 31 | 32 | # 3. 
Build the package
33 | python setup.py sdist bdist_wheel
34 |
35 | # 4. Check the build quality
36 | python -m twine check dist/*
37 |
38 | # 5. Commit and push changes
39 | git add .
40 | git commit -m "Release v0.6.1: Add multi-sentence scene graph parsing"
41 | git push origin main
42 |
43 | # 6. Create and push git tag
44 | git tag -a "v0.6.1" -m "Version 0.6.1: Add multi-sentence scene graph parsing"
45 | git push origin v0.6.1
46 |
47 | # 7. Upload to PyPI
48 | python -m twine upload dist/*
49 |
50 | # 8. Verify the release
51 | pip install --upgrade FactualSceneGraph
52 | ```
53 |
54 | ## 🔧 PyPI Configuration
55 |
56 | Create `~/.pypirc` for authentication:
57 |
58 | ```ini
59 | [distutils]
60 | index-servers = pypi
61 |
62 | [pypi]
63 | username = __token__
64 | password = pypi-your-api-token-here
65 | ```
66 |
67 | ## ✅ Post-Release Checklist
68 |
69 | - [ ] Check [PyPI package page](https://pypi.org/project/FactualSceneGraph/)
70 | - [ ] Verify installation: `pip install --upgrade FactualSceneGraph`
71 | - [ ] Test that the new features work as expected
72 | - [ ] Create a GitHub Release with release notes
73 | - [ ] Update documentation if needed
74 | - [ ] Announce the release (if applicable)
75 |
76 | ## 🐛 Troubleshooting
77 |
78 | ### Common Issues:
79 |
80 | 1. **Authentication errors**: Check your PyPI token in `~/.pypirc`
81 | 2. **Version conflicts**: Ensure the version number is incremented
82 | 3. **Build errors**: Check for syntax errors in setup.py
83 | 4. **Git push errors**: Ensure you have push permissions to the repository
84 |
85 | ### Testing Before Release:
86 |
87 | ```bash
88 | # Test local installation
89 | pip install -e .
90 |
91 | # Test import
92 | python -c "from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser; print('Import successful')"
93 |
94 | # Test new feature
95 | python -c "
96 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser
97 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', parser_type='sentence_merge')
98 | print('Sentence merge parser created successfully')
99 | "
100 | ```
--------------------------------------------------------------------------------
/data/factual_mr/meta.json:
--------------------------------------------------------------------------------
1 | {
2 |   "pr": "pair",
3 |   "gr": "group",
4 |   "pa": "part",
5 |   "sl": "slice"
6 | }
7 |
--------------------------------------------------------------------------------
/interface/sg_collect.html:
--------------------------------------------------------------------------------
Tagging Instructions (Click to expand)

Please annotate the scene graph in the image caption. The scene graph describes the text nodes in the image caption and the relations between those nodes.

Example HIT shown to the annotator:

  Image Caption: White antenna on top of the boat
  Image: [image displayed here]

A complete example

| Sub Node Id | Subject Node | Verb          | Preposition    | Obj Node Id | Object Node |
|-------------|--------------|---------------|----------------|-------------|-------------|
| Node0       | antenna      | has_attribute | no_preposition | Node1       | White       |
| Node0       | antenna      | no_verb       | on top of      | Node2       | boat        |

What is a node?

A text node is a phrase in the caption indicating either an object grounded in the image or an attribute of an object. For example, in the caption "White antenna on top of a boat", "antenna" and "boat" are object nodes, and "White" is an attribute node describing the object "antenna".

• All attributes should be adjectives. The "stone" in "stone wall" is not an attribute, since "stone" is a noun here; you should instead treat the whole phrase "stone wall" as one object node, a multi-word expression. Tip: if the attribute is an adjective, the phrase can be rephrased as "[object] is [attribute]". For example, "White antenna" can be rephrased as "[antenna] is [White]", but "[wall] is [stone]" sounds odd.
• Please note that the node text should exactly match the text in the caption; otherwise you cannot submit the HIT. For example, "white" does not match the text "White" in the caption.
• Please note that we also annotate quantity. For example, the quantity of "a shoe" is "one", the quantity of "a pair of shoes" is "a pair of", and the quantity of "two baskets of apples" is "two groups of". Note that this last phrase also contains an implicit relation, "apples in basket".
• When two nodes have the same text, differentiate them with node ids. For example, in the caption "a red wall on the back of a black wall", you could annotate the red "wall" as "Node0" and the black "wall" as "Node1"; the attributes "black" and "red" would then be "Node2" and "Node3".

What is a relation?

A relation includes two parts, a verb and a preposition, which together describe the relationship between the nodes. For simplicity, all verbs are converted into their base forms; for example, "is surrounded" and "surrounding" are both converted into "surround".

• We provide a list of candidate verbs and a list of candidate prepositions for annotating the captions. If you cannot find a verb or preposition that exactly matches the one in the caption, select the one with the closest semantic meaning. For example, for the caption "the man is leaning on the back of a bike", if we do not provide the preposition "on the back of" but do provide "on the backwards", please select that one as the answer.
• There are two default verbs, "no_verb" and "has_attribute". If you cannot find any verb describing the relation, or the verb is "is", "are", etc., just select "no_verb". We treat "has_attribute" as a special verb that only connects objects with attributes.
• There is one default preposition, "no_preposition". If you cannot find any preposition describing the relation, just select "no_preposition". Note that if you select "has_attribute" as the verb, the preposition is set to "no_preposition" by default.
• Important Note: Please always try to annotate the relation in the active voice. For example, in the caption "the gate is surrounded by fence", the relation between [fence] and [gate] should be [fence] "surround" "no_preposition" [gate], not [gate] "surround" "by" [fence]. Only when it is impossible to annotate the relation in the active voice should you use the passive voice. For example, in the caption "the mug is filled with coffee", the object [coffee] cannot perform the action "fill", so the relation can only be [mug] "fill" "with" [coffee]; although "fill" is the base form of "is filled", you can select "passive voice" for this verb to indicate that it is actually used in the passive voice. If the caption is "the man fills the mug with coffee", [man] can perform the action "fill", so the graph includes two triples: [man] "fill" "no_preposition" [mug] and [mug] "fill" "with" [coffee]. We provide an option to select whether the verb is in the active or passive voice. The order of the objects is very important in our study, so please be careful about the order!

Some Corner Cases (read only when you don't know how to annotate)

Below are some corner cases we annotated. If you find a caption difficult to annotate, please check these corner cases for reference. If you still don't know how to annotate after consulting them, please use the box at the bottom to write down why you find it difficult.

• Sometimes there is no relation in the caption. For example, the caption "olive topping" has only one object and no relation, so the whole graph consists of the single node "olive topping".
• We find the relation "of" is usually difficult to annotate. For example, the captions "head of a person" and "a person's head" should both be annotated as [person] "have" "no_preposition" [head].
• Sometimes there are two relations between a pair of objects. For example, the caption "street beneath and around the skater" should be annotated as two triples: [street] "no_verb" "beneath" [skater] and [street] "no_verb" "around" [skater].
• Co-references should be resolved before annotating the graph. For example, the caption "the apples have light spots on them" should be annotated as [light spots] "no_verb" "on" [apples]; the anaphor "them" refers to the object [apples].
• Sometimes there are grammar errors in the caption. Please fix them based on your judgment before annotating. For example, "the man is swinging racket a ball" should be corrected to "the man is swinging racket and a ball", which gives two triples: [man] "swing" "no_preposition" [racket] and [man] "swing" "no_preposition" [ball].
• Self-loops can occur in the graph. For example, the caption "two boys sitting side by side" should be annotated as [two boys] "sit" "side by side" [two boys]; the object [two boys] connects to itself through the relation "sit" "side by side".
• Sometimes a verb can be treated as an attribute. In the caption "the man is reading", the annotation should be [man] "has_attribute" "no_preposition" [reading]; here "reading" is treated as an attribute.

Annotation form (one per HIT):

  Image-Caption
  Caption: {{caption}}
  Image Id : {{image_id}}

  Node table columns: Node Ids | Quantity | Node Name | Node Position in the Caption | Remove Button
  Relation table columns: Sub Node Id | Subject Node | Verb Voice | Verb | Preposition | Obj Node Id | Object Node | Remove Button

  Submit
--------------------------------------------------------------------------------
/logo/adobe_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/logo/adobe_logo.png
--------------------------------------------------------------------------------
/logo/monash_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/logo/monash_logo.png
--------------------------------------------------------------------------------
/logo/wuhan_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/logo/wuhan_logo.png
--------------------------------------------------------------------------------
/release.sh:
--------------------------------------------------------------------------------
 1 | #!/bin/bash
 2 |
 3 | # Release script for FactualSceneGraph
 4 | # Usage: ./release.sh [version] [message]
 5 |
 6 | set -e  # Exit on any error
 7 |
 8 | VERSION=${1:-"0.6.1"}
 9 | MESSAGE=${2:-"Add multi-sentence scene graph parsing with sentence merge functionality"}
10 |
11 | echo "🚀 Starting release process for version $VERSION"
12 |
13 | # 1. Ensure we're on main branch and up to date
14 | echo "📥 Ensuring we're on main branch and up to date..."
15 | git checkout main
16 | git pull origin main
17 |
18 | # 2. Run tests (optional - uncomment if you have tests)
19 | # echo "🧪 Running tests..."
20 | # python -m pytest tests/
21 |
22 | # 3. Clean previous builds
23 | echo "🧹 Cleaning previous builds..."
24 | rm -rf build/ dist/ *.egg-info/
25 |
26 | # 4. Build the package
27 | echo "📦 Building package..."
28 | python setup.py sdist bdist_wheel
29 |
30 | # 5. Check the build
31 | echo "🔍 Checking build..."
32 | python -m twine check dist/*
33 |
34 | # 6. Commit changes
35 | echo "💾 Committing changes..."
36 | git add .
37 | git commit -m "Release v$VERSION: $MESSAGE" || echo "No changes to commit"
38 |
39 | # 7. Create and push tag
40 | echo "🏷️ Creating and pushing tag..."
41 | git tag -a "v$VERSION" -m "Version $VERSION: $MESSAGE"
42 | git push origin main
43 | git push origin "v$VERSION"
44 |
45 | # 8. Upload to PyPI (will prompt for credentials)
46 | echo "📤 Uploading to PyPI..."
47 | echo "⚠️ About to upload to PyPI. Make sure you have your credentials ready!"
48 | read -p "Continue? (y/N): " -n 1 -r
49 | echo
50 | if [[ $REPLY =~ ^[Yy]$ ]]; then
51 |     python -m twine upload dist/*
52 |     echo "✅ Successfully uploaded to PyPI!"
53 | else
54 |     echo "❌ Upload cancelled. You can manually upload later with:"
55 |     echo "   python -m twine upload dist/*"
56 | fi
57 |
58 | echo "🎉 Release process completed!"
59 | echo "📋 Summary:" 60 | echo " - Version: $VERSION" 61 | echo " - Git tag: v$VERSION pushed to origin" 62 | echo " - PyPI package: uploaded (if confirmed)" 63 | echo "" 64 | echo "🔗 Check your release at:" 65 | echo " - GitHub: https://github.com/zhuang-li/FACTUAL/releases/tag/v$VERSION" 66 | echo " - PyPI: https://pypi.org/project/FactualSceneGraph/$VERSION/" -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | git+https://github.com/openai/CLIP.git 2 | scipy==1.11.1 3 | sentence_transformers==2.2.2 4 | spacy 5 | nltk -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='FactualSceneGraph', 5 | version='0.6.1', 6 | author='Zhuang Li', 7 | author_email='lizhuang144@gmail.com', 8 | description='A package for scene graph parsing and evaluation with multi-sentence support', 9 | long_description=open('README.md').read(), 10 | long_description_content_type='text/markdown', 11 | url='https://github.com/zhuang-li/FACTUAL', 12 | package_dir={'': "src"}, 13 | packages=find_packages(where='src'), 14 | include_package_data=True, 15 | package_data={ 16 | 'factual_scene_graph.evaluation': ['resources/*'], 17 | }, 18 | install_requires=[ 19 | 'torch', 20 | 'transformers', 21 | 'tqdm', 22 | 'nltk', 23 | 'spacy', 24 | 'sentence-transformers', 25 | 'pandas', 26 | 'numpy', 27 | 'tabulate' 28 | # Add other dependencies needed for your package 29 | ], 30 | classifiers=[ 31 | 'Development Status :: 3 - Alpha', 32 | # Add additional classifiers as appropriate for your project 33 | 'Intended Audience :: Developers', 34 | 'License :: OSI Approved :: MIT License', 35 | 'Programming Language :: Python :: 3', 36 | 'Programming Language :: Python :: 3.7', 37 | 'Programming Language :: Python :: 3.8', 38 | 'Operating System :: OS Independent', 39 | ], 40 | ) 41 | -------------------------------------------------------------------------------- /src/factual_scene_graph/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/src/factual_scene_graph/__init__.py -------------------------------------------------------------------------------- /src/factual_scene_graph/evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/src/factual_scene_graph/evaluation/__init__.py -------------------------------------------------------------------------------- /src/factual_scene_graph/evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import random 2 | import re 3 | 4 | import nltk 5 | import spacy 6 | from nltk import WordNetLemmatizer 7 | from sentence_transformers import SentenceTransformer 8 | from tqdm import tqdm 9 | 10 | from .set_match_evaluation import eval_set_match 11 | from .spice_evaluation import eval_spice 12 | from .soft_spice_evaluation import * 13 | from ..utils import is_graph_format, space_out_symbols_in_graph 14 | import logging 15 | 16 | # Set up logging configuration (adjust the level and format as needed) 17 | 
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18 |
19 |
20 | class Evaluator:
21 |     def __init__(self, parser=None, text_encoder_checkpoint=None, device='cuda:0', lemmatize=False):
22 |
23 |         self.parser = parser
24 |         if text_encoder_checkpoint:
25 |             self.text_encoder = SentenceTransformer(text_encoder_checkpoint).to(device).eval()
26 |         self.lemmatize = lemmatize
27 |
28 |         if lemmatize:
39 |             self.lemmatizer = WordNetLemmatizer()
40 |
41 |     def _process_graphs(self, graph_string):
42 |         """
43 |         Perform text processing: lemmatization.
44 |
45 |         :param graph_string: A string containing the scene graph to be processed.
46 |         :return: Processed graph string.
47 |         """
51 |         if self.lemmatize:
61 |             # Lemmatize each token, leaving structural symbols and the special predicate 'is' untouched
62 |             exclude_words = {'is', ',', '(', ')'}
63 |             tokens = graph_string.split(' ')
64 |             graph_string = ' '.join([self.lemmatizer.lemmatize(token) if token not in exclude_words else token for token in tokens])
65 |
66 |         return graph_string
67 |
68 |     def evaluate(self, candidates, references, method='spice', batch_size=4, return_graphs=False, **kwargs):
69 |         """
70 |         Evaluate scene graphs or text captions.
71 |
72 |         :param candidates: List of candidate scene graphs or captions.
73 |         :param references: List of lists of reference scene graphs or captions.
74 |         :param method: Evaluation method ('spice', 'soft_spice', or 'set_match').
75 |         :param batch_size: Batch size for processing.
76 |         :param kwargs: Additional arguments for evaluation metrics.
77 |         :return: Evaluation scores, and optionally the processed candidates and references.
78 | """ 79 | logging.info("Starting evaluation...") 80 | 81 | if not all(isinstance(candidate, str) for candidate in candidates): 82 | raise ValueError("All candidates must be strings.") 83 | 84 | # Ensure references is a list of lists of strings 85 | if not all(isinstance(ref_list, list) and all(isinstance(ref, str) for ref in ref_list) for ref_list in 86 | references): 87 | raise ValueError("References must be a list of lists of strings.") 88 | 89 | # Determine input formats and parse if necessary 90 | candidates, references = self._parse_inputs(candidates, references, batch_size, **kwargs) 91 | 92 | # Choose the evaluation method 93 | method_function = { 94 | 'set_match': self._set_match_score, 95 | 'spice': self._spice_score, 96 | 'soft_spice': self._soft_spice_score 97 | }.get(method) 98 | 99 | logging.info(f"Evaluating using method: {method}") 100 | 101 | if method_function is None: 102 | raise ValueError(f"Unknown evaluation method: {method}") 103 | 104 | # Filter kwargs based on the method 105 | if method == 'soft_spice': 106 | method_kwargs = {k: kwargs[k] for k in kwargs if k in ['bidirectional']} 107 | scores = method_function(candidates, references, batch_size, **method_kwargs) 108 | elif method == 'spice': 109 | method_kwargs = {k: kwargs[k] for k in kwargs if k in ['merge_tuples_synonyms', 'synonym_match']} 110 | scores = method_function(candidates, references, **method_kwargs) 111 | else: 112 | scores = method_function(candidates, references) 113 | 114 | logging.info("Evaluation completed.") 115 | 116 | # multiply with 100 117 | scores = [100 * score for score in scores] 118 | 119 | # Return results 120 | return (scores, candidates, references) if return_graphs else scores 121 | 122 | def _parse_inputs(self, candidates, references, batch_size, **kwargs): 123 | """ 124 | Parse inputs if they are not in graph format. 125 | 126 | :param candidates: List of candidate scene graphs or captions. 127 | :param references: List of List of reference scene graphs or captions. 128 | :param batch_size: Batch size for processing. 129 | :param kwargs: Additional arguments for parsing. 130 | :return: Parsed candidates and references. 131 | """ 132 | # Check for parser availability for non-graph formats 133 | if not self._all_items_are_graphs(candidates) and self.parser is None: 134 | raise ValueError("Parser is required for non-graph candidate inputs.") 135 | 136 | # Ensure the structure of references is correct 137 | assert all(isinstance(ref_list, list) for ref_list in 138 | references), "Each reference should be a list of scene graphs or captions." 139 | 140 | # Parse candidates and references if they are not in graph format 141 | parsed_candidates = self._parse_if_needed(candidates, batch_size, is_nested=False, **kwargs) 142 | parsed_references = self._parse_if_needed(references, batch_size, is_nested=True, **kwargs) 143 | 144 | return parsed_candidates, parsed_references 145 | 146 | def _all_items_are_graphs(self, items): 147 | """ 148 | Check if all items in a list are in graph format. 149 | 150 | :param items: List of items (candidates or references). 151 | :return: Boolean indicating if all items are in graph format. 152 | """ 153 | return all(is_graph_format(item) for item in items) 154 | 155 | def _needs_parsing(self, items, is_nested, sample_size=5): 156 | """ 157 | Determine if parsing is needed based on a random sample from a list of items. 158 | 159 | :param items: List of items or nested list of items. 160 | :param is_nested: Boolean indicating if 'items' is a nested list. 
161 | :param sample_size: The number of items to sample for checking. 162 | :return: Boolean indicating if parsing is needed. 163 | """ 164 | 165 | # Flatten the list if it is nested 166 | if is_nested: 167 | flattened_items = [item for sublist in items for item in sublist] 168 | else: 169 | flattened_items = items 170 | 171 | # Sample a few items from the list 172 | sampled_items = random.sample(flattened_items, min(sample_size, len(flattened_items))) 173 | 174 | return not all(is_graph_format(item) for item in sampled_items) 175 | 176 | def _parse_if_needed(self, items, batch_size, is_nested, **kwargs): 177 | """ 178 | Parse items if they are not in graph format. Handles both nested and non-nested lists. 179 | Applies lemmatization to parsed graphs if enabled. 180 | 181 | :param items: List or list of lists of items (candidates or references). 182 | :param batch_size: Batch size for processing. 183 | :param is_nested: Boolean indicating if the items list is nested. 184 | :param kwargs: Additional arguments for parsing. 185 | :return: Parsed items, maintaining the original structure. 186 | """ 187 | # Determine whether parsing is needed 188 | logging.info("Determine whether parsing is needed...") 189 | needs_parsing = self._needs_parsing(items, is_nested) 190 | logging.info(f"Parsing is needed: {needs_parsing}") 191 | if needs_parsing: 192 | if is_nested: 193 | logging.info("Parsing references...") 194 | else: 195 | logging.info("Parsing candidates...") 196 | 197 | # Flatten nested list if necessary and parse 198 | flat_list, structure = (self._flatten_nested_list(items) if is_nested else (items, None)) 199 | parsed_flat_list = self.parser.parse(flat_list, batch_size=batch_size, return_text=True, 200 | **kwargs) if needs_parsing else flat_list 201 | 202 | # Apply lemmatization post-processing if enabled 203 | if self.lemmatize: 204 | parsed_flat_list = [self._process_graphs(graph_str) for graph_str in parsed_flat_list] 205 | 206 | parsed_flat_list = [space_out_symbols_in_graph(graph_str) for graph_str in parsed_flat_list] 207 | 208 | # Recover the nested list structure if it was flattened 209 | return self._recover_nested_list_structure(parsed_flat_list, structure) if is_nested else parsed_flat_list 210 | 211 | def _flatten_nested_list(self, nested_list): 212 | """ 213 | Flatten a nested list while keeping track of the original structure. 214 | 215 | :param nested_list: A list of lists to be flattened. 216 | :return: A tuple of the flattened list and a list of lengths of the original sublists. 217 | """ 218 | flat_list = [] 219 | structure = [] 220 | for sublist in nested_list: 221 | flat_list.extend(sublist) 222 | structure.append(len(sublist)) 223 | return flat_list, structure 224 | 225 | def _recover_nested_list_structure(self, flat_list, structure): 226 | """ 227 | Recover the structure of a nested list from the flattened version and the original structure information. 228 | 229 | :param flat_list: Flattened list of items. 230 | :param structure: List of lengths of the original sublists. 231 | :return: Nested list reconstructed from the flat list. 232 | """ 233 | nested_list, index = [], 0 234 | for length in structure: 235 | nested_list.append(flat_list[index:index + length]) 236 | index += length 237 | return nested_list 238 | 239 | def _set_match_score(self, candidates, references): 240 | """ 241 | Set the match score for each candidate and reference pair. 242 | 243 | :param candidates: Candidate scene graphs. 244 | :param references: Reference scene graphs. 
245 |         :return: List of match scores.
246 |         """
247 |         scores = []
248 |         for cand, refs in zip(candidates, references):
249 |             score = eval_set_match(cand, refs)
250 |             scores.append(score)
251 |         return scores
252 |
253 |     def _spice_score(self, candidates, references, merge_tuples_synonyms=False, synonym_match=True):
254 |         """
255 |         Compute SPICE scores.
256 |
257 |         :param candidates: List of candidate scene graphs.
258 |         :param references: List of lists of reference scene graphs.
259 |         :return: List of SPICE scores.
260 |         """
261 |         scores = []
262 |         for cand, refs in tqdm(zip(candidates, references), total=len(candidates)):
263 |             score = eval_spice(cand, refs, merge_tuples_synonyms, synonym_match)
264 |             scores.append(score)
265 |         return scores
266 |
267 |     def _soft_spice_score(self, candidates, references, batch_size, bidirectional=False):
268 |         """
269 |         Compute Soft-SPICE scores for a batch of candidates and references.
270 |
271 |         :param candidates: A list of candidate scene graphs.
272 |         :param references: A list of reference scene graphs corresponding to the candidates.
273 |         :param batch_size: Batch size to be used for encoding.
274 |         :param bidirectional: Whether to use bidirectional (precision/recall-style) matching.
275 |         :return: A list of Soft-SPICE scores for each candidate.
276 |         """
277 |         all_cand_phrases, all_ref_phrases, cand_lengths, ref_lengths = accumulate_phrases(candidates, references)
278 |         encoded_cands, encoded_refs = encode_phrases(self.text_encoder, all_cand_phrases, all_ref_phrases, batch_size)
279 |         scores = compute_scores(encoded_cands, encoded_refs, cand_lengths, ref_lengths, bidirectional)
280 |         return scores
281 |
282 |
283 |
--------------------------------------------------------------------------------
/src/factual_scene_graph/evaluation/resources/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/src/factual_scene_graph/evaluation/resources/__init__.py
--------------------------------------------------------------------------------
/src/factual_scene_graph/evaluation/set_match_evaluation.py:
--------------------------------------------------------------------------------
 1 | from ..utils import get_seg_list, space_out_symbols_in_graph
 2 |
 3 |
 4 | def eval_set_match(cand, refs):
 5 |     """
 6 |     Evaluate the exact set match between candidate and reference facts.
 7 |
 8 |     :param cand: Candidate scene graph.
 9 |     :param refs: Reference scene graph(s).
10 |     :return: True if the fact sets match exactly, ignoring order; else False.
11 |     """
12 |
13 |     cand_phrases = get_seg_list(cand)
14 |     ref_phrases = get_seg_list(refs)
15 |
16 |     return len(cand_phrases) == len(ref_phrases) and (sorted(cand_phrases) == sorted(ref_phrases))
17 |
--------------------------------------------------------------------------------
/src/factual_scene_graph/evaluation/soft_spice_evaluation.py:
--------------------------------------------------------------------------------
 1 | import numpy as np
 2 | from packaging import version
 3 | from ..utils import get_seg_list
 4 |
 5 | def _get_graph_phrases(graph_str_list, type_dict):
 6 |     """
 7 |     Extract phrases from graph strings and classify their types.
 8 |
 9 |     :param graph_str_list: List of graph strings.
10 |     :param type_dict: Dictionary to store type information.
11 |     :return: A list of unique phrases.
12 | """ 13 | seg_list = get_seg_list(graph_str_list) 14 | new_pairs = set() 15 | 16 | for seg in seg_list: 17 | new_seg = [item.strip() for item in seg.split(',')] 18 | if len(new_seg) == 1: 19 | _handle_single_element(new_seg, new_pairs, type_dict) 20 | elif len(new_seg) == 2: 21 | _handle_two_elements(new_seg, new_pairs, type_dict) 22 | elif len(new_seg) >= 3: 23 | _handle_three_or_more_elements(new_seg, new_pairs, type_dict) 24 | 25 | return list(new_pairs) 26 | 27 | 28 | def _handle_single_element(seg, pairs, type_dict): 29 | pairs.add(seg[0]) 30 | type_dict[seg[0]] = "entity" 31 | 32 | 33 | def _handle_two_elements(seg, pairs, type_dict): 34 | pairs.add(f"{seg[1]} {seg[0]}") 35 | type_dict[f"{seg[1]} {seg[0]}"] = "attribute" 36 | pairs.add(seg[0]) 37 | type_dict[seg[0]] = "entity" 38 | 39 | 40 | def _handle_three_or_more_elements(seg, pairs, type_dict): 41 | if seg[1] == 'is': 42 | pairs.add(f"{seg[2]} {seg[0]}") 43 | type_dict[f"{seg[2]} {seg[0]}"] = "attribute" 44 | else: 45 | phrase = f"{seg[0]} {' '.join(seg[1:-1])} {seg[-1]}" 46 | pairs.add(phrase) 47 | type_dict[phrase] = "relation" 48 | pairs.add(seg[-1]) 49 | type_dict[seg[-1]] = "entity" 50 | 51 | pairs.add(seg[0]) 52 | type_dict[seg[0]] = "entity" 53 | 54 | def encode_phrases(text_encoder, all_cand_phrases, all_ref_phrases, batch_size): 55 | all_encoded_phrases = _normalize_features( 56 | text_encoder.encode(all_cand_phrases + all_ref_phrases, batch_size=batch_size) 57 | ) 58 | encoded_cands, encoded_refs = np.split(all_encoded_phrases, [len(all_cand_phrases)]) 59 | return encoded_cands, encoded_refs 60 | 61 | 62 | def accumulate_phrases(candidates, references): 63 | all_cand_phrases = [] 64 | all_ref_phrases = [] 65 | cand_lengths = [] 66 | ref_lengths = [] 67 | type_dict = {} 68 | 69 | for cand, refs in zip(candidates, references): 70 | cand_phrases = _get_graph_phrases(cand, type_dict) 71 | ref_phrases = _get_graph_phrases(refs, type_dict) 72 | 73 | all_cand_phrases.extend(cand_phrases) 74 | all_ref_phrases.extend(ref_phrases) 75 | 76 | cand_lengths.append(len(cand_phrases)) 77 | ref_lengths.append(len(ref_phrases)) 78 | 79 | return all_cand_phrases, all_ref_phrases, cand_lengths, ref_lengths 80 | 81 | 82 | def compute_scores(encoded_cands, encoded_refs, cand_lengths, ref_lengths, bidirectional=False): 83 | scores = [] 84 | ref_start_idx = 0 85 | 86 | for cand_len, ref_len in zip(cand_lengths, ref_lengths): 87 | cand_feats = encoded_cands[:cand_len] 88 | encoded_cands = encoded_cands[cand_len:] # Update the remaining candidate features 89 | 90 | ref_feats = encoded_refs[ref_start_idx:ref_start_idx + ref_len] 91 | ref_start_idx += ref_len 92 | 93 | all_sims = cand_feats.dot(ref_feats.T) 94 | score_per_phrase = np.max(all_sims, axis=1) 95 | 96 | if bidirectional: 97 | score_per_phrase_ = np.max(all_sims, axis=0) # recall-like 98 | # Normalize each score from [-1,1] to [0,1] before averaging 99 | norm_scores_precision = (score_per_phrase + 1) / 2 100 | norm_scores_recall = (score_per_phrase_ + 1) / 2 101 | 102 | precision = np.mean(norm_scores_precision) 103 | recall = np.mean(norm_scores_recall) 104 | 105 | # Use harmonic mean (F1-like) 106 | if precision + recall > 0: 107 | softspice_bi = 2 * (precision * recall) / (precision + recall) 108 | else: 109 | softspice_bi = 0.0 110 | scores.append(softspice_bi) 111 | else: 112 | scores.append(np.mean(score_per_phrase)) 113 | 114 | return scores 115 | 116 | 117 | 118 | def _normalize_features(features): 119 | """ 120 | Normalize feature vectors. 
121 | 
122 | :param features: Feature vectors.
123 | :return: Normalized feature vectors.
124 | """
125 | return features / np.sqrt(np.sum(features ** 2, axis=1, keepdims=True))
--------------------------------------------------------------------------------
/src/factual_scene_graph/evaluation/spice_evaluation.py:
--------------------------------------------------------------------------------
1 | # Note: synonym matching uses the bundled dictionary in synonym_dictionary.py rather than live WordNet lookups.
2 | 
3 | from .synonym_dictionary import synonym_dictionary
4 | from ..utils import get_seg_list
5 | 
6 | def eval_spice(cand, refs, merge_tuples_synonyms=True, synonym_match=True):
7 | """
8 | Evaluate the SPICE metric between a candidate graph and reference graphs.
9 | 
10 | :param cand: Candidate scene graph.
11 | :param refs: Reference scene graph(s).
12 | :param merge_tuples_synonyms: Whether to merge synonymous tuples before matching.
13 | :param synonym_match: Whether to run a second, synset-based matching pass.
14 | :return: Calculated SPICE F1 score.
15 | """
16 | cand_tuples = get_graph_tuples(cand, merge_tuples_synonyms)
17 | ref_tuples = get_graph_tuples(refs, merge_tuples_synonyms)
18 | 
19 | # Exact matches are counted first; when synonym_match is True, a
20 | # second synset-based pass matches the remaining tuples. The final
21 | # score is the F1 over matched tuples.
22 | 
23 | 
24 | return calculate_spice_score(cand_tuples, ref_tuples, synonym_match)
25 | 
26 | def calculate_spice_score(cand_tuples, ref_tuples, synonym_match):
27 | matched_cand_indices = set()
28 | matched_ref_indices = set()
29 | total_matches = 0
30 | 
31 | # First pass: Exact matches
32 | for i, cand in enumerate(cand_tuples):
33 | for j, ref in enumerate(ref_tuples):
34 | if are_tuples_match(cand, ref) and j not in matched_ref_indices:
35 | matched_cand_indices.add(i)
36 | matched_ref_indices.add(j)
37 | total_matches += 1
38 | break
39 | 
40 | if synonym_match:
41 | # Second pass: synset-based similar matches (for unmatched candidates)
42 | for i, cand in enumerate(cand_tuples):
43 | if i not in matched_cand_indices:
44 | for j, ref in enumerate(ref_tuples):
45 | if j not in matched_ref_indices and similar_to_any(cand, [ref]):
46 | # A synonym match consumes the matched reference,
47 | # so it cannot be reused by any other candidate
48 | # in later iterations.
49 | matched_ref_indices.add(j)
50 | total_matches += 1
51 | break
52 | # Calculate precision, recall, and F1 score
53 | precision = calculate_score(total_matches, len(cand_tuples))
54 | recall = calculate_score(total_matches, len(ref_tuples))
55 | 
56 | assert precision <= 1 and recall <= 1, "Precision or recall is greater than 1, total_matches {0}, len(cand_tuples) {1}, len(ref_tuples) {2}".format(total_matches, len(cand_tuples), len(ref_tuples))
57 | 
58 | return calculate_f1(precision, recall)
59 | 
60 | def similar_to_any(candidate, references):
61 | """
62 | Check if a candidate tuple is similar to any reference tuple.
63 | 
64 | :param candidate: The candidate tuple to compare.
65 | :param references: A list of reference tuples.
66 | :return: True if similar to any reference, False otherwise.
67 | """
68 | candidate_synsets = get_synsets(candidate)
69 | 
70 | return any(are_tuples_match(candidate_synsets, get_synsets(ref)) for ref in references)
71 | 
72 | def get_synsets_for_word_set(word_set):
73 | return set().union(*[word_to_synset(word) for word in word_set])
74 | 
75 | def get_synsets(words):
76 | """
77 | Get synsets for a list of word sets.
78 | 
79 | :param words: A list of word sets.
80 | :return: A list of synset sets, one per word set.
81 | """
82 | return [get_synsets_for_word_set(word_set) for word_set in words]
83 | 
84 | 
85 | def word_to_synset(word):
86 | """
87 | Process a word into its synsets.
88 | 
89 | :param word: The word to process.
90 | :return: A set of synsets for the word.
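Multi-word entries are joined with underscores before the dictionary lookup, e.g. "traffic light" becomes "traffic_light".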
91 | """
92 | word = ' '.join(word.strip().lower().split())
93 | lemma_synset = set()
94 | 
95 | # If the word consists of multiple parts, join them with an underscore
96 | word_split = word.split()
97 | if len(word_split) >= 2:
98 | word = "_".join(word_split)
99 | 
100 | # Synsets are looked up in the bundled WordNet-derived dictionary
101 | # (english.synsets / english.exceptions under resources/) rather
102 | # than queried from NLTK WordNet at runtime, so no corpus download
103 | # is needed here. Both direct synsets and stem synsets are
104 | # collected, which lets inflected forms of the same word (for
105 | # example "jeans" and "jean") still match.
106 | 
107 | lemma_synset.update(synonym_dictionary.get_synsets(word))
108 | lemma_synset.update(synonym_dictionary.get_stem_synsets(word))
109 | 
110 | return lemma_synset
111 | 
112 | 
113 | def are_tuples_match(synsets1, synsets2):
114 | """
115 | Determine if two lists of synsets have non-empty intersections for corresponding elements.
116 | 
117 | :param synsets1: First list of synsets.
118 | :param synsets2: Second list of synsets.
119 | :return: True if all corresponding synsets have a non-empty intersection, False otherwise.
120 | """
121 | 
122 | return len(synsets1) == len(synsets2) and all(s1.intersection(s2) for s1, s2 in zip(synsets1, synsets2))
123 | 
124 | def calculate_score(match_count, total_count):
125 | """
126 | Calculate precision or recall.
127 | 
128 | :param match_count: The count of matched tuples.
129 | :param total_count: The total count of tuples.
130 | :return: The calculated score.
131 | """
132 | return match_count / total_count if total_count > 0 else 0
133 | 
134 | def calculate_f1(precision, recall):
135 | """
136 | Calculate the F1 score from precision and recall.
137 | 
138 | :param precision: Precision value.
139 | :param recall: Recall value.
140 | :return: The calculated F1 score.
141 | """
142 | return 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0
143 | 
144 | def get_graph_tuples(graph_str_list, merge_tuples_synonyms=True):
145 | """
146 | Get tuples from a scene graph.
147 | """
148 | seg_list = get_seg_list(graph_str_list)
149 | selected_obj_set = set()
150 | tuples = []
151 | 
152 | for hyp in seg_list:
153 | lf_seg = [token.strip() for token in hyp.split(',')]
154 | seg_len = len(lf_seg)
155 | 
156 | # Handle segments based on their length
157 | if seg_len == 1:
158 | add_unique_tuple(lf_seg[0], tuples, selected_obj_set)
159 | elif seg_len >= 2:
160 | process_lf_segment(lf_seg, tuples, selected_obj_set, seg_len)
161 | 
162 | if merge_tuples_synonyms:
163 | return merge_tuples_based_on_synonyms(sorted(tuples, key=tuple_sort_key))
164 | else:
165 | return sorted(tuples, key=tuple_sort_key)
166 | 
167 | def tuple_sort_key(t):
168 | """
169 | Generate a sort key for a tuple of sets of strings.
170 | 
171 | :param t: A tuple, each element of which is a set of strings.
172 | :return: A string that represents the sorted contents of the sets.
173 | """
174 | sorted_sets = [' '.join(sorted(s)) for s in t]  # Sort each set and join into strings
175 | return ' '.join(sorted_sets)  # Join the sorted strings from each set
176 | 
177 | 
178 | def merge_elements_by_synsets(tuples, position, check_length, unique_sets=None):
179 | """
180 | Generalized function to merge elements within tuples based on synsets.
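:param tuples: List of tuples whose elements are sets of words.
:param position: Index of the element to merge within each tuple.
:param check_length: Only tuples of exactly this length are processed.
:param unique_sets: Optional shared list of already-merged word sets; when supplied, no new sets are appended to it.
:return: The list of merged word sets, or None when unique_sets was supplied.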
181 | """
182 | is_shared_set = unique_sets is not None
183 | unique_sets = unique_sets or []
184 | 
185 | for t in tuples:
186 | if len(t) == check_length:
187 | merge_found = False
188 | for i, merged_set in enumerate(unique_sets):
189 | if len(get_synsets_for_word_set(t[position]).intersection(get_synsets_for_word_set(merged_set))) > 0:
190 | merge_found = True
191 | unique_sets[i] = merged_set.union(t[position])
192 | break
193 | if not merge_found and not is_shared_set:
194 | unique_sets.append(t[position])
195 | 
196 | for t in tuples:
197 | if len(t) == check_length:
198 | for merged_set in unique_sets:
199 | if len(get_synsets_for_word_set(t[position]).intersection(get_synsets_for_word_set(merged_set))) > 0:
200 | t[position].update(merged_set)
201 | break
202 | 
203 | return unique_sets if not is_shared_set else None
204 | 
205 | 
206 | def merge_tuples_based_on_synonyms(tuples):
207 | """
208 | Merge nodes, attributes, and relations in tuples based on synonyms.
209 | """
210 | # Create a shared unique set for specific merge operations
211 | shared_unique_set = merge_elements_by_synsets(tuples, 0, 1)  # Merging nodes and initializing shared set
212 | 
213 | # Use the shared set for merging nodes in three-element tuples and at the end of three-element tuples
214 | merge_elements_by_synsets(tuples, 0, 3, shared_unique_set)
215 | merge_elements_by_synsets(tuples, 2, 3, shared_unique_set)
216 | 
217 | # Merging attributes in two-element tuples
218 | merge_elements_by_synsets(tuples, 1, 2)
219 | 
220 | # Merging relations in three-element tuples
221 | merge_elements_by_synsets(tuples, 1, 3)
222 | 
223 | merged_tuples = merge_tuples(tuples)
224 | 
225 | return merged_tuples
226 | 
227 | def merge_tuples(tuples):
228 | """
229 | Merge tuples if they have similar elements.
230 | """
231 | merged_tuples = []
232 | for t in tuples:
233 | merge_found = False
234 | for i, mt in enumerate(merged_tuples):
235 | if similar_to_any(t, [mt]):
236 | merged_tuples[i] = merge_two_tuples(t, mt)
237 | merge_found = True
238 | break
239 | if not merge_found:
240 | merged_tuples.append(t)
241 | 
242 | return merged_tuples
243 | 
244 | def merge_two_tuples(tuple1, tuple2):
245 | """
246 | Merge two tuples that have synonyms in common.
247 | 
248 | :param tuple1: A tuple of word sets.
249 | :param tuple2: Another tuple of word sets.
250 | :return: The element-wise union of the two tuples.
251 | """
252 | # Element-wise union keeps every synonym observed on either
253 | # side of the merge
254 | return [t1.union(t2) for t1, t2 in zip(tuple1, tuple2)]
255 | 
256 | def add_unique_tuple(item, tuples, selected_obj_set):
257 | """
258 | Adds a unique tuple from an item.
259 | """
260 | if item not in selected_obj_set:
261 | tuples.append([{item}])
262 | selected_obj_set.add(item)
263 | 
264 | 
265 | def process_lf_segment(lf_seg, tuples, selected_obj_set, seg_len):
266 | """
267 | Processes a segment of length 2 or more and adds appropriate tuples.
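For example, ['bench', 'is', 'wooden'] adds the attribute tuple ({'bench'}, {'wooden'}) plus the entity tuple [{'bench'}], while ['man', 'wears', 'hat'] adds ({'man'}, {'wears'}, {'hat'}) plus entity tuples for 'man' and 'hat'.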
268 | """ 269 | # Construct the tuple string based on segment length 270 | if seg_len == 2 or (seg_len == 3 and lf_seg[1] == 'is'): 271 | tuple_str = lf_seg[0] + ' ' + lf_seg[-1] 272 | if tuple_str not in selected_obj_set: 273 | tuples.append(({lf_seg[0]}, {lf_seg[-1]})) 274 | selected_obj_set.add(tuple_str) 275 | add_unique_tuple(lf_seg[0], tuples, selected_obj_set) 276 | 277 | elif seg_len == 3: 278 | tuple_str = ' '.join(lf_seg) 279 | if tuple_str not in selected_obj_set: 280 | tuples.append(({lf_seg[0]}, {lf_seg[1]}, {lf_seg[2]})) 281 | selected_obj_set.add(tuple_str) 282 | add_unique_tuple(lf_seg[0], tuples, selected_obj_set) 283 | add_unique_tuple(lf_seg[2], tuples, selected_obj_set) 284 | 285 | elif seg_len > 3: 286 | tuple_str = lf_seg[0] + ' ' + ' '.join(lf_seg[1:-1]) + ' ' + lf_seg[-1] 287 | if tuple_str not in selected_obj_set: 288 | tuples.append(({lf_seg[0]}, {" ".join(lf_seg[1:-1])}, {lf_seg[-1]})) 289 | selected_obj_set.add(tuple_str) 290 | add_unique_tuple(lf_seg[0], tuples, selected_obj_set) 291 | add_unique_tuple(lf_seg[-1], tuples, selected_obj_set) 292 | 293 | 294 | 295 | -------------------------------------------------------------------------------- /src/factual_scene_graph/evaluation/synonym_dictionary.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from importlib import resources 3 | 4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') 5 | 6 | class SynonymDictionary: 7 | def __init__(self, exc_file_path, syn_file_path): 8 | self.word_to_bases = {} 9 | self.word_to_synsets = {} 10 | self.set_to_relations = {} 11 | 12 | # logging.info("Reading synonym dictionary..") 13 | # Reading exception file 14 | with open(exc_file_path, 'r', encoding='utf-8') as file: 15 | for base in file: 16 | base = base.strip() 17 | forms = next(file).strip().split() 18 | for form in forms: 19 | self.word_to_bases.setdefault(form, []).append(base) 20 | 21 | # Reading synset file 22 | with open(syn_file_path, 'r', encoding='utf-8') as file: 23 | for word in file: 24 | word = word.strip() 25 | sets = set(map(int, next(file).strip().split())) 26 | self.word_to_synsets[word] = sets 27 | 28 | # logging.info("Reading synonym dictionary.. 
Done")
29 | def get_synsets(self, word):
30 | return self.word_to_synsets.get(word, set())
31 | 
32 | def get_stem_synsets(self, word):
33 | bases = self.word_to_bases.get(word)
34 | if bases:
35 | sets = set()
36 | for base in bases:
37 | sets.update(self.get_synsets(base))
38 | return sets
39 | return self.get_synsets(self.morph(word))
40 | 
41 | 
42 | def morph(self, word):
43 | suffixes = [
44 | "s", "ses", "xes", "zes", "ches", "shes", "men", "ies",  # Noun suffixes
45 | "s", "ies", "es", "es", "ed", "ed", "ing", "ing",  # Verb suffixes
46 | "er", "est", "er", "est"  # Adjective suffixes
47 | ]
48 | 
49 | endings = [
50 | "", "s", "x", "z", "ch", "sh", "man", "y",  # Noun endings
51 | "", "y", "e", "", "e", "", "e", "",  # Verb endings
52 | "", "", "e", "e"  # Adjective endings
53 | ]
54 | 
55 | if word.endswith("ful"):
56 | base = word[:-3]  # Remove 'ful'
57 | if base in self.word_to_synsets:
58 | return base
59 | return word
60 | 
61 | if word.endswith("ss") or len(word) <= 2:
62 | return word
63 | 
64 | for suffix, ending in zip(suffixes, endings):
65 | if word.endswith(suffix):
66 | base = word[:-len(suffix)] + ending
67 | if base in self.word_to_synsets:
68 | return base
69 | 
70 | return word
71 | 
72 | 
73 | with resources.open_text('factual_scene_graph.evaluation.resources', 'english.exceptions') as f_exc, \
74 | resources.open_text('factual_scene_graph.evaluation.resources', 'english.synsets') as f_syn:
75 | synonym_dictionary = SynonymDictionary(f_exc.name,
76 | f_syn.name)
77 | 
78 | 
79 | 
--------------------------------------------------------------------------------
/src/factual_scene_graph/parser/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhuang-li/FactualSceneGraph/944b284e2b983113e05ee3f36da30707db622b7e/src/factual_scene_graph/parser/__init__.py
--------------------------------------------------------------------------------
/src/factual_scene_graph/parser/scene_graph_parser.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from nltk import WordNetLemmatizer
3 | import nltk
4 | from tqdm import tqdm
5 | from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6 | 
7 | from ..utils import space_out_symbols_in_graph, clean_graph_string, remove_factual_chars
8 | 
9 | 
10 | class SceneGraphParser:
11 | def __init__(self, checkpoint_path, device='cuda:0', lemmatize=False, lowercase=False, parser_type='default'):
12 | self.device = device
13 | self.tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
14 | self.model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint_path).to(device).eval()
15 | self.lemmatize = lemmatize
16 | self.lowercase = lowercase
17 | self.parser_type = parser_type
18 | 
19 | if lemmatize:
20 | self.lemmatizer = WordNetLemmatizer()
21 | 
22 | if parser_type == 'sentence_merge':
23 | # Download required NLTK data if not already downloaded
24 | try:
25 | nltk.data.find('tokenizers/punkt_tab')
26 | except LookupError:
27 | try:
28 | nltk.download('punkt_tab')
29 | except Exception:
30 | # Fallback to older punkt for older NLTK versions
31 | nltk.download('punkt')
32 | 
33 | def _process_text(self, text):
34 | """
35 | Optionally lemmatize the text and convert it to lowercase.
36 | 
37 | :param text: A string containing the text to be processed.
38 | :return: Processed text as a string.
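Note that lemmatization is applied token by token with WordNetLemmatizer before any lowercasing.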
39 | """
40 | 
41 | if self.lemmatize:
42 | # Lemmatize each word in the text
43 | tokens = text.split(' ')
44 | text = ' '.join([self.lemmatizer.lemmatize(token) for token in tokens])
45 | 
46 | 
47 | 
48 | if self.lowercase:
49 | text = text.lower()
50 | 
51 | return text
52 | 
53 | def _merge_graphs(self, graph_texts):
54 | """
55 | Merge multiple graph texts and remove duplicates.
56 | 
57 | :param graph_texts: List of graph texts to merge
58 | :return: Merged and deduplicated graph text
59 | """
60 | # Combine all graph texts
61 | combined_text = ' , '.join(graph_texts)
62 | # Clean and deduplicate
63 | return clean_graph_string(combined_text)
64 | 
65 | def parse(self, descriptions, max_input_len=64, max_output_len=128, beam_size=5, return_text=False, filter_factual_chars=False, batch_size=32):
66 | if isinstance(descriptions, str):
67 | descriptions = [descriptions]
68 | 
69 | if self.parser_type == 'sentence_merge':
70 | # Process all sentences from all descriptions in one batch
71 | all_sentences = []
72 | sentence_to_desc_map = []  # Maps each sentence back to its original description
73 | 
74 | for desc_idx, desc in enumerate(descriptions):
75 | sentences = nltk.sent_tokenize(desc)
76 | for sentence in sentences:
77 | sentence = sentence.strip()
78 | if sentence:  # Skip empty sentences
79 | all_sentences.append(sentence)
80 | sentence_to_desc_map.append(desc_idx)  # Store index of original description
81 | 
82 | if all_sentences:
83 | # Process all sentences in batches
84 | processed_sentences = [self._process_text(sent) for sent in all_sentences]
85 | sentence_graphs = self._parse_batch(processed_sentences, max_input_len, max_output_len,
86 | beam_size, filter_factual_chars, batch_size)
87 | 
88 | # Merge graphs for each original description
89 | processed_descriptions = [""] * len(descriptions)
90 | for i, graph in enumerate(sentence_graphs):
91 | if graph:  # Only merge non-empty graphs
92 | desc_idx = sentence_to_desc_map[i]
93 | if processed_descriptions[desc_idx]:
94 | processed_descriptions[desc_idx] = self._merge_graphs([processed_descriptions[desc_idx], graph])
95 | else:
96 | processed_descriptions[desc_idx] = graph
97 | else:
98 | processed_descriptions = [""] * len(descriptions)
99 | else:  # default parser
100 | processed_descriptions = [self._process_text(desc) for desc in descriptions]
101 | processed_descriptions = self._parse_batch(processed_descriptions, max_input_len, max_output_len,
102 | beam_size, filter_factual_chars, batch_size)
103 | 
104 | if return_text:
105 | return processed_descriptions
106 | else:
107 | return [self.graph_string_to_object(text) for text in processed_descriptions]
108 | 
109 | def _parse_batch(self, descriptions, max_input_len, max_output_len, beam_size, filter_factual_chars, batch_size):
110 | """Helper method to parse a batch of descriptions"""
111 | all_formatted_texts = []
112 | 
113 | # Process descriptions in batches
114 | for i in tqdm(range(0, len(descriptions), batch_size)):
115 | batch_descriptions = descriptions[i:i + batch_size]
116 | prompt_texts = ['Generate Scene Graph: ' + desc.strip() for desc in batch_descriptions]
117 | with torch.no_grad():
118 | encoded_inputs = self.tokenizer(
119 | prompt_texts,
120 | max_length=max_input_len,
121 | truncation=True,
122 | padding=True,
123 | return_tensors='pt'
124 | )
125 | tokens = encoded_inputs['input_ids'].to(self.device)
126 | attention_masks = encoded_inputs['attention_mask'].to(self.device)
127 | 
128 | early_stopping = beam_size > 1
129 | 
130 | generated_ids = self.model.generate(
131 | tokens,
132 | attention_mask=attention_masks, 133 | use_cache=True, 134 | decoder_start_token_id=self.tokenizer.pad_token_id, 135 | num_beams=beam_size, 136 | max_length=max_output_len, 137 | early_stopping=early_stopping, 138 | num_return_sequences=1, 139 | ) 140 | 141 | # Decoding the output 142 | generated_texts = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True, 143 | clean_up_tokenization_spaces=True) 144 | 145 | if filter_factual_chars: 146 | generated_texts = [remove_factual_chars(text) for text in generated_texts] 147 | 148 | formatted_texts = [clean_graph_string( 149 | space_out_symbols_in_graph(text.replace('Generate Scene Graph:', '').strip())) for text in 150 | generated_texts] 151 | all_formatted_texts.extend(formatted_texts) 152 | 153 | return all_formatted_texts 154 | 155 | def graph_string_to_object(self, graph_text): 156 | graph = {'entities': [], 'relations': []} 157 | entity_map = {} # Entity name to index mapping 158 | 159 | # Process each relation in the description 160 | relation_strs = graph_text.strip().split(') ,') 161 | for relation_str in relation_strs: 162 | relation_str = relation_str.strip().strip('()') 163 | parts = [part.strip() for part in relation_str.split(',')] 164 | 165 | if len(parts) != 3 and len(relation_strs) > 1: 166 | continue # Skip malformed relations 167 | elif len(parts) != 3 and len(relation_strs) == 1: 168 | self._get_or_create_entity_index(parts[0], graph, entity_map) 169 | else: 170 | subject, relationship, object_ = parts 171 | 172 | subject_index = self._get_or_create_entity_index(subject, graph, entity_map) 173 | 174 | if relationship == 'is': 175 | if object_.isdigit(): # Quantity 176 | graph['entities'][subject_index]['quantity'] = object_ 177 | else: # Attribute 178 | graph['entities'][subject_index]['attributes'].add(object_) 179 | else: 180 | object_index = self._get_or_create_entity_index(object_, graph, entity_map) 181 | # Add relation 182 | graph['relations'].append({'subject': subject_index, 'relation': relationship, 'object': object_index}) 183 | 184 | return graph 185 | 186 | def _get_or_create_entity_index(self, entity_name, graph, entity_map): 187 | if entity_name not in entity_map: 188 | new_index = len(graph['entities']) 189 | graph['entities'].append({'head': entity_name, 'quantity': '', 'attributes': set()}) 190 | entity_map[entity_name] = new_index 191 | else: 192 | new_index = entity_map[entity_name] 193 | 194 | return new_index 195 | 196 | 197 | # unit test main 198 | 199 | 200 | -------------------------------------------------------------------------------- /src/factual_scene_graph/utils.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import re 3 | 4 | import tabulate 5 | 6 | 7 | def get_seg_list(graphs): 8 | """ 9 | Extracts segments from a single graph or a list of graphs. 10 | 11 | :param graphs: A single graph string or a list of graph strings. 12 | :return: A list of unique segments from the graph(s). 13 | """ 14 | 15 | def extract_segments(graph): 16 | """ 17 | Extract segments from an individual graph string. 18 | 19 | :param graph: A single graph string. 20 | :return: A list of segments from the graph. 
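For example, "(dog,is,black),(cat)" becomes ['dog , is , black', 'cat'].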
21 | """ 22 | formatted_graph = space_out_symbols_in_graph(graph) 23 | return [segment.strip('()').strip() for segment in formatted_graph.split(') , (')] 24 | 25 | if isinstance(graphs, str): 26 | segments = extract_segments(graphs) 27 | elif isinstance(graphs, list): 28 | segments = [seg for graph in graphs for seg in extract_segments(graph)] 29 | else: 30 | raise TypeError('Input must be a string or a list of strings') 31 | 32 | # remove the duplicates, note this might influence the evaluation performance 33 | return list(set(segments)) 34 | 35 | 36 | def space_out_symbols_in_graph(graph_str): 37 | # Add spaces around parentheses and commas, then split into words 38 | formatted_str = graph_str.replace('(', ' ( ').replace(')', ' ) ').replace(',', ' , ') 39 | 40 | # Use strip to remove leading/trailing whitespace and join the words back into a string 41 | return ' '.join(word.strip() for word in formatted_str.split()) 42 | 43 | 44 | def is_graph_format(input_string): 45 | """ 46 | Check if the input string follows the graph format. 47 | 48 | :param input_string: A string to check. 49 | :return: True if the string contains elements in parentheses, False otherwise. 50 | """ 51 | # Pattern to match any content within parentheses 52 | graph_pattern = r"\(.*?\)" 53 | 54 | return bool(re.search(graph_pattern, input_string)) 55 | 56 | def remove_factual_chars(text): 57 | """ 58 | Remove specific substrings from the text, including patterns like ':digit'. 59 | 60 | :param text: Input text from which substrings will be removed. 61 | :return: Text after removing specific substrings. 62 | """ 63 | # Direct string replacements 64 | replacements = ['v:', 'pv:'] 65 | for replacement in replacements: 66 | text = text.replace(replacement, '') 67 | 68 | # Using regular expression to remove patterns like ':digit' 69 | text = re.sub(r':\d+', '', text) 70 | 71 | return text 72 | 73 | def clean_graph_string(fact_str): 74 | # Split the string into individual facts 75 | facts = fact_str.strip().split(') ,') 76 | # remove truncated parentheses 77 | if not fact_str.endswith(')') and len(facts) > 1: 78 | facts = facts[:-1] 79 | elif not fact_str.endswith(')') and len(facts) == 1: 80 | facts = [facts[0].split(',')[0]] 81 | # Use a set to filter out duplicate facts 82 | unique_facts = set() 83 | for fact in facts: 84 | fact = fact.strip().strip('()').strip() 85 | if fact: 86 | unique_facts.add(fact) 87 | 88 | # sort unique_facts 89 | 90 | unique_facts = sorted(unique_facts) 91 | 92 | # Reconstruct the string with unique facts 93 | unique_fact_str = ' , '.join([f'( {fact} )' for fact in unique_facts]) 94 | return unique_fact_str 95 | 96 | 97 | def tprint(graph, file=None): 98 | """ 99 | Print a scene graph as a table. 100 | The printed strings contain essential information about the parsed scene graph. 
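Expects the dictionary layout produced by SceneGraphParser.graph_string_to_object: entities with 'head', 'quantity' and 'attributes' fields, and relations holding entity indices.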
101 | """
102 | assert isinstance(graph, dict), 'Input must be a dictionary'
103 | _print = functools.partial(print, file=file)
104 | 
105 | _print('Entities:')
106 | entities_data = [
107 | [e['head'].lower(), e.get('quantity', ''), ','.join(e.get('attributes', set()))]
108 | for e in graph['entities']
109 | ]
110 | _print(tabulate.tabulate(entities_data, headers=['Entity', 'Quantity', 'Attributes'], tablefmt=_tabulate_format))
111 | 
112 | _print('Relations:')
113 | relations_data = [
114 | [
115 | graph['entities'][rel['subject']]['head'].lower(),
116 | rel['relation'].lower(),
117 | graph['entities'][rel['object']]['head'].lower()
118 | ]
119 | for rel in graph['relations']
120 | ]
121 | _print(tabulate.tabulate(relations_data, headers=['Subject', 'Relation', 'Object'], tablefmt=_tabulate_format))
122 | 
123 | 
124 | _tabulate_format = tabulate.TableFormat(
125 | lineabove=tabulate.Line("+", "-", "+", "+"),
126 | linebelowheader=tabulate.Line("|", "-", "+", "|"),
127 | linebetweenrows=None,
128 | linebelow=tabulate.Line("+", "-", "+", "+"),
129 | headerrow=tabulate.DataRow("|", "|", "|"),
130 | datarow=tabulate.DataRow("|", "|", "|"),
131 | padding=1, with_header_hide=None
132 | )
133 | 
134 | 
135 | 
136 | 
--------------------------------------------------------------------------------
/tests/test_eval_spice.py:
--------------------------------------------------------------------------------
1 | import logging
2 | 
3 | from factual_scene_graph.evaluation.spice_evaluation import calculate_spice_score, eval_spice
4 | logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
5 | 
6 | if __name__ == "__main__":
7 | cand = [
8 | ({"jean"}, {"be"}, {"blue"}),
9 | [{"jean"}],
10 | [{"blue"}],
11 | ({"man"}, {"wear"}, {"jeans"}),
12 | [{"man"}],
13 | [{"jeans"}]
14 | ]
15 | 
16 | refs = [
17 | ({"man"}, {"wear"}, {"jean"}),
18 | [{"man"}],
19 | [{"jean"}],
20 | ({"jean"}, {"be"}, {"blue"}),
21 | [{"blue"}]
22 | ]
23 | print(calculate_spice_score(cand, refs, synonym_match=True))
24 | 
25 | cand_graph = '( jean , is , blue ) , ( jean , is , blue )'
26 | 
27 | ref_graphs = ['( jean , is , blue ) , ( jeans , is , blue )', '( jean , is , blue ) , ( sea , is , blue ) , ( ocean , is , blue )']
28 | 
29 | print(eval_spice(cand_graph, ref_graphs, merge_tuples_synonyms=True, synonym_match=True))
30 | 
31 | 
--------------------------------------------------------------------------------
/tests/test_evaluator.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import torch
3 | 
4 | from factual_scene_graph.evaluation.evaluator import Evaluator
5 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser
6 | 
7 | device = "cuda" if torch.cuda.is_available() else "cpu"
8 | def test_scene_graph_parsing():
9 | 
10 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device)
11 | evaluator = Evaluator(parser=parser, device=device)
12 | scores = evaluator.evaluate(["2 beautiful pigs are flying on the sky with 2 bags on their backs"],[['( pigs , is , beautiful ) , ( bags , on back of , pigs ) , ( bags , is , 2 ) , ( pigs , is , 2 ) , ( pigs , fly on , sky )']],method='spice', beam_size=1, max_output_len=128)
13 | print(scores)
14 | 
15 | def test_scene_graph_parsing_on_random():
16 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device,lemmatize=False, lowercase=True)
17 | evaluator = Evaluator(parser=parser,text_encoder_checkpoint='all-MiniLM-L6-v2',
device=device,lemmatize=True) 18 | 19 | random_data_pd = pd.read_csv('data/factual_sg/random/test.csv') 20 | random_data_captions = random_data_pd['caption'].tolist() 21 | random_data_graphs = [[scene] for scene in random_data_pd['scene_graph'].tolist()] 22 | spice_scores, cand_graphs, ref_graphs = evaluator.evaluate(random_data_captions, random_data_graphs, method='spice', beam_size=1, batch_size=128, max_input_len=256, max_output_len=256, return_graphs=True) 23 | 24 | print('SPICE scores for random test set:') 25 | print(sum(spice_scores)/len(spice_scores)) 26 | 27 | set_match_scores = evaluator.evaluate(cand_graphs, ref_graphs,method='set_match', beam_size=1) 28 | 29 | print('Set Match scores for random test set:') 30 | print(sum(set_match_scores)/len(set_match_scores)) 31 | 32 | soft_spice_scores = evaluator.evaluate(cand_graphs, ref_graphs,method='soft_spice', beam_size=1) 33 | 34 | print('Soft-SPICE scores for random test set:') 35 | print(sum(soft_spice_scores)/len(soft_spice_scores)) 36 | 37 | 38 | 39 | if __name__ == "__main__": 40 | #test_scene_graph_parsing() 41 | test_scene_graph_parsing_on_random() -------------------------------------------------------------------------------- /tests/test_metric_human_correlation.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | import sys 5 | 6 | import numpy as np 7 | import scipy 8 | import torch 9 | from factual_scene_graph.evaluation.evaluator import Evaluator 10 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 11 | 12 | 13 | def collect_unique_captions(candidates, refs): 14 | """Collect unique captions from candidates and references.""" 15 | caption_set = set(candidates) # Add all candidates to the set 16 | for ref_list in refs: 17 | caption_set.update(ref_list) # Add all elements from each list in refs 18 | return list(caption_set) 19 | 20 | def parse_captions(captions, parser): 21 | """Parse captions and return a dictionary of results.""" 22 | parse_results = parser.parse(captions, batch_size=64, max_input_len=256, 23 | max_output_len=256, beam_size=1, return_text=True) 24 | return {caption: result 25 | for caption, result in zip(captions, parse_results)} 26 | 27 | def evaluate_graphs(candidates, refs, parse_dict, evaluator, return_graphs): 28 | """Evaluate the graphs and return the results.""" 29 | cand_graphs = [parse_dict[cand] for cand in candidates] 30 | ref_graphs = [[parse_dict[ref_i] for ref_i in ref] for ref in refs] 31 | return evaluator.evaluate(cand_graphs, ref_graphs, method='spice', 32 | beam_size=1, batch_size=64, max_input_len=256, 33 | max_output_len=256, return_graphs=return_graphs) 34 | 35 | def compute_correlation(input_json, tauvariant='c'): 36 | data = {} 37 | with open(input_json) as f: 38 | data.update(json.load(f)) 39 | print('Loaded {} images'.format(len(data))) 40 | 41 | refs = [] 42 | candidates = [] 43 | human_scores = [] 44 | for k, v in list(data.items()): 45 | for human_judgement in v['human_judgement']: 46 | if np.isnan(human_judgement['rating']): 47 | print('NaN') 48 | continue 49 | 50 | candidate = ' '.join(human_judgement['caption'].split()) 51 | candidates.append(candidate) 52 | 53 | ref = [' '.join(gt.split()) for gt in v['ground_truth']] 54 | refs.append(ref) 55 | human_scores.append(human_judgement['rating']) 56 | print('Loaded {} references and {} candidates'.format(len(refs), len(candidates))) 57 | assert len(candidates) == len(refs) 58 | 59 | device = "cuda" if 
torch.cuda.is_available() else "cpu" 60 | 61 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg', device=device, lemmatize=False, lowercase=True) 62 | evaluator = Evaluator(parser=parser, text_encoder_checkpoint='all-MiniLM-L6-v2', device=device, lemmatize=True) 63 | 64 | 65 | caption_list = collect_unique_captions(candidates, refs) 66 | parse_dict = parse_captions(caption_list, parser) 67 | 68 | # Evaluate with return_graphs=True 69 | spice_scores, cand_graphs, ref_graphs = evaluate_graphs(candidates, refs, parse_dict, evaluator, True) 70 | 71 | 72 | assert len(spice_scores) == len(human_scores) 73 | print('SPICE score: ', sum(spice_scores) / len(spice_scores)) 74 | print('{} Tau-{}: {:.3f}'.format(tauvariant, tauvariant, 100*scipy.stats.kendalltau(spice_scores, human_scores, variant=tauvariant)[0])) 75 | 76 | 77 | soft_spice_scores = evaluator.evaluate(cand_graphs, ref_graphs, method='soft_spice', batch_size=128) 78 | assert len(soft_spice_scores) == len(human_scores) 79 | print('Soft-SPICE score: ', sum(soft_spice_scores) / len(soft_spice_scores)) 80 | print('{} Tau-{}: {:.3f}'.format(tauvariant, tauvariant, 100*scipy.stats.kendalltau(soft_spice_scores, human_scores, variant=tauvariant)[0])) 81 | 82 | 83 | 84 | if __name__ == '__main__': 85 | compute_correlation('tests/test_data/flickr8k.json', tauvariant='c') -------------------------------------------------------------------------------- /tests/test_parser.py: -------------------------------------------------------------------------------- 1 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 2 | from factual_scene_graph.utils import space_out_symbols_in_graph, tprint, clean_graph_string 3 | 4 | 5 | def test_clean_graph_string(): 6 | # Test normal case 7 | assert clean_graph_string("( bench , is , woo") == "( bench )" 8 | # Test with extra spaces 9 | assert clean_graph_string("( bench , is , wooden ) ,") == "( bench , is , wooden )" 10 | # Test with multiple relations 11 | assert clean_graph_string("( bench , is , wooden ) , ( bench , v:faces , ") == "( bench , is , wooden )" 12 | # Test empty string 13 | assert clean_graph_string("") == "" 14 | print("All tests for clean_graph_string passed!") 15 | 16 | def test_space_out_symbols_in_graph(): 17 | # Test normal case 18 | assert space_out_symbols_in_graph("(bench,is,wooden)") == "( bench , is , wooden )" 19 | # Test with extra spaces 20 | assert space_out_symbols_in_graph(" (bench ,is,wooden) ") == "( bench , is , wooden )" 21 | # Test with multiple relations 22 | assert space_out_symbols_in_graph("(bench,is,wooden),(bench,v:faces,sea)") == "( bench , is , wooden ) , ( bench , v:faces , sea )" 23 | # Test empty string 24 | assert space_out_symbols_in_graph("") == "" 25 | print("All tests for space_out_symbols_in_graph passed!") 26 | 27 | def test_scene_graph_parser(): 28 | # Assuming SceneGraphParser is correctly instantiated as `parser` 29 | # Test normal input 30 | parser = SceneGraphParser('lizhuang144/flan-t5-base-VG-factual-sg-id', device='cuda:0') 31 | text_graph = parser.parse(["2 beautiful pigs are flying on the sky with 2 bags on their backs"],return_text=True) 32 | 33 | print(text_graph[0]) 34 | text_graph = parser.parse(["2 beautiful and strong pigs are flying on the sky with 2 bags on their backs", 35 | "a blue sky"], max_output_len=128, return_text=True) 36 | print(text_graph[0]) 37 | 38 | graph_obj = parser.parse(["2 beautiful and strong pigs are flying on the sky with 2 bags on their backs", 39 | "a blue sky"], 
max_output_len=128, return_text=False) 40 | print(graph_obj[0]) 41 | 42 | tprint(graph_obj[0]) 43 | 44 | graph_obj = parser.parse(["boy"], max_output_len=16, return_text=False) 45 | print(graph_obj[0]) 46 | 47 | tprint(graph_obj[0]) 48 | 49 | text_graph = parser.parse(["a logo is written on another logo"],return_text=True,filter_factual_chars=True) 50 | 51 | print(text_graph[0]) 52 | 53 | 54 | if __name__ == "__main__": 55 | test_space_out_symbols_in_graph() 56 | test_clean_graph_string() 57 | test_scene_graph_parser() -------------------------------------------------------------------------------- /tests/test_spice_parser.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import torch 3 | 4 | from factual_scene_graph.evaluation.evaluator import Evaluator 5 | from factual_scene_graph.parser.scene_graph_parser import SceneGraphParser 6 | 7 | device = "cuda" if torch.cuda.is_available() else "cpu" 8 | 9 | def test_scene_graph_parsing_on_random(): 10 | evaluator = Evaluator(parser=None,text_encoder_checkpoint='all-MiniLM-L6-v2', device=device,lemmatize=True) 11 | 12 | random_data_pd = pd.read_csv('data/factual_sg_id/random/test.csv') 13 | random_cand_captions = random_data_pd['caption'].tolist() 14 | 15 | 16 | caption_scene_dict = {} 17 | 18 | for line in open(file='tests/test_data/SPICE_parsing_outputs.txt', mode='r').readlines(): 19 | caption_scene_dict[line.split('\t')[0].strip()] = line.split('\t')[1].strip() 20 | # print(caption_scene_dict) 21 | random_cand_graphs = [] 22 | for caption in random_cand_captions: 23 | random_cand_graphs.append(caption_scene_dict[caption]) 24 | 25 | random_data_graphs = [[scene] for scene in random_data_pd['scene_graph'].tolist()] 26 | spice_scores, cand_graphs, ref_graphs = evaluator.evaluate(random_cand_graphs, random_data_graphs, method='spice', beam_size=1, batch_size=128, max_input_len=256, max_output_len=256, return_graphs=True) 27 | 28 | print('SPICE scores of SPICE Parser for random test set:') 29 | print(sum(spice_scores)/len(spice_scores)) 30 | 31 | set_match_scores = evaluator.evaluate(cand_graphs, ref_graphs,method='set_match', beam_size=1) 32 | 33 | print('Set Match scores of SPICE Parser for random test set:') 34 | print(sum(set_match_scores)/len(set_match_scores)) 35 | 36 | soft_spice_scores = evaluator.evaluate(cand_graphs, ref_graphs,method='soft_spice', beam_size=1) 37 | 38 | print('Soft-SPICE scores of SPICE Parser for random test set:') 39 | print(sum(soft_spice_scores)/len(soft_spice_scores)) 40 | 41 | 42 | 43 | if __name__ == "__main__": 44 | #test_scene_graph_parsing() 45 | test_scene_graph_parsing_on_random() -------------------------------------------------------------------------------- /tests/test_synonym_dictionary.py: -------------------------------------------------------------------------------- 1 | from importlib import resources 2 | 3 | from factual_scene_graph.evaluation.synonym_dictionary import SynonymDictionary 4 | 5 | 6 | with resources.open_text('factual_scene_graph.evaluation.resources', 'english.exceptions') as f_exc, \ 7 | resources.open_text('factual_scene_graph.evaluation.resources', 'english.synsets') as f_syn: 8 | synonym_dictionary = SynonymDictionary(f_exc.name, f_syn.name) 9 | print(synonym_dictionary.get_stem_synsets('write_down')) -------------------------------------------------------------------------------- /tests/train_parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import 
wandb 3 | from datasets import load_dataset 4 | from transformers import TrainerCallback, AutoTokenizer, AutoModelForSeq2SeqLM 5 | from transformers import ( 6 | DataCollatorForSeq2Seq, 7 | Seq2SeqTrainingArguments, 8 | Seq2SeqTrainer, 9 | ) 10 | import torch 11 | 12 | class LogLearningRateCallback(TrainerCallback): 13 | """ 14 | Callback that logs the learning rate at the end of each step. 15 | """ 16 | def on_step_end(self, args, state, control, **kwargs): 17 | optimizer = kwargs.get('optimizer') 18 | if optimizer: 19 | current_lr = optimizer.param_groups[0]['lr'] 20 | wandb.log({"learning_rate": current_lr}, step=state.global_step) 21 | 22 | def parse_args(): 23 | """ 24 | Parse command-line arguments. 25 | """ 26 | parser = argparse.ArgumentParser(description="Train a scene graph generation model") 27 | parser.add_argument( 28 | "--dataset", 29 | type=str, 30 | default="lizhuang144/FACTUAL_Scene_Graph", 31 | help="Dataset identifier from Hugging Face hub", 32 | ) 33 | parser.add_argument( 34 | "--checkpoint", 35 | type=str, 36 | default="google/flan-t5-base", 37 | help="Model checkpoint to use", 38 | ) 39 | parser.add_argument( 40 | "--output_dir", 41 | type=str, 42 | default="saved_models/factual_sg/", 43 | help="Directory to save model checkpoints", 44 | ) 45 | parser.add_argument("--num_epochs", type=int, default=10, help="Number of training epochs") 46 | parser.add_argument("--batch_size", type=int, default=64, help="Training batch size per device") 47 | parser.add_argument("--learning_rate", type=float, default=2e-5, help="Learning rate") 48 | parser.add_argument("--eval_steps", type=int, default=200, help="Number of steps between evaluations") 49 | parser.add_argument("--generation_max_length", type=int, default=512, help="Maximum length of generated sequences") 50 | parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Number of gradient accumulation steps") 51 | parser.add_argument("--seed", type=int, default=11, help="Random seed") 52 | return parser.parse_args() 53 | 54 | def prepare_examples(example): 55 | """ 56 | Prepares each example by creating an 'input_text' with a task prompt and setting the 'target_text'. 57 | """ 58 | example["input_text"] = "Generate Scene Graph: " + example["caption"] 59 | example["target_text"] = example["scene_graph"] 60 | return example 61 | 62 | def preprocess_function(examples, tokenizer): 63 | """ 64 | Tokenizes the input and target texts without fixed padding. 65 | Dynamic padding will be applied at the batch level by the data collator. 
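Pad tokens in the labels are replaced with -100 so that they are ignored by the loss.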
66 | """ 67 | inputs = examples["input_text"] 68 | targets = examples["target_text"] 69 | 70 | # Tokenize without fixed padding 71 | model_inputs = tokenizer( 72 | inputs, max_length=512, truncation=True, padding=False 73 | ) 74 | with tokenizer.as_target_tokenizer(): 75 | labels = tokenizer( 76 | targets, max_length=512, truncation=True, padding=False 77 | ) 78 | # Replace all pad tokens in the labels with -100 so they are ignored in the loss 79 | labels["input_ids"] = [ 80 | [(l if l != tokenizer.pad_token_id else -100) for l in label] 81 | for label in labels["input_ids"] 82 | ] 83 | model_inputs["labels"] = labels["input_ids"] 84 | 85 | # Optional: print lengths for debugging (comment out in production) 86 | print(len(model_inputs["input_ids"][0]), len(labels["input_ids"][0])) 87 | return model_inputs 88 | 89 | def main(): 90 | args = parse_args() 91 | 92 | # Initialize Weights & Biases 93 | project_name = "FACTUAL_Scene_Graph" 94 | run_name = ( 95 | f"dataset={args.dataset}-checkpoint={args.checkpoint}-num_epochs={args.num_epochs}-" 96 | f"batch_size={args.batch_size}-learning_rate={args.learning_rate}-seed={args.seed}-eval_steps={args.eval_steps}" 97 | ) 98 | wandb.init( 99 | project=project_name, 100 | name=run_name, 101 | config={ 102 | "dataset": args.dataset, 103 | "checkpoint": args.checkpoint, 104 | "num_epochs": args.num_epochs, 105 | "batch_size": args.batch_size, 106 | "learning_rate": args.learning_rate, 107 | "seed": args.seed, 108 | "eval_steps": args.eval_steps, 109 | }, 110 | ) 111 | 112 | # Use CUDA if available, else CPU 113 | device = "cuda" if torch.cuda.is_available() else "cpu" 114 | tokenizer = AutoTokenizer.from_pretrained(args.checkpoint) 115 | model = AutoModelForSeq2SeqLM.from_pretrained( 116 | args.checkpoint, trust_remote_code=True 117 | ).to(device) 118 | 119 | # Load and prepare the dataset 120 | dataset = load_dataset(args.dataset) 121 | dataset = dataset.map(prepare_examples) 122 | tokenized_dataset = dataset.map( 123 | lambda x: preprocess_function(x, tokenizer), batched=True 124 | ) 125 | # Create a 90/10 train-test split from the available 'train' split 126 | tokenized_dataset = tokenized_dataset["train"].train_test_split(test_size=0.1, seed=args.seed) 127 | 128 | # Use a data collator that pads to the longest sequence in the batch 129 | data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model, padding="longest") 130 | 131 | training_args = Seq2SeqTrainingArguments( 132 | output_dir=args.output_dir, 133 | eval_steps=args.eval_steps, 134 | eval_strategy="steps", 135 | save_strategy="steps", 136 | save_steps=args.eval_steps, 137 | generation_max_length=args.generation_max_length, 138 | num_train_epochs=args.num_epochs, 139 | predict_with_generate=True, 140 | seed=args.seed, 141 | overwrite_output_dir=True, 142 | save_total_limit=1, 143 | report_to="wandb", 144 | per_device_train_batch_size=args.batch_size, 145 | gradient_accumulation_steps=args.gradient_accumulation_steps, 146 | load_best_model_at_end=True, 147 | learning_rate=args.learning_rate, 148 | lr_scheduler_type="cosine", 149 | warmup_steps=500, 150 | metric_for_best_model="eval_loss", 151 | greater_is_better=False, 152 | ) 153 | 154 | trainer = Seq2SeqTrainer( 155 | model=model, 156 | args=training_args, 157 | train_dataset=tokenized_dataset["train"], 158 | eval_dataset=tokenized_dataset["test"], 159 | tokenizer=tokenizer, 160 | data_collator=data_collator, 161 | callbacks=[LogLearningRateCallback()], 162 | ) 163 | 164 | trainer.train() 165 | wandb.finish() 166 | 167 | if 
__name__ == "__main__": 168 | main() 169 | --------------------------------------------------------------------------------