├── .gitignore
├── LICENSE
├── README.md
├── a photo of a brown giraffe and a white stop sign.png
├── app_deqa.py
├── app_geneval.py
├── gunicorn.conf.py
├── requirements.txt
├── reward_server
│   ├── deqa.py
│   ├── gen_eval.py
│   └── object_names.txt
└── test
    ├── test_deqa.py
    └── test_geneval.py

/.gitignore:
--------------------------------------------------------------------------------
# Created by https://www.toptal.com/developers/gitignore/api/visualstudiocode,pycharm+all,python,vim,macos,linux
# Edit at https://www.toptal.com/developers/gitignore?templates=visualstudiocode,pycharm+all,python,vim,macos,linux

### Linux ###
*~

# temporary files which can be created if a process still has a handle open of a deleted file
.fuse_hidden*

# KDE directory preferences
.directory

# Linux trash folder which might appear on any partition or disk
.Trash-*

# .nfs files are created when an open file is removed but is still being accessed
.nfs*

### macOS ###
# General
.DS_Store
.AppleDouble
.LSOverride

# Icon must end with two \r
Icon


# Thumbnails
._*

# Files that might appear in the root of a volume
.DocumentRevisions-V100
.fseventsd
.Spotlight-V100
.TemporaryItems
.Trashes
.VolumeIcon.icns
.com.apple.timemachine.donotpresent

# Directories potentially created on remote AFP share
.AppleDB
.AppleDesktop
Network Trash Folder
Temporary Items
.apdisk

### macOS Patch ###
# iCloud generated files
*.icloud

### PyCharm+all ###
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839

# User-specific stuff
.idea/**/workspace.xml
.idea/**/tasks.xml
.idea/**/usage.statistics.xml
.idea/**/dictionaries
.idea/**/shelf

# AWS User-specific
.idea/**/aws.xml

# Generated files
.idea/**/contentModel.xml

# Sensitive or high-churn files
.idea/**/dataSources/
.idea/**/dataSources.ids
.idea/**/dataSources.local.xml
.idea/**/sqlDataSources.xml
.idea/**/dynamic.xml
.idea/**/uiDesigner.xml
.idea/**/dbnavigator.xml

# Gradle
.idea/**/gradle.xml
.idea/**/libraries

# Gradle and Maven with auto-import
# When using Gradle or Maven with auto-import, you should exclude module files,
# since they will be recreated, and may cause churn. Uncomment if using
# auto-import.
# .idea/artifacts
# .idea/compiler.xml
# .idea/jarRepositories.xml
# .idea/modules.xml
# .idea/*.iml
# .idea/modules
# *.iml
# *.ipr

# CMake
cmake-build-*/

# Mongo Explorer plugin
.idea/**/mongoSettings.xml

# File-based project format
*.iws

# IntelliJ
out/

# mpeltonen/sbt-idea plugin
.idea_modules/

# JIRA plugin
atlassian-ide-plugin.xml

# Cursive Clojure plugin
.idea/replstate.xml

# SonarLint plugin
.idea/sonarlint/

# Crashlytics plugin (for Android Studio and IntelliJ)
com_crashlytics_export_strings.xml
crashlytics.properties
crashlytics-build.properties
fabric.properties

# Editor-based Rest Client
.idea/httpRequests

# Android studio 3.1+ serialized cache file
.idea/caches/build_file_checksums.ser

### PyCharm+all Patch ###
# Ignore everything but code style settings and run configurations
# that are supposed to be shared within teams.

.idea/*

!.idea/codeStyles
!.idea/runConfigurations

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

### Python Patch ###
# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
poetry.toml

# ruff
.ruff_cache/

# LSP config files
pyrightconfig.json

### Vim ###
# Swap
[._]*.s[a-v][a-z]
!*.svg  # comment out if you don't need vector files
[._]*.sw[a-p]
[._]s[a-rt-v][a-z]
[._]ss[a-gi-z]
[._]sw[a-p]

# Session
Session.vim
Sessionx.vim

# Temporary
.netrwhist
# Auto-generated tag files
tags
# Persistent undo
[._]*.un~

### VisualStudioCode ###
.vscode/*
!.vscode/settings.json
!.vscode/tasks.json
!.vscode/launch.json
!.vscode/extensions.json
!.vscode/*.code-snippets

# Local History for Visual Studio Code
.history/

# Built Visual Studio Code Extensions
*.vsix

### VisualStudioCode Patch ###
# Ignore all local history of files
.history
.ionide

# End of https://www.toptal.com/developers/gitignore/api/visualstudiocode,pycharm+all,python,vim,macos,linux
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 Jie Liu

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and
to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# reward-server

Serves reward-model inference over HTTP.

## Install

### GenEval

```bash
# Create the conda environment
conda create -n reward_server python=3.10.16
# Install PyTorch (CUDA 12.1 builds)
pip install torch==2.1.2+cu121 torchvision==0.16.2+cu121 torchaudio==2.1.2+cu121 --index-url https://download.pytorch.org/whl/cu121
# Install the remaining dependencies
pip install -r requirements.txt
```

Then install mmdet:

```bash
mim install mmcv-full mmengine
git clone https://github.com/open-mmlab/mmdetection.git
cd mmdetection; git checkout 2.x
# Modify mmdet/__init__.py: set mmcv_maximum_version = '2.3.0'
pip install -e .
```

Then download the Mask2Former checkpoint (`$1` below is a placeholder for your checkpoint directory; substitute your own path):

```bash
wget https://download.openmmlab.com/mmdetection/v2.0/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco_20220504_001756-743b7d99.pth -O "$1/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.pth"
```

Set `MY_CONFIG_PATH` and `MY_CKPT_PATH` in `reward-server/reward_server/gen_eval.py` to your own paths (the mmdetection config file and the directory containing the downloaded checkpoint, respectively).

## Usage

### GenEval

Start the server side:

```bash
cd reward-server/
conda deactivate
conda activate reward_server
gunicorn "app_geneval:create_app()"
```

Set `NUM_DEVICES` in `gunicorn.conf.py` to the number of GPUs you want to use; the server spawns one worker per GPU and pins each worker to its own device via `CUDA_VISIBLE_DEVICES`.

After starting, you can run the client for testing:

```bash
python test/test_geneval.py
```

### DeQA

If you hit errors, refer to [DeQA](https://github.com/zhiyuanyou/DeQA-Score) to install DeQA's dependencies.
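Both servers speak the same pickle-over-HTTP protocol: the client POSTs a pickled dict and receives a pickled dict back. For DeQA, the request carries `"images"` (a list of JPEG-encoded bytes) and the response carries `"outputs"` (one float quality score per image). Below is a minimal client sketch, not part of the repo's test scripts; it assumes a server is already up (started as shown next) with `gunicorn.conf.py` switched to `reward='deqa'`, which binds `127.0.0.1:18086`:

```python
import io
import pickle

import requests
from PIL import Image

# Re-encode the image as JPEG bytes, as the server expects.
buffer = io.BytesIO()
Image.open("a photo of a brown giraffe and a white stop sign.png").convert("RGB").save(buffer, format="JPEG")

payload = pickle.dumps({"images": [buffer.getvalue()]})
# Assumes the default DeQA bind address from gunicorn.conf.py.
response = requests.post("http://127.0.0.1:18086", data=payload)
response.raise_for_status()

# The response is a pickled dict with one score per image.
print(pickle.loads(response.content)["outputs"])
```

Pickle is convenient here because both ends are trusted Python processes, but do not expose the port publicly: unpickling untrusted data can execute arbitrary code.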
Start the server side:

```bash
cd reward-server/
conda deactivate
conda activate reward_server
gunicorn "app_deqa:create_app()"
```

Set `reward='deqa'` in `gunicorn.conf.py` first so the server binds the DeQA port (18086).

After starting, you can run the client for testing:

```bash
python test/test_deqa.py
```
--------------------------------------------------------------------------------
/a photo of a brown giraffe and a white stop sign.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifan123/reward-server/92ac29fb322c028f8840db8f93c950422931710d/a photo of a brown giraffe and a white stop sign.png
--------------------------------------------------------------------------------
/app_deqa.py:
--------------------------------------------------------------------------------
from PIL import Image
from io import BytesIO
import pickle
import traceback
from reward_server.deqa import load_deqascore

from flask import Flask, request, Blueprint

root = Blueprint("root", __name__)

def create_app():
    global INFERENCE_FN
    INFERENCE_FN = load_deqascore()

    app = Flask(__name__)
    app.register_blueprint(root)
    return app

@root.route("/", methods=["POST"])
def inference():
    print(f"received POST request from {request.remote_addr}")
    data = request.get_data()

    try:
        # expects a pickled dict with "images"
        # images: (batch_size,) of JPEG bytes
        data = pickle.loads(data)

        images = [Image.open(BytesIO(d), formats=["jpeg"]) for d in data["images"]]

        print(f"Got {len(images)} images")

        outputs = INFERENCE_FN(images)

        response = {"outputs": outputs}

        # returns: a pickled dict with "outputs"
        # outputs: (batch_size,) of float quality scores
        response = pickle.dumps(response)

        returncode = 200
    except Exception:
        response = traceback.format_exc()
        print(response)
        response = response.encode("utf-8")
        returncode = 500

    return response, returncode


# Used only when running this module directly; under gunicorn the bind
# address comes from gunicorn.conf.py.
HOST = "127.0.0.1"
PORT = 8085

if __name__ == "__main__":
    create_app().run(HOST, PORT)
--------------------------------------------------------------------------------
/app_geneval.py:
--------------------------------------------------------------------------------
from PIL import Image
from io import BytesIO
import pickle
import traceback
from reward_server.gen_eval import load_geneval

from flask import Flask, request, Blueprint

root = Blueprint("root", __name__)

def create_app():
    global INFERENCE_FN
    INFERENCE_FN = load_geneval()

    app = Flask(__name__)
    app.register_blueprint(root)
    return app

@root.route("/", methods=["POST"])
def inference():
    print(f"received POST request from {request.remote_addr}")
    data = request.get_data()

    try:
        # expects a pickled dict with "images", "meta_datas", and "only_strict"
        # images: (batch_size,) of JPEG bytes
        # meta_datas: (batch_size,) of GenEval metadata dicts
        # only_strict: bool; when True, skip the relaxed `evaluate` pass
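        # Example payload, mirroring test/test_geneval.py (a sketch, not the only
        # valid shape; "include"/"exclude" clauses follow the GenEval metadata format):
        #   {"images": [jpeg_bytes],
        #    "meta_datas": [{"tag": "color_attr",
        #                    "include": [{"class": "giraffe", "count": 1, "color": "brown"},
        #                                {"class": "stop sign", "count": 1, "color": "white"}],
        #                    "prompt": "a photo of a brown giraffe and a white stop sign"}],
        #    "only_strict": False}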
        data = pickle.loads(data)

        images = [Image.open(BytesIO(d), formats=["jpeg"]) for d in data["images"]]
        meta_datas = data["meta_datas"]
        only_strict = data["only_strict"]

        print(f"Got {len(images)} images")

        scores, rewards, strict_rewards, group_rewards, group_strict_rewards = INFERENCE_FN(images, meta_datas, only_strict)

        response = {"scores": scores, "rewards": rewards, "strict_rewards": strict_rewards, "group_rewards": group_rewards, "group_strict_rewards": group_strict_rewards}

        # returns: a pickled dict with "scores", "rewards", "strict_rewards",
        # "group_rewards", and "group_strict_rewards"
        # scores: (batch_size,) of floats (soft per-clause reward, roughly in [0, 1])
        # rewards / strict_rewards: (batch_size,) of 0.0/1.0 correctness flags
        # group_*: per-task-tag dicts of the same rewards (-10.0 for other tags)
        response = pickle.dumps(response)

        returncode = 200
    except Exception:
        response = traceback.format_exc()
        print(response)
        response = response.encode("utf-8")
        returncode = 500

    return response, returncode


# Used only when running this module directly; under gunicorn the bind
# address comes from gunicorn.conf.py.
HOST = "127.0.0.1"
PORT = 8085

if __name__ == "__main__":
    create_app().run(HOST, PORT)
--------------------------------------------------------------------------------
/gunicorn.conf.py:
--------------------------------------------------------------------------------
import os

NUM_DEVICES = 8  # number of GPUs; one gunicorn worker is spawned per device
USED_DEVICES = set()

# reward = 'deqa'
reward = 'geneval'
if reward == 'deqa':
    port = 18086
if reward == 'geneval':
    port = 18085

def pre_fork(server, worker):
    # runs on server: assign the lowest free GPU id to the new worker
    global USED_DEVICES
    worker.device_id = next(i for i in range(NUM_DEVICES) if i not in USED_DEVICES)
    USED_DEVICES.add(worker.device_id)

def post_fork(server, worker):
    # runs on worker: pin this process to its assigned GPU
    os.environ["CUDA_VISIBLE_DEVICES"] = str(worker.device_id)

def child_exit(server, worker):
    # runs on server: release the GPU of the exited worker
    global USED_DEVICES
    USED_DEVICES.remove(worker.device_id)

# Gunicorn Configuration
bind = f"127.0.0.1:{port}"
# for cross-node access
# bind = "0.0.0.0:18085"
workers = NUM_DEVICES
worker_class = "sync"
timeout = 120
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
absl-py==2.1.0
addict==2.4.0
aiofiles==23.2.1
aiohappyeyeballs==2.5.0
aiohttp==3.11.13
aiohttp-cors==0.8.0
aiosignal==1.3.2
airportsdata==20250224
albucore==0.0.13
albumentations==1.4.10
aliyun-python-sdk-core==2.16.0
aliyun-python-sdk-kms==2.16.5
altair==5.5.0
annotated-types==0.7.0
anthropic==0.49.0
anyio==4.8.0
astor==0.8.1
asttokens==3.0.0
async-timeout==5.0.1
attrs==25.1.0
av==14.2.0
beautifulsoup4==4.13.3
bitsandbytes==0.45.3
blake3==1.0.4
blinker==1.9.0
braceexpand==0.1.7
cachetools==5.5.2
charset-normalizer==3.4.1
click==8.1.8
clip
clip-benchmark==1.6.1
cloudpickle==3.1.1
colorama==0.4.6
colorful==0.5.6
compressed-tensors==0.9.1
contourpy==1.3.1
crcmod==1.7
cryptography==44.0.2
cuda-bindings==12.8.0
cuda-python==12.8.0
cycler==0.12.1
Cython==3.0.12
datasets
decorator==5.1.1
decord==0.6.0
deepspeed==0.16.4
depyf==0.18.0
diffusers==0.33.1
dill==0.3.8
diskcache==5.6.3
distlib==0.3.9
distro==1.9.0
docker-pycreds==0.4.0
docstring-parser==0.16
einops==0.8.1
einops-exts==0.0.4
exceptiongroup==1.2.2
executing==2.2.0
fairscale==0.4.13
fastapi==0.115.11
ffmpy==0.5.0
filelock==3.14.0
fire==0.7.0
flash-attn==2.7.4.post1
flashinfer-python==0.2.3
flask==3.1.0
fonttools==4.56.0
frozenlist==1.5.0
fsspec==2024.12.0
ftfy==6.3.1
gguf==0.10.0
gitdb==4.0.12
gitpython==3.1.44
google-api-core==2.24.2
google-auth==2.38.0
googleapis-common-protos==1.69.2
grpcio==1.71.0
gunicorn==23.0.0
h11==0.14.0
hf-transfer==0.1.9
hjson==3.1.0
httpcore==1.0.7
httptools==0.6.4
httpx==0.28.1
huggingface-hub==0.29.1
icecream==2.1.4
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
imgaug==0.4.0
importlib-metadata==8.6.1
importlib-resources==6.5.2
inflect==6.0.4
intel-openmp==2023.1.0
interegular==0.3.3
ipython==8.34.0
itsdangerous==2.2.0
jedi==0.19.2
jinja2==3.1.5
jiter==0.8.2
jmespath==0.10.0
joblib==1.4.2
kiwisolver==1.4.8
lark==1.2.2
latex2mathml==3.77.0
lazy-loader==0.4
Levenshtein==0.26.1
litellm==1.63.12
llguidance==0.7.4
lm-format-enforcer==0.10.11
lmdb==1.6.2
lxml==5.3.1
markdown==3.7
markdown-it-py==3.0.0
markdown2==2.5.3
markupsafe==2.1.5
matplotlib==3.10.0
matplotlib-inline==0.1.7
mdurl==0.1.2
mistral-common==1.5.4
ml-collections==1.0.0
model-index==0.1.11
modelscope==1.24.0
mpmath==1.3.0
msgpack==1.1.0
msgspec==0.19.0
multidict==6.1.0
multiprocess==0.70.16
narwhals==1.30.0
nest-asyncio==1.6.0
networkx==3.4.2
ninja==1.11.1.3
numpy==1.26.0
nvidia-cublas-cu12==12.4.5.8
nvidia-cuda-cupti-cu12==12.4.127
nvidia-cuda-nvrtc-cu12==12.4.127
nvidia-cuda-runtime-cu12==12.4.127
nvidia-cudnn-cu12==9.1.0.70
nvidia-cufft-cu12==11.2.1.3
nvidia-curand-cu12==10.3.5.147
nvidia-cusolver-cu12==11.6.1.9
nvidia-cusparse-cu12==12.3.1.170
nvidia-cusparselt-cu12==0.6.2
nvidia-ml-py==12.570.86
nvidia-nccl-cu12==2.21.5
nvidia-nvjitlink-cu12==12.4.127
nvidia-nvtx-cu12==12.4.127
open-clip-torch==2.31.0
openai
opencensus==0.11.4
opencensus-context==0.1.3
opencv-contrib-python==4.11.0.86
opencv-python==4.11.0.86
opencv-python-headless==4.11.0.86
opendatalab==0.0.10
openmim==0.3.9
openxlab==0.1.2
opt-einsum==3.3.0
ordered-set==4.1.0
orjson==3.10.15
oss2==2.17.0
outlines==0.1.11
outlines-core==0.1.26
packaging==24.2
paddleocr==2.9.1
paddlepaddle-gpu==2.6.2
pandas==2.2.3
parso==0.8.4
partial-json-parser==0.2.1.1.post5
peft==0.10.0
pexpect==4.9.0
Pillow==10.4.0
platformdirs==4.3.6
prometheus-client==0.21.1
prometheus-fastapi-instrumentator==7.1.0
prompt-toolkit==3.0.50
propcache==0.3.0
proto-plus==1.26.1
protobuf
psutil==7.0.0
ptyprocess==0.7.0
pure-eval==0.2.3
py-cpuinfo==9.0.0
py-spy==0.4.0
pyarrow==19.0.1
pyasn1==0.6.1
pyasn1-modules==0.4.1
pyclipper==1.3.0.post6
pycocoevalcap==1.2
pycocotools==2.0.8
pycountry==24.6.1
pycparser==2.22
pycryptodome==3.22.0
pydantic==2.10.6
pydantic-core==2.27.2
pydub==0.25.1
pygments==2.19.1
pyparsing==3.2.1
python-dateutil==2.9.0.post0
python-docx==1.1.2
python-dotenv==1.0.1
python-Levenshtein==0.26.1
python-multipart==0.0.20
pytorch-lightning==2.5.1
pytz==2023.4
pyzmq==26.3.0
qwen-vl-utils==0.0.10
rapidfuzz==3.12.1
ray==2.43.0
referencing==0.36.2
regex==2024.11.6
requests==2.28.2
rich==13.4.2
rpds-py==0.23.1
rsa==4.9
safetensors==0.5.2
scikit-image==0.25.2
scikit-learn==1.6.1
scipy==1.15.2
semantic-version==2.10.0
sentencepiece==0.2.0
sentry-sdk==2.22.0
setproctitle==1.3.4
setuptools==60.2.0
sgl-kernel==0.0.5
sglang==0.4.4.post1
shapely==2.0.7
shortuuid==1.0.13
shtab==1.7.1
six==1.17.0
smart-open==7.1.0
smmap==5.0.2
sniffio==1.3.1
soupsieve==2.6
stack_data==0.6.3
starlette==0.46.1
svgwrite==1.4.3
sympy==1.13.1
tabulate==0.9.0
termcolor==2.5.0
terminaltables==3.1.10
threadpoolctl==3.5.0
tifffile==2025.2.18
tiktoken==0.9.0
timm==1.0.15
tokenizers==0.13.3
tomli==2.2.1
torchao==0.9.0
torchmetrics==1.7.1
tqdm
traitlets==5.14.3
transformers==4.28.0
triton==2.1.0
typeguard==4.4.2
typing_extensions==4.12.2
tyro==0.9.16
tzdata==2025.1
urllib3==1.26.20
uvicorn==0.34.0
uvloop==0.21.0
virtualenv==20.29.3
wandb==0.18.7
watchfiles==1.0.4
wcwidth==0.2.13
webdataset==0.2.111
websockets==11.0.3
werkzeug==3.1.3
wrapt==1.17.2
xformers
xgrammar
xxhash==3.5.0
yapf==0.43.0
yarl==1.18.3
zipp==3.21.0
--------------------------------------------------------------------------------
/reward_server/deqa.py:
--------------------------------------------------------------------------------
import torch
from transformers import AutoModelForCausalLM

def load_deqascore():
    model = AutoModelForCausalLM.from_pretrained(
        "zhiyuanyou/DeQA-Score-Mix3",
        trust_remote_code=True,
        attn_implementation="eager",
        torch_dtype=torch.float16,
        device_map=None,
    ).cuda()
    model.requires_grad_(False)

    @torch.no_grad()
    def compute_deqascore(images):
        score = model.score(images)
        # rescale the DeQA quality score (a 1-5 mean-opinion-score scale) to at most 1
        score = score / 5
        score = [sc.item() for sc in score]
        return score

    return compute_deqascore
--------------------------------------------------------------------------------
/reward_server/gen_eval.py:
--------------------------------------------------------------------------------
import json
import os
import sys
import time

import warnings
warnings.filterwarnings("ignore")

import numpy as np
from collections import defaultdict
from PIL import Image, ImageOps
import torch
import mmdet
from mmdet.apis import inference_detector, init_detector

import open_clip
from clip_benchmark.metrics import zeroshot_classification as zsc
zsc.tqdm = lambda it, *args, **kwargs: it

DEVICE = "cuda"
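# NOTE: these two paths are machine-specific; point MY_CONFIG_PATH at your
# mmdetection Mask2Former config and MY_CKPT_PATH at the directory holding the
# downloaded checkpoint (see the README). Because MY_CONFIG_PATH is absolute,
# the os.path.join against the mmdet package directory in load_models() below
# simply returns it unchanged.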
MY_CONFIG_PATH = "/m2v_intern/liujie/research/mmdetection/configs/mask2former/mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco.py"
MY_CKPT_PATH = "/m2v_intern/liujie/research/geneval/model/mask2former2"

def load_geneval():
    def timed(fn):
        def wrapper(*args, **kwargs):
            startt = time.time()
            result = fn(*args, **kwargs)
            endt = time.time()
            print(f'Function {fn.__name__!r} executed in {endt - startt:.3f}s', file=sys.stderr)
            return result
        return wrapper

    # Load models

    @timed
    def load_models():
        CONFIG_PATH = os.path.join(
            os.path.dirname(mmdet.__file__),
            MY_CONFIG_PATH
        )
        OBJECT_DETECTOR = "mask2former_swin-s-p4-w7-224_lsj_8x2_50e_coco"
        CKPT_PATH = os.path.join(MY_CKPT_PATH, f"{OBJECT_DETECTOR}.pth")
        object_detector = init_detector(CONFIG_PATH, CKPT_PATH, device=DEVICE)

        clip_arch = "ViT-L-14"
        clip_model, _, transform = open_clip.create_model_and_transforms(clip_arch, pretrained="openai", device=DEVICE)
        tokenizer = open_clip.get_tokenizer(clip_arch)

        with open(os.path.join(os.getcwd(), "reward_server/object_names.txt")) as cls_file:
            classnames = [line.strip() for line in cls_file]

        return object_detector, (clip_model, transform, tokenizer), classnames


    COLORS = ["red", "orange", "yellow", "green", "blue", "purple", "pink", "brown", "black", "white"]
    COLOR_CLASSIFIERS = {}

    # Evaluation parts

    class ImageCrops(torch.utils.data.Dataset):
        def __init__(self, image: Image.Image, objects):
            self._image = image.convert("RGB")
            bgcolor = "#999"
            if bgcolor == "original":
                self._blank = self._image.copy()
            else:
                self._blank = Image.new("RGB", image.size, color=bgcolor)
            self._objects = objects

        def __len__(self):
            return len(self._objects)

        def __getitem__(self, index):
            box, mask = self._objects[index]
            if mask is not None:
                assert tuple(self._image.size[::-1]) == tuple(mask.shape), (index, self._image.size[::-1], mask.shape)
                image = Image.composite(self._image, self._blank, Image.fromarray(mask))
            else:
                image = self._image
            image = image.crop(box[:4])
            return (transform(image), 0)


    def color_classification(image, bboxes, classname):
        if classname not in COLOR_CLASSIFIERS:
            COLOR_CLASSIFIERS[classname] = zsc.zero_shot_classifier(
                clip_model, tokenizer, COLORS,
                [
                    f"a photo of a {{c}} {classname}",
                    f"a photo of a {{c}}-colored {classname}",
                    f"a photo of a {{c}} object"
                ],
                DEVICE
            )
        clf = COLOR_CLASSIFIERS[classname]
        dataloader = torch.utils.data.DataLoader(
            ImageCrops(image, bboxes),
            batch_size=16, num_workers=4
        )
        with torch.no_grad():
            pred, _ = zsc.run_classification(clip_model, clf, dataloader, DEVICE)
        return [COLORS[index.item()] for index in pred.argmax(1)]


    def compute_iou(box_a, box_b):
        area_fn = lambda box: max(box[2] - box[0] + 1, 0) * max(box[3] - box[1] + 1, 0)
        i_area = area_fn([
            max(box_a[0], box_b[0]), max(box_a[1], box_b[1]),
            min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
        ])
        u_area = area_fn(box_a) + area_fn(box_b) - i_area
        return i_area / u_area if u_area else 0


    def relative_position(obj_a, obj_b):
        """Give position of A relative to B, factoring in object dimensions"""
        boxes = np.array([obj_a[0], obj_b[0]])[:, :4].reshape(2, 2, 2)
        center_a, center_b = boxes.mean(axis=-2)
        dim_a, dim_b = np.abs(np.diff(boxes, axis=-2))[..., 0, :]
        offset = center_a - center_b
        #
        revised_offset = np.maximum(np.abs(offset) - POSITION_THRESHOLD * (dim_a + dim_b), 0) * np.sign(offset)
        if np.all(np.abs(revised_offset) < 1e-3):
            return set()
        #
        dx, dy = revised_offset / np.linalg.norm(offset)
        relations = set()
        if dx < -0.5: relations.add("left of")
        if dx > 0.5: relations.add("right of")
        if dy < -0.5: relations.add("above")
        if dy > 0.5: relations.add("below")
        return relations


    def evaluate(image, objects, metadata):
        """
        Evaluate given image using detected objects on the global metadata specifications.
        Assumptions:
        * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
        * All clauses are independent, i.e., duplicating a clause has no effect on the correctness
        * CHANGED: Color and position will only be evaluated on the most confidently predicted objects;
          therefore, objects are expected to appear in sorted order
        """
        correct = True
        reason = []
        matched_groups = []
        # Check for expected objects
        for req in metadata.get('include', []):
            classname = req['class']
            matched = True
            found_objects = objects.get(classname, [])[:req['count']]
            if len(found_objects) < req['count']:
                correct = matched = False
                reason.append(f"expected {classname}>={req['count']}, found {len(found_objects)}")
            else:
                if 'color' in req:
                    # Color check
                    colors = color_classification(image, found_objects, classname)
                    if colors.count(req['color']) < req['count']:
                        correct = matched = False
                        reason.append(
                            f"expected {req['color']} {classname}>={req['count']}, found " +
                            f"{colors.count(req['color'])} {req['color']}; and " +
                            ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                        )
                if 'position' in req and matched:
                    # Relative position check
                    expected_rel, target_group = req['position']
                    if matched_groups[target_group] is None:
                        correct = matched = False
                        reason.append(f"no target for {classname} to be {expected_rel}")
                    else:
                        for obj in found_objects:
                            for target_obj in matched_groups[target_group]:
                                true_rels = relative_position(obj, target_obj)
                                if expected_rel not in true_rels:
                                    correct = matched = False
                                    reason.append(
                                        f"expected {classname} {expected_rel} target, found " +
                                        f"{' and '.join(true_rels)} target"
                                    )
                                    break
                            if not matched:
                                break
            if matched:
                matched_groups.append(found_objects)
            else:
                matched_groups.append(None)
        # Check for non-expected objects
        for req in metadata.get('exclude', []):
            classname = req['class']
            if len(objects.get(classname, [])) >= req['count']:
                correct = False
                reason.append(f"expected {classname}<{req['count']}, found {len(objects[classname])}")
        return correct, "\n".join(reason)

    def evaluate_reward(image, objects, metadata):
        """
        Evaluate given image using detected objects on the global metadata specifications.
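        Unlike `evaluate`, this also returns a soft reward: every 'include' clause
        contributes a count term 1 - |expected - found| / expected, plus a similar
        term for the color count and a 0/1 term for the position check when those
        attributes are requested; the final reward is the mean over all terms.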
        Assumptions:
        * Metadata combines 'include' clauses with AND, and 'exclude' clauses with OR
        * All clauses are independent, i.e., duplicating a clause has no effect on the correctness
        * CHANGED: Color and position will only be evaluated on the most confidently predicted objects;
          therefore, objects are expected to appear in sorted order
        """
        correct = True
        reason = []
        rewards = []
        matched_groups = []
        # Check for expected objects
        for req in metadata.get('include', []):
            classname = req['class']
            matched = True
            found_objects = objects.get(classname, [])
            # partial credit: how close the detected count is to the requested count
            rewards.append(1 - abs(req['count'] - len(found_objects)) / req['count'])
            if len(found_objects) != req['count']:
                correct = matched = False
                reason.append(f"expected {classname}=={req['count']}, found {len(found_objects)}")
                if 'color' in req or 'position' in req:
                    rewards.append(0.0)
            else:
                if 'color' in req:
                    # Color check
                    colors = color_classification(image, found_objects, classname)
                    rewards.append(1 - abs(req['count'] - colors.count(req['color'])) / req['count'])
                    if colors.count(req['color']) != req['count']:
                        correct = matched = False
                        reason.append(
                            f"expected {req['color']} {classname}=={req['count']}, found " +
                            f"{colors.count(req['color'])} {req['color']}; and " +
                            ", ".join(f"{colors.count(c)} {c}" for c in COLORS if c in colors)
                        )
                if 'position' in req and matched:
                    # Relative position check
                    expected_rel, target_group = req['position']
                    if matched_groups[target_group] is None:
                        correct = matched = False
                        reason.append(f"no target for {classname} to be {expected_rel}")
                        rewards.append(0.0)
                    else:
                        for obj in found_objects:
                            for target_obj in matched_groups[target_group]:
                                true_rels = relative_position(obj, target_obj)
                                if expected_rel not in true_rels:
                                    correct = matched = False
                                    reason.append(
                                        f"expected {classname} {expected_rel} target, found " +
                                        f"{' and '.join(true_rels)} target"
                                    )
                                    rewards.append(0.0)
                                    break
                            if not matched:
                                break
                        if matched:
                            # full credit only when every object satisfied the relation
                            rewards.append(1.0)
            if matched:
                matched_groups.append(found_objects)
            else:
                matched_groups.append(None)
        reward = sum(rewards) / len(rewards) if rewards else 0
        return correct, reward, "\n".join(reason)

    def evaluate_image(image_pils, metadatas, only_strict):
        results = inference_detector(object_detector, [np.array(image_pil) for image_pil in image_pils])
        ret = []
        for result, image_pil, metadata in zip(results, image_pils, metadatas):
            bbox = result[0] if isinstance(result, tuple) else result
            segm = result[1] if isinstance(result, tuple) and len(result) > 1 else None
            image = ImageOps.exif_transpose(image_pil)
            detected = {}
            # Determine bounding boxes to keep
            confidence_threshold = THRESHOLD if metadata['tag'] != "counting" else COUNTING_THRESHOLD
            for index, classname in enumerate(classnames):
                ordering = np.argsort(bbox[index][:, 4])[::-1]
                ordering = ordering[bbox[index][ordering, 4] > confidence_threshold]  # Threshold
                ordering = ordering[:MAX_OBJECTS].tolist()  # Limit number of detected objects per class
                detected[classname] = []
                while ordering:
                    max_obj = ordering.pop(0)
                    detected[classname].append((bbox[index][max_obj], None if segm is None else segm[index][max_obj]))
                    ordering = [
                        obj for obj in ordering
                        if NMS_THRESHOLD == 1 or compute_iou(bbox[index][max_obj], bbox[index][obj]) < NMS_THRESHOLD
                    ]
                if not detected[classname]:
                    del detected[classname]
            # Evaluate
            is_strict_correct, score, reason = evaluate_reward(image, detected, metadata)
            if only_strict:
                is_correct = False
            else:
                is_correct, _ = evaluate(image, detected, metadata)
            ret.append({
                'tag': metadata['tag'],
                'prompt': metadata['prompt'],
                'correct': is_correct,
                'strict_correct': is_strict_correct,
                'score': score,
                'reason': reason,
                'metadata': json.dumps(metadata),
                'details': json.dumps({
                    key: [box.tolist() for box, _ in value]
                    for key, value in detected.items()
                })
            })
        return ret

    object_detector, (clip_model, transform, tokenizer), classnames = load_models()
    THRESHOLD = 0.3
    COUNTING_THRESHOLD = 0.9
    MAX_OBJECTS = 16
    NMS_THRESHOLD = 1.0
    POSITION_THRESHOLD = 0.1


    @torch.no_grad()
    def compute_geneval(images, metadatas, only_strict=False):
        required_keys = ['single_object', 'two_object', 'counting', 'colors', 'position', 'color_attr']
        scores = []
        strict_rewards = []
        grouped_strict_rewards = defaultdict(list)
        rewards = []
        grouped_rewards = defaultdict(list)
        results = evaluate_image(images, metadatas, only_strict=only_strict)
        for result in results:
            strict_rewards.append(1.0 if result["strict_correct"] else 0.0)
            scores.append(result["score"])
            rewards.append(1.0 if result["correct"] else 0.0)
            tag = result["tag"]
            for key in required_keys:
                if key != tag:
                    # -10.0 is a sentinel marking samples that do not belong to this tag group
                    grouped_strict_rewards[key].append(-10.0)
                    grouped_rewards[key].append(-10.0)
                else:
                    grouped_strict_rewards[tag].append(1.0 if result["strict_correct"] else 0.0)
                    grouped_rewards[tag].append(1.0 if result["correct"] else 0.0)
        return scores, rewards, strict_rewards, dict(grouped_rewards), dict(grouped_strict_rewards)

    return compute_geneval
--------------------------------------------------------------------------------
/reward_server/object_names.txt:
--------------------------------------------------------------------------------
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
computer mouse
tv remote
computer keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush
--------------------------------------------------------------------------------
/test/test_deqa.py:
--------------------------------------------------------------------------------
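# Smoke test for the DeQA server: encodes one local image as JPEG, POSTs the
# pickled request, and prints the scores that come back. Run the server first
# (see the README); adjust the URL below if you changed the host or port in
# gunicorn.conf.py.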
import requests
from PIL import Image
import io
import pickle
import tqdm
from concurrent.futures import ThreadPoolExecutor
import os

BATCH_SIZE = 8

paths = [os.path.join(os.getcwd(), "a photo of a brown giraffe and a white stop sign.png")]

def f(_):
    for i in tqdm.tqdm(range(0, len(paths), BATCH_SIZE)):
        batch_paths = paths[i : i + BATCH_SIZE]

        jpeg_data = []
        for path in batch_paths:
            image = Image.open(path)

            # Compress the images using JPEG
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            jpeg_data.append(buffer.getvalue())

        data = {"images": jpeg_data}
        data_bytes = pickle.dumps(data)

        # Send the pickled request in an HTTP POST to the server
        # (the DeQA server binds 127.0.0.1:18086 with reward='deqa' in gunicorn.conf.py)
        url = "http://127.0.0.1:18086"
        response = requests.post(url, data=data_bytes)

        # Print the response from the server
        print(response)
        print(response.content)
        response_data = pickle.loads(response.content)

        for output in response_data["outputs"]:
            print(output)
            print("--")

# Uncomment to hammer the server from 8 threads at once:
# with ThreadPoolExecutor(max_workers=8) as executor:
#     for _ in executor.map(f, range(8)):
#         pass
f(1)
--------------------------------------------------------------------------------
/test/test_geneval.py:
--------------------------------------------------------------------------------
import requests
from PIL import Image
import io
import pickle
import tqdm
from concurrent.futures import ThreadPoolExecutor
import os

BATCH_SIZE = 24

paths = [os.path.join(os.getcwd(), "a photo of a brown giraffe and a white stop sign.png")]

def f(_):
    for i in tqdm.tqdm(range(0, len(paths), BATCH_SIZE), desc="Processing batches", unit="batch", dynamic_ncols=True, leave=True):
        batch_paths = paths[i : i + BATCH_SIZE]

        jpeg_data = []
        for path in batch_paths:
            image = Image.open(path)

            # Compress the images using JPEG
            buffer = io.BytesIO()
            image.save(buffer, format="JPEG")
            jpeg_data.append(buffer.getvalue())

        data = {
            "images": jpeg_data,
            "meta_datas": [{"tag": "color_attr", "include": [{"class": "giraffe", "count": 1, "color": "brown"}, {"class": "stop sign", "count": 1, "color": "white"}], "prompt": "a photo of a brown giraffe and a white stop sign"}],
            "only_strict": False,
        }

        data_bytes = pickle.dumps(data)

        # Send the pickled request in an HTTP POST to the server
        # (the GenEval server binds 127.0.0.1:18085 with the default gunicorn.conf.py)
        url = "http://127.0.0.1:18085"
        response = requests.post(url, data=data_bytes)

        # Unpickle and print the returned reward dict
        response_data = pickle.loads(response.content)
        print(response_data)

# Uncomment to hammer the server from 8 threads at once:
# with ThreadPoolExecutor(max_workers=8) as executor:
#     for _ in executor.map(f, range(8)):
#         pass
f(1)
--------------------------------------------------------------------------------