├── .gitignore ├── LICENSE ├── README.md ├── assets ├── .DS_Store ├── trex2 │ ├── countanything.jpg │ ├── demo.jpg │ ├── generic.jpg │ ├── gradio.jpg │ ├── head.jpg │ ├── interactive_0.jpg │ ├── interactive_1.jpg │ ├── method.jpg │ ├── trexlabel.jpg │ ├── video_cover.jpg │ └── video_cover2.png └── trex2_api_examples │ ├── generic_prompt1.jpg │ ├── generic_prompt2.jpg │ ├── generic_target.jpg │ └── interactive1.jpeg ├── demo_examples ├── customize_embedding.py ├── embedding_inference.py ├── football_player_embedding.txt ├── generic_inference.py └── interactive_inference.py ├── gradio_demo.py ├── requirements.txt ├── setup.py └── trex ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── model_wrapper.cpython-38.pyc └── visualize.cpython-38.pyc ├── model_wrapper.py ├── version.py └── visualize.py /.gitignore: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/.gitignore -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | IDEA License 1.0 2 | 3 | This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and the International Digital Economy Academy (“IDEA” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by IDEA under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by IDEA related to the Software (“Documentation”). 4 | 5 | By downloading the Software or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to IDEA that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity. 6 | 7 | 1. LICENSE GRANT 8 | 9 | a. You are granted a non-exclusive, worldwide, transferable, sublicensable, irrevocable, royalty free and limited license under IDEA’s copyright interests to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Software solely for your non-commercial research purposes. 10 | 11 | b. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. IDEA and its licensors reserve all rights not expressly granted by this License. 12 | 13 | c. If you intend to use the Software Products for any commercial purposes, you must request a license from IDEA, which IDEA may grant to you in its sole discretion. 14 | 15 | 2. REDISTRIBUTION AND USE 16 | 17 | a. If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall provide a copy of this Agreement to such third party. 18 | 19 | b. You must retain in all copies of the Software Products that you distribute the following attribution notice: "T is licensed under the IDEA License 1.0, Copyright (c) IDEA. All Rights Reserved." 20 | 21 | d. Your use of the Software Products must comply with applicable laws and regulations (including trade compliance laws and regulations). 22 | 23 | e. You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for in any manner that infringes, misappropriates, or otherwise violates any third-party rights. 24 | 25 | 3. DISCLAIMER OF WARRANTY 26 | 27 | UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. 28 | 29 | 4. LIMITATION OF LIABILITY 30 | 31 | IN NO EVENT WILL IDEA OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF IDEA OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. 32 | 33 | 5. INDEMNIFICATION 34 | 35 | You will indemnify, defend and hold harmless IDEA and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “IDEA Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any IDEA Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the IDEA Parties of any such Claims, and cooperate with IDEA Parties in defending such Claims. You will also grant the IDEA Parties sole control of the defense or settlement, at IDEA’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and IDEA or the other IDEA Parties. 36 | 37 | 6. TERMINATION; SURVIVAL 38 | 39 | a. This License will automatically terminate upon any breach by you of the terms of this License. 40 | 41 | b. If you institute litigation or other proceedings against IDEA or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. 42 | 43 | c. The following sections survive termination of this License: 2 (Redistribution and use), 3 (Disclaimers of Warranty), 4 (Limitation of Liability), 5 (Indemnification), 6 (Termination; Survival), 7 (Trademarks) and 8 (Applicable Law; Dispute Resolution). 44 | 45 | 7. TRADEMARKS 46 | 47 | Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with IDEA without the prior written permission of IDEA, except to the extent necessary to make the reference required by the attribution notice of this Agreement. 48 | 49 | 8. APPLICABLE LAW; DISPUTE RESOLUTION 50 | 51 | This License will be governed and construed under the laws of the People’s Republic of China without regard to conflicts of law provisions. The parties expressly agree that the United Nations Convention on Contracts for the International Sale of Goods will not apply. Any suit or proceeding arising out of or relating to this License will be brought in the courts, as applicable, in Shenzhen, Guangdong, and each party irrevocably submits to the jurisdiction and venue of such courts. 52 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 |
3 | 4 |
5 | 6 |
7 |

A picture speaks volumes, as do the words that frame it.

8 |
9 | 10 |
11 | 12 | ![Static Badge](https://img.shields.io/badge/T--Rex-2-2) [![arXiv preprint](https://img.shields.io/badge/arxiv_2403.14610-blue%3Flog%3Darxiv)](https://arxiv.org/pdf/2403.14610.pdf) [![Homepage](https://img.shields.io/badge/homepage-visit-blue)](https://deepdataspace.com/home) [![Hits](https://hits.seeyoufarm.com/api/count/incr/badge.svg?url=https%3A%2F%2Fgithub.com%2FMountchicken%2FT-Rex&count_bg=%2379C83D&title_bg=%23DF9B9B&icon=iconify.svg&icon_color=%23FFF9F9&title=VISITORS&edge_flat=false)](https://hits.seeyoufarm.com) [![Static Badge](https://img.shields.io/badge/Try_Demo!-blue?logo=chainguard&logoColor=green)](https://deepdataspace.com/playground/ivp) 13 |
14 | 15 | ---- 16 | 🎉 **[T-Rex Label](https://trexlabel.com?source=gt) surpasses 2,000 users!** 17 | Just three months after launch, [T-Rex Label](https://trexlabel.com?source=gt) has grown to over 2,000 users. In our latest release, we’ve improved the annotation tool's user experience and expanded the YOLO format export, making it easier for researchers to quickly build datasets. For any feedback, feel free to reach out at [trexlabel_dm@idea.edu.cn](mailto:trexlabel_dm@idea.edu.cn). 18 | 19 | ---- 20 | 📌 If you find our project helpful and need more API token quotas, you can request additional tokens by [filling out this form](https://docs.google.com/forms/d/e/1FAIpQLSfjogAtkgoVyFX9wvCAE15mD7QtHdKdKOrVmcE5GT1xu-03Aw/viewform?usp=sf_link). Our team will review your request and allocate more tokens for your use in one or two days. You can also apply for more tokens by sending us an email. 21 | 22 | ---- 23 | 24 | # Introduction Video 🎥 25 | Turn on the music if possible 🎧 26 | 27 | [![Video Name](assets/trex2/video_cover.jpg)](https://github.com/Mountchicken/Union14M/assets/65173622/60be19f5-88e4-478e-b1a3-af62b8d6d177) 28 | 29 | # News 📰 30 | - **2024-06-24**: We have introduced two new free products based on T-Rex2: 31 | - [**Count Anything APP**](https://apps.apple.com/app/id6502489882): CountAnything is a versatile, efficient, and cost-effective counting tool that utilizes advanced computer vision algorithms, specifically T-Rex, for automated counting. It is applicable across various industries, including manufacturing, agriculture, and aquaculture. 32 | 33 | [![Video Name](assets/trex2/countanything.jpg)](https://github.com/Mountchicken/Mountchicken/assets/65173622/1cffc04a-d9be-46ec-b87e-f754b71d6e21) 34 | 35 | - [**T-Rex Label**](https://www.trexlabel.com/?source=gh): T-Rex Label is an advanced annotation tool powerd by T-Rex2, specifically designed to handle the complexities of various industries and scenarios. It is the ideal choice for those aiming to streamline their workflows and effortlessly create high-quality datasets. 36 | 37 | [![Video Name](assets/trex2/trexlabel.jpg)](https://github.com/Mountchicken/CodeCookbook/assets/65173622/58129775-533d-4aad-88f4-e1992546f9ba) 38 | 39 | - **2024-05-17**: [Grounding DINO 1.5](https://github.com/IDEA-Research/Grounding-DINO-1.5-API) is released. This is IDEA Research's Most Capable Open-World Object Detection Model Series. It can detect any object throught text prompts! 40 | 41 | # Contents 📜 42 | - [Introduction Video 🎥](#introduction-video-) 43 | - [News 📰](#news-) 44 | - [Contents 📜](#contents-) 45 | - [1. Introduction 📚](#1-introduction-) 46 | - [What Can T-Rex Do 📝](#what-can-t-rex-do-) 47 | - [2. Try Demo 🎮](#2-try-demo-) 48 | - [3. API Usage Examples📚](#3-api-usage-examples) 49 | - [Setup](#setup) 50 | - [Interactive Visual Prompt API](#interactive-visual-prompt-api) 51 | - [Generic Visual Prompt API](#generic-visual-prompt-api) 52 | - [Customize Visual Prompt Embedding API](#customize-visual-prompt-embedding-api) 53 | - [Embedding Inference API](#embedding-inference-api) 54 | - [4. Local Gradio Demo with API🎨](#4-local-gradio-demo-with-api) 55 | - [4.1. Setup](#41-setup) 56 | - [4.2. Run the Gradio Demo](#42-run-the-gradio-demo) 57 | - [4.3. Basic Operations](#43-basic-operations) 58 | - [5. Related Works](#5-related-works) 59 | - [BibTeX 📚](#bibtex-) 60 | 61 | # 1. Introduction 📚 62 | Object detection, the ability to locate and identify objects within an image, is a cornerstone of computer vision, pivotal to applications ranging from autonomous driving to content moderation. A notable limitation of traditional object detection models is their closed-set nature. These models are trained on a predetermined set of categories, confining their ability to recognize only those specific categories. The training process itself is arduous, demanding expert knowledge, extensive datasets, and intricate model tuning to achieve desirable accuracy. Moreover, the introduction of a novel object category, exacerbates these challenges, necessitating the entire process to be repeated. 63 | 64 | T-Rex2 addresses these limitations by integrating both text and visual prompts in one model, thereby harnessing the strengths of both modalities. The synergy of text and visual prompts equips T-Rex2 with robust zero-shot capabilities, making it a versatile tool in the ever-changing landscape of object detection. 65 | 66 |
67 | 68 |
69 | 70 | ## What Can T-Rex Do 📝 71 | T-Rex2 is well-suited for a variety of real-world applications, including but not limited to: agriculture, industry, livstock and wild animals monitoring, biology, medicine, OCR, retail, electronics, transportation, logistics, and more. T-Rex2 mainly supports three major workflows including interactive visual prompt workflow, generic visual prompt workflow and text prompt workflow. It can cover most of the application scenarios that require object detection 72 | 73 | [![Video Name](assets/trex2/video_cover2.png)](https://github.com/Mountchicken/Union14M/assets/65173622/c3585d49-208c-4ba4-9954-fd1572d299dc) 74 | 75 | # 2. Try Demo 🎮 76 | We are now opening online demo for T-Rex2. [Check our demo here](https://deepdataspace.com/playground/ivp) 77 | 78 |
79 | 80 |
81 | 82 | 83 | # 3. API Usage Examples📚 84 | We are now opening free API access to T-Rex2. For educators, students, and researchers, we offer an API with extensive usage times to support your educational and research endeavors. You can get API at here [request API](https://cloud.deepdataspace.com/apply-token?from=github). 85 | - [Full API documentation can be found here](https://cloudapi-sdk.deepdataspace.com/dds_cloudapi_sdk/tasks/trex_interactive.html). 86 | 87 | 88 | ## Setup 89 | Install the API package and acquire the API token from the email. 90 | ```bash 91 | git clone https://github.com/IDEA-Research/T-Rex.git 92 | cd T-Rex 93 | pip install dds-cloudapi-sdk==0.1.1 94 | pip install -v -e . 95 | ``` 96 | 97 | 98 | 99 | ## Interactive Visual Prompt API 100 | - In interactive visual prompt workflow, users can provide visual prompts in boxes or points format on a given image to specify the object to be detected. 101 | 102 | ```python 103 | python demo_examples/interactive_inference.py --token 104 | ``` 105 | - You are supposed get the following visualization results at `demo_vis/` 106 |
107 | 108 | 109 |
110 | 111 | ## Generic Visual Prompt API 112 | - In generic visual prompt workflow, users can provide visual prompts on one reference image 113 | and detect on the other image. 114 | 115 | ```python 116 | python demo_examples/generic_inference.py --token 117 | ``` 118 | - You are supposed get the following visualization results at `demo_vis/` 119 |
120 | + 121 | = 122 | 123 |
124 | 125 | ## Customize Visual Prompt Embedding API 126 | In this workflow, you can customize a visual embedding for a object category using multiple images. With this embedding, you can detect on any images. 127 | 128 | ```python 129 | python demo_examples/customize_embedding.py --token 130 | ``` 131 | - You are supposed to get a download link for this visual prompt embedding in `safetensors` format. Save it and let's use it for `embedding_inference`. 132 | 133 | ## Embedding Inference API 134 | With the visual prompt embeddings generated from the previous API. You can use it detect on any images. 135 | ```python 136 | python demo_examples/embedding_inference.py --token 137 | ``` 138 | 139 | # 4. Local Gradio Demo with API🎨 140 |
141 | 142 |
143 | 144 | ## 4.1. Setup 145 | - Install T-Rex2 API if you haven't done so 146 | ```bash 147 | - install gradio and other dependencies 148 | ```bash 149 | # install gradio and other dependencies 150 | pip install gradio-image-prompter 151 | ``` 152 | 153 | ## 4.2. Run the Gradio Demo 154 | ```bash 155 | python gradio_demo.py --trex2_api_token 156 | ``` 157 | 158 | ## 4.3. Basic Operations 159 | - **Draw Box**: Draw a box on the image to specify the object to be detected. Drag the left mouse button to draw a box. 160 | - **Draw Point**: Draw a point on the image to specify the object to be detected. Click the left mouse button to draw a point. 161 | - **Interactive Visual Prompt**: Provide visual prompts in boxes or points format on a given image to specify the object to be detected. The Input Target Image and Interactive Visual Prompt Image should be the same 162 | - **Generic Visual Prompt**: Provide visual prompts on multiple reference images and detect on the other image. 163 | 164 | # 5. Related Works 165 | :fire: We release the [training and inference code](https://github.com/UX-Decoder/DINOv) and [demo link](http://semantic-sam.xyzou.net:6099/) of [DINOv](https://arxiv.org/pdf/2311.13601.pdf), which can handle in-context **visual prompts** for open-set and referring detection & segmentation. Check it out! 166 | 167 | # 6. LICENSE 168 | We use [IDEA License 1.0](LICENSE) 169 | 170 | # BibTeX 📚 171 | ``` 172 | @misc{jiang2024trex2, 173 | title={T-Rex2: Towards Generic Object Detection via Text-Visual Prompt Synergy}, 174 | author={Qing Jiang and Feng Li and Zhaoyang Zeng and Tianhe Ren and Shilong Liu and Lei Zhang}, 175 | year={2024}, 176 | eprint={2403.14610}, 177 | archivePrefix={arXiv}, 178 | primaryClass={cs.CV} 179 | } 180 | ``` 181 | -------------------------------------------------------------------------------- /assets/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/.DS_Store -------------------------------------------------------------------------------- /assets/trex2/countanything.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/countanything.jpg -------------------------------------------------------------------------------- /assets/trex2/demo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/demo.jpg -------------------------------------------------------------------------------- /assets/trex2/generic.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/generic.jpg -------------------------------------------------------------------------------- /assets/trex2/gradio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/gradio.jpg -------------------------------------------------------------------------------- /assets/trex2/head.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/head.jpg -------------------------------------------------------------------------------- /assets/trex2/interactive_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/interactive_0.jpg -------------------------------------------------------------------------------- /assets/trex2/interactive_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/interactive_1.jpg -------------------------------------------------------------------------------- /assets/trex2/method.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/method.jpg -------------------------------------------------------------------------------- /assets/trex2/trexlabel.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/trexlabel.jpg -------------------------------------------------------------------------------- /assets/trex2/video_cover.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/video_cover.jpg -------------------------------------------------------------------------------- /assets/trex2/video_cover2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2/video_cover2.png -------------------------------------------------------------------------------- /assets/trex2_api_examples/generic_prompt1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2_api_examples/generic_prompt1.jpg -------------------------------------------------------------------------------- /assets/trex2_api_examples/generic_prompt2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2_api_examples/generic_prompt2.jpg -------------------------------------------------------------------------------- /assets/trex2_api_examples/generic_target.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2_api_examples/generic_target.jpg -------------------------------------------------------------------------------- /assets/trex2_api_examples/interactive1.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/assets/trex2_api_examples/interactive1.jpeg -------------------------------------------------------------------------------- /demo_examples/customize_embedding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from trex import TRex2APIWrapper, visualize 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description="Interactive Inference") 12 | parser.add_argument( 13 | "--token", 14 | type=str, 15 | help="The token for T-Rex2 API. We are now opening free API access to T-Rex2", 16 | ) 17 | parser.add_argument( 18 | "--box_threshold", type=float, default=0.3, help="The threshold for box score" 19 | ) 20 | return parser.parse_args() 21 | 22 | 23 | if __name__ == "__main__": 24 | args = get_args() 25 | trex2 = TRex2APIWrapper(args.token) 26 | 27 | target_image = "assets/trex2_api_examples/generic_target.jpg" 28 | prompts = [ 29 | dict( 30 | image="assets/trex2_api_examples/generic_prompt1.jpg", 31 | interactions=[ 32 | { 33 | "type": "rect", 34 | "category_id": 1, 35 | "rect": [692, 338, 725, 459], 36 | }, 37 | { 38 | "type": "rect", 39 | "category_id": 1, 40 | "rect": [561, 231, 634, 351], 41 | }, 42 | ], 43 | ), 44 | dict( 45 | image="assets/trex2_api_examples/generic_prompt2.jpg", 46 | interactions=[ 47 | { 48 | "type": "rect", 49 | "category_id": 1, 50 | "rect": [561, 231, 634, 351], 51 | }, 52 | ], 53 | ), 54 | ] 55 | result = trex2.visual_prompt_inference( 56 | target_image, prompts, return_type=["embedding"] 57 | )[1] 58 | # save this base64 result to a file 59 | with open("demo_examples/football_player_embedding.txt", "w") as f: 60 | f.write(result) 61 | -------------------------------------------------------------------------------- /demo_examples/embedding_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from trex import TRex2APIWrapper, visualize 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description="Interactive Inference") 12 | parser.add_argument( 13 | "--token", 14 | type=str, 15 | help="The token for T-Rex2 API. We are now opening free API access to T-Rex2", 16 | ) 17 | parser.add_argument( 18 | "--box_threshold", type=float, default=0.3, help="The threshold for box score" 19 | ) 20 | parser.add_argument( 21 | "--vis_dir", 22 | type=str, 23 | default="demo_vis/", 24 | help="The directory for visualization", 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == "__main__": 30 | args = get_args() 31 | trex2 = TRex2APIWrapper(args.token) 32 | target_image = "assets/trex2_api_examples/generic_target.jpg" 33 | embedding = "demo_examples/football_player_embedding.txt" 34 | with open(embedding, "r") as f: 35 | embedding = f.read() 36 | result = trex2.embedding_inference(target_image, embedding) 37 | # filter out the boxes with low score 38 | scores = np.array(result["scores"]) 39 | labels = np.array(result["labels"]) 40 | boxes = np.array(result["boxes"]) 41 | filter_mask = scores > args.box_threshold 42 | filtered_result = { 43 | "scores": scores[filter_mask], 44 | "labels": labels[filter_mask], 45 | "boxes": boxes[filter_mask], 46 | } 47 | # visualize the results 48 | if not os.path.exists(args.vis_dir): 49 | os.makedirs(args.vis_dir) 50 | 51 | image = Image.open(target_image) 52 | image = visualize(image, filtered_result, draw_score=True) 53 | image.save(os.path.join(args.vis_dir, f"embedding.jpg")) 54 | print(f"Visualized image saved to {args.vis_dir}/embedding.jpg") 55 | -------------------------------------------------------------------------------- /demo_examples/football_player_embedding.txt: -------------------------------------------------------------------------------- 1 | W3siY2F0ZWdvcnlfaWQiOiAxLCAiZW1iZWRkaW5nIjogIjd0Yk12UlpXcGIwRkF3UkFPQnFEdmtBS29iMEVYbHcrNGtFYnZ2M1l3TDM1SkplK2E3S1B2dWdReUwwOFJUUytNcEMwUHYvREVqNnY3UUEvMWJHYlBVaTFqNzBFOVhZK0toVW12ZzdZSGo5Qnp3Rys2aHdLdmhCSEtMMTlhTW0rNytxT1BOSEhIcjI4dFRvK2Jkbkd2VWhmaERvNGZFbytZaU9aUHNZanU3M0FoeHcrODRpalBiTEpYNzNvaEhzK1RWNFZQby9DQmIzQ3BiWS8xY3lGdmFheHlyd0Vnd2krdkFJN3Z3ekR2cjRINTQyK2JCdUpQVnJGRkQ3d3hNZStUMDd0dmVyT2tiMERXcnkrZGpBaFBwRmh2RDBPaUY2K2FjSXJQc2owbnoxQVJxZzdZdTB6djFVR2xUNjZPQUE4c3FLbFB0NzhyTDY0cXA0K0l4c0VQNTFoQzc1NVBvdTk0WFhnTy9HS21iMk9ZWFk5NFZtSVBtcjVINytyY3hDK3htVmJ2aEIvSWo5a3pFUTliUUREdmZZRVc3MnA2S0kremhTYnZXS3NhajY0MFpjK3hGRFp2SHlsWlRzSU9lWTk0SnlWUG53ZU1iM01wOVE5dVFtT3ZuQkxFVDR2RFFLOTFoMEdQdlVzN3IxQkNKMDlvRDBSUHg0eC83MnZnVDQ5ZkRheVBONUpPYjNra215OVBFTEt2cXdPaTc3c211OCsrdmFMdmJZOWhUNGIrQjY4MlZYanZRQWFwendlYlJBL09LUVh2bldocXI3R0tvcTdRanhadk9aUks3NHUvNE85NGlIYXZmbTVDVC9NdDFvOUpNWGZ2SzVjMEwybnI3czhjbXI1dk5BYkRMMUhGWk8rRGtVNHZ6NVlDYjdLOVlnK1NDQ2pQVVhFcFQ2WVJYZStmZ1BYUGFMVC9yMk03L1M4aHJGOHZoSTJpajN5ZkxZLzkvdjZQVGlwbHIwK2NJSytqVEgxUHBUTkN6OUdiWGErcjdFT1BtamcyYjBIbDltOGxxU1h2dFhFN0VBRXFrWEJ4U1dRdlJvMU5iNjJ4VjYrQ0tOOXZlTEFGejBhbU9xOUdBWDRQbVFZeVQ3dzhZazlDSndGdmNMaWNiME1EMU85ZnJvY3ZiRTZiYnlNTUJ1K3NDOFl2VHpneHIyeXk5UytZRmNPdmg2d0U3Nmg2RUcvVktHQXZpNW9pN3hJMU1vOTdvdDZ2WXdoaUx6TVBJNjlTRm93UGw5V3dEMkdBRU0reVhzSVBGTHlrcjR3bjZJOStpUFpQSi9XQnovakFmVThEUWhUTzNndytiN3k3WEMrMWVJZnY2YldTTDJQZFI0L0lYUWRQcHpHdnI3MjFzMDlTUCswUG0vQ2lEM0dITis4dmlaZnZoWitTYjRHaGNzOVlXVWpQRUxaRkQyUGxCUzk4TFVOdkhDa2ZidHlSSm05ckxEVFBKU0g5NzNBMVdnKzVUQWlQb3Mzc0x3SDE4czhjSXprdnNkdTRMeThMbHEvUHJWa3Zyd3pKejUya1IrOUtXUFFPd1NxMWoxT1hHKytZcU01dnNRVEJENUs4L2U5ZHRZM1BsbDNVYitmVzFTK3RiT0VQZ0dTK2J2ZldTZythTi9tdkd0S2d6MWtweks5cFNtSFBxMEtUTHc5MVdJOU1vMSt2WUR2Y0QydWdMSSt4R0tudm52aHNEMGMwOWE3OHRjaHZwSnhyTDVmQlQrK2FINGJ2NWYwNWp6b2FtaStLQ1FadW82YVZyN094SzY5TmtYQ1BYQ3JLTDdHMHVlOXdjZXl2TklOdDc1ME81VzliUlFKUGc9PSJ9XQ== -------------------------------------------------------------------------------- /demo_examples/generic_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from trex import TRex2APIWrapper, visualize 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description="Interactive Inference") 12 | parser.add_argument( 13 | "--token", 14 | type=str, 15 | help="The token for T-Rex2 API. We are now opening free API access to T-Rex2", 16 | ) 17 | parser.add_argument( 18 | "--box_threshold", type=float, default=0.3, help="The threshold for box score" 19 | ) 20 | parser.add_argument( 21 | "--vis_dir", 22 | type=str, 23 | default="demo_vis/", 24 | help="The directory for visualization", 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == "__main__": 30 | args = get_args() 31 | trex2 = TRex2APIWrapper(args.token) 32 | 33 | target_image = "assets/trex2_api_examples/generic_target.jpg" 34 | prompts = [ 35 | dict( 36 | image="assets/trex2_api_examples/generic_prompt1.jpg", 37 | interactions=[ 38 | { 39 | "type": "rect", 40 | "category_id": 1, 41 | "rect": [692, 338, 725, 459], 42 | }, 43 | { 44 | "type": "rect", 45 | "category_id": 1, 46 | "rect": [561, 231, 634, 351], 47 | }, 48 | ], 49 | ), 50 | dict( 51 | image="assets/trex2_api_examples/generic_prompt2.jpg", 52 | interactions=[ 53 | { 54 | "type": "rect", 55 | "category_id": 1, 56 | "rect": [561, 231, 634, 351], 57 | }, 58 | ], 59 | ), 60 | ] 61 | result = trex2.visual_prompt_inference(target_image, prompts)[0] 62 | # filter out the boxes with low score 63 | 64 | scores = np.array(result["scores"]) 65 | labels = np.array(result["labels"]) 66 | boxes = np.array(result["boxes"]) 67 | filter_mask = scores > args.box_threshold 68 | filtered_result = { 69 | "scores": scores[filter_mask], 70 | "labels": labels[filter_mask], 71 | "boxes": boxes[filter_mask], 72 | } 73 | # visualize the results 74 | if not os.path.exists(args.vis_dir): 75 | os.makedirs(args.vis_dir) 76 | 77 | image = Image.open(target_image) 78 | image = visualize(image, filtered_result, draw_score=True) 79 | image.save(os.path.join(args.vis_dir, f"generic.jpg")) 80 | print(f"Visualized image saved to {args.vis_dir}/generic.jpg") 81 | -------------------------------------------------------------------------------- /demo_examples/interactive_inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from trex import TRex2APIWrapper, visualize 8 | 9 | 10 | def get_args(): 11 | parser = argparse.ArgumentParser(description="Interactive Inference") 12 | parser.add_argument( 13 | "--token", 14 | type=str, 15 | help="The token for T-Rex2 API. We are now opening free API access to T-Rex2", 16 | ) 17 | parser.add_argument( 18 | "--box_threshold", type=float, default=0.3, help="The threshold for box score" 19 | ) 20 | parser.add_argument( 21 | "--vis_dir", 22 | type=str, 23 | default="demo_vis/", 24 | help="The directory for visualization", 25 | ) 26 | return parser.parse_args() 27 | 28 | 29 | if __name__ == "__main__": 30 | args = get_args() 31 | trex2 = TRex2APIWrapper(args.token) 32 | # demo for box input 33 | target_image = "assets/trex2_api_examples/interactive1.jpeg" 34 | prompts = [ 35 | dict( 36 | image="assets/trex2_api_examples/interactive1.jpeg", 37 | interactions=[ 38 | { 39 | "type": "rect", 40 | "category_id": 1, 41 | "rect": [347, 1259, 600, 1437], 42 | }, 43 | { 44 | "type": "rect", 45 | "category_id": 1, 46 | "rect": [1085, 1154, 1154, 1246], 47 | }, 48 | { 49 | "type": "rect", 50 | "category_id": 2, 51 | "rect": [1465, 787, 1497, 877], 52 | }, 53 | ], 54 | ) 55 | ] 56 | 57 | result = trex2.visual_prompt_inference(target_image, prompts)[0] 58 | # filter out the boxes with low score 59 | 60 | scores = np.array(result["scores"]) 61 | labels = np.array(result["labels"]) 62 | boxes = np.array(result["boxes"]) 63 | filter_mask = scores > args.box_threshold 64 | filtered_result = { 65 | "scores": scores[filter_mask], 66 | "labels": labels[filter_mask], 67 | "boxes": boxes[filter_mask], 68 | } 69 | # visualize the results 70 | if not os.path.exists(args.vis_dir): 71 | os.makedirs(args.vis_dir) 72 | 73 | image = Image.open(target_image) 74 | image = visualize(image, filtered_result, draw_score=True) 75 | image.save(os.path.join(args.vis_dir, f"interactive.jpg")) 76 | print(f"Visualized image saved to {args.vis_dir}/interactive.jpg") 77 | -------------------------------------------------------------------------------- /gradio_demo.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import random 4 | from typing import Dict, List 5 | 6 | import gradio as gr 7 | import numpy as np 8 | from gradio_image_prompter import ImagePrompter 9 | from PIL import Image, ImageDraw, ImageFont 10 | 11 | from trex import TRex2APIWrapper 12 | 13 | 14 | def arg_parse(): 15 | parser = argparse.ArgumentParser(description="Gradio Demo for T-Rex2") 16 | parser.add_argument( 17 | "--trex2_api_token", 18 | type=str, 19 | help="API token for T-Rex2", 20 | ) 21 | parser.add_argument("--sam_type", type=str, default="vit_l", help="SAM model type") 22 | parser.add_argument( 23 | "--sam_checkpoint_path", type=str, help="path to checkpoint file" 24 | ) 25 | args = parser.parse_args() 26 | return args 27 | 28 | 29 | def plot_boxes_to_image( 30 | image_pil: Image, 31 | tgt: Dict, 32 | return_point: bool = False, 33 | point_width: float = 1.0, 34 | return_score=True, 35 | ) -> Image: 36 | """Plot bounding boxes and labels on an image. 37 | 38 | Args: 39 | image_pil (PIL.Image): The input image as a PIL Image object. 40 | tgt (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing 41 | the bounding boxes and labels. The keys are: 42 | - scores: A tuple containing the height and width of the image. 43 | - boxes: A list of normalized bounding boxes as a list of shape (N, 4), in 44 | (x_center, y_center, width, height) format. 45 | - labels: A list of string labels for each bounding box. 46 | return_point (bool): Draw center point instead of bounding box. Defaults to False. 47 | 48 | Returns: 49 | Union[PIL.Image, PIL.Image]: A tuple containing the input image and ploted image. 50 | """ 51 | # Get the bounding boxes and labels from the target dictionary 52 | boxes = tgt["boxes"] 53 | scores = tgt["scores"] 54 | 55 | # Create a PIL ImageDraw object to draw on the input image 56 | draw = ImageDraw.Draw(image_pil) 57 | # Create a new binary mask image with the same size as the input image 58 | mask = Image.new("L", image_pil.size, 0) 59 | # Create a PIL ImageDraw object to draw on the mask image 60 | mask_draw = ImageDraw.Draw(mask) 61 | 62 | # Draw boxes and masks for each box and label in the target dictionary 63 | for box, score in zip(boxes, scores): 64 | # Convert the box coordinates from 0..1 to 0..W, 0..H 65 | color = tuple(np.random.randint(0, 255, size=3).tolist()) 66 | # Extract the box coordinates 67 | x0, y0, x1, y1 = box 68 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) 69 | if return_point: 70 | ceter_x = int((x0 + x1) / 2) 71 | ceter_y = int((y0 + y1) / 2) 72 | # Draw the center point on the input image 73 | draw.ellipse( 74 | ( 75 | ceter_x - point_width, 76 | ceter_y - point_width, 77 | ceter_x + point_width, 78 | ceter_y + point_width, 79 | ), 80 | fill=color, 81 | width=point_width, 82 | ) 83 | else: 84 | # Draw the box outline on the input image 85 | draw.rectangle([x0, y0, x1, y1], outline=color, width=int(point_width)) 86 | 87 | # Draw the label text on the input image 88 | if return_score: 89 | text = f"{score:.2f}" 90 | else: 91 | text = f"" 92 | font = ImageFont.load_default() 93 | if hasattr(font, "getbbox"): 94 | bbox = draw.textbbox((x0, y0), text, font) 95 | else: 96 | w, h = draw.textsize(text, font) 97 | bbox = (x0, y0, w + x0, y0 + h) 98 | if not return_point: 99 | draw.rectangle(bbox, fill=color) 100 | draw.text((x0, y0), text, fill="white") 101 | 102 | # Draw the box on the mask image 103 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) 104 | return image_pil, mask 105 | 106 | 107 | def multi_mask2one_mask(masks): 108 | _, _, h, w = masks.shape 109 | for i, mask in enumerate(masks): 110 | mask_image = mask.reshape(h, w, 1) 111 | whole_mask = mask_image if i == 0 else whole_mask + mask_image 112 | whole_mask = np.where(whole_mask == False, 0, 255) 113 | return whole_mask 114 | 115 | 116 | def numpy2PIL(numpy_image): 117 | out = Image.fromarray(numpy_image.astype(np.uint8)) 118 | return out 119 | 120 | 121 | def draw_mask(mask, draw, random_color=True): 122 | if random_color: 123 | color = ( 124 | random.randint(0, 255), 125 | random.randint(0, 255), 126 | random.randint(0, 255), 127 | 153, 128 | ) 129 | else: 130 | color = (30, 144, 255, 153) 131 | 132 | nonzero_coords = np.transpose(np.nonzero(mask)) 133 | 134 | for coord in nonzero_coords: 135 | draw.point(coord[::-1], fill=color) 136 | 137 | 138 | def build_annotation(boxes, mask): 139 | annotations = [] 140 | mask_coor = np.transpose(np.nonzero(mask)).astype(np.int32).tolist() 141 | for i, box in enumerate(boxes): 142 | # convert box from xyxy to xywh 143 | box = box.tolist() 144 | box[2] -= box[0] 145 | box[3] -= box[1] 146 | box = np.array(box).astype(np.int32).tolist() 147 | area = box[2] * box[3] 148 | annotation = { 149 | "id": i, 150 | "image_id": 0, 151 | "category_id": 0, 152 | "segmentation": [], 153 | "mask": mask_coor, 154 | "area": area, 155 | "bbox": box, 156 | "iscrowd": 0, 157 | } 158 | annotations.append(annotation) 159 | return json.dumps(dict(annotation=annotations)) 160 | 161 | 162 | def clean_input(): 163 | return [None] * 9 164 | 165 | 166 | def parse_visual_prompt(points: List): 167 | boxes = [] 168 | pos_points = [] 169 | neg_points = [] 170 | for point in points: 171 | if point[2] == 2 and point[-1] == 3: 172 | x1, y1, _, x2, y2, _ = point 173 | boxes.append([x1, y1, x2, y2]) 174 | elif point[2] == 1 and point[-1] == 4: 175 | x, y, _, _, _, _ = point 176 | pos_points.append([x, y]) 177 | elif point[2] == 0 and point[-1] == 4: 178 | x, y, _, _, _, _ = point 179 | neg_points.append([x, y]) 180 | return boxes, pos_points, neg_points 181 | 182 | 183 | def pack_model_input_interactive(interactive_input): 184 | ref_image = interactive_input["image"] 185 | ref_image = Image.fromarray(ref_image) 186 | ref_visual_prompt = interactive_input["points"] 187 | boxes, pos_points, neg_points = parse_visual_prompt(ref_visual_prompt) 188 | # boxes and points can not show at the same time 189 | if not len(boxes) > 0: 190 | raise gr.Error("You can't draw box. We do not support point prompt for now") 191 | if len(boxes) > 0: 192 | prompts = { 193 | "prompt_image": ref_image, 194 | "type": "rect", 195 | "prompts": [{"category_id": 1, "rects": boxes}], 196 | } 197 | interactions = [dict(type="rect", category_id=1, rect=box) for box in boxes] 198 | prompts = [ 199 | dict( 200 | image=ref_image, 201 | interactions=interactions, 202 | ) 203 | ] 204 | return prompts 205 | 206 | 207 | def pack_model_input_generic(generic_vp_dict): 208 | prompts = [] 209 | for k, v in generic_vp_dict.items(): 210 | if v is None: 211 | continue 212 | ref_image = v["image"] 213 | ref_image = Image.fromarray(ref_image) 214 | ref_visual_prompt = v["points"] 215 | boxes, pos_points, _ = parse_visual_prompt(ref_visual_prompt) 216 | # boxes and points can not show at the same time 217 | if not len(boxes) > 0: 218 | raise gr.Error("You can't draw box. We do not support point prompt for now") 219 | if len(boxes) > 0: 220 | interactions = [dict(type="rect", category_id=1, rect=box) for box in boxes] 221 | prompts.append(dict(image=ref_image, interactions=interactions)) 222 | return prompts 223 | 224 | 225 | def trex2_postprocess( 226 | target_image, 227 | trex2_results, 228 | visual_threshold, 229 | return_point, 230 | point_width, 231 | return_score, 232 | ): 233 | if isinstance(trex2_results, dict): 234 | trex2_results = [trex2_results] 235 | # filter based on visual threshold 236 | scores = np.array(trex2_results[0]["scores"]) 237 | boxes = np.array(trex2_results[0]["boxes"]) 238 | labels = np.array(trex2_results[0]["labels"]) 239 | filter_mask = scores > float(visual_threshold) 240 | boxes = boxes[filter_mask] 241 | labels = labels[filter_mask] 242 | scores = scores[filter_mask] 243 | trex2_results[0]["boxes"] = boxes 244 | trex2_results[0]["labels"] = labels 245 | trex2_results[0]["scores"] = scores 246 | image_with_box = plot_boxes_to_image( 247 | target_image, trex2_results[0], return_point, point_width, return_score 248 | )[0] 249 | visualization = np.array(image_with_box) 250 | mask = None 251 | return visualization, len(boxes), build_annotation(boxes, mask) 252 | 253 | 254 | def inference( 255 | target_image, 256 | interactive_input, 257 | generic_vp1, 258 | generic_vp2, 259 | generic_vp3, 260 | generic_vp4, 261 | generic_vp5, 262 | generic_vp6, 263 | generic_vp7, 264 | generic_vp8, 265 | visual_threshold, 266 | return_point, 267 | point_width, 268 | return_score, 269 | ): 270 | 271 | generic_vp_dict = { 272 | "1": generic_vp1, 273 | "2": generic_vp2, 274 | "3": generic_vp3, 275 | "4": generic_vp4, 276 | "5": generic_vp5, 277 | "6": generic_vp6, 278 | "7": generic_vp7, 279 | "8": generic_vp8, 280 | } 281 | if target_image is None: 282 | gr.Error("Please provide a target image") 283 | # tell if generic visual prompt is empty 284 | target_image = Image.fromarray(target_image) 285 | generic_is_empty = True 286 | for _, v in generic_vp_dict.items(): 287 | if v is not None: 288 | generic_is_empty = False 289 | break 290 | # We support: 291 | # 1. interactive visual prompt 292 | # 2. generic visual prompt 293 | if interactive_input is not None and generic_is_empty: 294 | prompts = pack_model_input_interactive(interactive_input) 295 | trex2_results = trex2.visual_prompt_inference(target_image, prompts)[0] 296 | elif interactive_input is None and not generic_is_empty: 297 | prompts = pack_model_input_generic(generic_vp_dict) 298 | trex2_results = trex2.visual_prompt_inference(target_image, prompts)[0] 299 | else: 300 | raise gr.Error( 301 | "You should provide either interactive visual prompt or generic visual prompt" 302 | ) 303 | visualization, num_count, coco_anno = trex2_postprocess( 304 | target_image, 305 | trex2_results, 306 | visual_threshold, 307 | return_point, 308 | point_width, 309 | return_score, 310 | ) 311 | # interactive only inference 312 | return visualization, num_count, coco_anno 313 | 314 | 315 | args = arg_parse() 316 | trex2 = TRex2APIWrapper(args.trex2_api_token) 317 | # args.device = 'cuda' if torch.cuda.is_available() else 'cpu' 318 | # sam = sam_model_registry['vit_l'](checkpoint=args.sam_checkpoint_path) 319 | # sam.to(device=args.device) 320 | # sam_predictor = SamPredictor(sam) 321 | 322 | if __name__ == "__main__": 323 | interactive_1 = ImagePrompter(label="1", scale=1) 324 | generic_vp1 = ImagePrompter(label="Generic Visual Prompt 1", scale=1) 325 | generic_vp2 = ImagePrompter(label="Generic Visual Prompt 2", scale=1) 326 | generic_vp3 = ImagePrompter(label="Generic Visual Prompt 3", scale=1) 327 | generic_vp4 = ImagePrompter(label="Generic Visual Prompt 4", scale=1) 328 | generic_vp5 = ImagePrompter(label="Generic Visual Prompt 5", scale=1) 329 | generic_vp6 = ImagePrompter(label="Generic Visual Prompt 6", scale=1) 330 | generic_vp7 = ImagePrompter(label="Generic Visual Prompt 7", scale=1) 331 | generic_vp8 = ImagePrompter(label="Generic Visual Prompt 8", scale=1) 332 | with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo: 333 | with gr.Row(): 334 | with gr.Column(): 335 | with gr.Row(): 336 | with gr.Column(): 337 | target_image = gr.Image(label="Input Target Image", width=300) 338 | with gr.Column(): 339 | with gr.Row(): 340 | return_point = gr.Checkbox(label="Return Point Anno") 341 | with gr.Row(): 342 | return_score = gr.Checkbox(label="Return Score") 343 | with gr.Row(): 344 | point_width = gr.Slider( 345 | label="Line/Point Width", 346 | value=5.0, 347 | minimum=0.0, 348 | maximum=20.0, 349 | step=0.01, 350 | ) 351 | with gr.Row(): 352 | output_image = gr.Image(label="Output Image", width=300) 353 | with gr.Row(): 354 | num_count = gr.Textbox( 355 | label="Counting Results", lines=1, show_copy_button=True 356 | ) 357 | with gr.Row(): 358 | coco_anno = gr.Textbox( 359 | label="COCO Results", 360 | lines=1, 361 | max_lines=4, 362 | show_copy_button=True, 363 | ) 364 | 365 | with gr.Column(): 366 | with gr.Row(): 367 | interactions = "LeftClick (Point Prompt) | PressMove (Box Prompt)" 368 | gr.Markdown( 369 | "

This is for interactive visual prompt

" 370 | ) 371 | gr.Markdown( 372 | "

[🖱️ | 🖐️]: 🌟🌟 {} 🌟🌟

".format( 373 | interactions 374 | ) 375 | ) 376 | with gr.Row(): 377 | interactive = gr.TabbedInterface( 378 | [interactive_1], ["Interactive Visual Prompt"] 379 | ) 380 | with gr.Row(): 381 | interactions = "LeftClick (Point Prompt) | PressMove (Box Prompt)" 382 | gr.Markdown( 383 | "

This is for generic visual prompt

" 384 | ) 385 | gr.Markdown( 386 | "

[🖱️ | 🖐️]: 🌟🌟 {} 🌟🌟

".format( 387 | interactions 388 | ) 389 | ) 390 | with gr.Row(): 391 | generic = gr.TabbedInterface( 392 | [ 393 | generic_vp1, 394 | generic_vp2, 395 | generic_vp3, 396 | generic_vp4, 397 | generic_vp5, 398 | generic_vp6, 399 | generic_vp7, 400 | generic_vp8, 401 | ], 402 | ["1", "2", "3", "4", "5", "6", "7", "8"], 403 | ) 404 | with gr.Row(): 405 | visual_threshold = gr.Slider( 406 | label="Visual Prompt Threshold", 407 | value=0.3, 408 | minimum=0.0, 409 | maximum=1.0, 410 | step=0.01, 411 | ) 412 | 413 | with gr.Row(): 414 | clean = gr.Button("Clean Inputs") 415 | infer = gr.Button("Run T-Rex2🦖🦖🦖") 416 | clean.click( 417 | fn=clean_input, 418 | outputs=[ 419 | interactive_1, 420 | generic_vp1, 421 | generic_vp2, 422 | generic_vp3, 423 | generic_vp4, 424 | generic_vp5, 425 | generic_vp6, 426 | generic_vp7, 427 | generic_vp8, 428 | ], 429 | ) 430 | infer.click( 431 | fn=inference, 432 | inputs=[ 433 | target_image, 434 | interactive_1, 435 | generic_vp1, 436 | generic_vp2, 437 | generic_vp3, 438 | generic_vp4, 439 | generic_vp5, 440 | generic_vp6, 441 | generic_vp7, 442 | generic_vp8, 443 | visual_threshold, 444 | return_point, 445 | point_width, 446 | return_score, 447 | ], 448 | outputs=[output_image, num_count, coco_anno], 449 | ) 450 | demo.launch() 451 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pydantic==2.10.6 2 | gradio==4.44.1 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import subprocess 4 | 5 | import torch 6 | from setuptools import find_packages, setup 7 | from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension 8 | 9 | version = "v1.0" 10 | package_name = "trex" 11 | cwd = os.path.dirname(os.path.abspath(__file__)) 12 | 13 | sha = "Unknown" 14 | try: 15 | sha = subprocess.check_output(["git", "rev-parse", "HEAD"], 16 | cwd=cwd).decode("ascii").strip() 17 | except Exception: 18 | pass 19 | 20 | 21 | def write_version_file(): 22 | version_path = os.path.join(cwd, "trex/", "version.py") 23 | with open(version_path, "w") as f: 24 | f.write(f"__version__ = '{version}'\n") 25 | # f.write(f"git_version = {repr(sha)}\n") 26 | 27 | 28 | def parse_requirements(fname="requirements.txt", with_version=True): 29 | """Parse the package dependencies listed in a requirements file but strips 30 | specific versioning information. 31 | 32 | Args: 33 | fname (str): path to requirements file 34 | with_version (bool, default=False): if True include version specs 35 | 36 | Returns: 37 | List[str]: list of requirements items 38 | 39 | CommandLine: 40 | python -c "import setup; print(setup.parse_requirements())" 41 | """ 42 | import re 43 | import sys 44 | from os.path import exists 45 | 46 | require_fpath = fname 47 | 48 | def parse_line(line): 49 | """Parse information from a line in a requirements text file.""" 50 | if line.startswith("-r "): 51 | # Allow specifying requirements in other files 52 | target = line.split(" ")[1] 53 | for info in parse_require_file(target): 54 | yield info 55 | else: 56 | info = {"line": line} 57 | if line.startswith("-e "): 58 | info["package"] = line.split("#egg=")[1] 59 | elif "@git+" in line: 60 | info["package"] = line 61 | else: 62 | # Remove versioning from the package 63 | pat = "(" + "|".join([">=", "==", ">"]) + ")" 64 | parts = re.split(pat, line, maxsplit=1) 65 | parts = [p.strip() for p in parts] 66 | 67 | info["package"] = parts[0] 68 | if len(parts) > 1: 69 | op, rest = parts[1:] 70 | if ";" in rest: 71 | # Handle platform specific dependencies 72 | # http://setuptools.readthedocs.io/en/latest/setuptools.html#declaring-platform-specific-dependencies 73 | version, platform_deps = map(str.strip, 74 | rest.split(";")) 75 | info["platform_deps"] = platform_deps 76 | else: 77 | version = rest # NOQA 78 | info["version"] = (op, version) 79 | yield info 80 | 81 | def parse_require_file(fpath): 82 | with open(fpath, "r") as f: 83 | for line in f.readlines(): 84 | line = line.strip() 85 | if line and not line.startswith("#"): 86 | for info in parse_line(line): 87 | yield info 88 | 89 | def gen_packages_items(): 90 | if exists(require_fpath): 91 | for info in parse_require_file(require_fpath): 92 | parts = [info["package"]] 93 | if with_version and "version" in info: 94 | parts.extend(info["version"]) 95 | if not sys.version.startswith("3.4"): 96 | # apparently package_deps are broken in 3.4 97 | platform_deps = info.get("platform_deps") 98 | if platform_deps is not None: 99 | parts.append(";" + platform_deps) 100 | item = "".join(parts) 101 | yield item 102 | 103 | packages = list(gen_packages_items()) 104 | return packages 105 | 106 | 107 | if __name__ == "__main__": 108 | print(f"Building wheel {package_name}-{version}") 109 | 110 | with open("LICENSE", "r", encoding="utf-8") as f: 111 | license = f.read() 112 | 113 | write_version_file() 114 | 115 | setup( 116 | name="trex", 117 | version="v1.0", 118 | author="International Digital Economy Academy, Qing Jiang", 119 | url="https://github.com/IDEA-Research/T-Rex", 120 | description="T-Rex2 API wrapper.", 121 | license=license, 122 | install_requires=parse_requirements("requirements.txt"), 123 | packages=find_packages(exclude=("tests", )), 124 | ext_modules=None, 125 | cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, 126 | ) 127 | -------------------------------------------------------------------------------- /trex/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_wrapper import TRex2APIWrapper 2 | from .visualize import visualize 3 | 4 | __all__ = ["TRex2APIWrapper", "visualize"] 5 | -------------------------------------------------------------------------------- /trex/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/trex/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /trex/__pycache__/model_wrapper.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/trex/__pycache__/model_wrapper.cpython-38.pyc -------------------------------------------------------------------------------- /trex/__pycache__/visualize.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/IDEA-Research/T-Rex/a1a8f8f8d5b694802efd721118b1f227c569ecd9/trex/__pycache__/visualize.cpython-38.pyc -------------------------------------------------------------------------------- /trex/model_wrapper.py: -------------------------------------------------------------------------------- 1 | import base64 2 | import tempfile 3 | import time 4 | from io import BytesIO 5 | from typing import Dict, List, Union 6 | 7 | import numpy as np 8 | import requests 9 | from PIL import Image 10 | 11 | 12 | def encode_image(image): 13 | """ 14 | Encodes an image to a base64 string. 15 | 16 | Args: 17 | image (str or PIL.Image.Image): The image to encode. 18 | - If str: should be a valid image file path. 19 | - If PIL.Image.Image: image will be encoded from memory. 20 | 21 | Returns: 22 | str: Base64-encoded image string. 23 | """ 24 | if isinstance(image, str): 25 | # Treat as file path 26 | with open(image, "rb") as image_file: 27 | return base64.b64encode(image_file.read()).decode("utf-8") 28 | 29 | elif isinstance(image, Image.Image): 30 | # Encode from in-memory PIL image 31 | buffer = BytesIO() 32 | image.save(buffer, format="JPEG") # You can change to PNG if needed 33 | return base64.b64encode(buffer.getvalue()).decode("utf-8") 34 | 35 | else: 36 | raise TypeError("Input must be a file path (str) or PIL.Image.Image object.") 37 | 38 | 39 | class TRex2APIWrapper: 40 | """API wrapper for T-Rex2 41 | 42 | Args: 43 | token (str): The token for T-Rex2 API. We are now opening free API access to T-Rex2. For 44 | educators, students, and researchers, we offer an API with extensive usage times to 45 | support your educational and research endeavors. Please send a request to this email 46 | address (weiliu@idea.edu.cn) and attach your usage purpose as well as your institution. 47 | """ 48 | 49 | def __init__(self, token: str): 50 | self.headers = {"Content-Type": "application/json", "Token": token} 51 | 52 | def call_api(self, task_dict): 53 | resp = requests.post( 54 | url="https://api.deepdataspace.com/v2/task/trex/detection", 55 | json=task_dict, 56 | headers=self.headers, 57 | ) 58 | json_resp = resp.json() 59 | if json_resp["msg"] != "ok": 60 | raise RuntimeError(f"API call failed with error: {json_resp}") 61 | task_uuid = json_resp["data"]["task_uuid"] 62 | 63 | while True: 64 | resp = requests.get( 65 | f"https://api.deepdataspace.com/v2/task_status/{task_uuid}", 66 | headers=self.headers, 67 | ) 68 | json_resp = resp.json() 69 | if json_resp["data"]["status"] not in ["waiting", "running"]: 70 | break 71 | time.sleep(1) 72 | 73 | if json_resp["data"]["status"] == "failed": 74 | raise RuntimeError(f"API call failed with error: {json_resp['msg']}") 75 | elif json_resp["data"]["status"] == "success": 76 | return json_resp 77 | 78 | def convert_embedding_prompt( 79 | self, target_image: Union[str, Image.Image], base64_embedding: str 80 | ): 81 | """Convert the prompt to the format required by the API""" 82 | target_image_base64 = encode_image(target_image) 83 | prompt = { 84 | "model": "T-Rex-2.0", 85 | "image": f"data:image/jpg;base64,{target_image_base64}", 86 | "targets": ["bbox"], 87 | "prompt": {"type": "embedding", "embedding": base64_embedding}, 88 | } 89 | return prompt 90 | 91 | def convert_visual_prompt( 92 | self, 93 | target_image: Union[str, Image.Image], 94 | prompts: List[Dict], 95 | return_type: List[str] = ["bbox"], 96 | ): 97 | """Convert the prompt to the format required by the API""" 98 | target_image_base64 = encode_image(target_image) 99 | 100 | for prompt in prompts: 101 | prompt["image"] = f"data:image/jpg;base64,{encode_image(prompt['image'])}" 102 | 103 | prompt = { 104 | "model": "T-Rex-2.0", 105 | "image": f"data:image/jpg;base64,{target_image_base64}", 106 | "targets": return_type, 107 | "prompt": {"type": "visual_images", "visual_images": prompts}, 108 | } 109 | 110 | return prompt 111 | 112 | def visual_prompt_inference( 113 | self, 114 | target_image: Union[str, Image.Image], 115 | prompt: List[Dict], 116 | return_type: List[str] = ["bbox"], 117 | ): 118 | """Visual prompt inference for both interactive and generic workflow. 119 | 120 | Args: 121 | target_image (Union[str, Image.Image]): The image to upload. Can be a file path or PIL.Image 122 | prompts (List[dict]): A list of prompt dict. Each dict is for one prompt image: 123 | # Box prompt 124 | [ 125 | { 126 | "image": (str or Image.Image): Prompt Image 1, 127 | "interactions": [ 128 | { 129 | "type": "rect", 130 | "category_id": 12, 131 | "rect": [159.78119507908616, 186.52658172231986, 337.2996485061512, 309.2963532513181], 132 | }, 133 | { 134 | "type": "rect", 135 | "category_id": 1, 136 | "rect": [159.78119507908616, 186.52658172231986, 337.2996485061512, 309.2963532513181], 137 | } 138 | ... # more prompt on current image 139 | ] 140 | }, 141 | { 142 | "image": (str or Image.Image): Prompt Image 2, 143 | "interactions": [ 144 | { 145 | "type": "rect", 146 | "category_id": 12, 147 | "rect": [159.78119507908616, 186.52658172231986, 337.2996485061512, 309.2963532513181], 148 | }, 149 | { 150 | "type": "rect", 151 | "category_id": 1, 152 | "rect": [159.78119507908616, 186.52658172231986, 337.2996485061512, 309.2963532513181], 153 | } 154 | ... # more prompt on current image 155 | ] 156 | } 157 | ... # more prompt image. 158 | ] 159 | # Point prompt 160 | [ 161 | { 162 | "image": (str or Image.Image): Prompt Image 1, 163 | "interactions": [ 164 | { 165 | "type": "point", 166 | "category_id": 12, 167 | "point": [159.78119507908616, 186.52658172231986], 168 | }, 169 | { 170 | "type": "point", 171 | "category_id": 1, 172 | "point": [159.78119507908616, 186.52658172231986], 173 | } 174 | ... # more prompt on current image 175 | ] 176 | }, 177 | { 178 | "image": (str or Image.Image): Prompt Image 2, 179 | "interactions": [ 180 | { 181 | "type": "point", 182 | "category_id": 12, 183 | "point": [159.78119507908616, 186.52658172231986], 184 | }, 185 | { 186 | "type": "point", 187 | "category_id": 1, 188 | "point": [159.78119507908616, 186.52658172231986], 189 | } 190 | ... # more prompt on current image 191 | ] 192 | } 193 | ... # more prompt image. 194 | ] 195 | return_type (List[str]): The type of return value. Currently only support "bbox" and "embedding". 196 | 197 | Returns: 198 | detection_result (Dict): Detection result in format: 199 | { 200 | "scores": (List[float]): A list of scores for each object in the batch 201 | "labels": (List[int]): A list of labels for each object in the batch 202 | "boxes": (List[List[int]]): A list of boxes for each object in the batch, 203 | in format [xmin, ymin, xmax, ymax] 204 | } 205 | base64_embedding (str): The base64 encoding of the embedding. Only available when 206 | "embedding" is in return_type, else None 207 | 208 | """ 209 | # Convert the interactive prompt to the format required by the API 210 | prompt = self.convert_visual_prompt(target_image, prompt, return_type) 211 | # call the API 212 | result = self.call_api(prompt) 213 | detection_result = self.postprocess(result["data"]["result"]["objects"]) 214 | if "embedding" in return_type: 215 | base64_embedding = result["data"]["result"]["embedding"] 216 | else: 217 | base64_embedding = None 218 | return detection_result, base64_embedding 219 | 220 | def embedding_inference( 221 | self, target_image: Union[str, Image.Image], base64_embedding: str 222 | ): 223 | """Prompt inference workflow. 224 | Args: 225 | target_image (Union[str, Image.Image]): The image to upload. Can be a file path or PIL.Image 226 | base64_embedding (str): The base64 encoding of the embedding. 227 | 228 | Returns: 229 | Dict: Return dict in format: 230 | { 231 | "scores": (torch.Tensor): Sigmoid logits in shape (batch_size, 900, num_classes), 232 | class order is the same as the order in the prompt 233 | "labels": (List[List[int]]): A list of list of labels for each batch image. 234 | "boxes": (torch.Tensor): Normalized prediction boxes in shape (batch_size, 900, 4), 235 | format is (xmin, ymin, ymin, ymax) 236 | } 237 | """ 238 | prompt = self.convert_embedding_prompt(target_image, base64_embedding) 239 | result = self.call_api(prompt) 240 | detection_result = self.postprocess(result["data"]["result"]["objects"]) 241 | return detection_result 242 | 243 | def postprocess(self, object_batches): 244 | """Postprocess the result from the API 245 | 246 | Args: 247 | object_batches (List[Dict]): List of Dicts. Each dict contains the prediction 248 | on each image. Each TRexObject contains the following keys: 249 | - category_id (int): The category id of the object 250 | - score (float): The score of the object 251 | - bbox (List[int]): The bounding box of the object in format [xmin, ymin, xmax, ymax] 252 | 253 | Returns: 254 | List[Dict]: Return a list of dict in format: 255 | { 256 | "scores": (List[float]): A list of scores for each object in the batch 257 | "labels": (List[int]): A list of labels for each object in the batch 258 | "boxes": (List[List[int]]): A list of boxes for each object in the batch 259 | } 260 | 261 | """ 262 | scores = [] 263 | labels = [] 264 | boxes = [] 265 | for obj in object_batches: 266 | score = obj["score"] 267 | category_id = obj["category_id"] 268 | bbox = obj["bbox"] 269 | scores.append(score) 270 | labels.append(category_id) 271 | boxes.append(bbox) 272 | return {"scores": scores, "labels": labels, "boxes": boxes} 273 | -------------------------------------------------------------------------------- /trex/version.py: -------------------------------------------------------------------------------- 1 | __version__ = 'v1.0' 2 | -------------------------------------------------------------------------------- /trex/visualize.py: -------------------------------------------------------------------------------- 1 | from typing import Dict 2 | 3 | import numpy as np 4 | from PIL import Image, ImageDraw, ImageFont 5 | 6 | 7 | def visualize(image_pil: Image, 8 | target: Dict, 9 | return_point: bool = False, 10 | draw_width: float = 6.0, 11 | random_color: bool = True, 12 | overwrite_color: Dict = None, 13 | agnostic_random_color: bool = False, 14 | draw_score=False, 15 | draw_label=True) -> Image: 16 | """Plot bounding boxes and labels on an image. 17 | 18 | Args: 19 | image_pil (PIL.Image): The input image as a PIL Image object. 20 | model_targetoutput (Dict[str, Union[torch.Tensor, List[torch.Tensor]]]): The target dictionary containing 21 | the bounding boxes and labels. The keys are: 22 | - boxes (List[int]): A list of bounding boxes in shape (N, 4), [x1, y1, x2, y2] format. 23 | - scores (List[float]): A list of scores for each bounding box. shape (N) 24 | - labels (List[str]): A list of string labels for each bounding box. shape (N) 25 | return_point (bool): Draw center point instead of bounding box. Defaults to False. 26 | draw_width (float): The width of the drawn bounding box or point. Defaults to 1.0. 27 | random_color (bool): Use random color for each category. Defaults to True. 28 | overwrite_color (Dict): Overwrite color for each category. Defaults to None. 29 | agnostic_random_color (bool): If True, we will use random color for all boxes. 30 | draw_score (bool): Draw score on the image. Defaults to False. 31 | 32 | Returns: 33 | Union[PIL.Image, PIL.Image]: A tuple containing the input image and ploted image. 34 | """ 35 | # Get the bounding boxes and labels from the target dictionary 36 | boxes = target["boxes"] 37 | scores = target["scores"] 38 | labels = target["labels"] 39 | 40 | label2color = {} 41 | if overwrite_color: 42 | label2color = overwrite_color 43 | else: 44 | # find all unique labels 45 | unique_labels = set(labels) 46 | # generate random color for each label 47 | for label in unique_labels: 48 | if random_color: 49 | label2color[str(label)] = tuple( 50 | np.random.randint(0, 255, size=3).tolist()) 51 | else: 52 | label2color[str(label)] = (255, 255, 255) 53 | 54 | # Create a PIL ImageDraw object to draw on the input image 55 | draw = ImageDraw.Draw(image_pil) 56 | # Create a new binary mask image with the same size as the input image 57 | mask = Image.new("L", image_pil.size, 0) 58 | # Create a PIL ImageDraw object to draw on the mask image 59 | mask_draw = ImageDraw.Draw(mask) 60 | 61 | # Draw boxes and masks for each box and label in the target dictionary 62 | for box, score, label in zip(boxes, scores, labels): 63 | # Convert the box coordinates from 0..1 to 0..W, 0..H 64 | score = score.item() 65 | if not agnostic_random_color: 66 | color = label2color[str(label)] 67 | else: 68 | color = tuple(np.random.randint(0, 255, size=3).tolist()) 69 | # Extract the box coordinates 70 | x0, y0, x1, y1 = box 71 | x0, y0, x1, y1 = int(x0), int(y0), int(x1), int(y1) 72 | if return_point: 73 | ceter_x = int((x0 + x1) / 2) 74 | ceter_y = int((y0 + y1) / 2) 75 | # Draw the center point on the input image 76 | draw.ellipse((ceter_x - draw_width, ceter_y - draw_width, 77 | ceter_x + draw_width, ceter_y + draw_width), 78 | fill=color, 79 | width=draw_width) 80 | else: 81 | # Draw the box outline on the input image 82 | draw.rectangle([x0, y0, x1, y1], 83 | outline=color, 84 | width=int(draw_width)) 85 | 86 | # Draw the label text on the input image 87 | if not draw_label: 88 | label = "" 89 | if draw_score: 90 | text = f"{label} {score:.2f}" 91 | else: 92 | text = f"{label}" 93 | font = ImageFont.load_default() 94 | if hasattr(font, "getbbox"): 95 | bbox = draw.textbbox((x0, y0), text, font) 96 | else: 97 | w, h = draw.textsize(text, font) 98 | bbox = (x0, y0, w + x0, y0 + h) 99 | if not return_point: 100 | draw.rectangle(bbox, fill=color) 101 | draw.text((x0, y0), text, fill="white") 102 | 103 | # Draw the box on the mask image 104 | mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6) 105 | return image_pil 106 | --------------------------------------------------------------------------------