├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ ├── front-build.yml │ ├── front-publish.yml │ └── python-publish.yml ├── .gitignore ├── .gitpod.DockerFile ├── .gitpod.yml ├── .pylintrc ├── HISTORY.md ├── LICENSE ├── Makefile ├── README.md ├── clip_retrieval ├── __init__.py ├── cli.py ├── clip_back.py ├── clip_back_prepro │ ├── README.md │ ├── __init__.py │ ├── index_combiner.py │ └── parquet_to_arrow.py ├── clip_client.py ├── clip_end2end.py ├── clip_filter.py ├── clip_front.py ├── clip_index.py ├── clip_inference │ ├── __init__.py │ ├── distributor.py │ ├── logger.py │ ├── main.py │ ├── mapper.py │ ├── reader.py │ ├── runner.py │ ├── slurm_distributor.py │ ├── slurm_worker.py │ ├── worker.py │ └── writer.py ├── h14_nsfw_model.py └── ivf_metadata_ordering.py ├── doc_assets ├── clip-back-grafana.png ├── clip-front-pic.png └── grafana_dashboard.json ├── docs ├── distributed_clip_inference.md ├── laion5B_back.md └── laion5B_h14_back.md ├── front ├── .gitignore ├── .npmignore ├── .npmrc ├── README.md ├── config.json ├── package.json ├── server.js ├── src │ ├── assets │ │ ├── download.png │ │ ├── image-search.png │ │ └── search.png │ ├── clip-front.js │ ├── clip-service.js │ └── index.html └── webpack.config.js ├── mypy.ini ├── notebook ├── clip-client-query-api.ipynb ├── clip-retrieval-getting-started.ipynb ├── retrieval_example.ipynb └── simple_filter.ipynb ├── pytest.ini ├── requirements-test.txt ├── requirements.txt ├── setup.py └── tests ├── test_back.sh ├── test_clip_client.py ├── test_clip_inference ├── playground.ipynb ├── test_distributor.py ├── test_embeddings │ ├── 0.pkl │ ├── 1.pkl │ ├── 2.pkl │ └── 3.pkl ├── test_get_tasks.py ├── test_images │ ├── 123_456.jpg │ ├── 208_495.jpg │ ├── 321_421.jpg │ ├── 389_535.jpg │ ├── 416_264.jpg │ ├── 456_123.jpg │ └── 524_316.jpg ├── test_main.py ├── test_mapper.py ├── test_reader.py ├── test_runner.py ├── test_tars │ ├── image1.tar │ └── image2.tar ├── test_tensors │ ├── 0.pkl │ ├── 1.pkl │ ├── 2.pkl │ └── 3.pkl ├── test_worker.py └── test_writer.py ├── test_end2end.py ├── test_filter.sh ├── test_index.sh ├── test_inference.sh └── test_simple.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.8 20 | - name: Install 21 | run: | 22 | python3 -m venv .env 23 | source .env/bin/activate 24 | python -m pip install -U pip 25 | make install-dev 26 | - name: Lint 27 | run: | 28 | source .env/bin/activate 29 | make lint 30 | tests: 31 | runs-on: ubuntu-latest 32 | strategy: 33 | matrix: 34 | python-version: ["3.8", "3.10"] 35 | 36 | steps: 37 | - uses: actions/checkout@v2 38 | - name: Set up Python ${{ matrix.python-version }} 39 | uses: actions/setup-python@v2 40 | with: 41 | python-version: ${{ matrix.python-version }} 42 | - name: Install 43 | run: | 44 | python3 -m venv .env 45 | source .env/bin/activate 46 | make install 47 | make install-dev 48 | - name: Unit tests 49 | run: | 50 | source .env/bin/activate 51 | make test 52 | -------------------------------------------------------------------------------- /.github/workflows/front-build.yml: -------------------------------------------------------------------------------- 1 | name: front build 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | build: 11 | 12 | runs-on: ubuntu-latest 13 | 14 | strategy: 15 | matrix: 16 | node-version: [18.x] 17 | 18 | steps: 19 | - uses: actions/checkout@v2 20 | - name: Use Node.js ${{ matrix.node-version }} 21 | uses: actions/setup-node@v1 22 | with: 23 | node-version: ${{ matrix.node-version }} 24 | - run: cd front && npm install 25 | - run: cd front && npm run build 26 | - run: cd front && npm test 27 | DeployPages: 28 | runs-on: ubuntu-latest 29 | if: ${{ github.event_name == 'push' }} 30 | steps: 31 | - name: Checkout 🛎️ 32 | uses: actions/checkout@v2.3.1 # If you're using actions/checkout@v2 you must set persist-credentials to false in most cases for the deployment to work correctly. 33 | with: 34 | persist-credentials: false 35 | fetch-depth: 0 36 | - name: Build 37 | run: | 38 | cd front 39 | npm install 40 | npm run build 41 | cp -R build/ ../../ 42 | cd .. 43 | rm -Rf ./* 44 | git checkout gh-pages 45 | rm -Rf ./* 46 | rm -Rf .github .gitignore .gitpod .gitpod.DockerFile .npmignore .npmrc 47 | cp -R ../build/* ./ 48 | - name: Create commits 49 | run: | 50 | git config user.name 'rom1504bot' 51 | git config user.email 'rom1504bot@users.noreply.github.com' 52 | git add --all 53 | git commit --amend -m "Update gh-pages" 54 | - name: Deploy 🚀 55 | uses: ad-m/github-push-action@master 56 | with: 57 | github_token: ${{ secrets.GITHUB_TOKEN }} 58 | branch: gh-pages 59 | force: true -------------------------------------------------------------------------------- /.github/workflows/front-publish.yml: -------------------------------------------------------------------------------- 1 | name: npm-publish 2 | on: 3 | push: 4 | branches: 5 | - main # Change this to your default branch 6 | jobs: 7 | npm-publish: 8 | name: npm-publish 9 | runs-on: ubuntu-latest 10 | steps: 11 | - name: Checkout repository 12 | uses: actions/checkout@master 13 | - name: Set up Node.js 14 | uses: actions/setup-node@master 15 | with: 16 | node-version: 18.0.0 17 | - run: cd front && npm install 18 | - run: cd front && npm run build 19 | - run: cd front && npm test 20 | - id: publish 21 | uses: JS-DevTools/npm-publish@v1 22 | with: 23 | token: ${{ secrets.NPM_AUTH_TOKEN }} 24 | package: front/package.json 25 | - name: Create Release 26 | if: steps.publish.outputs.type != 'none' 27 | id: create_release 28 | uses: actions/create-release@v1 29 | env: 30 | GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} 31 | with: 32 | tag_name: ${{ steps.publish.outputs.version }} 33 | release_name: Release ${{ steps.publish.outputs.version }} 34 | body: ${{ steps.publish.outputs.version }} 35 | draft: false 36 | prerelease: false -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | jobs: 11 | deploy: 12 | runs-on: ubuntu-latest 13 | steps: 14 | - uses: actions/checkout@v2 15 | - uses: actions-ecosystem/action-regex-match@v2 16 | id: regex-match 17 | with: 18 | text: ${{ github.event.head_commit.message }} 19 | regex: '^Release ([^ ]+)' 20 | - name: Use Node.js 14.x 21 | uses: actions/setup-node@v1 22 | with: 23 | node-version: 14.x 24 | - run: cd front && npm install 25 | - run: cd front && npm run build 26 | - name: Set up Python 27 | uses: actions/setup-python@v2 28 | with: 29 | python-version: '3.8' 30 | - name: Install dependencies 31 | run: | 32 | python -m pip install --upgrade pip 33 | pip install setuptools wheel twine pex 34 | - name: Build and publish 35 | if: ${{ steps.regex-match.outputs.match != '' && github.event_name != 'pull_request' }} 36 | env: 37 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 38 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 39 | run: | 40 | python setup.py sdist bdist_wheel 41 | twine upload dist/* 42 | - name: Build pex 43 | run: | 44 | make build-pex 45 | - name: Release 46 | if: ${{ steps.regex-match.outputs.match != '' && github.event_name != 'pull_request' }} 47 | uses: softprops/action-gh-release@v1 48 | with: 49 | files: | 50 | clip_retrieval_torch.tgz 51 | clip_retrieval.tgz 52 | tag_name: ${{ steps.regex-match.outputs.group1 }} 53 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .env 2 | *.egg-info 3 | .vscode 4 | .env 5 | __pycache__ 6 | myimglist.txt 7 | .ipynb_checkpoints 8 | output_folder 9 | indice_folder 10 | image_folder 11 | cat 12 | embedding_folder 13 | index_folder 14 | indices_paths.json 15 | .coverage* 16 | test_folder 17 | build 18 | dist 19 | wandb 20 | .pexing 21 | *.tgz 22 | *.pex -------------------------------------------------------------------------------- /.gitpod.DockerFile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full:latest 2 | 3 | RUN apt-get update && apt-get install -y python3-opencv -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: 2 | file: .gitpod.DockerFile -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=CVS 13 | 14 | # Pickle collected data for later comparisons. 15 | persistent=yes 16 | 17 | # List of plugins (as comma separated values of python modules names) to load, 18 | # usually to register additional checkers. 19 | load-plugins= 20 | 21 | 22 | [MESSAGES CONTROL] 23 | 24 | # Enable the message, report, category or checker with the given id(s). You can 25 | # either give multiple identifier separated by comma (,) or put this option 26 | # multiple time. See also the "--disable" option for examples. 27 | enable=indexing-exception,old-raise-syntax 28 | 29 | # Disable the message, report, category or checker with the given id(s). You 30 | # can either give multiple identifiers separated by comma (,) or put this 31 | # option multiple times (only on the command line, not in the configuration 32 | # file where it should appear only once).You can also use "--disable=all" to 33 | # disable everything first and then reenable specific checks. For example, if 34 | # you want to run only the similarities checker, you can use "--disable=all 35 | # --enable=similarities". If you want to run only the classes checker, but have 36 | # no Warning level messages displayed, use"--disable=all --enable=classes 37 | # --disable=W" 38 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,no-else-return,wrong-import-order,unnecessary-pass,logging-fstring-interpolation,logging-format-interpolation,C0330 39 | 40 | 41 | [REPORTS] 42 | 43 | # Set the output format. Available formats are text, parseable, colorized, msvs 44 | # (visual studio) and html. You can also give a reporter class, eg 45 | # mypackage.mymodule.MyReporterClass. 46 | output-format=text 47 | 48 | # Tells whether to display a full report or only the messages 49 | reports=no 50 | 51 | # Python expression which should return a note less than 10 (10 is the highest 52 | # note). You have access to the variables errors warning, statement which 53 | # respectively contain the number of errors / warnings messages and the total 54 | # number of statements analyzed. This is used by the global evaluation report 55 | # (RP0004). 56 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 57 | 58 | # Template used to display messages. This is a python new-style format string 59 | # used to format the message information. See doc for all details 60 | #msg-template= 61 | 62 | 63 | [TYPECHECK] 64 | 65 | # Tells whether missing members accessed in mixin class should be ignored. A 66 | # mixin class is detected if its name ends with "mixin" (case insensitive). 67 | ignore-mixin-members=yes 68 | 69 | # List of classes names for which member attributes should not be checked 70 | # (useful for classes with attributes dynamically set). 71 | ignored-classes=SQLObject 72 | 73 | # List of members which are set dynamically and missed by pylint inference 74 | # system, and so shouldn't trigger E0201 when accessed. Python regular 75 | # expressions are accepted. 76 | generated-members=REQUEST,acl_users,aq_parent 77 | 78 | # List of decorators that create context managers from functions, such as 79 | # contextlib.contextmanager. 80 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 81 | 82 | 83 | [VARIABLES] 84 | 85 | # Tells whether we should check for unused import in __init__ files. 86 | init-import=no 87 | 88 | # A regular expression matching the beginning of the name of dummy variables 89 | # (i.e. not used). 90 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 91 | 92 | # List of additional names supposed to be defined in builtins. Remember that 93 | # you should avoid to define new builtins when possible. 94 | additional-builtins= 95 | 96 | 97 | [BASIC] 98 | 99 | # Regular expression which should only match correct module names 100 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 101 | 102 | # Regular expression which should only match correct module level names 103 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 104 | 105 | # Regular expression which should only match correct class names 106 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 107 | 108 | # Regular expression which should only match correct function names 109 | function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 110 | 111 | # Regular expression which should only match correct method names 112 | method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 113 | 114 | # Regular expression which should only match correct instance attribute names 115 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 116 | 117 | # Regular expression which should only match correct argument names 118 | argument-rgx=^[a-z][a-z0-9_]*$ 119 | 120 | # Regular expression which should only match correct variable names 121 | variable-rgx=^[a-z][a-z0-9_]*$ 122 | 123 | # Regular expression which should only match correct attribute names in class 124 | # bodies 125 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 126 | 127 | # Regular expression which should only match correct list comprehension / 128 | # generator expression variable names 129 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 130 | 131 | # Good variable names which should always be accepted, separated by a comma 132 | good-names=main,_ 133 | 134 | # Bad variable names which should always be refused, separated by a comma 135 | bad-names= 136 | 137 | # Regular expression which should only match function or class names that do 138 | # not require a docstring. 139 | no-docstring-rgx=(__.*__|main) 140 | 141 | # Minimum line length for functions/classes that require docstrings, shorter 142 | # ones are exempt. 143 | docstring-min-length=10 144 | 145 | 146 | [FORMAT] 147 | 148 | # Maximum number of characters on a single line. 149 | max-line-length=120 150 | 151 | # Regexp for a line that is allowed to be longer than the limit. 152 | ignore-long-lines=(?x) 153 | (^\s*(import|from)\s 154 | |\$Id:\s\/\/depot\/.+#\d+\s\$ 155 | |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') 156 | |^\s*\#\ LINT\.ThenChange 157 | |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ 158 | |pylint 159 | |""" 160 | |\# 161 | |lambda 162 | |(https?|ftp):) 163 | 164 | # Allow the body of an if to be on the same line as the test if there is no 165 | # else. 166 | single-line-if-stmt=y 167 | 168 | # Maximum number of lines in a module 169 | max-module-lines=99999 170 | 171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 172 | # tab). 173 | indent-string=' ' 174 | 175 | 176 | [SIMILARITIES] 177 | 178 | # Minimum lines number of a similarity. 179 | min-similarity-lines=4 180 | 181 | # Ignore comments when computing similarities. 182 | ignore-comments=yes 183 | 184 | # Ignore docstrings when computing similarities. 185 | ignore-docstrings=yes 186 | 187 | # Ignore imports when computing similarities. 188 | ignore-imports=no 189 | 190 | 191 | [MISCELLANEOUS] 192 | 193 | # List of note tags to take in consideration, separated by a comma. 194 | notes= 195 | 196 | 197 | [IMPORTS] 198 | 199 | # Deprecated modules which should not be used, separated by a comma 200 | deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets 201 | 202 | # Create a graph of every (i.e. internal and external) dependencies in the 203 | # given file (report RP0402 must not be disabled) 204 | import-graph= 205 | 206 | # Create a graph of external dependencies in the given file (report RP0402 must 207 | # not be disabled) 208 | ext-import-graph= 209 | 210 | # Create a graph of internal dependencies in the given file (report RP0402 must 211 | # not be disabled) 212 | int-import-graph= 213 | 214 | extension-pkg-whitelist=_jsonnet 215 | 216 | 217 | [CLASSES] 218 | 219 | # List of method names used to declare (i.e. assign) instance attributes. 220 | defining-attr-methods=__init__,__new__,setUp 221 | 222 | # List of valid names for the first argument in a class method. 223 | valid-classmethod-first-arg=cls,class_ 224 | 225 | # List of valid names for the first argument in a metaclass class method. 226 | valid-metaclass-classmethod-first-arg=mcs 227 | 228 | 229 | [DESIGN] 230 | 231 | # Maximum number of arguments for function / method 232 | max-args=5 233 | 234 | # Argument names that match this expression will be ignored. Default to name 235 | # with leading underscore 236 | ignored-argument-names=_.* 237 | 238 | # Maximum number of locals for function / method body 239 | max-locals=15 240 | 241 | # Maximum number of return / yield for function / method body 242 | max-returns=6 243 | 244 | # Maximum number of branch for function / method body 245 | max-branches=12 246 | 247 | # Maximum number of statements in function / method body 248 | max-statements=50 249 | 250 | # Maximum number of parents for a class (see R0901). 251 | max-parents=7 252 | 253 | # Maximum number of attributes for a class (see R0902). 254 | max-attributes=7 255 | 256 | # Minimum number of public methods for a class (see R0903). 257 | min-public-methods=2 258 | 259 | # Maximum number of public methods for a class (see R0904). 260 | max-public-methods=20 261 | 262 | 263 | 264 | [TOKENS] 265 | 266 | # Number of spaces of indent required when the last token on the preceding line 267 | # is an open (, [, or {. 268 | indent-after-paren=4 -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 2.44.0 2 | 3 | * Support get_tokenizer in clip back inf 4 | 5 | ## 2.43.0 6 | 7 | * Update more deps (fire, pyarrow, pandas, torch) 8 | 9 | ## 2.42.0 10 | 11 | * Update deps 12 | 13 | ## 2.41.0 14 | 15 | * Update scipy requirement from <1.9.2 to <1.11.5 16 | * catch and skip images that fail to load (thanks @heyalexchoi) 17 | * Handle images in multiple folder for files reader and handle uppercase extension (thanks @BIGBALLON) 18 | 19 | ## 2.40.0 20 | 21 | * Add support for the full open clip model name format : ViT-B-32/laion2b_s34b_b79k (thanks @mehdidc @barinov274) 22 | 23 | ## 2.39.0 24 | 25 | * Add DeepSparse backend for CLIP inference (thanks @mgoin) 26 | * fix parquet to arrow script failed when number of samples is small (thanks @luke-han) 27 | * Integration with hugging face ClipModel (thanks @Sofianel5) 28 | 29 | ## 2.38.0 30 | 31 | * Add webp to list of supported files in reader. 32 | * Remove version constraint of fsspec. 33 | 34 | ## 2.37.0 35 | 36 | * Update versions to fix pex and npm build 37 | * Improve errors for empty input folders. 38 | * Default context to fix bug with some requests returning 404 39 | 40 | ## 2.36.1 41 | 42 | * Fix truncate 43 | 44 | ## 2.36.0 45 | 46 | * Make jit=False the default in clip inference 47 | * update webdataset and fsspec 48 | * Add H14 NSFW detector 49 | * Support get tokenizer in clip back (thanks @nousr) 50 | * enable filtering by image with clip-retrieval filter 51 | 52 | ## 2.35.1 53 | 54 | * update key toggles in inf.main (thanks @nousr) 55 | 56 | ## 2.35.0 57 | 58 | * Slurm distributor (thanks @nousr) 59 | * Autocast for openclip 60 | * support openclip in clip back 61 | 62 | ## 2.34.2 63 | 64 | * Read image data from path in case "image_path" is present 65 | 66 | ## 2.34.1 67 | 68 | * Makes file image reader in clip inference fast 69 | 70 | ## 2.34.0 71 | 72 | * Make it possible to use an embedding as query of the back 73 | 74 | ## 2.33.0 75 | 76 | * add clip-client module for querying backend remotely (thanks @afiaka87 ) 77 | 78 | ## 2.32.0 79 | 80 | * use better mclip from https://github.com/FreddeFrallan/Multilingual-CLIP 81 | 82 | ## 2.31.1 83 | 84 | * add clearer way to disable aesthetic scoring in front 85 | 86 | ## 2.31.0 87 | 88 | * aesthetic option 89 | 90 | ## 2.30.0 91 | 92 | * Log error for unsupported input_format (thanks @dmvaldman) 93 | * Add open_clip support (thanks @cat-state) 94 | 95 | ## 2.29.1 96 | 97 | * fix mclip in clip back 98 | 99 | ## 2.29.0 100 | 101 | * add violence detector to clip back 102 | 103 | ## 2.28.0 104 | 105 | * add feature to pass options in config file 106 | 107 | ## 2.27.0 108 | 109 | * safety model for ViT-B/32 110 | 111 | ## 2.26.0 112 | 113 | * replace safety heuristic by safety model 114 | 115 | ## 2.25.4 116 | 117 | * enable back dedup of images 118 | 119 | ## 2.25.3 120 | 121 | * turn off image dedup by default temporarily 122 | 123 | ## 2.25.2 124 | 125 | * fix range search use 126 | 127 | ## 2.25.1 128 | 129 | * add back node build in publish 130 | 131 | ## 2.25.0 132 | 133 | * new arrow provider in clip back 134 | * index combiner script 135 | * parquet to arrow script 136 | * deduplication of results feature 137 | 138 | ## 2.24.10 139 | 140 | * one more fix for text only 141 | 142 | ## 2.24.9 143 | 144 | * fix image_tensor_count vs text_counter count in runner 145 | 146 | ## 2.24.8 147 | 148 | * fix file count check for input format files 149 | 150 | ## 2.24.7 151 | 152 | * going back to autofaiss main 153 | 154 | ## 2.24.6 155 | 156 | * switch to fork of autofaiss 157 | 158 | ## 2.24.5 159 | 160 | * properly close the wandb run at the end 161 | 162 | ## 2.24.4 163 | 164 | * fix pex building 165 | 166 | ## 2.24.3 167 | 168 | * fix version ranges 169 | 170 | ## 2.24.2 171 | 172 | * fix sample_count == 0 issue in logger and handle no text sample properly in main 173 | 174 | ## 2.24.1 175 | 176 | * improve logger by checking the file exists before reading 177 | 178 | ## 2.24.0 179 | 180 | * use zero padding for output file names 181 | * add proper multi gpu support in pyspark distributor 182 | * improve printing of error in logger 183 | 184 | ## 2.23.3 185 | 186 | * fix another small issue with logger reporting 187 | 188 | ## 2.23.2 189 | 190 | * small fix in logger computation 191 | 192 | ## 2.23.1 193 | 194 | * Fix race condition when using mkdir in writer 195 | 196 | ## 2.23.0 197 | 198 | * Refactor clip inference, make it support distributed inference 199 | 200 | ## 2.22.0 201 | 202 | * add use_jit option to back and inference, now True by default, add clip_model option to back 203 | 204 | ## 2.21.0 205 | 206 | * mclip support in clip back and front 207 | 208 | ## 2.20.0 209 | 210 | * replace null bytes while transforming parquet to hdf5 211 | * Use collate_fn to skip corrupt images without using recursion (thanks @afiaka87) 212 | * truncate text inputs in clip back 213 | 214 | ## 2.19.1 215 | 216 | * fix url column option bug 217 | 218 | ## 2.19.0 219 | 220 | * add url column option 221 | * use torch no grad to fix a memleak in clip back 222 | 223 | ## 2.18.0 224 | 225 | * add default backend url in clip back 226 | 227 | ## 2.17.0 228 | 229 | * add option in clip end 2 end to avoid running the back 230 | 231 | ## 2.16.2 232 | 233 | * update for autofaiss 234 | 235 | ## 2.16.1 236 | 237 | * add missing front building in python publish 238 | 239 | ## 2.16.0 240 | 241 | * clip retrieval end2end 242 | 243 | ## 2.15.1 244 | 245 | * minor bug fix about missing .npy extension in output of clip inference 246 | 247 | ## 2.15.0 248 | 249 | * mclip support 250 | * use fsspec to make it possible to output to any fs 251 | 252 | ## 2.14.3 253 | 254 | * add indice deduplication in the output of clip back 255 | 256 | ## 2.14.2 257 | 258 | * use the npy mapping in all cases for ivf reordering since it's fast enough 259 | 260 | ## 2.14.1 261 | 262 | * save ivf_old_to_new_mapping for the text index to use 263 | 264 | ## 2.14.0 265 | 266 | * implement ivf re-ordering for much faster metadata fetching 267 | * add download button in front 268 | 269 | ## 2.13.1 270 | 271 | * fix filterDuplicateUrls issue when there is no url, only images 272 | * fix default columns_to_return 273 | 274 | ## 2.13.0 275 | 276 | * add a simple filter ipynb notebook 277 | 278 | ## 2.12.0 279 | 280 | * implement infinite scroll feature 281 | 282 | ## 2.11.2 283 | 284 | * fix limiting of results in clip back 285 | * fix absence of caption in clip front 286 | 287 | ## 2.11.1 288 | 289 | * fix an issue in clip front handling of default 290 | * limit the number of results to the number available in clip back 291 | 292 | ## 2.11.0 293 | 294 | * add compression by default when creating the hdf5 cache file 295 | 296 | ## 2.10.0 297 | 298 | * add columns_to_return in clip back 299 | * safe mode in front 300 | 301 | ## 2.9.2 302 | 303 | * fix metrics sorting in metrics summary 304 | 305 | ## 2.9.1 306 | 307 | * add download url time and descriptions in metrics summary endpoint 308 | 309 | ## 2.9.0 310 | 311 | * add prometheus endpoint in clip back 312 | 313 | ## 2.8.1 314 | 315 | * properly display errors in clip index 316 | 317 | ## 2.8.0 318 | 319 | * add nb cores option in clip index 320 | 321 | ## 2.7.1 322 | 323 | * add folder name option and catch errors in clip index 324 | 325 | ## 2.7.0 326 | 327 | * package front in npm 328 | 329 | ## 2.6.0 330 | 331 | * implement image url search in clip back 332 | 333 | ## 2.5.0 334 | 335 | * add memory mapping option in clip back : 0 memory usage to load an index! 336 | 337 | ## 2.4.0 338 | 339 | * add copy metadata option to clip index 340 | 341 | ## 2.3.0 342 | 343 | * allows controlling the amount of ram used during the creation process of the index 344 | * add logs in clip back to inform when each thing is loaded 345 | * fix PIL call (thanks @pvl) 346 | 347 | ## 2.2.0 348 | 349 | * expose max_index_memory_usage 350 | 351 | ## 2.1.0 352 | 353 | * --wds_image_key, --wds_caption_key options (thanks @afiaka87) 354 | * implement h5py caching in clip back 355 | 356 | ## 2.0.4 357 | 358 | * fix clip back and filter to use sorted metadatas 359 | 360 | ## 2.0.3 361 | 362 | * fix finding the last batch number (continuing output) 363 | 364 | ## 2.0.2 365 | 366 | * add warn and continue handler to avoid crashing 367 | 368 | ## 2.0.1 369 | 370 | * add missing webdataset dep 371 | 372 | ## 2.0.0 373 | 374 | * webdataset input format 375 | * save in batch 376 | * test files in tests folder 377 | * save metadata as parquet 378 | * use autofaiss in a new clip index 379 | * remove indexing from clip batch and rename to clip inference 380 | 381 | ## 1.0.1 382 | 383 | * fixes 384 | 385 | ## 1.0.0 386 | 387 | * it works 388 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Romain Beaumont 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements-test.txt 7 | 8 | lint: ## [Local development] Run mypy, pylint and black 9 | python -m mypy clip_retrieval 10 | python -m pylint clip_retrieval 11 | python -m black --check -l 120 . 12 | 13 | black: ## [Local development] Auto-format python code using black 14 | python -m black -l 120 . 15 | 16 | build-pex: 17 | python3 -m venv .pexing 18 | . .pexing/bin/activate && python -m pip install -U pip && python -m pip install pex 19 | . .pexing/bin/activate && python -m pex --layout packed setuptools gcsfs charset-normalizer s3fs pyspark torch torchvision . -o clip_retrieval.pex -v 20 | rm -rf .pexing 21 | tar czf clip_retrieval_torch.tgz clip_retrieval.pex/.deps/torch-* 22 | tar czf clip_retrieval.tgz --exclude clip_retrieval.pex/.deps/torch-* clip_retrieval.pex 23 | 24 | venv-lint-test: ## [Continuous integration] 25 | python3 -m venv .env && . .env/bin/activate && make install install-dev lint test && rm -rf .env 26 | 27 | test: ## [Local development] Run unit tests 28 | rm -rf tests/test_folder/ 29 | python -m pytest -x -s -v tests 30 | 31 | .PHONY: help 32 | 33 | help: # Run `make help` to get help on the make commands 34 | @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 35 | -------------------------------------------------------------------------------- /clip_retrieval/__init__.py: -------------------------------------------------------------------------------- 1 | """clip retrieval""" 2 | 3 | from .clip_back import clip_back 4 | from .clip_filter import clip_filter 5 | from .clip_index import clip_index 6 | from .clip_inference.main import main as clip_inference 7 | 8 | # from .clip_inference import clip_inference 9 | from .clip_end2end import clip_end2end 10 | from .clip_front import clip_front 11 | -------------------------------------------------------------------------------- /clip_retrieval/cli.py: -------------------------------------------------------------------------------- 1 | """cli entry point""" 2 | 3 | from clip_retrieval.clip_back_prepro.parquet_to_arrow import parquet_to_arrow 4 | from clip_retrieval.clip_back import clip_back 5 | from clip_retrieval.clip_inference import clip_inference 6 | from clip_retrieval.clip_inference.worker import worker 7 | from clip_retrieval.clip_inference.slurm_worker import slurm_worker 8 | from clip_retrieval.clip_filter import clip_filter 9 | from clip_retrieval.clip_index import clip_index 10 | from clip_retrieval.clip_end2end import clip_end2end 11 | from clip_retrieval.clip_front import clip_front 12 | from clip_retrieval.clip_back_prepro.index_combiner import index_combiner 13 | import fire 14 | 15 | 16 | def main(): 17 | """Main entry point""" 18 | fire.Fire( 19 | { 20 | "back": clip_back, 21 | "index": clip_index, 22 | "filter": clip_filter, 23 | "end2end": clip_end2end, 24 | "front": clip_front, 25 | "index_combiner": index_combiner, 26 | "parquet_to_arrow": parquet_to_arrow, 27 | "inference": clip_inference, 28 | "inference.worker": worker, 29 | "inference.slurm_worker": slurm_worker, 30 | } 31 | ) 32 | 33 | 34 | if __name__ == "__main__": 35 | main() 36 | -------------------------------------------------------------------------------- /clip_retrieval/clip_back_prepro/README.md: -------------------------------------------------------------------------------- 1 | ## clip back prepro 2 | 3 | Clip back preprocessing jobs transform metadata and indices into a form that is easier to use for clip back. 4 | 5 | This is helpful to load large datasets in clip back, for example laion5B that has a 800GB index and 900GB of metadata. 6 | 7 | ### Parquet to arrow 8 | 9 | The parquet to arrow script converts many parquet files into a few arrow files. 10 | 11 | The benefit of arrow format compared to parquet is that it's possible to memmap it, allowing to use large amount of metadata at no memory cost. 12 | 13 | Usage example 14 | 15 | ```bash 16 | clip-retrieval parquet_to_arrow --parquet_folder "/media/hd2/allmeta/2Ben"\ 17 | --output_arrow_folder "/media/nvme/large_index/metadata/2B-en"\ 18 | --columns_to_return='["url", "caption"]' 19 | ``` 20 | 21 | ### Index combiner 22 | 23 | The indexer combiner script converts many indices into a single index file, without using memory. 24 | 25 | This makes it possible to use a large index at low memory cost (<500MB) 26 | 27 | Usage example 28 | 29 | ```bash 30 | clip_retrieval index_combiner --input_folder "the/indices"\ 31 | --output_folder "output" 32 | ``` -------------------------------------------------------------------------------- /clip_retrieval/clip_back_prepro/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/clip_retrieval/clip_back_prepro/__init__.py -------------------------------------------------------------------------------- /clip_retrieval/clip_back_prepro/index_combiner.py: -------------------------------------------------------------------------------- 1 | """the index combiner module is used to combine the index files into a single index file""" 2 | 3 | from pathlib import Path 4 | from faiss.contrib.ondisk import merge_ondisk 5 | import faiss 6 | import fire 7 | import os 8 | 9 | 10 | def index_combiner(input_folder, output_folder): 11 | """combine the index files into a single index file""" 12 | index_dir = Path(input_folder) 13 | block_fnames = sorted([str(a) for a in index_dir.glob("*") if "index" in str(a)]) 14 | empty_index = faiss.read_index(block_fnames[0], faiss.IO_FLAG_MMAP) 15 | empty_index.ntotal = 0 16 | 17 | if not os.path.exists(output_folder): 18 | os.makedirs(output_folder) 19 | 20 | merge_ondisk(empty_index, block_fnames, output_folder + "/merged_index.ivfdata") 21 | 22 | faiss.write_index(empty_index, output_folder + "/populated.index") 23 | 24 | 25 | if __name__ == "__main__": 26 | fire.Fire(index_combiner) 27 | -------------------------------------------------------------------------------- /clip_retrieval/clip_back_prepro/parquet_to_arrow.py: -------------------------------------------------------------------------------- 1 | """the parquet to arrow module is used to convert the parquet files into arrow files""" 2 | 3 | from multiprocessing.pool import ThreadPool 4 | import os 5 | from pathlib import Path 6 | import pyarrow.parquet as pq 7 | import pyarrow as pa 8 | from tqdm import tqdm 9 | import fire 10 | import math 11 | 12 | 13 | def file_to_count(filename): 14 | with open(filename, "rb") as f: 15 | parquet_file = pq.ParquetFile(f, memory_map=True) 16 | return parquet_file.metadata.num_rows 17 | 18 | 19 | def count_samples(files): 20 | total_count = 0 21 | with ThreadPool(10) as p: 22 | for c in tqdm(p.imap(file_to_count, files), total=len(files)): 23 | total_count += c 24 | return total_count 25 | 26 | 27 | def parquet_to_arrow(parquet_folder, output_arrow_folder, columns_to_return): 28 | """convert the parquet files into arrow files""" 29 | os.makedirs(output_arrow_folder, exist_ok=True) 30 | data_dir = Path(parquet_folder) 31 | files = sorted(data_dir.glob("*.parquet")) 32 | number_samples = count_samples(files) 33 | print("There are {} samples in the dataset".format(number_samples)) # pylint: disable=consider-using-f-string 34 | 35 | schema = pq.read_table(files[0], columns=columns_to_return).schema 36 | sink = None 37 | current_batch_count = 0 38 | batch_counter = 0 39 | key_format = max(0, int(math.log10(number_samples / 10**10))) + 1 40 | for parquet_files in tqdm(files): 41 | if sink is None or current_batch_count > 10**10: 42 | if sink is not None: 43 | writer.close() 44 | sink.close() 45 | file_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string 46 | key_format=key_format, true_key=batch_counter 47 | ) 48 | file_name = f"{output_arrow_folder}/{file_key}.arrow" 49 | print(f"Writing to {file_name}") 50 | sink = pa.OSFile(file_name, "wb") 51 | writer = pa.ipc.new_file(sink, schema) 52 | current_batch_count = 0 53 | batch_counter += 1 54 | 55 | print("going to read parquet file: ", parquet_files) 56 | for i in range(2): 57 | try: 58 | table = pq.read_table(parquet_files, columns=columns_to_return, use_threads=False) 59 | except Exception as e: # pylint: disable=broad-except 60 | if i == 1: 61 | raise e 62 | print("Error reading parquet file: ", e) 63 | print("Retrying once...") 64 | continue 65 | writer.write_table(table) 66 | current_batch_count += table.num_rows 67 | if sink is not None: 68 | writer.close() 69 | sink.close() 70 | 71 | 72 | if __name__ == "__main__": 73 | fire.Fire(parquet_to_arrow) 74 | -------------------------------------------------------------------------------- /clip_retrieval/clip_client.py: -------------------------------------------------------------------------------- 1 | """Clip client is a simple python module that allows you to query the backend remotely.""" 2 | 3 | import base64 4 | import enum 5 | import json 6 | from pathlib import Path 7 | from typing import Dict, List, Optional 8 | 9 | import requests 10 | 11 | 12 | class Modality(enum.Enum): 13 | IMAGE = "image" 14 | TEXT = "text" 15 | 16 | 17 | class ClipClient: 18 | """Remotely query the CLIP backend via REST""" 19 | 20 | def __init__( 21 | self, 22 | url: str, 23 | indice_name: str, 24 | use_mclip: bool = False, 25 | aesthetic_score: int = 9, 26 | aesthetic_weight: float = 0.5, 27 | modality: Modality = Modality.IMAGE, 28 | num_images: int = 40, 29 | deduplicate: bool = True, 30 | use_safety_model: bool = True, 31 | use_violence_detector: bool = True, 32 | ): 33 | """ 34 | url: (required) URL of the backend. 35 | indice_name: (required) which indice to search over e.g. "laion5B" or "laion_400m". 36 | use_mclip: (optional) whether to use mclip, a multilingual version of clip. Default is False. 37 | aesthetic_score: (optional) ranking score for aesthetic, higher is prettier. Default is 9. 38 | aesthetic_weight: (optional) weight of the aesthetic score, between 0 and 1. Default is 0.5. 39 | modality: (optional) Search modality. One of Modality.IMAGE or Modality.TEXT. Default is Modality.IMAGE. 40 | num_images: (optional) Number of images to return. Default is 40. 41 | deduplicate: (optional) Whether to deduplicate the result by image embedding. Default is true. 42 | use_safety_model: (optional) Whether to remove unsafe images. Default is true. 43 | use_violence_detector: (optional) Whether to remove images with violence. Default is true. 44 | """ 45 | self.url = url 46 | self.indice_name = indice_name 47 | self.use_mclip = use_mclip 48 | self.aesthetic_score = aesthetic_score 49 | self.aesthetic_weight = aesthetic_weight 50 | self.modality = modality.value 51 | self.num_images = num_images 52 | self.deduplicate = deduplicate 53 | self.use_safety_model = use_safety_model 54 | self.use_violence_detector = use_violence_detector 55 | 56 | def query( 57 | self, 58 | text: Optional[str] = None, 59 | image: Optional[str] = None, 60 | embedding_input: Optional[list] = None, 61 | ) -> List[Dict]: 62 | """ 63 | Given text or image/s, search for other captions/images that are semantically similar. 64 | 65 | Args: 66 | text: text to be searched semantically. 67 | image: base64 string of image to be searched semantically 68 | 69 | Returns: 70 | List of dictionaries containing the results in the form of: 71 | [ 72 | { 73 | "id": 42, 74 | "similarity": 0.323424523424, 75 | "url": "https://example.com/image.jpg", 76 | "caption": "This is a caption", 77 | }, 78 | ... 79 | ] 80 | """ 81 | if text and image: 82 | raise ValueError("Only one of text or image can be provided.") 83 | if text: 84 | return self.__search_knn_api__(text=text) 85 | elif image: 86 | if image.startswith("http"): 87 | return self.__search_knn_api__(image_url=image) 88 | else: 89 | assert Path(image).exists(), f"{image} does not exist." 90 | return self.__search_knn_api__(image=image) 91 | elif embedding_input: 92 | return self.__search_knn_api__(embedding_input=embedding_input) 93 | else: 94 | raise ValueError("Either text or image must be provided.") 95 | 96 | def __search_knn_api__( 97 | self, 98 | text: Optional[str] = None, 99 | image: Optional[str] = None, 100 | image_url: Optional[str] = None, 101 | embedding_input: Optional[list] = None, 102 | ) -> List: 103 | """ 104 | This function is used to send the request to the knn service. 105 | It represents a direct API call and should not be called directly outside the package. 106 | 107 | Args: 108 | text: text to be searched semantically. 109 | image: base64 string of image to be searched semantically. 110 | image_url: url of the image to be searched semantically. 111 | embedding_input: embedding input 112 | 113 | Returns: 114 | List of dictionaries containing the results in the form of: 115 | [ 116 | { 117 | "id": 42, 118 | "similarity": 0.323424523424, 119 | "url": "https://example.com/image.jpg", 120 | "caption": "This is a caption", 121 | }, 122 | ... 123 | ] 124 | 125 | """ 126 | if image: 127 | # Convert image to base64 string 128 | with open(image, "rb") as image_file: 129 | encoded_string = base64.b64encode(image_file.read()) 130 | image = str(encoded_string.decode("utf-8")) 131 | return requests.post( 132 | self.url, 133 | data=json.dumps( 134 | { 135 | "text": text, 136 | "image": image, 137 | "image_url": image_url, 138 | "embedding_input": embedding_input, 139 | "deduplicate": self.deduplicate, 140 | "use_safety_model": self.use_safety_model, 141 | "use_violence_detector": self.use_violence_detector, 142 | "indice_name": self.indice_name, 143 | "use_mclip": self.use_mclip, 144 | "aesthetic_score": self.aesthetic_score, 145 | "aesthetic_weight": self.aesthetic_weight, 146 | "modality": self.modality, 147 | "num_images": self.num_images, 148 | # num_results_ids is hardcoded to the num_images parameter. 149 | "num_result_ids": self.num_images, 150 | } 151 | ), 152 | timeout=3600, 153 | ).json() 154 | -------------------------------------------------------------------------------- /clip_retrieval/clip_end2end.py: -------------------------------------------------------------------------------- 1 | """clip end2end combines img2dataset, inference, index, back and front to produce a retrieval system in one command""" 2 | 3 | import fire 4 | 5 | 6 | def clip_end2end(url_list, output_folder, run_back=True): 7 | """main entry point of clip end2end""" 8 | 9 | import os # pylint: disable=import-outside-toplevel 10 | from img2dataset import download # pylint: disable=import-outside-toplevel 11 | from clip_retrieval import clip_inference # pylint: disable=import-outside-toplevel 12 | from clip_retrieval import clip_index # pylint: disable=import-outside-toplevel 13 | from clip_retrieval import clip_back # pylint: disable=import-outside-toplevel 14 | import fsspec # pylint: disable=import-outside-toplevel 15 | 16 | fs, output_folder_in_fs = fsspec.core.url_to_fs(output_folder) 17 | print(output_folder_in_fs) 18 | if not fs.exists(output_folder_in_fs): 19 | fs.mkdir(output_folder_in_fs) 20 | image_folder_name = os.path.join(output_folder, "images") 21 | embeddings_folder = os.path.join(output_folder, "embeddings") 22 | index_folder = os.path.join(output_folder, "index") 23 | # img2dataset 24 | download( 25 | url_list, 26 | image_size=256, 27 | output_folder=image_folder_name, 28 | thread_count=128, 29 | processes_count=4, 30 | input_format="parquet", 31 | output_format="webdataset", 32 | url_col="URL", 33 | caption_col="TEXT", 34 | ) 35 | # Clip inference 36 | input_files = [image_folder_name + "/" + p for p in next(fs.walk(image_folder_name))[2] if p.endswith(".tar")] 37 | clip_inference( 38 | input_dataset=input_files, 39 | output_folder=embeddings_folder, 40 | input_format="webdataset", 41 | enable_metadata=True, 42 | write_batch_size=100000, 43 | batch_size=512, 44 | cache_path=None, 45 | ) 46 | # Clip index 47 | os.mkdir(index_folder) 48 | clip_index(embeddings_folder, index_folder=index_folder) 49 | 50 | # Clip back 51 | indice_path = os.path.join(output_folder, "indices_paths.json") 52 | with fsspec.open(indice_path, "w") as f: 53 | f.write('{"example_index": "' + index_folder + '"}') 54 | if run_back: 55 | clip_back(port=1234, indices_paths=indice_path) 56 | 57 | 58 | if __name__ == "__main__": 59 | fire.Fire(clip_end2end) 60 | -------------------------------------------------------------------------------- /clip_retrieval/clip_filter.py: -------------------------------------------------------------------------------- 1 | """clip filter is a tool to use a knn index and a image/text collection to extract interesting subsets""" 2 | 3 | 4 | import fire 5 | 6 | 7 | def clip_filter(query, output_folder, indice_folder, num_results=100, threshold=None): 8 | """Entry point of clip filter""" 9 | 10 | import faiss # pylint: disable=import-outside-toplevel 11 | import torch # pylint: disable=import-outside-toplevel 12 | import os # pylint: disable=import-outside-toplevel 13 | import shutil # pylint: disable=import-outside-toplevel 14 | from pathlib import Path # pylint: disable=import-outside-toplevel 15 | import pandas as pd # pylint: disable=import-outside-toplevel 16 | import clip # pylint: disable=import-outside-toplevel 17 | from PIL import Image # pylint: disable=import-outside-toplevel 18 | 19 | device = "cuda" if torch.cuda.is_available() else "cpu" 20 | model, preprocess = clip.load("ViT-B/32", device=device, jit=False) 21 | 22 | data_dir = Path(indice_folder + "/metadata") 23 | df = pd.concat(pd.read_parquet(parquet_file) for parquet_file in sorted(data_dir.glob("*.parquet"))) 24 | 25 | url_list = None 26 | if "url" in df: 27 | url_list = df["url"].tolist() 28 | 29 | image_list = df["image_path"].tolist() 30 | image_index = faiss.read_index(indice_folder + "/image.index") 31 | indices_loaded = { 32 | "image_list": image_list, 33 | "image_index": image_index, 34 | } 35 | 36 | image_index = indices_loaded["image_index"] 37 | image_list = indices_loaded["image_list"] 38 | if not os.path.exists(output_folder): 39 | os.mkdir(output_folder) 40 | if query.endswith((".png", ".jpg", ".jpeg", ".bmp")) and os.path.isfile(query): 41 | im = Image.open(query) 42 | query_features = model.encode_image(preprocess(im).unsqueeze(0).to(device)) 43 | else: 44 | text = clip.tokenize([query]).to(device) 45 | query_features = model.encode_text(text) 46 | query_features /= query_features.norm(dim=-1, keepdim=True) 47 | query_features = query_features.cpu().detach().numpy().astype("float32") 48 | 49 | index = image_index 50 | 51 | if threshold is not None: 52 | _, d, i = index.range_search(query_features, threshold) 53 | print(f"Found {i.shape} items with query '{query}' and threshold {threshold}") 54 | else: 55 | d, i = index.search(query_features, num_results) 56 | print(f"Found {num_results} items with query '{query}'") 57 | i = i[0] 58 | d = d[0] 59 | 60 | min_d = min(d) 61 | max_d = max(d) 62 | print(f"The minimum distance is {min_d:.2f} and the maximum is {max_d:.2f}") 63 | print( 64 | "You may want to use these numbers to increase your --num_results parameter. Or use the --threshold parameter." 65 | ) 66 | 67 | print(f"Copying the images in {output_folder}") 68 | 69 | for _, ei in zip(d, i): 70 | path = image_list[ei] 71 | if os.path.exists(path): 72 | shutil.copy(path, output_folder) 73 | if url_list is not None: 74 | print(url_list[ei]) 75 | 76 | 77 | if __name__ == "__main__": 78 | fire.Fire(clip_filter) 79 | -------------------------------------------------------------------------------- /clip_retrieval/clip_front.py: -------------------------------------------------------------------------------- 1 | """clip front""" 2 | 3 | from flask import Flask, send_from_directory, request 4 | import json 5 | import fire 6 | 7 | 8 | def add_static_endpoints(app, default_backend=None, default_index=None, url_column="url"): 9 | """add static endpoints to the flask app""" 10 | import pkg_resources # pylint: disable=import-outside-toplevel 11 | 12 | front_path = pkg_resources.resource_filename("clip_retrieval", "../front/build") 13 | 14 | def static_dir_(): 15 | return send_from_directory(front_path, "index.html") 16 | 17 | app.route("/")(static_dir_) 18 | 19 | def config_json(): 20 | back = default_backend if default_backend is not None else request.host_url 21 | index = default_index if default_index is not None else "" 22 | config = {"defaultBackend": back, "defaultIndex": index, "urlColumn": url_column} 23 | return json.dumps(config) 24 | 25 | app.route("/config.json")(config_json) 26 | 27 | def static_dir(path): 28 | return send_from_directory(front_path, path) 29 | 30 | app.route("/")(static_dir) 31 | 32 | 33 | def clip_front(default_backend=None, default_index=None, url_column="url"): 34 | app = Flask(__name__) 35 | add_static_endpoints(app, default_backend, default_index, url_column) 36 | app.run(host="0.0.0.0", port=1235, debug=False) 37 | 38 | 39 | if __name__ == "__main__": 40 | fire.Fire(clip_front) 41 | -------------------------------------------------------------------------------- /clip_retrieval/clip_index.py: -------------------------------------------------------------------------------- 1 | """Clip index is a tool to index clip embeddings using autofaiss""" 2 | 3 | import fire 4 | import os 5 | from shutil import copytree 6 | import logging 7 | 8 | 9 | LOGGER = logging.getLogger(__name__) 10 | 11 | 12 | def quantize(emb_folder, index_folder, index_name, max_index_memory_usage, current_memory_available, nb_cores): 13 | """calls autofaiss to build an index""" 14 | 15 | from autofaiss import build_index # pylint: disable=import-outside-toplevel 16 | 17 | try: 18 | LOGGER.debug(f"starting index {index_name}") 19 | if os.path.exists(emb_folder): 20 | LOGGER.debug( 21 | f"embedding path exist, building index {index_name}" 22 | f"using embeddings {emb_folder} ; saving in {index_folder}" 23 | ) 24 | build_index( 25 | embeddings=emb_folder, 26 | index_path=index_folder + "/" + index_name + ".index", 27 | index_infos_path=index_folder + "/" + index_name + ".json", 28 | max_index_memory_usage=max_index_memory_usage, 29 | current_memory_available=current_memory_available, 30 | nb_cores=nb_cores, 31 | ) 32 | LOGGER.debug(f"index {index_name} done") 33 | except Exception as e: # pylint: disable=broad-except 34 | LOGGER.exception(f"index {index_name} failed") 35 | raise e 36 | 37 | 38 | def clip_index( 39 | embeddings_folder, 40 | index_folder, 41 | max_index_memory_usage="4G", 42 | current_memory_available="16G", 43 | copy_metadata=True, 44 | image_subfolder="img_emb", 45 | text_subfolder="text_emb", 46 | nb_cores=None, 47 | ): 48 | """indexes clip embeddings using autofaiss""" 49 | quantize( 50 | embeddings_folder + "/" + image_subfolder, 51 | index_folder, 52 | "image", 53 | max_index_memory_usage, 54 | current_memory_available, 55 | nb_cores, 56 | ) 57 | quantize( 58 | embeddings_folder + "/" + text_subfolder, 59 | index_folder, 60 | "text", 61 | max_index_memory_usage, 62 | current_memory_available, 63 | nb_cores, 64 | ) 65 | if copy_metadata: 66 | copytree(embeddings_folder + "/metadata", index_folder + "/metadata") 67 | 68 | 69 | if __name__ == "__main__": 70 | fire.Fire(clip_index) 71 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/__init__.py: -------------------------------------------------------------------------------- 1 | """clip inference""" 2 | 3 | from .main import main as clip_inference 4 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/distributor.py: -------------------------------------------------------------------------------- 1 | """distributors provide way to compute using several gpus and several machines""" 2 | 3 | import os 4 | 5 | from .worker import worker 6 | 7 | 8 | class SequentialDistributor: 9 | def __init__(self, tasks, worker_args): 10 | self.tasks = tasks 11 | self.worker_args = worker_args 12 | 13 | def __call__(self): 14 | """ 15 | call a single `worker(...)` and pass it everything. 16 | """ 17 | worker( 18 | tasks=self.tasks, 19 | **self.worker_args, 20 | ) 21 | 22 | 23 | class PysparkDistributor: 24 | """the pyspark distributor uses pyspark for distribution""" 25 | 26 | def __init__(self, tasks, worker_args): 27 | self.tasks = tasks 28 | self.worker_args = worker_args 29 | 30 | def __call__(self): 31 | """ 32 | Parallelize work and call `worker(...)` 33 | """ 34 | 35 | import pyspark # pylint: disable=import-outside-toplevel 36 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 37 | 38 | spark = SparkSession.getActiveSession() 39 | 40 | if spark is None: 41 | print("No pyspark session found, creating a new one!") 42 | spark = ( 43 | SparkSession.builder.config("spark.driver.memory", "16G") 44 | .master("local[" + str(2) + "]") 45 | .appName("spark-stats") 46 | .getOrCreate() 47 | ) 48 | 49 | rdd = spark.sparkContext.parallelize(c=self.tasks, numSlices=len(self.tasks)) 50 | 51 | def run(partition_id): 52 | context = pyspark.TaskContext.get() 53 | if "gpu" in context.resources(): 54 | gpu = context.resources()["gpu"].addresses[0] 55 | os.environ["CUDA_VISIBLE_DEVICES"] = gpu 56 | 57 | worker(tasks=[partition_id], **self.worker_args) 58 | 59 | rdd.foreach(run) 60 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/logger.py: -------------------------------------------------------------------------------- 1 | """The logger module allows logging to stdout and wandb""" 2 | 3 | from collections import defaultdict 4 | import fsspec 5 | import multiprocessing 6 | import time 7 | import json 8 | import wandb 9 | import queue 10 | import traceback 11 | 12 | 13 | class LoggerWriter: 14 | """the logger writer write stats to json file, for each worker""" 15 | 16 | def __init__(self, partition_id, stats_folder): 17 | self.partition_id = partition_id 18 | self.stats_folder = stats_folder 19 | 20 | def start(self): 21 | ctx = multiprocessing.get_context("spawn") 22 | self.queue = ctx.Queue() 23 | self.updater_process = ctx.Process(target=self.updater) 24 | self.updater_process.start() 25 | 26 | def end(self): 27 | self.queue.put(None) 28 | self.updater_process.join() 29 | self.queue.close() 30 | 31 | def __call__(self, stats): 32 | self.queue.put(stats) 33 | 34 | def updater(self): 35 | """updater process that writes stats to file from the queue""" 36 | stats = defaultdict(lambda: 0) 37 | fs, relative_path = fsspec.core.url_to_fs(self.stats_folder) 38 | last_write = None 39 | while True: 40 | item = self.queue.get() 41 | if item is None: 42 | self.write_stats(stats, fs, relative_path, False) 43 | return 44 | for k in item: 45 | stats[k] += item[k] 46 | if last_write is None or time.time() - last_write > 5: 47 | self.write_stats(stats, fs, relative_path, True) 48 | last_write = time.time() 49 | 50 | def sum(self, stats, new_stats): 51 | for k in stats.keys(): 52 | stats[k] += new_stats[k] 53 | return stats 54 | 55 | def write_stats(self, stats, fs, relative_path, wip): 56 | fs.makedirs(relative_path, exist_ok=True) 57 | if not wip and fs.exists(relative_path + f"/wip_{self.partition_id}.json"): 58 | fs.rm(relative_path + f"/wip_{self.partition_id}.json") 59 | prefix = "wip_" if wip else "" 60 | with fs.open(relative_path + f"/{prefix}{self.partition_id}.json", "w") as f: 61 | f.write(json.dumps(stats)) 62 | 63 | 64 | class LoggerReader: 65 | """the logger reader read stats of all json files and aggregate them""" 66 | 67 | def __init__(self, stats_folder, wandb_project="clip_retrieval", enable_wandb=False): 68 | self.stats_folder = stats_folder 69 | self.enable_wandb = enable_wandb 70 | self.wandb_project = wandb_project 71 | self.log_interval = 5 72 | 73 | def start(self): 74 | ctx = multiprocessing.get_context("spawn") 75 | self.queue = ctx.Queue() 76 | self.start_time = time.perf_counter() 77 | self.reader_process = ctx.Process(target=self.reader) 78 | self.reader_process.start() 79 | 80 | def end(self): 81 | self.queue.put("end") 82 | self.reader_process.join() 83 | self.queue.close() 84 | 85 | def reader(self): 86 | """reader process that reads stats from files and aggregates them""" 87 | try: # pylint: disable=too-many-nested-blocks 88 | if self.enable_wandb: 89 | self.current_run = wandb.init(project=self.wandb_project) 90 | else: 91 | self.current_run = None 92 | 93 | last_check = 0 94 | stats = {} 95 | start_time_no_initial_load = float("inf") 96 | fs, relative_path = fsspec.core.url_to_fs(self.stats_folder, use_listings_cache=False) 97 | 98 | fs.makedirs(relative_path, exist_ok=True) 99 | 100 | while True: # pylint: disable=too-many-nested-blocks 101 | time.sleep(0.1) 102 | try: 103 | self.queue.get(False) 104 | last_one = True 105 | except queue.Empty as _: 106 | last_one = False 107 | if not last_one and time.perf_counter() - last_check < self.log_interval: 108 | continue 109 | 110 | last_check = time.perf_counter() 111 | 112 | stats_files = fs.glob(relative_path + "/*.json") 113 | for k in stats_files: 114 | filename = k.split("/")[-1] 115 | if filename[:4] == "wip_" or filename not in stats: 116 | for i in range(5): # pylint: disable=unused-variable 117 | try: 118 | fs.invalidate_cache() 119 | if not fs.exists(k): 120 | continue 121 | with fs.open(k, "r") as f: 122 | stats[filename] = json.loads(f.read()) 123 | if filename[:4] != "wip_" and "wip_" + filename in stats: 124 | del stats["wip_" + filename] 125 | break 126 | except Exception as e: # pylint: disable=broad-except 127 | if i == 4: 128 | print(f"failed to read {k} error : {e}") 129 | time.sleep(1) 130 | 131 | stats_aggregated = defaultdict(lambda: 0) 132 | for k, v in stats.items(): 133 | for k2 in v: 134 | stats_aggregated[k2] += v[k2] 135 | 136 | for v in stats.values(): 137 | start_time_no_initial_load = min(start_time_no_initial_load, v["start_time"]) 138 | 139 | current_time = time.perf_counter() 140 | current_real_time = time.time() 141 | total_duration = current_time - self.start_time 142 | 143 | if stats_aggregated["sample_count"] == 0: 144 | if last_one: 145 | self._finish() 146 | break 147 | continue 148 | total_duration_no_initial_load = current_real_time - start_time_no_initial_load 149 | 150 | stats_aggregated["average_read_duration_per_sample"] = ( 151 | stats_aggregated["read_duration"] / stats_aggregated["sample_count"] 152 | ) 153 | stats_aggregated["average_inference_duration_per_sample"] = ( 154 | stats_aggregated["inference_duration"] / stats_aggregated["sample_count"] 155 | ) 156 | stats_aggregated["average_write_duration_per_sample"] = ( 157 | stats_aggregated["write_duration"] / stats_aggregated["sample_count"] 158 | ) 159 | stats_aggregated["average_total_duration_per_sample"] = ( 160 | stats_aggregated["total_duration"] / stats_aggregated["sample_count"] 161 | ) 162 | stats_aggregated["sample_per_sec"] = stats_aggregated["sample_count"] / total_duration 163 | stats_aggregated["total_job_duration"] = total_duration 164 | stats_aggregated["total_duration_no_initial_load"] = total_duration_no_initial_load 165 | stats_aggregated["sample_per_sec_no_initial_load"] = ( 166 | stats_aggregated["sample_count"] / total_duration_no_initial_load 167 | ) 168 | 169 | to_log = [ 170 | "sample_count", 171 | "sample_per_sec", 172 | "sample_per_sec_no_initial_load", 173 | "total_job_duration", 174 | "average_read_duration_per_sample", 175 | "average_inference_duration_per_sample", 176 | "average_write_duration_per_sample", 177 | "average_total_duration_per_sample", 178 | ] 179 | stats_for_logging = {} 180 | for k in to_log: 181 | stats_for_logging[k] = stats_aggregated[k] 182 | 183 | print( 184 | "\r", 185 | "sample_per_sec " 186 | + str(int(stats_for_logging["sample_per_sec_no_initial_load"])) 187 | + " ; sample_count " 188 | + str(stats_for_logging["sample_count"]) 189 | + " ", 190 | end="", 191 | ) 192 | if self.enable_wandb: 193 | wandb.log(stats_for_logging) 194 | 195 | if last_one: 196 | self._finish() 197 | break 198 | except Exception as e: # pylint: disable=broad-except 199 | traceback.print_exc() 200 | print("logger error", e) 201 | self._finish() 202 | return 203 | 204 | def _finish(self): 205 | if self.current_run is not None: 206 | self.current_run.finish() 207 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/main.py: -------------------------------------------------------------------------------- 1 | """main module combines distributor, runner, reader, mapper, writer to produce clip embeddings""" 2 | 3 | import fire 4 | import math 5 | from braceexpand import braceexpand 6 | 7 | from clip_retrieval.clip_inference.logger import LoggerReader 8 | from clip_retrieval.clip_inference.reader import folder_to_keys 9 | from clip_retrieval.clip_inference.slurm_distributor import SlurmDistributor 10 | from clip_retrieval.clip_inference.distributor import PysparkDistributor, SequentialDistributor 11 | 12 | 13 | def calculate_partition_count( 14 | input_format, 15 | input_dataset, 16 | enable_image, 17 | enable_text, 18 | enable_metadata, 19 | write_batch_size, 20 | wds_number_file_per_input_file, 21 | ): 22 | """ 23 | Calculate the partition count needed to store the resulting embeddings. 24 | 25 | Return: 26 | - the output partition count and the updated toggles for image, text and metadata. 27 | """ 28 | 29 | sample_count = 0 30 | 31 | if input_format == "files": 32 | keys, text_files, image_files, metadata_files = folder_to_keys( 33 | input_dataset, 34 | enable_text=enable_text, 35 | enable_image=enable_image, 36 | enable_metadata=enable_metadata, 37 | ) 38 | if text_files is None or len(text_files) == 0: 39 | enable_text = False 40 | if image_files is None or len(image_files) == 0: 41 | enable_image = False 42 | if metadata_files is None or len(metadata_files) == 0: 43 | enable_metadata = False 44 | if not enable_text and not enable_image and not enable_metadata: 45 | raise ValueError("no sample found") 46 | keys, text_files, image_files, metadata_files = folder_to_keys( 47 | input_dataset, 48 | enable_text=enable_text, 49 | enable_image=enable_image, 50 | enable_metadata=enable_metadata, 51 | ) 52 | sample_count = len(keys) 53 | elif input_format == "webdataset": 54 | sample_count = len(input_dataset) * wds_number_file_per_input_file 55 | else: 56 | raise ValueError(f"Unsupported input_format {input_format}") 57 | 58 | if sample_count == 0: 59 | raise ValueError("no sample found") 60 | 61 | print(f"The number of samples has been estimated to be {sample_count}") 62 | 63 | output_partition_count = math.ceil(sample_count / write_batch_size) 64 | 65 | return output_partition_count, enable_text, enable_image, enable_metadata 66 | 67 | 68 | # pylint: disable=unused-argument 69 | def main( 70 | input_dataset, 71 | output_folder, 72 | input_format="files", 73 | cache_path=None, 74 | batch_size=256, 75 | num_prepro_workers=4, 76 | enable_text=True, 77 | enable_image=True, 78 | enable_metadata=False, 79 | write_batch_size=10**6, 80 | wds_image_key="jpg", 81 | wds_caption_key="txt", 82 | clip_model="ViT-B/32", 83 | mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", 84 | use_mclip=False, 85 | use_jit=False, 86 | distribution_strategy="sequential", 87 | wds_number_file_per_input_file=10000, 88 | output_partition_count=None, 89 | wandb_project="clip_retrieval", 90 | enable_wandb=False, 91 | clip_cache_path=None, 92 | slurm_job_name=None, 93 | slurm_partition=None, 94 | slurm_nodes=None, 95 | slurm_job_comment=None, 96 | slurm_nodelist=None, 97 | slurm_exclude=None, 98 | slurm_job_timeout=None, 99 | slurm_cache_path=None, 100 | slurm_verbose_wait=False, 101 | ): 102 | # package arguments to pass on to the distributor 103 | local_args = dict(locals()) 104 | 105 | expanded_dataset = list(braceexpand(input_dataset)) if input_format == "webdataset" else input_dataset 106 | 107 | # compute this now for the distributors to use 108 | if output_partition_count is None: 109 | output_partition_count, enable_text, enable_image, enable_metadata = calculate_partition_count( 110 | input_format=input_format, 111 | input_dataset=expanded_dataset, 112 | enable_image=enable_image, 113 | enable_text=enable_text, 114 | enable_metadata=enable_metadata, 115 | write_batch_size=write_batch_size, 116 | wds_number_file_per_input_file=wds_number_file_per_input_file, 117 | ) 118 | 119 | # update the local args to match the computed values 120 | local_args["output_partition_count"] = output_partition_count 121 | local_args["enable_text"] = enable_text 122 | local_args["enable_image"] = enable_image 123 | local_args["enable_metadata"] = enable_metadata 124 | 125 | local_args.pop("wds_number_file_per_input_file") 126 | local_args.pop("write_batch_size") 127 | local_args.pop("distribution_strategy") 128 | local_args.pop("wandb_project") 129 | local_args.pop("enable_wandb") 130 | 131 | tasks = list(range(output_partition_count)) 132 | worker_args = {k: v for k, v in local_args.items() if not k.startswith("slurm_")} 133 | 134 | if distribution_strategy == "sequential": 135 | distributor = SequentialDistributor(tasks=tasks, worker_args=worker_args) 136 | elif distribution_strategy == "pyspark": 137 | distributor = PysparkDistributor(tasks=tasks, worker_args=worker_args) 138 | elif distribution_strategy == "slurm": 139 | slurm_args = {k.lstrip("slurm_"): v for k, v in local_args.items() if k.startswith("slurm_")} 140 | distributor = SlurmDistributor(tasks=tasks, worker_args=worker_args, slurm_args=slurm_args) 141 | else: 142 | print( 143 | f"The {distribution_strategy} strategy is not implemented. Please choose from: [sequential, pyspark, slurm]" 144 | ) 145 | 146 | logger_reader = LoggerReader( 147 | stats_folder=output_folder + "/stats", 148 | wandb_project=wandb_project, 149 | enable_wandb=enable_wandb, 150 | ) 151 | 152 | logger_reader.start() 153 | 154 | distributor() 155 | 156 | logger_reader.end() 157 | 158 | 159 | if __name__ == "__main__": 160 | fire.Fire(main) 161 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/mapper.py: -------------------------------------------------------------------------------- 1 | """mapper module transform images and text to embeddings""" 2 | 3 | import torch 4 | from all_clip import load_clip 5 | from sentence_transformers import SentenceTransformer 6 | 7 | 8 | def normalized(a, axis=-1, order=2): 9 | import numpy as np # pylint: disable=import-outside-toplevel 10 | 11 | l2 = np.atleast_1d(np.linalg.norm(a, order, axis)) 12 | l2[l2 == 0] = 1 13 | return a / np.expand_dims(l2, axis) 14 | 15 | 16 | class ClipMapper: 17 | """transforms images and texts into clip embeddings""" 18 | 19 | def __init__( 20 | self, 21 | enable_image, 22 | enable_text, 23 | enable_metadata, 24 | use_mclip, 25 | clip_model, 26 | use_jit, 27 | mclip_model, 28 | warmup_batch_size=1, 29 | clip_cache_path=None, 30 | ): 31 | self.enable_image = enable_image 32 | self.enable_text = enable_text 33 | self.enable_metadata = enable_metadata 34 | self.use_mclip = use_mclip 35 | self.device = "cuda" if torch.cuda.is_available() else "cpu" 36 | model, _, _ = load_clip( 37 | clip_model=clip_model, 38 | use_jit=use_jit, 39 | warmup_batch_size=warmup_batch_size, 40 | clip_cache_path=clip_cache_path, 41 | ) 42 | self.model_img = model.encode_image 43 | self.model_txt = model.encode_text 44 | if use_mclip: 45 | print("\nLoading MCLIP model for text embedding\n") 46 | mclip = SentenceTransformer(mclip_model) 47 | self.model_txt = mclip.encode 48 | 49 | def __call__(self, item): 50 | with torch.no_grad(): 51 | image_embs = None 52 | text_embs = None 53 | image_filename = None 54 | text = None 55 | metadata = None 56 | if self.enable_image: 57 | image_features = self.model_img(item["image_tensor"].to(self.device)) 58 | image_features /= image_features.norm(dim=-1, keepdim=True) 59 | image_embs = image_features.cpu().to(torch.float16).numpy() 60 | image_filename = item["image_filename"] 61 | if self.enable_text: 62 | if self.use_mclip: 63 | text_embs = normalized(self.model_txt(item["text"])) 64 | else: 65 | text_features = self.model_txt(item["text_tokens"].to(self.device)) 66 | text_features /= text_features.norm(dim=-1, keepdim=True) 67 | text_embs = text_features.cpu().to(torch.float16).numpy() 68 | text = item["text"] 69 | if self.enable_metadata: 70 | metadata = item["metadata"] 71 | 72 | return { 73 | "image_embs": image_embs, 74 | "text_embs": text_embs, 75 | "image_filename": image_filename, 76 | "text": text, 77 | "metadata": metadata, 78 | } 79 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/reader.py: -------------------------------------------------------------------------------- 1 | """Reader module provides files and webdataset readers""" 2 | 3 | from pathlib import Path 4 | from PIL import Image, UnidentifiedImageError 5 | from torch.utils.data import DataLoader 6 | from torch.utils.data.dataloader import default_collate 7 | import io 8 | 9 | 10 | def folder_to_keys(folder, enable_text=True, enable_image=True, enable_metadata=False): 11 | """returns a list of keys from a folder of images and text""" 12 | path = Path(folder) 13 | text_files = None 14 | metadata_files = None 15 | image_files = None 16 | if enable_text: 17 | text_files = [*path.glob("**/*.txt")] 18 | text_files = {text_file.relative_to(path).as_posix(): text_file for text_file in text_files} 19 | if enable_image: 20 | image_files = [ 21 | *path.glob("**/*.png"), 22 | *path.glob("**/*.jpg"), 23 | *path.glob("**/*.jpeg"), 24 | *path.glob("**/*.bmp"), 25 | *path.glob("**/*.webp"), 26 | *path.glob("**/*.PNG"), 27 | *path.glob("**/*.JPG"), 28 | *path.glob("**/*.JPEG"), 29 | *path.glob("**/*.BMP"), 30 | *path.glob("**/*.WEBP"), 31 | ] 32 | image_files = {image_file.relative_to(path).as_posix(): image_file for image_file in image_files} 33 | if enable_metadata: 34 | metadata_files = [*path.glob("**/*.json")] 35 | metadata_files = {metadata_file.relative_to(path).as_posix(): metadata_file for metadata_file in metadata_files} 36 | 37 | keys = None 38 | 39 | def join(new_set): 40 | return new_set & keys if keys is not None else new_set 41 | 42 | if enable_text: 43 | keys = join(text_files.keys()) 44 | elif enable_image: 45 | keys = join(image_files.keys()) 46 | elif enable_metadata: 47 | keys = join(metadata_files.keys()) 48 | 49 | keys = list(sorted(keys)) 50 | 51 | return keys, text_files, image_files, metadata_files 52 | 53 | 54 | def get_image_dataset(): 55 | """retrieve image dataset module without importing torch at the top level""" 56 | 57 | from torch.utils.data import Dataset # pylint: disable=import-outside-toplevel 58 | 59 | class ImageDataset(Dataset): 60 | """ImageDataset is a pytorch Dataset exposing image and text tensors from a folder of image and text""" 61 | 62 | def __init__( 63 | self, 64 | preprocess, 65 | tokenizer, 66 | folder, 67 | enable_text=True, 68 | enable_image=True, 69 | enable_metadata=False, 70 | input_sampler=lambda a: a, 71 | ): 72 | super().__init__() 73 | 74 | self.keys, text_files, image_files, metadata_files = folder_to_keys( 75 | folder, enable_text, enable_image, enable_metadata 76 | ) 77 | self.keys = input_sampler(self.keys) 78 | self.enable_text = enable_text 79 | self.enable_image = enable_image 80 | self.enable_metadata = enable_metadata 81 | keys_set = set(self.keys) 82 | if self.enable_text: 83 | self.tokenizer = lambda text: tokenizer([text])[0] 84 | self.text_files = {k: v for k, v in text_files.items() if k in keys_set} 85 | if self.enable_image: 86 | self.image_files = {k: v for k, v in image_files.items() if k in keys_set} 87 | self.image_transform = preprocess 88 | if self.enable_metadata: 89 | self.metadata_files = {k: v for k, v in metadata_files.items() if k in keys_set} 90 | 91 | def __len__(self): 92 | return len(self.keys) 93 | 94 | def __getitem__(self, ind): 95 | key = self.keys[ind] 96 | output = {} 97 | 98 | if self.enable_image: 99 | image_file = self.image_files[key] 100 | try: 101 | image_tensor = self.image_transform(Image.open(image_file)) 102 | except (UnidentifiedImageError, OSError) as e: 103 | print(f"Failed to load image {image_file}. Error: {e}. Skipping.") 104 | return None # return None to be filtered in the batch collate_fn 105 | output["image_filename"] = str(image_file) 106 | output["image_tensor"] = image_tensor 107 | 108 | if self.enable_text: 109 | text_file = self.text_files[key] 110 | caption = text_file.read_text() 111 | tokenized_text = self.tokenizer(caption) 112 | output["text_tokens"] = tokenized_text 113 | output["text"] = caption 114 | 115 | if self.enable_metadata: 116 | metadata_file = self.metadata_files[key] 117 | metadata = metadata_file.read_text() 118 | output["metadata"] = metadata 119 | 120 | return output 121 | 122 | return ImageDataset 123 | 124 | 125 | def create_webdataset( 126 | urls, 127 | image_transform, 128 | tokenizer, 129 | enable_text=True, 130 | enable_image=True, 131 | image_key="jpg", 132 | caption_key="txt", 133 | enable_metadata=False, 134 | cache_path=None, 135 | input_sampler=lambda a: a, 136 | ): 137 | """Create a WebDataset reader, it can read a webdataset of image, text and json""" 138 | import webdataset as wds # pylint: disable=import-outside-toplevel 139 | 140 | urls = input_sampler(urls) 141 | 142 | dataset = wds.WebDataset(urls, cache_dir=cache_path, cache_size=10**10, handler=wds.handlers.warn_and_continue) 143 | 144 | def _tokenizer(text): 145 | return tokenizer([text])[0] 146 | 147 | def filter_dataset(item): 148 | if enable_text and caption_key not in item: 149 | return False 150 | if enable_image and image_key not in item: 151 | return False 152 | if enable_metadata and "json" not in item: 153 | return False 154 | return True 155 | 156 | filtered_dataset = dataset.select(filter_dataset) 157 | 158 | def preprocess_dataset(item): 159 | output = {} 160 | if enable_image: 161 | image_data = item[image_key] 162 | image = Image.open(io.BytesIO(image_data)) 163 | image_tensor = image_transform(image) 164 | output["image_filename"] = item["__key__"] 165 | output["image_tensor"] = image_tensor 166 | 167 | if enable_text: 168 | text = item[caption_key] 169 | caption = text.decode("utf-8") 170 | tokenized_text = _tokenizer(caption) 171 | output["text_tokens"] = tokenized_text 172 | output["text"] = caption 173 | 174 | if enable_metadata: 175 | metadata_file = item["json"] 176 | metadata = metadata_file.decode("utf-8") 177 | output["metadata"] = metadata 178 | return output 179 | 180 | transformed_dataset = filtered_dataset.map(preprocess_dataset, handler=wds.handlers.warn_and_continue) 181 | return transformed_dataset 182 | 183 | 184 | def dataset_to_dataloader(dataset, batch_size, num_prepro_workers, input_format): 185 | """Create a pytorch dataloader from a dataset""" 186 | 187 | def collate_fn(batch): 188 | batch = list(filter(lambda x: x is not None, batch)) 189 | return default_collate(batch) 190 | 191 | data = DataLoader( 192 | dataset, 193 | batch_size=batch_size, 194 | shuffle=False, 195 | num_workers=num_prepro_workers, 196 | pin_memory=True, 197 | prefetch_factor=2, 198 | collate_fn=collate_fn if input_format == "files" else None, 199 | ) 200 | return data 201 | 202 | 203 | class FilesReader: 204 | """FilesReader is a reader that reads files from a folder""" 205 | 206 | def __init__( 207 | self, 208 | sampler, 209 | preprocess, 210 | tokenizer, 211 | input_dataset, 212 | batch_size, 213 | num_prepro_workers, 214 | enable_text=True, 215 | enable_image=True, 216 | enable_metadata=False, 217 | ) -> None: 218 | super().__init__() 219 | dataset = get_image_dataset()( 220 | preprocess, tokenizer, input_dataset, enable_text, enable_image, enable_metadata, sampler 221 | ) 222 | self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "files") 223 | 224 | def __iter__(self): 225 | for batch in self.dataloader: 226 | yield batch 227 | 228 | 229 | class WebdatasetReader: 230 | """WebdatasetReader is a reader that reads samples from a webdataset""" 231 | 232 | def __init__( 233 | self, 234 | sampler, 235 | preprocess, 236 | tokenizer, 237 | input_dataset, 238 | batch_size, 239 | num_prepro_workers, 240 | enable_text=True, 241 | enable_image=True, 242 | enable_metadata=False, 243 | wds_image_key="jpg", 244 | wds_caption_key="txt", 245 | cache_path=None, 246 | ): 247 | self.batch_size = batch_size 248 | dataset = create_webdataset( 249 | input_dataset, 250 | preprocess, 251 | tokenizer, 252 | enable_text=enable_text, 253 | enable_image=enable_image, 254 | image_key=wds_image_key, 255 | caption_key=wds_caption_key, 256 | enable_metadata=enable_metadata, 257 | cache_path=cache_path, 258 | input_sampler=sampler, 259 | ) 260 | self.dataloader = dataset_to_dataloader(dataset, batch_size, num_prepro_workers, "webdataset") 261 | 262 | def __iter__(self): 263 | for batch in self.dataloader: 264 | yield batch 265 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/runner.py: -------------------------------------------------------------------------------- 1 | """The runner combine reader, mapper and writer to produce clip embeddings""" 2 | 3 | import time 4 | 5 | 6 | class Sampler: 7 | """Sampler""" 8 | 9 | def __init__(self, output_partition_id, output_partition_count): 10 | self.output_partition_id = output_partition_id 11 | self.output_partition_count = output_partition_count 12 | 13 | def __call__(self, l): 14 | return [e for i, e in enumerate(l) if i % self.output_partition_count == self.output_partition_id] 15 | 16 | 17 | class Runner: 18 | """Runner class""" 19 | 20 | def __init__(self, reader_builder, mapper_builder, writer_builder, logger_builder, output_partition_count): 21 | self.reader_builder = reader_builder 22 | self.mapper_builder = mapper_builder 23 | self.writer_builder = writer_builder 24 | self.logger_builder = logger_builder 25 | self.output_partition_count = output_partition_count 26 | 27 | def __call__(self, i): 28 | sampler = Sampler(i, self.output_partition_count) 29 | reader = self.reader_builder(sampler) 30 | writer = self.writer_builder(i) 31 | mapper = self.mapper_builder() 32 | logger = self.logger_builder(i) 33 | logger.start() 34 | iterator = reader.__iter__() 35 | while True: 36 | begin_time = time.time() 37 | start_time = time.perf_counter() 38 | try: 39 | batch = iterator.__next__() 40 | except StopIteration: 41 | break 42 | read_duration = time.perf_counter() - start_time 43 | start_time = time.perf_counter() 44 | embeddings = mapper(batch) 45 | inference_duration = time.perf_counter() - start_time 46 | start_time = time.perf_counter() 47 | writer(embeddings) 48 | write_duration = time.perf_counter() - start_time 49 | end_time = time.time() 50 | logger( 51 | { 52 | "start_time": begin_time, 53 | "end_time": end_time, 54 | "read_duration": read_duration, 55 | "inference_duration": inference_duration, 56 | "write_duration": write_duration, 57 | "total_duration": end_time - begin_time, 58 | "sample_count": batch["image_tensor"].shape[0] 59 | if "image_tensor" in batch 60 | else batch["text_tokens"].shape[0], 61 | } 62 | ) 63 | logger.end() 64 | writer.flush() 65 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/slurm_distributor.py: -------------------------------------------------------------------------------- 1 | """Distribute work using SLURM""" 2 | 3 | import os 4 | import time 5 | import json 6 | import subprocess 7 | from datetime import datetime 8 | 9 | TIMESTAMP = datetime.now().timestamp() 10 | 11 | 12 | class SlurmDistributor: 13 | """distribute work across a collection of slurm jobs""" 14 | 15 | def __init__(self, tasks, worker_args, slurm_args): 16 | self.num_tasks = len(tasks) 17 | self.worker_args = worker_args 18 | self.slurm_args = slurm_args 19 | 20 | self.job_timeout = slurm_args.pop("job_timeout") 21 | self.verbose_wait = slurm_args.pop("verbose_wait") 22 | 23 | def __call__(self): 24 | """ 25 | Create a sbatch file, submit it to slurm, and wait for it to finish. 26 | """ 27 | # pop the cache path from the slurm args to remove it 28 | cache_path = self.slurm_args.pop("cache_path") 29 | 30 | # create the cache path if it doesn't exist 31 | if cache_path is None: 32 | cache_path = os.path.expanduser("~/.cache") 33 | 34 | os.makedirs(cache_path, exist_ok=True) 35 | 36 | # make the filenames unique using the current timestamp 37 | sbatch_script_path = os.path.join(cache_path, f"sbatch_script_{TIMESTAMP}.sh") 38 | 39 | # save the file to the cache path 40 | with open(sbatch_script_path, "w", encoding="utf-8") as sbatch_file: 41 | sbatch_file.write( 42 | self._generate_sbatch(cache_path=cache_path, slurm_args=self.slurm_args, worker_args=self.worker_args) 43 | ) 44 | 45 | # now we need to run the job 46 | status = self._run_job(sbatch_script_path) 47 | 48 | # interpret the results 49 | if status == "success": 50 | print("job succeeded") 51 | return True 52 | elif status == "failed": 53 | print("job failed") 54 | return False 55 | else: 56 | print("exception occurred") 57 | return False 58 | 59 | def _run_job(self, sbatch_file): 60 | """ 61 | Run a job and wait for it to finish. 62 | """ 63 | try: 64 | job_id = self._start_job(sbatch_file) 65 | 66 | print(f"waiting for job {job_id}") 67 | 68 | timeout = self.job_timeout 69 | 70 | if timeout is None: 71 | print("You have not specified a timeout, defaulting to 2 weeks.") 72 | timeout = 1.21e6 73 | 74 | status = self._wait_for_job_to_finish(job_id=job_id, timeout=timeout) 75 | 76 | if not status: 77 | print(f"canceling {job_id}") 78 | subprocess.check_output(["scancel", job_id]).decode("utf8") 79 | status = self._wait_for_job_to_finish(job_id) 80 | print("job cancelled") 81 | return "failed" 82 | else: 83 | print("job succeeded") 84 | return "success" 85 | except Exception as e: # pylint: disable=broad-except 86 | print(e) 87 | return "exception occurred" 88 | 89 | def _wait_for_job_to_finish(self, job_id, timeout=30): 90 | t = time.time() 91 | while 1: 92 | if time.time() - t > timeout: 93 | return False 94 | time.sleep(1) 95 | if self._is_job_finished(job_id): 96 | return True 97 | 98 | def _is_job_finished(self, job_id): 99 | status = subprocess.check_output(["squeue", "-j", job_id]).decode("utf8") 100 | 101 | if self.verbose_wait: 102 | print(f"job status is {status}") 103 | 104 | return status == "slurm_load_jobs error: Invalid job id specified" or len(status.split("\n")) == 2 105 | 106 | def _start_job(self, sbatch_file): 107 | """start job""" 108 | args = ["sbatch"] 109 | args.append(sbatch_file) 110 | sbatch_output = subprocess.check_output(args).decode("utf8") 111 | lines = sbatch_output.split("\n") 112 | 113 | lines = [line for line in lines if "Submitted" in line] 114 | if len(lines) == 0: 115 | raise ValueError(f"slurm sbatch failed: {sbatch_output}") 116 | 117 | parsed_sbatch = lines[0].split(" ") 118 | job_id = parsed_sbatch[3].strip() 119 | return job_id 120 | 121 | def _write_json_worker_args(self, worker_args, cache_path): 122 | """write the worker args to a json file""" 123 | worker_args_path = os.path.join(cache_path, f"worker_args_{TIMESTAMP}.json") 124 | with open(worker_args_path, "w", encoding="utf-8") as worker_args_file: 125 | json.dump(worker_args, worker_args_file, indent=4) 126 | return worker_args_path 127 | 128 | def _generate_sbatch(self, cache_path, slurm_args, worker_args): 129 | """ 130 | Generate sbatch for a worker. 131 | 132 | sbatch: allows you to specify a configuration and task in a file 133 | - https://slurm.schedmd.com/sbatch.html 134 | """ 135 | # write the worker args to a file 136 | worker_args_path = self._write_json_worker_args(worker_args, cache_path) 137 | 138 | venv = os.environ["VIRTUAL_ENV"] 139 | scomment = ("--comment " + slurm_args["job_comment"]) if ["job_comment"] is not None else "" 140 | sbatch_scomment = ( 141 | ("#SBATCH --comment " + slurm_args["job_comment"]) if slurm_args["job_comment"] is not None else "" 142 | ) 143 | nodelist = ("#SBATCH --nodelist " + slurm_args["nodelist"]) if slurm_args["nodelist"] is not None else "" 144 | exclude = ("#SBATCH --exclude " + slurm_args["exclude"]) if slurm_args["exclude"] is not None else "" 145 | 146 | return f"""#!/bin/bash 147 | # Define sbatch config, use exclusive to capture all resources in each node 148 | #SBATCH --partition={slurm_args["partition"]} 149 | #SBATCH --job-name={slurm_args["job_name"]} 150 | #SBATCH --output={cache_path}/slurm-%x_%j.out 151 | #SBATCH --nodes={slurm_args["nodes"]} 152 | #SBATCH --ntasks-per-node=8 153 | #SBATCH --cpus-per-gpu=6 154 | #SBATCH --gres=gpu:8 155 | #SBATCH --exclusive 156 | 157 | {sbatch_scomment} 158 | {nodelist} 159 | {exclude} 160 | 161 | # Environment variables for the inner script 162 | export NUM_TASKS={self.num_tasks} 163 | export WORLD_SIZE={slurm_args["nodes"] * 8} # 8 gpus per node 164 | export WORKER_ARGS_PATH={worker_args_path} 165 | 166 | # Run the internal script 167 | source {venv}/bin/activate 168 | srun --cpu_bind=v --accel-bind=gn {scomment} clip-retrieval inference.slurm_worker 169 | """ 170 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/slurm_worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Read environment variables and distribute work work to this rank. 3 | This script is called by the slurm distributor. 4 | 5 | Will launch a worker. 6 | """ 7 | 8 | import os 9 | import json 10 | 11 | from torch.cuda import set_device 12 | 13 | from clip_retrieval.clip_inference.worker import worker 14 | 15 | 16 | def get_task_list(num_tasks, world_size, global_rank, local_rank): 17 | """Get the list of tasks to process.""" 18 | tasks_per_worker = num_tasks // world_size 19 | 20 | # Assign a subset of tasks to each worker 21 | start = global_rank * tasks_per_worker 22 | end = start + tasks_per_worker 23 | 24 | # If the tasks don't divide evenly 25 | # then we should redistribute the remainder 26 | if global_rank < num_tasks % world_size: 27 | start += global_rank 28 | end += global_rank + 1 29 | else: 30 | start += num_tasks % world_size 31 | end += num_tasks % world_size 32 | 33 | tasks = list(range(start, end)) 34 | 35 | print(f"worker global rank:{global_rank}\tlocal rank: {local_rank}\tprocessing tasks {tasks}") 36 | 37 | return tasks 38 | 39 | 40 | def slurm_worker(): 41 | """Distribute work to this job and launch a job.""" 42 | 43 | # Read environment variables 44 | # These are set by slurm_distributor or SLURM itself 45 | num_tasks = int(os.environ["NUM_TASKS"]) 46 | global_rank = int(os.environ["SLURM_PROCID"]) 47 | world_size = int(os.environ["WORLD_SIZE"]) 48 | local_rank = int(os.environ["SLURM_LOCALID"]) 49 | 50 | # Read the worker args from the file 51 | with open(os.environ["WORKER_ARGS_PATH"], "r", encoding="utf-8") as worker_args_file: 52 | worker_args = json.load(worker_args_file) 53 | 54 | # Find the range of tasks to process 55 | tasks = get_task_list(num_tasks, world_size, global_rank, local_rank) 56 | 57 | # set device 58 | set_device(local_rank) 59 | 60 | # Launch the worker 61 | worker(tasks, **worker_args) 62 | 63 | 64 | if __name__ == "__main__": 65 | slurm_worker() 66 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/worker.py: -------------------------------------------------------------------------------- 1 | """ 2 | Inference Worker: 3 | 4 | A completely independent process that will be started once for each GPU node. 5 | Distributors will call this either through the CLI or directly. 6 | 7 | The worker sequentially process the tasks passed to it. 8 | Tasks are lists of partition_id's that this worker will be responsible for. 9 | """ 10 | 11 | import fire 12 | from braceexpand import braceexpand 13 | 14 | from clip_retrieval.clip_inference.runner import Runner 15 | from clip_retrieval.clip_inference.mapper import ClipMapper 16 | from clip_retrieval.clip_inference.writer import NumpyWriter 17 | from clip_retrieval.clip_inference.logger import LoggerWriter 18 | from clip_retrieval.clip_inference.reader import FilesReader, WebdatasetReader 19 | from all_clip import load_clip 20 | 21 | 22 | def worker( 23 | tasks, 24 | input_dataset, 25 | output_folder, 26 | output_partition_count, 27 | input_format="files", 28 | cache_path=None, 29 | batch_size=256, 30 | num_prepro_workers=4, 31 | enable_text=True, 32 | enable_image=True, 33 | enable_metadata=False, 34 | wds_image_key="jpg", 35 | wds_caption_key="txt", 36 | clip_model="ViT-B/32", 37 | mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", 38 | use_mclip=False, 39 | use_jit=True, 40 | clip_cache_path=None, 41 | ): 42 | """Start a worker""" 43 | print("Starting the worker", flush=True) 44 | 45 | # check for brace expansion 46 | if input_format == "webdataset" and not isinstance(input_dataset, list): 47 | input_dataset = list(braceexpand(input_dataset)) 48 | 49 | print(f"dataset is {len(input_dataset)}", flush=True) 50 | 51 | def reader_builder(sampler): 52 | _, preprocess, tokenizer = load_clip( 53 | clip_model=clip_model, 54 | use_jit=use_jit, 55 | warmup_batch_size=batch_size, 56 | clip_cache_path=clip_cache_path, 57 | ) 58 | if input_format == "files": 59 | return FilesReader( 60 | sampler, 61 | preprocess, 62 | tokenizer, 63 | input_dataset, 64 | batch_size, 65 | num_prepro_workers, 66 | enable_text=enable_text, 67 | enable_image=enable_image, 68 | enable_metadata=enable_metadata, 69 | ) 70 | elif input_format == "webdataset": 71 | return WebdatasetReader( 72 | sampler, 73 | preprocess, 74 | tokenizer, 75 | input_dataset, 76 | batch_size, 77 | num_prepro_workers, 78 | enable_text=enable_text, 79 | enable_image=enable_image, 80 | enable_metadata=enable_metadata, 81 | wds_image_key=wds_image_key, 82 | wds_caption_key=wds_caption_key, 83 | cache_path=cache_path, 84 | ) 85 | else: 86 | raise ValueError(f"Unknown input_format: {input_format}") 87 | 88 | def mapper_builder(): 89 | return ClipMapper( 90 | enable_image=enable_image, 91 | enable_text=enable_text, 92 | enable_metadata=enable_metadata, 93 | use_mclip=use_mclip, 94 | clip_model=clip_model, 95 | use_jit=use_jit, 96 | mclip_model=mclip_model, 97 | clip_cache_path=clip_cache_path, 98 | warmup_batch_size=batch_size, 99 | ) 100 | 101 | def writer_builder(i): 102 | return NumpyWriter( 103 | partition_id=i, 104 | output_folder=output_folder, 105 | enable_text=enable_text, 106 | enable_image=enable_image, 107 | enable_metadata=enable_metadata, 108 | output_partition_count=output_partition_count, 109 | ) 110 | 111 | def logger_builder(i): 112 | return LoggerWriter( 113 | partition_id=i, 114 | stats_folder=output_folder + "/stats", 115 | ) 116 | 117 | runner = Runner( 118 | reader_builder=reader_builder, 119 | mapper_builder=mapper_builder, 120 | writer_builder=writer_builder, 121 | logger_builder=logger_builder, 122 | output_partition_count=output_partition_count, 123 | ) 124 | 125 | for task in tasks: 126 | print(f"Starting work on task {task}", flush=True) 127 | runner(task) 128 | 129 | 130 | if __name__ == "__main__": 131 | fire.Fire(worker) 132 | -------------------------------------------------------------------------------- /clip_retrieval/clip_inference/writer.py: -------------------------------------------------------------------------------- 1 | """writer module saves embeddings""" 2 | 3 | import fsspec 4 | from io import BytesIO 5 | import json 6 | import math 7 | 8 | 9 | class OutputSink: 10 | """This output sink can save image, text embeddings as npy and metadata as parquet""" 11 | 12 | def __init__(self, output_folder, enable_text, enable_image, enable_metadata, partition_id, output_partition_count): 13 | self.enable_text = enable_text 14 | self.enable_image = enable_image 15 | self.enable_metadata = enable_metadata 16 | self.fs, output_folder = fsspec.core.url_to_fs(output_folder) 17 | self.output_folder = output_folder 18 | self.img_emb_folder = output_folder + "/img_emb" 19 | self.text_emb_folder = output_folder + "/text_emb" 20 | self.metadata_folder = output_folder + "/metadata" 21 | self.batch_num = partition_id 22 | self.oom_partition_count = int(math.log10(output_partition_count)) + 1 23 | 24 | if enable_image: 25 | self.fs.makedirs(self.img_emb_folder, exist_ok=True) 26 | 27 | if enable_text: 28 | self.fs.makedirs(self.text_emb_folder, exist_ok=True) 29 | 30 | self.fs.makedirs(self.metadata_folder, exist_ok=True) 31 | 32 | self.batch_count = 0 33 | self.__init_batch() 34 | 35 | def __init_batch(self): 36 | self.image_embeddings = [] 37 | self.text_embeddings = [] 38 | self.image_names = [] 39 | self.captions = [] 40 | self.metadata = [] 41 | self.batch_count = 0 42 | 43 | def add(self, sample): 44 | """ 45 | add to buffers the image embeddings, text embeddings, and meta 46 | """ 47 | 48 | self.batch_count += sample["image_embs"].shape[0] if self.enable_image else sample["text_embs"].shape[0] 49 | if self.enable_image: 50 | self.image_embeddings.append(sample["image_embs"]) 51 | self.image_names.extend(sample["image_filename"]) 52 | if self.enable_text: 53 | self.captions.extend(sample["text"]) 54 | self.text_embeddings.append(sample["text_embs"]) 55 | if self.enable_metadata: 56 | self.metadata.extend(sample["metadata"]) 57 | 58 | def __write_batch(self): 59 | """ 60 | write a batch of embeddings and meta to npy and parquet 61 | """ 62 | import numpy as np # pylint: disable=import-outside-toplevel 63 | import pandas as pd # pylint: disable=import-outside-toplevel 64 | 65 | data_lists = [] 66 | data_columns = [] 67 | batch_num_str = str(self.batch_num).zfill(self.oom_partition_count) 68 | if self.enable_image: 69 | img_emb_mat = np.concatenate(self.image_embeddings) 70 | output_path_img = self.img_emb_folder + "/img_emb_" + batch_num_str 71 | 72 | with self.fs.open(output_path_img + ".npy", "wb") as f: 73 | npb = BytesIO() 74 | np.save(npb, img_emb_mat) 75 | f.write(npb.getbuffer()) 76 | 77 | data_lists.append(self.image_names) 78 | data_columns.append("image_path") 79 | 80 | if self.enable_text: 81 | text_emb_mat = np.concatenate(self.text_embeddings) 82 | output_path_text = self.text_emb_folder + "/text_emb_" + batch_num_str 83 | 84 | with self.fs.open(output_path_text + ".npy", "wb") as f: 85 | npb = BytesIO() 86 | np.save(npb, text_emb_mat) 87 | f.write(npb.getbuffer()) 88 | 89 | data_lists.append(self.captions) 90 | data_columns.append("caption") 91 | 92 | if self.enable_metadata: 93 | data_lists.append(self.metadata) 94 | data_columns.append("metadata") 95 | 96 | df = pd.DataFrame(data=list(zip(*data_lists)), columns=data_columns) 97 | if self.enable_metadata: 98 | parsed_metadata = pd.json_normalize(df["metadata"].apply(json.loads)) 99 | without_existing_columns = parsed_metadata.drop( 100 | columns=set(["caption", "metadata", "image_path"]) & set(parsed_metadata.keys()) 101 | ) 102 | df = df.join(without_existing_columns).drop(columns=["metadata"]) 103 | 104 | output_path_metadata = self.metadata_folder + "/metadata_" + batch_num_str + ".parquet" 105 | with self.fs.open(output_path_metadata, "wb") as f: 106 | df.to_parquet(f) 107 | 108 | def flush(self): 109 | if self.batch_count == 0: 110 | return 111 | self.__write_batch() 112 | self.__init_batch() 113 | 114 | 115 | class NumpyWriter: 116 | """the numpy writer writes embeddings to folders img_emb, text_emb, and metadata""" 117 | 118 | def __init__(self, partition_id, output_folder, enable_text, enable_image, enable_metadata, output_partition_count): 119 | self.sink = OutputSink( 120 | output_folder, enable_text, enable_image, enable_metadata, partition_id, output_partition_count 121 | ) 122 | 123 | def __call__(self, batch): 124 | self.sink.add(batch) 125 | 126 | def flush(self): 127 | self.sink.flush() 128 | -------------------------------------------------------------------------------- /clip_retrieval/h14_nsfw_model.py: -------------------------------------------------------------------------------- 1 | """Modeling & Loading code for H14 NSFW Detector""" 2 | 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | # pylint: disable=invalid-name 10 | class H14_NSFW_Detector(nn.Module): 11 | """An NSFW detector for H14 CLIP embeds""" 12 | 13 | def __init__(self, input_size=1024, cache_folder=os.path.expanduser("~/.cache/clip_retrieval")): 14 | super().__init__() 15 | self.input_size = input_size 16 | self.layers = nn.Sequential( 17 | nn.Linear(self.input_size, 1024), 18 | nn.ReLU(), 19 | nn.Dropout(0.2), 20 | nn.Linear(1024, 2048), 21 | nn.ReLU(), 22 | nn.Dropout(0.2), 23 | nn.Linear(2048, 1024), 24 | nn.ReLU(), 25 | nn.Dropout(0.2), 26 | nn.Linear(1024, 256), 27 | nn.ReLU(), 28 | nn.Dropout(0.2), 29 | nn.Linear(256, 128), 30 | nn.ReLU(), 31 | nn.Dropout(0.2), 32 | nn.Linear(128, 16), 33 | nn.Linear(16, 1), 34 | ) 35 | 36 | # Load the model from the cache folder 37 | self.load_state_dict(self.load_state(cache_folder)) 38 | self.eval() 39 | 40 | def forward(self, x): 41 | """Forward pass of the model""" 42 | return self.layers(x) 43 | 44 | # pylint: disable=unused-argument 45 | def predict(self, x, batch_size): 46 | """autokeras interface""" 47 | with torch.no_grad(): 48 | x = torch.from_numpy(x) 49 | y = self.layers(x) 50 | return y.detach().cpu().numpy() 51 | 52 | def load_state(self, cache_folder: str): 53 | """ 54 | Load the model from the cache folder 55 | If it does not exist, create it 56 | """ 57 | 58 | cache_subfolder = os.path.join(cache_folder, "h14_nsfw_model") 59 | if not os.path.exists(cache_subfolder): 60 | os.makedirs(cache_subfolder) 61 | 62 | model_path = os.path.join(cache_subfolder, "model.pt") 63 | if not os.path.exists(model_path): 64 | print("Downloading model...") 65 | import urllib.request # pylint: disable=import-outside-toplevel 66 | 67 | urllib.request.urlretrieve( 68 | "https://github.com/LAION-AI/CLIP-based-NSFW-Detector/raw/main/h14_nsfw.pth", model_path 69 | ) 70 | print("Downloaded model H14 NSFW model to:", model_path) 71 | 72 | return torch.load(model_path, map_location="cpu") 73 | -------------------------------------------------------------------------------- /clip_retrieval/ivf_metadata_ordering.py: -------------------------------------------------------------------------------- 1 | """ivf metadata ordering is a module to reorder a metadata collection by ivf clusters""" 2 | 3 | import os 4 | from tqdm import tqdm 5 | from pathlib import Path 6 | import numpy as np 7 | from collections import defaultdict 8 | import heapq 9 | import time 10 | import pandas as pd 11 | 12 | import pyarrow.parquet as pq 13 | import h5py 14 | import faiss 15 | 16 | 17 | def search_to_new_ids(index, query, k): 18 | """ 19 | this function maps the result ids to the ones ordered by the ivf clusters 20 | to be used along with a re-ordered metadata 21 | """ 22 | distances, indices = index.search(query, k) 23 | opq2 = faiss.downcast_VectorTransform(index.chain.at(0)) 24 | xq = opq2.apply(query) 25 | _, l = faiss.extract_index_ivf(index).quantizer.search(xq, faiss.extract_index_ivf(index).nprobe) 26 | il = faiss.extract_index_ivf(index).invlists 27 | list_sizes = [il.list_size(i) for i in range(il.nlist)] 28 | starting_offset = [] 29 | c = 0 30 | for i in list_sizes: 31 | starting_offset.append(c) 32 | c += i 33 | old_id_to_new_id = {} 34 | for i in l[0]: 35 | i = int(i) 36 | ids = il.get_ids(i) 37 | list_size = il.list_size(int(i)) 38 | items = faiss.rev_swig_ptr(ids, list_size) 39 | for nit, it in enumerate(items): 40 | old_id_to_new_id[it] = starting_offset[i] + nit 41 | il.release_ids(ids=ids, list_no=i) 42 | ids = np.array([old_id_to_new_id[i] if i != -1 else -1 for i in indices[0]]) 43 | return distances, ids 44 | 45 | 46 | def get_old_to_new_mapping(index): 47 | """ 48 | use an ivf index to compute a mapping from initial ids to ids ordered by clusters 49 | """ 50 | il = faiss.extract_index_ivf(index).invlists 51 | d = np.ones((index.ntotal,), "int64") 52 | begin_list = [] 53 | current_begin = 0 54 | for i in tqdm(range(il.nlist)): 55 | begin_list.append(current_begin) 56 | ids = il.get_ids(i) 57 | list_size = il.list_size(int(i)) 58 | items = faiss.rev_swig_ptr(ids, list_size) 59 | new_ids = range(current_begin, current_begin + list_size) 60 | d.put(np.array(items, "int"), np.array(new_ids, "int")) 61 | il.release_ids(ids=ids, list_no=i) 62 | current_begin += list_size 63 | 64 | return d 65 | 66 | 67 | def re_order_parquet(index, input_path, output_path, columns_to_return): 68 | """ 69 | use external sort to reorder parquet files 70 | """ 71 | d = get_old_to_new_mapping(index) 72 | data_dir = Path(input_path) 73 | if not os.path.exists(output_path): 74 | os.mkdir(output_path) 75 | current_offset = 0 76 | current_id = 0 77 | for parquet_files in tqdm(sorted(data_dir.glob("*.parquet"))): 78 | df = pd.read_parquet(parquet_files) 79 | df["new_id"] = d[current_offset : current_offset + len(df)] 80 | saved_df = df[columns_to_return + ["new_id"]] 81 | saved_df = saved_df.sort_values("new_id") 82 | saved_df.to_parquet(output_path + "/meta_" + str(current_id) + ".parquet") 83 | current_id += 1 84 | current_offset += len(df) 85 | 86 | 87 | class Hdf5Sink: 88 | """ 89 | A hdf5 sink: take as input rows and write them to hdf5 regularly 90 | """ 91 | 92 | def __init__(self, output_hdf5_file, keys): 93 | self.f = h5py.File(output_hdf5_file, "w") 94 | self.ds = self.f.create_group("dataset") 95 | self.buffer = [] 96 | self.keys = keys 97 | 98 | def write(self, sample): 99 | self.buffer.append(sample) 100 | if len(self.buffer) == 10**6: 101 | self._write_buffer() 102 | 103 | def end(self): 104 | self._write_buffer() 105 | self.f.close() 106 | 107 | def _write_buffer(self): 108 | """ 109 | Write a list of rows to hdf5 110 | """ 111 | if len(self.buffer) == 0: 112 | return 113 | df = pd.DataFrame(self.buffer, columns=self.keys) 114 | for k, v in df.items(): 115 | if k not in self.keys: 116 | continue 117 | col = v 118 | if col.dtype in ("float64", "float32"): 119 | col = col.fillna(0.0) 120 | if col.dtype in ("int64", "int32"): 121 | col = col.fillna(0) 122 | if col.dtype == "object": 123 | col = col.fillna("") 124 | z = col.to_numpy() 125 | if k not in self.ds: 126 | self.ds.create_dataset(k, data=z, maxshape=(None,), compression="gzip") 127 | else: 128 | prevlen = len(self.ds[k]) 129 | self.ds[k].resize((prevlen + len(z),)) 130 | self.ds[k][prevlen:] = z 131 | self.buffer = [] 132 | 133 | 134 | class DummySink: 135 | def __init__(self): 136 | pass 137 | 138 | def write(self, sample): 139 | pass 140 | 141 | def end(self): 142 | pass 143 | 144 | 145 | def external_sort_parquet(output_sink, input_path): 146 | """ 147 | create heap 148 | add to heap 1 batch of each file 149 | store in dict nb of item in heap for each file 150 | start getting from heap and pushing to sink 151 | when nb_item[last_retrieved] == 0 and there is some item left in this file, add a new batch of that file in heap 152 | """ 153 | 154 | h = [] 155 | data_dir = Path(input_path) 156 | files = [pq.ParquetFile(filename, memory_map=True) for filename in sorted(data_dir.glob("*.parquet"))] 157 | batches_list = [ffile.iter_batches(batch_size=10**4) for ffile in files] 158 | index_to_value = {} 159 | counts = [ffile.metadata.num_rows for ffile in files] 160 | current_count_per_file = defaultdict(lambda: 0) 161 | 162 | def read_batch(i): 163 | batch = next(batches_list[i]) 164 | current_count_per_file[i] += batch.num_rows 165 | df = batch.to_pandas() 166 | data = zip(df["new_id"], *[df[c] for c in [c for c in df.columns if c != "new_id"]]) 167 | for e in data: 168 | heapq.heappush(h, (e[0], i)) 169 | index_to_value[e[0]] = e[1:] 170 | 171 | for i in range(len(batches_list)): 172 | read_batch(i) 173 | 174 | done_count_per_file = defaultdict(lambda: 0) 175 | c = 0 176 | begin = time.time() 177 | while h: 178 | c += 1 179 | e, i = heapq.heappop(h) 180 | v = index_to_value[e] 181 | del index_to_value[e] 182 | output_sink.write(v) 183 | current_count_per_file[i] -= 1 184 | done_count_per_file[i] += 1 185 | if current_count_per_file[i] == 0 and done_count_per_file[i] < counts[i]: 186 | read_batch(i) 187 | if c % 100000 == 0: 188 | print(e, c, time.time() - begin, "s") 189 | 190 | output_sink.end() 191 | -------------------------------------------------------------------------------- /doc_assets/clip-back-grafana.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/doc_assets/clip-back-grafana.png -------------------------------------------------------------------------------- /doc_assets/clip-front-pic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/doc_assets/clip-front-pic.png -------------------------------------------------------------------------------- /doc_assets/grafana_dashboard.json: -------------------------------------------------------------------------------- 1 | { 2 | "__inputs": [], 3 | "__requires": [ 4 | { 5 | "type": "grafana", 6 | "id": "grafana", 7 | "name": "Grafana", 8 | "version": "8.0.6" 9 | }, 10 | { 11 | "type": "panel", 12 | "id": "timeseries", 13 | "name": "Time series", 14 | "version": "" 15 | } 16 | ], 17 | "annotations": { 18 | "list": [ 19 | { 20 | "builtIn": 1, 21 | "datasource": "-- Grafana --", 22 | "enable": true, 23 | "hide": true, 24 | "iconColor": "rgba(0, 211, 255, 1)", 25 | "name": "Annotations & Alerts", 26 | "type": "dashboard" 27 | } 28 | ] 29 | }, 30 | "editable": true, 31 | "gnetId": null, 32 | "graphTooltip": 0, 33 | "id": null, 34 | "links": [], 35 | "panels": [ 36 | { 37 | "datasource": null, 38 | "fieldConfig": { 39 | "defaults": { 40 | "color": { 41 | "mode": "palette-classic" 42 | }, 43 | "custom": { 44 | "axisLabel": "", 45 | "axisPlacement": "auto", 46 | "barAlignment": 0, 47 | "drawStyle": "line", 48 | "fillOpacity": 0, 49 | "gradientMode": "none", 50 | "hideFrom": { 51 | "legend": false, 52 | "tooltip": false, 53 | "viz": false 54 | }, 55 | "lineInterpolation": "linear", 56 | "lineWidth": 1, 57 | "pointSize": 5, 58 | "scaleDistribution": { 59 | "type": "linear" 60 | }, 61 | "showPoints": "auto", 62 | "spanNulls": false, 63 | "stacking": { 64 | "group": "A", 65 | "mode": "none" 66 | }, 67 | "thresholdsStyle": { 68 | "mode": "off" 69 | } 70 | }, 71 | "mappings": [], 72 | "thresholds": { 73 | "mode": "absolute", 74 | "steps": [ 75 | { 76 | "color": "green", 77 | "value": null 78 | }, 79 | { 80 | "color": "red", 81 | "value": 80 82 | } 83 | ] 84 | }, 85 | "unit": "s" 86 | }, 87 | "overrides": [] 88 | }, 89 | "gridPos": { 90 | "h": 9, 91 | "w": 12, 92 | "x": 0, 93 | "y": 0 94 | }, 95 | "id": 2, 96 | "options": { 97 | "legend": { 98 | "calcs": [], 99 | "displayMode": "list", 100 | "placement": "bottom" 101 | }, 102 | "tooltip": { 103 | "mode": "single" 104 | } 105 | }, 106 | "targets": [ 107 | { 108 | "exemplar": true, 109 | "expr": "full_knn_request_time_sum / full_knn_request_time_count", 110 | "interval": "", 111 | "legendFormat": "full", 112 | "refId": "A" 113 | }, 114 | { 115 | "exemplar": true, 116 | "expr": "metadata_get_time_sum / metadata_get_time_count", 117 | "hide": false, 118 | "interval": "", 119 | "legendFormat": "metadata", 120 | "refId": "B" 121 | }, 122 | { 123 | "exemplar": true, 124 | "expr": "image_clip_inference_time_sum / image_clip_inference_time_count", 125 | "hide": false, 126 | "interval": "", 127 | "legendFormat": "image clip inference", 128 | "refId": "C" 129 | }, 130 | { 131 | "exemplar": true, 132 | "expr": "text_clip_inference_time_sum / text_clip_inference_time_count", 133 | "hide": false, 134 | "interval": "", 135 | "legendFormat": "text clip inference", 136 | "refId": "D" 137 | }, 138 | { 139 | "exemplar": true, 140 | "expr": "download_time_sum / download_time_count", 141 | "hide": false, 142 | "interval": "", 143 | "legendFormat": "download time", 144 | "refId": "E" 145 | }, 146 | { 147 | "exemplar": true, 148 | "expr": "knn_index_time_sum / knn_index_time_count", 149 | "hide": false, 150 | "interval": "", 151 | "legendFormat": "knn index time", 152 | "refId": "F" 153 | }, 154 | { 155 | "exemplar": true, 156 | "expr": "image_prepro_time_sum / image_prepro_time_count", 157 | "hide": false, 158 | "interval": "", 159 | "legendFormat": "image prepro", 160 | "refId": "G" 161 | }, 162 | { 163 | "exemplar": true, 164 | "expr": "text_prepro_time_sum / text_prepro_time_count", 165 | "hide": false, 166 | "interval": "", 167 | "legendFormat": "text prepro", 168 | "refId": "H" 169 | } 170 | ], 171 | "title": "Average latencies", 172 | "type": "timeseries" 173 | }, 174 | { 175 | "datasource": null, 176 | "fieldConfig": { 177 | "defaults": { 178 | "color": { 179 | "mode": "palette-classic" 180 | }, 181 | "custom": { 182 | "axisLabel": "", 183 | "axisPlacement": "auto", 184 | "barAlignment": 0, 185 | "drawStyle": "line", 186 | "fillOpacity": 0, 187 | "gradientMode": "none", 188 | "hideFrom": { 189 | "legend": false, 190 | "tooltip": false, 191 | "viz": false 192 | }, 193 | "lineInterpolation": "linear", 194 | "lineWidth": 1, 195 | "pointSize": 5, 196 | "scaleDistribution": { 197 | "type": "linear" 198 | }, 199 | "showPoints": "auto", 200 | "spanNulls": false, 201 | "stacking": { 202 | "group": "A", 203 | "mode": "none" 204 | }, 205 | "thresholdsStyle": { 206 | "mode": "off" 207 | } 208 | }, 209 | "mappings": [], 210 | "thresholds": { 211 | "mode": "absolute", 212 | "steps": [ 213 | { 214 | "color": "green", 215 | "value": null 216 | }, 217 | { 218 | "color": "red", 219 | "value": 80 220 | } 221 | ] 222 | }, 223 | "unit": "none" 224 | }, 225 | "overrides": [] 226 | }, 227 | "gridPos": { 228 | "h": 9, 229 | "w": 12, 230 | "x": 12, 231 | "y": 0 232 | }, 233 | "id": 3, 234 | "options": { 235 | "legend": { 236 | "calcs": [], 237 | "displayMode": "list", 238 | "placement": "bottom" 239 | }, 240 | "tooltip": { 241 | "mode": "single" 242 | } 243 | }, 244 | "targets": [ 245 | { 246 | "exemplar": true, 247 | "expr": "increase(full_knn_request_time_count[$__range])\n", 248 | "interval": "", 249 | "legendFormat": "full", 250 | "refId": "A" 251 | }, 252 | { 253 | "exemplar": true, 254 | "expr": "increase(metadata_get_time_count[$__range])\n", 255 | "hide": false, 256 | "interval": "", 257 | "legendFormat": "metadata", 258 | "refId": "B" 259 | }, 260 | { 261 | "exemplar": true, 262 | "expr": "increase(image_clip_inference_time_count[$__range])", 263 | "hide": false, 264 | "interval": "", 265 | "legendFormat": "image clip inference", 266 | "refId": "C" 267 | }, 268 | { 269 | "exemplar": true, 270 | "expr": "increase(text_clip_inference_time_count[$__range])", 271 | "hide": false, 272 | "interval": "", 273 | "legendFormat": "text clip inference", 274 | "refId": "D" 275 | }, 276 | { 277 | "exemplar": true, 278 | "expr": "increase(download_time_count[$__range])", 279 | "hide": false, 280 | "interval": "", 281 | "legendFormat": "download time", 282 | "refId": "E" 283 | }, 284 | { 285 | "exemplar": true, 286 | "expr": "increase(knn_index_time_count[$__range])", 287 | "hide": false, 288 | "interval": "", 289 | "legendFormat": "knn index time", 290 | "refId": "F" 291 | }, 292 | { 293 | "exemplar": true, 294 | "expr": "increase(image_prepro_time_count[$__range])", 295 | "hide": false, 296 | "interval": "", 297 | "legendFormat": "image prepro", 298 | "refId": "G" 299 | }, 300 | { 301 | "exemplar": true, 302 | "expr": "increase(text_prepro_time_count[$__range])", 303 | "hide": false, 304 | "interval": "", 305 | "legendFormat": "text prepro", 306 | "refId": "H" 307 | } 308 | ], 309 | "title": "Request count", 310 | "type": "timeseries" 311 | } 312 | ], 313 | "schemaVersion": 30, 314 | "style": "dark", 315 | "tags": [], 316 | "templating": { 317 | "list": [] 318 | }, 319 | "time": { 320 | "from": "now-1h", 321 | "to": "now" 322 | }, 323 | "timepicker": {}, 324 | "timezone": "", 325 | "title": "Clip", 326 | "uid": "zF8DzpI7z", 327 | "version": 2 328 | } 329 | -------------------------------------------------------------------------------- /docs/distributed_clip_inference.md: -------------------------------------------------------------------------------- 1 | # distributed clip inference 2 | 3 | If you want to generate billion of clip embeddings, read this. 4 | 5 | This guide is about using pyspark to run clip inference in multiple node and using multiple gpus. 6 | 7 | you may also be interested by [distributed img2dataset](https://github.com/rom1504/img2dataset/blob/main/examples/distributed_img2dataset_tutorial.md) 8 | 9 | We will be assuming ubuntu 20.04. 10 | 11 | ## Setup the master node 12 | 13 | On the master node: 14 | 15 | First download spark: 16 | ```bash 17 | wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz 18 | tar xf spark-3.2.0-bin-hadoop3.2.tgz 19 | ``` 20 | 21 | Then download clip inference: 22 | ```bash 23 | rm -rf clip_retrieval.pex 24 | wget https://github.com/rom1504/clip-retrieval/releases/latest/download/clip_retrieval.tgz -O clip_retrieval.tgz 25 | wget https://github.com/rom1504/clip-retrieval/releases/latest/download/clip_retrieval_torch.tgz -O clip_retrieval_torch.tgz 26 | tar xf clip_retrieval.tgz 27 | tar xf clip_retrieval_torch.tgz 28 | ``` 29 | 30 | If the master node cannot open ports that are visible from your local machine, you can do a tunnel between your local machine and the master node to be able to see the spark ui (at http://localhost:8080) 31 | ```bash 32 | ssh -L 8080:localhost:8080 -L 4040:localhost:4040 master_node 33 | ``` 34 | 35 | 36 | ## Setup the worker nodes 37 | 38 | ### ssh basic setup 39 | 40 | Still in the master node, create a ips.txt with the ips of all the nodes 41 | 42 | ```bash 43 | ssh-keyscan `cat ips.txt` >> ~/.ssh/known_hosts 44 | ``` 45 | 46 | You may use a script like this to fill your .ssh/config file 47 | ``` 48 | def generate(ip): 49 | print( 50 | f"Host {ip}\n" 51 | f" HostName {ip}\n" 52 | " User ubuntu\n" 53 | " IdentityFile ~/yourkey.pem" 54 | ) 55 | 56 | with open("ips.txt") as f: 57 | lines = f.readlines() 58 | for line in lines: 59 | generate(line.strip()) 60 | ``` 61 | python3 generate.py >> ~/.ssh/config 62 | 63 | Install pssh with `sudo apt install pssh` 64 | 65 | Pick the right username (MASTER_USER) for the master node, and (USER) for the worker nodes, then run this to check your parallel ssh setup: 66 | ```bash 67 | USER=rom1504 68 | ``` 69 | 70 | Optionally, if another node than the current one has access to the worker nodes, you may need to add a ssh key to all the nodes with: 71 | ``` 72 | for IP in `cat ips.txt` 73 | do 74 | ssh-copy-id -i the_new_id_rsa $USER@$IP 75 | done 76 | ``` 77 | 78 | Check you can connect to all the nodes with: 79 | ``` 80 | parallel-ssh -l $USER -i -h ips.txt uname -a 81 | ``` 82 | 83 | ##### Install some packages 84 | 85 | ```bash 86 | parallel-ssh -l $USER -i -h ips.txt "sudo apt update" 87 | parallel-ssh -l $USER -i -h ips.txt "sudo apt install openjdk-11-jre-headless libgl1 htop tmux bwm-ng sshfs python3-distutils python3-apt python3.8 -y" 88 | ``` 89 | 90 | 91 | #### [Optional] Network setting on aws 92 | 93 | put in same VPC and security group and allow inbound 94 | 95 | ##### Download clip retrieval on all nodes 96 | 97 | Download clip retrieval on all node by retrying this N times until parallel ssh says success for all: 98 | ```bash 99 | 100 | parallel-ssh -i -h ips.txt "rm -rf clip_retrieval.pex" 101 | parallel-ssh -i -h ips.txt "wget https://github.com/rom1504/clip-retrieval/releases/latest/download/clip_retrieval.tgz -O clip_retrieval.tgz" 102 | parallel-ssh -i -h ips.txt "wget https://github.com/rom1504/clip-retrieval/releases/latest/download/clip_retrieval_torch.tgz -O clip_retrieval_torch.tgz" 103 | parallel-ssh -i -h ips.txt "tar xf clip_retrieval.tgz" 104 | parallel-ssh -i -h ips.txt "tar xf clip_retrieval_torch.tgz" 105 | ``` 106 | 107 | ##### Download spark on workers 108 | 109 | parallel-ssh -l $USER -i -h ips.txt "wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz" 110 | parallel-ssh -l $USER -i -h ips.txt "tar xf spark-3.2.0-bin-hadoop3.2.tgz" 111 | 112 | echo '[{"id":{"componentName": "spark.worker","resourceName":"gpu"},"addresses":["0","1","2","3","4","5","6","7"]}]' > gpufile 113 | parallel-scp -h ips.txt gpufile /home/ubuntu/gpufile 114 | 115 | #### Start the master node 116 | 117 | When you're ready, you can start the master node with: 118 | 119 | ```bash 120 | ./spark-3.2.0-bin-hadoop3.2/sbin/start-master.sh -p 7077 121 | ``` 122 | 123 | 124 | #### Start the worker nodes 125 | 126 | When you're ready, you can start the worker nodes with: 127 | 128 | ```bash 129 | parallel-ssh -l $USER -i -h ips.txt 'SPARK_WORKER_OPTS="-Dspark.worker.resource.gpu.amount=8 -Dspark.worker.resourcesFile=/home/ubuntu/gpufile" ./spark-3.2.0-bin-hadoop3.2/sbin/start-worker.sh -c 16 -m 24G "spark://172.31.44.42:7077"' 130 | ``` 131 | 132 | Replace 172.31.44.42 by the master node ip. 133 | 134 | 135 | #### Stop the worker nodes 136 | 137 | When you're done, you can stop the worker nodes with: 138 | 139 | ```bash 140 | parallel-ssh -l $USER -i -h ips.txt "rm -rf ~/spark-3.2.0-bin-hadoop3.2/work/*" 141 | parallel-ssh -l $USER -i -h ips.txt "pkill java" 142 | ``` 143 | 144 | #### Stop the master node 145 | 146 | When you're done, you can stop the master node with: 147 | 148 | ```bash 149 | pkill java 150 | ``` 151 | 152 | 153 | ### Running clip inference on it 154 | 155 | Once your spark cluster is setup, you're ready to start clip inference in distributed mode. 156 | Make sure to open your spark UI, at http://localhost:8080 (or the ip where the master node is running) 157 | 158 | Save this script to inference.py. 159 | 160 | Then run `./clip_retrieval.pex/__main__.py inference.py` 161 | 162 | ```python 163 | from clip_retrieval import clip_inference 164 | import shutil 165 | import os 166 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 167 | 168 | from pyspark import SparkConf, SparkContext 169 | 170 | def create_spark_session(): 171 | # this must be a path that is available on all worker nodes 172 | 173 | os.environ['PYSPARK_PYTHON'] = "/home/ubuntu/clip_retrieval.pex/__main__.py" 174 | spark = ( 175 | SparkSession.builder 176 | .config("spark.submit.deployMode", "client") \ 177 | .config("spark.executorEnv.PEX_ROOT", "./.pex") 178 | .config("spark.task.resource.gpu.amount", "1") 179 | .config("spark.executor.resource.gpu.amount", "8") 180 | #.config("spark.executor.cores", "16") 181 | #.config("spark.cores.max", "48") # you can reduce this number if you want to use only some cores ; if you're using yarn the option name is different, check spark doc 182 | .config("spark.driver.port", "5678") 183 | .config("spark.driver.blockManager.port", "6678") 184 | .config("spark.driver.host", "172.31.44.42") 185 | .config("spark.driver.bindAddress", "172.31.44.42") 186 | .config("spark.executor.memory", "16G") # make sure to increase this if you're using more cores per executor 187 | .config("spark.executor.memoryOverhead", "8G") 188 | .config("spark.task.maxFailures", "100") 189 | .master("spark://172.31.44.42:7077") # this should point to your master node, if using the tunnelling version, keep this to localhost 190 | .appName("spark-stats") 191 | .getOrCreate() 192 | ) 193 | return spark 194 | 195 | spark = create_spark_session() 196 | 197 | clip_inference(input_dataset="pipe:aws s3 cp --quiet s3://laion-us-east-1/laion-data/laion2B-data/{000000..231349}.tar -", output_folder="s3://laion-us-east-1/my_test_embedding2", input_format="webdataset", enable_metadata=True, write_batch_size=1000000, num_prepro_workers=8, batch_size=512, cache_path=None, enable_wandb=True, distribution_strategy="pyspark", clip_model="ViT-B/14") 198 | ``` 199 | 200 | ## Some benchmarks 201 | 202 | Using 1 node with 8 a100 on aws, using s3 as input and output: 203 | * 7000 sample/s on 8 a100 on vit-b / 32 : 2500 for one gpu so it's resizing bottlenecked 204 | * 7000 sample/s on 8 a100 on vit-b / 16 : 1100 sample/s for one gpu so it's still bottlenecked by resizing but much better 205 | * 2500 sample/s on 8 a100 on vit-l / 14 : 312 sample/s for one gpu so it's optimal 206 | 207 | on 4 such nodes, the speed are multiplied by 4 which is optimal. 208 | -------------------------------------------------------------------------------- /docs/laion5B_back.md: -------------------------------------------------------------------------------- 1 | # Run clip-retrieval back with laion5B index 2 | 3 | First step, download all the files from https://huggingface.co/datasets/laion/laion5B-index 4 | 5 | Second step, 6 | ```bash 7 | python3 -m venv .env 8 | source .env/bin/activate 9 | pip install clip-retrieval autokeras==1.0.18 keras==2.8.0 Keras-Preprocessing==1.1.2 tensorflow==2.8.0` 10 | ``` 11 | (autokeras is optional and needed only for safety filtering) 12 | 13 | Then put this 14 | ```json 15 | { 16 | "laion5B": { 17 | "indice_folder": "laion5B-index", 18 | "provide_safety_model": true, 19 | "enable_faiss_memory_mapping": true, 20 | "use_arrow": true, 21 | "enable_hdf5": false, 22 | "reorder_metadata_by_ivf_index": false, 23 | "columns_to_return": ["url", "caption"], 24 | "clip_model": "ViT-L/14", 25 | "enable_mclip_option": true 26 | } 27 | } 28 | ``` 29 | in indices.json file 30 | 31 | Finally run this 32 | ```python 33 | export CUDA_VISIBLE_DEVICES= 34 | clip-retrieval back --provide_violence_detector True --provide_safety_model True --clip_model="ViT-L/14" --default_backend="http://localhost:1234/" --port 1234 --indices-paths indices.json --use_arrow True --enable_faiss_memory_mapping True --columns_to_return='["url", "caption", "md5"]' 35 | ``` 36 | -------------------------------------------------------------------------------- /docs/laion5B_h14_back.md: -------------------------------------------------------------------------------- 1 | # How to setup clip-back with an H/14 index of Laion5B 2 | 3 | 1. Create a python virtual environment & install huggingface_hub & clip-retrieval 4 | - `pip install huggingface_hub clip-retrieval` 5 | 2. Install `aria2` on your system 6 | https://github.com/aria2/aria2 7 | 3. Navigate to your large storage 8 | - `cd /somehwere/with/lots/of/space` 9 | 4. Download the index parts from the hugging-face repository 10 | - `mkdir index-parts && cd index-parts` 11 | - `for i in {00..79}; do aria2c -x 16 https://huggingface.co/datasets/laion/laion5b-h14-index/resolve/main/index-parts/$i.index -o $i.index; done` 12 | - `cd ..` 13 | 5. Combine the index parts using the following command 14 | - `clip-retrieval index_combiner --input_folder "index-parts" --output_folder "combined-indices"` 15 | 6. Now download the metadata parts from the following metadata repos 16 | 17 | - ***multi embeddings*** 18 | - `mkdir multi-embeddings && cd multi-embeddings` 19 | - `for i in {0000..2268}; do aria2c -x 16 https://huggingface.co/datasets/laion/laion2b-multi-vit-h-14-embeddings/resolve/main/metadata/metadata_$i.parquet -o metadata_$i.parquet; done` 20 | - `cd ..` 21 | - ***english embeddings*** 22 | - `mkdir en-embeddings && cd en-embeddings` 23 | - `for i in {0000..2313}; do aria2c -x 16 https://huggingface.co/datasets/laion/laion2b-en-vit-h-14-embeddings/resolve/main/metadata/metadata_$i.parquet -o metadata_$i.parquet; done` 24 | - `cd ..` 25 | - ***nolang embeddings*** 26 | - `mkdir nolang-embeddings && nolang en-embeddings` 27 | - `for i in {0000..1273}; do aria2c -x 16 https://huggingface.co/datasets/laion/laion1b-nolang-vit-h-14-embeddings/resolve/main/metadata/metadata_$i.parquet -o metadata_$i.parquet; done` 28 | - `cd ..` 29 | 30 | 7. Now run the metadata combiner for each of the metadata folders (Warning: ensure all metadata parquet files are present before combining them, or the combined arrow file may be misaligned with the index) 31 | 32 | - ***multi embeddings*** 33 | - `clip-retrieval parquet_to_arrow --parquet_folder="multi-embeddings" --output_arrow_folder="multi-combined" --columns_to_return='["url", "caption"]'` 34 | - ***english embeddings*** 35 | - `clip-retrieval parquet_to_arrow --parquet_folder="en-embeddings" --output_arrow_folder="en-combined" --columns_to_return='["url", "caption"]'` 36 | - ***nolang embeddings*** 37 | - `clip-retrieval parquet_to_arrow --parquet_folder="nolang-embeddings" --output_arrow_folder="nolang-combined" --columns_to_return='["url", "caption"]'` 38 | 39 | 8. Create a parent directory to hold all of the index information 40 | - `mkdir Laion5B_H14 && mkdir Laion5B_H14/metadata && mkdir Laion5B_H14/image.index` 41 | 9. Move all of the metadata `arrow files` to the metadata subfolder of our new parent folder 42 | > **NOTE: in order to maintain the proper ordering, it is important to use the following file names** 43 | - `mv en-combined/0.arrow Laion5B_H14/metadata/0_en.arrow` 44 | - `mv multi-combined/0.arrow Laion5B_H14/metadata/1_multi.arrow` 45 | - `mv nolang-combined/0.arrow Laion5B_H14/metadata/2_nolang.arrow` 46 | 10. Move the files generated from the index combination step into the `image.index` subfolder 47 | - `mv combined-indices/* Laion5B_H14/image.index/` 48 | 11. Create an indices.json file with the following (edit as necessary, more info on parameters in the [Main README](https://github.com/rom1504/clip-retrieval#clip-back)) 49 | 50 | ``` 51 | { 52 | "laion5B-H-14": { 53 | "indice_folder": "Laion5B_H14", 54 | "provide_safety_model": true, 55 | "enable_faiss_memory_mapping": true, 56 | "use_arrow": true, 57 | "enable_hdf5": false, 58 | "reorder_metadata_by_ivf_index": false, 59 | "columns_to_return": ["url", "caption"], 60 | "clip_model": "open_clip:ViT-H-14", 61 | "enable_mclip_option": false 62 | } 63 | } 64 | ``` 65 | -------------------------------------------------------------------------------- /front/.gitignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | package-lock.json 3 | build -------------------------------------------------------------------------------- /front/.npmignore: -------------------------------------------------------------------------------- 1 | node_modules 2 | package-lock.json -------------------------------------------------------------------------------- /front/.npmrc: -------------------------------------------------------------------------------- 1 | package-lock=false -------------------------------------------------------------------------------- /front/README.md: -------------------------------------------------------------------------------- 1 | ## clip-retrieval-front 2 | [![NPM version](https://badge.fury.io/js/clip-retrieval-front.svg)](http://badge.fury.io/js/clip-retrieval-front) 3 | 4 | Easily compute clip embeddings and build a clip retrieval system with them. 100M text+image embeddings can be processed in 20h using a 3080. 5 | 6 | This is the front end for the [clip-retrieval](https://github.com/rom1504/clip-retrieval/) package. 7 | If you arrived straight here, first check [clip-retrieval README](https://github.com/rom1504/clip-retrieval/) 8 | 9 | You can use it at [clip-retrieval ui](https://rom1504.github.io/clip-retrieval/) 10 | 11 | Or you can run it yourself with: 12 | ``` 13 | npm install -g clip-retrieval-front 14 | clip-retrieval-front 3005 15 | ``` 16 | 17 | -------------------------------------------------------------------------------- /front/config.json: -------------------------------------------------------------------------------- 1 | { 2 | "defaultBackend": "https://knn.laion.ai", 3 | "defaultIndex": "laion5B-H-14" 4 | } 5 | -------------------------------------------------------------------------------- /front/package.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "clip-retrieval-front", 3 | "version": "2.31.1", 4 | "description": "Easily compute clip embeddings and build a clip retrieval system with them. 100M text+image embeddings can be processed in 20h using a 3080.", 5 | "main": "server.js", 6 | "dependencies": { 7 | "compression": "^1.7.4", 8 | "express": "^4.18.2" 9 | }, 10 | "repository": { 11 | "type": "git", 12 | "url": "https://github.com/rom1504/clip-retrieval.git" 13 | }, 14 | "bin": { 15 | "clip-retrieval-front": "./server.js" 16 | }, 17 | "devDependencies": { 18 | "@vaadin/vaadin-button": "^23.3.14", 19 | "@vaadin/vaadin-combo-box": "^23.3.14", 20 | "@vaadin/vaadin-select": "^23.3.14", 21 | "@webcomponents/webcomponentsjs": "^2.8.0", 22 | "bootstrap": "^5.2.3", 23 | "clean-webpack-plugin": "^4.0.0", 24 | "copy-webpack-plugin": "^11.0.0", 25 | "css-loader": "^6.8.1", 26 | "dateformat": "^5.0.3", 27 | "file-loader": "^6.2.0", 28 | "html-webpack-plugin": "^5.5.1", 29 | "http-server": "^14.1.1", 30 | "json-bigint": "^1.0.0", 31 | "lit-element": "^3.3.2", 32 | "lit-html": "^2.7.4", 33 | "raw-loader": "^4.0.2", 34 | "standard": "^17.0.0", 35 | "style-loader": "^3.3.3", 36 | "to-string-loader": "^1.2.0", 37 | "webpack": "^5.84.1", 38 | "webpack-cli": "^5.1.1", 39 | "webpack-dev-server": "^4.15.0", 40 | "webpack-merge": "^5.9.0" 41 | }, 42 | "scripts": { 43 | "build": "webpack --env production", 44 | "lint": "standard", 45 | "fix": "standard --fix", 46 | "prod-start": "npm run build && node server.js 3005", 47 | "start": "webpack-dev-server --env development", 48 | "test": "npm run lint" 49 | }, 50 | "author": "Romain Beaumont", 51 | "license": "MIT" 52 | } 53 | -------------------------------------------------------------------------------- /front/server.js: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env node 2 | 3 | const express = require('express') 4 | const compression = require('compression') 5 | const path = require('path') 6 | 7 | // Create our app 8 | const app = express() 9 | 10 | app.use(compression()) 11 | app.use(express.static(path.join(__dirname, './build'))) 12 | 13 | // Start the server 14 | const server = app.listen(process.argv[2] === undefined ? 8080 : process.argv[2], function () { 15 | console.log('Server listening on port ' + server.address().port) 16 | }) 17 | -------------------------------------------------------------------------------- /front/src/assets/download.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/front/src/assets/download.png -------------------------------------------------------------------------------- /front/src/assets/image-search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/front/src/assets/image-search.png -------------------------------------------------------------------------------- /front/src/assets/search.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/front/src/assets/search.png -------------------------------------------------------------------------------- /front/src/clip-front.js: -------------------------------------------------------------------------------- 1 | /* globals customElements, FileReader */ 2 | import { LitElement, html, css } from 'lit-element' 3 | import ClipService from './clip-service' 4 | 5 | class ClipFront extends LitElement { 6 | constructor () { 7 | super() 8 | window.fetch('config.json').then(res => res.json()).then(config => { 9 | this.defaultIndex = config.defaultIndex 10 | this.defaultBackend = config.defaultBackend 11 | this.urlColumn = config.urlColumn || 'url' 12 | this.init() 13 | }) 14 | } 15 | 16 | async init () { 17 | const urlParams = new URLSearchParams(window.location.search) 18 | const back = urlParams.get('back') 19 | const index = urlParams.get('index') 20 | const query = urlParams.get('query') 21 | const useMclip = urlParams.get('useMclip') 22 | const imageUrl = urlParams.get('imageUrl') 23 | if (index != null) { 24 | this.currentIndex = index 25 | } else { 26 | this.currentIndex = back === null || back === this.defaultBackend ? this.defaultIndex : '' 27 | } 28 | if (back != null) { 29 | this.backendHost = back 30 | } else { 31 | this.backendHost = this.defaultBackend 32 | } 33 | if (query != null) { 34 | this.text = query 35 | } else { 36 | this.text = '' 37 | } 38 | if (useMclip != null) { 39 | this.useMclip = useMclip === 'true' 40 | } else { 41 | this.useMclip = false 42 | } 43 | this.service = new ClipService(this.backendHost) 44 | this.numImages = 40 45 | this.numResultIds = 3000 46 | this.lastMetadataId = null 47 | this.onGoingMetadataFetch = false 48 | this.indices = [] 49 | this.images = [] 50 | this.modality = 'image' 51 | this.blacklist = {} 52 | this.lastSearch = 'text' 53 | this.displayCaptions = true 54 | this.displaySimilarities = false 55 | this.displayFullCaptions = false 56 | this.safeMode = true 57 | this.removeViolence = true 58 | this.firstLoad = true 59 | this.imageUrl = imageUrl === null ? undefined : imageUrl 60 | this.hideDuplicateUrls = true 61 | this.hideDuplicateImages = true 62 | this.aestheticScore = '' 63 | this.aestheticWeight = '0.5' 64 | await this.initIndices() 65 | this.postInit() 66 | } 67 | 68 | setBackendToDefault () { 69 | this.backendHost = this.defaultBackend 70 | this.initIndices(true) 71 | } 72 | 73 | async initIndices (forceChange) { 74 | await this.service.getIndices().then(l => { 75 | this.indices = l 76 | if (forceChange || this.currentIndex === '') { 77 | this.currentIndex = this.indices[0] 78 | } 79 | }).catch(e => { 80 | console.error(e) 81 | if (!forceChange) { 82 | this.setBackendToDefault() 83 | } 84 | }) 85 | } 86 | 87 | static get properties () { 88 | return { 89 | service: { type: Object }, 90 | images: { type: Array }, 91 | image: { type: String }, 92 | imageUrl: { type: String }, 93 | text: { type: String }, 94 | numImages: { type: Number }, 95 | modality: { type: String }, 96 | indices: { type: Array }, 97 | currentIndex: { type: String }, 98 | backendHost: { type: String }, 99 | blacklist: { type: Object }, 100 | displaySimilarities: { type: Boolean }, 101 | displayCaptions: { type: Boolean }, 102 | displayFullCaptions: { type: Boolean }, 103 | safeMode: { type: Boolean }, 104 | removeViolence: { type: Boolean }, 105 | hideDuplicateUrls: { type: Boolean }, 106 | hideDuplicateImages: { type: Boolean }, 107 | useMclip: { type: Boolean }, 108 | aestheticWeight: { type: String }, 109 | aestheticScore: { type: String } 110 | } 111 | } 112 | 113 | postInit () { 114 | const searchElem = this.shadowRoot.getElementById('searchBar') 115 | searchElem.addEventListener('keyup', e => { if (e.keyCode === 13) { this.textSearch() } }) 116 | const productsElement = this.shadowRoot.getElementById('products') 117 | window.onscroll = () => { 118 | if ((window.innerHeight + window.pageYOffset) >= productsElement.offsetHeight) { 119 | this.fetchMoreMetadata() 120 | } 121 | } 122 | } 123 | 124 | async initialScroll () { 125 | const productsElement = this.shadowRoot.getElementById('products') 126 | let i = 0 127 | while ((window.innerHeight + window.pageYOffset) >= productsElement.offsetHeight) { 128 | await this.fetchMoreMetadata() 129 | i += 1 130 | if (i > 5) { 131 | break 132 | } 133 | } 134 | } 135 | 136 | updated (_changedProperties) { 137 | if (_changedProperties.has('backendHost')) { 138 | this.service.backend = this.backendHost 139 | this.initIndices(!this.firstLoad) 140 | this.firstLoad = false 141 | this.setUrlParams() 142 | } 143 | if (_changedProperties.has('currentIndex')) { 144 | this.setUrlParams() 145 | } 146 | if (_changedProperties.has('image')) { 147 | if (this.image !== undefined) { 148 | this.imageSearch() 149 | return 150 | } 151 | } 152 | if (_changedProperties.has('imageUrl')) { 153 | if (this.imageUrl !== undefined) { 154 | this.imageUrlSearch() 155 | return 156 | } 157 | } 158 | if (_changedProperties.has('useMclip') || _changedProperties.has('modality') || _changedProperties.has('currentIndex') || 159 | _changedProperties.has('hideDuplicateUrls') || _changedProperties.has('hideDuplicateImages') || _changedProperties.has('safeMode') || 160 | _changedProperties.has('removeViolence') || _changedProperties.has('aestheticScore') || _changedProperties.has('aestheticWeight')) { 161 | if (this.image !== undefined || this.text !== '' || this.imageUrl !== undefined) { 162 | this.redoSearch() 163 | } 164 | } 165 | } 166 | 167 | async redoSearch () { 168 | if (this.lastSearch === 'text') { 169 | this.textSearch() 170 | } else if (this.lastSearch === 'image') { 171 | this.imageSearch() 172 | } else if (this.lastSearch === 'imageUrl') { 173 | this.imageUrlSearch() 174 | } 175 | } 176 | 177 | setUrlParams () { 178 | const urlParams = new URLSearchParams(window.location.search) 179 | if (this.text !== '') { 180 | urlParams.set('query', this.text) 181 | } else { 182 | urlParams.delete('query') 183 | } 184 | if (this.imageUrl !== undefined) { 185 | urlParams.set('imageUrl', this.imageUrl) 186 | } else { 187 | urlParams.delete('imageUrl') 188 | } 189 | urlParams.set('back', this.backendHost) 190 | urlParams.set('index', this.currentIndex) 191 | urlParams.set('useMclip', this.useMclip) 192 | window.history.pushState({}, '', '?' + urlParams.toString()) 193 | } 194 | 195 | async fetchMoreMetadata (amount = 40) { 196 | if (this.onGoingMetadataFetch) { 197 | return 198 | } 199 | this.onGoingMetadataFetch = true 200 | console.log('fetching more metadata starting from position', this.lastMetadataId) 201 | if (this.lastMetadataId === null) { 202 | this.onGoingMetadataFetch = false 203 | return 204 | } 205 | amount = Math.min(amount, this.numResultIds - this.lastMetadataId - 1) 206 | if (amount <= 0) { 207 | this.onGoingMetadataFetch = false 208 | return 209 | } 210 | const ids = this.images.slice(this.lastMetadataId + 1, this.lastMetadataId + amount + 1).map(i => i.id) 211 | try { 212 | const metasWithIds = Object.fromEntries((await this.service.getMetadata(ids, this.currentIndex)).map(({ id, metadata }) => [id, metadata])) 213 | this.images = this.images.map(image => { 214 | if (metasWithIds[image.id] !== undefined) { 215 | image = { ...metasWithIds[image.id], ...image } 216 | } 217 | return image 218 | }) 219 | this.lastMetadataId += amount 220 | } catch (e) { 221 | console.log(e) 222 | } 223 | this.onGoingMetadataFetch = false 224 | } 225 | 226 | callClip (overrideCount = null) { 227 | const text = this.text === undefined ? null : this.text 228 | const image = this.image === undefined ? null : this.image 229 | const imageUrl = this.imageUrl === undefined ? null : this.imageUrl 230 | const numImages = overrideCount === null ? this.numImages : overrideCount 231 | const numResultIds = overrideCount === null ? this.numResultIds : overrideCount 232 | return this.service.callClipService(text, image, imageUrl, null, this.modality, numImages, 233 | this.currentIndex, numResultIds, this.useMclip, this.hideDuplicateImages, this.safeMode, this.removeViolence, this.aestheticScore, this.aestheticWeight) 234 | } 235 | 236 | async download () { 237 | function downloadFile (filename, text) { 238 | const element = document.createElement('a') 239 | element.setAttribute('href', 'data:application/json;charset=utf-8,' + encodeURIComponent(text)) 240 | element.setAttribute('download', filename) 241 | 242 | element.style.display = 'none' 243 | document.body.appendChild(element) 244 | 245 | element.click() 246 | 247 | document.body.removeChild(element) 248 | } 249 | const count = this.modality === 'image' && this.currentIndex === this.indices[0] ? 10000 : 100 250 | const results = await this.callClip(count) 251 | downloadFile('clipsubset.json', JSON.stringify(results, null, 2)) 252 | } 253 | 254 | async textSearch () { 255 | if (this.text === '') { 256 | return 257 | } 258 | this.image = undefined 259 | this.imageUrl = undefined 260 | const results = await this.callClip() 261 | console.log(results) 262 | this.images = results 263 | this.lastMetadataId = Math.min(this.numImages, results.length) - 1 264 | this.lastSearch = 'text' 265 | this.setUrlParams() 266 | setTimeout(() => this.initialScroll(), 0) 267 | } 268 | 269 | async imageSearch () { 270 | this.text = '' 271 | this.imageUrl = undefined 272 | const results = await this.callClip() 273 | console.log(results) 274 | this.images = results 275 | this.lastMetadataId = Math.min(this.numImages, results.length) - 1 276 | this.lastSearch = 'image' 277 | this.setUrlParams() 278 | setTimeout(() => this.initialScroll(), 0) 279 | } 280 | 281 | async imageUrlSearch () { 282 | this.text = '' 283 | this.image = undefined 284 | const results = await this.callClip() 285 | console.log(results) 286 | this.images = results 287 | this.lastMetadataId = Math.min(this.numImages, results.length) - 1 288 | this.lastSearch = 'imageUrl' 289 | this.setUrlParams() 290 | setTimeout(() => this.initialScroll(), 0) 291 | } 292 | 293 | static get styles () { 294 | return css` 295 | input:-webkit-autofill, 296 | input:-webkit-autofill:hover, 297 | input:-webkit-autofill:focus, 298 | input:-webkit-autofill:active { 299 | -webkit-transition: "color 9999s ease-out, background-color 9999s ease-out"; 300 | -webkit-transition-delay: 9999s; 301 | } 302 | 303 | figcaption { 304 | display: table-caption; 305 | caption-side: bottom; 306 | background: #fff; 307 | padding: 0 0px 0px; 308 | } 309 | 310 | #searchBar, #searchBar:hover, #searchBar:focus, #searchBar:valid { 311 | border-radius: 25px; 312 | border-color: #ddd; 313 | background-color:white; 314 | border-width:1px; 315 | width:85%; 316 | padding:15px; 317 | outline: none; 318 | border-style: solid; 319 | font-size:16px; 320 | font-family:arial, sans-serif; 321 | } 322 | #searchBar:hover, #searchBar:focus { 323 | box-shadow: 0px 0px 7px #ccc; 324 | } 325 | #all { 326 | margin-left:2%; 327 | margin-right:2%; 328 | margin-top:2%; 329 | } 330 | #inputSearchBar:hover > #searchBar { 331 | box-shadow: 0px 0px 7px #ccc !important; 332 | } 333 | #download { 334 | width: 22px; 335 | margin-left:0.5%; 336 | vertical-align:middle; 337 | cursor:pointer; 338 | } 339 | #imageSearch { 340 | width: 22px; 341 | margin-left:0.5%; 342 | vertical-align:middle; 343 | cursor:pointer; 344 | } 345 | #textSearch { 346 | width: 22px; 347 | margin-left:1.5%; 348 | vertical-align:middle; 349 | cursor:pointer; 350 | } 351 | .subImageSearch { 352 | width: 22px; 353 | height: 22px: 354 | cursor:pointer; 355 | float:right; 356 | z-index:90; 357 | display:None; 358 | } 359 | .subTextSearch { 360 | width: 22px; 361 | height: 22px: 362 | cursor:pointer; 363 | margin-left:5%; 364 | margin-right:5%; 365 | float:right; 366 | z-index:90; 367 | display:None; 368 | } 369 | figure:hover > .subImageSearch { 370 | display:inline; 371 | cursor:pointer; 372 | } 373 | figure:hover > .subTextSearch { 374 | display:inline; 375 | cursor:pointer; 376 | } 377 | #products { 378 | margin-top:50px; 379 | width:85%; 380 | float:right; 381 | display: inline-grid; 382 | } 383 | @media (min-width: 500px) { 384 | #products { 385 | grid-template-columns: repeat(2, 1fr); 386 | } 387 | } 388 | 389 | @media (min-width: 700px) { 390 | #products{ 391 | grid-template-columns: repeat(4, 1fr); 392 | } 393 | } 394 | 395 | @media (min-width: 1000px) { 396 | #products { 397 | grid-template-columns: repeat(5, 1fr); 398 | } 399 | } 400 | 401 | @media (min-width: 1300px) { 402 | #products { 403 | grid-template-columns: repeat(7, 1fr); 404 | } 405 | } 406 | 407 | @media (min-width: 1600px) { 408 | #products{ 409 | grid-template-columns: repeat(8, 1fr); 410 | } 411 | } 412 | #filter { 413 | position:absolute; 414 | top:20px; 415 | width:12%; 416 | float:left; 417 | } 418 | #searchLine { 419 | margin-left:15%; 420 | } 421 | 422 | figcaption { 423 | font-size:16px; 424 | } 425 | 426 | figure,img.pic,figcaption { 427 | width:150px; 428 | } 429 | 430 | @media (max-width: 500px) { 431 | 432 | #searchBar, #searchBar:hover, #searchBar:focus, #searchBar:valid { 433 | width:60%; 434 | } 435 | #filter { 436 | font-size:14px; 437 | width:100px; 438 | } 439 | 440 | #products { 441 | grid-template-columns: repeat(3, 1fr); 442 | } 443 | figure,img.pic,figcaption { 444 | width:70px; 445 | } 446 | #searchLine { 447 | margin-left:100px; 448 | } 449 | 450 | figcaption { 451 | font-size:12px; 452 | } 453 | 454 | #products { 455 | width:60%; 456 | } 457 | } 458 | 459 | ` 460 | } 461 | 462 | updateImage (file) { 463 | const reader = new FileReader() 464 | reader.readAsDataURL(file) 465 | reader.onload = () => { 466 | this.image = reader.result.split(',')[1] 467 | } 468 | reader.onerror = (error) => { 469 | console.log('Error: ', error) 470 | } 471 | } 472 | 473 | renderImage (image) { 474 | let src 475 | if (image.image !== undefined) { 476 | src = `data:image/png;base64, ${image.image}` 477 | } 478 | if (image[this.urlColumn] !== undefined) { 479 | src = image[this.urlColumn] 480 | } 481 | /* 482 | // useful for testing broken images 483 | const hashCode = s => s.split('').reduce((a,b)=>{a=((a<<5)-a)+b.charCodeAt(0);return a&a},0) 484 | 485 | const disp = hashCode(src) % 2 == 0 486 | src = (disp ? "" : "sss") +src 487 | */ 488 | return html` 489 |
491 | ${this.displaySimilarities ? html`

${(image.similarity).toFixed(4)}

` : ''} 492 | ${image.caption !== undefined 493 | ? html` { this.text = image.caption; this.textSearch() }} />` 494 | : ''} 495 | 496 | { 497 | if (image.image !== undefined) { 498 | this.image = image.image 499 | } else if (image[this.urlColumn] !== undefined) { 500 | this.imageUrl = image[this.urlColumn] 501 | } 502 | }} /> 503 | ${image.caption !== undefined ? image.caption : ''} { this.blacklist = { ...this.blacklist, ...{ [src]: true } } }} /> 506 | 507 | ${this.displayCaptions 508 | ? html`
509 | ${image.caption !== undefined && image.caption.length > 50 && 510 | !this.displayFullCaptions 511 | ? image.caption.substr(0, 50) + '...' 512 | : image.caption}
` 513 | : ''} 514 | 515 | 516 |
517 | ` 518 | } 519 | 520 | filterDuplicateUrls (images) { 521 | const urls = {} 522 | return images.filter(image => { 523 | if (image[this.urlColumn] !== undefined) { 524 | if (urls[image[this.urlColumn]] === undefined) { 525 | urls[image[this.urlColumn]] = true 526 | return true 527 | } 528 | return false 529 | } 530 | return true 531 | }) 532 | } 533 | 534 | render () { 535 | if (this.images === undefined) { 536 | return html`
` 537 | } 538 | const preFiltered = this.images 539 | .filter(image => image.caption !== undefined || image[this.urlColumn] !== undefined || image.image !== undefined) 540 | const filteredImages = this.hideDuplicateUrls ? this.filterDuplicateUrls(preFiltered) : preFiltered 541 | 542 | return html` 543 |
544 |
545 | 546 | { this.text = e.target.value }}/> 547 | { this.textSearch() }} /> 548 | { this.shadowRoot.getElementById('filechooser').click() }} /> 549 | { this.download() }} /> 550 | 551 | this.updateImage(this.shadowRoot.getElementById('filechooser').files[0])}> 552 | 553 | 554 |
555 |
556 | Backend url:
{ this.backendHost = e.target.value }}/>
557 | Index:

559 | ${this.image !== undefined ? html`
` : ''} 560 | ${this.imageUrl !== undefined ? html`
` : ''} 561 | Clip retrieval works by converting the text query to a CLIP embedding 562 | , then using that embedding to query a knn index of clip image embedddings

563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
573 |
574 |
580 | 581 |
582 | ${filteredImages.map(image => this.renderImage(image))} 583 | ${this.safeMode && this.images.length !== 0 && filteredImages.length === 0 ? 'Displaying only nice pictures in safe mode!' : ''} 584 |
585 |
586 | ` 587 | } 588 | } 589 | 590 | customElements.define('clip-front', ClipFront) 591 | -------------------------------------------------------------------------------- /front/src/clip-service.js: -------------------------------------------------------------------------------- 1 | /* globals fetch */ 2 | 3 | import JsonBigint from 'json-bigint' 4 | 5 | export default class ClipService { 6 | constructor (backend) { 7 | this.backend = backend 8 | } 9 | 10 | async getIndices () { 11 | const result = JsonBigint.parse(await (await fetch(this.backend + '/indices-list', { 12 | })).text()) 13 | 14 | return result 15 | } 16 | 17 | async callClipService (text, image, imageUrl, embeddingInput, modality, numImages, indexName, numResultIds, useMclip, hideDuplicateImages, useSafetyModel, useViolenceDetector, aestheticScore, aestheticWeight) { 18 | console.log('calling', text, numImages) 19 | const result = JsonBigint.parse(await (await fetch(this.backend + '/knn-service', { 20 | method: 'POST', 21 | body: JSON.stringify({ 22 | text, 23 | image, 24 | image_url: imageUrl, 25 | embedding_input: embeddingInput, 26 | modality, 27 | num_images: numImages, 28 | indice_name: indexName, 29 | num_result_ids: numResultIds, 30 | use_mclip: useMclip, 31 | deduplicate: hideDuplicateImages, 32 | use_safety_model: useSafetyModel, 33 | use_violence_detector: useViolenceDetector, 34 | aesthetic_score: aestheticScore, 35 | aesthetic_weight: aestheticWeight 36 | }) 37 | })).text()) 38 | 39 | return result 40 | } 41 | 42 | async getMetadata (ids, indexName) { 43 | const result = JsonBigint.parse(await (await fetch(this.backend + '/metadata', { 44 | method: 'POST', 45 | body: JSON.stringify({ 46 | ids, 47 | indice_name: indexName 48 | }) 49 | })).text()) 50 | 51 | return result 52 | } 53 | } 54 | -------------------------------------------------------------------------------- /front/src/index.html: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | Clip front 9 | 10 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | -------------------------------------------------------------------------------- /front/webpack.config.js: -------------------------------------------------------------------------------- 1 | 'use strict' 2 | 3 | const { resolve, join } = require('path') 4 | const { merge } = require('webpack-merge') 5 | const CopyWebpackPlugin = require('copy-webpack-plugin') 6 | const HtmlWebpackPlugin = require('html-webpack-plugin') 7 | const { CleanWebpackPlugin } = require('clean-webpack-plugin') 8 | 9 | const ENV = process.argv.find(arg => arg.includes('production')) 10 | ? 'production' 11 | : 'development' 12 | const OUTPUT_PATH = ENV === 'production' ? resolve('build') : resolve('src') 13 | const INDEX_TEMPLATE = resolve('./src/index.html') 14 | 15 | const webcomponentsjs = './node_modules/@webcomponents/webcomponentsjs' 16 | 17 | const assets = [ 18 | { 19 | from: resolve('./src/assets'), 20 | to: join('assets') 21 | }, 22 | { 23 | from: resolve('./config.json'), 24 | to: join('config.json') 25 | } 26 | ] 27 | 28 | const polyfills = [ 29 | { 30 | from: resolve(`${webcomponentsjs}/webcomponents-*.js`), 31 | to: join(OUTPUT_PATH, 'vendor') 32 | }, 33 | { 34 | from: resolve(`${webcomponentsjs}/bundles/*.js`), 35 | to: join(OUTPUT_PATH, 'vendor', 'bundles') 36 | }, 37 | { 38 | from: resolve(`${webcomponentsjs}/custom-elements-es5-adapter.js`), 39 | to: join(OUTPUT_PATH, 'vendor') 40 | } 41 | ] 42 | 43 | const commonConfig = merge([ 44 | { 45 | entry: './src/clip-front.js', 46 | output: { 47 | path: OUTPUT_PATH, 48 | filename: '[name].[chunkhash:8].js' 49 | }, 50 | module: { 51 | rules: [ 52 | { 53 | test: /\.css$/, 54 | use: [ 55 | 'to-string-loader', 56 | 'css-loader' 57 | ] 58 | }, 59 | { 60 | test: /\.png|\.gif|\.txt$/, 61 | use: [ 62 | { 63 | loader: 'file-loader' 64 | } 65 | ] 66 | } 67 | ] 68 | }, 69 | resolve: { 70 | extensions: ['.js', '.jsx'] 71 | } 72 | } 73 | ]) 74 | 75 | const developmentConfig = merge([ 76 | { 77 | devtool: 'cheap-module-source-map', 78 | plugins: [ 79 | new CopyWebpackPlugin({ patterns: [...polyfills, ...assets] }), 80 | new HtmlWebpackPlugin({ 81 | template: INDEX_TEMPLATE 82 | }) 83 | ], 84 | 85 | devServer: { 86 | compress: true, 87 | port: 3005, 88 | historyApiFallback: true, 89 | host: '0.0.0.0' 90 | } 91 | } 92 | ]) 93 | 94 | const productionConfig = merge([ 95 | { 96 | devtool: 'nosources-source-map', 97 | plugins: [ 98 | new CleanWebpackPlugin(), 99 | new CopyWebpackPlugin({ patterns: [...polyfills, ...assets] }), 100 | new HtmlWebpackPlugin({ 101 | template: INDEX_TEMPLATE, 102 | filename: 'index.html', 103 | minify: { 104 | collapseWhitespace: true, 105 | removeComments: true, 106 | minifyCSS: true, 107 | minifyJS: true 108 | } 109 | }) 110 | ] 111 | } 112 | ]) 113 | 114 | module.exports = mode => { 115 | if (mode.production) { 116 | return merge(commonConfig, productionConfig, { mode: 'production' }) 117 | } 118 | const config = merge(commonConfig, developmentConfig, { mode: 'development' }) 119 | 120 | return config 121 | } 122 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | # Global options: 2 | 3 | [mypy] 4 | python_version = 3.8 5 | ignore_missing_imports = True -------------------------------------------------------------------------------- /notebook/clip-retrieval-getting-started.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 2, 4 | "metadata": { 5 | "language_info": { 6 | "codemirror_mode": { 7 | "name": "ipython", 8 | "version": 3 9 | }, 10 | "file_extension": ".py", 11 | "mimetype": "text/x-python", 12 | "name": "python", 13 | "nbconvert_exporter": "python", 14 | "pygments_lexer": "ipython3", 15 | "version": "3.8.10" 16 | }, 17 | "orig_nbformat": 2, 18 | "kernelspec": { 19 | "name": "python3", 20 | "display_name": "Python 3.8.10 64-bit ('.env': venv)" 21 | }, 22 | "colab": { 23 | "name": "clip-retrieval-getting-started.ipynb", 24 | "provenance": [] 25 | }, 26 | "interpreter": { 27 | "hash": "843de08df30066b821f0437d83317f7e657c9d58c210bb967a72474dd7dcb832" 28 | } 29 | }, 30 | "cells": [ 31 | { 32 | "cell_type": "markdown", 33 | "source": [ 34 | "## Install" 35 | ], 36 | "metadata": { 37 | "id": "uT9FwUjk_lRD" 38 | } 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": null, 43 | "source": [ 44 | "!pip install clip-retrieval img2dataset" 45 | ], 46 | "outputs": [], 47 | "metadata": { 48 | "id": "LIJwsAPIjvnX", 49 | "outputId": "a8ea45f0-0cf7-4b5c-836a-c7c4f5bbed1f", 50 | "colab": { 51 | "base_uri": "https://localhost:8080/" 52 | } 53 | } 54 | }, 55 | { 56 | "cell_type": "markdown", 57 | "source": [ 58 | "## Get some image urls" 59 | ], 60 | "metadata": { 61 | "id": "q5-9yk7y_qlW" 62 | } 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "source": [ 68 | "!echo 'https://placekitten.com/200/305' >> myimglist.txt\n", 69 | "!echo 'https://placekitten.com/200/304' >> myimglist.txt\n", 70 | "!echo 'https://placekitten.com/200/303' >> myimglist.txt" 71 | ], 72 | "outputs": [], 73 | "metadata": { 74 | "id": "SA89YmKtjvnX" 75 | } 76 | }, 77 | { 78 | "cell_type": "markdown", 79 | "source": [ 80 | "## Download the image urls" 81 | ], 82 | "metadata": { 83 | "id": "N8Tbn2Kl_t1N" 84 | } 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": null, 89 | "source": [ 90 | "!img2dataset --url_list=myimglist.txt --output_folder=image_folder --thread_count=64 --image_size=256" 91 | ], 92 | "outputs": [], 93 | "metadata": { 94 | "id": "BVZW6noqjvnY", 95 | "colab": { 96 | "base_uri": "https://localhost:8080/" 97 | }, 98 | "outputId": "484db3d4-249a-4d61-f2d8-5a0f1817a1b4" 99 | } 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "source": [ 104 | "## Produce embeddings" 105 | ], 106 | "metadata": { 107 | "id": "FMW4ncir_1Jt" 108 | } 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": null, 113 | "source": [ 114 | "!clip-retrieval inference --input_dataset image_folder --output_folder embedding_folder" 115 | ], 116 | "outputs": [], 117 | "metadata": { 118 | "id": "MvNr8NJRjvnZ", 119 | "colab": { 120 | "base_uri": "https://localhost:8080/" 121 | }, 122 | "outputId": "1484e573-c84d-4593-c7f6-a461b1d516ca" 123 | } 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "source": [ 129 | "!ls -R embedding_folder" 130 | ], 131 | "outputs": [], 132 | "metadata": { 133 | "id": "aBcl0APqjvnZ", 134 | "colab": { 135 | "base_uri": "https://localhost:8080/" 136 | }, 137 | "outputId": "14bda9f6-e687-43e7-d7ee-754048cc0c2e" 138 | } 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "source": [ 143 | "## Produce knn indices" 144 | ], 145 | "metadata": { 146 | "id": "Am62ARgs_3_e" 147 | } 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": null, 152 | "source": [ 153 | "!clip-retrieval index --embeddings_folder=embedding_folder --index_folder=index_folder" 154 | ], 155 | "outputs": [], 156 | "metadata": { 157 | "id": "xwzha2vY6OnP", 158 | "outputId": "63e1d45a-1c13-4f84-8a45-1c7196fb7eb6", 159 | "colab": { 160 | "base_uri": "https://localhost:8080/" 161 | } 162 | } 163 | }, 164 | { 165 | "cell_type": "markdown", 166 | "source": [ 167 | "## Use the index to get a subset of files" 168 | ], 169 | "metadata": { 170 | "id": "gL_X73OY_6TW" 171 | } 172 | }, 173 | { 174 | "cell_type": "code", 175 | "execution_count": null, 176 | "source": [ 177 | "!clip-retrieval filter --query \"cat\" --output_folder \"cat/\" --indice_folder \"index_folder\"" 178 | ], 179 | "outputs": [], 180 | "metadata": { 181 | "id": "COVo6tHQjvnZ", 182 | "colab": { 183 | "base_uri": "https://localhost:8080/" 184 | }, 185 | "outputId": "57c02131-5f3a-417b-fd53-36ef5fef1061" 186 | } 187 | }, 188 | { 189 | "cell_type": "code", 190 | "execution_count": null, 191 | "source": [ 192 | "!ls" 193 | ], 194 | "outputs": [], 195 | "metadata": { 196 | "id": "wmVuLCKmubsI", 197 | "colab": { 198 | "base_uri": "https://localhost:8080/" 199 | }, 200 | "outputId": "64547d42-80cc-45a0-d8c6-320f007a77c8" 201 | } 202 | }, 203 | { 204 | "cell_type": "code", 205 | "execution_count": null, 206 | "source": [ 207 | "ls -R cat" 208 | ], 209 | "outputs": [], 210 | "metadata": { 211 | "id": "KOdR2ybtjvna", 212 | "colab": { 213 | "base_uri": "https://localhost:8080/" 214 | }, 215 | "outputId": "219d38d5-f4b1-46b6-b178-c5da880431d0" 216 | } 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "source": [ 222 | "from IPython.display import Image\n", 223 | "Image(filename='cat/000000000.jpg') " 224 | ], 225 | "outputs": [], 226 | "metadata": { 227 | "id": "GHtA2Jlajvna", 228 | "colab": { 229 | "base_uri": "https://localhost:8080/", 230 | "height": 273 231 | }, 232 | "outputId": "3448f92e-d3de-48ae-8ba6-807d414d45fb" 233 | } 234 | }, 235 | { 236 | "cell_type": "markdown", 237 | "source": [ 238 | "## Run the knn service backend" 239 | ], 240 | "metadata": { 241 | "id": "tcvl9hog_-Lg" 242 | } 243 | }, 244 | { 245 | "cell_type": "code", 246 | "execution_count": null, 247 | "source": [ 248 | "%%bash\n", 249 | "echo '{\"example_index\": \"index_folder\"}' > indices_paths.json\n", 250 | "npm install -g localtunnel" 251 | ], 252 | "outputs": [], 253 | "metadata": { 254 | "id": "8mKtpVPi6jiZ", 255 | "outputId": "a7382021-c907-44ba-ecf1-4fd864c63089", 256 | "colab": { 257 | "base_uri": "https://localhost:8080/" 258 | } 259 | } 260 | }, 261 | { 262 | "cell_type": "code", 263 | "execution_count": null, 264 | "source": [ 265 | "# after running the next cell, visit the localtunnel url once then go to\n", 266 | "# https://rom1504.github.io/clip-retrieval/?back=" 267 | ], 268 | "outputs": [], 269 | "metadata": { 270 | "id": "cUCDh4cq7RgW" 271 | } 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": null, 276 | "source": [ 277 | "from threading import Thread\n", 278 | "\n", 279 | "def app():\n", 280 | " !clip-retrieval back --port 1234 --indices-paths indices_paths.json\n", 281 | "\n", 282 | "if __name__ == '__main__':\n", 283 | " t1 = Thread(target = app)\n", 284 | " a = t1.start()\n", 285 | " !lt --port 1234" 286 | ], 287 | "outputs": [], 288 | "metadata": { 289 | "id": "q6SaDruy6SOJ", 290 | "outputId": "a544d2c1-d3d0-4267-a12b-c41c3d599e1e", 291 | "colab": { 292 | "base_uri": "https://localhost:8080/" 293 | } 294 | } 295 | } 296 | ] 297 | } 298 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | 2 | [pytest] 3 | log_cli = 1 4 | log_cli_level = DEBUG -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | img2dataset 2 | black==23.12.1 3 | mypy==1.8.0 4 | pylint==3.0.3 5 | pytest-cov==4.1.0 6 | pytest-xdist==3.5.0 7 | pytest==7.4.4 8 | types-setuptools 9 | types-requests 10 | types-certifi 11 | pyspark 12 | deepsparse-nightly[clip] -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | img2dataset>=1.25.5,<2 2 | clip-anytorch>=2.5.0,<3 3 | tqdm>=4.62.3,<5 4 | fire>=0.4.0,<0.6.0 5 | torch>=1.7.1,<3 6 | torchvision>=0.10.1,<2 7 | numpy>=1.19.5,<2 8 | faiss-cpu>=1.7.2,<2 9 | flask>=3.0.0,<4 10 | flask_restful>=0.3.9,<1 11 | flask_cors>=4.0.0,<5 12 | pandas>=1.1.5,<3 13 | pyarrow>=6.0.1,<16 14 | autofaiss>=2.9.6,<3 15 | webdataset>=0.2,<0.3 16 | h5py>=3.1.0,<4 17 | prometheus-client>=0.13.1,<1 18 | fsspec 19 | sentence-transformers>=2.2.0,<3 20 | wandb>=0.12.0,<0.17 21 | open-clip-torch>=2.0.0,<3.0.0 22 | requests>=2.27.1,<3 23 | aiohttp>=3.8.1,<4 24 | multilingual-clip>=1.0.10,<2 25 | transformers 26 | urllib3<2 27 | scipy<1.13 28 | all_clip<2 29 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from pathlib import Path 3 | import os 4 | 5 | if __name__ == "__main__": 6 | with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: 7 | long_description = file.read() 8 | 9 | import os 10 | 11 | def package_files(directory): 12 | paths = [] 13 | for path, _, filenames in os.walk(directory): 14 | for filename in filenames: 15 | paths.append(os.path.join("..", path, filename)) 16 | return paths 17 | 18 | extra_files = package_files("front/build") 19 | 20 | def _read_reqs(relpath): 21 | fullpath = os.path.join(os.path.dirname(__file__), relpath) 22 | with open(fullpath) as f: 23 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 24 | 25 | REQUIREMENTS = _read_reqs("requirements.txt") 26 | 27 | setup( 28 | name="clip_retrieval", 29 | packages=find_packages(), 30 | package_data={"": extra_files}, 31 | include_package_data=True, 32 | version="2.44.0", 33 | license="MIT", 34 | description="Easily computing clip embeddings and building a clip retrieval system with them", 35 | long_description=long_description, 36 | long_description_content_type="text/markdown", 37 | entry_points={"console_scripts": ["clip-retrieval = clip_retrieval.cli:main"]}, 38 | author="Romain Beaumont", 39 | author_email="romain.rom1@gmail.com", 40 | url="https://github.com/rom1504/clip-retrieval", 41 | data_files=[ 42 | (".", ["README.md"]), 43 | ], 44 | keywords=["machine learning", "computer vision", "download", "image", "dataset"], 45 | install_requires=REQUIREMENTS, 46 | classifiers=[ 47 | "Development Status :: 4 - Beta", 48 | "Intended Audience :: Developers", 49 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 50 | "License :: OSI Approved :: MIT License", 51 | "Programming Language :: Python :: 3.8", 52 | ], 53 | ) 54 | -------------------------------------------------------------------------------- /tests/test_back.sh: -------------------------------------------------------------------------------- 1 | echo '{"example_index": "/tmp/my_index"}' > indices_paths.json 2 | clip-retrieval back --port 1234 --indices-paths indices_paths.json & 3 | FOO_PID=$! 4 | sleep 10 5 | curl -d '{"text":"cat", "modality":"image", "num_images": 10, "indice_name": "example_index"}' -H "Content-Type: application/json" -X POST http://localhost:1234/knn-service 6 | kill $FOO_PID 7 | -------------------------------------------------------------------------------- /tests/test_clip_client.py: -------------------------------------------------------------------------------- 1 | """Test the ClipClient class.""" 2 | import logging 3 | import pytest 4 | 5 | LOGGER = logging.getLogger(__name__) 6 | LOGGER.info("Test ClipClient.query()") 7 | from clip_retrieval.clip_client import ClipClient, Modality 8 | 9 | test_url = "https://placekitten.com/400/600" 10 | test_caption = "an image of a cat" 11 | test_image_1 = "tests/test_clip_inference/test_images/123_456.jpg" 12 | test_image_2 = "tests/test_clip_inference/test_images/416_264.jpg" 13 | 14 | knn_service_url = "https://knn.laion.ai/knn-service" 15 | 16 | 17 | # NOTE: This test may fail if the backend is down. 18 | @pytest.mark.skip(reason="temporarily skipping this test while laion knn is down") 19 | def test_query(): 20 | """ 21 | Test the ClipClient.query() method. 22 | """ 23 | # Create a client 24 | client = ClipClient( 25 | url=knn_service_url, 26 | indice_name="laion5B-L-14", 27 | use_mclip=False, 28 | aesthetic_score=9, 29 | aesthetic_weight=0.5, 30 | modality=Modality.IMAGE, 31 | num_images=40, 32 | ) 33 | 34 | # test search via text 35 | text_search_results = client.query(text=test_caption) 36 | assert len(text_search_results) > 0, "No results found" 37 | assert "url" in text_search_results[0], "url not found in search results" 38 | assert "caption" in text_search_results[0], "caption not found in search results" 39 | assert "similarity" in text_search_results[0], "similarity not found in search results" 40 | assert "id" in text_search_results[0], "id not found in search results" 41 | LOGGER.info(f"{len(text_search_results)} results found") 42 | LOGGER.info(text_search_results[0]) 43 | 44 | # test search via image 45 | image_search_results = client.query(image=test_image_1) 46 | assert len(image_search_results) > 0, "No results found" 47 | assert "url" in image_search_results[0], "url not found in search results" 48 | assert "caption" in image_search_results[0], "caption not found in search results" 49 | assert "similarity" in image_search_results[0], "similarity not found in search results" 50 | assert "id" in image_search_results[0], "id not found in search results" 51 | LOGGER.info(f"{len(image_search_results)} results found") 52 | LOGGER.info(image_search_results[0]) 53 | 54 | # test search via url of image 55 | image_url_search_results = client.query(image=test_url) 56 | assert len(image_url_search_results) > 0, "No results found" 57 | assert "url" in image_url_search_results[0], "url not found in search results" 58 | assert "caption" in image_url_search_results[0], "caption not found in search results" 59 | assert "similarity" in image_url_search_results[0], "similarity not found in search results" 60 | assert "id" in image_url_search_results[0], "id not found in search results" 61 | LOGGER.info(f"{len(image_url_search_results)} results found") 62 | LOGGER.info(image_url_search_results[0]) 63 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_distributor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import tempfile 4 | import pytest 5 | 6 | from clip_retrieval.clip_inference.distributor import SequentialDistributor, PysparkDistributor 7 | 8 | 9 | @pytest.mark.parametrize("distributor_kind", ["sequential", "pyspark"]) 10 | def test_distributor(distributor_kind): 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 12 | 13 | with tempfile.TemporaryDirectory() as tmpdir: 14 | current_folder = os.path.dirname(__file__) 15 | input_dataset = os.path.join(current_folder, "test_images") 16 | 17 | worker_args = { 18 | "input_dataset": input_dataset, 19 | "output_folder": tmpdir, 20 | "output_partition_count": 2, 21 | "num_prepro_workers": 6, 22 | "batch_size": 2, 23 | "enable_text": False, 24 | "enable_image": True, 25 | "enable_metadata": False, 26 | } 27 | 28 | tasks = [0, 1] 29 | 30 | if distributor_kind == "sequential": 31 | distributor = SequentialDistributor(tasks=tasks, worker_args=worker_args) 32 | elif distributor_kind == "pyspark": 33 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 34 | 35 | spark = ( 36 | SparkSession.builder.config("spark.driver.memory", "16G") 37 | .master("local[" + str(2) + "]") 38 | .appName("spark-stats") 39 | .getOrCreate() 40 | ) 41 | 42 | distributor = PysparkDistributor(tasks=tasks, worker_args=worker_args) 43 | 44 | distributor() 45 | 46 | with open(os.path.join(tmpdir, "img_emb/img_emb_0.npy"), "rb") as f: 47 | image_embs = np.load(f) 48 | assert image_embs.shape[0] == 4 49 | 50 | with open(os.path.join(tmpdir, "img_emb/img_emb_1.npy"), "rb") as f: 51 | image_embs = np.load(f) 52 | assert image_embs.shape[0] == 3 53 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_embeddings/0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_embeddings/0.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_embeddings/1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_embeddings/1.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_embeddings/2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_embeddings/2.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_embeddings/3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_embeddings/3.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_get_tasks.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | from clip_retrieval.clip_inference.slurm_worker import get_task_list 4 | 5 | 6 | def test_uneven_tasks(): 7 | """Test task distribution for an uneven distribution of tasks/workers.""" 8 | 9 | world_size = 3 10 | num_tasks = 11 11 | 12 | SOLUTION = { 13 | 0: [0, 1, 2, 3], 14 | 1: [4, 5, 6, 7], 15 | 2: [8, 9, 10], 16 | } 17 | 18 | # Test that the tasks are distributed as evenly as possible 19 | for global_rank in range(world_size): 20 | tasks = get_task_list(num_tasks=num_tasks, world_size=world_size, global_rank=global_rank, local_rank=-1) 21 | assert tasks == SOLUTION[global_rank] 22 | 23 | 24 | def test_even_tasks(): 25 | """Test task distribution for an even distribution of tasks/workers.""" 26 | 27 | world_size = 3 28 | num_tasks = 9 29 | 30 | SOLUTION = { 31 | 0: [0, 1, 2], 32 | 1: [3, 4, 5], 33 | 2: [6, 7, 8], 34 | } 35 | 36 | # Test that the tasks are distributed as evenly as possible 37 | for global_rank in range(world_size): 38 | tasks = get_task_list(num_tasks=num_tasks, world_size=world_size, global_rank=global_rank, local_rank=-1) 39 | assert tasks == SOLUTION[global_rank] 40 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/123_456.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/123_456.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/208_495.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/208_495.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/321_421.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/321_421.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/389_535.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/389_535.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/416_264.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/416_264.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/456_123.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/456_123.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_images/524_316.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_images/524_316.jpg -------------------------------------------------------------------------------- /tests/test_clip_inference/test_main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import pytest 4 | 5 | import tempfile 6 | from clip_retrieval.clip_inference.main import main 7 | 8 | 9 | def test_main(): 10 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 11 | current_folder = os.path.dirname(__file__) 12 | input_dataset = os.path.join(current_folder, "test_images") 13 | 14 | with tempfile.TemporaryDirectory() as tmpdir: 15 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 16 | 17 | spark = ( 18 | SparkSession.builder.config("spark.driver.memory", "16G") 19 | .master("local[" + str(2) + "]") 20 | .appName("spark-stats") 21 | .getOrCreate() 22 | ) 23 | 24 | main( 25 | input_dataset, 26 | output_folder=tmpdir, 27 | input_format="files", 28 | cache_path=None, 29 | batch_size=8, 30 | num_prepro_workers=8, 31 | enable_text=False, 32 | enable_image=True, 33 | enable_metadata=False, 34 | write_batch_size=4, 35 | wds_image_key="jpg", 36 | wds_caption_key="txt", 37 | clip_model="ViT-B/32", 38 | mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", 39 | use_mclip=False, 40 | use_jit=True, 41 | distribution_strategy="pyspark", 42 | wds_number_file_per_input_file=10000, 43 | output_partition_count=None, 44 | ) 45 | 46 | with open(tmpdir + "/img_emb/img_emb_0.npy", "rb") as f: 47 | image_embs = np.load(f) 48 | assert image_embs.shape[0] == 4 49 | 50 | with open(tmpdir + "/img_emb/img_emb_1.npy", "rb") as f: 51 | image_embs = np.load(f) 52 | assert image_embs.shape[0] == 3 53 | 54 | 55 | # python -m pytest -x -s -v tests -k "test_main_empty_input" 56 | def test_main_empty_input(): 57 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 58 | current_folder = os.path.dirname(__file__) 59 | input_dataset = os.path.join(current_folder, "test_images_empty") 60 | 61 | with tempfile.TemporaryDirectory() as tmpdir, pytest.raises(Exception) as exc_info: 62 | main( 63 | input_dataset, 64 | output_folder=tmpdir, 65 | input_format="files", 66 | cache_path=None, 67 | batch_size=8, 68 | num_prepro_workers=8, 69 | enable_text=False, 70 | enable_image=True, 71 | enable_metadata=False, 72 | write_batch_size=4, 73 | wds_image_key="jpg", 74 | wds_caption_key="txt", 75 | clip_model="ViT-B/32", 76 | mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", 77 | use_mclip=False, 78 | use_jit=True, 79 | distribution_strategy="sequential", 80 | wds_number_file_per_input_file=10000, 81 | output_partition_count=None, 82 | ) 83 | 84 | print(exc_info.traceback) 85 | 86 | assert exc_info.value.args[0] == "no sample found" 87 | assert str(exc_info.value) == "no sample found" 88 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_mapper.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import pickle 3 | import os 4 | import numpy as np 5 | 6 | from clip_retrieval.clip_inference.mapper import ClipMapper 7 | 8 | 9 | @pytest.mark.parametrize( 10 | "model", 11 | [ 12 | "ViT-B/32", 13 | "open_clip:ViT-B-32/laion2b_s34b_b79k", 14 | "hf_clip:patrickjohncyh/fashion-clip", 15 | "nm:mgoin/CLIP-ViT-B-32-laion2b_s34b_b79k-ds", 16 | ], 17 | ) 18 | def test_mapper(model): 19 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 20 | 21 | mapper = ClipMapper( 22 | enable_image=True, 23 | enable_text=False, 24 | enable_metadata=False, 25 | use_mclip=False, 26 | clip_model=model, 27 | use_jit=True, 28 | mclip_model="", 29 | ) 30 | 31 | current_dir = os.path.dirname(os.path.abspath(__file__)) 32 | tensor_files = [i for i in os.listdir(current_dir + "/test_tensors")] 33 | 34 | for tensor_file in tensor_files: 35 | with open(current_dir + "/test_tensors/{}".format(tensor_file), "rb") as f: 36 | tensor = pickle.load(f) 37 | sample = mapper(tensor) 38 | assert sample["image_embs"].shape[0] == tensor["image_tensor"].shape[0] 39 | assert sample["image_embs"].dtype == np.dtype("float16") 40 | pass 41 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_reader.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | from clip_retrieval.clip_inference.reader import FilesReader, WebdatasetReader 3 | from clip_retrieval.clip_inference.runner import Sampler 4 | import os 5 | 6 | from all_clip import load_clip 7 | 8 | 9 | @pytest.mark.parametrize("file_format", ["files", "webdataset"]) 10 | def test_reader(file_format): 11 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 12 | current_folder = os.path.dirname(__file__) 13 | if file_format == "files": 14 | input_dataset = current_folder + "/test_images" 15 | else: 16 | tar_folder = current_folder + "/test_tars" 17 | input_dataset = [tar_folder + "/image1.tar", tar_folder + "/image2.tar"] 18 | batch_size = 2 19 | num_prepro_workers = 2 20 | _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size) 21 | 22 | output_partition_count = 2 23 | actual_values = [] 24 | for output_partition_id in range(output_partition_count): 25 | sampler = Sampler(output_partition_id=output_partition_id, output_partition_count=output_partition_count) 26 | if file_format == "files": 27 | reader = FilesReader( 28 | sampler, 29 | preprocess, 30 | tokenizer, 31 | input_dataset, 32 | batch_size, 33 | num_prepro_workers, 34 | enable_text=False, 35 | enable_image=True, 36 | enable_metadata=False, 37 | ) 38 | elif file_format == "webdataset": 39 | reader = WebdatasetReader( 40 | sampler, 41 | preprocess, 42 | tokenizer, 43 | input_dataset, 44 | batch_size, 45 | num_prepro_workers, 46 | enable_text=False, 47 | enable_image=True, 48 | enable_metadata=False, 49 | ) 50 | vals = [i["image_tensor"].shape[0] for i in reader] 51 | actual_values.append(vals) 52 | 53 | assert actual_values == [[2, 2], [2, 1]] 54 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_runner.py: -------------------------------------------------------------------------------- 1 | from clip_retrieval.clip_inference.logger import LoggerWriter 2 | from clip_retrieval.clip_inference.runner import Runner 3 | from clip_retrieval.clip_inference.reader import FilesReader 4 | from clip_retrieval.clip_inference.mapper import ClipMapper 5 | from clip_retrieval.clip_inference.writer import NumpyWriter 6 | from all_clip import load_clip 7 | import os 8 | import numpy as np 9 | import tempfile 10 | 11 | 12 | def test_runner(): 13 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 14 | 15 | output_partition_count = 2 16 | num_prepro_workers = 8 17 | batch_size = 8 18 | current_folder = os.path.dirname(__file__) 19 | folder = current_folder + "/test_images" 20 | 21 | with tempfile.TemporaryDirectory() as tmpdir: 22 | 23 | def reader_builder(sampler): 24 | _, preprocess, tokenizer = load_clip(warmup_batch_size=batch_size) 25 | return FilesReader( 26 | sampler, 27 | preprocess, 28 | tokenizer, 29 | folder, 30 | batch_size, 31 | num_prepro_workers, 32 | enable_text=False, 33 | enable_image=True, 34 | enable_metadata=False, 35 | ) 36 | 37 | def mapper_builder(): 38 | return ClipMapper( 39 | enable_image=True, 40 | enable_text=False, 41 | enable_metadata=False, 42 | use_mclip=False, 43 | clip_model="ViT-B/32", 44 | use_jit=True, 45 | mclip_model="", 46 | warmup_batch_size=batch_size, 47 | ) 48 | 49 | def logger_builder(i): 50 | return LoggerWriter( 51 | partition_id=i, 52 | stats_folder=tmpdir + "/stats", 53 | ) 54 | 55 | def writer_builder(i): 56 | return NumpyWriter( 57 | partition_id=i, 58 | output_folder=tmpdir, 59 | enable_text=False, 60 | enable_image=True, 61 | enable_metadata=False, 62 | output_partition_count=output_partition_count, 63 | ) 64 | 65 | runner = Runner( 66 | reader_builder=reader_builder, 67 | mapper_builder=mapper_builder, 68 | writer_builder=writer_builder, 69 | logger_builder=logger_builder, 70 | output_partition_count=output_partition_count, 71 | ) 72 | 73 | runner(0) 74 | 75 | with open(tmpdir + "/img_emb/img_emb_0.npy", "rb") as f: 76 | image_embs = np.load(f) 77 | assert image_embs.shape[0] == 4 78 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tars/image1.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tars/image1.tar -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tars/image2.tar: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tars/image2.tar -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tensors/0.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tensors/0.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tensors/1.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tensors/1.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tensors/2.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tensors/2.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_tensors/3.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/clip-retrieval/ee0931f89c69cf2e39b5187d50a40873b7999d2b/tests/test_clip_inference/test_tensors/3.pkl -------------------------------------------------------------------------------- /tests/test_clip_inference/test_worker.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | 4 | import tempfile 5 | from clip_retrieval.clip_inference.worker import worker 6 | 7 | 8 | def test_worker(): 9 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 10 | current_folder = os.path.dirname(__file__) 11 | input_dataset = os.path.join(current_folder, "test_images") 12 | 13 | with tempfile.TemporaryDirectory() as tmpdir: 14 | worker( 15 | tasks=[0, 1], 16 | input_dataset=input_dataset, 17 | output_folder=tmpdir, 18 | input_format="files", 19 | output_partition_count=2, 20 | cache_path=None, 21 | batch_size=2, 22 | num_prepro_workers=6, 23 | enable_text=False, 24 | enable_image=True, 25 | enable_metadata=False, 26 | wds_image_key="jpg", 27 | wds_caption_key="txt", 28 | clip_model="ViT-B/32", 29 | mclip_model="sentence-transformers/clip-ViT-B-32-multilingual-v1", 30 | use_mclip=False, 31 | use_jit=True, 32 | clip_cache_path=None, 33 | ) 34 | 35 | with open(tmpdir + "/img_emb/img_emb_0.npy", "rb") as f: 36 | image_embs = np.load(f) 37 | assert image_embs.shape[0] == 4 38 | 39 | with open(tmpdir + "/img_emb/img_emb_1.npy", "rb") as f: 40 | image_embs = np.load(f) 41 | assert image_embs.shape[0] == 3 42 | -------------------------------------------------------------------------------- /tests/test_clip_inference/test_writer.py: -------------------------------------------------------------------------------- 1 | from clip_retrieval.clip_inference.writer import NumpyWriter 2 | import numpy as np 3 | import pickle 4 | import tempfile 5 | import os 6 | 7 | 8 | def test_writer(): 9 | with tempfile.TemporaryDirectory() as tmpdir: 10 | writer = NumpyWriter( 11 | partition_id=0, 12 | output_folder=tmpdir, 13 | enable_text=False, 14 | enable_image=True, 15 | enable_metadata=False, 16 | output_partition_count=1, 17 | ) 18 | current_folder = os.path.dirname(__file__) 19 | embedding_files = [i for i in os.listdir(current_folder + "/test_embeddings")] 20 | expected_shape = 0 21 | for embedding_file in embedding_files: 22 | with open(current_folder + "/test_embeddings/{}".format(embedding_file), "rb") as f: 23 | embedding = pickle.load(f) 24 | expected_shape += embedding["image_embs"].shape[0] 25 | writer(embedding) 26 | writer.flush() 27 | 28 | with open(tmpdir + "/img_emb/img_emb_0.npy", "rb") as f: 29 | image_embs = np.load(f) 30 | assert image_embs.shape[0] == expected_shape 31 | -------------------------------------------------------------------------------- /tests/test_end2end.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "" 4 | 5 | from img2dataset import download 6 | from clip_retrieval import clip_inference 7 | from clip_retrieval import clip_index 8 | import pandas as pd 9 | import shutil 10 | import subprocess 11 | import time 12 | import requests 13 | import logging 14 | 15 | 16 | LOGGER = logging.getLogger(__name__) 17 | 18 | test_list = [ 19 | ["first", "https://placekitten.com/400/600"], 20 | ["second", "https://placekitten.com/200/300"], 21 | ["third", "https://placekitten.com/300/200"], 22 | ["fourth", "https://placekitten.com/400/400"], 23 | ["fifth", "https://placekitten.com/200/200"], 24 | [None, "https://placekitten.com/200/200"], 25 | ] 26 | 27 | 28 | def generate_parquet(output_file): 29 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 30 | df.to_parquet(output_file) 31 | 32 | 33 | def test_end2end(): 34 | current_folder = os.path.dirname(__file__) 35 | test_folder = current_folder + "/" + "test_folder" 36 | if os.path.exists(test_folder): 37 | shutil.rmtree(test_folder) 38 | if not os.path.exists(test_folder): 39 | os.mkdir(test_folder) 40 | url_list_name = os.path.join(test_folder, "url_list") 41 | image_folder_name = os.path.join(test_folder, "images") 42 | 43 | url_list_name += ".parquet" 44 | generate_parquet(url_list_name) 45 | 46 | download( 47 | url_list_name, 48 | image_size=256, 49 | output_folder=image_folder_name, 50 | thread_count=32, 51 | input_format="parquet", 52 | output_format="webdataset", 53 | url_col="url", 54 | caption_col="caption", 55 | ) 56 | 57 | assert os.path.exists(image_folder_name) 58 | 59 | embeddings_folder = os.path.join(test_folder, "embeddings") 60 | 61 | clip_inference( 62 | input_dataset=f"{image_folder_name}/00000.tar", 63 | output_folder=embeddings_folder, 64 | input_format="webdataset", 65 | enable_metadata=True, 66 | write_batch_size=100000, 67 | batch_size=8, 68 | cache_path=None, 69 | ) 70 | 71 | assert os.path.exists(embeddings_folder) 72 | 73 | index_folder = os.path.join(test_folder, "index") 74 | 75 | os.mkdir(index_folder) 76 | 77 | clip_index(embeddings_folder, index_folder=index_folder) 78 | 79 | assert os.path.exists(index_folder + "/image.index") 80 | assert os.path.exists(index_folder + "/text.index") 81 | 82 | indice_path = os.path.join(test_folder, "indices_paths.json") 83 | with open(indice_path, "w") as f: 84 | f.write('{"example_index": "' + index_folder + '"}') 85 | 86 | p = subprocess.Popen( 87 | f"clip-retrieval back --port=1239 --indices_paths='{indice_path}' --enable_mclip_option=False", 88 | shell=True, 89 | stdout=subprocess.PIPE, 90 | ) 91 | for i in range(8): 92 | try: 93 | time.sleep(10) 94 | r = requests.post( 95 | "http://localhost:1239/knn-service", 96 | json={"text": "cat", "modality": "image", "num_images": 10, "indice_name": "example_index"}, 97 | ) 98 | _ = r.json() 99 | assert r.status_code == 200 100 | break 101 | except Exception as e: 102 | if i == 7: 103 | raise e 104 | -------------------------------------------------------------------------------- /tests/test_filter.sh: -------------------------------------------------------------------------------- 1 | clip-retrieval filter --query "cat" --output_folder "cat/" --indice_folder "/tmp/my_index" -------------------------------------------------------------------------------- /tests/test_index.sh: -------------------------------------------------------------------------------- 1 | time clip-retrieval index --embeddings_folder=/tmp/folder/ --index_folder=/tmp/my_index -------------------------------------------------------------------------------- /tests/test_inference.sh: -------------------------------------------------------------------------------- 1 | rm -rf /tmp/folder 2 | time clip-retrieval inference --input_dataset="http://the-eye.eu/eleuther_staging/cah/releases/laion400m/{00000..01000}.tar" --output_folder="/tmp/folder" \ 3 | --input_format "webdataset" --subset_size=100000 --enable_metadata=True --write_batch_size=100000 --batch_size=512 --cache_path=None -------------------------------------------------------------------------------- /tests/test_simple.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | 3 | 4 | def test_basic(): 5 | print("it works !") 6 | --------------------------------------------------------------------------------