├── .flake8 ├── .gitignore ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── GBC_data_construction.png ├── GBC_illustration.png ├── GBC_sdxl.png └── GBC_viewer.png ├── configs ├── captioning │ ├── default.yaml │ └── llava_yoloworld.yaml ├── generation │ ├── gbc2i │ │ ├── sampling_base.yaml │ │ ├── sampling_gbc_encode_with_context.yaml │ │ ├── sampling_gbc_encode_with_context_ipa.yaml │ │ ├── sampling_gbc_encode_without_context.yaml │ │ ├── sampling_region_base.yaml │ │ ├── sampling_region_base_ipa.yaml │ │ ├── sampling_region_gbc_encode_with_context.yaml │ │ ├── sampling_region_gbc_encode_with_context_ipa.yaml │ │ ├── sampling_region_gbc_encode_without_context.yaml │ │ ├── sampling_region_gbc_encode_without_context_ipa.yaml │ │ └── secondary │ │ │ ├── sampling_layer.yaml │ │ │ ├── sampling_layer_region.yaml │ │ │ └── sampling_layer_region_ipa.yaml │ ├── graph_transform_ex.yaml │ └── t2gbc_default.yaml └── processing │ ├── clip_filtering.yaml │ ├── compute_all_scores.yaml │ ├── compute_clip_scores.yaml │ ├── compute_toxicity_scores.yaml │ ├── relation_composition_filtering.yaml │ └── to_structured_text.yaml ├── data ├── gbc │ ├── prompt_gen │ │ ├── library_turtle_frog_steamponk.json │ │ └── library_turtle_frog_steamponk.parquet │ ├── wiki │ │ ├── wiki_gbc_graphs.jsonl │ │ └── wiki_gbc_graphs.parquet │ ├── wiki_pixtral_gdino │ │ ├── wiki_gbc_graphs.jsonl │ │ └── wiki_gbc_graphs.parquet │ └── wiki_with_clip_scores │ │ ├── wiki_gbc_graphs_with_clip.jsonl │ │ └── wiki_gbc_graphs_with_clip.parquet └── images │ ├── t2gbc2i │ ├── frog.png │ ├── library.png │ ├── mechanical_cat.png │ └── turtle.png │ └── wiki │ ├── Cesenatico.jpg │ ├── Eiffel_tower_0.jpg │ ├── Eiffel_tower_1.jpg │ ├── Eiffel_tower_2.jpg │ ├── Household_utensils.jpg │ ├── Ilish_Bhaat.jpg │ ├── Jean-Francois_Portaels.jpg │ ├── Kinkakuji.jpg │ ├── Kompoziciya.jpg │ ├── Launch_of_Shenzhou.jpg │ ├── Magdeburg.jpg │ ├── Maid_of_the_Mist.jpg │ ├── Mumbai_Flora_Fountain.jpg │ ├── Painting.jpg │ ├── Petani_padi.jpg │ ├── README.md │ ├── Regalia.jpg │ ├── Tartu_raudteejaama_veetorn.jpg │ ├── Vat_Pa_Phai_temple.jpg │ ├── Wild_horses.jpg │ ├── corgi.jpg │ └── orange_cat.jpg ├── prompts ├── captioning │ ├── query_composition.txt │ ├── query_entity.txt │ ├── query_image.txt │ ├── query_relation.txt │ ├── system_composition.txt │ ├── system_entity.txt │ ├── system_image.txt │ └── system_relation.txt └── t2i │ ├── banana_apple.yaml │ ├── banana_apple_graph_only.yaml │ ├── dog_cat_ref_image.yaml │ ├── dog_cat_ref_image_graph_only.yaml │ ├── living_room.yaml │ ├── living_room_graph_only.yaml │ ├── neg_default.yaml │ ├── person_graph_only.yaml │ ├── t2gbc_seed.txt │ ├── t2gbc_seed.yaml │ └── t2gbc_seed_with_entity_specification.yaml ├── pyproject.toml ├── scripts ├── captioning │ └── run_gbc_captioning.py ├── generation │ ├── gbc2i.py │ ├── t2gbc.py │ └── t2gbc2i.py ├── processing │ └── process_gbc.py └── setup │ ├── download_llava_models.py │ ├── download_yolo_world_models.py │ ├── setup_llava_query.sh │ └── setup_yolo_world_detection.sh ├── src └── gbc │ ├── __init__.py │ ├── captioning │ ├── __init__.py │ ├── auto_actions.py │ ├── conversion │ │ ├── __init__.py │ │ ├── parse_node_infos.py │ │ └── to_gbc_graph.py │ ├── detection │ │ ├── __init__.py │ │ ├── detection.py │ │ ├── detection_action.py │ │ ├── grounding_dino.py │ │ ├── hack_mmengine_registry.py │ │ └── yolo_world.py │ ├── mllm │ │ ├── __init__.py │ │ ├── composition_mst.py │ │ ├── llava │ │ │ ├── __init__.py │ │ │ ├── llava_base.py │ │ │ ├── llava_chat_handler.py │ │ │ └── llava_queries.py │ │ ├── parse_mllm_output.py │ │ ├── pixtral │ │ │ ├── __init__.py │ │ │ ├── pixtral_base.py │ │ │ └── pixtral_queries.py │ │ ├── query_base.py │ │ └── query_prototype.py │ ├── pipeline │ │ ├── __init__.py │ │ ├── get_relational_queries.py │ │ ├── pipeline.py │ │ └── pipeline_functional.py │ └── primitives │ │ ├── __init__.py │ │ ├── action.py │ │ ├── action_io.py │ │ └── io_unit.py │ ├── data │ ├── __init__.py │ ├── bbox │ │ ├── __init__.py │ │ ├── annotate.py │ │ └── bbox.py │ ├── caption │ │ ├── __init__.py │ │ └── caption.py │ └── graph │ │ ├── __init__.py │ │ ├── gbc_graph.py │ │ └── gbc_graph_full.py │ ├── processing │ ├── __init__.py │ ├── data_transforms │ │ ├── __init__.py │ │ ├── basic_transforms.py │ │ ├── clip_filtering.py │ │ ├── clip_scoring.py │ │ ├── function_transforms.py │ │ └── toxicity_scoring.py │ ├── local_process.py │ └── meta_process.py │ ├── t2i │ ├── __init__.py │ ├── gbc2i │ │ ├── cfg.py │ │ ├── gbc_sampling.py │ │ ├── get_sigmas.py │ │ ├── k_diffusion │ │ │ ├── NOTICE.md │ │ │ ├── __init__.py │ │ │ ├── euler.py │ │ │ └── wrapper.py │ │ ├── mask_from_xattn_score.py │ │ └── sampling.py │ ├── modules │ │ ├── attention.py │ │ ├── attn_masks.py │ │ ├── graph_attn.py │ │ ├── text_encoders.py │ │ └── unet_patch.py │ ├── prompt.py │ ├── t2gbc │ │ ├── gbc_prompt_gen.py │ │ ├── generation.py │ │ ├── graph_parse.py │ │ └── sampling_constraint.py │ └── utils │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── loader.py │ │ └── segmentation.py │ ├── texts │ ├── __init__.py │ ├── basics.py │ ├── classifiers.py │ └── text_helpers.py │ └── utils.py ├── tests ├── test_captioning_unit │ ├── test_auto_image_entity_actions.py │ ├── test_captioning_pipeline.py │ ├── test_captioning_pipeline_functional.py │ ├── test_detection_action.py │ ├── test_detection_grounding_dino.py │ ├── test_detection_yolo_world.py │ ├── test_image_query.py │ ├── test_io_unit.py │ └── test_pixtral_query.py ├── test_data_unit │ ├── test_basic_filter.py │ ├── test_clip_filter.py │ ├── test_gbc_to_text_and_image.py │ └── test_graph_basics.py └── test_integral │ ├── test_gbc2i_sampling.sh │ ├── test_gbc_captioning.sh │ ├── test_gbc_captioning_batch_single_image.sh │ ├── test_gbc_captioning_single_image.sh │ ├── test_gbc_captioning_single_image_llava_yoloworld.sh │ └── test_gbc_processing.sh └── viewer ├── .gitignore ├── README.md ├── config.js ├── data └── .gitkeep ├── index.html ├── main.js ├── package-lock.json ├── package.json ├── server ├── api.py ├── headers.py └── requirements.txt ├── style.css ├── utils.js ├── vis-settings.js └── vite.config.js /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 88 3 | ignore = E401, E402, E731, E203, W503, W504 4 | 5 | per-file-ignores = 6 | setup/download_models.py: E501 7 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | docs/build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # pipenv 89 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 90 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 91 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 92 | # install all needed dependencies. 93 | #Pipfile.lock 94 | 95 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 96 | __pypackages__/ 97 | 98 | # Celery stuff 99 | celerybeat-schedule 100 | celerybeat.pid 101 | 102 | # SageMath parsed files 103 | *.sage.py 104 | 105 | # Environments 106 | .env 107 | .venv 108 | env/ 109 | venv/ 110 | ENV/ 111 | env.bak/ 112 | venv.bak/ 113 | 114 | # Spyder project settings 115 | .spyderproject 116 | .spyproject 117 | 118 | # Rope project settings 119 | .ropeproject 120 | 121 | # mypy 122 | .mypy_cache/ 123 | .dmypy.json 124 | dmypy.json 125 | 126 | # Pyre type checker 127 | .pyre/ 128 | 129 | .DS_Store 130 | **/.DS_Store 131 | 132 | tests/outputs/ 133 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to making participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the open source team at [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com). All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant](https://www.contributor-covenant.org), version 1.4, 71 | available at [https://www.contributor-covenant.org/version/1/4/code-of-conduct.html](https://www.contributor-covenant.org/version/1/4/code-of-conduct.html) -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contribution Guide 2 | 3 | Thanks for your interest in contributing. This project was released to accompany a research paper for purposes of reproducibility, and beyond its publication there are limited plans for future development of the repository. 4 | 5 | While we welcome new pull requests and issues please note that our response may be limited. Forks and out-of-tree improvements are strongly encouraged. 6 | 7 | ## Before you get started 8 | 9 | By submitting a pull request, you represent that you have the right to license your contribution to Apple and the community, and agree by submitting the patch that your contributions are licensed under the [LICENSE](LICENSE). 10 | 11 | We ask that all community members read and observe our [Code of Conduct](CODE_OF_CONDUCT.md). 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (C) 2024 Apple Inc. All Rights Reserved. 2 | 3 | IMPORTANT: This Apple software is supplied to you by Apple 4 | Inc. ("Apple") in consideration of your agreement to the following 5 | terms, and your use, installation, modification or redistribution of 6 | this Apple software constitutes acceptance of these terms. If you do 7 | not agree with these terms, please do not use, install, modify or 8 | redistribute this Apple software. 9 | 10 | In consideration of your agreement to abide by the following terms, and 11 | subject to these terms, Apple grants you a personal, non-exclusive 12 | license, under Apple's copyrights in this original Apple software (the 13 | "Apple Software"), to use, reproduce, modify and redistribute the Apple 14 | Software, with or without modifications, in source and/or binary forms; 15 | provided that if you redistribute the Apple Software in its entirety and 16 | without modifications, you must retain this notice and the following 17 | text and disclaimers in all such redistributions of the Apple Software. 18 | Neither the name, trademarks, service marks or logos of Apple Inc. may 19 | be used to endorse or promote products derived from the Apple Software 20 | without specific prior written permission from Apple. Except as 21 | expressly stated in this notice, no other rights or licenses, express or 22 | implied, are granted by Apple herein, including but not limited to any 23 | patent rights that may be infringed by your derivative works or by other 24 | works in which the Apple Software may be incorporated. 25 | 26 | The Apple Software is provided by Apple on an "AS IS" basis. APPLE 27 | MAKES NO WARRANTIES, EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION 28 | THE IMPLIED WARRANTIES OF NON-INFRINGEMENT, MERCHANTABILITY AND FITNESS 29 | FOR A PARTICULAR PURPOSE, REGARDING THE APPLE SOFTWARE OR ITS USE AND 30 | OPERATION ALONE OR IN COMBINATION WITH YOUR PRODUCTS. 31 | 32 | IN NO EVENT SHALL APPLE BE LIABLE FOR ANY SPECIAL, INDIRECT, INCIDENTAL 33 | OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF 34 | SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS 35 | INTERRUPTION) ARISING IN ANY WAY OUT OF THE USE, REPRODUCTION, 36 | MODIFICATION AND/OR DISTRIBUTION OF THE APPLE SOFTWARE, HOWEVER CAUSED 37 | AND WHETHER UNDER THEORY OF CONTRACT, TORT (INCLUDING NEGLIGENCE), 38 | STRICT LIABILITY OR OTHERWISE, EVEN IF APPLE HAS BEEN ADVISED OF THE 39 | POSSIBILITY OF SUCH DAMAGE. 40 | -------------------------------------------------------------------------------- /assets/GBC_data_construction.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/assets/GBC_data_construction.png -------------------------------------------------------------------------------- /assets/GBC_illustration.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/assets/GBC_illustration.png -------------------------------------------------------------------------------- /assets/GBC_sdxl.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/assets/GBC_sdxl.png -------------------------------------------------------------------------------- /assets/GBC_viewer.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/assets/GBC_viewer.png -------------------------------------------------------------------------------- /configs/captioning/default.yaml: -------------------------------------------------------------------------------- 1 | pipeline_config: 2 | 3 | artifact_format: .jsonl 4 | save_frequency: 100 5 | save_dir: tests/outputs/captioning/gbc_wiki_images 6 | save_callback: null 7 | save_images: False 8 | 9 | batch_query: False 10 | batch_size: 32 # ignored if batch_query is False 11 | 12 | include_entity_query: True 13 | include_composition_query: True 14 | include_relation_query: True 15 | 16 | mask_inside_threshold: 0.85 17 | 18 | shared: 19 | 20 | suitable_for_detection_func: 21 | _target_: gbc.texts.text_helpers.suitable_for_detection 22 | _partial_: true 23 | 24 | potential_same_object_func: 25 | _target_: gbc.texts.text_helpers.potential_same_object 26 | _partial_: true 27 | 28 | detection_model: 29 | _target_: gbc.captioning.detection.grounding_dino.GroundingDinoDetection 30 | model_name: IDEA-Research/grounding-dino-tiny 31 | 32 | pixtral: 33 | model_name: nm-testing/pixtral-12b-FP8-dynamic 34 | gpu_memory_utilization: 0.8 # can be set to 0.3 if you have 80GB GPU memory 35 | 36 | queries: 37 | 38 | image_query: 39 | _target_: gbc.captioning.mllm.pixtral.PixtralImageQuery 40 | model_kwargs: ${...shared.pixtral} 41 | system_file: prompts/captioning/system_image.txt 42 | query_file: prompts/captioning/query_image.txt 43 | suitable_for_detection_func: ${...shared.suitable_for_detection_func} 44 | 45 | entity_query: 46 | _target_: gbc.captioning.mllm.pixtral.PixtralEntityQuery 47 | model_kwargs: ${...shared.pixtral} 48 | system_file: prompts/captioning/system_entity.txt 49 | query_file: prompts/captioning/query_entity.txt 50 | suitable_for_detection_func: ${...shared.suitable_for_detection_func} 51 | potential_same_object_func: ${...shared.potential_same_object_func} 52 | query_kwargs: 53 | max_tokens: 512 54 | 55 | relation_query: 56 | _target_: gbc.captioning.mllm.pixtral.PixtralRelationQuery 57 | model_kwargs: ${...shared.pixtral} 58 | system_file: prompts/captioning/system_relation.txt 59 | query_file: prompts/captioning/query_relation.txt 60 | 61 | composition_query: 62 | _target_: gbc.captioning.mllm.pixtral.PixtralCompositionQuery 63 | model_kwargs: ${...shared.pixtral} 64 | system_file: prompts/captioning/system_composition.txt 65 | query_file: prompts/captioning/query_composition.txt 66 | 67 | detection_from_image: 68 | _target_: gbc.captioning.detection.detection_action.DetectionAction 69 | detection_model: ${...shared.detection_model} 70 | score_threshold: 0.35 # 0.05 71 | nms_single_threshold: 0.05 72 | nms_multiple_threshold: 0.2 73 | select_high_conf_tolerance: 0.05 74 | topk: 6 75 | min_abs_area: 5000 76 | max_rel_area: null 77 | 78 | detection_from_entity: 79 | _target_: gbc.captioning.detection.detection_action.DetectionAction 80 | detection_model: ${...shared.detection_model} 81 | score_threshold: 0.35 # 0.05 82 | nms_single_threshold: 0.05 83 | nms_multiple_threshold: 0.2 84 | select_high_conf_tolerance: 0.05 85 | topk: 6 86 | min_abs_area: 5000 87 | max_rel_area: 0.8 88 | -------------------------------------------------------------------------------- /configs/captioning/llava_yoloworld.yaml: -------------------------------------------------------------------------------- 1 | pipeline_config: 2 | 3 | artifact_format: .jsonl 4 | save_frequency: 100 5 | save_dir: tests/outputs/captioning/gbc_wiki_images 6 | save_callback: null 7 | save_images: False 8 | 9 | batch_query: False 10 | batch_size: 32 # ignored if batch_query is False 11 | 12 | include_entity_query: True 13 | include_composition_query: True 14 | include_relation_query: True 15 | 16 | mask_inside_threshold: 0.85 17 | 18 | shared: 19 | 20 | suitable_for_detection_func: 21 | _target_: gbc.texts.text_helpers.suitable_for_detection 22 | _partial_: true 23 | 24 | potential_same_object_func: 25 | _target_: gbc.texts.text_helpers.potential_same_object 26 | _partial_: true 27 | 28 | detection_model: 29 | _target_: gbc.captioning.detection.yolo_world.YoloWorldDetection 30 | work_dir: logs 31 | model_version: x_v2 32 | 33 | llava_7B: 34 | gpu_id: 0 35 | version: Mistral-7B 36 | 37 | llava_34B: 38 | gpu_id: 0 39 | version: Yi-34B 40 | 41 | queries: 42 | 43 | image_query: 44 | _target_: gbc.captioning.mllm.llava.LlavaImageQuery 45 | model_kwargs: ${...shared.llava_34B} 46 | system_file: prompts/captioning/system_image.txt 47 | query_file: prompts/captioning/query_image.txt 48 | suitable_for_detection_func: ${...shared.suitable_for_detection_func} 49 | 50 | entity_query: 51 | _target_: gbc.captioning.mllm.llava.LlavaEntityQuery 52 | model_kwargs: ${...shared.llava_7B} 53 | system_file: prompts/captioning/system_entity.txt 54 | query_file: prompts/captioning/query_entity.txt 55 | suitable_for_detection_func: ${...shared.suitable_for_detection_func} 56 | potential_same_object_func: ${...shared.potential_same_object_func} 57 | 58 | relation_query: 59 | _target_: gbc.captioning.mllm.llava.LlavaRelationQuery 60 | model_kwargs: ${...shared.llava_7B} 61 | system_file: prompts/captioning/system_relation.txt 62 | query_file: prompts/captioning/query_relation.txt 63 | 64 | composition_query: 65 | _target_: gbc.captioning.mllm.llava.LlavaCompositionQuery 66 | model_kwargs: ${...shared.llava_34B} 67 | system_file: prompts/captioning/system_composition.txt 68 | query_file: prompts/captioning/query_composition.txt 69 | 70 | detection_from_image: 71 | _target_: gbc.captioning.detection.detection_action.DetectionAction 72 | detection_model: ${...shared.detection_model} 73 | score_threshold: 0.35 # 0.05 74 | nms_single_threshold: 0.05 75 | nms_multiple_threshold: 0.2 76 | select_high_conf_tolerance: 0.05 77 | topk: 6 78 | min_abs_area: 5000 79 | max_rel_area: null 80 | 81 | detection_from_entity: 82 | _target_: gbc.captioning.detection.detection_action.DetectionAction 83 | detection_model: ${...shared.detection_model} 84 | score_threshold: 0.35 # 0.05 85 | nms_single_threshold: 0.05 86 | nms_multiple_threshold: 0.2 87 | select_high_conf_tolerance: 0.05 88 | topk: 6 89 | min_abs_area: 5000 90 | max_rel_area: 0.8 91 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_base.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-base/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.diffusion_sampling 6 | _partial_: true 7 | 8 | internal_sampling_func: 9 | _target_: gbc.t2i.sample_euler_ancestral 10 | _partial_: true 11 | eta: 0.0 12 | 13 | num_samples: 8 14 | padding_mode: cycling 15 | cfg_scale: 6 16 | seed: 1315 17 | num_steps: 24 18 | width: 1024 19 | height: 1024 20 | 21 | train_scheduler: 22 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 23 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 24 | subfolder: scheduler 25 | 26 | model_config: 27 | 28 | unet: 29 | _target_: diffusers.UNet2DConditionModel.from_pretrained 30 | _load_config_: 31 | device: cuda 32 | precision: torch.float16 33 | to_freeze: true 34 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 35 | subfolder: unet 36 | 37 | te: 38 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 39 | _load_config_: 40 | device: cuda 41 | precision: torch.float16 42 | to_freeze: true 43 | tokenizers: 44 | - openai/clip-vit-large-patch14 45 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 46 | text_model_and_configs: 47 | - 48 | - _target_: transformers.CLIPTextModel.from_pretrained 49 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 50 | subfolder: text_encoder 51 | - disable_autocast: false 52 | concat_bucket: 0 53 | use_pooled: false 54 | need_mask: false 55 | layer_idx: -2 56 | - 57 | - _target_: transformers.CLIPTextModel.from_pretrained 58 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 59 | subfolder: text_encoder_2 60 | - disable_autocast: false 61 | concat_bucket: 0 62 | use_pooled: true 63 | need_mask: false 64 | layer_idx: -2 65 | zero_for_padding: false 66 | 67 | vae: 68 | _target_: diffusers.AutoencoderKL.from_pretrained 69 | _load_config_: 70 | device: cuda 71 | precision: torch.float16 72 | to_freeze: true 73 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 74 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_gbc_encode_with_context.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-gbc-with-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: false 10 | use_region_attention: true 11 | use_layer_attention: false # note that layer attention does not work with get_mask_from_xattn_scores 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | get_mask_from_xattn_scores: true 21 | n_first_phase_steps: 12 22 | first_phase_start_compute_ema_step: 6 23 | region_mask_ema: 0.9 24 | 25 | num_samples: 8 26 | padding_mode: cycling 27 | cfg_scale: 6 28 | seed: 1315 29 | num_steps: 24 30 | width: 1024 31 | height: 1024 32 | 33 | train_scheduler: 34 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 35 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 36 | subfolder: scheduler 37 | 38 | model_config: 39 | 40 | unet: 41 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 42 | _load_config_: 43 | device: cuda 44 | precision: torch.float16 45 | to_freeze: true 46 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 47 | subfolder: unet 48 | use_region_attention: true 49 | use_caption_mask: true 50 | use_layer_attention: false 51 | use_flex_attention_for_region: false 52 | 53 | te: 54 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 55 | _load_config_: 56 | device: cuda 57 | precision: torch.float16 58 | to_freeze: true 59 | tokenizers: 60 | - openai/clip-vit-large-patch14 61 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 62 | text_model_and_configs: 63 | - 64 | - _target_: transformers.CLIPTextModel.from_pretrained 65 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 66 | subfolder: text_encoder 67 | - disable_autocast: false 68 | concat_bucket: 0 69 | use_pooled: false 70 | need_mask: false 71 | layer_idx: -2 72 | - 73 | - _target_: transformers.CLIPTextModel.from_pretrained 74 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 75 | subfolder: text_encoder_2 76 | - disable_autocast: false 77 | concat_bucket: 0 78 | use_pooled: true 79 | need_mask: false 80 | layer_idx: -2 81 | zero_for_padding: false 82 | 83 | vae: 84 | _target_: diffusers.AutoencoderKL.from_pretrained 85 | _load_config_: 86 | device: cuda 87 | precision: torch.float16 88 | to_freeze: true 89 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 90 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_gbc_encode_with_context_ipa.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-ipa-gbc-with-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: false 10 | use_region_attention: true 11 | use_layer_attention: false # note that layer attention does not work with get_mask_from_xattn_scores 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: true 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | get_mask_from_xattn_scores: true 21 | n_first_phase_steps: 12 22 | first_phase_start_compute_ema_step: 6 23 | region_mask_ema: 0.9 24 | 25 | num_samples: 8 26 | padding_mode: uniform_expansion 27 | cfg_scale: 6 28 | seed: 1315 29 | num_steps: 24 30 | width: 1024 31 | height: 1024 32 | 33 | train_scheduler: 34 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 35 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 36 | subfolder: scheduler 37 | 38 | model_config: 39 | 40 | unet: 41 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 42 | _load_config_: 43 | device: cuda 44 | precision: torch.float16 45 | to_freeze: true 46 | load_ip_adapter_kwargs: 47 | pretrained_model_name_or_path: h94/IP-Adapter 48 | subfolder: sdxl_models 49 | weight_name: ip-adapter_sdxl.bin 50 | ip_adapter_scale: 1.0 51 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 52 | subfolder: unet 53 | use_region_attention: true 54 | use_caption_mask: true 55 | use_layer_attention: false 56 | use_flex_attention_for_region: false 57 | 58 | te: 59 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 60 | _load_config_: 61 | device: cuda 62 | precision: torch.float16 63 | to_freeze: true 64 | tokenizers: 65 | - openai/clip-vit-large-patch14 66 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 67 | text_model_and_configs: 68 | - 69 | - _target_: transformers.CLIPTextModel.from_pretrained 70 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 71 | subfolder: text_encoder 72 | - disable_autocast: false 73 | concat_bucket: 0 74 | use_pooled: false 75 | need_mask: false 76 | layer_idx: -2 77 | - 78 | - _target_: transformers.CLIPTextModel.from_pretrained 79 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 80 | subfolder: text_encoder_2 81 | - disable_autocast: false 82 | concat_bucket: 0 83 | use_pooled: true 84 | need_mask: false 85 | layer_idx: -2 86 | zero_for_padding: false 87 | 88 | vae: 89 | _target_: diffusers.AutoencoderKL.from_pretrained 90 | _load_config_: 91 | device: cuda 92 | precision: torch.float16 93 | to_freeze: true 94 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 95 | 96 | image_encoder: 97 | _target_: transformers.CLIPVisionModelWithProjection.from_pretrained 98 | _load_config_: 99 | device: cuda 100 | precision: torch.float16 101 | to_freeze: true 102 | pretrained_model_name_or_path: h94/IP-Adapter 103 | subfolder: sdxl_models/image_encoder 104 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_gbc_encode_without_context.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-gbc-without-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: false 10 | use_region_attention: true 11 | use_layer_attention: false # note that layer attention does not work with get_mask_from_xattn_scores 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | get_mask_from_xattn_scores: true 21 | n_first_phase_steps: 12 22 | first_phase_start_compute_ema_step: 6 23 | region_mask_ema: 0.9 24 | 25 | num_samples: 8 26 | padding_mode: cycling 27 | cfg_scale: 6 28 | seed: 1315 29 | num_steps: 24 30 | width: 1024 31 | height: 1024 32 | 33 | train_scheduler: 34 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 35 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 36 | subfolder: scheduler 37 | 38 | model_config: 39 | 40 | unet: 41 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 42 | _load_config_: 43 | device: cuda 44 | precision: torch.float16 45 | to_freeze: true 46 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 47 | subfolder: unet 48 | use_region_attention: true 49 | use_caption_mask: true 50 | use_layer_attention: false 51 | use_flex_attention_for_region: false 52 | 53 | te: 54 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 55 | _load_config_: 56 | device: cuda 57 | precision: torch.float16 58 | to_freeze: true 59 | tokenizers: 60 | - openai/clip-vit-large-patch14 61 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 62 | text_model_and_configs: 63 | - 64 | - _target_: transformers.CLIPTextModel.from_pretrained 65 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 66 | subfolder: text_encoder 67 | - disable_autocast: false 68 | concat_bucket: 0 69 | use_pooled: false 70 | need_mask: false 71 | layer_idx: -2 72 | - 73 | - _target_: transformers.CLIPTextModel.from_pretrained 74 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 75 | subfolder: text_encoder_2 76 | - disable_autocast: false 77 | concat_bucket: 0 78 | use_pooled: true 79 | need_mask: false 80 | layer_idx: -2 81 | zero_for_padding: false 82 | 83 | vae: 84 | _target_: diffusers.AutoencoderKL.from_pretrained 85 | _load_config_: 86 | device: cuda 87 | precision: torch.float16 88 | to_freeze: true 89 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 90 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_base.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-region-base/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: false 13 | labels_in_neg: false 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: cycling 23 | cfg_scale: 6 24 | seed: 1315 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 43 | subfolder: unet 44 | use_region_attention: true 45 | use_caption_mask: false 46 | use_layer_attention: false 47 | use_flex_attention_for_region: false 48 | 49 | te: 50 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 51 | _load_config_: 52 | device: cuda 53 | precision: torch.float16 54 | to_freeze: true 55 | tokenizers: 56 | - openai/clip-vit-large-patch14 57 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 58 | text_model_and_configs: 59 | - 60 | - _target_: transformers.CLIPTextModel.from_pretrained 61 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 62 | subfolder: text_encoder 63 | - disable_autocast: false 64 | concat_bucket: 0 65 | use_pooled: false 66 | need_mask: false 67 | layer_idx: -2 68 | - 69 | - _target_: transformers.CLIPTextModel.from_pretrained 70 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 71 | subfolder: text_encoder_2 72 | - disable_autocast: false 73 | concat_bucket: 0 74 | use_pooled: true 75 | need_mask: false 76 | layer_idx: -2 77 | zero_for_padding: false 78 | 79 | vae: 80 | _target_: diffusers.AutoencoderKL.from_pretrained 81 | _load_config_: 82 | device: cuda 83 | precision: torch.float16 84 | to_freeze: true 85 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 86 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_base_ipa.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-ipa-region-base/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: false 13 | labels_in_neg: false 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: cycling 23 | cfg_scale: 6 24 | seed: 1315 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | load_ip_adapter_kwargs: 43 | pretrained_model_name_or_path: h94/IP-Adapter 44 | subfolder: sdxl_models 45 | weight_name: ip-adapter_sdxl.bin 46 | ip_adapter_scale: 1.0 47 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 48 | subfolder: unet 49 | use_region_attention: true 50 | use_caption_mask: false 51 | use_layer_attention: false 52 | use_flex_attention_for_region: false 53 | 54 | te: 55 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 56 | _load_config_: 57 | device: cuda 58 | precision: torch.float16 59 | to_freeze: true 60 | tokenizers: 61 | - openai/clip-vit-large-patch14 62 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 63 | text_model_and_configs: 64 | - 65 | - _target_: transformers.CLIPTextModel.from_pretrained 66 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 67 | subfolder: text_encoder 68 | - disable_autocast: false 69 | concat_bucket: 0 70 | use_pooled: false 71 | need_mask: false 72 | layer_idx: -2 73 | - 74 | - _target_: transformers.CLIPTextModel.from_pretrained 75 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 76 | subfolder: text_encoder_2 77 | - disable_autocast: false 78 | concat_bucket: 0 79 | use_pooled: true 80 | need_mask: false 81 | layer_idx: -2 82 | zero_for_padding: false 83 | 84 | vae: 85 | _target_: diffusers.AutoencoderKL.from_pretrained 86 | _load_config_: 87 | device: cuda 88 | precision: torch.float16 89 | to_freeze: true 90 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 91 | 92 | image_encoder: 93 | _target_: transformers.CLIPVisionModelWithProjection.from_pretrained 94 | _load_config_: 95 | device: cuda 96 | precision: torch.float16 97 | to_freeze: true 98 | pretrained_model_name_or_path: h94/IP-Adapter 99 | subfolder: sdxl_models/image_encoder 100 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_gbc_encode_with_context.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-region-gbc-with-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: true 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: uniform_expansion 23 | cfg_scale: 6 24 | seed: 1215 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 43 | subfolder: unet 44 | use_region_attention: true 45 | use_caption_mask: true 46 | use_layer_attention: false 47 | use_flex_attention_for_region: false 48 | 49 | te: 50 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 51 | _load_config_: 52 | device: cuda 53 | precision: torch.float16 54 | to_freeze: true 55 | tokenizers: 56 | - openai/clip-vit-large-patch14 57 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 58 | text_model_and_configs: 59 | - 60 | - _target_: transformers.CLIPTextModel.from_pretrained 61 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 62 | subfolder: text_encoder 63 | - disable_autocast: false 64 | concat_bucket: 0 65 | use_pooled: false 66 | need_mask: false 67 | layer_idx: -2 68 | - 69 | - _target_: transformers.CLIPTextModel.from_pretrained 70 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 71 | subfolder: text_encoder_2 72 | - disable_autocast: false 73 | concat_bucket: 0 74 | use_pooled: true 75 | need_mask: false 76 | layer_idx: -2 77 | zero_for_padding: false 78 | 79 | vae: 80 | _target_: diffusers.AutoencoderKL.from_pretrained 81 | _load_config_: 82 | device: cuda 83 | precision: torch.float16 84 | to_freeze: true 85 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 86 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_gbc_encode_with_context_ipa.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-ipa-region-gbc-with-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: true 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: cycling 23 | cfg_scale: 6 24 | seed: 1315 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | load_ip_adapter_kwargs: 43 | pretrained_model_name_or_path: h94/IP-Adapter 44 | subfolder: sdxl_models 45 | weight_name: ip-adapter_sdxl.bin 46 | ip_adapter_scale: 1.0 47 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 48 | subfolder: unet 49 | use_region_attention: true 50 | use_caption_mask: true 51 | use_layer_attention: false 52 | use_flex_attention_for_region: false 53 | 54 | te: 55 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 56 | _load_config_: 57 | device: cuda 58 | precision: torch.float16 59 | to_freeze: true 60 | tokenizers: 61 | - openai/clip-vit-large-patch14 62 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 63 | text_model_and_configs: 64 | - 65 | - _target_: transformers.CLIPTextModel.from_pretrained 66 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 67 | subfolder: text_encoder 68 | - disable_autocast: false 69 | concat_bucket: 0 70 | use_pooled: false 71 | need_mask: false 72 | layer_idx: -2 73 | - 74 | - _target_: transformers.CLIPTextModel.from_pretrained 75 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 76 | subfolder: text_encoder_2 77 | - disable_autocast: false 78 | concat_bucket: 0 79 | use_pooled: true 80 | need_mask: false 81 | layer_idx: -2 82 | zero_for_padding: false 83 | 84 | vae: 85 | _target_: diffusers.AutoencoderKL.from_pretrained 86 | _load_config_: 87 | device: cuda 88 | precision: torch.float16 89 | to_freeze: true 90 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 91 | 92 | image_encoder: 93 | _target_: transformers.CLIPVisionModelWithProjection.from_pretrained 94 | _load_config_: 95 | device: cuda 96 | precision: torch.float16 97 | to_freeze: true 98 | pretrained_model_name_or_path: h94/IP-Adapter 99 | subfolder: sdxl_models/image_encoder 100 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_gbc_encode_without_context.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-region-gbc-without-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: cycling 23 | cfg_scale: 6 24 | seed: 1315 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 43 | subfolder: unet 44 | use_region_attention: true 45 | use_caption_mask: true 46 | use_layer_attention: false 47 | use_flex_attention_for_region: false 48 | 49 | te: 50 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 51 | _load_config_: 52 | device: cuda 53 | precision: torch.float16 54 | to_freeze: true 55 | tokenizers: 56 | - openai/clip-vit-large-patch14 57 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 58 | text_model_and_configs: 59 | - 60 | - _target_: transformers.CLIPTextModel.from_pretrained 61 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 62 | subfolder: text_encoder 63 | - disable_autocast: false 64 | concat_bucket: 0 65 | use_pooled: false 66 | need_mask: false 67 | layer_idx: -2 68 | - 69 | - _target_: transformers.CLIPTextModel.from_pretrained 70 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 71 | subfolder: text_encoder_2 72 | - disable_autocast: false 73 | concat_bucket: 0 74 | use_pooled: true 75 | need_mask: false 76 | layer_idx: -2 77 | zero_for_padding: false 78 | 79 | vae: 80 | _target_: diffusers.AutoencoderKL.from_pretrained 81 | _load_config_: 82 | device: cuda 83 | precision: torch.float16 84 | to_freeze: true 85 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 86 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/sampling_region_gbc_encode_without_context_ipa.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-ipa-region-gbc-without-context/ 2 | 3 | image_sampling_func: 4 | 5 | _target_: gbc.t2i.gbc_diffusion_sampling 6 | _partial_: true 7 | 8 | return_flattened: true 9 | return_with_bbox: true 10 | use_region_attention: true 11 | use_layer_attention: false 12 | exclusive_region_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: false 15 | 16 | internal_sampling_func: 17 | _target_: gbc.t2i.sample_euler_ancestral 18 | _partial_: true 19 | eta: 0.0 20 | 21 | num_samples: 8 22 | padding_mode: cycling 23 | cfg_scale: 6 24 | seed: 1315 25 | num_steps: 24 26 | width: 1024 27 | height: 1024 28 | 29 | train_scheduler: 30 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 31 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 32 | subfolder: scheduler 33 | 34 | model_config: 35 | 36 | unet: 37 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 38 | _load_config_: 39 | device: cuda 40 | precision: torch.float16 41 | to_freeze: true 42 | load_ip_adapter_kwargs: 43 | pretrained_model_name_or_path: h94/IP-Adapter 44 | subfolder: sdxl_models 45 | weight_name: ip-adapter_sdxl.bin 46 | ip_adapter_scale: 1.0 47 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 48 | subfolder: unet 49 | use_region_attention: true 50 | use_caption_mask: true 51 | use_layer_attention: false 52 | use_flex_attention_for_region: false 53 | 54 | te: 55 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 56 | _load_config_: 57 | device: cuda 58 | precision: torch.float16 59 | to_freeze: true 60 | tokenizers: 61 | - openai/clip-vit-large-patch14 62 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 63 | text_model_and_configs: 64 | - 65 | - _target_: transformers.CLIPTextModel.from_pretrained 66 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 67 | subfolder: text_encoder 68 | - disable_autocast: false 69 | concat_bucket: 0 70 | use_pooled: false 71 | need_mask: false 72 | layer_idx: -2 73 | - 74 | - _target_: transformers.CLIPTextModel.from_pretrained 75 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 76 | subfolder: text_encoder_2 77 | - disable_autocast: false 78 | concat_bucket: 0 79 | use_pooled: true 80 | need_mask: false 81 | layer_idx: -2 82 | zero_for_padding: false 83 | 84 | vae: 85 | _target_: diffusers.AutoencoderKL.from_pretrained 86 | _load_config_: 87 | device: cuda 88 | precision: torch.float16 89 | to_freeze: true 90 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 91 | 92 | image_encoder: 93 | _target_: transformers.CLIPVisionModelWithProjection.from_pretrained 94 | _load_config_: 95 | device: cuda 96 | precision: torch.float16 97 | to_freeze: true 98 | pretrained_model_name_or_path: h94/IP-Adapter 99 | subfolder: sdxl_models/image_encoder 100 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/secondary/sampling_layer.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-layer-gbc-without-context/ 2 | 3 | 4 | sampling_func: 5 | 6 | _target_: gbc.t2i.gbc_diffusion_sampling 7 | _partial_: true 8 | 9 | return_flattened: true 10 | return_with_bbox: true 11 | use_region_attention: false 12 | use_layer_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: false 15 | 16 | layer_mask_config: 17 | use_edge: true 18 | # the following would be ignored if use_edge is false 19 | attend_from_bbox: true 20 | include_child_to_parent: true 21 | attend_to_bbox: true 22 | 23 | prompt_neg_prop: 0.5 24 | adj_neg_prop: 0 25 | all_neg_prop: 0.5 26 | 27 | internal_sampling_func: 28 | _target_: gbc.t2i.sample_euler_ancestral 29 | _partial_: true 30 | eta: 0.0 31 | 32 | num_samples: 6 33 | padding_mode: cycling 34 | cfg_scale: 6 35 | seed: 1315 36 | num_steps: 24 37 | width: 1024 38 | height: 1024 39 | 40 | train_scheduler: 41 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 42 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 43 | subfolder: scheduler 44 | 45 | model_config: 46 | 47 | unet: 48 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 49 | _load_config_: 50 | device: cuda 51 | precision: torch.float16 52 | to_freeze: true 53 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 54 | subfolder: unet 55 | use_region_attention: false 56 | use_caption_mask: false 57 | use_layer_attention: true 58 | use_flex_attention_for_layer: true 59 | 60 | te: 61 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 62 | _load_config_: 63 | device: cuda 64 | precision: torch.float16 65 | to_freeze: true 66 | tokenizers: 67 | - openai/clip-vit-large-patch14 68 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 69 | text_model_and_configs: 70 | - 71 | - _target_: transformers.CLIPTextModel.from_pretrained 72 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 73 | subfolder: text_encoder 74 | - disable_autocast: false 75 | concat_bucket: 0 76 | use_pooled: false 77 | need_mask: false 78 | layer_idx: -2 79 | - 80 | - _target_: transformers.CLIPTextModel.from_pretrained 81 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 82 | subfolder: text_encoder_2 83 | - disable_autocast: false 84 | concat_bucket: 0 85 | use_pooled: true 86 | need_mask: false 87 | layer_idx: -2 88 | zero_for_padding: false 89 | 90 | vae: 91 | _target_: diffusers.AutoencoderKL.from_pretrained 92 | _load_config_: 93 | device: cuda 94 | precision: torch.float16 95 | to_freeze: true 96 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 97 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/secondary/sampling_layer_region.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-layer-region-gbc-with-context/ 2 | 3 | 4 | sampling_func: 5 | 6 | _target_: gbc.t2i.gbc_diffusion_sampling 7 | _partial_: true 8 | 9 | return_flattened: true 10 | return_with_bbox: true 11 | use_region_attention: true 12 | use_layer_attention: true 13 | labels_in_neg: true 14 | concat_ancestor_prompts: true 15 | 16 | layer_mask_config: 17 | use_edge: true 18 | # the following would be ignored if use_edge is false 19 | attend_from_bbox: true 20 | include_child_to_parent: true 21 | attend_to_bbox: true 22 | 23 | prompt_neg_prop: 0.5 24 | adj_neg_prop: 0 25 | all_neg_prop: 0.5 26 | 27 | internal_sampling_func: 28 | _target_: gbc.t2i.sample_euler_ancestral 29 | _partial_: true 30 | eta: 0.0 31 | 32 | num_samples: 6 33 | padding_mode: cycling 34 | cfg_scale: 6 35 | seed: 1315 36 | num_steps: 24 37 | width: 1024 38 | height: 1024 39 | 40 | train_scheduler: 41 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 42 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 43 | subfolder: scheduler 44 | 45 | model_config: 46 | 47 | unet: 48 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 49 | _load_config_: 50 | device: cuda 51 | precision: torch.float16 52 | to_freeze: true 53 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 54 | subfolder: unet 55 | use_region_attention: true 56 | use_caption_mask: true 57 | use_layer_attention: true 58 | use_flex_attention_for_region: false 59 | use_flex_attention_for_layer: true 60 | 61 | te: 62 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 63 | _load_config_: 64 | device: cuda 65 | precision: torch.float16 66 | to_freeze: true 67 | tokenizers: 68 | - openai/clip-vit-large-patch14 69 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 70 | text_model_and_configs: 71 | - 72 | - _target_: transformers.CLIPTextModel.from_pretrained 73 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 74 | subfolder: text_encoder 75 | - disable_autocast: false 76 | concat_bucket: 0 77 | use_pooled: false 78 | need_mask: false 79 | layer_idx: -2 80 | - 81 | - _target_: transformers.CLIPTextModel.from_pretrained 82 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 83 | subfolder: text_encoder_2 84 | - disable_autocast: false 85 | concat_bucket: 0 86 | use_pooled: true 87 | need_mask: false 88 | layer_idx: -2 89 | zero_for_padding: false 90 | 91 | vae: 92 | _target_: diffusers.AutoencoderKL.from_pretrained 93 | _load_config_: 94 | device: cuda 95 | precision: torch.float16 96 | to_freeze: true 97 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 98 | -------------------------------------------------------------------------------- /configs/generation/gbc2i/secondary/sampling_layer_region_ipa.yaml: -------------------------------------------------------------------------------- 1 | save_dir: ./tests/outputs/generation/gbc2i/sdxl-ipa-with-context/ 2 | 3 | 4 | sampling_func: 5 | 6 | _target_: gbc.t2i.gbc_diffusion_sampling 7 | _partial_: true 8 | 9 | return_flattened: true 10 | return_with_bbox: true 11 | use_region_attention: true 12 | use_layer_attention: true 13 | exclusive_region_attention: true 14 | labels_in_neg: true 15 | concat_ancestor_prompts: true 16 | 17 | internal_sampling_func: 18 | _target_: gbc.t2i.sample_euler_ancestral 19 | _partial_: true 20 | eta: 0.0 21 | 22 | layer_mask_config: 23 | use_edge: true 24 | # the following would be ignored if use_edge is false 25 | attend_from_bbox: true 26 | include_child_to_parent: false 27 | attend_to_bbox: true 28 | 29 | prompt_neg_prop: 1 30 | adj_neg_prop: 0 31 | all_neg_prop: 0 32 | 33 | num_samples: 6 34 | padding_mode: cycling 35 | cfg_scale: 6 36 | seed: 1315 37 | num_steps: 24 38 | width: 1024 39 | height: 1024 40 | 41 | train_scheduler: 42 | _target_: diffusers.EulerDiscreteScheduler.from_pretrained 43 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 44 | subfolder: scheduler 45 | 46 | model_config: 47 | 48 | unet: 49 | _target_: gbc.t2i.modules.unet_patch.GbcUNet2DConditionModel.from_pretrained 50 | _load_config_: 51 | device: cuda 52 | precision: torch.float16 53 | to_freeze: true 54 | load_ip_adapter_kwargs: 55 | pretrained_model_name_or_path: h94/IP-Adapter 56 | subfolder: sdxl_models 57 | weight_name: ip-adapter_sdxl.bin 58 | ip_adapter_scale: 1.0 59 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 60 | subfolder: unet 61 | use_region_attention: true 62 | use_caption_mask: true 63 | use_layer_attention: true 64 | use_flex_attention_for_region: false 65 | use_flex_attention_for_layer: true 66 | 67 | te: 68 | _target_: gbc.t2i.modules.text_encoders.ConcatTextEncoders 69 | _load_config_: 70 | device: cuda 71 | precision: torch.float16 72 | to_freeze: true 73 | tokenizers: 74 | - openai/clip-vit-large-patch14 75 | - laion/CLIP-ViT-bigG-14-laion2B-39B-b160k 76 | text_model_and_configs: 77 | - 78 | - _target_: transformers.CLIPTextModel.from_pretrained 79 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 80 | subfolder: text_encoder 81 | - disable_autocast: false 82 | concat_bucket: 0 83 | use_pooled: false 84 | need_mask: false 85 | layer_idx: -2 86 | - 87 | - _target_: transformers.CLIPTextModel.from_pretrained 88 | pretrained_model_name_or_path: stabilityai/stable-diffusion-xl-base-1.0 89 | subfolder: text_encoder_2 90 | - disable_autocast: false 91 | concat_bucket: 0 92 | use_pooled: true 93 | need_mask: false 94 | layer_idx: -2 95 | zero_for_padding: false 96 | 97 | vae: 98 | _target_: diffusers.AutoencoderKL.from_pretrained 99 | _load_config_: 100 | device: cuda 101 | precision: torch.float16 102 | to_freeze: true 103 | pretrained_model_name_or_path: madebyollin/sdxl-vae-fp16-fix 104 | 105 | image_encoder: 106 | _target_: transformers.CLIPVisionModelWithProjection.from_pretrained 107 | _load_config_: 108 | device: cuda 109 | precision: torch.float16 110 | to_freeze: true 111 | pretrained_model_name_or_path: h94/IP-Adapter 112 | subfolder: sdxl_models/image_encoder 113 | -------------------------------------------------------------------------------- /configs/generation/graph_transform_ex.yaml: -------------------------------------------------------------------------------- 1 | graph_transform: 2 | _target_: gbc.processing.data_transforms.basic_filter_and_extract 3 | _partial_: true 4 | drop_vertex_size_kwargs: 5 | min_rel_width: 0.1 6 | min_rel_height: 0.1 7 | same_level_max_bbox_overlap_ratio: 0.25 8 | max_n_vertices: 10 9 | keep_in_edges: false 10 | -------------------------------------------------------------------------------- /configs/generation/t2gbc_default.yaml: -------------------------------------------------------------------------------- 1 | save_file: tests/outputs/generation/t2gbc/gbc_prompt_gen.json 2 | 3 | prompt_gen_config: 4 | allow_composition: true 5 | star_graph: false 6 | verbose: true 7 | seed: 1015 8 | temperature: 1 9 | top_p: 0.95 10 | top_k: 60 11 | repetition_penalty: 1 12 | max_new_tokens: 4096 13 | torch_dtype: torch.float32 14 | device: cpu 15 | attn_implementation: sdpa 16 | 17 | model_config: 18 | prompt_gen_model_name_or_path: graph-based-captions/GBC10M-PromptGen-200M 19 | -------------------------------------------------------------------------------- /configs/processing/clip_filtering.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki_with_clip_scores/] 3 | input_formats: [.jsonl] 4 | data_transform: 5 | _target_: gbc.processing.data_transforms.gbc_clip_filter 6 | _partial_: true 7 | clip_names: 8 | - dfn5b-h-patch14-378 9 | min_clip_scores: 10 | original: 0.24970313230613642 11 | short-image: 0.23114516488176592 12 | detail-image: 0.19171233770958623 13 | relation: 0.18524515134437888 14 | composition: 0.11810152712453109 15 | short-composition: 0.1397709942144013 16 | detail-entity: 0.15333012089283748 17 | name_transform: 18 | _target_: gbc.processing.data_transforms.string_replace 19 | _partial_: true 20 | old_str: _with_clip 21 | new_str: _clip_filtered 22 | save_format: .parquet 23 | save_dir: tests/outputs/processing/clip_filtering/ 24 | -------------------------------------------------------------------------------- /configs/processing/compute_all_scores.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki/] 3 | input_formats: [.jsonl] 4 | 5 | data_transform: 6 | _target_: gbc.processing.data_transforms.chain_transforms 7 | _args_: 8 | - _target_: gbc.processing.data_transforms.create_list_transform 9 | transform_function: 10 | _target_: gbc.processing.data_transforms.compute_caption_statistics 11 | _partial_: true 12 | - _target_: gbc.processing.data_transforms.compute_toxicity_scores 13 | _partial_: true 14 | model_name: original 15 | device: cuda 16 | - _target_: gbc.processing.data_transforms.compute_clip_scores 17 | _partial_: true 18 | clip_models: 19 | openai-l-patch14-336: 20 | _target_: gbc.processing.data_transforms.clip_scoring.HfClipScoreModel 21 | device: "cuda:0" 22 | clip_model_path: openai/clip-vit-large-patch14-336 23 | dfn5b-h-patch14-378: 24 | _target_: gbc.processing.data_transforms.clip_scoring.OpenClipScoreModel 25 | device: "cuda:0" 26 | clip_model_path: hf-hub:apple/DFN5B-CLIP-ViT-H-14-384 27 | tokenizer_name: ViT-H-14 28 | batch_size: 256 29 | 30 | name_transform: 31 | _target_: gbc.processing.data_transforms.append_string_to_filename 32 | _partial_: true 33 | append_text: _with_toxicity 34 | 35 | save_format: .json 36 | save_dir: tests/outputs/processing/all_scores/ 37 | -------------------------------------------------------------------------------- /configs/processing/compute_clip_scores.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki/] 3 | input_formats: [.jsonl] 4 | 5 | data_transform: 6 | _target_: gbc.processing.data_transforms.compute_clip_scores 7 | _partial_: true 8 | clip_models: 9 | openai-l-patch14-336: 10 | _target_: gbc.processing.data_transforms.clip_scoring.HfClipScoreModel 11 | device: "cuda:0" 12 | clip_model_path: openai/clip-vit-large-patch14-336 13 | dfn5b-h-patch14-378: 14 | _target_: gbc.processing.data_transforms.clip_scoring.OpenClipScoreModel 15 | device: "cuda:0" 16 | clip_model_path: hf-hub:apple/DFN5B-CLIP-ViT-H-14-384 17 | tokenizer_name: ViT-H-14 18 | batch_size: 256 19 | 20 | name_transform: 21 | _target_: gbc.processing.data_transforms.append_string_to_filename 22 | _partial_: true 23 | append_text: _with_clip 24 | 25 | save_format: .jsonl 26 | save_dir: tests/outputs/processing/clip_scoring/ 27 | -------------------------------------------------------------------------------- /configs/processing/compute_toxicity_scores.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki/] 3 | input_formats: [.jsonl] 4 | 5 | data_transform: 6 | _target_: gbc.processing.data_transforms.compute_toxicity_scores 7 | _partial_: true 8 | model_name: original 9 | device: cuda 10 | 11 | name_transform: 12 | _target_: gbc.processing.data_transforms.append_string_to_filename 13 | _partial_: true 14 | append_text: _with_toxicity 15 | 16 | save_format: .json 17 | save_dir: tests/outputs/processing/toxicity_scoring/ 18 | -------------------------------------------------------------------------------- /configs/processing/relation_composition_filtering.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki/] 3 | input_formats: [.jsonl] 4 | data_transform: 5 | _target_: gbc.processing.data_transforms.create_list_transform 6 | transform_function: 7 | _target_: gbc.processing.data_transforms.basic_filter_and_extract 8 | _partial_: true 9 | drop_vertex_size_kwargs: 10 | min_rel_width: 0.1 11 | min_rel_height: 0.1 12 | drop_vertex_types: 13 | - relation 14 | drop_caption_types: 15 | - hardcode 16 | - composition 17 | max_n_vertices: 10 18 | name_transform: 19 | _target_: gbc.processing.data_transforms.append_string_to_filename 20 | _partial_: true 21 | append_text: _relation_filtered_max_10_vertices 22 | save_format: .parquet 23 | save_dir: tests/outputs/processing/filtering/ 24 | -------------------------------------------------------------------------------- /configs/processing/to_structured_text.yaml: -------------------------------------------------------------------------------- 1 | processing_config: 2 | input_paths: [data/gbc/wiki/] 3 | input_formats: [.jsonl] 4 | data_transform: 5 | _target_: gbc.processing.data_transforms.create_list_transform 6 | transform_function: 7 | _target_: gbc.processing.data_transforms.gbc_graph_to_text_and_image 8 | _partial_: true 9 | text_format: structured 10 | graph_traversal_mode: topological 11 | caption_agg_mode_for_structured: fisrt 12 | read_image: false # Object of type Image can not be stored in json etc. 13 | name_transform: 14 | _target_: gbc.processing.data_transforms.append_string_to_filename 15 | _partial_: true 16 | append_text: _structured_text 17 | save_format: .json 18 | save_dir: tests/outputs/processing/filtering/ 19 | -------------------------------------------------------------------------------- /data/gbc/prompt_gen/library_turtle_frog_steamponk.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/gbc/prompt_gen/library_turtle_frog_steamponk.parquet -------------------------------------------------------------------------------- /data/gbc/wiki/wiki_gbc_graphs.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/gbc/wiki/wiki_gbc_graphs.parquet -------------------------------------------------------------------------------- /data/gbc/wiki_pixtral_gdino/wiki_gbc_graphs.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/gbc/wiki_pixtral_gdino/wiki_gbc_graphs.parquet -------------------------------------------------------------------------------- /data/gbc/wiki_with_clip_scores/wiki_gbc_graphs_with_clip.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/gbc/wiki_with_clip_scores/wiki_gbc_graphs_with_clip.parquet -------------------------------------------------------------------------------- /data/images/t2gbc2i/frog.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/t2gbc2i/frog.png -------------------------------------------------------------------------------- /data/images/t2gbc2i/library.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/t2gbc2i/library.png -------------------------------------------------------------------------------- /data/images/t2gbc2i/mechanical_cat.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/t2gbc2i/mechanical_cat.png -------------------------------------------------------------------------------- /data/images/t2gbc2i/turtle.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/t2gbc2i/turtle.png -------------------------------------------------------------------------------- /data/images/wiki/Cesenatico.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Cesenatico.jpg -------------------------------------------------------------------------------- /data/images/wiki/Eiffel_tower_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Eiffel_tower_0.jpg -------------------------------------------------------------------------------- /data/images/wiki/Eiffel_tower_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Eiffel_tower_1.jpg -------------------------------------------------------------------------------- /data/images/wiki/Eiffel_tower_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Eiffel_tower_2.jpg -------------------------------------------------------------------------------- /data/images/wiki/Household_utensils.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Household_utensils.jpg -------------------------------------------------------------------------------- /data/images/wiki/Ilish_Bhaat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Ilish_Bhaat.jpg -------------------------------------------------------------------------------- /data/images/wiki/Jean-Francois_Portaels.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Jean-Francois_Portaels.jpg -------------------------------------------------------------------------------- /data/images/wiki/Kinkakuji.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Kinkakuji.jpg -------------------------------------------------------------------------------- /data/images/wiki/Kompoziciya.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Kompoziciya.jpg -------------------------------------------------------------------------------- /data/images/wiki/Launch_of_Shenzhou.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Launch_of_Shenzhou.jpg -------------------------------------------------------------------------------- /data/images/wiki/Magdeburg.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Magdeburg.jpg -------------------------------------------------------------------------------- /data/images/wiki/Maid_of_the_Mist.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Maid_of_the_Mist.jpg -------------------------------------------------------------------------------- /data/images/wiki/Mumbai_Flora_Fountain.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Mumbai_Flora_Fountain.jpg -------------------------------------------------------------------------------- /data/images/wiki/Painting.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Painting.jpg -------------------------------------------------------------------------------- /data/images/wiki/Petani_padi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Petani_padi.jpg -------------------------------------------------------------------------------- /data/images/wiki/README.md: -------------------------------------------------------------------------------- 1 | ## Attribution 2 | 3 | All images in this folder are from [Wikimedia Commons](https://commons.wikimedia.org/wiki/Accueil) and are under their respective licenses. 4 | 5 | - Cesenatico.jpg : https://commons.wikimedia.org/wiki/File:Cesenatico_-_Porto_Canale_(2023).jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 6 | - Eiffel_tower_0.jpg : https://fr.m.wikipedia.org/wiki/Fichier:Eiffel_Tower_from_north_Avenue_de_New_York,_Aug_2010.jpg / [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0) 7 | - Eiffel_tower_1.jpg : https://commons.wikimedia.org/wiki/File:Tour_Eiffel_Wikimedia_Commons.jpg / Public domain 8 | - Eiffel_tower_2.jpg : https://commons.wikimedia.org/wiki/File:Paris_-_The_Eiffel_Tower_in_spring_-_2307.jpg / [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/) 9 | - Household_utensils.jpg : https://commons.wikimedia.org/wiki/File:Household_utensils_-_artwork_-_Malsch_01.jpg / [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0) 10 | - Ilish_Bhaat.jpg : https://commons.wikimedia.org/wiki/File:Ilish_Bhaat.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 11 | - Jean-Francois_Portaels.jpg : https://commons.wikimedia.org/wiki/File:Jean-Francois_Portaels_-_The_Procession.jpg / Public domain 12 | - Kinkakuji.jpg : https://commons.wikimedia.org/wiki/File:Kinkaku-ji_in_November_2016_-02.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 13 | - Kompoziciya.jpg : https://commons.wikimedia.org/wiki/File:143266165_1913_Kompoziciya_s_chernuym_treugolnikom_H_m_525_h_525cm_MacDougalls_2013.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 14 | - Launch_of_Shenzhou.jpg : https://commons.wikimedia.org/wiki/File:Launch_of_Shenzhou_13.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 15 | - Magdeburg.jpg : https://commons.wikimedia.org/wiki/File:Magdeburg_asv2022-08_img28_F%C3%BCrstenwallpark_monument.jpg / [Free Art License](https://en.wikipedia.org/wiki/Free_Art_License) 16 | - Maid_of_the_Mist.jpg : https://commons.wikimedia.org/wiki/File:Maid_of_the_Mist_-_pot-o-gold.jpg / [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/) 17 | - Mumbai_Flora_Fountain.jpg : https://commons.wikimedia.org/wiki/File:Mumbai_03-2016_72_Flora_Fountain.jpg / [Free Art License](https://en.wikipedia.org/wiki/Free_Art_License) 18 | - Painting.jpg : https://commons.wikimedia.org/wiki/File:39.%D0%9A%D1%80%D0%B0%D1%81%D0%BA%D0%B8_%D0%BE%D1%81%D0%B5%D0%BD%D0%B8_70%D1%8580_%D1%85.%D0%BC..jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 19 | - Petani_padi.jpg : https://commons.wikimedia.org/wiki/File:Petani_padi.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 20 | - Regalia.jpg : https://commons.wikimedia.org/wiki/File:Crown,_sceptre,_orb_%26_key_of_the_King_of_Sweden_2014.jpg / Public domain 21 | - Tartu_raudteejaama_veetorn.jpg : https://commons.wikimedia.org/wiki/File:Tartu_raudteejaama_veetorn,_2010.JPG / [CC BY-SA 3.0](https://creativecommons.org/licenses/by-sa/3.0/) 22 | - Vat_Pa_Phai_temple.jpg : https://commons.wikimedia.org/wiki/File:Vat_Pa_Phai_temple_with_a_Buddhist_monk,_orange_marigold,_clouds_and_blue_sky,_in_Luang_Prabang.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 23 | - Wild_horses.jpg : https://commons.wikimedia.org/wiki/File:Wild_horses,_%C5%A0ar_Mountains.jpg / [CC BY-SA 4.0](https://creativecommons.org/licenses/by-sa/4.0) 24 | - corgi.jpg : https://commons.wikimedia.org/wiki/File:Fawn_and_white_Welsh_Corgi_puppy_standing_on_rear_legs_and_sticking_out_the_tongue.jpg / [CC0 1.0 Universal](https://creativecommons.org/publicdomain/zero/1.0/) 25 | - orange_cat.jpg : https://commons.wikimedia.org/wiki/File:Orange_tabby_cat_sitting_on_fallen_leaves-Hisashi-01A.jpg / [CC BY-SA 2.0](https://creativecommons.org/licenses/by-sa/2.0) 26 | -------------------------------------------------------------------------------- /data/images/wiki/Regalia.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Regalia.jpg -------------------------------------------------------------------------------- /data/images/wiki/Tartu_raudteejaama_veetorn.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Tartu_raudteejaama_veetorn.jpg -------------------------------------------------------------------------------- /data/images/wiki/Vat_Pa_Phai_temple.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Vat_Pa_Phai_temple.jpg -------------------------------------------------------------------------------- /data/images/wiki/Wild_horses.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/Wild_horses.jpg -------------------------------------------------------------------------------- /data/images/wiki/corgi.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/corgi.jpg -------------------------------------------------------------------------------- /data/images/wiki/orange_cat.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/data/images/wiki/orange_cat.jpg -------------------------------------------------------------------------------- /prompts/captioning/query_composition.txt: -------------------------------------------------------------------------------- 1 | Please describe the composition of the {} in the bounding boxes, followed by some general descriptions that apply to all {}. The composition should include {} and be based on the following hints (do not mention hints or bounding boxes in the response). 2 | {} -------------------------------------------------------------------------------- /prompts/captioning/query_entity.txt: -------------------------------------------------------------------------------- 1 | Please assess the image focusing on '{}'. Start by confirming its presence with 'Object Present: Yes' or 'Object Present: No'. If present, describe its key features in a detailed caption with at most 50 words. Then, evaluate if any aspects stand out for further emphasis, stating 'Prominent Features: Yes' or 'No' while preferring "Yes". If yes, list a few notable features in brackets, applying [single] or [multiple] as appropriate. Importantly, do not include '{}' in features. Instead, you should break it down. -------------------------------------------------------------------------------- /prompts/captioning/query_image.txt: -------------------------------------------------------------------------------- 1 | Following the instruction, please provide a detailed caption and a concise formatted caption for the given image. Note that it is crucial for you to put elements that can be detected and should be further described in picture in brackets as [object_name][node_type] in the concise formatted caption. -------------------------------------------------------------------------------- /prompts/captioning/query_relation.txt: -------------------------------------------------------------------------------- 1 | Please infer and list of at most {} relations between {} in this images. -------------------------------------------------------------------------------- /prompts/captioning/system_composition.txt: -------------------------------------------------------------------------------- 1 | Your role is to analyze images containing objects within pre-labeled bounding boxes and describe the compositional arrangement of these objects based on provided hints. You will then provide general descriptions that apply to all the objects collectively. 2 | 3 | Input Image Format Explanation: 4 | - The image will feature objects of interest, each enclosed within a bounding box. 5 | - Each bounding box will be numbered centrally to uniquely identify it. 6 | - The objects will be similar in nature (e.g., all dogs) and positioned within a scene. 7 | 8 | Utilizing Hints for Analyzing Composition: 9 | - Begin by reviewing the hints provided regarding the spatial arrangement of the objects. 10 | - These hints may specify the relative positions of objects (e.g., "Object 3 is in the top right corner"). 11 | - Use the hints to guide your description of how the objects relate to each other within their bounding boxes. 12 | 13 | Output Format: 14 | - Composition Description:Start with "Composition:" followed by a description informed by the hints and using the bounding box numbers. This description should elucidate the spatial arrangement of the objects as per the hints. 15 | - General Descriptions: Provide observations that apply to all objects within the specified group, excluding unrelated elements or background details. Preface this section with "General descriptions:". 16 | 17 | Additional Guidelines: 18 | - Describe the spatial arrangement of objects without inferring spatial relations from the sequence of numbers. 19 | - Utilize clear spatial language to articulate the composition. 20 | - The description should reflect the actual visual composition, not the order of numbers in the bounding boxes. 21 | 22 | 23 | Examples: 24 | 25 | Example for 3 Dogs in Bounding Boxes: 26 | 27 | Query Prompt: "Please describe the composition of the 3 dogs in the bounding boxes, followed by some general descriptions that apply to all dogs." 28 | 29 | System Response: 30 | 31 | Composition: Dog 3 is in front, with dog 2 to the left and dog 1 to the right. 32 | General descriptions: 33 | - The three dogs are aligned in a row on the grass. 34 | - They share similar sizes and features, suggesting they may be from the same breed. 35 | 36 | Additional Examples: 37 | 38 | For 5 Flowers in a Garden Bed in Bounding Boxes: 39 | Composition: Flower 4 takes a central position, flanked by flower 2 and flower 3 on either side, while flower 1 and flower 5 bookend the arrangement at the outer edges. 40 | General descriptions: 41 | - Each flower is in full bloom, indicating a peak growing season. 42 | 43 | For 2 Cats in a Window in Bounding Boxes: 44 | Composition: Cat 1 is positioned on the left side of the window sill, while cat 2 is curled up on the right. 45 | General descriptions: 46 | - Both cats are basking in the sunlight coming through the window. 47 | - Their relaxed postures suggest a shared sense of comfort and tranquility. 48 | -------------------------------------------------------------------------------- /prompts/captioning/system_entity.txt: -------------------------------------------------------------------------------- 1 | Your task is to perform an in-depth analysis of a cropped image focusing on a requested object, like a "house". The process involves a step-by-step evaluation to identify the object's presence, describe its features, craft concise captions, and assess any prominent objects. 2 | 3 | Process Overview: 4 | 5 | Verify Object Presence: 6 | - Examine the image to determine if the specified object, or any instance of it, is present. 7 | - State the presence with "Object Present: Yes" or "Object Present: No". 8 | 9 | Provide Appropriate Caption (If Object Is Present): 10 | - Provide a detailed description of the object, focusing solely on its features without reference to other elements in the image. 11 | - The description should contain at most 50 words. 12 | 13 | Assessment of Prominent Objects: 14 | - Evaluate the described features to determine if any stand out for further description and are detectable by an object detection model. This is crucial for complex objects such as 'man', 'woman', 'family', 'couple', 'cat', or 'house', where components or distinctive attributes are significant. For example, when analyzing 'woman', consider elements like dress [single], skirt [single], or hair [single] as prominent features. For simpler objects like 'cup' or 'chair', additional descriptions may not be needed. 15 | 16 | Identification of Prominent Features (If Applicable): 17 | - If there are prominent features identified, list and format these features for potential detection by an object detection model. 18 | - Ensure these features are parts or components of the main object and not the entire object itself. 19 | - Use [single] for unique, standalone items, and [multiple] for features present more than once, such as roof [single] or windows [multiple]. 20 | - Group similar items under a single [multiple] label rather than describing them separately, even if individual descriptions were provided in the detailed caption. For example, multiple distinct windows in a house should be labeled as windows [multiple] rather than individually enumerated. 21 | - For groups like families or couples, identify members separately (e.g., man [single], woman [single]) rather than as a collective unit. This contrasts with grouping similar inanimate objects (e.g., windows [multiple]), where individual distinction isn't necessary. 22 | - Consistency with the caption: Ensure that the features identified as [single] or [multiple] are also mentioned in the caption. 23 | 24 | Example Responses: 25 | 26 | Example 1: Object Not Present 27 | 28 | Object Presence: No 29 | Detailed Caption: N/A 30 | Prominent Features: N/A 31 | Identification of Prominent Features: N/A 32 | 33 | Example 2: Object Present Without Prominent Features (requested object: "cup") 34 | 35 | Object Presence: Yes 36 | Detailed Caption: A simple ceramic cup on a wooden table. The cup has a smooth, unadorned surface and a standard curved handle on one side. 37 | Prominent Features: No 38 | Identification of Prominent Features: N/A 39 | 40 | Example 3: Object Present With Prominent Features (requested object: "family") 41 | 42 | Object Presence: Yes 43 | Detailed Caption: A family of four is captured enjoying a sunny day in the park. The father, in casual attire, is engrossed in setting up a picnic spot, while the mother, donned in a summer dress, is laying out a feast on a blanket. Nearby, two children, a boy and a girl, engage in playful antics; the boy is kicking a football with fervor, and the girl, adorned in a light frock, is gleefully chasing bubbles. 44 | Prominent Features: Yes 45 | Identification of Prominent Features: 46 | - Father: [single] 47 | - Mother: [single] 48 | - Boy: [single] 49 | - Girl: [single] 50 | 51 | Example 4: Object Present With Prominent Features (requested object: "car") 52 | 53 | Object Presence: Yes 54 | Detailed Caption: A vintage car in pristine condition, with shiny chrome bumpers and classic spoke wheels. The car's body is painted in a vibrant red, and the leather interior is visible through the clear windows. A unique hood ornament adorns the front, adding to the car's elegance. 55 | Prominent Features: Yes 56 | Identification of Prominent Features: 57 | - Chrome bumpers: [single] 58 | - Wheels: [multiple] 59 | - Hood ornament: [single] -------------------------------------------------------------------------------- /prompts/captioning/system_relation.txt: -------------------------------------------------------------------------------- 1 | Your role involves analyzing the spatial and direct interactions between pre-identified elements within an image, described through annotations like [beach], [turquoise waters], [people], [shoreline], [lush vegetation]. Your task is to objectively describe how these elements are related or positioned relative to each other within the scene. 2 | 3 | Steps for Identifying Objective Relations: 4 | 1. Review Annotated Elements: Start by examining the list of annotated elements. Understand the nature of each element as it is presented in the image. 5 | 2. Identify Spatial Positions: Determine the spatial positioning of these elements in relation to each other. Focus on direct relationships such as touching, overlapping, or proximity without implying any further interpretation. 6 | 3. Describe Direct Interactions: Look for and describe any direct interactions between elements, such as one element supporting another, blocking, or leading into another. 7 | 4. Format Relations List: Provide your findings as a list of bullet points. Each bullet should detail a direct and observable relationship between two or more elements, using their annotated identifiers for reference. 8 | 9 | Example Relations Based on Annotated Elements: 10 | 11 | For elements: [beach], [turquoise waters], [people], [shoreline], [lush vegetation], you might reply: 12 | 13 | - The [people] are standing on the [beach], with the [lush vegetation] to their left. 14 | - [Turquoise waters] lap against the [beach] at the [shoreline], with [people] scattered along its edge. 15 | - [Lush vegetation] flanks the left side of the [beach], providing a natural border. 16 | - The [shoreline] separates the [beach] from the [turquoise waters]. 17 | - To the right of the [lush vegetation], the [beach] stretches towards the [turquoise waters]. 18 | 19 | For another set of elements: [eagle], [snake], [wings], you might reply: 20 | 21 | - The [eagle] has its [wings] spread above the [snake]. 22 | - The [snake] is positioned below the [eagle]. 23 | - The [eagle]'s claws are near or in contact with the [snake]. 24 | 25 | Guidelines for Reporting Relations: 26 | 1. Ensure descriptions are based solely on visible or directly inferable relationships without adding interpretations or assumptions. 27 | 2. Maintain clarity and precision in articulating the spatial and interactional dynamics between elements. 28 | 3. Stick to objective descriptions that reflect the physical and observable aspects of the elements' relationships. 29 | 4. Only answer the list of bullet points without anything before or after. 30 | 5. Do not include any bullet point with 1 or even 0 elements. 31 | 32 | - Visible Relationships Only: Report relationships that are clearly depicted in the image. If no clear relationships are visible, state "No visible relationships." 33 | - Objective Descriptions: Keep descriptions factual and based solely on what can be seen in the image. 34 | - Avoid Assumptions: Do not infer or assume any relationships that aren't clearly shown in the image. 35 | - Bullet Point Format: Present each observable relationship as a separate bullet point, avoiding any descriptive text not related to the direct relationships. 36 | - No Relation Inference: Refrain from implying relationships or positions that are not explicitly shown. If elements are simply present without any discernible interaction, it is acceptable to say "Elements are present without visible interaction." 37 | - Avoid Single Element Points: Do not include bullet points that mention only one element or have no elements at all. Each bullet point must reference the relationship between two or more elements. -------------------------------------------------------------------------------- /prompts/t2i/banana_apple.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | # Ensure that only first prompt is used in first phase 3 | - Banana and apple arranged in a plate on a rustic wooden table bathed in warm sunlight. 4 | - Red banana on table. 5 | - Yellow apple on table. 6 | bboxes: 7 | - [0, 0, 1, 1] 8 | - [0.3, 0.45, 0.9, 0.9] 9 | - [0.1, 0.3, 0.5, 0.7] 10 | adjacency: 11 | - [1, 2] 12 | - [] 13 | - [] 14 | labeled_adjacency: 15 | - [[1, ["Banana"]], [2, ["apple"]]] 16 | - [] 17 | - [] 18 | -------------------------------------------------------------------------------- /prompts/t2i/banana_apple_graph_only.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | # Ensure that only first prompt is used in first phase 3 | - Banana and apple arranged in a plate on a rustic wooden table bathed in warm sunlight. 4 | - Red banana on table. 5 | - Yellow apple on table. 6 | bboxes: 7 | - [0, 0, 1, 1] 8 | - [0, 0, 0, 0] 9 | - [0, 0, 0, 0] 10 | adjacency: 11 | - [1, 2] 12 | - [] 13 | - [] 14 | labeled_adjacency: 15 | - [[1, ["Banana"]], [2, ["apple"]]] 16 | - [] 17 | - [] 18 | -------------------------------------------------------------------------------- /prompts/t2i/dog_cat_ref_image.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | - A cat and a dog sitting near a lamppost, depicted in a comic style. 3 | - Dog sitting on the ground. 4 | - Cat sitting on the ground. 5 | bboxes: 6 | - [0, 0, 1, 1] 7 | - [0.1, 0.4, 0.45, 0.9] 8 | - [0.55, 0.5, 0.9, 0.9] 9 | adjacency: 10 | - [1, 2] 11 | - [] 12 | - [] 13 | labeled_adjacency: 14 | - [[1, ["dog"]], [2, ["cat"]]] 15 | - [] 16 | - [] 17 | ref_images: 18 | - data/images/wiki/corgi.jpg 19 | - data/images/wiki/orange_cat.jpg 20 | ref_image_idxs: [1, 2] 21 | -------------------------------------------------------------------------------- /prompts/t2i/dog_cat_ref_image_graph_only.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | - A cat and a dog sitting near a lamppost, depicted in a comic style. 3 | - Dog sitting on the ground. 4 | - Cat sitting on the ground. 5 | bboxes: 6 | # Ensure that only first prompt is used in first phase 7 | - [0, 0, 1, 1] 8 | - [0, 0, 0, 0] 9 | - [0, 0, 0, 0] 10 | adjacency: 11 | - [1, 2] 12 | - [] 13 | - [] 14 | labeled_adjacency: 15 | - [[1, ["dog"]], [2, ["cat"]]] 16 | - [] 17 | - [] 18 | ref_images: 19 | - data/images/wiki/corgi.jpg 20 | - data/images/wiki/orange_cat.jpg 21 | ref_image_idxs: [1, 2] 22 | -------------------------------------------------------------------------------- /prompts/t2i/living_room.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | - A cozy living room with a large sofa and a painting hanging on the wall above it. 3 | - The sofa is a plush, deep blue with soft cushions and a textured throw draped over one side. 4 | - The painting depicts a banana and a apple on a wooden table. 5 | - Red banana on table. 6 | - Yellow apple on table. 7 | - A white cushion featuring floral pattern. 8 | bboxes: 9 | - [0, 0, 1, 1] 10 | - [0.15, 0.6, 0.9, 0.9] 11 | - [0.3, 0.1, 0.8, 0.45] 12 | - [0.35, 0.15, 0.6, 0.35] 13 | - [0.55, 0.25, 0.7, 0.4] 14 | - [0.3, 0.55, 0.6, 0.8] 15 | adjacency: 16 | - [1, 2] 17 | - [5] 18 | - [3, 4] 19 | - [] 20 | - [] 21 | - [] 22 | labeled_adjacency: 23 | - [[1, ["sofa"]], [2, ["painting"]]] 24 | - [[5, ["cushions"]]] 25 | - [[3, ["banana"]], [4, ["apple"]]] 26 | - [] 27 | - [] 28 | - [] 29 | -------------------------------------------------------------------------------- /prompts/t2i/living_room_graph_only.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | - A cozy living room with a large sofa and a painting hanging on the wall above it. 3 | - The sofa is a plush, deep blue with soft cushions and a textured throw draped over one side. 4 | - The painting is an abstract piece with vibrant colors, including splashes of red, yellow, and blue. 5 | - The living room has warm, ambient lighting from a nearby floor lamp, casting a soft glow on the sofa and painting. 6 | bboxes: 7 | # Ensure that only first prompt is used in first phase 8 | - [0, 0, 1, 1] 9 | - [0, 0, 0, 0] 10 | - [0, 0, 0, 0] 11 | - [0, 0, 0, 0] 12 | adjacency: 13 | - [1, 2, 3] 14 | - [] 15 | - [] 16 | - [] 17 | labeled_adjacency: 18 | - [[1, ["sofa"]], [2, ["painting"]], [3, ["living room"]]] 19 | - [] 20 | - [] 21 | - [] 22 | -------------------------------------------------------------------------------- /prompts/t2i/neg_default.yaml: -------------------------------------------------------------------------------- 1 | - low quality, worst quality 2 | -------------------------------------------------------------------------------- /prompts/t2i/person_graph_only.yaml: -------------------------------------------------------------------------------- 1 | - prompts: 2 | - A person wearing a t-shirt and shorts on street looking at viewer. 3 | - The t-shirt is white with a small logo on the chest. 4 | - The shorts are navy-blue, knee-length with side pockets and a subtle checkered pattern. 5 | bboxes: 6 | # Ensure that only first prompt is used in first phase 7 | - [0, 0, 1, 1] 8 | - [0, 0, 0, 0] 9 | - [0, 0, 0, 0] 10 | adjacency: 11 | - [1, 2] 12 | - [] 13 | - [] 14 | labeled_adjacency: 15 | - [[1, ["t-shirt"]], [2, ["shorts"]]] 16 | - [] 17 | - [] 18 | -------------------------------------------------------------------------------- /prompts/t2i/t2gbc_seed.txt: -------------------------------------------------------------------------------- 1 | A cozy library room with a large wooden bookshelf, a leather armchair, and a small reading table with an old lamp. 2 | A turtle sunbathing on a log in a quiet pond, with lily pads floating on the water. 3 | A frog in a mystical forest filled with oversized mushrooms. 4 | A steampunk-inspired workshop with gears on the walls and a mechanical cat. 5 | -------------------------------------------------------------------------------- /prompts/t2i/t2gbc_seed.yaml: -------------------------------------------------------------------------------- 1 | - A cozy library room with a large wooden bookshelf, a leather armchair, and a small reading table with an old lamp. 2 | - A turtle sunbathing on a log in a quiet pond, with lily pads floating on the water. 3 | - A frog in a mystical forest filled with oversized mushrooms. 4 | - A steampunk-inspired workshop with gears on the walls and a mechanical cat. 5 | -------------------------------------------------------------------------------- /prompts/t2i/t2gbc_seed_with_entity_specification.yaml: -------------------------------------------------------------------------------- 1 | - A cozy library room with a large wooden [bookshelf], a [leather armchair], and a small reading table with an old [lamp]. 2 | - A [turtle] sunbathing on a [log] in a quiet pond, with lily pads floating on the water. 3 | - A [frog] in a mystical forest filled with oversized mushrooms. 4 | - A steampunk-inspired workshop with gears on the walls and a [mechanical cat]. 5 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [project] 6 | name = "gbc" 7 | version = "0.1.0" 8 | description = "Captioning pipeline and data utilities for graph-based captioning." 9 | readme = "README.md" 10 | license = { file="LICENSE" } 11 | classifiers = [ 12 | "Programming Language :: Python :: 3", 13 | "License :: OSI Approved :: Apple Sample Code License", 14 | ] 15 | requires-python = ">=3.10" 16 | dependencies = [ 17 | "pandas", 18 | "pyarrow", 19 | "pydantic>=2.6.4", 20 | "tqdm>=4.66.2", 21 | "numpy<2.0.0", 22 | "transformers>=4.46.3", 23 | "Pillow>=10.2.0", 24 | "omegaconf>=2.3.0", 25 | "hydra-core>=1.3.2", 26 | ] 27 | 28 | [project.optional-dependencies] 29 | tests = [ 30 | "matplotlib", 31 | "objprint>=0.2.3", 32 | "opencv_python>=4.8.0.76", 33 | "opencv_python_headless>=4.9.0.80", 34 | ] 35 | processing = [ 36 | "torch>=2.2.2", 37 | "nltk", 38 | "detoxify>=0.5.2", 39 | "open_clip_torch>=2.24.0", 40 | ] 41 | t2i = [ 42 | "torch>=2.5.0", # >= 2.5.0 for flex attention 43 | "torchvision", 44 | "lightning", 45 | "einops", 46 | "diffusers>=0.30.0", 47 | "supervision", # For bbox annotation 48 | "scikit-image", # for segmentation 49 | "sentencepiece", # for tokenizer 50 | "accelerate", 51 | ] 52 | captioning = [ 53 | "torch>=2.2.2", 54 | "torchvision>=0.17.2", 55 | "lightning>=2.2.2", 56 | "vllm", 57 | "hbutils>=0.9.3", 58 | "sentence_transformers>=2.5.1", 59 | "scipy>=1.11.4", 60 | "supervision", 61 | ] 62 | 63 | [tool.setuptools] 64 | package-dir = {"" = "src"} 65 | include-package-data = true 66 | -------------------------------------------------------------------------------- /scripts/generation/gbc2i.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | import os 6 | from copy import deepcopy 7 | from omegaconf import OmegaConf 8 | from hydra.utils import instantiate 9 | 10 | from gbc.utils import setup_gbc_logger, load_list_from_file, save_list_to_file 11 | from gbc.data import GbcGraphFull 12 | from gbc.t2i import load_any, GbcPrompt 13 | from gbc.t2i.utils import truncate_or_pad_to_length 14 | 15 | 16 | def read_prompt_file(prompt_file, config): 17 | if prompt_file.endswith(".yaml"): 18 | prompts = OmegaConf.load(prompt_file) 19 | return prompts, [None] * len(prompts) 20 | gbc_graphs = load_list_from_file(prompt_file, GbcGraphFull) 21 | if "graph_transform" in config: 22 | graph_transform = instantiate(config.graph_transform) 23 | gbc_graphs = [graph_transform(gbc_graph) for gbc_graph in gbc_graphs] 24 | prompts = [GbcPrompt.from_gbc_graph(gbc_graph) for gbc_graph in gbc_graphs] 25 | return prompts, gbc_graphs 26 | 27 | 28 | def read_prompt_files(prompt_files, config): 29 | all_prompts = [] 30 | gbc_graphs = [] 31 | for prompt_file in prompt_files: 32 | prompts, graphs = read_prompt_file(prompt_file, config) 33 | all_prompts.extend(prompts) 34 | gbc_graphs.extend(graphs) 35 | return all_prompts, gbc_graphs 36 | 37 | 38 | if __name__ == "__main__": 39 | 40 | parser = argparse.ArgumentParser() 41 | parser.add_argument( 42 | "--configs", 43 | type=str, 44 | nargs="+", 45 | default=["configs/generation/gbc2i/sampling_base.yaml"], 46 | ) 47 | parser.add_argument( 48 | "--prompt_files", 49 | type=str, 50 | nargs="+", 51 | default=["prompts/t2i/t2gbc_seed.yaml"], 52 | ) 53 | parser.add_argument( 54 | "--neg_prompt_files", 55 | type=str, 56 | nargs="+", 57 | default=["prompts/t2i/neg_default.yaml"], 58 | ) 59 | parser.add_argument( 60 | "--save_dir", 61 | type=str, 62 | default=None, 63 | ) 64 | args = parser.parse_args() 65 | 66 | configs = [] 67 | for config in args.configs: 68 | conf = OmegaConf.load(config) 69 | configs.append(conf) 70 | config = OmegaConf.merge(*configs) 71 | 72 | if args.save_dir is None: 73 | assert "save_dir" in config, ( 74 | "'save_dir' must be specified either as a " 75 | "command line argument or in the config file" 76 | ) 77 | args.save_dir = config.save_dir 78 | 79 | pos_prompts, gbc_graphs = read_prompt_files(args.prompt_files, config) 80 | neg_prompts, _ = read_prompt_files(args.neg_prompt_files, config) 81 | 82 | logger = setup_gbc_logger() 83 | unet = load_any(config.model_config.unet) 84 | te = load_any(config.model_config.te) 85 | vae = load_any(config.model_config.vae) 86 | 87 | if "image_encoder" in config.model_config: 88 | image_encoder = load_any(config.model_config.image_encoder) 89 | add_kwargs = {"image_encoder": image_encoder} 90 | else: 91 | add_kwargs = {} 92 | 93 | sampling_func = instantiate(config.image_sampling_func) 94 | images = sampling_func( 95 | unet=unet, 96 | te=te, 97 | vae=vae, 98 | prompts=pos_prompts, 99 | neg_prompts=neg_prompts, 100 | **add_kwargs, 101 | ) 102 | 103 | os.makedirs(args.save_dir, exist_ok=True) 104 | for i, image in enumerate(images): 105 | image.save(os.path.join(args.save_dir, f"{i}.png")) 106 | logger.info(f"Saved {len(images)} images to {args.save_dir}") 107 | 108 | num_samples = config.image_sampling_func.get("num_samples", len(pos_prompts)) 109 | padding_mode = config.image_sampling_func.get("padding_mode", "cycling") 110 | img_idxs = truncate_or_pad_to_length( 111 | list(range(len(gbc_graphs))), num_samples, padding_mode 112 | ) 113 | gbc_graphs_generated = [] 114 | for idx in range(num_samples): 115 | img_path = os.path.join(args.save_dir, f"{idx}.png") 116 | img_url = None 117 | gbc_graph = gbc_graphs[img_idxs[idx]] 118 | if gbc_graph is not None: 119 | gbc_graph = deepcopy(gbc_graph).to_gbc_graph() 120 | gbc_graph.img_path = img_path 121 | gbc_graphs_generated.append(gbc_graph) 122 | if gbc_graphs_generated: 123 | save_list_to_file( 124 | gbc_graphs_generated, 125 | os.path.join(args.save_dir, "gbc_graphs_generated.parquet"), 126 | ) 127 | logger.info("Gbc graphs with paths to generated images saved") 128 | -------------------------------------------------------------------------------- /scripts/processing/process_gbc.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import argparse 5 | import os 6 | from omegaconf import OmegaConf 7 | from hydra.utils import instantiate 8 | 9 | from gbc.utils import setup_gbc_logger 10 | from gbc.data import GbcGraphFull 11 | from gbc.processing import local_process_data 12 | 13 | 14 | if __name__ == "__main__": 15 | parser = argparse.ArgumentParser( 16 | description="Convert JSONL/JSON/Parquet files to JSON/JSONL/Parquet." 17 | ) 18 | parser.add_argument( 19 | "--input_paths", 20 | nargs="+", 21 | default=None, 22 | help="List of input files or directories.", 23 | ) 24 | parser.add_argument( 25 | "--input_formats", 26 | nargs="+", 27 | default=None, 28 | help="List of input formats to look for (e.g., .json, .jsonl, .parquet).", 29 | ) 30 | parser.add_argument( 31 | "--save_format", 32 | default=None, 33 | help="Desired output format (e.g., .json, .jsonl, .parquet).", 34 | ) 35 | parser.add_argument( 36 | "--save_dir", 37 | default=None, 38 | help="Directory to save the converted files.", 39 | ) 40 | parser.add_argument( 41 | "--configs", 42 | type=str, 43 | nargs="*", 44 | default=None, 45 | help="List of configs to be used. Latter ones override former ones.", 46 | ) 47 | args = parser.parse_args() 48 | 49 | if args.configs is not None: 50 | configs = [] 51 | for config in args.configs: 52 | conf = OmegaConf.load(config) 53 | configs.append(conf) 54 | config = OmegaConf.merge(*configs) 55 | else: 56 | config = OmegaConf.create() 57 | if "processing_config" in config: 58 | config = config.processing_config 59 | config = instantiate(config) 60 | 61 | for key, value in vars(args).items(): 62 | if key == "configs": 63 | continue 64 | if value is not None: 65 | config[key] = value 66 | assert key in config, f"{key} not found in neither args nor config" 67 | 68 | setup_gbc_logger() 69 | os.makedirs(config.save_dir, exist_ok=True) 70 | 71 | data_transform = config.get("data_transform", None) 72 | name_transform = config.get("name_transform", None) 73 | 74 | local_process_data( 75 | config.input_paths, 76 | save_dir=config.save_dir, 77 | save_format=config.save_format, 78 | input_formats=config.input_formats, 79 | data_class=GbcGraphFull, 80 | data_transform=data_transform, 81 | name_transform=name_transform, 82 | ) 83 | -------------------------------------------------------------------------------- /scripts/setup/download_llava_models.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | hf_hub_download( 10 | repo_id="cmp-nct/llava-1.6-gguf", 11 | filename="mmproj-mistral7b-f16-q6_k.gguf", 12 | local_dir="models/LLaVA-1.6/", 13 | ) 14 | hf_hub_download( 15 | repo_id="cmp-nct/llava-1.6-gguf", 16 | filename="ggml-mistral-q_4_k.gguf", 17 | local_dir="models/LLaVA-1.6/", 18 | ) 19 | hf_hub_download( 20 | repo_id="cmp-nct/llava-1.6-gguf", 21 | filename="mmproj-llava-34b-f16-q6_k.gguf", 22 | local_dir="models/LLaVA-1.6/", 23 | ) 24 | hf_hub_download( 25 | repo_id="cmp-nct/llava-1.6-gguf", 26 | filename="ggml-yi-34b-f16-q_3_k.gguf", 27 | local_dir="models/LLaVA-1.6/", 28 | ) 29 | -------------------------------------------------------------------------------- /scripts/setup/download_yolo_world_models.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from huggingface_hub import hf_hub_download 5 | 6 | 7 | if __name__ == "__main__": 8 | 9 | hf_hub_download( 10 | repo_id="wondervictor/YOLO-World", 11 | filename="yolo_world_v2_x_obj365v1_goldg_cc3mlite_pretrain-8698fbfa.pth", 12 | local_dir="models/yolo_world/", 13 | ) 14 | hf_hub_download( 15 | repo_id="wondervictor/YOLO-World", 16 | filename="yolo_world_v2_l_obj365v1_goldg_cc3mv2_pretrain-2f3a4a22.pth", 17 | local_dir="models/yolo_world/", 18 | ) 19 | -------------------------------------------------------------------------------- /scripts/setup/setup_llava_query.sh: -------------------------------------------------------------------------------- 1 | CMAKE_ARGS="-DLLAMA_CUDA=on" python -m pip install llama-cpp-python==0.2.79 2 | python -m pip install huggingface_hub 3 | python scripts/setup/download_llava_models.py 4 | -------------------------------------------------------------------------------- /scripts/setup/setup_yolo_world_detection.sh: -------------------------------------------------------------------------------- 1 | # Install YOLO-World and get configs 2 | git clone --recursive https://github.com/AILab-CVC/YOLO-World 3 | cd YOLO-World 4 | git checkout b449b98202e931590513c16e4830318be2dde946 5 | python -m pip install . 6 | cd .. 7 | 8 | # https://github.com/open-mmlab/mmdetection/issues/12008 9 | # python -m pip install torch==2.4.0 torchvision==0.19.0 ## Needed if no hacking_mmengine_history 10 | 11 | # There is some issue in mmcv and mmdet version, so use mim for installation 12 | # We get warning "yolo-world 0.1.0 requires mmdet==3.0.0, but you have mmdet 3.3.0 which is incompatible" but this should be fine 13 | # see https://github.com/AILab-CVC/YOLO-World/issues/364 and https://github.com/AILab-CVC/YOLO-World/issues/279 14 | python -m pip install -U openmim 15 | mim install mmcv==2.1.0 16 | mim install mmdet==3.3.0 17 | 18 | python -m pip install huggingface_hub 19 | python scripts/setup/download_yolo_world_models.py 20 | -------------------------------------------------------------------------------- /src/gbc/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/src/gbc/__init__.py -------------------------------------------------------------------------------- /src/gbc/captioning/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .pipeline import ( 5 | GbcPipeline, 6 | run_queries, 7 | resume_captioning, 8 | run_gbc_captioning, 9 | ) 10 | from .auto_actions import ( 11 | AutoImageQuery, 12 | AutoEntityQuery, 13 | AutoRelationQuery, 14 | AutoCompositionQuery, 15 | AutoDetectionActionFromImage, 16 | AutoDetectionActionFromEntity, 17 | ) 18 | from .primitives import get_action_input_from_img_path 19 | -------------------------------------------------------------------------------- /src/gbc/captioning/conversion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/src/gbc/captioning/conversion/__init__.py -------------------------------------------------------------------------------- /src/gbc/captioning/detection/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/src/gbc/captioning/detection/__init__.py -------------------------------------------------------------------------------- /src/gbc/captioning/detection/grounding_dino.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from typing import Optional 5 | from PIL import Image 6 | from functools import cache 7 | 8 | import numpy as np 9 | import torch 10 | from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 11 | 12 | from gbc.utils import get_gbc_logger 13 | from .detection import Detection 14 | 15 | 16 | @cache 17 | def load_grounding_dino( 18 | model_name: str = "IDEA-Research/grounding-dino-tiny", device: Optional[str] = None 19 | ): 20 | logger = get_gbc_logger() 21 | logger.info("Load GroundingDINO model...") 22 | if device is None: 23 | device = "cuda" if torch.cuda.is_available() else "cpu" 24 | processor = AutoProcessor.from_pretrained(model_name) 25 | model = AutoModelForZeroShotObjectDetection.from_pretrained(model_name).to(device) 26 | return processor, model 27 | 28 | 29 | class GroundingDinoDetection(Detection): 30 | """ 31 | Detection wrapper for 32 | `GroudingDINO `_ model. 33 | 34 | .. note:: 35 | The loaded model is cached in memory so that repeated instantiations of 36 | the class with the same parameters would reuse the same model. 37 | 38 | Attributes 39 | ---------- 40 | model_name: str, default="IDEA-Research/grounding-dino-tiny" 41 | The name of the GroundingDINO model to use. 42 | device: Optional[str], default=None 43 | The device to use for inference. 44 | box_threshold: float, default=0.25 45 | The bbox threshold for ``processor.post_process_grounded_object_detection``. 46 | """ 47 | 48 | def __init__( 49 | self, 50 | model_name: str = "IDEA-Research/grounding-dino-tiny", 51 | device: Optional[str] = None, 52 | box_threshold: float = 0.25, 53 | ): 54 | if device is None: 55 | device = "cuda" if torch.cuda.is_available() else "cpu" 56 | self.processor, self.model = load_grounding_dino(model_name, device) 57 | self.device = device 58 | self.box_threshold = box_threshold 59 | 60 | def detect_core( 61 | self, image: np.ndarray, texts: list[str] 62 | ) -> tuple[list[tuple[int, int, int, int]], list[float], list[int]]: 63 | 64 | image = Image.fromarray(image) 65 | texts = [text.strip(".") + "." for text in texts] 66 | inputs = self.processor( 67 | images=[image for _ in range(len(texts))], 68 | text=texts, 69 | return_tensors="pt", 70 | padding=True, 71 | ).to(self.device) 72 | with torch.no_grad(): 73 | outputs = self.model(**inputs) 74 | # The problem of GroundingDINO is that it returns probability token by token 75 | # We encode and compute for each text separately and take the maximum token 76 | # score for each piece of text 77 | results = self.processor.post_process_grounded_object_detection( 78 | outputs, 79 | inputs.input_ids, 80 | box_threshold=self.box_threshold, 81 | # This does not matter as we do not use returned label 82 | text_threshold=self.box_threshold, 83 | target_sizes=[image.size[::-1] for _ in range(len(texts))], 84 | ) 85 | scores = [] 86 | bboxes = [] 87 | labels = [] 88 | for label, result in enumerate(results): 89 | scores.extend(result["scores"].tolist()) 90 | bboxes.extend(result["boxes"].tolist()) 91 | labels.extend([label] * len(result["scores"])) 92 | return bboxes, scores, labels 93 | -------------------------------------------------------------------------------- /src/gbc/captioning/detection/hack_mmengine_registry.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | """ 5 | Workaround to make mmengine work with torch >= 2.5.0 6 | Taken from https://github.com/open-mmlab/mmdetection/issues/12008 7 | """ 8 | 9 | import logging 10 | 11 | from mmengine.registry import Registry 12 | from mmengine.logging import print_log 13 | from typing import Type, Optional, Union, List 14 | 15 | 16 | def _register_module( 17 | self, 18 | module: Type, 19 | module_name: Optional[Union[str, List[str]]] = None, 20 | force: bool = False, 21 | ) -> None: 22 | """Register a module. 23 | 24 | Args: 25 | module (type): Module to be registered. Typically a class or a 26 | function, but generally all ``Callable`` are acceptable. 27 | module_name (str or list of str, optional): The module name to be 28 | registered. If not specified, the class name will be used. 29 | Defaults to None. 30 | force (bool): Whether to override an existing class with the same 31 | name. Defaults to False. 32 | """ 33 | if not callable(module): 34 | raise TypeError(f"module must be Callable, but got {type(module)}") 35 | 36 | if module_name is None: 37 | module_name = module.__name__ 38 | if isinstance(module_name, str): 39 | module_name = [module_name] 40 | for name in module_name: 41 | if not force and name in self._module_dict: 42 | existed_module = self.module_dict[name] 43 | # raise KeyError(f'{name} is already registered in {self.name} ' 44 | # f'at {existed_module.__module__}') 45 | print_log( 46 | f"{name} is already registered in {self.name} " 47 | f"at {existed_module.__module__}. Registration ignored.", 48 | logger="current", 49 | level=logging.INFO, 50 | ) 51 | self._module_dict[name] = module 52 | 53 | 54 | Registry._register_module = _register_module 55 | -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/src/gbc/captioning/mllm/__init__.py -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/llava/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .llava_base import load_llava_model, llava_query_single 5 | from .llava_queries import * 6 | -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/llava/llava_queries.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import traceback 5 | from typing import Optional, Union 6 | from PIL import Image 7 | 8 | from gbc.utils import get_gbc_logger 9 | 10 | from .llava_base import load_llava_model, llava_query_single 11 | from ..query_prototype import ( 12 | MllmQueryPrototype, 13 | ImageQuery, 14 | EntityQuery, 15 | RelationQuery, 16 | CompositionQuery, 17 | ) 18 | from ...primitives import QueryResult 19 | 20 | 21 | __all__ = [ 22 | "LlavaQueryAction", 23 | "LlavaImageQuery", 24 | "LlavaEntityQuery", 25 | "LlavaCompositionQuery", 26 | "LlavaRelationQuery", 27 | ] 28 | 29 | 30 | class LlavaQueryAction(MllmQueryPrototype): 31 | """ 32 | Base class for all LLaVA query actions. 33 | """ 34 | 35 | def load_model(self, **kwargs): 36 | return load_llava_model(**kwargs) 37 | 38 | def query_prelim( 39 | self, 40 | image: Image.Image, 41 | filled_in_query: Optional[Union[str, list[str], tuple[str]]] = None, 42 | ) -> QueryResult: 43 | try: 44 | query_output = llava_query_single( 45 | self.query_model, 46 | image, 47 | self.query_message, 48 | filled_in_query=filled_in_query, 49 | system_message=self.system_message, 50 | temperature=self.query_kwargs.pop("temperature", 0.1), 51 | **self.query_kwargs, 52 | ) 53 | except Exception as e: 54 | logger = get_gbc_logger() 55 | logger.warning(f"Failed to query: {e}") 56 | traceback.print_exc() 57 | query_output = "" 58 | return query_output 59 | 60 | 61 | class LlavaImageQuery(ImageQuery, LlavaQueryAction): 62 | pass 63 | 64 | 65 | class LlavaEntityQuery(EntityQuery, LlavaQueryAction): 66 | pass 67 | 68 | 69 | class LlavaRelationQuery(RelationQuery, LlavaQueryAction): 70 | pass 71 | 72 | 73 | class LlavaCompositionQuery(CompositionQuery, LlavaQueryAction): 74 | pass 75 | -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/pixtral/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .pixtral_base import load_pixtral_model, pixtral_query_single 5 | from .pixtral_queries import * 6 | -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/pixtral/pixtral_base.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from functools import cache 5 | from typing import Optional, Union 6 | from PIL import Image 7 | 8 | from vllm import LLM, SamplingParams 9 | 10 | from gbc.utils import get_gbc_logger 11 | 12 | 13 | @cache 14 | def load_pixtral_model( 15 | model_name: str = "nm-testing/pixtral-12b-FP8-dynamic", 16 | max_num_seqs: int = 1, 17 | enforce_eager: bool = True, 18 | max_model_len: int = 8192, 19 | **kwargs, 20 | ): 21 | logger = get_gbc_logger() 22 | logger.info(f"Load Pixtral model from {model_name} ...") 23 | llm = LLM( 24 | model=model_name, 25 | max_num_seqs=max_num_seqs, 26 | enforce_eager=enforce_eager, 27 | max_model_len=max_model_len, 28 | **kwargs, 29 | ) 30 | return llm 31 | 32 | 33 | def pixtral_query_single( 34 | model: LLM, 35 | image: Image.Image, 36 | query: str, 37 | filled_in_query: Optional[Union[str, list[str], tuple[str]]] = None, 38 | system_message: Optional[str] = None, 39 | temperature: float = 0.1, 40 | max_tokens: Optional[int] = None, 41 | verbose: bool = False, 42 | **kwargs, 43 | ): 44 | if filled_in_query: 45 | if isinstance(filled_in_query, tuple) or isinstance(filled_in_query, list): 46 | query = query.format(*filled_in_query) 47 | else: 48 | query = query.format(filled_in_query) 49 | if system_message is not None: 50 | query = system_message + "\n" + query 51 | 52 | if verbose: 53 | logger = get_gbc_logger() 54 | logger.debug(f"Filled in query: {filled_in_query}") 55 | logger.debug(f"Query: {query}") 56 | 57 | inputs = { 58 | "prompt": f"[INST]{query}.\n[IMG][/INST]", 59 | "multi_modal_data": {"image": [image]}, 60 | } 61 | sampling_params = SamplingParams( 62 | temperature=temperature, max_tokens=max_tokens, **kwargs 63 | ) 64 | outputs = model.generate(inputs, sampling_params=sampling_params, use_tqdm=False) 65 | output = outputs[0].outputs[0].text 66 | 67 | return output 68 | -------------------------------------------------------------------------------- /src/gbc/captioning/mllm/pixtral/pixtral_queries.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import traceback 5 | from typing import Optional, Union 6 | from PIL import Image 7 | 8 | from gbc.utils import get_gbc_logger 9 | 10 | from .pixtral_base import load_pixtral_model, pixtral_query_single 11 | from ..query_prototype import ( 12 | MllmQueryPrototype, 13 | ImageQuery, 14 | EntityQuery, 15 | RelationQuery, 16 | CompositionQuery, 17 | ) 18 | from ...primitives import QueryResult 19 | 20 | 21 | __all__ = [ 22 | "PixtralQueryAction", 23 | "PixtralImageQuery", 24 | "PixtralEntityQuery", 25 | "PixtralCompositionQuery", 26 | "PixtralRelationQuery", 27 | ] 28 | 29 | 30 | class PixtralQueryAction(MllmQueryPrototype): 31 | """ 32 | Base class for all Pixtral query actions. 33 | """ 34 | 35 | def load_model(self, **kwargs): 36 | return load_pixtral_model(**kwargs) 37 | 38 | def query_prelim( 39 | self, 40 | image: Image.Image, 41 | filled_in_query: Optional[Union[str, list[str], tuple[str]]] = None, 42 | ) -> QueryResult: 43 | try: 44 | query_output = pixtral_query_single( 45 | self.query_model, 46 | image, 47 | self.query_message, 48 | filled_in_query=filled_in_query, 49 | system_message=self.system_message, 50 | temperature=self.query_kwargs.pop("temperature", 0.1), 51 | **self.query_kwargs, 52 | ) 53 | except Exception as e: 54 | logger = get_gbc_logger() 55 | logger.warning(f"Failed to query: {e}") 56 | traceback.print_exc() 57 | query_output = "" 58 | return query_output 59 | 60 | 61 | class PixtralImageQuery(ImageQuery, PixtralQueryAction): 62 | pass 63 | 64 | 65 | class PixtralEntityQuery(EntityQuery, PixtralQueryAction): 66 | pass 67 | 68 | 69 | class PixtralRelationQuery(RelationQuery, PixtralQueryAction): 70 | pass 71 | 72 | 73 | class PixtralCompositionQuery(CompositionQuery, PixtralQueryAction): 74 | pass 75 | -------------------------------------------------------------------------------- /src/gbc/captioning/pipeline/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .pipeline import GbcPipeline 5 | from .pipeline_functional import * 6 | -------------------------------------------------------------------------------- /src/gbc/captioning/pipeline/pipeline_functional.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | """ 5 | This module implements the functional interface of the GBC pipeline. 6 | The arguments of :class:`~gbc.captioning.pipeline.pipeline.GbcPipeline` can be 7 | passed as keyword arguments to the functions defined in this module. 8 | """ 9 | 10 | from typing import Optional, Union 11 | from omegaconf import DictConfig 12 | 13 | from .pipeline import GbcPipeline 14 | from ..primitives import Action, ActionInputPair, NodeInfo 15 | 16 | 17 | def run_gbc_captioning( 18 | img_files_or_folders: Union[str, list[str]], 19 | captioning_cfg: DictConfig, 20 | *, 21 | attempt_resume: bool = True, 22 | return_raw_results: bool = False, 23 | **kwargs, 24 | ): 25 | """ 26 | Functional wrapper for 27 | :meth:`gbc.captioning.pipeline.pipeline.GbcPipeline.run_gbc_captioning`. 28 | """ 29 | gbc_pipeline = GbcPipeline.from_config(captioning_cfg, **kwargs) 30 | return gbc_pipeline.run_gbc_captioning( 31 | img_files_or_folders, 32 | attempt_resume=attempt_resume, 33 | return_raw_results=return_raw_results, 34 | ) 35 | 36 | 37 | def run_image_entity_captioning( 38 | img_files_or_folders: Union[str, list[str]], 39 | captioning_cfg: DictConfig, 40 | *, 41 | node_infos: list[NodeInfo] = None, 42 | completed_actions: list[ActionInputPair] = None, 43 | tqdm_desc: Optional[str] = None, 44 | return_raw_results: bool = False, 45 | **kwargs, 46 | ): 47 | """ 48 | Functional wrapper for 49 | :meth:`gbc.captioning.pipeline.pipeline.GbcPipeline.run_image_entity_captioning`. 50 | """ 51 | gbc_pipeline = GbcPipeline.from_config(captioning_cfg, **kwargs) 52 | return gbc_pipeline.run_image_entity_captioning( 53 | img_files_or_folders, 54 | node_infos=node_infos, 55 | completed_actions=completed_actions, 56 | tqdm_desc=tqdm_desc, 57 | return_raw_results=return_raw_results, 58 | ) 59 | 60 | 61 | def run_relational_captioning( 62 | node_infos: list[NodeInfo], 63 | captioning_cfg: DictConfig, 64 | *, 65 | completed_actions: list[Action] = None, 66 | tqdm_desc: Optional[str] = None, 67 | return_raw_results: bool = False, 68 | **kwargs, 69 | ): 70 | """ 71 | Functional wrapper for 72 | :meth:`gbc.captioning.pipeline.pipeline.GbcPipeline.run_relational_captioning`. 73 | """ 74 | gbc_pipeline = GbcPipeline.from_config(captioning_cfg, **kwargs) 75 | return gbc_pipeline.run_relational_captioning( 76 | node_infos=node_infos, 77 | completed_actions=completed_actions, 78 | tqdm_desc=tqdm_desc, 79 | return_raw_results=return_raw_results, 80 | ) 81 | 82 | 83 | def resume_captioning( 84 | save_dir: str, 85 | captioning_cfg: DictConfig, 86 | *, 87 | recursive: bool = True, 88 | return_raw_results: bool = False, 89 | **kwargs, 90 | ): 91 | """ 92 | Functional wrapper for 93 | :meth:`gbc.captioning.pipeline.pipeline.GbcPipeline.resume_captioning`. 94 | """ 95 | gbc_pipeline = GbcPipeline.from_config(captioning_cfg, save_dir=save_dir, **kwargs) 96 | return gbc_pipeline.resume_captioning( 97 | recursive=recursive, 98 | return_raw_results=return_raw_results, 99 | ) 100 | 101 | 102 | def run_queries( 103 | action_input_pairs: list[ActionInputPair], 104 | captioning_cfg: DictConfig, 105 | *, 106 | node_infos: Optional[list[NodeInfo]] = None, 107 | completed_actions: Optional[list[ActionInputPair]] = None, 108 | recursive: bool = True, 109 | init_queried_nodes_from_node_infos: bool = True, 110 | tqdm_desc: Optional[str] = None, 111 | return_raw_results: bool = False, 112 | **kwargs, 113 | ): 114 | """ 115 | Functional wrapper for 116 | :meth:`gbc.captioning.pipeline.pipeline.GbcPipeline.run_queries`. 117 | """ 118 | gbc_pipeline = GbcPipeline.from_config(captioning_cfg, **kwargs) 119 | return gbc_pipeline.run_queries( 120 | action_input_pairs, 121 | node_infos=node_infos, 122 | completed_actions=completed_actions, 123 | recursive=recursive, 124 | init_queried_nodes_from_node_infos=init_queried_nodes_from_node_infos, 125 | tqdm_desc=tqdm_desc, 126 | return_raw_results=return_raw_results, 127 | ) 128 | -------------------------------------------------------------------------------- /src/gbc/captioning/primitives/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .io_unit import * 5 | from .action import * 6 | from .action_io import * 7 | -------------------------------------------------------------------------------- /src/gbc/captioning/primitives/io_unit.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import re 5 | from typing import Optional, Literal 6 | from pydantic import BaseModel 7 | 8 | 9 | from gbc.data import Description 10 | 11 | 12 | __all__ = ["EntityInfo", "QueryResult", "RefPosition", "find_ref_poss"] 13 | 14 | 15 | class RefPosition(BaseModel): 16 | """ 17 | Represents a reference to a specific segment within a list of text descriptions. 18 | 19 | Attributes 20 | ---------- 21 | desc_index : int 22 | Index of the target description within the list. 23 | start : int 24 | Start position of the text segment within the target description. 25 | end : int 26 | End position of the text segment within the target description. 27 | """ 28 | 29 | desc_index: int 30 | start: int 31 | end: int 32 | 33 | 34 | def find_ref_poss(desc_texts: list[str], entity_text: str) -> list[RefPosition]: 35 | """ 36 | Find all positions of the entity text within the provided description texts. 37 | 38 | Parameters 39 | ---------- 40 | desc_texts 41 | List of description texts to search within. 42 | entity_text 43 | The entity text to find within the description texts. 44 | 45 | Returns 46 | ------- 47 | list[RefPosition] 48 | List of reference positions where the entity text occurs in the 49 | description texts. 50 | """ 51 | ref_poss = [] 52 | # Create a pattern that matches the entity text only 53 | # if it's surrounded by word boundaries 54 | pattern = r"\b" + re.escape(entity_text) + r"\b" 55 | 56 | for idx, desc_text in enumerate(desc_texts): 57 | # Find all occurrences of the entity text in the description text, 58 | # considering word boundaries 59 | for match in re.finditer(pattern, desc_text, re.IGNORECASE): 60 | start = match.start() 61 | end = match.end() 62 | ref_pos = RefPosition(desc_index=idx, start=start, end=end) 63 | ref_poss.append(ref_pos) 64 | 65 | return ref_poss 66 | 67 | 68 | class EntityInfo(BaseModel): 69 | """ 70 | Stores information about entities including their label and text. 71 | 72 | Attributes 73 | ---------- 74 | label : Literal["image", "entity", "single", "multiple", "relation", "composition"] 75 | Label indicating the type of entity. 76 | 77 | - In ``action_input.entity_info``, this translates to node label. 78 | - In ``action_input.entities``, the label ``single`` and ``multiple`` affect 79 | how detected bounding boxes are post processed. 80 | text : str | None 81 | Text associated with the entity. 82 | The texts from ``entities`` in 83 | :class:`~gbc.captioning.primitives.io_unit.QueryResult` 84 | are later parsed as edge labels. 85 | entity_id : str 86 | Identifier for the entity. 87 | """ 88 | 89 | label: Literal["image", "entity", "single", "multiple", "relation", "composition"] 90 | text: Optional[str] = None 91 | entity_id: str = "" 92 | 93 | 94 | class QueryResult(BaseModel): 95 | """ 96 | Represents the result of a query, including descriptions and entities. 97 | 98 | Attributes 99 | ---------- 100 | descs : list of tuple of (Description, list of str) 101 | List of descriptions and their associated entity ids. 102 | This second part is only useful for relation queries, in which case 103 | it is used to indicate which entities are involved in the relation. 104 | entities : list of tuple of (EntityInfo, list of RefPosition) 105 | List of entities and their reference positions within the descriptions. 106 | raw : str | None 107 | Raw query result data. 108 | """ 109 | 110 | descs: list[tuple[Description, list[str]]] 111 | entities: list[tuple[EntityInfo, list[RefPosition]]] 112 | raw: Optional[str] = None 113 | -------------------------------------------------------------------------------- /src/gbc/data/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .bbox import Bbox 5 | from .caption import Description, Caption 6 | from .graph import GbcGraph, GbcGraphFull 7 | -------------------------------------------------------------------------------- /src/gbc/data/bbox/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .bbox import * 5 | -------------------------------------------------------------------------------- /src/gbc/data/bbox/annotate.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import numpy as np 5 | import supervision as sv 6 | 7 | from .bbox import Bbox 8 | 9 | 10 | def annotate_bboxes( 11 | image: np.ndarray, 12 | bboxes: list[Bbox], 13 | include_labels: bool = True, 14 | ): 15 | image_width, image_height = image.shape[1], image.shape[0] 16 | ref_length = (image_width + image_height) / 2 17 | thickness = int(ref_length / 100) 18 | text_thickness = int(thickness / 2) 19 | text_scale = ref_length / 500 20 | bounding_box_annotator = sv.BoxAnnotator(thickness=thickness, color=sv.Color.GREEN) 21 | label_annotator = sv.LabelAnnotator( 22 | text_thickness=text_thickness, 23 | text_position=sv.Position.CENTER, 24 | color=sv.Color.GREEN, 25 | text_scale=text_scale, 26 | text_padding=thickness, 27 | text_color=sv.Color.BLACK, 28 | ) 29 | xyxy = [] 30 | confidence = [] 31 | for bbox in bboxes: 32 | if not isinstance(bbox, Bbox): 33 | bbox = Bbox.model_validate(bbox) 34 | xyxy.append( 35 | ( 36 | bbox.left * image_width, 37 | bbox.top * image_height, 38 | bbox.right * image_width, 39 | bbox.bottom * image_height, 40 | ) 41 | ) 42 | confidence.append(bbox.confidence) 43 | detections = sv.Detections( 44 | xyxy=np.array(xyxy), 45 | confidence=np.array(confidence), 46 | class_id=np.zeros(len(bboxes), dtype=int), 47 | ) 48 | annotated_image = image.copy() 49 | annotated_image = bounding_box_annotator.annotate(annotated_image, detections) 50 | if include_labels: 51 | annotated_image = label_annotator.annotate( 52 | annotated_image, 53 | detections, 54 | labels=[str(i) for i in range(1, len(detections) + 1)], 55 | ) 56 | return annotated_image 57 | 58 | 59 | def annotate_all_labels(image: np.ndarray, labeled_bboxes: list[tuple[str, Bbox]]): 60 | if len(labeled_bboxes) == 0: 61 | return image 62 | image_width, image_height = image.shape[1], image.shape[0] 63 | ref_length = min(image_width, image_height) 64 | thickness = int(ref_length / 100) 65 | text_thickness = int(thickness / 2) 66 | text_scale = ref_length / 800 67 | bounding_box_annotator = sv.BoxAnnotator(thickness=thickness) 68 | label_annotator = sv.LabelAnnotator( 69 | text_thickness=text_thickness, 70 | text_scale=text_scale, 71 | text_padding=thickness, 72 | text_color=sv.Color.BLACK, 73 | text_position=sv.Position.TOP_RIGHT, 74 | ) 75 | xyxy = [] 76 | labels = [] 77 | confidence = [] 78 | class_id = [] 79 | text_to_id = dict() 80 | current_id = 0 81 | for text, bbox in labeled_bboxes: 82 | if not isinstance(bbox, Bbox): 83 | bbox = Bbox.model_validate(bbox) 84 | xyxy.append( 85 | ( 86 | bbox.left * image_width, 87 | bbox.top * image_height, 88 | bbox.right * image_width, 89 | bbox.bottom * image_height, 90 | ) 91 | ) 92 | confidence.append(bbox.confidence) 93 | if text not in text_to_id: 94 | text_to_id[text] = current_id 95 | current_id += 1 96 | labels.append(text) 97 | class_id.append(text_to_id[text]) 98 | detections = sv.Detections( 99 | xyxy=np.array(xyxy), 100 | confidence=np.array(confidence), 101 | class_id=np.array(class_id), 102 | ) 103 | annotated_image = image.copy() 104 | annotated_image = bounding_box_annotator.annotate(annotated_image, detections) 105 | annotated_image = label_annotator.annotate( 106 | annotated_image, detections, labels=labels 107 | ) 108 | return annotated_image 109 | -------------------------------------------------------------------------------- /src/gbc/data/caption/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .caption import * 5 | -------------------------------------------------------------------------------- /src/gbc/data/graph/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .gbc_graph import GbcGraph, GbcEdge, GbcVertex 5 | from .gbc_graph_full import GbcGraphFull, GbcVertexFull 6 | -------------------------------------------------------------------------------- /src/gbc/data/graph/gbc_graph.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import os 5 | from PIL import Image 6 | from typing import Literal, Optional 7 | from pydantic import BaseModel 8 | 9 | from gbc.utils import ImageCache 10 | from ..bbox import Bbox 11 | from ..caption import Description 12 | 13 | 14 | class GbcEdge(BaseModel): 15 | # Source and targets are vertices id 16 | source: str 17 | text: str 18 | target: str 19 | 20 | 21 | class GbcVertex(BaseModel): 22 | vertex_id: str 23 | bbox: Bbox 24 | label: Literal["image", "entity", "composition", "relation"] 25 | descs: list[Description] 26 | in_edges: list[GbcEdge] = [] 27 | out_edges: list[GbcEdge] = [] 28 | 29 | 30 | class GbcGraph(BaseModel): 31 | 32 | vertices: list[GbcVertex] 33 | img_url: Optional[str] = None 34 | img_path: Optional[str] = None 35 | original_caption: Optional[str] = None 36 | short_caption: Optional[str] = None 37 | detail_caption: Optional[str] = None 38 | 39 | _image_cache: Optional[ImageCache] = None 40 | 41 | def model_post_init(self, context): 42 | for vertex in self.vertices: 43 | if vertex.label == "image": 44 | for desc in vertex.descs: 45 | if desc.label == "original" and self.original_caption is None: 46 | self.original_caption = desc.text 47 | elif desc.label == "short" and self.short_caption is None: 48 | self.short_caption = desc.text 49 | elif desc.label == "detail" and self.detail_caption is None: 50 | self.detail_caption = desc.text 51 | break 52 | 53 | def get_image(self, img_root_dir: str = "") -> Optional[Image.Image]: 54 | if self._image_cache is None: 55 | self._image_cache = ImageCache( 56 | img_path=os.path.join(img_root_dir, self.img_path), 57 | img_url=self.img_url, 58 | ) 59 | return self._image_cache.get_image() 60 | -------------------------------------------------------------------------------- /src/gbc/processing/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .local_process import local_process_data 5 | from .meta_process import meta_process_data 6 | -------------------------------------------------------------------------------- /src/gbc/processing/data_transforms/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .basic_transforms import * 5 | from .function_transforms import * 6 | 7 | try: 8 | from .clip_scoring import compute_clip_scores 9 | from .clip_filtering import gbc_clip_filter 10 | from .toxicity_scoring import compute_toxicity_scores 11 | except ModuleNotFoundError: 12 | pass 13 | -------------------------------------------------------------------------------- /src/gbc/processing/data_transforms/function_transforms.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from collections.abc import Callable 5 | from functools import reduce 6 | 7 | 8 | def create_list_transform(transform_function: Callable) -> Callable: 9 | """ 10 | Returns a function that applies a given transformation function 11 | to a list of objects. 12 | 13 | Parameters 14 | ---------- 15 | transform_function 16 | A function to apply to each element of a list. 17 | 18 | Returns 19 | ------- 20 | Callable 21 | A function that takes a list of objects and returns a new list of 22 | transformed objects. 23 | """ 24 | 25 | def list_transform(obj_list): 26 | return list(map(transform_function, obj_list)) 27 | 28 | return list_transform 29 | 30 | 31 | def chain_transforms(*transform_functions: Callable) -> Callable: 32 | """ 33 | Returns a function that chains multiple transformation functions together, 34 | applying them sequentially to an input. 35 | 36 | Parameters 37 | ---------- 38 | *transform_functions : Callable 39 | A variable number of functions to apply sequentially. 40 | 41 | Returns 42 | ------- 43 | Callable 44 | A function that takes an object and applies the chain of transformations 45 | to it in order. 46 | """ 47 | 48 | def chained_transform(input_obj): 49 | return reduce(lambda obj, func: func(obj), transform_functions, input_obj) 50 | 51 | return chained_transform 52 | -------------------------------------------------------------------------------- /src/gbc/processing/data_transforms/toxicity_scoring.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | from tqdm import tqdm 6 | from functools import cache 7 | from detoxify import Detoxify 8 | 9 | from gbc.data.graph import GbcGraph, GbcGraphFull 10 | 11 | 12 | def compute_toxicity_scores( 13 | gbc_graphs: list[GbcGraph], model_name: str = "original", device: str | None = None 14 | ) -> list[GbcGraphFull]: 15 | """ 16 | Compute toxicity scores for GBC graphs. 17 | 18 | This function calculates toxicity scores for all captions associated with 19 | vertices in the provided GBC graphs. 20 | The scores are computed using the specified toxicity model. 21 | 22 | Parameters 23 | ---------- 24 | gbc_graphs 25 | A list of GBC graphs for which toxicity scores need to be computed. 26 | If a :class:`~GbcGraph` is passed, 27 | a new :class:`~GbcGraphFull` instance is created. 28 | model_name 29 | The name of the toxicity model to use for computing scores. 30 | Default is ``original``. 31 | device 32 | The device to run the toxicity model on (e.g., ``cpu`` or ``cuda``). 33 | 34 | Returns 35 | ------- 36 | list of GbcGraphFull 37 | A list of :class:`~GbcGraphFull` objects with updated toxicity scores 38 | for captions. 39 | 40 | Notes 41 | ----- 42 | - If an input ``gbc_graph`` is already in :class:`~GbcGraphFull`, 43 | the modifications will be in-place. 44 | - If a caption already has toxicity scores computed, they are retained. 45 | """ 46 | 47 | gbc_graphs_with_toxicity_scores = [] 48 | 49 | for gbc_graph in tqdm(gbc_graphs, desc="Computing toxicity scores"): 50 | if not isinstance(gbc_graph, GbcGraphFull): 51 | gbc_graph = GbcGraphFull.from_gbc_graph(gbc_graph) 52 | for vertex in gbc_graph.vertices: 53 | for caption in vertex.descs: 54 | if caption.toxicity_scores is None or len(caption.toxicity_scores) == 0: 55 | caption.toxicity_scores = compute_toxicity( 56 | caption.text, model_name, device 57 | ) 58 | gbc_graphs_with_toxicity_scores.append(gbc_graph) 59 | 60 | return gbc_graphs_with_toxicity_scores 61 | 62 | 63 | @cache 64 | def get_toxicity_model(model_name="original", device=None): 65 | if device is None: 66 | device = "cuda" if torch.cuda.is_available() else "cpu" 67 | model = Detoxify(model_name, device=device) 68 | return model 69 | 70 | 71 | def compute_toxicity(caption: str, model_name="original", device=None): 72 | model = get_toxicity_model(model_name, device) 73 | scores = model.predict(caption) 74 | # from float32 to float 75 | scores = {key: float(value) for key, value in scores.items()} 76 | return scores 77 | -------------------------------------------------------------------------------- /src/gbc/processing/local_process.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import os 5 | from pathlib import Path 6 | from typing import Callable, Optional, Type 7 | 8 | from gbc.utils import load_list_from_file, save_list_to_file, get_files_recursively 9 | from .meta_process import meta_process_data 10 | 11 | 12 | def local_process_data( 13 | inputs: list[str], 14 | save_dir: str, 15 | save_format: str, 16 | *, 17 | input_formats: Optional[list[str]] = None, 18 | data_class: Optional[Type] = None, 19 | data_transform: Optional[Callable] = None, 20 | name_transform: Optional[Callable[[str], str]] = None, 21 | ): 22 | """ 23 | Processes local data files, optionally transforming and combining them, 24 | and saves the results. 25 | 26 | This function handles data processing for a list of input paths, where 27 | each path can be a file or a directory. 28 | It supports loading, transforming, and combining data, then saving 29 | the results to a specified directory in a specified format. 30 | Both the loaded data and the saved data list of dictionaries. 31 | 32 | Parameters 33 | ---------- 34 | inputs 35 | List of file or directory paths to be processed. 36 | save_dir 37 | Directory where the processed files will be saved. 38 | save_format 39 | Format in which to save the processed files (".json", ".jsonl", or ".parquet"). 40 | input_formats 41 | List of acceptable input file formats (e.g., [".json", ".jsonl"]). 42 | Defaults to ``None`` which means all formats are accepted. 43 | data_class 44 | Class type to validate and load each data item. 45 | Defaults to ``None`` which means the raw dictionary is used. 46 | data_transform 47 | Function to transform the loaded data. 48 | If ``None``, no transformation is applied. 49 | name_transform 50 | Function to transform the name from the input file to the name 51 | of the output file. 52 | 53 | Raises 54 | ------ 55 | ValueError 56 | If the input file path does not exist or is not a file or directory. 57 | """ 58 | 59 | def save_callback(data_list: list, input_file: str, input: str): 60 | # The case when input is a file 61 | if input_file == input: 62 | input_rel_path = os.path.basename(input_file) 63 | # The case when input is a folder 64 | else: 65 | input_rel_path = os.path.relpath(input_file, input) 66 | if name_transform is not None: 67 | input_rel_path = name_transform(input_rel_path) 68 | save_path = str(Path(save_dir) / Path(input_rel_path).with_suffix(save_format)) 69 | save_list_to_file(data_list, save_path, exclude_none=False) 70 | 71 | def get_input_files(input_path: str) -> list[str]: 72 | return get_files_recursively(input_path, input_formats=input_formats) 73 | 74 | def load_callback(input_path: str) -> list: 75 | return load_list_from_file(input_path, data_class) 76 | 77 | meta_process_data( 78 | inputs, 79 | save_callback, 80 | get_input_files=get_input_files, 81 | load_callback=load_callback, 82 | data_transform=data_transform, 83 | data_combine=None, 84 | ) 85 | -------------------------------------------------------------------------------- /src/gbc/processing/meta_process.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from typing import Callable, Optional, TypeVar 5 | 6 | from gbc.utils import load_list_from_file, get_gbc_logger 7 | 8 | 9 | InputType = TypeVar("InputType") 10 | LoadedDataType = TypeVar("LoadedDataType") 11 | TransformedDataType = TypeVar("TransformedDataType") 12 | 13 | 14 | def meta_process_data( 15 | inputs: list[InputType], 16 | save_callback: Callable[ 17 | [TransformedDataType, Optional[str], Optional[InputType]], None 18 | ], 19 | *, 20 | get_input_files: Optional[Callable[[InputType], list[str]]] = None, 21 | load_callback: Optional[Callable[[str], LoadedDataType]] = None, 22 | data_transform: Optional[Callable[[LoadedDataType], TransformedDataType]] = None, 23 | data_combine: Optional[ 24 | Callable[[TransformedDataType, TransformedDataType], TransformedDataType] 25 | ] = None, 26 | ) -> None: 27 | """ 28 | Processes and combines data from multiple input sources. 29 | 30 | This function provides a flexible way to handle data processing by allowing 31 | the user to specify custom functions for loading, transforming, combining, 32 | and saving data. The processing workflow is as follows: 33 | 34 | 1. Extract file paths from each input item using ``get_input_files``. 35 | 2. Load data from each file using ``load_callback``. 36 | 3. Optionally transform the loaded data using ``data_transform``. 37 | 4. Optionally combine transformed data using ``data_combine``. 38 | 5. Save the final result using ``save_callback``. 39 | 40 | Parameters 41 | ---------- 42 | inputs 43 | List of input items to be processed. 44 | save_callback 45 | Function to save the processed data. If ``data_combine`` is not provided, 46 | ``save_callback`` is called for each data file. 47 | Otherwise, ``save_callback`` is called for the combined result. 48 | It takes input file path and input as arguments. 49 | get_input_files 50 | Function to extract file paths from each input item. 51 | If ``None``, the input items are used directly as the file paths. 52 | load_callback 53 | Function to load data from a file path. 54 | If ``None``, :func:`~gbc.utils.load_list_from_file` is used. 55 | data_transform 56 | Function to transform the loaded data. 57 | If ``None``, no transformation is applied. 58 | data_combine 59 | Function to combine transformed data from two files 60 | If ``None``, the data is saved individually using ``save_callback``. 61 | """ 62 | logger = get_gbc_logger() 63 | combined_result = None 64 | all_input_files = [] 65 | all_inputs = [] 66 | for input in inputs: 67 | if get_input_files is not None: 68 | input_files = get_input_files(input) 69 | else: 70 | input_files = [input] 71 | for input_file in input_files: 72 | if load_callback is not None: 73 | data = load_callback(input_file) 74 | else: 75 | data = load_list_from_file(input_file) 76 | if data_transform is not None: 77 | logger.info(f"Processing {input_file}...") 78 | data = data_transform(data) 79 | if data_combine is not None: 80 | if combined_result is None: 81 | combined_result = data 82 | else: 83 | combined_result = data_combine(combined_result, data) 84 | all_input_files.append(input_file) 85 | all_inputs.append(data) 86 | else: 87 | save_callback(data, input_file=input_file, input=input) 88 | if combined_result is not None and save_callback is not None: 89 | save_callback(combined_result, input_file=all_input_files, input=all_inputs) 90 | -------------------------------------------------------------------------------- /src/gbc/t2i/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .utils import load_any 5 | from .prompt import Prompt, GbcPrompt 6 | from .gbc2i.sampling import diffusion_sampling 7 | from .gbc2i.gbc_sampling import gbc_diffusion_sampling 8 | from .gbc2i.k_diffusion import sample_euler_ancestral 9 | from .t2gbc.gbc_prompt_gen import gbc_prompt_gen 10 | -------------------------------------------------------------------------------- /src/gbc/t2i/gbc2i/get_sigmas.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from collections.abc import Callable 5 | 6 | import numpy as np 7 | import torch 8 | from diffusers import EulerDiscreteScheduler 9 | 10 | 11 | def get_sigmas_from_diffusers_scheduler( 12 | num_steps: int, 13 | scheduler: EulerDiscreteScheduler, 14 | omit_last_timestep: bool = False, 15 | ): 16 | # Either is fine as long as we do not have two consecutive 0s at the end 17 | # but the first one is much worse when we have few steps 18 | if omit_last_timestep: 19 | sigmas = scheduler.sigmas[ 20 | torch.linspace( 21 | 0, scheduler.config.num_train_timesteps - 1, num_steps 22 | ).long() 23 | ] 24 | sigmas = torch.concat([sigmas, sigmas.new_zeros(1)], 0) 25 | else: 26 | sigmas = scheduler.sigmas[ 27 | torch.linspace( 28 | 0, scheduler.config.num_train_timesteps, num_steps + 1 29 | ).long() 30 | ] 31 | return sigmas 32 | 33 | 34 | def get_sigmas_for_rf( 35 | num_steps, max_sigma, min_sigma=0, time_disc_func: Callable | None = None 36 | ): 37 | max_time = max_sigma / (1 + max_sigma) 38 | min_time = min_sigma / (1 + min_sigma) 39 | time_disc_func = time_disc_func or uniform_time 40 | time = np.flip(time_disc_func(min_time, max_time, num_steps)) 41 | sigmas = time / (1 - time) 42 | return sigmas 43 | 44 | 45 | def uniform_time(min_time, max_time, num_steps): 46 | return np.linspace(min_time, max_time, num_steps + 1) 47 | 48 | 49 | def sigmoid_time(min_time, max_time, num_steps, rho=10): 50 | # independent of rho 51 | min_time = max(min_time, 1e-5) 52 | min_time_logit = np.log(min_time / (1 - min_time)) 53 | max_time_logit = np.log(max_time / (1 - max_time)) 54 | min_time_rt = min_time_logit / rho + 0.5 55 | max_time_rt = max_time_logit / rho + 0.5 56 | time_rt = np.linspace(min_time_rt, max_time_rt, num_steps + 1) 57 | time = 1 / (1 + np.exp(-rho * (time_rt - 0.5))) 58 | time[0] = min_time 59 | return time 60 | 61 | 62 | def sigmoid_time_scale(min_time, max_time, num_steps, rho=10): 63 | time_rt = np.linspace(-0.5, 0.5, num_steps + 1) 64 | time = 1 / (1 + np.exp(-rho * time_rt)) 65 | # scale to [0, 1] 66 | time = (time - time[0]) / (time[-1] - time[0]) 67 | # scale to [min_time, max_time] 68 | time = time * (max_time - min_time) + min_time 69 | return time 70 | -------------------------------------------------------------------------------- /src/gbc/t2i/gbc2i/k_diffusion/NOTICE.md: -------------------------------------------------------------------------------- 1 | Code in this folder is adapted from [k-diffusion](https://github.com/crowsonkb/k-diffusion). 2 | -------------------------------------------------------------------------------- /src/gbc/t2i/gbc2i/k_diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .euler import sample_euler_ancestral 5 | from .wrapper import DiscreteSchedule, DiscreteEpsDDPMDenoiser 6 | -------------------------------------------------------------------------------- /src/gbc/t2i/t2gbc/gbc_prompt_gen.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from functools import cache 5 | from typing import Literal 6 | 7 | import torch 8 | import lightning.pytorch as pl 9 | from transformers import AutoModelForCausalLM, AutoTokenizer 10 | 11 | from ..utils import truncate_or_pad_to_length 12 | from .generation import generate_text 13 | from .sampling_constraint import create_prefix_allowed_tokens_fn 14 | from .graph_parse import parse_gbc_graph 15 | 16 | 17 | @cache 18 | def load_gbc_prompt_gen( 19 | pretrained_model_name_or_path: str, 20 | torch_dtype: str | torch.dtype = torch.float32, 21 | device: str = "cpu", 22 | attn_implementation: str = "sdpa", 23 | ): 24 | tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path) 25 | model = ( 26 | AutoModelForCausalLM.from_pretrained( 27 | pretrained_model_name_or_path, 28 | torch_dtype=torch_dtype, 29 | attn_implementation=attn_implementation, 30 | ) 31 | .eval() 32 | .to(device) 33 | ) 34 | return tokenizer, model 35 | 36 | 37 | @torch.no_grad() 38 | @torch.inference_mode() 39 | def gbc_prompt_gen( 40 | pretrained_model_name_or_path: str, 41 | prompts: list[str], 42 | num_samples: int | None = None, 43 | padding_mode: Literal["repeat_last", "cycling", "uniform_expansion"] = "cycling", 44 | allow_composition: bool = True, 45 | star_graph: bool = False, 46 | entity_lists: list[list[str]] | None = None, 47 | verbose: bool = False, 48 | seed: int | None = None, 49 | # Generation config 50 | temperature: float = 1, 51 | top_p: float = 0.95, 52 | top_k: int = 60, 53 | repetition_penalty: float = 1, 54 | max_new_tokens: int = 4096, 55 | generation_kwargs: dict = {}, 56 | # For model loading 57 | torch_dtype: str | torch.dtype = torch.float32, 58 | device: str = "cpu", 59 | attn_implementation: str = "sdpa", 60 | ): 61 | 62 | if seed is not None: 63 | pl.seed_everything(seed) 64 | 65 | if num_samples is not None: 66 | prompts = truncate_or_pad_to_length( 67 | prompts, num_samples, padding_mode=padding_mode 68 | ) 69 | if entity_lists is not None: 70 | entity_lists = truncate_or_pad_to_length( 71 | entity_lists, num_samples, padding_mode=padding_mode 72 | ) 73 | 74 | tokenizer, model = load_gbc_prompt_gen( 75 | pretrained_model_name_or_path=pretrained_model_name_or_path, 76 | torch_dtype=torch_dtype, 77 | device=device, 78 | attn_implementation=attn_implementation, 79 | ) 80 | 81 | data = [] 82 | 83 | for idx, prompt in enumerate(prompts): 84 | 85 | input = f"""Node #0 image 86 | type: image 87 | is_leave: False 88 | desc: {prompt} 89 | parents: 90 | bbox: left: 0.0000000, top: 0.0000000, right: 1.0000000, bottom: 1.0000000 91 | """ 92 | input = "\n".join(line.strip() for line in input.splitlines()) 93 | 94 | prev = input 95 | 96 | entity_lists_i = [entity_lists[idx]] if entity_lists is not None else [] 97 | 98 | prefix_allowed_tokens_fn = create_prefix_allowed_tokens_fn( 99 | tokenizer, 100 | allow_composition=allow_composition, 101 | star_graph=star_graph, 102 | entity_lists=entity_lists_i, 103 | ) 104 | model_kwargs = { 105 | "prefix_allowed_tokens_fn": prefix_allowed_tokens_fn, 106 | } 107 | 108 | result = generate_text( 109 | model, 110 | tokenizer, 111 | input, 112 | temperature=temperature, 113 | top_p=top_p, 114 | top_k=top_k, 115 | repetition_penalty=repetition_penalty, 116 | max_new_tokens=max_new_tokens, 117 | stream_output=verbose, 118 | model_kwargs=model_kwargs, 119 | generation_kwargs=generation_kwargs, 120 | ) 121 | 122 | if verbose: 123 | print() 124 | prev = "" 125 | for k in result: 126 | if len(k) > len(prev): 127 | if verbose: 128 | print(k[len(prev) :], end="", flush=True) 129 | prev = k 130 | result = prev 131 | if verbose: 132 | print() 133 | gbc_graph = parse_gbc_graph(result, prompt, verbose=verbose) 134 | data.append(gbc_graph) 135 | 136 | return data 137 | -------------------------------------------------------------------------------- /src/gbc/t2i/t2gbc/generation.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import traceback 5 | from queue import Queue 6 | from threading import Thread 7 | 8 | import torch 9 | import transformers 10 | from transformers import GenerationConfig, PreTrainedModel, PreTrainedTokenizerBase 11 | 12 | 13 | def generate_text( 14 | model: PreTrainedModel, 15 | tokenizer: PreTrainedTokenizerBase, 16 | prompt="", 17 | temperature=0.5, 18 | top_p=0.95, 19 | top_k=45, 20 | repetition_penalty=1.17, 21 | max_new_tokens=128, 22 | stream_output=False, 23 | autocast_gen=lambda: torch.autocast("cpu", enabled=False), 24 | model_kwargs={}, 25 | generation_kwargs={}, 26 | ): 27 | inputs = tokenizer(prompt, return_tensors="pt") 28 | input_ids = inputs["input_ids"].to(next(model.parameters()).device) 29 | generation_config = GenerationConfig( 30 | temperature=temperature, 31 | top_p=top_p, 32 | top_k=top_k, 33 | repetition_penalty=repetition_penalty, 34 | do_sample=True, 35 | **generation_kwargs, 36 | ) 37 | model_generate_params = { 38 | "input_ids": input_ids, 39 | "generation_config": generation_config, 40 | "return_dict_in_generate": True, 41 | "output_scores": True, 42 | "max_new_tokens": max_new_tokens, 43 | **model_kwargs, 44 | } 45 | 46 | if stream_output: 47 | 48 | def generate_with_callback(callback=None, **kwargs): 49 | kwargs.setdefault("stopping_criteria", transformers.StoppingCriteriaList()) 50 | kwargs["stopping_criteria"].append(Stream(callback_func=callback)) 51 | with torch.no_grad(), autocast_gen(): 52 | model.generate(**kwargs) 53 | 54 | def generate_with_streaming(**kwargs): 55 | return Iteratorize(generate_with_callback, kwargs, callback=None) 56 | 57 | with generate_with_streaming(**model_generate_params) as generator: 58 | for output in generator: 59 | decoded_output = tokenizer.decode(output) 60 | if output[-1] == tokenizer.eos_token_id: 61 | break 62 | yield decoded_output 63 | return # early return for stream_output 64 | 65 | with torch.no_grad(), autocast_gen(): 66 | generation_output = model.generate(**model_generate_params) 67 | s = generation_output.sequences[0] 68 | output = tokenizer.decode(s) 69 | yield output 70 | 71 | 72 | class Stream(transformers.StoppingCriteria): 73 | def __init__(self, callback_func=None): 74 | self.callback_func = callback_func 75 | 76 | def __call__(self, input_ids, scores) -> bool: 77 | if self.callback_func is not None: 78 | self.callback_func(input_ids[0]) 79 | return False 80 | 81 | 82 | class Iteratorize: 83 | """ 84 | Transforms a function that takes a callback 85 | into a lazy iterator (generator). 86 | """ 87 | 88 | def __init__(self, func, kwargs=None, callback=None): 89 | self.mfunc = func 90 | self.c_callback = callback 91 | self.q = Queue() 92 | self.sentinel = object() 93 | self.kwargs = kwargs or {} 94 | self.stop_now = False 95 | 96 | def _callback(val): 97 | if self.stop_now: 98 | raise ValueError 99 | self.q.put(val) 100 | 101 | def gentask(): 102 | try: 103 | ret = self.mfunc(callback=_callback, **self.kwargs) 104 | except ValueError: 105 | pass 106 | except Exception: 107 | traceback.print_exc() 108 | 109 | self.q.put(self.sentinel) 110 | if self.c_callback: 111 | self.c_callback(ret) 112 | 113 | self.thread = Thread(target=gentask) 114 | self.thread.start() 115 | 116 | def __iter__(self): 117 | return self 118 | 119 | def __next__(self): 120 | obj = self.q.get(True, None) 121 | if obj is self.sentinel: 122 | self.thread.join() 123 | raise StopIteration 124 | else: 125 | return obj 126 | 127 | def __enter__(self): 128 | return self 129 | 130 | def __exit__(self, exc_type, exc_val, exc_tb): 131 | self.stop_now = True 132 | -------------------------------------------------------------------------------- /src/gbc/t2i/t2gbc/graph_parse.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import re 5 | 6 | from gbc.data.bbox import Bbox 7 | from gbc.data.caption import Description 8 | from gbc.data.graph import GbcVertex, GbcEdge, GbcGraph 9 | 10 | 11 | def parse_gbc_graph(result: str, prompt: str, verbose: bool = False): 12 | 13 | node_parser: re.Pattern = re.compile( 14 | r"Node #(\d+)(.*)\n" 15 | r"type: (.+)\n" 16 | r"is_leave: (True|False)\n" 17 | r"desc: (.+)\n" 18 | r"parents:((?:\s*#\d+\((?:(?!#).)*\))*)\n" 19 | r"bbox: (.*)" 20 | ) 21 | parents_parser: re.Pattern = re.compile( 22 | r"\s*#(\d)+\(((?:(?!#).)*)\: ((?:(?!#).)*)\)" 23 | ) 24 | bbox_parser: re.Pattern = re.compile(r"((?:(?!\:).)*): ([\d\.]+)(?:, )*") 25 | 26 | vertices = {} 27 | for match in node_parser.finditer(result): 28 | id, vertex_id, vertex_label, is_leave, description, parents, bbox = ( 29 | match.groups() 30 | ) 31 | id = id.strip() 32 | vertex_id = vertex_id.strip() 33 | vertex_label = vertex_label.strip() 34 | is_leave = {"True": True, "False": False}[is_leave.strip()] 35 | description = description.strip() 36 | parents = parents_parser.findall(parents) 37 | try: 38 | bbox = bbox_parser.findall(bbox) 39 | bbox = {k: float(v) for k, v in bbox} 40 | except Exception: 41 | bbox = {} 42 | if bbox == {}: 43 | continue 44 | vertices[vertex_id] = GbcVertex( 45 | vertex_id=vertex_id, 46 | label=vertex_label, 47 | bbox=Bbox(**bbox), 48 | descs=[Description(text=description, label="short")], 49 | in_edges=[], 50 | out_edges=[], 51 | ) 52 | in_edges = [] 53 | for pid, parent_vid, text in parents: 54 | if parent_vid in vertices: 55 | edge = GbcEdge( 56 | source=parent_vid, target=vertex_id, text=text, label="short" 57 | ) 58 | vertices[parent_vid].out_edges.append(edge) 59 | in_edges.append(edge) 60 | vertices[vertex_id].in_edges = in_edges 61 | if verbose: 62 | print("=" * 60) 63 | print( 64 | f"{vertex_id} | {vertex_label} | {is_leave}\n" 65 | f"{parents} | {bbox}\n" 66 | f"{description}" 67 | ) 68 | 69 | gbc_graph = GbcGraph( 70 | vertices=list(vertices.values()), 71 | short_caption=prompt, 72 | ) 73 | return gbc_graph 74 | -------------------------------------------------------------------------------- /src/gbc/t2i/utils/loader.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from typing import Any 5 | import omegaconf 6 | from hydra.utils import instantiate 7 | from dataclasses import dataclass 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | @dataclass 14 | class ModelLoadingConfig: 15 | ckpt_path: str | None = None 16 | state_dict_key: str | None = None 17 | state_dict_prefix: str | None = None 18 | precision: str | None = None 19 | device: str | None = None 20 | to_compile: bool = False 21 | to_freeze: bool = False 22 | 23 | 24 | def load_torch_model_from_path(path: str): 25 | if path.endswith(".safetensors"): 26 | import safetensors 27 | 28 | return safetensors.torch.load_file(path, device="cpu") 29 | return torch.load(path, map_location=lambda storage, loc: storage) 30 | 31 | 32 | def extract_state_dict(state_dict: dict[str, Any], key: str | None, prefix: str | None): 33 | if key is not None: 34 | state_dict = state_dict[key] 35 | if prefix is None: 36 | return state_dict 37 | extracted_state_dict = {} 38 | for key, params in state_dict.items(): 39 | if key.startswith(prefix): 40 | extracted_state_dict[key[len(prefix) :]] = params 41 | return extracted_state_dict 42 | 43 | 44 | def prepare_model(model: nn.Module, model_loading_config: ModelLoadingConfig): 45 | if model_loading_config.ckpt_path is not None: 46 | state_dict_all = load_torch_model_from_path(model_loading_config.ckpt_path) 47 | state_dict = extract_state_dict( 48 | state_dict_all, 49 | model_loading_config.state_dict_key, 50 | model_loading_config.state_dict_prefix, 51 | ) 52 | model.load_state_dict(state_dict) 53 | if model_loading_config.precision is not None: 54 | model = model.to(eval(model_loading_config.precision)) 55 | if model_loading_config.device is not None: 56 | model = model.to(model_loading_config.device) 57 | if model_loading_config.to_compile: 58 | model = torch.compile(model) 59 | if model_loading_config.to_freeze: 60 | model.requires_grad_(False).eval() 61 | return model 62 | 63 | 64 | def load_any(obj): 65 | if isinstance(obj, dict) or isinstance(obj, omegaconf.DictConfig): 66 | if "_load_config_" in obj: 67 | load_config = obj.pop("_load_config_") 68 | load_config = ModelLoadingConfig(**load_config) 69 | else: 70 | load_config = None 71 | obj = instantiate(obj) 72 | if load_config is not None: 73 | obj = prepare_model(obj, load_config) 74 | return obj 75 | -------------------------------------------------------------------------------- /src/gbc/texts/__init__.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from .basics import * 5 | -------------------------------------------------------------------------------- /src/gbc/texts/basics.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | 5 | def plural_to_singular(string): 6 | # So that we do not need to install it unless needed 7 | from hbutils.string.plural import singular_form 8 | 9 | quantifiers = [ 10 | "a", 11 | "an", 12 | "some", 13 | "many", 14 | "several", 15 | "few", 16 | "one", 17 | "two", 18 | "three", 19 | "four", 20 | "five", 21 | "six", 22 | "seven", 23 | "eight", 24 | "nine", 25 | "ten", 26 | "eleven", 27 | "twelve", 28 | "thirteen", 29 | "fourteen", 30 | "fifteen", 31 | "sixteen", 32 | "seventeen", 33 | "eighteen", 34 | "nineteen", 35 | "twenty", 36 | ] 37 | if " " in string: 38 | string_parts = string.split() 39 | if string_parts[0].lower() in quantifiers: 40 | string = " ".join(string_parts[1:]) 41 | return singular_form(string).lower() 42 | 43 | 44 | def remove_repeated_suffix(s: str) -> str: 45 | """ 46 | Removes the repeated suffix from the string efficiently using Rolling Hash. 47 | """ 48 | if not s: 49 | return s 50 | 51 | n = len(s) 52 | base = 257 # A prime number base for hashing 53 | mod = 10**9 + 7 # A large prime modulus to prevent overflow 54 | 55 | # Precompute prefix hashes and powers of the base 56 | prefix_hash = [0] * (n + 1) 57 | power = [1] * (n + 1) 58 | 59 | for i in range(n): 60 | prefix_hash[i + 1] = (prefix_hash[i] * base + ord(s[i])) % mod 61 | power[i + 1] = (power[i] * base) % mod 62 | 63 | def get_hash(left, right): 64 | return (prefix_hash[right] - prefix_hash[left] * power[right - left]) % mod 65 | 66 | max_k = 0 # To store the maximum k where suffix is repeated 67 | 68 | # Iterate over possible suffix lengths from 1 to n//2 69 | for k in range(1, n // 2 + 1): 70 | # Compare the last k characters with the k characters before them 71 | if get_hash(n - 2 * k, n - k) == get_hash(n - k, n): 72 | max_k = k # Update max_k if a repeated suffix is found 73 | 74 | if max_k > 0: 75 | # Remove the extra occurrences of the suffix 76 | # Calculate how many times the suffix is repeated consecutively 77 | m = 2 78 | while max_k * (m + 1) <= n and get_hash( 79 | n - (m + 1) * max_k, n - m * max_k 80 | ) == get_hash(n - m * max_k, n - (m - 1) * max_k): 81 | m += 1 82 | # Remove (m-1) copies of the suffix 83 | s = s[: n - (m - 1) * max_k] 84 | 85 | return s 86 | -------------------------------------------------------------------------------- /src/gbc/texts/classifiers.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | from typing import List, Tuple, Union 5 | from functools import cache 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import lightning.pytorch as pl 11 | from sentence_transformers import SentenceTransformer 12 | 13 | from gbc.utils import get_gbc_logger 14 | 15 | 16 | @cache 17 | def load_emb_model(): 18 | model = SentenceTransformer( 19 | "jinaai/jina-embeddings-v2-small-en", trust_remote_code=True 20 | ) 21 | return model 22 | 23 | 24 | @cache 25 | def load_emb_classifier(model_path, gpu_id=0): 26 | logger = get_gbc_logger() 27 | logger.info(f"Loading text classifier from {model_path} on gpu {gpu_id}...") 28 | if torch.cuda.is_available(): 29 | map_location = f"cuda:{gpu_id}" 30 | else: 31 | map_location = "cpu" 32 | model = EmbClassfier.load_from_checkpoint( 33 | model_path, 34 | emb_model=load_emb_model(), 35 | map_location=map_location, 36 | ) 37 | model = model.eval() 38 | return model 39 | 40 | 41 | @cache 42 | def load_emb_pair_classifier(model_path, gpu_id=0): 43 | logger = get_gbc_logger() 44 | logger.info(f"Loading text pair classifier from {model_path} on gpu {gpu_id}...") 45 | if torch.cuda.is_available(): 46 | map_location = f"cuda:{gpu_id}" 47 | else: 48 | map_location = "cpu" 49 | model = EmbPairClassfier.load_from_checkpoint( 50 | model_path, 51 | emb_model=load_emb_model(), 52 | map_location=map_location, 53 | ) 54 | model = model.eval() 55 | return model 56 | 57 | 58 | class EmbClassfier(pl.LightningModule): 59 | def __init__( 60 | self, 61 | emb_model: SentenceTransformer = None, 62 | ): 63 | assert emb_model is not None 64 | super(EmbClassfier, self).__init__() 65 | self.save_hyperparameters(ignore=["emb_model"]) 66 | self.text_model = emb_model.eval().requires_grad_(False) 67 | test_output = self.text_model.encode(["hello", "world"]) 68 | 69 | self.head = BinaryClassifierHead(test_output.shape[1]) 70 | self.train_params = self.head.parameters() 71 | 72 | def forward(self, texts: str | List[str]): 73 | if isinstance(texts, str): 74 | texts = [texts] 75 | return self.head(self.text_model.encode(texts, convert_to_tensor=True)) 76 | 77 | 78 | class EmbPairClassfier(pl.LightningModule): 79 | def __init__( 80 | self, 81 | emb_model: SentenceTransformer = None, 82 | ): 83 | assert emb_model is not None 84 | super(EmbPairClassfier, self).__init__() 85 | self.save_hyperparameters(ignore=["emb_model"]) 86 | self.text_model = emb_model.eval().requires_grad_(False) 87 | test_output = self.text_model.encode(["hello", "world"]) 88 | 89 | self.head = PairClassifierHead(test_output.shape[1]) 90 | self.train_params = self.head.parameters() 91 | 92 | def forward(self, text_pairs: Union[Tuple[str, str], Tuple[List[str], List[str]]]): 93 | texts1, texts2 = text_pairs 94 | if isinstance(texts1, str): 95 | texts1 = [texts1] 96 | if isinstance(texts2, str): 97 | texts2 = [texts2] 98 | return self.head( 99 | self.text_model.encode(texts1, convert_to_tensor=True), 100 | self.text_model.encode(texts2, convert_to_tensor=True), 101 | ) 102 | 103 | 104 | class BinaryClassifierHead(nn.Module): 105 | def __init__(self, in_dim): 106 | super().__init__() 107 | self.fc = nn.Linear(in_dim, 1) 108 | 109 | def forward(self, x): 110 | return self.fc(x) 111 | 112 | 113 | class PairClassifierHead(nn.Module): 114 | def __init__(self, in_dim): 115 | super().__init__() 116 | self.fc = nn.Sequential( 117 | nn.LayerNorm(in_dim * 2), 118 | nn.Linear(in_dim * 2, in_dim * 8), 119 | nn.SiLU(), 120 | nn.Linear(in_dim * 8, 1), 121 | ) 122 | 123 | def forward(self, x, y): 124 | return self.fc(torch.concat((x, y), dim=1)) 125 | -------------------------------------------------------------------------------- /src/gbc/texts/text_helpers.py: -------------------------------------------------------------------------------- 1 | # For licensing see accompanying LICENSE file. 2 | # Copyright (C) 2024 Apple Inc. All Rights Reserved. 3 | 4 | import torch 5 | 6 | from .basics import plural_to_singular 7 | from .classifiers import load_emb_classifier, load_emb_pair_classifier 8 | 9 | 10 | def potential_same_object(text1: str, text2: str, config=None) -> bool: 11 | singular_text1 = plural_to_singular(text1) 12 | singular_text2 = plural_to_singular(text2) 13 | if singular_text1 == singular_text2: 14 | return True 15 | if config is None: 16 | return False 17 | model = load_emb_pair_classifier( 18 | config.text_pair_model_path, gpu_id=getattr(config, "gpu_id", 0) 19 | ) 20 | with torch.no_grad(): 21 | output1 = model((text1, text2)).squeeze().cpu().numpy() 22 | output2 = model((singular_text1, singular_text2)).squeeze().cpu().numpy() 23 | return (output1 > 0) or (output2 > 0) 24 | 25 | 26 | def suitable_for_detection(text: str, config=None) -> bool: 27 | if text in ["right", "left", "top", "bottom", "front", "back", "side"]: 28 | return False 29 | # Detection with digit would cause collision between entity and composition nodes 30 | if text.isdigit(): 31 | return False 32 | # This will break the naming convention of the node ids 33 | # As we split each detection result by '_' 34 | if "_" in text: 35 | return False 36 | # Otherwise we return True if no model is loaded 37 | if config is None: 38 | return True 39 | model = load_emb_classifier( 40 | config.text_binary_model_path, gpu_id=getattr(config, "gpu_id", 0) 41 | ) 42 | with torch.no_grad(): 43 | output = model(text).squeeze().cpu().numpy() 44 | return output > 0 45 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_auto_image_entity_actions.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | from omegaconf import OmegaConf 3 | 4 | from gbc.utils import setup_gbc_logger 5 | from gbc.captioning import ( 6 | AutoImageQuery, 7 | AutoEntityQuery, 8 | get_action_input_from_img_path, 9 | ) 10 | 11 | 12 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 13 | action_input = get_action_input_from_img_path(img_path) 14 | config = OmegaConf.load("configs/captioning/default.yaml") 15 | 16 | setup_gbc_logger() 17 | 18 | 19 | ### 20 | 21 | print("Testing AutoImageQuery with config...") 22 | image_query = AutoImageQuery(config) 23 | 24 | print("Testing caching") 25 | image_query = AutoImageQuery(config) 26 | queries, result, image = image_query.query(action_input) 27 | 28 | print("Queries to complete:") 29 | op([query.model_dump() for query in queries]) 30 | 31 | print("Result:") 32 | op(result.model_dump()) 33 | 34 | 35 | ### 36 | 37 | print("----------------------------------------------") 38 | 39 | print("Testing AutoEntityQuery with config...") 40 | entity_query = AutoEntityQuery(config) 41 | 42 | print("Testing caching") 43 | entity_query = AutoEntityQuery(config) 44 | action_input.entity_info.text = "tower" 45 | queries, result, image = entity_query.query(action_input) 46 | 47 | print("Queries to complete:") 48 | op([query.model_dump() for query in queries]) 49 | 50 | print("Result:") 51 | op(result.model_dump()) 52 | 53 | 54 | ### 55 | 56 | print("----------------------------------------------") 57 | 58 | print("Testing AutoImageQuery...") 59 | 60 | image_query = AutoImageQuery() 61 | 62 | print("Testing AutoEntityQuery...") 63 | 64 | entity_query = AutoEntityQuery() 65 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_captioning_pipeline.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from objprint import op 3 | 4 | from gbc.utils import save_list_to_file, setup_gbc_logger 5 | from gbc.captioning import GbcPipeline 6 | 7 | setup_gbc_logger() 8 | 9 | config = OmegaConf.load("configs/captioning/default.yaml") 10 | gbc_pipeline = GbcPipeline.from_config(config) 11 | 12 | img_file_1 = "data/images/wiki/Eiffel_tower_0.jpg" 13 | img_file_2 = "data/images/wiki/Eiffel_tower_1.jpg" 14 | 15 | # Perform captioning on a single image 16 | gbc = gbc_pipeline.run_gbc_captioning(img_file_1) 17 | # Pretty print the GBC graph 18 | op(gbc[0].model_dump()) 19 | 20 | # Perform captioning on multiple images 21 | gbcs = gbc_pipeline.run_gbc_captioning([img_file_1, img_file_2]) 22 | # Save the GBC graphs, can save as json, jsonl, or parquet 23 | save_list_to_file(gbcs, "tests/outputs/captioning/gbc_eiffel_tower.json") 24 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_captioning_pipeline_functional.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from objprint import op 3 | 4 | from gbc.utils import save_list_to_file, setup_gbc_logger 5 | from gbc.captioning import run_gbc_captioning 6 | 7 | setup_gbc_logger() 8 | 9 | config = OmegaConf.load("configs/captioning/default.yaml") 10 | 11 | img_file_1 = "data/images/wiki/Eiffel_tower_0.jpg" 12 | img_file_2 = "data/images/wiki/Eiffel_tower_1.jpg" 13 | 14 | # Perform captioning on a single image 15 | gbc = run_gbc_captioning(img_file_1, config, include_relation_query=False) 16 | # Pretty print the GBC graph 17 | op(gbc[0].model_dump()) 18 | 19 | # Perform captioning on multiple images 20 | gbcs = run_gbc_captioning( 21 | [img_file_1, img_file_2], config, batch_query=True, batch_size=8 22 | ) 23 | # Save the GBC graphs, can save as json, jsonl, or parquet 24 | save_list_to_file(gbcs, "tests/outputs/captioning/gbc_batch_eiffel_tower.json") 25 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_detection_action.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | from omegaconf import OmegaConf 3 | 4 | import cv2 5 | 6 | from gbc.utils import setup_gbc_logger 7 | from gbc.data.bbox.annotate import annotate_all_labels 8 | from gbc.captioning import ( 9 | AutoDetectionActionFromImage, 10 | AutoDetectionActionFromEntity, 11 | ) 12 | from gbc.captioning.primitives import ( 13 | EntityInfo, 14 | ActionInputWithEntities, 15 | get_action_input_from_img_path, 16 | ) 17 | from gbc.captioning.detection.detection_action import DetectionAction 18 | 19 | 20 | ### 21 | 22 | setup_gbc_logger() 23 | 24 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 25 | action_input = get_action_input_from_img_path(img_path) 26 | 27 | entities = [ 28 | EntityInfo(label="multiple", text="boats", entity_id="boats"), 29 | EntityInfo(label="single", text="tower", entity_id="tower"), 30 | EntityInfo(label="single", text="train", entity_id="train"), 31 | ] 32 | 33 | action_input_with_entities = ActionInputWithEntities( 34 | image=action_input.image, 35 | entity_info=action_input.entity_info, 36 | entities=entities, 37 | ) 38 | 39 | config = OmegaConf.load("configs/captioning/default.yaml") 40 | 41 | 42 | ### 43 | 44 | # The first time detection model is loaded 45 | print("Testing DetectionAction...") 46 | 47 | detection_action = DetectionAction() 48 | queries, _, _ = detection_action.query(action_input_with_entities) 49 | 50 | op([query.model_dump() for query in queries]) 51 | 52 | 53 | ### 54 | 55 | print("-------------------------------------------------------------") 56 | print("Testing AutoDetectionActionFromImage with config...") 57 | 58 | auto_detection_action = AutoDetectionActionFromImage(config) 59 | 60 | print("-------------------------------------------------------------") 61 | print("Testing caching...") 62 | auto_detection_action = AutoDetectionActionFromImage(config) 63 | queries, _, _ = auto_detection_action.query(action_input_with_entities) 64 | 65 | op([query.model_dump() for query in queries]) 66 | 67 | labeled_bboxes = [] 68 | for query in queries: 69 | labeled_bboxes.append( 70 | ( 71 | query.action_input.first_entity_id, 72 | query.action_input.bbox, 73 | ) 74 | ) 75 | # Note that cv2 is in BGR instead of RGB 76 | img_annotated = annotate_all_labels( 77 | action_input.get_image(return_pil=False)[:, :, ::-1], labeled_bboxes 78 | ) 79 | save_img_path = "tests/outputs/detection/annotated_eiffel_tower.jpg" 80 | cv2.imwrite(save_img_path, img_annotated) 81 | 82 | 83 | ### 84 | 85 | print("-------------------------------------------------------------") 86 | print("Testing AutoDetectionActionFromImage without config...") 87 | auto_detection_action = AutoDetectionActionFromImage() 88 | queries, _, _ = auto_detection_action.query(action_input_with_entities) 89 | 90 | op([query.model_dump() for query in queries]) 91 | 92 | 93 | ### 94 | 95 | print("-------------------------------------------------------------") 96 | print("Testing AutoDetectionActionFromEntity with config...") 97 | 98 | auto_detection_action = AutoDetectionActionFromEntity(config) 99 | queries, _, _ = auto_detection_action.query(action_input_with_entities) 100 | 101 | op([query.model_dump() for query in queries]) 102 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_detection_grounding_dino.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from objprint import op 3 | 4 | from gbc.utils import setup_gbc_logger 5 | from gbc.captioning.detection.grounding_dino import GroundingDinoDetection 6 | 7 | 8 | setup_gbc_logger() 9 | 10 | print("Testing GroundingDinoDetection...") 11 | 12 | detection_model = GroundingDinoDetection() 13 | 14 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 15 | texts = ["boats", "tower", "train"] 16 | image = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB) 17 | 18 | print("Running detection...") 19 | # UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. # noqa 20 | bboxes = detection_model.detect(image, texts) 21 | 22 | print("----------------------------------------------") 23 | for text, bboxes_per_text in bboxes.items(): 24 | print(f"Detected bounding boxes for '{text}':") 25 | op([bbox.model_dump() for bbox in bboxes_per_text]) 26 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_detection_yolo_world.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from objprint import op 3 | 4 | from gbc.utils import setup_gbc_logger 5 | from gbc.captioning.detection.yolo_world import YoloWorldDetection 6 | 7 | 8 | setup_gbc_logger() 9 | 10 | print("Testing YoloWorldDetection X...") 11 | 12 | detection_model = YoloWorldDetection(model_version="x_v2") 13 | 14 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 15 | texts = ["boats", "tower", "train"] 16 | image = cv2.imread(img_path) 17 | 18 | print("Running detection...") 19 | bboxes = detection_model.detect(image, texts) 20 | 21 | print("----------------------------------------------") 22 | for text, bboxes_per_text in bboxes.items(): 23 | print(f"Detected bounding boxes for '{text}':") 24 | op([bbox.model_dump() for bbox in bboxes_per_text]) 25 | 26 | 27 | print("----------------------------------------------") 28 | 29 | print("Testing YoloWorldDetection L...") 30 | 31 | detection_model = YoloWorldDetection(model_version="l_v2") 32 | 33 | print("Running detection...") 34 | bboxes = detection_model.detect(image, texts) 35 | 36 | print("----------------------------------------------") 37 | for text, bboxes_per_text in bboxes.items(): 38 | print(f"Detected bounding boxes for '{text}':") 39 | op([bbox.model_dump() for bbox in bboxes_per_text]) 40 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_image_query.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | 3 | from gbc.utils import setup_gbc_logger 4 | from gbc.captioning import ( 5 | get_action_input_from_img_path, 6 | ) 7 | from gbc.captioning.mllm.llava import LlavaImageQuery 8 | 9 | setup_gbc_logger() 10 | 11 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 12 | action_input = get_action_input_from_img_path(img_path) 13 | 14 | system_file_image = "prompts/captioning/system_image.txt" 15 | query_file_image = "prompts/captioning/query_image.txt" 16 | 17 | llava_image_query = LlavaImageQuery( 18 | query_file=query_file_image, 19 | system_file=system_file_image, 20 | ) 21 | 22 | queries, result, image = llava_image_query.query(action_input) 23 | 24 | print("Queries to complete:") 25 | op([query.model_dump() for query in queries]) 26 | 27 | print("----------------------------------------------") 28 | print("Result:") 29 | op(result.model_dump()) 30 | 31 | print("----------------------------------------------") 32 | print("Queries to complete:") 33 | op([query.model_dump() for query in queries]) 34 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_io_unit.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | 3 | from gbc.captioning import get_action_input_from_img_path 4 | from gbc.captioning.auto_actions import AutoImageQuery 5 | from gbc.captioning.primitives import ( 6 | NodeInfo, 7 | QueryResult, 8 | ActionInputPair, 9 | ) 10 | 11 | 12 | img_path = "data/images/wiki/Eiffel_tower_0.jpg" 13 | action_input = get_action_input_from_img_path(img_path) 14 | 15 | 16 | print("------ Test NodeInfo -------") 17 | 18 | print("Define NodeInfo") 19 | query_result = QueryResult(descs=[], entities=[], raw="") 20 | node_info = NodeInfo( 21 | action_input=action_input, 22 | query_result=query_result, 23 | ) 24 | print(node_info) 25 | 26 | print("Convert to dict") 27 | node_info_dict = node_info.model_dump() 28 | op(node_info_dict) 29 | 30 | print("Convert from dict") 31 | node_info = NodeInfo.model_validate(node_info_dict) 32 | print(node_info) 33 | 34 | print("Load image") 35 | image = node_info.action_input.get_image() 36 | print(node_info) 37 | 38 | 39 | print() 40 | print("------ Test ActionInputPair -------") 41 | 42 | print("Define ActionInputPair") 43 | action_input_pair = ActionInputPair( 44 | action_class=AutoImageQuery, 45 | action_input=action_input, 46 | ) 47 | print(action_input_pair) 48 | 49 | print("Convert to dict") 50 | action_input_pair_dict = action_input_pair.model_dump() 51 | op(action_input_pair_dict) 52 | 53 | print("Convert from dict") 54 | action_input_pair = ActionInputPair.model_validate(action_input_pair_dict) 55 | print(action_input_pair) 56 | -------------------------------------------------------------------------------- /tests/test_captioning_unit/test_pixtral_query.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from gbc.utils import setup_gbc_logger 4 | from gbc.captioning.mllm.pixtral import load_pixtral_model, pixtral_query_single 5 | 6 | 7 | setup_gbc_logger() 8 | 9 | 10 | with open("prompts/captioning/query_image.txt", "r") as f: 11 | query = f.read() 12 | 13 | with open("prompts/captioning/system_image.txt", "r") as f: 14 | system_message = f.read() 15 | 16 | 17 | image = Image.open("data/images/wiki/Eiffel_tower_0.jpg").convert("RGB") 18 | 19 | 20 | print("Testing pixtral_query_single with image query...") 21 | 22 | model = load_pixtral_model("nm-testing/pixtral-12b-FP8-dynamic") 23 | result = pixtral_query_single(model, image, query, system_message=system_message) 24 | print(result) 25 | -------------------------------------------------------------------------------- /tests/test_data_unit/test_basic_filter.py: -------------------------------------------------------------------------------- 1 | # from objprint import op 2 | from tqdm import tqdm 3 | from copy import deepcopy 4 | from itertools import product 5 | 6 | from gbc.data import GbcGraphFull 7 | from gbc.utils import setup_gbc_logger, load_list_from_file 8 | from gbc.processing.data_transforms import basic_filter_and_extract 9 | 10 | 11 | setup_gbc_logger() 12 | 13 | gbc_graphs = load_list_from_file( 14 | "data/gbc/wiki/wiki_gbc_graphs.jsonl", class_type=GbcGraphFull 15 | ) 16 | 17 | 18 | drop_composition_descendants_list = [False, True] 19 | drop_vertex_size_kwargs_list = [ 20 | None, 21 | {"min_rel_width": 0.3, "min_rel_height": 0.3}, 22 | {"max_rel_width": 0.7, "max_rel_size": 0.8, "min_rel_size": 0.2}, 23 | ] 24 | drop_vertex_types_list = [ 25 | None, 26 | ["relation"], 27 | ["composition", "entity"], 28 | ["composition", "relation"], 29 | ["composition", "relation", "entity"], 30 | ] 31 | drop_caption_types_list = [ 32 | None, 33 | ["composition"], 34 | ["detail-image"], 35 | ["relation", "original-image"], 36 | ["detail-entity", "hardcode", "bagofwords"], 37 | ] 38 | same_level_max_bbox_overlap_ratio_list = [None, 0.1, 0.5, 1] 39 | max_n_vertices_list = [None, 1, 5, 10] 40 | max_depth_list = [None, 1, 3] 41 | subgraph_extraction_mode_list = ["bfs", "dfs"] 42 | subgraph_edge_shuffling_list = [False, True] 43 | keep_in_edges_list = [False, True] 44 | keep_out_edges_list = [False, True] 45 | 46 | 47 | # Cycle through graphs for testing 48 | graph_count = len(gbc_graphs) 49 | graph_index = 0 # Start index for cycling through graphs 50 | 51 | # Iterate over all parameter combinations 52 | parameter_combinations = product( 53 | drop_composition_descendants_list, 54 | drop_vertex_size_kwargs_list, 55 | drop_vertex_types_list, 56 | drop_caption_types_list, 57 | same_level_max_bbox_overlap_ratio_list, 58 | max_n_vertices_list, 59 | max_depth_list, 60 | subgraph_extraction_mode_list, 61 | subgraph_edge_shuffling_list, 62 | keep_in_edges_list, 63 | keep_out_edges_list, 64 | ) 65 | 66 | # Run tests, 115200 iterations, could take several minutes 67 | for combination in tqdm(parameter_combinations): 68 | # Cycle through graphs 69 | original_graph = gbc_graphs[graph_index] 70 | graph_index = (graph_index + 1) % graph_count # Cycle index 71 | 72 | # Create a deepcopy of the graph for this test 73 | test_graph = deepcopy(original_graph) 74 | 75 | # Unpack parameters 76 | ( 77 | drop_composition_descendants, 78 | drop_vertex_size_kwargs, 79 | drop_vertex_types, 80 | drop_caption_types, 81 | same_level_max_bbox_overlap_ratio, 82 | max_n_vertices, 83 | max_depth, 84 | subgraph_extraction_mode, 85 | subgraph_edge_shuffling, 86 | keep_in_edges, 87 | keep_out_edges, 88 | ) = combination 89 | 90 | # # Print parameters 91 | # print("Running test with parameters:") 92 | # print(f" drop_composition_descendants: {drop_composition_descendants}") 93 | # print(f" drop_vertex_size_kwargs: {drop_vertex_size_kwargs}") 94 | # print(f" drop_vertex_types: {drop_vertex_types}") 95 | # print(f" drop_caption_types: {drop_caption_types}") 96 | # print(f" same_level_max_bbox_overlap_ratio: {same_level_max_bbox_overlap_ratio}") 97 | # print(f" max_n_vertices: {max_n_vertices}") 98 | # print(f" max_depth: {max_depth}") 99 | # print(f" subgraph_extraction_mode: {subgraph_extraction_mode}") 100 | # print(f" subgraph_edge_shuffling: {subgraph_edge_shuffling}") 101 | # print(f" keep_in_edges: {keep_in_edges}") 102 | # print(f" keep_out_edges: {keep_out_edges}") 103 | 104 | # Apply the function 105 | result_graph = basic_filter_and_extract( 106 | gbc_graph=test_graph, 107 | drop_composition_descendants=drop_composition_descendants, 108 | drop_vertex_size_kwargs=drop_vertex_size_kwargs, 109 | drop_vertex_types=drop_vertex_types, 110 | drop_caption_types=drop_caption_types, 111 | same_level_max_bbox_overlap_ratio=same_level_max_bbox_overlap_ratio, 112 | max_n_vertices=max_n_vertices, 113 | max_depth=max_depth, 114 | subgraph_extraction_mode=subgraph_extraction_mode, 115 | subraph_edge_shuffling=subgraph_edge_shuffling, 116 | keep_in_edges=keep_in_edges, 117 | keep_out_edges=keep_out_edges, 118 | ) 119 | 120 | # # Log the result 121 | # op(result_graph.model_dump(), indent=2) # Pretty-print the result graph 122 | -------------------------------------------------------------------------------- /tests/test_data_unit/test_clip_filter.py: -------------------------------------------------------------------------------- 1 | # from objprint import op 2 | from tqdm import tqdm 3 | from copy import deepcopy 4 | from itertools import product 5 | 6 | from gbc.data import GbcGraphFull 7 | from gbc.utils import setup_gbc_logger, load_list_from_file 8 | from gbc.processing.data_transforms import gbc_clip_filter 9 | 10 | 11 | setup_gbc_logger() 12 | 13 | gbc_graphs = load_list_from_file( 14 | "data/gbc/wiki/with_clip_scores/wiki_gbc_graphs_with_clip.jsonl", 15 | class_type=GbcGraphFull, 16 | ) 17 | 18 | 19 | split_rather_than_filter_list = [False, True] 20 | exclude_labels_list = [ 21 | ["hardcode"], 22 | ["relation", "image-detail"], 23 | ["hardcode", "composition"], 24 | ] 25 | random_filtering_probs_list = [dict(), {"short": 0.3, "detail": 0.2}] 26 | add_bag_of_words_list = [False, True] 27 | remove_edge_with_no_corr_list = [False, True] 28 | filter_image_using_short_clip_scores_list = [False, True] 29 | filter_after_selection_list = [False, True] 30 | max_n_vertices_per_graph_list = [None, 10] 31 | node_selection_mode_list = ["bfs", "dfs", "random"] 32 | max_n_vertices_list = [None, 100] 33 | 34 | # Cycle through graphs for testing 35 | graph_count = len(gbc_graphs) 36 | graph_index = 0 # Start index for cycling through graphs 37 | 38 | # Iterate over all parameter combinations 39 | parameter_combinations = product( 40 | split_rather_than_filter_list, 41 | exclude_labels_list, 42 | random_filtering_probs_list, 43 | add_bag_of_words_list, 44 | remove_edge_with_no_corr_list, 45 | filter_image_using_short_clip_scores_list, 46 | filter_after_selection_list, 47 | max_n_vertices_per_graph_list, 48 | node_selection_mode_list, 49 | max_n_vertices_list, 50 | ) 51 | 52 | # Run tests, 115200 iterations, could take several minutes 53 | for combination in tqdm(parameter_combinations): 54 | # Cycle through graphs 55 | original_graph = gbc_graphs[graph_index] 56 | graph_index = (graph_index + 1) % graph_count # Cycle index 57 | 58 | # Create a deepcopy of the graph for this test 59 | test_graph = deepcopy(original_graph) 60 | 61 | # Unpack parameters 62 | ( 63 | split_rather_than_filter, 64 | exclude_labels, 65 | random_filtering_probs, 66 | add_bag_of_words, 67 | remove_edge_with_no_corr, 68 | filter_image_using_short_clip_scores, 69 | filter_after_selection, 70 | max_n_vertices_per_graph, 71 | node_selection_mode, 72 | max_n_vertices, 73 | ) = combination 74 | 75 | # Apply the function 76 | result_graph = gbc_clip_filter( 77 | test_graph, 78 | split_rather_than_filter=split_rather_than_filter, 79 | exclude_labels=exclude_labels, 80 | random_filtering_probs=random_filtering_probs, 81 | add_bag_of_words=add_bag_of_words, 82 | remove_edge_with_no_corr=remove_edge_with_no_corr, 83 | filter_image_using_short_clip_scores=filter_image_using_short_clip_scores, 84 | filter_after_selection=filter_after_selection, 85 | max_n_vertices_per_graph=max_n_vertices_per_graph, 86 | node_selection_mode=node_selection_mode, 87 | max_n_vertices=max_n_vertices, 88 | ) 89 | 90 | # # Log the result 91 | # op(result_graph.model_dump(), indent=2) # Pretty-print the result graph 92 | -------------------------------------------------------------------------------- /tests/test_data_unit/test_gbc_to_text_and_image.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | from copy import deepcopy 3 | from itertools import product 4 | 5 | from gbc.data import GbcGraphFull 6 | from gbc.utils import setup_gbc_logger, load_list_from_file 7 | from gbc.processing.data_transforms import gbc_graph_to_text_and_image 8 | 9 | 10 | setup_gbc_logger() 11 | 12 | gbc_graphs = load_list_from_file( 13 | "data/gbc/wiki/wiki_gbc_graphs.jsonl", class_type=GbcGraphFull 14 | ) 15 | 16 | 17 | text_format_list = ["set", "set_with_bbox", "concat", "structured"] 18 | graph_traversal_mode_list = ["bfs", "dfs", "random", "topological"] 19 | read_image_list = [False, True] 20 | remove_repeated_suffix_list = [False, True] 21 | 22 | 23 | # Cycle through graphs for testing 24 | graph_count = len(gbc_graphs) 25 | graph_index = 0 # Start index for cycling through graphs 26 | 27 | # Iterate over all parameter combinations 28 | parameter_combinations = product( 29 | text_format_list, 30 | graph_traversal_mode_list, 31 | read_image_list, 32 | remove_repeated_suffix_list, 33 | ) 34 | 35 | # Run tests, 115200 iterations, could take several minutes 36 | for combination in parameter_combinations: 37 | # Cycle through graphs 38 | original_graph = gbc_graphs[graph_index] 39 | graph_index = (graph_index + 1) % graph_count # Cycle index 40 | 41 | # Create a deepcopy of the graph for this test 42 | test_graph = deepcopy(original_graph) 43 | 44 | # Unpack parameters 45 | text_format, graph_traversal_mode, read_image, remove_repeated_suffix = combination 46 | 47 | # Print parameters 48 | print("-------------------------------------------------------") 49 | print("Running test with parameters:") 50 | print(f" text_format: {text_format}") 51 | print(f" graph_traversal_mode: {graph_traversal_mode}") 52 | print(f" read_image: {read_image}") 53 | print(f" remove_repeated_suffix: {remove_repeated_suffix}") 54 | 55 | print("") 56 | 57 | if text_format == "structured": 58 | caption_agg_mode_for_structured_list = ["first", "concat"] 59 | 60 | for caption_agg_mode_for_structured in caption_agg_mode_for_structured_list: 61 | print( 62 | " caption_agg_mode_for_structured: ", 63 | caption_agg_mode_for_structured, 64 | ) 65 | print("") 66 | result = gbc_graph_to_text_and_image( 67 | gbc_graph=test_graph, 68 | graph_traversal_mode=graph_traversal_mode, 69 | text_format=text_format, 70 | read_image=read_image, 71 | remove_repeated_suffix=remove_repeated_suffix, 72 | caption_agg_mode_for_structured=caption_agg_mode_for_structured, 73 | ) 74 | op(result, indent=2) 75 | 76 | else: 77 | result = gbc_graph_to_text_and_image( 78 | gbc_graph=test_graph, 79 | graph_traversal_mode=graph_traversal_mode, 80 | text_format=text_format, 81 | read_image=read_image, 82 | remove_repeated_suffix=remove_repeated_suffix, 83 | ) 84 | op(result, indent=2) 85 | -------------------------------------------------------------------------------- /tests/test_data_unit/test_graph_basics.py: -------------------------------------------------------------------------------- 1 | from objprint import op 2 | from copy import deepcopy 3 | 4 | from gbc.data import GbcGraphFull 5 | from gbc.utils import setup_gbc_logger, load_list_from_file 6 | 7 | 8 | setup_gbc_logger() 9 | 10 | gbc_graphs = load_list_from_file( 11 | "data/gbc/wiki/wiki_gbc_graphs.jsonl", class_type=GbcGraphFull 12 | ) 13 | 14 | print("The first gbc graph is:") 15 | op(gbc_graphs[0].model_dump()) 16 | 17 | 18 | for mode in ["dfs", "bfs", "random"]: 19 | print("--------------------------------------------------------------") 20 | print(f"Extracted gbc subgraph with 5 vertices in {mode} order:") 21 | op(gbc_graphs[0].get_subgraph(5, mode=mode).model_dump()) 22 | 23 | gbc_graph = deepcopy(gbc_graphs[0]) 24 | 25 | gbc_graph = gbc_graph.drop_vertices_by_size(min_rel_width=0.2, min_rel_height=0.2) 26 | print("--------------------------------------------------------------") 27 | print("Dropping vertices with width < 0.2 or height < 0.2:") 28 | op(gbc_graph.model_dump()) 29 | 30 | 31 | gbc_graph = gbc_graph.drop_captions_by_type(["hardcode", "composition"]) 32 | print("--------------------------------------------------------------") 33 | print("Dropping captions of type 'hardcode' and 'composition':") 34 | op(gbc_graph.model_dump()) 35 | 36 | gbc_graph = gbc_graph.drop_vertices_by_type(["relation"]) 37 | print("--------------------------------------------------------------") 38 | print("Dropping vertices of type 'relation':") 39 | op(gbc_graph.model_dump()) 40 | 41 | gbc_graph = gbc_graph.drop_vertices_by_type(["composition"]) 42 | print("--------------------------------------------------------------") 43 | print("Dropping vertices of type 'composition':") 44 | op(gbc_graph.model_dump()) 45 | 46 | 47 | gbc_graph = deepcopy(gbc_graphs[0]) 48 | 49 | gbc_graph = gbc_graph.drop_composition_descendants() 50 | print("--------------------------------------------------------------") 51 | print("Dropping composition descendants:") 52 | op(gbc_graph.model_dump()) 53 | 54 | gbc_graph = gbc_graph.drop_vertices_by_overlap_area( 55 | max_overlap_ratio=0.5, keep_in_edges=False 56 | ) 57 | print("--------------------------------------------------------------") 58 | print("Dropping vertices so that overlap ratio < 0.5:") 59 | op(gbc_graph.model_dump()) 60 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc2i_sampling.sh: -------------------------------------------------------------------------------- 1 | python scripts/generation/gbc2i.py \ 2 | --configs configs/generation/gbc2i/sampling_base.yaml \ 3 | --prompt_file prompts/t2i/t2gbc_seed.yaml 4 | python scripts/generation/gbc2i.py \ 5 | --configs configs/generation/gbc2i/sampling_region_gbc_encode_without_context_ipa.yaml \ 6 | --prompt_files prompts/t2i/dog_cat_ref_image.yaml 7 | python scripts/generation/gbc2i.py \ 8 | --configs configs/generation/gbc2i/sampling_gbc_encode_without_context.yaml \ 9 | --prompt_files prompts/t2i/banana_apple_graph_only.yaml prompts/t2i/living_room_graph_only.yaml 10 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc_captioning.sh: -------------------------------------------------------------------------------- 1 | python scripts/captioning/run_gbc_captioning.py \ 2 | --img_paths data/images/wiki/ \ 3 | --save_dir tests/outputs/captioning/gbc_wiki/ \ 4 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc_captioning_batch_single_image.sh: -------------------------------------------------------------------------------- 1 | python scripts/captioning/run_gbc_captioning.py \ 2 | --img_paths data/images/wiki/Mumbai_Flora_Fountain.jpg \ 3 | --save_dir tests/outputs/captioning/gbc_batch_fountain/ \ 4 | --save_images --batch_query --batch_size 4 5 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc_captioning_single_image.sh: -------------------------------------------------------------------------------- 1 | python scripts/captioning/run_gbc_captioning.py \ 2 | --img_paths data/images/wiki/Kinkakuji.jpg \ 3 | --save_dir tests/outputs/captioning/gbc_kinkakuji/ \ 4 | --save_images 5 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc_captioning_single_image_llava_yoloworld.sh: -------------------------------------------------------------------------------- 1 | python scripts/captioning/run_gbc_captioning.py \ 2 | --config_file configs/captioning/llava_yoloworld.yaml \ 3 | --img_paths data/images/wiki/Wild_horses.jpg \ 4 | --save_dir tests/outputs/captioning/gbc_wild_horses/ \ 5 | --save_images 6 | -------------------------------------------------------------------------------- /tests/test_integral/test_gbc_processing.sh: -------------------------------------------------------------------------------- 1 | python scripts/processing/process_gbc.py --configs configs/processing/to_structured_text.yaml 2 | python scripts/processing/process_gbc.py --configs configs/processing/compute_clip_scores.yaml 3 | python scripts/processing/process_gbc.py --configs configs/processing/compute_toxicity_scores.yaml 4 | python scripts/processing/process_gbc.py --configs configs/processing/compute_all_scores.yaml 5 | python scripts/processing/process_gbc.py --configs configs/processing/relation_composition_filtering.yaml 6 | python scripts/processing/process_gbc.py --configs configs/processing/clip_filtering.yaml 7 | -------------------------------------------------------------------------------- /viewer/.gitignore: -------------------------------------------------------------------------------- 1 | # data 2 | temp.duckdb 3 | **/*.parquet 4 | 5 | # Logs 6 | logs 7 | *.log 8 | npm-debug.log* 9 | yarn-debug.log* 10 | yarn-error.log* 11 | pnpm-debug.log* 12 | lerna-debug.log* 13 | 14 | node_modules 15 | dist 16 | dist-ssr 17 | *.local 18 | 19 | # Editor directories and files 20 | .vscode/* 21 | !.vscode/extensions.json 22 | .idea 23 | .DS_Store 24 | *.suo 25 | *.ntvs* 26 | *.njsproj 27 | *.sln 28 | *.sw? 29 | -------------------------------------------------------------------------------- /viewer/README.md: -------------------------------------------------------------------------------- 1 | # GBC Viewer 2 | 3 | The GBC Viewer is an interactive tool for exploring GBC-annotated data. It supports both reading images locally and from the internet. 4 | 5 |

6 | 7 |

8 | 9 | 10 | Download the released datasets [GBC1M](https://huggingface.co/datasets/graph-based-captions/GBC1M/tree/main/data/parquet)/[GBC10M](https://huggingface.co/datasets/graph-based-captions/GBC10M/tree/main/data/parquet) into the [data](data) folder and explore them using the viewer! 11 | 12 | ## Requirements 13 | 14 | - **Node.js**: `>= 20` 15 | - **Python**: `>= 3.10` 16 | 17 | ## Installation 18 | 19 | Install the required dependencies: 20 | 21 | ```bash 22 | npm install 23 | python3 -m pip install -r ./server/requirements.txt 24 | ``` 25 | 26 | ## Usage 27 | 28 | ### Build and Run the Viewer 29 | 30 | To build the website and start the Python server with both frontend and backend running together: 31 | 32 | ```bash 33 | npm run build 34 | python ./server/api.py --path ../data/gbc/wiki --img_root_dir .. --frontend_path dist --port 5050 35 | ``` 36 | 37 | In this example: 38 | - The server reads data from [wiki_gbc_graphs.parquet](../data/gbc/wiki/wiki_gbc_graphs.parquet) in the `../data/gbc/wiki` directory. 39 | - The website will be available at [http://localhost:5050/](http://localhost:5050/). 40 | 41 | **Command-Line Arguments** 42 | 43 | The `server/api.py` script supports the following arguments: 44 | 45 | - `--path`: Path to the directory containing parquet files. Only immediate parquet files are used (non-recursive), and other formats like json or jsonl are ignored. 46 | - `--img_root_dir`: Root directory for resolving relative paths in the `img_path` field. 47 | - `--frontend_path`: Path to the folder containing the built frontend (e.g., `dist`). 48 | - `--port`: Port where the server will run (defaults to `5050`). 49 | - `--low_ram`: Enables low RAM mode, which reduces memory usage but may be slower. Recommended for large datasets like GBC1M or GBC10M. 50 | 51 | ### Image Preview 52 | 53 | - **Local Images**: The viewer looks for images locally using the sample's `img_path` field. 54 | - **Fallback to Internet**: If `img_path` is `None` or the local image is missing, the viewer fetches the image using `img_url`. 55 | 56 | 57 | ## Development 58 | 59 | For development, you can run the viewer in different modes: 60 | 61 | - **Frontend Only**: Start the Vite dev server for the frontend. 62 | ```bash 63 | npm run dev:vite 64 | ``` 65 | 66 | - **Backend Only**: Start the API dev server. The frontend will be available at [http://localhost:5173/](http://localhost:5173/). 67 | ```bash 68 | npm run dev:api 69 | ``` 70 | 71 | - **Frontend and Backend Together**: Start both servers simultaneously. 72 | ```bash 73 | npm run dev 74 | ``` 75 | 76 | By default, the development server reads data from the [data](data) folder. You can modify the arguments of `npm run dev:api` in [package.json](package.json) to change this behavior. 77 | -------------------------------------------------------------------------------- /viewer/config.js: -------------------------------------------------------------------------------- 1 | const config = { 2 | "api_url": "/api", 3 | } 4 | 5 | export default config 6 | -------------------------------------------------------------------------------- /viewer/data/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/apple/ml-gbc/59220d21168251ee4ef0332a4e7b0ebfdc909d32/viewer/data/.gitkeep -------------------------------------------------------------------------------- /viewer/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | GBC Viewer 9 | 10 | 12 | 15 | 16 | 17 | 18 | 19 |
20 |
21 |

GBC Viewer

22 |
23 | 30 |
31 |
32 |
33 |
34 | 44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 | Bounding Box and Caption 52 |
53 |
54 |
55 |
56 | 57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |

Image URL not available

66 |
67 |
68 |
Vertex Information
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 | Graph Viewer 84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /viewer/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "gbc-viewtool", 3 | "private": true, 4 | "version": "1.0.0", 5 | "type": "module", 6 | "scripts": { 7 | "dev:vite": "vite", 8 | "dev:api": "python ./server/api.py --path ./data --low-ram", 9 | "dev": "npm-run-all --parallel dev:vite dev:api --", 10 | "build": "vite build", 11 | "preview": "vite preview" 12 | }, 13 | "devDependencies": { 14 | "npm-run-all": "^4.1.5", 15 | "vite": "^5.3.1" 16 | }, 17 | "dependencies": { 18 | "axios": "^1.7.2", 19 | "gbc-viewtool": "^1.0.0", 20 | "showdown": "^2.1.0" 21 | } 22 | } 23 | -------------------------------------------------------------------------------- /viewer/server/headers.py: -------------------------------------------------------------------------------- 1 | headers = [ 2 | ( 3 | "accept", 4 | "text/html,application/xhtml+xml,application/xml;q=0.9," 5 | "image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.9", 6 | ), 7 | ("accept-encoding", "gzip, deflate, br"), 8 | ( 9 | "accept-language", 10 | "zh-TW,zh;q=0.9,en-US;q=0.8,en;q=0.7,af;q=0.6,ja;q=0.5,zh-CN;q=0.4", 11 | ), 12 | ("cache-control", "max-age=0"), 13 | ("if-modified-since", "Mon, 04 May 2020 12:41:48 GMT"), 14 | ("referer", "https://www.pixiv.net/artworks/81295155"), 15 | ("sec-fetch-dest", "document"), 16 | ("sec-fetch-mode", "navigate"), 17 | ("sec-fetch-site", "none"), 18 | ("sec-fetch-user", "?1"), 19 | ("upgrade-insecure-requests", "1"), 20 | ( 21 | "user-agent", 22 | "Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36 (KHTML, like Gecko)" 23 | " Chrome/81.0.4044.129 Safari/537.36", 24 | ), 25 | ] 26 | 27 | headers_dict = {i.strip(): j.strip() for (i, j) in headers} 28 | -------------------------------------------------------------------------------- /viewer/server/requirements.txt: -------------------------------------------------------------------------------- 1 | fastapi 2 | httpx 3 | async_lru 4 | h2 5 | duckdb 6 | uvicorn 7 | Pillow 8 | -------------------------------------------------------------------------------- /viewer/style.css: -------------------------------------------------------------------------------- 1 | html, body, .container-fluid, .row { 2 | height: 100%; 3 | } 4 | #sidebar { 5 | width: 15%; 6 | } 7 | #mynetwork { 8 | width: 100%; 9 | height: 100%; 10 | } 11 | #imageviewer { 12 | display: flex; 13 | justify-content: center; 14 | align-items: center; 15 | width: 100%; 16 | height: 100%; /* Adjust this value as needed */ 17 | background-color: gray; 18 | } 19 | #imgwrapper { 20 | position: relative; 21 | display: inline-block; 22 | line-height: 0; 23 | } 24 | #imgwrapper img { 25 | max-width: 100%; 26 | max-height: 40vh; 27 | width: auto; 28 | height: auto; 29 | } 30 | .bbox { 31 | position: absolute; 32 | border: .2rem solid; 33 | pointer-events: none; 34 | } 35 | #caption { 36 | height: 0; 37 | overflow-y: auto; 38 | } 39 | #captionViewer { 40 | height: 100%; 41 | } 42 | 43 | .outer-container { 44 | width: 100%; 45 | height: 100%; 46 | box-sizing: border-box; 47 | } 48 | 49 | .inner-container { 50 | width: 100%; 51 | height: 100%; 52 | position: relative; 53 | } 54 | 55 | .object { 56 | width: 100%; 57 | height: 100%; 58 | position: absolute; 59 | top: 0; 60 | left: 0; 61 | box-sizing: border-box; 62 | } 63 | .vis-tooltip{ 64 | visibility: hidden !important; 65 | } 66 | .nav li { 67 | white-space: nowrap; 68 | overflow: hidden; 69 | text-overflow: ellipsis; 70 | max-width: 100%; /* Adjust this value as needed */ 71 | color: black; 72 | } 73 | 74 | :root{ 75 | --global-color: rgba(233,43,124,1); 76 | --image-color: rgba(233,43,124,1); 77 | --composition-color: green; 78 | --entity-color: blue; 79 | --relation-color: orange; 80 | } 81 | .node { 82 | text-decoration: none; 83 | } 84 | .node:hover { 85 | text-decoration: underline; 86 | cursor: pointer; 87 | } 88 | .node-image { 89 | color: rgba(233,43,124,1); 90 | } 91 | .node-global { 92 | color: rgba(233,43,124,1); 93 | } 94 | .node-composition { 95 | color: green; 96 | } 97 | .node-relation { 98 | color: orange; 99 | } 100 | .node-entity { 101 | color: blue; 102 | } -------------------------------------------------------------------------------- /viewer/utils.js: -------------------------------------------------------------------------------- 1 | function parseName(label) { 2 | if(label==''){ 3 | return 'empty(root)'; 4 | } 5 | return label.replace(/_/g, ' ').replace(/\|/g, ' / ').replace(/[^a-zA-Z0-9_/ ]/g, ''); 6 | } 7 | 8 | // function classname(string) { 9 | // return string.replace(/[.*+?^${}()|[\]\\]/g, '_-_'); 10 | // } 11 | function classname(str) { 12 | return str.replace(/[^a-zA-Z0-9_-]/g, '_'); 13 | } 14 | 15 | function nodeJumper(nodes, id, text) { 16 | let node = nodes.get(id); 17 | text = text == undefined ? node.label : text; 18 | return `${text}` 19 | } 20 | 21 | function rewriteCaption(description, texts, nodes) { 22 | var history = [""]; 23 | texts = texts.filter((value, index) => { 24 | let result = !history.includes(value[0]); 25 | history.push(value[0]); 26 | return result 27 | }); 28 | console.log("texts", texts) 29 | for (let [text, id] of texts) { 30 | // Use 'i' so that it's case-insensitive 31 | var re = new RegExp( 32 | `([^a-zA-Z])(${text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')})([^a-zA-Z])`, 33 | "gi" 34 | ); 35 | // Create the regular expression with case insensitivity and word boundary handling 36 | var re = new RegExp( 37 | `(^|[^a-zA-Z])(${text.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')})([^a-zA-Z]|$)`, 38 | "gi" 39 | ); 40 | description = description.replace(re, (match, p1, p2, p3) => { 41 | // console.log(`Original text: ${text}`); 42 | // console.log(`Found in description: ${p2}`); 43 | return `${p1}${nodeJumper(nodes, id, p2)}${p3}`; 44 | }); 45 | } 46 | return description 47 | } 48 | 49 | function printVertex(vertex, nodes, edges, options) { 50 | let result = ""; 51 | result += "**Vertex ID: " + parseName(vertex.vertex_id) + "‌**\n"; 52 | result += "\n**Label: " + vertex.label + "**\n ---\n"; 53 | 54 | if (vertex.in_edges.length > 0) { 55 | result += "\nIn Edges: \n"; 56 | for (let in_edge of vertex.in_edges) { 57 | result += `- ${in_edge.text} (${nodeJumper(nodes, in_edge.source)} -> this)\n`; 58 | } 59 | result += "\n
"; 60 | } 61 | 62 | let out_edge_texts = []; 63 | if (vertex.out_edges.length > 0) { 64 | result += "\nOut Edges: \n"; 65 | for (let out_edge of vertex.out_edges) { 66 | result += `- ${out_edge.text} (this -> ${nodeJumper(nodes, out_edge.target)})\n`; 67 | out_edge_texts.push([out_edge.text, out_edge.target]); 68 | } 69 | result += "\n
"; 70 | } 71 | 72 | result += "\nCaptions:\n"; 73 | for (let { text, label } of vertex.descs) { 74 | text = rewriteCaption(text, out_edge_texts, nodes); 75 | result += `- (**${label}**) ${text}\n`; 76 | } 77 | return result 78 | } 79 | 80 | function bbox(node, width, offset, transparency) { 81 | var style = ""; 82 | offset = offset || 0; 83 | style += `top:${node.bbox.top * 100 + offset}%;`; 84 | style += `left:${node.bbox.left * 100 + offset}%;`; 85 | style += `bottom:${(1 - node.bbox.bottom) * 100 + offset}%;`; 86 | style += `right:${(1 - node.bbox.right) * 100 + offset}%;`; 87 | style += `border: ${width ? width : '.2'}rem solid var(--${node.group}-color);`; 88 | style += `opacity: ${transparency ? transparency : 1};`; 89 | return `
`; 90 | } 91 | 92 | export { parseName, printVertex, classname, nodeJumper, bbox } 93 | -------------------------------------------------------------------------------- /viewer/vis-settings.js: -------------------------------------------------------------------------------- 1 | var options = { 2 | "groups": { 3 | "image": { 4 | "color": { 5 | "background": "rgba(255,150,150,1)", 6 | "border": "red", 7 | "highlight": { 8 | "background": "rgba(255,200,200,1)", 9 | "border": "rgba(233,43,124,1)" 10 | } 11 | } 12 | }, 13 | "global": { 14 | "color": { 15 | "background": "rgba(255,150,150,1)", 16 | "border": "red", 17 | "highlight": { 18 | "background": "rgba(255,200,200,1)", 19 | "border": "rgba(233,43,124,1)" 20 | } 21 | } 22 | }, 23 | "composition": { 24 | "color": { 25 | "background": "rgba(150,255,150,1)", 26 | "border": "green", "highlight": { 27 | "background": "rgba(180,255,180,1)", "border": "rgba(43,124,43,1)" 28 | } 29 | } 30 | }, 31 | "entity": { 32 | "color": { 33 | "background": "rgba(150,150,255,1)", "border": "blue", "highlight": { 34 | "background": "rgba(210,229,255,1)", "border": "rgba(43,124,233,1)" 35 | } 36 | } 37 | }, 38 | "relation": { 39 | "color": { 40 | "background": "rgba(255,255,150,1)", "border": "orange", "highlight": { 41 | "background": "rgba(255,255,200,1)", "border": "rgba(200,200,43,1)" 42 | } 43 | } 44 | } 45 | } 46 | }; 47 | export { options }; -------------------------------------------------------------------------------- /viewer/vite.config.js: -------------------------------------------------------------------------------- 1 | export default { 2 | esbuild: { 3 | drop: ['console', 'debugger'], 4 | }, 5 | server: { 6 | port: process.env.VITE_PORT || 5173, 7 | proxy: { 8 | '/api': { 9 | // Replace target with your API server address 10 | // Only meaningful for dev server 11 | target: 'http://localhost:5050', 12 | changeOrigin: true 13 | } 14 | } 15 | } 16 | } --------------------------------------------------------------------------------