├── .gitattributes ├── .github ├── scripts │ ├── make_executable.sh │ └── update_readme.py └── workflows │ ├── install-and-test-unibench.yml │ ├── publish.yml │ └── update-readme.yml ├── .gitignore ├── .pyre_configuration ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── assets ├── UPDATES.md └── header.png ├── benchmark_builder ├── README.md ├── __init__.py ├── benchmarks │ ├── __init__.py │ └── sundataset.py ├── class_names.py ├── main.py ├── templates.py └── utils.py ├── environment.yml ├── figures ├── example_figure1.ipynb ├── example_figure2.ipynb └── siglip2_results.ipynb ├── setup.py └── unibench ├── __init__.py ├── benchmarks_zoo ├── README.md ├── __init__.py ├── benchmarks.py ├── registry.py └── wrappers │ ├── __init__.py │ ├── bechmark_handler.py │ └── huggingface.py ├── common_utils ├── __init__.py ├── constants.py └── utils.py ├── evaluator.py ├── models_zoo ├── README.md ├── __init__.py ├── models.py ├── registry.py └── wrappers │ ├── __init__.py │ ├── base.py │ ├── blip.py │ ├── clip.py │ ├── huggingface.py │ ├── lit.py │ ├── transformations │ ├── __init__.py │ ├── faceblur.py │ └── grayscale2rgb.py │ └── xvlm_util │ ├── __init__.py │ ├── box_ops.py │ ├── clip_vit.py │ ├── swin_transformer.py │ ├── tokenization_bert.py │ ├── tokenization_roberta.py │ ├── utils.py │ ├── vit.py │ ├── xbert.py │ ├── xroberta.py │ └── xvlm.py └── output.py /.gitattributes: -------------------------------------------------------------------------------- 1 | *.f filter=lfs diff=lfs merge=lfs -text -------------------------------------------------------------------------------- /.github/scripts/make_executable.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | chmod +x .github/scripts/update_readme.py 3 | -------------------------------------------------------------------------------- /.github/scripts/update_readme.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import re 4 | import os 5 | 6 | def get_latest_update(): 7 | """Extract the latest update section from UPDATES.md""" 8 | with open("assets/UPDATES.md", "r") as file: 9 | content = file.read() 10 | 11 | # Find all update sections (starting with ## and followed by date/version) 12 | update_sections = re.findall(r'(## .+?(?=## |\Z))', content, re.DOTALL) 13 | 14 | if not update_sections: 15 | return None 16 | 17 | # Return the first (latest) update section 18 | latest_update = update_sections[0].strip() 19 | 20 | # Convert the section to a list of lines for easier processing 21 | update_lines = latest_update.split('\n') 22 | 23 | # Get the heading (first line) and change ## to #### 24 | heading = update_lines[0].replace('##', '####') 25 | 26 | # Get the content (everything after the heading) 27 | content_lines = update_lines[1:] 28 | content = '\n'.join([line.strip() for line in content_lines if line.strip()]) 29 | 30 | return { 31 | 'heading': heading, 32 | 'content': content 33 | } 34 | 35 | def update_readme(latest_update): 36 | """Update the News and Updates section in README.md""" 37 | with open("README.md", "r") as file: 38 | readme_content = file.read() 39 | 40 | # Define patterns to find the start of News section and the next section after it 41 | news_section_start = r'## News and Updates\s+For the latest news and updates, see the snippet below.\s+' 42 | news_start_match = re.search(news_section_start, readme_content) 43 | 44 | if not news_start_match: 45 | print("News and Updates section 
not found in README.md") 46 | return False 47 | 48 | # Find the next section header after News and Updates 49 | next_section_pattern = r'(?m)^## [^#]' 50 | next_sections = list(re.finditer(next_section_pattern, readme_content)) 51 | 52 | # Find the position of the section after News and Updates 53 | news_section_pos = news_start_match.start() 54 | next_section_pos = None 55 | 56 | for section in next_sections: 57 | if section.start() > news_section_pos: 58 | next_section_pos = section.start() 59 | break 60 | 61 | if not next_section_pos: 62 | print("Could not find the section after News and Updates") 63 | return False 64 | 65 | # Extract the parts before News section, the News section intro, and after the News section 66 | before_news = readme_content[:news_start_match.end()] 67 | after_news = readme_content[next_section_pos:] 68 | 69 | # Create the new README content with a single update section 70 | new_update_section = f"{latest_update['heading']}\n{latest_update['content']}\n \nFor full details, refer to the [UPDATES.md](./assets/UPDATES.md) file.\n\n" 71 | updated_readme = before_news + new_update_section + after_news 72 | 73 | # Write the updated content back to README.md 74 | with open("README.md", "w") as file: 75 | file.write(updated_readme) 76 | 77 | return True 78 | 79 | def main(): 80 | latest_update = get_latest_update() 81 | if latest_update: 82 | success = update_readme(latest_update) 83 | if success: 84 | print(f"Successfully updated README.md with latest update: {latest_update['heading']}") 85 | else: 86 | print("Failed to update README.md") 87 | else: 88 | print("No updates found in UPDATES.md") 89 | 90 | if __name__ == "__main__": 91 | main() 92 | -------------------------------------------------------------------------------- /.github/workflows/install-and-test-unibench.yml: -------------------------------------------------------------------------------- 1 | name: Python application 2 | 3 | on: 4 | push: 5 | branches: ["main"] 6 | pull_request: 7 | branches: ["main"] 8 | 9 | permissions: 10 | contents: read 11 | 12 | jobs: 13 | build: 14 | runs-on: ubuntu-latest 15 | 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 3.10 19 | uses: actions/setup-python@v3 20 | with: 21 | python-version: "3.10" 22 | - name: Install dependencies 23 | run: | 24 | pip install -e . 
25 | pip install flake8 pytest 26 | - name: Test with pytest 27 | run: | 28 | python -m pytest tests/ 29 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Upload Python Package to PyPI when a Release is Created 2 | 3 | on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | pypi-publish: 9 | name: Publish release to PyPI 10 | runs-on: ubuntu-latest 11 | environment: 12 | name: pypi 13 | url: https://pypi.org/p/unibench 14 | permissions: 15 | id-token: write 16 | steps: 17 | - uses: actions/checkout@v4 18 | - name: Set up Python 19 | uses: actions/setup-python@v4 20 | with: 21 | python-version: "3.x" 22 | - name: Install dependencies 23 | run: | 24 | python -m pip install --upgrade pip 25 | pip install setuptools wheel 26 | - name: Build package 27 | run: | 28 | python setup.py sdist bdist_wheel # Could also be python -m build 29 | - name: Publish package distributions to PyPI 30 | uses: pypa/gh-action-pypi-publish@release/v1 31 | -------------------------------------------------------------------------------- /.github/workflows/update-readme.yml: -------------------------------------------------------------------------------- 1 | name: Update README News Section 2 | 3 | on: 4 | push: 5 | paths: 6 | - 'assets/UPDATES.md' 7 | 8 | jobs: 9 | update-readme: 10 | runs-on: ubuntu-latest 11 | permissions: 12 | contents: write 13 | steps: 14 | - name: Check out repository 15 | uses: actions/checkout@v4 16 | 17 | - name: Set up Python 18 | uses: actions/setup-python@v4 19 | with: 20 | python-version: '3.10' 21 | 22 | - name: Update README 23 | run: python .github/scripts/update_readme.py 24 | 25 | - name: Commit changes 26 | run: | 27 | git config --local user.email "action@github.com" 28 | git config --local user.name "GitHub Action" 29 | git add README.md 30 | git diff --staged --quiet || git commit -m "Update News section in README [skip ci]" 31 | git push 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by https://www.toptal.com/developers/gitignore/api/python 2 | # Edit at https://www.toptal.com/developers/gitignore?templates=python 3 | 4 | basic_test.py 5 | dataset_builder/ 6 | test_* 7 | 8 | ### Python ### 9 | .vscode/ 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | share/python-wheels/ 34 | *.egg-info/ 35 | .installed.cfg 36 | *.egg 37 | MANIFEST 38 | 39 | # PyInstaller 40 | # Usually these files are written by a python script from a template 41 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
42 | *.manifest 43 | *.spec 44 | 45 | # Installer logs 46 | pip-log.txt 47 | pip-delete-this-directory.txt 48 | 49 | # Unit test / coverage reports 50 | htmlcov/ 51 | .tox/ 52 | .nox/ 53 | .coverage 54 | .coverage.* 55 | .cache 56 | nosetests.xml 57 | coverage.xml 58 | *.cover 59 | *.py,cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | cover/ 63 | 64 | # Translations 65 | *.mo 66 | *.pot 67 | 68 | # Django stuff: 69 | *.log 70 | local_settings.py 71 | db.sqlite3 72 | db.sqlite3-journal 73 | 74 | # Flask stuff: 75 | instance/ 76 | .webassets-cache 77 | 78 | # Scrapy stuff: 79 | .scrapy 80 | 81 | # Sphinx documentation 82 | docs/_build/ 83 | 84 | # PyBuilder 85 | .pybuilder/ 86 | target/ 87 | 88 | # Jupyter Notebook 89 | .ipynb_checkpoints 90 | 91 | # IPython 92 | profile_default/ 93 | ipython_config.py 94 | 95 | # pyenv 96 | # For a library or package, you might want to ignore these files since the code is 97 | # intended to run in multiple environments; otherwise, check them in: 98 | # .python-version 99 | 100 | # pipenv 101 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 102 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 103 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 104 | # install all needed dependencies. 105 | #Pipfile.lock 106 | 107 | # poetry 108 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 109 | # This is especially recommended for binary packages to ensure reproducibility, and is more 110 | # commonly ignored for libraries. 111 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 112 | #poetry.lock 113 | 114 | # pdm 115 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 116 | #pdm.lock 117 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 118 | # in version control. 119 | # https://pdm.fming.dev/#use-with-ide 120 | .pdm.toml 121 | 122 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 123 | __pypackages__/ 124 | 125 | # Celery stuff 126 | celerybeat-schedule 127 | celerybeat.pid 128 | 129 | # SageMath parsed files 130 | *.sage.py 131 | 132 | # Environments 133 | .env 134 | .venv 135 | env/ 136 | venv/ 137 | ENV/ 138 | env.bak/ 139 | venv.bak/ 140 | 141 | # Spyder project settings 142 | .spyderproject 143 | .spyproject 144 | 145 | # Rope project settings 146 | .ropeproject 147 | 148 | # mkdocs documentation 149 | /site 150 | 151 | # mypy 152 | .mypy_cache/ 153 | .dmypy.json 154 | dmypy.json 155 | 156 | # Pyre type checker 157 | .pyre/ 158 | 159 | # pytype static type analyzer 160 | .pytype/ 161 | 162 | # Cython debug symbols 163 | cython_debug/ 164 | 165 | # PyCharm 166 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 167 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 168 | # and can be added to the global gitignore or merged into this file. For a more nuclear 169 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
170 | #.idea/ 171 | 172 | ### Python Patch ### 173 | # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration 174 | poetry.toml 175 | 176 | # ruff 177 | .ruff_cache/ 178 | 179 | # LSP config files 180 | pyrightconfig.json 181 | 182 | # Remove *.out files 183 | *.out 184 | 185 | # End of https://www.toptal.com/developers/gitignore/api/python -------------------------------------------------------------------------------- /.pyre_configuration: -------------------------------------------------------------------------------- 1 | { 2 | "site_package_search_strategy": "pep561", 3 | "source_directories": [ 4 | "." 5 | ], 6 | "exclude": [ 7 | ".*build/.*" 8 | ] 9 | } 10 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 
54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to UniBench 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Meta's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Meta has a [bounty program](https://bugbounty.meta.com/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to UniBench, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Attribution-NonCommercial 4.0 International 2 | 3 | ======================================================================= 4 | 5 | Creative Commons Corporation ("Creative Commons") is not a law firm and 6 | does not provide legal services or legal advice. Distribution of 7 | Creative Commons public licenses does not create a lawyer-client or 8 | other relationship. Creative Commons makes its licenses and related 9 | information available on an "as-is" basis. 
Creative Commons gives no 10 | warranties regarding its licenses, any material licensed under their 11 | terms and conditions, or any related information. Creative Commons 12 | disclaims all liability for damages resulting from their use to the 13 | fullest extent possible. 14 | 15 | Using Creative Commons Public Licenses 16 | 17 | Creative Commons public licenses provide a standard set of terms and 18 | conditions that creators and other rights holders may use to share 19 | original works of authorship and other material subject to copyright 20 | and certain other rights specified in the public license below. The 21 | following considerations are for informational purposes only, are not 22 | exhaustive, and do not form part of our licenses. 23 | 24 | Considerations for licensors: Our public licenses are 25 | intended for use by those authorized to give the public 26 | permission to use material in ways otherwise restricted by 27 | copyright and certain other rights. Our licenses are 28 | irrevocable. Licensors should read and understand the terms 29 | and conditions of the license they choose before applying it. 30 | Licensors should also secure all rights necessary before 31 | applying our licenses so that the public can reuse the 32 | material as expected. Licensors should clearly mark any 33 | material not subject to the license. This includes other CC- 34 | licensed material, or material used under an exception or 35 | limitation to copyright. More considerations for licensors: 36 | wiki.creativecommons.org/Considerations_for_licensors 37 | 38 | Considerations for the public: By using one of our public 39 | licenses, a licensor grants the public permission to use the 40 | licensed material under specified terms and conditions. If 41 | the licensor's permission is not necessary for any reason--for 42 | example, because of any applicable exception or limitation to 43 | copyright--then that use is not regulated by the license. Our 44 | licenses grant only permissions under copyright and certain 45 | other rights that a licensor has authority to grant. Use of 46 | the licensed material may still be restricted for other 47 | reasons, including because others have copyright or other 48 | rights in the material. A licensor may make special requests, 49 | such as asking that all changes be marked or described. 50 | Although not required by our licenses, you are encouraged to 51 | respect those requests where reasonable. More considerations 52 | for the public: 53 | wiki.creativecommons.org/Considerations_for_licensees 54 | 55 | ======================================================================= 56 | 57 | Creative Commons Attribution-NonCommercial 4.0 International Public 58 | License 59 | 60 | By exercising the Licensed Rights (defined below), You accept and agree 61 | to be bound by the terms and conditions of this Creative Commons 62 | Attribution-NonCommercial 4.0 International Public License ("Public 63 | License"). To the extent this Public License may be interpreted as a 64 | contract, You are granted the Licensed Rights in consideration of Your 65 | acceptance of these terms and conditions, and the Licensor grants You 66 | such rights in consideration of benefits the Licensor receives from 67 | making the Licensed Material available under these terms and 68 | conditions. 69 | 70 | 71 | Section 1 -- Definitions. 72 | 73 | a. 
Adapted Material means material subject to Copyright and Similar 74 | Rights that is derived from or based upon the Licensed Material 75 | and in which the Licensed Material is translated, altered, 76 | arranged, transformed, or otherwise modified in a manner requiring 77 | permission under the Copyright and Similar Rights held by the 78 | Licensor. For purposes of this Public License, where the Licensed 79 | Material is a musical work, performance, or sound recording, 80 | Adapted Material is always produced where the Licensed Material is 81 | synched in timed relation with a moving image. 82 | 83 | b. Adapter's License means the license You apply to Your Copyright 84 | and Similar Rights in Your contributions to Adapted Material in 85 | accordance with the terms and conditions of this Public License. 86 | 87 | c. Copyright and Similar Rights means copyright and/or similar rights 88 | closely related to copyright including, without limitation, 89 | performance, broadcast, sound recording, and Sui Generis Database 90 | Rights, without regard to how the rights are labeled or 91 | categorized. For purposes of this Public License, the rights 92 | specified in Section 2(b)(1)-(2) are not Copyright and Similar 93 | Rights. 94 | d. Effective Technological Measures means those measures that, in the 95 | absence of proper authority, may not be circumvented under laws 96 | fulfilling obligations under Article 11 of the WIPO Copyright 97 | Treaty adopted on December 20, 1996, and/or similar international 98 | agreements. 99 | 100 | e. Exceptions and Limitations means fair use, fair dealing, and/or 101 | any other exception or limitation to Copyright and Similar Rights 102 | that applies to Your use of the Licensed Material. 103 | 104 | f. Licensed Material means the artistic or literary work, database, 105 | or other material to which the Licensor applied this Public 106 | License. 107 | 108 | g. Licensed Rights means the rights granted to You subject to the 109 | terms and conditions of this Public License, which are limited to 110 | all Copyright and Similar Rights that apply to Your use of the 111 | Licensed Material and that the Licensor has authority to license. 112 | 113 | h. Licensor means the individual(s) or entity(ies) granting rights 114 | under this Public License. 115 | 116 | i. NonCommercial means not primarily intended for or directed towards 117 | commercial advantage or monetary compensation. For purposes of 118 | this Public License, the exchange of the Licensed Material for 119 | other material subject to Copyright and Similar Rights by digital 120 | file-sharing or similar means is NonCommercial provided there is 121 | no payment of monetary compensation in connection with the 122 | exchange. 123 | 124 | j. Share means to provide material to the public by any means or 125 | process that requires permission under the Licensed Rights, such 126 | as reproduction, public display, public performance, distribution, 127 | dissemination, communication, or importation, and to make material 128 | available to the public including in ways that members of the 129 | public may access the material from a place and at a time 130 | individually chosen by them. 131 | 132 | k. Sui Generis Database Rights means rights other than copyright 133 | resulting from Directive 96/9/EC of the European Parliament and of 134 | the Council of 11 March 1996 on the legal protection of databases, 135 | as amended and/or succeeded, as well as other essentially 136 | equivalent rights anywhere in the world. 
137 | 138 | l. You means the individual or entity exercising the Licensed Rights 139 | under this Public License. Your has a corresponding meaning. 140 | 141 | 142 | Section 2 -- Scope. 143 | 144 | a. License grant. 145 | 146 | 1. Subject to the terms and conditions of this Public License, 147 | the Licensor hereby grants You a worldwide, royalty-free, 148 | non-sublicensable, non-exclusive, irrevocable license to 149 | exercise the Licensed Rights in the Licensed Material to: 150 | 151 | a. reproduce and Share the Licensed Material, in whole or 152 | in part, for NonCommercial purposes only; and 153 | 154 | b. produce, reproduce, and Share Adapted Material for 155 | NonCommercial purposes only. 156 | 157 | 2. Exceptions and Limitations. For the avoidance of doubt, where 158 | Exceptions and Limitations apply to Your use, this Public 159 | License does not apply, and You do not need to comply with 160 | its terms and conditions. 161 | 162 | 3. Term. The term of this Public License is specified in Section 163 | 6(a). 164 | 165 | 4. Media and formats; technical modifications allowed. The 166 | Licensor authorizes You to exercise the Licensed Rights in 167 | all media and formats whether now known or hereafter created, 168 | and to make technical modifications necessary to do so. The 169 | Licensor waives and/or agrees not to assert any right or 170 | authority to forbid You from making technical modifications 171 | necessary to exercise the Licensed Rights, including 172 | technical modifications necessary to circumvent Effective 173 | Technological Measures. For purposes of this Public License, 174 | simply making modifications authorized by this Section 2(a) 175 | (4) never produces Adapted Material. 176 | 177 | 5. Downstream recipients. 178 | 179 | a. Offer from the Licensor -- Licensed Material. Every 180 | recipient of the Licensed Material automatically 181 | receives an offer from the Licensor to exercise the 182 | Licensed Rights under the terms and conditions of this 183 | Public License. 184 | 185 | b. No downstream restrictions. You may not offer or impose 186 | any additional or different terms or conditions on, or 187 | apply any Effective Technological Measures to, the 188 | Licensed Material if doing so restricts exercise of the 189 | Licensed Rights by any recipient of the Licensed 190 | Material. 191 | 192 | 6. No endorsement. Nothing in this Public License constitutes or 193 | may be construed as permission to assert or imply that You 194 | are, or that Your use of the Licensed Material is, connected 195 | with, or sponsored, endorsed, or granted official status by, 196 | the Licensor or others designated to receive attribution as 197 | provided in Section 3(a)(1)(A)(i). 198 | 199 | b. Other rights. 200 | 201 | 1. Moral rights, such as the right of integrity, are not 202 | licensed under this Public License, nor are publicity, 203 | privacy, and/or other similar personality rights; however, to 204 | the extent possible, the Licensor waives and/or agrees not to 205 | assert any such rights held by the Licensor to the limited 206 | extent necessary to allow You to exercise the Licensed 207 | Rights, but not otherwise. 208 | 209 | 2. Patent and trademark rights are not licensed under this 210 | Public License. 211 | 212 | 3. 
To the extent possible, the Licensor waives any right to 213 | collect royalties from You for the exercise of the Licensed 214 | Rights, whether directly or through a collecting society 215 | under any voluntary or waivable statutory or compulsory 216 | licensing scheme. In all other cases the Licensor expressly 217 | reserves any right to collect such royalties, including when 218 | the Licensed Material is used other than for NonCommercial 219 | purposes. 220 | 221 | 222 | Section 3 -- License Conditions. 223 | 224 | Your exercise of the Licensed Rights is expressly made subject to the 225 | following conditions. 226 | 227 | a. Attribution. 228 | 229 | 1. If You Share the Licensed Material (including in modified 230 | form), You must: 231 | 232 | a. retain the following if it is supplied by the Licensor 233 | with the Licensed Material: 234 | 235 | i. identification of the creator(s) of the Licensed 236 | Material and any others designated to receive 237 | attribution, in any reasonable manner requested by 238 | the Licensor (including by pseudonym if 239 | designated); 240 | 241 | ii. a copyright notice; 242 | 243 | iii. a notice that refers to this Public License; 244 | 245 | iv. a notice that refers to the disclaimer of 246 | warranties; 247 | 248 | v. a URI or hyperlink to the Licensed Material to the 249 | extent reasonably practicable; 250 | 251 | b. indicate if You modified the Licensed Material and 252 | retain an indication of any previous modifications; and 253 | 254 | c. indicate the Licensed Material is licensed under this 255 | Public License, and include the text of, or the URI or 256 | hyperlink to, this Public License. 257 | 258 | 2. You may satisfy the conditions in Section 3(a)(1) in any 259 | reasonable manner based on the medium, means, and context in 260 | which You Share the Licensed Material. For example, it may be 261 | reasonable to satisfy the conditions by providing a URI or 262 | hyperlink to a resource that includes the required 263 | information. 264 | 265 | 3. If requested by the Licensor, You must remove any of the 266 | information required by Section 3(a)(1)(A) to the extent 267 | reasonably practicable. 268 | 269 | 4. If You Share Adapted Material You produce, the Adapter's 270 | License You apply must not prevent recipients of the Adapted 271 | Material from complying with this Public License. 272 | 273 | 274 | Section 4 -- Sui Generis Database Rights. 275 | 276 | Where the Licensed Rights include Sui Generis Database Rights that 277 | apply to Your use of the Licensed Material: 278 | 279 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right 280 | to extract, reuse, reproduce, and Share all or a substantial 281 | portion of the contents of the database for NonCommercial purposes 282 | only; 283 | 284 | b. if You include all or a substantial portion of the database 285 | contents in a database in which You have Sui Generis Database 286 | Rights, then the database in which You have Sui Generis Database 287 | Rights (but not its individual contents) is Adapted Material; and 288 | 289 | c. You must comply with the conditions in Section 3(a) if You Share 290 | all or a substantial portion of the contents of the database. 291 | 292 | For the avoidance of doubt, this Section 4 supplements and does not 293 | replace Your obligations under this Public License where the Licensed 294 | Rights include other Copyright and Similar Rights. 295 | 296 | 297 | Section 5 -- Disclaimer of Warranties and Limitation of Liability. 298 | 299 | a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE 300 | EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS 301 | AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF 302 | ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, 303 | IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, 304 | WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR 305 | PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, 306 | ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT 307 | KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT 308 | ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. 309 | 310 | b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE 311 | TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, 312 | NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, 313 | INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, 314 | COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR 315 | USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN 316 | ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR 317 | DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR 318 | IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. 319 | 320 | c. The disclaimer of warranties and limitation of liability provided 321 | above shall be interpreted in a manner that, to the extent 322 | possible, most closely approximates an absolute disclaimer and 323 | waiver of all liability. 324 | 325 | 326 | Section 6 -- Term and Termination. 327 | 328 | a. This Public License applies for the term of the Copyright and 329 | Similar Rights licensed here. However, if You fail to comply with 330 | this Public License, then Your rights under this Public License 331 | terminate automatically. 332 | 333 | b. Where Your right to use the Licensed Material has terminated under 334 | Section 6(a), it reinstates: 335 | 336 | 1. automatically as of the date the violation is cured, provided 337 | it is cured within 30 days of Your discovery of the 338 | violation; or 339 | 340 | 2. upon express reinstatement by the Licensor. 341 | 342 | For the avoidance of doubt, this Section 6(b) does not affect any 343 | right the Licensor may have to seek remedies for Your violations 344 | of this Public License. 345 | 346 | c. For the avoidance of doubt, the Licensor may also offer the 347 | Licensed Material under separate terms or conditions or stop 348 | distributing the Licensed Material at any time; however, doing so 349 | will not terminate this Public License. 350 | 351 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public 352 | License. 353 | 354 | 355 | Section 7 -- Other Terms and Conditions. 356 | 357 | a. The Licensor shall not be bound by any additional or different 358 | terms or conditions communicated by You unless expressly agreed. 359 | 360 | b. Any arrangements, understandings, or agreements regarding the 361 | Licensed Material not stated herein are separate from and 362 | independent of the terms and conditions of this Public License. 363 | 364 | 365 | Section 8 -- Interpretation. 366 | 367 | a. For the avoidance of doubt, this Public License does not, and 368 | shall not be interpreted to, reduce, limit, restrict, or impose 369 | conditions on any use of the Licensed Material that could lawfully 370 | be made without permission under this Public License. 371 | 372 | b. 
To the extent possible, if any provision of this Public License is 373 | deemed unenforceable, it shall be automatically reformed to the 374 | minimum extent necessary to make it enforceable. If the provision 375 | cannot be reformed, it shall be severed from this Public License 376 | without affecting the enforceability of the remaining terms and 377 | conditions. 378 | 379 | c. No term or condition of this Public License will be waived and no 380 | failure to comply consented to unless expressly agreed to by the 381 | Licensor. 382 | 383 | d. Nothing in this Public License constitutes or may be interpreted 384 | as a limitation upon, or waiver of, any privileges and immunities 385 | that apply to the Licensor or You, including from the legal 386 | processes of any jurisdiction or authority. 387 | 388 | ======================================================================= 389 | 390 | Creative Commons is not a party to its public 391 | licenses. Notwithstanding, Creative Commons may elect to apply one of 392 | its public licenses to material it publishes and in those instances 393 | will be considered the “Licensor.” The text of the Creative Commons 394 | public licenses is dedicated to the public domain under the CC0 Public 395 | Domain Dedication. Except for the limited purpose of indicating that 396 | material is shared under a Creative Commons public license or as 397 | otherwise permitted by the Creative Commons policies published at 398 | creativecommons.org/policies, Creative Commons does not authorize the 399 | use of the trademark "Creative Commons" or any other trademark or logo 400 | of Creative Commons without its prior written consent including, 401 | without limitation, in connection with any unauthorized modifications 402 | to any of its public licenses or any other arrangements, 403 | understandings, or agreements concerning use of licensed material. For 404 | the avoidance of doubt, this paragraph does not form part of the 405 | public licenses. 406 | 407 | Creative Commons may be contacted at creativecommons.org. 408 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | | ![header](./assets/header.png "header") | 3 | | :------------------------------------------------: | 4 | | *[[Arxiv link](https://arxiv.org/abs/2408.04810)]* | 5 | 6 |

7 | Getting Started •
8 | Usage •
9 | Benchmarks & Models •
10 | Credit & Citation
11 | 

12 | 
13 | # Vision-Language Model Evaluation Repository
14 | 
15 | This repository is designed to simplify the evaluation process of vision-language models. It provides a comprehensive set of tools and scripts for evaluating VLMs and benchmarks. We offer 60+ VLMs, inclusive of recent large-scale models like EVACLIP, with scales reaching up to 4.3B parameters and 12.8B training samples. Additionally, we provide implementations for 40+ evaluation benchmarks.
16 | 
17 | ## News and Updates
18 | 
19 | For the latest news and updates, see the snippet below.
20 | 
21 | #### April 15, 2025 - v0.4.0
22 | - Removed FaceNet from required libraries.
23 | - Added SigLIP2 models
24 | - Added bivlc benchmark
25 | - Created benchmark_builder for future benchmark implementations
26 | - Added News & Updates section in README
27 | - Fixed Sun397 benchmark
28 | 
29 | For full details, refer to the [UPDATES.md](./assets/UPDATES.md) file.
30 | 
31 | ## Coming Soon
32 | 
33 | - [ ] L-VLM (e.g. PaliGemma, LlavaNext)
34 | 
35 | ## Getting Started
36 | 
37 | Install the package:
38 | ```
39 | pip install unibench -U
40 | ```
41 | 
42 | 
43 | [option 2] Install Dependencies
45 | 
46 | 1. Install the necessary dependencies by:
47 |    - Option 1, creating a new conda env: `conda env create -f environment.yml`
48 |    - Option 2, updating your conda env with required libraries: `conda env update --file environment.yml --prune`
49 | 2. Activate the environment: `conda activate unibench`
50 | 3. Install the spaCy English language model: `python -m spacy download en_core_web_sm`
51 | 4. Install the package: `pip install git+https://github.com/facebookresearch/unibench`
52 | 
53 | 
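Either way, a quick sanity check for the installation is to call the `unibench` command-line entry point and list what is registered — these are the same `list_models` / `list_benchmarks` commands documented in the Supported Models and benchmarks section below, and this assumes the console script ended up on your `PATH`:

```console
unibench list_models
# or
unibench list_benchmarks
```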
54 | 55 | ## Usage 56 | 57 | ### Print out Results from Evaluated Models 58 | 59 | The following command will print the results of the evaluations on all benchmarks and models: 60 | 61 | ```console 62 | unibench show_results 63 | ``` 64 | 65 | ### Run Evaluation using Command Line 66 | 67 | The following command will run the evaluation on all benchmarks and models: 68 | 69 | ```console 70 | unibench evaluate 71 | ``` 72 | 73 | ### Run Evaluation using Custom Script 74 | 75 | The following command will run the evaluation on all benchmarks and models: 76 | 77 | ```python 78 | import unibench as vlm 79 | 80 | evaluator = vlm.Evaluator() 81 | evaluator.evaluate() 82 | ``` 83 | 84 | ### Arguments for Evaluation 85 | 86 | `evaluate` function takes the following arguments: 87 | 88 | ```console 89 | Args: 90 | save_freq (int): The frequency at which to save results. Defaults to 1000. 91 | face_blur (bool): Whether to use face blurring during evaluation. Defaults to False. 92 | device (str): The device to use for evaluation. Defaults to "cuda" if available otherwise "cpu". 93 | batch_per_gpu (int): Evaluation batch size per GPU. Defaults to 32. 94 | ``` 95 | 96 | 97 | The `Evaluator` class takes the following arguments: 98 | 99 | ```console 100 | Args: 101 | seed (int): Random seed for reproducibility. 102 | num_workers (int): Number of workers for data loading. 103 | models (Union[List[str], str]): List of models to evaluate or "all" to evaluate all available models. 104 | benchmarks (Union[List[str], str]): List of benchmarks to evaluate or "all" to evaluate all available benchmarks. 105 | model_id (Union[int, None]): Specific model ID to evaluate. 106 | benchmark_id (Union[int, None]): Specific benchmark ID to evaluate. 107 | output_dir (str): Directory to save evaluation results. 108 | benchmarks_dir (str): Directory containing benchmark data. 109 | download_aggregate_precomputed (bool): Whether to download aggregate precomputed results. 110 | download_all_precomputed (bool): Whether to download all precomputed results. 111 | ``` 112 | 113 | ### Example 114 | 115 | The following command will run the evaluation for openclip_vitB32 trained on metaclip400m and CLIP ResNet50 on vg_relation,clevr_distance,pcam,imageneta benchmarks: 116 | 117 | ```console 118 | unibench evaluate --models=[openclip_vitB32_metaclip_400m,clip_resnet50] --benchmarks=[vg_relation,clevr_distance,pcam,imageneta] 119 | ``` 120 | 121 | In addition to saving the results in `~/.cache/unibench`, the output would be a summary of the evaluation results: 122 | 123 | ```console 124 | model_name non-natural images reasoning relation robustness 125 | ──────────────────────────────────────────────────────────────────────────────────────── 126 | clip_resnet50 63.95 14.89 54.13 23.27 127 | openclip_vitB32_metaclip_400m 63.87 19.46 51.54 28.71 128 | ``` 129 | 130 | ## Supported Models and benchmarks 131 | Full list of models and benchmarks are available in the [models_zoo](unibench/models_zoo/README.md) and [benchmarks_zoo](unibench/benchmarks_zoo/README.md). 
You are also able to run the following commands: 132 | 133 | ```console 134 | unibench list_models 135 | # or 136 | unibench list_benchmarks 137 | ``` 138 | 139 | ### Sample Models 140 | 141 | 142 | | | Dataset Size (Million) | Number of Parameters (Million) | Learning Objective | Architecture | Model Name | 143 | | :----------------- | ---------------------: | -----------------------------: | :----------------- | :----------- | :------------ | 144 | | blip_vitB16_14m | 14 | 86 | BLIP | vit | BLIP ViT B 16 | 145 | | blip_vitL16_129m | 129 | 307 | BLIP | vit | BLIP ViT L 16 | 146 | | blip_vitB16_129m | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 147 | | blip_vitB16_coco | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 148 | | blip_vitB16_flickr | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 149 | 150 | 151 | ### Sample benchmarks 152 | | | benchmark | benchmark_type | 153 | | :------------- | :-------- | :------------- | 154 | | clevr_distance | zero-shot | vtab | 155 | | fgvc_aircraft | zero-shot | transfer | 156 | | objectnet | zero-shot | robustness | 157 | | winoground | relation | relation | 158 | | imagenetc | zero-shot | corruption | 159 | 160 | ### benchmarks Overview 161 | 162 | | benchmark type | number of benchmarks | 163 | | :------------- | :------------------: | 164 | | ImageNet | 1 | 165 | | vtab | 18 | 166 | | transfer | 7 | 167 | | robustness | 6 | 168 | | relation | 6 | 169 | | corruption | 1 | 170 | 171 | 191 | 192 | 193 | 194 | ### How results are saved 195 | 196 | For each model, the results are saved in the output directory defined in constants: `~./.cache/unibench/outputs`. 197 | 198 | 199 | ### Add new Benchmark 200 | 201 | To add new benchmark, you can simply inherit from the `torch.utils.data.Dataset` class and implement the `__getitem__`, and `__len__` methods. 
For example, here is how to add ImageNetA as a new benchmark: 202 | 203 | ```python 204 | from functools import partial 205 | from unibench import Evaluator 206 | from unibench.benchmarks_zoo import ZeroShotBenchmarkHandler 207 | from torchvision.datasets import FashionMNIST 208 | 209 | class_names = [ 210 | "T-shirt/top", 211 | "Trouser", 212 | "Pullover", 213 | "Dress", 214 | "Coat", 215 | "Sandal", 216 | "Shirt", 217 | "Sneaker", 218 | "Bag", 219 | "Ankle boot", 220 | ] 221 | 222 | templates = ["an image of {}"] 223 | 224 | benchmark = partial( 225 | FashionMNIST, root="/fsx-robust/haideraltahan", train=False, download=True 226 | ) 227 | handler = partial( 228 | ZeroShotBenchmarkHandler, 229 | benchmark_name="fashion_mnist_new", 230 | classes=class_names, 231 | templates=templates, 232 | ) 233 | 234 | 235 | eval = Evaluator() 236 | 237 | eval.add_benchmark( 238 | benchmark, 239 | handler, 240 | meta_data={ 241 | "benchmark_type": "object recognition", 242 | }, 243 | ) 244 | eval.update_benchmark_list(["fashion_mnist_new"]) 245 | eval.update_model_list(["blip_vitB16_129m"]) 246 | eval.evaluate() 247 | ``` 248 | 249 | ### Add new Model 250 | 251 | The most important compontent of adding a new model is creating or using pre-existing `AbstractModel` and implementing `compute_zeroshot_weights`, `get_image_embeddings`, and `get_text_embeddings`, similar to how `ClipModel` works: 252 | 253 | ```python 254 | class ClipModel(AbstractModel): 255 | def __init__( 256 | self, 257 | model, 258 | model_name, 259 | **kwargs, 260 | ): 261 | super(ClipModel, self).__init__(model, model_name, **kwargs) 262 | 263 | def compute_zeroshot_weights(self): 264 | zeroshot_weights = [] 265 | for class_name in self.classes: 266 | texts = [template.format(class_name) for template in self.templates] 267 | 268 | class_embedding = self.get_text_embeddings(texts) 269 | 270 | class_embedding = class_embedding.mean(dim=0) 271 | class_embedding /= class_embedding.norm(dim=-1, keepdim=True) 272 | 273 | zeroshot_weights.append(class_embedding) 274 | self.zeroshot_weights = torch.stack(zeroshot_weights).T 275 | 276 | @torch.no_grad() 277 | def get_image_embeddings(self, images): 278 | image_features = self.model.encode_image(images.to(self.device)) 279 | image_features /= image_features.norm(dim=1, keepdim=True) 280 | return image_features.unsqueeze(1) 281 | 282 | @torch.no_grad() 283 | def get_text_embeddings(self, captions): 284 | if ( 285 | "truncate" in inspect.getfullargspec(self.tokenizer.__call__)[0] 286 | or "truncate" in inspect.getfullargspec(self.tokenizer)[0] 287 | ): 288 | caption_tokens = self.tokenizer( 289 | captions, context_length=self.context_length, truncate=True 290 | ).to(self.device) 291 | else: 292 | caption_tokens = self.tokenizer( 293 | captions, context_length=self.context_length 294 | ).to(self.device) 295 | 296 | caption_embeddings = self.model.encode_text(caption_tokens) 297 | caption_embeddings /= caption_embeddings.norm(dim=-1, keepdim=True) 298 | 299 | return caption_embeddings 300 | 301 | ``` 302 | 303 | Using the following class, we can then add models to the list of models. Here we have an example of adding and evaluating `ViTamin-L`. 
304 | 305 | ```python 306 | from functools import partial 307 | from io import open_code 308 | from unibench import Evaluator 309 | from unibench.models_zoo.wrappers.clip import ClipModel 310 | import open_clip 311 | 312 | model, _, _ = open_clip.create_model_and_transforms( 313 | "ViTamin-L", pretrained="datacomp1b" 314 | ) 315 | 316 | tokenizer = open_clip.get_tokenizer("ViTamin-L") 317 | 318 | model = partial( 319 | ClipModel, 320 | model=model, 321 | model_name="vitamin_l_comp1b", 322 | tokenizer=tokenizer, 323 | input_resolution=model.visual.image_size[0], 324 | logit_scale=model.logit_scale, 325 | ) 326 | 327 | 328 | eval = Evaluator(benchmarks_dir="/fsx-checkpoints/haideraltahan/.cache/unibench/data") 329 | 330 | eval.add_model(model=model) 331 | eval.update_benchmark_list(["imagenet1k"]) 332 | eval.update_model_list(["vitamin_l_comp1b"]) 333 | eval.evaluate() 334 | 335 | ``` 336 | 337 | 338 | ## Contributing 339 | 340 | [Contributions](CONTRIBUTING.md) (e.g. adding new benchmarks/models), issues, and feature requests are welcome! For any changes, please open an issue first to discuss what you would like to change or improve. 341 | 342 | When contributing please ensure tests are passing: 343 | 344 | ```python 345 | # if need be, pip install pytest 346 | python -m pytest tests/ 347 | ``` 348 | 349 | ## License 350 | 351 | The majority of UniBench is licensed under [CC-BY-NC](LICENSE), however portions of the project are available under separate license terms: 352 | 353 | | License | Libraries | 354 | | :----------------- | :-------------------------------------------------------------------------------------------------------: | 355 | | MIT license | zipp, tabulate, rich, openai-clip, latextable, gdown | 356 | | Apache 2.0 license | transformers, timm, opencv-python, open-clip-torch, ftfy, fire, debtcollector, datasets, oslo.concurrency | 357 | | BSD license | torchvision, torch, seaborn, scipy, scikit-learn, fairscale, cycler, contourpy, click, GitPython | 358 | 359 | ## Citation 360 | 361 | If you use this repository in your research, please cite it as follows: 362 | 363 | ```bibtex 364 | @inproceedings{altahan2024unibenchvisualreasoningrequires, 365 | title={UniBench: Visual Reasoning Requires Rethinking Vision-Language Beyond Scaling}, 366 | author={Haider Al-Tahan and Quentin Garrido and Randall Balestriero and Diane Bouchacourt and Caner Hazirbas and Mark Ibrahim}, 367 | year={2024}, 368 | eprint={2408.04810}, 369 | archivePrefix={arXiv}, 370 | primaryClass={cs.CV}, 371 | url={https://arxiv.org/abs/2408.04810}, 372 | } 373 | ``` 374 | 375 | ## Recognition 376 | 377 | Library structure was inspired by [Robert Geirhos](https://github.com/rgeirhos)'s work https://github.com/bethgelab/model-vs-human 378 | -------------------------------------------------------------------------------- /assets/UPDATES.md: -------------------------------------------------------------------------------- 1 | # Updates 2 | 3 | ## April 15, 2025 - v0.4.0 4 | - Removed FaceNet from required libraries. 
5 | - Added SigLIP2 models 6 | - Added bivlc benchmark 7 | - Created benchmark_builder for future benchmark implementations 8 | - Added News & Updates section in README 9 | - Fixed Sun397 benchmark 10 | 11 | ## September 15, 2024 - v0.3.1 12 | - Fixes type expectation in utils to print models by @marksibrahim 13 | 14 | ## September 3, 2024 - v0.3.0 15 | - Add long descripton to setup.py for pypi by @hazirbas 16 | - Fix type error in reading data frame by @hazirbas 17 | - Update 0.3.0 18 | 19 | ## August 15, 2024 - v0.2.0 20 | - Initial release 21 | -------------------------------------------------------------------------------- /assets/header.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/unibench/b32b4018f4b5f5daf7fd6771db09898ae4eff351/assets/header.png -------------------------------------------------------------------------------- /benchmark_builder/README.md: -------------------------------------------------------------------------------- 1 | ## Here we have documentations on how we generate teh huggingface cdatasets -------------------------------------------------------------------------------- /benchmark_builder/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ -------------------------------------------------------------------------------- /benchmark_builder/benchmarks/__init__.py: -------------------------------------------------------------------------------- 1 | from .sundataset import SUN397 -------------------------------------------------------------------------------- /benchmark_builder/benchmarks/sundataset.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any, Callable, Optional, Tuple, Union 3 | 4 | import PIL.Image 5 | 6 | from torchvision.datasets import VisionDataset 7 | from torchvision.datasets.utils import download_and_extract_archive 8 | 9 | class SUN397(VisionDataset): 10 | """`The SUN397 Data Set `_. 11 | 12 | The SUN397 or Scene UNderstanding (SUN) is a dataset for scene recognition consisting of 13 | 397 categories with 108'754 images. 14 | 15 | Args: 16 | root (str or ``pathlib.Path``): Root directory of the dataset. 17 | transform (callable, optional): A function/transform that takes in a PIL image and returns a transformed 18 | version. E.g, ``transforms.RandomCrop``. 19 | target_transform (callable, optional): A function/transform that takes in the target and transforms it. 20 | download (bool, optional): If true, downloads the dataset from the internet and 21 | puts it in root directory. If dataset is already downloaded, it is not 22 | downloaded again. 
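        partition_idx (int, optional): Index of the official test partition to load
            (reads ``Testing_{partition_idx:02}.txt``); must be between 1 and 10. Defaults to 1.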
23 | """ 24 | 25 | _DATASET_URL = "http://vision.princeton.edu/projects/2010/SUN/SUN397.tar.gz" 26 | _DATASET_MD5 = "8ca2778205c41d23104230ba66911c7a" 27 | _PARTITION_INFO_URL = "https://vision.princeton.edu/projects/2010/SUN/download/Partitions.zip" 28 | 29 | 30 | def __init__( 31 | self, 32 | root: Union[str, Path], 33 | transform: Optional[Callable] = None, 34 | target_transform: Optional[Callable] = None, 35 | download: bool = False, 36 | partition_idx: int = 1, 37 | ) -> None: 38 | super().__init__(root, transform=transform, target_transform=target_transform) 39 | assert partition_idx > 0 and partition_idx <= 10, "partition_idx should be between 1 and 10" 40 | self._data_dir = Path(self.root) / "SUN397" 41 | 42 | if download: 43 | self._download() 44 | 45 | if not self._check_exists(): 46 | raise RuntimeError("Dataset not found. You can use download=True to download it") 47 | 48 | with open(self._data_dir / "ClassName.txt") as f: 49 | self.classes = [c[3:].strip() for c in f] 50 | 51 | self.class_to_idx = dict(zip(self.classes, range(len(self.classes)))) 52 | 53 | self.partition = self._data_dir / f"Testing_{partition_idx:02}.txt" 54 | with open(self.partition) as f: 55 | self.partition_list = [line.strip() for line in f] 56 | 57 | self._image_files = [] 58 | for entry in self.partition_list: 59 | image_path = self._data_dir / entry.lstrip('/') 60 | if image_path.exists(): 61 | self._image_files.append(image_path) 62 | else: 63 | print(f"Image not found: {image_path}") 64 | self._labels = [ 65 | self.class_to_idx["/".join(path.relative_to(self._data_dir).parts[1:-1])] for path in self._image_files 66 | ] 67 | self.classes = [cl.replace("_", " ").replace("/", " ") for cl in self.classes] 68 | 69 | def __len__(self) -> int: 70 | return len(self._image_files) 71 | 72 | def __getitem__(self, idx: int) -> Tuple[Any, Any]: 73 | image_file, label = self._image_files[idx], self._labels[idx] 74 | image = PIL.Image.open(image_file).convert("RGB") 75 | 76 | if self.transform: 77 | image = self.transform(image) 78 | 79 | if self.target_transform: 80 | label = self.target_transform(label) 81 | 82 | return image, label 83 | 84 | def _check_exists(self) -> bool: 85 | return self._data_dir.is_dir() 86 | 87 | def _download(self) -> None: 88 | if self._check_exists(): 89 | return 90 | download_and_extract_archive(self._DATASET_URL, download_root=self.root, md5=self._DATASET_MD5) 91 | download_and_extract_archive(self._PARTITION_INFO_URL, download_root=self._data_dir) -------------------------------------------------------------------------------- /benchmark_builder/main.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | ###### CREDIT GOES TO https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/datasets ###### 9 | 10 | from pathlib import Path 11 | import fire 12 | from tqdm import tqdm 13 | import shutil 14 | import webdataset 15 | from class_names import CLASS_NAMES 16 | from templates import TEMPLATES 17 | from utils import * 18 | import os 19 | from benchmarks import SUN397 20 | from datasets import load_dataset 21 | import numpy as np 22 | 23 | def main( 24 | dataset_name, 25 | root_dir="/research/haider/Datasets", 26 | language="en", 27 | upload2huggingface=True, 28 | image_format="webp", 29 | max_size=500_000_000, 30 | num_workers=64, 31 | ): 32 | transform = PIL_to_bytes(image_format) 33 | classnames = ( 34 | CLASS_NAMES[language][dataset_name] if dataset_name in CLASS_NAMES[language] else None 35 | ) 36 | templates = ( 37 | TEMPLATES[language][dataset_name] 38 | if dataset_name in TEMPLATES[language] 39 | else TEMPLATES[language]["imagenet1k"] 40 | ) 41 | 42 | if dataset_name == "sun397": 43 | ds = SUN397(root=root_dir, transform=transform, download=True, partition_idx=1) 44 | ds.templates = templates 45 | elif dataset_name == "bivlc": 46 | ds = load_dataset("imirandam/BiVLC", split="test", cache_dir=root_dir) 47 | ds = ds.map(lambda example: {"image": [transform(example["image"]), transform(example["negative_image"])], "captions": [example['caption'], example['negative_caption']], "split": f"{example['type']}\n{example['subtype']}"}, num_proc=num_workers) 48 | ds = [(example['image'], example['captions'], example['split']) for example in ds] 49 | ds = ListDataset(ds) 50 | else: 51 | raise ValueError(f"Unknown dataset: {dataset_name}") 52 | 53 | output_dir = Path(root_dir) / dataset_name 54 | split_dir = output_dir / "test" 55 | if output_dir.exists(): 56 | shutil.rmtree(output_dir) 57 | 58 | split_dir.mkdir(parents=True, exist_ok=True) 59 | 60 | dataloader = torch.utils.data.DataLoader( 61 | ds, 62 | batch_size=1, 63 | num_workers=num_workers, 64 | collate_fn=lambda batch: batch[0] # No collate, only for multiprocessing 65 | ) 66 | 67 | 68 | if hasattr(ds, "classes") and ds.classes: 69 | classnames_fname = output_dir / "classnames.txt" 70 | with open(classnames_fname, "w") as classnames_file: 71 | print(*ds.classes, sep="\n", end="\n", file=classnames_file) 72 | print("Saved class names to '%s'" % classnames_fname) 73 | else: 74 | print("WARNING: No class names found") 75 | 76 | if hasattr(ds, "templates") and ds.templates: 77 | templates_fname = output_dir / "zeroshot_classification_templates.txt" 78 | with open(templates_fname, "w") as templates_file: 79 | print(*ds.templates, sep="\n", end="\n", file=templates_file) 80 | print("Saved class names to '%s'" % templates_fname) 81 | else: 82 | print("WARNING: No zeroshot classification templates found") 83 | 84 | data_fname = os.path.join(split_dir, r"%d.tar") 85 | sink = webdataset.ShardWriter( 86 | data_fname, 87 | maxsize=max_size 88 | ) 89 | nsamples = 0 90 | for index, batch in enumerate(tqdm(dataloader, desc="Converting")): 91 | if len(batch) == 2: 92 | input, output = batch 93 | elif len(batch) == 3: 94 | input, output, split = batch 95 | else: 96 | raise ValueError(f"Unknown batch size: {len(batch)}") 97 | 98 | nsamples += 1 99 | 100 | if isinstance(input, bytes): 101 | input = {f"0.{image_format}": input} 102 | elif isinstance(input, list) or isinstance(input, torch.Tensor): 103 | input = {f"{i}.{image_format}": img for i, img in enumerate(input)} 104 | 105 | if isinstance(output, int): 106 | output = {'cls': 
output} 107 | elif isinstance(output, list): 108 | output = {'npy': np.array(output)} 109 | 110 | if split is not None: 111 | output['split.txt'] = split 112 | 113 | sink.write({ 114 | "__key__": "s%07d" % index, 115 | **input, 116 | **output 117 | }) 118 | num_shards = sink.shard 119 | sink.close() 120 | print("Saved dataset to '%s'" % data_fname.replace(r"%d", "{0..%d}" % (num_shards - 1))) 121 | nshards_fname = split_dir / "nshards.txt" 122 | with open(nshards_fname, "w") as nshards_file: 123 | print(num_shards, end="\n", file=nshards_file) 124 | print("Saved number of shards = %d to '%s'" % (num_shards, nshards_fname)) 125 | print("Final dataset size:", nsamples) 126 | if upload2huggingface: 127 | os.system(f"huggingface-cli upload wds_{dataset_name} {str(output_dir)} {str(output_dir)} --repo-type dataset") 128 | 129 | 130 | if __name__ == "__main__": 131 | fire.Fire(main) 132 | -------------------------------------------------------------------------------- /benchmark_builder/templates.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | ###### CREDIT GOES TO https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/datasets ###### 9 | 10 | TEMPLATES = { 11 | "en": { 12 | "cifar10": [ 13 | "a photo of a {c}.", 14 | "a blurry photo of a {c}.", 15 | "a black and white photo of a {c}.", 16 | "a low contrast photo of a {c}.", 17 | "a high contrast photo of a {c}.", 18 | "a bad photo of a {c}.", 19 | "a good photo of a {c}.", 20 | "a photo of a small {c}.", 21 | "a photo of a big {c}.", 22 | "a photo of the {c}.", 23 | "a blurry photo of the {c}.", 24 | "a black and white photo of the {c}.", 25 | "a low contrast photo of the {c}.", 26 | "a high contrast photo of the {c}.", 27 | "a bad photo of the {c}.", 28 | "a good photo of the {c}.", 29 | "a photo of the small {c}.", 30 | "a photo of the big {c}.", 31 | ], 32 | "cifar100": [ 33 | "a photo of a {c}.", 34 | "a blurry photo of a {c}.", 35 | "a black and white photo of a {c}.", 36 | "a low contrast photo of a {c}.", 37 | "a high contrast photo of a {c}.", 38 | "a bad photo of a {c}.", 39 | "a good photo of a {c}.", 40 | "a photo of a small {c}.", 41 | "a photo of a big {c}.", 42 | "a photo of the {c}.", 43 | "a blurry photo of the {c}.", 44 | "a black and white photo of the {c}.", 45 | "a low contrast photo of the {c}.", 46 | "a high contrast photo of the {c}.", 47 | "a bad photo of the {c}.", 48 | "a good photo of the {c}.", 49 | "a photo of the small {c}.", 50 | "a photo of the big {c}.", 51 | ], 52 | "imagenet1k": [ 53 | "a bad photo of a {c}.", 54 | "a photo of many {c}.", 55 | "a sculpture of a {c}.", 56 | "a photo of the hard to see {c}.", 57 | "a low resolution photo of the {c}.", 58 | "a rendering of a {c}.", 59 | "graffiti of a {c}.", 60 | "a bad photo of the {c}.", 61 | "a cropped photo of the {c}.", 62 | "a tattoo of a {c}.", 63 | "the embroidered {c}.", 64 | "a photo of a hard to see {c}.", 65 | "a bright photo of a {c}.", 66 | "a photo of a clean {c}.", 67 | "a photo of a dirty {c}.", 68 | "a dark photo of the {c}.", 69 | "a drawing of a {c}.", 70 | "a photo of my {c}.", 71 | "the plastic {c}.", 72 | "a photo of the cool {c}.", 73 | "a close-up photo of a {c}.", 74 | "a black and white photo of the {c}.", 75 | "a painting of the {c}.", 76 | "a painting of a 
{c}.", 77 | "a pixelated photo of the {c}.", 78 | "a sculpture of the {c}.", 79 | "a bright photo of the {c}.", 80 | "a cropped photo of a {c}.", 81 | "a plastic {c}.", 82 | "a photo of the dirty {c}.", 83 | "a jpeg corrupted photo of a {c}.", 84 | "a blurry photo of the {c}.", 85 | "a photo of the {c}.", 86 | "a good photo of the {c}.", 87 | "a rendering of the {c}.", 88 | "a {c} in a video game.", 89 | "a photo of one {c}.", 90 | "a doodle of a {c}.", 91 | "a close-up photo of the {c}.", 92 | "a photo of a {c}.", 93 | "the origami {c}.", 94 | "the {c} in a video game.", 95 | "a sketch of a {c}.", 96 | "a doodle of the {c}.", 97 | "a origami {c}.", 98 | "a low resolution photo of a {c}.", 99 | "the toy {c}.", 100 | "a rendition of the {c}.", 101 | "a photo of the clean {c}.", 102 | "a photo of a large {c}.", 103 | "a rendition of a {c}.", 104 | "a photo of a nice {c}.", 105 | "a photo of a weird {c}.", 106 | "a blurry photo of a {c}.", 107 | "a cartoon {c}.", 108 | "art of a {c}.", 109 | "a sketch of the {c}.", 110 | "a embroidered {c}.", 111 | "a pixelated photo of a {c}.", 112 | "itap of the {c}.", 113 | "a jpeg corrupted photo of the {c}.", 114 | "a good photo of a {c}.", 115 | "a plushie {c}.", 116 | "a photo of the nice {c}.", 117 | "a photo of the small {c}.", 118 | "a photo of the weird {c}.", 119 | "the cartoon {c}.", 120 | "art of the {c}.", 121 | "a drawing of the {c}.", 122 | "a photo of the large {c}.", 123 | "a black and white photo of a {c}.", 124 | "the plushie {c}.", 125 | "a dark photo of a {c}.", 126 | "itap of a {c}.", 127 | "graffiti of the {c}.", 128 | "a toy {c}.", 129 | "itap of my {c}.", 130 | "a photo of a cool {c}.", 131 | "a photo of a small {c}.", 132 | "a tattoo of the {c}.", 133 | ], 134 | "food101": ["a photo of {c}, a type of food."], 135 | "sun397": ["a photo of a {c}.", "a photo of the {c}."], 136 | "cars": [ 137 | "a photo of a {c}.", 138 | "a photo of the {c}.", 139 | "a photo of my {c}.", 140 | "i love my {c}!", 141 | "a photo of my dirty {c}.", 142 | "a photo of my clean {c}.", 143 | "a photo of my new {c}.", 144 | "a photo of my old {c}.", 145 | ], 146 | "fgvc_aircraft": [ 147 | "a photo of a {c}, a type of aircraft.", 148 | "a photo of the {c}, a type of aircraft.", 149 | ], 150 | "dtd": [ 151 | "a photo of a {c} texture.", 152 | "a photo of a {c} pattern.", 153 | "a photo of a {c} thing.", 154 | "a photo of a {c} object.", 155 | "a photo of the {c} texture.", 156 | "a photo of the {c} pattern.", 157 | "a photo of the {c} thing.", 158 | "a photo of the {c} object.", 159 | ], 160 | "pets": ["a photo of a {c}, a type of pet."], 161 | "caltech101": [ 162 | "a photo of a {c}.", 163 | "a painting of a {c}.", 164 | "a plastic {c}.", 165 | "a sculpture of a {c}.", 166 | "a sketch of a {c}.", 167 | "a tattoo of a {c}.", 168 | "a toy {c}.", 169 | "a rendition of a {c}.", 170 | "a embroidered {c}.", 171 | "a cartoon {c}.", 172 | "a {c} in a video game.", 173 | "a plushie {c}.", 174 | "a origami {c}.", 175 | "art of a {c}.", 176 | "graffiti of a {c}.", 177 | "a drawing of a {c}.", 178 | "a doodle of a {c}.", 179 | "a photo of the {c}.", 180 | "a painting of the {c}.", 181 | "the plastic {c}.", 182 | "a sculpture of the {c}.", 183 | "a sketch of the {c}.", 184 | "a tattoo of the {c}.", 185 | "the toy {c}.", 186 | "a rendition of the {c}.", 187 | "the embroidered {c}.", 188 | "the cartoon {c}.", 189 | "the {c} in a video game.", 190 | "the plushie {c}.", 191 | "the origami {c}.", 192 | "art of the {c}.", 193 | "graffiti of the {c}.", 194 | "a drawing of the 
{c}.", 195 | "a doodle of the {c}.", 196 | ], 197 | "flowers": ["a photo of a {c}, a type of flower."], 198 | "mnist": ['a photo of the number: "{c}".'], 199 | "stl10": ["a photo of a {c}.", "a photo of the {c}."], 200 | "eurosat": [ 201 | "a centered satellite photo of {c}.", 202 | "a centered satellite photo of a {c}.", 203 | "a centered satellite photo of the {c}.", 204 | ], 205 | "gtsrb": [ 206 | 'a zoomed in photo of a "{c}" traffic sign.', 207 | 'a centered photo of a "{c}" traffic sign.', 208 | 'a close up photo of a "{c}" traffic sign.', 209 | ], 210 | "country211": [ 211 | "a photo i took in {c}.", 212 | "a photo i took while visiting {c}.", 213 | "a photo from my home country of {c}.", 214 | "a photo from my visit to {c}.", 215 | "a photo showing the country of {c}.", 216 | ], 217 | "renderedsst2": ["a {c} review of a movie."], 218 | "voc2007": ["a photo of a {c}."], 219 | "voc2007_multilabel": ["a photo of a {c}."], 220 | "fer2013": [ 221 | "a photo of a {c} looking face.", 222 | "a photo of a face showing the emotion: {c}.", 223 | "a photo of a face looking {c}.", 224 | "a face that looks {c}.", 225 | "they look {c}.", 226 | "look at how {c} they are.", 227 | ], 228 | "clevr_count_all": ["a picture of {c} objects"], 229 | "clevr_closest_object_distance": ["{c} shapes."], 230 | "pcam": ["a histopathology slide showing {c}", "histopathology image of {c}"], 231 | "svhn": [ 232 | "a photo of the number {c} written on a sign", 233 | "an outdoor house number {c}", 234 | "the number {c} in the center of the image", 235 | "an outdoor number {c} writte on a sign", 236 | "an outdoor number {c}", 237 | "a centered image of the number {c}", 238 | ], 239 | "resisc45": [ 240 | "a sattelite image of {c}", 241 | "an aerial view of {c}", 242 | "a sattelite photo of {c}", 243 | "{c} from above", 244 | ], 245 | "kitti_closest_vehicle_distance": ["{c}"], 246 | "smallnorb_label_azimuth": [ 247 | "an object rotated at {c}", 248 | "something rotated at {c}", 249 | "{c} rotation", 250 | "something at a {c} angle", 251 | ], 252 | "smallnorb_label_elevation": [ 253 | "an object rotated at {c}", 254 | "something rotated at {c}", 255 | "{c} rotation", 256 | "something at a {c} angle", 257 | ], 258 | "dsprites_label_x_position": [ 259 | "an object located at position {c}% on the horizontal axis" 260 | ], 261 | "dsprites_label_orientation": [ 262 | "an object rotated at {c}", 263 | "something rotated at {c}", 264 | "{c} rotation", 265 | "something at a {c} angle", 266 | ], 267 | "dmlab": ["{c}"], 268 | "diabetic_retinopathy": ["a retinal image with {c}"], 269 | "dummy": ["a photo of a {c}"], 270 | } 271 | } 272 | -------------------------------------------------------------------------------- /benchmark_builder/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | ###### CREDIT GOES TO https://github.com/LAION-AI/CLIP_benchmark/blob/main/clip_benchmark/datasets ###### 9 | 10 | import io 11 | import torch 12 | 13 | 14 | class ListDataset(torch.utils.data.Dataset): 15 | def __init__(self, data): 16 | self.data = data 17 | 18 | def __getitem__(self, idx): 19 | return self.data[idx] 20 | 21 | def __len__(self): 22 | return len(self.data) 23 | 24 | 25 | def PIL_to_bytes(image_format): 26 | OPTIONS = { 27 | "webp": dict(format="webp", lossless=True), 28 | "png": dict(format="png"), 29 | "jpg": dict(format="jpeg"), 30 | } 31 | 32 | def transform(image): 33 | bytestream = io.BytesIO() 34 | image.save(bytestream, **OPTIONS[image_format]) 35 | return bytestream.getvalue() 36 | 37 | return transform 38 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: base 2 | channels: 3 | - pytorch 4 | - rapidsai 5 | - nvidia 6 | - conda-forge 7 | - defaults 8 | dependencies: 9 | - _libgcc_mutex=0.1=conda_forge 10 | - _openmp_mutex=4.5=2_kmp_llvm 11 | - absl-py=2.1.0=pyhd8ed1ab_0 12 | - aiohttp=3.9.3=py310h2372a71_1 13 | - aiosignal=1.3.1=pyhd8ed1ab_0 14 | - archspec=0.2.3=pyhd8ed1ab_0 15 | - astunparse=1.6.3=pyhd8ed1ab_0 16 | - async-timeout=4.0.3=pyhd8ed1ab_0 17 | - attrs=23.2.0=pyh71513ae_0 18 | - aws-c-auth=0.7.11=h0b4cabd_1 19 | - aws-c-cal=0.6.9=h14ec70c_3 20 | - aws-c-common=0.9.12=hd590300_0 21 | - aws-c-compression=0.2.17=h572eabf_8 22 | - aws-c-event-stream=0.4.1=h97bb272_2 23 | - aws-c-http=0.8.0=h9129f04_2 24 | - aws-c-io=0.14.0=hf8f278a_1 25 | - aws-c-mqtt=0.10.1=h2b97f5f_0 26 | - aws-c-s3=0.4.9=hca09fc5_0 27 | - aws-c-sdkutils=0.1.13=h572eabf_1 28 | - aws-checksums=0.1.17=h572eabf_7 29 | - aws-crt-cpp=0.26.0=h04327c0_8 30 | - aws-sdk-cpp=1.11.210=hba3e011_10 31 | - blinker=1.7.0=pyhd8ed1ab_0 32 | - bokeh=3.4.0=pyhd8ed1ab_0 33 | - boltons=24.0.0=pyhd8ed1ab_0 34 | - brotli=1.0.9=he6710b0_2 35 | - brotli-python=1.1.0=py310hc6cd4ac_1 36 | - bzip2=1.0.8=hd590300_5 37 | - c-ares=1.28.1=hd590300_0 38 | - ca-certificates=2024.3.11=h06a4308_0 39 | - cached-property=1.5.2=hd8ed1ab_1 40 | - cached_property=1.5.2=pyha770c72_1 41 | - cachetools=5.3.3=pyhd8ed1ab_0 42 | - certifi=2024.2.2=pyhd8ed1ab_0 43 | - cffi=1.16.0=py310h2fee648_0 44 | - charset-normalizer=3.3.2=pyhd8ed1ab_0 45 | - click=8.1.7=unix_pyh707e725_0 46 | - cloudpickle=3.0.0=pyhd8ed1ab_0 47 | - colorama=0.4.6=pyhd8ed1ab_0 48 | - conda=24.3.0=py310hff52083_0 49 | - conda-libmamba-solver=24.1.0=pyhd8ed1ab_0 50 | - conda-package-handling=2.2.0=pyh38be061_0 51 | - conda-package-streaming=0.9.0=pyhd8ed1ab_0 52 | - contourpy=1.2.1=py310hd41b1e2_0 53 | - cryptography=42.0.5=py310h75e40e8_0 54 | - cuda-cccl=12.4.127=0 55 | - cuda-cccl_linux-64=12.0.90=ha770c72_1 56 | - cuda-cudart=12.1.105=0 57 | - cuda-cudart-dev=12.1.105=0 58 | - cuda-cudart-dev_linux-64=12.0.107=h59595ed_8 59 | - cuda-cudart-static=12.0.107=hd3aeb46_8 60 | - cuda-cudart-static_linux-64=12.0.107=h59595ed_8 61 | - cuda-cudart_linux-64=12.0.107=h59595ed_8 62 | - cuda-cupti=12.1.105=0 63 | - cuda-libraries=12.1.0=0 64 | - cuda-nvcc-dev_linux-64=12.0.76=ha770c72_1 65 | - cuda-nvcc-impl=12.0.76=h59595ed_1 66 | - cuda-nvcc-tools=12.0.76=h59595ed_1 67 | - cuda-nvrtc=12.1.105=0 68 | - cuda-nvtx=12.1.105=0 69 | - cuda-opencl=12.4.127=0 70 | - cuda-profiler-api=12.0.76=ha770c72_0 71 | - cuda-python=12.4.0=py310h9f9f131_1 72 | - cuda-runtime=12.1.0=0 73 | - cuda-version=12.0=hffde075_3 74 | - 
cudf=24.02.02=cuda12_py310_240227_gdd34fdbe35_0 75 | - cudnn=8.9.7.29=h092f7fd_3 76 | - cuml=24.02.00=cuda12_py310_240213_geb50e481d_0 77 | - cupy=13.0.0=py310h7aad9d2_3 78 | - cupy-core=13.0.0=py310had4011e_3 79 | - cycler=0.11.0=pyhd3eb1b0_0 80 | - cytoolz=0.12.3=py310h2372a71_0 81 | - dask=2024.1.1=pyhd8ed1ab_0 82 | - dask-core=2024.1.1=pyhd8ed1ab_0 83 | - dask-cuda=24.02.00=py310_240212_g96bedbc_0 84 | - dask-cudf=24.02.02=cuda12_py310_240227_gdd34fdbe35_0 85 | - distributed=2024.1.1=pyhd8ed1ab_0 86 | - distro=1.9.0=pyhd8ed1ab_0 87 | - dlpack=0.5=h9c3ff4c_0 88 | - fastrlock=0.8.2=py310hc6cd4ac_2 89 | - filelock=3.13.3=pyhd8ed1ab_0 90 | - flatbuffers=23.5.26=h59595ed_1 91 | - fmt=10.2.1=h00ab1b0_0 92 | - fonttools=4.25.0=pyhd3eb1b0_0 93 | - freetype=2.12.1=h267a509_2 94 | - frozenlist=1.4.1=py310h2372a71_0 95 | - fsspec=2024.3.1=pyhca7485f_0 96 | - gast=0.5.4=pyhd8ed1ab_0 97 | - gflags=2.2.2=he1b5a44_1004 98 | - giflib=5.2.1=h0b41bf4_3 99 | - glog=0.6.0=h6f12383_0 100 | - gmock=1.14.0=ha770c72_1 101 | - gmp=6.3.0=h59595ed_1 102 | - gmpy2=2.1.2=py310h3ec546c_1 103 | - google-auth=2.29.0=pyhca7485f_0 104 | - google-auth-oauthlib=1.2.0=pyhd8ed1ab_0 105 | - google-pasta=0.2.0=pyh8c360ce_0 106 | - grpcio=1.59.3=py310h1b8f574_0 107 | - gtest=1.14.0=h00ab1b0_1 108 | - h5py=3.10.0=nompi_py310h65828d5_101 109 | - hdf5=1.14.3=nompi_h4f84152_100 110 | - icu=73.2=h59595ed_0 111 | - idna=3.6=pyhd8ed1ab_0 112 | - importlib-metadata=7.1.0=pyha770c72_0 113 | - importlib_metadata=7.1.0=hd8ed1ab_0 114 | - jinja2=3.1.3=pyhd8ed1ab_0 115 | - joblib=1.4.0=pyhd8ed1ab_0 116 | - jsonpatch=1.33=pyhd8ed1ab_0 117 | - jsonpointer=2.4=py310hff52083_3 118 | - keras=2.15.0=pyhd8ed1ab_0 119 | - keyutils=1.6.1=h166bdaf_0 120 | - kiwisolver=1.4.4=py310h6a678d5_0 121 | - krb5=1.21.2=h659d440_0 122 | - lcms2=2.16=hb7c19ff_0 123 | - ld_impl_linux-64=2.40=h41732ed_0 124 | - lerc=4.0.0=h27087fc_0 125 | - libabseil=20230802.1=cxx17_h59595ed_0 126 | - libaec=1.1.3=h59595ed_0 127 | - libarchive=3.7.2=h2aa1ff5_1 128 | - libarrow=14.0.2=h84dd17c_3_cpu 129 | - libarrow-acero=14.0.2=h59595ed_3_cpu 130 | - libarrow-dataset=14.0.2=h59595ed_3_cpu 131 | - libarrow-flight=14.0.2=h120cb0d_3_cpu 132 | - libarrow-flight-sql=14.0.2=h61ff412_3_cpu 133 | - libarrow-gandiva=14.0.2=hacb8726_3_cpu 134 | - libarrow-substrait=14.0.2=h61ff412_3_cpu 135 | - libblas=3.9.0=22_linux64_openblas 136 | - libbrotlicommon=1.1.0=hd590300_1 137 | - libbrotlidec=1.1.0=hd590300_1 138 | - libbrotlienc=1.1.0=hd590300_1 139 | - libcblas=3.9.0=22_linux64_openblas 140 | - libcrc32c=1.1.2=h9c3ff4c_0 141 | - libcublas=12.1.0.26=0 142 | - libcublas-dev=12.1.0.26=0 143 | - libcudf=24.02.02=cuda12_240227_gdd34fdbe35_0 144 | - libcufft=11.0.2.4=0 145 | - libcufile=1.9.1.3=0 146 | - libcufile-dev=1.9.1.3=0 147 | - libcuml=24.02.00=cuda12_240213_geb50e481d_0 148 | - libcumlprims=24.02.00=cuda12_240213_g0e32024_0 149 | - libcurand=10.3.5.147=0 150 | - libcurand-dev=10.3.5.147=0 151 | - libcurl=8.7.1=hca28451_0 152 | - libcusolver=11.4.4.55=0 153 | - libcusolver-dev=11.4.4.55=0 154 | - libcusparse=12.0.2.55=0 155 | - libcusparse-dev=12.0.2.55=0 156 | - libdeflate=1.20=hd590300_0 157 | - libedit=3.1.20191231=he28a2e2_2 158 | - libev=4.33=hd590300_2 159 | - libevent=2.1.12=hf998b51_1 160 | - libffi=3.4.2=h7f98852_5 161 | - libgcc-ng=13.2.0=h807b86a_5 162 | - libgfortran-ng=13.2.0=h69a702a_5 163 | - libgfortran5=13.2.0=ha4646dd_5 164 | - libgoogle-cloud=2.12.0=h5206363_4 165 | - libgrpc=1.59.3=hd6c4280_0 166 | - libhwloc=2.9.3=default_h554bfaf_1009 167 | - 
libiconv=1.17=hd590300_2 168 | - libjpeg-turbo=3.0.0=hd590300_1 169 | - libkvikio=24.02.01=cuda12_240226_gfe01c15_0 170 | - liblapack=3.9.0=22_linux64_openblas 171 | - libllvm14=14.0.6=hcd5def8_4 172 | - libllvm15=15.0.7=hb3ce162_4 173 | - libmagma=2.7.2=h173bb3b_2 174 | - libmagma_sparse=2.7.2=h173bb3b_3 175 | - libmamba=1.5.8=had39da4_0 176 | - libmambapy=1.5.8=py310h39ff949_0 177 | - libnghttp2=1.58.0=h47da74e_1 178 | - libnl=3.9.0=hd590300_0 179 | - libnpp=12.0.2.50=0 180 | - libnsl=2.0.1=hd590300_0 181 | - libnvjitlink=12.1.105=0 182 | - libnvjpeg=12.1.1.14=0 183 | - libopenblas=0.3.27=pthreads_h413a1c8_0 184 | - libparquet=14.0.2=h352af49_3_cpu 185 | - libpng=1.6.43=h2797004_0 186 | - libprotobuf=4.24.4=hf27288f_0 187 | - libraft=24.02.00=cuda12_240212_g698d6c7b_0 188 | - libraft-headers=24.02.00=cuda12_240212_g698d6c7b_0 189 | - libraft-headers-only=24.02.00=cuda12_240212_g698d6c7b_0 190 | - libre2-11=2023.09.01=h7a70373_1 191 | - librmm=24.02.00=cuda12_240212_g09b406c1_0 192 | - libsolv=0.7.28=hfc55251_2 193 | - libsqlite=3.45.2=h2797004_0 194 | - libssh2=1.11.0=h0841786_0 195 | - libstdcxx-ng=13.2.0=h7e041cc_5 196 | - libthrift=0.19.0=hb90f79a_1 197 | - libtiff=4.6.0=h1dd3fc0_3 198 | - libtorch=2.1.2=cuda120_h86db2e7_300 199 | - libutf8proc=2.8.0=h166bdaf_0 200 | - libuuid=2.38.1=h0b41bf4_0 201 | - libuv=1.48.0=hd590300_0 202 | - libwebp-base=1.3.2=hd590300_0 203 | - libxcb=1.15=h0b41bf4_0 204 | - libxcrypt=4.4.36=hd590300_1 205 | - libxml2=2.12.6=h232c23b_1 206 | - libzlib=1.2.13=hd590300_5 207 | - llvm-openmp=18.1.2=h4dfa4b3_0 208 | - llvmlite=0.42.0=py310h1b8f574_1 209 | - locket=1.0.0=pyhd8ed1ab_0 210 | - lz4=4.3.3=py310h350c4a5_0 211 | - lz4-c=1.9.4=hcb278e6_0 212 | - lzo=2.10=h516909a_1000 213 | - magma=2.7.2=h51420fd_3 214 | - mamba=1.5.8=py310h51d5547_0 215 | - markdown=3.6=pyhd8ed1ab_0 216 | - markdown-it-py=3.0.0=pyhd8ed1ab_0 217 | - markupsafe=2.1.5=py310h2372a71_0 218 | - matplotlib-base=3.8.0=py310h1128e8f_0 219 | - mdurl=0.1.2=pyhd8ed1ab_0 220 | - menuinst=2.0.2=py310hff52083_0 221 | - mkl=2023.2.0=h84fe81f_50496 222 | - ml_dtypes=0.2.0=py310hcc13569_2 223 | - mpc=1.3.1=hfe3b2da_0 224 | - mpfr=4.2.1=h9458935_1 225 | - mpmath=1.3.0=pyhd8ed1ab_0 226 | - msgpack-python=1.0.7=py310hd41b1e2_0 227 | - multidict=6.0.5=py310h2372a71_0 228 | - munkres=1.1.4=py_0 229 | - nccl=2.21.5.1=h3a97aeb_0 230 | - ncurses=6.4.20240210=h59595ed_0 231 | - networkx=3.3=pyhd8ed1ab_1 232 | - numba=0.59.1=py310h7dc5dd1_0 233 | - numpy=1.24.4=py310ha4c1d20_0 234 | - nvcomp=3.0.6=h10b603f_0 235 | - nvtx=0.2.10=py310h2372a71_0 236 | - oauthlib=3.2.2=pyhd8ed1ab_0 237 | - openjpeg=2.5.2=h488ebb8_0 238 | - openssl=3.2.1=hd590300_1 239 | - opt_einsum=3.3.0=pyhc1e730c_2 240 | - orc=1.9.2=h4b38347_0 241 | - packaging=24.0=pyhd8ed1ab_0 242 | - pandas=1.5.3=py310h9b08913_1 243 | - partd=1.4.1=pyhd8ed1ab_0 244 | - pillow=10.3.0=py310hf73ecf8_0 245 | - pip=24.0=pyhd8ed1ab_0 246 | - platformdirs=4.2.0=pyhd8ed1ab_0 247 | - pluggy=1.4.0=pyhd8ed1ab_0 248 | - protobuf=4.24.4=py310h620c231_0 249 | - psutil=5.9.8=py310h2372a71_0 250 | - pthread-stubs=0.4=h36c2ea0_1001 251 | - pyarrow=14.0.2=py310hf9e7431_3_cpu 252 | - pyarrow-hotfix=0.6=pyhd8ed1ab_0 253 | - pyasn1=0.5.1=pyhd8ed1ab_0 254 | - pyasn1-modules=0.3.0=pyhd8ed1ab_0 255 | - pybind11-abi=4=hd8ed1ab_3 256 | - pycosat=0.6.6=py310h2372a71_0 257 | - pycparser=2.22=pyhd8ed1ab_0 258 | - pygments=2.17.2=pyhd8ed1ab_0 259 | - pyjwt=2.8.0=pyhd8ed1ab_1 260 | - pylibraft=24.02.00=cuda12_py310_240212_g698d6c7b_0 261 | - pynvjitlink=0.1.14=py310hdaa3023_0 262 | - 
pynvml=11.4.1=pyhd8ed1ab_0 263 | - pyopenssl=24.0.0=pyhd8ed1ab_0 264 | - pyparsing=3.0.9=py310h06a4308_0 265 | - pysocks=1.7.1=pyha2e5f31_6 266 | - python=3.10.14=hd12c33a_0_cpython 267 | - python-dateutil=2.9.0=pyhd8ed1ab_0 268 | - python-flatbuffers=24.3.25=pyh59ac667_0 269 | - python_abi=3.10=4_cp310 270 | - pytorch=2.1.2=cuda120_py310h8a81058_300 271 | - pytorch-cuda=12.1=ha16c6d3_5 272 | - pytz=2024.1=pyhd8ed1ab_0 273 | - pyu2f=0.1.5=pyhd8ed1ab_0 274 | - pyyaml=6.0.1=py310h2372a71_1 275 | - raft-dask=24.02.00=cuda12_py310_240212_g698d6c7b_0 276 | - rapids-dask-dependency=24.02.00=0 277 | - rdma-core=51.0=hd3aeb46_0 278 | - re2=2023.09.01=h7f4b329_1 279 | - readline=8.2=h8228510_1 280 | - reproc=14.2.4.post0=hd590300_1 281 | - reproc-cpp=14.2.4.post0=h59595ed_1 282 | - requests=2.31.0=pyhd8ed1ab_0 283 | - requests-oauthlib=2.0.0=pyhd8ed1ab_0 284 | - rich=13.7.1=pyhd8ed1ab_0 285 | - rmm=24.02.00=cuda12_py310_240212_g09b406c1_0 286 | - rsa=4.9=pyhd8ed1ab_0 287 | - ruamel.yaml=0.18.6=py310h2372a71_0 288 | - ruamel.yaml.clib=0.2.8=py310h2372a71_0 289 | - s2n=1.4.1=h06160fa_0 290 | - scikit-learn=1.3.0=py310h1128e8f_1 291 | - scipy=1.13.0=py310hb13e2d6_0 292 | - seaborn=0.12.2=py310h06a4308_0 293 | - setuptools=69.2.0=pyhd8ed1ab_0 294 | - six=1.16.0=pyh6c4a22f_0 295 | - sleef=3.5.1=h9b69904_2 296 | - snappy=1.1.10=hdb0a2a9_1 297 | - sortedcontainers=2.4.0=pyhd8ed1ab_0 298 | - spdlog=1.12.0=hd2e6256_2 299 | - sympy=1.12=pypyh9d50eac_103 300 | - tbb=2021.11.0=h00ab1b0_1 301 | - tblib=3.0.0=pyhd8ed1ab_0 302 | - tensorboard=2.15.2=pyhd8ed1ab_0 303 | - tensorboard-data-server=0.7.0=py310h75e40e8_1 304 | - tensorflow=2.15.0=cuda120py310h9360858_3 305 | - tensorflow-base=2.15.0=cuda120py310heceb7ac_3 306 | - tensorflow-estimator=2.15.0=cuda120py310h549c77d_3 307 | - termcolor=2.4.0=pyhd8ed1ab_0 308 | - threadpoolctl=2.2.0=pyh0d69192_0 309 | - tk=8.6.13=noxft_h4845f30_101 310 | - toolz=0.12.1=pyhd8ed1ab_0 311 | - tornado=6.4=py310h2372a71_0 312 | - tqdm=4.66.2=pyhd8ed1ab_0 313 | - treelite=4.0.0=py310h4a6579d_0 314 | - truststore=0.8.0=pyhd8ed1ab_0 315 | - typing-extensions=4.11.0=hd8ed1ab_0 316 | - typing_extensions=4.11.0=pyha770c72_0 317 | - tzdata=2024a=h0c530f3_0 318 | - ucx=1.15.0=hda83522_8 319 | - ucx-proc=1.0.0=gpu 320 | - ucx-py=0.36.00=py310_240212_g1266b48_0 321 | - urllib3=2.2.1=pyhd8ed1ab_0 322 | - werkzeug=3.0.2=pyhd8ed1ab_0 323 | - wheel=0.43.0=pyhd8ed1ab_1 324 | - wrapt=1.14.1=py310h5764c6d_1 325 | - xorg-libxau=1.0.11=hd590300_0 326 | - xorg-libxdmcp=1.1.3=h7f98852_0 327 | - xyzservices=2024.4.0=pyhd8ed1ab_0 328 | - xz=5.2.6=h166bdaf_0 329 | - yaml=0.2.5=h7f98852_2 330 | - yaml-cpp=0.8.0=h59595ed_0 331 | - yarl=1.9.4=py310h2372a71_0 332 | - zict=3.0.0=pyhd8ed1ab_0 333 | - zipp=3.17.0=pyhd8ed1ab_0 334 | - zstandard=0.22.0=py310h1275a96_0 335 | - zstd=1.5.5=hfc55251_0 336 | - pip: 337 | - facenet-pytorch==2.5.3 338 | - latextable==1.0.1 339 | - netaddr==1.2.1 340 | - netifaces==0.11.0 341 | - texttable==1.7.0 342 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | import setuptools 8 | 9 | # read the contents of your README file 10 | from pathlib import Path 11 | this_directory = Path(__file__).parent 12 | long_description = (this_directory / "README.md").read_text() 13 | 14 | minimal_requirements = [ 15 | "fire", 16 | "pandas", 17 | "rich", 18 | "huggingface_hub", 19 | "oslo.concurrency", 20 | "pyarrow", 21 | ] 22 | 23 | new_model_requirements = [ 24 | "datasets", 25 | "torch", 26 | "torchvision", 27 | ] 28 | 29 | new_benchmark_requirements = [ 30 | "torch", 31 | "torchvision", 32 | "open_clip_torch", 33 | "openai-clip", 34 | "timm", 35 | "opencv-python", 36 | "transformers", 37 | 'GitPython', 38 | 'fairscale', 39 | 'gdown', 40 | 'scipy', 41 | ] 42 | 43 | setuptools.setup( 44 | name="unibench", 45 | version="0.4.0", 46 | author="Haider Al-Tahan", 47 | author_email="haideraltahan@meta.com", 48 | description="This repository is designed to simplify the evaluation process of vision-language models. It provides a comprehensive set of tools and scripts for evaluating VLM models and benchmarks.", 49 | long_description=long_description, 50 | long_description_content_type="text/markdown", 51 | url="https://github.com/facebookresearch/unibench", 52 | project_urls={ 53 | "Bug Tracker": "https://github.com/facebookresearch/unibench/issues", 54 | }, 55 | classifiers=[ 56 | "Programming Language :: Python :: 3", 57 | "License :: OSI Approved :: MIT License", 58 | "Operating System :: OS Independent", 59 | ], 60 | packages=setuptools.find_packages(), 61 | python_requires=">=3.8", 62 | install_requires=minimal_requirements, 63 | extras_require={ 64 | "new_benchmark": minimal_requirements + new_benchmark_requirements, 65 | "new_model": minimal_requirements + new_model_requirements, 66 | "all": list(set(minimal_requirements + new_model_requirements + new_benchmark_requirements)), 67 | }, 68 | setup_requires=["pytest-runner"], 69 | tests_require=["pytest"], 70 | test_suite="tests", 71 | entry_points={"console_scripts": ["unibench = unibench.evaluator:run"]}, 72 | dependency_links=["https://pypi.nvidia.com"], 73 | ) 74 | -------------------------------------------------------------------------------- /unibench/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | __version__ = "0.4.0" 8 | __author__ = "Haider Al-Tahan" 9 | 10 | from .evaluator import Evaluator 11 | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/README.md: -------------------------------------------------------------------------------- 1 | # List of Implemented Benchmarks 2 | 3 | | | benchmark | benchmark_type | 4 | |:--------------------|:------------|:---------------| 5 | | clevr_distance | zero-shot | vtab | 6 | | dspr_x_position | zero-shot | vtab | 7 | | dspr_y_position | zero-shot | vtab | 8 | | dspr_orientation | zero-shot | vtab | 9 | | sun397 | zero-shot | vtab | 10 | | retinopathy | zero-shot | vtab | 11 | | resisc45 | zero-shot | vtab | 12 | | svhn | zero-shot | vtab | 13 | | pets | zero-shot | vtab | 14 | | eurosat | zero-shot | vtab | 15 | | dtd | zero-shot | vtab | 16 | | dmlab | zero-shot | vtab | 17 | | clevr_count | zero-shot | vtab | 18 | | cifar100 | zero-shot | vtab | 19 | | caltech101 | zero-shot | vtab | 20 | | smallnorb_elevation | zero-shot | vtab | 21 | | pcam | zero-shot | vtab | 22 | | smallnorb_azimuth | zero-shot | vtab | 23 | | fer2013 | zero-shot | transfer | 24 | | voc2007 | zero-shot | transfer | 25 | | mnist | zero-shot | transfer | 26 | | country211 | zero-shot | transfer | 27 | | fgvc_aircraft | zero-shot | transfer | 28 | | cars | zero-shot | transfer | 29 | | cifar10 | zero-shot | transfer | 30 | | imageneta | zero-shot | robustness | 31 | | imagenetr | zero-shot | robustness | 32 | | imagenete | zero-shot | robustness | 33 | | objectnet | zero-shot | robustness | 34 | | imagenet9 | zero-shot | robustness | 35 | | imagenetv2 | zero-shot | robustness | 36 | | flickr30k_order | relation | relation | 37 | | sugarcrepe | relation | relation | 38 | | winoground | relation | relation | 39 | | vg_attribution | relation | relation | 40 | | vg_relation | relation | relation | 41 | | coco_order | relation | relation | 42 | | imagenet | zero-shot | imagenet | 43 | | imagenetc | zero-shot | corruption | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from .registry import * 8 | from .benchmarks import * 9 | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from collections import defaultdict 9 | 10 | _benchmark_registry = defaultdict(list) # mapping of benchmark names to entrypoint fns 11 | _benchmark_info_registry = defaultdict(defaultdict) 12 | 13 | 14 | def register_benchmark(names, info=None): 15 | def inner_decorator(fn): 16 | # add entries to registry dict/sets 17 | benchmark_name = fn.__name__ 18 | if isinstance(names, str): 19 | _benchmark_registry[names].append(benchmark_name) 20 | else: 21 | for name in names: 22 | _benchmark_registry[name].append(benchmark_name) 23 | 24 | if info is not None: 25 | _benchmark_info_registry[benchmark_name] = info 26 | 27 | # register_benchmark_handler(benchmark_name, fn) 28 | 29 | return fn 30 | 31 | return inner_decorator 32 | 33 | 34 | # def register_benchmark_handler(benchmark_name, benchmark_handler): 35 | # _benchmark_handlers[benchmark_name] = benchmark_handler 36 | 37 | 38 | # def get_benchmark_handler(benchmark_name): 39 | # return _benchmark_handlers[benchmark_name] 40 | 41 | def load_benchmark(benchmark_name, **kwargs): 42 | import unibench.benchmarks_zoo.benchmarks as benchmarks 43 | supported_benchmarks = list_benchmarks("all") 44 | module_name = supported_benchmarks.index(benchmark_name) 45 | if module_name is None: 46 | raise NameError( 47 | f"Benchmark {benchmark_name} is not supported, " 48 | f"please select from {list(supported_benchmarks.keys())}" 49 | ) 50 | 51 | return eval(f"benchmarks.{benchmark_name}")(benchmark_name, **kwargs) 52 | 53 | 54 | def get_benchmark_info(benchmark_name): 55 | return _benchmark_info_registry[benchmark_name] 56 | 57 | 58 | def get_benchmark_types(): 59 | return list(_benchmark_registry.keys()) 60 | 61 | 62 | def list_benchmarks(framework="all"): 63 | """Return list of available benchmark names, sorted alphabetically""" 64 | r = [] 65 | if framework == "all": 66 | for _, v in _benchmark_registry.items(): 67 | r.extend(v) 68 | else: 69 | r = _benchmark_registry[framework] 70 | if isinstance(r, str): 71 | return [r] 72 | return sorted(list(r)) 73 | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from .huggingface import HuggingFaceDataset 8 | from .bechmark_handler import * 9 | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/wrappers/bechmark_handler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from abc import abstractmethod 9 | import itertools 10 | from torch.distributions import Categorical 11 | from torch import softmax 12 | import torch 13 | 14 | 15 | class BenchmarkHandler: 16 | def __init__(self, benchmark_name, benchmark): 17 | self.benchmark_name = benchmark_name 18 | self.benchmark = benchmark 19 | 20 | @abstractmethod 21 | def eval_batch(self, model, batch): 22 | raise NotImplementedError 23 | 24 | @abstractmethod 25 | def on_validation_start(self, model): 26 | pass 27 | 28 | 29 | class ZeroShotBenchmarkHandler(BenchmarkHandler): 30 | def __init__(self, benchmark_name, benchmark, classes, templates, topx=1): 31 | BenchmarkHandler.__init__(self, benchmark_name, benchmark) 32 | assert classes is not None, "Classes must be provided for zero shot benchmarks" 33 | assert ( 34 | templates is not None 35 | ), "Templates must be provided for zero shot benchmarks" 36 | self.classes = classes 37 | self.templates = templates 38 | self.topx = topx 39 | 40 | def on_validation_start(self, model): 41 | model.set_classes(self.classes) 42 | model.set_templates(self.templates) 43 | model.compute_zeroshot_weights() 44 | 45 | def get_zeroshot_predictions(self, model, images): 46 | logit_scale = ( 47 | model.logit_scale.exp() 48 | if model.logit_scale is not None 49 | else torch.tensor(100.0) 50 | ) 51 | 52 | return ( 53 | (logit_scale * model.get_image_embeddings(images) @ model.zeroshot_weights) 54 | .squeeze() 55 | .float() 56 | ) 57 | 58 | def eval_batch(self, model, batch): 59 | split = "" 60 | if len(batch) == 4: 61 | images, targets, sample_id, split = batch 62 | elif len(batch) == 3: 63 | images, targets, sample_id = batch 64 | else: 65 | images, targets = batch 66 | 67 | logits = self.get_zeroshot_predictions(model, images) 68 | 69 | if len(targets.shape) > 1: 70 | pred = softmax(logits, dim=-1).topk(1)[1].squeeze() 71 | entropy = Categorical(probs=softmax(logits, dim=-1)).entropy() 72 | correct = targets[range(len(targets)), pred.squeeze()].clamp(0, 1) 73 | confidence = softmax(logits, dim=-1).topk(1)[0].squeeze() 74 | top5 = softmax(logits, dim=-1).topk(5)[1] 75 | correct_top5 = ( 76 | torch.bitwise_and( 77 | torch.nn.functional.one_hot(top5, len(self.classes)).sum(1), 78 | targets, 79 | ) 80 | .sum(1) 81 | .int() 82 | .clamp(0, 1) 83 | ) 84 | targets = targets.topk(1)[1].squeeze() 85 | top5 = top5.tolist() 86 | 87 | else: 88 | pred = softmax(logits, dim=-1) 89 | confidence = pred.max(1)[0].squeeze() 90 | entropy = Categorical(probs=pred).entropy() 91 | _, pred = pred.topk(self.topx, 1, True, True) 92 | pred = pred.t() 93 | correct = pred.eq(targets.view(1, -1).expand_as(pred)).int().sum(0) 94 | 95 | if len(self.classes) < 5: 96 | top5 = targets 97 | correct_top5 = [1] * len(targets) 98 | else: 99 | pred = softmax(logits, dim=-1) 100 | _, top5 = pred.topk(5, 1, True, True) 101 | correct_top5 = ( 102 | torch.bitwise_and( 103 | torch.nn.functional.one_hot(top5, len(self.classes)).sum(1), 104 | torch.nn.functional.one_hot(targets, len(self.classes)), 105 | ) 106 | .sum(1) 107 | .int() 108 | ) 109 | pred = pred.topk(1, 1, True, True)[1].squeeze() 110 | 111 | res = { 112 | "entropy": entropy, 113 | "image_class": targets, 114 | "split": split, 115 | "benchmark_name": self.benchmark_name, 116 | "correctness": correct, 117 | "correctness_top5": correct_top5, 118 | "predictions": pred, 119 | "predictions_top5": top5, 120 | "confidence": confidence, 121 | } 122 | 123 | if len(batch) > 2: 124 | res["image_name"] = sample_id 125 | 126 | return res 127 | 128 | 129 | class 
RelationBenchmarkHandler(BenchmarkHandler): 130 | def __init__(self, benchmark_name, benchmark): 131 | BenchmarkHandler.__init__(self, benchmark_name, benchmark) 132 | 133 | def get_similarity(self, model, images, captions): 134 | image_features = model.get_image_embeddings(images) 135 | num_captions = len(captions) 136 | batch_size = len(captions[0]) 137 | 138 | caption_features = ( 139 | model.get_text_embeddings(list(itertools.chain.from_iterable(captions))) 140 | .reshape(num_captions, batch_size, -1) 141 | .permute(1, 0, 2) 142 | ) 143 | 144 | scores = torch.einsum("nkd,nld->nkl", image_features, caption_features) 145 | 146 | if model.use_itm_head: 147 | scores = model.use_mlp_head( 148 | scores, 149 | model.model.visual_encoder(images.to(model.device)).unsqueeze(1), 150 | captions, 151 | ) 152 | 153 | return scores 154 | 155 | def eval_batch(self, model, batch): 156 | attribute = None 157 | if len(batch) == 4: 158 | images, captions, sample_id, attribute = batch 159 | else: 160 | images, captions, sample_id = batch 161 | 162 | if self.benchmark_name == "bivlc": 163 | sim_C0_I0 = self.get_similarity(model, images[0], [captions[0]]).squeeze() 164 | sim_C0_I1 = self.get_similarity(model, images[1], [captions[0]]).squeeze() 165 | sim_C1_I0 = self.get_similarity(model, images[0], [captions[1]]).squeeze() 166 | sim_C1_I1 = self.get_similarity(model, images[1], [captions[1]]).squeeze() 167 | 168 | Ipos_2T = sim_C0_I0 > sim_C1_I0 169 | Ineg_2T = sim_C1_I1 > sim_C0_I1 170 | Tpos_2I = sim_C0_I0 > sim_C0_I1 171 | Tneg_2I = sim_C1_I1 > sim_C1_I0 172 | 173 | I2T = torch.logical_and(Ipos_2T, Ineg_2T) 174 | T2I = torch.logical_and(Tpos_2I, Tneg_2I) 175 | group_score = torch.logical_and(I2T, T2I) 176 | 177 | res = { 178 | "image_name": sample_id, 179 | "benchmark_name": self.benchmark_name, 180 | "correctness": group_score.int(), 181 | 'I2T': I2T.int(), 182 | 'T2I': T2I.int(), 183 | } 184 | 185 | elif isinstance(images, list): 186 | c_i0 = self.get_similarity(model, images[0], captions).squeeze() 187 | c_i1 = self.get_similarity(model, images[1], captions).squeeze() 188 | text_correct = torch.logical_and( 189 | c_i0[:, 0] > c_i0[:, 1], c_i1[:, 1] > c_i1[:, 0] 190 | ).int() 191 | image_correct = torch.logical_and( 192 | c_i0[:, 0] > c_i1[:, 0], c_i1[:, 1] > c_i0[:, 1] 193 | ).int() 194 | correct = torch.logical_and(text_correct, image_correct).int() 195 | 196 | res = { 197 | "image_name": sample_id, 198 | "benchmark_name": self.benchmark_name, 199 | "correctness": correct, 200 | "text_correctness": text_correct, 201 | "image_correctness": image_correct, 202 | } 203 | else: 204 | scores = self.get_similarity(model, images, captions) 205 | preds = torch.argmax(scores.squeeze(), axis=-1) 206 | correct = (preds == 0).int() 207 | 208 | res = { 209 | "image_name": sample_id, 210 | "benchmark_name": self.benchmark_name, 211 | "correctness": correct, 212 | "confidence": scores.squeeze(1).max(1)[0], 213 | "entropy": Categorical( 214 | probs=softmax(scores.squeeze(1), dim=-1) 215 | ).entropy(), 216 | } 217 | 218 | if attribute is not None: 219 | if "\n" in attribute[0]: 220 | attribute = [x.split("\n") for x in attribute] 221 | res["attribute"] = attribute 222 | 223 | return res 224 | -------------------------------------------------------------------------------- /unibench/benchmarks_zoo/wrappers/huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 
4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from datasets import load_dataset, load_from_disk 8 | from huggingface_hub import hf_hub_download 9 | import torch 10 | from torch.utils.data import Dataset 11 | from ...common_utils.constants import DATA_DIR 12 | from pathlib import Path 13 | 14 | 15 | class HuggingFaceDataset(Dataset): 16 | def load_txt_file(self, dataset_url, filename, dir): 17 | file = hf_hub_download( 18 | repo_id=dataset_url, 19 | filename=filename, 20 | repo_type="dataset", 21 | local_dir=dir, 22 | local_dir_use_symlinks=False, 23 | ) 24 | 25 | res = [] 26 | with open(file) as f: 27 | for line in f: 28 | res.append( 29 | line.replace("_", " ") 30 | .replace("\n", "") 31 | .replace("{c}", "{}") 32 | .replace(" ", " ") 33 | .lower() 34 | ) 35 | return res 36 | 37 | def __init__( 38 | self, 39 | dataset_url, 40 | root: str = DATA_DIR, 41 | transform=None, 42 | target_transform=None, 43 | download_num_workers=60, 44 | image_extension="webp", 45 | classes=None, 46 | templates=None, 47 | *args, 48 | **kwargs 49 | ): 50 | Dataset.__init__(self, *args, **kwargs) 51 | assert dataset_url != "", "Please provide a dataset url" 52 | 53 | self.dataset_name = dataset_url.split("/")[-1] 54 | self.root_dir = root 55 | self.dataset_dir = Path(self.root_dir) / self.dataset_name 56 | self.dataset_url = dataset_url 57 | self.image_extension = image_extension 58 | self.transform = transform 59 | self.download_num_workers = download_num_workers 60 | self.target_transform = target_transform 61 | 62 | self.classes = classes 63 | self.templates = templates 64 | 65 | if not self.dataset_dir.exists(): 66 | self.download_dataset() 67 | 68 | self.dataset = load_from_disk(str(self.dataset_dir)) 69 | 70 | try: 71 | if classes is None: 72 | self.classes = self.load_txt_file( 73 | dataset_url, "classnames.txt", str(self.dataset_dir) 74 | ) 75 | if templates is None: 76 | self.templates = self.load_txt_file( 77 | dataset_url, 78 | "zeroshot_classification_templates.txt", 79 | str(self.dataset_dir), 80 | ) 81 | except: 82 | pass 83 | 84 | def __len__(self): 85 | return len(self.dataset) 86 | 87 | def download_dataset(self): 88 | try: 89 | self.dataset = load_dataset( 90 | self.dataset_url, 91 | trust_remote_code=True, 92 | split="test", 93 | num_proc=self.download_num_workers, 94 | ) 95 | except: 96 | self.dataset = load_dataset( 97 | self.dataset_url, 98 | trust_remote_code=True, 99 | split="test", 100 | ) 101 | 102 | self.dataset.save_to_disk(str(self.dataset_dir)) 103 | 104 | def __getitem__(self, index): 105 | item = self.dataset[index] 106 | 107 | # Loading Images 108 | samples = [] 109 | for k in item.keys(): 110 | if self.image_extension in k: 111 | img = item[k].convert("RGB") 112 | if self.transform is not None: 113 | img = self.transform(img) 114 | samples.append(img) 115 | 116 | if len(samples) == 1: 117 | samples = samples[0] 118 | 119 | # Loading Labels 120 | if "cls" in item.keys(): 121 | target = item["cls"] 122 | if self.target_transform is not None: 123 | target = self.target_transform(target) 124 | else: 125 | target = item["npy"] 126 | if all(isinstance(t, int) for t in target): 127 | target = torch.nn.functional.one_hot( 128 | torch.tensor(target), len(self.classes) 129 | ).sum(0) 130 | 131 | for t in target: 132 | if self.target_transform is not None: 133 | t = self.target_transform(t) 134 | 135 | if "split.txt" in item.keys(): 136 | return ( 137 | samples, 138 | target, 139 | 
str(item["__key__"]), 140 | item["split.txt"], 141 | ) 142 | 143 | return samples, target, str(item["__key__"]) 144 | -------------------------------------------------------------------------------- /unibench/common_utils/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /unibench/common_utils/constants.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from pathlib import Path 8 | import os 9 | 10 | ################################################################## 11 | # DIRECTORIES 12 | ################################################################## 13 | PROJ_DIR = Path(__file__).parent.parent.absolute() 14 | CURRENT_DIR = Path(os.getcwd()) 15 | HUB_CACHE_DIR = Path(os.getenv("TORCH_HOME", Path.home().joinpath(".cache").joinpath("torch"))).joinpath("hub") 16 | CACHE_DIR = Path(os.getenv("UNIBENCH_HUB", Path.home().joinpath(".cache").joinpath("unibench"))) 17 | 18 | DATA_DIR = CACHE_DIR.joinpath("data") 19 | OUTPUT_DIR = CACHE_DIR.joinpath("outputs") 20 | LOCK_DIR = CACHE_DIR.joinpath("locks") 21 | 22 | ################################################################## 23 | # MEAN AND STD 24 | ################################################################## 25 | 26 | DEFAULT_CROP_PCT = 0.875 27 | DEFAULT_CROP_MODE = 'center' 28 | IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406) 29 | IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225) 30 | IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5) 31 | IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5) 32 | IMAGENET_DPN_MEAN = (124 / 255, 117 / 255, 104 / 255) 33 | IMAGENET_DPN_STD = tuple([1 / (.0167 * 255)] * 3) 34 | OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073) 35 | OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711) 36 | -------------------------------------------------------------------------------- /unibench/common_utils/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import random 9 | from typing import Optional 10 | import numpy as np 11 | from huggingface_hub import hf_hub_download, snapshot_download 12 | 13 | import pandas as pd 14 | from rich.table import Table 15 | from rich import box 16 | 17 | 18 | def df_to_table( 19 | pandas_dataframe: pd.DataFrame, 20 | rich_table: Table, 21 | show_index: bool = True, 22 | index_name: Optional[str] = None, 23 | ) -> Table: 24 | """Convert a pandas.DataFrame obj into a rich.Table obj. 25 | Args: 26 | pandas_dataframe (DataFrame): A Pandas DataFrame to be converted to a rich Table. 27 | rich_table (Table): A rich Table that should be populated by the DataFrame values. 28 | show_index (bool): Add a column with a row count to the table. Defaults to True. 29 | index_name (str, optional): The column name to give to the index column. 
Defaults to None, showing no value. 30 | Returns: 31 | Table: The rich Table instance passed, populated with the DataFrame values.""" 32 | 33 | if show_index: 34 | index_name = str(index_name) if index_name else "" 35 | rich_table.add_column(index_name) 36 | 37 | for column in pandas_dataframe.columns: 38 | rich_table.add_column(str(column)) 39 | 40 | for index, value_list in enumerate(pandas_dataframe.values.tolist()): 41 | row = [str(pandas_dataframe.index.to_list()[index])] if show_index else [] 42 | if isinstance(value_list[0], float): 43 | row += [str(round(float(x), 2)) for x in value_list] 44 | else: 45 | row += [str(x) for x in value_list] 46 | rich_table.add_row(*row) 47 | 48 | rich_table.row_styles = ["none", "dim"] 49 | rich_table.box = box.SIMPLE_HEAD 50 | 51 | return rich_table 52 | 53 | 54 | def seed_everything(seed: int): 55 | import torch 56 | random.seed(seed) 57 | np.random.seed(seed) 58 | torch.manual_seed(seed) 59 | torch.cuda.manual_seed_all(seed) 60 | 61 | 62 | def get_benchmark_mappings(axis, benchmarks=None): 63 | from ..benchmarks_zoo.registry import get_benchmark_info 64 | from ..benchmarks_zoo import list_benchmarks 65 | if benchmarks is None: 66 | benchmarks = list_benchmarks() 67 | benchmark_mappings = {} 68 | for benchmark in benchmarks: 69 | if axis is None: 70 | benchmark_mappings[benchmark] = get_benchmark_info(benchmark) 71 | else: 72 | benchmark_mappings[benchmark] = get_benchmark_info(benchmark)[axis] 73 | return benchmark_mappings 74 | 75 | 76 | def get_model_mappings(axis, models=None): 77 | from ..models_zoo.registry import get_model_info 78 | from ..models_zoo import list_models 79 | if models is None: 80 | models = list_models() 81 | model_mappings = {} 82 | for model in models: 83 | if axis is None: 84 | model_mappings[model] = get_model_info(model) 85 | else: 86 | model_mappings[model] = get_model_info(model)[axis] 87 | return model_mappings 88 | 89 | 90 | def download_only_aggregate(output_dir): 91 | print(f"Downloading only aggregate results...{output_dir}") 92 | hf_hub_download( 93 | repo_id="haideraltahan/unibench", 94 | cache_dir=output_dir, 95 | local_dir=output_dir, 96 | local_dir_use_symlinks=False, 97 | repo_type="dataset", 98 | filename="aggregate.f", 99 | ) 100 | 101 | def download_all_results(output_dir): 102 | print(f"Downloading all results...{output_dir}") 103 | snapshot_download( 104 | repo_id="haideraltahan/unibench", 105 | cache_dir=output_dir, 106 | local_dir=output_dir, 107 | local_dir_use_symlinks=False, 108 | repo_type="dataset", 109 | ) 110 | -------------------------------------------------------------------------------- /unibench/evaluator.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | from typing import List, Union 9 | import os 10 | 11 | import fire 12 | import pandas as pd 13 | from rich.progress import Progress 14 | 15 | 16 | from unibench.output import OutputHandler 17 | from unibench.common_utils.utils import ( 18 | seed_everything, 19 | get_model_mappings, 20 | get_benchmark_mappings, 21 | df_to_table, 22 | ) 23 | from unibench.common_utils.constants import OUTPUT_DIR, DATA_DIR 24 | 25 | from rich.console import Console 26 | from rich.table import Table 27 | 28 | 29 | class Evaluator(object): 30 | """ 31 | The Evaluator class is responsible for evaluating machine learning models on various benchmarks. 32 | It provides methods to update the list of models and benchmarks, download benchmarks, add new models and benchmarks, 33 | generate aggregate results, and evaluate models. 34 | 35 | Attributes: 36 | seed (int): Random seed for reproducibility. 37 | num_workers (int): Number of workers for data loading. 38 | models (Union[List[str], str]): List of models to evaluate or "all" to evaluate all available models. 39 | benchmarks (Union[List[str], str]): List of benchmarks to evaluate or "all" to evaluate all available benchmarks. 40 | model_id (Union[int, None]): Specific model ID to evaluate. 41 | benchmark_id (Union[int, None]): Specific benchmark ID to evaluate. 42 | output_dir (str): Directory to save evaluation results. 43 | benchmarks_dir (str): Directory containing benchmark data. 44 | download_aggregate_precomputed (bool): Whether to download aggregate precomputed results. Used for minor analysis and fast loading. 45 | download_all_precomputed (bool): Whether to download all precomputed results. Used for slow loading and comprehensive analysis. 46 | 47 | Methods: 48 | update_benchmark_list(benchmarks, benchmark_id=None): 49 | Updates the list of benchmarks to evaluate. 50 | 51 | update_model_list(models, model_id=None): 52 | Updates the list of models to evaluate. 53 | 54 | download_benchmarks(): 55 | Downloads the specified benchmarks. 56 | 57 | list_models(): 58 | Lists all available models. 59 | 60 | add_benchmark(benchmark, handler, meta_data={}): 61 | Adds a new benchmark to the list of benchmarks. 62 | 63 | generate_aggregate_results(): 64 | Generates aggregate results from the evaluation. 65 | 66 | list_benchmarks(): 67 | Lists all available benchmarks. 68 | 69 | add_model(model, meta_data={}): 70 | Adds a new model to the list of models. 71 | 72 | show_results(): 73 | Displays the evaluation results. 74 | 75 | evaluate(save_freq=1000, face_blur=False, device="cuda" if torch.cuda.is_available() else "cpu", batch_per_gpu=32): 76 | Evaluates the models on the benchmarks and saves the results. 
77 | """ 78 | 79 | def __init__( 80 | self, 81 | seed: int = 1337, 82 | num_workers=int(os.environ.get("SLURM_CPUS_PER_TASK") or 96), 83 | models: Union[List[str], str] = "all", 84 | benchmarks: Union[List[str], str] = "all", 85 | model_id: Union[int, None] = None, 86 | benchmark_id: Union[int, None] = None, 87 | output_dir: str = OUTPUT_DIR, 88 | benchmarks_dir: str = DATA_DIR, 89 | download_aggregate_precomputed: bool = True, 90 | download_all_precomputed: bool = False, 91 | ): 92 | self.seed = seed 93 | self.num_workers = num_workers 94 | self.benchmarks_dir = benchmarks_dir 95 | self.output_dir = output_dir 96 | 97 | self.update_model_list(models, model_id) 98 | self.update_benchmark_list(benchmarks, benchmark_id) 99 | 100 | self.outputhandler = OutputHandler( 101 | output_dir=output_dir, 102 | download_aggregate_precomputed=download_aggregate_precomputed, 103 | download_all_precomputed=download_all_precomputed, 104 | ) 105 | 106 | def update_benchmark_list( 107 | self, benchmarks: Union[List[str], str], benchmark_id: Union[int, None] = None 108 | ): 109 | from unibench.benchmarks_zoo import list_benchmarks 110 | 111 | if isinstance(benchmarks, str): 112 | self.benchmarks = list_benchmarks(benchmarks) 113 | elif isinstance(benchmarks, list): 114 | self.benchmarks = benchmarks 115 | 116 | if benchmark_id is not None: 117 | self.benchmarks = [self.benchmarks[int(benchmark_id)]] 118 | print("Evaluating only benchmark {}".format(self.benchmarks[0])) 119 | 120 | assert ( 121 | isinstance(self.benchmarks, list) and len(self.benchmarks) > 0 122 | ), "Please provide benchmarks to evaluate!" 123 | print("There are {} benchmarks to evaluate".format(len(self.benchmarks))) 124 | 125 | def update_model_list( 126 | self, models: Union[List[str], str], model_id: Union[int, None] = None 127 | ): 128 | from unibench.models_zoo import list_models 129 | 130 | if isinstance(models, str): 131 | self.models = list_models(models) 132 | elif isinstance(models, list): 133 | self.models = models 134 | 135 | if model_id is not None: 136 | self.models = [self.models[int(model_id)]] 137 | print("Evaluating only model {}".format(self.models[0])) 138 | 139 | assert ( 140 | isinstance(self.models, list) and len(self.models) > 0 141 | ), "Please provide models to evaluate!" 
142 | 143 | print("There are {} models to evaluate".format(len(self.models))) 144 | 145 | def download_benchmarks(self): 146 | from unibench.benchmarks_zoo.registry import load_benchmark 147 | 148 | for benchmark in self.benchmarks: 149 | print(f"Loading {benchmark}") 150 | load_benchmark(benchmark, root=self.benchmarks_dir) 151 | print(f"Done Loading {benchmark}") 152 | 153 | def list_models(self) -> dict: 154 | model_mappings = get_model_mappings(None) 155 | # print(pd.DataFrame(model_mappings).transpose().to_markdown()) 156 | Console().print( 157 | df_to_table( 158 | pd.DataFrame(model_mappings).transpose(), 159 | Table(show_header=True, header_style="bold magenta"), 160 | index_name="model_name", 161 | ) 162 | ) 163 | return model_mappings 164 | 165 | def add_benchmark(self, benchmark, handler, meta_data={}): 166 | from unibench.benchmarks_zoo.registry import register_benchmark 167 | import unibench.benchmarks_zoo.benchmarks as benchmarks_module 168 | 169 | def temp_func(benchmark_name, transform=None, **kwargs): 170 | bm = benchmark(transform=transform, **kwargs) 171 | 172 | return handler(benchmark=bm) 173 | 174 | benchmark_name = handler.keywords["benchmark_name"] 175 | 176 | temp_func.__name__ = benchmark_name 177 | register_benchmark("new_benchmark", meta_data)(temp_func) 178 | 179 | setattr(benchmarks_module, benchmark_name, temp_func) 180 | self.benchmarks.append(benchmark_name) 181 | 182 | def generate_aggregate_results(self): 183 | self.outputhandler.load_all_csv(self.models, self.benchmarks) 184 | self.outputhandler.generate_aggregate_results() 185 | print("Aggregate results generated") 186 | 187 | def list_benchmarks(self) -> dict: 188 | benchmark_mappings = get_benchmark_mappings(None) 189 | Console().print( 190 | df_to_table( 191 | pd.DataFrame(benchmark_mappings).transpose(), 192 | Table(show_header=True, header_style="bold magenta"), 193 | index_name="benchmark_name", 194 | ) 195 | ) 196 | return benchmark_mappings 197 | 198 | def add_model( 199 | self, 200 | model, 201 | meta_data: dict = {}, 202 | ): 203 | import unibench.models_zoo.models as models_module 204 | from unibench.models_zoo.registry import register_model 205 | 206 | def temp_func(model_name, **kwargs): 207 | return model(**kwargs) 208 | 209 | model_name = model.keywords["model_name"] 210 | 211 | temp_func.__name__ = model_name 212 | register_model("new_model", meta_data)(temp_func) 213 | 214 | setattr(models_module, model_name, temp_func) 215 | self.models.append(model_name) 216 | 217 | def show_results(self): 218 | Console().print( 219 | df_to_table( 220 | self.outputhandler.print_dataframe( 221 | **{"benchmark_name": self.benchmarks, "model_name": self.models} 222 | ).round(4) 223 | * 100, 224 | Table(show_header=True, header_style="bold magenta"), 225 | index_name="model_name", 226 | ) 227 | ) 228 | 229 | def evaluate( 230 | self, 231 | save_freq: int = 1000, 232 | face_blur: bool = False, 233 | device="cpu", 234 | batch_per_gpu: int = 32, 235 | ): 236 | """ 237 | Evaluate models on benchmarks and return and saving the results. 238 | 239 | Args: 240 | save_freq (int): The frequency at which to save results. Defaults to 1000. 241 | face_blur (bool): Whether to use face blurring during evaluation. Defaults to False. 242 | device (str): The device to use for evaluation. Defaults to "cuda" if available otherwise "cpu". 243 | batch_per_gpu (int): The batch size per GPU. Defaults to 32. 244 | 245 | Returns: 246 | query results: The results of the query for the specified benchmarks and models. 
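Example (a hypothetical invocation, not taken from the repository; "clip_vitB32" is one of the models listed in unibench/models_zoo/README.md, and any other registered model or benchmark names can be substituted):

    >>> from unibench.evaluator import Evaluator
    >>> evaluator = Evaluator(models=["clip_vitB32"], benchmarks="all")
    >>> evaluator.evaluate(save_freq=500, batch_per_gpu=16)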
247 | """ 248 | import torch 249 | from unibench.benchmarks_zoo.registry import load_benchmark 250 | from unibench.models_zoo.registry import load_model 251 | 252 | device = "cuda" if torch.cuda.is_available() else "cpu" 253 | seed_everything(self.seed) 254 | 255 | with Progress(transient=True) as progress: 256 | pg_models = progress.add_task( 257 | "[green]Processing...", total=len(self.models) 258 | ) 259 | pg_benchmarks = progress.add_task( 260 | "[green]Processing...", total=len(self.benchmarks) 261 | ) 262 | pg_benchmark = progress.add_task( 263 | "[green]Processing...", total=len(self.benchmarks), visible=False 264 | ) 265 | for model_name in self.models: 266 | progress.update( 267 | pg_models, description=f"[green]Processing {model_name}..." 268 | ) 269 | 270 | model = None 271 | 272 | for benchmark_name in self.benchmarks: 273 | progress.update( 274 | pg_benchmarks, 275 | description=f"[green]Processing {benchmark_name}...", 276 | ) 277 | 278 | number_entries = self.outputhandler.check_if_computed( 279 | model_name=model_name, benchmark_name=benchmark_name 280 | ) 281 | 282 | if number_entries == True: 283 | progress.update(pg_benchmarks, advance=1, refresh=True) 284 | continue 285 | 286 | if model is None: 287 | model = load_model( 288 | model_name=model_name, 289 | batch_per_gpu=batch_per_gpu, 290 | face_blur=face_blur, 291 | device=device, 292 | ) 293 | 294 | if model is None: 295 | raise ValueError( 296 | f"{model_name} does not exist in the currently supported models" 297 | ) 298 | 299 | dh = load_benchmark( 300 | benchmark_name, 301 | transform=model.get_preprocess_transforms(), 302 | root=self.benchmarks_dir, 303 | ) 304 | 305 | ds = dh.benchmark 306 | 307 | dh.on_validation_start(model) 308 | 309 | dl = torch.utils.data.DataLoader( 310 | ds, 311 | batch_size=model.get_batch_size(), 312 | shuffle=False, 313 | num_workers=self.num_workers, 314 | pin_memory=True, 315 | ) 316 | 317 | if number_entries == len(ds): 318 | progress.update(pg_benchmarks, advance=1, refresh=True) 319 | continue 320 | elif number_entries > len(ds) or (0 < number_entries < len(ds)): 321 | print(f"Reseting results for {model_name}") 322 | self.outputhandler.delete_rows( 323 | model_name=model_name, benchmark_name=benchmark_name 324 | ) 325 | 326 | progress.update( 327 | pg_benchmark, total=len(dl), completed=0, visible=True 328 | ) 329 | for idx, batch in enumerate(dl): 330 | progress.update( 331 | pg_benchmark, 332 | description=f"[green]Processing Batch #{idx}...", 333 | visible=True, 334 | ) 335 | 336 | for i, sample in enumerate(batch): 337 | if isinstance(sample, torch.Tensor) and device == "cuda": 338 | batch[i] = batch[i].to(device) 339 | 340 | with torch.no_grad(), torch.amp.autocast("cuda"): 341 | values_to_save = dh.eval_batch(model, batch) 342 | 343 | self.outputhandler.add_values( 344 | model_name=model_name, **values_to_save 345 | ) 346 | 347 | if idx % save_freq == 0 and idx > 0: 348 | self.outputhandler.save_csv(model_name, benchmark_name) 349 | progress.update(pg_benchmark, advance=1) 350 | progress.update(pg_benchmark, visible=False) 351 | self.outputhandler.save_csv(model_name, benchmark_name) 352 | self.outputhandler.save_aggregate_results( 353 | model_name, benchmark_name 354 | ) 355 | progress.update(pg_benchmarks, advance=1) 356 | progress.update(pg_models, advance=1) 357 | 358 | Console().print( 359 | df_to_table( 360 | self.outputhandler.print_dataframe( 361 | **{"benchmark_name": self.benchmarks, "model_name": self.models} 362 | ).round(4) 363 | * 100, 364 | 
Table(show_header=True, header_style="bold magenta"), 365 | index_name="model_name", 366 | ) 367 | ) 368 | Console().print(f"The results are saved in {self.output_dir}") 369 | 370 | 371 | def run(): 372 | fire.Fire(Evaluator) 373 | -------------------------------------------------------------------------------- /unibench/models_zoo/README.md: -------------------------------------------------------------------------------- 1 | # List of Implemented VLMs 2 | 3 | | | Dataset Size (Million) | Number of Parameters (Million) | Learning Objective | Architecture | Model Name | 4 | |:--------------------------------|---------------:|-------------:|:----------------------------|:---------------|:-------------------| 5 | | blip_vitB16_14m | 14 | 86 | BLIP | vit | BLIP ViT B 16 | 6 | | blip_vitL16_129m | 129 | 307 | BLIP | vit | BLIP ViT L 16 | 7 | | blip_vitB16_129m | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 8 | | blip_vitB16_coco | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 9 | | blip_vitB16_flickr | 129 | 86 | BLIP | vit | BLIP ViT B 16 | 10 | | blip_vitL16_coco | 129 | 307 | BLIP | vit | BLIP ViT L 16 | 11 | | blip_vitL16_flickr | 129 | 307 | BLIP | vit | BLIP ViT L 16 | 12 | | eva02_vitE14_plus_2b | 2000 | 4350 | EVA02 | vit | EVA02 ViT E 14 | 13 | | eva02_vitE14_2b | 2000 | 4350 | EVA02 | vit | EVA02 ViT E 14 | 14 | | eva02_vitL14_2b | 2000 | 307 | EVA02 | vit | EVA02 ViT L 14 | 15 | | eva02_vitB16_2b | 2000 | 86 | EVA02 | vit | EVA02 ViT B 16 | 16 | | eva01_vitG14_plus_2b | 2000 | 1011 | EVA01 | vit | EVA01 ViT g 14 | 17 | | eva01_vitG14_400m | 400 | 1011 | EVA01 | vit | EVA01 ViT g 14 | 18 | | clipa_vitbigG14 | 1280 | 1843 | CLIPA | vit | CLIPA ViT G 14 | 19 | | clipa_vitH14 | 1280 | 633 | CLIPA | vit | CLIPA ViT H 14 | 20 | | clipa_vitL14 | 1280 | 307 | CLIPA | vit | CLIPA ViT L 14 | 21 | | siglip_vitL16 | 10000 | 307 | Contrastive (sigmoid-based) | vit | SigLIP ViT L 16 | 22 | | siglip_vitB16 | 10000 | 86 | Contrastive (sigmoid-based) | vit | SigLIP ViT B 16 | 23 | | openclip_vitB32_metaclip_fullcc | 2500 | 86 | Contrastive | vit | MetaCLIP ViT B 32 | 24 | | openclip_vitB16_metaclip_400m | 400 | 86 | Contrastive | vit | MetaCLIP ViT B 16 | 25 | | openclip_vitB32_metaclip_400m | 400 | 86 | Contrastive | vit | MetaCLIP ViT B 32 | 26 | | openclip_vitB16_metaclip_fullcc | 2500 | 86 | Contrastive | vit | MetaCLIP ViT B 16 | 27 | | openclip_vitL14_dfn2b | 2000 | 307 | Contrastive | vit | OpenCLIP ViT L 14 | 28 | | openclip_vitL14_metaclip_400 | 400 | 307 | Contrastive | vit | MetaCLIP ViT L 14 | 29 | | openclip_vitL14_metaclip_fullcc | 2500 | 307 | Contrastive | vit | MetaCLIP ViT L 14 | 30 | | openclip_vitH14_metaclip_fullcc | 2500 | 633 | Contrastive | vit | MetaCLIP ViT H 14 | 31 | | openclip_vitH14_dfn5b | 5000 | 633 | Contrastive | vit | OpenCLIP ViT H 14 | 32 | | openclip_convnext_base | 400 | 88 | Contrastive | conv | OpenCLIP ConvNext | 33 | | openclip_vitB32_datacomp_s | 13 | 86 | Contrastive | vit | DataComp ViT B 32 | 34 | | openclip_vitB32_datacomp_m | 128 | 86 | Contrastive | vit | DataComp ViT B 32 | 35 | | openclip_vitB32_datacomp_xl | 12800 | 86 | Contrastive | vit | DataComp ViT B 32 | 36 | | openclip_vitB16_datacomp_xl | 12800 | 86 | Contrastive | vit | DataComp ViT B 16 | 37 | | openclip_vitB16_datacomp_l | 1280 | 86 | Contrastive | vit | DataComp ViT B 16 | 38 | | openclip_vitH14 | 2000 | 633 | Contrastive | vit | OpenCLIP ViT H 14 | 39 | | xvlm_flickr | 16 | 86 | XVLM | Swin | XVLM Swin B | 40 | | flava_full | 70 | 86 | Other | vit | FLAVA ViT B 32 | 41 | | openclip_vitL14_400m | 
400 | 307 | Contrastive | vit | OpenCLIP ViT L 14 | 42 | | openclip_vitL14_datacomp_xl | 12800 | 307 | Contrastive | vit | DataComp ViT L 14 | 43 | | openclip_vitL14_2b | 2000 | 307 | Contrastive | vit | OpenCLIP ViT L 14 | 44 | | clip_vitL14 | 400 | 307 | Contrastive | vit | CLIP ViT L 14 | 45 | | xvlm_coco | 16 | 86 | XVLM | Swin | XVLM Swin B | 46 | | openclip_vitB32_400m | 400 | 86 | Contrastive | vit | OpenCLIP ViT B 32 | 47 | | openclip_vitB32_2b | 2000 | 86 | Contrastive | vit | OpenCLIP ViT B 32 | 48 | | openclip_vitG14_2b | 2000 | 1011 | Contrastive | vit | OpenCLIP ViT g 14 | 49 | | openclip_vitbigG14_2b | 2000 | 1843 | Contrastive | vit | OpenCLIP ViT G 14 | 50 | | openclip_vitB16_2b | 2000 | 86 | Contrastive | vit | OpenCLIP ViT B 16 | 51 | | openclip_vitB16_400m | 400 | 86 | Contrastive | vit | OpenCLIP ViT B 16 | 52 | | opencoca_vitL14_2b | 2000 | 307 | Other | vit | OpenCOCA ViT L 14 | 53 | | opencoca_vitB32_2b | 2000 | 86 | Other | vit | OpenCOCA ViT B 32 | 54 | | negclip_vitB32 | 400 | 86 | Negative CLIP | vit | NegCLIP ViT B 32 | 55 | | clip_vitB16 | 400 | 86 | Contrastive | vit | CLIP ViT B 16 | 56 | | clip_resnet50 | 400 | 38 | Contrastive | conv | CLIP ResNet50 | 57 | | openclip_resnet101_yfcc | 15 | 56 | Contrastive | conv | OpenCLIP ResNet101 | 58 | | openclip_resnet50_yfcc | 15 | 38 | Contrastive | conv | OpenCLIP ResNet50 | 59 | | openclip_resnet50_cc | 12 | 38 | Contrastive | conv | OpenCLIP ResNet50 | 60 | | clip_resnet101 | 400 | 56 | Contrastive | conv | CLIP ResNet101 | 61 | | clip_resnet50x4 | 400 | 87 | Contrastive | conv | CLIP ResNet50x4 | 62 | | clip_resnet50x16 | 400 | 167 | Contrastive | conv | CLIP ResNet50x16 | 63 | | clip_resnet50x64 | 400 | 420 | Contrastive | conv | CLIP ResNet50x64 | 64 | | clip_vitB32 | 400 | 86 | Contrastive | vit | CLIP ViT B 32 | -------------------------------------------------------------------------------- /unibench/models_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from .registry import * 8 | from .models import * 9 | -------------------------------------------------------------------------------- /unibench/models_zoo/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | from collections import defaultdict 8 | 9 | _model_registry = defaultdict(list) # mapping of model names to entrypoint fns 10 | _model_info_registry = defaultdict(defaultdict) 11 | 12 | 13 | def register_model(frameworks, info): 14 | def inner_decorator(fn): 15 | # add entries to registry dict/sets 16 | model_name = fn.__name__ 17 | if isinstance(frameworks, str): 18 | _model_registry[frameworks].append(model_name) 19 | else: 20 | for framework in frameworks: 21 | _model_registry[framework].append(model_name) 22 | 23 | if info is not None: 24 | _model_info_registry[model_name] = info 25 | 26 | return fn 27 | 28 | return inner_decorator 29 | 30 | 31 | def get_model_info(model_name): 32 | return _model_info_registry[model_name] 33 | 34 | 35 | def load_model(model_name, **kwargs): 36 | import unibench.models_zoo.models as models 37 | if model_name in list_models("all"): 38 | model = eval(f"models.{model_name}")(model_name, **kwargs) 39 | else: 40 | return None 41 | return model 42 | 43 | 44 | def list_models(framework="all"): 45 | """Return list of available model names, sorted alphabetically""" 46 | r = [] 47 | if framework == "all": 48 | for _, v in _model_registry.items(): 49 | r.extend(v) 50 | else: 51 | r = _model_registry[framework] 52 | return sorted(list(r)) 53 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from .clip import ClipModel 8 | from .huggingface import FlavaModel, XVLMModel 9 | from .blip import BlipModel 10 | from .base import AbstractModel 11 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/base.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | from abc import ABC, abstractmethod 8 | from typing import List, Union 9 | 10 | import torch 11 | import math 12 | import timm.data 13 | import torch 14 | from torchvision.transforms import ( 15 | Compose, 16 | Resize, 17 | CenterCrop, 18 | Normalize, 19 | ToTensor, 20 | InterpolationMode, 21 | ) 22 | 23 | from .transformations.grayscale2rgb import GrayScale2RGB 24 | 25 | 26 | class AbstractModel(ABC): 27 | def __init__( 28 | self, 29 | model, 30 | model_name, 31 | input_resolution: int = 224, 32 | crop_pct: float = 1.0, 33 | norm_mean=timm.data.constants.OPENAI_CLIP_MEAN, 34 | norm_std=timm.data.constants.OPENAI_CLIP_STD, 35 | use_norm: bool = True, 36 | batch_per_gpu: int = 32, 37 | logit_scale=None, 38 | context_length=77, 39 | tokenizer=None, 40 | interpolation=InterpolationMode.BICUBIC, 41 | use_itm_head: bool = False, 42 | face_blur: bool = False, 43 | device: str = "cuda", 44 | ) -> None: 45 | super(AbstractModel, self).__init__() 46 | assert device in ["cpu", "cuda"], "device must be 'cpu' or 'cuda'" 47 | 48 | self.model = model 49 | self.use_itm_head = use_itm_head 50 | self.batch_per_gpu = batch_per_gpu 51 | self.logit_scale = logit_scale 52 | self.context_length = context_length 53 | self.model_name = model_name 54 | self.crop_pct = crop_pct 55 | self.input_resolution = input_resolution 56 | self.norm_mean = norm_mean 57 | self.norm_std = norm_std 58 | self.interpolation = interpolation 59 | self.use_norm = use_norm 60 | self.face_blur = face_blur 61 | self.device = device 62 | self.tokenizer = tokenizer 63 | 64 | self.model = self.model.to(device) 65 | self.model.eval() 66 | 67 | self.zeroshot_weights = None 68 | self.classes = None 69 | self.templates = None 70 | 71 | @abstractmethod 72 | def get_image_embeddings(self, images: torch.Tensor) -> torch.Tensor: 73 | raise NotImplementedError 74 | 75 | @abstractmethod 76 | def get_text_embeddings(self, texts: Union[list, torch.Tensor]) -> torch.Tensor: 77 | raise NotImplementedError 78 | 79 | def set_templates(self, templates: List[str]) -> None: 80 | if self.templates != templates: 81 | self.templates = templates 82 | 83 | def set_classes(self, classes: List[str]) -> None: 84 | if self.classes != classes: 85 | self.classes = classes 86 | 87 | def get_zeroshot_predictions(self, images, zeroshot_weights): 88 | return ( 89 | ( 90 | self.logit_scale.exp() 91 | if self.logit_scale is not None 92 | else torch.tensor(100.0) 93 | ) 94 | * self.get_image_embeddings(images) 95 | @ zeroshot_weights 96 | ).squeeze() 97 | 98 | @abstractmethod 99 | def compute_zeroshot_weights(self) -> None: 100 | pass 101 | 102 | def get_batch_size(self) -> int: 103 | return self.batch_per_gpu * ( 104 | 1 if self.device == "cpu" else torch.cuda.device_count() 105 | ) 106 | 107 | def get_preprocess_transforms(self): 108 | scale_size = int(math.floor(self.input_resolution / self.crop_pct)) 109 | transforms = [ 110 | Resize(scale_size, interpolation=self.interpolation), 111 | CenterCrop(self.input_resolution), 112 | ] 113 | 114 | if self.face_blur: 115 | from .transformations.faceblur import FaceBlur 116 | transforms.append(FaceBlur(input_resolution=self.input_resolution)) 117 | 118 | transforms.append(ToTensor()) 119 | transforms.append(GrayScale2RGB()) 120 | 121 | if self.use_norm: 122 | transforms.append( 123 | Normalize( 124 | mean=self.norm_mean, 125 | std=self.norm_std, 126 | ) 127 | ) 128 | 129 | return Compose(transforms) 130 | -------------------------------------------------------------------------------- 
/unibench/models_zoo/wrappers/blip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | import itertools 8 | import torch 9 | 10 | from .clip import ClipModel 11 | 12 | 13 | class BlipModel(ClipModel): 14 | def __init__(self, **kwargs): 15 | super(BlipModel, self).__init__(**kwargs) 16 | 17 | @torch.no_grad() 18 | def get_image_embeddings(self, images): 19 | image_features = self.model.visual_encoder(images.to(self.device)) 20 | image_features = self.model.vision_proj(image_features[:, 0, :]).float() 21 | image_features /= image_features.norm(dim=-1, keepdim=True) 22 | return image_features.unsqueeze(1) 23 | 24 | @torch.no_grad() 25 | def use_mlp_head(self, sim_scores, image_features, captions): 26 | num_captions = len(captions) 27 | batch_size = len(captions[0]) 28 | captions = self.tokenizer( 29 | list(itertools.chain.from_iterable(captions)), 30 | padding="max_length", 31 | truncation=True, 32 | max_length=self.context_length, 33 | return_tensors="pt", 34 | ) 35 | 36 | text_token = ( 37 | captions["input_ids"] 38 | .to(self.device) 39 | .reshape(num_captions, batch_size, -1) 40 | .permute(1, 0, 2) 41 | ) 42 | text_attention = ( 43 | captions["attention_mask"] 44 | .to(self.device) 45 | .reshape(num_captions, batch_size, -1) 46 | .permute(1, 0, 2) 47 | ) 48 | 49 | # AUGMENT TEXT to IMG scores by parsing the text conditioned on the image 50 | new_scores = torch.full(sim_scores.size(), -100.0).to( 51 | self.device 52 | ) # batch x n_image_options x n_text_options) 53 | n_text_options = new_scores.size(2) 54 | for i in range(sim_scores.size(0)): 55 | text_candidates_i = text_token[i] 56 | text_attention_i = text_attention[i] 57 | encoder_att = torch.ones( 58 | (n_text_options, image_features.size(2)), dtype=torch.long 59 | ).to(self.device) 60 | for j in range(sim_scores.size(1)): # loop over image options 61 | encoder_output = image_features[i, j] # size n hidden states x dim d 62 | encoder_output = encoder_output.repeat(n_text_options, 1, 1) 63 | output = self.model.text_encoder( 64 | text_candidates_i, 65 | attention_mask=text_attention_i, 66 | encoder_hidden_states=encoder_output, 67 | encoder_attention_mask=encoder_att, 68 | return_dict=False, 69 | )[0] 70 | score = self.model.itm_head(output[:, 0, :])[ 71 | :, 1 72 | ] # logits that the text is relevant to the image 73 | new_scores[i, j, :] = score + sim_scores[i, j, :] 74 | return new_scores 75 | 76 | @torch.no_grad() 77 | def get_text_embeddings(self, captions): 78 | captions = self.tokenizer( 79 | captions, 80 | padding="max_length", 81 | truncation=True, 82 | max_length=self.context_length, 83 | return_tensors="pt", 84 | ) 85 | 86 | text_features = self.model.text_encoder( 87 | captions["input_ids"].to(self.device), 88 | attention_mask=captions["attention_mask"].to(self.device), 89 | return_dict=True, 90 | mode="text", 91 | ) 92 | 93 | text_features = self.model.text_proj( 94 | text_features.last_hidden_state[:, 0, :] 95 | ).float() 96 | 97 | text_features /= text_features.norm(dim=-1, keepdim=True) 98 | 99 | return text_features 100 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. 
and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | import torch 9 | from .base import AbstractModel 10 | import inspect 11 | 12 | 13 | class ClipModel(AbstractModel): 14 | def __init__( 15 | self, 16 | model, 17 | model_name, 18 | **kwargs, 19 | ): 20 | super(ClipModel, self).__init__(model, model_name, **kwargs) 21 | 22 | def compute_zeroshot_weights(self): 23 | zeroshot_weights = [] 24 | for class_name in self.classes: 25 | texts = [template.format(class_name) for template in self.templates] 26 | 27 | class_embedding = self.get_text_embeddings(texts) 28 | 29 | class_embedding = class_embedding.mean(dim=0) 30 | class_embedding /= class_embedding.norm(dim=-1, keepdim=True) 31 | 32 | zeroshot_weights.append(class_embedding) 33 | self.zeroshot_weights = torch.stack(zeroshot_weights).T 34 | 35 | @torch.no_grad() 36 | def get_image_embeddings(self, images): 37 | image_features = self.model.encode_image(images.to(self.device)) 38 | image_features /= image_features.norm(dim=1, keepdim=True) 39 | return image_features.unsqueeze(1) 40 | 41 | @torch.no_grad() 42 | def get_text_embeddings(self, captions): 43 | if ( 44 | "truncate" in inspect.getfullargspec(self.tokenizer.__call__)[0] 45 | or "truncate" in inspect.getfullargspec(self.tokenizer)[0] 46 | ): 47 | caption_tokens = self.tokenizer( 48 | captions, context_length=self.context_length, truncate=True 49 | ).to(self.device) 50 | else: 51 | caption_tokens = self.tokenizer( 52 | captions, context_length=self.context_length 53 | ).to(self.device) 54 | 55 | caption_embeddings = self.model.encode_text(caption_tokens) 56 | caption_embeddings /= caption_embeddings.norm(dim=-1, keepdim=True) 57 | 58 | return caption_embeddings 59 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/huggingface.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | import torch 8 | 9 | from .blip import BlipModel 10 | from .clip import ClipModel 11 | 12 | 13 | class FlavaModel(ClipModel): 14 | def __init__(self, *args, **kwargs): 15 | super(FlavaModel, self).__init__(*args, **kwargs) 16 | 17 | @torch.no_grad() 18 | def get_image_embeddings(self, images): 19 | image_features = self.model.flava.get_image_features( 20 | images.to(self.device), 21 | )[:, 0, :] 22 | image_features /= image_features.norm(dim=-1, keepdim=True) 23 | return image_features.unsqueeze(1) 24 | 25 | @torch.no_grad() 26 | def get_text_embeddings(self, texts): 27 | captions = self.tokenizer( 28 | text=texts, 29 | return_tensors="pt", 30 | padding="max_length", 31 | max_length=self.context_length, 32 | truncation=True 33 | ) 34 | 35 | text_features = self.model.flava.get_text_features( 36 | **{k: v.cuda() for k, v in captions.items()} 37 | )[:, 0, :] 38 | text_features /= text_features.norm(dim=-1, keepdim=True) 39 | 40 | return text_features 41 | 42 | 43 | class XVLMModel(BlipModel): 44 | def __init__(self, *args, **kwargs): 45 | super(XVLMModel, self).__init__(*args, **kwargs) 46 | 47 | @torch.no_grad() 48 | def get_image_embeddings(self, images): 49 | image_features_out = self.model.vision_encoder( 50 | images.to(self.device), 51 | output_attentions=None, 52 | output_hidden_states=None, 53 | return_dict=False, 54 | ) 55 | image_features = self.model.vision_proj(image_features_out[:, 0, :]) 56 | image_features = image_features.float() 57 | image_features /= image_features.norm(dim=-1, keepdim=True) 58 | return image_features.unsqueeze(1) 59 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/lit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | import os 8 | import jax 9 | import jax.numpy as jnp 10 | from matplotlib import pyplot as plt 11 | import numpy as np 12 | import pandas as pd 13 | from common_utils.constants import CACHE_DIR 14 | from .base import AbstractModel 15 | import tensorflow as tf 16 | import tqdm 17 | 18 | 19 | class LiTModel(AbstractModel): 20 | def __init__(self, model, model_name, *args, **kwargs): 21 | super(LiTModel, self).__init__(model, model_name, *args, **kwargs) 22 | self.tokenizer = model.get_tokenizer() 23 | 24 | def _get_zeroshot_weights(self, class_names, templates): 25 | zeroshot_weights = [] 26 | for class_name in tqdm(class_names): 27 | texts = [ 28 | template.format(class_name) for template in self.templates 29 | ] # format with class 30 | texts = self.tokenize(texts) # tokenize 31 | _, class_embeddings, _ = self.model.apply( 32 | self.lit_variables, tokens=texts 33 | ) # embed with text encoder 34 | class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True) 35 | class_embedding = class_embeddings.mean(dim=0) 36 | class_embedding /= class_embedding.norm() 37 | zeroshot_weights.append(class_embedding) 38 | zeroshot_weights = torch.stack(zeroshot_weights, dim=1).cuda() 39 | 40 | return zeroshot_weights 41 | 42 | def get_zeroshot_predictions(self, images, zeroshot_weights): 43 | if self.zeroshot_weights is None: 44 | self.zeroshot_weights = self._get_zeroshot_weights( 45 | imagenet_classes, imagenet_templates 46 | ) 47 | return self.get_image_embeddings(images) @ zeroshot_weights 48 | 49 | def get_image_embeddings(self, images): 50 | image_features = self.model.module.encode_image(images) 51 | image_features /= image_features.norm(dim=-1, keepdim=True) 52 | return image_features 53 | 54 | def get_text_embeddings(self, text): 55 | caption_options = [] 56 | for c_option in text: 57 | caption_tokenized = torch.cat([clip.tokenize(c) for c in c_option]) 58 | caption_embeddings = self.model.module.encode_text( 59 | caption_tokenized.to("cuda") 60 | ) # B x D 61 | caption_embeddings /= caption_embeddings.norm(dim=-1, keepdim=True) # B x D 62 | caption_options.append(caption_embeddings.unsqueeze(1)) # B x 1 x D 63 | 64 | return torch.stack(caption_options, axis=1) 65 | 66 | def forward_batch(self, images, output="zero_shot", text=None): 67 | assert output in self.OUTPUT_TYPES 68 | 69 | if output == "zero_shot": 70 | return self.get_zeroshot_predictions(images, self.zeroshot_weights) 71 | elif output == "relations": 72 | image_options = self.get_image_embeddings(images) # B x L x D 73 | caption_options = self.get_text_embeddings(text) # B x K x D 74 | return np.einsum( 75 | "nkd,nld->nkl", image_options, caption_options 76 | ) # B x K x L 77 | raise NotImplementedError(f"Not implemented for {self.model_name}") 78 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/transformations/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/transformations/faceblur.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 
4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | from facenet_pytorch import MTCNN 8 | from PIL import Image 9 | import numpy as np 10 | import torch 11 | import torchvision 12 | 13 | 14 | class FaceBlur(torch.nn.Module): 15 | def __init__(self, input_resolution=224, kernel_size=21, sigma=10.0): 16 | super(FaceBlur, self).__init__() 17 | self.mtcnn = MTCNN(keep_all=True, image_size=input_resolution) 18 | self.kernel_size = kernel_size 19 | self.sigma = sigma 20 | 21 | def forward(self, image): 22 | boxes, _ = self.mtcnn.detect(image) 23 | image_c = np.array(image.copy()) 24 | if boxes is not None: 25 | for x, y, w, h in boxes: # MTCNN boxes are (x1, y1, x2, y2), so w and h are the bottom-right corner 26 | x, y, w, h = int(x), int(y), int(w), int(h) 27 | 28 | roi = image_c[y:h, x:w] 29 | 30 | if any(dim < 10 for dim in roi.shape[:2]): # skip regions too small to blur reliably 31 | continue 32 | 33 | roi = torchvision.transforms.functional.gaussian_blur( 34 | Image.fromarray(roi), kernel_size=self.kernel_size, sigma=self.sigma 35 | ) 36 | image_c[y:h, x:w] = roi 37 | return Image.fromarray(image_c) -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/transformations/grayscale2rgb.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | import torch 8 | 9 | class GrayScale2RGB(torch.nn.Module): 10 | def __init__(self): 11 | super(GrayScale2RGB, self).__init__() 12 | 13 | def forward(self, image): 14 | if image.shape[0] == 1: 15 | image = image.repeat(3, 1, 1) 16 | return image -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/box_ops.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 8 | """ 9 | Utilities for bounding box manipulation and GIoU.
10 | """ 11 | import torch 12 | from torchvision.ops.boxes import box_area 13 | 14 | 15 | def box_cxcywh_to_xyxy(x): # 这个用了 16 | x_c, y_c, w, h = x.unbind(-1) 17 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), 18 | (x_c + 0.5 * w), (y_c + 0.5 * h)] 19 | return torch.stack(b, dim=-1) 20 | 21 | 22 | def box_xyxy_to_cxcywh(x): 23 | x0, y0, x1, y1 = x.unbind(-1) 24 | b = [(x0 + x1) / 2, (y0 + y1) / 2, 25 | (x1 - x0), (y1 - y0)] 26 | return torch.stack(b, dim=-1) 27 | 28 | 29 | # modified from torchvision to also return the union 30 | def box_iou(boxes1, boxes2): 31 | area1 = box_area(boxes1) 32 | area2 = box_area(boxes2) 33 | 34 | lt = torch.max(boxes1[:, None, :2], boxes2[:, :2]) # [N,M,2] 35 | rb = torch.min(boxes1[:, None, 2:], boxes2[:, 2:]) # [N,M,2] 36 | 37 | wh = (rb - lt).clamp(min=0) # [N,M,2] 38 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M] 39 | 40 | union = area1[:, None] + area2 - inter 41 | 42 | iou = inter / union 43 | return iou, union 44 | 45 | 46 | def generalized_box_iou(boxes1, boxes2): 47 | """ 48 | Generalized IoU from https://giou.stanford.edu/ 49 | 50 | The boxes should be in [x0, y0, x1, y1] format 51 | 52 | Returns a [N, M] pairwise matrix, where N = len(boxes1) 53 | and M = len(boxes2) 54 | """ 55 | iou, union = box_iou(boxes1, boxes2) 56 | 57 | lt = torch.min(boxes1[:, None, :2], boxes2[:, :2]) 58 | rb = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) 59 | 60 | wh = (rb - lt).clamp(min=0) # [N,M,2] 61 | area = wh[:, :, 0] * wh[:, :, 1] 62 | 63 | return iou - (area - union) / area 64 | 65 | 66 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/clip_vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | # Copyright 2021 The OpenAI Team Authors and The HuggingFace Team. All rights reserved. 8 | # 9 | # Licensed under the Apache License, Version 2.0 (the "License"); 10 | # you may not use this file except in compliance with the License. 11 | # You may obtain a copy of the License at 12 | # 13 | # http://www.apache.org/licenses/LICENSE-2.0 14 | # 15 | # Unless required by applicable law or agreed to in writing, software 16 | # distributed under the License is distributed on an "AS IS" BASIS, 17 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 18 | # See the License for the specific language governing permissions and 19 | # limitations under the License. 20 | 21 | from dataclasses import dataclass 22 | from typing import Any, Optional, Tuple 23 | 24 | import torch 25 | import torch.utils.checkpoint 26 | from torch import nn 27 | 28 | from transformers.activations import ACT2FN 29 | from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling 30 | from transformers.utils import logging 31 | 32 | 33 | logger = logging.get_logger(__name__) 34 | 35 | 36 | # Copied from transformers.models.bart.modeling_bart._expand_mask 37 | def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): 38 | """ 39 | Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
40 | """ 41 | bsz, src_len = mask.size() 42 | tgt_len = tgt_len if tgt_len is not None else src_len 43 | 44 | expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) 45 | 46 | inverted_mask = 1.0 - expanded_mask 47 | 48 | return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) 49 | 50 | 51 | # contrastive loss function, adapted from 52 | # https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html 53 | def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: 54 | return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) 55 | 56 | 57 | def clip_loss(similarity: torch.Tensor) -> torch.Tensor: 58 | caption_loss = contrastive_loss(similarity) 59 | image_loss = contrastive_loss(similarity.T) 60 | return (caption_loss + image_loss) / 2.0 61 | 62 | 63 | class CLIPVisionEmbeddings(nn.Module): 64 | def __init__(self, image_size, patch_size, hidden_size): 65 | super().__init__() 66 | self.embed_dim = hidden_size 67 | self.image_size = image_size 68 | self.patch_size = patch_size 69 | 70 | 71 | 72 | def forward(self, pixel_values): 73 | batch_size = pixel_values.shape[0] 74 | patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] 75 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 76 | 77 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 78 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 79 | embeddings = embeddings + self.position_embedding(self.position_ids) 80 | return embeddings 81 | 82 | 83 | class CLIPAttention(nn.Module): 84 | """Multi-headed attention from 'Attention Is All You Need' paper""" 85 | 86 | def __init__(self, hidden_size, num_attention_heads, attention_dropout): 87 | super().__init__() 88 | self.embed_dim = hidden_size 89 | self.num_heads = num_attention_heads 90 | self.head_dim = self.embed_dim // self.num_heads 91 | assert ( 92 | self.head_dim * self.num_heads == self.embed_dim 93 | ), f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`: {self.num_heads})." 
94 | self.scale = self.head_dim ** -0.5 95 | self.dropout = attention_dropout 96 | 97 | self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) 98 | self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) 99 | self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) 100 | self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) 101 | 102 | def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): 103 | return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() 104 | 105 | def forward( 106 | self, 107 | hidden_states: torch.Tensor, 108 | attention_mask: Optional[torch.Tensor] = None, 109 | causal_attention_mask: Optional[torch.Tensor] = None, 110 | output_attentions: bool = False, 111 | ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: 112 | """Input shape: Batch x Time x Channel""" 113 | 114 | bsz, tgt_len, embed_dim = hidden_states.size() 115 | 116 | # get query proj 117 | query_states = self.q_proj(hidden_states) * self.scale 118 | key_states = self._shape(self.k_proj(hidden_states), -1, bsz) 119 | value_states = self._shape(self.v_proj(hidden_states), -1, bsz) 120 | 121 | proj_shape = (bsz * self.num_heads, -1, self.head_dim) 122 | query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) 123 | key_states = key_states.view(*proj_shape) 124 | value_states = value_states.view(*proj_shape) 125 | 126 | src_len = key_states.size(1) 127 | attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) 128 | 129 | if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): 130 | raise ValueError( 131 | f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is {attn_weights.size()}" 132 | ) 133 | 134 | # apply the causal_attention_mask first 135 | if causal_attention_mask is not None: 136 | if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): 137 | raise ValueError( 138 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {causal_attention_mask.size()}" 139 | ) 140 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask 141 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 142 | 143 | if attention_mask is not None: 144 | if attention_mask.size() != (bsz, 1, tgt_len, src_len): 145 | raise ValueError( 146 | f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" 147 | ) 148 | attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask 149 | attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) 150 | 151 | attn_weights = nn.functional.softmax(attn_weights, dim=-1) 152 | 153 | if output_attentions: 154 | # this operation is a bit akward, but it's required to 155 | # make sure that attn_weights keeps its gradient. 
156 | # In order to do so, attn_weights have to reshaped 157 | # twice and have to be reused in the following 158 | attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) 159 | attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) 160 | else: 161 | attn_weights_reshaped = None 162 | 163 | attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) 164 | 165 | attn_output = torch.bmm(attn_probs, value_states) 166 | 167 | if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): 168 | raise ValueError( 169 | f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is {attn_output.size()}" 170 | ) 171 | 172 | attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) 173 | attn_output = attn_output.transpose(1, 2) 174 | attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) 175 | 176 | attn_output = self.out_proj(attn_output) 177 | 178 | return attn_output, attn_weights_reshaped 179 | 180 | 181 | class CLIPMLP(nn.Module): 182 | def __init__(self, hidden_act, hidden_size, intermediate_size): 183 | super().__init__() 184 | self.activation_fn = ACT2FN[hidden_act] 185 | self.fc1 = nn.Linear(hidden_size, intermediate_size) 186 | self.fc2 = nn.Linear(intermediate_size, hidden_size) 187 | 188 | def forward(self, hidden_states): 189 | hidden_states = self.fc1(hidden_states) 190 | hidden_states = self.activation_fn(hidden_states) 191 | hidden_states = self.fc2(hidden_states) 192 | return hidden_states 193 | 194 | 195 | class CLIPEncoderLayer(nn.Module): 196 | def __init__(self, hidden_size, hidden_act, num_attention_heads, attention_dropout, intermediate_size): 197 | super().__init__() 198 | self.self_attn = CLIPAttention(hidden_size, num_attention_heads, attention_dropout) 199 | self.layer_norm1 = nn.LayerNorm(hidden_size) 200 | self.mlp = CLIPMLP(hidden_act, hidden_size, intermediate_size) 201 | self.layer_norm2 = nn.LayerNorm(hidden_size) 202 | 203 | def forward( 204 | self, 205 | hidden_states: torch.Tensor, 206 | attention_mask: None, 207 | ): 208 | """ 209 | Args: 210 | hidden_states (:obj:`torch.FloatTensor`): input to the layer of shape :obj:`(batch, seq_len, embed_dim)` 211 | attention_mask (:obj:`torch.FloatTensor`): attention mask of size 212 | :obj:`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. 213 | :obj:`(config.encoder_attention_heads,)`. 214 | output_attentions (:obj:`bool`, `optional`): 215 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 216 | returned tensors for more detail. 217 | """ 218 | residual = hidden_states 219 | 220 | hidden_states = self.layer_norm1(hidden_states) 221 | hidden_states, attn_weights = self.self_attn( 222 | hidden_states=hidden_states, 223 | attention_mask=attention_mask, 224 | causal_attention_mask=None, 225 | output_attentions=False, 226 | ) 227 | hidden_states = residual + hidden_states 228 | 229 | residual = hidden_states 230 | hidden_states = self.layer_norm2(hidden_states) 231 | hidden_states = self.mlp(hidden_states) 232 | hidden_states = residual + hidden_states 233 | 234 | return hidden_states 235 | 236 | 237 | class CLIPEncoder(nn.Module): 238 | """ 239 | Transformer encoder consisting of :obj:`config.num_hidden_layers` self attention layers. Each layer is a 240 | :class:`~transformers.CLIPEncoderLayer`. 
241 | 242 | Args: 243 | config: CLIPConfig 244 | """ 245 | 246 | def __init__(self, hidden_size, hidden_act, num_attention_heads, attention_dropout, intermediate_size, num_hidden_layers, local_attn_depth): 247 | super().__init__() 248 | self.depth = num_hidden_layers 249 | self.local_attn_depth = local_attn_depth 250 | self.layers = nn.ModuleList([CLIPEncoderLayer(hidden_size, hidden_act, num_attention_heads, attention_dropout, intermediate_size) for _ in range(num_hidden_layers)]) 251 | 252 | def forward( 253 | self, 254 | inputs_embeds, 255 | idx_to_group_img=None, 256 | image_atts=None 257 | ): 258 | r""" 259 | Args: 260 | inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`): 261 | Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded 262 | representation. This is useful if you want more control over how to convert :obj:`input_ids` indices 263 | into associated vectors than the model's internal embedding lookup matrix. 264 | attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 265 | Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``: 266 | 267 | - 1 for tokens that are **not masked**, 268 | - 0 for tokens that are **masked**. 269 | 270 | `What are attention masks? <../glossary.html#attention-mask>`__ 271 | causal_attention_mask (:obj:`torch.Tensor` of shape :obj:`(batch_size, sequence_length)`, `optional`): 272 | Causal mask for the text model. Mask values selected in ``[0, 1]``: 273 | 274 | - 1 for tokens that are **not masked**, 275 | - 0 for tokens that are **masked**. 276 | 277 | `What are attention masks? <../glossary.html#attention-mask>`__ 278 | output_attentions (:obj:`bool`, `optional`): 279 | Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under 280 | returned tensors for more detail. 281 | output_hidden_states (:obj:`bool`, `optional`): 282 | Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors 283 | for more detail. 284 | return_dict (:obj:`bool`, `optional`): 285 | Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple. 
286 | """ 287 | 288 | do_gather = True if idx_to_group_img is not None else False 289 | 290 | if do_gather and (image_atts is not None): 291 | full_atts = torch.ones(inputs_embeds.shape[:2], dtype=inputs_embeds.dtype).to(inputs_embeds.device) 292 | image_atts_blk = torch.cat([image_atts, full_atts], dim=0) 293 | 294 | image_atts_blk = image_atts_blk.unsqueeze(1).unsqueeze(2) 295 | image_atts_blk = (1.0 - image_atts_blk) * -10000.0 296 | # (bs, 1, 1, num_patches) 297 | image_atts_blk = image_atts_blk.expand(-1, -1, image_atts_blk.size(-1), -1) 298 | else: 299 | image_atts_blk = None 300 | 301 | hidden_states = inputs_embeds 302 | for idx, encoder_layer in enumerate(self.layers): 303 | if (self.local_attn_depth > 0) and (idx >= self.depth - self.local_attn_depth): 304 | if do_gather: 305 | do_gather = False 306 | hidden_states_bs = torch.gather(hidden_states, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, hidden_states.shape[1], hidden_states.shape[2])) 307 | hidden_states = torch.cat([hidden_states_bs, hidden_states], dim=0) 308 | 309 | hidden_states = encoder_layer(hidden_states, attention_mask=image_atts_blk) 310 | else: 311 | hidden_states = encoder_layer(hidden_states, attention_mask=None) 312 | 313 | return hidden_states 314 | 315 | 316 | class CLIPVisionTransformer(nn.Module): 317 | def __init__(self, image_size, patch_size, hidden_size, hidden_act, num_attention_heads, attention_dropout, intermediate_size, num_hidden_layers, local_attn_depth=0): 318 | super().__init__() 319 | 320 | self.image_size = image_size 321 | self.patch_size = patch_size 322 | 323 | self.num_patch_embed = (self.image_size // self.patch_size) ** 2 324 | self.patch_embed = nn.Conv2d( 325 | in_channels=3, out_channels=hidden_size, kernel_size=self.patch_size, stride=self.patch_size, bias=False 326 | ) 327 | self.class_embedding = nn.Parameter(torch.randn(hidden_size)) 328 | self.num_pos_embed = self.num_patch_embed + 1 329 | self.pos_embed = nn.Embedding(self.num_pos_embed, hidden_size) 330 | self.register_buffer("position_ids", torch.arange(self.num_pos_embed).expand((1, -1))) 331 | 332 | self.pre_layrnorm = nn.LayerNorm(hidden_size) 333 | self.encoder = CLIPEncoder(hidden_size, hidden_act, num_attention_heads, attention_dropout, intermediate_size, 334 | num_hidden_layers, local_attn_depth=local_attn_depth) 335 | self.post_layernorm = nn.LayerNorm(hidden_size) 336 | 337 | def forward( 338 | self, 339 | x, 340 | idx_to_group_img=None, 341 | image_atts=None 342 | ): 343 | 344 | batch_size = x.shape[0] 345 | patch_embeds = self.patch_embed(x) # shape = [*, width, grid, grid] 346 | patch_embeds = patch_embeds.flatten(2).transpose(1, 2) 347 | 348 | class_embeds = self.class_embedding.expand(batch_size, 1, -1) 349 | embeddings = torch.cat([class_embeds, patch_embeds], dim=1) 350 | hidden_states = embeddings + self.pos_embed(self.position_ids) 351 | 352 | hidden_states = self.pre_layrnorm(hidden_states) 353 | 354 | outputs = self.encoder( 355 | inputs_embeds=hidden_states, 356 | idx_to_group_img=idx_to_group_img, 357 | image_atts=image_atts) 358 | 359 | outputs = self.post_layernorm(outputs) 360 | 361 | if idx_to_group_img is not None: 362 | bs = len(idx_to_group_img) 363 | outputs, outputs_fullatts = torch.split(outputs, [bs, outputs.size(0)-bs]) 364 | return outputs, outputs_fullatts 365 | 366 | return outputs 367 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/tokenization_roberta.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | # coding=utf-8 8 | # Copyright 2018 The Open AI Team Authors and The HuggingFace Inc. team. 9 | # 10 | # Licensed under the Apache License, Version 2.0 (the "License"); 11 | # you may not use this file except in compliance with the License. 12 | # You may obtain a copy of the License at 13 | # 14 | # http://www.apache.org/licenses/LICENSE-2.0 15 | # 16 | # Unless required by applicable law or agreed to in writing, software 17 | # distributed under the License is distributed on an "AS IS" BASIS, 18 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 19 | # See the License for the specific language governing permissions and 20 | # limitations under the License. 21 | """Tokenization classes for RoBERTa.""" 22 | 23 | from typing import List, Optional 24 | 25 | from transformers.tokenization_utils import AddedToken 26 | from transformers.utils import logging 27 | from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer 28 | 29 | 30 | logger = logging.get_logger(__name__) 31 | 32 | VOCAB_FILES_NAMES = { 33 | "vocab_file": "vocab.json", 34 | "merges_file": "merges.txt", 35 | } 36 | 37 | PRETRAINED_VOCAB_FILES_MAP = { 38 | "vocab_file": { 39 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/vocab.json", 40 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/vocab.json", 41 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/vocab.json", 42 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/vocab.json", 43 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/vocab.json", 44 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/vocab.json", 45 | }, 46 | "merges_file": { 47 | "roberta-base": "https://huggingface.co/roberta-base/resolve/main/merges.txt", 48 | "roberta-large": "https://huggingface.co/roberta-large/resolve/main/merges.txt", 49 | "roberta-large-mnli": "https://huggingface.co/roberta-large-mnli/resolve/main/merges.txt", 50 | "distilroberta-base": "https://huggingface.co/distilroberta-base/resolve/main/merges.txt", 51 | "roberta-base-openai-detector": "https://huggingface.co/roberta-base-openai-detector/resolve/main/merges.txt", 52 | "roberta-large-openai-detector": "https://huggingface.co/roberta-large-openai-detector/resolve/main/merges.txt", 53 | }, 54 | } 55 | 56 | PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = { 57 | "roberta-base": 512, 58 | "roberta-large": 512, 59 | "roberta-large-mnli": 512, 60 | "distilroberta-base": 512, 61 | "roberta-base-openai-detector": 512, 62 | "roberta-large-openai-detector": 512, 63 | } 64 | 65 | 66 | class RobertaTokenizer(GPT2Tokenizer): 67 | """ 68 | Constructs a RoBERTa tokenizer, derived from the GPT-2 tokenizer, using byte-level Byte-Pair-Encoding. 
69 | 70 | This tokenizer has been trained to treat spaces like parts of the tokens (a bit like sentencepiece) so a word will 71 | be encoded differently whether it is at the beginning of the sentence (without space) or not: 72 | 73 | :: 74 | 75 | >>> from transformers import RobertaTokenizer 76 | >>> tokenizer = RobertaTokenizer.from_pretrained("roberta-base") 77 | >>> tokenizer("Hello world")['input_ids'] 78 | [0, 31414, 232, 328, 2] 79 | >>> tokenizer(" Hello world")['input_ids'] 80 | [0, 20920, 232, 2] 81 | 82 | You can get around that behavior by passing ``add_prefix_space=True`` when instantiating this tokenizer or when you 83 | call it on some text, but since the model was not pretrained this way, it might yield a decrease in performance. 84 | 85 | .. note:: 86 | 87 | When used with ``is_split_into_words=True``, this tokenizer will add a space before each word (even the first 88 | one). 89 | 90 | This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the main methods. 91 | Users should refer to this superclass for more information regarding those methods. 92 | 93 | Args: 94 | vocab_file (:obj:`str`): 95 | Path to the vocabulary file. 96 | merges_file (:obj:`str`): 97 | Path to the merges file. 98 | errors (:obj:`str`, `optional`, defaults to :obj:`"replace"`): 99 | Paradigm to follow when decoding bytes to UTF-8. See `bytes.decode 100 | `__ for more information. 101 | bos_token (:obj:`str`, `optional`, defaults to :obj:`""`): 102 | The beginning of sequence token that was used during pretraining. Can be used a sequence classifier token. 103 | 104 | .. note:: 105 | 106 | When building a sequence using special tokens, this is not the token that is used for the beginning of 107 | sequence. The token used is the :obj:`cls_token`. 108 | eos_token (:obj:`str`, `optional`, defaults to :obj:`""`): 109 | The end of sequence token. 110 | 111 | .. note:: 112 | 113 | When building a sequence using special tokens, this is not the token that is used for the end of 114 | sequence. The token used is the :obj:`sep_token`. 115 | sep_token (:obj:`str`, `optional`, defaults to :obj:`""`): 116 | The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences for 117 | sequence classification or for a text and a question for question answering. It is also used as the last 118 | token of a sequence built with special tokens. 119 | cls_token (:obj:`str`, `optional`, defaults to :obj:`""`): 120 | The classifier token which is used when doing sequence classification (classification of the whole sequence 121 | instead of per-token classification). It is the first token of the sequence when built with special tokens. 122 | unk_token (:obj:`str`, `optional`, defaults to :obj:`""`): 123 | The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this 124 | token instead. 125 | pad_token (:obj:`str`, `optional`, defaults to :obj:`""`): 126 | The token used for padding, for example when batching sequences of different lengths. 127 | mask_token (:obj:`str`, `optional`, defaults to :obj:`""`): 128 | The token used for masking values. This is the token used when training this model with masked language 129 | modeling. This is the token which the model will try to predict. 130 | add_prefix_space (:obj:`bool`, `optional`, defaults to :obj:`False`): 131 | Whether or not to add an initial space to the input. This allows to treat the leading word just as any 132 | other word. 
(the RoBERTa tokenizer detects beginnings of words by the preceding space). 133 | """ 134 | 135 | vocab_files_names = VOCAB_FILES_NAMES 136 | pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP 137 | max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES 138 | model_input_names = ["input_ids", "attention_mask"] 139 | 140 | def __init__( 141 | self, 142 | vocab_file, 143 | merges_file, 144 | errors="replace", 145 | bos_token="<s>", 146 | eos_token="</s>", 147 | sep_token="</s>", 148 | cls_token="<s>", 149 | unk_token="<unk>", 150 | pad_token="<pad>", 151 | mask_token="<mask>", 152 | add_prefix_space=False, 153 | **kwargs 154 | ): 155 | bos_token = AddedToken(bos_token, lstrip=False, rstrip=False) if isinstance(bos_token, str) else bos_token 156 | eos_token = AddedToken(eos_token, lstrip=False, rstrip=False) if isinstance(eos_token, str) else eos_token 157 | sep_token = AddedToken(sep_token, lstrip=False, rstrip=False) if isinstance(sep_token, str) else sep_token 158 | cls_token = AddedToken(cls_token, lstrip=False, rstrip=False) if isinstance(cls_token, str) else cls_token 159 | unk_token = AddedToken(unk_token, lstrip=False, rstrip=False) if isinstance(unk_token, str) else unk_token 160 | pad_token = AddedToken(pad_token, lstrip=False, rstrip=False) if isinstance(pad_token, str) else pad_token 161 | 162 | # Mask token behaves like a normal word, i.e. includes the space before it 163 | mask_token = AddedToken(mask_token, lstrip=True, rstrip=False) if isinstance(mask_token, str) else mask_token 164 | 165 | super().__init__( 166 | vocab_file=vocab_file, 167 | merges_file=merges_file, 168 | errors=errors, 169 | bos_token=bos_token, 170 | eos_token=eos_token, 171 | unk_token=unk_token, 172 | sep_token=sep_token, 173 | cls_token=cls_token, 174 | pad_token=pad_token, 175 | mask_token=mask_token, 176 | add_prefix_space=add_prefix_space, 177 | **kwargs, 178 | ) 179 | 180 | def build_inputs_with_special_tokens( 181 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 182 | ) -> List[int]: 183 | """ 184 | Build model inputs from a sequence or a pair of sequences for sequence classification tasks by concatenating and 185 | adding special tokens. A RoBERTa sequence has the following format: 186 | 187 | - single sequence: ``<s> X </s>`` 188 | - pair of sequences: ``<s> A </s></s> B </s>`` 189 | 190 | Args: 191 | token_ids_0 (:obj:`List[int]`): 192 | List of IDs to which the special tokens will be added. 193 | token_ids_1 (:obj:`List[int]`, `optional`): 194 | Optional second list of IDs for sequence pairs. 195 | 196 | Returns: 197 | :obj:`List[int]`: List of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens. 198 | """ 199 | if token_ids_1 is None: 200 | # return [self.cls_token_id] + token_ids_0 + [self.sep_token_id] 201 | return [self.cls_token_id] + token_ids_0 202 | 203 | cls = [self.cls_token_id] 204 | sep = [self.sep_token_id] 205 | return cls + token_ids_0 + sep + sep + token_ids_1 + sep 206 | 207 | def get_special_tokens_mask( 208 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False 209 | ) -> List[int]: 210 | """ 211 | Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding 212 | special tokens using the tokenizer ``prepare_for_model`` method. 213 | 214 | Args: 215 | token_ids_0 (:obj:`List[int]`): 216 | List of IDs. 217 | token_ids_1 (:obj:`List[int]`, `optional`): 218 | Optional second list of IDs for sequence pairs. 
219 | already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`): 220 | Whether or not the token list is already formatted with special tokens for the model. 221 | 222 | Returns: 223 | :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token. 224 | """ 225 | if already_has_special_tokens: 226 | return super().get_special_tokens_mask( 227 | token_ids_0=token_ids_0, token_ids_1=token_ids_1, already_has_special_tokens=True 228 | ) 229 | 230 | if token_ids_1 is None: 231 | return [1] + ([0] * len(token_ids_0)) + [1] 232 | return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1] 233 | 234 | def create_token_type_ids_from_sequences( 235 | self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None 236 | ) -> List[int]: 237 | """ 238 | Create a mask from the two sequences passed to be used in a sequence-pair classification task. RoBERTa does not 239 | make use of token type ids, therefore a list of zeros is returned. 240 | 241 | Args: 242 | token_ids_0 (:obj:`List[int]`): 243 | List of IDs. 244 | token_ids_1 (:obj:`List[int]`, `optional`): 245 | Optional second list of IDs for sequence pairs. 246 | 247 | Returns: 248 | :obj:`List[int]`: List of zeros. 249 | """ 250 | sep = [self.sep_token_id] 251 | cls = [self.cls_token_id] 252 | 253 | if token_ids_1 is None: 254 | return len(cls + token_ids_0 + sep) * [0] 255 | return len(cls + token_ids_0 + sep + sep + token_ids_1 + sep) * [0] 256 | 257 | def prepare_for_tokenization(self, text, is_split_into_words=False, **kwargs): 258 | add_prefix_space = kwargs.pop("add_prefix_space", self.add_prefix_space) 259 | if (is_split_into_words or add_prefix_space) and (len(text) > 0 and not text[0].isspace()): 260 | text = " " + text 261 | return (text, kwargs) 262 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | import os 8 | import subprocess 9 | import yaml 10 | 11 | from unibench.common_utils.constants import HUB_CACHE_DIR 12 | import gdown 13 | 14 | download_urls = { 15 | "xvlm-flickr": { 16 | "model_url": "1vhdtH3iFaoZuMqOGm-8YM-diPWVfRJzv", 17 | "vision_config_url": "https://github.com/zengyan-97/X-VLM/raw/e7b960256d194952321b5adad39770c03e6ce9c2/configs/config_swinB_384.json", 18 | "config_url": "13-GCckeAh7QUeFVGwye7qLJamwl_hXdf", 19 | "bert_config_url": "https://github.com/zengyan-97/X-VLM/raw/e7b960256d194952321b5adad39770c03e6ce9c2/configs/config_bert.json", 20 | }, 21 | "xvlm-coco": { 22 | "model_url": "1bv6_pZOsXW53EhlwU0ZgSk03uzFI61pN", 23 | "vision_config_url": "https://github.com/zengyan-97/X-VLM/raw/e7b960256d194952321b5adad39770c03e6ce9c2/configs/config_swinB_384.json", 24 | "config_url": "11pdOukGXZzmPubvjLhJ2Sr1BIBRTEM-P", 25 | "bert_config_url": "https://github.com/zengyan-97/X-VLM/raw/e7b960256d194952321b5adad39770c03e6ce9c2/configs/config_bert.json", 26 | }, 27 | } 28 | 29 | 30 | def get_config(version, root_dir=HUB_CACHE_DIR): 31 | config_path = os.path.join(root_dir, f"{version}-config") 32 | model_path = os.path.join(root_dir, f"{version}.pth") 33 | bert_config_path = os.path.join( 34 | root_dir, 35 | "configs", 36 | download_urls[version]["bert_config_url"].split("/")[-1], 37 | ) 38 | vision_config_path = os.path.join( 39 | root_dir, 40 | "configs", 41 | download_urls[version]["vision_config_url"].split("/")[-1], 42 | ) 43 | 44 | if not ( 45 | os.path.exists(config_path) 46 | and os.path.exists(model_path) 47 | and os.path.exists(bert_config_path) 48 | and os.path.exists(vision_config_path) 49 | ): 50 | print(f"Downloading XVLM model to {root_dir}...") 51 | model_url = download_urls[version]["model_url"] 52 | config_url = download_urls[version]["config_url"] 53 | bert_config_url = download_urls[version]["bert_config_url"] 54 | vision_config_url = download_urls[version]["vision_config_url"] 55 | os.makedirs(os.path.join(root_dir, "configs"), exist_ok=True) 56 | gdown.download(id=model_url, output=model_path, quiet=False) 57 | gdown.download(id=config_url, output=config_path, quiet=False) 58 | subprocess.call(["wget", "-c", bert_config_url, "-O", bert_config_path]) 59 | subprocess.call(["wget", "-c", vision_config_url, "-O", vision_config_path]) 60 | 61 | config = yaml.load(open(config_path, "r"), Loader=yaml.Loader) 62 | config["vision_config"] = vision_config_path 63 | config["text_config"] = bert_config_path 64 | 65 | return config, model_path 66 | -------------------------------------------------------------------------------- /unibench/models_zoo/wrappers/xvlm_util/vit.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 6 | """ 7 | 8 | 9 | # Multi-Grained Vision Language Pre-Training: Aligning Texts with Visual Concepts (https://arxiv.org/abs/2111.08276) 10 | # Github: https://github.com/zengyan-97/X-VLM 11 | # Copyright (c) 2022, ByteDance Inc. 12 | # All rights reserved. 
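# Module overview: ViT building blocks for X-VLM. Mlp, Attention (which can cache the
# attention map and its gradients when called with register_hook=True), Block, and
# VisionTransformer, whose last local_attn_depth blocks can attend over per-group copies
# of the sequence gathered with idx_to_group_img and masked via image_atts.
# interpolate_pos_embed() bicubically resizes a checkpoint's position embeddings when the
# number of patches differs from the checkpoint.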
13 | 14 | import sys 15 | 16 | import torch 17 | import torch.nn as nn 18 | import torch.nn.functional as F 19 | from functools import partial 20 | 21 | from timm.models.vision_transformer import _cfg, PatchEmbed 22 | from timm.models.registry import register_model 23 | from timm.models.layers import trunc_normal_, DropPath 24 | 25 | 26 | class Mlp(nn.Module): 27 | """ MLP as used in Vision Transformer, MLP-Mixer and related networks 28 | """ 29 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 30 | super().__init__() 31 | out_features = out_features or in_features 32 | hidden_features = hidden_features or in_features 33 | self.fc1 = nn.Linear(in_features, hidden_features) 34 | self.act = act_layer() 35 | self.fc2 = nn.Linear(hidden_features, out_features) 36 | self.drop = nn.Dropout(drop) 37 | 38 | def forward(self, x): 39 | x = self.fc1(x) 40 | x = self.act(x) 41 | x = self.drop(x) 42 | x = self.fc2(x) 43 | x = self.drop(x) 44 | return x 45 | 46 | 47 | class Attention(nn.Module): 48 | def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.): 49 | super().__init__() 50 | self.num_heads = num_heads 51 | head_dim = dim // num_heads 52 | # NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights 53 | self.scale = qk_scale or head_dim ** -0.5 54 | self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) 55 | self.attn_drop = nn.Dropout(attn_drop) 56 | self.proj = nn.Linear(dim, dim) 57 | self.proj_drop = nn.Dropout(proj_drop) 58 | self.attn_gradients = None 59 | self.attention_map = None 60 | 61 | def save_attn_gradients(self, attn_gradients): 62 | self.attn_gradients = attn_gradients 63 | 64 | def get_attn_gradients(self): 65 | return self.attn_gradients 66 | 67 | def save_attention_map(self, attention_map): 68 | self.attention_map = attention_map 69 | 70 | def get_attention_map(self): 71 | return self.attention_map 72 | 73 | def forward(self, x, register_hook=False, image_atts=None): 74 | B, N, C = x.shape 75 | qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) 76 | q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple) 77 | 78 | attn = (q @ k.transpose(-2, -1)) * self.scale 79 | 80 | if image_atts is not None: 81 | attn += image_atts 82 | 83 | attn = attn.softmax(dim=-1) 84 | attn = self.attn_drop(attn) 85 | 86 | if register_hook: 87 | self.save_attention_map(attn) 88 | attn.register_hook(self.save_attn_gradients) 89 | 90 | # attn: (bs, num_heads, num_patches, num_patches) 91 | # v: (bs, num_heads, num_patches, d) 92 | # attn @ v: (bs, num_heads, num_patches, d) 93 | x = (attn @ v).transpose(1, 2).reshape(B, N, C) 94 | x = self.proj(x) 95 | x = self.proj_drop(x) 96 | return x 97 | 98 | 99 | class Block(nn.Module): 100 | 101 | def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0., 102 | drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm): 103 | super().__init__() 104 | self.norm1 = norm_layer(dim) 105 | self.attn = Attention( 106 | dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop) 107 | # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here 108 | self.drop_path = DropPath(drop_path) if drop_path > 0. 
else nn.Identity() 109 | self.norm2 = norm_layer(dim) 110 | mlp_hidden_dim = int(dim * mlp_ratio) 111 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) 112 | 113 | def forward(self, x, register_hook=False, image_atts=None): 114 | x = x + self.drop_path(self.attn(self.norm1(x), register_hook=register_hook, image_atts=image_atts)) 115 | x = x + self.drop_path(self.mlp(self.norm2(x))) 116 | return x 117 | 118 | 119 | class VisionTransformer(nn.Module): 120 | """ Vision Transformer 121 | A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale` - 122 | https://arxiv.org/abs/2010.11929 123 | """ 124 | def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12, 125 | num_heads=12, mlp_ratio=4., qkv_bias=True, qk_scale=None, representation_size=None, 126 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0., norm_layer=None, local_attn_depth=0): 127 | """ 128 | Args: 129 | img_size (int, tuple): input image size 130 | patch_size (int, tuple): patch size 131 | in_chans (int): number of input channels 132 | num_classes (int): number of classes for classification head 133 | embed_dim (int): embedding dimension 134 | depth (int): depth of transformer 135 | num_heads (int): number of attention heads 136 | mlp_ratio (int): ratio of mlp hidden dim to embedding dim 137 | qkv_bias (bool): enable bias for qkv if True 138 | qk_scale (float): override default qk scale of head_dim ** -0.5 if set 139 | representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set 140 | drop_rate (float): dropout rate 141 | attn_drop_rate (float): attention dropout rate 142 | drop_path_rate (float): stochastic depth rate 143 | norm_layer: (nn.Module): normalization layer 144 | """ 145 | super().__init__() 146 | self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models 147 | norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6) 148 | 149 | self.patch_embed = PatchEmbed( 150 | img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) 151 | 152 | self.num_patch_embed = self.patch_embed.num_patches 153 | 154 | self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) 155 | 156 | self.num_pos_embed = self.num_patch_embed + 1 157 | self.pos_embed = nn.Parameter(torch.zeros(1, self.num_pos_embed, embed_dim)) 158 | 159 | self.pos_drop = nn.Dropout(p=drop_rate) 160 | 161 | dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule 162 | self.blocks = nn.ModuleList([ 163 | Block( 164 | dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale, 165 | drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer) 166 | for i in range(depth)]) 167 | 168 | self.depth = depth 169 | self.local_attn_depth = local_attn_depth # do local attn from index=(depth - local_attn_depth) 170 | 171 | self.norm = norm_layer(embed_dim) 172 | 173 | trunc_normal_(self.pos_embed, std=.02) 174 | trunc_normal_(self.cls_token, std=.02) 175 | self.apply(self._init_weights) 176 | 177 | def _init_weights(self, m): 178 | if isinstance(m, nn.Linear): 179 | trunc_normal_(m.weight, std=.02) 180 | if isinstance(m, nn.Linear) and m.bias is not None: 181 | nn.init.constant_(m.bias, 0) 182 | elif isinstance(m, nn.LayerNorm): 183 | nn.init.constant_(m.bias, 0) 184 | nn.init.constant_(m.weight, 1.0) 185 | 186 | @torch.jit.ignore 187 | def 
no_weight_decay(self): 188 | return {'pos_embed', 'cls_token'} 189 | 190 | def forward(self, x, register_blk=-1, idx_to_group_img=None, image_atts=None): 191 | 192 | B = x.shape[0] 193 | x = self.patch_embed(x) 194 | 195 | cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks 196 | x = torch.cat((cls_tokens, x), dim=1) 197 | 198 | x = x + self.pos_embed[:,:x.size(1),:] 199 | x = self.pos_drop(x) 200 | 201 | do_gather = True if idx_to_group_img is not None else False 202 | 203 | if do_gather and (image_atts is not None): 204 | full_atts = torch.ones(x.shape[:2], dtype=x.dtype).to(x.device) 205 | image_atts_blk = torch.cat([image_atts, full_atts], dim=0) 206 | 207 | image_atts_blk = image_atts_blk.unsqueeze(1).unsqueeze(2) 208 | image_atts_blk = (1.0 - image_atts_blk) * -10000.0 209 | else: 210 | image_atts_blk = None 211 | 212 | for i, blk in enumerate(self.blocks): 213 | if (self.local_attn_depth > 0) and (i >= self.depth-self.local_attn_depth): 214 | if do_gather: 215 | do_gather = False 216 | 217 | x_bs = torch.gather(x, dim=0, index=idx_to_group_img.view(-1, 1, 1).expand(-1, x.shape[1], x.shape[2])) 218 | x = torch.cat([x_bs, x], dim=0) 219 | 220 | x = blk(x, register_blk == i, image_atts=image_atts_blk) 221 | 222 | else: 223 | x = blk(x, register_blk==i, image_atts=None) 224 | 225 | x = self.norm(x) 226 | 227 | if idx_to_group_img is not None: 228 | bs = len(idx_to_group_img) 229 | x_bs, x_fullatts = torch.split(x, [bs, x.size(0)-bs]) 230 | return x_bs, x_fullatts 231 | 232 | return x 233 | 234 | 235 | def interpolate_pos_embed(pos_embed_checkpoint, num_patches, num_extra_tokens=1): 236 | # num_patches = visual_encoder.num_patch_embed 237 | # num_extra_tokens = visual_encoder.num_pos_embed - visual_encoder.num_patch_embed 238 | 239 | # interpolate position embedding 240 | embedding_size = pos_embed_checkpoint.shape[-1] 241 | # height (== width) for the checkpoint position embedding 242 | orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5) 243 | # height (== width) for the new position embedding 244 | new_size = int(num_patches ** 0.5) 245 | 246 | if orig_size != new_size: 247 | # class_token and dist_token are kept unchanged 248 | extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] 249 | # only the position tokens are interpolated 250 | pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] 251 | pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) 252 | pos_tokens = torch.nn.functional.interpolate( 253 | pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False) 254 | pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) 255 | new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) 256 | print('reshape position embedding from %d to %d' % (orig_size ** 2, new_size ** 2)) 257 | 258 | return new_pos_embed 259 | else: 260 | return pos_embed_checkpoint 261 | -------------------------------------------------------------------------------- /unibench/output.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (c) Meta Platforms, Inc. and affiliates. 3 | All rights reserved. 4 | This source code is licensed under the license found in the 5 | LICENSE file in the root directory of this source tree. 
6 | """ 7 | 8 | import os.path 9 | from pathlib import Path 10 | 11 | import pandas as pd 12 | 13 | from unibench.common_utils.utils import download_all_results, download_only_aggregate 14 | from oslo_concurrency import lockutils 15 | 16 | from .common_utils.constants import OUTPUT_DIR, LOCK_DIR 17 | 18 | 19 | class OutputHandler(object): 20 | def __init__( 21 | self, 22 | load_all_csv=False, 23 | round_values=4, 24 | output_dir=OUTPUT_DIR, 25 | download_all_precomputed=False, 26 | download_aggregate_precomputed=False, 27 | ): 28 | self.output_dir = Path(output_dir) 29 | if download_all_precomputed: 30 | download_all_results(self.output_dir) 31 | elif download_aggregate_precomputed: 32 | download_only_aggregate(self.output_dir) 33 | self.round_values = round_values 34 | self.reset_local_csv() 35 | lockutils.set_defaults(lock_path=LOCK_DIR) 36 | self.load_aggregate_results() 37 | if load_all_csv: 38 | self.load_all_csv() 39 | 40 | def reset_local_csv(self): 41 | self._local_csv = pd.DataFrame() 42 | 43 | def check_if_computed(self, model_name, benchmark_name, **kwargs): 44 | self.load_aggregate_results() 45 | res = self.query( 46 | df=self._aggregate, 47 | **{"model_name": model_name, "benchmark_name": benchmark_name} 48 | ) 49 | if len(res) >= 1: 50 | return True 51 | 52 | self.load_csv(model_name, benchmark_name) 53 | return len(self.query(**kwargs)) 54 | 55 | def load_all_csvs(self, model_names): 56 | self._model_csv = pd.DataFrame() 57 | dfs = [] 58 | for model in model_names: 59 | model_folder = self.output_dir.joinpath(model) 60 | for benchmark_file in os.listdir(model_folder): 61 | file = model_folder.joinpath(benchmark_file) 62 | if ".f" in file.suffix and file.exists(): 63 | try: 64 | dfs.append(pd.read_feather(file)) 65 | except: 66 | print("Error reading file: ", file) 67 | else: 68 | print("File not found: ", file) 69 | 70 | self._model_csv = pd.concat(dfs).reset_index(drop=True).round(self.round_values) 71 | 72 | def load_all_csv(self, model_name, benchmark_name): 73 | self._model_csv = pd.DataFrame() 74 | dfs = [] 75 | for model in model_name: 76 | model_folder = self.output_dir.joinpath(model) 77 | for benchmark in benchmark_name: 78 | file = model_folder.joinpath(benchmark + ".f") 79 | if file.exists(): 80 | try: 81 | dfs.append(pd.read_feather(file)) 82 | except: 83 | print("Error reading file: ", file) 84 | else: 85 | print("File not found: ", file) 86 | 87 | self._model_csv = pd.concat(dfs).reset_index(drop=True).round(self.round_values) 88 | 89 | def load_csv(self, model_name, benchmark_name): 90 | file_name = str( 91 | self.output_dir.joinpath(model_name).joinpath(benchmark_name + ".f") 92 | ) 93 | 94 | # Load the csv if it exists 95 | if os.path.exists(file_name): 96 | self._model_csv = pd.read_feather(file_name) 97 | else: 98 | self._model_csv = pd.DataFrame() 99 | 100 | def load_model_csvs(self, model_name, use_cols=None): 101 | model_folder = self.output_dir.joinpath(model_name) 102 | 103 | self._model_csv = pd.DataFrame() 104 | dfs = [] 105 | for file in os.listdir(model_folder): 106 | if file.endswith(".f"): 107 | dfs.append( 108 | pd.read_feather(model_folder.joinpath(file), columns=use_cols) 109 | ) 110 | 111 | self._model_csv = pd.concat(dfs).reset_index(drop=True).round(self.round_values) 112 | 113 | def get_csv(self): 114 | return pd.concat([self._local_csv, self._model_csv]) 115 | 116 | def add_values(self, **kwargs): 117 | import torch 118 | for k in kwargs.keys(): 119 | if isinstance(kwargs[k], torch.Tensor): 120 | kwargs[k] = 
kwargs[k].cpu().squeeze().tolist() 121 | self._local_csv = pd.concat([self._local_csv, pd.DataFrame(kwargs)]) 122 | 123 | def query(self, df=None, **kwargs): 124 | if df is None: 125 | df = self._model_csv 126 | if len(kwargs) == 0: 127 | return df 128 | 129 | mask = pd.Series([True] * len(df)) 130 | 131 | for k, v in kwargs.items(): 132 | if isinstance(v, list): 133 | mask &= df[k].isin(v) 134 | else: 135 | mask &= (df[k] == v) 136 | 137 | return df[mask] 138 | 139 | def delete_rows(self, model_name, benchmark_name, **kwargs): 140 | # file_name = str(OUTPUT_DIR.joinpath(model_name + ".f")) 141 | self.output_dir.joinpath(model_name).mkdir(parents=True, exist_ok=True) 142 | file_name = str( 143 | self.output_dir.joinpath(model_name).joinpath(benchmark_name + ".f") 144 | ) 145 | 146 | # Load the csv if it exists 147 | if os.path.exists(file_name): 148 | self._model_csv = pd.read_feather(file_name) 149 | else: 150 | pass 151 | 152 | self._model_csv.drop(self.query(**kwargs).index, inplace=True) 153 | self._model_csv = self._model_csv.reset_index(drop=True) 154 | 155 | Path(self.output_dir).mkdir(parents=True, exist_ok=True) 156 | self._model_csv.to_feather(file_name) 157 | 158 | def _get_benchmark_mappings(self, axis): 159 | from .benchmarks_zoo.registry import get_benchmark_info, list_benchmarks 160 | benchmark_mappings = {} 161 | for benchmark in list_benchmarks(): 162 | if axis is None: 163 | benchmark_mappings[benchmark] = get_benchmark_info(benchmark) 164 | else: 165 | benchmark_mappings[benchmark] = get_benchmark_info(benchmark)[axis] 166 | return benchmark_mappings 167 | 168 | def get_aggregate_results(self): 169 | if not hasattr(self, "_aggregate"): 170 | self.load_aggregate_results() 171 | return self._aggregate 172 | 173 | @lockutils.synchronized(name="aggregate", external=True, fair=True) 174 | def load_aggregate_results(self): 175 | file = self.output_dir.joinpath("aggregate.f") 176 | if file.exists(): 177 | self._aggregate = pd.read_feather(file) 178 | 179 | @lockutils.synchronized(name="aggregate", external=True, fair=True) 180 | def save_aggregate_results(self, model_name, benchmark_name): 181 | file_dir = self.output_dir.joinpath("aggregate.f") 182 | if file_dir.exists(): 183 | self._aggregate = pd.read_feather(file_dir) 184 | 185 | df = self.query( 186 | self._model_csv, 187 | **{"model_name": [model_name], "benchmark_name": [benchmark_name]} 188 | ) 189 | 190 | df = ( 191 | df.groupby(["model_name", "benchmark_name"])["correctness"] 192 | .mean() 193 | .reset_index() 194 | ) 195 | 196 | df = ( 197 | pd.concat([self._aggregate, df]) 198 | .drop_duplicates(subset=["model_name", "benchmark_name"], keep="last") 199 | .reset_index(drop=True) 200 | ) 201 | 202 | df.to_feather(file_dir) 203 | 204 | def print_dataframe(self, **kwargs): 205 | self.load_aggregate_results() 206 | df = self.query(df=self._aggregate, **kwargs) 207 | benchmark_mappings = self._get_benchmark_mappings("benchmark_type") 208 | df["benchmark_type"] = df["benchmark_name"].map(benchmark_mappings) 209 | df = ( 210 | df.groupby(["model_name", "benchmark_name", "benchmark_type"])[ 211 | "correctness" 212 | ] 213 | .mean() 214 | .reset_index() 215 | ) 216 | 217 | df = ( 218 | df.groupby(["model_name", "benchmark_type"])["correctness"] 219 | .mean() 220 | .reset_index() 221 | ) 222 | return df.pivot( 223 | index="model_name", columns="benchmark_type", values="correctness" 224 | ) 225 | 226 | def save_csv(self, model_name, benchmark_name): 227 | self.output_dir.joinpath(model_name).mkdir(parents=True, exist_ok=True) 
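        # Per-benchmark results are written as one feather file per (model, benchmark)
        # pair under output_dir/<model_name>/<benchmark_name>.f; the in-memory buffer
        # accumulated by add_values() is appended to that file and then cleared.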
228 | file_name = str( 229 | self.output_dir.joinpath(model_name).joinpath(benchmark_name + ".f") 230 | ) 231 | 232 | # Load the csv if it exists 233 | if os.path.exists(file_name): 234 | self._model_csv = pd.read_feather(file_name) 235 | else: 236 | self._model_csv = pd.DataFrame() 237 | 238 | # Add the local csv to the model csv 239 | self._model_csv = ( 240 | pd.concat( 241 | [self._model_csv, self._local_csv.reset_index(drop=True)], 242 | axis=0, 243 | ignore_index=True, 244 | ) 245 | .round(self.round_values) 246 | .reset_index(drop=True) 247 | ) 248 | 249 | # Save the model csv 250 | self._model_csv.to_feather(file_name) 251 | self.reset_local_csv() --------------------------------------------------------------------------------
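A minimal usage sketch for the OutputHandler API above. It assumes the model_name / benchmark_name / correctness column layout that save_aggregate_results() groups on, uses placeholder model and benchmark identifiers, and assumes precomputed aggregate results can be downloaded so that aggregate.f exists before the aggregate table is updated.

import torch

from unibench.output import OutputHandler

# Download only the precomputed aggregate table so aggregate.f is present.
handler = OutputHandler(download_aggregate_precomputed=True)

# Buffer per-sample results in memory; tensors are moved to CPU and converted to lists.
# "my_model" and "my_benchmark" are placeholder identifiers, not registered names.
handler.add_values(
    model_name=["my_model"] * 4,
    benchmark_name=["my_benchmark"] * 4,
    correctness=torch.tensor([1.0, 0.0, 1.0, 1.0]),
)

# Append the buffer to <output_dir>/my_model/my_benchmark.f and clear the buffer.
handler.save_csv("my_model", "my_benchmark")

# Fold the mean correctness for this (model, benchmark) pair into aggregate.f,
# then reload the aggregate table and query the newly added row.
handler.save_aggregate_results("my_model", "my_benchmark")
handler.load_aggregate_results()
print(handler.query(df=handler.get_aggregate_results(), model_name="my_model"))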