├── .github ├── dependabot.yml └── workflows │ ├── ci.yml │ └── python-publish.yml ├── .gitignore ├── .gitpod.DockerFile ├── .gitpod.yml ├── .pylintrc ├── HISTORY.md ├── LICENSE ├── Makefile ├── README.md ├── dataset_examples ├── SBUcaptions.md ├── cc12m.md ├── cc3m.md ├── common_pool.md ├── coyo-700m.md ├── datacomp.md ├── laion-aesthetic.md ├── laion-art.md ├── laion-coco.md ├── laion-face.md ├── laion-high-resolution.md ├── laion400m.md ├── laion5B.md └── mscoco.md ├── doc_assets ├── wandb_metrics.png └── wandb_table.png ├── examples ├── distributed_img2dataset_tutorial.md ├── pyspark_example.py ├── ray_example │ ├── README.md │ ├── cluster_minimal.yaml │ └── ray_example.py └── simple_example.py ├── img2dataset ├── __init__.py ├── architecture.md ├── blurrer.py ├── distributor.py ├── downloader.py ├── logger.py ├── main.py ├── reader.py ├── resizer.py └── writer.py ├── mypy.ini ├── notebook └── img2dataset_getting_started.ipynb ├── requirements-test.txt ├── requirements.txt ├── setup.py └── tests ├── blur_test_files ├── bbox.npy ├── blurred.png ├── original.png ├── resize_border.jpg ├── resize_center_crop.jpg ├── resize_keep_ratio.jpg ├── resize_keep_ratio_largest.jpg ├── resize_no.jpg └── test_bbox.parquet ├── conftest.py ├── fixtures.py ├── http_server.py ├── resize_test_image ├── 123_456.jpg ├── 208_495.jpg ├── 321_421.jpg ├── 389_535.jpg ├── 416_264.jpg ├── 456_123.jpg └── 524_316.jpg ├── test_blurrer.py ├── test_downloader.py ├── test_files ├── benchmark.sh ├── hashes.json ├── large_bench.sh ├── large_bench_tf.sh ├── s3_bench.sh ├── sample_image.parquet ├── sample_image.txt ├── test_1000.parquet ├── test_1000.txt ├── test_10000.parquet ├── test_10000.txt └── unique_images.txt ├── test_main.py ├── test_reader.py ├── test_resizer.py └── test_writer.py /.github/dependabot.yml: -------------------------------------------------------------------------------- 1 | version: 2 2 | updates: 3 | - package-ecosystem: "pip" 4 | directory: "/" 5 | schedule: 6 | interval: "daily" 7 | -------------------------------------------------------------------------------- /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: Continuous integration 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | pull_request: 8 | branches: 9 | - main 10 | 11 | jobs: 12 | lint: 13 | runs-on: ubuntu-latest 14 | steps: 15 | - uses: actions/checkout@v2 16 | - name: Set up Python 3.8 17 | uses: actions/setup-python@v2 18 | with: 19 | python-version: 3.8 20 | - name: Install 21 | run: | 22 | python3 -m venv .env 23 | source .env/bin/activate 24 | python -m pip install -U pip 25 | make install-dev 26 | - name: Lint 27 | run: | 28 | source .env/bin/activate 29 | make lint 30 | pex: 31 | runs-on: ubuntu-latest 32 | steps: 33 | - uses: actions/checkout@v2 34 | - name: Set up Python 35 | uses: actions/setup-python@v2 36 | with: 37 | python-version: '3.8' 38 | - name: Install dependencies 39 | run: | 40 | python -m pip install --upgrade pip 41 | pip install setuptools wheel twine pex 42 | - name: Build pex 43 | run: | 44 | make build-pex 45 | tests: 46 | runs-on: ubuntu-latest 47 | strategy: 48 | matrix: 49 | python-version: [3.8, '3.10'] 50 | 51 | steps: 52 | - uses: actions/checkout@v2 53 | - name: Set up Python ${{ matrix.python-version }} 54 | uses: actions/setup-python@v2 55 | with: 56 | python-version: ${{ matrix.python-version }} 57 | - name: Install 58 | run: | 59 | python3 -m venv .env 60 | source .env/bin/activate 61 | make install 62 | make install-dev 63 | - name: Unit tests 64 | run: | 65 | source .env/bin/activate 66 | ray start --head --disable-usage-stats 67 | ray start --address='127.0.0.1:6379' 68 | make test 69 | 70 | -------------------------------------------------------------------------------- /.github/workflows/python-publish.yml: -------------------------------------------------------------------------------- 1 | name: Release 2 | 3 | on: 4 | push: 5 | branches: 6 | - main 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v2 12 | - uses: actions-ecosystem/action-regex-match@v2 13 | id: regex-match 14 | with: 15 | text: ${{ github.event.head_commit.message }} 16 | regex: '^Release ([^ ]+)' 17 | - name: Set up Python 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: '3.8' 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install setuptools wheel twine pex 25 | - name: Build pex 26 | run: | 27 | make build-pex 28 | - name: Release 29 | if: ${{ steps.regex-match.outputs.match != '' }} 30 | uses: softprops/action-gh-release@v1 31 | with: 32 | files: img2dataset.pex 33 | tag_name: ${{ steps.regex-match.outputs.group1 }} 34 | - name: Build and publish 35 | if: ${{ steps.regex-match.outputs.match != '' }} 36 | env: 37 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 38 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 39 | run: | 40 | python setup.py sdist bdist_wheel 41 | twine upload dist/* 42 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.egg-info 2 | .vscode 3 | .env 4 | output_folder 5 | myimglist.txt 6 | __pycache__ 7 | .envtest 8 | bench 9 | test_folder 10 | .coverage* 11 | .env* 12 | wandb 13 | *.pex 14 | .pexing 15 | build 16 | .hypothesis 17 | -------------------------------------------------------------------------------- /.gitpod.DockerFile: -------------------------------------------------------------------------------- 1 | FROM gitpod/workspace-full:latest 2 | 3 | RUN sudo apt-get update && sudo apt-get install -y python3-opencv 4 | -------------------------------------------------------------------------------- /.gitpod.yml: -------------------------------------------------------------------------------- 1 | image: 2 | file: .gitpod.DockerFile -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | # Specify a configuration file. 4 | #rcfile= 5 | 6 | # Python code to execute, usually for sys.path manipulation such as 7 | # pygtk.require(). 8 | #init-hook= 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=CVS 13 | 14 | # Pickle collected data for later comparisons. 15 | persistent=yes 16 | 17 | # List of plugins (as comma separated values of python modules names) to load, 18 | # usually to register additional checkers. 19 | load-plugins= 20 | 21 | 22 | [MESSAGES CONTROL] 23 | 24 | # Enable the message, report, category or checker with the given id(s). You can 25 | # either give multiple identifier separated by comma (,) or put this option 26 | # multiple time. See also the "--disable" option for examples. 27 | enable=indexing-exception,old-raise-syntax 28 | 29 | # Disable the message, report, category or checker with the given id(s). You 30 | # can either give multiple identifiers separated by comma (,) or put this 31 | # option multiple times (only on the command line, not in the configuration 32 | # file where it should appear only once).You can also use "--disable=all" to 33 | # disable everything first and then reenable specific checks. For example, if 34 | # you want to run only the similarities checker, you can use "--disable=all 35 | # --enable=similarities". If you want to run only the classes checker, but have 36 | # no Warning level messages displayed, use"--disable=all --enable=classes 37 | # --disable=W" 38 | disable=design,similarities,no-self-use,attribute-defined-outside-init,locally-disabled,star-args,pointless-except,bad-option-value,global-statement,fixme,suppressed-message,useless-suppression,locally-enabled,no-member,no-name-in-module,import-error,unsubscriptable-object,unbalanced-tuple-unpacking,undefined-variable,not-context-manager,no-else-return,wrong-import-order,unnecessary-pass,logging-fstring-interpolation,logging-format-interpolation,C0330 39 | 40 | 41 | [REPORTS] 42 | 43 | # Set the output format. Available formats are text, parseable, colorized, msvs 44 | # (visual studio) and html. You can also give a reporter class, eg 45 | # mypackage.mymodule.MyReporterClass. 46 | output-format=text 47 | 48 | # Tells whether to display a full report or only the messages 49 | reports=no 50 | 51 | # Python expression which should return a note less than 10 (10 is the highest 52 | # note). You have access to the variables errors warning, statement which 53 | # respectively contain the number of errors / warnings messages and the total 54 | # number of statements analyzed. This is used by the global evaluation report 55 | # (RP0004). 56 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 57 | 58 | # Template used to display messages. This is a python new-style format string 59 | # used to format the message information. See doc for all details 60 | #msg-template= 61 | 62 | 63 | [TYPECHECK] 64 | 65 | # Tells whether missing members accessed in mixin class should be ignored. A 66 | # mixin class is detected if its name ends with "mixin" (case insensitive). 67 | ignore-mixin-members=yes 68 | 69 | # List of classes names for which member attributes should not be checked 70 | # (useful for classes with attributes dynamically set). 71 | ignored-classes=SQLObject 72 | 73 | # List of members which are set dynamically and missed by pylint inference 74 | # system, and so shouldn't trigger E0201 when accessed. Python regular 75 | # expressions are accepted. 76 | generated-members=REQUEST,acl_users,aq_parent 77 | 78 | # List of decorators that create context managers from functions, such as 79 | # contextlib.contextmanager. 80 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 81 | 82 | 83 | [VARIABLES] 84 | 85 | # Tells whether we should check for unused import in __init__ files. 86 | init-import=no 87 | 88 | # A regular expression matching the beginning of the name of dummy variables 89 | # (i.e. not used). 90 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 91 | 92 | # List of additional names supposed to be defined in builtins. Remember that 93 | # you should avoid to define new builtins when possible. 94 | additional-builtins= 95 | 96 | 97 | [BASIC] 98 | 99 | # Regular expression which should only match correct module names 100 | module-rgx=(([a-z_][a-z0-9_]*)|([A-Z][a-zA-Z0-9]+))$ 101 | 102 | # Regular expression which should only match correct module level names 103 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 104 | 105 | # Regular expression which should only match correct class names 106 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 107 | 108 | # Regular expression which should only match correct function names 109 | function-rgx=^(?:(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 110 | 111 | # Regular expression which should only match correct method names 112 | method-rgx=^(?:(?P__[a-z0-9_]+__|next)|(?P_{0,2}[A-Z][a-zA-Z0-9]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 113 | 114 | # Regular expression which should only match correct instance attribute names 115 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 116 | 117 | # Regular expression which should only match correct argument names 118 | argument-rgx=^[a-z][a-z0-9_]*$ 119 | 120 | # Regular expression which should only match correct variable names 121 | variable-rgx=^[a-z][a-z0-9_]*$ 122 | 123 | # Regular expression which should only match correct attribute names in class 124 | # bodies 125 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 126 | 127 | # Regular expression which should only match correct list comprehension / 128 | # generator expression variable names 129 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 130 | 131 | # Good variable names which should always be accepted, separated by a comma 132 | good-names=main,_ 133 | 134 | # Bad variable names which should always be refused, separated by a comma 135 | bad-names= 136 | 137 | # Regular expression which should only match function or class names that do 138 | # not require a docstring. 139 | no-docstring-rgx=(__.*__|main) 140 | 141 | # Minimum line length for functions/classes that require docstrings, shorter 142 | # ones are exempt. 143 | docstring-min-length=10 144 | 145 | 146 | [FORMAT] 147 | 148 | # Maximum number of characters on a single line. 149 | max-line-length=120 150 | 151 | # Regexp for a line that is allowed to be longer than the limit. 152 | ignore-long-lines=(?x) 153 | (^\s*(import|from)\s 154 | |\$Id:\s\/\/depot\/.+#\d+\s\$ 155 | |^[a-zA-Z_][a-zA-Z0-9_]*\s*=\s*("[^"]\S+"|'[^']\S+') 156 | |^\s*\#\ LINT\.ThenChange 157 | |^[^#]*\#\ type:\ [a-zA-Z_][a-zA-Z0-9_.,[\] ]*$ 158 | |pylint 159 | |""" 160 | |\# 161 | |lambda 162 | |(https?|ftp):) 163 | 164 | # Allow the body of an if to be on the same line as the test if there is no 165 | # else. 166 | single-line-if-stmt=y 167 | 168 | # Maximum number of lines in a module 169 | max-module-lines=99999 170 | 171 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 172 | # tab). 173 | indent-string=' ' 174 | 175 | 176 | [SIMILARITIES] 177 | 178 | # Minimum lines number of a similarity. 179 | min-similarity-lines=4 180 | 181 | # Ignore comments when computing similarities. 182 | ignore-comments=yes 183 | 184 | # Ignore docstrings when computing similarities. 185 | ignore-docstrings=yes 186 | 187 | # Ignore imports when computing similarities. 188 | ignore-imports=no 189 | 190 | 191 | [MISCELLANEOUS] 192 | 193 | # List of note tags to take in consideration, separated by a comma. 194 | notes= 195 | 196 | 197 | [IMPORTS] 198 | 199 | # Deprecated modules which should not be used, separated by a comma 200 | deprecated-modules=regsub,TERMIOS,Bastion,rexec,sets 201 | 202 | # Create a graph of every (i.e. internal and external) dependencies in the 203 | # given file (report RP0402 must not be disabled) 204 | import-graph= 205 | 206 | # Create a graph of external dependencies in the given file (report RP0402 must 207 | # not be disabled) 208 | ext-import-graph= 209 | 210 | # Create a graph of internal dependencies in the given file (report RP0402 must 211 | # not be disabled) 212 | int-import-graph= 213 | 214 | extension-pkg-whitelist=_jsonnet 215 | 216 | 217 | [CLASSES] 218 | 219 | # List of method names used to declare (i.e. assign) instance attributes. 220 | defining-attr-methods=__init__,__new__,setUp 221 | 222 | # List of valid names for the first argument in a class method. 223 | valid-classmethod-first-arg=cls,class_ 224 | 225 | # List of valid names for the first argument in a metaclass class method. 226 | valid-metaclass-classmethod-first-arg=mcs 227 | 228 | 229 | [DESIGN] 230 | 231 | # Maximum number of arguments for function / method 232 | max-args=5 233 | 234 | # Argument names that match this expression will be ignored. Default to name 235 | # with leading underscore 236 | ignored-argument-names=_.* 237 | 238 | # Maximum number of locals for function / method body 239 | max-locals=15 240 | 241 | # Maximum number of return / yield for function / method body 242 | max-returns=6 243 | 244 | # Maximum number of branch for function / method body 245 | max-branches=12 246 | 247 | # Maximum number of statements in function / method body 248 | max-statements=50 249 | 250 | # Maximum number of parents for a class (see R0901). 251 | max-parents=7 252 | 253 | # Maximum number of attributes for a class (see R0902). 254 | max-attributes=7 255 | 256 | # Minimum number of public methods for a class (see R0903). 257 | min-public-methods=2 258 | 259 | # Maximum number of public methods for a class (see R0904). 260 | max-public-methods=20 261 | 262 | 263 | 264 | [TOKENS] 265 | 266 | # Number of spaces of indent required when the last token on the preceding line 267 | # is an open (, [, or {. 268 | indent-after-paren=4 -------------------------------------------------------------------------------- /HISTORY.md: -------------------------------------------------------------------------------- 1 | ## 1.45.0 2 | 3 | * update pyarrow 4 | * add incremental model extend (thanks @edwardguil) 5 | 6 | ## 1.44.1 7 | 8 | * extend fire dep range 9 | 10 | ## 1.44.0 11 | 12 | * Deps update 13 | 14 | ## 1.43.0 15 | 16 | * Remove version restriction for fsspec 17 | 18 | ## 1.42.0 19 | 20 | * ray distibutor (thanks @Vaishaal) 21 | * Remove tmp_dir only if the output dir is not in s3 (thanks @ezzarum) 22 | * support more input formats (thanks @ldfandian) 23 | 24 | ## 1.41.0 25 | 26 | * Verify hashes during download. (thanks @GeorgiosSmyrnis and @carlini) 27 | * opencv-python => opencv-python-headless (thanks @shionhonda) 28 | 29 | ## 1.40.0 30 | 31 | * Add SBU captions benchmark 32 | * Bump ffspec version 33 | * Fix face blurring when padding/cropping 34 | * Add support for other hash functions 35 | 36 | ## 1.39.0 37 | 38 | * Make opt out the default, add warning about ethical issues with slowing down democratization of skills and art. 39 | 40 | ## 1.38.0 41 | 42 | * Incorporate face blurring with bounding boxes. (thanks @GeorgiosSmyrnis) 43 | 44 | ## 1.37.0 45 | 46 | * Add support for resizing with fixed aspect ratio while fixing the largest image dimension (thanks @gabrielilharco) 47 | 48 | ## 1.36.0 49 | 50 | * bumping webdataset version to 0.2.5+ 51 | 52 | ## 1.35.0 53 | 54 | * added max_image_area flag (thanks @sagadre) 55 | 56 | ## 1.34.0 57 | 58 | * Add argument validator in main. 59 | * Respect noai and noimageai directives when downloading image files (thanks @raincoastchris) 60 | * add list of int, float feature in TFRecordSampleWriter (thanks @justHungryMan) 61 | 62 | ## 1.33.0 63 | 64 | * feat: support pyspark < 3 when distributing image-to-dataset job (thanks @nateagr) 65 | 66 | ## 1.32.0 67 | 68 | * feat: support min image size + max aspect ratio (@borisdayma) 69 | 70 | ## 1.31.0 71 | 72 | * feat: allow encoding in different formats (thanks @borisdayma) 73 | 74 | ## 1.30.2 75 | 76 | * Fix error message for incorrect input format 77 | 78 | ## 1.30.1 79 | 80 | * Bug fix: shard id was incorrect when resuming (thanks @lxj616) 81 | 82 | ## 1.30.0 83 | 84 | * Implement shard retrying 85 | 86 | ## 1.29.0 87 | 88 | * Validate input and output format 89 | * Implement incremental mode 90 | 91 | ## 1.28.0 92 | 93 | * use pyarrow in the reader to make it much faster 94 | 95 | ## 1.27.4 96 | 97 | * use 2022.1.0 of fsspec for python3.6 98 | 99 | ## 1.27.3 100 | 101 | * fix fsspec version 102 | 103 | ## 1.27.2 104 | 105 | * fix fsspec version 106 | 107 | ## 1.27.1 108 | 109 | * add gcsfs to pex 110 | 111 | ## 1.27.0 112 | 113 | * buffered writer fix: release ram more often 114 | * feat: accept numpy arrays (thanks @borisdayma) 115 | 116 | ## 1.26.0 117 | 118 | * add tfrecord output format (thanks @borisdayma) 119 | 120 | ## 1.25.6 121 | 122 | * fix an interaction between md5 and exif option 123 | 124 | ## 1.25.5 125 | 126 | * fix dependency ranges 127 | 128 | ## 1.25.4 129 | 130 | * use exifread-nocycle to avoid cycle in exifread 131 | 132 | ## 1.25.3 133 | 134 | * retry whole sharding if it fails 135 | 136 | ## 1.25.2 137 | 138 | * retry writing the shard in reader in case of error 139 | 140 | ## 1.25.1 141 | 142 | * small fix for logger and continuing 143 | * use time instead of perf_counter to measure shard duration 144 | 145 | ## 1.25.0 146 | 147 | * make metadata writer much faster by building the schema in the downloader instead of guessing it 148 | * add new option allowing to disable reencoding 149 | 150 | ## 1.24.1 151 | 152 | * hide opencv warning 153 | 154 | ## 1.24.0 155 | 156 | * force one thread for opencv 157 | * make total logger start time the minimum of workers start time 158 | * add s3fs into the released pex for convenience 159 | * make sharding faster on high latency fs by using a thread pool 160 | 161 | ## 1.23.1 162 | 163 | * fix logger on s3: do not use listing caching in logger 164 | 165 | ## 1.23.0 166 | 167 | * add tutorial on how to setup a spark cluster and use it for distributed img2dataset 168 | better aws s3 support: 169 | * initialize logger fs in subprocess to avoid moving fs over a fork() 170 | * use spawn instead of fork method 171 | 172 | * make total logging more intuitive and convenient by logging every worker return 173 | 174 | ## 1.22.3 175 | 176 | * fix release regex 177 | 178 | ## 1.22.2 179 | 180 | * fix fsspec support by using tmp_dir in main.py 181 | 182 | ## 1.22.1 183 | 184 | * fix pex creation 185 | 186 | ## 1.22.0 187 | 188 | * add option not to write 189 | 190 | ## 1.21.2 191 | 192 | * try catch in the logger for json.load 193 | * prevent error if logger sync is called when no call has been done 194 | * Add a build-pex target in Makefile and CI 195 | 196 | ## 1.21.1 197 | 198 | * decrease default log interval to 5s 199 | 200 | ## 1.21.0 201 | 202 | * add option to retry http download 203 | 204 | ## 1.20.2 205 | 206 | * add original_width by default for a consistent schema 207 | 208 | ## 1.20.1 209 | 210 | * fix relative path handling 211 | 212 | ## 1.20.0 213 | 214 | * Add multi distributor support : multiprocessing and pyspark 215 | 216 | ## 1.19.0 217 | 218 | * make the reader emits file paths instead of samples 219 | 220 | ## 1.18.0 221 | 222 | * use a logger process to make logging distribution friendly, also save json stat files next to folder/tar files 223 | 224 | ## 1.17.0 225 | 226 | * Use fsspec to support all filesystems 227 | 228 | ## 1.16.0 229 | 230 | * implement md5 of images feature 231 | 232 | ## 1.15.1 233 | 234 | * fix null convert in writer 235 | 236 | ## 1.15.0 237 | 238 | * add parquet writer 239 | 240 | ## 1.14.0 241 | 242 | * make reader memory efficient by using feather files 243 | 244 | ## 1.13.0 245 | 246 | * large refactoring of the whole code in submodules 247 | * Enhance image resize processing (esp re downscale) (@rwightman) 248 | 249 | ## 1.12.0 250 | 251 | * handle transparency (thanks @borisdayma) 252 | * add json input file support 253 | 254 | 255 | ## 1.11.0 256 | 257 | * Add support for .tsv.gz files (thanks @robvanvolt) 258 | 259 | ## 1.10.1 260 | 261 | * raise clean exception on image decoding error 262 | * remove the \n in urls for txt inputs 263 | * save the error message when resizing fails in metadata 264 | * add type hints to download function 265 | 266 | ## 1.10.0 267 | 268 | * use semaphores to decrease memory usage 269 | 270 | ## 1.9.9 271 | 272 | * fix an issue with resize_mode "no" 273 | 274 | ## 1.9.8 275 | 276 | * optimize listing files is back, sorted is eager so the iterator returned by iglob is ok 277 | 278 | ## 1.9.7 279 | 280 | * revert last commit, it could cause double iteration on an iterator which can cause surprising behaviors 281 | 282 | ## 1.9.6 283 | 284 | * optimize listing files (thanks @Skylion) 285 | 286 | ## 1.9.5 287 | 288 | * fix a bug affecting downloading multiple files 289 | 290 | ## 1.9.4 291 | 292 | * ensure sharded_images_to_dl is removed from memory at the end of downloading a file 293 | 294 | ## 1.9.3 295 | 296 | * solve the stemming issue: make keys uniques 297 | 298 | ## 1.9.2 299 | 300 | * Save empty caption if caption are none instead of not having the caption file 301 | 302 | ## 1.9.1 303 | 304 | * fix for the new logging feature when cleaning the status dict 305 | 306 | ## 1.9.0 307 | 308 | * wandb support is back 309 | 310 | ## 1.8.5 311 | 312 | * support for python 3.6 313 | 314 | ## 1.8.4 315 | 316 | * convert caption to str before writing 317 | 318 | ## 1.8.3 319 | 320 | * add back timeout properly 321 | 322 | ## 1.8.2 323 | 324 | * fixes 325 | 326 | ## 1.8.1 327 | 328 | * revert wandb for now, code is too complex and there are issues 329 | 330 | ## 1.8.0 331 | 332 | * feat: custom timeout (thanks @borisdayma) 333 | * feat: support wandb (thanks @borisdayma) 334 | 335 | ## 1.7.0 336 | 337 | * use albumentations for resizing (thanks @borisdayma) 338 | 339 | ## 1.6.1 340 | 341 | * depend on pyyaml to be able to use the last webdataset 342 | 343 | ## 1.6.0 344 | 345 | * feat: handle tsv + center crop (thanks @borisdayma) 346 | 347 | ## 1.5.3 348 | 349 | * increase stability by closing the pool and tarwriter explicitly 350 | 351 | ## 1.5.2 352 | 353 | * improve memory usage 354 | 355 | ## 1.5.1 356 | 357 | * glob only input files of the right ext 358 | 359 | ## 1.5.0 360 | 361 | * add a save_additional_columns option 362 | 363 | ## 1.4.0 364 | 365 | * Multiple file support 366 | * Status dataframe 367 | 368 | ## 1.3.0 369 | 370 | * Uses a resizing method less prone to aliasing (thanks @skylion) 371 | * multi processing + multi threading 372 | 373 | ## 1.2.0 374 | 375 | * add webdataset support and benchmarks 376 | * supports reading as parquet and csv 377 | 378 | ## 1.1.1 379 | 380 | * fix cli 381 | 382 | ## 1.1.0 383 | 384 | * add image resizing mode 385 | 386 | ## 1.0.1 387 | 388 | * fixes 389 | 390 | ## 1.0.0 391 | 392 | * it works 393 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Romain Beaumont 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | install: ## [Local development] Upgrade pip, install requirements, install package. 2 | python -m pip install -U pip 3 | python -m pip install -e . 4 | 5 | install-dev: ## [Local development] Install test requirements 6 | python -m pip install -r requirements-test.txt 7 | 8 | lint: ## [Local development] Run mypy, pylint and black 9 | python -m mypy img2dataset 10 | python -m pylint img2dataset 11 | python -m black --check -l 120 . 12 | 13 | black: ## [Local development] Auto-format python code using black 14 | python -m black -l 120 . 15 | 16 | build-pex: 17 | python3 -m venv .pexing 18 | . .pexing/bin/activate && python -m pip install -U pip && python -m pip install pex 19 | . .pexing/bin/activate && python -m pex setuptools scipy==1.9.0 gcsfs s3fs pyspark==3.2.0 requests==2.27.1 . -o img2dataset.pex -v 20 | rm -rf .pexing 21 | 22 | test: ## [Local development] Run unit tests 23 | python -m pytest -x -s -v tests 24 | 25 | .PHONY: help 26 | 27 | help: # Run `make help` to get help on the make commands 28 | @grep -E '^[0-9a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}' 29 | -------------------------------------------------------------------------------- /dataset_examples/SBUcaptions.md: -------------------------------------------------------------------------------- 1 | ## SBU Captions 2 | 3 | [SBU Captions]([https://www.cs.rice.edu/~vo9/sbucaptions/sbu-captions]) is a large-scale dataset that contains 860K image-text pairs as well as many other meta-attributes to increase the usability to train various models. This dataset is one of the key benchmark datasets. 4 | 5 | ### Download the metadata 6 | 7 | 8 | ``` 9 | wget https://www.cs.rice.edu/~vo9/sbucaptions/sbu-captions-all.tar.gz 10 | tar -xvzf sbu-captions-all.tar.gz 11 | 12 | ``` 13 | 14 | ### Download the images with img2dataset 15 | 16 | ``` 17 | img2dataset --url_list sbu-captions-all.json --input_format "json" --url_col "image_urls" --caption_col "captions" --output_format webdataset --output_folder sbucaptions --processes_count 16 --thread_count 64 --image_size 256 18 | ``` 19 | 20 | ### Benchmark 21 | 22 | https://wandb.ai/rom1504/img2dataset/runs/2nhepsmf 23 | 24 | 1000 sample/s using 16 cores 25 | Average bandwidth 500Mb/s ; cpu usage 100% on all cores 26 | Write speed on disk : about 20MB/s average 27 | 28 | -------------------------------------------------------------------------------- /dataset_examples/cc12m.md: -------------------------------------------------------------------------------- 1 | ## CC12M 2 | 3 | [CC12M](https://github.com/google-research-datasets/conceptual-12m) is a dataset of 12 million image and caption. 4 | 5 | 6 | ### Download the metadata 7 | 8 | `wget https://storage.googleapis.com/conceptual_12m/cc12m.tsv` 9 | That's a 2.6GB file 10 | 11 | Add the column names at the top of the file with `sed -i '1s/^/url\tcaption\n/' cc12m.tsv` 12 | 13 | ### Download the images with img2dataset 14 | 15 | Run this command. It will download the cc12m dataset as resized images in the webdataset format. 16 | 17 | ``` 18 | img2dataset --url_list cc12m.tsv --input_format "tsv"\ 19 | --url_col "url" --caption_col "caption" --output_format webdataset\ 20 | --output_folder cc12m --processes_count 16 --thread_count 64 --image_size 256\ 21 | --enable_wandb True 22 | ``` 23 | 24 | ### Benchmark 25 | 26 | https://wandb.ai/rom1504/img2dataset/reports/Download-cc12m-with-img2dataset--VmlldzoxMjIxMTY0 27 | * 630 sample/s : cc12m has a lot of large images so resizing makes cpu the bottleneck 28 | * total: 5h 29 | * output: 331GB 30 | 31 | -------------------------------------------------------------------------------- /dataset_examples/cc3m.md: -------------------------------------------------------------------------------- 1 | ## CC3M 2 | 3 | CC3M is a dataset of 3 million image and caption. 4 | 5 | ### Download the metadata 6 | 7 | Go to https://ai.google.com/research/ConceptualCaptions/download and press download 8 | That's a 500MB tsv file 9 | 10 | Add the column names at the top of the file with `sed -i '1s/^/caption\turl\n/' cc3m.tsv` 11 | 12 | ### Download the images with img2dataset 13 | 14 | Run this command. It will download the cc3m dataset as resized images in the webdataset format. 15 | 16 | ``` 17 | img2dataset --url_list cc3m.tsv --input_format "tsv"\ 18 | --url_col "url" --caption_col "caption" --output_format webdataset\ 19 | --output_folder cc3m --processes_count 16 --thread_count 64 --image_size 256\ 20 | --enable_wandb True 21 | ``` 22 | 23 | ### Benchmark 24 | 25 | https://wandb.ai/rom1504/img2dataset/reports/Download-cc3m-with-img2dataset--VmlldzoxMjE5MTE4 26 | 27 | This dataset has a lot of high resolution images, so this results in about 850 image downloader per second. Overall this takes about one hour. Using a computer with 16 cores, and 2Gbps of bandwidth. 28 | -------------------------------------------------------------------------------- /dataset_examples/common_pool.md: -------------------------------------------------------------------------------- 1 | ## CommonPool 2 | 3 | CommonPool is a dataset with 12.8 billion image-text pairs collected from Common Crawl, and is part of [DataComp](https://github.com/mlfoundations/datacomp), a benchmark for designing multimodal datasets. 4 | See http://datacomp.ai/ and https://arxiv.org/abs/2304.14108 for details. 5 | 6 | Along with the largest pool with 12.8B samples, CommonPool also comes in three smaller versions, containing 12.8M, 128M, and 1.28B samples. 7 | 8 | 9 | ### Downloading CommonPool 10 | 11 | CommonPool can be downloaded using img2dataset by following the instructions on https://github.com/mlfoundations/datacomp/blob/main/download_upstream.py -------------------------------------------------------------------------------- /dataset_examples/coyo-700m.md: -------------------------------------------------------------------------------- 1 | ## COYO-700M 2 | 3 | [COYO-700M](https://github.com/kakaobrain/coyo-dataset) is a large-scale dataset that contains 747M image-text pairs as well as many other meta-attributes to increase the usability to train various models. Our dataset follows a similar strategy to previous vision-and-language datasets, collecting many informative pairs of alt-text and its associated image in HTML documents. We expect COYO to be used to train popular large-scale foundation models complementary to other similar datasets. 4 | 5 | ### Download the metadata 6 | 7 | Download from https://huggingface.co/datasets/kakaobrain/coyo-700m 8 | We are providing a [download guide](https://github.com/kakaobrain/coyo-dataset/tree/main/download) 9 | 10 | ``` 11 | mkdir coyo-700m && cd coyo-700m 12 | for i in {00000..00127}; do wget https://huggingface.co/datasets/kakaobrain/coyo-700m/resolve/main/data/part-$i-17da4908-939c-46e5-91d0-15f256041956-c000.snappy.parquet; done 13 | cd .. 14 | ``` 15 | 16 | ### Download the images with img2dataset 17 | 18 | ``` 19 | img2dataset --url_list coyo-700m --input_format "parquet"\ 20 | --url_col "url" --caption_col "text" --output_format webdataset\ 21 | --output_folder coyo-700m-webdataset --processes_count 16 --thread_count 64 --image_size 384\ 22 | --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \ 23 | --save_additional_columns '["clip_similarity_vitb32","clip_similarity_vitl14","nsfw_score_opennsfw2","nsfw_score_gantman","watermark_score","aesthetic_score_laion_v2"]' --enable_wandb False 24 | ``` 25 | 26 | ### Benchmark 27 | 28 | -------------------------------------------------------------------------------- /dataset_examples/datacomp.md: -------------------------------------------------------------------------------- 1 | ## DataComp-1B 2 | 3 | DataComp-1B is a dataset with 1.4 billion image-text pairs collected from Common Crawl and subsequently filtered. DataComp-1B is derived from CommonPool, as part of [DataComp](https://github.com/mlfoundations/datacomp), a benchmark for designing multimodal datasets. 4 | DataComp-1B comprises the best performing subset of the `xlarge` version of CommonPool found by [Gadre et al., 2023](https://arxiv.org/abs/2304.14108). 5 | See http://datacomp.ai/ and https://arxiv.org/abs/2304.14108 for details. 6 | 7 | 8 | ### Downloading DataComp-1B 9 | 10 | CommonPool can be downloaded using img2dataset by following the instructions on https://github.com/mlfoundations/datacomp/tree/main#downloading-datacomp-1b -------------------------------------------------------------------------------- /dataset_examples/laion-aesthetic.md: -------------------------------------------------------------------------------- 1 | ## Laion-aesthetic 2 | 3 | Laion aesthetic is a laion5B subset with aesthetic > 7 pwatermark < 0.8 punsafe < 0.5 4 | See [full description](https://github.com/LAION-AI/laion-datasets/blob/main/laion-aesthetic.md) 5 | 6 | It is available at https://huggingface.co/datasets/laion/laion1B-nolang-aesthetic 7 | https://huggingface.co/datasets/laion/laion2B-en-aesthetic 8 | https://huggingface.co/datasets/laion/laion2B-multi-aesthetic 9 | 10 | It has 52M + 51M + 17M samples 11 | 12 | A good use case is to train an image generation model. 13 | 14 | ### Download the metadata 15 | 16 | Download from https://huggingface.co/datasets/laion/laion1B-nolang-aesthetic 17 | https://huggingface.co/datasets/laion/laion2B-en-aesthetic 18 | https://huggingface.co/datasets/laion/laion2B-multi-aesthetic 19 | 20 | ``` 21 | mkdir laion2B-en-aesthetic && cd laion2B-en-aesthetic 22 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-en-aesthetic/resolve/main/part-$i-9230b837-b1e0-4254-8b88-ed2976e9cee9-c000.snappy.parquet; done 23 | cd .. 24 | ``` 25 | 26 | Very similar for laion2B-multi and laion1B-nolang 27 | 28 | Example of copy to s3: 29 | ``` 30 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-en-aesthetic/resolve/main/part-$i-9230b837-b1e0-4254-8b88-ed2976e9cee9-c000.snappy.parquet -O - | aws s3 cp - s3://s-laion/laion-aesthetic/metadata/laion2B-en-aesthetic/part-$i-9230b837-b1e0-4254-8b88-ed2976e9cee9-c000.snappy.parquet; done 31 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-multi-aesthetic/resolve/main/part-$i-41ee6475-31c6-4d39-960e-7dbbe96bc95b-c000.snappy.parquet -O - | aws s3 cp - s3://s-laion/laion-aesthetic/metadata/laion2B-multi-aesthetic/part-$i-41ee6475-31c6-4d39-960e-7dbbe96bc95b-c000.snappy.parquet; done 32 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion1B-nolang-aesthetic/resolve/main/part-$i-604e83c4-a4f2-460a-8aae-1c0fa1d4f6d5-c000.snappy.parquet -O - | aws s3 cp - s3://s-laion/laion-aesthetic/metadata/laion1B-nolang-aesthetic/part-$i-604e83c4-a4f2-460a-8aae-1c0fa1d4f6d5-c000.snappy.parquet; done 33 | ``` 34 | 35 | ### Download the images with img2dataset 36 | 37 | ``` 38 | img2dataset --url_list laion2B-en-aesthetic --input_format "parquet"\ 39 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 40 | --output_folder laion2B-en-aesthetic-data --processes_count 16 --thread_count 64 --image_size 384\ 41 | --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \ 42 | --save_additional_columns '["similarity","hash","punsafe","pwatermark","aesthetic"]' --enable_wandb True 43 | ``` 44 | 45 | ### Benchmark 46 | 47 | -------------------------------------------------------------------------------- /dataset_examples/laion-art.md: -------------------------------------------------------------------------------- 1 | ## Laion-art 2 | 3 | Laion art is a 8M samples laion5B subset with aesthetic > 8 pwatermark < 0.8 punsafe < 0.5 4 | See [full description](https://github.com/LAION-AI/laion-datasets/blob/main/laion-aesthetic.md) 5 | 6 | It is available at https://huggingface.co/datasets/laion/laion-art 7 | 8 | A good use case is to train an image generation model. 9 | 10 | ### Download the metadata 11 | 12 | Download from [https://huggingface.co/datasets/laion/laion1B-nolang-aesthetic 13 | https://huggingface.co/datasets/laion/laion2B-en-aesthetic 14 | https://huggingface.co/datasets/laion/laion2B-multi-aesthetic](https://huggingface.co/datasets/laion/laion-art) 15 | 16 | ``` 17 | wget https://huggingface.co/datasets/laion/laion-art/resolve/main/laion-art.parquet 18 | ``` 19 | 20 | ### Download the images with img2dataset 21 | 22 | ``` 23 | img2dataset --url_list laion-art --input_format "parquet"\ 24 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 25 | --output_folder laion-high-resolution --processes_count 16 --thread_count 64 --image_size 384\ 26 | --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \ 27 | --save_additional_columns '["similarity","hash","punsafe","pwatermark","aesthetic","LANGUAGE"]' --enable_wandb True 28 | ``` 29 | 30 | ### Benchmark 31 | -------------------------------------------------------------------------------- /dataset_examples/laion-coco.md: -------------------------------------------------------------------------------- 1 | ## LAION-COCO 2 | 3 | LAION-COCO is a 600M subset of LAION2B-EN, captioned with an ensemble of BLIP L/14 and 2 CLIP versions (L/14 and RN50x64). 4 | It is available at https://huggingface.co/datasets/laion/laion-coco 5 | 6 | ### Download the metadata 7 | 8 | Download from https://huggingface.co/datasets/laion/laion-coco 9 | 10 | ```bash 11 | mkdir -p laion-coco && cd laion-coco/ 12 | 13 | for i in {0..127}; do 14 | wget "https://huggingface.co/datasets/laion/laion-coco/resolve/main/part-$(printf "%05d" $i)-2256f782-126f-4dc6-b9c6-e6757637749d-c000.snappy.parquet" 15 | done 16 | 17 | cd .. 18 | ``` 19 | 20 | ### Download the images with img2dataset 21 | 22 | ```bash 23 | img2dataset --url_list laion-coco --input_format "parquet"\ 24 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 25 | --output_folder laion-coco-output --processes_count 16 --thread_count 64 --image_size 256\ 26 | --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \ 27 | --save_additional_columns '["similarity","hash","punsafe","pwatermark","top_caption","all_captions","all_similarities"]' --enable_wandb True 28 | ``` 29 | -------------------------------------------------------------------------------- /dataset_examples/laion-face.md: -------------------------------------------------------------------------------- 1 | # Laion-Face 2 | 3 | [LAION-Face](https://github.com/FacePerceiver/LAION-Face) is the human face subset of [LAION-400M](https://laion.ai/laion-400-open-dataset/), it consists of 50 million image-text pairs. Face detection is conducted to find images with faces. Apart from the 50 million full-set(LAION-Face 50M), there is a 20 million sub-set(LAION-Face 20M) for fast evaluation. 4 | 5 | LAION-Face is first used as the training set of [FaRL](https://github.com/FacePerceiver/FaRL), which provides powerful pre-training transformer backbones for face analysis tasks. 6 | 7 | For more details, please check the official repo at https://github.com/FacePerceiver/LAION-Face . 8 | 9 | ## Download and convert metadata 10 | ```bash 11 | wget -l1 -r --no-parent https://the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/ 12 | mv the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/ . 13 | wget https://huggingface.co/datasets/FacePerceiver/laion-face/resolve/main/laion_face_ids.pth 14 | wget https://raw.githubusercontent.com/FacePerceiver/LAION-Face/master/convert_parquet.py 15 | python convert_parquet.py ./laion_face_ids.pth ./laion400m-meta ./laion_face_meta 16 | ``` 17 | 18 | ## Download the images with img2dataset 19 | When metadata is ready, you can start download the images. 20 | 21 | ```bash 22 | wget https://raw.githubusercontent.com/FacePerceiver/LAION-Face/master/download.sh 23 | bash download.sh ./laion_face_meta ./laion_face_data 24 | ``` 25 | 26 | Please be patient, this command might run over days, and cost about 2T disk space, and it will download 50 million image-text pairs as 32 parts. 27 | 28 | - To use the **LAION-Face 50M**, you should use all the 32 parts. 29 | - To use the **LAION-Face 20M**, you should use these parts. 30 | ``` 31 | 0,2,5,8,13,15,17,18,21,22,24,25,28 32 | ``` 33 | 34 | checkout `download.sh` and [img2dataset](https://github.com/rom1504/img2dataset) for more details and parameter setting. 35 | 36 | 37 | 38 | -------------------------------------------------------------------------------- /dataset_examples/laion-high-resolution.md: -------------------------------------------------------------------------------- 1 | ## Laion-high-resolution 2 | 3 | Laion high resolution is a >= 1024x1024 subset of laion5B. 4 | It is available at https://huggingface.co/datasets/laion/laion-high-resolution 5 | It has 170M samples 6 | 7 | A good use case is to train a superresolution model. 8 | 9 | ### Download the metadata 10 | 11 | Download from https://huggingface.co/datasets/laion/laion-high-resolution 12 | 13 | ``` 14 | mkdir -p laion-high-resolution && cd laion-high-resolution 15 | 16 | for i in {0..127}; do 17 | wget "https://huggingface.co/datasets/laion/laion-high-resolution/resolve/main/part-$(printf "%05d" $i)-5d6701c4-b238-4c0a-84e4-fe8e9daea963-c000.snappy.parquet" 18 | done 19 | 20 | cd .. 21 | ``` 22 | 23 | ### Download the images with img2dataset 24 | 25 | ``` 26 | img2dataset --url_list laion-high-resolution --input_format "parquet"\ 27 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 28 | --output_folder laion-high-resolution-output --processes_count 16 --thread_count 64 --image_size 1024\ 29 | --resize_only_if_bigger=True --resize_mode="keep_ratio" --skip_reencode=True \ 30 | --save_additional_columns '["similarity","hash","punsafe","pwatermark","LANGUAGE"]' --enable_wandb True 31 | ``` 32 | 33 | ### Benchmark 34 | 35 | https://wandb.ai/rom1504/img2dataset/reports/laion-high-resolution--VmlldzoxOTY0MzA4 36 | 37 | This can be downloaded at 280 sample/s so it takes 7 days to download with one 32 cores 2Gbps machine. 38 | The result is 50TB (high resolution images are big and slow to download!) 39 | -------------------------------------------------------------------------------- /dataset_examples/laion400m.md: -------------------------------------------------------------------------------- 1 | ## laion-400m 2 | 3 | [laion-400m](https://laion.ai/laion-400-open-dataset/) is a 400M image text dataset 4 | 5 | ### Download the metadata 6 | 7 | ``` 8 | wget -l1 -r --no-parent https://the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/ 9 | mv the-eye.eu/public/AI/cah/laion400m-met-release/laion400m-meta/ . 10 | ``` 11 | 12 | ### Download the images with img2dataset 13 | 14 | ``` 15 | img2dataset --url_list laion400m-meta --input_format "parquet"\ 16 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 17 | --output_folder laion400m-data --processes_count 16 --thread_count 128 --image_size 256\ 18 | --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True 19 | ``` 20 | 21 | ### Benchmark 22 | 23 | This can be downloaded at 1300 sample/s so it takes 3.5 days to download with one 16 cores 2Gbps machine. 24 | The result is 10TB 25 | -------------------------------------------------------------------------------- /dataset_examples/laion5B.md: -------------------------------------------------------------------------------- 1 | ## Laion5B 2 | 3 | Laion5B has 5.86B samples 4 | See https://laion.ai/laion-5b-a-new-era-of-open-large-scale-multi-modal-datasets/ and https://rom1504.medium.com/semantic-search-at-billions-scale-95f21695689a for details. 5 | 6 | ### Download the metadata 7 | 8 | #### Normal 9 | 10 | Download from https://huggingface.co/datasets/laion/laion2B-en https://huggingface.co/datasets/laion/laion2B-multi https://huggingface.co/datasets/laion/laion1B-nolang 11 | 12 | ``` 13 | mkdir laion2B-en && cd laion2B-en 14 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-en/resolve/main/part-$i-5114fd87-297e-42b0-9d11-50f1df323dfa-c000.snappy.parquet; done 15 | cd .. 16 | ``` 17 | 18 | ``` 19 | mkdir laion2B-multi && cd laion2B-multi 20 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-multi/resolve/main/part-$i-fc82da14-99c9-4ff6-ab6a-ac853ac82819-c000.snappy.parquet; done 21 | cd .. 22 | ``` 23 | 24 | ``` 25 | mkdir laion1B-nolang && cd laion1B-nolang 26 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion1B-nolang/resolve/main/part-$i-d6a94da9-d368-4d5b-9ab7-3f6d3c7abdb3-c000.snappy.parquet; done 27 | ``` 28 | 29 | 30 | #### Joined: with punsafe and pwatermark 31 | 32 | You may also choose to download the joined collection that include the pwatermark and punsafe fields: 33 | https://huggingface.co/datasets/laion/laion2B-en-joined https://huggingface.co/datasets/laion/laion2B-multi-joined https://huggingface.co/datasets/laion/laion1B-nolang-joined 34 | 35 | ``` 36 | mkdir laion2B-en && cd laion2B-en 37 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-en-joined/resolve/main/part-$i-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet; done 38 | cd .. 39 | ``` 40 | 41 | ``` 42 | mkdir laion2B-multi && cd laion2B-multi 43 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-multi-joined/resolve/main/part-$i-fcd86c9b-36f4-49ff-bea1-8c9a0e029fb7-c000.snappy.parquet; done 44 | cd .. 45 | ``` 46 | 47 | ``` 48 | mkdir laion1B-nolang && cd laion1B-nolang 49 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion1B-nolang-joined/resolve/main/part-$i-4852663c-9585-44b0-9a45-f95c2b89c792-c000.snappy.parquet; done 50 | ``` 51 | 52 | #### With md5 hashes in addition 53 | 54 | If you want to be extra safe, you may use the collection that contain the md5 hash of the images from a download in may 2022. 55 | 56 | In that case you can use `--compute_hash "md5" --verify_hash '["md5","md5"]'` to automatically drop out the images that do not match theses hashes. 57 | 58 | As of january 2023, that means dropping about 15% of the dataset. Some of those images are actually still good but have been slightly changed by the websites. 59 | 60 | https://huggingface.co/datasets/laion/laion2B-en-md5 61 | 62 | https://huggingface.co/datasets/laion/laion2B-multi-md5 63 | 64 | https://huggingface.co/datasets/laion/laion1B-nolang-md5 65 | 66 | 67 | #### Saving to aws 68 | 69 | If instead of saving to a local folder you prefer saving on eg aws, you may use commands like this: 70 | ``` 71 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-en-joined/resolve/main/part-$i-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet -O - | aws s3 cp - s3://laion5b/metadata/laion2B-en-joined/part-$i-4cfd6e30-f032-46ee-9105-8696034a8373-c000.snappy.parquet; done 72 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion2B-multi-joined/resolve/main/part-$i-fcd86c9b-36f4-49ff-bea1-8c9a0e029fb7-c000.snappy.parquet -O - | aws s3 cp - s3://laion5b/metadata/laion2B-multi-joined/art-$i-fcd86c9b-36f4-49ff-bea1-8c9a0e029fb7-c000.snappy.parquet; done 73 | for i in {00000..00127}; do wget https://huggingface.co/datasets/laion/laion1B-nolang-joined/resolve/main/part-$i-4852663c-9585-44b0-9a45-f95c2b89c792-c000.snappy.parquet -O - | aws s3 cp - s3://laion5b/metadata/laion1B-nolang-joined/part-$i-4852663c-9585-44b0-9a45-f95c2b89c792-c000.snappy.parquet; done 74 | ``` 75 | 76 | You may also decide to put a `&` at the end of the line to download all files in parallel. (in this case expand the for loop in multiple lines for bash syntax reasons) 77 | 78 | ### Download the images 79 | 80 | This one is big so I advise doing it in distributed mode. I followed distributed_img2dataset_tutorial.md. 81 | Note some aws specifics in that guide (in particular regarding VPC and security group configs to allow worker and master to talk together) 82 | Below is some specifics. 83 | 84 | #### What infra 85 | 86 | In practice I advise to rent 1 master node and 10 worker nodes with the instance type c6i.4xlarge (16 intel cores). 87 | That makes it possible to download laion5B in a week. 88 | 89 | Each instance downloads at around 1000 sample/s. 90 | The below config produces a dataset of size 220TB. You can choose to resize to 256 instead to get a 50TB dataset. 91 | 92 | #### Script 93 | 94 | Example of master config: 95 | ``` 96 | ./spark-3.2.0-bin-hadoop3.2/sbin/start-master.sh -p 7077 97 | ``` 98 | 99 | Example of worker config: 100 | ``` 101 | parallel-ssh -l $USER -i -h ips.txt './spark-3.2.0-bin-hadoop3.2/sbin/start-worker.sh -c 16 -m 24G "spark://172.31.46.59:7077"' 102 | ``` 103 | 104 | bash 105 | ``` 106 | aws s3 rm --recursive s3://laion-us-east-1/test_output/ 107 | ./img2dataset.pex download.py 108 | ``` 109 | 110 | 111 | ```python 112 | from img2dataset import download 113 | import shutil 114 | import os 115 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 116 | 117 | from pyspark import SparkConf, SparkContext 118 | 119 | def create_spark_session(): 120 | # this must be a path that is available on all worker nodes 121 | pex_file = "/home/ubuntu/img2dataset.pex" 122 | 123 | os.environ['PYSPARK_PYTHON'] = pex_file 124 | spark = ( 125 | SparkSession.builder 126 | .config("spark.submit.deployMode", "client") \ 127 | #.config("spark.files", pex_file) \ # you may choose to uncomment this option if you want spark to automatically download the pex file, but it may be slow 128 | .config("spark.executorEnv.PEX_ROOT", "./.pex") 129 | #.config("spark.executor.cores", "16") 130 | #.config("spark.cores.max", "48") # you can reduce this number if you want to use only some cores ; if you're using yarn the option name is different, check spark doc 131 | .config("spark.driver.port", "5678") 132 | .config("spark.driver.blockManager.port", "6678") 133 | .config("spark.driver.host", "172.31.44.42") 134 | .config("spark.driver.bindAddress", "172.31.44.42") 135 | .config("spark.executor.memory", "16G") # make sure to increase this if you're using more cores per executor 136 | .config("spark.executor.memoryOverhead", "8G") 137 | .config("spark.task.maxFailures", "100") 138 | .master("spark://172.31.44.42:7077") # this should point to your master node, if using the tunnelling version, keep this to localhost 139 | .appName("spark-stats") 140 | .getOrCreate() 141 | ) 142 | return spark 143 | 144 | spark = create_spark_session() 145 | 146 | url_list = "s3://laion-us-east-1/laion-metadata/laion2B-en/" 147 | output_dir = "s3://laion-us-east-1/laion-data/laion2B-data" 148 | 149 | download( 150 | processes_count=1, 151 | thread_count=64, 152 | url_list = url_list, 153 | image_size=384, 154 | resize_only_if_bigger=True, 155 | resize_mode="keep_ratio", 156 | skip_reencode=True, 157 | output_folder=output_dir, 158 | output_format="webdataset", 159 | input_format="parquet", 160 | url_col="URL", 161 | caption_col="TEXT", 162 | enable_wandb=True, 163 | number_sample_per_shard=10000, 164 | distributor="pyspark", 165 | save_additional_columns=["NSFW","similarity","LICENSE"], 166 | oom_shard_count=6, 167 | ) 168 | ``` 169 | 170 | Will result in : 171 | ``` 172 | Total Objects: 694047 173 | Total Size: 84.8 TiB 174 | ``` 175 | 176 | Same config for laion2B-multi and laion1B-nolang 177 | -------------------------------------------------------------------------------- /dataset_examples/mscoco.md: -------------------------------------------------------------------------------- 1 | ## mscoco 2 | 3 | [mscoco](https://academictorrents.com/details/74dec1dd21ae4994dfd9069f9cb0443eb960c962) train split is a dataset of 600 thousands image and caption. 4 | 5 | 6 | ### Download the metadata 7 | 8 | `wget https://huggingface.co/datasets/ChristophSchuhmann/MS_COCO_2017_URL_TEXT/resolve/main/mscoco.parquet` 9 | That's a 18M file. It contains the train split from [mscoco](https://academictorrents.com/details/74dec1dd21ae4994dfd9069f9cb0443eb960c962) 10 | 11 | 12 | ### Download the images with img2dataset 13 | 14 | Run this command. It will download the mscoco dataset as resized images in the webdataset format. 15 | 16 | ``` 17 | img2dataset --url_list mscoco.parquet --input_format "parquet"\ 18 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 19 | --output_folder mscoco --processes_count 16 --thread_count 64 --image_size 256\ 20 | --enable_wandb True 21 | ``` 22 | 23 | ### Benchmark 24 | 25 | https://wandb.ai/rom1504/img2dataset/reports/MSCOCO--VmlldzoxMjczMTkz 26 | * 800 sample/s 27 | * total: 10min 28 | * output: 20GB 29 | 30 | -------------------------------------------------------------------------------- /doc_assets/wandb_metrics.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/doc_assets/wandb_metrics.png -------------------------------------------------------------------------------- /doc_assets/wandb_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/doc_assets/wandb_table.png -------------------------------------------------------------------------------- /examples/distributed_img2dataset_tutorial.md: -------------------------------------------------------------------------------- 1 | # Distributed img2dataset tutorial 2 | 3 | Img2dataset can be used on a single machine to download and resize at around 100 sample/s/core. 4 | For large node, that has been measure to go up to 4000 samples/s (with 40 cores). 5 | 6 | However, what if you have billion of samples and you don't want to wait weeks ? 7 | 8 | To support that use case, img2dataset proposes to use multiple machines by setting up a pyspark cluster. 9 | This document will help you setup such a cluster and run img2dataset on it. 10 | 11 | ## Where to get a cluster, what machines to use? 12 | 13 | These providers have been tested to work well with img2dataset: 14 | * aliyun small 2 cores nodes ($4.5/month for 40 sample/s) 15 | * aws c6i.4xlarge nodes ($0.68/h for 1000 sample/s) 16 | * Databricks AWS r5.2xlarge nodes ($0.504/h for 1000 sample/s) 17 | 18 | Ubuntu 20.04 works well with img2dataset. Centos7 also works. 19 | Other providers probably work too but haven't been tested. 20 | 21 | ## Setting up a pyspark cluster 22 | 23 | ### You already got a cluster 24 | 25 | That option is of course the best. If you have an existing on-premise cluster, or you're using a cloud cluster like amazon emr, then you're all set, go directly to the use img2dataset section. 26 | You may want to put https://github.com/rom1504/img2dataset/releases/latest/download/img2dataset.pex in a place that is available to all your nodes. 27 | 28 | ### You don't have a cluster, but you have access to N machines over ssh 29 | 30 | That's a common case, you have access to N machines, and you have a place to store the data. 31 | This is actually fairly easy to use this to setup a pyspark cluster. Let's see how to do it. 32 | 33 | Tools: 34 | * spark and pyspark 35 | * parallel ssh 36 | * pex 37 | 38 | We will be assuming ubuntu 20.04. 39 | 40 | 41 | #### Setup the master node 42 | 43 | On the master node: 44 | 45 | First download spark: 46 | ```bash 47 | wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz 48 | tar xf spark-3.2.0-bin-hadoop3.2.tgz 49 | ``` 50 | 51 | Then download img2dataset: 52 | ```bash 53 | wget https://github.com/rom1504/img2dataset/releases/latest/download/img2dataset.pex -O img2dataset.pex 54 | ``` 55 | 56 | If the master node cannot open ports that are visible from your local machine, you can do a tunnel between your local machine and the master node to be able to see the spark ui (at http://localhost:8080) 57 | ```bash 58 | ssh -L 8080:localhost:8080 -L 4040:localhost:4040 master_node 59 | ``` 60 | 61 | 62 | #### Setup the worker nodes 63 | 64 | ##### ssh basic setup 65 | 66 | Still in the master node, create a ips.txt with the ips of all the nodes 67 | 68 | ```bash 69 | ssh-keyscan `cat ips.txt` >> ~/.ssh/known_hosts 70 | ``` 71 | 72 | You may use a script like this to fill your .ssh/config file 73 | ``` 74 | def generate(ip): 75 | print( 76 | f"Host {ip}\n" 77 | f" HostName {ip}\n" 78 | " User ubuntu\n" 79 | " IdentityFile ~/yourkey.pem" 80 | ) 81 | 82 | with open("ips.txt") as f: 83 | lines = f.readlines() 84 | for line in lines: 85 | generate(line.strip()) 86 | ``` 87 | python3 generate.py >> ~/.ssh/config 88 | 89 | Install pssh with `sudo apt install pssh` 90 | 91 | Pick the right username (MASTER_USER) for the master node, and (USER) for the worker nodes, then run this to check your parallel ssh setup: 92 | ```bash 93 | MASTER_USER=rom1504 94 | USER=rom1504 95 | ``` 96 | 97 | Optionally, if another node than the current one has access to the worker nodes, you may need to add a ssh key to all the nodes with: 98 | ``` 99 | for IP in `cat ips.txt` 100 | do 101 | ssh-copy-id -i the_new_id_rsa $USER@$IP 102 | done 103 | ``` 104 | 105 | Check you can connect to all the nodes with: 106 | ``` 107 | parallel-ssh -l $USER -i -h ips.txt uname -a 108 | ``` 109 | 110 | ##### Install some packages 111 | 112 | ```bash 113 | sudo apt update 114 | sudo apt install openjdk-11-jre-headless libgl1 htop tmux bwm-ng sshfs -y 115 | ``` 116 | 117 | ```bash 118 | parallel-ssh -l $USER -i -h ips.txt "sudo apt update" 119 | parallel-ssh -l $USER -i -h ips.txt "sudo apt install openjdk-11-jre-headless libgl1 htop tmux bwm-ng sshfs -y" 120 | ``` 121 | 122 | 123 | #### Network setting 124 | 125 | on master: 126 | ```bash 127 | sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts' 128 | ``` 129 | 130 | on workers 131 | ```bash 132 | parallel-ssh -l $USER -i -h ips.txt "sudo sh -c 'echo \`hostname -I\` \`hostname\` >> /etc/hosts'" 133 | ``` 134 | 135 | 136 | ### Install knot resolver 137 | 138 | ```bash 139 | parallel-ssh -l $USER -i -h ips.txt "sudo apt update && sudo apt install libgl1 htop tmux bwm-ng python3.8-venv awscli -y" 140 | parallel-ssh -l $USER -i -h ips.txt "wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb && sudo dpkg -i knot-resolver-release.deb && sudo apt update && sudo apt install -y knot-resolver" 141 | ``` 142 | 143 | ```bash 144 | parallel-ssh -l $USER -i -h ips.txt "sudo systemctl stop systemd-resolved" 145 | parallel-ssh -l $USER -i -h ips.txt "sudo systemctl start kresd@{1..4}.service" 146 | parallel-ssh -l $USER -i -h ips.txt 'sudo sh -c "echo nameserver 127.0.0.1 > /etc/resolv.conf"' 147 | parallel-ssh -l $USER -i -h ips.txt 'dig @localhost google.com' 148 | ``` 149 | 150 | 151 | ##### Download img2dataset on all nodes 152 | 153 | Download img2dataset on all node by retrying this N times until parallel ssh says success for all: 154 | ```bash 155 | parallel-ssh -i -h ips.txt "wget -c https://github.com/rom1504/img2dataset/releases/latest/download/img2dataset.pex -O img2dataset_new.pex" 156 | ``` 157 | Then: 158 | ```bash 159 | parallel-ssh -l $USER -i -h ips.txt "mv img2dataset_new.pex img2dataset.pex" 160 | parallel-ssh -l $USER -i -h ips.txt "chmod +x img2dataset.pex" 161 | ``` 162 | 163 | ##### Download spark on workers 164 | 165 | ```bash 166 | parallel-ssh -l $USER -i -h ips.txt "wget https://archive.apache.org/dist/spark/spark-3.2.0/spark-3.2.0-bin-hadoop3.2.tgz" 167 | parallel-ssh -l $USER -i -h ips.txt "tar xf spark-3.2.0-bin-hadoop3.2.tgz" 168 | ``` 169 | 170 | #### Start the master node 171 | 172 | When you're ready, you can start the master node with: 173 | 174 | ```bash 175 | ./spark-3.2.0-bin-hadoop3.2/sbin/start-master.sh -h master_node -p 7077 176 | ``` 177 | 178 | Replace master_node by the master node ip. 179 | 180 | 181 | #### Start the worker nodes 182 | 183 | When you're ready, you can start the worker nodes with: 184 | 185 | ```bash 186 | parallel-ssh -l $USER -i -h ips.txt "./spark-3.2.0-bin-hadoop3.2/sbin/start-worker.sh -c 16 -m 16G spark://master_node:7077" 187 | ``` 188 | 189 | Replace master_node by the master node ip. 190 | Replace -c 16 -m 16g but the number of cores and ram you want to use on each worker. 191 | 192 | 193 | #### Stop the worker nodes 194 | 195 | When you're done, you can stop the worker nodes with: 196 | 197 | ```bash 198 | parallel-ssh -l $USER -i -h ips.txt "rm -rf ~/spark-3.2.0-bin-hadoop3.2/work/*" 199 | pkill -f "ssh -R" 200 | parallel-ssh -l $USER -i -h ips.txt "pkill java" 201 | ``` 202 | 203 | 204 | #### Stop the master node 205 | 206 | When you're done, you can stop the master node with: 207 | 208 | ```bash 209 | pkill java 210 | ``` 211 | 212 | 213 | ### Running img2dataset on it 214 | 215 | Once your spark cluster is setup, you're ready to start img2dataset in distributed mode. 216 | Make sure to open your spark UI, at http://master_node:8080 217 | 218 | Save this script to download.py. 219 | 220 | Then run ./img2dataset.pex download.py 221 | 222 | Replace master_node by the master node ip. 223 | 224 | ```python 225 | from img2dataset import download 226 | import shutil 227 | import os 228 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 229 | 230 | from pyspark import SparkConf, SparkContext 231 | 232 | def create_spark_session(): 233 | # this must be a path that is available on all worker nodes 234 | pex_file = "/home/rom1504/img2dataset.pex" 235 | 236 | os.environ['PYSPARK_PYTHON'] = pex_file 237 | spark = ( 238 | SparkSession.builder 239 | .config("spark.submit.deployMode", "client") \ 240 | #.config("spark.files", pex_file) \ # you may choose to uncomment this option if you want spark to automatically download the pex file, but it may be slow 241 | .config("spark.executorEnv.PEX_ROOT", "./.pex") 242 | #.config("spark.executor.cores", "2") # this can be set to the number of cores of the machine 243 | #.config("spark.cores.max", "200") # total number of cores to use over the whole spark cluster 244 | .config("spark.driver.port", "5678") 245 | .config("spark.driver.blockManager.port", "6678") 246 | .config("spark.driver.host", "master_node") 247 | .config("spark.driver.bindAddress", "master_node") 248 | .config("spark.executor.memory", "16GB") # make sure to increase this if you're using more cores per executor 249 | .config("spark.executor.memoryOverhead", "8GB") 250 | .config("spark.task.maxFailures", "100") 251 | .master("spark://master_node:7077") # this should point to your master node, if using the tunnelling version, keep this to localhost 252 | .appName("spark-stats") 253 | .getOrCreate() 254 | ) 255 | return spark 256 | 257 | output_dir = "/tmp/bench" 258 | 259 | 260 | spark = create_spark_session() 261 | 262 | url_list = "some_file.parquet" 263 | 264 | download( 265 | processes_count=1, # this is not used with spark, instead one task for each core will be started (nb executor * nb core per executor) 266 | thread_count=32, 267 | retries=0, 268 | url_list = url_list, 269 | image_size=384, 270 | resize_only_if_bigger=True, 271 | resize_mode="keep_ratio", 272 | skip_reencode=True, 273 | output_folder=output_dir, 274 | output_format="webdataset", 275 | input_format="parquet", 276 | url_col="URL", 277 | caption_col="TEXT", 278 | enable_wandb=False, 279 | number_sample_per_shard=10000, 280 | distributor="pyspark", 281 | save_additional_columns=["NSFW","similarity","LICENSE"] 282 | ) 283 | ``` 284 | 285 | ### You have Databricks access 286 | 287 | This [notebook](https://smellslike.ml/extras/Download_LAION_with_Databricks.html) by [@smellslikeml](https://github.com/smellslikeml/) shows how to use a Databricks's managed spark cluster. It includes the network optimizations suggested [here](https://github.com/rom1504/img2dataset#setting-up-a-high-performance-dns-resolver). 288 | -------------------------------------------------------------------------------- /examples/pyspark_example.py: -------------------------------------------------------------------------------- 1 | from img2dataset import download 2 | import shutil 3 | import os 4 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 5 | 6 | output_dir = os.path.abspath("bench") 7 | 8 | if os.path.exists(output_dir): 9 | shutil.rmtree(output_dir) 10 | 11 | spark = ( 12 | SparkSession.builder.config("spark.driver.memory", "16G").master("local[16]").appName("spark-stats").getOrCreate() 13 | ) 14 | 15 | download( 16 | processes_count=16, 17 | thread_count=32, 18 | url_list="../tests/test_files/test_10000.parquet", 19 | image_size=256, 20 | output_folder=output_dir, 21 | output_format="webdataset", 22 | input_format="parquet", 23 | url_col="URL", 24 | caption_col="TEXT", 25 | enable_wandb=True, 26 | number_sample_per_shard=1000, 27 | distributor="pyspark", 28 | ) 29 | 30 | # rm -rf bench 31 | -------------------------------------------------------------------------------- /examples/ray_example/README.md: -------------------------------------------------------------------------------- 1 | ## Instructions for running a large img2dataset job on a ray cluster on AWS 2 | First install ray: 3 | ``` pip install ray ``` 4 | 5 | If you are on AWS you can spin up a ray cluster this way: 6 | 7 | ``` ray up cluster_minimal.yaml ``` 8 | 9 | Then you can run your job: 10 | ```ray submit cluster_minimal.yaml ray_example.py -- --url_list --out_folder ``` 11 | 12 | You may also setup a ray cluster by following https://docs.ray.io/en/latest/cluster/getting-started.html 13 | 14 | -------------------------------------------------------------------------------- /examples/ray_example/cluster_minimal.yaml: -------------------------------------------------------------------------------- 1 | # An unique identifier for the head node and workers of this cluster. 2 | cluster_name: minimal 3 | min_workers: 0 4 | max_workers: 10 5 | upscaling_speed: 1.0 6 | available_node_types: 7 | ray.head.default: 8 | resources: {} 9 | node_config: 10 | ImageId: ami-0ea1c7db66fee3098 11 | InstanceType: m5.24xlarge 12 | # if you have an IamInstanceProfile fill it out here... 13 | #IamInstanceProfile: 14 | # Arn: 15 | ray.worker.default: 16 | min_workers: 0 17 | max_workers: 500 18 | node_config: 19 | ImageId: ami-0ea1c7db66fee3098 20 | InstanceType: m5.24xlarge 21 | InstanceMarketOptions: 22 | MarketType: spot 23 | # if you have an IamInstanceProfile fill it out here... 24 | #IamInstanceProfile: 25 | # Arn: 26 | 27 | # Cloud-provider specific configuration. 28 | provider: 29 | type: aws 30 | region: us-east-1 31 | 32 | initialization_commands: 33 | - wget https://secure.nic.cz/files/knot-resolver/knot-resolver-release.deb 34 | - sudo dpkg -i knot-resolver-release.deb 35 | - sudo apt update 36 | - sudo apt install -y knot-resolver 37 | - sudo sh -c 'echo `hostname -I` `hostname` >> /etc/hosts' 38 | - sudo sh -c 'echo nameserver 127.0.0.1 > /etc/resolv.conf' 39 | - sudo systemctl stop systemd-resolved 40 | - sudo systemctl start kresd@1.service 41 | - sudo systemctl start kresd@2.service 42 | - sudo systemctl start kresd@3.service 43 | - sudo systemctl start kresd@4.service 44 | - sudo systemctl start kresd@5.service 45 | - sudo systemctl start kresd@6.service 46 | - sudo systemctl start kresd@7.service 47 | - sudo systemctl start kresd@8.service 48 | - sudo apt-get install ffmpeg libsm6 libxext6 -y 49 | 50 | setup_commands: 51 | - wget https://repo.anaconda.com/miniconda/Miniconda3-py39_22.11.1-1-Linux-x86_64.sh -O miniconda.sh 52 | - bash ~/miniconda.sh -f -b -p miniconda3/ 53 | - echo 'export PATH="$HOME/miniconda3/bin/:$PATH"' >> ~/.bashrc 54 | # if you have AWS CREDS fill them out here 55 | #- echo 'export AWS_ACCESS_KEY_ID=' >> ~/.bashrc 56 | #- echo 'export AWS_SECRET_ACCESS_KEY=' >> ~/.bashrc 57 | - pip install --upgrade pip setuptools wheel 58 | - pip install ray 59 | - pip install img2dataset 60 | - pip install opencv-python --upgrade 61 | - wandb login KEY 62 | - pip install s3fs==2022.11.0 63 | - pip install botocore==1.27.59 64 | 65 | head_setup_commands: [] 66 | 67 | -------------------------------------------------------------------------------- /examples/ray_example/ray_example.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import time 3 | from collections import Counter 4 | 5 | import ray 6 | from img2dataset import download 7 | 8 | import argparse 9 | 10 | 11 | @ray.remote 12 | def main(args): 13 | download( 14 | processes_count=1, 15 | thread_count=32, 16 | retries=0, 17 | timeout=10, 18 | url_list=args.url_list, 19 | image_size=512, 20 | resize_only_if_bigger=True, 21 | resize_mode="keep_ratio_largest", 22 | skip_reencode=True, 23 | output_folder=args.out_folder, 24 | output_format="webdataset", 25 | input_format="parquet", 26 | url_col="url", 27 | caption_col="alt", 28 | enable_wandb=True, 29 | subjob_size=48 * 120 * 2, 30 | number_sample_per_shard=10000, 31 | distributor="ray", 32 | oom_shard_count=8, 33 | compute_hash="sha256", 34 | save_additional_columns=["uid"], 35 | ) 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser() 40 | parser.add_argument("--url_list") 41 | parser.add_argument("--out_folder") 42 | args = parser.parse_args() 43 | ray.init(address="localhost:6379") 44 | main(args) 45 | -------------------------------------------------------------------------------- /examples/simple_example.py: -------------------------------------------------------------------------------- 1 | from img2dataset import download 2 | import shutil 3 | import os 4 | 5 | if __name__ == "__main__": 6 | output_dir = os.path.abspath("bench") 7 | 8 | if os.path.exists(output_dir): 9 | shutil.rmtree(output_dir) 10 | 11 | download( 12 | processes_count=16, 13 | thread_count=32, 14 | url_list="../tests/test_files/test_10000.parquet", 15 | image_size=256, 16 | output_folder=output_dir, 17 | output_format="files", 18 | input_format="parquet", 19 | url_col="URL", 20 | caption_col="TEXT", 21 | enable_wandb=True, 22 | number_sample_per_shard=1000, 23 | distributor="multiprocessing", 24 | ) 25 | 26 | # rm -rf bench 27 | -------------------------------------------------------------------------------- /img2dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """Img2dataset""" 2 | 3 | from img2dataset.main import main 4 | from img2dataset.main import download 5 | -------------------------------------------------------------------------------- /img2dataset/architecture.md: -------------------------------------------------------------------------------- 1 | Img2dataset is split in these modules: 2 | 3 | * reader: read the url data and yield it as file shards (list of arrow files) 4 | * writer: write the image data 5 | * resizer: take as input images, and return resized images 6 | * downloader: takes one shard, read it to memory, resize it using the resizer, write it using the writer 7 | * main: takes a collection of files, reads them as shards using the reader, spawn N processes and in each use a downloader to process shards 8 | 9 | Main is the only one that is exposed to the user 10 | 11 | The objective of this split in modules is to make it easier to expand the functionalities (new input and output format, new resizing, new ways to distribute) 12 | 13 | -------------------------------------------------------------------------------- /img2dataset/blurrer.py: -------------------------------------------------------------------------------- 1 | """blurrer module to blur parts of the image""" 2 | 3 | import numpy as np 4 | 5 | import albumentations as A 6 | 7 | 8 | class BoundingBoxBlurrer: 9 | """blur images based on a bounding box. 10 | 11 | The bounding box used is assumed to have format [x_min, y_min, x_max, y_max] 12 | (with elements being floats in [0,1], relative to the original shape of the 13 | image). 14 | """ 15 | 16 | def __init__(self) -> None: 17 | pass 18 | 19 | def __call__(self, img, bbox_list): 20 | """Apply blurring to bboxes of an image. 21 | 22 | Args: 23 | img: The image to blur. 24 | bbox_list: The list of bboxes to blur. 25 | 26 | Returns: 27 | The image with bboxes blurred. 28 | """ 29 | 30 | # Skip if there are no boxes to blur. 31 | if len(bbox_list) == 0: 32 | return img 33 | 34 | height, width = img.shape[:2] 35 | 36 | # Convert to float temporarily 37 | img = img.astype(np.float32) / 255.0 38 | 39 | mask = np.zeros_like(img) 40 | 41 | # Incorporate max diagonal from ImageNet code. 42 | max_diagonal = 0 43 | 44 | for bbox in bbox_list: 45 | adjusted_bbox = [ 46 | int(bbox[0] * width), 47 | int(bbox[1] * height), 48 | int(bbox[2] * width), 49 | int(bbox[3] * height), 50 | ] 51 | 52 | diagonal = max(adjusted_bbox[2] - adjusted_bbox[0], adjusted_bbox[3] - adjusted_bbox[1]) 53 | max_diagonal = max(max_diagonal, diagonal) 54 | 55 | # Adjusting bbox as in: 56 | # https://github.com/princetonvisualai/imagenet-face-obfuscation 57 | adjusted_bbox[0] = int(adjusted_bbox[0] - 0.1 * diagonal) 58 | adjusted_bbox[1] = int(adjusted_bbox[1] - 0.1 * diagonal) 59 | adjusted_bbox[2] = int(adjusted_bbox[2] + 0.1 * diagonal) 60 | adjusted_bbox[3] = int(adjusted_bbox[3] + 0.1 * diagonal) 61 | 62 | # Clipping for indexing. 63 | adjusted_bbox[0] = np.clip(adjusted_bbox[0], 0, width - 1) 64 | adjusted_bbox[1] = np.clip(adjusted_bbox[1], 0, height - 1) 65 | adjusted_bbox[2] = np.clip(adjusted_bbox[2], 0, width - 1) 66 | adjusted_bbox[3] = np.clip(adjusted_bbox[3], 0, height - 1) 67 | 68 | mask[adjusted_bbox[1] : adjusted_bbox[3], adjusted_bbox[0] : adjusted_bbox[2], ...] = 1 69 | 70 | sigma = 0.1 * max_diagonal 71 | ksize = int(2 * np.ceil(4 * sigma)) + 1 72 | blurred_img = A.augmentations.gaussian_blur(img, ksize=ksize, sigma=sigma) 73 | blurred_mask = A.augmentations.gaussian_blur(mask, ksize=ksize, sigma=sigma) 74 | 75 | result = img * (1 - blurred_mask) + blurred_img * blurred_mask 76 | 77 | # Convert back to uint8 78 | result = (result * 255.0).astype(np.uint8) 79 | 80 | return result 81 | -------------------------------------------------------------------------------- /img2dataset/distributor.py: -------------------------------------------------------------------------------- 1 | """distributor defines the distribution strategies for img2dataset""" 2 | 3 | from contextlib import contextmanager 4 | from multiprocessing import get_context 5 | from itertools import islice, chain 6 | 7 | from tqdm import tqdm 8 | 9 | 10 | def retrier(runf, failed_shards, max_shard_retry): 11 | # retry failed shards max_shard_retry times 12 | for i in range(max_shard_retry): 13 | if len(failed_shards) == 0: 14 | break 15 | print(f"Retrying {len(failed_shards)} shards, try {i+1}") 16 | failed_shards = runf(failed_shards) 17 | if len(failed_shards) != 0: 18 | print( 19 | f"Retried {max_shard_retry} times, but {len(failed_shards)} shards " 20 | "still failed. You may restart the same command to retry again." 21 | ) 22 | 23 | 24 | def multiprocessing_distributor(processes_count, downloader, reader, _, max_shard_retry): 25 | """Distribute the work to the processes using multiprocessing""" 26 | ctx = get_context("spawn") 27 | with ctx.Pool(processes_count, maxtasksperchild=5) as process_pool: 28 | 29 | def run(gen): 30 | failed_shards = [] 31 | for status, row in tqdm(process_pool.imap_unordered(downloader, gen)): 32 | if status is False: 33 | failed_shards.append(row) 34 | return failed_shards 35 | 36 | failed_shards = run(reader) 37 | 38 | retrier(run, failed_shards, max_shard_retry) 39 | 40 | process_pool.terminate() 41 | process_pool.join() 42 | del process_pool 43 | 44 | 45 | def pyspark_distributor(processes_count, downloader, reader, subjob_size, max_shard_retry): 46 | """Distribute the work to the processes using pyspark""" 47 | 48 | with _spark_session(processes_count) as spark: 49 | 50 | def batcher(iterable, batch_size): 51 | iterator = iter(iterable) 52 | for first in iterator: 53 | yield list(chain([first], islice(iterator, batch_size - 1))) 54 | 55 | def run(gen): 56 | failed_shards = [] 57 | for batch in batcher(gen, subjob_size): 58 | rdd = spark.sparkContext.parallelize(batch, len(batch)) 59 | for status, row in rdd.map(downloader).collect(): 60 | if status is False: 61 | failed_shards.append(row) 62 | return failed_shards 63 | 64 | failed_shards = run(reader) 65 | 66 | retrier(run, failed_shards, max_shard_retry) 67 | 68 | 69 | try: 70 | import ray # pylint: disable=import-outside-toplevel 71 | 72 | @ray.remote 73 | def ray_download(downloader, shards): 74 | status, row = downloader(shards) 75 | return status, row 76 | 77 | def ray_distributor(processes_count, downloader, reader, _, max_shard_retry): # type: ignore 78 | # pylint: disable=unused-argument 79 | ret = [] 80 | count = 0 81 | for task in reader: 82 | count += 1 83 | ret.append(ray_download.remote(downloader, task)) 84 | ray.get(ret) 85 | 86 | except ModuleNotFoundError as e: 87 | 88 | def ray_distributor(processes_count, downloader, reader, subjob_size, max_shard_retry): # type: ignore # pylint: disable=unused-argument 89 | return None 90 | 91 | 92 | @contextmanager 93 | def _spark_session(processes_count: int): 94 | """Create and close a spark session if none exist""" 95 | 96 | from pyspark.sql import SparkSession # pylint: disable=import-outside-toplevel 97 | import pyspark # pylint: disable=import-outside-toplevel 98 | 99 | spark_major_version = int(pyspark.version.__version__[0]) 100 | if spark_major_version >= 3: 101 | spark = SparkSession.getActiveSession() 102 | else: 103 | spark = pyspark.sql.SparkSession._instantiatedSession # type: ignore # pylint: disable=protected-access 104 | 105 | if spark is None: 106 | print("No pyspark session found, creating a new one!") 107 | owned = True 108 | spark = ( 109 | SparkSession.builder.config("spark.driver.memory", "16G") 110 | .master("local[" + str(processes_count) + "]") 111 | .appName("spark-stats") 112 | .getOrCreate() 113 | ) 114 | else: 115 | owned = False 116 | 117 | try: 118 | yield spark 119 | finally: 120 | if owned: 121 | spark.stop() 122 | -------------------------------------------------------------------------------- /img2dataset/downloader.py: -------------------------------------------------------------------------------- 1 | """the downloader module handles the downloading""" 2 | 3 | from multiprocessing.pool import ThreadPool 4 | from threading import Semaphore 5 | import urllib.request 6 | import io 7 | import math 8 | import exifread 9 | import json 10 | import time 11 | import hashlib 12 | import pyarrow as pa 13 | import traceback 14 | 15 | import fsspec 16 | from .logger import CappedCounter 17 | from .logger import write_stats 18 | 19 | 20 | def is_disallowed(headers, user_agent_token, disallowed_header_directives): 21 | """Check if HTTP headers contain an X-Robots-Tag directive disallowing usage""" 22 | for values in headers.get_all("X-Robots-Tag", []): 23 | try: 24 | uatoken_directives = values.split(":", 1) 25 | directives = [x.strip().lower() for x in uatoken_directives[-1].split(",")] 26 | ua_token = uatoken_directives[0].lower() if len(uatoken_directives) == 2 else None 27 | if (ua_token is None or ua_token == user_agent_token) and any( 28 | x in disallowed_header_directives for x in directives 29 | ): 30 | return True 31 | except Exception as err: # pylint: disable=broad-except 32 | traceback.print_exc() 33 | print(f"Failed to parse X-Robots-Tag: {values}: {err}") 34 | return False 35 | 36 | 37 | def download_image(row, timeout, user_agent_token, disallowed_header_directives): 38 | """Download an image with urllib""" 39 | key, url = row 40 | img_stream = None 41 | user_agent_string = "Mozilla/5.0 (X11; Ubuntu; Linux x86_64; rv:72.0) Gecko/20100101 Firefox/72.0" 42 | if user_agent_token: 43 | user_agent_string += f" (compatible; {user_agent_token}; +https://github.com/rom1504/img2dataset)" 44 | try: 45 | request = urllib.request.Request(url, data=None, headers={"User-Agent": user_agent_string}) 46 | with urllib.request.urlopen(request, timeout=timeout) as r: 47 | if disallowed_header_directives and is_disallowed( 48 | r.headers, 49 | user_agent_token, 50 | disallowed_header_directives, 51 | ): 52 | return key, None, "Use of image disallowed by X-Robots-Tag directive" 53 | img_stream = io.BytesIO(r.read()) 54 | return key, img_stream, None 55 | except Exception as err: # pylint: disable=broad-except 56 | if img_stream is not None: 57 | img_stream.close() 58 | return key, None, str(err) 59 | 60 | 61 | def download_image_with_retry(row, timeout, retries, user_agent_token, disallowed_header_directives): 62 | for _ in range(retries + 1): 63 | key, img_stream, err = download_image(row, timeout, user_agent_token, disallowed_header_directives) 64 | if img_stream is not None: 65 | return key, img_stream, err 66 | return key, None, err 67 | 68 | 69 | def compute_key(key, shard_id, oom_sample_per_shard, oom_shard_count): 70 | true_key = (10**oom_sample_per_shard) * shard_id + key 71 | key_format = oom_sample_per_shard + oom_shard_count 72 | str_key = "{true_key:0{key_format}d}".format( # pylint: disable=consider-using-f-string 73 | key_format=key_format, true_key=true_key 74 | ) 75 | return str_key 76 | 77 | 78 | class Downloader: 79 | """The downloader class gets calls with shards, download them then call the writer to write them down""" 80 | 81 | def __init__( 82 | self, 83 | sample_writer_class, 84 | resizer, 85 | thread_count, 86 | save_caption, 87 | extract_exif, 88 | output_folder, 89 | column_list, 90 | timeout, 91 | number_sample_per_shard, 92 | oom_shard_count, 93 | compute_hash, 94 | verify_hash_type, 95 | encode_format, 96 | retries, 97 | user_agent_token, 98 | disallowed_header_directives, 99 | blurring_bbox_col=None, 100 | ) -> None: 101 | self.sample_writer_class = sample_writer_class 102 | self.resizer = resizer 103 | self.thread_count = thread_count 104 | self.save_caption = save_caption 105 | self.extract_exif = extract_exif 106 | self.output_folder = output_folder 107 | self.column_list = column_list 108 | self.timeout = timeout 109 | self.number_sample_per_shard = number_sample_per_shard 110 | self.oom_shard_count = oom_shard_count 111 | self.compute_hash = compute_hash 112 | self.verify_hash_type = verify_hash_type 113 | self.encode_format = encode_format 114 | self.retries = retries 115 | self.user_agent_token = None if user_agent_token is None else user_agent_token.strip().lower() 116 | self.disallowed_header_directives = ( 117 | None 118 | if disallowed_header_directives is None 119 | else {directive.strip().lower() for directive in disallowed_header_directives} 120 | ) 121 | self.blurring_bbox_col = blurring_bbox_col 122 | 123 | def __call__( 124 | self, 125 | row, 126 | ): 127 | try: 128 | self.download_shard(row) 129 | return (True, row) 130 | except Exception as err: # pylint: disable=broad-except 131 | traceback.print_exc() 132 | print(f"shard {row[0]} failed with error {err}") 133 | return (False, row) 134 | 135 | def download_shard( 136 | self, 137 | row, 138 | ): 139 | """Function to start an image downloading in one process""" 140 | 141 | shard_id, shard_file = row 142 | start_time = time.time() 143 | 144 | fs, shard_path = fsspec.core.url_to_fs(shard_file) 145 | with fs.open(shard_path, "rb") as f: 146 | df = pa.ipc.open_file(f).read_all() 147 | schema = df.schema 148 | schema = ( 149 | schema.append(pa.field("key", pa.string())) 150 | .append(pa.field("status", pa.string())) 151 | .append(pa.field("error_message", pa.string())) 152 | .append(pa.field("width", pa.int32())) 153 | .append(pa.field("height", pa.int32())) 154 | .append(pa.field("original_width", pa.int32())) 155 | .append(pa.field("original_height", pa.int32())) 156 | ) 157 | if self.extract_exif: 158 | schema = schema.append(pa.field("exif", pa.string())) 159 | 160 | if self.compute_hash is not None and self.compute_hash not in schema.names: 161 | schema = schema.append(pa.field(self.compute_hash, pa.string())) 162 | 163 | pydict = df.select(self.column_list).to_pydict() 164 | shard_to_dl = list(enumerate(zip(*(pydict[col] for col in self.column_list)))) 165 | del pydict 166 | del df 167 | 168 | status_dict = CappedCounter() 169 | 170 | count = len(shard_to_dl) 171 | successes = 0 172 | failed_to_download = 0 173 | failed_to_resize = 0 174 | url_indice = self.column_list.index("url") 175 | caption_indice = self.column_list.index("caption") if "caption" in self.column_list else None 176 | hash_indice = ( 177 | self.column_list.index(self.verify_hash_type) if self.verify_hash_type in self.column_list else None 178 | ) 179 | bbox_indice = self.column_list.index(self.blurring_bbox_col) if self.blurring_bbox_col is not None else None 180 | key_url_list = [(key, x[url_indice]) for key, x in shard_to_dl] 181 | 182 | # this prevents an accumulation of more than twice the number of threads in sample ready to resize 183 | # limit the memory usage 184 | semaphore = Semaphore(self.thread_count * 2) 185 | 186 | def data_generator(): 187 | for e in key_url_list: 188 | semaphore.acquire() # pylint: disable=consider-using-with 189 | yield e 190 | 191 | loader = data_generator() 192 | 193 | # give schema to writer 194 | sample_writer = self.sample_writer_class( 195 | shard_id, 196 | self.output_folder, 197 | self.save_caption, 198 | self.oom_shard_count, 199 | schema, 200 | self.encode_format, 201 | ) 202 | oom_sample_per_shard = math.ceil(math.log10(self.number_sample_per_shard)) 203 | with ThreadPool(self.thread_count) as thread_pool: 204 | for key, img_stream, error_message in thread_pool.imap_unordered( 205 | lambda x: download_image_with_retry( 206 | x, 207 | timeout=self.timeout, 208 | retries=self.retries, 209 | user_agent_token=self.user_agent_token, 210 | disallowed_header_directives=self.disallowed_header_directives, 211 | ), 212 | loader, 213 | ): 214 | try: 215 | _, sample_data = shard_to_dl[key] 216 | str_key = compute_key(key, shard_id, oom_sample_per_shard, self.oom_shard_count) 217 | meta = { 218 | # Skip columns containing a the verification hash and only save the compute hash 219 | **{ 220 | self.column_list[i]: sample_data[i] 221 | for i in range(len(self.column_list)) 222 | if (hash_indice is None or i != hash_indice) 223 | }, 224 | "key": str_key, 225 | "status": None, 226 | "error_message": error_message, 227 | "width": None, 228 | "height": None, 229 | "original_width": None, 230 | "original_height": None, 231 | } 232 | if self.extract_exif: 233 | meta["exif"] = None 234 | 235 | if self.compute_hash is not None: 236 | meta[self.compute_hash] = None 237 | 238 | if error_message is not None: 239 | failed_to_download += 1 240 | status = "failed_to_download" 241 | status_dict.increment(error_message) 242 | meta["status"] = status 243 | sample_writer.write( 244 | None, 245 | str_key, 246 | sample_data[caption_indice] if caption_indice is not None else None, 247 | meta, 248 | ) 249 | semaphore.release() 250 | continue 251 | 252 | if hash_indice is not None: 253 | img_stream.seek(0) 254 | test_hash = getattr(hashlib, self.verify_hash_type)(img_stream.read()).hexdigest() 255 | if test_hash != sample_data[hash_indice]: 256 | failed_to_download += 1 257 | status = "failed_to_download" 258 | status_dict.increment("hash mismatch") 259 | meta["status"] = status 260 | meta["error_message"] = "hash mismatch" 261 | sample_writer.write( 262 | None, 263 | str_key, 264 | sample_data[caption_indice] if caption_indice is not None else None, 265 | meta, 266 | ) 267 | img_stream.close() 268 | del img_stream 269 | semaphore.release() 270 | continue 271 | 272 | img_stream.seek(0) 273 | bbox_list = sample_data[bbox_indice] if bbox_indice is not None else None 274 | ( 275 | img, 276 | width, 277 | height, 278 | original_width, 279 | original_height, 280 | error_message, 281 | ) = self.resizer(img_stream, bbox_list) 282 | if error_message is not None: 283 | failed_to_resize += 1 284 | status = "failed_to_resize" 285 | status_dict.increment(error_message) 286 | meta["status"] = status 287 | meta["error_message"] = error_message 288 | sample_writer.write( 289 | None, 290 | str_key, 291 | sample_data[caption_indice] if caption_indice is not None else None, 292 | meta, 293 | ) 294 | img_stream.close() 295 | del img_stream 296 | semaphore.release() 297 | continue 298 | successes += 1 299 | status = "success" 300 | status_dict.increment(status) 301 | 302 | if self.extract_exif: 303 | try: 304 | img_stream.seek(0) 305 | exif = json.dumps( 306 | { 307 | k: str(v).strip() 308 | for k, v in exifread.process_file(img_stream, details=False).items() 309 | if v is not None 310 | } 311 | ) 312 | except Exception as _: # pylint: disable=broad-except 313 | exif = None 314 | meta["exif"] = exif 315 | 316 | if self.compute_hash is not None: 317 | img_stream.seek(0) 318 | meta[self.compute_hash] = getattr(hashlib, self.compute_hash)(img_stream.read()).hexdigest() 319 | 320 | meta["status"] = status 321 | meta["width"] = width 322 | meta["height"] = height 323 | meta["original_width"] = original_width 324 | meta["original_height"] = original_height 325 | img_stream.close() 326 | del img_stream 327 | 328 | sample_writer.write( 329 | img, 330 | str_key, 331 | sample_data[caption_indice] if caption_indice is not None else None, 332 | meta, 333 | ) 334 | except Exception as err: # pylint: disable=broad-except 335 | traceback.print_exc() 336 | print(f"Sample {key} failed to download: {err}") 337 | semaphore.release() 338 | 339 | sample_writer.close() 340 | thread_pool.terminate() 341 | thread_pool.join() 342 | del thread_pool 343 | 344 | end_time = time.time() 345 | write_stats( 346 | self.output_folder, 347 | shard_id, 348 | count, 349 | successes, 350 | failed_to_download, 351 | failed_to_resize, 352 | start_time, 353 | end_time, 354 | status_dict, 355 | self.oom_shard_count, 356 | ) 357 | fs.rm(shard_path) 358 | -------------------------------------------------------------------------------- /img2dataset/logger.py: -------------------------------------------------------------------------------- 1 | """logging utils for the downloader""" 2 | 3 | import wandb 4 | import time 5 | from collections import Counter 6 | import fsspec 7 | import json 8 | import multiprocessing 9 | import queue 10 | import traceback 11 | 12 | 13 | class CappedCounter: 14 | """Maintain a counter with a capping to avoid memory issues""" 15 | 16 | def __init__(self, max_size=10**5): 17 | self.max_size = max_size 18 | self.counter = Counter() 19 | 20 | def increment(self, key): 21 | if len(self.counter) >= self.max_size: 22 | self._keep_most_frequent() 23 | self.counter[key] += 1 24 | 25 | def _keep_most_frequent(self): 26 | self.counter = Counter(dict(self.counter.most_common(int(self.max_size / 2)))) 27 | 28 | def most_common(self, k): 29 | return self.counter.most_common(k) 30 | 31 | def update(self, counter): 32 | self.counter.update(counter.counter) 33 | if len(self.counter) >= self.max_size: 34 | self._keep_most_frequent() 35 | 36 | def dump(self): 37 | return self.counter 38 | 39 | @classmethod 40 | def load(cls, d, max_size=10**5): 41 | c = CappedCounter(max_size) 42 | c.counter = Counter(d) 43 | return c 44 | 45 | 46 | class Logger: 47 | """logger which logs when number of calls reaches a value or a time interval has passed""" 48 | 49 | def __init__(self, min_interval=0): 50 | """Log only every if min_interval (seconds) have elapsed since last log""" 51 | # wait for all processes to return 52 | self.processes_returned = 0 53 | # min time (in seconds) before logging a new table (avoids too many logs) 54 | self.min_interval = min_interval 55 | self.last = time.perf_counter() 56 | # keep track of whether we logged the last call 57 | self.last_call_logged = False 58 | self.last_args = None 59 | self.last_kwargs = None 60 | 61 | def __call__(self, *args, **kwargs): 62 | self.processes_returned += 1 63 | if time.perf_counter() - self.last > self.min_interval: 64 | self.do_log(*args, **kwargs) 65 | self.last = time.perf_counter() 66 | self.last_call_logged = True 67 | else: 68 | self.last_call_logged = False 69 | self.last_args = args 70 | self.last_kwargs = kwargs 71 | 72 | def do_log(self, *args, **kwargs): 73 | raise NotImplementedError() 74 | 75 | def sync(self): 76 | """Ensure last call is logged""" 77 | if not self.last_call_logged and self.last_args is not None: 78 | self.do_log(*self.last_args, **self.last_kwargs) 79 | # reset for next file 80 | self.processes_returned = 0 81 | 82 | 83 | class SpeedLogger(Logger): 84 | """Log performance metrics""" 85 | 86 | def __init__(self, prefix, enable_wandb, **logger_args): 87 | super().__init__(**logger_args) 88 | self.prefix = prefix 89 | self.start_time = float("+inf") 90 | self.end_time = float("-inf") 91 | self.count = 0 92 | self.success = 0 93 | self.failed_to_download = 0 94 | self.failed_to_resize = 0 95 | self.enable_wandb = enable_wandb 96 | 97 | def __call__( 98 | self, count, success, failed_to_download, failed_to_resize, start_time, end_time 99 | ): # pylint: disable=arguments-differ 100 | self.count += count 101 | self.success += success 102 | self.failed_to_download += failed_to_download 103 | self.failed_to_resize += failed_to_resize 104 | self.start_time = min(start_time, self.start_time) 105 | self.end_time = max(end_time, self.end_time) 106 | super().__call__( 107 | self.count, self.success, self.failed_to_download, self.failed_to_resize, self.start_time, self.end_time 108 | ) 109 | 110 | def do_log( 111 | self, count, success, failed_to_download, failed_to_resize, start_time, end_time 112 | ): # pylint: disable=arguments-differ 113 | duration = end_time - start_time 114 | img_per_sec = count / duration 115 | success_ratio = 1.0 * success / count 116 | failed_to_download_ratio = 1.0 * failed_to_download / count 117 | failed_to_resize_ratio = 1.0 * failed_to_resize / count 118 | 119 | print( 120 | " - ".join( 121 | [ 122 | f"{self.prefix:<7}", 123 | f"success: {success_ratio:.3f}", 124 | f"failed to download: {failed_to_download_ratio:.3f}", 125 | f"failed to resize: {failed_to_resize_ratio:.3f}", 126 | f"images per sec: {img_per_sec:.0f}", 127 | f"count: {count}", 128 | ] 129 | ) 130 | ) 131 | 132 | if self.enable_wandb: 133 | wandb.log( 134 | { 135 | f"{self.prefix}/img_per_sec": img_per_sec, 136 | f"{self.prefix}/success": success_ratio, 137 | f"{self.prefix}/failed_to_download": failed_to_download_ratio, 138 | f"{self.prefix}/failed_to_resize": failed_to_resize_ratio, 139 | f"{self.prefix}/count": count, 140 | } 141 | ) 142 | 143 | 144 | class StatusTableLogger(Logger): 145 | """Log status table to W&B, up to `max_status` most frequent items""" 146 | 147 | def __init__(self, max_status=100, min_interval=60, enable_wandb=False, **logger_args): 148 | super().__init__(min_interval=min_interval, **logger_args) 149 | # avoids too many errors unique to a specific website (SSL certificates, etc) 150 | self.max_status = max_status 151 | self.enable_wandb = enable_wandb 152 | 153 | def do_log(self, status_dict, count): # pylint: disable=arguments-differ 154 | if self.enable_wandb: 155 | status_table = wandb.Table( 156 | columns=["status", "frequency", "count"], 157 | data=[[k, 1.0 * v / count, v] for k, v in status_dict.most_common(self.max_status)], 158 | ) 159 | wandb.run.log({"status": status_table}) 160 | 161 | 162 | def write_stats( 163 | output_folder, 164 | shard_id, 165 | count, 166 | successes, 167 | failed_to_download, 168 | failed_to_resize, 169 | start_time, 170 | end_time, 171 | status_dict, 172 | oom_shard_count, 173 | ): 174 | """Write stats to disk""" 175 | stats = { 176 | "count": count, 177 | "successes": successes, 178 | "failed_to_download": failed_to_download, 179 | "failed_to_resize": failed_to_resize, 180 | "duration": end_time - start_time, 181 | "start_time": start_time, 182 | "end_time": end_time, 183 | "status_dict": status_dict.dump(), 184 | } 185 | fs, output_path = fsspec.core.url_to_fs(output_folder) 186 | shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string 187 | shard_id=shard_id, oom_shard_count=oom_shard_count 188 | ) 189 | json_file = f"{output_path}/{shard_name}_stats.json" 190 | with fs.open(json_file, "w") as f: 191 | json.dump(stats, f, indent=4) 192 | 193 | 194 | # https://docs.python.org/3/library/multiprocessing.html 195 | # logger process that reads stats files regularly, aggregates and send to wandb / print to terminal 196 | class LoggerProcess(multiprocessing.context.SpawnProcess): 197 | """Logger process that reads stats files regularly, aggregates and send to wandb / print to terminal""" 198 | 199 | def __init__(self, output_folder, enable_wandb, wandb_project, config_parameters, log_interval=5): 200 | super().__init__() 201 | self.log_interval = log_interval 202 | self.enable_wandb = enable_wandb 203 | self.output_folder = output_folder 204 | self.stats_files = set() 205 | self.wandb_project = wandb_project 206 | self.done_shards = set() 207 | self.config_parameters = config_parameters 208 | ctx = multiprocessing.get_context("spawn") 209 | self.q = ctx.Queue() 210 | 211 | def run(self): 212 | """Run logger process""" 213 | 214 | fs, output_path = fsspec.core.url_to_fs(self.output_folder, use_listings_cache=False) 215 | 216 | if self.enable_wandb: 217 | self.current_run = wandb.init(project=self.wandb_project, config=self.config_parameters, anonymous="allow") 218 | else: 219 | self.current_run = None 220 | self.total_speed_logger = SpeedLogger("total", enable_wandb=self.enable_wandb) 221 | self.status_table_logger = StatusTableLogger(enable_wandb=self.enable_wandb) 222 | last_check = 0 223 | total_status_dict = CappedCounter() 224 | while True: 225 | time.sleep(0.1) 226 | try: 227 | self.q.get(False) 228 | last_one = True 229 | except queue.Empty as _: 230 | last_one = False 231 | if not last_one and time.perf_counter() - last_check < self.log_interval: 232 | continue 233 | 234 | try: 235 | # read stats files 236 | stats_files = fs.glob(output_path + "/*.json") 237 | 238 | # filter out files that have an id smaller that are already done 239 | stats_files = [f for f in stats_files if int(f.split("/")[-1].split("_")[0]) not in self.done_shards] 240 | 241 | # get new stats files 242 | new_stats_files = set(stats_files) - self.stats_files 243 | if len(new_stats_files) == 0: 244 | if last_one: 245 | self.finish() 246 | return 247 | 248 | # read new stats files 249 | for stats_file in new_stats_files: 250 | with fs.open(stats_file, "r") as f: 251 | try: 252 | stats = json.load(f) 253 | SpeedLogger("worker", enable_wandb=self.enable_wandb)( 254 | count=stats["count"], 255 | success=stats["successes"], 256 | failed_to_download=stats["failed_to_download"], 257 | failed_to_resize=stats["failed_to_resize"], 258 | start_time=stats["start_time"], 259 | end_time=stats["end_time"], 260 | ) 261 | self.total_speed_logger( 262 | count=stats["count"], 263 | success=stats["successes"], 264 | failed_to_download=stats["failed_to_download"], 265 | failed_to_resize=stats["failed_to_resize"], 266 | start_time=stats["start_time"], 267 | end_time=stats["end_time"], 268 | ) 269 | status_dict = CappedCounter.load(stats["status_dict"]) 270 | total_status_dict.update(status_dict) 271 | self.status_table_logger(total_status_dict, self.total_speed_logger.count) 272 | except Exception as err: # pylint: disable=broad-except 273 | print(f"failed to parse stats file {stats_file}", err) 274 | 275 | self.stats_files.add(stats_file) 276 | last_check = time.perf_counter() 277 | 278 | if last_one: 279 | self.finish() 280 | return 281 | except Exception as e: # pylint: disable=broad-except 282 | traceback.print_exc() 283 | print("logger error", e) 284 | self.finish() 285 | return 286 | 287 | def finish(self): 288 | """Finish logger process""" 289 | self.total_speed_logger.sync() 290 | self.status_table_logger.sync() 291 | if self.current_run is not None: 292 | self.current_run.finish() 293 | 294 | def join(self, timeout=None): 295 | """Stop logger process""" 296 | self.q.put("stop") 297 | super().join() 298 | self.q.close() 299 | -------------------------------------------------------------------------------- /img2dataset/main.py: -------------------------------------------------------------------------------- 1 | """Img2dataset""" 2 | 3 | from typing import List, Optional 4 | import fire 5 | import logging 6 | from .logger import LoggerProcess 7 | from .resizer import Resizer 8 | from .blurrer import BoundingBoxBlurrer 9 | from .writer import ( 10 | WebDatasetSampleWriter, 11 | FilesSampleWriter, 12 | ParquetSampleWriter, 13 | TFRecordSampleWriter, 14 | DummySampleWriter, 15 | ) 16 | from .reader import Reader 17 | from .downloader import Downloader 18 | from .distributor import ( 19 | multiprocessing_distributor, 20 | pyspark_distributor, 21 | ray_distributor, 22 | ) 23 | import fsspec 24 | import sys 25 | import signal 26 | import os 27 | 28 | logging.getLogger("exifread").setLevel(level=logging.CRITICAL) 29 | 30 | 31 | def arguments_validator(params): 32 | """Validate the arguments""" 33 | if params["compute_hash"] not in [None, "md5", "sha256", "sha512"]: 34 | hash_type = params["compute_hash"] 35 | raise ValueError(f"Unsupported hash to compute: {hash_type}") 36 | 37 | if params["verify_hash"] is not None: 38 | _, verify_hash_type = params["verify_hash"] 39 | if verify_hash_type != params["compute_hash"]: 40 | raise ValueError( 41 | "verify_hash and compute_hash must be the same " 42 | f"but got {verify_hash_type} and {params['compute_hash']}" 43 | ) 44 | 45 | if params["save_additional_columns"] is not None: 46 | save_additional_columns_set = set(params["save_additional_columns"]) 47 | 48 | forbidden_columns = set( 49 | [ 50 | "key", 51 | "caption", 52 | "url", 53 | "width", 54 | "height", 55 | "original_width", 56 | "original_height", 57 | "status", 58 | "error_message", 59 | "exif", 60 | "md5", 61 | "sha256", 62 | "sha512", 63 | ] 64 | ) 65 | intersection = save_additional_columns_set.intersection(forbidden_columns) 66 | if intersection: 67 | raise ValueError( 68 | f"You cannot use in save_additional_columns the following columns: {intersection}." 69 | + "img2dataset reserves these columns for its own use. Please remove them from save_additional_columns." 70 | ) 71 | 72 | 73 | def download( 74 | url_list: str, 75 | image_size: int = 256, 76 | output_folder: str = "images", 77 | processes_count: int = 1, 78 | resize_mode: str = "border", 79 | resize_only_if_bigger: bool = False, 80 | upscale_interpolation: str = "lanczos", 81 | downscale_interpolation: str = "area", 82 | encode_quality: int = 95, 83 | encode_format: str = "jpg", 84 | skip_reencode: bool = False, 85 | output_format: str = "files", 86 | input_format: str = "txt", 87 | url_col: str = "url", 88 | caption_col: Optional[str] = None, 89 | bbox_col: Optional[str] = None, 90 | thread_count: int = 256, 91 | number_sample_per_shard: int = 10000, 92 | extract_exif: bool = True, 93 | save_additional_columns: Optional[List[str]] = None, 94 | timeout: int = 10, 95 | enable_wandb: bool = False, 96 | wandb_project: str = "img2dataset", 97 | oom_shard_count: int = 5, 98 | compute_hash: Optional[str] = "sha256", 99 | verify_hash: Optional[List[str]] = None, 100 | distributor: str = "multiprocessing", 101 | subjob_size: int = 1000, 102 | retries: int = 0, 103 | disable_all_reencoding: bool = False, 104 | min_image_size: int = 0, 105 | max_image_area: float = float("inf"), 106 | max_aspect_ratio: float = float("inf"), 107 | incremental_mode: str = "incremental", 108 | max_shard_retry: int = 1, 109 | user_agent_token: Optional[str] = None, 110 | disallowed_header_directives: Optional[List[str]] = None, 111 | ): 112 | """Download is the main entry point of img2dataset, it uses multiple processes and download multiple files""" 113 | if disallowed_header_directives is None: 114 | disallowed_header_directives = ["noai", "noimageai", "noindex", "noimageindex"] 115 | if len(disallowed_header_directives) == 0: 116 | disallowed_header_directives = None 117 | 118 | config_parameters = dict(locals()) 119 | arguments_validator(config_parameters) 120 | 121 | def make_path_absolute(path): 122 | fs, p = fsspec.core.url_to_fs(path) 123 | if fs.protocol == "file": 124 | return os.path.abspath(p) 125 | return path 126 | 127 | output_folder = make_path_absolute(output_folder) 128 | url_list = make_path_absolute(url_list) 129 | 130 | logger_process = LoggerProcess(output_folder, enable_wandb, wandb_project, config_parameters) 131 | 132 | tmp_path = output_folder + "/_tmp" 133 | fs, tmp_dir = fsspec.core.url_to_fs(tmp_path) 134 | if not fs.exists(tmp_dir): 135 | fs.mkdir(tmp_dir) 136 | 137 | def signal_handler(signal_arg, frame): # pylint: disable=unused-argument 138 | try: 139 | fs.rm(tmp_dir, recursive=True) 140 | except Exception as _: # pylint: disable=broad-except 141 | pass 142 | logger_process.terminate() 143 | sys.exit(0) 144 | 145 | signal.signal(signal.SIGINT, signal_handler) 146 | 147 | save_caption = caption_col is not None 148 | 149 | fs, output_path = fsspec.core.url_to_fs(output_folder) 150 | start_shard_id = 0 151 | 152 | if not fs.exists(output_path): 153 | fs.mkdir(output_path) 154 | done_shards = set() 155 | else: 156 | if incremental_mode == "incremental": 157 | done_shards = set(int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json")) 158 | elif incremental_mode == "overwrite": 159 | fs.rm(output_path, recursive=True) 160 | fs.mkdir(output_path) 161 | done_shards = set() 162 | elif incremental_mode == "extend": 163 | existing_shards = [int(x.split("/")[-1].split("_")[0]) for x in fs.glob(output_path + "/*.json")] 164 | start_shard_id = max(existing_shards, default=-1) + 1 165 | done_shards = set() 166 | else: 167 | raise ValueError(f"Unknown incremental mode {incremental_mode}") 168 | 169 | logger_process.done_shards = done_shards 170 | logger_process.start() 171 | 172 | if bbox_col is not None: 173 | if save_additional_columns is None: 174 | save_additional_columns = [bbox_col] 175 | else: 176 | save_additional_columns.append(bbox_col) 177 | 178 | if verify_hash is not None: 179 | verify_hash_col, verify_hash_type = verify_hash 180 | else: 181 | verify_hash_col = None 182 | verify_hash_type = None 183 | 184 | reader = Reader( 185 | url_list, 186 | input_format, 187 | url_col, 188 | caption_col, 189 | verify_hash_col, 190 | verify_hash_type, 191 | save_additional_columns, 192 | number_sample_per_shard, 193 | done_shards, 194 | tmp_path, 195 | start_shard_id, 196 | ) 197 | 198 | if output_format == "webdataset": 199 | sample_writer_class = WebDatasetSampleWriter 200 | elif output_format == "parquet": 201 | sample_writer_class = ParquetSampleWriter # type: ignore 202 | elif output_format == "files": 203 | sample_writer_class = FilesSampleWriter # type: ignore 204 | elif output_format == "tfrecord": 205 | sample_writer_class = TFRecordSampleWriter # type: ignore 206 | elif output_format == "dummy": 207 | sample_writer_class = DummySampleWriter # type: ignore 208 | else: 209 | raise ValueError(f"Invalid output format {output_format}") 210 | 211 | if bbox_col is not None: 212 | blurrer = BoundingBoxBlurrer() 213 | else: 214 | blurrer = None 215 | 216 | resizer = Resizer( 217 | image_size=image_size, 218 | resize_mode=resize_mode, 219 | resize_only_if_bigger=resize_only_if_bigger, 220 | upscale_interpolation=upscale_interpolation, 221 | downscale_interpolation=downscale_interpolation, 222 | encode_quality=encode_quality, 223 | encode_format=encode_format, 224 | skip_reencode=skip_reencode, 225 | disable_all_reencoding=disable_all_reencoding, 226 | min_image_size=min_image_size, 227 | max_image_area=max_image_area, 228 | max_aspect_ratio=max_aspect_ratio, 229 | blurrer=blurrer, 230 | ) 231 | 232 | downloader = Downloader( 233 | sample_writer_class=sample_writer_class, 234 | resizer=resizer, 235 | thread_count=thread_count, 236 | save_caption=save_caption, 237 | extract_exif=extract_exif, 238 | output_folder=output_folder, 239 | column_list=reader.column_list, 240 | timeout=timeout, 241 | number_sample_per_shard=number_sample_per_shard, 242 | oom_shard_count=oom_shard_count, 243 | compute_hash=compute_hash, 244 | verify_hash_type=verify_hash_type, 245 | encode_format=encode_format, 246 | retries=retries, 247 | user_agent_token=user_agent_token, 248 | disallowed_header_directives=disallowed_header_directives, 249 | blurring_bbox_col=bbox_col, 250 | ) 251 | 252 | print("Starting the downloading of this file") 253 | if distributor == "multiprocessing": 254 | distributor_fn = multiprocessing_distributor 255 | elif distributor == "pyspark": 256 | distributor_fn = pyspark_distributor 257 | elif distributor == "ray": 258 | distributor_fn = ray_distributor 259 | else: 260 | raise ValueError(f"Distributor {distributor} not supported") 261 | 262 | distributor_fn( 263 | processes_count, 264 | downloader, 265 | reader, 266 | subjob_size, 267 | max_shard_retry, 268 | ) 269 | logger_process.join() 270 | 271 | if not hasattr(fs, "s3"): 272 | fs.rm(tmp_dir, recursive=True) 273 | 274 | 275 | def main(): 276 | fire.Fire(download) 277 | 278 | 279 | if __name__ == "__main__": 280 | main() 281 | -------------------------------------------------------------------------------- /img2dataset/reader.py: -------------------------------------------------------------------------------- 1 | """Reader is module to read the url list and return shards""" 2 | 3 | from multiprocessing.pool import ThreadPool 4 | import math 5 | import fsspec 6 | import time 7 | import pyarrow.parquet as pq 8 | import pyarrow.csv as csv_pa 9 | import pyarrow.json as json_pa 10 | import pyarrow as pa 11 | import pandas as pd 12 | 13 | 14 | class Reader: 15 | """ 16 | The reader class reads an url list and returns shards 17 | It provides an iter method 18 | It provides attributes: 19 | - column_list: the list of columns to read 20 | - input_format: the format of the input file 21 | - url_col: the column name of the url 22 | - caption_col: the column name of the caption 23 | - verify_hash_col: the column containing the hash to verify. 24 | - verify_hash_type: the type of hash to verify. 25 | - save_additional_columns: the list of additional columns to save 26 | - number_sample_per_shard: the number of samples per shard 27 | - done_shards: a set of already done shards 28 | - start_shard_id: the shard id to begin downloading from 29 | """ 30 | 31 | def __init__( 32 | self, 33 | url_list, 34 | input_format, 35 | url_col, 36 | caption_col, 37 | verify_hash_col, 38 | verify_hash_type, 39 | save_additional_columns, 40 | number_sample_per_shard, 41 | done_shards, 42 | tmp_path, 43 | start_shard_id: int = 0, 44 | ) -> None: 45 | self.input_format = input_format 46 | self.url_col = url_col 47 | self.caption_col = caption_col 48 | self.verify_hash_col = verify_hash_col 49 | self.verify_hash_type = verify_hash_type 50 | self.save_additional_columns = save_additional_columns 51 | self.number_sample_per_shard = number_sample_per_shard 52 | self.done_shards = done_shards 53 | self.start_shard_id = start_shard_id 54 | 55 | fs, url_path = fsspec.core.url_to_fs(url_list) 56 | self.fs = fs 57 | self.tmp_path = tmp_path 58 | 59 | if fs.isdir(url_path): 60 | self.input_files = sorted(fs.glob(url_path.rstrip("/") + "/*." + input_format)) 61 | if len(self.input_files) == 0: 62 | raise ValueError(f"No file found at path {url_path} with extension {input_format}") 63 | else: 64 | self.input_files = [url_path] 65 | 66 | if self.input_format in ["txt", "txt.gz"]: 67 | self.column_list = ["url"] 68 | elif self.input_format in ["json", "json.gz", "jsonl", "jsonl.gz", "csv", "csv.gz", "tsv", "tsv.gz", "parquet"]: 69 | self.column_list = self.save_additional_columns if self.save_additional_columns is not None else [] 70 | if self.caption_col is not None: 71 | self.column_list = self.column_list + ["caption"] 72 | if self.verify_hash_col is not None: 73 | if self.verify_hash_type in ["md5", "sha256", "sha512"]: 74 | self.column_list = self.column_list + [self.verify_hash_type] 75 | else: 76 | raise ValueError(f"Invalid hash type {self.verify_hash_type}") 77 | self.column_list = self.column_list + ["url"] 78 | else: 79 | raise ValueError(f"Invalid input format {self.input_format}") 80 | 81 | def _save_to_arrow(self, input_file, start_shard_id): 82 | """Read the input file and save to arrow files in a temporary directory""" 83 | if self.input_format in [ 84 | "txt", 85 | "txt.gz", 86 | "csv", 87 | "csv.gz", 88 | "tsv", 89 | "tsv.gz", 90 | "json", 91 | "json.gz", 92 | "jsonl", 93 | "jsonl.gz", 94 | ]: 95 | compression = None 96 | if self.input_format.endswith(".gz"): 97 | compression = "gzip" 98 | with self.fs.open(input_file, encoding="utf-8", mode="rb", compression=compression) as file: 99 | if self.input_format in ["txt", "txt.gz"]: 100 | df = csv_pa.read_csv(file, read_options=csv_pa.ReadOptions(column_names=["url"])) 101 | elif self.input_format in ["json", "json.gz"]: 102 | df = pa.Table.from_pandas(pd.read_json(file)) 103 | elif self.input_format in ["csv", "csv.gz"]: 104 | df = csv_pa.read_csv(file) 105 | elif self.input_format in ["tsv", "tsv.gz"]: 106 | df = csv_pa.read_csv(file, parse_options=csv_pa.ParseOptions(delimiter="\t")) 107 | elif self.input_format in ["jsonl", "jsonl.gz"]: 108 | df = json_pa.read_json(file) 109 | else: 110 | raise ValueError(f"Unknown input format {self.input_format}") 111 | elif self.input_format == "parquet": 112 | with self.fs.open(input_file, mode="rb") as file: 113 | columns_to_read = [self.url_col] 114 | if self.caption_col is not None: 115 | columns_to_read += [self.caption_col] 116 | if self.verify_hash_col is not None: 117 | columns_to_read += [self.verify_hash_col] 118 | if self.save_additional_columns is not None: 119 | columns_to_read += self.save_additional_columns 120 | df = pq.read_table(file, columns=columns_to_read) 121 | else: 122 | raise ValueError(f"Unknown input format {self.input_format}") 123 | 124 | column_names = df.column_names 125 | if self.caption_col is not None: 126 | column_names = [c if c != self.caption_col else "caption" for c in column_names] 127 | 128 | if self.verify_hash_col is not None: 129 | column_names = [c if c != self.verify_hash_col else self.verify_hash_type for c in column_names] 130 | 131 | column_names = [c if c != self.url_col else "url" for c in column_names] 132 | 133 | df = df.rename_columns(column_names) 134 | 135 | number_samples = df.num_rows 136 | 137 | number_shards = math.ceil(df.num_rows / self.number_sample_per_shard) 138 | shards_to_write = [ 139 | (start_shard_id + shard_id, shard_id) 140 | for shard_id in range(number_shards) 141 | if start_shard_id + shard_id not in self.done_shards 142 | ] 143 | if len(shards_to_write) == 0: 144 | return [], number_shards 145 | 146 | def write_shard(t): 147 | full_shard_id, shard_id = t 148 | begin_shard = shard_id * self.number_sample_per_shard 149 | end_shard = min(number_samples, (1 + shard_id) * self.number_sample_per_shard) 150 | df_shard = df.slice(begin_shard, end_shard - begin_shard).select(self.column_list) 151 | tmp_file = self.tmp_path + f"/{full_shard_id}.feather" 152 | for i in range(10): 153 | try: 154 | fs, tmp_path = fsspec.core.url_to_fs(tmp_file) 155 | with fs.open(tmp_path, "wb") as file: 156 | with pa.ipc.new_file(file, df_shard.schema) as writer: 157 | writer.write_table(df_shard) 158 | return (full_shard_id, tmp_file) 159 | except Exception as e: # pylint: disable=broad-except 160 | if i != 9: 161 | print("retrying to write to file due to error:", e) 162 | time.sleep(1) 163 | else: 164 | raise e 165 | # can't reach here 166 | raise ValueError("Failed to write to file.") 167 | 168 | for i in range(10): 169 | shards = [] 170 | # thread pool to make it faster to write files to low latency file systems (ie s3, hdfs) 171 | try: 172 | with ThreadPool(32) as thread_pool: 173 | for shard in thread_pool.imap_unordered(write_shard, shards_to_write): 174 | shards.append(shard) 175 | break 176 | except Exception as e: # pylint: disable=broad-except 177 | if i != 9: 178 | print("retrying whole sharding to write to files due to error:", e) 179 | time.sleep(2 * i) 180 | else: 181 | raise e 182 | 183 | shards.sort(key=lambda k: k[0]) 184 | 185 | del df 186 | 187 | return shards, number_shards 188 | 189 | def __iter__(self): 190 | """ 191 | Iterate over shards, yield shards of size number_sample_per_shard or less for the last one 192 | Each shard is a tuple (shard_id, shard) 193 | shard is a tuple (sample id, sample) 194 | sample is a tuple of the columns 195 | """ 196 | start_shard_id = self.start_shard_id 197 | for i, input_file in enumerate(self.input_files): 198 | print("Sharding file number " + str(i + 1) + " of " + str(len(self.input_files)) + " called " + input_file) 199 | 200 | shards, number_shards = self._save_to_arrow(input_file, start_shard_id) 201 | print("File sharded in " + str(len(shards)) + " shards") 202 | print( 203 | "Downloading starting now, check your bandwidth speed (with bwm-ng)" 204 | "your cpu (with htop), and your disk usage (with iotop)!" 205 | ) 206 | 207 | for shard_id, arrow_file in shards: 208 | yield ( 209 | shard_id, 210 | arrow_file, 211 | ) 212 | start_shard_id += number_shards 213 | -------------------------------------------------------------------------------- /img2dataset/resizer.py: -------------------------------------------------------------------------------- 1 | """resizer module handle image resizing""" 2 | 3 | import albumentations as A 4 | import cv2 5 | import numpy as np 6 | from enum import Enum 7 | import imghdr 8 | import os 9 | 10 | _INTER_STR_TO_CV2 = { 11 | "nearest": cv2.INTER_NEAREST, 12 | "linear": cv2.INTER_LINEAR, 13 | "bilinear": cv2.INTER_LINEAR, 14 | "cubic": cv2.INTER_CUBIC, 15 | "bicubic": cv2.INTER_CUBIC, 16 | "area": cv2.INTER_AREA, 17 | "lanczos": cv2.INTER_LANCZOS4, 18 | "lanczos4": cv2.INTER_LANCZOS4, 19 | } 20 | 21 | 22 | class ResizeMode(Enum): 23 | no = 0 # pylint: disable=invalid-name 24 | keep_ratio = 1 # pylint: disable=invalid-name 25 | center_crop = 2 # pylint: disable=invalid-name 26 | border = 3 # pylint: disable=invalid-name 27 | keep_ratio_largest = 4 # pylint: disable=invalid-name 28 | 29 | 30 | # thanks https://stackoverflow.com/questions/11130156/suppress-stdout-stderr-print-from-python-functions 31 | class SuppressStdoutStderr: 32 | """ 33 | A context manager for doing a "deep suppression" of stdout and stderr in 34 | Python, i.e. will suppress all print, even if the print originates in a 35 | compiled C/Fortran sub-function. 36 | This will not suppress raised exceptions, since exceptions are printed 37 | to stderr just before a script exits, and after the context manager has 38 | exited (at least, I think that is why it lets exceptions through). 39 | 40 | """ 41 | 42 | def __init__(self): 43 | # Open a pair of null files 44 | self.null_fds = [os.open(os.devnull, os.O_RDWR) for x in range(2)] 45 | # Save the actual stdout (1) and stderr (2) file descriptors. 46 | self.save_fds = [os.dup(1), os.dup(2)] 47 | 48 | def __enter__(self): 49 | # Assign the null pointers to stdout and stderr. 50 | os.dup2(self.null_fds[0], 1) 51 | os.dup2(self.null_fds[1], 2) 52 | 53 | def __exit__(self, *_): 54 | # Re-assign the real stdout/stderr back to (1) and (2) 55 | os.dup2(self.save_fds[0], 1) 56 | os.dup2(self.save_fds[1], 2) 57 | # Close all file descriptors 58 | for fd in self.null_fds + self.save_fds: 59 | os.close(fd) 60 | 61 | 62 | def inter_str_to_cv2(inter_str): 63 | inter_str = inter_str.lower() 64 | if inter_str not in _INTER_STR_TO_CV2: 65 | raise ValueError(f"Invalid option for interpolation: {inter_str}") 66 | return _INTER_STR_TO_CV2[inter_str] 67 | 68 | 69 | class Resizer: 70 | """ 71 | Resize images 72 | Expose a __call__ method to be used as a callable object 73 | 74 | Should be used to resize one image at a time 75 | 76 | Options: 77 | resize_mode: "no", "keep_ratio", "center_crop", "border" 78 | resize_only_if_bigger: if True, resize only if image is bigger than image_size 79 | image_size: size of the output image to resize 80 | """ 81 | 82 | def __init__( 83 | self, 84 | image_size, 85 | resize_mode, 86 | resize_only_if_bigger, 87 | upscale_interpolation="lanczos", 88 | downscale_interpolation="area", 89 | encode_quality=95, 90 | encode_format="jpg", 91 | skip_reencode=False, 92 | disable_all_reencoding=False, 93 | min_image_size=0, 94 | max_image_area=float("inf"), 95 | max_aspect_ratio=float("inf"), 96 | blurrer=None, 97 | ): 98 | if encode_format not in ["jpg", "png", "webp"]: 99 | raise ValueError(f"Invalid encode format {encode_format}") 100 | if encode_format == "png": 101 | if encode_quality < 0 or encode_quality > 9: 102 | raise ValueError( 103 | "For png, encode quality represents compression which" 104 | f"must be between 0 and 9, got {encode_quality}" 105 | ) 106 | 107 | self.image_size = image_size 108 | if isinstance(resize_mode, str): 109 | if resize_mode not in ResizeMode.__members__: # pylint: disable=unsupported-membership-test 110 | raise ValueError(f"Invalid option for resize_mode: {resize_mode}") 111 | resize_mode = ResizeMode[resize_mode] 112 | self.resize_mode = resize_mode 113 | self.resize_only_if_bigger = resize_only_if_bigger 114 | self.upscale_interpolation = inter_str_to_cv2(upscale_interpolation) 115 | self.downscale_interpolation = inter_str_to_cv2(downscale_interpolation) 116 | self.encode_format = encode_format 117 | cv2_img_quality = None 118 | if encode_format == "jpg": 119 | cv2_img_quality = int(cv2.IMWRITE_JPEG_QUALITY) 120 | self.what_ext = "jpeg" 121 | elif encode_format == "png": 122 | cv2_img_quality = int(cv2.IMWRITE_PNG_COMPRESSION) 123 | self.what_ext = "png" 124 | elif encode_format == "webp": 125 | cv2_img_quality = int(cv2.IMWRITE_WEBP_QUALITY) 126 | self.what_ext = "webp" 127 | if cv2_img_quality is None: 128 | raise ValueError(f"Invalid option for encode_format: {encode_format}") 129 | self.encode_params = [cv2_img_quality, encode_quality] 130 | self.skip_reencode = skip_reencode 131 | self.disable_all_reencoding = disable_all_reencoding 132 | self.min_image_size = min_image_size 133 | self.max_image_area = max_image_area 134 | self.max_aspect_ratio = max_aspect_ratio 135 | self.blurrer = blurrer 136 | 137 | def __call__(self, img_stream, blurring_bbox_list=None): 138 | """ 139 | input: an image stream, optionally a list of bounding boxes to blur. 140 | output: img_str, width, height, original_width, original_height, err 141 | """ 142 | try: 143 | if self.disable_all_reencoding: 144 | return img_stream.read(), None, None, None, None, None 145 | with SuppressStdoutStderr(): 146 | cv2.setNumThreads(1) 147 | img_stream.seek(0) 148 | encode_needed = imghdr.what(img_stream) != self.what_ext if self.skip_reencode else True 149 | img_stream.seek(0) 150 | img_buf = np.frombuffer(img_stream.read(), np.uint8) 151 | img = cv2.imdecode(img_buf, cv2.IMREAD_UNCHANGED) 152 | if img is None: 153 | raise ValueError("Image decoding error") 154 | if len(img.shape) == 3 and img.shape[-1] == 4: 155 | # alpha matting with white background 156 | alpha = img[:, :, 3, np.newaxis] 157 | img = alpha / 255 * img[..., :3] + 255 - alpha 158 | img = np.rint(img.clip(min=0, max=255)).astype(np.uint8) 159 | encode_needed = True 160 | original_height, original_width = img.shape[:2] 161 | # check if image is too small 162 | if min(original_height, original_width) < self.min_image_size: 163 | return None, None, None, None, None, "image too small" 164 | if original_height * original_width > self.max_image_area: 165 | return None, None, None, None, None, "image area too large" 166 | # check if wrong aspect ratio 167 | if max(original_height, original_width) / min(original_height, original_width) > self.max_aspect_ratio: 168 | return None, None, None, None, None, "aspect ratio too large" 169 | 170 | # check if resizer was defined during init if needed 171 | if blurring_bbox_list is not None and self.blurrer is None: 172 | return None, None, None, None, None, "blurrer not defined" 173 | 174 | # Flag to check if blurring is still needed. 175 | maybe_blur_still_needed = True 176 | 177 | # resizing in following conditions 178 | if self.resize_mode in (ResizeMode.keep_ratio, ResizeMode.center_crop): 179 | downscale = min(original_width, original_height) > self.image_size 180 | if not self.resize_only_if_bigger or downscale: 181 | interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation 182 | img = A.smallest_max_size(img, self.image_size, interpolation=interpolation) 183 | if blurring_bbox_list is not None and self.blurrer is not None: 184 | img = self.blurrer(img=img, bbox_list=blurring_bbox_list) 185 | if self.resize_mode == ResizeMode.center_crop: 186 | img = A.center_crop(img, self.image_size, self.image_size) 187 | encode_needed = True 188 | maybe_blur_still_needed = False 189 | elif self.resize_mode in (ResizeMode.border, ResizeMode.keep_ratio_largest): 190 | downscale = max(original_width, original_height) > self.image_size 191 | if not self.resize_only_if_bigger or downscale: 192 | interpolation = self.downscale_interpolation if downscale else self.upscale_interpolation 193 | img = A.longest_max_size(img, self.image_size, interpolation=interpolation) 194 | if blurring_bbox_list is not None and self.blurrer is not None: 195 | img = self.blurrer(img=img, bbox_list=blurring_bbox_list) 196 | if self.resize_mode == ResizeMode.border: 197 | img = A.pad( 198 | img, 199 | self.image_size, 200 | self.image_size, 201 | border_mode=cv2.BORDER_CONSTANT, 202 | value=[255, 255, 255], 203 | ) 204 | encode_needed = True 205 | maybe_blur_still_needed = False 206 | 207 | # blur parts of the image if needed 208 | if maybe_blur_still_needed and blurring_bbox_list is not None and self.blurrer is not None: 209 | img = self.blurrer(img=img, bbox_list=blurring_bbox_list) 210 | 211 | height, width = img.shape[:2] 212 | if encode_needed: 213 | img_str = cv2.imencode(f".{self.encode_format}", img, params=self.encode_params)[1].tobytes() 214 | else: 215 | img_str = img_buf.tobytes() 216 | return img_str, width, height, original_width, original_height, None 217 | 218 | except Exception as err: # pylint: disable=broad-except 219 | return None, None, None, None, None, str(err) 220 | -------------------------------------------------------------------------------- /img2dataset/writer.py: -------------------------------------------------------------------------------- 1 | """"writer module handle writing the images to disk""" 2 | 3 | import json 4 | import os 5 | 6 | import fsspec 7 | import numpy as np 8 | import pyarrow as pa 9 | import pyarrow.parquet as pq 10 | import webdataset as wds 11 | 12 | 13 | class BufferedParquetWriter: 14 | """Write samples to parquet files incrementally with a buffer""" 15 | 16 | def __init__(self, output_file, schema, buffer_size=100): 17 | self.buffer_size = buffer_size 18 | self.schema = schema 19 | self._initiatlize_buffer() 20 | fs, output_path = fsspec.core.url_to_fs(output_file) 21 | self.output_fd = fs.open(output_path, "wb") 22 | self.parquet_writer = pq.ParquetWriter(self.output_fd, schema) 23 | 24 | def _initiatlize_buffer(self): 25 | self.current_buffer_size = 0 26 | self.buffer = {k: [] for k in self.schema.names} 27 | 28 | def _add_sample_to_buffer(self, sample): 29 | for k in self.schema.names: 30 | self.buffer[k].append(sample[k]) 31 | self.current_buffer_size += 1 32 | 33 | def write(self, sample): 34 | if self.current_buffer_size >= self.buffer_size: 35 | self.flush() 36 | self._add_sample_to_buffer(sample) 37 | 38 | def flush(self): 39 | """Write the buffer to disk""" 40 | if self.current_buffer_size == 0: 41 | return 42 | 43 | df = pa.Table.from_pydict(self.buffer, self.schema) 44 | self.parquet_writer.write_table(df) 45 | self._initiatlize_buffer() 46 | 47 | def close(self): 48 | self.flush() 49 | if self.parquet_writer is not None: 50 | self.parquet_writer.close() 51 | self.parquet_writer = None 52 | self.output_fd.close() 53 | 54 | 55 | class ParquetSampleWriter: 56 | """ParquetSampleWriter is a image+caption writer to parquet""" 57 | 58 | def __init__( 59 | self, 60 | shard_id, 61 | output_folder, 62 | save_caption, 63 | oom_shard_count, 64 | schema, 65 | encode_format, 66 | ): 67 | self.oom_shard_count = oom_shard_count 68 | self.encode_format = encode_format 69 | schema = schema.append(pa.field(encode_format, pa.binary())) 70 | shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string 71 | shard_id=shard_id, oom_shard_count=oom_shard_count 72 | ) 73 | output_file = f"{output_folder}/{shard_name}.parquet" 74 | self.buffered_parquet_writer = BufferedParquetWriter(output_file, schema, 100) 75 | self.save_caption = save_caption 76 | 77 | def write(self, img_str, key, caption, meta): 78 | """Keep sample in memory then write to disk when close() is called""" 79 | if img_str is not None: 80 | sample = {"key": key, self.encode_format: img_str} 81 | if self.save_caption: 82 | sample["txt"] = str(caption) if caption is not None else "" 83 | else: 84 | sample = {"key": key, self.encode_format: None} 85 | if self.save_caption: 86 | sample["txt"] = None 87 | sample.update(meta) 88 | self.buffered_parquet_writer.write(sample) 89 | 90 | def close(self): 91 | self.buffered_parquet_writer.close() 92 | 93 | 94 | class WebDatasetSampleWriter: 95 | """WebDatasetSampleWriter is a image+caption writer to webdataset""" 96 | 97 | def __init__( 98 | self, 99 | shard_id, 100 | output_folder, 101 | save_caption, 102 | oom_shard_count, 103 | schema, 104 | encode_format, 105 | ): 106 | self.oom_shard_count = oom_shard_count 107 | shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string 108 | shard_id=shard_id, oom_shard_count=oom_shard_count 109 | ) 110 | self.shard_id = shard_id 111 | fs, output_path = fsspec.core.url_to_fs(output_folder) 112 | self.tar_fd = fs.open(f"{output_path}/{shard_name}.tar", "wb") 113 | self.tarwriter = wds.TarWriter(self.tar_fd) 114 | self.save_caption = save_caption 115 | self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) 116 | self.encode_format = encode_format 117 | 118 | def write(self, img_str, key, caption, meta): 119 | """write sample to tars""" 120 | if img_str is not None: 121 | sample = {"__key__": key, self.encode_format: img_str} 122 | if self.save_caption: 123 | sample["txt"] = str(caption) if caption is not None else "" 124 | # some meta data may not be JSON serializable 125 | for k, v in meta.items(): 126 | if isinstance(v, np.ndarray): 127 | meta[k] = v.tolist() 128 | sample["json"] = json.dumps(meta, indent=4) 129 | self.tarwriter.write(sample) 130 | self.buffered_parquet_writer.write(meta) 131 | 132 | def close(self): 133 | self.buffered_parquet_writer.close() 134 | self.tarwriter.close() 135 | self.tar_fd.close() 136 | 137 | 138 | class TFRecordSampleWriter: 139 | """TFRecordSampleWriter is a image+caption writer to TFRecord""" 140 | 141 | def __init__( 142 | self, 143 | shard_id, 144 | output_folder, 145 | save_caption, 146 | oom_shard_count, 147 | schema, 148 | encode_format, 149 | ): 150 | try: 151 | os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" 152 | import tensorflow_io as _ # pylint: disable=import-outside-toplevel 153 | from tensorflow.python.lib.io.tf_record import TFRecordWriter # pylint: disable=import-outside-toplevel 154 | from tensorflow.python.training.training import ( # pylint: disable=import-outside-toplevel 155 | BytesList, 156 | Example, 157 | Feature, 158 | Features, 159 | FloatList, 160 | Int64List, 161 | ) 162 | 163 | self._BytesList = BytesList # pylint: disable=invalid-name 164 | self._Int64List = Int64List # pylint: disable=invalid-name 165 | self._FloatList = FloatList # pylint: disable=invalid-name 166 | self._Example = Example # pylint: disable=invalid-name 167 | self._Features = Features # pylint: disable=invalid-name 168 | self._Feature = Feature # pylint: disable=invalid-name 169 | except ImportError as e: 170 | raise ModuleNotFoundError( 171 | "tfrecords require tensorflow and tensorflow_io to be installed." 172 | "Run `pip install tensorflow tensorflow_io`." 173 | ) from e 174 | 175 | self.oom_shard_count = oom_shard_count 176 | shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string 177 | shard_id=shard_id, oom_shard_count=oom_shard_count 178 | ) 179 | self.shard_id = shard_id 180 | self.tf_writer = TFRecordWriter(f"{output_folder}/{shard_name}.tfrecord") 181 | self.save_caption = save_caption 182 | self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) 183 | self.encode_format = encode_format 184 | 185 | def write(self, img_str, key, caption, meta): 186 | """Write a sample using tfrecord writer""" 187 | if img_str is not None: 188 | sample = { 189 | "key": self._bytes_feature(key.encode()), 190 | self.encode_format: self._bytes_feature(img_str), 191 | } 192 | if self.save_caption: 193 | sample["txt"] = self._bytes_feature(str(caption) if caption is not None else "") 194 | for k, v in meta.items(): 195 | sample[k] = self._feature(v) 196 | tf_example = self._Example(features=self._Features(feature=sample)) 197 | self.tf_writer.write(tf_example.SerializeToString()) 198 | self.buffered_parquet_writer.write(meta) 199 | 200 | def close(self): 201 | self.buffered_parquet_writer.close() 202 | self.tf_writer.close() 203 | 204 | def _feature(self, value): 205 | """Convert to proper feature type""" 206 | if isinstance(value, list): 207 | return self._list_feature(value) 208 | elif isinstance(value, int): 209 | return self._int64_feature(value) 210 | elif isinstance(value, float): 211 | return self._float_feature(value) 212 | else: 213 | return self._bytes_feature(value) 214 | 215 | def _bytes_feature(self, value): 216 | """Returns a bytes_list from a string / byte.""" 217 | if value is None: 218 | value = "" 219 | if isinstance(value, str): 220 | value = value.encode() 221 | return self._Feature(bytes_list=self._BytesList(value=[value])) 222 | 223 | def _float_feature(self, value): 224 | """Returns a float_list from a float / double.""" 225 | return self._Feature(float_list=self._FloatList(value=[value])) 226 | 227 | def _int64_feature(self, value): 228 | """Returns an int64_list from a bool / enum / int / uint.""" 229 | return self._Feature(int64_list=self._Int64List(value=[value])) 230 | 231 | def _list_feature(self, value): 232 | """Returns an list of int64_list, float_list, bytes_list.""" 233 | if isinstance(value[0], int): 234 | return self._Feature(int64_list=self._Int64List(value=value)) 235 | elif isinstance(value[0], float): 236 | return self._Feature(float_list=self._FloatList(value=value)) 237 | else: 238 | for i, bytes_feature in enumerate(value): 239 | if bytes_feature is None: 240 | value[i] = "" 241 | if isinstance(bytes_feature, str): 242 | value[i] = bytes_feature.encode() 243 | return self._Feature(bytes_list=self._BytesList(value=value)) 244 | 245 | 246 | class FilesSampleWriter: 247 | """FilesSampleWriter is a caption+image writer to files""" 248 | 249 | def __init__( 250 | self, 251 | shard_id, 252 | output_folder, 253 | save_caption, 254 | oom_shard_count, 255 | schema, 256 | encode_format, 257 | ): 258 | self.oom_shard_count = oom_shard_count 259 | shard_name = "{shard_id:0{oom_shard_count}d}".format( # pylint: disable=consider-using-f-string 260 | shard_id=shard_id, oom_shard_count=oom_shard_count 261 | ) 262 | self.shard_id = shard_id 263 | self.fs, self.subfolder = fsspec.core.url_to_fs(f"{output_folder}/{shard_name}") 264 | if not self.fs.exists(self.subfolder): 265 | self.fs.mkdir(self.subfolder) 266 | self.save_caption = save_caption 267 | self.buffered_parquet_writer = BufferedParquetWriter(output_folder + "/" + shard_name + ".parquet", schema, 100) 268 | self.encode_format = encode_format 269 | 270 | def write(self, img_str, key, caption, meta): 271 | """Write sample to disk""" 272 | if img_str is not None: 273 | filename = f"{self.subfolder}/{key}.{self.encode_format}" 274 | with self.fs.open(filename, "wb") as f: 275 | f.write(img_str) 276 | if self.save_caption: 277 | caption = str(caption) if caption is not None else "" 278 | caption_filename = f"{self.subfolder}/{key}.txt" 279 | with self.fs.open(caption_filename, "w") as f: 280 | f.write(str(caption)) 281 | 282 | # some meta data may not be JSON serializable 283 | for k, v in meta.items(): 284 | if isinstance(v, np.ndarray): 285 | meta[k] = v.tolist() 286 | j = json.dumps(meta, indent=4) 287 | meta_filename = f"{self.subfolder}/{key}.json" 288 | with self.fs.open(meta_filename, "w") as f: 289 | f.write(j) 290 | self.buffered_parquet_writer.write(meta) 291 | 292 | def close(self): 293 | self.buffered_parquet_writer.close() 294 | 295 | 296 | class DummySampleWriter: 297 | """Does not write""" 298 | 299 | def __init__(self, shard_id, output_folder, save_caption, oom_shard_count, schema, encode_format): 300 | pass 301 | 302 | def write(self, img_str, key, caption, meta): 303 | pass 304 | 305 | def close(self): 306 | pass 307 | -------------------------------------------------------------------------------- /mypy.ini: -------------------------------------------------------------------------------- 1 | # Global options: 2 | 3 | [mypy] 4 | python_version = 3.8 5 | ignore_missing_imports = True 6 | -------------------------------------------------------------------------------- /requirements-test.txt: -------------------------------------------------------------------------------- 1 | black==24.1.1 2 | mypy==1.8.0 3 | pylint==3.0.3 4 | pytest-cov==4.1.0 5 | pytest-xdist==3.5.0 6 | pytest==8.0.0 7 | psutil 8 | pyspark 9 | uvicorn 10 | fastapi 11 | tensorflow 12 | tensorflow_io 13 | types-requests 14 | types-pkg_resources 15 | ray 16 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm>=4.62.3,<5 2 | opencv-python-headless>=4.5.5.62,<5 3 | fire>=0.4.0,<0.6.0 4 | webdataset>=0.2.5,<0.3 5 | pandas>=1.1.5,<3 6 | pyarrow>=6.0.1,<16 7 | exifread-nocycle>=3.0.1,<4 8 | albumentations>=1.1.0,<2 9 | dataclasses>=0.6,<1.0.0 10 | wandb>=0.16.0,<0.17 11 | fsspec 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | from pathlib import Path 3 | import os 4 | 5 | if __name__ == "__main__": 6 | with Path(Path(__file__).parent, "README.md").open(encoding="utf-8") as file: 7 | long_description = file.read() 8 | 9 | def _read_reqs(relpath): 10 | fullpath = os.path.join(os.path.dirname(__file__), relpath) 11 | with open(fullpath) as f: 12 | return [s.strip() for s in f.readlines() if (s.strip() and not s.startswith("#"))] 13 | 14 | REQUIREMENTS = _read_reqs("requirements.txt") 15 | 16 | setup( 17 | name="img2dataset", 18 | packages=find_packages(), 19 | include_package_data=True, 20 | version="1.45.0", 21 | license="MIT", 22 | description="Easily turn a set of image urls to an image dataset", 23 | long_description=long_description, 24 | long_description_content_type="text/markdown", 25 | entry_points={"console_scripts": ["img2dataset = img2dataset:main"]}, 26 | author="Romain Beaumont", 27 | author_email="romain.rom1@gmail.com", 28 | url="https://github.com/rom1504/img2dataset", 29 | data_files=[(".", ["README.md"])], 30 | keywords=["machine learning", "computer vision", "download", "image", "dataset"], 31 | install_requires=REQUIREMENTS, 32 | classifiers=[ 33 | "Development Status :: 4 - Beta", 34 | "Intended Audience :: Developers", 35 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 36 | "License :: OSI Approved :: MIT License", 37 | "Programming Language :: Python :: 3.6", 38 | ], 39 | ) 40 | -------------------------------------------------------------------------------- /tests/blur_test_files/bbox.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/bbox.npy -------------------------------------------------------------------------------- /tests/blur_test_files/blurred.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/blurred.png -------------------------------------------------------------------------------- /tests/blur_test_files/original.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/original.png -------------------------------------------------------------------------------- /tests/blur_test_files/resize_border.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/resize_border.jpg -------------------------------------------------------------------------------- /tests/blur_test_files/resize_center_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/resize_center_crop.jpg -------------------------------------------------------------------------------- /tests/blur_test_files/resize_keep_ratio.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/resize_keep_ratio.jpg -------------------------------------------------------------------------------- /tests/blur_test_files/resize_keep_ratio_largest.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/resize_keep_ratio_largest.jpg -------------------------------------------------------------------------------- /tests/blur_test_files/resize_no.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/resize_no.jpg -------------------------------------------------------------------------------- /tests/blur_test_files/test_bbox.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/blur_test_files/test_bbox.parquet -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import urllib.request 3 | import subprocess 4 | import sys 5 | import time 6 | import sys 7 | 8 | 9 | def spawn_and_wait_server(): 10 | port = f"123{sys.version_info.minor}" 11 | process = subprocess.Popen( 12 | [ 13 | sys.executable, 14 | "-m", 15 | "uvicorn", 16 | "tests.http_server:app", 17 | "--port", 18 | str(port), 19 | ] 20 | ) 21 | while True: 22 | try: 23 | urllib.request.urlopen(f"http://localhost:{port}") 24 | except Exception as e: 25 | time.sleep(0.1) 26 | else: 27 | break 28 | return process 29 | 30 | 31 | # credits to pytest-xdist's README 32 | @pytest.fixture(scope="session", autouse=True) 33 | def http_server(tmp_path_factory, worker_id): 34 | if worker_id == "master": 35 | # single worker: just run the HTTP server 36 | process = spawn_and_wait_server() 37 | yield process 38 | process.kill() 39 | process.wait() 40 | return 41 | 42 | # get the temp directory shared by all workers 43 | root_tmp_dir = tmp_path_factory.getbasetemp().parent 44 | 45 | # try to get a lock 46 | lock = root_tmp_dir / "lock" 47 | try: 48 | lock.mkdir(exist_ok=False) 49 | except FileExistsError: 50 | yield # failed, don't run the HTTP server 51 | return 52 | 53 | # got the lock, run the HTTP server 54 | process = spawn_and_wait_server() 55 | yield process 56 | process.kill() 57 | process.wait() 58 | -------------------------------------------------------------------------------- /tests/fixtures.py: -------------------------------------------------------------------------------- 1 | import pandas as pd 2 | import cv2 3 | import glob 4 | import random 5 | import os 6 | import sys 7 | import gzip 8 | 9 | 10 | def setup_fixtures(count=5, disallowed=0): 11 | test_list = [] 12 | current_folder = os.path.dirname(__file__) 13 | test_folder = current_folder + "/" + "resize_test_image" 14 | port = f"123{sys.version_info.minor}" 15 | image_paths = glob.glob(test_folder + "/*") 16 | for i in range(count): 17 | item = random.randint(0, len(image_paths) - 1) 18 | test_list.append( 19 | ( 20 | f"caption {i}" if i != 0 else "", 21 | image_paths[item].replace(test_folder, f"http://localhost:{port}/allowed"), 22 | ) 23 | ) 24 | test_list = test_list[:count] 25 | 26 | for i in range(disallowed): 27 | item = random.randint(0, len(image_paths) - 1) 28 | test_list.append( 29 | ( 30 | f"caption {i}" if i != 0 else "", 31 | image_paths[item].replace(test_folder, f"http://localhost:{port}/disallowed"), 32 | ) 33 | ) 34 | test_list = test_list[: count + disallowed] 35 | 36 | return test_list 37 | 38 | 39 | def generate_url_list_txt(output_file, test_list, compression_on=False): 40 | if compression_on: 41 | f = gzip.open(output_file, "wt") 42 | else: 43 | f = open(output_file, "w") 44 | with f: 45 | for _, url in test_list: 46 | f.write(url + "\n") 47 | 48 | 49 | def generate_csv(output_file, test_list, compression=None): 50 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 51 | df.to_csv(output_file, compression=compression) 52 | 53 | 54 | def generate_tsv(output_file, test_list, compression=None): 55 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 56 | df.to_csv(output_file, sep="\t", compression=compression) 57 | 58 | 59 | def generate_json(output_file, test_list, compression=None): 60 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 61 | df.to_json(output_file, compression=compression) 62 | 63 | 64 | def generate_jsonl(output_file, test_list, compression=None): 65 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 66 | df.to_json(output_file, orient="records", lines=True, compression=compression) 67 | 68 | 69 | def generate_parquet(output_file, test_list): 70 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 71 | df.to_parquet(output_file) 72 | 73 | 74 | def generate_input_file(input_format, url_list_name, test_list): 75 | if input_format == "txt": 76 | url_list_name += ".txt" 77 | generate_url_list_txt(url_list_name, test_list) 78 | elif input_format == "txt.gz": 79 | url_list_name += ".txt.gz" 80 | generate_url_list_txt(url_list_name, test_list, True) 81 | elif input_format == "csv": 82 | url_list_name += ".csv" 83 | generate_csv(url_list_name, test_list) 84 | elif input_format == "csv.gz": 85 | url_list_name += ".csv.gz" 86 | generate_csv(url_list_name, test_list, "gzip") 87 | elif input_format == "tsv": 88 | url_list_name += ".tsv" 89 | generate_tsv(url_list_name, test_list) 90 | elif input_format == "tsv.gz": 91 | url_list_name += ".tsv.gz" 92 | generate_tsv(url_list_name, test_list, "gzip") 93 | elif input_format == "json": 94 | url_list_name += ".json" 95 | generate_json(url_list_name, test_list) 96 | elif input_format == "json.gz": 97 | url_list_name += ".json.gz" 98 | generate_json(url_list_name, test_list, "gzip") 99 | elif input_format == "jsonl": 100 | url_list_name += ".jsonl" 101 | generate_jsonl(url_list_name, test_list) 102 | elif input_format == "jsonl.gz": 103 | url_list_name += ".jsonl.gz" 104 | generate_jsonl(url_list_name, test_list, "gzip") 105 | elif input_format == "parquet": 106 | url_list_name += ".parquet" 107 | generate_parquet(url_list_name, test_list) 108 | 109 | return url_list_name 110 | 111 | 112 | def get_all_files(folder, ext): 113 | return sorted(list(glob.glob(folder + "/**/*." + ext, recursive=True))) 114 | 115 | 116 | def check_one_image_size(img, img_unresized, image_size, resize_mode, resize_only_if_bigger): 117 | width = img.shape[1] 118 | height = img.shape[0] 119 | width_unresized = img_unresized.shape[1] 120 | height_unresized = img_unresized.shape[0] 121 | resized = True 122 | if resize_only_if_bigger: 123 | if ( 124 | max(width_unresized, height_unresized) <= image_size 125 | and resize_mode == "border" 126 | or min(width_unresized, height_unresized) <= image_size 127 | and resize_mode in ["keep_ratio", "center_crop"] 128 | ): 129 | if width_unresized != width or height_unresized != height: 130 | raise Exception( 131 | f"Image size is not the same as the original one in resize only if bigger mode," 132 | f"expected={width_unresized}, {height_unresized} found={width}, {height}" 133 | ) 134 | else: 135 | resized = False 136 | 137 | if not resized: 138 | return 139 | 140 | if resize_mode == "border": 141 | if width != image_size or height != image_size: 142 | raise Exception(f"Image size is not 256x256 in border mode found={width}x{height}") 143 | elif resize_mode == "keep_ratio": 144 | ratio = float(image_size) / min(width_unresized, height_unresized) 145 | new_size = tuple([round(x * ratio) for x in [width_unresized, height_unresized]]) 146 | if new_size != (width, height): 147 | raise Exception( 148 | f"Image size is not of the right size in keep ratio mode" 149 | f"expected = {new_size[0]}, {new_size[1]} found = {width}, {height} " 150 | ) 151 | 152 | 153 | def check_image_size(file_list, l_unresized, image_size, resize_mode, resize_only_if_bigger): 154 | for file, file_unresized in zip(file_list, l_unresized): 155 | img = cv2.imread(file) 156 | img_unresized = cv2.imread(file_unresized) 157 | check_one_image_size(img, img_unresized, image_size, resize_mode, resize_only_if_bigger) 158 | -------------------------------------------------------------------------------- /tests/http_server.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from fastapi import FastAPI, Response 4 | from fastapi.staticfiles import StaticFiles 5 | 6 | 7 | class StaticFilesXRobotsTagHeader(StaticFiles): 8 | async def get_response(self, *args, **kwargs) -> Response: 9 | response = await super().get_response(*args, **kwargs) 10 | response.headers["X-Robots-Tag"] = "noai, noimageai, noindex, noimageindex, nofollow" 11 | return response 12 | 13 | 14 | app = FastAPI() 15 | 16 | current_folder = os.path.dirname(__file__) 17 | test_folder = str(current_folder) + "/" + "resize_test_image" 18 | 19 | 20 | @app.get("/") 21 | async def get(): 22 | return "hi" 23 | 24 | 25 | app.mount("/allowed", StaticFiles(directory=test_folder), name="static_allowed") 26 | app.mount("/disallowed", StaticFilesXRobotsTagHeader(directory=test_folder), name="static_disallowed") 27 | -------------------------------------------------------------------------------- /tests/resize_test_image/123_456.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/123_456.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/208_495.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/208_495.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/321_421.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/321_421.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/389_535.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/389_535.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/416_264.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/416_264.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/456_123.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/456_123.jpg -------------------------------------------------------------------------------- /tests/resize_test_image/524_316.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/resize_test_image/524_316.jpg -------------------------------------------------------------------------------- /tests/test_blurrer.py: -------------------------------------------------------------------------------- 1 | """Tests for the bounding box blurring module.""" 2 | 3 | from img2dataset.blurrer import BoundingBoxBlurrer 4 | import os 5 | import pytest 6 | import cv2 7 | import numpy as np 8 | 9 | 10 | def test_blurrer(): 11 | """Test whether blurrer works properly.""" 12 | current_folder = os.path.dirname(__file__) 13 | test_folder = os.path.join(current_folder, "blur_test_files") 14 | orig_image_path = os.path.join(test_folder, "original.png") 15 | blur_image_path = os.path.join(test_folder, "blurred.png") 16 | bbox_path = os.path.join(test_folder, "bbox.npy") 17 | 18 | blurrer = BoundingBoxBlurrer() 19 | orig_image = cv2.imread(orig_image_path) 20 | blur_image = cv2.imread(blur_image_path) 21 | with open(bbox_path, "rb") as f: 22 | bbox = np.load(f) 23 | 24 | blur_image_test = blurrer(orig_image, bbox) 25 | 26 | assert np.array_equal(blur_image, blur_image_test) # Also checks for shape 27 | -------------------------------------------------------------------------------- /tests/test_downloader.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import pytest 3 | import json 4 | from fixtures import setup_fixtures 5 | from img2dataset.resizer import Resizer 6 | from img2dataset.writer import FilesSampleWriter 7 | from img2dataset.downloader import Downloader 8 | 9 | import os 10 | import pandas as pd 11 | 12 | 13 | @pytest.mark.parametrize("compute_hash", ["md5", "sha256", "sha512"]) 14 | def test_valid_hash(compute_hash, tmp_path): 15 | test_folder = str(tmp_path) 16 | current_folder = os.path.dirname(__file__) 17 | input_file = os.path.join(current_folder, "test_files", "sample_image.txt") 18 | with open(input_file, "r") as file: 19 | test_list = pd.DataFrame([(url.rstrip(),) for url in file.readlines()], columns=["url"]) 20 | 21 | image_folder_name = os.path.join(test_folder, "images") 22 | os.mkdir(image_folder_name) 23 | 24 | resizer = Resizer(256, "border", False) 25 | writer = FilesSampleWriter 26 | 27 | downloader = Downloader( 28 | writer, 29 | resizer, 30 | thread_count=32, 31 | save_caption=False, 32 | extract_exif=True, 33 | output_folder=image_folder_name, 34 | column_list=["url"], 35 | timeout=10, 36 | number_sample_per_shard=10, 37 | oom_shard_count=5, 38 | compute_hash=compute_hash, 39 | verify_hash_type=None, 40 | encode_format="jpg", 41 | retries=0, 42 | user_agent_token="img2dataset", 43 | disallowed_header_directives=["noai", "noindex"], 44 | ) 45 | 46 | tmp_file = os.path.join(test_folder, "sample_image.feather") 47 | df = pd.DataFrame(test_list, columns=["url"]) 48 | df.to_feather(tmp_file) 49 | 50 | downloader((0, tmp_file)) 51 | 52 | df = pd.read_parquet(image_folder_name + "/00000.parquet") 53 | 54 | desired_output_file = os.path.join(current_folder, "test_files", "hashes.json") 55 | with open(desired_output_file, "r") as f: 56 | hashes_dict = json.load(f) 57 | 58 | assert df[compute_hash][0] == hashes_dict[compute_hash] 59 | 60 | 61 | @pytest.mark.parametrize("compute_hash", ["md5", "sha256", "sha512"]) 62 | def test_unique_hash(compute_hash, tmp_path): 63 | current_folder = os.path.dirname(__file__) 64 | input_file = os.path.join(current_folder, "test_files", "unique_images.txt") 65 | with open(input_file, "r") as file: 66 | test_list = pd.DataFrame([(url.rstrip(),) for url in file.readlines()], columns=["url"]) 67 | 68 | test_folder = str(tmp_path) 69 | 70 | image_folder_name = os.path.join(test_folder, "images") 71 | os.mkdir(image_folder_name) 72 | 73 | resizer = Resizer(256, "border", False) 74 | writer = FilesSampleWriter 75 | 76 | downloader = Downloader( 77 | writer, 78 | resizer, 79 | thread_count=32, 80 | save_caption=False, 81 | extract_exif=True, 82 | output_folder=image_folder_name, 83 | column_list=["url"], 84 | timeout=10, 85 | number_sample_per_shard=10, 86 | oom_shard_count=5, 87 | compute_hash=compute_hash, 88 | verify_hash_type=None, 89 | encode_format="jpg", 90 | retries=0, 91 | user_agent_token="img2dataset", 92 | disallowed_header_directives=["noai", "noindex"], 93 | ) 94 | 95 | tmp_file = os.path.join(test_folder, "test_list.feather") 96 | df = pd.DataFrame(test_list, columns=["url"]) 97 | df.to_feather(tmp_file) 98 | 99 | downloader((0, tmp_file)) 100 | 101 | assert len(os.listdir(image_folder_name + "/00000")) >= 3 * 10 102 | 103 | df = pd.read_parquet(image_folder_name + "/00000.parquet") 104 | 105 | success = df[df[compute_hash].notnull()] 106 | 107 | assert len(success) > 10 108 | 109 | assert len(success) == len(success.drop_duplicates(compute_hash)) 110 | 111 | 112 | def test_downloader(tmp_path): 113 | test_folder = str(tmp_path) 114 | n_allowed = 5 115 | n_disallowed = 5 116 | test_list = setup_fixtures(count=n_allowed, disallowed=n_disallowed) 117 | 118 | assert len(test_list) == n_allowed + n_disallowed 119 | 120 | image_folder_name = os.path.join(test_folder, "images") 121 | 122 | os.mkdir(image_folder_name) 123 | 124 | resizer = Resizer(256, "border", False) 125 | writer = FilesSampleWriter 126 | 127 | downloader = Downloader( 128 | writer, 129 | resizer, 130 | thread_count=32, 131 | save_caption=True, 132 | extract_exif=True, 133 | output_folder=image_folder_name, 134 | column_list=["caption", "url"], 135 | timeout=10, 136 | number_sample_per_shard=10, 137 | oom_shard_count=5, 138 | compute_hash="md5", 139 | verify_hash_type="None", 140 | encode_format="jpg", 141 | retries=0, 142 | user_agent_token="img2dataset", 143 | disallowed_header_directives=["noai", "noindex"], 144 | ) 145 | 146 | tmp_file = os.path.join(test_folder, "test_list.feather") 147 | df = pd.DataFrame(test_list, columns=["caption", "url"]) 148 | df.to_feather(tmp_file) 149 | 150 | downloader((0, tmp_file)) 151 | 152 | assert len(os.listdir(image_folder_name + "/00000")) == 3 * n_allowed 153 | -------------------------------------------------------------------------------- /tests/test_files/benchmark.sh: -------------------------------------------------------------------------------- 1 | set -e 2 | #rm -rf bench 3 | time img2dataset --processes_count 16 --thread_count 128 --url_list=test_10000.parquet --image_size=256 --output_folder=bench \ 4 | --output_format="files" --input_format "parquet" --url_col "URL" --caption_col "TEXT" --enable_wandb True --number_sample_per_shard 1000 \ 5 | --distributor multiprocessing 6 | #rm -rf bench -------------------------------------------------------------------------------- /tests/test_files/hashes.json: -------------------------------------------------------------------------------- 1 | {"md5": "2b65fd14dde4b7875a6fdf733888316c", "sha256": "08dc1f9bc6a55a04882a2cc3ac792feca030362ddd62990c74b708fe1606ef47", "sha512": "0645ed8341a300400efdf2cb6d12f80617dbb306c88ca4ffff3100a18e0ef374686245dc6146aa30299ac7a09de77c31bd81773363636ff74c75ff24c1a5923a"} -------------------------------------------------------------------------------- /tests/test_files/large_bench.sh: -------------------------------------------------------------------------------- 1 | ## this benchmarks uses parquet files from https://github.com/rom1504/cah-prepro 2 | 3 | rm -rf /media/hd/testing/tmp_test 4 | img2dataset --url_list /media/hd/testing/cah_400M_meta --input_format "parquet"\ 5 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 6 | --output_folder /media/hd/testing/tmp_test --processes_count 16 --thread_count 64 --image_size 256\ 7 | --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True --distributor multiprocessing -------------------------------------------------------------------------------- /tests/test_files/large_bench_tf.sh: -------------------------------------------------------------------------------- 1 | ## this benchmarks uses parquet files from https://github.com/rom1504/cah-prepro 2 | 3 | rm -rf /media/hd/testing/tmp_test 4 | img2dataset --url_list /media/hd/testing/cah_400M_meta --input_format "parquet"\ 5 | --url_col "URL" --caption_col "TEXT" --output_format tfrecord\ 6 | --output_folder /media/hd/testing/tmp_test --processes_count 16 --thread_count 64 --image_size 256\ 7 | --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True --distributor pyspark -------------------------------------------------------------------------------- /tests/test_files/s3_bench.sh: -------------------------------------------------------------------------------- 1 | ## this benchmarks uses parquet files from https://github.com/rom1504/cah-prepro 2 | 3 | #rm -rf /media/nvme/sample_unresized_laion400m 4 | aws s3 rm --recursive s3://laion-watermark/my_img2dataset_test/ 5 | img2dataset --url_list "/media/hd/testing/cah_400M_meta" --input_format "parquet"\ 6 | --url_col "URL" --caption_col "TEXT" --output_format webdataset\ 7 | --output_folder "s3://laion-watermark/my_img2dataset_test" --processes_count 16 --thread_count 64 --image_size 256\ 8 | --save_additional_columns '["NSFW","similarity","LICENSE"]' --enable_wandb True --distributor pyspark 9 | -------------------------------------------------------------------------------- /tests/test_files/sample_image.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/test_files/sample_image.parquet -------------------------------------------------------------------------------- /tests/test_files/sample_image.txt: -------------------------------------------------------------------------------- 1 | https://raw.githubusercontent.com/rom1504/img2dataset/main/tests/blur_test_files/original.png -------------------------------------------------------------------------------- /tests/test_files/test_1000.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/test_files/test_1000.parquet -------------------------------------------------------------------------------- /tests/test_files/test_10000.parquet: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/rom1504/img2dataset/a70e10d352ec11fd611b86ab81a29223a16c841e/tests/test_files/test_10000.parquet -------------------------------------------------------------------------------- /tests/test_files/unique_images.txt: -------------------------------------------------------------------------------- 1 | https://ae01.alicdn.com/kf/H70ee1312e27344219c9933b925984c4ex/SHUOKE-Bi-LED-Projector-headlights-Lens-2-5-Inch-Car-Headlight-Projector-bi-lenses-for-BMW.jpg 2 | https://www.cvs.com/bizcontent/merchandising/productimages/large/305210304956_5.jpg 3 | https://blueseatblogs.com/wp-content/uploads/2016/07/Untitled-copy-4.jpg 4 | https://www.lenovo.com/medias/40AF0135AU-500-1.png?context=bWFzdGVyfHJvb3R8NzA3NDd8aW1hZ2UvcG5nfGhlMS9oN2IvMTEwNjI3ODEwMTgxNDIucG5nfDFkNWQzODkxN2M1NzllYWZkNGUyNTVhZGU0NTI1ZThlM2M3NTBkNTBlZDE4M2I3ZWQ4OTQ5ZDU4YmIyYTI3ZTY 5 | https://cdn-0.brighthubeducation.com/wp-content/uploads/2009/12/band-24298_640-300x272.png?ezimgfmt=rs:0x0/rscb3/ng:webp/ngcb3 6 | https://sc02.alicdn.com/kf/H9474839c52104fbc83f5d7d87450108fl/235122577/H9474839c52104fbc83f5d7d87450108fl.jpg_.webp 7 | https://live-production.wcms.abc-cdn.net.au/bd2f16060e3b71356492fde5c5165a6f?impolicy=wcms_crop_resize&cropH=562&cropW=1000&xPos=0&yPos=0&width=862&height=485 8 | https://i1.wp.com/ae01.alicdn.com/kf/Hfa068558ed824f4ea2b6c1fa24b250f06/Heavy-Duty-Extra-Long-Blade-Camping-Woodworking-Saw-Hand-Saw-Garden-Saw-for-Wood-Dry-Wood.jpg_640x640.jpg?strip=all&quality=70&resize=50,50 9 | https://st.automobilemag.com/uploads/sites/11/2013/04/2013-Chrysler-300-SRT8-left-side-view-3.jpg 10 | https://img.perniaspopupshop.com/catalog/product/s/h/SHLK022113_1.jpg?impolicy=listingimagenew 11 | https://headtopics.com/images/2020/1/10/premiumtimesng/nigeria-immigration-probes-bribery-allegation-at-border-checkpoints-1215575743024988160.webp 12 | https://imgs.michaels.com/MAM/assets/1/5E3C12034D34434F8A9BAAFDDF0F8E1B/img/5CC6BE3C1F7A431BA0D54F2BD7748F44/10171148_20.jpg?fit=inside|1024:1024 13 | https://i1.wp.com/ae01.alicdn.com/kf/H807a27dfda7e4ea59f812622fb483c5cD/New-Arrived-LED-270-Strobe-Light-RGB-And-White-Color-For-Disco-Dj-Pub-Wedding-Party.png 14 | https://d1fdloi71mui9q.cloudfront.net/KA6p6VieSeiU7QagRlTC_40f7689eab0aa363cd2ae22a201a9348 15 | https://i.etsystatic.com/iusa/472168/55565844/iusa_400x400.55565844_40v8.jpg?version=0 16 | https://craft-mart.com/wp-content/uploads/2021/07/227-DEN-house-plans-600.jpg.webp 17 | https://ii1.pepperfry.com/media/catalog/product/b/r/568x284/brisbane-rhs-3-seater-sofa-with-coffee-table-in--grey-and-yellow-colour-by-arra-brisbane-rhs-3-seate-emr0ys.jpg 18 | https://ae01.alicdn.com/kf/HTB172zuOXXXXXXpaXXXq6xXFXXXs/Chashma-Brand-Women-Purple-Glasses-Classic-Design-Female-Optical-Frames.jpg 19 | https://i.imgur.com/aPbtskx_d.webp?maxwidth=640&shape=thumb&fidelity=medium 20 | https://static.super-shop.com/484971-tech-deck-fingerboard-birdhouse-10pack-01-960w.jpg 21 | https://i.etsystatic.com/14636520/r/il/d6d2eb/2666822748/il_fullxfull.2666822748_em0j.jpg 22 | https://ae01.alicdn.com/kf/HTB1fPeMaxD1gK0jSZFKq6AJrVXa5/Fuel-Pump-For-Johnson-Evinrude-100-105-115-125-135-140-HP-438559-385784-433390-1pcs.jpg_Q90.jpg_.webp 23 | https://assets.sunglasshut.com/is/image/LuxotticaRetail/8053672805437__001.png 24 | https://images.news18.com/ibnlive/uploads/2020/10/1603437876_untitled-design-2.png?impolicy=website&width=510&height=356 25 | https://st.motortrend.com/uploads/sites/5/2021/06/di_escort_10_di_driving_2.jpg 26 | https://ae01.alicdn.com/kf/HTB1m2afLXXXXXaiXpXXq6xXFXXXE/High-Quality-Trumpet-Bb-B-Flat-Brass-with-a-Silver-plated-mouthpiece-and-a-Pair-of.jpg 27 | https://thegrio.com/wp-content/uploads/2020/04/Insecure-Issa-Smile.png 28 | https://res.klook.com/images/fl_lossy.progressive,q_65/c_fill,w_1295,h_720/w_80,x_15,y_15,g_south_west,l_klook_water/activities/hvw4jgt5qalzplmtphzy/CasaAmatllerTicketinBarcelona.webp 29 | https://statusmarkets.in/wp-content/uploads/2021/11/5-Sandwich-Maker-Options-To-Make-Crispy-Sandwiches.jpg 30 | https://mobileimages.lowes.com/productimages/eab7cc88-42b3-449c-a2db-299501e01b5d/00814409.jpg?size=pdhi 31 | https://pisces.bbystatic.com/image2/BestBuy_US/images/products/6416/6416554_sd.jpg 32 | http://ae01.alicdn.com/kf/HTB1H7DmLlLoK1RjSZFuq6xn0XXam/27-xenoblade-2-pyra-pvc.jpg 33 | https://my-live-01.slatic.net/p/960a82ffaa6dd744531ffdc3b01815ef.jpg_2200x2200q80.jpg_.webp 34 | https://ae01.alicdn.com/kf/HTB1lxYwaOnrK1RjSsziq6xptpXaW.jpg 35 | https://i.etsystatic.com/24705470/r/il/70aba7/2627122317/il_340x270.2627122317_4jyh.jpg 36 | https://ae01.alicdn.com/kf/Hc93d3f3a14ce4f4e9b7cbaccf888bbeeK/Early-Autumn-New-Chic-Western-Style-Small-Shirt-Retro-French-Design-Sense-Palace-Style-Long-Sleeve.jpg_Q90.jpg_.webp 37 | https://i3.wp.com/ae01.alicdn.com/kf/Hc9c3cb4769c848a08c50029c2456a56bW/Wireless-Touch-Sensor-LED-Under-Cabinet-Light-Kitchen-LED-Battery-Wardrobe-Closet-Puck-Light-with-Controller.jpg 38 | https://www.masslive.com/resizer/oBaVCH9QRcAdGDUGFMCIkivO93o=/1280x0/smart/arc-anglerfish-arc2-prod-advancelocal.s3.amazonaws.com/public/5MZH2D7VINGEFF6FMHQZ7EDBC4.png 39 | https://ae01.alicdn.com/kf/HTB1XePTceGSBuNjSspbq6AiipXa5/Custom-Name-Solid-Silver-Ring-Men-Two-sides-Engraved-Name-Personalized-Engagement-Rings-Wholesale.jpg 40 | https://www.twi-global.com/image-library/Case-Studies/2019-Case-Studies/Carbon-Fibre-Composites/Carbon-Fibre-Fig-1b-v2.x86dd4117.jpg?crop=376,367,32,0 41 | https://i.etsystatic.com/24318141/d/il/57f34b/2640943861/il_340x270.2640943861_k16p.jpg?version=0 42 | https://i.etsystatic.com/18639132/r/il/188e9d/2093816019/il_340x270.2093816019_1hof.jpg 43 | https://i.etsystatic.com/18828870/r/il/ed6622/1692255574/il_fullxfull.1692255574_paeh.jpg 44 | https://ae01.alicdn.com/kf/HTB1SGuSQFXXXXb0XXXXq6xXFXXXX/Sexy-White-Beaded-Barefoot-Sandals-Beach-Foot-Jewelry-With-Starfish-Anklets-For-Women-2019-Summer-Ankle.jpg_Q90.jpg_.webp 45 | https://www.wyndhamhotels.com/content/dam/property-images/en-us/di/us/tn/franklin/02728/02728_exterior_view_1.jpg 46 | https://static.wixstatic.com/media/be5788_fe6a4c37067040388649a5b614942b9f.png/v1/fill/w_520,h_452,al_c,lg_1,q_85/be5788_fe6a4c37067040388649a5b614942b9f.webp 47 | https://ii1.pepperfry.com/media/catalog/product/m/i/568x284/miranda-three-seater-sofa-in-camel-yellow-colour-by-woodsworth-miranda-three-seater-sofa-in-camel-ye-mipvbl.jpg 48 | http://mobileimages.lowes.com/productimages/dfac35bf-7330-4069-ac58-f90dcae35a98/01108373.jpg 49 | https://i0.wp.com/ae01.alicdn.com/kf/HTB1S3_JMNYaK1RjSZFnq6y80pXax/Five-night-at-freddy-s-toys-FNAF-Fazbear-Chica-Bonnie-Mangle-Foxy-Bracelet-keychain-Freddy-Bag.jpg?crop=7,3,927,600&quality=3628 50 | https://i.etsystatic.com/7061050/r/il/c94af2/2416889156/il_fullxfull.2416889156_8jay.jpg 51 | https://ae01.alicdn.com/kf/H52614cd077dc483d8c4ff4dc326cdb96E/Simple-PVC-Cover-Diary-Drawing-Painting-Notebook-Cute-Soft-Cover-White-Paper-Notebook-Memo-Pad-School.jpg_Q90.jpg_.webp 52 | https://s13emagst.akamaized.net/products/18626/18625722/images/res_0caaca8696e9db21a6696182088e0512.jpg 53 | https://www.telegraph.co.uk/multimedia/archive/03381/71169547_Universit_3381660b.jpg 54 | https://vodcdn.abplive.in/2020/02/ca2a91550d7cf8d2576635f87fc71f2b.jpg?impolicy=abp_cdn&imwidth=330 55 | -------------------------------------------------------------------------------- /tests/test_main.py: -------------------------------------------------------------------------------- 1 | from img2dataset import download 2 | import os 3 | import shutil 4 | import pytest 5 | import glob 6 | import numpy as np 7 | import pandas as pd 8 | import cv2 9 | import time 10 | import tarfile 11 | from fixtures import ( 12 | get_all_files, 13 | check_image_size, 14 | generate_input_file, 15 | setup_fixtures, 16 | ) 17 | 18 | testdata = [ 19 | ("border", False, False), 20 | ("border", False, True), 21 | ("border", True, False), 22 | ("keep_ratio", False, False), 23 | ("keep_ratio", True, False), 24 | ("keep_ratio", True, True), 25 | ("center_crop", False, False), 26 | ("center_crop", True, False), 27 | ("no", False, False), 28 | ("no", False, True), 29 | ] 30 | 31 | 32 | @pytest.mark.parametrize("image_size", [256, 512]) 33 | @pytest.mark.parametrize("resize_mode, resize_only_if_bigger, skip_reencode", testdata) 34 | def test_download_resize(image_size, resize_mode, resize_only_if_bigger, skip_reencode, tmp_path): 35 | test_folder = str(tmp_path) 36 | test_list = setup_fixtures() 37 | prefix = resize_mode + "_" + str(resize_only_if_bigger) + "_" 38 | url_list_name = os.path.join(test_folder, prefix + "url_list") 39 | image_folder_name = os.path.join(test_folder, prefix + "images") 40 | unresized_folder = os.path.join(test_folder, prefix + "unresized_images") 41 | 42 | url_list_name = generate_input_file("txt", url_list_name, test_list) 43 | 44 | download( 45 | url_list_name, 46 | image_size=image_size, 47 | output_folder=unresized_folder, 48 | thread_count=32, 49 | resize_mode="no", 50 | resize_only_if_bigger=resize_only_if_bigger, 51 | ) 52 | 53 | download( 54 | url_list_name, 55 | image_size=image_size, 56 | output_folder=image_folder_name, 57 | thread_count=32, 58 | resize_mode=resize_mode, 59 | resize_only_if_bigger=resize_only_if_bigger, 60 | skip_reencode=skip_reencode, 61 | ) 62 | 63 | l = get_all_files(image_folder_name, "jpg") 64 | j = [a for a in get_all_files(image_folder_name, "json") if "stats" not in a] 65 | assert len(j) == len(test_list) 66 | p = get_all_files(image_folder_name, "parquet") 67 | assert len(p) == 1 68 | l_unresized = get_all_files(unresized_folder, "jpg") 69 | assert len(l) == len(test_list) 70 | check_image_size(l, l_unresized, image_size, resize_mode, resize_only_if_bigger) 71 | 72 | 73 | @pytest.mark.parametrize( 74 | "input_format, output_format", 75 | [ 76 | ["txt", "files"], 77 | ["txt", "webdataset"], 78 | ["txt.gz", "files"], 79 | ["txt.gz", "webdataset"], 80 | ["csv", "files"], 81 | ["csv", "webdataset"], 82 | ["csv.gz", "files"], 83 | ["csv.gz", "webdataset"], 84 | ["tsv", "files"], 85 | ["tsv", "webdataset"], 86 | ["tsv.gz", "files"], 87 | ["tsv.gz", "webdataset"], 88 | ["json", "files"], 89 | ["json", "webdataset"], 90 | ["json.gz", "files"], 91 | ["json.gz", "webdataset"], 92 | ["jsonl", "files"], 93 | ["jsonl", "webdataset"], 94 | ["jsonl.gz", "files"], 95 | ["jsonl.gz", "webdataset"], 96 | ["parquet", "files"], 97 | ["parquet", "webdataset"], 98 | ["parquet", "parquet"], 99 | ["parquet", "dummy"], 100 | ["parquet", "tfrecord"], 101 | ], 102 | ) 103 | def test_download_input_format(input_format, output_format, tmp_path): 104 | test_list = setup_fixtures() 105 | test_folder = str(tmp_path) 106 | 107 | prefix = input_format + "_" + output_format + "_" 108 | url_list_name = os.path.join(test_folder, prefix + "url_list") 109 | image_folder_name = os.path.join(test_folder, prefix + "images") 110 | 111 | url_list_name = generate_input_file(input_format, url_list_name, test_list) 112 | 113 | download( 114 | url_list_name, 115 | image_size=256, 116 | output_folder=image_folder_name, 117 | thread_count=32, 118 | input_format=input_format, 119 | output_format=output_format, 120 | url_col="url", 121 | caption_col="caption", 122 | compute_hash="md5", 123 | ) 124 | 125 | if output_format != "dummy": 126 | df = pd.read_parquet(image_folder_name + "/00000.parquet") 127 | 128 | expected_columns = [ 129 | "url", 130 | "key", 131 | "status", 132 | "error_message", 133 | "width", 134 | "height", 135 | "original_width", 136 | "original_height", 137 | "exif", 138 | "md5", 139 | ] 140 | 141 | if input_format not in ["txt", "txt.gz"]: 142 | expected_columns.insert(2, "caption") 143 | 144 | if output_format == "parquet": 145 | expected_columns.append("jpg") 146 | 147 | assert set(df.columns.tolist()) == set(expected_columns) 148 | 149 | expected_file_count = len(test_list) 150 | if output_format == "files": 151 | l = get_all_files(image_folder_name, "jpg") 152 | assert len(l) == expected_file_count 153 | elif output_format == "webdataset": 154 | l = glob.glob(image_folder_name + "/*.tar") 155 | assert len(l) == 1 156 | if l[0] != image_folder_name + "/00000.tar": 157 | raise Exception(l[0] + " is not 00000.tar") 158 | 159 | assert ( 160 | len([x for x in tarfile.open(image_folder_name + "/00000.tar").getnames() if x.endswith(".jpg")]) 161 | == expected_file_count 162 | ) 163 | elif output_format == "parquet": 164 | l = glob.glob(image_folder_name + "/*.parquet") 165 | assert len(l) == 1 166 | if l[0] != image_folder_name + "/00000.parquet": 167 | raise Exception(l[0] + " is not 00000.parquet") 168 | 169 | assert len(pd.read_parquet(image_folder_name + "/00000.parquet").index) == expected_file_count 170 | elif output_format == "dummy": 171 | l = [ 172 | x 173 | for x in glob.glob(image_folder_name + "/*") 174 | if ( 175 | not x.endswith(".json") 176 | and not x.endswith(".jsonl") 177 | and not x.endswith(".json.gz") 178 | and not x.endswith(".jsonl.gz") 179 | ) 180 | ] 181 | assert len(l) == 0 182 | elif output_format == "tfrecord": 183 | l = glob.glob(image_folder_name + "/*.tfrecord") 184 | assert len(l) == 1 185 | if l[0] != image_folder_name + "/00000.tfrecord": 186 | raise Exception(l[0] + " is not 00000.tfrecord") 187 | 188 | 189 | @pytest.mark.parametrize( 190 | "input_format, output_format", 191 | [ 192 | ["txt", "files"], 193 | ["txt", "webdataset"], 194 | ["txt.gz", "files"], 195 | ["txt.gz", "webdataset"], 196 | ["csv", "files"], 197 | ["csv", "webdataset"], 198 | ["csv.gz", "files"], 199 | ["csv.gz", "webdataset"], 200 | ["tsv", "files"], 201 | ["tsv", "webdataset"], 202 | ["tsv.gz", "files"], 203 | ["tsv.gz", "webdataset"], 204 | ["json", "files"], 205 | ["json", "webdataset"], 206 | ["json.gz", "files"], 207 | ["json.gz", "webdataset"], 208 | ["jsonl", "files"], 209 | ["jsonl", "webdataset"], 210 | ["jsonl.gz", "files"], 211 | ["jsonl.gz", "webdataset"], 212 | ["parquet", "files"], 213 | ["parquet", "webdataset"], 214 | ], 215 | ) 216 | def test_download_multiple_input_files(input_format, output_format, tmp_path): 217 | test_list = setup_fixtures() 218 | prefix = input_format + "_" + output_format + "_" 219 | test_folder = str(tmp_path) 220 | 221 | subfolder = test_folder + "/" + prefix + "input_folder" 222 | if not os.path.exists(subfolder): 223 | os.mkdir(subfolder) 224 | url_list_names = [os.path.join(subfolder, prefix + "url_list1"), os.path.join(subfolder, prefix + "url_list2")] 225 | image_folder_name = os.path.join(test_folder, prefix + "images") 226 | 227 | for url_list_name in url_list_names: 228 | url_list_name = generate_input_file(input_format, url_list_name, test_list) 229 | 230 | download( 231 | subfolder, 232 | image_size=256, 233 | output_folder=image_folder_name, 234 | thread_count=32, 235 | input_format=input_format, 236 | output_format=output_format, 237 | url_col="url", 238 | caption_col="caption", 239 | ) 240 | 241 | expected_file_count = len(test_list) 242 | if output_format == "files": 243 | l = get_all_files(image_folder_name, "jpg") 244 | assert len(l) == expected_file_count * 2 245 | elif output_format == "webdataset": 246 | l = sorted(glob.glob(image_folder_name + "/*.tar")) 247 | assert len(l) == 2 248 | if l[0] != image_folder_name + "/00000.tar": 249 | raise Exception(l[0] + " is not 00000.tar") 250 | if l[1] != image_folder_name + "/00001.tar": 251 | raise Exception(l[1] + " is not 00001.tar") 252 | 253 | assert ( 254 | len([x for x in tarfile.open(image_folder_name + "/00000.tar").getnames() if x.endswith(".jpg")]) 255 | == expected_file_count 256 | ) 257 | assert ( 258 | len([x for x in tarfile.open(image_folder_name + "/00001.tar").getnames() if x.endswith(".jpg")]) 259 | == expected_file_count 260 | ) 261 | 262 | 263 | @pytest.mark.parametrize( 264 | "save_caption, output_format", 265 | [ 266 | [True, "files"], 267 | [False, "files"], 268 | [True, "webdataset"], 269 | [False, "webdataset"], 270 | ], 271 | ) 272 | def test_captions_saving(save_caption, output_format, tmp_path): 273 | test_folder = str(tmp_path) 274 | test_list = setup_fixtures() 275 | 276 | input_format = "parquet" 277 | prefix = str(save_caption) + "_" + input_format + "_" + output_format + "_" 278 | url_list_name = os.path.join(test_folder, prefix + "url_list") 279 | image_folder_name = os.path.join(test_folder, prefix + "images") 280 | url_list_name = generate_input_file("parquet", url_list_name, test_list) 281 | download( 282 | url_list_name, 283 | image_size=256, 284 | output_folder=image_folder_name, 285 | thread_count=32, 286 | input_format=input_format, 287 | output_format=output_format, 288 | url_col="url", 289 | caption_col="caption" if save_caption else None, 290 | ) 291 | 292 | expected_file_count = len(test_list) 293 | if output_format == "files": 294 | l = get_all_files(image_folder_name, "jpg") 295 | assert len(l) == expected_file_count 296 | l = get_all_files(image_folder_name, "txt") 297 | if save_caption: 298 | assert len(l) == expected_file_count 299 | for expected, real in zip(test_list, l): 300 | true_real = open(real).read() 301 | true_expected = expected[0] if expected[0] is not None else "" 302 | assert true_expected == true_real 303 | else: 304 | assert len(l) == 0 305 | elif output_format == "webdataset": 306 | l = glob.glob(image_folder_name + "/*.tar") 307 | assert len(l) == 1 308 | if l[0] != image_folder_name + "/00000.tar": 309 | raise Exception(l[0] + " is not 00000.tar") 310 | 311 | with tarfile.open(image_folder_name + "/00000.tar") as f: 312 | assert len([x for x in f.getnames() if x.endswith(".jpg")]) == expected_file_count 313 | txt_files = sorted([x for x in f.getnames() if x.endswith(".txt")]) 314 | if save_caption: 315 | assert len(txt_files) == expected_file_count 316 | for expected, real in zip(test_list, txt_files): 317 | true_expected = expected[0] if expected[0] is not None else "" 318 | true_real = f.extractfile(real).read().decode("utf-8") 319 | assert true_expected == true_real 320 | else: 321 | assert len(txt_files) == 0 322 | 323 | 324 | def test_webdataset(tmp_path): 325 | test_list = setup_fixtures() 326 | test_folder = str(tmp_path) 327 | url_list_name = os.path.join(test_folder, "url_list") 328 | image_folder_name = os.path.join(test_folder, "images") 329 | 330 | url_list_name = generate_input_file("txt", url_list_name, test_list) 331 | 332 | download( 333 | url_list_name, image_size=256, output_folder=image_folder_name, thread_count=32, output_format="webdataset" 334 | ) 335 | 336 | l = glob.glob(image_folder_name + "/*.tar") 337 | assert len(l) == 1 338 | if l[0] != image_folder_name + "/00000.tar": 339 | raise Exception(l[0] + " is not 00000.tar") 340 | 341 | assert len(tarfile.open(image_folder_name + "/00000.tar").getnames()) == len(test_list) * 2 342 | 343 | os.remove(url_list_name) 344 | shutil.rmtree(image_folder_name) 345 | 346 | 347 | def test_relative_path(tmp_path): 348 | test_folder = str(tmp_path) 349 | test_list = setup_fixtures() 350 | 351 | url_list_name = os.path.join(test_folder, "url_list") 352 | image_folder_name = os.path.join(test_folder, "images") 353 | 354 | url_list_name = generate_input_file("txt", url_list_name, test_list) 355 | 356 | url_list_name = os.path.relpath(url_list_name) 357 | image_folder_name = os.path.relpath(image_folder_name) 358 | 359 | download( 360 | url_list_name, image_size=256, output_folder=image_folder_name, thread_count=32, output_format="webdataset" 361 | ) 362 | 363 | l = glob.glob(image_folder_name + "/*.tar") 364 | assert len(l) == 1 365 | if l[0] != image_folder_name + "/00000.tar": 366 | raise Exception(l[0] + " is not 00000.tar") 367 | 368 | assert len(tarfile.open(image_folder_name + "/00000.tar").getnames()) == len(test_list) * 2 369 | 370 | 371 | @pytest.mark.parametrize( 372 | "distributor", 373 | [ 374 | "multiprocessing", 375 | "pyspark", 376 | "ray", 377 | ], 378 | ) 379 | def test_distributors(distributor, tmp_path): 380 | test_folder = str(tmp_path) 381 | test_list = setup_fixtures() 382 | 383 | url_list_name = os.path.join(test_folder, "url_list") 384 | image_folder_name = os.path.join(test_folder, "images") 385 | 386 | url_list_name = generate_input_file("txt", url_list_name, test_list) 387 | 388 | download( 389 | url_list_name, 390 | image_size=256, 391 | output_folder=image_folder_name, 392 | thread_count=32, 393 | output_format="webdataset", 394 | distributor=distributor, 395 | ) 396 | 397 | l = glob.glob(image_folder_name + "/*.tar") 398 | assert len(l) == 1 399 | if l[0] != image_folder_name + "/00000.tar": 400 | raise Exception(l[0] + " is not 00000.tar") 401 | 402 | assert len(tarfile.open(image_folder_name + "/00000.tar").getnames()) == len(test_list) * 2 403 | 404 | 405 | # @pytest.mark.skip(reason="slow") 406 | @pytest.mark.parametrize("output_format", ["webdataset", "files"]) 407 | def test_benchmark(output_format, tmp_path): 408 | test_folder = str(tmp_path) 409 | current_folder = os.path.dirname(__file__) 410 | 411 | prefix = output_format + "_" 412 | url_list_name = os.path.join(current_folder, "test_files/test_1000.parquet") 413 | image_folder_name = os.path.join(test_folder, prefix + "images") 414 | 415 | t = time.time() 416 | 417 | download( 418 | url_list_name, 419 | image_size=256, 420 | output_folder=image_folder_name, 421 | thread_count=32, 422 | output_format=output_format, 423 | input_format="parquet", 424 | url_col="URL", 425 | caption_col="TEXT", 426 | ) 427 | 428 | took = time.time() - t 429 | 430 | print("Took " + str(took) + "s") 431 | 432 | if took > 100: 433 | raise Exception("Very slow, took " + str(took)) 434 | 435 | 436 | @pytest.mark.parametrize( 437 | "resize_mode, resize_only_if_bigger", 438 | [ 439 | ["no", False], 440 | ["border", False], 441 | ["keep_ratio", False], 442 | ["keep_ratio_largest", False], 443 | ["center_crop", False], 444 | ["no", True], 445 | ["border", True], 446 | ["keep_ratio", True], 447 | ["keep_ratio_largest", True], 448 | ["center_crop", True], 449 | ], 450 | ) 451 | def test_blur_and_resize(resize_mode, resize_only_if_bigger, tmp_path): 452 | test_folder = str(tmp_path) 453 | output_folder = os.path.join(test_folder, "images") 454 | 455 | current_folder = os.path.dirname(__file__) 456 | input_parquet = os.path.join(current_folder, "blur_test_files", "test_bbox.parquet") 457 | 458 | download( 459 | input_parquet, 460 | input_format="parquet", 461 | image_size=600, 462 | output_folder=output_folder, 463 | output_format="files", 464 | thread_count=32, 465 | resize_mode=resize_mode, 466 | resize_only_if_bigger=resize_only_if_bigger, 467 | bbox_col="bboxes", 468 | ) 469 | 470 | output_img_path = get_all_files(output_folder, "jpg")[0] 471 | if resize_only_if_bigger: 472 | desired_output_img_path = os.path.join( 473 | current_folder, "blur_test_files", "resize_no.jpg" 474 | ) # Original image is smaller 475 | else: 476 | desired_output_img_path = os.path.join(current_folder, "blur_test_files", f"resize_{resize_mode}.jpg") 477 | 478 | output_img = cv2.imread(output_img_path) 479 | desired_img = cv2.imread(desired_output_img_path) 480 | assert np.array_equal(output_img, desired_img) 481 | 482 | 483 | def test_verify_hash(tmp_path): 484 | test_folder = str(tmp_path) 485 | output_folder = os.path.join(test_folder, "images") 486 | 487 | current_folder = os.path.dirname(__file__) 488 | input_parquet = os.path.join(current_folder, "test_files", "sample_image.parquet") 489 | 490 | download( 491 | input_parquet, 492 | input_format="parquet", 493 | image_size=224, 494 | output_folder=output_folder, 495 | output_format="files", 496 | thread_count=32, 497 | verify_hash=["sha256hash", "sha256"], 498 | ) 499 | 500 | df = pd.read_parquet(os.path.join(output_folder, "00000.parquet")) 501 | 502 | assert df["sha256"].isna().to_numpy().sum() == 1 503 | -------------------------------------------------------------------------------- /tests/test_reader.py: -------------------------------------------------------------------------------- 1 | from img2dataset.reader import Reader 2 | import os 3 | from fixtures import generate_input_file, setup_fixtures 4 | import pytest 5 | import math 6 | import time 7 | import gc 8 | import psutil 9 | import pandas as pd 10 | 11 | 12 | def current_memory_usage(): 13 | return psutil.Process().memory_info().rss / 1024 / 1024 14 | 15 | 16 | @pytest.mark.parametrize( 17 | "input_format", 18 | [ 19 | "txt", 20 | "txt.gz", 21 | "csv", 22 | "csv.gz", 23 | "tsv", 24 | "tsv.gz", 25 | "json", 26 | "json.gz", 27 | "jsonl", 28 | "jsonl.gz", 29 | "parquet", 30 | ], 31 | ) 32 | def test_reader(input_format, tmp_path): 33 | """Tests whether Reader class works as expected.""" 34 | expected_count = 10**5 + 5312 35 | test_folder = str(tmp_path) 36 | test_list = setup_fixtures(count=expected_count) 37 | prefix = input_format + "_" 38 | url_list_name = os.path.join(test_folder, prefix + "url_list") 39 | url_list_name = generate_input_file(input_format, url_list_name, test_list) 40 | 41 | tmp_path = os.path.join(test_folder, prefix + "tmp") 42 | os.mkdir(tmp_path) 43 | 44 | done_shards = [0, 1, 2, 3] 45 | batch_size = 1000 46 | reader = Reader( 47 | url_list=url_list_name, 48 | input_format=input_format, 49 | url_col="url", 50 | caption_col=None if input_format in ["txt", "txt.gz"] else "caption", 51 | verify_hash_col=None, 52 | verify_hash_type=None, 53 | save_additional_columns=None, 54 | number_sample_per_shard=batch_size, 55 | done_shards=done_shards, 56 | tmp_path=test_folder, 57 | ) 58 | 59 | if input_format in ["txt", "txt.gz"]: 60 | assert reader.column_list == ["url"] 61 | else: 62 | assert reader.column_list == ["caption", "url"] 63 | last_shard_num = math.ceil(expected_count / batch_size) - 1 64 | 65 | total_sample_count = 0 66 | start_time = time.time() 67 | initial_memory_usage = current_memory_usage() 68 | for i, (shard_id, shard_path) in enumerate(reader): 69 | incremental_shard_id = len(done_shards) + i 70 | assert incremental_shard_id == shard_id 71 | shard_df = pd.read_feather(shard_path) 72 | shard = list(enumerate(shard_df[reader.column_list].to_records(index=False).tolist())) 73 | total_sample_count += len(shard) 74 | if last_shard_num == incremental_shard_id: 75 | assert len(shard) <= batch_size 76 | else: 77 | assert len(shard) == batch_size 78 | 79 | begin_expected = incremental_shard_id * batch_size 80 | end_expected = (incremental_shard_id + 1) * batch_size 81 | 82 | expected_shard = list(enumerate(test_list[begin_expected:end_expected])) 83 | if input_format in ["txt", "txt.gz"]: 84 | expected_shard = [(i, (url,)) for i, (_, url) in expected_shard] 85 | assert shard == expected_shard 86 | current_usage = current_memory_usage() 87 | assert current_usage - initial_memory_usage < 100 88 | del expected_shard 89 | del shard 90 | 91 | del reader 92 | 93 | assert total_sample_count == expected_count - batch_size * len(done_shards) 94 | 95 | total_time = time.time() - start_time 96 | print("Total time:", total_time) 97 | assert total_time <= 1.0 98 | 99 | gc.collect() 100 | 101 | final_memory_usage = current_memory_usage() 102 | assert final_memory_usage - initial_memory_usage < 100 103 | -------------------------------------------------------------------------------- /tests/test_resizer.py: -------------------------------------------------------------------------------- 1 | from img2dataset.resizer import Resizer 2 | import os 3 | import glob 4 | import pytest 5 | from fixtures import check_one_image_size 6 | import io 7 | import cv2 8 | import numpy as np 9 | 10 | testdata = [ 11 | ("border", False, False), 12 | ("border", False, True), 13 | ("border", True, False), 14 | ("keep_ratio", False, False), 15 | ("keep_ratio", True, False), 16 | ("keep_ratio", True, True), 17 | ("keep_ratio_largest", False, False), 18 | ("keep_ratio_largest", True, False), 19 | ("keep_ratio_largest", True, True), 20 | ("center_crop", False, False), 21 | ("center_crop", True, False), 22 | ("no", False, False), 23 | ("no", False, True), 24 | ] 25 | 26 | testformat = [ 27 | (95, "jpg"), 28 | (95, "webp"), 29 | (9, "png"), 30 | ] 31 | 32 | 33 | @pytest.mark.parametrize("image_size", [256, 512]) 34 | @pytest.mark.parametrize("resize_mode, resize_only_if_bigger, skip_reencode", testdata) 35 | @pytest.mark.parametrize("encode_quality, encode_format", testformat) 36 | def test_resizer(image_size, resize_mode, resize_only_if_bigger, skip_reencode, encode_quality, encode_format): 37 | current_folder = os.path.dirname(__file__) 38 | test_folder = current_folder + "/" + "resize_test_image" 39 | image_paths = glob.glob(test_folder + "/*") 40 | resizer = Resizer( 41 | image_size, 42 | resize_mode, 43 | resize_only_if_bigger, 44 | encode_quality=encode_quality, 45 | encode_format=encode_format, 46 | skip_reencode=skip_reencode, 47 | ) 48 | for image_path in image_paths: 49 | with open(image_path, "rb") as f: 50 | img = f.read() 51 | image_original_stream = io.BytesIO(img) 52 | image_resized_str, width, height, original_width, original_height, err = resizer(image_original_stream) 53 | assert err is None 54 | image_original_stream = io.BytesIO(img) 55 | image_original = cv2.imdecode(np.frombuffer(image_original_stream.read(), np.uint8), cv2.IMREAD_UNCHANGED) 56 | image_resized = cv2.imdecode(np.frombuffer(image_resized_str, np.uint8), cv2.IMREAD_UNCHANGED) 57 | width_resized = image_resized.shape[1] 58 | height_resized = image_resized.shape[0] 59 | width_original = image_original.shape[1] 60 | height_original = image_original.shape[0] 61 | assert width_resized == width 62 | assert height_resized == height 63 | assert width_original == original_width 64 | assert height_original == original_height 65 | check_one_image_size(image_resized, image_original, image_size, resize_mode, resize_only_if_bigger) 66 | 67 | 68 | def test_resizer_filter(): 69 | current_folder = os.path.dirname(__file__) 70 | test_folder = current_folder + "/" + "resize_test_image" 71 | image_paths = glob.glob(test_folder + "/*") 72 | resizer = Resizer( 73 | image_size=256, resize_mode="no", resize_only_if_bigger=True, min_image_size=200, max_aspect_ratio=1.5 74 | ) 75 | errors = [] 76 | for image_path in image_paths: 77 | with open(image_path, "rb") as f: 78 | img = f.read() 79 | image_original_stream = io.BytesIO(img) 80 | _, _, _, _, _, err = resizer(image_original_stream) 81 | errors.append(err) 82 | expected_errors = [(None, 2), ("image too small", 2), ("aspect ratio too large", 3)] 83 | for expected_error, count in expected_errors: 84 | assert count == errors.count(expected_error) 85 | 86 | resizer = Resizer(image_size=256, resize_mode="no", resize_only_if_bigger=True, max_image_area=60000) 87 | errors = [] 88 | for image_path in image_paths: 89 | with open(image_path, "rb") as f: 90 | img = f.read() 91 | image_original_stream = io.BytesIO(img) 92 | _, _, _, _, _, err = resizer(image_original_stream) 93 | errors.append(err) 94 | expected_errors = [(None, 2), ("image area too large", 5)] 95 | for expected_error, count in expected_errors: 96 | assert count == errors.count(expected_error) 97 | -------------------------------------------------------------------------------- /tests/test_writer.py: -------------------------------------------------------------------------------- 1 | from img2dataset.writer import ( 2 | FilesSampleWriter, 3 | WebDatasetSampleWriter, 4 | ParquetSampleWriter, 5 | DummySampleWriter, 6 | TFRecordSampleWriter, 7 | ) 8 | 9 | import os 10 | import glob 11 | import pytest 12 | import tarfile 13 | import pandas as pd 14 | import pyarrow as pa 15 | 16 | 17 | @pytest.mark.parametrize("writer_type", ["files", "webdataset", "parquet", "dummy", "tfrecord"]) 18 | def test_writer(writer_type, tmp_path): 19 | current_folder = os.path.dirname(__file__) 20 | test_folder = str(tmp_path) 21 | input_folder = current_folder + "/" + "resize_test_image" 22 | output_folder = test_folder + "/" + "test_write" 23 | os.mkdir(output_folder) 24 | image_paths = glob.glob(input_folder + "/*") 25 | schema = pa.schema( 26 | [ 27 | pa.field("key", pa.string()), 28 | pa.field("caption", pa.string()), 29 | pa.field("status", pa.string()), 30 | pa.field("error_message", pa.string()), 31 | pa.field("width", pa.int32()), 32 | pa.field("height", pa.int32()), 33 | pa.field("original_width", pa.int32()), 34 | pa.field("original_height", pa.int32()), 35 | pa.field("labels", pa.list_(pa.int32())), 36 | ] 37 | ) 38 | if writer_type == "files": 39 | writer_class = FilesSampleWriter 40 | elif writer_type == "webdataset": 41 | writer_class = WebDatasetSampleWriter 42 | elif writer_type == "parquet": 43 | writer_class = ParquetSampleWriter 44 | elif writer_type == "dummy": 45 | writer_class = DummySampleWriter 46 | elif writer_type == "tfrecord": 47 | writer_class = TFRecordSampleWriter 48 | 49 | writer = writer_class(0, output_folder, True, 5, schema, "jpg") 50 | 51 | for i, image_path in enumerate(image_paths): 52 | with open(image_path, "rb") as f: 53 | img_str = f.read() 54 | writer.write( 55 | img_str=img_str, 56 | key=str(i), 57 | caption=str(i), 58 | meta={ 59 | "key": str(i), 60 | "caption": str(i), 61 | "status": "ok", 62 | "error_message": "", 63 | "width": 100, 64 | "height": 100, 65 | "original_width": 100, 66 | "original_height": 100, 67 | "labels": [0, 100, 200], 68 | }, 69 | ) 70 | writer.close() 71 | 72 | if writer_type != "dummy": 73 | df = pd.read_parquet(output_folder + "/00000.parquet") 74 | 75 | expected_columns = [ 76 | "key", 77 | "caption", 78 | "status", 79 | "error_message", 80 | "width", 81 | "height", 82 | "original_width", 83 | "original_height", 84 | "labels", 85 | ] 86 | 87 | if writer_type == "parquet": 88 | expected_columns.append("jpg") 89 | 90 | assert df.columns.tolist() == expected_columns 91 | 92 | assert df["key"].iloc[0] == "0" 93 | assert df["caption"].iloc[0] == "0" 94 | assert df["status"].iloc[0] == "ok" 95 | assert df["error_message"].iloc[0] == "" 96 | assert df["width"].iloc[0] == 100 97 | assert df["height"].iloc[0] == 100 98 | assert df["original_width"].iloc[0] == 100 99 | assert df["original_height"].iloc[0] == 100 100 | assert (df["labels"].iloc[0] == [0, 100, 200]).all() 101 | 102 | if writer_type == "files": 103 | saved_files = list(glob.glob(output_folder + "/00000/*")) 104 | assert len(saved_files) == 3 * len(image_paths) 105 | elif writer_type == "webdataset": 106 | l = glob.glob(output_folder + "/*.tar") 107 | assert len(l) == 1 108 | if l[0] != output_folder + "/00000.tar": 109 | raise Exception(l[0] + " is not 00000.tar") 110 | 111 | assert len(tarfile.open(output_folder + "/00000.tar").getnames()) == len(image_paths) * 3 112 | elif writer_type == "parquet": 113 | l = glob.glob(output_folder + "/*.parquet") 114 | assert len(l) == 1 115 | if l[0] != output_folder + "/00000.parquet": 116 | raise Exception(l[0] + " is not 00000.parquet") 117 | 118 | assert len(df.index) == len(image_paths) 119 | elif writer_type == "dummy": 120 | l = glob.glob(output_folder + "/*") 121 | assert len(l) == 0 122 | elif writer_type == "tfrecord": 123 | l = glob.glob(output_folder + "/*.tfrecord") 124 | assert len(l) == 1 125 | if l[0] != output_folder + "/00000.tfrecord": 126 | raise Exception(l[0] + " is not 00000.tfrecord") 127 | --------------------------------------------------------------------------------