├── .env.example
├── .github
│   └── workflows
│       └── continuous-integration-pip.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .project-root
├── .vscode
│   ├── launch.json
│   └── settings.json
├── LICENSE
├── README.md
├── config
│   ├── data
│   │   ├── flive.yaml
│   │   ├── kadid10k.yaml
│   │   ├── koniq10k.yaml
│   │   ├── live.yaml
│   │   ├── livec.yaml
│   │   ├── spaq.yaml
│   │   └── tid2013.yaml
│   ├── default.yaml
│   ├── dist
│   │   ├── 2gpus.yaml
│   │   ├── 4gpus.yaml
│   │   ├── cpu_only.yaml
│   │   └── single_gpu.yaml
│   ├── hydra
│   │   ├── default.yaml
│   │   └── job_logging
│   │       └── custom.yaml
│   ├── job
│   │   ├── debug.yaml
│   │   ├── eval.yaml
│   │   ├── train.yaml
│   │   ├── train_loda_flive.yaml
│   │   ├── train_loda_kadid10k.yaml
│   │   ├── train_loda_koniq10k.yaml
│   │   ├── train_loda_live.yaml
│   │   ├── train_loda_livec.yaml
│   │   ├── train_loda_spaq.yaml
│   │   └── train_loda_tid2013.yaml
│   ├── load
│   │   ├── eval.yaml
│   │   └── scratch.yaml
│   ├── log
│   │   ├── debug.yaml
│   │   ├── eval.yaml
│   │   └── train.yaml
│   ├── loss
│   │   └── default.yaml
│   ├── model
│   │   └── loda.yaml
│   ├── optimizer
│   │   ├── adam.yaml
│   │   └── adamW.yaml
│   └── scheduler
│       └── cosineAnnealingLR.yaml
├── environment.yaml
├── requirements-dev.txt
├── requirements.txt
├── scripts
│   ├── benchmark
│   │   ├── benchmark_loda_all.sh
│   │   ├── benchmark_loda_eval_all.sh
│   │   ├── benchmark_loda_flive.sh
│   │   ├── benchmark_loda_kadid10k.sh
│   │   ├── benchmark_loda_koniq10k.sh
│   │   ├── benchmark_loda_live.sh
│   │   ├── benchmark_loda_livec.sh
│   │   ├── benchmark_loda_spaq.sh
│   │   └── benchmark_loda_tid2013.sh
│   ├── process_flive.py
│   ├── process_kadid10k.py
│   ├── process_koniq10k.py
│   ├── process_live.py
│   ├── process_livechallenge.py
│   ├── process_spaq.py
│   └── process_tid2013.py
├── setup.cfg
├── src
│   ├── __init__.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── dataloader.py
│   │   ├── dataloader_mode.py
│   │   ├── flive_dataset.py
│   │   ├── kadid10k_dataset.py
│   │   ├── koniq10k_dataset.py
│   │   ├── live_dataset.py
│   │   ├── livechallenge_dataset.py
│   │   ├── spaq_dataset.py
│   │   └── tid2013_dataset.py
│   ├── eval.py
│   ├── model
│   │   ├── __init__.py
│   │   ├── loda.py
│   │   ├── model.py
│   │   ├── model_dispatcher.py
│   │   └── patch_embed.py
│   ├── tools
│   │   ├── __init__.py
│   │   ├── test_model.py
│   │   └── train_model.py
│   ├── trainer.py
│   └── utils
│       ├── __init__.py
│       ├── dataset_utils.py
│       ├── loss.py
│       ├── metrics.py
│       ├── utils.py
│       └── writer.py
└── tests
    ├── __init__.py
    ├── model
    │   ├── __init__.py
    │   ├── model_test.py
    │   └── net_arch_test.py
    └── test_case.py

/.env.example:
--------------------------------------------------------------------------------
 1 | # example of a file for storing private and user-specific environment variables, like keys or system paths
 2 | # rename it to ".env" (excluded from version control by default)
 3 | # .env is loaded automatically by src/trainer.py
 4 | # hydra allows you to reference variables in .yaml configs with special syntax: ${oc.env:MY_VAR}
 5 | 
 6 | # The project uses wandb by default for logging
 7 | # You should set WANDB_INIT_ENTITY to use wandb,
 8 | # or comment it out to disable wandb in the config files
 9 | # WANDB_INIT_ENTITY=username
10 | 
11 | MY_VAR="/home/user/my/system/path"
12 | 
--------------------------------------------------------------------------------
/.github/workflows/continuous-integration-pip.yml:
--------------------------------------------------------------------------------
 1 | name: CI (pip)
 2 | on: [push]
 3 | 
 4 | jobs:
 5 |   build:
 6 |     strategy:
 7 |       matrix:
 8 |         python-version: [3.10.14]
 9 |     runs-on: ubuntu-latest
10 |     steps:
11 |       - name: Checkout
12 |         uses: actions/checkout@v2
13 |       - name: Set up Python ${{ matrix.python-version }}
14 |         uses: actions/setup-python@v1
15 |         with:
16 |           python-version: ${{ matrix.python-version }}
17 |       - name: Install dependencies
18 |         run: |
19 |           python -m pip install --upgrade pip
20 |           pip install -r requirements.txt
21 |           pip install -r requirements-dev.txt
22 |           # install black and autopep8 if available (optional formatters; don't fail the build without them)
23 |           pip install black || true
24 |           pip install autopep8 || true
25 |           # install sphinx_gallery if available (optional)
26 |           pip install sphinx_gallery || true
27 |       - name: Lint with flake8
28 |         run: |
29 |           # stop the build if there are Python syntax errors or undefined names
30 |           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
31 |           # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
32 |           flake8 . --count --exit-zero --max-complexity=10 --statistics
33 |       - name: Lint with isort
34 |         run: |
35 |           isort --profile black . -c -v
36 |       - name: Test with pytest
37 |         run: pytest -q tests
38 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
 1 | # logs, checkpoints
 2 | chkpt/
 3 | logs/
 4 | dataset/meta/
 5 | outputs/
 6 | wandb/
 7 | data/
 8 | runs/
 9 | 
10 | # just a temporary folder...
11 | temp/
12 | 
13 | # Byte-compiled / optimized / DLL files
14 | __pycache__/
15 | *.py[cod]
16 | *$py.class
17 | 
18 | # C extensions
19 | *.so
20 | 
21 | # Distribution / packaging
22 | .Python
23 | build/
24 | develop-eggs/
25 | dist/
26 | downloads/
27 | eggs/
28 | .eggs/
29 | lib/
30 | lib64/
31 | parts/
32 | sdist/
33 | var/
34 | wheels/
35 | *.egg-info/
36 | .installed.cfg
37 | *.egg
38 | MANIFEST
39 | 
40 | # PyInstaller
41 | # Usually these files are written by a python script from a template
42 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # pyenv 88 | .python-version 89 | 90 | # celery beat schedule file 91 | celerybeat-schedule 92 | 93 | # SageMath parsed files 94 | *.sage.py 95 | 96 | # Environments 97 | .env 98 | .venv 99 | env/ 100 | venv/ 101 | ENV/ 102 | env.bak/ 103 | venv.bak/ 104 | 105 | # Spyder project settings 106 | .spyderproject 107 | .spyproject 108 | 109 | # Rope project settings 110 | .ropeproject 111 | 112 | # mkdocs documentation 113 | /site 114 | 115 | # mypy 116 | .mypy_cache/ 117 | 118 | # Created by https://www.toptal.com/developers/gitignore/api/python,pycharm,pycharm+all,vscode,jupyternotebooks 119 | # Edit at https://www.toptal.com/developers/gitignore?templates=python,pycharm,pycharm+all,vscode,jupyternotebooks 120 | 121 | ### JupyterNotebooks ### 122 | # gitignore template for Jupyter Notebooks 123 | # website: http://jupyter.org/ 124 | 125 | .ipynb_checkpoints 126 | */.ipynb_checkpoints/* 127 | 128 | # IPython 129 | profile_default/ 130 | ipython_config.py 131 | 132 | # Remove previous ipynb_checkpoints 133 | # git rm -r .ipynb_checkpoints/ 134 | 135 | ### PyCharm ### 136 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 137 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 138 | 139 | # User-specific stuff 140 | .idea/**/workspace.xml 141 | .idea/**/tasks.xml 142 | .idea/**/usage.statistics.xml 143 | .idea/**/dictionaries 144 | .idea/**/shelf 145 | 146 | # Generated files 147 | .idea/**/contentModel.xml 148 | 149 | # Sensitive or high-churn files 150 | .idea/**/dataSources/ 151 | .idea/**/dataSources.ids 152 | .idea/**/dataSources.local.xml 153 | .idea/**/sqlDataSources.xml 154 | .idea/**/dynamic.xml 155 | .idea/**/uiDesigner.xml 156 | .idea/**/dbnavigator.xml 157 | 158 | # Gradle 159 | .idea/**/gradle.xml 160 | .idea/**/libraries 161 | 162 | # Gradle and Maven with auto-import 163 | # When using Gradle or Maven with auto-import, you should exclude module files, 164 | # since they will be recreated, and may cause churn. Uncomment if using 165 | # auto-import. 
166 | # .idea/artifacts 167 | # .idea/compiler.xml 168 | # .idea/jarRepositories.xml 169 | # .idea/modules.xml 170 | # .idea/*.iml 171 | # .idea/modules 172 | # *.iml 173 | # *.ipr 174 | 175 | # CMake 176 | cmake-build-*/ 177 | 178 | # Mongo Explorer plugin 179 | .idea/**/mongoSettings.xml 180 | 181 | # File-based project format 182 | *.iws 183 | 184 | # IntelliJ 185 | out/ 186 | 187 | # mpeltonen/sbt-idea plugin 188 | .idea_modules/ 189 | 190 | # JIRA plugin 191 | atlassian-ide-plugin.xml 192 | 193 | # Cursive Clojure plugin 194 | .idea/replstate.xml 195 | 196 | # Crashlytics plugin (for Android Studio and IntelliJ) 197 | com_crashlytics_export_strings.xml 198 | crashlytics.properties 199 | crashlytics-build.properties 200 | fabric.properties 201 | 202 | # Editor-based Rest Client 203 | .idea/httpRequests 204 | 205 | # Android studio 3.1+ serialized cache file 206 | .idea/caches/build_file_checksums.ser 207 | 208 | ### PyCharm Patch ### 209 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 210 | 211 | # *.iml 212 | # modules.xml 213 | # .idea/misc.xml 214 | # *.ipr 215 | 216 | # Sonarlint plugin 217 | # https://plugins.jetbrains.com/plugin/7973-sonarlint 218 | .idea/**/sonarlint/ 219 | 220 | # SonarQube Plugin 221 | # https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin 222 | .idea/**/sonarIssues.xml 223 | 224 | # Markdown Navigator plugin 225 | # https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced 226 | .idea/**/markdown-navigator.xml 227 | .idea/**/markdown-navigator-enh.xml 228 | .idea/**/markdown-navigator/ 229 | 230 | # Cache file creation bug 231 | # See https://youtrack.jetbrains.com/issue/JBR-2257 232 | .idea/$CACHE_FILE$ 233 | 234 | # CodeStream plugin 235 | # https://plugins.jetbrains.com/plugin/12206-codestream 236 | .idea/codestream.xml 237 | 238 | ### PyCharm+all ### 239 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider 240 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 241 | 242 | # User-specific stuff 243 | 244 | # Generated files 245 | 246 | # Sensitive or high-churn files 247 | 248 | # Gradle 249 | 250 | # Gradle and Maven with auto-import 251 | # When using Gradle or Maven with auto-import, you should exclude module files, 252 | # since they will be recreated, and may cause churn. Uncomment if using 253 | # auto-import. 
254 | # .idea/artifacts 255 | # .idea/compiler.xml 256 | # .idea/jarRepositories.xml 257 | # .idea/modules.xml 258 | # .idea/*.iml 259 | # .idea/modules 260 | # *.iml 261 | # *.ipr 262 | 263 | # CMake 264 | 265 | # Mongo Explorer plugin 266 | 267 | # File-based project format 268 | 269 | # IntelliJ 270 | 271 | # mpeltonen/sbt-idea plugin 272 | 273 | # JIRA plugin 274 | 275 | # Cursive Clojure plugin 276 | 277 | # Crashlytics plugin (for Android Studio and IntelliJ) 278 | 279 | # Editor-based Rest Client 280 | 281 | # Android studio 3.1+ serialized cache file 282 | 283 | ### PyCharm+all Patch ### 284 | # Ignores the whole .idea folder and all .iml files 285 | # See https://github.com/joeblau/gitignore.io/issues/186 and https://github.com/joeblau/gitignore.io/issues/360 286 | 287 | .idea/ 288 | 289 | # Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-249601023 290 | 291 | *.iml 292 | modules.xml 293 | .idea/misc.xml 294 | *.ipr 295 | 296 | # Sonarlint plugin 297 | .idea/sonarlint 298 | 299 | ### Python ### 300 | # Byte-compiled / optimized / DLL files 301 | __pycache__/ 302 | *.py[cod] 303 | *$py.class 304 | 305 | # C extensions 306 | *.so 307 | 308 | # Distribution / packaging 309 | .Python 310 | build/ 311 | develop-eggs/ 312 | dist/ 313 | downloads/ 314 | eggs/ 315 | .eggs/ 316 | lib/ 317 | lib64/ 318 | parts/ 319 | sdist/ 320 | var/ 321 | wheels/ 322 | pip-wheel-metadata/ 323 | share/python-wheels/ 324 | *.egg-info/ 325 | .installed.cfg 326 | *.egg 327 | MANIFEST 328 | 329 | # PyInstaller 330 | # Usually these files are written by a python script from a template 331 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 332 | *.manifest 333 | *.spec 334 | 335 | # Installer logs 336 | pip-log.txt 337 | pip-delete-this-directory.txt 338 | 339 | # Unit test / coverage reports 340 | htmlcov/ 341 | .tox/ 342 | .nox/ 343 | .coverage 344 | .coverage.* 345 | .cache 346 | nosetests.xml 347 | coverage.xml 348 | *.cover 349 | *.py,cover 350 | .hypothesis/ 351 | .pytest_cache/ 352 | pytestdebug.log 353 | 354 | # Translations 355 | *.mo 356 | *.pot 357 | 358 | # Django stuff: 359 | *.log 360 | local_settings.py 361 | db.sqlite3 362 | db.sqlite3-journal 363 | 364 | # Flask stuff: 365 | instance/ 366 | .webassets-cache 367 | 368 | # Scrapy stuff: 369 | .scrapy 370 | 371 | # Sphinx documentation 372 | docs/_build/ 373 | doc/_build/ 374 | 375 | # PyBuilder 376 | target/ 377 | 378 | # Jupyter Notebook 379 | 380 | # IPython 381 | 382 | # pyenv 383 | .python-version 384 | 385 | # pipenv 386 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 387 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 388 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 389 | # install all needed dependencies. 390 | #Pipfile.lock 391 | 392 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 393 | __pypackages__/ 394 | 395 | # Celery stuff 396 | celerybeat-schedule 397 | celerybeat.pid 398 | 399 | # SageMath parsed files 400 | *.sage.py 401 | 402 | # Environments 403 | .env 404 | .venv 405 | env/ 406 | venv/ 407 | ENV/ 408 | env.bak/ 409 | venv.bak/ 410 | pythonenv* 411 | 412 | # Spyder project settings 413 | .spyderproject 414 | .spyproject 415 | 416 | # Rope project settings 417 | .ropeproject 418 | 419 | # mkdocs documentation 420 | /site 421 | 422 | # mypy 423 | .mypy_cache/ 424 | .dmypy.json 425 | dmypy.json 426 | 427 | # Pyre type checker 428 | .pyre/ 429 | 430 | # pytype static type analyzer 431 | .pytype/ 432 | 433 | # profiling data 434 | .prof 435 | 436 | ### vscode ### 437 | .vscode/* 438 | !.vscode/settings.json 439 | !.vscode/tasks.json 440 | !.vscode/launch.json 441 | !.vscode/extensions.json 442 | *.code-workspace 443 | 444 | # include specific subfolder 445 | !config/data/ 446 | !config/dist/ 447 | 448 | # End of https://www.toptal.com/developers/gitignore/api/python,pycharm,pycharm+all,vscode,jupyternotebooks 449 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com:/pre-commit/pre-commit-hooks 3 | # if there is network issue 4 | # - repo: git@github.com:/pre-commit/pre-commit-hooks.git 5 | # if there is no github ssh 6 | # - repo: https://github.com/pre-commit/pre-commit-hooks 7 | rev: v4.4.0 8 | hooks: 9 | - id: check-yaml 10 | - id: debug-statements 11 | - id: end-of-file-fixer 12 | - id: trailing-whitespace 13 | 14 | - repo: https://github.com:/PyCQA/autoflake 15 | rev: v2.2.0 16 | hooks: 17 | - id: autoflake 18 | 19 | - repo: https://github.com:/psf/black 20 | rev: 23.7.0 21 | hooks: 22 | - id: black 23 | 24 | - repo: https://github.com:/PyCQA/flake8 25 | rev: 6.0.0 26 | hooks: 27 | - id: flake8 28 | 29 | - repo: https://github.com:/pycqa/isort 30 | rev: 5.12.0 31 | hooks: 32 | - id: isort 33 | args: ["--profile", "black"] 34 | 35 | # shell scripts linter 36 | - repo: https://github.com:/shellcheck-py/shellcheck-py 37 | rev: v0.9.0.5 38 | hooks: 39 | - id: shellcheck 40 | 41 | # jupyter notebook cell output clearing 42 | - repo: https://github.com:/kynan/nbstripout 43 | rev: 0.6.1 44 | hooks: 45 | - id: nbstripout 46 | 47 | # jupyter notebook linting 48 | - repo: https://github.com:/nbQA-dev/nbQA 49 | rev: 1.7.0 50 | hooks: 51 | - id: nbqa-black 52 | args: ["--line-length=99"] 53 | - id: nbqa-isort 54 | args: ["--profile=black"] 55 | - id: nbqa-flake8 56 | args: 57 | [ 58 | "--extend-ignore=E203,E402,E501,F401,F841", 59 | "--exclude=logs/*,data/*", 60 | ] 61 | -------------------------------------------------------------------------------- /.project-root: -------------------------------------------------------------------------------- 1 | # this file is required for inferring the project root directory 2 | # do not delete 3 | -------------------------------------------------------------------------------- /.vscode/launch.json: -------------------------------------------------------------------------------- 1 | { 2 | // Use IntelliSense to learn about possible attributes. 3 | // Hover to view descriptions of existing attributes. 
4 | // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 5 | "version": "0.2.0", 6 | "configurations": [ 7 | { 8 | "name": "Python: Current File", 9 | "type": "python", 10 | "request": "launch", 11 | "program": "${file}", 12 | "console": "integratedTerminal", 13 | "justMyCode": true 14 | }, 15 | { 16 | "name": "trainer", 17 | "type": "python", 18 | "request": "launch", 19 | "program": "${workspaceFolder}/src/trainer.py", 20 | "console": "integratedTerminal", 21 | "justMyCode": false, 22 | "args": [ 23 | "job=debug", 24 | "hydra.job_logging.handlers.file.filename=logs/debug.log", 25 | "hydra.verbose=true", 26 | ] 27 | }, 28 | { 29 | "name": "eval", 30 | "type": "python", 31 | "request": "launch", 32 | "program": "${workspaceFolder}/src/eval.py", 33 | "console": "integratedTerminal", 34 | "justMyCode": false, 35 | "args": [ 36 | "job=eval", 37 | "hydra.job_logging.handlers.file.filename=logs/eval.log", 38 | "hydra.verbose=true", 39 | "data=koniq10k", 40 | "load.network_chkpt_path=runs/debug/loda_koniq10k_debug/chkpt_dir/loda_koniq10k_debug_1510.pt", 41 | "split_index=0", 42 | ] 43 | }, 44 | ] 45 | } 46 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "files.watcherExclude": { 3 | "**/data/**": true, 4 | "**/.pytest_cache/**": true, 5 | "**/chkpt/**": true, 6 | "**/runs/**": true, 7 | "**/wandb/**": true 8 | } 9 | } 10 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 
39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # LoDa (CVPR 2024) 2 | 3 | This is the **official repository** of the [**paper**](https://openaccess.thecvf.com/content/CVPR2024/html/Xu_Boosting_Image_Quality_Assessment_through_Efficient_Transformer_Adaptation_with_Local_CVPR_2024_paper.html) "*Boosting Image Quality Assessment through Efficient Transformer Adaptation with Local Feature Enhancement*". 
4 | 
5 | ## Updates
6 | 
7 | * [06/2024] We released the source code of LoDa; check it out on [GitHub](https://github.com/NeosXu/LoDa)
8 | 
9 | ## To-Dos
10 | 
11 | * [ ] Checkpoints & Logs
12 | * [x] Initialization
13 | 
14 | ## Usage
15 | 
16 | ### Prerequisites
17 | 
18 | #### Installation
19 | 
20 | We recommend using the **conda** package manager to avoid dependency problems.
21 | 
22 | 1. Clone the repository
23 | 
24 | ```sh
25 | git clone https://github.com/NeosXu/LoDa
26 | ```
27 | 
28 | 2. Install Python dependencies
29 | 
30 | ```sh
31 | # Using conda (recommended)
32 | conda env create -f environment.yaml
33 | conda activate loda
34 | 
35 | # Using pip
36 | pip install -r requirements.txt
37 | pip install -r requirements-dev.txt # Optional, for code formatting
38 | 
39 | pre-commit install # Optional, for code formatting
40 | ```
41 | 
42 | #### Data Preparation
43 | You need to download the datasets used in the paper and place them under a common directory `data`.
44 | 
45 | For each dataset, run the corresponding preprocessing script to prepare the images, metadata, and train/test splits.
46 | 
47 | ```sh
48 | dataset_names=("live" "tid2013" "kadid10k" "livechallenge" "koniq10k" "spaq" "flive")
49 | for dn in "${dataset_names[@]}"
50 | do
51 |     python scripts/process_"$dn".py
52 | done
53 | ```
54 | 
55 | Afterwards, the directory structure should look like this:
56 | 
57 | ```
58 | ├── data
59 | |   ├── flive
60 | |   ├── kadid10k
61 | |   ├── koniq10k
62 | |   ├── live_iqa
63 | |   ├── LIVEC
64 | |   ├── spaq
65 | |   ├── tid2013
66 | |   ├── meta_info
67 | |   |   ├── meta_info_FLIVEDataset.csv
68 | |   |   ├── meta_info_KADID10kDataset.csv
69 | |   |   ├── meta_info_KonIQ10kDataset.csv
70 | |   |   ├── ...
71 | |   ├── train_split_info
72 | |   |   ├── flive_82_seed3407.pkl
73 | |   |   ├── kadid10k_82_seed3407.pkl
74 | |   |   ├── koniq10k_82_seed3407.pkl
75 | |   |   ├── ...
76 | ```
77 | 
78 | Alternatively, you can simply download the `meta_info` and `train_split_info` folders from [Google Drive](https://drive.google.com/drive/folders/1LiOQ2dvdssnUoVnIsB97Z21g6g_cmrbw?usp=sharing).
79 | 
80 | ### Training
81 | 
82 | ```bash
83 | mkdir -p logs
84 | # all datasets
85 | bash scripts/benchmark/benchmark_loda_all.sh 0
86 | # a single dataset
87 | bash scripts/benchmark/benchmark_loda_koniq10k.sh 0
88 | ```
89 | 
90 | ### Evaluation
91 | 
92 | ```bash
93 | mkdir -p logs
94 | # all datasets
95 | bash scripts/benchmark/benchmark_loda_eval_all.sh 0
96 | ```
97 | 
98 | ## Citing LoDa
99 | 
100 | If you find this project helpful in your research, please consider citing our paper:
101 | 
102 | ```text
103 | @InProceedings{Xu_2024_CVPR,
104 |     author = {Xu, Kangmin and Liao, Liang and Xiao, Jing and Chen, Chaofeng and Wu, Haoning and Yan, Qiong and Lin, Weisi},
105 |     title = {Boosting Image Quality Assessment through Efficient Transformer Adaptation with Local Feature Enhancement},
106 |     booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
107 |     month = {June},
108 |     year = {2024},
109 |     pages = {2662-2672}
110 | }
111 | ```
112 | 
113 | ## Acknowledgement
114 | 
115 | We borrowed some parts from the following open-source projects:
116 | 
117 | * [IQA-PyTorch](https://github.com/chaofengc/IQA-PyTorch)
118 | * [pytorch-project-template](https://github.com/ryul99/pytorch-project-template)
119 | 
120 | Many thanks to them.
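
As a quick reference for the Training and Evaluation sections above, a single split can also be launched directly through Hydra overrides instead of the benchmark scripts. The sketch below mirrors the arguments used by `scripts/benchmark/benchmark_loda_all.sh` and `.vscode/launch.json`; the checkpoint path is the debug placeholder taken from `launch.json`, so substitute the path produced by your own training run, and adjust the GPU index as needed.

```bash
# train one split of KonIQ-10k on GPU 0 (assumes the loda environment is active)
CUDA_VISIBLE_DEVICES=0 python src/trainer.py \
    job=train_loda_koniq10k \
    split_index=0

# evaluate a saved checkpoint on the same split
# (checkpoint path is a placeholder; point it at your own chkpt_dir)
CUDA_VISIBLE_DEVICES=0 python src/eval.py \
    job=eval \
    data=koniq10k \
    load.network_chkpt_path=runs/debug/loda_koniq10k_debug/chkpt_dir/loda_koniq10k_debug_1510.pt \
    split_index=0
```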
121 | -------------------------------------------------------------------------------- /config/data/flive.yaml: -------------------------------------------------------------------------------- 1 | name: flive 2 | root: data/flive 3 | meta_info_file: data/meta_info/meta_info_FLIVEDataset.csv 4 | train_test_split_file: data/train_split_info/flive_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 39810 7 | train_data_num: 31848 8 | test_data_num: 7962 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/kadid10k.yaml: -------------------------------------------------------------------------------- 1 | name: kadid10k 2 | root: data/kadid10k 3 | meta_info_file: data/meta_info/meta_info_KADID10kDataset.csv 4 | train_test_split_file: data/train_split_info/kadid10k_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 10125 7 | train_data_num: 8100 8 | test_data_num: 2025 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/koniq10k.yaml: -------------------------------------------------------------------------------- 1 | name: koniq10k 2 | root: data/koniq10k 3 | meta_info_file: data/meta_info/meta_info_KonIQ10kDataset.csv 4 | train_test_split_file: data/train_split_info/koniq10k_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 10073 7 | train_data_num: 8058 8 | test_data_num: 2015 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/live.yaml: -------------------------------------------------------------------------------- 1 | name: live 2 | root: data/live_iqa 3 | meta_info_file: data/meta_info/meta_info_LIVEIQADataset.csv 4 | train_test_split_file: data/train_split_info/live_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 779 7 | train_data_num: 619 8 | test_data_num: 160 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/livec.yaml: -------------------------------------------------------------------------------- 1 | name: livec 2 | root: data/LIVEC 3 | meta_info_file: data/meta_info/meta_info_LIVEChallengeDataset.csv 4 | train_test_split_file: data/train_split_info/livechallenge_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 1162 7 | train_data_num: 930 8 | test_data_num: 232 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/spaq.yaml: -------------------------------------------------------------------------------- 1 | name: spaq 2 | root: data/spaq 3 | meta_info_file: data/meta_info/meta_info_SPAQDataset.csv 4 | train_test_split_file: data/train_split_info/spaq_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 11125 7 | train_data_num: 8900 8 | test_data_num: 2225 9 | image_size: 384 10 | patch_size: 224 11 | -------------------------------------------------------------------------------- /config/data/tid2013.yaml: -------------------------------------------------------------------------------- 1 | name: tid2013 2 | root: data/tid2013 3 | meta_info_file: data/meta_info/meta_info_TID2013Dataset.csv 4 | train_test_split_file: data/train_split_info/tid2013_82_seed3407.pkl 5 | divide_dataset_per_gpu: true 6 | data_num: 3000 7 | train_data_num: 2400 8 | test_data_num: 600 9 | 
image_size: 384
10 | patch_size: 224
11 | 
--------------------------------------------------------------------------------
/config/default.yaml:
--------------------------------------------------------------------------------
 1 | # @package _global_
 2 | defaults:
 3 |   - _self_
 4 |   - job: train
 5 |   - hydra: default
 6 | 
--------------------------------------------------------------------------------
/config/dist/2gpus.yaml:
--------------------------------------------------------------------------------
 1 | device: cuda
 2 | mode: nccl
 3 | master_addr: 127.0.0.1
 4 | master_port: '23456'
 5 | timeout: null
 6 | # gpus is the number of GPUs to use with DDP (this value is passed as world_size to DDP).
 7 | # DDP is not used when gpus is 0; all available GPUs are used when gpus is -1.
 8 | gpus: 2
 9 | # device_num is the exact number of devices in use.
10 | # It is needed because gpus can be 0 or -1 and therefore does not always equal the device count.
11 | device_num: 2
12 | 
--------------------------------------------------------------------------------
/config/dist/4gpus.yaml:
--------------------------------------------------------------------------------
 1 | device: cuda
 2 | mode: nccl
 3 | master_addr: 127.0.0.1
 4 | master_port: '23456'
 5 | timeout: null
 6 | # gpus is the number of GPUs to use with DDP (this value is passed as world_size to DDP).
 7 | # DDP is not used when gpus is 0; all available GPUs are used when gpus is -1.
 8 | gpus: 4
 9 | # device_num is the exact number of devices in use.
10 | # It is needed because gpus can be 0 or -1 and therefore does not always equal the device count.
11 | device_num: 4
12 | 
--------------------------------------------------------------------------------
/config/dist/cpu_only.yaml:
--------------------------------------------------------------------------------
 1 | device: cpu
 2 | # if device is cpu, device_num can currently only be set to 1
 3 | device_num: 1
 4 | 
--------------------------------------------------------------------------------
/config/dist/single_gpu.yaml:
--------------------------------------------------------------------------------
 1 | device: cuda
 2 | # gpus is the number of GPUs to use with DDP (this value is passed as world_size to DDP).
 3 | # DDP is not used when gpus is 0; all available GPUs are used when gpus is -1.
 4 | gpus: 0
 5 | # device_num is the exact number of devices in use.
 6 | # It is needed because gpus can be 0 or -1 and therefore does not always equal the device count.
 7 | device_num: 1
 8 | 
--------------------------------------------------------------------------------
/config/hydra/default.yaml:
--------------------------------------------------------------------------------
 1 | # https://hydra.cc/docs/configure_hydra/intro/
 2 | 
 3 | # enable color logging
 4 | defaults:
 5 |   - override job_logging: custom
 6 |   - override hydra_logging: colorlog
 7 | 
 8 | # output directory, generated dynamically on each run
 9 | run:
10 |   dir: ${working_dir}/outputs/${name}-${now:%Y-%m-%d_%H-%M-%S}
11 | job:
12 |   chdir: false
13 | 
--------------------------------------------------------------------------------
/config/hydra/job_logging/custom.yaml:
--------------------------------------------------------------------------------
 1 | # @package hydra.job_logging
 2 | # python logging configuration for tasks
 3 | version: 1
 4 | formatters:
 5 |   simple:
 6 |     format: '%(message)s'
 7 |   detailed:
 8 |     format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
 9 | handlers:
10 |   console:
11 |     class: logging.StreamHandler
12 |     formatter: simple
13 |     stream: ext://sys.stdout
14 |   file:
15 |     class: logging.FileHandler
16 |     formatter: detailed
17 |     filename: logs/trainer.log
18 | root:
19 |   level: INFO
20 |   handlers: [console, file]
21 | 
22 | disable_existing_loggers: False
23 | 
--------------------------------------------------------------------------------
/config/job/debug.yaml:
--------------------------------------------------------------------------------
 1 | # @package _global_
 2 | 
 3 | defaults:
 4 |   - _self_
 5 |   - /dist: single_gpu
 6 |   - /model: loda
 7 |   - /data: koniq10k
 8 |   - /optimizer: adamW
 9 |   - /scheduler: cosineAnnealingLR
10 |   - /loss: default
11 |   - /log: debug
12 |   - /load: scratch
13 | 
14 | # job general configs
15 | project_name: loda
16 | name: loda_koniq10k_debug
17 | run_group: debug
18 | working_dir: runs/${run_group}/${name}
19 | random_seed: 3407
20 | train_test_num: 1
21 | num_epoch: 10
22 | split_index: 0
23 | 
24 | # training configs
25 | train:
26 |   patch_num: 3
27 |   batch_size: 16
28 |   num_workers: 2
29 | 
30 | # test configs
31 | test:
32 |   patch_num: 15
33 |   batch_size: 16
34 |   num_workers: 2
35 | 
--------------------------------------------------------------------------------
/config/job/eval.yaml:
--------------------------------------------------------------------------------
 1 | # @package _global_
 2 | 
 3 | defaults:
 4 |   - _self_
 5 |   - /dist: single_gpu
 6 |   - /model: loda
 7 |   - /data: koniq10k
 8 |   - /optimizer: adamW
 9 |   - /scheduler: cosineAnnealingLR
10 |   - /loss: default
11 |   - /log: eval
12 |   - /load: eval
13 | 
14 | # job general configs
15 | project_name: loda
16 | name: loda_koniq10k_eval
17 | run_group: eval
18 | working_dir: runs/${run_group}/${name}
19 | random_seed: 3407
20 | train_test_num: 1
21 | num_epoch: 10
22 | split_index: 0
23 | 
24 | # training configs
25 | train:
26 |   patch_num: 3
27 |   batch_size: 200
28 |   num_workers: 10
29 | 
30 | # test configs
31 | test:
32 |   patch_num: 15
33 |   batch_size: 512
34 |   num_workers: 10
35 | 
--------------------------------------------------------------------------------
/config/job/train.yaml:
--------------------------------------------------------------------------------
 1 | # @package _global_
 2 | 
 3 | defaults:
 4 |   - _self_
 5 |   - /dist: single_gpu
 6 |   - /model: loda
 7 |   - /data: koniq10k
 8 |   - /optimizer: adamW
 9 |   - /scheduler: cosineAnnealingLR
10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_koniq10k_train 17 | run_group: train 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 3 27 | batch_size: 200 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_flive.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: flive 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_flive_train_split${split_index} 17 | run_group: loda_benchmark_flive 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 1 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_kadid10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: kadid10k 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_kadid10k_train_split${split_index} 17 | run_group: loda_benchmark_kadid10k 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 3 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_koniq10k.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: koniq10k 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_koniq10k_train_split${split_index} 17 | run_group: loda_benchmark_koniq10k 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 3 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_live.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 
| - /data: live 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_live_train_split${split_index} 17 | run_group: loda_benchmark_live 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 5 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_livec.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: livec 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_livec_train_split${split_index} 17 | run_group: loda_benchmark_livec 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 5 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_spaq.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: spaq 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_spaq_train_split${split_index} 17 | run_group: loda_benchmark_spaq 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 3 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/job/train_loda_tid2013.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | 3 | defaults: 4 | - _self_ 5 | - /dist: single_gpu 6 | - /model: loda 7 | - /data: tid2013 8 | - /optimizer: adamW 9 | - /scheduler: cosineAnnealingLR 10 | - /loss: default 11 | - /log: train 12 | - /load: scratch 13 | 14 | # job general configs 15 | project_name: loda 16 | name: loda_tid2013_train_split${split_index} 17 | run_group: loda_benchmark_tid2013 18 | working_dir: runs/${run_group}/${name} 19 | random_seed: 3407 20 | train_test_num: 1 21 | num_epoch: 10 22 | split_index: 0 23 | 24 | # training configs 25 | train: 26 | patch_num: 10 27 | batch_size: 128 28 | num_workers: 10 29 | 30 | # test configs 31 | test: 32 | patch_num: 15 33 | batch_size: 512 34 | num_workers: 10 35 | -------------------------------------------------------------------------------- /config/load/eval.yaml: -------------------------------------------------------------------------------- 1 | resume_state_path: 
null 2 | network_chkpt_path: ??? 3 | wandb_load_path: null 4 | strict_load: false 5 | -------------------------------------------------------------------------------- /config/load/scratch.yaml: -------------------------------------------------------------------------------- 1 | resume_state_path: null 2 | network_chkpt_path: null 3 | wandb_load_path: null 4 | strict_load: false 5 | -------------------------------------------------------------------------------- /config/log/debug.yaml: -------------------------------------------------------------------------------- 1 | chkpt_dir: ${working_dir}/chkpt_dir 2 | summary_interval: 1 3 | net_chkpt_interval: 1 4 | train_chkpt_interval: 1 5 | 6 | # wandb settings 7 | use_wandb: true 8 | wandb_save_model: false 9 | wandb_init_conf: 10 | project: ${project_name} 11 | name: ${name} 12 | entity: null 13 | mode: disabled 14 | save_code: false 15 | 16 | # tensorboard settings 17 | use_tensorboard: false 18 | -------------------------------------------------------------------------------- /config/log/eval.yaml: -------------------------------------------------------------------------------- 1 | chkpt_dir: ${working_dir}/chkpt_dir 2 | summary_interval: 1 3 | net_chkpt_interval: 1 4 | train_chkpt_interval: 1 5 | 6 | # wandb settings 7 | use_wandb: false 8 | 9 | # tensorboard settings 10 | use_tensorboard: false 11 | -------------------------------------------------------------------------------- /config/log/train.yaml: -------------------------------------------------------------------------------- 1 | chkpt_dir: ${working_dir}/chkpt_dir 2 | summary_interval: 1 3 | net_chkpt_interval: 1 4 | train_chkpt_interval: 1 5 | 6 | # wandb settings 7 | use_wandb: true 8 | wandb_save_model: false 9 | wandb_init_conf: 10 | project: ${project_name} 11 | name: ${name} 12 | entity: ${oc.env:WANDB_INIT_ENTITY} 13 | mode: offline 14 | save_code: true 15 | 16 | # tensorboard settings 17 | use_tensorboard: false 18 | -------------------------------------------------------------------------------- /config/loss/default.yaml: -------------------------------------------------------------------------------- 1 | fn: 2 | - [plcc_loss, 1] 3 | -------------------------------------------------------------------------------- /config/model/loda.yaml: -------------------------------------------------------------------------------- 1 | model_name: loda 2 | basic_model_name: vit_base_patch16_224.augreg2_in21k_ft_in1k 3 | basic_model_pretrained: true 4 | hyper_vit: 5 | dropout_rate: 0.1 6 | vit_param: 7 | img_size: 224 8 | patch_size: 16 9 | embed_dim: 768 10 | depth: 12 11 | qkv_bias: true 12 | num_heads: 12 13 | num_classes: 1000 14 | learner_param: 15 | num_classes: 1 16 | embed_dim: ${model.vit_param.embed_dim} 17 | feature_channels: [256, 512, 1024, 2048] 18 | cnn_feature_num: 4 19 | interaction_block_num: ${model.vit_param.depth} 20 | latent_dim: 64 21 | grid_size: 7 22 | cross_attn_num_heads: 4 23 | feature_model: 24 | name: resnet50 25 | load_timm_model: true 26 | out_indices: [1, 2, 3, 4] 27 | -------------------------------------------------------------------------------- /config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | name: adam 2 | param: 3 | lr: 3e-4 4 | weight_decay: 1e-2 5 | -------------------------------------------------------------------------------- /config/optimizer/adamW.yaml: -------------------------------------------------------------------------------- 1 | name: adamW 2 | param: 3 | lr: 3e-4 4 | 
weight_decay: 1e-2 5 | -------------------------------------------------------------------------------- /config/scheduler/cosineAnnealingLR.yaml: -------------------------------------------------------------------------------- 1 | name: CosineAnnealingLR 2 | param: 3 | T_max: ${eval:'${data.train_data_num} * ${train.patch_num} * ${num_epoch} // (${train.batch_size} * ${dist.device_num})'} 4 | eta_min: 0 5 | -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: loda 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - defaults 6 | - conda-forge 7 | dependencies: 8 | - _libgcc_mutex=0.1=conda_forge 9 | - _openmp_mutex=4.5=2_gnu 10 | - abseil-cpp=20211102.0=h27087fc_1 11 | - absl-py=1.4.0=pyhd8ed1ab_0 12 | - aiohttp=3.8.1=py310h5764c6d_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - antlr-python-runtime=4.9.3=pyhd8ed1ab_1 15 | - appdirs=1.4.4=pyh9f0ad1d_0 16 | - async-timeout=4.0.2=pyhd8ed1ab_0 17 | - attrs=22.2.0=pyh71513ae_0 18 | - autopep8=2.0.2=pyhd8ed1ab_0 19 | - black=23.3.0=py310hff52083_0 20 | - blas=1.0=mkl 21 | - blinker=1.6.2=pyhd8ed1ab_0 22 | - bottleneck=1.3.5=py310ha9d4c09_0 23 | - brotlipy=0.7.0=py310h7f8727e_1002 24 | - bzip2=1.0.8=h7b6447c_0 25 | - c-ares=1.19.0=h5eee18b_0 26 | - ca-certificates=2023.5.7=hbcca054_0 27 | - cachetools=5.3.0=pyhd8ed1ab_0 28 | - certifi=2023.5.7=pyhd8ed1ab_0 29 | - cffi=1.15.1=py310h5eee18b_3 30 | - cfgv=3.3.1=pyhd8ed1ab_0 31 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 32 | - click=8.1.3=unix_pyhd8ed1ab_2 33 | - colorama=0.4.6=pyhd8ed1ab_0 34 | - cryptography=41.0.1=py310h75e40e8_0 35 | - cuda-cudart=11.7.99=0 36 | - cuda-cupti=11.7.101=0 37 | - cuda-libraries=11.7.1=0 38 | - cuda-nvrtc=11.7.99=0 39 | - cuda-nvtx=11.7.91=0 40 | - cuda-runtime=11.7.1=0 41 | - distlib=0.3.6=pyhd8ed1ab_0 42 | - docker-pycreds=0.4.0=py_0 43 | - et_xmlfile=1.1.0=pyhd8ed1ab_0 44 | - exceptiongroup=1.1.1=pyhd8ed1ab_0 45 | - ffmpeg=4.3=hf484d3e_0 46 | - filelock=3.9.0=py310h06a4308_0 47 | - flake8=6.0.0=pyhd8ed1ab_0 48 | - freetype=2.12.1=h4a9f257_0 49 | - frozenlist=1.3.3=py310h5eee18b_0 50 | - fsspec=2023.4.0=pyh1a96a4e_0 51 | - giflib=5.2.1=h5eee18b_3 52 | - gitdb=4.0.10=pyhd8ed1ab_0 53 | - gitpython=3.1.31=pyhd8ed1ab_0 54 | - gmp=6.2.1=h295c915_3 55 | - gmpy2=2.1.2=py310heeb90bb_0 56 | - gnutls=3.6.15=he1e5248_0 57 | - google-auth=2.17.3=pyh1a96a4e_0 58 | - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0 59 | - grpc-cpp=1.48.0=h00ec82a_0 60 | - grpcio=1.48.0=py310h70d52b7_0 61 | - huggingface_hub=0.14.1=pyhd8ed1ab_0 62 | - hydra-core=1.3.2=pyhd8ed1ab_0 63 | - identify=2.5.22=pyhd8ed1ab_0 64 | - idna=3.4=py310h06a4308_0 65 | - importlib-metadata=6.6.0=pyha770c72_0 66 | - importlib_resources=5.12.0=pyhd8ed1ab_0 67 | - iniconfig=2.0.0=pyhd8ed1ab_0 68 | - intel-openmp=2021.4.0=h06a4308_3561 69 | - isort=5.12.0=pyhd8ed1ab_1 70 | - jinja2=3.1.2=py310h06a4308_0 71 | - joblib=1.2.0=pyhd8ed1ab_0 72 | - jpeg=9e=h5eee18b_1 73 | - lame=3.100=h7b6447c_0 74 | - lcms2=2.15=hfd0df8a_0 75 | - ld_impl_linux-64=2.38=h1181459_1 76 | - lerc=3.0=h295c915_0 77 | - libblas=3.9.0=12_linux64_mkl 78 | - libcblas=3.9.0=12_linux64_mkl 79 | - libcublas=11.10.3.66=0 80 | - libcufft=10.7.2.124=h4fbf590_0 81 | - libcufile=1.6.1.9=0 82 | - libcurand=10.3.2.106=0 83 | - libcusolver=11.4.0.1=0 84 | - libcusparse=11.7.4.91=0 85 | - libdeflate=1.17=h5eee18b_0 86 | - libffi=3.4.2=h6a678d5_6 87 | - libgcc-ng=13.1.0=he5830b7_0 88 | - libgfortran-ng=12.2.0=h69a702a_19 89 | - 
libgfortran5=12.2.0=h337968e_19 90 | - libgomp=13.1.0=he5830b7_0 91 | - libiconv=1.16=h7f8727e_2 92 | - libidn2=2.3.2=h7f8727e_0 93 | - liblapack=3.9.0=12_linux64_mkl 94 | - libnpp=11.7.4.75=0 95 | - libnsl=2.0.0=h7f98852_0 96 | - libnvjpeg=11.8.0.2=0 97 | - libpng=1.6.39=h5eee18b_0 98 | - libprotobuf=3.20.3=he621ea3_0 99 | - libsqlite=3.42.0=h2797004_0 100 | - libstdcxx-ng=13.1.0=hfd8a6a1_0 101 | - libtasn1=4.19.0=h5eee18b_0 102 | - libtiff=4.5.0=h6a678d5_2 103 | - libunistring=0.9.10=h27cfd23_0 104 | - libuuid=2.38.1=h0b41bf4_0 105 | - libwebp=1.2.4=h11a3e52_1 106 | - libwebp-base=1.2.4=h5eee18b_1 107 | - libxcb=1.13=h7f98852_1004 108 | - libzlib=1.2.13=hd590300_5 109 | - lz4-c=1.9.4=h6a678d5_0 110 | - markdown=3.4.3=pyhd8ed1ab_0 111 | - markupsafe=2.1.1=py310h7f8727e_0 112 | - mccabe=0.7.0=pyhd8ed1ab_0 113 | - mkl=2021.4.0=h06a4308_640 114 | - mkl-service=2.4.0=py310h7f8727e_0 115 | - mkl_fft=1.3.1=py310hd6ae3a3_0 116 | - mkl_random=1.2.2=py310h00e6091_0 117 | - mpc=1.1.0=h10f8cd9_1 118 | - mpfr=4.0.2=hb69a4c5_1 119 | - multidict=6.0.2=py310h5eee18b_0 120 | - mypy_extensions=1.0.0=pyha770c72_0 121 | - ncurses=6.4=h6a678d5_0 122 | - nettle=3.7.3=hbbd107a_1 123 | - networkx=2.8.4=py310h06a4308_1 124 | - nodeenv=1.7.0=pyhd8ed1ab_0 125 | - numexpr=2.8.4=py310h8879344_0 126 | - numpy=1.23.5=py310hd5efca6_0 127 | - numpy-base=1.23.5=py310h8e6c178_0 128 | - oauthlib=3.2.2=pyhd8ed1ab_0 129 | - omegaconf=2.3.0=pyhd8ed1ab_0 130 | - openh264=2.1.1=h4ff587b_0 131 | - openjpeg=2.5.0=hfec8fc6_2 132 | - openpyxl=3.1.2=py310h2372a71_0 133 | - openssl=3.1.1=hd590300_1 134 | - packaging=23.1=pyhd8ed1ab_0 135 | - pandas=1.5.3=py310h1128e8f_0 136 | - pathspec=0.11.1=pyhd8ed1ab_0 137 | - pathtools=0.1.2=py_1 138 | - pillow=9.4.0=py310h023d228_1 139 | - pip=23.0.1=py310h06a4308_0 140 | - platformdirs=2.5.2=py310h06a4308_0 141 | - pluggy=1.0.0=pyhd8ed1ab_5 142 | - pre-commit=3.2.2=pyha770c72_0 143 | - protobuf=3.20.3=py310h6a678d5_0 144 | - psutil=5.9.0=py310h5eee18b_0 145 | - pthread-stubs=0.4=h36c2ea0_1001 146 | - pyasn1=0.4.8=py_0 147 | - pyasn1-modules=0.2.7=py_0 148 | - pycodestyle=2.10.0=pyhd8ed1ab_0 149 | - pycparser=2.21=pyhd3eb1b0_0 150 | - pyflakes=3.0.1=pyhd8ed1ab_0 151 | - pyjwt=2.6.0=pyhd8ed1ab_0 152 | - pyopenssl=23.2.0=pyhd8ed1ab_1 153 | - pysocks=1.7.1=py310h06a4308_0 154 | - pytest=7.3.1=pyhd8ed1ab_0 155 | - python=3.10.12=hd12c33a_0_cpython 156 | - python-dateutil=2.8.2=pyhd8ed1ab_0 157 | - python-dotenv=1.0.0=pyhd8ed1ab_0 158 | - python_abi=3.10=2_cp310 159 | - pytorch=2.0.0=py3.10_cuda11.7_cudnn8.5.0_0 160 | - pytorch-cuda=11.7=h778d358_3 161 | - pytorch-mutex=1.0=cuda 162 | - pytz=2023.3=pyhd8ed1ab_0 163 | - pyu2f=0.1.5=pyhd8ed1ab_0 164 | - pyyaml=6.0=py310h5764c6d_4 165 | - re2=2022.06.01=h27087fc_1 166 | - readline=8.2=h5eee18b_0 167 | - requests=2.28.1=py310h06a4308_1 168 | - requests-oauthlib=1.3.1=pyhd8ed1ab_0 169 | - rsa=4.9=pyhd8ed1ab_0 170 | - safetensors=0.3.1=py310hcb5633a_0 171 | - scikit-learn=1.0.2=py310h1246948_0 172 | - scipy=1.8.1=py310h7612f91_0 173 | - screen=4.8.0=he28a2e2_0 174 | - sentry-sdk=1.20.0=pyhd8ed1ab_0 175 | - setproctitle=1.2.2=py310h5764c6d_2 176 | - setuptools=66.0.0=py310h06a4308_0 177 | - six=1.16.0=pyhd3eb1b0_1 178 | - smmap=3.0.5=pyh44b312d_0 179 | - sqlite=3.41.2=h5eee18b_0 180 | - sympy=1.11.1=pyh04b8f61_3 181 | - tensorboard=2.11.2=pyhd8ed1ab_0 182 | - tensorboard-data-server=0.6.1=py310h600f1e7_4 183 | - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0 184 | - threadpoolctl=3.1.0=pyh8a188c0_0 185 | - timm=0.9.2=pyhd8ed1ab_0 186 | - tk=8.6.12=h1ccaba5_0 
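# Usage sketch: to reproduce this pinned environment, one would typically run
#   conda env create -f environment.yaml
#   conda activate loda
# (a pip-only setup can instead use requirements.txt / requirements-dev.txt)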
187 | - tomli=2.0.1=pyhd8ed1ab_0 188 | - torchaudio=2.0.0=py310_cu117 189 | - torchtriton=2.0.0=py310 190 | - torchvision=0.15.0=py310_cu117 191 | - tqdm=4.65.0=pyhd8ed1ab_1 192 | - typing-extensions=4.5.0=py310h06a4308_0 193 | - typing_extensions=4.5.0=py310h06a4308_0 194 | - tzdata=2023c=h04d1e81_0 195 | - ukkonen=1.0.1=py310hbf28c38_2 196 | - urllib3=1.26.15=py310h06a4308_0 197 | - virtualenv=20.17.1=py310hff52083_0 198 | - wandb=0.15.0=pyhd8ed1ab_0 199 | - werkzeug=2.3.1=pyhd8ed1ab_0 200 | - wheel=0.38.4=py310h06a4308_0 201 | - xorg-libxau=1.0.11=hd590300_0 202 | - xorg-libxdmcp=1.1.3=h7f98852_0 203 | - xz=5.2.10=h5eee18b_1 204 | - yaml=0.2.5=h7f98852_2 205 | - yarl=1.7.2=py310h5764c6d_2 206 | - zipp=3.15.0=pyhd8ed1ab_0 207 | - zlib=1.2.13=hd590300_5 208 | - zstd=1.5.5=hc292b87_0 209 | - pip: 210 | - colorlog==6.7.0 211 | - hydra-colorlog==1.2.0 212 | - mpmath==1.2.1 213 | - pyrootutils==1.0.4 214 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | autopep8 2 | flake8 3 | pre-commit 4 | black 5 | isort 6 | pytest -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | # --------- pytorch --------- # 2 | torch>=2.0.0 3 | torchvision>=0.15.0 4 | 5 | # --------- hydra --------- # 6 | hydra-core==1.3.2 7 | hydra-colorlog==1.2.0 8 | 9 | # --------- loggers --------- # 10 | wandb==0.15.0 11 | 12 | # --------- huggingface --------- # 13 | timm==0.9.2 14 | 15 | # --------- pillow --------- # 16 | pillow==9.4.0 17 | 18 | # --------- others --------- # 19 | scipy==1.8.1 20 | pandas==1.5.3 21 | tensorboard==2.11.2 22 | pyrootutils==1.0.4 23 | tqdm==4.65.0 24 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on koniq10k 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_koniq10k \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_koniq10k.log 2>&1 9 | 10 | # loda on kadid10k 11 | for split_idx in {0..9}; 12 | do 13 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 14 | job=train_loda_kadid10k \ 15 | split_index="${split_idx}" 16 | done >> logs/benchmark_loda_kadid10k.log 2>&1 17 | 18 | # loda on spaq 19 | for split_idx in {0..9}; 20 | do 21 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 22 | job=train_loda_spaq \ 23 | split_index="${split_idx}" 24 | done >> logs/benchmark_loda_spaq.log 2>&1 25 | 26 | # loda on livec 27 | for split_idx in {0..9}; 28 | do 29 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 30 | job=train_loda_livec \ 31 | split_index="${split_idx}" 32 | done >> logs/benchmark_loda_livec.log 2>&1 33 | 34 | # loda on live 35 | for split_idx in {0..9}; 36 | do 37 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 38 | job=train_loda_live \ 39 | split_index="${split_idx}" 40 | done >> logs/benchmark_loda_live.log 2>&1 41 | 42 | # loda on tid2013 43 | for split_idx in {0..9}; 44 | do 45 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 46 | job=train_loda_tid2013 \ 47 | split_index="${split_idx}" 48 | done >> logs/benchmark_loda_tid2013.log 2>&1 49 | 50 | # loda on flive 51 | for split_idx in {0..9}; 52 | do 53 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 54 | 
job=train_loda_flive \ 55 | split_index="${split_idx}" 56 | done >> logs/benchmark_loda_flive.log 2>&1 57 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_eval_all.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on koniq10k 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 6 | job=eval \ 7 | run_group=loda_koniq10k_eval \ 8 | name=loda_koniq10k_eval_split"${split_idx}" \ 9 | split_index="${split_idx}" \ 10 | data=koniq10k \ 11 | load.network_chkpt_path=chkpt/rep/koniq10k/loda_koniq10k_split"${split_idx}".pt 12 | done >> logs/loda_koniq10k_eval.log 2>&1 13 | 14 | # loda on kadid10k 15 | for split_idx in {0..9}; 16 | do 17 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 18 | job=eval \ 19 | run_group=loda_kadid10k_eval \ 20 | name=loda_kadid10k_eval_split"${split_idx}" \ 21 | split_index="${split_idx}" \ 22 | data=kadid10k \ 23 | load.network_chkpt_path=chkpt/rep/kadid10k/loda_kadid10k_split"${split_idx}".pt 24 | done >> logs/loda_kadid10k_eval.log 2>&1 25 | 26 | # loda on livec 27 | for split_idx in {0..9}; 28 | do 29 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 30 | job=eval \ 31 | run_group=loda_livec_eval \ 32 | name=loda_livec_eval_split"${split_idx}" \ 33 | split_index="${split_idx}" \ 34 | data=livec \ 35 | load.network_chkpt_path=chkpt/rep/livec/loda_livec_split"${split_idx}".pt 36 | done >> logs/loda_livec_eval.log 2>&1 37 | 38 | # loda on live 39 | for split_idx in {0..9}; 40 | do 41 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 42 | job=eval \ 43 | run_group=loda_live_eval \ 44 | name=loda_live_eval_split"${split_idx}" \ 45 | split_index="${split_idx}" \ 46 | data=live \ 47 | load.network_chkpt_path=chkpt/rep/live/loda_live_split"${split_idx}".pt 48 | done >> logs/loda_live_eval.log 2>&1 49 | 50 | # loda on spaq 51 | for split_idx in {0..9}; 52 | do 53 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 54 | job=eval \ 55 | run_group=loda_spaq_eval \ 56 | name=loda_spaq_eval_split"${split_idx}" \ 57 | split_index="${split_idx}" \ 58 | data=spaq \ 59 | load.network_chkpt_path=chkpt/rep/spaq/loda_spaq_split"${split_idx}".pt 60 | done >> logs/loda_spaq_eval.log 2>&1 61 | 62 | # loda on tid2013 63 | for split_idx in {0..9}; 64 | do 65 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 66 | job=eval \ 67 | run_group=loda_tid2013_eval \ 68 | name=loda_tid2013_eval_split"${split_idx}" \ 69 | split_index="${split_idx}" \ 70 | data=tid2013 \ 71 | load.network_chkpt_path=chkpt/rep/tid2013/loda_tid2013_split"${split_idx}".pt 72 | done >> logs/loda_tid2013_eval.log 2>&1 73 | 74 | # loda on flive 75 | for split_idx in {0..9}; 76 | do 77 | CUDA_VISIBLE_DEVICES=$1 python src/eval.py \ 78 | job=eval \ 79 | run_group=loda_flive_eval \ 80 | name=loda_flive_eval_split"${split_idx}" \ 81 | split_index="${split_idx}" \ 82 | data=flive \ 83 | load.network_chkpt_path=chkpt/rep/flive/loda_flive_split"${split_idx}".pt 84 | done >> logs/loda_flive_eval.log 2>&1 85 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_flive.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on flive 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_flive \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_flive.log 2>&1 9 |
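# Usage sketch (assumes the repo root as working directory and an existing
# logs/ directory, e.g. created with `mkdir -p logs`):
#   bash scripts/benchmark/benchmark_loda_flive.sh 0
# The single positional argument is forwarded to CUDA_VISIBLE_DEVICES, i.e.
# it selects the GPU that trains all 10 splits; output is appended to logs/.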
-------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_kadid10k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on kadid10k 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_kadid10k \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_kadid10k.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_koniq10k.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on koniq10k 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_koniq10k \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_koniq10k.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_live.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on live 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_live \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_live.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_livec.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on livec 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_livec \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_livec.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_spaq.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on spaq 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_spaq \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_spaq.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/benchmark/benchmark_loda_tid2013.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # loda on tid2013 3 | for split_idx in {0..9}; 4 | do 5 | CUDA_VISIBLE_DEVICES=$1 python src/trainer.py \ 6 | job=train_loda_tid2013 \ 7 | split_index="${split_idx}" 8 | done >> logs/benchmark_loda_tid2013.log 2>&1 9 | -------------------------------------------------------------------------------- /scripts/process_flive.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import pyrootutils 11 | import torchvision 12 | from tqdm import tqdm 13 | 14 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 15 | 16 | from src.utils.dataset_utils import pil_loader 17 | 18 | """ 19 | The FLIVE dataset introduced by: 20 | 21 | Zhenqiang Ying, Haoran Niu, Praful Gupta, Dhruv Mahajan, Deepti Ghadiyaram, Alan Bovik. 22 | "From Patches to Pictures (PaQ-2-PiQ): Mapping the Perceptual Space of Picture Quality." 23 | CVPR2020.
24 | 25 | Reference github: 26 | [1] https://github.com/niu-haoran/FLIVE_Database 27 | [2] https://github.com/baidut/PaQ-2-PiQ 28 | 29 | Image/patch labels are in [1]; please copy the following preparation script from [2] into [1], 30 | because the script in [1] has bugs: 31 | 32 | https://github.com/baidut/PaQ-2-PiQ/blob/master/database_prep.ipynb 33 | 34 | Besides, the patch labels in [1] are not complete. 9 patches from EE371R are in 35 | 36 | https://github.com/baidut/PaQ-2-PiQ/tree/master/database/patches 37 | """ 38 | 39 | 40 | def get_meta_info(): 41 | """ 42 | Only whole images are used. 43 | """ 44 | img_label_file = "data/flive/labels_image.csv" 45 | save_meta_path = "data/meta_info/meta_info_FLIVEDataset.csv" 46 | 47 | all_img_label = pd.read_csv(img_label_file) 48 | 49 | with open(save_meta_path, "w") as sf: 50 | csvwriter = csv.writer(sf) 51 | head = ["img_name", "mos"] 52 | csvwriter.writerow(head) 53 | 54 | # get image info 55 | for i in tqdm(range(all_img_label.shape[0])): 56 | name = all_img_label.loc[i]["name"] 57 | mos = all_img_label.loc[i]["mos"] 58 | 59 | row = [name, mos] 60 | csvwriter.writerow(row) 61 | 62 | 63 | def get_random_splits(seed=3407): 64 | random.seed(seed) 65 | total_num = 39810 66 | all_img_index = list(range(total_num)) 67 | num_splits = 10 68 | save_path = f"data/train_split_info/flive_82_seed{seed}.pkl" 69 | 70 | # ratio = [0.8, 0.2] # train/test 71 | sep_index = int(round(0.8 * total_num)) 72 | 73 | split_info = {} 74 | for i in range(num_splits): 75 | random.shuffle(all_img_index) 76 | split_info[i] = { 77 | "train": all_img_index[:sep_index], 78 | "val": [], 79 | "test": all_img_index[sep_index:], 80 | } 81 | print( 82 | "train num: {} | val num: {} | test num: {}".format( 83 | len(split_info[i]["train"]), 84 | len(split_info[i]["val"]), 85 | len(split_info[i]["test"]), 86 | ) 87 | ) 88 | with open(save_path, "wb") as sf: 89 | pickle.dump(split_info, sf) 90 | 91 | 92 | def downsample_images(down_size=384): 93 | image_path = "data/flive" 94 | meta_info_file = "data/meta_info/meta_info_FLIVEDataset.csv" 95 | save_image_path = "data/flive/flive_384" 96 | 97 | if not os.path.exists(save_image_path): 98 | os.makedirs(save_image_path, exist_ok=True) 99 | 100 | meta_info = pd.read_csv(meta_info_file) 101 | 102 | preprocess = torchvision.transforms.Compose( 103 | [ 104 | torchvision.transforms.Resize( 105 | size=down_size, 106 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 107 | ) 108 | ] 109 | ) 110 | 111 | for i in tqdm(range(meta_info.shape[0])): 112 | img_name = meta_info.loc[i]["img_name"] 113 | img = pil_loader(os.path.join(image_path, img_name)) 114 | resized_img = preprocess(img) 115 | img_name = os.path.splitext(img_name)[0] + ".png" 116 | img_folder = os.path.split(img_name)[0] 117 | img_folder = os.path.join(save_image_path, img_folder) 118 | if not os.path.exists(img_folder): 119 | os.makedirs(img_folder, exist_ok=True) 120 | resized_img.save(os.path.join(save_image_path, img_name)) 121 | 122 | 123 | if __name__ == "__main__": 124 | get_meta_info() 125 | get_random_splits() 126 | downsample_images() 127 | -------------------------------------------------------------------------------- /scripts/process_kadid10k.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import pyrootutils 11 | import torchvision 12 | from
tqdm import tqdm 13 | 14 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 15 | 16 | from src.utils.dataset_utils import pil_loader 17 | 18 | 19 | def get_meta_info(): 20 | info_file = "data/kadid10k/dmos.csv" 21 | 22 | save_meta_path = "data/meta_info/meta_info_KADID10kDataset.csv" 23 | with open(info_file, "r") as f, open(save_meta_path, "w+") as sf: 24 | csvreader = csv.reader(f) 25 | head = next(csvreader) 26 | print(head) 27 | 28 | new_head = ["ref_name", "dist_name", "dmos", "std"] 29 | csvwriter = csv.writer(sf) 30 | csvwriter.writerow(new_head) 31 | for _, row in enumerate(csvreader): 32 | dis_name = row[0] 33 | ref_name = row[1] 34 | dmos = row[2] 35 | std = row[3] 36 | csvwriter.writerow([ref_name, dis_name, dmos, std]) 37 | 38 | 39 | def get_random_splits(seed=3407): 40 | random.seed(seed) 41 | meta_info_file = "data/meta_info/meta_info_KADID10kDataset.csv" 42 | save_path = f"data/train_split_info/kadid10k_82_seed{seed}.pkl" 43 | ratio = 0.8 44 | 45 | meta_info = pd.read_csv(meta_info_file) 46 | 47 | ref_img_list = sorted( 48 | list(set(meta_info["ref_name"].tolist())) 49 | ) # set iteration order is nondeterministic; sort it for reproducible splits 50 | ref_img_num = len(ref_img_list) 51 | num_splits = 10 52 | train_num = int(round(ratio * ref_img_num)) 53 | 54 | split_info = {} 55 | for i in range(num_splits): 56 | split_info[i] = {"train": [], "val": [], "test": []} 57 | 58 | for i in range(num_splits): 59 | random.shuffle(ref_img_list) 60 | train_ref_img_names = ref_img_list[:train_num] 61 | for j in range(meta_info.shape[0]): 62 | tmp_ref_name = meta_info.loc[j]["ref_name"] 63 | if tmp_ref_name in train_ref_img_names: 64 | split_info[i]["train"].append(j) 65 | else: 66 | split_info[i]["test"].append(j) 67 | print( 68 | meta_info.shape[0], 69 | len(split_info[i]["train"]), 70 | len(split_info[i]["test"]), 71 | ) 72 | with open(save_path, "wb") as sf: 73 | pickle.dump(split_info, sf) 74 | 75 | 76 | def downsample_images(down_size=384): 77 | image_path = "data/kadid10k/images" 78 | meta_info_file = "data/meta_info/meta_info_KADID10kDataset.csv" 79 | save_image_path = "data/kadid10k/images_384" 80 | 81 | if not os.path.exists(save_image_path): 82 | os.makedirs(save_image_path, exist_ok=True) 83 | 84 | meta_info = pd.read_csv(meta_info_file) 85 | 86 | preprocess = torchvision.transforms.Compose( 87 | [ 88 | torchvision.transforms.Resize( 89 | size=down_size, 90 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 91 | ) 92 | ] 93 | ) 94 | 95 | for i in tqdm(range(meta_info.shape[0])): 96 | img_name = meta_info.loc[i]["dist_name"] 97 | img = pil_loader(os.path.join(image_path, img_name)) 98 | resized_img = preprocess(img) 99 | img_name = os.path.splitext(img_name)[0] + ".png" 100 | resized_img.save(os.path.join(save_image_path, img_name)) 101 | 102 | 103 | if __name__ == "__main__": 104 | get_meta_info() 105 | get_random_splits() 106 | downsample_images() 107 | -------------------------------------------------------------------------------- /scripts/process_koniq10k.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import pyrootutils 11 | import torchvision 12 | from tqdm import tqdm 13 | 14 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 15 | 16 | from src.utils.dataset_utils import pil_loader 17 | 18 | 19 | def
get_meta_info(): 20 | """ 21 | Train/Val/Test split file from official github: 22 | https://github.com/subpic/koniq/blob/master/metadata/koniq10k_distributions_sets.csv 23 | """ 24 | info_file = "data/koniq10k/koniq10k_scores_and_distributions.csv" 25 | 26 | save_meta_path = "data/meta_info/meta_info_KonIQ10kDataset.csv" 27 | with open(info_file, "r") as f, open(save_meta_path, "w+") as sf: 28 | csvreader = csv.reader(f) 29 | _ = next(csvreader) 30 | 31 | csvwriter = csv.writer(sf) 32 | new_head = [ 33 | "img_name", 34 | "mos", 35 | "std", 36 | "mos_zscore", 37 | "c1", 38 | "c2", 39 | "c3", 40 | "c4", 41 | "c5", 42 | "c_total", 43 | ] 44 | csvwriter.writerow(new_head) 45 | for _, row in enumerate(csvreader): 46 | new_row = [row[0]] + row[7:10] + row[1:7] 47 | csvwriter.writerow(new_row) 48 | 49 | 50 | def get_random_splits(seed=3407): 51 | """ 52 | Use 10 random splits, as in most papers. 53 | """ 54 | random.seed(seed) 55 | total_num = 10073 56 | all_img_index = list(range(total_num)) 57 | num_splits = 10 58 | 59 | # ratio = [0.8, 0.2] # train/test 60 | train_index = int(round(0.8 * total_num)) 61 | 62 | save_path = f"./data/train_split_info/koniq10k_82_seed{seed}.pkl" 63 | split_info = {} 64 | for i in range(num_splits): 65 | random.shuffle(all_img_index) 66 | split_info[i] = { 67 | "train": all_img_index[:train_index], 68 | "val": [], 69 | "test": all_img_index[train_index:], 70 | } 71 | print( 72 | "train num: {} | val num: {} | test num: {}".format( 73 | len(split_info[i]["train"]), 74 | len(split_info[i]["val"]), 75 | len(split_info[i]["test"]), 76 | ) 77 | ) 78 | with open(save_path, "wb") as sf: 79 | pickle.dump(split_info, sf) 80 | 81 | 82 | def downsample_images(down_size=384): 83 | image_path = "data/koniq10k/1024x768" 84 | meta_info_file = "data/meta_info/meta_info_KonIQ10kDataset.csv" 85 | save_image_path = "data/koniq10k/512x384" 86 | 87 | if not os.path.exists(save_image_path): 88 | os.makedirs(save_image_path, exist_ok=True) 89 | 90 | meta_info = pd.read_csv(meta_info_file) 91 | 92 | preprocess = torchvision.transforms.Compose( 93 | [torchvision.transforms.Resize(size=down_size)] 94 | ) 95 | 96 | for i in tqdm(range(meta_info.shape[0])): 97 | img_name = meta_info.loc[i]["img_name"] 98 | img = pil_loader(os.path.join(image_path, img_name)) 99 | resized_img = preprocess(img) 100 | img_name = os.path.splitext(img_name)[0] + ".png" 101 | resized_img.save(os.path.join(save_image_path, img_name)) 102 | 103 | 104 | if __name__ == "__main__": 105 | get_meta_info() 106 | get_random_splits() 107 | downsample_images() 108 | -------------------------------------------------------------------------------- /scripts/process_live.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import pyrootutils 11 | import scipy.io as sio 12 | import torchvision 13 | from tqdm import tqdm 14 | 15 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 16 | 17 | from src.utils.dataset_utils import pil_loader 18 | 19 | 20 | def get_meta_info(): 21 | root_dir = "data/live_iqa" 22 | 23 | dmos = sio.loadmat( 24 | os.path.join(root_dir, "dmos_realigned.mat") 25 | ) # difference of mos: test - ref.
lower is better 26 | mos = dmos["dmos_new"][0] 27 | org_flag = dmos["orgs"][0] 28 | 29 | refnames = sio.loadmat(os.path.join(root_dir, "refnames_all.mat")) 30 | refnames = refnames["refnames_all"][0] 31 | 32 | sub_folders = ( 33 | ["jp2k"] * 227 34 | + ["jpeg"] * 233 35 | + ["wn"] * 174 36 | + ["gblur"] * 174 37 | + ["fastfading"] * 174 38 | ) 39 | sub_indexes = list(range(1, 228)) + list(range(1, 234)) + list(range(1, 175)) * 3 40 | 41 | save_meta_path = "data/meta_info/meta_info_LIVEIQADataset.csv" 42 | with open(save_meta_path, "w") as f: 43 | csvwriter = csv.writer(f) 44 | header = ["ref_name", "dist_name", "mos"] 45 | csvwriter.writerow(header) 46 | for i in range(len(sub_folders)): 47 | ref_name = f"refimgs/{refnames[i][0]}" 48 | dis_name = f"{sub_folders[i]}/img{sub_indexes[i]}.bmp" 49 | tmpmos = mos[i] 50 | if org_flag[i] != 1: 51 | csvwriter.writerow([ref_name, dis_name, tmpmos]) 52 | 53 | 54 | def get_random_splits(seed=3407): 55 | random.seed(seed) 56 | meta_info_file = "data/meta_info/meta_info_LIVEIQADataset.csv" 57 | save_path = f"data/train_split_info/live_82_seed{seed}.pkl" 58 | ratio = 0.8 59 | 60 | meta_info = pd.read_csv(meta_info_file) 61 | 62 | ref_img_list = list(set(meta_info["ref_name"].tolist())) 63 | ref_img_num = len(ref_img_list) 64 | num_splits = 10 65 | train_num = int(ratio * ref_img_num) 66 | 67 | split_info = {} 68 | for i in range(num_splits): 69 | split_info[i] = {"train": [], "val": [], "test": []} 70 | 71 | for i in range(num_splits): 72 | random.shuffle(ref_img_list) 73 | train_ref_img_names = ref_img_list[:train_num] 74 | for j in range(meta_info.shape[0]): 75 | tmp_ref_name = meta_info.loc[j]["ref_name"] 76 | if tmp_ref_name in train_ref_img_names: 77 | split_info[i]["train"].append(j) 78 | else: 79 | split_info[i]["test"].append(j) 80 | print( 81 | meta_info.shape[0], len(split_info[i]["train"]), len(split_info[i]["test"]) 82 | ) 83 | with open(save_path, "wb") as sf: 84 | pickle.dump(split_info, sf) 85 | 86 | 87 | def downsample_images(down_size=384): 88 | image_path = "data/live_iqa" 89 | meta_info_file = "data/meta_info/meta_info_LIVEIQADataset.csv" 90 | save_image_path = "data/live_iqa/images_384" 91 | 92 | if not os.path.exists(save_image_path): 93 | os.makedirs(save_image_path, exist_ok=True) 94 | 95 | meta_info = pd.read_csv(meta_info_file) 96 | 97 | preprocess = torchvision.transforms.Compose( 98 | [ 99 | torchvision.transforms.Resize( 100 | size=down_size, 101 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 102 | ) 103 | ] 104 | ) 105 | 106 | for i in tqdm(range(meta_info.shape[0])): 107 | img_name = meta_info.loc[i]["dist_name"] 108 | img = pil_loader(os.path.join(image_path, img_name)) 109 | resized_img = preprocess(img) 110 | img_name = os.path.splitext(img_name)[0] + ".bmp" 111 | img_folder = os.path.split(img_name)[0] 112 | img_folder = os.path.join(save_image_path, img_folder) 113 | if not os.path.exists(img_folder): 114 | os.makedirs(img_folder, exist_ok=True) 115 | resized_img.save(os.path.join(save_image_path, img_name)) 116 | 117 | 118 | if __name__ == "__main__": 119 | get_meta_info() 120 | get_random_splits() 121 | downsample_images() 122 | -------------------------------------------------------------------------------- /scripts/process_livechallenge.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | 
import pyrootutils 11 | import scipy.io as sio 12 | import torchvision 13 | from tqdm import tqdm 14 | 15 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 16 | 17 | from src.utils.dataset_utils import pil_loader 18 | 19 | 20 | def get_meta_info(): 21 | root_dir = "data/LIVEC" 22 | names = sio.loadmat(os.path.join(root_dir, "Data", "AllImages_release.mat")) 23 | mos_labels = sio.loadmat(os.path.join(root_dir, "Data", "AllMOS_release.mat")) 24 | mos_std = sio.loadmat(os.path.join(root_dir, "Data", "AllStdDev_release.mat")) 25 | 26 | img_names = names["AllImages_release"] 27 | mos_labels = mos_labels["AllMOS_release"][0] 28 | mos_std = mos_std["AllStdDev_release"][0] 29 | 30 | save_meta_path = "data/meta_info/meta_info_LIVEChallengeDataset.csv" 31 | with open(save_meta_path, "w") as f: 32 | csvwriter = csv.writer(f) 33 | header = ["img_name", "mos", "std"] 34 | csvwriter.writerow(header) 35 | for idx, name_item in enumerate(img_names): 36 | img_name = name_item[0][0] 37 | mos = mos_labels[idx] 38 | std = mos_std[idx] 39 | csvwriter.writerow([img_name, mos, std]) 40 | 41 | 42 | def get_random_splits(seed=3407): 43 | random.seed(seed) 44 | total_num = 1162 45 | all_img_index = list(range(total_num)) 46 | num_splits = 10 47 | 48 | # ratio = [0.8, 0.2] # train/test 49 | train_index = int(round(0.8 * total_num)) 50 | 51 | save_path = f"data/train_split_info/livechallenge_82_seed{seed}.pkl" 52 | split_info = {} 53 | for i in range(num_splits): 54 | random.shuffle(all_img_index) 55 | split_info[i] = { 56 | "train": all_img_index[:train_index], 57 | "val": [], 58 | "test": all_img_index[train_index:], 59 | } 60 | print( 61 | "train num: {} | val num: {} | test num: {}".format( 62 | len(split_info[i]["train"]), 63 | len(split_info[i]["val"]), 64 | len(split_info[i]["test"]), 65 | ) 66 | ) 67 | with open(save_path, "wb") as sf: 68 | pickle.dump(split_info, sf) 69 | 70 | 71 | def downsample_images(down_size=384): 72 | image_path = "data/LIVEC/Images" 73 | meta_info_file = "data/meta_info/meta_info_LIVEChallengeDataset.csv" 74 | save_image_path = "data/LIVEC/Images_384" 75 | 76 | if not os.path.exists(save_image_path): 77 | os.makedirs(save_image_path, exist_ok=True) 78 | 79 | meta_info = pd.read_csv(meta_info_file) 80 | # remove first 7 training images as previous works 81 | # https://github.com/chaofengc/IQA-PyTorch/blob/fe95923f9c48188c65666930048597b45c9046de/pyiqa/data/livechallenge_dataset.py#L38 82 | meta_info = meta_info[7:].reset_index() 83 | 84 | preprocess = torchvision.transforms.Compose( 85 | [ 86 | torchvision.transforms.Resize( 87 | size=down_size, 88 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 89 | ) 90 | ] 91 | ) 92 | 93 | for i in tqdm(range(meta_info.shape[0])): 94 | img_name = meta_info.loc[i]["img_name"] 95 | img = pil_loader(os.path.join(image_path, img_name)) 96 | resized_img = preprocess(img) 97 | img_name = os.path.splitext(img_name)[0] + ".bmp" 98 | resized_img.save(os.path.join(save_image_path, img_name)) 99 | 100 | 101 | if __name__ == "__main__": 102 | get_meta_info() 103 | get_random_splits() 104 | downsample_images() 105 | -------------------------------------------------------------------------------- /scripts/process_spaq.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import torchvision 11 | from PIL 
import Image, ImageOps 12 | from tqdm import tqdm 13 | 14 | 15 | def get_meta_info(): 16 | mos_label_file = "data/spaq/Annotations/MOS and Image attribute scores.xlsx" 17 | scene_label_file = "data/spaq/Annotations/Scene category labels.xlsx" 18 | exif_label_file = "data/spaq/Annotations/EXIF_tags.xlsx" 19 | 20 | mos_label = pd.read_excel(mos_label_file) 21 | scene_label = pd.read_excel(scene_label_file) 22 | exif_label = pd.read_excel(exif_label_file) 23 | 24 | new_head = ( 25 | mos_label.keys().tolist() 26 | + scene_label.keys().tolist()[1:] 27 | + exif_label.keys().tolist()[1:] 28 | ) 29 | new_head[0] = "img_name" 30 | new_head[1] = "mos" 31 | new_head[-2] = "Time0" 32 | new_head[-1] = "Time1" 33 | 34 | save_meta_path = "data/meta_info/meta_info_SPAQDataset.csv" 35 | with open(save_meta_path, "w+") as sf: 36 | csvwriter = csv.writer(sf) 37 | csvwriter.writerow(new_head) 38 | for ridx in range(mos_label.shape[0]): 39 | mos_row = mos_label.loc[ridx].tolist() 40 | scene_row = scene_label.loc[ridx].tolist() 41 | exif_row = exif_label.loc[ridx].tolist() 42 | # print(mos_row, scene_row, exif_row) 43 | assert mos_row[0] == scene_row[0] == exif_row[0] 44 | row_label = mos_row + scene_row[1:] + exif_row[1:] 45 | csvwriter.writerow(row_label) 46 | 47 | 48 | def get_random_splits(seed=3407): 49 | random.seed(seed) 50 | total_num = 11125 51 | all_img_index = list(range(total_num)) 52 | num_splits = 10 53 | save_path = f"data/train_split_info/spaq_82_seed{seed}.pkl" 54 | 55 | # ratio = [0.8, 0.2] # train/test 56 | sep_index = int(round(0.8 * total_num)) 57 | 58 | split_info = {} 59 | for i in range(num_splits): 60 | random.shuffle(all_img_index) 61 | split_info[i] = { 62 | "train": all_img_index[:sep_index], 63 | "val": [], 64 | "test": all_img_index[sep_index:], 65 | } 66 | print( 67 | "train num: {} | val num: {} | test num: {}".format( 68 | len(split_info[i]["train"]), 69 | len(split_info[i]["val"]), 70 | len(split_info[i]["test"]), 71 | ) 72 | ) 73 | with open(save_path, "wb") as sf: 74 | pickle.dump(split_info, sf) 75 | 76 | 77 | def downsample_images(down_size=384): 78 | spaq_image_path = "data/spaq/TestImage" 79 | meta_info_file = "data/meta_info/meta_info_SPAQDataset.csv" 80 | save_image_path = "data/spaq/TestImage_384" 81 | 82 | if not os.path.exists(save_image_path): 83 | os.makedirs(save_image_path, exist_ok=True) 84 | 85 | meta_info = pd.read_csv(meta_info_file) 86 | 87 | preprocess = torchvision.transforms.Compose( 88 | [torchvision.transforms.Resize(size=down_size)] 89 | ) 90 | 91 | for i in tqdm(range(meta_info.shape[0])): 92 | img_name = meta_info.loc[i]["img_name"] 93 | # JPEG images may carry an EXIF orientation tag that PIL.Image.open does not apply 94 | # https://github.com/python-pillow/Pillow/issues/4703 95 | img = Image.open(os.path.join(spaq_image_path, img_name)) 96 | # There seems to be a bug in this function in pillow 9.4.0, 97 | # which has been fixed in pillow 9.5.0: 98 | # https://github.com/python-pillow/Pillow/pull/6890 99 | # Unfortunately pillow 9.5.0 is not compatible with the current env. 100 | # At the time of writing, conda cannot resolve an env with both 101 | # pillow 9.5.0 and a compatible torchvision, so if this issue still 102 | # remains, you can keep the code below (though a few images will 103 | # still be rotated) or create a new virtual env with pip and then 104 | # remove the try...except. 105 | try: 106 | img_t = ImageOps.exif_transpose(img) 107 | except Exception: 108 | img_t = img 109 | resized_img = preprocess(img_t) 110 | img_name = os.path.splitext(img_name)[0] + ".png" 111 |
resized_img.save(os.path.join(save_image_path, img_name)) 112 | 113 | 114 | if __name__ == "__main__": 115 | get_meta_info() 116 | get_random_splits() 117 | downsample_images() 118 | -------------------------------------------------------------------------------- /scripts/process_tid2013.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import csv 5 | import os 6 | import pickle 7 | import random 8 | 9 | import pandas as pd 10 | import pyrootutils 11 | import torchvision 12 | from tqdm import tqdm 13 | 14 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 15 | 16 | from src.utils.dataset_utils import pil_loader 17 | 18 | 19 | def get_meta_info(): 20 | root_dir = "data/tid2013" 21 | save_meta_path = "data/meta_info/meta_info_TID2013Dataset.csv" 22 | 23 | mos_file = os.path.join(root_dir, "mos_with_names.txt") 24 | std_file = os.path.join(root_dir, "mos_std.txt") 25 | 26 | mos_names = [x.strip().split() for x in open(mos_file).readlines()] 27 | std = [x.strip() for x in open(std_file).readlines()] 28 | 29 | with open(save_meta_path, "w") as f: 30 | csvwriter = csv.writer(f) 31 | header = ["ref_name", "dist_name", "mos", "std"] 32 | csvwriter.writerow(header) 33 | for idx, ((mos, name), std) in enumerate(zip(mos_names, std)): 34 | ref_name = f"I{name[1:3]}.BMP" 35 | ref_name = ref_name.replace("I25.BMP", "i25.bmp") 36 | img_path = os.path.join(root_dir, "distorted_images", name) 37 | if not os.path.exists(img_path): 38 | name = name.replace("i", "I") 39 | csvwriter.writerow([ref_name, name, mos, std]) 40 | 41 | 42 | def get_random_splits(seed=3407): 43 | random.seed(seed) 44 | meta_info_file = "data/meta_info/meta_info_TID2013Dataset.csv" 45 | save_path = f"data/train_split_info/tid2013_82_seed{seed}.pkl" 46 | ratio = 0.8 47 | 48 | meta_info = pd.read_csv(meta_info_file) 49 | 50 | ref_img_list = list(set(meta_info["ref_name"].tolist())) 51 | ref_img_num = len(ref_img_list) 52 | num_splits = 10 53 | train_num = int(ratio * ref_img_num) 54 | 55 | split_info = {} 56 | for i in range(num_splits): 57 | split_info[i] = {"train": [], "val": [], "test": []} 58 | 59 | for i in range(num_splits): 60 | random.shuffle(ref_img_list) 61 | train_ref_img_names = ref_img_list[:train_num] 62 | for j in range(meta_info.shape[0]): 63 | tmp_ref_name = meta_info.loc[j]["ref_name"] 64 | if tmp_ref_name in train_ref_img_names: 65 | split_info[i]["train"].append(j) 66 | else: 67 | split_info[i]["test"].append(j) 68 | print( 69 | meta_info.shape[0], len(split_info[i]["train"]), len(split_info[i]["test"]) 70 | ) 71 | with open(save_path, "wb") as sf: 72 | pickle.dump(split_info, sf) 73 | 74 | 75 | def downsample_images(down_size=384): 76 | image_path = "data/tid2013/distorted_images" 77 | meta_info_file = "data/meta_info/meta_info_TID2013Dataset.csv" 78 | save_image_path = "data/tid2013/distorted_images_384" 79 | 80 | if not os.path.exists(save_image_path): 81 | os.makedirs(save_image_path, exist_ok=True) 82 | 83 | meta_info = pd.read_csv(meta_info_file) 84 | 85 | preprocess = torchvision.transforms.Compose( 86 | [ 87 | torchvision.transforms.Resize( 88 | size=down_size, 89 | interpolation=torchvision.transforms.InterpolationMode.BICUBIC, 90 | ) 91 | ] 92 | ) 93 | 94 | for i in tqdm(range(meta_info.shape[0])): 95 | img_name = meta_info.loc[i]["dist_name"] 96 | img = pil_loader(os.path.join(image_path, img_name)) 97 | resized_img = preprocess(img) 98 | img_name = 
os.path.splitext(img_name)[0] + ".bmp" 99 | resized_img.save(os.path.join(save_image_path, img_name)) 100 | 101 | 102 | if __name__ == "__main__": 103 | get_meta_info() 104 | get_random_splits() 105 | downsample_images() 106 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length=127 3 | exclude= 4 | __pycache__ 5 | .venv 6 | .pytest_cache 7 | .github 8 | .vscode 9 | .idea 10 | ignore= 11 | # Too many leading '#' for block comment (E266) 12 | # module level import not at top of file (E402) 13 | E266,W503,E402,E203 14 | 15 | [isort] 16 | known_third_party=wandb 17 | -------------------------------------------------------------------------------- /src/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeosXu/LoDa/82304c20c34c1b5bd45f27bd7ab6e9104a285152/src/__init__.py -------------------------------------------------------------------------------- /src/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .dataloader import create_dataloader 2 | from .dataloader_mode import DataloaderMode 3 | 4 | __all__ = ["DataloaderMode", "create_dataloader"] 5 | -------------------------------------------------------------------------------- /src/dataset/dataloader.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torchvision 4 | from torch.utils.data import DataLoader 5 | from torch.utils.data.distributed import DistributedSampler 6 | 7 | from .dataloader_mode import DataloaderMode 8 | from .flive_dataset import FLIVE_Dataset 9 | from .kadid10k_dataset import KADID10k_Dataset 10 | from .koniq10k_dataset import KonIQ10k_Dataset 11 | from .live_dataset import LIVEDataset 12 | from .livechallenge_dataset import LIVEChallengeDataset 13 | from .spaq_dataset import SPAQ_Dataset 14 | from .tid2013_dataset import TID2013Dataset 15 | 16 | 17 | def get_transforms(cfg, mode): 18 | if ( 19 | cfg.data.name == "livec" 20 | or cfg.data.name == "koniq10k" 21 | or cfg.data.name == "spaq" 22 | or cfg.data.name == "flive" 23 | or cfg.data.name == "kadid10k" 24 | or cfg.data.name == "live" 25 | or cfg.data.name == "tid2013" 26 | ): 27 | if mode is DataloaderMode.train: 28 | transforms = torchvision.transforms.Compose( 29 | [ 30 | torchvision.transforms.RandomHorizontalFlip(), 31 | torchvision.transforms.RandomVerticalFlip(), 32 | torchvision.transforms.RandomCrop(size=cfg.data.patch_size), 33 | torchvision.transforms.ToTensor(), 34 | torchvision.transforms.Normalize( 35 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 36 | ), 37 | ] 38 | ) 39 | elif mode is DataloaderMode.val or mode is DataloaderMode.test: 40 | transforms = torchvision.transforms.Compose( 41 | [ 42 | torchvision.transforms.RandomCrop(size=cfg.data.patch_size), 43 | torchvision.transforms.ToTensor(), 44 | torchvision.transforms.Normalize( 45 | mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225) 46 | ), 47 | ] 48 | ) 49 | else: 50 | raise ValueError(f"invalid dataloader mode {mode}") 51 | else: 52 | raise ValueError(f"invalid dataset name {cfg.data.name}") 53 | 54 | return transforms 55 | 56 | 57 | def get_dataset(cfg, mode, split_index=0): 58 | transforms = get_transforms(cfg=cfg, mode=mode) 59 | 60 | # prepare data index 61 | split_index = cfg.split_index 62 | with 
open(cfg.data.train_test_split_file, "rb") as f: 63 | split_idx = pickle.load(f) 64 | train_idx = split_idx[split_index]["train"] 65 | val_idx = split_idx[split_index]["val"] 66 | test_idx = split_idx[split_index]["test"] 67 | 68 | if mode is DataloaderMode.train: 69 | img_idx = train_idx 70 | elif mode is DataloaderMode.val: 71 | img_idx = val_idx 72 | elif mode is DataloaderMode.test: 73 | img_idx = test_idx 74 | else: 75 | raise ValueError(f"invalid dataloader mode {mode}") 76 | 77 | if cfg.data.name == "koniq10k": 78 | dataset = KonIQ10k_Dataset( 79 | cfg=cfg, index=img_idx, transform=transforms, mode=mode 80 | ) 81 | elif cfg.data.name == "livec": 82 | dataset = LIVEChallengeDataset( 83 | cfg=cfg, index=img_idx, transform=transforms, mode=mode 84 | ) 85 | elif cfg.data.name == "spaq": 86 | dataset = SPAQ_Dataset(cfg=cfg, index=img_idx, transform=transforms, mode=mode) 87 | elif cfg.data.name == "flive": 88 | dataset = FLIVE_Dataset(cfg=cfg, index=img_idx, transform=transforms, mode=mode) 89 | elif cfg.data.name == "kadid10k": 90 | dataset = KADID10k_Dataset( 91 | cfg=cfg, index=img_idx, transform=transforms, mode=mode 92 | ) 93 | elif cfg.data.name == "live": 94 | dataset = LIVEDataset(cfg=cfg, index=img_idx, transform=transforms, mode=mode) 95 | elif cfg.data.name == "tid2013": 96 | dataset = TID2013Dataset( 97 | cfg=cfg, index=img_idx, transform=transforms, mode=mode 98 | ) 99 | else: 100 | raise ValueError(f"invalid dataset name {cfg.data.name}") 101 | 102 | return dataset 103 | 104 | 105 | def create_dataloader(cfg, mode, rank, split_index=0): 106 | data_loader = DataLoader 107 | dataset = get_dataset(cfg=cfg, mode=mode, split_index=split_index) 108 | train_use_shuffle = True 109 | sampler = None 110 | if ( 111 | cfg.dist.device == "cuda" 112 | and cfg.dist.gpus != 0 113 | and cfg.data.divide_dataset_per_gpu 114 | ): 115 | sampler = DistributedSampler(dataset, cfg.dist.gpus, rank) 116 | train_use_shuffle = False 117 | if mode is DataloaderMode.train: 118 | return ( 119 | data_loader( 120 | dataset=dataset, 121 | batch_size=cfg.train.batch_size, 122 | shuffle=train_use_shuffle, 123 | sampler=sampler, 124 | num_workers=cfg.train.num_workers, 125 | pin_memory=True, 126 | drop_last=True, 127 | ), 128 | sampler, 129 | ) 130 | elif mode is DataloaderMode.test: 131 | return ( 132 | data_loader( 133 | dataset=dataset, 134 | batch_size=cfg.test.batch_size, 135 | shuffle=False, 136 | sampler=sampler, 137 | num_workers=cfg.test.num_workers, 138 | pin_memory=True, 139 | drop_last=False, 140 | ), 141 | sampler, 142 | ) 143 | else: 144 | raise ValueError(f"invalid dataloader mode {mode}") 145 | -------------------------------------------------------------------------------- /src/dataset/dataloader_mode.py: -------------------------------------------------------------------------------- 1 | from enum import Enum, auto 2 | 3 | 4 | class DataloaderMode(Enum): 5 | train = auto() 6 | val = auto() 7 | test = auto() 8 | inference = auto() 9 | -------------------------------------------------------------------------------- /src/dataset/flive_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class FLIVE_Dataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 
| 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["img_name"] 28 | img_name = os.path.splitext(img_name)[0] + ".png" 29 | img_path = os.path.join("flive_384", img_name) 30 | label = meta_info.loc[idx]["mos"] 31 | for _ in range(patch_num): 32 | sample.append((img_path, label)) 33 | 34 | self.samples = sample 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | 42 | Returns: 43 | tuple: (sample, target) where target is class_index of the target class. 44 | """ 45 | path, target = self.samples[index] 46 | img = pil_loader(os.path.join(self.root, path)) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | # if there are more than one image or more than one target 51 | # can organize it as 52 | # return [img1, img2], [targe1, target2] 53 | return img, target 54 | 55 | def __len__(self): 56 | length = len(self.samples) 57 | return length 58 | -------------------------------------------------------------------------------- /src/dataset/kadid10k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class KADID10k_Dataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["dist_name"] 28 | img_name = os.path.splitext(img_name)[0] + ".png" 29 | img_path = os.path.join("images_384", img_name) 30 | label = meta_info.loc[idx]["dmos"] 31 | for _ in range(patch_num): 32 | sample.append((img_path, label)) 33 | 34 | self.samples = sample 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | 42 | Returns: 43 | tuple: (sample, target) where target is class_index of the target class. 
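(Note: for this IQA dataset the returned target is the dmos value read from the meta file, i.e. a continuous quality score rather than a class index.)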
44 | """ 45 | path, target = self.samples[index] 46 | img = pil_loader(os.path.join(self.root, path)) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | # if there are more than one image or more than one target 51 | # can organize it as 52 | # return [img1, img2], [targe1, target2] 53 | return img, target 54 | 55 | def __len__(self): 56 | length = len(self.samples) 57 | return length 58 | -------------------------------------------------------------------------------- /src/dataset/koniq10k_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class KonIQ10k_Dataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["img_name"] 28 | img_name = os.path.splitext(img_name)[0] + ".png" 29 | img_path = os.path.join("512x384", img_name) 30 | label = meta_info.loc[idx]["mos"] 31 | for _ in range(patch_num): 32 | sample.append((img_path, label)) 33 | 34 | self.samples = sample 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | 42 | Returns: 43 | tuple: (sample, target) where target is class_index of the target class. 
44 | """ 45 | path, target = self.samples[index] 46 | img = pil_loader(os.path.join(self.root, path)) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | # if there are more than one image or more than one target 51 | # can organize it as 52 | # return [img1, img2], [targe1, target2] 53 | return img, target 54 | 55 | def __len__(self): 56 | length = len(self.samples) 57 | return length 58 | -------------------------------------------------------------------------------- /src/dataset/live_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class LIVEDataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["dist_name"] 28 | img_name = os.path.splitext(img_name)[0] + ".bmp" 29 | img_path = os.path.join("images_384", img_name) 30 | label = meta_info.loc[idx]["mos"] 31 | for _ in range(patch_num): 32 | sample.append((img_path, label)) 33 | 34 | self.samples = sample 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | 42 | Returns: 43 | tuple: (sample, target) where target is class_index of the target class. 
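(Note: here the target is the realigned DMOS taken from the meta file, where lower values mean better quality; it is a continuous score, not a class index.)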
44 | """ 45 | path, target = self.samples[index] 46 | img = pil_loader(os.path.join(self.root, path)) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | # if there is more than one image or more than one target, 51 | # they can be organized as 52 | # return [img1, img2], [target1, target2] 53 | return img, target 54 | 55 | def __len__(self): 56 | length = len(self.samples) 57 | return length 58 | -------------------------------------------------------------------------------- /src/dataset/livechallenge_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class LIVEChallengeDataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | # remove the first 7 training images, as in previous works 16 | # https://github.com/chaofengc/IQA-PyTorch/blob/fe95923f9c48188c65666930048597b45c9046de/pyiqa/data/livechallenge_dataset.py#L38 17 | meta_info = meta_info[7:].reset_index() 18 | 19 | if mode is DataloaderMode.train: 20 | patch_num = cfg.train.patch_num 21 | elif mode is DataloaderMode.val: 22 | patch_num = cfg.val.patch_num 23 | elif mode is DataloaderMode.test: 24 | patch_num = cfg.test.patch_num 25 | else: 26 | raise ValueError(f"invalid dataloader mode {mode}") 27 | 28 | sample = [] 29 | for idx in index: 30 | img_name = meta_info.loc[idx]["img_name"] 31 | img_path = os.path.join("Images", img_name) 32 | label = meta_info.loc[idx]["mos"] 33 | for _ in range(patch_num): 34 | sample.append((img_path, label)) 35 | 36 | self.samples = sample 37 | self.transform = transform 38 | 39 | def __getitem__(self, index): 40 | """ 41 | Args: 42 | index (int): Index 43 | 44 | Returns: 45 | tuple: (sample, target) where target is the mos score of the image, not a class index.
46 | """ 47 | path, target = self.samples[index] 48 | img = pil_loader(os.path.join(self.root, path)) 49 | if self.transform is not None: 50 | img = self.transform(img) 51 | 52 | return img, target 53 | 54 | def __len__(self): 55 | length = len(self.samples) 56 | return length 57 | -------------------------------------------------------------------------------- /src/dataset/spaq_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class SPAQ_Dataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["img_name"] 28 | img_name = os.path.splitext(img_name)[0] + ".png" 29 | img_path = os.path.join("TestImage_384", img_name) 30 | label = meta_info.loc[idx]["mos"] 31 | for _ in range(patch_num): 32 | sample.append((img_path, label)) 33 | 34 | self.samples = sample 35 | self.transform = transform 36 | 37 | def __getitem__(self, index): 38 | """ 39 | Args: 40 | index (int): Index 41 | 42 | Returns: 43 | tuple: (sample, target) where target is class_index of the target class. 44 | """ 45 | path, target = self.samples[index] 46 | img = pil_loader(os.path.join(self.root, path)) 47 | if self.transform is not None: 48 | img = self.transform(img) 49 | 50 | # if there are more than one image or more than one target 51 | # can organize it as 52 | # return [img1, img2], [targe1, target2] 53 | return img, target 54 | 55 | def __len__(self): 56 | length = len(self.samples) 57 | return length 58 | -------------------------------------------------------------------------------- /src/dataset/tid2013_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import pandas as pd 4 | import torch.utils.data as data 5 | 6 | from src.utils.dataset_utils import pil_loader 7 | 8 | from .dataloader_mode import DataloaderMode 9 | 10 | 11 | class TID2013Dataset(data.Dataset): 12 | def __init__(self, cfg, index, transform, mode): 13 | self.root = cfg.data.root 14 | meta_info = pd.read_csv(cfg.data.meta_info_file) 15 | 16 | if mode is DataloaderMode.train: 17 | patch_num = cfg.train.patch_num 18 | elif mode is DataloaderMode.val: 19 | patch_num = cfg.val.patch_num 20 | elif mode is DataloaderMode.test: 21 | patch_num = cfg.test.patch_num 22 | else: 23 | raise ValueError(f"invalid dataloader mode {mode}") 24 | 25 | sample = [] 26 | for idx in index: 27 | img_name = meta_info.loc[idx]["dist_name"] 28 | img_path = os.path.join("distorted_images", img_name) 29 | label = meta_info.loc[idx]["mos"] 30 | for _ in range(patch_num): 31 | sample.append((img_path, label)) 32 | 33 | self.samples = sample 34 | self.transform = transform 35 | 36 | def __getitem__(self, index): 37 | """ 38 | Args: 39 | index (int): Index 40 | 41 | Returns: 42 | tuple: (sample, target) where target is class_index of the target class. 
43 | """ 44 | path, target = self.samples[index] 45 | img = pil_loader(os.path.join(self.root, path)) 46 | if self.transform is not None: 47 | img = self.transform(img) 48 | 49 | # if there are more than one image or more than one target 50 | # can organize it as 51 | # return [img1, img2], [targe1, target2] 52 | return img, target 53 | 54 | def __len__(self): 55 | length = len(self.samples) 56 | return length 57 | -------------------------------------------------------------------------------- /src/eval.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import os 3 | import random 4 | import traceback 5 | 6 | import hydra 7 | import pyrootutils 8 | import torch 9 | import torch.distributed as dist 10 | import torch.multiprocessing as mp 11 | from hydra.core.hydra_config import HydraConfig 12 | from omegaconf import DictConfig, OmegaConf, open_dict 13 | 14 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 15 | 16 | from src.dataset import DataloaderMode, create_dataloader 17 | from src.model import Model, create_model 18 | from src.tools.test_model import test_model 19 | from src.utils.loss import get_loss 20 | from src.utils.utils import get_logger, is_logging_process, set_random_seed 21 | from src.utils.writer import Writer 22 | 23 | 24 | def setup(cfg, rank): 25 | os.environ["MASTER_ADDR"] = cfg.dist.master_addr 26 | os.environ["MASTER_PORT"] = cfg.dist.master_port 27 | timeout_sec = 1800 28 | if cfg.dist.timeout is not None: 29 | os.environ["NCCL_BLOCKING_WAIT"] = "1" 30 | timeout_sec = cfg.dist.timeout 31 | timeout = datetime.timedelta(seconds=timeout_sec) 32 | 33 | # initialize the process group 34 | dist.init_process_group( 35 | cfg.dist.mode, 36 | rank=rank, 37 | world_size=cfg.dist.gpus, 38 | timeout=timeout, 39 | ) 40 | 41 | 42 | def cleanup(): 43 | dist.destroy_process_group() 44 | 45 | 46 | def distributed_run(fn, cfg): 47 | mp.spawn(fn, args=(cfg,), nprocs=cfg.dist.gpus, join=True) 48 | 49 | 50 | def test_loop(rank, cfg): 51 | logger = get_logger(cfg, os.path.basename(__file__)) 52 | if cfg.dist.device == "cuda" and cfg.dist.gpus != 0: 53 | cfg.dist.device = rank 54 | setup(cfg, rank) 55 | torch.cuda.set_device(cfg.dist.device) 56 | 57 | if not OmegaConf.has_resolver("eval"): 58 | OmegaConf.register_new_resolver("eval", eval) 59 | 60 | # setup writer 61 | if is_logging_process(): 62 | # set log/checkpoint dir 63 | os.makedirs(cfg.log.chkpt_dir, exist_ok=True) 64 | # set writer (tensorboard / wandb) 65 | writer = Writer(cfg, "tensorboard") 66 | cfg_str = OmegaConf.to_yaml(cfg) 67 | logger.info("Config:\n" + cfg_str) 68 | if cfg.data.root == "": 69 | logger.error("test data directory cannot be empty.") 70 | raise Exception("Please specify directories of data") 71 | logger.info("Set up test process") 72 | else: 73 | writer = None 74 | 75 | # make dataloader 76 | if is_logging_process(): 77 | logger.info("Making test dataloader...") 78 | test_loader, _ = create_dataloader(cfg, DataloaderMode.test, rank) 79 | 80 | # init Model 81 | net_arch = create_model(cfg=cfg) 82 | loss_f = get_loss(cfg=cfg) 83 | model = Model(cfg, net_arch, loss_f, rank) 84 | 85 | # load training state / network checkpoint 86 | assert cfg.load.network_chkpt_path is not None 87 | model.load_network() 88 | 89 | try: 90 | test_model(cfg, model, test_loader, writer) 91 | if is_logging_process(): 92 | logger.info("End of Test") 93 | except Exception: 94 | if is_logging_process(): 95 | logger.error(traceback.format_exc()) 96 | 
else: 97 | traceback.print_exc() 98 | finally: 99 | if cfg.dist.device == "cuda" and cfg.dist.gpus != 0: 100 | cleanup() 101 | 102 | 103 | @hydra.main(version_base=None, config_path="../config", config_name="default") 104 | def main(hydra_cfg: DictConfig): 105 | hydra_cfg.dist.device = hydra_cfg.dist.device.lower() 106 | with open_dict(hydra_cfg): 107 | hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging 108 | hydra_cfg.hydra_output_dir = HydraConfig.get().run.dir 109 | # random seed 110 | if hydra_cfg.random_seed is None: 111 | hydra_cfg.random_seed = random.randint(1, 10000) 112 | set_random_seed(hydra_cfg.random_seed) 113 | 114 | if hydra_cfg.dist.device == "cuda" and hydra_cfg.dist.gpus < 0: 115 | hydra_cfg.dist.gpus = torch.cuda.device_count() 116 | if hydra_cfg.dist.device == "cpu" or hydra_cfg.dist.gpus == 0: 117 | hydra_cfg.dist.gpus = 0 118 | test_loop(0, hydra_cfg) 119 | else: 120 | distributed_run(test_loop, hydra_cfg) 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /src/model/__init__.py: -------------------------------------------------------------------------------- 1 | from .model import Model 2 | from .model_dispatcher import create_model 3 | 4 | __all__ = ["Model", "create_model"] 5 | -------------------------------------------------------------------------------- /src/model/loda.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on 3 | timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm 4 | """ 5 | 6 | from collections import OrderedDict 7 | 8 | import timm 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | from timm.models.layers import trunc_normal_ 13 | from timm.models.vision_transformer import Mlp, VisionTransformer 14 | 15 | from .patch_embed import PatchEmbed 16 | 17 | 18 | class NormedLinear(nn.Module): 19 | def __init__(self, in_features, out_features): 20 | super(NormedLinear, self).__init__() 21 | self.weight = nn.Parameter(torch.Tensor(in_features, out_features)) 22 | self.weight.data.uniform_(-1, 1).renorm_(2, 1, 1e-5).mul_(1e5) 23 | 24 | def forward(self, x): 25 | cosine = F.normalize(x, dim=1).mm(F.normalize(self.weight, dim=0)) 26 | return cosine 27 | 28 | 29 | class CrossAttention(nn.Module): 30 | def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0.0, proj_drop=0.0): 31 | super().__init__() 32 | assert dim % num_heads == 0, "dim should be divisible by num_heads" 33 | self.num_heads = num_heads 34 | head_dim = dim // num_heads 35 | self.scale = head_dim**-0.5 36 | 37 | self.q = nn.Linear(dim, dim, bias=qkv_bias) 38 | self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias) 39 | self.attn_drop = nn.Dropout(attn_drop) 40 | self.proj = nn.Linear(dim, dim) 41 | self.proj_drop = nn.Dropout(proj_drop) 42 | 43 | def forward(self, query, image_token): 44 | B, N, C = image_token.shape 45 | kv = ( 46 | self.kv(image_token) 47 | .reshape(B, N, 2, self.num_heads, C // self.num_heads) 48 | .permute(2, 0, 3, 1, 4) 49 | ) 50 | k, v = kv.unbind(0) # make torchscript happy (cannot use tensor as tuple) 51 | 52 | B, N, C = query.shape 53 | q = ( 54 | self.q(query) 55 | .reshape(B, N, 1, self.num_heads, C // self.num_heads) 56 | .permute(2, 0, 3, 1, 4) 57 | ) 58 | q = q[0] 59 | 60 | attn = (q @ k.transpose(-2, -1)) * self.scale 61 | attn = attn.softmax(dim=-1) 62 | attn = self.attn_drop(attn) 63 | 64 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 65 | x = 
self.proj(x) 66 | x = self.proj_drop(x) 67 | return x 68 | 69 | 70 | class Learner(nn.Module): 71 | def __init__(self, cfg): 72 | super().__init__() 73 | 74 | num_classes = cfg.model.learner_param.num_classes 75 | embed_dim = cfg.model.learner_param.embed_dim 76 | feature_channels = cfg.model.learner_param.feature_channels 77 | cnn_feature_num = cfg.model.learner_param.cnn_feature_num 78 | interaction_block_num = cfg.model.learner_param.interaction_block_num 79 | latent_dim = cfg.model.learner_param.latent_dim 80 | grid_size = cfg.model.learner_param.grid_size 81 | cross_attn_num_heads = cfg.model.learner_param.cross_attn_num_heads 82 | 83 | # hyper net 84 | self.conv = nn.ModuleList( 85 | [ 86 | nn.Sequential( 87 | *[ 88 | nn.Conv2d( 89 | feature_channels[i], 90 | latent_dim, 91 | kernel_size=1, 92 | stride=1, 93 | padding=0, 94 | bias=True, 95 | ), 96 | nn.GELU(), 97 | nn.Conv2d( 98 | latent_dim, 99 | latent_dim, 100 | kernel_size=3, 101 | stride=1, 102 | padding=1, 103 | bias=False, 104 | ), 105 | nn.AdaptiveAvgPool2d(grid_size), 106 | ] 107 | ) 108 | for i in range(cnn_feature_num) 109 | ] 110 | ) 111 | self.down_proj = nn.ModuleList( 112 | [ 113 | Mlp( 114 | in_features=embed_dim, 115 | hidden_features=latent_dim, 116 | out_features=latent_dim, 117 | ) 118 | for _ in range(interaction_block_num) 119 | ] 120 | ) 121 | self.cross_attn = nn.ModuleList( 122 | [ 123 | CrossAttention(dim=latent_dim, num_heads=cross_attn_num_heads) 124 | for _ in range(interaction_block_num) 125 | ] 126 | ) 127 | self.up_proj = nn.ModuleList( 128 | [ 129 | Mlp( 130 | in_features=latent_dim, 131 | hidden_features=embed_dim, 132 | out_features=embed_dim, 133 | ) 134 | for _ in range(interaction_block_num) 135 | ] 136 | ) 137 | self.scale_factor = nn.Parameter( 138 | torch.randn(interaction_block_num, embed_dim) * 0.02 139 | ) 140 | 141 | # new head 142 | self.head = NormedLinear(embed_dim, num_classes) 143 | 144 | self._init_parameters() 145 | 146 | def _init_parameters(self): 147 | trunc_normal_(self.scale_factor, std=0.02) 148 | 149 | def forward(self, x): 150 | return self.head(x) 151 | 152 | 153 | class LoDa(VisionTransformer): 154 | def __init__( 155 | self, 156 | cfg=None, 157 | embed_layer=PatchEmbed, 158 | basic_state_dict=None, 159 | *argv, 160 | **karg, 161 | ): 162 | # Recreate ViT 163 | super().__init__( 164 | embed_layer=embed_layer, 165 | *argv, 166 | **karg, 167 | **(cfg.model.vit_param), 168 | ) 169 | 170 | # load basic state_dict 171 | if basic_state_dict is not None: 172 | self.load_state_dict(basic_state_dict, False) 173 | 174 | self.learner = Learner(cfg) 175 | self.dropout = nn.Dropout(cfg.model.hyper_vit.dropout_rate) 176 | self.head = nn.Identity() 177 | 178 | # feature_extraction model 179 | self.feature_model = timm.create_model( 180 | cfg.model.feature_model.name, 181 | pretrained=cfg.model.feature_model.load_timm_model, 182 | features_only=True, 183 | out_indices=cfg.model.feature_model.out_indices, 184 | ) 185 | 186 | def freeze(self): 187 | for param in self.parameters(): 188 | param.requires_grad = False 189 | 190 | for param in self.learner.parameters(): 191 | param.requires_grad = True 192 | 193 | def un_freeze(self): 194 | for param in self.parameters(): 195 | param.requires_grad = True 196 | 197 | def obtain_state_to_save(self): 198 | feature_model_state_dict = self.feature_model.state_dict() 199 | bn_buffer = OrderedDict() 200 | for key, value in feature_model_state_dict.items(): 201 | if ( 202 | "running_mean" in key 203 | or "running_var" in key 204 | or 
"num_batches_tracked" in key 205 | ): 206 | bn_buffer[key] = value 207 | 208 | state_dict_to_save = { 209 | "learner": self.learner.state_dict(), 210 | "bn_buffer": bn_buffer, 211 | } 212 | return state_dict_to_save 213 | 214 | def load_saved_state(self, saved_state_dict, strict=False): 215 | self.learner.load_state_dict(saved_state_dict["learner"], strict) 216 | self.feature_model.load_state_dict(saved_state_dict["bn_buffer"], strict) 217 | 218 | def forward_hyper_net(self, x): 219 | batch_size = x.shape[0] 220 | features_list = self.feature_model(x) 221 | 222 | cnn_token_list = [] 223 | for i in range(len(features_list)): 224 | cnn_image_token = self.learner.conv[i](features_list[i]) 225 | latent_dim = cnn_image_token.shape[1] 226 | cnn_image_token = cnn_image_token.permute(0, 2, 3, 1).reshape( 227 | batch_size, -1, latent_dim 228 | ) 229 | cnn_token_list.append(cnn_image_token) 230 | 231 | return torch.cat(cnn_token_list, dim=1) 232 | 233 | def forward_features(self, x, cnn_tokens): 234 | x = self.patch_embed(x) 235 | cls_token = self.cls_token.expand(x.shape[0], -1, -1) 236 | 237 | # concatenate CLS token 238 | x = torch.cat((cls_token, x), dim=1) 239 | x = self.pos_drop(x + self.pos_embed) 240 | 241 | for i in range(len(self.blocks)): 242 | x_down = self.learner.down_proj[i](x) 243 | x_down = x_down + self.learner.cross_attn[i](x_down, cnn_tokens) 244 | x_up = self.learner.up_proj[i](x_down) 245 | x = x + x_up * self.learner.scale_factor[i] 246 | x = self.blocks[i](x) 247 | 248 | x = self.norm(x) 249 | return x[:, 0, :] 250 | 251 | def forward(self, x): 252 | cnn_tokens = self.forward_hyper_net(x) 253 | x = self.forward_features(x, cnn_tokens) 254 | 255 | x = self.dropout(x) 256 | x = self.learner(x) 257 | x = self.head(x) 258 | return x 259 | -------------------------------------------------------------------------------- /src/model/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from collections import OrderedDict 4 | 5 | import torch 6 | import torch.nn 7 | import wandb 8 | from omegaconf import OmegaConf 9 | from torch.nn.parallel import DistributedDataParallel as DDP 10 | 11 | from src.utils.utils import get_logger, is_logging_process 12 | 13 | 14 | class Model: 15 | def __init__(self, cfg, net_arch, loss_f, rank=0): 16 | # prepare var 17 | self.cfg = cfg 18 | self.device = self.cfg.dist.device 19 | self.net = net_arch.to(self.device) 20 | self.rank = rank 21 | if self.device != "cpu" and self.cfg.dist.gpus != 0: 22 | self.net = DDP(self.net, device_ids=[self.rank]) 23 | self.step = 0 24 | self.epoch = -1 25 | self._logger = get_logger(cfg, os.path.basename(__file__)) 26 | 27 | # init optimizer 28 | optimizer_mode = self.cfg.optimizer.name 29 | if optimizer_mode == "adam": 30 | self.optimizer = torch.optim.Adam( 31 | self.net.parameters(), **(self.cfg.optimizer.param) 32 | ) 33 | elif optimizer_mode == "adamW": 34 | self.optimizer = torch.optim.AdamW( 35 | self.net.parameters(), **(self.cfg.optimizer.param) 36 | ) 37 | else: 38 | raise Exception("%s optimizer not supported" % optimizer_mode) 39 | 40 | # init scheduler 41 | scheduler_mode = self.cfg.scheduler.name 42 | if scheduler_mode == "CosineAnnealingLR": 43 | self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( 44 | self.optimizer, **(self.cfg.scheduler.param) 45 | ) 46 | else: 47 | raise Exception("%s scheduler not supported" % scheduler_mode) 48 | 49 | # init loss 50 | self.loss_f = loss_f 51 | self.log = OmegaConf.create() 52 | 
self.log.loss_v = 0 53 | 54 | def optimize_parameters(self, model_input, model_target): 55 | self.net.train() 56 | self.optimizer.zero_grad() 57 | output = self.run_network(model_input) 58 | loss_v = self.compute_loss(output, model_target) 59 | loss_v.backward() 60 | self.optimizer.step() 61 | self.scheduler.step() 62 | # set log 63 | self.log.loss_v = loss_v.detach().item() 64 | 65 | def compute_loss(self, model_output, model_target): 66 | loss_v = 0.0 67 | for fn, weight in self.loss_f: 68 | loss_v += weight * fn( 69 | model_output, model_target.unsqueeze(1).to(self.device) 70 | ) 71 | return loss_v 72 | 73 | def inference(self, model_input): 74 | self.net.eval() 75 | output = self.run_network(model_input) 76 | return output 77 | 78 | def run_network(self, model_input): 79 | model_input = model_input.to(self.device) 80 | output = self.net(model_input) 81 | return output 82 | 83 | def save_network(self, save_file=True): 84 | if is_logging_process(): 85 | net = self.net.module if isinstance(self.net, DDP) else self.net 86 | state_dict = net.obtain_state_to_save() 87 | for module_name, module_param in state_dict.items(): 88 | if isinstance(module_param, torch.Tensor): 89 | state_dict[module_name] = module_param.to("cpu") 90 | else: 91 | for key, param in module_param.items(): 92 | state_dict[module_name][key] = param.to("cpu") 93 | if save_file: 94 | save_filename = "%s_%d.pt" % (self.cfg.name, self.step) 95 | save_path = osp.join(self.cfg.log.chkpt_dir, save_filename) 96 | torch.save(state_dict, save_path) 97 | if self.cfg.log.use_wandb and self.cfg.log.wandb_save_model: 98 | wandb.save(save_path) 99 | if is_logging_process(): 100 | self._logger.info("Saved network checkpoint to: %s" % save_path) 101 | return state_dict 102 | 103 | def load_network(self, loaded_net=None): 104 | add_log = False 105 | if loaded_net is None: 106 | add_log = True 107 | if self.cfg.load.wandb_load_path is not None: 108 | self.cfg.load.network_chkpt_path = wandb.restore( 109 | self.cfg.load.network_chkpt_path, 110 | run_path=self.cfg.load.wandb_load_path, 111 | ).name 112 | loaded_net = torch.load( 113 | self.cfg.load.network_chkpt_path, 114 | map_location=torch.device(self.device), 115 | ) 116 | loaded_clean_net = OrderedDict() # remove unnecessary 'module.' 
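# (checkpoints saved from a DDP-wrapped model prefix every key with "module."; stripping it lets the same checkpoint load into both wrapped and unwrapped models)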
117 | for k, v in loaded_net.items(): 118 | if k.startswith("module."): 119 | loaded_clean_net[k[7:]] = v 120 | else: 121 | loaded_clean_net[k] = v 122 | 123 | if isinstance(self.net, DDP): 124 | self.net.module.load_saved_state( 125 | loaded_clean_net, strict=self.cfg.load.strict_load 126 | ) 127 | else: 128 | self.net.load_saved_state( 129 | loaded_clean_net, strict=self.cfg.load.strict_load 130 | ) 131 | if is_logging_process() and add_log: 132 | self._logger.info( 133 | "Checkpoint %s is loaded" % self.cfg.load.network_chkpt_path 134 | ) 135 | 136 | def save_training_state(self): 137 | if is_logging_process(): 138 | save_filename = "%s_%d.state" % (self.cfg.name, self.step) 139 | save_path = osp.join(self.cfg.log.chkpt_dir, save_filename) 140 | net_state_dict = self.save_network(False) 141 | state = { 142 | "model": net_state_dict, 143 | "optimizer": self.optimizer.state_dict(), 144 | "scheduler": self.scheduler.state_dict(), 145 | "step": self.step, 146 | "epoch": self.epoch, 147 | } 148 | torch.save(state, save_path) 149 | if self.cfg.log.use_wandb and self.cfg.log.wandb_save_model: 150 | wandb.save(save_path) 151 | if is_logging_process(): 152 | self._logger.info("Saved training state to: %s" % save_path) 153 | 154 | def load_training_state(self): 155 | if self.cfg.load.wandb_load_path is not None: 156 | self.cfg.load.resume_state_path = wandb.restore( 157 | self.cfg.load.resume_state_path, 158 | run_path=self.cfg.load.wandb_load_path, 159 | ).name 160 | resume_state = torch.load( 161 | self.cfg.load.resume_state_path, 162 | map_location=torch.device(self.device), 163 | ) 164 | 165 | self.load_network(loaded_net=resume_state["model"]) 166 | self.optimizer.load_state_dict(resume_state["optimizer"]) 167 | self.scheduler.load_state_dict(resume_state["scheduler"]) 168 | self.step = resume_state["step"] 169 | self.epoch = resume_state["epoch"] 170 | if is_logging_process(): 171 | self._logger.info( 172 | "Resuming from training state: %s" % self.cfg.load.resume_state_path 173 | ) 174 | -------------------------------------------------------------------------------- /src/model/model_dispatcher.py: -------------------------------------------------------------------------------- 1 | import timm 2 | 3 | from .loda import LoDa 4 | 5 | 6 | def create_model(cfg): 7 | if cfg.model.model_name == "loda": 8 | basic_model = timm.create_model( 9 | cfg.model.basic_model_name, 10 | img_size=cfg.model.vit_param.img_size, 11 | pretrained=cfg.model.basic_model_pretrained, 12 | num_classes=cfg.model.vit_param.num_classes, 13 | ) 14 | 15 | net_arch = LoDa(cfg=cfg, basic_state_dict=basic_model.state_dict()) 16 | net_arch.freeze() 17 | else: 18 | raise Exception("%s model not supported" % cfg.model.model_name) 19 | 20 | return net_arch 21 | -------------------------------------------------------------------------------- /src/model/patch_embed.py: -------------------------------------------------------------------------------- 1 | """ Image to Patch Embedding using Conv2d 2 | 3 | A convolution based approach to patchifying a 2D image w/ embedding projection. 
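When flatten=True, the forward pass reshapes the BCHW feature map into a BNC sequence of patch tokens (see the comment in forward below).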
4 | 5 | Based on the impl in https://github.com/google-research/vision_transformer 6 | 7 | Hacked together by / Copyright 2020 Ross Wightman 8 | """ 9 | from torch import nn as nn 10 | 11 | from src.utils.utils import to_2tuple 12 | 13 | 14 | class PatchEmbed(nn.Module): 15 | """2D Image to Patch Embedding""" 16 | 17 | def __init__( 18 | self, 19 | img_size=224, 20 | patch_size=16, 21 | in_chans=3, 22 | embed_dim=768, 23 | norm_layer=None, 24 | flatten=True, 25 | bias=True, 26 | ): 27 | super().__init__() 28 | img_size = to_2tuple(img_size) 29 | patch_size = to_2tuple(patch_size) 30 | self.img_size = img_size 31 | self.patch_size = patch_size 32 | self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1]) 33 | self.num_patches = self.grid_size[0] * self.grid_size[1] 34 | self.flatten = flatten 35 | 36 | self.proj = nn.Conv2d( 37 | in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias 38 | ) 39 | self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() 40 | 41 | def forward(self, x): 42 | B, C, H, W = x.shape 43 | x = self.proj(x) 44 | if self.flatten: 45 | x = x.flatten(2).transpose(1, 2) # BCHW -> BNC 46 | x = self.norm(x) 47 | return x 48 | -------------------------------------------------------------------------------- /src/tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .test_model import test_model 2 | from .train_model import train_model 3 | 4 | __all__ = ["test_model", "train_model"] 5 | -------------------------------------------------------------------------------- /src/tools/test_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torch 5 | from tqdm import tqdm 6 | 7 | from src.utils.metrics import ( 8 | calculate_plcc, 9 | calculate_rmse, 10 | calculate_srcc, 11 | logistic_regression, 12 | ) 13 | from src.utils.utils import get_logger, is_logging_process 14 | 15 | 16 | def test_model(cfg, model, test_loader, writer): 17 | logger = get_logger(cfg, os.path.basename(__file__)) 18 | 19 | model.net.eval() 20 | 21 | if is_logging_process(): 22 | pbar = tqdm(test_loader) 23 | else: 24 | pbar = test_loader 25 | 26 | total_test_loss = 0 27 | pred_scores = [] 28 | gt_scores = [] 29 | test_loop_len = 0 30 | with torch.no_grad(): 31 | for model_input, model_target in pbar: 32 | target = model_target.to(cfg.dist.device) 33 | 34 | output = model.inference(model_input) 35 | loss_v = model.compute_loss(output, target) 36 | 37 | if cfg.dist.gpus > 0: 38 | # Aggregate loss_v from all GPUs. loss_v is set as the sum of all GPUs' loss_v. 39 | torch.distributed.all_reduce(loss_v) 40 | loss_v /= torch.tensor(float(cfg.dist.gpus)) 41 | 42 | # gather scores from all GPUs. 43 | # TODO: is torch.distributed.all_reduce sequential? 
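# (note: all_gather below fills output_list/target_list with one tensor per rank; this assumes every rank contributes tensors of identical shape)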
44 | output_list = [ 45 | torch.zeros(output.shape, dtype=output.dtype, device=output.device) 46 | for _ in range(cfg.dist.device_num) 47 | ] 48 | target_list = [ 49 | torch.zeros(target.shape, dtype=target.dtype, device=target.device) 50 | for _ in range(cfg.dist.device_num) 51 | ] 52 | torch.distributed.all_gather(output_list, output) 53 | torch.distributed.all_gather(target_list, target) 54 | output = torch.concat(output_list) 55 | target = torch.concat(target_list) 56 | 57 | total_test_loss += loss_v.to("cpu").item() 58 | pred_scores = pred_scores + output.squeeze(1).cpu().tolist() 59 | gt_scores = gt_scores + target.cpu().tolist() 60 | test_loop_len += 1 61 | 62 | total_test_loss /= test_loop_len 63 | 64 | # compute metrics related to the Image Quality Assessment task 65 | pred_scores = np.mean( 66 | np.reshape(np.array(pred_scores).squeeze(), (-1, cfg.test.patch_num)), 67 | axis=1, 68 | ) 69 | gt_scores = np.mean( 70 | np.reshape(np.array(gt_scores).squeeze(), (-1, cfg.test.patch_num)), axis=1 71 | ) 72 | 73 | test_srcc = calculate_srcc(pred_scores, gt_scores) 74 | # Fit the scale of predicted scores to MOS scores using the logistic regression suggested by VQEG. 75 | pred_scores = logistic_regression(gt_scores, pred_scores) 76 | test_plcc = calculate_plcc(pred_scores, gt_scores) 77 | test_rmse = calculate_rmse(pred_scores, gt_scores) 78 | 79 | if writer is not None: 80 | writer.logging_with_step(test_plcc, model.step, "test_plcc") 81 | writer.logging_with_step(test_srcc, model.step, "test_srcc") 82 | writer.logging_with_step(test_rmse, model.step, "test_rmse") 83 | writer.logging_with_step(total_test_loss, model.step, "test_loss") 84 | if is_logging_process(): 85 | logger.info( 86 | "Test PLCC %.04f at (epoch: %d / step: %d)" 87 | % (test_plcc, model.epoch + 1, model.step) 88 | ) 89 | logger.info( 90 | "Test SRCC %.04f at (epoch: %d / step: %d)" 91 | % (test_srcc, model.epoch + 1, model.step) 92 | ) 93 | logger.info( 94 | "Test RMSE %.04f at (epoch: %d / step: %d)" 95 | % (test_rmse, model.epoch + 1, model.step) 96 | ) 97 | logger.info( 98 | "Test Loss %.04f at (epoch: %d / step: %d)" 99 | % (total_test_loss, model.epoch + 1, model.step) 100 | ) 101 | -------------------------------------------------------------------------------- /src/tools/train_model.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | 4 | from tqdm import tqdm 5 | 6 | from src.utils.utils import get_logger, is_logging_process 7 | 8 | 9 | def train_model(cfg, model, train_loader, writer): 10 | logger = get_logger(cfg, os.path.basename(__file__)) 11 | model.net.train() 12 | 13 | if is_logging_process(): 14 | pbar = tqdm(train_loader, postfix=f"loss: {model.log.loss_v:.04f}") 15 | else: 16 | pbar = train_loader 17 | 18 | for model_input, model_target in pbar: 19 | model.optimize_parameters(model_input, model_target) 20 | loss = model.log.loss_v 21 | 22 | if is_logging_process(): 23 | pbar.postfix = f"loss: {model.log.loss_v:.04f}" 24 | 25 | model.step += 1 26 | 27 | if is_logging_process() and (loss > 1e8 or math.isnan(loss)): 28 | logger.error("Loss exploded to %.02f at step %d!" 
% (loss, model.step)) 29 | raise Exception("Loss exploded") 30 | 31 | if model.step % cfg.log.summary_interval == 0: 32 | if writer is not None: 33 | writer.logging_with_step(loss, model.step, "train_loss") 34 | -------------------------------------------------------------------------------- /src/trainer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import itertools 3 | import os 4 | import random 5 | import traceback 6 | 7 | import hydra 8 | import pyrootutils 9 | import torch 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | from hydra.core.hydra_config import HydraConfig 13 | from omegaconf import DictConfig, OmegaConf, open_dict 14 | 15 | pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True) 16 | 17 | from src.dataset import DataloaderMode, create_dataloader 18 | from src.model import Model, create_model 19 | from src.tools import test_model, train_model 20 | from src.utils.loss import get_loss 21 | from src.utils.utils import get_logger, is_logging_process, set_random_seed 22 | from src.utils.writer import Writer 23 | 24 | 25 | def setup(cfg, rank): 26 | os.environ["MASTER_ADDR"] = cfg.dist.master_addr 27 | os.environ["MASTER_PORT"] = cfg.dist.master_port 28 | timeout_sec = 1800 29 | if cfg.dist.timeout is not None: 30 | os.environ["NCCL_BLOCKING_WAIT"] = "1" 31 | timeout_sec = cfg.dist.timeout 32 | timeout = datetime.timedelta(seconds=timeout_sec) 33 | 34 | # initialize the process group 35 | dist.init_process_group( 36 | cfg.dist.mode, 37 | rank=rank, 38 | world_size=cfg.dist.gpus, 39 | timeout=timeout, 40 | ) 41 | 42 | 43 | def cleanup(): 44 | dist.destroy_process_group() 45 | 46 | 47 | def distributed_run(fn, cfg): 48 | mp.spawn(fn, args=(cfg,), nprocs=cfg.dist.gpus, join=True) 49 | 50 | 51 | def train_loop(rank, cfg): 52 | logger = get_logger(cfg, os.path.basename(__file__)) 53 | if cfg.dist.device == "cuda" and cfg.dist.gpus != 0: 54 | cfg.dist.device = rank 55 | setup(cfg, rank) 56 | torch.cuda.set_device(cfg.dist.device) 57 | 58 | if not OmegaConf.has_resolver("eval"): 59 | OmegaConf.register_new_resolver("eval", eval) 60 | 61 | # setup writer 62 | if is_logging_process(): 63 | # set log/checkpoint dir 64 | os.makedirs(cfg.log.chkpt_dir, exist_ok=True) 65 | # set writer (tensorboard / wandb) 66 | writer = Writer(cfg, "tensorboard") 67 | cfg_str = OmegaConf.to_yaml(cfg) 68 | logger.info("Config:\n" + cfg_str) 69 | if cfg.data.root == "": 70 | logger.error("train or test data directory cannot be empty.") 71 | raise Exception("Please specify directories of data") 72 | logger.info("Set up train process") 73 | else: 74 | writer = None 75 | 76 | # make dataloader 77 | if is_logging_process(): 78 | logger.info("Making train dataloader...") 79 | train_loader, train_sampler = create_dataloader(cfg, DataloaderMode.train, rank) 80 | if is_logging_process(): 81 | logger.info("Making test dataloader...") 82 | test_loader, _ = create_dataloader(cfg, DataloaderMode.test, rank) 83 | 84 | # init Model 85 | net_arch = create_model(cfg=cfg) 86 | loss_f = get_loss(cfg=cfg) 87 | model = Model(cfg, net_arch, loss_f, rank) 88 | 89 | # load training state / network checkpoint 90 | if cfg.load.resume_state_path is not None: 91 | model.load_training_state() 92 | elif cfg.load.network_chkpt_path is not None: 93 | model.load_network() 94 | else: 95 | if is_logging_process(): 96 | logger.info("Starting new training run.") 97 | 98 | try: 99 | if ( 100 | cfg.dist.device == "cpu" 101 | or 
cfg.dist.gpus == 0 102 | or cfg.data.divide_dataset_per_gpu 103 | ): 104 | epoch_step = 1 105 | else: 106 | epoch_step = cfg.dist.gpus 107 | for model.epoch in itertools.count(model.epoch + 1, epoch_step): 108 | if model.epoch >= cfg.num_epoch: 109 | break 110 | if train_sampler is not None: 111 | train_sampler.set_epoch(model.epoch) 112 | train_model(cfg, model, train_loader, writer) 113 | if model.epoch % cfg.log.net_chkpt_interval == 0: 114 | model.save_network() 115 | if model.epoch % cfg.log.train_chkpt_interval == 0: 116 | model.save_training_state() 117 | test_model(cfg, model, test_loader, writer) 118 | if is_logging_process(): 119 | logger.info("End of Train") 120 | except Exception: 121 | if is_logging_process(): 122 | logger.error(traceback.format_exc()) 123 | else: 124 | traceback.print_exc() 125 | finally: 126 | if cfg.dist.device == "cuda" and cfg.dist.gpus != 0: 127 | cleanup() 128 | 129 | 130 | @hydra.main(version_base=None, config_path="../config", config_name="default") 131 | def main(hydra_cfg: DictConfig): 132 | hydra_cfg.dist.device = hydra_cfg.dist.device.lower() 133 | with open_dict(hydra_cfg): 134 | hydra_cfg.job_logging_cfg = HydraConfig.get().job_logging 135 | hydra_cfg.hydra_output_dir = HydraConfig.get().run.dir 136 | # random seed 137 | if hydra_cfg.random_seed is None: 138 | hydra_cfg.random_seed = random.randint(1, 10000) 139 | set_random_seed(hydra_cfg.random_seed) 140 | 141 | if hydra_cfg.dist.device == "cuda" and hydra_cfg.dist.gpus < 0: 142 | hydra_cfg.dist.gpus = torch.cuda.device_count() 143 | if hydra_cfg.dist.device == "cpu" or hydra_cfg.dist.gpus == 0: 144 | hydra_cfg.dist.gpus = 0 145 | train_loop(0, hydra_cfg) 146 | else: 147 | distributed_run(train_loop, hydra_cfg) 148 | 149 | 150 | if __name__ == "__main__": 151 | main() 152 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeosXu/LoDa/82304c20c34c1b5bd45f27bd7ab6e9104a285152/src/utils/__init__.py -------------------------------------------------------------------------------- /src/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | based on IQA-PyTorch: https://github.com/chaofengc/IQA-PyTorch 3 | """ 4 | import collections 5 | import os 6 | import random 7 | from collections.abc import Sequence 8 | from itertools import repeat 9 | 10 | import torch 11 | import torchvision.transforms as tf 12 | import torchvision.transforms.functional as F 13 | from PIL import Image 14 | 15 | 16 | def getFileName(path, suffix): 17 | filename = [] 18 | f_list = os.listdir(path) 19 | for i in f_list: 20 | if os.path.splitext(i)[1] == suffix: 21 | filename.append(i) 22 | return filename 23 | 24 | 25 | def getTIDFileName(path, suffix): 26 | filename = [] 27 | f_list = os.listdir(path) 28 | for i in f_list: 29 | if suffix.find(os.path.splitext(i)[1]) != -1: 30 | filename.append(i[1:3]) 31 | return filename 32 | 33 | 34 | def pil_loader(path): 35 | return Image.open(path).convert("RGB") 36 | 37 | 38 | def transform_mapping(key, args): 39 | if key == "hflip" and args: 40 | return [PairedRandomHorizontalFlip()] 41 | elif key == "vflip" and args: 42 | return [PairedRandomVerticalFlip()] 43 | elif key == "random_crop": 44 | return [PairedRandomCrop(args)] 45 | elif key == "center_crop": 46 | return [PairedCenterCrop(args)] 47 | elif key == "resize": 48 | return [PairedResize(args)] 49 | 
elif key == "adaptive_resize": 50 | return [PairedAdaptiveResize(args)] 51 | elif key == "random_square_resize": 52 | return [PairedRandomSquareResize(args)] 53 | elif key == "random_arp_resize": 54 | return [PairedRandomARPResize(args)] 55 | elif key == "ada_pad": 56 | return [PairedAdaptivePadding(args)] 57 | elif key == "rot90" and args: 58 | return [PairedRandomRot90(args)] 59 | elif key == "randomerase": 60 | return [PairedRandomErasing(**args)] 61 | elif key == "totensor" and args: 62 | return [PairedToTensor()] 63 | elif key == "normalize" and args: 64 | return [PairedNormalize(args)] 65 | else: 66 | return [] 67 | 68 | 69 | def _check_pair(x): 70 | if isinstance(x, (tuple, list)) and len(x) >= 2: 71 | return True 72 | 73 | 74 | class PairedToTensor(tf.ToTensor): 75 | """Pair version of center crop""" 76 | 77 | def to_tensor(self, x): 78 | if isinstance(x, torch.Tensor): 79 | return x 80 | else: 81 | return F.to_tensor(x) 82 | 83 | def __call__(self, imgs): 84 | if _check_pair(imgs): 85 | for i in range(len(imgs)): 86 | imgs[i] = self.to_tensor(imgs[i]) 87 | return imgs 88 | else: 89 | return self.to_tensor(imgs) 90 | 91 | 92 | class PairedNormalize(tf.Normalize): 93 | """Pair version of normalize""" 94 | 95 | def __call__(self, imgs): 96 | if _check_pair(imgs): 97 | for i in range(len(imgs)): 98 | imgs[i] = super().forward(imgs[i]) 99 | return imgs 100 | else: 101 | return super().forward(imgs) 102 | 103 | 104 | class PairedCenterCrop(tf.CenterCrop): 105 | """Pair version of center crop""" 106 | 107 | def forward(self, imgs): 108 | if _check_pair(imgs): 109 | for i in range(len(imgs)): 110 | imgs[i] = super().forward(imgs[i]) 111 | return imgs 112 | elif isinstance(imgs, Image.Image): 113 | return super().forward(imgs) 114 | 115 | 116 | class PairedRandomCrop(tf.RandomCrop): 117 | """Pair version of random crop""" 118 | 119 | def _pad(self, img): 120 | if self.padding is not None: 121 | img = F.pad(img, self.padding, self.fill, self.padding_mode) 122 | 123 | width, height = img.size 124 | # pad the width if needed 125 | if self.pad_if_needed and width < self.size[1]: 126 | padding = [self.size[1] - width, 0] 127 | img = F.pad(img, padding, self.fill, self.padding_mode) 128 | # pad the height if needed 129 | if self.pad_if_needed and height < self.size[0]: 130 | padding = [0, self.size[0] - height] 131 | img = F.pad(img, padding, self.fill, self.padding_mode) 132 | return img 133 | 134 | def forward(self, imgs): 135 | if _check_pair(imgs): 136 | i, j, h, w = self.get_params(imgs[0], self.size) 137 | for i in range(len(imgs)): 138 | img = self._pad(imgs[i]) 139 | img = F.crop(img, i, j, h, w) 140 | imgs[i] = img 141 | return imgs 142 | elif isinstance(imgs, Image.Image): 143 | return super().forward(imgs) 144 | 145 | 146 | class PairedRandomErasing(tf.RandomErasing): 147 | """Pair version of random erasing""" 148 | 149 | def forward(self, imgs): 150 | if _check_pair(imgs): 151 | if torch.rand(1) < self.p: 152 | # cast self.value to script acceptable type 153 | if isinstance(self.value, (int, float)): 154 | value = [self.value] 155 | elif isinstance(self.value, str): 156 | value = None 157 | elif isinstance(self.value, tuple): 158 | value = list(self.value) 159 | else: 160 | value = self.value 161 | 162 | if value is not None and not (len(value) in (1, imgs[0].shape[-3])): 163 | raise ValueError( 164 | "If value is a sequence, it should have either a single value or " 165 | f"{imgs[0].shape[-3]} (number of input channels)" 166 | ) 167 | 168 | x, y, h, w, v = self.get_params( 169 | 
imgs[0], scale=self.scale, ratio=self.ratio, value=value 170 | ) 171 | for i in range(len(imgs)): 172 | imgs[i] = F.erase(imgs[i], x, y, h, w, v, self.inplace) 173 | return imgs 174 | elif isinstance(imgs, Image.Image): 175 | return super().forward(imgs) 176 | 177 | 178 | class PairedRandomHorizontalFlip(tf.RandomHorizontalFlip): 179 | """Pair version of random hflip""" 180 | 181 | def forward(self, imgs): 182 | if _check_pair(imgs): 183 | if torch.rand(1) < self.p: 184 | for i in range(len(imgs)): 185 | imgs[i] = F.hflip(imgs[i]) 186 | return imgs 187 | elif isinstance(imgs, Image.Image): 188 | return super().forward(imgs) 189 | 190 | 191 | class PairedRandomVerticalFlip(tf.RandomVerticalFlip): 192 | """Pair version of random vflip""" 193 | 194 | def forward(self, imgs): 195 | if _check_pair(imgs): 196 | if torch.rand(1) < self.p: 197 | for i in range(len(imgs)): 198 | imgs[i] = F.vflip(imgs[i]) 199 | return imgs 200 | elif isinstance(imgs, Image.Image): 201 | return super().forward(imgs) 202 | 203 | 204 | class PairedRandomRot90(torch.nn.Module): 205 | """Pair version of random 90-degree rotation""" 206 | 207 | def __init__(self, p=0.5): 208 | super().__init__() 209 | self.p = p 210 | 211 | def forward(self, imgs): 212 | if _check_pair(imgs): 213 | if torch.rand(1) < self.p: 214 | for i in range(len(imgs)): 215 | imgs[i] = F.rotate(imgs[i], 90) 216 | return imgs 217 | elif isinstance(imgs, Image.Image): 218 | if torch.rand(1) < self.p: 219 | imgs = F.rotate(imgs, 90) 220 | return imgs 221 | 222 | 223 | class PairedResize(tf.Resize): 224 | """Pair version of resize""" 225 | 226 | def forward(self, imgs): 227 | if _check_pair(imgs): 228 | for i in range(len(imgs)): 229 | imgs[i] = super().forward(imgs[i]) 230 | return imgs 231 | elif isinstance(imgs, Image.Image): 232 | return super().forward(imgs) 233 | 234 | 235 | class PairedAdaptiveResize(tf.Resize): 236 | """Aspect-ratio-preserving resize, applied only when the image is too small""" 237 | 238 | def forward(self, imgs): 239 | if _check_pair(imgs): 240 | for i in range(len(imgs)): 241 | tmpimg = imgs[i] 242 | min_size = min(tmpimg.size) 243 | if min_size < self.size: 244 | tmpimg = super().forward(tmpimg) 245 | imgs[i] = tmpimg 246 | return imgs 247 | elif isinstance(imgs, Image.Image): 248 | tmpimg = imgs 249 | min_size = min(tmpimg.size) 250 | if min_size < self.size: 251 | tmpimg = super().forward(tmpimg) 252 | return tmpimg 253 | 254 | 255 | class PairedRandomARPResize(torch.nn.Module): 256 | """Pair version of random aspect-ratio-preserving resize""" 257 | 258 | def __init__( 259 | self, size_range, interpolation=tf.InterpolationMode.BILINEAR, antialias=None 260 | ): 261 | super().__init__() 262 | self.interpolation = interpolation 263 | self.antialias = antialias 264 | self.size_range = size_range 265 | if not (isinstance(size_range, Sequence) and len(size_range) == 2): 266 | raise TypeError( 267 | f"size_range should be a sequence of 2 ints. 
Got {size_range} with {type(size_range)}" 268 | ) 269 | 270 | def forward(self, imgs): 271 | min_size, max_size = sorted(self.size_range) 272 | target_size = random.randint(min_size, max_size) 273 | if _check_pair(imgs): 274 | for i in range(len(imgs)): 275 | imgs[i] = F.resize(imgs[i], target_size, self.interpolation) 276 | return imgs 277 | elif isinstance(imgs, Image.Image): 278 | return F.resize(imgs, target_size, self.interpolation) 279 | 280 | 281 | class PairedRandomSquareResize(torch.nn.Module): 282 | """Pair version of random square resize""" 283 | 284 | def __init__( 285 | self, size_range, interpolation=tf.InterpolationMode.BILINEAR, antialias=None 286 | ): 287 | super().__init__() 288 | self.interpolation = interpolation 289 | self.antialias = antialias 290 | self.size_range = size_range 291 | if not (isinstance(size_range, Sequence) and len(size_range) == 2): 292 | raise TypeError( 293 | f"size_range should be a sequence of 2 ints. Got {size_range} with {type(size_range)}" 294 | ) 295 | 296 | def forward(self, imgs): 297 | min_size, max_size = sorted(self.size_range) 298 | target_size = random.randint(min_size, max_size) 299 | target_size = (target_size, target_size) 300 | if _check_pair(imgs): 301 | for i in range(len(imgs)): 302 | imgs[i] = F.resize(imgs[i], target_size, self.interpolation) 303 | return imgs 304 | elif isinstance(imgs, Image.Image): 305 | return F.resize(imgs, target_size, self.interpolation) 306 | 307 | 308 | class PairedAdaptivePadding(torch.nn.Module): 309 | """Pad an image (or image pair) up to a fixed target size""" 310 | 311 | def __init__(self, target_size, fill=0, padding_mode="constant"): 312 | super().__init__() 313 | self.target_size = self.to_2tuple(target_size) 314 | self.fill = fill 315 | self.padding_mode = padding_mode 316 | 317 | def get_padding(self, x): 318 | w, h = x.size 319 | th, tw = self.target_size 320 | assert ( 321 | th >= h and tw >= w 322 | ), f"Target size {self.target_size} should be larger than image size ({h}, {w})" 323 | pad_row = th - h 324 | pad_col = tw - w 325 | pad_l, pad_r, pad_t, pad_b = ( 326 | pad_col // 2, 327 | pad_col - pad_col // 2, 328 | pad_row // 2, 329 | pad_row - pad_row // 2, 330 | ) 331 | return (pad_l, pad_t, pad_r, pad_b) 332 | 333 | def forward(self, imgs): 334 | if _check_pair(imgs): 335 | for i in range(len(imgs)): 336 | padding = self.get_padding(imgs[i]) 337 | imgs[i] = F.pad(imgs[i], padding, self.fill, self.padding_mode) 338 | return imgs 339 | elif isinstance(imgs, Image.Image): 340 | padding = self.get_padding(imgs) 341 | imgs = F.pad(imgs, padding, self.fill, self.padding_mode) 342 | return imgs 343 | 344 | def to_2tuple(self, x): 345 | if isinstance(x, collections.abc.Iterable): 346 | return x 347 | return tuple(repeat(x, 2)) 348 | -------------------------------------------------------------------------------- /src/utils/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def plcc_loss(y_pred, y): 5 | y = y.detach().float() 6 | sigma_hat, m_hat = torch.std_mean(y_pred, unbiased=False) 7 | y_pred = (y_pred - m_hat) / (sigma_hat + 1e-8) 8 | sigma, m = torch.std_mean(y, unbiased=False) 9 | y = (y - m) / (sigma + 1e-8) 10 | loss0 = torch.nn.functional.mse_loss(y_pred, y) / 4 11 | rho = torch.mean(y_pred * y) 12 | loss1 = torch.nn.functional.mse_loss(rho * y_pred, y) / 4 13 | return ((loss0 + loss1) / 2).float() 14 | 15 | 16 | def get_loss(cfg): 17 | loss_fn = [] 18 | for name, weight in cfg.loss.fn: 19 | if name == "plcc_loss": 20 | fn = plcc_loss 21 | else: 22 | raise 
Exception("%s loss not supported" % name) 23 | 24 | loss_fn.append((fn, weight)) 25 | 26 | return loss_fn 27 | -------------------------------------------------------------------------------- /src/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | from scipy.optimize import curve_fit 4 | 5 | 6 | def logistic_func(X, bayta1, bayta2, bayta3, bayta4): 7 | # 4-parameter logistic function 8 | logisticPart = 1 + np.exp(np.negative(np.divide(X - bayta3, np.abs(bayta4)))) 9 | yhat = bayta2 + np.divide(bayta1 - bayta2, logisticPart) 10 | return yhat 11 | 12 | 13 | def logistic_regression(y, y_pred): 14 | # Fit the scale of predict scores to MOS scores using logistic regression suggested by VQEG. 15 | beta_init = [np.max(y), np.min(y), np.mean(y_pred), 0.5] 16 | popt, _ = curve_fit(logistic_func, y_pred, y, p0=beta_init, maxfev=int(1e8)) 17 | y_pred_logistic = logistic_func(y_pred, *popt) 18 | return y_pred_logistic 19 | 20 | 21 | def calculate_rmse(y_pred, y, fit_scale=None, eps=1e-8): 22 | if fit_scale is not None: 23 | y_pred = logistic_regression(y, y_pred) 24 | return np.sqrt(np.mean((y_pred - y) ** 2) + eps) 25 | 26 | 27 | def calculate_plcc(y_pred, y, fit_scale=None): 28 | if fit_scale is not None: 29 | y_pred = logistic_regression(y, y_pred) 30 | return stats.pearsonr(y_pred, y)[0] 31 | 32 | 33 | def calculate_srcc(y_pred, y): 34 | return stats.spearmanr(y_pred, y)[0] 35 | 36 | 37 | def calculate_krcc(y_pred, y): 38 | return stats.kendalltau(y_pred, y)[0] 39 | -------------------------------------------------------------------------------- /src/utils/utils.py: -------------------------------------------------------------------------------- 1 | import collections.abc 2 | import logging 3 | import os 4 | import random 5 | from itertools import repeat 6 | 7 | import numpy as np 8 | import torch 9 | import torch.distributed as dist 10 | from omegaconf import OmegaConf 11 | 12 | 13 | def set_random_seed(seed): 14 | # setting up the random seed 15 | os.environ["PYTHONHASHSEED"] = str(seed) 16 | random.seed(seed) 17 | np.random.seed(seed) 18 | torch.manual_seed(seed) 19 | torch.cuda.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | torch.backends.cudnn.deterministic = True 22 | 23 | 24 | def is_logging_process(): 25 | return not dist.is_initialized() or dist.get_rank() == 0 26 | 27 | 28 | def get_logger(cfg, name=None): 29 | # log_file_path is used when unit testing 30 | if is_logging_process(): 31 | logging.config.dictConfig( 32 | OmegaConf.to_container(cfg.job_logging_cfg, resolve=True) 33 | ) 34 | return logging.getLogger(name) 35 | 36 | 37 | # From PyTorch internals 38 | def _ntuple(n): 39 | def parse(x): 40 | if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): 41 | return tuple(x) 42 | return tuple(repeat(x, n)) 43 | 44 | return parse 45 | 46 | 47 | to_1tuple = _ntuple(1) 48 | to_2tuple = _ntuple(2) 49 | to_3tuple = _ntuple(3) 50 | to_4tuple = _ntuple(4) 51 | to_ntuple = _ntuple 52 | -------------------------------------------------------------------------------- /src/utils/writer.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import wandb 4 | from torch.utils.tensorboard import SummaryWriter 5 | 6 | 7 | class Writer(SummaryWriter): 8 | def __init__(self, cfg, logdir): 9 | self.cfg = cfg 10 | if cfg.log.use_tensorboard: 11 | self.tensorboard = SummaryWriter(logdir) 12 | if cfg.log.use_wandb: 13 | 
wandb_init_conf = cfg.log.wandb_init_conf 14 | wandb.init(config=cfg, **wandb_init_conf) 15 | if cfg.log.wandb_init_conf.save_code: 16 | wandb.run.log_code("./src/") 17 | wandb.save(os.path.join(cfg.hydra_output_dir, ".hydra/*.yaml")) 18 | 19 | def logging_with_step(self, value, step, logging_name): 20 | if self.cfg.log.use_tensorboard: 21 | self.tensorboard.add_scalar(logging_name, value, step) 22 | if self.cfg.log.use_wandb: 23 | wandb.log({logging_name: value}, step=step) 24 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeosXu/LoDa/82304c20c34c1b5bd45f27bd7ab6e9104a285152/tests/__init__.py -------------------------------------------------------------------------------- /tests/model/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NeosXu/LoDa/82304c20c34c1b5bd45f27bd7ab6e9104a285152/tests/model/__init__.py -------------------------------------------------------------------------------- /tests/model/model_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | 5 | from src.model import Model, create_model 6 | from src.utils.loss import get_loss 7 | from tests.test_case import ProjectTestCase 8 | 9 | 10 | class TestModel(ProjectTestCase): 11 | @classmethod 12 | def setup_class(cls): 13 | cls.model_input = torch.rand(4, 3, 224, 224) 14 | cls.model_target = torch.randn(4) * 4 + 1 15 | 16 | def setup_method(self, method): 17 | super(TestModel, self).setup_method() 18 | self.net = create_model(cfg=self.cfg) 19 | self.loss_f = get_loss(cfg=self.cfg) 20 | self.model = Model(self.cfg, self.net, self.loss_f) 21 | 22 | def test_model(self): 23 | assert self.model.cfg == self.cfg 24 | assert self.model.net == self.net 25 | assert self.model.loss_f == self.loss_f 26 | 27 | def test_run_network(self): 28 | output = self.model.run_network(self.model_input.to(self.cfg.dist.device)) 29 | assert output.shape == self.model_target.unsqueeze(1).shape 30 | 31 | def test_optimize_parameters(self): 32 | self.model.optimize_parameters( 33 | self.model_input.to(self.cfg.dist.device), 34 | self.model_target.to(self.cfg.dist.device), 35 | ) 36 | assert self.model.log.loss_v is not None 37 | 38 | def test_inference(self): 39 | output = self.model.inference(self.model_input.to(self.cfg.dist.device)) 40 | assert output.shape == self.model_target.unsqueeze(1).shape 41 | 42 | def test_save_load_network(self): 43 | local_net = create_model(cfg=self.cfg) 44 | self.loss_f = get_loss(cfg=self.cfg) 45 | local_model = Model(self.cfg, local_net, self.loss_f) 46 | 47 | self.model.save_network() 48 | save_filename = "%s_%d.pt" % (self.cfg.name, self.model.step) 49 | save_path = os.path.join(self.cfg.log.chkpt_dir, save_filename) 50 | self.cfg.load.network_chkpt_path = save_path 51 | 52 | assert os.path.exists(save_path) and os.path.isfile(save_path) 53 | 54 | local_model.load_network() 55 | parameters = zip( 56 | list(local_model.net.parameters()), list(self.model.net.parameters()) 57 | ) 58 | for load, origin in parameters: 59 | assert (load == origin).all() 60 | 61 | def test_save_load_state(self): 62 | local_net = create_model(cfg=self.cfg) 63 | self.loss_f = get_loss(cfg=self.cfg) 64 | local_model = Model(self.cfg, local_net, self.loss_f) 65 | 66 | self.model.save_training_state() 67 | save_filename = 
"%s_%d.state" % (self.cfg.name, self.model.step) 68 | save_path = os.path.join(self.cfg.log.chkpt_dir, save_filename) 69 | self.cfg.load.resume_state_path = save_path 70 | 71 | assert os.path.exists(save_path) and os.path.isfile(save_path) 72 | 73 | local_model.load_training_state() 74 | parameters = zip( 75 | list(local_model.net.parameters()), list(self.model.net.parameters()) 76 | ) 77 | for load, origin in parameters: 78 | assert (load == origin).all() 79 | assert local_model.epoch == self.model.epoch 80 | assert local_model.step == self.model.step 81 | -------------------------------------------------------------------------------- /tests/model/net_arch_test.py: -------------------------------------------------------------------------------- 1 | import os 2 | import tempfile 3 | 4 | import torch 5 | from hydra import compose, initialize 6 | 7 | from src.model import create_model 8 | 9 | TEST_DIR = tempfile.mkdtemp(prefix="project_tests") 10 | 11 | 12 | def test_net_arch(): 13 | os.makedirs(TEST_DIR, exist_ok=True) 14 | with initialize(version_base=None, config_path="../../config"): 15 | cfg = compose(config_name="default", overrides=[f"working_dir={TEST_DIR}"]) 16 | 17 | net = create_model(cfg) 18 | 19 | model_input = torch.rand(4, 3, 224, 224) 20 | model_target = torch.randn(4) * 4 + 1 21 | 22 | output = net(model_input) 23 | assert output.shape == model_target.unsqueeze(1).shape 24 | -------------------------------------------------------------------------------- /tests/test_case.py: -------------------------------------------------------------------------------- 1 | # ref: https://github.com/allenai/allennlp/blob/9c51d6c89875b3a3a50cac165d6f3188d9941c5b/allennlp/common/testing/test_case.py 2 | 3 | import os 4 | import pathlib 5 | import shutil 6 | import tempfile 7 | 8 | from hydra import compose, initialize 9 | from omegaconf import OmegaConf, open_dict 10 | 11 | from src.utils.utils import get_logger 12 | 13 | TEST_DIR = tempfile.mkdtemp(prefix="project_tests") 14 | 15 | 16 | class ProjectTestCase: 17 | def setup_method(self): 18 | # set log/checkpoint dir 19 | self.TEST_DIR = pathlib.Path(TEST_DIR) 20 | self.working_dir = self.TEST_DIR 21 | chkpt_dir = (self.TEST_DIR / "chkpt").resolve() 22 | os.makedirs(self.TEST_DIR, exist_ok=True) 23 | os.makedirs(chkpt_dir, exist_ok=True) 24 | 25 | # set cfg 26 | OmegaConf.register_new_resolver("eval", eval, replace=True) 27 | with initialize(version_base=None, config_path="../config"): 28 | self.cfg = compose( 29 | config_name="default", overrides=[f"working_dir={self.working_dir}"] 30 | ) 31 | self.cfg.dist.device = "cpu" 32 | self.cfg.log.chkpt_dir = str(chkpt_dir) 33 | self.cfg.log.use_wandb = False 34 | self.cfg.log.use_tensorboard = False 35 | 36 | # load job_logging_cfg 37 | project_root_path = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) 38 | 39 | _ = OmegaConf.load(os.path.join(project_root_path, "config/default.yaml")) 40 | 41 | # extract logging config of hydra 42 | logging_cfg_path = os.path.join( 43 | project_root_path, "config/hydra/job_logging/custom.yaml" 44 | ) 45 | if os.path.exists(logging_cfg_path): 46 | logging_cfg = OmegaConf.load(logging_cfg_path) 47 | else: 48 | logging_cfg = dict() 49 | with open_dict(self.cfg): 50 | self.cfg.job_logging_cfg = logging_cfg 51 | 52 | # set log file to dummy file 53 | self.cfg.job_logging_cfg.handlers.file.filename = str( 54 | (self.working_dir / "trainer.log").resolve() 55 | ) 56 | 57 | # set logger 58 | self.logger = get_logger(self.cfg, os.path.basename(__file__)) 59 | 
60 | def teardown_method(self): 61 | shutil.rmtree(self.TEST_DIR) 62 | --------------------------------------------------------------------------------
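Appendix (editor's illustrative sketch, not part of the repository): minimal single-image scoring with a trained checkpoint, mirroring tests/test_case.py and tests/model/net_arch_test.py. The working directory and checkpoint path below are placeholders, and the snippet assumes it is run from the repo root so that Hydra can find config/.

import os

import torch
from hydra import compose, initialize
from omegaconf import OmegaConf, open_dict

from src.model import Model, create_model
from src.utils.loss import get_loss

work_dir = "/tmp/loda_demo"  # placeholder working directory
os.makedirs(work_dir, exist_ok=True)

# config_path is resolved relative to this file, so "config" assumes the repo root
with initialize(version_base=None, config_path="config"):
    cfg = compose(config_name="default", overrides=[f"working_dir={work_dir}"])
cfg.dist.device = "cpu"

# Model.__init__ calls get_logger(), which expects cfg.job_logging_cfg; trainer.py
# and eval.py inject it from Hydra, so mimic tests/test_case.py here instead.
with open_dict(cfg):
    cfg.job_logging_cfg = OmegaConf.load("config/hydra/job_logging/custom.yaml")
cfg.job_logging_cfg.handlers.file.filename = os.path.join(work_dir, "demo.log")

net = create_model(cfg)  # frozen ViT backbone + trainable LoDa learner
model = Model(cfg, net, get_loss(cfg))
cfg.load.network_chkpt_path = "chkpt/loda_demo.pt"  # placeholder checkpoint path
model.load_network()
score = model.inference(torch.rand(1, 3, 224, 224))  # (1, 1) tensor of predicted MOS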