├── .github └── workflows │ ├── ci.yml │ └── pypi-publish.yml ├── .gitignore ├── .pylintrc ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── dm_pix ├── __init__.py ├── _src │ ├── __init__.py │ ├── api_test.py │ ├── augment.py │ ├── augment_test.py │ ├── color_conversion.py │ ├── color_conversion_test.py │ ├── depth_and_space.py │ ├── depth_and_space_test.py │ ├── interpolation.py │ ├── interpolation_test.py │ ├── metrics.py │ ├── metrics_test.py │ ├── patch.py │ ├── patch_test.py │ └── test_utils.py ├── images │ └── pix_logo.png └── py.typed ├── docs ├── Makefile ├── api.rst ├── conf.py ├── ext │ └── coverage_check.py └── index.rst ├── examples ├── README.md ├── assets │ ├── adjust_brightness_jax_logo.jpg │ ├── adjust_contrast_jax_logo.jpg │ ├── adjust_gamma_jax_logo.jpg │ ├── adjust_hue_jax_logo.jpg │ ├── flip_left_right_jax_logo.jpg │ ├── flip_up_down_jax_logo.jpg │ ├── gaussian_blur_jax_logo.jpg │ └── jax_logo.jpg ├── image_augmentation.ipynb └── image_augmentation.py ├── pyproject.toml ├── readthedocs.yml └── test.sh /.github/workflows/ci.yml: -------------------------------------------------------------------------------- 1 | name: ci 2 | 3 | on: 4 | push: 5 | branches: ["master"] 6 | pull_request: 7 | branches: ["master"] 8 | 9 | jobs: 10 | build-and-test: 11 | name: "Python ${{ matrix.python-version }} on ${{ matrix.os }}" 12 | runs-on: "${{ matrix.os }}" 13 | 14 | strategy: 15 | matrix: 16 | python-version: ["3.8", "3.9", "3.10"] 17 | os: [ubuntu-latest] 18 | 19 | steps: 20 | - uses: "actions/checkout@v2" 21 | - uses: "actions/setup-python@v4" 22 | with: 23 | python-version: "${{ matrix.python-version }}" 24 | cache: "pip" 25 | cache-dependency-path: "pyproject.toml" 26 | - name: Run CI tests 27 | run: bash test.sh 28 | shell: bash 29 | -------------------------------------------------------------------------------- /.github/workflows/pypi-publish.yml: -------------------------------------------------------------------------------- 1 | name: pypi 2 | 3 | 
on: 4 | release: 5 | types: [created] 6 | 7 | jobs: 8 | deploy: 9 | runs-on: ubuntu-latest 10 | steps: 11 | - uses: actions/checkout@v4 12 | - name: Set up Python 13 | uses: actions/setup-python@v4 14 | with: 15 | python-version: '3.x' 16 | - name: Install dependencies 17 | run: | 18 | python -m pip install --upgrade pip 19 | pip install setuptools wheel twine build 20 | - name: Check consistency between the package version and release tag 21 | run: | 22 | pip install . 23 | RELEASE_VER=${GITHUB_REF#refs/*/} 24 | PACKAGE_VER="v`python -c 'import dm_pix; print(dm_pix.__version__)'`" 25 | if [ $RELEASE_VER != $PACKAGE_VER ] 26 | then 27 | echo "package ver. ($PACKAGE_VER) != release ver. ($RELEASE_VER)"; exit 1 28 | fi 29 | - name: Build and publish 30 | env: 31 | TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} 32 | TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} 33 | run: | 34 | python -m build 35 | twine upload dist/* 36 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | 9 | # Mac OS 10 | .DS_Store 11 | 12 | # Python tools 13 | .mypy_cache/ 14 | .pytype/ 15 | .ipynb_checkpoints 16 | 17 | # Editors 18 | .idea 19 | .vscode 20 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 
12 | ignore=third_party 13 | 14 | # Add files or directories matching the regex patterns to the blacklist. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=no 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code 35 | extension-pkg-whitelist= 36 | 37 | 38 | [MESSAGES CONTROL] 39 | 40 | # Only show warnings with the listed confidence levels. Leave empty to show 41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 42 | confidence= 43 | 44 | # Enable the message, report, category or checker with the given id(s). You can 45 | # either give multiple identifier separated by comma (,) or put this option 46 | # multiple time (only on the command line, not in the configuration file where 47 | # it should appear only once). See also the "--disable" option for examples. 48 | #enable= 49 | 50 | # Disable the message, report, category or checker with the given id(s). You 51 | # can either give multiple identifiers separated by comma (,) or put this 52 | # option multiple times (only on the command line, not in the configuration 53 | # file where it should appear only once).You can also use "--disable=all" to 54 | # disable everything first and then reenable specific checks. For example, if 55 | # you want to run only the similarities checker, you can use "--disable=all 56 | # --enable=similarities". 
If you want to run only the classes checker, but have 57 | # no Warning level messages displayed, use"--disable=all --enable=classes 58 | # --disable=W" 59 | disable=apply-builtin, 60 | attribute-defined-outside-init, 61 | backtick, 62 | bad-option-value, 63 | buffer-builtin, 64 | c-extension-no-member, 65 | cmp-builtin, 66 | cmp-method, 67 | coerce-builtin, 68 | coerce-method, 69 | delslice-method, 70 | div-method, 71 | duplicate-code, 72 | eq-without-hash, 73 | execfile-builtin, 74 | file-builtin, 75 | filter-builtin-not-iterating, 76 | fixme, 77 | getslice-method, 78 | global-statement, 79 | hex-method, 80 | idiv-method, 81 | implicit-str-concat-in-sequence, 82 | import-error, 83 | import-self, 84 | import-star-module-level, 85 | input-builtin, 86 | intern-builtin, 87 | invalid-str-codec, 88 | invalid-unary-operand-type, 89 | locally-disabled, 90 | long-builtin, 91 | long-suffix, 92 | map-builtin-not-iterating, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-member, 102 | no-self-use, 103 | nonzero-method, 104 | oct-method, 105 | old-division, 106 | old-ne-operator, 107 | old-octal-literal, 108 | old-raise-syntax, 109 | parameter-unpacking, 110 | print-statement, 111 | raising-string, 112 | range-builtin-not-iterating, 113 | raw_input-builtin, 114 | rdiv-method, 115 | reduce-builtin, 116 | relative-import, 117 | reload-builtin, 118 | round-builtin, 119 | setslice-method, 120 | signature-differs, 121 | standarderror-builtin, 122 | suppressed-message, 123 | sys-max-int, 124 | too-few-public-methods, 125 | too-many-ancestors, 126 | too-many-arguments, 127 | too-many-boolean-expressions, 128 | too-many-branches, 129 | too-many-instance-attributes, 130 | too-many-locals, 131 | too-many-public-methods, 132 | too-many-return-statements, 133 | too-many-statements, 134 | trailing-newlines, 135 | unichr-builtin, 136 | 
unicode-builtin, 137 | unpacking-in-except, 138 | useless-else-on-loop, 139 | useless-suppression, 140 | using-cmp-argument, 141 | xrange-builtin, 142 | wrong-import-order, 143 | zip-builtin-not-iterating, 144 | 145 | 146 | [REPORTS] 147 | 148 | # Set the output format. Available formats are text, parseable, colorized, msvs 149 | # (visual studio) and html. You can also give a reporter class, eg 150 | # mypackage.mymodule.MyReporterClass. 151 | output-format=text 152 | 153 | # Put messages in a separate file for each module / package specified on the 154 | # command line instead of printing them on stdout. Reports (if any) will be 155 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 156 | # and it will be removed in Pylint 2.0. 157 | files-output=no 158 | 159 | # Tells whether to display a full report or only the messages 160 | reports=no 161 | 162 | # Python expression which should return a note less than 10 (10 is the highest 163 | # note). You have access to the variables errors warning, statement which 164 | # respectively contain the number of errors / warnings messages and the total 165 | # number of statements analyzed. This is used by the global evaluation report 166 | # (RP0004). 167 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 168 | 169 | # Template used to display messages. This is a python new-style format string 170 | # used to format the message information. See doc for all details 171 | #msg-template= 172 | 173 | 174 | [BASIC] 175 | 176 | # Good variable names which should always be accepted, separated by a comma 177 | good-names=main,_ 178 | 179 | # Bad variable names which should always be refused, separated by a comma 180 | bad-names= 181 | 182 | # Colon-delimited sets of names that determine each other's naming style when 183 | # the name regexes allow several styles. 
184 | name-group= 185 | 186 | # Include a hint for the correct naming format with invalid-name 187 | include-naming-hint=no 188 | 189 | # List of decorators that produce properties, such as abc.abstractproperty. Add 190 | # to this list to register other decorators that produce valid properties. 191 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 192 | 193 | # Regular expression matching correct function names 194 | function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 195 | 196 | # Regular expression matching correct variable names 197 | variable-rgx=^[a-z][a-z0-9_]*$ 198 | 199 | # Regular expression matching correct constant names 200 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 201 | 202 | # Regular expression matching correct attribute names 203 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 204 | 205 | # Regular expression matching correct argument names 206 | argument-rgx=^[a-z][a-z0-9_]*$ 207 | 208 | # Regular expression matching correct class attribute names 209 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 210 | 211 | # Regular expression matching correct inline iteration names 212 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 213 | 214 | # Regular expression matching correct class names 215 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 216 | 217 | # Regular expression matching correct module names 218 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 219 | 220 | # Regular expression matching correct method names 221 | method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 222 | 223 | # Regular expression which should only match function or class names that do 224 | # 
not require a docstring. 225 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 226 | 227 | # Minimum line length for functions/classes that require docstrings, shorter 228 | # ones are exempt. 229 | docstring-min-length=10 230 | 231 | 232 | [TYPECHECK] 233 | 234 | # List of decorators that produce context managers, such as 235 | # contextlib.contextmanager. Add to this list to register other decorators that 236 | # produce valid context managers. 237 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 238 | 239 | # Tells whether missing members accessed in mixin class should be ignored. A 240 | # mixin class is detected if its name ends with "mixin" (case insensitive). 241 | ignore-mixin-members=yes 242 | 243 | # List of module names for which member attributes should not be checked 244 | # (useful for modules/projects where namespaces are manipulated during runtime 245 | # and thus existing member attributes cannot be deduced by static analysis. It 246 | # supports qualified module names, as well as Unix pattern matching. 247 | ignored-modules= 248 | 249 | # List of class names for which member attributes should not be checked (useful 250 | # for classes with dynamically set attributes). This supports the use of 251 | # qualified names. 252 | ignored-classes=optparse.Values,thread._local,_thread._local 253 | 254 | # List of members which are set dynamically and missed by pylint inference 255 | # system, and so shouldn't trigger E1101 when accessed. Python regular 256 | # expressions are accepted. 257 | generated-members= 258 | 259 | # List of decorators that change the signature of a decorated function. 260 | signature-mutators=toolz.functoolz.curry 261 | 262 | [FORMAT] 263 | 264 | # Maximum number of characters on a single line. 265 | max-line-length=80 266 | 267 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 268 | # lines made too long by directives to pytype. 
269 | 270 | # Regexp for a line that is allowed to be longer than the limit. 271 | ignore-long-lines=(?x)( 272 | ^\s*(\#\ )??$| 273 | ^\s*(from\s+\S+\s+)?import\s+.+$) 274 | 275 | # Allow the body of an if to be on the same line as the test if there is no 276 | # else. 277 | single-line-if-stmt=yes 278 | 279 | # List of optional constructs for which whitespace checking is disabled. `dict- 280 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 281 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 282 | # `empty-line` allows space-only lines. 283 | no-space-check= 284 | 285 | # Maximum number of lines in a module 286 | max-module-lines=99999 287 | 288 | # String used as indentation unit. The internal Google style guide mandates 2 289 | # spaces. Google's externally-published style guide says 4, consistent with 290 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 291 | # projects (like TensorFlow). 292 | indent-string=' ' 293 | 294 | # Number of spaces of indent required inside a hanging or continued line. 295 | indent-after-paren=4 296 | 297 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 298 | expected-line-ending-format= 299 | 300 | 301 | [MISCELLANEOUS] 302 | 303 | # List of note tags to take in consideration, separated by a comma. 304 | notes=TODO 305 | 306 | 307 | [VARIABLES] 308 | 309 | # Tells whether we should check for unused import in __init__ files. 310 | init-import=no 311 | 312 | # A regular expression matching the name of dummy variables (i.e. expectedly 313 | # not used). 314 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 315 | 316 | # List of additional names supposed to be defined in builtins. Remember that 317 | # you should avoid to define new builtins when possible. 318 | additional-builtins= 319 | 320 | # List of strings which can identify a callback function by name. 
A callback 321 | # name must start or end with one of those strings. 322 | callbacks=cb_,_cb 323 | 324 | # List of qualified module names which can have objects that can redefine 325 | # builtins. 326 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 327 | 328 | 329 | [LOGGING] 330 | 331 | # Logging modules to check that the string format arguments are in logging 332 | # function parameter format 333 | logging-modules=logging,absl.logging,tensorflow.google.logging 334 | 335 | 336 | [SIMILARITIES] 337 | 338 | # Minimum lines number of a similarity. 339 | min-similarity-lines=4 340 | 341 | # Ignore comments when computing similarities. 342 | ignore-comments=yes 343 | 344 | # Ignore docstrings when computing similarities. 345 | ignore-docstrings=yes 346 | 347 | # Ignore imports when computing similarities. 348 | ignore-imports=no 349 | 350 | 351 | [SPELLING] 352 | 353 | # Spelling dictionary name. Available dictionaries: none. To make it working 354 | # install python-enchant package. 355 | spelling-dict= 356 | 357 | # List of comma separated words that should not be checked. 358 | spelling-ignore-words= 359 | 360 | # A path to a file that contains private dictionary; one word per line. 361 | spelling-private-dict-file= 362 | 363 | # Tells whether to store unknown words to indicated private dictionary in 364 | # --spelling-private-dict-file option instead of raising a message. 365 | spelling-store-unknown-words=no 366 | 367 | 368 | [IMPORTS] 369 | 370 | # Deprecated modules which should not be used, separated by a comma 371 | deprecated-modules=regsub, 372 | TERMIOS, 373 | Bastion, 374 | rexec, 375 | sets 376 | 377 | # Create a graph of every (i.e. 
internal and external) dependencies in the 378 | # given file (report RP0402 must not be disabled) 379 | import-graph= 380 | 381 | # Create a graph of external dependencies in the given file (report RP0402 must 382 | # not be disabled) 383 | ext-import-graph= 384 | 385 | # Create a graph of internal dependencies in the given file (report RP0402 must 386 | # not be disabled) 387 | int-import-graph= 388 | 389 | # Force import order to recognize a module as part of the standard 390 | # compatibility libraries. 391 | known-standard-library= 392 | 393 | # Force import order to recognize a module as part of a third party library. 394 | known-third-party=enchant, absl 395 | 396 | # Analyse import fallback blocks. This can be used to support both Python 2 and 397 | # 3 compatible code, which means that the block might have code that exists 398 | # only in one or another interpreter, leading to false positives when analysed. 399 | analyse-fallback-blocks=no 400 | 401 | 402 | [CLASSES] 403 | 404 | # List of method names used to declare (i.e. assign) instance attributes. 405 | defining-attr-methods=__init__, 406 | __new__, 407 | setUp 408 | 409 | # List of member names, which should be excluded from the protected access 410 | # warning. 411 | exclude-protected=_asdict, 412 | _fields, 413 | _replace, 414 | _source, 415 | _make 416 | 417 | # List of valid names for the first argument in a class method. 418 | valid-classmethod-first-arg=cls, 419 | class_ 420 | 421 | # List of valid names for the first argument in a metaclass class method. 422 | valid-metaclass-classmethod-first-arg=mcs 423 | 424 | 425 | [EXCEPTIONS] 426 | 427 | # Exceptions that will emit a warning when being caught. 
Defaults to 428 | # "Exception" 429 | overgeneral-exceptions=StandardError, 430 | Exception, 431 | BaseException 432 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing guidelines 2 | 3 | ## How to become a contributor and submit your own code 4 | 5 | ### Contributor License Agreements 6 | 7 | We'd love to accept your patches! Before we can take them, we have to jump a 8 | couple of legal hurdles. 9 | 10 | Please fill out either the individual or corporate Contributor License Agreement 11 | (CLA). 12 | 13 | * If you are an individual writing original source code and you're sure you 14 | own the intellectual property, then you'll need to sign an [Individual CLA]. 15 | * If you work for a company that wants to allow you to contribute your work, 16 | then you'll need to sign a [Corporate CLA]. 17 | 18 | Follow either of the two links above to access the appropriate CLA and 19 | instructions for how to sign and return it. Once we receive it, we'll be able to 20 | accept your pull requests. 21 | 22 | ***NOTE***: Only original source code from you and other people that have signed 23 | the CLA can be accepted into the main repository. 24 | 25 | ### Contributing code 26 | 27 | If you have improvements to PIX, send us your pull requests! For those just 28 | getting started, Github has a [HowTo]. 
29 | 30 | [Individual CLA]: https://cla.developers.google.com/about/google-individual?csw=1 "Individual CLA" 31 | [Corporate CLA]: https://cla.developers.google.com/about/google-corporate?csw=1 "Corporate CLA" 32 | [HowTo]: https://help.github.com/articles/using-pull-requests/ "GitHub PR HowTo" 33 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 
35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # PIX 2 | 3 | 4 | 5 | PIX is an image processing library in [JAX], for [JAX]. 6 | 7 | [![GitHub Workflow CI](https://img.shields.io/github/actions/workflow/status/deepmind/dm_pix/ci.yml?label=pytest&logo=python&logoColor=white&style=flat-square)](https://github.com/deepmind/dm_pix/actions/workflows/ci.yml) 8 | [![Read the Docs](https://img.shields.io/readthedocs/dm_pix?label=ReadTheDocs&logo=readthedocs&logoColor=white&style=flat-square)](https://dm-pix.readthedocs.io/en/latest/?badge=latest) 9 | [![PyPI](https://img.shields.io/pypi/v/dm_pix?logo=pypi&logoColor=white&style=flat-square)](https://pypi.org/project/dm-pix/) 10 | 11 | ## Overview 12 | 13 | [JAX] is a library resulting from the union of [Autograd] and [XLA] for 14 | high-performance machine learning research. It provides [NumPy], [SciPy], 15 | automatic differentiation and first-class GPU/TPU support. 
16 | 17 | PIX is a library built on top of JAX with the goal of providing image processing 18 | functions and tools to JAX in a way that they can be optimised and parallelised 19 | through [`jax.jit`][jit], [`jax.vmap`][vmap] and [`jax.pmap`][pmap]. 20 | 21 | ## Installation 22 | 23 | PIX is written in pure Python, but depends on C++ code via JAX. 24 | 25 | Because JAX installation is different depending on your CUDA version, PIX does 26 | not list JAX as a dependency in [`pyproject.toml`], although it is listed there 27 | for reference as a commented-out entry. 28 | 29 | First, follow [JAX installation instructions] to install JAX with the relevant 30 | accelerator support. 31 | 32 | Then, install PIX using `pip`: 33 | 34 | ```bash 35 | $ pip install dm-pix 36 | ``` 37 | 38 | ## Quickstart 39 | 40 | To use `PIX`, you just need to `import dm_pix as pix` and use it right away! 41 | 42 | For example, let's assume we have loaded the JAX logo (available in 43 | `examples/assets/jax_logo.jpg`) in a variable called `image` and we want to flip 44 | it left to right. 45 | 46 | ![JAX logo] 47 | 48 | All that's needed is the following code! 49 | 50 | ```python 51 | import dm_pix as pix 52 | 53 | # Load an image into a NumPy array with your preferred library. 54 | image = load_image() 55 | 56 | flip_left_right_image = pix.flip_left_right(image) 57 | ``` 58 | 59 | And here is the result! 60 | 61 | ![JAX logo left-right] 62 | 63 | All the functions in PIX can be [`jax.jit`][jit]ed, [`jax.vmap`][vmap]ed and 64 | [`jax.pmap`][pmap]ed, so all the following functions can take advantage of 65 | optimization and parallelization. 66 | 67 | ```python 68 | import dm_pix as pix 69 | import jax 70 | 71 | # Load an image into a NumPy array with your preferred library. 72 | image = load_image() 73 | 74 | # Vanilla Python function. 75 | flip_left_right_image = pix.flip_left_right(image) 76 | 77 | # `jax.jit`ed function.
78 | flip_left_right_image = jax.jit(pix.flip_left_right)(image) 79 | 80 | # Assuming we have a single device, like a CPU or a single GPU, we add a 81 | # single leading dimension for using `image` with the vectorized or 82 | # the multi-device parallelized version of `pix.flip_left_right`. 83 | # To know more, please refer to JAX documentation of `jax.vmap` and `jax.pmap`. 84 | image = image[jax.numpy.newaxis, ...] 85 | 86 | # `jax.vmap`ed function. 87 | flip_left_right_image = jax.vmap(pix.flip_left_right)(image) 88 | 89 | # `jax.pmap`ed function. 90 | flip_left_right_image = jax.pmap(pix.flip_left_right)(image) 91 | ``` 92 | 93 | You can check it yourself that the result from the four versions of 94 | `pix.flip_left_right` is the same (up to the accelerator floating point 95 | accuracy)! 96 | 97 | ## Examples 98 | 99 | We have a few examples in the [`examples/`] folder. They are not much 100 | more involved than the previous example, but they may be a good starting point 101 | for you! 102 | 103 | ## Testing 104 | 105 | We provide a suite of tests to help you both testing your development 106 | environment and to know more about the library itself! All test files have 107 | `_test` suffix, and can be executed using `pytest`. 108 | 109 | If you already have PIX installed, you just need to install some extra 110 | dependencies and run `pytest` as follows: 111 | 112 | ```bash 113 | $ pip install -e ".[test]" 114 | $ python -m pytest [-n <num_cpus>] dm_pix 115 | ``` 116 | 117 | If you want an isolated virtual environment, you just need to run our utility 118 | `bash` script as follows: 119 | 120 | ```bash 121 | $ ./test.sh 122 | ``` 123 | 124 | ## Citing PIX 125 | 126 | This repository is part of the [DeepMind JAX Ecosystem], to cite PIX please use 127 | the [DeepMind JAX Ecosystem citation]. 128 | 129 | ## Contribute! 130 | 131 | We are very happy to accept contributions! 132 | 133 | Please read our [contributing guidelines](./CONTRIBUTING.md) and send us PRs!
134 | 135 | [Autograd]: https://github.com/hips/autograd "Autograd on GitHub" 136 | [DeepMind JAX Ecosystem]: https://deepmind.com/blog/article/using-jax-to-accelerate-our-research "DeepMind JAX Ecosystem" 137 | [DeepMind JAX Ecosystem citation]: https://github.com/deepmind/jax/blob/main/deepmind2020jax.txt "Citation" 138 | [JAX]: https://github.com/jax-ml/jax "JAX on GitHub" 139 | [JAX installation instructions]: https://github.com/jax-ml/jax#installation "JAX installation" 140 | [jit]: https://jax.readthedocs.io/en/latest/jax.html#jax.jit "jax.jit documentation" 141 | [NumPy]: https://numpy.org/ "NumPy" 142 | [pmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.pmap "jax.pmap documentation" 143 | [SciPy]: https://www.scipy.org/ "SciPy" 144 | [XLA]: https://www.tensorflow.org/xla "XLA" 145 | [vmap]: https://jax.readthedocs.io/en/latest/jax.html#jax.vmap "jax.vmap documentation" 146 | 147 | [`examples/`]: ./examples/ 148 | [JAX logo]: ./examples/assets/jax_logo.jpg 149 | [JAX logo left-right]: ./examples/assets/flip_left_right_jax_logo.jpg 150 | [`pyproject.toml`]: ./pyproject.toml 151 | -------------------------------------------------------------------------------- /dm_pix/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """PIX public APIs.""" 15 | 16 | from dm_pix._src import augment 17 | from dm_pix._src import color_conversion 18 | from dm_pix._src import depth_and_space 19 | from dm_pix._src import interpolation 20 | from dm_pix._src import metrics 21 | from dm_pix._src import patch 22 | 23 | __version__ = "0.4.4" 24 | 25 | # Augmentations. 26 | adjust_brightness = augment.adjust_brightness 27 | adjust_contrast = augment.adjust_contrast 28 | adjust_gamma = augment.adjust_gamma 29 | adjust_hue = augment.adjust_hue 30 | adjust_saturation = augment.adjust_saturation 31 | affine_transform = augment.affine_transform 32 | center_crop = augment.center_crop 33 | elastic_deformation = augment.elastic_deformation 34 | flip_left_right = augment.flip_left_right 35 | flip_up_down = augment.flip_up_down 36 | gaussian_blur = augment.gaussian_blur 37 | pad_to_size = augment.pad_to_size 38 | random_brightness = augment.random_brightness 39 | random_contrast = augment.random_contrast 40 | random_crop = augment.random_crop 41 | random_flip_left_right = augment.random_flip_left_right 42 | random_flip_up_down = augment.random_flip_up_down 43 | random_gamma = augment.random_gamma 44 | random_hue = augment.random_hue 45 | random_saturation = augment.random_saturation 46 | resize_with_crop_or_pad = augment.resize_with_crop_or_pad 47 | rotate = augment.rotate 48 | rot90 = augment.rot90 49 | solarize = augment.solarize 50 | 51 | # Color conversions. 52 | hsl_to_rgb = color_conversion.hsl_to_rgb 53 | hsv_to_rgb = color_conversion.hsv_to_rgb 54 | rgb_to_hsl = color_conversion.rgb_to_hsl 55 | rgb_to_hsv = color_conversion.rgb_to_hsv 56 | rgb_to_grayscale = color_conversion.rgb_to_grayscale 57 | 58 | # Depth and space transformations. 59 | depth_to_space = depth_and_space.depth_to_space 60 | space_to_depth = depth_and_space.space_to_depth 61 | 62 | # Interpolation functions. 
63 | flat_nd_linear_interpolate = interpolation.flat_nd_linear_interpolate 64 | flat_nd_linear_interpolate_constant = ( 65 | interpolation.flat_nd_linear_interpolate_constant) 66 | 67 | # Metrics. 68 | mae = metrics.mae 69 | mse = metrics.mse 70 | psnr = metrics.psnr 71 | rmse = metrics.rmse 72 | simse = metrics.simse 73 | ssim = metrics.ssim 74 | 75 | # Patch extraction functions. 76 | extract_patches = patch.extract_patches 77 | 78 | del augment, color_conversion, depth_and_space, interpolation, metrics, patch 79 | 80 | __all__ = ( 81 | "adjust_brightness", 82 | "adjust_contrast", 83 | "adjust_gamma", 84 | "adjust_hue", 85 | "adjust_saturation", 86 | "affine_transform", 87 | "center_crop", 88 | "depth_to_space", 89 | "elastic_deformation", 90 | "extract_patches", 91 | "flat_nd_linear_interpolate", 92 | "flat_nd_linear_interpolate_constant", 93 | "flip_left_right", 94 | "flip_up_down", 95 | "gaussian_blur", 96 | "hsl_to_rgb", 97 | "hsv_to_rgb", 98 | "mae", 99 | "mse", 100 | "pad_to_size", 101 | "psnr", 102 | "random_brightness", 103 | "random_contrast", 104 | "random_crop", 105 | "random_flip_left_right", 106 | "random_flip_up_down", 107 | "random_gamma", 108 | "random_hue", 109 | "random_saturation", 110 | "resize_with_crop_or_pad", 111 | "rotate", 112 | "rgb_to_hsl", 113 | "rgb_to_hsv", 114 | "rgb_to_grayscale", 115 | "rmse", 116 | "rot90", 117 | "simse", 118 | "ssim", 119 | "solarize", 120 | "space_to_depth", 121 | ) 122 | 123 | # _________________________________________ 124 | # / Please don't use symbols in `_src` they \ 125 | # \ are not part of the PIX public API. 
/ 126 | # ----------------------------------------- 127 | # \ ^__^ 128 | # \ (oo)\_______ 129 | # (__)\ )\/\ 130 | # ||----w | 131 | # || || 132 | # 133 | try: 134 | del _src # pylint: disable=undefined-variable 135 | except NameError: 136 | pass 137 | -------------------------------------------------------------------------------- /dm_pix/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | -------------------------------------------------------------------------------- /dm_pix/_src/api_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for dm_pix API.""" 15 | 16 | import inspect 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import dm_pix as pix 21 | from dm_pix._src import test_utils 22 | 23 | 24 | class ApiTest(parameterized.TestCase): 25 | 26 | @parameterized.named_parameters(*test_utils.get_public_functions(pix)) 27 | def test_key_argument(self, f): 28 | sig = inspect.signature(f) 29 | param_names = tuple(sig.parameters) 30 | self.assertNotIn("rng", param_names, 31 | "Prefer `key` to `rng` in PIX (following JAX).") 32 | if "key" in param_names: 33 | self.assertLess( 34 | param_names.index("key"), param_names.index("image"), 35 | "RNG `key` argument should be before `image` in PIX.") 36 | 37 | @parameterized.named_parameters(*test_utils.get_public_functions(pix)) 38 | def test_kwarg_only_defaults(self, f): 39 | argspec = inspect.getfullargspec(f) 40 | if f.__name__ == "rot90": 41 | # Special case for `k` in rot90. 42 | self.assertLen(argspec.defaults, 1) 43 | return 44 | 45 | self.assertEmpty( 46 | argspec.defaults or (), 47 | "Optional keyword arguments in PIX should be keyword " 48 | "only. Prefer `f(x, *, axis=-1)` to `f(x, axis=-1)`.") 49 | 50 | 51 | if __name__ == "__main__": 52 | absltest.main() 53 | -------------------------------------------------------------------------------- /dm_pix/_src/augment.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module provides image augmentation functions. 15 | 16 | All functions expect float-encoded images, with values in [0, 1]. 17 | Do not clip their outputs to this range to allow chaining without losing 18 | information. The outside-of-bounds behavior is (as much as possible) similar to 19 | that of TensorFlow. 20 | """ 21 | 22 | import functools 23 | from typing import Any, Callable, Optional, Sequence, Tuple, Union 24 | 25 | import chex 26 | from dm_pix._src import color_conversion 27 | from dm_pix._src import interpolation 28 | import jax 29 | import jax.numpy as jnp 30 | 31 | # DO NOT REMOVE - Logging lib. 32 | 33 | 34 | def adjust_brightness(image: chex.Array, delta: chex.Numeric) -> chex.Array: 35 | """Shifts the brightness of an RGB image by a given amount. 36 | 37 | This is equivalent to tf.image.adjust_brightness. 38 | 39 | Args: 40 | image: an RGB image, given as a float tensor in [0, 1]. 41 | delta: the (additive) amount to shift each channel by. 42 | 43 | Returns: 44 | The brightness-adjusted image. May be outside of the [0, 1] range. 45 | """ 46 | # DO NOT REMOVE - Logging usage. 47 | 48 | return image + jnp.asarray(delta, image.dtype) 49 | 50 | 51 | def adjust_contrast( 52 | image: chex.Array, 53 | factor: chex.Numeric, 54 | *, 55 | channel_axis: int = -1, 56 | ) -> chex.Array: 57 | """Adjusts the contrast of an RGB image by a given multiplicative amount. 58 | 59 | This is equivalent to `tf.image.adjust_contrast`. 60 | 61 | Args: 62 | image: an RGB image, given as a float tensor in [0, 1]. 63 | factor: the (multiplicative) amount to adjust contrast by. 64 | channel_axis: the index of the channel axis. 65 | 66 | Returns: 67 | The contrast-adjusted image. May be outside of the [0, 1] range. 68 | """ 69 | # DO NOT REMOVE - Logging usage. 
70 | 71 | if _channels_last(image, channel_axis): 72 | spatial_axes = (-3, -2) 73 | else: 74 | spatial_axes = (-2, -1) 75 | mean = jnp.mean(image, axis=spatial_axes, keepdims=True) 76 | return jnp.asarray(factor, image.dtype) * (image - mean) + mean 77 | 78 | 79 | def adjust_gamma( 80 | image: chex.Array, 81 | gamma: chex.Numeric, 82 | *, 83 | gain: chex.Numeric = 1., 84 | assume_in_bounds: bool = False, 85 | ) -> chex.Array: 86 | """Adjusts the gamma of an RGB image. 87 | 88 | This is equivalent to `tf.image.adjust_gamma`, i.e. returns 89 | `gain * image ** gamma`. 90 | 91 | Args: 92 | image: an RGB image, given as a [0-1] float tensor. 93 | gamma: the exponent to apply. 94 | gain: the (multiplicative) gain to apply. 95 | assume_in_bounds: whether the input image should be assumed to have all 96 | values within [0, 1]. If False (default), the inputs will be clipped to 97 | that range avoid NaNs. 98 | 99 | Returns: 100 | The gamma-adjusted image. 101 | """ 102 | # DO NOT REMOVE - Logging usage. 103 | 104 | if not assume_in_bounds: 105 | image = jnp.clip(image, 0., 1.) # Clip image for safety. 106 | return jnp.asarray(gain, image.dtype) * ( 107 | image**jnp.asarray(gamma, image.dtype)) 108 | 109 | 110 | def adjust_hue( 111 | image: chex.Array, 112 | delta: chex.Numeric, 113 | *, 114 | channel_axis: int = -1, 115 | ) -> chex.Array: 116 | """Adjusts the hue of an RGB image by a given multiplicative amount. 117 | 118 | This is equivalent to `tf.image.adjust_hue` when TF is running on GPU. When 119 | running on CPU, the results will be different if all RGB values for a pixel 120 | are outside of the [0, 1] range. 121 | 122 | Args: 123 | image: an RGB image, given as a [0-1] float tensor. 124 | delta: the (additive) angle to shift hue by. 125 | channel_axis: the index of the channel axis. 126 | 127 | Returns: 128 | The saturation-adjusted image. 129 | """ 130 | # DO NOT REMOVE - Logging usage. 
131 | 132 | rgb = color_conversion.split_channels(image, channel_axis) 133 | hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb) 134 | rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes((hue + delta) % 1.0, 135 | saturation, value) 136 | return jnp.stack(rgb_adjusted, axis=channel_axis) 137 | 138 | 139 | def adjust_saturation( 140 | image: chex.Array, 141 | factor: chex.Numeric, 142 | *, 143 | channel_axis: int = -1, 144 | ) -> chex.Array: 145 | """Adjusts the saturation of an RGB image by a given multiplicative amount. 146 | 147 | This is equivalent to `tf.image.adjust_saturation`. 148 | 149 | Args: 150 | image: an RGB image, given as a [0-1] float tensor. 151 | factor: the (multiplicative) amount to adjust saturation by. 152 | channel_axis: the index of the channel axis. 153 | 154 | Returns: 155 | The saturation-adjusted image. 156 | """ 157 | # DO NOT REMOVE - Logging usage. 158 | 159 | rgb = color_conversion.split_channels(image, channel_axis) 160 | hue, saturation, value = color_conversion.rgb_planes_to_hsv_planes(*rgb) 161 | factor = jnp.asarray(factor, image.dtype) 162 | rgb_adjusted = color_conversion.hsv_planes_to_rgb_planes( 163 | hue, jnp.clip(saturation * factor, 0., 1.), value) 164 | return jnp.stack(rgb_adjusted, axis=channel_axis) 165 | 166 | 167 | def elastic_deformation( 168 | key: chex.PRNGKey, 169 | image: chex.Array, 170 | alpha: chex.Numeric, 171 | sigma: chex.Numeric, 172 | *, 173 | order: int = 1, 174 | mode: str = "nearest", 175 | cval: float = 0., 176 | channel_axis: int = -1, 177 | ) -> chex.Array: 178 | """Applies an elastic deformation to the given image. 179 | 180 | Introduced by [Simard, 2003] and popularized by [Ronneberger, 2015]. Deforms 181 | images by moving pixels locally around using displacement fields. 182 | 183 | Small sigma values (< 1.) give pixelated images while higher values result 184 | in water like results. 
Alpha should be between 5x and 10x the value 185 | given for sigma for sensible results. 186 | 187 | Args: 188 | key: a JAX RNG key. 189 | image: a JAX array representing an image. Assumes that the image is 190 | either HWC or CHW. 191 | alpha: strength of the distortion field. Higher values mean that pixels are 192 | moved further with respect to the distortion field's direction. 193 | sigma: standard deviation of the gaussian kernel used to smooth the 194 | distortion fields. 195 | order: the order of the spline interpolation, default is 1. The order has 196 | to be in the range [0, 1]. Note that PIX interpolation will only be used 197 | for order=1, for other values we use `jax.scipy.ndimage.map_coordinates`. 198 | mode: the mode parameter determines how the input array is extended beyond 199 | its boundaries. Default is 'nearest'. Modes 'nearest' and 'constant' use 200 | PIX interpolation, which is very fast on accelerators (especially on 201 | TPUs). For all other modes, 'wrap', 'mirror' and 'reflect', we rely 202 | on `jax.scipy.ndimage.map_coordinates`, which however is slow on 203 | accelerators, so use it with care. 204 | cval: value to fill past edges of input if mode is 'constant'. Default is 205 | 0.0. 206 | channel_axis: the index of the channel axis. 207 | 208 | Returns: 209 | The transformed image. 210 | """ 211 | # DO NOT REMOVE - Logging usage. 212 | 213 | chex.assert_rank(image, 3) 214 | if channel_axis != -1: 215 | image = jnp.moveaxis(image, source=channel_axis, destination=-1) 216 | single_channel_shape = (*image.shape[:-1], 1) 217 | key_i, key_j = jax.random.split(key) 218 | noise_i = jax.random.uniform(key_i, shape=single_channel_shape) * 2 - 1 219 | noise_j = jax.random.uniform(key_j, shape=single_channel_shape) * 2 - 1 220 | 221 | # ~3 sigma on each side of the kernel's center covers ~99.7% of the 222 | # probability mass. There is some fiddling for smaller values.
Source: 223 | # https://docs.opencv.org/3.1.0/d4/d86/group__imgproc__filter.html#gac05a120c1ae92a6060dd0db190a61afa 224 | kernel_size = ((sigma - 0.8) / 0.3 + 1) / 0.5 + 1 225 | shift_map_i = gaussian_blur( 226 | image=noise_i, 227 | sigma=sigma, 228 | kernel_size=kernel_size) * alpha 229 | shift_map_j = gaussian_blur( 230 | image=noise_j, 231 | sigma=sigma, 232 | kernel_size=kernel_size) * alpha 233 | 234 | meshgrid = list( 235 | jnp.meshgrid( 236 | *[jnp.arange(size) for size in single_channel_shape], indexing="ij")) 237 | meshgrid[0] += shift_map_i 238 | meshgrid[1] += shift_map_j 239 | 240 | interpolate_function = _get_interpolate_function( 241 | mode=mode, 242 | order=order, 243 | cval=cval, 244 | ) 245 | transformed_image = jnp.concatenate([ 246 | interpolate_function( 247 | image[..., channel, jnp.newaxis], jnp.asarray(meshgrid)) 248 | for channel in range(image.shape[-1]) 249 | ], axis=-1) 250 | 251 | if channel_axis != -1: # Set channel axis back to original index. 252 | transformed_image = jnp.moveaxis( 253 | transformed_image, source=-1, destination=channel_axis) 254 | return transformed_image 255 | 256 | 257 | def center_crop( 258 | image: chex.Array, 259 | height: chex.Numeric, 260 | width: chex.Numeric, 261 | *, 262 | channel_axis: int = -1, 263 | ) -> chex.Array: 264 | """Crops an image to the given size keeping the same center of the original. 265 | 266 | Target height/width given can be greater than the current size of the image 267 | which results in being a no-op for that dimension. 268 | 269 | In case of odd size along any dimension the bottom/right side gets the extra 270 | pixel. 271 | 272 | Args: 273 | image: a JAX array representing an image. Assumes that the image is either 274 | ...HWC or ...CHW. 275 | height: target height to crop the image to. 276 | width: target width to crop the image to. 277 | channel_axis: the index of the channel axis. 278 | 279 | Returns: 280 | The cropped image(s). 281 | """ 282 | # DO NOT REMOVE - Logging usage. 
283 | 284 | chex.assert_rank(image, {3, 4}) 285 | batch, current_height, current_width, channel = _get_dimension_values( 286 | image=image, channel_axis=channel_axis 287 | ) 288 | center_h, center_w = current_height // 2, current_width // 2 289 | 290 | left = max(center_w - (width // 2), 0) 291 | right = min(left + width, current_width) 292 | top = max(center_h - (height // 2), 0) 293 | bottom = min(top + height, current_height) 294 | 295 | if _channels_last(image, channel_axis): 296 | start_indices = (top, left, 0) 297 | limit_indices = (bottom, right, channel) 298 | else: 299 | start_indices = (0, top, left) 300 | limit_indices = (channel, bottom, right) 301 | 302 | if batch is not None: # In case batch of images is given. 303 | start_indices = (0, *start_indices) 304 | limit_indices = (batch, *limit_indices) 305 | 306 | return jax.lax.slice( 307 | image, start_indices=start_indices, limit_indices=limit_indices 308 | ) 309 | 310 | 311 | def pad_to_size( 312 | image: chex.Array, 313 | target_height: int, 314 | target_width: int, 315 | *, 316 | mode: str = "constant", 317 | pad_kwargs: Optional[Any] = None, 318 | channel_axis: int = -1, 319 | ) -> chex.Array: 320 | """Pads an image to the given size keeping the original image centered. 321 | 322 | For different padding methods and kwargs please see: 323 | https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.pad.html 324 | 325 | In case of odd size difference along any dimension the bottom/right side gets 326 | the extra padding pixel. 327 | 328 | Target size can be smaller than original size which results in a no-op for 329 | such dimension. 330 | 331 | Args: 332 | image: a JAX array representing an image. Assumes that the image is either 333 | ...HWC or ...CHW. 334 | target_height: target height to pad the image to. 335 | target_width: target width to pad the image to. 336 | mode: Mode for padding the images, see jax.numpy.pad for details. Default is 337 | `constant`. 
338 | pad_kwargs: Keyword arguments to pass jax.numpy.pad, see documentation for 339 | options. 340 | channel_axis: the index of the channel axis. 341 | 342 | Returns: 343 | The padded image(s). 344 | """ 345 | # DO NOT REMOVE - Logging usage. 346 | 347 | chex.assert_rank(image, {3, 4}) 348 | batch, height, width, _ = _get_dimension_values( 349 | image=image, channel_axis=channel_axis 350 | ) 351 | delta_width = max(target_width - width, 0) 352 | delta_height = max(target_height - height, 0) 353 | if delta_width == 0 and delta_height == 0: 354 | return image 355 | 356 | left = delta_width // 2 357 | right = max(target_width - (left + width), 0) 358 | top = delta_height // 2 359 | bottom = max(target_height - (top + height), 0) 360 | 361 | pad_width = ((top, bottom), (left, right), (0, 0)) 362 | if batch: 363 | pad_width = ((0, 0), *pad_width) 364 | 365 | return jnp.pad(image, pad_width=pad_width, mode=mode, **pad_kwargs or {}) 366 | 367 | 368 | def resize_with_crop_or_pad( 369 | image: chex.Array, 370 | target_height: chex.Numeric, 371 | target_width: chex.Numeric, 372 | *, 373 | pad_mode: str = "constant", 374 | pad_kwargs: Optional[Any] = None, 375 | channel_axis: int = -1, 376 | ) -> chex.Array: 377 | """Crops and/or pads an image to a target width and height. 378 | 379 | Equivalent in functionality to tf.image.resize_with_crop_or_pad but allows for 380 | different padding methods as well beyond zero padding. 381 | 382 | Args: 383 | image: a JAX array representing an image. Assumes that the image is either 384 | ...HWC or ...CHW. 385 | target_height: target height to crop or pad the image to. 386 | target_width: target width to crop or pad the image to. 387 | pad_mode: mode for padding the images, see jax.numpy.pad for details. 388 | Default is `constant`. 389 | pad_kwargs: keyword arguments to pass jax.numpy.pad, see documentation for 390 | options. 391 | channel_axis: the index of the channel axis. 
392 | 393 | Returns: 394 | The image(s) resized by crop or pad to the desired target size. 395 | """ 396 | # DO NOT REMOVE - Logging usage. 397 | 398 | chex.assert_rank(image, {3, 4}) 399 | image = center_crop( 400 | image, 401 | height=target_height, 402 | width=target_width, 403 | channel_axis=channel_axis, 404 | ) 405 | return pad_to_size( 406 | image, 407 | target_height=target_height, 408 | target_width=target_width, 409 | channel_axis=channel_axis, 410 | mode=pad_mode, 411 | pad_kwargs=pad_kwargs, 412 | ) 413 | 414 | 415 | def flip_left_right( 416 | image: chex.Array, 417 | *, 418 | channel_axis: int = -1, 419 | ) -> chex.Array: 420 | """Flips an image along the horizontal axis. 421 | 422 | Assumes that the image is either ...HWC or ...CHW and flips the W axis. 423 | 424 | Args: 425 | image: a JAX array representing an image. Assumes that the image is either 426 | ...HWC or ...CHW. 427 | channel_axis: the index of the channel axis. 428 | 429 | Returns: 430 | The flipped image. 431 | """ 432 | # DO NOT REMOVE - Logging usage. 433 | 434 | if _channels_last(image, channel_axis): 435 | flip_axis = -2 # Image is ...HWC 436 | else: 437 | flip_axis = -1 # Image is ...CHW 438 | return jnp.flip(image, axis=flip_axis) 439 | 440 | 441 | def flip_up_down( 442 | image: chex.Array, 443 | *, 444 | channel_axis: int = -1, 445 | ) -> chex.Array: 446 | """Flips an image along the vertical axis. 447 | 448 | Assumes that the image is either ...HWC or ...CHW, and flips the H axis. 449 | 450 | Args: 451 | image: a JAX array representing an image. Assumes that the image is either 452 | ...HWC or ...CHW. 453 | channel_axis: the index of the channel axis. 454 | 455 | Returns: 456 | The flipped image. 457 | """ 458 | # DO NOT REMOVE - Logging usage. 
459 | 460 | if _channels_last(image, channel_axis): 461 | flip_axis = -3 # Image is ...HWC 462 | else: 463 | flip_axis = -2 # Image is ...CHW 464 | return jnp.flip(image, axis=flip_axis) 465 | 466 | 467 | def gaussian_blur( 468 | image: chex.Array, 469 | sigma: chex.Numeric, 470 | kernel_size: float, 471 | *, 472 | padding: str = "SAME", 473 | channel_axis: int = -1, 474 | ) -> chex.Array: 475 | """Applies gaussian blur (convolution with a Gaussian kernel). 476 | 477 | Args: 478 | image: the input image, as a [0-1] float tensor. Should have 3 or 4 479 | dimensions with two spatial dimensions. 480 | sigma: the standard deviation (in pixels) of the gaussian kernel. 481 | kernel_size: the size (in pixels) of the square gaussian kernel. Will be 482 | "rounded" to the next odd integer. 483 | padding: either "SAME" or "VALID", passed to the underlying convolution. 484 | channel_axis: the index of the channel axis. 485 | 486 | Returns: 487 | The blurred image. 488 | """ 489 | # DO NOT REMOVE - Logging usage. 490 | 491 | chex.assert_rank(image, {3, 4}) 492 | data_format = "NHWC" if _channels_last(image, channel_axis) else "NCHW" 493 | dimension_numbers = (data_format, "HWIO", data_format) 494 | num_channels = image.shape[channel_axis] 495 | radius = int(kernel_size / 2) 496 | kernel_size_ = 2 * radius + 1 497 | x = jnp.arange(-radius, radius + 1).astype(jnp.float32) 498 | blur_filter = jnp.exp(-x**2 / (2. * sigma**2)) 499 | blur_filter = blur_filter / jnp.sum(blur_filter) 500 | blur_v = jnp.reshape(blur_filter, [kernel_size_, 1, 1, 1]) 501 | blur_h = jnp.reshape(blur_filter, [1, kernel_size_, 1, 1]) 502 | blur_h = jnp.tile(blur_h, [1, 1, 1, num_channels]) 503 | blur_v = jnp.tile(blur_v, [1, 1, 1, num_channels]) 504 | 505 | expand_batch_dim = image.ndim == 3 506 | if expand_batch_dim: 507 | image = image[jnp.newaxis, ...] 
508 | blurred = _depthwise_conv2d( 509 | image, 510 | kernel=blur_h, 511 | strides=(1, 1), 512 | padding=padding, 513 | channel_axis=channel_axis, 514 | dimension_numbers=dimension_numbers) 515 | blurred = _depthwise_conv2d( 516 | blurred, 517 | kernel=blur_v, 518 | strides=(1, 1), 519 | padding=padding, 520 | channel_axis=channel_axis, 521 | dimension_numbers=dimension_numbers) 522 | if expand_batch_dim: 523 | blurred = jnp.squeeze(blurred, axis=0) 524 | return blurred 525 | 526 | 527 | def rot90( 528 | image: chex.Array, 529 | k: int = 1, 530 | *, 531 | channel_axis: int = -1, 532 | ) -> chex.Array: 533 | """Rotates an image counter-clockwise by 90 degrees. 534 | 535 | This is equivalent to tf.image.rot90. Assumes that the image is either 536 | ...HWC or ...CHW. 537 | 538 | Args: 539 | image: an RGB image, given as a float tensor in [0, 1]. 540 | k: the number of times the rotation is applied. 541 | channel_axis: the index of the channel axis. 542 | 543 | Returns: 544 | The rotated image. 545 | """ 546 | # DO NOT REMOVE - Logging usage. 547 | 548 | if _channels_last(image, channel_axis): 549 | spatial_axes = (-3, -2) # Image is ...HWC 550 | else: 551 | spatial_axes = (-2, -1) # Image is ...CHW 552 | return jnp.rot90(image, k, spatial_axes) 553 | 554 | 555 | def solarize(image: chex.Array, threshold: chex.Numeric) -> chex.Array: 556 | """Applies solarization to an image. 557 | 558 | All values above a given threshold will be inverted. 559 | 560 | Args: 561 | image: an RGB image, given as a [0-1] float tensor. 562 | threshold: the threshold for inversion. 563 | 564 | Returns: 565 | The solarized image. 566 | """ 567 | # DO NOT REMOVE - Logging usage. 568 | 569 | return jnp.where(image < threshold, image, 1. 
- image)


def affine_transform(
    image: chex.Array,
    matrix: chex.Array,
    *,
    offset: Union[chex.Array, chex.Numeric] = 0.,
    order: int = 1,
    mode: str = "nearest",
    cval: float = 0.0,
) -> chex.Array:
  """Applies an affine transformation given by matrix.

  Given an output image pixel index vector o, the pixel value is determined from
  the input image at position jnp.dot(matrix, o) + offset.

  This does 'pull' (or 'backward') resampling, transforming the output space to
  the input to locate data. Affine transformations are often described in the
  'push' (or 'forward') direction, transforming input to output. If you have a
  matrix for the 'push' transformation, use its inverse (jax.numpy.linalg.inv)
  in this function.

  Args:
    image: a JAX array representing an image. Assumes that the image is
      either HWC or CHW.
    matrix: the inverse coordinate transformation matrix, mapping output
      coordinates to input coordinates. If ndim is the number of dimensions of
      input, the given matrix must have one of the following shapes:

      - (ndim, ndim): the linear transformation matrix for each output
        coordinate.
      - (ndim,): assume that the 2-D transformation matrix is diagonal, with the
        diagonal specified by the given value.
      - (ndim + 1, ndim + 1): assume that the transformation is specified using
        homogeneous coordinates [1]. In this case, any value passed to offset is
        ignored.
      - (ndim, ndim + 1): as above, but the bottom row of a homogeneous
        transformation matrix is always [0, 0, 0, 1], and may be omitted.

    offset: the offset into the array where the transform is applied. If a
      float, offset is the same for each axis. If an array, offset should
      contain one value for each axis.
    order: the order of the spline interpolation, default is 1. The order has
      to be in the range [0, 1]. Note that PIX interpolation will only be used
      for order=1, for other values we use `jax.scipy.ndimage.map_coordinates`.
    mode: the mode parameter determines how the input array is extended beyond
      its boundaries. Default is 'nearest'. Modes 'nearest' and 'constant' use
      PIX interpolation, which is very fast on accelerators (especially on
      TPUs). For all other modes, 'wrap', 'mirror' and 'reflect', we rely
      on `jax.scipy.ndimage.map_coordinates`, which however is slow on
      accelerators, so use it with care.
    cval: value to fill past edges of input if mode is 'constant'. Default is
      0.0.

  Returns:
    The input image transformed by the given matrix.

  Example transformations:
    Rotation:

    >>> angle = jnp.pi / 4
    >>> matrix = jnp.array([
    ...     [jnp.cos(angle), -jnp.sin(angle), 0],
    ...     [jnp.sin(angle), jnp.cos(angle), 0],
    ...     [0, 0, 1],
    ... ])
    >>> result = dm_pix.affine_transform(image=image, matrix=matrix)

    Translation can be expressed through either the matrix itself
    or the offset parameter:

    >>> matrix = jnp.array([
    ...     [1, 0, 0, 25],
    ...     [0, 1, 0, 25],
    ...     [0, 0, 1, 0],
    ... ])
    >>> result = dm_pix.affine_transform(image=image, matrix=matrix)
    >>> # Or with offset:
    >>> matrix = jnp.array([
    ...     [1, 0, 0],
    ...     [0, 1, 0],
    ...     [0, 0, 1],
    ... ])
    >>> offset = jnp.array([25, 25, 0])
    >>> result = dm_pix.affine_transform(
    ...     image=image, matrix=matrix, offset=offset)

    Reflection:

    >>> matrix = jnp.array([
    ...     [-1, 0, 0],
    ...     [0, 1, 0],
    ...     [0, 0, 1],
    ... ])
    >>> result = dm_pix.affine_transform(image=image, matrix=matrix)

    Scale:

    >>> matrix = jnp.array([
    ...     [2, 0, 0],
    ...     [0, 1, 0],
    ...     [0, 0, 1],
    ... ])
    >>> result = dm_pix.affine_transform(image=image, matrix=matrix)

    Shear:

    >>> matrix = jnp.array([
    ...     [1, 0.5, 0],
    ...     [0.5, 1, 0],
    ...     [0, 0, 1],
    ... ])
    >>> result = dm_pix.affine_transform(image=image, matrix=matrix)

    One can also combine different transformation matrices:

    >>> matrix = rotation_matrix.dot(translation_matrix)
  """
  # DO NOT REMOVE - Logging usage.

  chex.assert_rank(image, 3)
  chex.assert_rank(matrix, {1, 2})
  chex.assert_rank(offset, {0, 1})

  # A 1-D matrix is shorthand for a diagonal transformation matrix.
  if matrix.ndim == 1:
    matrix = jnp.diag(matrix)

  if matrix.shape not in [(3, 3), (4, 4), (3, 4)]:
    error_msg = (
        "Expected matrix shape must be one of (ndim, ndim), (ndim,)"
        "(ndim + 1, ndim + 1) or (ndim, ndim + 1) being ndim the image.ndim. "
        f"The affine matrix provided has shape {matrix.shape}.")
    raise ValueError(error_msg)

  # Build the grid of output coordinates to be mapped back into the input.
  meshgrid = jnp.meshgrid(*[jnp.arange(size) for size in image.shape],
                          indexing="ij")
  indices = jnp.concatenate(
      [jnp.expand_dims(x, axis=-1) for x in meshgrid], axis=-1)

  # Homogeneous matrices carry their own translation; it overrides `offset`.
  if matrix.shape == (4, 4) or matrix.shape == (3, 4):
    offset = matrix[:image.ndim, image.ndim]
    matrix = matrix[:image.ndim, :image.ndim]

  coordinates = indices @ matrix.T
  coordinates = jnp.moveaxis(coordinates, source=-1, destination=0)

  # Alter coordinates to account for offset.
  offset = jnp.full((3,), fill_value=offset)
  coordinates += jnp.reshape(offset, (*offset.shape, 1, 1, 1))

  interpolate_function = _get_interpolate_function(
      mode=mode,
      order=order,
      cval=cval,
  )
  return interpolate_function(image, coordinates)


def rotate(
    image: chex.Array,
    angle: float,
    *,
    order: int = 1,
    mode: str = "nearest",
    cval: float = 0.0,
) -> chex.Array:
  """Rotates an image around its center using interpolation.

  Args:
    image: a JAX array representing an image. Assumes that the image is
      either HWC or CHW.
    angle: the counter-clockwise rotation angle in units of radians.
    order: the order of the spline interpolation, default is 1. The order has
      to be in the range [0, 1]. See `affine_transform` for details.
    mode: the mode parameter determines how the input array is extended beyond
      its boundaries. Default is 'nearest'. See `affine_transform` for details.
    cval: value to fill past edges of input if mode is 'constant'. Default is
      0.0.

  Returns:
    The rotated image.
  """
  # DO NOT REMOVE - Logging usage.

  # Calculate inverse transform matrix assuming clockwise rotation.
  c = jnp.cos(angle)
  s = jnp.sin(angle)
  matrix = jnp.array([[c, s, 0], [-s, c, 0], [0, 0, 1]])

  # Use the offset to place the rotation at the image center.
  image_center = (jnp.asarray(image.shape) - 1.) / 2.
  offset = image_center - matrix @ image_center

  return affine_transform(image, matrix, offset=offset, order=order, mode=mode,
                          cval=cval)


def random_flip_left_right(
    key: chex.PRNGKey,
    image: chex.Array,
    *,
    probability: chex.Numeric = 0.5,
) -> chex.Array:
  """Applies `flip_left_right` with a given probability.

  Args:
    key: a JAX RNG key.
    image: a JAX array representing an image. Assumes that the image is either
      ...HWC or ...CHW.
    probability: the probability of applying flip_left_right transform. Must be
      a value in [0, 1].

  Returns:
    A left-right flipped image if condition is met, otherwise original image.
  """
  # DO NOT REMOVE - Logging usage.

  should_transform = jax.random.bernoulli(key=key, p=probability)
  return jax.lax.cond(should_transform, flip_left_right, lambda x: x, image)


def random_flip_up_down(
    key: chex.PRNGKey,
    image: chex.Array,
    *,
    probability: chex.Numeric = 0.5,
) -> chex.Array:
  """Applies `flip_up_down` with a given probability.

  Args:
    key: a JAX RNG key.
    image: a JAX array representing an image. Assumes that the image is either
      ...HWC or ...CHW.
    probability: the probability of applying flip_up_down transform. Must be a
      value in [0, 1].

  Returns:
    An up-down flipped image if condition is met, otherwise original image.
  """
  # DO NOT REMOVE - Logging usage.

  should_transform = jax.random.bernoulli(key=key, p=probability)
  return jax.lax.cond(should_transform, flip_up_down, lambda x: x, image)


def random_brightness(
    key: chex.PRNGKey,
    image: chex.Array,
    max_delta: chex.Numeric,
) -> chex.Array:
  """`adjust_brightness(...)` with random delta in `[-max_delta, max_delta)`."""
  # DO NOT REMOVE - Logging usage.

  delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
  return adjust_brightness(image, delta)


def random_gamma(
    key: chex.PRNGKey,
    image: chex.Array,
    min_gamma: chex.Numeric,
    max_gamma: chex.Numeric,
    *,
    gain: chex.Numeric = 1,
    assume_in_bounds: bool = False,
) -> chex.Array:
  """`adjust_gamma(...)` with random gamma in `[min_gamma, max_gamma)`."""
  # DO NOT REMOVE - Logging usage.

  gamma = jax.random.uniform(key, (), minval=min_gamma, maxval=max_gamma)
  return adjust_gamma(
      image, gamma, gain=gain, assume_in_bounds=assume_in_bounds)


def random_hue(
    key: chex.PRNGKey,
    image: chex.Array,
    max_delta: chex.Numeric,
    *,
    channel_axis: int = -1,
) -> chex.Array:
  """`adjust_hue(...)` with random delta in `[-max_delta, max_delta)`."""
  # DO NOT REMOVE - Logging usage.

  delta = jax.random.uniform(key, (), minval=-max_delta, maxval=max_delta)
  return adjust_hue(image, delta, channel_axis=channel_axis)


def random_contrast(
    key: chex.PRNGKey,
    image: chex.Array,
    lower: chex.Numeric,
    upper: chex.Numeric,
    *,
    channel_axis: int = -1,
) -> chex.Array:
  """`adjust_contrast(...)` with random factor in `[lower, upper)`."""
  # DO NOT REMOVE - Logging usage.

  factor = jax.random.uniform(key, (), minval=lower, maxval=upper)
  return adjust_contrast(image, factor, channel_axis=channel_axis)


def random_saturation(
    key: chex.PRNGKey,
    image: chex.Array,
    lower: chex.Numeric,
    upper: chex.Numeric,
    *,
    channel_axis: int = -1,
) -> chex.Array:
  """`adjust_saturation(...)` with random factor in `[lower, upper)`."""
  # DO NOT REMOVE - Logging usage.
883 | 884 | factor = jax.random.uniform(key, (), minval=lower, maxval=upper) 885 | return adjust_saturation(image, factor, channel_axis=channel_axis) 886 | 887 | 888 | def random_crop( 889 | key: chex.PRNGKey, 890 | image: chex.Array, 891 | crop_sizes: Sequence[int], 892 | ) -> chex.Array: 893 | """Crop images randomly to specified sizes. 894 | 895 | Given an input image, it crops the image to the specified `crop_sizes`. If 896 | `crop_sizes` are lesser than the image's sizes, the offset for cropping is 897 | chosen at random. To deterministically crop an image, 898 | please use `jax.lax.dynamic_slice` and specify offsets and crop sizes. 899 | 900 | Args: 901 | key: key for pseudo-random number generator. 902 | image: a JAX array which represents an image. 903 | crop_sizes: a sequence of integers, each of which sequentially specifies the 904 | crop size along the corresponding dimension of the image. Sequence length 905 | must be identical to the rank of the image and the crop size should not be 906 | greater than the corresponding image dimension. 907 | 908 | Returns: 909 | A cropped image, a JAX array whose shape is same as `crop_sizes`. 910 | """ 911 | # DO NOT REMOVE - Logging usage. 
912 | 913 | image_shape = image.shape 914 | assert len(image_shape) == len(crop_sizes), ( 915 | f"Number of image dims {len(image_shape)} and number of crop_sizes " 916 | f"{len(crop_sizes)} do not match.") 917 | assert image_shape >= crop_sizes, ( 918 | f"Crop sizes {crop_sizes} should be a subset of image size {image_shape} " 919 | "in each dimension .") 920 | random_keys = jax.random.split(key, len(crop_sizes)) 921 | 922 | slice_starts = [ 923 | jax.random.randint(k, (), 0, img_size - crop_size + 1) 924 | for k, img_size, crop_size in zip(random_keys, image_shape, crop_sizes) 925 | ] 926 | out = jax.lax.dynamic_slice(image, slice_starts, crop_sizes) 927 | 928 | return out 929 | 930 | 931 | def _channels_last(image: chex.Array, channel_axis: int): 932 | last = channel_axis == -1 or channel_axis == (image.ndim - 1) 933 | if not last: 934 | assert channel_axis == -3 or channel_axis == range(image.ndim)[-3] 935 | return last 936 | 937 | 938 | def _depthwise_conv2d( 939 | inputs: chex.Array, 940 | kernel: chex.Array, 941 | *, 942 | strides: Tuple[int, int], 943 | padding: str, 944 | channel_axis: int, 945 | dimension_numbers: Tuple[str, str, str], 946 | ) -> chex.Array: 947 | """Computes a depthwise conv2d in Jax. 948 | 949 | Reference implementation: http://shortn/_oEpb0c2V3l 950 | 951 | Args: 952 | inputs: an NHWC or NCHW tensor (depending on dimension_numbers), with N=1. 953 | kernel: a [H', W', 1, C] tensor. 954 | strides: optional stride for the kernel. 955 | padding: "SAME" or "VALID". 956 | channel_axis: the index of the channel axis. 957 | dimension_numbers: see jax.lax.conv_general_dilated. 958 | 959 | Returns: 960 | The depthwise convolution of inputs with kernel, with the same 961 | dimension_numbers as the input. 
  """
  return jax.lax.conv_general_dilated(
      inputs,
      kernel,
      strides,
      padding,
      feature_group_count=inputs.shape[channel_axis],
      dimension_numbers=dimension_numbers)


def _get_interpolate_function(
    mode: str,
    order: int,
    cval: float = 0.,
) -> Callable[[chex.Array, chex.Array], chex.Array]:
  """Selects the interpolation function to use based on the given parameters.

  PIX interpolations are preferred given they are faster on accelerators. For
  the cases where such interpolation is not implemented by PIX we rely on
  jax.scipy.ndimage.map_coordinates. See specifics below.

  Args:
    mode: the mode parameter determines how the input array is extended beyond
      its boundaries. Modes 'nearest' and 'constant' use PIX interpolation,
      which is very fast on accelerators (especially on TPUs). For all other
      modes, 'wrap', 'mirror' and 'reflect', we rely on
      `jax.scipy.ndimage.map_coordinates`, which however is slow on
      accelerators, so use it with care.
    order: the order of the spline interpolation. The order has to be in the
      range [0, 1]. Note that PIX interpolation will only be used for order=1,
      for other values we use `jax.scipy.ndimage.map_coordinates`.
    cval: value to fill past edges of input if mode is 'constant'.

  Returns:
    The selected interpolation function.
  """
  if mode == "nearest" and order == 1:
    interpolate_function = interpolation.flat_nd_linear_interpolate
  elif mode == "constant" and order == 1:
    interpolate_function = functools.partial(
        interpolation.flat_nd_linear_interpolate_constant, cval=cval)
  else:
    # Fallback path: correct for all modes/orders, but slow on accelerators.
    interpolate_function = functools.partial(
        jax.scipy.ndimage.map_coordinates, mode=mode, order=order, cval=cval)
  return interpolate_function


def _get_dimension_values(
    image: chex.Array,
    channel_axis: int,
) -> Tuple[Optional[int], int, int, int]:
  """Gets shape values in BHWC order.

  If single image is given B is None.

  Small utility to get dimension values regardless of channel axis and single
  image or batch of images are passed.

  Args:
    image: a JAX array representing an image. Assumes that the image is either
      ...HWC or ...CHW.
    channel_axis: the index of the channel axis.

  Returns:
    A tuple with the values of each dimension in order BHWC.
  """
  chex.assert_rank(image, {3, 4})
  if image.ndim == 4:
    if _channels_last(image=image, channel_axis=channel_axis):
      batch, height, width, channel = image.shape
    else:
      batch, channel, height, width = image.shape
  else:
    # Single image: no batch dimension, so B is reported as None.
    if _channels_last(image=image, channel_axis=channel_axis):
      batch, (height, width, channel) = None, image.shape
    else:
      batch, (channel, height, width) = None, image.shape
  return batch, height, width, channel
--------------------------------------------------------------------------------
/dm_pix/_src/augment_test.py:
--------------------------------------------------------------------------------
# Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Tests for dm_pix._src.augment.""" 15 | 16 | import functools 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | from dm_pix._src import augment 21 | import jax 22 | import jax.numpy as jnp 23 | import numpy as np 24 | import scipy 25 | import tensorflow as tf 26 | 27 | _IMG_SHAPE = (131, 111, 3) 28 | _RAND_FLOATS_IN_RANGE = list( 29 | np.random.uniform(0., 1., size=(10,) + _IMG_SHAPE).astype(np.float32)) 30 | _RAND_FLOATS_OUT_OF_RANGE = list( 31 | np.random.uniform(-0.5, 1.5, size=(10,) + _IMG_SHAPE).astype(np.float32)) 32 | _KERNEL_SIZE = _IMG_SHAPE[0] / 10. 33 | 34 | 35 | class _ImageAugmentationTest(parameterized.TestCase): 36 | """Runs tests for the various augments with the correct arguments.""" 37 | 38 | def _test_fn_with_random_arg(self, images_list, jax_fn, reference_fn, 39 | **kw_range): 40 | pass 41 | 42 | def _test_fn(self, images_list, jax_fn, reference_fn): 43 | pass 44 | 45 | def assertAllCloseTolerant(self, x, y): 46 | # Increase tolerance on TPU due to lower precision. 
47 | tol = 1e-2 if jax.local_devices()[0].platform == "tpu" else 1e-4 48 | np.testing.assert_allclose(x, y, rtol=tol, atol=tol) 49 | self.assertEqual(x.dtype, y.dtype) 50 | 51 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 52 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 53 | def test_adjust_brightness(self, images_list): 54 | self._test_fn_with_random_arg( 55 | images_list, 56 | jax_fn=augment.adjust_brightness, 57 | reference_fn=tf.image.adjust_brightness, 58 | delta=(-0.5, 0.5)) 59 | 60 | key = jax.random.PRNGKey(0) 61 | self._test_fn_with_random_arg( 62 | images_list, 63 | jax_fn=functools.partial(augment.random_brightness, key), 64 | reference_fn=None, 65 | max_delta=(0, 0.5)) 66 | 67 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 68 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 69 | def test_adjust_contrast(self, images_list): 70 | self._test_fn_with_random_arg( 71 | images_list, 72 | jax_fn=augment.adjust_contrast, 73 | reference_fn=tf.image.adjust_contrast, 74 | factor=(0.5, 1.5)) 75 | key = jax.random.PRNGKey(0) 76 | self._test_fn_with_random_arg( 77 | images_list, 78 | jax_fn=functools.partial(augment.random_contrast, key, upper=1), 79 | reference_fn=None, 80 | lower=(0, 0.9)) 81 | 82 | # Doesn't make sense outside of [0, 1]. 
83 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE)) 84 | def test_adjust_gamma(self, images_list): 85 | self._test_fn_with_random_arg( 86 | images_list, 87 | jax_fn=augment.adjust_gamma, 88 | reference_fn=tf.image.adjust_gamma, 89 | gamma=(0.5, 1.5)) 90 | key = jax.random.PRNGKey(0) 91 | self._test_fn_with_random_arg( 92 | images_list, 93 | jax_fn=functools.partial(augment.random_gamma, key, min_gamma=1), 94 | reference_fn=None, 95 | max_gamma=(1.5, 1.9)) 96 | 97 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 98 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 99 | def test_adjust_saturation(self, images_list): 100 | # tf.image.adjust_saturation has a buggy implementation when the green and 101 | # blue channels have very close values that don't match the red channel. 102 | # This is due to a rounding error in http://shortn/_ETSJsEwUj5 103 | # if (g - b) < 0 but small enough that (hh + 1) == 1. 104 | # Eg: tf.image.adjust_saturation([[[0.75, 0.0369078, 0.0369079]]], 1.0) 105 | # -> [[[0.03690779, 0.03690779, 0.03690779]]] 106 | # Perturb the inputs slightly so that this doesn't happen. 107 | def perturb(rgb): 108 | rgb_new = np.copy(rgb) 109 | rgb_new[..., 1] += 0.001 * (np.abs(rgb[..., 2] - rgb[..., 1]) < 1e-3) 110 | return rgb_new 111 | 112 | images_list = list(map(perturb, images_list)) 113 | self._test_fn_with_random_arg( 114 | images_list, 115 | jax_fn=augment.adjust_saturation, 116 | reference_fn=tf.image.adjust_saturation, 117 | factor=(0.5, 1.5)) 118 | key = jax.random.PRNGKey(0) 119 | self._test_fn_with_random_arg( 120 | images_list, 121 | jax_fn=functools.partial(augment.random_saturation, key, upper=1), 122 | reference_fn=None, 123 | lower=(0, 0.9)) 124 | 125 | # CPU TF uses a different hue adjustment method outside of the [0, 1] range. 126 | # Disable out-of-range tests. 
127 | @parameterized.named_parameters( 128 | ("in_range", _RAND_FLOATS_IN_RANGE),) 129 | def test_adjust_hue(self, images_list): 130 | self._test_fn_with_random_arg( 131 | images_list, 132 | jax_fn=augment.adjust_hue, 133 | reference_fn=tf.image.adjust_hue, 134 | delta=(-0.5, 0.5)) 135 | key = jax.random.PRNGKey(0) 136 | self._test_fn_with_random_arg( 137 | images_list, 138 | jax_fn=functools.partial(augment.random_hue, key), 139 | reference_fn=None, 140 | max_delta=(0, 0.5)) 141 | 142 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 143 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 144 | def test_rot90(self, images_list): 145 | self._test_fn( 146 | images_list, 147 | jax_fn=lambda img: augment.rot90(img, k=1), 148 | reference_fn=lambda img: tf.image.rot90(img, k=1)) 149 | self._test_fn( 150 | images_list, 151 | jax_fn=lambda img: augment.rot90(img, k=2), 152 | reference_fn=lambda img: tf.image.rot90(img, k=2)) 153 | self._test_fn( 154 | images_list, 155 | jax_fn=lambda img: augment.rot90(img, k=3), 156 | reference_fn=lambda img: tf.image.rot90(img, k=3)) 157 | 158 | # The functions below don't have a TF equivalent to compare to, we just check 159 | # that they run. 
160 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 161 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 162 | def test_flip(self, images_list): 163 | self._test_fn( 164 | images_list, 165 | jax_fn=augment.flip_left_right, 166 | reference_fn=tf.image.flip_left_right) 167 | self._test_fn( 168 | images_list, 169 | jax_fn=augment.flip_up_down, 170 | reference_fn=tf.image.flip_up_down) 171 | key = jax.random.PRNGKey(0) 172 | self._test_fn( 173 | images_list, 174 | jax_fn=functools.partial(augment.random_flip_left_right, key), 175 | reference_fn=None) 176 | self._test_fn( 177 | images_list, 178 | jax_fn=functools.partial(augment.random_flip_up_down, key), 179 | reference_fn=None) 180 | self._test_fn_with_random_arg( 181 | images_list, 182 | jax_fn=functools.partial(augment.random_flip_left_right, key), 183 | reference_fn=None, 184 | probability=(0., 1.)) 185 | self._test_fn_with_random_arg( 186 | images_list, 187 | jax_fn=functools.partial(augment.random_flip_up_down, key), 188 | reference_fn=None, 189 | probability=(0., 1.)) 190 | 191 | # Due to a bug in scipy we cannot test all available modes, refer to these 192 | # issues for more information: https://github.com/jax-ml/jax/issues/11097, 193 | # https://github.com/jax-ml/jax/issues/11097 194 | @parameterized.named_parameters( 195 | ("in_range_nearest_0", _RAND_FLOATS_IN_RANGE, "nearest", 0), 196 | ("in_range_nearest_1", _RAND_FLOATS_IN_RANGE, "nearest", 1), 197 | ("in_range_mirror_1", _RAND_FLOATS_IN_RANGE, "mirror", 1), 198 | ("out_of_range_nearest_0", _RAND_FLOATS_OUT_OF_RANGE, "nearest", 0), 199 | ("out_of_range_nearest_1", _RAND_FLOATS_OUT_OF_RANGE, "nearest", 1), 200 | ("out_of_range_mirror_1", _RAND_FLOATS_OUT_OF_RANGE, "mirror", 1), 201 | ) 202 | def test_affine_transform(self, images_list, mode, order): 203 | # (ndim, ndim) no offset 204 | self._test_fn( 205 | images_list, 206 | jax_fn=functools.partial( 207 | augment.affine_transform, matrix=np.eye(3), mode=mode, order=order), 208 | 
reference_fn=functools.partial( 209 | scipy.ndimage.affine_transform, 210 | matrix=np.eye(3), 211 | order=order, 212 | mode=mode)) 213 | 214 | # (ndim, ndim) with offset 215 | matrix = jnp.array([[-0.5, 0.2, 0.], [0.8, 0.5, 0.], [0., 0., 1.]]) 216 | offset = jnp.array([40., 32., 0.]) 217 | self._test_fn( 218 | images_list, 219 | jax_fn=functools.partial( 220 | augment.affine_transform, 221 | matrix=matrix, 222 | mode=mode, 223 | offset=offset, 224 | order=order), 225 | reference_fn=functools.partial( 226 | scipy.ndimage.affine_transform, 227 | matrix=matrix, 228 | offset=offset, 229 | order=order, 230 | mode=mode)) 231 | 232 | # (ndim + 1, ndim + 1) 233 | matrix = jnp.array([[0.4, 0.2, 0, -10], [0.2, -0.5, 0, 5], [0, 0, 1, 0], 234 | [0, 0, 0, 1]]) 235 | self._test_fn( 236 | images_list, 237 | jax_fn=functools.partial( 238 | augment.affine_transform, matrix=matrix, mode=mode, order=order), 239 | reference_fn=functools.partial( 240 | scipy.ndimage.affine_transform, 241 | matrix=matrix, 242 | order=order, 243 | mode=mode)) 244 | 245 | # (ndim, ndim + 1) 246 | matrix = jnp.array([[0.4, 0.2, 0, -10], [0.2, -0.5, 0, 5], [0, 0, 1, 0]]) 247 | self._test_fn( 248 | images_list, 249 | jax_fn=functools.partial( 250 | augment.affine_transform, matrix=matrix, mode=mode, order=order), 251 | reference_fn=functools.partial( 252 | scipy.ndimage.affine_transform, 253 | matrix=matrix, 254 | order=order, 255 | mode=mode)) 256 | 257 | # (ndim,) 258 | matrix = jnp.array([0.4, 0.2, 1]) 259 | self._test_fn( 260 | images_list, 261 | jax_fn=functools.partial( 262 | augment.affine_transform, matrix=matrix, mode=mode, order=order), 263 | reference_fn=functools.partial( 264 | scipy.ndimage.affine_transform, 265 | matrix=matrix, 266 | order=order, 267 | mode=mode)) 268 | 269 | @parameterized.product( 270 | parameters_base=[ 271 | ("in_range", "nearest", 0), 272 | ("in_range", "nearest", 1), 273 | ("in_range", "mirror", 1), 274 | ("in_range", "constant", 1), 275 | ("out_of_range", "nearest", 0), 
276 | ("out_of_range", "nearest", 1), 277 | ("out_of_range", "mirror", 1), 278 | ("out_of_range", "constant", 1), 279 | ], 280 | cval=(0.0, 1.0, -2.0), 281 | angle=(0.0, np.pi / 4, -np.pi / 4), 282 | ) 283 | def test_rotate(self, parameters_base, cval, angle): 284 | images_list_type, mode, order = parameters_base 285 | if images_list_type == "in_range": 286 | images_list = _RAND_FLOATS_IN_RANGE 287 | elif images_list_type == "out_of_range": 288 | images_list = _RAND_FLOATS_OUT_OF_RANGE 289 | else: 290 | raise ValueError(f"{images_list_type} not a valid image list for tests.") 291 | self._test_fn( 292 | images_list, 293 | jax_fn=functools.partial( 294 | augment.rotate, angle=angle, mode=mode, order=order, cval=cval), 295 | reference_fn=functools.partial( 296 | scipy.ndimage.rotate, 297 | angle=angle * 180 / np.pi, # SciPy uses degrees. 298 | order=order, 299 | mode=mode, 300 | cval=cval, 301 | reshape=False)) 302 | 303 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 304 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 305 | def test_solarize(self, images_list): 306 | self._test_fn_with_random_arg( 307 | images_list, 308 | jax_fn=augment.solarize, 309 | reference_fn=None, 310 | threshold=(0., 1.)) 311 | 312 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 313 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 314 | def test_gaussian_blur(self, images_list): 315 | blur_fn = functools.partial(augment.gaussian_blur, kernel_size=_KERNEL_SIZE) 316 | self._test_fn_with_random_arg( 317 | images_list, jax_fn=blur_fn, reference_fn=None, sigma=(0.1, 2.0)) 318 | 319 | @parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 320 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 321 | def test_random_crop(self, images_list): 322 | key = jax.random.PRNGKey(43) 323 | crop_fn = lambda img: augment.random_crop(key, img, (100, 100, 3)) 324 | self._test_fn(images_list, jax_fn=crop_fn, reference_fn=None) 325 | 326 | 
@parameterized.named_parameters(("in_range", _RAND_FLOATS_IN_RANGE), 327 | ("out_of_range", _RAND_FLOATS_OUT_OF_RANGE)) 328 | def test_elastic_deformation(self, images_list): 329 | key = jax.random.PRNGKey(43) 330 | alpha = 10. 331 | sigma = 5. 332 | elastic_deformation = functools.partial( 333 | augment.elastic_deformation, 334 | key, 335 | alpha=alpha, 336 | sigma=sigma, 337 | ) 338 | # Due to the difference between random number generation in numpy (in 339 | # reference function) and JAX's for the displacement fields we cannot test 340 | # this against some of the available functions. At the time of writing open 341 | # source options are either unmaintained or are not readily available. 342 | self._test_fn( 343 | images_list, 344 | jax_fn=elastic_deformation, 345 | reference_fn=None, 346 | ) 347 | elastic_deformation = functools.partial( 348 | augment.elastic_deformation, 349 | key, 350 | sigma=sigma) 351 | # Sigma has to be constant for jit since kernel_size is derived from it. 352 | self._test_fn_with_random_arg( 353 | images_list, 354 | jax_fn=elastic_deformation, 355 | reference_fn=None, 356 | alpha=(40, 80)) 357 | 358 | @parameterized.product( 359 | images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE), 360 | height=(131, 111, 1, 88), 361 | width=(111, 105, 1, 40), 362 | ) 363 | def test_center_crop(self, images_list, height, width): 364 | center_crop = functools.partial( 365 | augment.center_crop, 366 | height=height, 367 | width=width, 368 | ) 369 | # Using layer as reference as other tf utility functions are not exactly 370 | # this same center crop: 371 | # - tf.image.crop_and_resize 372 | # - tf.image.central_crop 373 | reference = tf.keras.layers.CenterCrop( 374 | height=height, 375 | width=width, 376 | ) 377 | self._test_fn(images_list, jax_fn=center_crop, reference_fn=reference) 378 | 379 | @parameterized.product( 380 | images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE), 381 | target_height=(156, 131, 200, 251), 382 | 
target_width=(156, 111, 200, 251), 383 | ) 384 | def test_pad_to_size(self, images_list, target_height, target_width): 385 | pad_fn = functools.partial( 386 | augment.pad_to_size, 387 | target_height=target_height, 388 | target_width=target_width, 389 | mode="constant", 390 | pad_kwargs={"constant_values": 0}, 391 | ) 392 | # We have to rely on `resize_with_crop_or_pad` as there are no pad to size 393 | # equivalents. 394 | reference_fn = functools.partial( 395 | tf.image.resize_with_crop_or_pad, 396 | target_height=target_height, 397 | target_width=target_width, 398 | ) 399 | 400 | self._test_fn(images_list, jax_fn=pad_fn, reference_fn=reference_fn) 401 | 402 | @parameterized.product( 403 | images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE), 404 | target_height=(156, 138, 200, 251), 405 | target_width=(156, 138, 200, 251), 406 | ) 407 | def test_resize_with_crop_or_pad( 408 | self, images_list, target_height, target_width 409 | ): 410 | resize_crop_or_pad = functools.partial( 411 | augment.resize_with_crop_or_pad, 412 | target_height=target_height, 413 | target_width=target_width, 414 | pad_mode="constant", 415 | pad_kwargs={"constant_values": 0}, 416 | ) 417 | reference_fn = functools.partial( 418 | tf.image.resize_with_crop_or_pad, 419 | target_height=target_height, 420 | target_width=target_width, 421 | ) 422 | 423 | self._test_fn( 424 | images_list, jax_fn=resize_crop_or_pad, reference_fn=reference_fn 425 | ) 426 | 427 | 428 | class TestMatchReference(_ImageAugmentationTest): 429 | 430 | def _test_fn_with_random_arg( 431 | self, images_list, jax_fn, reference_fn, **kw_range 432 | ): 433 | if reference_fn is None: 434 | return 435 | assert len(kw_range) == 1 436 | kw_name, (random_min, random_max) = list(kw_range.items())[0] 437 | for image_rgb in images_list: 438 | argument = np.random.uniform(random_min, random_max, size=()) 439 | adjusted_jax = jax_fn(image_rgb, **{kw_name: argument}) 440 | adjusted_reference = reference_fn(image_rgb, argument) 
441 | if hasattr(adjusted_reference, "numpy"): 442 | adjusted_reference = adjusted_reference.numpy() 443 | self.assertAllCloseTolerant(adjusted_jax, adjusted_reference) 444 | 445 | def _test_fn(self, images_list, jax_fn, reference_fn): 446 | if reference_fn is None: 447 | return 448 | for image_rgb in images_list: 449 | adjusted_jax = jax_fn(image_rgb) 450 | adjusted_reference = reference_fn(image_rgb) 451 | if hasattr(adjusted_reference, "numpy"): 452 | adjusted_reference = adjusted_reference.numpy() 453 | self.assertAllCloseTolerant(adjusted_jax, adjusted_reference) 454 | 455 | 456 | class TestVmap(_ImageAugmentationTest): 457 | 458 | def _test_fn_with_random_arg(self, images_list, jax_fn, reference_fn, 459 | **kw_range): 460 | del reference_fn # unused. 461 | assert len(kw_range) == 1 462 | kw_name, (random_min, random_max) = list(kw_range.items())[0] 463 | arguments = [ 464 | np.random.uniform(random_min, random_max, size=()) for _ in images_list 465 | ] 466 | fn_vmap = jax.vmap(jax_fn) 467 | outputs_vmaped = list( 468 | fn_vmap( 469 | np.stack(images_list, axis=0), 470 | **{kw_name: np.stack(arguments, axis=0)})) 471 | assert len(images_list) == len(outputs_vmaped) 472 | assert len(images_list) == len(arguments) 473 | for image_rgb, argument, adjusted_vmap in zip(images_list, arguments, 474 | outputs_vmaped): 475 | adjusted_jax = jax_fn(image_rgb, **{kw_name: argument}) 476 | self.assertAllCloseTolerant(adjusted_jax, adjusted_vmap) 477 | 478 | def _test_fn(self, images_list, jax_fn, reference_fn): 479 | del reference_fn # unused. 
480 | fn_vmap = jax.vmap(jax_fn) 481 | outputs_vmaped = list(fn_vmap(np.stack(images_list, axis=0))) 482 | assert len(images_list) == len(outputs_vmaped) 483 | for image_rgb, adjusted_vmap in zip(images_list, outputs_vmaped): 484 | adjusted_jax = jax_fn(image_rgb) 485 | self.assertAllCloseTolerant(adjusted_jax, adjusted_vmap) 486 | 487 | 488 | class TestJit(_ImageAugmentationTest): 489 | 490 | def _test_fn_with_random_arg(self, images_list, jax_fn, reference_fn, 491 | **kw_range): 492 | del reference_fn # unused. 493 | assert len(kw_range) == 1 494 | kw_name, (random_min, random_max) = list(kw_range.items())[0] 495 | jax_fn_jitted = jax.jit(jax_fn) 496 | for image_rgb in images_list: 497 | argument = np.random.uniform(random_min, random_max, size=()) 498 | adjusted_jax = jax_fn(image_rgb, **{kw_name: argument}) 499 | adjusted_jit = jax_fn_jitted(image_rgb, **{kw_name: argument}) 500 | self.assertAllCloseTolerant(adjusted_jax, adjusted_jit) 501 | 502 | def _test_fn(self, images_list, jax_fn, reference_fn): 503 | del reference_fn # unused. 
504 | jax_fn_jitted = jax.jit(jax_fn) 505 | for image_rgb in images_list: 506 | adjusted_jax = jax_fn(image_rgb) 507 | adjusted_jit = jax_fn_jitted(image_rgb) 508 | self.assertAllCloseTolerant(adjusted_jax, adjusted_jit) 509 | 510 | 511 | class TestCustom(parameterized.TestCase): 512 | """Tests custom logic that is not covered by reference functions.""" 513 | 514 | @parameterized.product( 515 | images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE), 516 | height=(250, 200), 517 | width=(250, 200), 518 | expected_height=(131, 131), 519 | expected_width=(111, 111), 520 | ) 521 | def test_center_crop_size_bigger_than_original( 522 | self, 523 | images_list, 524 | height, 525 | width, 526 | expected_height, 527 | expected_width, 528 | ): 529 | output = augment.center_crop( 530 | image=jnp.array(images_list), 531 | height=height, 532 | width=width, 533 | ) 534 | 535 | self.assertEqual(output.shape[1], expected_height) 536 | self.assertEqual(output.shape[2], expected_width) 537 | 538 | @parameterized.product( 539 | images_list=(_RAND_FLOATS_IN_RANGE, _RAND_FLOATS_OUT_OF_RANGE), 540 | target_height=(55, 84), 541 | target_width=(55, 84), 542 | expected_height=(131, 131), 543 | expected_width=(111, 111), 544 | ) 545 | def test_pad_to_size_when_target_size_smaller_than_original( 546 | self, 547 | images_list, 548 | target_height, 549 | target_width, 550 | expected_height, 551 | expected_width, 552 | ): 553 | output = augment.pad_to_size( 554 | image=jnp.array(images_list), 555 | target_height=target_height, 556 | target_width=target_width, 557 | ) 558 | 559 | self.assertEqual(output.shape[1], expected_height) 560 | self.assertEqual(output.shape[2], expected_width) 561 | 562 | 563 | if __name__ == "__main__": 564 | jax.config.update("jax_default_matmul_precision", "float32") 565 | absltest.main() 566 | -------------------------------------------------------------------------------- /dm_pix/_src/color_conversion.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module provides functions to convert color spaces. 15 | 16 | All functions expect float-encoded images, with values in [0, 1]. 17 | """ 18 | 19 | from typing import Tuple 20 | 21 | import chex 22 | import jax.numpy as jnp 23 | 24 | # DO NOT REMOVE - Logging lib. 25 | 26 | 27 | def split_channels( 28 | image: chex.Array, 29 | channel_axis: int, 30 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 31 | """Splits an image into its channels. 32 | 33 | Args: 34 | image: an image, with float values in range [0, 1]. Behavior outside of 35 | these bounds is not guaranteed. 36 | channel_axis: the channel axis. image should have 3 layers along this axis. 37 | 38 | Returns: 39 | A tuple of 3 images, with float values in range [0, 1], stacked 40 | along channel_axis. 41 | """ 42 | # DO NOT REMOVE - Logging usage. 43 | 44 | chex.assert_axis_dimension(image, axis=channel_axis, expected=3) 45 | split_axes = jnp.split(image, 3, axis=channel_axis) 46 | return tuple(map(lambda x: jnp.squeeze(x, axis=channel_axis), split_axes)) 47 | 48 | 49 | def rgb_to_hsv( 50 | image_rgb: chex.Array, 51 | *, 52 | channel_axis: int = -1, 53 | ) -> chex.Array: 54 | """Converts an image from RGB to HSV. 
55 | 56 | Args: 57 | image_rgb: an RGB image, with float values in range [0, 1]. Behavior outside 58 | of these bounds is not guaranteed. 59 | channel_axis: the channel axis. image_rgb should have 3 layers along this 60 | axis. 61 | 62 | Returns: 63 | An HSV image, with float values in range [0, 1], stacked along channel_axis. 64 | """ 65 | # DO NOT REMOVE - Logging usage. 66 | 67 | eps = jnp.finfo(image_rgb.dtype).eps 68 | image_rgb = jnp.where(jnp.abs(image_rgb) < eps, 0., image_rgb) 69 | red, green, blue = split_channels(image_rgb, channel_axis) 70 | return jnp.stack( 71 | rgb_planes_to_hsv_planes(red, green, blue), axis=channel_axis) 72 | 73 | 74 | def hsv_to_rgb( 75 | image_hsv: chex.Array, 76 | *, 77 | channel_axis: int = -1, 78 | ) -> chex.Array: 79 | """Converts an image from HSV to RGB. 80 | 81 | Args: 82 | image_hsv: an HSV image, with float values in range [0, 1]. Behavior outside 83 | of these bounds is not guaranteed. 84 | channel_axis: the channel axis. image_hsv should have 3 layers along this 85 | axis. 86 | 87 | Returns: 88 | An RGB image, with float values in range [0, 1], stacked along channel_axis. 89 | """ 90 | # DO NOT REMOVE - Logging usage. 91 | hue, saturation, value = split_channels(image_hsv, channel_axis) 92 | return jnp.stack( 93 | hsv_planes_to_rgb_planes(hue, saturation, value), axis=channel_axis) 94 | 95 | 96 | def rgb_planes_to_hsv_planes( 97 | red: chex.Array, 98 | green: chex.Array, 99 | blue: chex.Array, 100 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 101 | """Converts red, green, blue color planes to hue, saturation, value planes. 102 | 103 | All planes should have the same shape, with float values in range [0, 1]. 104 | Behavior outside of these bounds is not guaranteed. 105 | 106 | Reference implementation: http://shortn/_DjPmiAOWSQ 107 | 108 | Args: 109 | red: the red color plane. 110 | green: the green color plane. 111 | blue: the blue color plane. 
112 | 113 | Returns: 114 | A tuple of (hue, saturation, value) planes, as float values in range [0, 1]. 115 | """ 116 | # DO NOT REMOVE - Logging usage. 117 | 118 | value = jnp.maximum(jnp.maximum(red, green), blue) 119 | minimum = jnp.minimum(jnp.minimum(red, green), blue) 120 | range_ = value - minimum 121 | 122 | # Avoid divisions by zeros by using safe values for the division. Even if the 123 | # results are masked by jnp.where and the function would give correct results, 124 | # this would produce NaNs when computing gradients. 125 | safe_value = jnp.where(value > 0, value, 1.) 126 | safe_range = jnp.where(range_ > 0, range_, 1.) 127 | 128 | saturation = jnp.where(value > 0, range_ / safe_value, 0.) 129 | norm = 1. / (6. * safe_range) 130 | 131 | hue = jnp.where(value == green, 132 | norm * (blue - red) + 2. / 6., 133 | norm * (red - green) + 4. / 6.) 134 | hue = jnp.where(value == red, norm * (green - blue), hue) 135 | hue = jnp.where(range_ > 0, hue, 0.) + (hue < 0.) 136 | 137 | return hue, saturation, value 138 | 139 | 140 | def hsv_planes_to_rgb_planes( 141 | hue: chex.Array, 142 | saturation: chex.Array, 143 | value: chex.Array, 144 | ) -> Tuple[chex.Array, chex.Array, chex.Array]: 145 | """Converts hue, saturation, value planes to red, green, blue color planes. 146 | 147 | All planes should have the same shape, with float values in range [0, 1]. 148 | Behavior outside of these bounds is not guaranteed. 149 | 150 | Reference implementation: http://shortn/_NvL2jK8F87 151 | 152 | Args: 153 | hue: the hue plane (wrapping). 154 | saturation: the saturation plane. 155 | value: the value plane. 156 | 157 | Returns: 158 | A tuple of (red, green, blue) planes, as float values in range [0, 1]. 159 | """ 160 | # DO NOT REMOVE - Logging usage. 161 | 162 | dh = (hue % 1.0) * 6. # Wrap when hue >= 360°. 163 | dr = jnp.clip(jnp.abs(dh - 3.) - 1., 0., 1.) 164 | dg = jnp.clip(2. - jnp.abs(dh - 2.), 0., 1.) 165 | db = jnp.clip(2. - jnp.abs(dh - 4.), 0., 1.) 
166 | one_minus_s = 1. - saturation 167 | 168 | red = value * (one_minus_s + saturation * dr) 169 | green = value * (one_minus_s + saturation * dg) 170 | blue = value * (one_minus_s + saturation * db) 171 | 172 | return red, green, blue 173 | 174 | 175 | def rgb_to_hsl( 176 | image_rgb: chex.Array, 177 | *, 178 | channel_axis: int = -1, 179 | ) -> chex.Array: 180 | """Converts an image from RGB to HSL. 181 | 182 | Args: 183 | image_rgb: an RGB image, with float values in range [0, 1]. Behavior outside 184 | of these bounds is not guaranteed. 185 | channel_axis: the channel axis. image_rgb should have 3 layers along this 186 | axis. 187 | 188 | Returns: 189 | An HSL image, with float values in range [0, 1], stacked along channel_axis. 190 | """ 191 | # DO NOT REMOVE - Logging usage. 192 | 193 | red, green, blue = split_channels(image_rgb, channel_axis) 194 | 195 | c_max = jnp.maximum(red, jnp.maximum(green, blue)) 196 | c_min = jnp.minimum(red, jnp.minimum(green, blue)) 197 | c_sum = c_max + c_min 198 | c_diff = c_max - c_min 199 | 200 | mask = c_min == c_max 201 | 202 | rc = (c_max - red) / c_diff 203 | gc = (c_max - green) / c_diff 204 | bc = (c_max - blue) / c_diff 205 | 206 | eps = jnp.finfo(jnp.float32).eps 207 | h = jnp.where( 208 | mask, 0, 209 | (jnp.where(red == c_max, bc - gc, 210 | jnp.where(green == c_max, 2 + rc - bc, 4 + gc - rc)) / 6) % 1) 211 | s = jnp.where(mask, 0, (c_diff + eps) / 212 | (2 * eps + jnp.where(c_sum <= 1, c_sum, 2 - c_sum))) 213 | l = c_sum / 2 214 | 215 | return jnp.stack([h, s, l], axis=-1) 216 | 217 | 218 | def hsl_to_rgb( 219 | image_hsl: chex.Array, 220 | *, 221 | channel_axis: int = -1, 222 | ) -> chex.Array: 223 | """Converts an image from HSL to RGB. 224 | 225 | Args: 226 | image_hsl: an HSV image, with float values in range [0, 1]. Behavior outside 227 | of these bounds is not guaranteed. 228 | channel_axis: the channel axis. image_hsv should have 3 layers along this 229 | axis. 
230 | 231 | Returns: 232 | An RGB image, with float values in range [0, 1], stacked along channel_axis. 233 | """ 234 | # DO NOT REMOVE - Logging usage. 235 | 236 | h, s, l = split_channels(image_hsl, channel_axis) 237 | 238 | m2 = jnp.where(l <= 0.5, l * (1 + s), l + s - l * s) 239 | m1 = 2 * l - m2 240 | 241 | def _f(hue): 242 | hue = hue % 1.0 243 | return jnp.where( 244 | hue < 1 / 6, m1 + 6 * (m2 - m1) * hue, 245 | jnp.where( 246 | hue < 0.5, m2, 247 | jnp.where(hue < 2 / 3, m1 + 6 * (m2 - m1) * (2 / 3 - hue), m1))) 248 | 249 | image_rgb = jnp.stack([_f(h + 1 / 3), _f(h), _f(h - 1 / 3)], axis=-1) 250 | return jnp.where(s[..., jnp.newaxis] == 0, l[..., jnp.newaxis], image_rgb) 251 | 252 | 253 | def rgb_to_grayscale( 254 | image: chex.Array, 255 | *, 256 | keep_dims: bool = False, 257 | luma_standard="rec601", 258 | channel_axis: int = -1, 259 | ) -> chex.Array: 260 | """Converts an image to a grayscale image using the luma value. 261 | 262 | This is equivalent to `tf.image.rgb_to_grayscale` (when keep_channels=False). 263 | 264 | Args: 265 | image: an RGB image, given as a float tensor in [0, 1]. 266 | keep_dims: if False (default), returns a tensor with a single channel. If 267 | True, will tile the resulting channel. 268 | luma_standard: the luma standard to use, either "rec601", "rec709" or 269 | "bt2001". The default rec601 corresponds to TensorFlow's. 270 | channel_axis: the index of the channel axis. 271 | 272 | Returns: 273 | The grayscale image. 274 | """ 275 | # DO NOT REMOVE - Logging usage. 276 | 277 | assert luma_standard in ["rec601", "rec709", "bt2001"] 278 | if luma_standard == "rec601": 279 | # TensorFlow's default. 
280 | rgb_weights = jnp.array([0.2989, 0.5870, 0.1140], dtype=image.dtype) 281 | elif luma_standard == "rec709": 282 | rgb_weights = jnp.array([0.2126, 0.7152, 0.0722], dtype=image.dtype) 283 | else: 284 | rgb_weights = jnp.array([0.2627, 0.6780, 0.0593], dtype=image.dtype) 285 | grayscale = jnp.tensordot(image, rgb_weights, axes=(channel_axis, -1)) 286 | # Add back the channel axis. 287 | grayscale = jnp.expand_dims(grayscale, axis=channel_axis) 288 | if keep_dims: 289 | if channel_axis < 0: 290 | channel_axis += image.ndim 291 | reps = [(1 if axis != channel_axis else 3) for axis in range(image.ndim)] 292 | return jnp.tile(grayscale, reps) # Tile to 3 along the channel axis. 293 | else: 294 | return grayscale 295 | -------------------------------------------------------------------------------- /dm_pix/_src/color_conversion_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for dm_pix._src.color_conversion.""" 15 | 16 | import colorsys 17 | import enum 18 | import functools 19 | from typing import Sequence 20 | 21 | from absl.testing import parameterized 22 | import chex 23 | from dm_pix._src import augment 24 | from dm_pix._src import color_conversion 25 | import jax 26 | import jax.numpy as jnp 27 | import numpy as np 28 | import tensorflow as tf 29 | 30 | _NUM_IMAGES = 100 31 | _IMG_SHAPE = (16, 16, 3) 32 | _FLAT_IMG_SHAPE = (_IMG_SHAPE[0] * _IMG_SHAPE[1], _IMG_SHAPE[2]) 33 | _QUANTISATIONS = (None, 16, 2) 34 | 35 | 36 | @enum.unique 37 | class TestImages(enum.Enum): 38 | """Enum classes representing random images with (low, high, num_images).""" 39 | RAND_FLOATS_IN_RANGE = (0., 1., _NUM_IMAGES) 40 | RAND_FLOATS_OUT_OF_RANGE = (-0.5, 1.5, _NUM_IMAGES) 41 | ALL_ONES = (1., 1., 1) 42 | ALL_ZEROS = (0., 0., 1) 43 | 44 | 45 | def generate_test_images( 46 | low: float, 47 | high: float, 48 | num_images: int, 49 | ) -> Sequence[chex.Array]: 50 | images = np.random.uniform( 51 | low=low, 52 | high=high, 53 | size=(num_images,) + _IMG_SHAPE, 54 | ) 55 | return list(images.astype(np.float32)) 56 | 57 | 58 | class ColorConversionTest( 59 | chex.TestCase, 60 | parameterized.TestCase, 61 | ): 62 | 63 | @chex.all_variants 64 | @parameterized.product( 65 | test_images=[ 66 | TestImages.RAND_FLOATS_IN_RANGE, 67 | TestImages.RAND_FLOATS_OUT_OF_RANGE, 68 | TestImages.ALL_ONES, 69 | TestImages.ALL_ZEROS, 70 | ], 71 | channel_last=[True, False], 72 | ) 73 | def test_hsv_to_rgb(self, test_images, channel_last): 74 | channel_axis = -1 if channel_last else -3 75 | hsv_to_rgb = self.variant( 76 | functools.partial( 77 | color_conversion.hsv_to_rgb, channel_axis=channel_axis)) 78 | for hsv in generate_test_images(*test_images.value): 79 | hsv = np.clip(hsv, 0., 1.) 
80 | rgb_tf = tf.image.hsv_to_rgb(hsv).numpy() 81 | if not channel_last: 82 | hsv = hsv.swapaxes(-1, -3) 83 | rgb_jax = hsv_to_rgb(hsv) 84 | if not channel_last: 85 | rgb_jax = rgb_jax.swapaxes(-1, -3) 86 | np.testing.assert_allclose(rgb_jax, rgb_tf, rtol=1e-3, atol=1e-3) 87 | 88 | # Check that taking gradients through the conversion function does not 89 | # create NaNs, which might happen due to some division by zero. 90 | # We're only testing the gradients of a single image rather than all test 91 | # images to avoid making the tests run too slow. 92 | jacobian = jax.jacrev(hsv_to_rgb)(hsv) 93 | self.assertFalse(jnp.isnan(jacobian).any(), "NaNs in HSV to RGB gradients") 94 | 95 | @chex.all_variants 96 | @parameterized.product( 97 | test_images=[ 98 | TestImages.RAND_FLOATS_IN_RANGE, 99 | TestImages.RAND_FLOATS_OUT_OF_RANGE, 100 | TestImages.ALL_ONES, 101 | TestImages.ALL_ZEROS, 102 | ], 103 | channel_last=[True, False], 104 | ) 105 | def test_rgb_to_hsv(self, test_images, channel_last): 106 | channel_axis = -1 if channel_last else -3 107 | rgb_to_hsv = self.variant( 108 | functools.partial( 109 | color_conversion.rgb_to_hsv, channel_axis=channel_axis)) 110 | for rgb in generate_test_images(*test_images.value): 111 | hsv_tf = tf.image.rgb_to_hsv(rgb).numpy() 112 | if not channel_last: 113 | rgb = rgb.swapaxes(-1, -3) 114 | hsv_jax = rgb_to_hsv(rgb) 115 | if not channel_last: 116 | hsv_jax = hsv_jax.swapaxes(-1, -3) 117 | np.testing.assert_allclose(hsv_jax, hsv_tf, rtol=1e-3, atol=1e-3) 118 | 119 | # Check that taking gradients through the conversion function does not 120 | # create NaNs, which might happen due to some division by zero. 121 | # We're only testing the gradients of a single image rather than all test 122 | # images to avoid making the tests run too slow. 
123 | rgb = generate_test_images(*test_images.value)[0] 124 | if not channel_last: 125 | rgb = rgb.swapaxes(-1, -3) 126 | jacobian = jax.jacrev(rgb_to_hsv)(rgb) 127 | self.assertFalse(jnp.isnan(jacobian).any(), "NaNs in RGB to HSV gradients") 128 | 129 | def test_rgb_to_hsv_subnormals(self): 130 | 131 | # Create a tensor that will contain some subnormal floating points. 132 | img = jnp.zeros((5, 5, 3)) 133 | img = img.at[2, 2, 1].set(1) 134 | blurred_img = augment.gaussian_blur(img, sigma=0.08, kernel_size=5.) 135 | fun = lambda x: color_conversion.rgb_to_hsv(x).sum() 136 | grad_fun = jax.grad(fun) 137 | 138 | grad = grad_fun(blurred_img) 139 | self.assertFalse(jnp.isnan(grad).any(), "NaNs in RGB to HSV gradients") 140 | 141 | @chex.all_variants 142 | def test_vmap_roundtrip(self): 143 | images = generate_test_images(*TestImages.RAND_FLOATS_IN_RANGE.value) 144 | rgb_init = np.stack(images, axis=0) 145 | rgb_to_hsv = self.variant(jax.vmap(color_conversion.rgb_to_hsv)) 146 | hsv_to_rgb = self.variant(jax.vmap(color_conversion.hsv_to_rgb)) 147 | hsv = rgb_to_hsv(rgb_init) 148 | rgb_final = hsv_to_rgb(hsv) 149 | np.testing.assert_allclose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3) 150 | 151 | def test_jit_roundtrip(self): 152 | images = generate_test_images(*TestImages.RAND_FLOATS_IN_RANGE.value) 153 | rgb_init = np.stack(images, axis=0) 154 | hsv = jax.jit(color_conversion.rgb_to_hsv)(rgb_init) 155 | rgb_final = jax.jit(color_conversion.hsv_to_rgb)(hsv) 156 | np.testing.assert_allclose(rgb_init, rgb_final, rtol=1e-3, atol=1e-3) 157 | 158 | @chex.all_variants 159 | @parameterized.named_parameters( 160 | ("black", 0, 0), 161 | ("gray", 0.000001, 0.999999), 162 | ("white", 1, 1), 163 | ) 164 | def test_rgb_to_hsl_golden(self, minval, maxval): 165 | """Compare against colorsys.rgb_to_hls as a golden implementation.""" 166 | key = jax.random.PRNGKey(0) 167 | for quantization in (None, 16, 2): 168 | key_rand_uni, key = jax.random.split(key) 169 | image_rgb = 
jax.random.uniform( 170 | key=key_rand_uni, 171 | shape=_FLAT_IMG_SHAPE, 172 | dtype=np.float32, 173 | minval=minval, 174 | maxval=maxval, 175 | ) 176 | 177 | # Use quantization to probe the corners of the color cube. 178 | if quantization is not None: 179 | image_rgb = jnp.round(image_rgb * quantization) / quantization 180 | 181 | hsl_true = np.zeros_like(image_rgb) 182 | for i in range(image_rgb.shape[0]): 183 | h, l, s = colorsys.rgb_to_hls(*image_rgb[i, :]) 184 | hsl_true[i, :] = [h, s, l] 185 | 186 | image_rgb = np.reshape(image_rgb, _IMG_SHAPE) 187 | hsl_true = np.reshape(hsl_true, _IMG_SHAPE) 188 | rgb_to_hsl = self.variant(color_conversion.rgb_to_hsl) 189 | np.testing.assert_allclose( 190 | rgb_to_hsl(image_rgb), hsl_true, atol=1E-5, rtol=1E-5) 191 | 192 | @chex.all_variants 193 | @parameterized.named_parameters( 194 | ("black", 0, 0.000001), 195 | ("white", 0.999999, 1), 196 | ) 197 | def test_rgb_to_hsl_stable(self, minval, maxval): 198 | """rgb_to_hsl's output near the black+white corners should be in [0, 1].""" 199 | key_rand_uni = jax.random.PRNGKey(0) 200 | image_rgb = jax.random.uniform( 201 | key=key_rand_uni, 202 | shape=_FLAT_IMG_SHAPE, 203 | dtype=np.float32, 204 | minval=minval, 205 | maxval=maxval, 206 | ) 207 | rgb_to_hsl = self.variant(color_conversion.rgb_to_hsl) 208 | hsl = rgb_to_hsl(image_rgb) 209 | self.assertTrue(jnp.all(jnp.isfinite(hsl))) 210 | self.assertLessEqual(jnp.max(hsl), 1.) 211 | self.assertGreaterEqual(jnp.min(hsl), 0.) 212 | 213 | @chex.all_variants 214 | def test_hsl_to_rgb_golden(self): 215 | """Compare against colorsys.rgb_to_hls as a golden implementation.""" 216 | key = jax.random.PRNGKey(0) 217 | for quantization in _QUANTISATIONS: 218 | key_rand_uni, key = jax.random.split(key) 219 | image_hsl = ( 220 | jax.random.uniform(key_rand_uni, _FLAT_IMG_SHAPE).astype(np.float32)) 221 | 222 | # Use quantization to probe the corners of the color cube. 
223 | if quantization is not None: 224 | image_hsl = jnp.round(image_hsl * quantization) / quantization 225 | 226 | rgb_true = np.zeros_like(image_hsl) 227 | for i in range(image_hsl.shape[0]): 228 | h, s, l = image_hsl[i, :] 229 | rgb_true[i, :] = colorsys.hls_to_rgb(h, l, s) 230 | 231 | rgb_true = np.reshape(rgb_true, _IMG_SHAPE) 232 | image_hsl = np.reshape(image_hsl, _IMG_SHAPE) 233 | hsl_to_rgb = self.variant(color_conversion.hsl_to_rgb) 234 | np.testing.assert_allclose( 235 | hsl_to_rgb(image_hsl), rgb_true, atol=1E-5, rtol=1E-5) 236 | 237 | @chex.all_variants 238 | def test_hsl_rgb_roundtrip(self): 239 | key = jax.random.PRNGKey(0) 240 | for quantization in _QUANTISATIONS: 241 | key_rand_uni, key = jax.random.split(key) 242 | image_rgb = jax.random.uniform(key_rand_uni, _IMG_SHAPE) 243 | 244 | # Use quantization to probe the corners of the color cube. 245 | if quantization is not None: 246 | image_rgb = jnp.round(image_rgb * quantization) / quantization 247 | 248 | rgb_to_hsl = self.variant(color_conversion.rgb_to_hsl) 249 | hsl_to_rgb = self.variant(color_conversion.hsl_to_rgb) 250 | np.testing.assert_allclose( 251 | image_rgb, hsl_to_rgb(rgb_to_hsl(image_rgb)), atol=1E-5, rtol=1E-5) 252 | 253 | @chex.all_variants 254 | @parameterized.product( 255 | test_images=[ 256 | TestImages.RAND_FLOATS_IN_RANGE, 257 | TestImages.RAND_FLOATS_OUT_OF_RANGE, 258 | TestImages.ALL_ONES, 259 | TestImages.ALL_ZEROS, 260 | ], 261 | keep_dims=[True, False], 262 | channel_last=[True, False], 263 | ) 264 | def test_grayscale(self, test_images, keep_dims, channel_last): 265 | channel_axis = -1 if channel_last else -3 266 | rgb_to_grayscale = self.variant( 267 | functools.partial( 268 | color_conversion.rgb_to_grayscale, 269 | keep_dims=keep_dims, 270 | channel_axis=channel_axis)) 271 | for rgb in generate_test_images(*test_images.value): 272 | grayscale_tf = tf.image.rgb_to_grayscale(rgb).numpy() 273 | if not channel_last: 274 | rgb = rgb.swapaxes(-1, -3) 275 | grayscale_jax = 
rgb_to_grayscale(rgb) 276 | if not channel_last: 277 | grayscale_jax = grayscale_jax.swapaxes(-1, -3) 278 | if keep_dims: 279 | for i in range(_IMG_SHAPE[-1]): 280 | np.testing.assert_allclose( 281 | grayscale_jax[..., [i]], grayscale_tf, atol=1E-5, rtol=1E-5) 282 | else: 283 | np.testing.assert_allclose( 284 | grayscale_jax, grayscale_tf, atol=1E-5, rtol=1E-5) 285 | 286 | 287 | if __name__ == "__main__": 288 | tf.test.main() 289 | -------------------------------------------------------------------------------- /dm_pix/_src/depth_and_space.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module provides functions for rearranging blocks of spatial data.""" 15 | 16 | import chex 17 | import jax 18 | import jax.numpy as jnp 19 | 20 | # DO NOT REMOVE - Logging lib. 21 | 22 | 23 | def depth_to_space(inputs: chex.Array, block_size: int) -> chex.Array: 24 | """Rearranges data from depth into blocks of spatial data. 25 | 26 | Args: 27 | inputs: Array of shape [H, W, C] or [N, H, W, C]. The number of channels 28 | (depth dimension) must be divisible by block_size ** 2. 29 | block_size: Size of spatial blocks >= 2. 30 | 31 | Returns: 32 | For inputs of shape [H, W, C] the output is a reshaped array of shape 33 | [H * B, W * B, C / (B ** 2)], where B is `block_size`. 
If there's a leading 34 | batch dimension, it stays unchanged. 35 | """ 36 | # DO NOT REMOVE - Logging usage. 37 | 38 | chex.assert_rank(inputs, {3, 4}) 39 | if inputs.ndim == 4: # Batched case. 40 | return jax.vmap(depth_to_space, in_axes=(0, None))(inputs, block_size) 41 | 42 | height, width, depth = inputs.shape 43 | if depth % (block_size**2) != 0: 44 | raise ValueError( 45 | f"Number of channels {depth} must be divisible by block_size ** 2" 46 | f" {block_size**2}." 47 | ) 48 | new_depth = depth // (block_size**2) 49 | outputs = jnp.reshape(inputs, 50 | [height, width, block_size, block_size, new_depth]) 51 | outputs = jnp.transpose(outputs, [0, 2, 1, 3, 4]) 52 | outputs = jnp.reshape(outputs, 53 | [height * block_size, width * block_size, new_depth]) 54 | return outputs 55 | 56 | 57 | def space_to_depth(inputs: chex.Array, block_size: int) -> chex.Array: 58 | """Rearranges data from blocks of spatial data into depth. 59 | 60 | This is the reverse of depth_to_space. 61 | 62 | Args: 63 | inputs: Array of shape [H, W, C] or [N, H, W, C]. The height and width must 64 | each be divisible by block_size. 65 | block_size: Size of spatial blocks >= 2. 66 | 67 | Returns: 68 | For inputs of shape [H, W, C] the output is a reshaped array of shape 69 | [H / B, W / B, C * (B ** 2)], where B is `block_size`. If there's a leading 70 | batch dimension, it stays unchanged. 71 | """ 72 | # DO NOT REMOVE - Logging usage. 73 | 74 | chex.assert_rank(inputs, {3, 4}) 75 | if inputs.ndim == 4: # Batched case. 
76 | return jax.vmap(space_to_depth, in_axes=(0, None))(inputs, block_size) 77 | 78 | height, width, depth = inputs.shape 79 | if height % block_size != 0: 80 | raise ValueError( 81 | f"Height {height} must be divisible by block size {block_size}.") 82 | if width % block_size != 0: 83 | raise ValueError( 84 | f"Width {width} must be divisible by block size {block_size}.") 85 | new_depth = depth * (block_size**2) 86 | new_height = height // block_size 87 | new_width = width // block_size 88 | outputs = jnp.reshape(inputs, 89 | [new_height, block_size, new_width, block_size, depth]) 90 | outputs = jnp.transpose(outputs, [0, 2, 1, 3, 4]) 91 | outputs = jnp.reshape(outputs, [new_height, new_width, new_depth]) 92 | return outputs 93 | -------------------------------------------------------------------------------- /dm_pix/_src/depth_and_space_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for dm_pix._src.depth_and_space.""" 15 | 16 | from absl.testing import parameterized 17 | import chex 18 | from dm_pix._src import depth_and_space 19 | import numpy as np 20 | import tensorflow as tf 21 | 22 | 23 | class DepthAndSpaceTest(chex.TestCase, parameterized.TestCase): 24 | 25 | @chex.all_variants 26 | @parameterized.parameters(([1, 1, 1, 9], 3), ([2, 2, 2, 8], 2)) 27 | def test_depth_to_space(self, input_shape, block_size): 28 | depth_to_space_fn = self.variant( 29 | depth_and_space.depth_to_space, static_argnums=1) 30 | inputs = np.arange(np.prod(input_shape), dtype=np.int32) 31 | inputs = np.reshape(inputs, input_shape) 32 | output_tf = tf.nn.depth_to_space(inputs, block_size).numpy() 33 | output_jax = depth_to_space_fn(inputs, block_size) 34 | np.testing.assert_array_equal(output_tf, output_jax) 35 | 36 | @chex.all_variants 37 | @parameterized.parameters(([1, 3, 3, 1], 3), ([2, 4, 4, 2], 2)) 38 | def test_space_to_depth(self, input_shape, block_size): 39 | space_to_depth_fn = self.variant( 40 | depth_and_space.space_to_depth, static_argnums=1) 41 | inputs = np.arange(np.prod(input_shape), dtype=np.int32) 42 | inputs = np.reshape(inputs, input_shape) 43 | output_tf = tf.nn.space_to_depth(inputs, block_size).numpy() 44 | output_jax = space_to_depth_fn(inputs, block_size) 45 | np.testing.assert_array_equal(output_tf, output_jax) 46 | 47 | 48 | if __name__ == "__main__": 49 | tf.test.main() 50 | -------------------------------------------------------------------------------- /dm_pix/_src/interpolation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""This module provides functions for interpolating ND images.

All functions expect float-encoded images, with values in [0, 1].
"""

import itertools
from typing import Optional, Sequence, Tuple

import chex
from jax import lax
import jax.numpy as jnp

# DO NOT REMOVE - Logging lib.


def _round_half_away_from_zero(a: chex.Array) -> chex.Array:
  """Rounds to nearest integer; integer inputs pass through unchanged.

  Relies on `lax.round`, whose default rounding method rounds halves away
  from zero (as opposed to banker's rounding).
  """
  return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a)


def _make_linear_interpolation_indices_nd(
    coordinates: chex.Array,
    shape: chex.Array,
) -> Tuple[chex.Array, chex.Array, chex.Array]:
  """Creates linear interpolation indices and weights for ND coordinates.

  Args:
    coordinates: An array of shape (N, M_coordinates).
    shape: The shape of the ND volume, e.g. if N=3 shape=(dim_z, dim_y, dim_x).

  Returns:
    The lower and upper indices of `coordinates` and their weights.
  """
  lower = jnp.floor(coordinates).astype(jnp.int32)
  upper = jnp.ceil(coordinates).astype(jnp.int32)
  # Fractional part in [0, 1); exactly 0 for integral coordinates (where
  # lower == upper, so the interpolation degenerates to a lookup).
  weights = coordinates - lower

  # Expand dimensions for `shape` to allow broadcasting it to every coordinate.
  # Expansion size is equal to the number of dimensions of `coordinates` - 1.
  shape = shape.reshape(shape.shape + (1,) * (coordinates.ndim - 1))

  # Clamp to the valid index range: out-of-range queries replicate the edge
  # voxel instead of indexing out of bounds.
  lower = jnp.clip(lower, 0, shape - 1)
  upper = jnp.clip(upper, 0, shape - 1)

  return lower, upper, weights


def _make_linear_interpolation_indices_flat_nd(
    coordinates: chex.Array,
    shape: Sequence[int],
) -> Tuple[chex.Array, chex.Array]:
  """Creates flat linear interpolation indices and weights for ND coordinates.

  Args:
    coordinates: An array of shape (N, M_coordinates).
    shape: The shape of the ND volume, e.g. if N=3 shape=(dim_z, dim_y, dim_x).

  Returns:
    The indices into the flattened input and their weights.
  """
  coordinates = jnp.asarray(coordinates)
  shape = jnp.asarray(shape)

  if shape.shape[0] != coordinates.shape[0]:
    raise ValueError(
        (f"{coordinates.shape[0]}-dimensional coordinates provided for "
         f"{shape.shape[0]}-dimensional input"))

  lower_nd, upper_nd, weights_nd = _make_linear_interpolation_indices_nd(
      coordinates, shape)

  # Here we want to translate e.g. a 3D-disposed indices to linear ones, since
  # we have to index on the flattened source, so:
  # flat_idx = shape[1] * shape[2] * z_idx + shape[2] * y_idx + x_idx

  # The `strides` of a `shape`-sized array tell us how many elements we have to
  # skip to move to the next position along a certain axis in that array.
  # For example, for a shape=(5,4,2) we have to skip 1 value to move to the next
  # column (3rd axis), 2 values to move to get to the same position in the next
  # row (2nd axis) and 4*2=8 values to move to get to the same position on the
  # 1st axis.
  strides = jnp.concatenate([jnp.cumprod(shape[:0:-1])[::-1], jnp.array([1])])

  # Array of 2^n rows where the ith row is the binary representation of i.
  # Each row selects one corner of the interpolation hyper-cube: 0 picks the
  # lower index along that axis, 1 picks the upper one.
  binary_array = jnp.array(
      list(itertools.product([0, 1], repeat=shape.shape[0])))

  # Expand dimensions to allow broadcasting `strides` and `binary_array` to
  # every coordinate.
  # Expansion size is equal to the number of dimensions of `coordinates` - 1.
  strides = strides.reshape(strides.shape + (1,) * (coordinates.ndim - 1))
  binary_array = binary_array.reshape(binary_array.shape + (1,) *
                                      (coordinates.ndim - 1))

  lower_1d = lower_nd * strides
  upper_1d = upper_nd * strides

  point_weights = []
  point_indices = []

  for r in binary_array:
    # `point_indices` is defined as:
    # `jnp.matmul(binary_array, upper) + jnp.matmul(1-binary_array, lower)`
    # however, to date, that implementation turns out to be slower than the
    # equivalent following one.
    point_indices.append(jnp.sum(upper_1d * r + lower_1d * (1 - r), axis=0))
    # Corner weight: product over axes of `w` (upper picked) or `1 - w`
    # (lower picked); the 2^n weights per coordinate sum to 1.
    point_weights.append(
        jnp.prod(r * weights_nd + (1 - r) * (1 - weights_nd), axis=0))
  return jnp.stack(point_indices, axis=0), jnp.stack(point_weights, axis=0)


def _linear_interpolate_using_indices_nd(
    volume: chex.Array,
    indices: chex.Array,
    weights: chex.Array,
) -> chex.Array:
  """Interpolates linearly on `volume` using `indices` and `weights`."""
  target = jnp.sum(weights * volume[indices], axis=0)
  # Integer volumes are rounded (half away from zero) before casting back so
  # the result stays in the input dtype instead of truncating.
  if jnp.issubdtype(volume.dtype, jnp.integer):
    target = _round_half_away_from_zero(target)
  return target.astype(volume.dtype)


def flat_nd_linear_interpolate(
    volume: chex.Array,
    coordinates: chex.Array,
    *,
    unflattened_vol_shape: Optional[Sequence[int]] = None,
) -> chex.Array:
  """Maps the input ND volume to coordinates by linear interpolation.

  Args:
    volume: A volume (flat if `unflattened_vol_shape` is provided) where to
      query coordinates.
    coordinates: An array of shape (N, M_coordinates). Where M_coordinates can
      be M-dimensional. If M_coordinates == 1, then `coordinates.shape` can
      simply be (N,), e.g. if N=3 and M_coordinates=1, this has the form (z, y,
      x).
    unflattened_vol_shape: The shape of the `volume` before flattening. If
      provided, then `volume` must be pre-flattened.

  Returns:
    The resulting mapped coordinates. The shape of the output is `M_coordinates`
    (derived from `coordinates` by dropping the first axis).
  """
  # DO NOT REMOVE - Logging usage.

  if unflattened_vol_shape is None:
    # ND input: remember its shape for index computation, then flatten so a
    # single gather can fetch all interpolation corners.
    unflattened_vol_shape = volume.shape
    volume = volume.flatten()

  indices, weights = _make_linear_interpolation_indices_flat_nd(
      coordinates, shape=unflattened_vol_shape)
  return _linear_interpolate_using_indices_nd(
      jnp.asarray(volume), indices, weights)


def flat_nd_linear_interpolate_constant(
    volume: chex.Array,
    coordinates: chex.Array,
    *,
    cval: float = 0.,
    unflattened_vol_shape: Optional[Sequence[int]] = None,
) -> chex.Array:
  """Maps volume by interpolation and returns a constant outside boundaries.

  Maps the input ND volume to coordinates by linear interpolation, but returns
  a constant value if the coordinates fall outside the volume boundary.

  Args:
    volume: A volume (flat if `unflattened_vol_shape` is provided) where to
      query coordinates.
    coordinates: An array of shape (N, M_coordinates). Where M_coordinates can
      be M-dimensional. If M_coordinates == 1, then `coordinates.shape` can
      simply be (N,), e.g. if N=3 and M_coordinates=1, this has the form (z, y,
      x).
    cval: A constant value to map to for coordinates that fall outside
      the volume boundaries. Must not be None.
    unflattened_vol_shape: The shape of the `volume` before flattening. If
      provided, then `volume` must be pre-flattened.

  Returns:
    The resulting mapped coordinates. The shape of the output is `M_coordinates`
    (derived from `coordinates` by dropping the first axis).
  """
  # DO NOT REMOVE - Logging usage.

  volume_shape = volume.shape
  if unflattened_vol_shape is not None:
    volume_shape = unflattened_vol_shape

  # Initialize considering all coordinates within the volume and loop through
  # boundaries.
  is_in_bounds = jnp.full(coordinates.shape[1:], True)
  for dim, dim_size in enumerate(volume_shape):
    is_in_bounds = jnp.logical_and(is_in_bounds, coordinates[dim] >= 0)
    is_in_bounds = jnp.logical_and(is_in_bounds,
                                   coordinates[dim] <= dim_size - 1)

  # NOTE(review): `1. - is_in_bounds` promotes the bool mask to float, so for
  # integer volumes the output dtype is float here (unlike
  # `flat_nd_linear_interpolate`, which preserves the volume dtype) — confirm
  # callers expect that.
  return flat_nd_linear_interpolate(
      volume,
      coordinates,
      unflattened_vol_shape=unflattened_vol_shape
  ) * is_in_bounds + (1. - is_in_bounds) * cval
--------------------------------------------------------------------------------
/dm_pix/_src/interpolation_test.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for dm_pix._src.interpolation."""

import itertools
from typing import Sequence, Tuple

from absl.testing import absltest
from absl.testing import parameterized
import chex
from dm_pix._src import interpolation
import jax.numpy as jnp
import numpy as np

# Coordinate-array shapes exercised by every test; the first entry of each
# shape selects the volume dimensionality (1D, 3D or 4D) in _prepare_inputs.
_SHAPE_COORDS = ((1, 1), (1, 3), (3, 2), (4, 4), (4, 1, 4), (4, 2, 2))
# Out-of-bounds fill values exercised by the `_constant` tests.
_CVALS = (0.0, 1.0, -2.0)


def _prepare_inputs(
    shape_output_coordinates: Tuple[int, ...]
) -> Tuple[jnp.ndarray, jnp.ndarray]:
  """Returns the volume and coordinates to be used in the function under test.

  Args:
    shape_output_coordinates: [N, M] shape for the output coordinates, where N
      is a scalar that determines also the number of dimensions in the volume
      and M is either a scalar (output coordinates will be a 1D array) or a
      vector.
  """
  # Fixed query coordinates; some entries (e.g. 8.4, 21) fall outside the
  # template volume to exercise edge clamping / constant-fill behavior.
  template_coords = jnp.array([(2, 0, 3, 2.9), (1, 0, 4.3, 1), (1, 0.5, 8.4, 1),
                               (21, 0.5, 4, 1)])

  # Reference 4D volume of shape (3, 2, 4, 5); lower-dimensional volumes are
  # sliced out of it below.
  template_volume = jnp.array(
      [[[[1.6583091, 2.0139587, 2.4636955, 0.11345804, 4.044214],
         [4.538101, 4.3030543, 0.6967968, 2.0311975, 2.5746036],
         [0.52024364, 4.767304, 2.3863382, 2.496363, 3.7334495],
         [1.367867, 4.18175, 0.38294435, 3.9395797, 2.6097183]],
        [[0.7470304, 3.8882136, 0.42186677, 1.9224191, 2.3947673],
         [4.859208, 2.7876246, 0.7796812, 3.234911, 2.0911336],
         [3.9205093, 4.027418, 2.9367173, 4.367462, 0.5682403],
         [3.32689, 0.5056447, 1.3147497, 3.549356, 0.57163835]]],
       [[[4.6056757, 3.0942523, 4.809611, 0.6062186, 4.1184435],
         [1.0862654, 1.0130441, 0.24880886, 2.9144812, 2.831624],
         [0.8990741, 4.6315174, 3.490876, 3.997823, 3.166548],
         [2.2909844, 2.1135485, 0.7603508, 1.7530066, 3.3882804]],
        [[2.2388606, 0.62632084, 0.39939642, 1.2361205, 4.4961414],
         [1.3705498, 4.6373777, 2.2974424, 2.9484348, 1.8847889],
         [4.856637, 3.4407651, 1.5632284, 0.30945182, 4.8406916],
         [4.10108, 0.44603765, 3.893259, 2.656221, 4.652004]]],
       [[[1.8670297, 4.1097646, 3.9615297, 0.9295058, 3.9903827],
         [3.3507752, 1.4316595, 4.0365667, 2.3517795, 2.7806897],
         [1.245628, 4.8092294, 3.3148618, 3.6758037, 2.4036856],
         [4.2023296, 0.6232512, 2.2606378, 2.1633143, 3.019858]],
        [[3.6607206, 0.26809275, 0.43593287, 0.3059131, 0.5254775],
         [0.27680695, 0.88441014, 4.8790736, 4.796288, 4.922847],
         [3.3822608, 2.5350225, 3.771946, 0.46694577, 4.0173407],
         [4.835033, 4.4530325, 1.4543611, 4.67758, 3.4009826]]]])

  if shape_output_coordinates[0] == 1:
    volume = template_volume[0, 0, 0, :]
  elif shape_output_coordinates[0] == 3:
    volume = template_volume[0, :, :, :]
  elif shape_output_coordinates[0] == template_volume.ndim:
    volume = template_volume
  else:
    raise ValueError("Unsupported shape_output_coordinates[0] = "
                     f"{shape_output_coordinates[0]}")

  if len(shape_output_coordinates) == 2:
    if shape_output_coordinates <= template_coords.shape:
      # Get a slice of the `template_coords`.
      coordinates = template_coords[0:shape_output_coordinates[0],
                                    0:shape_output_coordinates[1]]

      if shape_output_coordinates[1] == 1:
        # Do [[ num ]] -> [ num ] to test special case.
        coordinates = coordinates.squeeze(axis=-1)
    else:
      raise ValueError("Unsupported shape_output_coordinates[1] = "
                       f"{shape_output_coordinates[1]}")
  else:
    try:
      # In this case, try reshaping `template_coords` to the desired shape.
      coordinates = jnp.reshape(template_coords, shape_output_coordinates)
    except TypeError as e:
      raise ValueError(f"Unsupported shape_output_coordinates = "
                       f"{shape_output_coordinates}") from e

  return volume, coordinates


def _prepare_expected(shape_coordinates: Sequence[int]) -> jnp.ndarray:
  """Returns golden interpolation results for each coordinate-shape fixture."""
  if len(shape_coordinates) == 2:
    if tuple(shape_coordinates) == (1, 1):
      out = jnp.array(2.4636955)
    elif tuple(shape_coordinates) == (1, 3):
      out = jnp.array([2.4636955, 1.6583091, 0.11345804])
    elif tuple(shape_coordinates) == (3, 2):
      out = jnp.array([2.7876246, 1.836134])
    elif shape_coordinates[0] == 4:
      out = jnp.array([4.922847, 3.128356, 3.4009826, 0.88441014])
    else:
      raise ValueError(f"Unsupported shape_coordinates = {shape_coordinates}")
  elif shape_coordinates[0] == 4:
    try:
      out = jnp.array([4.922847, 3.128356, 3.4009826, 0.88441014])
      out = jnp.reshape(out, shape_coordinates[1:])
    except TypeError as e:
      raise ValueError(
          f"Unsupported shape_coordinates = {shape_coordinates}") from e
  else:
    raise ValueError(f"Unsupported shape_coordinates = {shape_coordinates}")

  return out


def _prepare_expected_const(shape_coordinates: Sequence[int],
                            cval: float) -> jnp.ndarray:
  """Returns golden results when out-of-bounds queries map to `cval`."""
  if len(shape_coordinates) == 2:
    if tuple(shape_coordinates) == (3, 2):
      out = jnp.array([cval, 1.836134])
    elif shape_coordinates[0] == 4:
      out = jnp.array([cval, 3.128356, cval, cval])
    else:
      # All remaining fixtures have fully in-bounds coordinates, so the
      # expected values match plain interpolation.
      return _prepare_expected(shape_coordinates)
  elif shape_coordinates[0] == 4:
    try:
      out = jnp.array([cval, 3.128356, cval, cval])
      out = jnp.reshape(out, shape_coordinates[1:])
    except TypeError as e:
      raise ValueError(
          f"Unsupported shape_coordinates = {shape_coordinates}") from e
  else:
    raise ValueError(f"Unsupported shape_coordinates = {shape_coordinates}")

  return out


class InterpolationTest(chex.TestCase, parameterized.TestCase):
  """Golden tests for the flat ND linear interpolation functions."""

  @chex.all_variants
  @parameterized.named_parameters([
      dict(testcase_name=f"_{shape}_coords", shape_coordinates=shape)
      for shape in _SHAPE_COORDS
  ])
  def test_flat_nd_linear_interpolate(self, shape_coordinates):
    volume, coords = _prepare_inputs(shape_coordinates)
    expected = _prepare_expected(shape_coordinates)

    flat_nd_linear_interpolate = self.variant(
        interpolation.flat_nd_linear_interpolate)
    np.testing.assert_allclose(
        flat_nd_linear_interpolate(volume, coords), expected)
    # Pre-flattened volume with an explicit shape must agree with ND input.
    np.testing.assert_allclose(
        flat_nd_linear_interpolate(
            volume.flatten(), coords, unflattened_vol_shape=volume.shape),
        expected)

  @chex.all_variants
  @parameterized.named_parameters([
      (f"_{shape}_coords_{cval}_cval", shape, cval)
      for cval, shape in itertools.product(_CVALS, _SHAPE_COORDS)
  ])
  def test_flat_nd_linear_interpolate_constant(self, shape_coordinates, cval):
    volume, coords = _prepare_inputs(shape_coordinates)
    expected = _prepare_expected_const(shape_coordinates, cval)

    flat_nd_linear_interpolate_constant = self.variant(
        interpolation.flat_nd_linear_interpolate_constant)
    np.testing.assert_allclose(
        flat_nd_linear_interpolate_constant(volume, coords, cval=cval),
        expected)
    # Pre-flattened volume with an explicit shape must agree with ND input.
    np.testing.assert_allclose(
        flat_nd_linear_interpolate_constant(
            volume.flatten(),
            coords,
            cval=cval,
            unflattened_vol_shape=volume.shape), expected)


if __name__ == "__main__":
  absltest.main()
--------------------------------------------------------------------------------
/dm_pix/_src/metrics.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited.
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Functions to compare image pairs. 15 | 16 | All functions expect float-encoded images, with values in [0, 1], with NHWC 17 | shapes. Each image metric function returns a scalar for each image pair. 18 | """ 19 | 20 | from typing import Callable, Optional 21 | 22 | import chex 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | # DO NOT REMOVE - Logging lib. 27 | 28 | 29 | def mae( 30 | a: chex.Array, 31 | b: chex.Array, 32 | *, 33 | ignore_nans: bool = False, 34 | ) -> chex.Numeric: 35 | """Returns the Mean Absolute Error between `a` and `b`. 36 | 37 | Args: 38 | a: First image (or set of images). 39 | b: Second image (or set of images). 40 | ignore_nans: If True, will ignore NaNs in the inputs. 41 | 42 | Returns: 43 | MAE between `a` and `b`. 44 | """ 45 | # DO NOT REMOVE - Logging usage. 46 | 47 | chex.assert_rank([a, b], {3, 4}) 48 | chex.assert_type([a, b], float) 49 | chex.assert_equal_shape([a, b]) 50 | mean_fn = jnp.nanmean if ignore_nans else jnp.mean 51 | return mean_fn(jnp.abs(a - b), axis=(-3, -2, -1)) 52 | 53 | 54 | def mse( 55 | a: chex.Array, 56 | b: chex.Array, 57 | *, 58 | ignore_nans: bool = False, 59 | ) -> chex.Numeric: 60 | """Returns the Mean Squared Error between `a` and `b`. 61 | 62 | Args: 63 | a: First image (or set of images). 64 | b: Second image (or set of images). 65 | ignore_nans: If True, will ignore NaNs in the inputs. 
66 | 67 | Returns: 68 | MSE between `a` and `b`. 69 | """ 70 | # DO NOT REMOVE - Logging usage. 71 | 72 | chex.assert_rank([a, b], {3, 4}) 73 | chex.assert_type([a, b], float) 74 | chex.assert_equal_shape([a, b]) 75 | mean_fn = jnp.nanmean if ignore_nans else jnp.mean 76 | return mean_fn(jnp.square(a - b), axis=(-3, -2, -1)) 77 | 78 | 79 | def psnr( 80 | a: chex.Array, 81 | b: chex.Array, 82 | *, 83 | ignore_nans: bool = False, 84 | ) -> chex.Numeric: 85 | """Returns the Peak Signal-to-Noise Ratio between `a` and `b`. 86 | 87 | Assumes that the dynamic range of the images (the difference between the 88 | maximum and the minimum allowed values) is 1.0. 89 | 90 | Args: 91 | a: First image (or set of images). 92 | b: Second image (or set of images). 93 | ignore_nans: If True, will ignore NaNs in the inputs. 94 | 95 | Returns: 96 | PSNR in decibels between `a` and `b`. 97 | """ 98 | # DO NOT REMOVE - Logging usage. 99 | 100 | chex.assert_rank([a, b], {3, 4}) 101 | chex.assert_type([a, b], float) 102 | chex.assert_equal_shape([a, b]) 103 | return -10.0 * jnp.log(mse(a, b, ignore_nans=ignore_nans)) / jnp.log(10.0) 104 | 105 | 106 | def rmse( 107 | a: chex.Array, 108 | b: chex.Array, 109 | *, 110 | ignore_nans: bool = False, 111 | ) -> chex.Array: 112 | """Returns the Root Mean Squared Error between `a` and `b`. 113 | 114 | Args: 115 | a: First image (or set of images). 116 | b: Second image (or set of images). 117 | ignore_nans: If True, will ignore NaNs in the inputs. 118 | 119 | Returns: 120 | RMSE between `a` and `b`. 121 | """ 122 | # DO NOT REMOVE - Logging usage. 123 | 124 | chex.assert_rank([a, b], {3, 4}) 125 | chex.assert_type([a, b], float) 126 | chex.assert_equal_shape([a, b]) 127 | return jnp.sqrt(mse(a, b, ignore_nans=ignore_nans)) 128 | 129 | 130 | def simse( 131 | a: chex.Array, 132 | b: chex.Array, 133 | *, 134 | ignore_nans: bool = False, 135 | ) -> chex.Numeric: 136 | """Returns the Scale-Invariant Mean Squared Error between `a` and `b`. 
137 | 138 | For each image pair, a scaling factor for `b` is computed as the solution to 139 | the following problem: 140 | 141 | min_alpha || vec(a) - alpha * vec(b) ||_2^2 142 | 143 | where `a` and `b` are flattened, i.e., vec(x) = np.flatten(x). The MSE between 144 | the optimally scaled `b` and `a` is returned: mse(a, alpha*b). 145 | 146 | This is a scale-invariant metric, so for example: simse(x, y) == sims(x, y*5). 147 | 148 | This metric was used in "Shape, Illumination, and Reflectance from Shading" by 149 | Barron and Malik, TPAMI, '15. 150 | 151 | Args: 152 | a: First image (or set of images). 153 | b: Second image (or set of images). 154 | ignore_nans: If True, will ignore NaNs in the inputs. 155 | 156 | Returns: 157 | SIMSE between `a` and `b`. 158 | """ 159 | # DO NOT REMOVE - Logging usage. 160 | 161 | chex.assert_rank([a, b], {3, 4}) 162 | chex.assert_type([a, b], float) 163 | chex.assert_equal_shape([a, b]) 164 | 165 | sum_fn = jnp.nansum if ignore_nans else jnp.sum 166 | a_dot_b = sum_fn((a * b), axis=(-3, -2, -1), keepdims=True) 167 | b_dot_b = sum_fn((b * b), axis=(-3, -2, -1), keepdims=True) 168 | alpha = a_dot_b / b_dot_b 169 | return mse(a, alpha * b, ignore_nans=ignore_nans) 170 | 171 | 172 | def ssim( 173 | a: chex.Array, 174 | b: chex.Array, 175 | *, 176 | max_val: float = 1.0, 177 | filter_size: int = 11, 178 | filter_sigma: float = 1.5, 179 | k1: float = 0.01, 180 | k2: float = 0.03, 181 | return_map: bool = False, 182 | precision=jax.lax.Precision.HIGHEST, 183 | filter_fn: Optional[Callable[[chex.Array], chex.Array]] = None, 184 | ignore_nans: bool = False, 185 | ) -> chex.Numeric: 186 | """Computes the structural similarity index (SSIM) between image pairs. 187 | 188 | This function is based on the standard SSIM implementation from: 189 | Z. Wang, A. C. Bovik, H. R. Sheikh and E. P. Simoncelli, 190 | "Image quality assessment: from error visibility to structural similarity", 191 | in IEEE Transactions on Image Processing, vol. 13, no. 
4, pp. 600-612, 2004. 192 | 193 | This function was modeled after tf.image.ssim, and should produce comparable 194 | output. 195 | 196 | Note: the true SSIM is only defined on grayscale. This function does not 197 | perform any colorspace transform. If the input is in a color space, then it 198 | will compute the average SSIM. 199 | 200 | Args: 201 | a: First image (or set of images). 202 | b: Second image (or set of images). 203 | max_val: The maximum magnitude that `a` or `b` can have. 204 | filter_size: Window size (>= 1). Image dims must be at least this small. 205 | filter_sigma: The bandwidth of the Gaussian used for filtering (> 0.). 206 | k1: One of the SSIM dampening parameters (> 0.). 207 | k2: One of the SSIM dampening parameters (> 0.). 208 | return_map: If True, will cause the per-pixel SSIM "map" to be returned. 209 | precision: The numerical precision to use when performing convolution. 210 | filter_fn: An optional argument for overriding the filter function used by 211 | SSIM, which would otherwise be a 2D Gaussian blur specified by filter_size 212 | and filter_sigma. 213 | ignore_nans: If True, will ignore NaNs in the inputs. 214 | 215 | Returns: 216 | Each image's mean SSIM, or a tensor of individual values if `return_map`. 217 | """ 218 | # DO NOT REMOVE - Logging usage. 219 | 220 | chex.assert_rank([a, b], {3, 4}) 221 | chex.assert_type([a, b], float) 222 | chex.assert_equal_shape([a, b]) 223 | 224 | if filter_fn is None: 225 | # Construct a 1D Gaussian blur filter. 226 | hw = filter_size // 2 227 | shift = (2 * hw - filter_size + 1) / 2 228 | f_i = ((jnp.arange(filter_size) - hw + shift) / filter_sigma) ** 2 229 | filt = jnp.exp(-0.5 * f_i) 230 | filt /= jnp.sum(filt) 231 | 232 | # Construct a 1D convolution. 233 | def filter_fn_1(z): 234 | return jnp.convolve(z, filt, mode="valid", precision=precision) 235 | 236 | filter_fn_vmap = jax.vmap(filter_fn_1) 237 | 238 | # Apply the vectorized filter along the y axis. 
239 | def filter_fn_y(z): 240 | z_flat = jnp.moveaxis(z, -3, -1).reshape((-1, z.shape[-3])) 241 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( 242 | z.shape[-2], 243 | z.shape[-1], 244 | -1, 245 | ) 246 | z_filtered = jnp.moveaxis( 247 | filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -3 248 | ) 249 | return z_filtered 250 | 251 | # Apply the vectorized filter along the x axis. 252 | def filter_fn_x(z): 253 | z_flat = jnp.moveaxis(z, -2, -1).reshape((-1, z.shape[-2])) 254 | z_filtered_shape = ((z.shape[-4],) if z.ndim == 4 else ()) + ( 255 | z.shape[-3], 256 | z.shape[-1], 257 | -1, 258 | ) 259 | z_filtered = jnp.moveaxis( 260 | filter_fn_vmap(z_flat).reshape(z_filtered_shape), -1, -2 261 | ) 262 | return z_filtered 263 | 264 | # Apply the blur in both x and y. 265 | filter_fn = lambda z: filter_fn_y(filter_fn_x(z)) 266 | 267 | mu0 = filter_fn(a) 268 | mu1 = filter_fn(b) 269 | mu00 = mu0 * mu0 270 | mu11 = mu1 * mu1 271 | mu01 = mu0 * mu1 272 | sigma00 = filter_fn(a**2) - mu00 273 | sigma11 = filter_fn(b**2) - mu11 274 | sigma01 = filter_fn(a * b) - mu01 275 | 276 | # Clip the variances and covariances to valid values. 
277 | # Variance must be non-negative: 278 | epsilon = jnp.finfo(jnp.float32).eps ** 2 279 | sigma00 = jnp.maximum(epsilon, sigma00) 280 | sigma11 = jnp.maximum(epsilon, sigma11) 281 | sigma01 = jnp.sign(sigma01) * jnp.minimum( 282 | jnp.sqrt(sigma00 * sigma11), jnp.abs(sigma01) 283 | ) 284 | 285 | c1 = (k1 * max_val) ** 2 286 | c2 = (k2 * max_val) ** 2 287 | numer = (2 * mu01 + c1) * (2 * sigma01 + c2) 288 | denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2) 289 | ssim_map = numer / denom 290 | mean_fn = jnp.nanmean if ignore_nans else jnp.mean 291 | ssim_value = mean_fn(ssim_map, axis=tuple(range(-3, 0))) 292 | return ssim_map if return_map else ssim_value 293 | -------------------------------------------------------------------------------- /dm_pix/_src/metrics_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
import functools

from absl.testing import absltest
import chex
from dm_pix._src import metrics
import jax
import numpy as np
import tensorflow as tf


class MSETest(chex.TestCase, absltest.TestCase):
  """Tests for the pointwise metrics (PSNR and SIMSE)."""

  def setUp(self):
    super().setUp()
    # Two independent random NHWC image batches in [0, 1].
    key = jax.random.PRNGKey(0)
    key1, key2 = jax.random.split(key)
    self._img1 = jax.random.uniform(
        key1,
        shape=(4, 32, 32, 3),
        minval=0.0,
        maxval=1.0,
    )
    self._img2 = jax.random.uniform(
        key2,
        shape=(4, 32, 32, 3),
        minval=0.0,
        maxval=1.0,
    )

  @chex.all_variants
  def test_psnr_match(self):
    # Golden test against the TensorFlow reference implementation.
    psnr = self.variant(metrics.psnr)
    values_jax = psnr(self._img1, self._img2)

    values_tf = tf.image.psnr(self._img1, self._img2, max_val=1.0).numpy()

    np.testing.assert_allclose(values_jax, values_tf, rtol=1e-3, atol=1e-3)

  @chex.all_variants
  def test_psnr_ignore_nans(self):
    psnr = self.variant(functools.partial(metrics.psnr, ignore_nans=True))

    # Poison one pixel per image with NaN; ignore_nans must mask it out.
    values_jax_nan = psnr(
        self._img1.at[:, 0, 0, 0].set(np.nan),
        self._img1.at[:, 0, 0, 0].set(np.nan),
    )

    assert not np.any(np.isnan(values_jax_nan))

  @chex.all_variants
  def test_simse_invariance(self):
    # SIMSE of an image against a scaled copy of itself must be ~0.
    simse = self.variant(metrics.simse)

    simse_jax = simse(self._img1, self._img1 * 2.0)

    np.testing.assert_allclose(simse_jax, np.zeros(4), rtol=1e-6, atol=1e-6)

  @chex.all_variants
  def test_simse_ignore_nans(self):
    simse = self.variant(functools.partial(metrics.simse, ignore_nans=True))

    # Poison one pixel per image with NaN; ignore_nans must mask it out.
    simse_jax_nan = simse(
        self._img1.at[:, 0, 0, 0].set(np.nan),
        self._img1.at[:, 0, 0, 0].set(np.nan),
    )

    assert not np.any(np.isnan(simse_jax_nan))


class SSIMTests(chex.TestCase, absltest.TestCase):
  """Tests for the SSIM metric."""

  @chex.all_variants
  def test_ssim_golden(self):
    """Test that the SSIM implementation matches the Tensorflow version."""

    key = jax.random.PRNGKey(0)
    # Batched and unbatched shapes, square and non-square spatial dims.
    for shape in ((2, 12, 12, 3), (12, 12, 3), (2, 12, 15, 3), (17, 12, 3)):
      for _ in range(4):
        (
            max_val_key,
            img0_key,
            img1_key,
            filter_size_key,
            filter_sigma_key,
            k1_key,
            k2_key,
            key,
        ) = jax.random.split(key, 8)
        max_val = jax.random.uniform(max_val_key, minval=0.1, maxval=3.0)
        img0 = max_val * jax.random.uniform(img0_key, shape=shape)
        img1 = max_val * jax.random.uniform(img1_key, shape=shape)
        # Randomize every SSIM hyper-parameter to cover the full API surface.
        filter_size = jax.random.randint(
            filter_size_key, shape=(), minval=1, maxval=10
        )
        filter_sigma = jax.random.uniform(
            filter_sigma_key, shape=(), minval=0.1, maxval=10.0
        )
        k1 = jax.random.uniform(k1_key, shape=(), minval=0.001, maxval=0.1)
        k2 = jax.random.uniform(k2_key, shape=(), minval=0.001, maxval=0.1)

        ssim_gt = tf.image.ssim(
            img0,
            img1,
            max_val,
            filter_size=filter_size,
            filter_sigma=filter_sigma,
            k1=k1,
            k2=k2,
        ).numpy()
        for return_map in [False, True]:
          ssim_fn = self.variant(
              functools.partial(
                  metrics.ssim,
                  max_val=max_val,
                  filter_size=filter_size,
                  filter_sigma=filter_sigma,
                  k1=k1,
                  k2=k2,
                  return_map=return_map,
              )
          )

          ssim = ssim_fn(img0, img1)

          # The mean of the per-pixel map must equal the scalar SSIM.
          if not return_map:
            np.testing.assert_allclose(ssim, ssim_gt, atol=1e-5, rtol=1e-5)
          else:
            np.testing.assert_allclose(
                np.mean(ssim, list(range(-3, 0))), ssim_gt, atol=1e-5, rtol=1e-5
            )
          self.assertLessEqual(np.max(ssim), 1.0)
          self.assertGreaterEqual(np.min(ssim), -1.0)

  @chex.all_variants
  def test_ssim_lowerbound(self):
    """Test the unusual corner case where SSIM is -1."""
    # A linear ramp against its negation with tiny k1/k2 drives SSIM to -1.
    filter_size = 11
    grid_coords = [np.linspace(-1, 1, filter_size)] * 2
    img = np.meshgrid(*grid_coords)[0][np.newaxis, ..., np.newaxis]
    eps = 1e-5
    ssim_fn = self.variant(
        functools.partial(
            metrics.ssim,
            max_val=1.0,
            filter_size=filter_size,
            filter_sigma=1.5,
            k1=eps,
            k2=eps,
        )
    )

    ssim = ssim_fn(img, -img)

    np.testing.assert_allclose(ssim, -np.ones_like(ssim), atol=1e-5, rtol=1e-5)

  @chex.all_variants
  def test_ssim_finite_grad(self):
    """Test that SSIM produces a finite gradient on large flat regions."""
    img = np.zeros((64, 64, 3))

    grad = self.variant(jax.grad(metrics.ssim))(img, img)

    np.testing.assert_equal(grad, np.zeros_like(grad))

  @chex.all_variants
  def test_ssim_ignore_nans(self):
    """Test that SSIM ignores NaNs."""
    ssim_fn = self.variant(
        functools.partial(
            metrics.ssim,
            max_val=1.0,
            filter_size=11,
            filter_sigma=1.5,
            k1=0.01,
            k2=0.03,
            ignore_nans=True,
        )
    )
    key = jax.random.PRNGKey(0)
    _, key1 = jax.random.split(key)
    img = jax.random.uniform(
        key1,
        shape=(4, 32, 32, 3),
        minval=0.0,
        maxval=1.0,
    )

    # Poison one pixel per image with NaN; ignore_nans must mask it out.
    ssim = ssim_fn(
        img.at[:, 0, 0, 0].set(np.nan), img.at[:, 0, 0, 0].set(np.nan)
    )

    assert not np.any(np.isnan(ssim))


if __name__ == "__main__":
  absltest.main()
--------------------------------------------------------------------------------
/dm_pix/_src/patch.py:
--------------------------------------------------------------------------------
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """This module provides image patching functionality.""" 15 | 16 | from typing import Sequence 17 | 18 | import chex 19 | import jax 20 | import jax.numpy as jnp 21 | 22 | # DO NOT REMOVE - Logging lib. 23 | 24 | 25 | def extract_patches( 26 | images: chex.Array, 27 | sizes: Sequence[int], 28 | strides: Sequence[int], 29 | rates: Sequence[int], 30 | *, 31 | padding: str = "VALID", 32 | ) -> jnp.ndarray: 33 | """Extract patches from images. 34 | 35 | This function is a wrapper for `jax.lax.conv_general_dilated_patches` 36 | to conform to the same interface as `tf.image.extract_patches`, except for 37 | this function supports arbitrary-dimensional `images`, not only 4D as in 38 | `tf.image.extract_patches`. 39 | 40 | The function extracts patches of shape `sizes` from `images` in the same 41 | manner as a convolution with kernel of shape `sizes`, stride equal to 42 | `strides`, and the given `padding` scheme. The patches are stacked in the 43 | channel dimension. 44 | 45 | Args: 46 | images: input batch of images of shape [B, H, W, ..., C]. 47 | sizes: size of the extracted patches. 48 | Must be [1, size_rows, size_cols, ..., 1]. 49 | strides: how far the centers of two consecutive patches are in the images. 50 | Must be [1, stride_rows, stride_cols, ..., 1]. 51 | rates: sampling rate. Must be [1, rate_rows, rate_cols, ..., 1]. This is the 52 | input stride, specifying how far two consecutive patch samples are in the 53 | input. 
Equivalent to extracting patches with `patch_sizes_eff = 54 | patch_sizes + (patch_sizes - 1) * (rates - 1)`, followed by subsampling 55 | them spatially by a factor of rates. This is equivalent to rate in dilated 56 | (a.k.a. Atrous) convolutions. 57 | padding: the type of padding algorithm to use. 58 | 59 | Returns: 60 | Tensor of shape 61 | [B, patch_rows, patch_cols, ..., size_rows * size_cols * ... * C]. 62 | """ 63 | # DO NOT REMOVE - Logging usage. 64 | 65 | ndim = images.ndim 66 | 67 | if len(sizes) != ndim or sizes[0] != 1 or sizes[-1] != 1: 68 | raise ValueError("Input `sizes` must be [1, size_rows, size_cols, ..., 1] " 69 | f"and same length as `images.ndim` {ndim}. Got {sizes}.") 70 | if len(strides) != ndim or strides[0] != 1 or strides[-1] != 1: 71 | raise ValueError("Input `strides` must be [1, size_rows, size_cols, ..., 1]" 72 | f"and same length as `images.ndim` {ndim}. Got {strides}.") 73 | if len(rates) != ndim or rates[0] != 1 or rates[-1] != 1: 74 | raise ValueError("Input `rates` must be [1, size_rows, size_cols, ..., 1] " 75 | f"and same length as `images.ndim` {ndim}. Got {rates}.") 76 | 77 | channels = images.shape[-1] 78 | lhs_spec = out_spec = (0, ndim - 1) + tuple(range(1, ndim - 1)) 79 | rhs_spec = tuple(range(ndim)) 80 | patches = jax.lax.conv_general_dilated_patches( 81 | lhs=images, 82 | filter_shape=sizes[1:-1], 83 | window_strides=strides[1:-1], 84 | padding=padding, 85 | rhs_dilation=rates[1:-1], 86 | dimension_numbers=jax.lax.ConvDimensionNumbers( 87 | lhs_spec, rhs_spec, out_spec) 88 | ) 89 | 90 | # `conv_general_dilated_patches` returns `patches` in channel-major order, 91 | # rearrange to match interface of `tf.image.extract_patches`. 
92 | patches = jnp.reshape(patches, patches.shape[:-1] + (channels, -1)) 93 | patches = jnp.moveaxis(patches, -2, -1) 94 | patches = jnp.reshape(patches, patches.shape[:-2] + (-1,)) 95 | return patches 96 | -------------------------------------------------------------------------------- /dm_pix/_src/patch_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Tests for dm_pix._src.patch.""" 15 | 16 | import functools 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | import chex 21 | from dm_pix._src import patch 22 | import numpy as np 23 | import tensorflow as tf 24 | 25 | 26 | def _create_test_images(shape): 27 | images = np.arange(np.prod(np.array(shape)), dtype=np.float32) 28 | return np.reshape(images, shape) 29 | 30 | 31 | class PatchTest(chex.TestCase, parameterized.TestCase): 32 | 33 | @chex.all_variants 34 | @parameterized.named_parameters( 35 | ('padding_valid', 'VALID'), 36 | ('padding_same', 'SAME'), 37 | ) 38 | def test_extract_patches(self, padding): 39 | image_shape = (2, 5, 7, 3) 40 | images = _create_test_images(image_shape) 41 | 42 | sizes = (1, 2, 3, 1) 43 | strides = (1, 1, 2, 1) 44 | rates = (1, 2, 1, 1) 45 | 46 | extract_patches = self.variant( 47 | functools.partial(patch.extract_patches, padding=padding), 48 | static_argnums=(1, 2, 3)) 49 | jax_patches = extract_patches( 50 | images, 51 | sizes, 52 | strides, 53 | rates, 54 | ) 55 | tf_patches = tf.image.extract_patches( 56 | images, 57 | sizes=sizes, 58 | strides=strides, 59 | rates=rates, 60 | padding=padding, 61 | ) 62 | np.testing.assert_array_equal(jax_patches, tf_patches.numpy()) 63 | 64 | @chex.all_variants 65 | @parameterized.named_parameters( 66 | ('padding_valid', 'VALID'), 67 | ('padding_same', 'SAME'), 68 | ) 69 | def test_extract_patches_0d(self, padding): 70 | image_shape = (2, 3) 71 | images = _create_test_images(image_shape) 72 | 73 | sizes = (1, 1) 74 | strides = (1, 1) 75 | rates = (1, 1) 76 | 77 | extract_patches = self.variant( 78 | functools.partial(patch.extract_patches, padding=padding), 79 | static_argnums=(1, 2, 3)) 80 | jax_patches = extract_patches( 81 | images, 82 | sizes, 83 | strides, 84 | rates, 85 | ) 86 | # 0D patches is a no-op. 
87 | np.testing.assert_array_equal(jax_patches, images) 88 | 89 | @chex.all_variants 90 | @parameterized.named_parameters( 91 | ('padding_valid', 'VALID'), 92 | ('padding_same', 'SAME'), 93 | ) 94 | def test_extract_patches_1d(self, padding): 95 | image_shape = (2, 7, 3) 96 | images = _create_test_images(image_shape) 97 | 98 | sizes = (1, 2, 1) 99 | strides = (1, 1, 1) 100 | rates = (1, 2, 1) 101 | 102 | extract_patches = self.variant( 103 | functools.partial(patch.extract_patches, padding=padding), 104 | static_argnums=(1, 2, 3)) 105 | jax_patches = extract_patches( 106 | images, 107 | sizes, 108 | strides, 109 | rates, 110 | ) 111 | jax_patches = np.expand_dims(jax_patches, -2) 112 | # Reference patches are computed over an image with an extra singleton dim. 113 | tf_patches = tf.image.extract_patches( 114 | np.expand_dims(images, 2), 115 | sizes=sizes + (1,), 116 | strides=strides + (1,), 117 | rates=rates + (1,), 118 | padding=padding, 119 | ) 120 | np.testing.assert_array_equal(jax_patches, tf_patches.numpy()) 121 | 122 | @chex.all_variants 123 | def test_extract_patches_3d(self): 124 | image_shape = (2, 4, 9, 6, 3) 125 | images = _create_test_images(image_shape) 126 | 127 | sizes = (1, 2, 3, 2, 1) 128 | strides = (1, 2, 3, 2, 1) 129 | rates = (1, 1, 1, 1, 1) 130 | 131 | extract_patches = self.variant( 132 | functools.partial(patch.extract_patches, padding='VALID'), 133 | static_argnums=(1, 2, 3)) 134 | jax_patches = extract_patches( 135 | images, 136 | sizes, 137 | strides, 138 | rates, 139 | ) 140 | # Reconstructing the original from non-overlapping patches. 
141 | images_reconstructed = np.reshape( 142 | jax_patches, 143 | jax_patches.shape[:-1] + sizes[1:-1] + images.shape[-1:] 144 | ) 145 | images_reconstructed = np.moveaxis(images_reconstructed, 146 | (-4, -3, -2), 147 | (2, 4, 6)) 148 | images_reconstructed = images_reconstructed.reshape(image_shape) 149 | np.testing.assert_allclose(images_reconstructed, images, rtol=5e-3) 150 | 151 | @chex.all_variants 152 | @parameterized.product( 153 | ({ 154 | 'sizes': (1, 2, 3), 155 | 'strides': (1, 1, 2, 1), 156 | 'rates': (1, 2, 1, 1), 157 | }, { 158 | 'sizes': (1, 2, 3, 1), 159 | 'strides': (1, 1, 2), 160 | 'rates': (1, 2, 1, 1), 161 | }, { 162 | 'sizes': (1, 2, 3, 1), 163 | 'strides': (1, 1, 2, 1), 164 | 'rates': (1, 2, 1), 165 | }, { 166 | 'sizes': (1, 2, 1), 167 | 'strides': (1, 2, 1), 168 | 'rates': (1, 1), 169 | }, { 170 | 'sizes': (1, 1), 171 | 'strides': (1, 2), 172 | 'rates': (1, 1), 173 | }, { 174 | 'sizes': (1, 1), 175 | 'strides': (1,), 176 | 'rates': (1, 1), 177 | }, { 178 | 'sizes': (1, 2, 3, 4, 1), 179 | 'strides': (1, 2, 3, 4, 2), 180 | 'rates': (1, 1, 1, 1, 1), 181 | }, { 182 | 'sizes': (1, 2, 3, 1), 183 | 'strides': (1, 2, 3, 4, 1), 184 | 'rates': (1, 1, 1, 1, 1), 185 | }), 186 | padding=('VALID', 'SAME'), 187 | ) 188 | def test_extract_patches_raises(self, sizes, strides, rates, padding): 189 | image_shape = (2, 5, 7, 3) 190 | images = _create_test_images(image_shape) 191 | 192 | extract_patches = self.variant( 193 | functools.partial(patch.extract_patches, padding=padding), 194 | static_argnums=(1, 2, 3)) 195 | with self.assertRaises(ValueError): 196 | extract_patches( 197 | images, 198 | sizes, 199 | strides, 200 | rates, 201 | ) 202 | 203 | 204 | if __name__ == '__main__': 205 | absltest.main() 206 | -------------------------------------------------------------------------------- /dm_pix/_src/test_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. 
All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Testing utilities for PIX.""" 15 | 16 | import inspect 17 | import types 18 | from typing import Sequence, Tuple 19 | 20 | 21 | def get_public_functions( 22 | root_module: types.ModuleType) -> Sequence[Tuple[str, types.FunctionType]]: 23 | """Returns `(function_name, function)` for all functions of `root_module`.""" 24 | fns = [] 25 | for name in dir(root_module): 26 | o = getattr(root_module, name) 27 | if inspect.isfunction(o): 28 | fns.append((name, o)) 29 | return fns 30 | 31 | 32 | def get_public_symbols( 33 | root_module: types.ModuleType) -> Sequence[Tuple[str, types.FunctionType]]: 34 | """Returns `(symbol_name, symbol)` for all symbols of `root_module`.""" 35 | fns = [] 36 | for name in getattr(root_module, '__all__'): 37 | o = getattr(root_module, name) 38 | fns.append((name, o)) 39 | return fns 40 | -------------------------------------------------------------------------------- /dm_pix/images/pix_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/dm_pix/images/pix_logo.png -------------------------------------------------------------------------------- /dm_pix/py.typed: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/dm_pix/py.typed -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SOURCEDIR = . 8 | BUILDDIR = _build 9 | 10 | # Put it first so that "make" without argument is like "make help". 11 | help: 12 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 13 | 14 | .PHONY: help Makefile 15 | 16 | # Catch-all target: route all unknown targets to Sphinx using the new 17 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 18 | %: Makefile 19 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 20 | -------------------------------------------------------------------------------- /docs/api.rst: -------------------------------------------------------------------------------- 1 | Augmentations 2 | ============= 3 | 4 | .. currentmodule:: dm_pix 5 | 6 | .. autosummary:: 7 | 8 | adjust_brightness 9 | adjust_contrast 10 | adjust_gamma 11 | adjust_hue 12 | adjust_saturation 13 | affine_transform 14 | center_crop 15 | elastic_deformation 16 | flip_left_right 17 | flip_up_down 18 | gaussian_blur 19 | pad_to_size 20 | random_brightness 21 | random_contrast 22 | random_crop 23 | random_flip_left_right 24 | random_flip_up_down 25 | random_gamma 26 | random_hue 27 | random_saturation 28 | resize_with_crop_or_pad 29 | rotate 30 | rot90 31 | solarize 32 | 33 | adjust_brightness 34 | ~~~~~~~~~~~~~~~~~ 35 | 36 | .. autofunction:: adjust_brightness 37 | 38 | adjust_contrast 39 | ~~~~~~~~~~~~~~~ 40 | 41 | .. autofunction:: adjust_contrast 42 | 43 | adjust_gamma 44 | ~~~~~~~~~~~~ 45 | 46 | .. 
autofunction:: adjust_gamma 47 | 48 | adjust_hue 49 | ~~~~~~~~~~ 50 | 51 | .. autofunction:: adjust_hue 52 | 53 | adjust_saturation 54 | ~~~~~~~~~~~~~~~~~ 55 | 56 | .. autofunction:: adjust_saturation 57 | 58 | affine_transform 59 | ~~~~~~~~~~~~~~~~~ 60 | 61 | .. autofunction:: affine_transform 62 | 63 | center_crop 64 | ~~~~~~~~~~~ 65 | 66 | .. autofunction:: center_crop 67 | 68 | 69 | elastic_deformation 70 | ~~~~~~~~~~~~~~~~~~~ 71 | 72 | .. autofunction:: elastic_deformation 73 | 74 | flip_left_right 75 | ~~~~~~~~~~~~~~~ 76 | 77 | .. autofunction:: flip_left_right 78 | 79 | flip_up_down 80 | ~~~~~~~~~~~~ 81 | 82 | .. autofunction:: flip_up_down 83 | 84 | gaussian_blur 85 | ~~~~~~~~~~~~~ 86 | 87 | .. autofunction:: gaussian_blur 88 | 89 | pad_to_size 90 | ~~~~~~~~~~~ 91 | 92 | .. autofunction:: pad_to_size 93 | 94 | random_brightness 95 | ~~~~~~~~~~~~~~~~~ 96 | 97 | .. autofunction:: random_brightness 98 | 99 | random_contrast 100 | ~~~~~~~~~~~~~~~ 101 | 102 | .. autofunction:: random_contrast 103 | 104 | random_crop 105 | ~~~~~~~~~~~ 106 | 107 | .. autofunction:: random_crop 108 | 109 | random_flip_left_right 110 | ~~~~~~~~~~~~~~~~~~~~~~ 111 | 112 | .. autofunction:: random_flip_left_right 113 | 114 | random_flip_up_down 115 | ~~~~~~~~~~~~~~~~~~~ 116 | 117 | .. autofunction:: random_flip_up_down 118 | 119 | random_gamma 120 | ~~~~~~~~~~~~ 121 | 122 | .. autofunction:: random_gamma 123 | 124 | random_hue 125 | ~~~~~~~~~~ 126 | 127 | .. autofunction:: random_hue 128 | 129 | random_saturation 130 | ~~~~~~~~~~~~~~~~~ 131 | 132 | .. autofunction:: random_saturation 133 | 134 | resize_with_crop_or_pad 135 | ~~~~~~~~~~~~~~~~~~~~~~~ 136 | 137 | .. autofunction:: resize_with_crop_or_pad 138 | 139 | rotate 140 | ~~~~~~ 141 | 142 | .. autofunction:: rotate 143 | 144 | rot90 145 | ~~~~~ 146 | 147 | .. autofunction:: rot90 148 | 149 | solarize 150 | ~~~~~~~~ 151 | 152 | .. autofunction:: solarize 153 | 154 | 155 | Color conversions 156 | ================= 157 | 158 | .. 
currentmodule:: dm_pix 159 | 160 | .. autosummary:: 161 | hsl_to_rgb 162 | hsv_to_rgb 163 | rgb_to_hsl 164 | rgb_to_hsv 165 | rgb_to_grayscale 166 | 167 | hsl_to_rgb 168 | ~~~~~~~~~~ 169 | 170 | .. autofunction:: hsl_to_rgb 171 | 172 | hsv_to_rgb 173 | ~~~~~~~~~~ 174 | 175 | .. autofunction:: hsv_to_rgb 176 | 177 | rgb_to_hsl 178 | ~~~~~~~~~~ 179 | 180 | .. autofunction:: rgb_to_hsl 181 | 182 | rgb_to_hsv 183 | ~~~~~~~~~~ 184 | 185 | .. autofunction:: rgb_to_hsv 186 | 187 | rgb_to_grayscale 188 | ~~~~~~~~~~~~~~~~ 189 | 190 | .. autofunction:: rgb_to_grayscale 191 | 192 | 193 | Depth and space transformations 194 | =============================== 195 | 196 | .. currentmodule:: dm_pix 197 | 198 | .. autosummary:: 199 | depth_to_space 200 | space_to_depth 201 | 202 | depth_to_space 203 | ~~~~~~~~~~~~~~ 204 | 205 | .. autofunction:: depth_to_space 206 | 207 | space_to_depth 208 | ~~~~~~~~~~~~~~ 209 | 210 | .. autofunction:: space_to_depth 211 | 212 | 213 | Interpolation functions 214 | ======================= 215 | 216 | .. currentmodule:: dm_pix 217 | 218 | .. autosummary:: 219 | flat_nd_linear_interpolate 220 | flat_nd_linear_interpolate_constant 221 | 222 | flat_nd_linear_interpolate 223 | ~~~~~~~~~~~~~~~~~~~~~~~~~~ 224 | 225 | .. autofunction:: flat_nd_linear_interpolate 226 | 227 | flat_nd_linear_interpolate_constant 228 | ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 229 | 230 | .. autofunction:: flat_nd_linear_interpolate_constant 231 | 232 | 233 | Metrics 234 | ======= 235 | 236 | .. currentmodule:: dm_pix 237 | 238 | .. autosummary:: 239 | mae 240 | mse 241 | psnr 242 | rmse 243 | simse 244 | ssim 245 | 246 | mae 247 | ~~~ 248 | 249 | .. autofunction:: mae 250 | 251 | mse 252 | ~~~ 253 | 254 | .. autofunction:: mse 255 | 256 | psnr 257 | ~~~~ 258 | 259 | .. autofunction:: psnr 260 | 261 | rmse 262 | ~~~~ 263 | 264 | .. autofunction:: rmse 265 | 266 | simse 267 | ~~~~~ 268 | 269 | .. autofunction:: simse 270 | 271 | ssim 272 | ~~~~ 273 | 274 | .. 
autofunction:: ssim 275 | 276 | 277 | Patch extraction functions 278 | ========================== 279 | 280 | .. currentmodule:: dm_pix 281 | 282 | .. autosummary:: 283 | extract_patches 284 | 285 | extract_patches 286 | ~~~~~~~~~~~~~~~ 287 | 288 | .. autofunction:: extract_patches 289 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | """Configuration file for the Sphinx documentation builder.""" 15 | 16 | # This file only contains a selection of the most common options. For a full 17 | # list see the documentation: 18 | # http://www.sphinx-doc.org/en/master/config 19 | 20 | # -- Path setup -------------------------------------------------------------- 21 | 22 | # If extensions (or modules to document with autodoc) are in another directory, 23 | # add these directories to sys.path here. If the directory is relative to the 24 | # documentation root, use os.path.abspath to make it absolute, like shown here. 
25 | 26 | # pylint: disable=g-bad-import-order 27 | # pylint: disable=g-import-not-at-top 28 | import inspect 29 | import os 30 | import sys 31 | import typing 32 | 33 | 34 | def _add_annotations_import(path): 35 | """Appends a future annotations import to the file at the given path.""" 36 | with open(path) as f: 37 | contents = f.read() 38 | if contents.startswith('from __future__ import annotations'): 39 | # If we run sphinx multiple times then we will append the future import 40 | # multiple times too. 41 | return 42 | 43 | assert contents.startswith('#'), (path, contents.split('\n')[0]) 44 | with open(path, 'w') as f: 45 | # NOTE: This is subtle and not unit tested, we're prefixing the first line 46 | # in each Python file with this future import. It is important to prefix 47 | # not insert a newline such that source code locations are accurate (we link 48 | # to GitHub). The assertion above ensures that the first line in the file is 49 | # a comment so it is safe to prefix it. 50 | f.write('from __future__ import annotations ') 51 | f.write(contents) 52 | 53 | 54 | def _recursive_add_annotations_import(): 55 | for path, _, files in os.walk('../dm_pix/'): 56 | for file in files: 57 | if file.endswith('.py'): 58 | _add_annotations_import(os.path.abspath(os.path.join(path, file))) 59 | 60 | 61 | if 'READTHEDOCS' in os.environ: 62 | _recursive_add_annotations_import() 63 | 64 | # We remove `None` type annotations as this breaks Sphinx under Python 3.7 and 65 | # 3.8 with error `AssertionError: Invalid annotation [...] 
None is not a class.` 66 | filter_nones = lambda x: dict((k, v) for k, v in x.items() if v is not None) 67 | typing.get_type_hints = lambda obj, *unused: filter_nones(obj.__annotations__) 68 | sys.path.insert(0, os.path.abspath('../')) 69 | sys.path.append(os.path.abspath('ext')) 70 | 71 | import dm_pix as pix 72 | from sphinxcontrib import katex 73 | 74 | # -- Project information ----------------------------------------------------- 75 | 76 | project = 'PIX' 77 | copyright = '2021, DeepMind' # pylint: disable=redefined-builtin 78 | author = 'PIX Contributors' 79 | 80 | # -- General configuration --------------------------------------------------- 81 | 82 | master_doc = 'index' 83 | 84 | # Add any Sphinx extension module names here, as strings. They can be 85 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 86 | # ones. 87 | extensions = [ 88 | 'sphinx.ext.autodoc', 89 | 'sphinx.ext.autosummary', 90 | 'sphinx.ext.doctest', 91 | 'sphinx.ext.inheritance_diagram', 92 | 'sphinx.ext.intersphinx', 93 | 'sphinx.ext.linkcode', 94 | 'sphinx.ext.napoleon', 95 | 'sphinxcontrib.bibtex', 96 | 'sphinxcontrib.katex', 97 | 'sphinx_autodoc_typehints', 98 | 'sphinx_rtd_theme', 99 | 'coverage_check', 100 | 'myst_nb', # This is used for the .ipynb notebooks 101 | ] 102 | 103 | # Add any paths that contain templates here, relative to this directory. 104 | templates_path = ['_templates'] 105 | 106 | # List of patterns, relative to source directory, that match files and 107 | # directories to ignore when looking for source files. 108 | # This pattern also affects html_static_path and html_extra_path. 
109 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 110 | 111 | # -- Options for autodoc ----------------------------------------------------- 112 | 113 | autodoc_default_options = { 114 | 'member-order': 'bysource', 115 | 'special-members': True, 116 | 'exclude-members': '__repr__, __str__, __weakref__', 117 | } 118 | 119 | # -- Options for HTML output ------------------------------------------------- 120 | 121 | # The theme to use for HTML and HTML Help pages. See the documentation for 122 | # a list of builtin themes. 123 | # 124 | html_theme = 'sphinx_rtd_theme' 125 | 126 | # Add any paths that contain custom static files (such as style sheets) here, 127 | # relative to this directory. They are copied after the builtin static files, 128 | # so a file named "default.css" will overwrite the builtin "default.css". 129 | html_static_path = [] 130 | 131 | # -- Options for bibtex ------------------------------------------------------ 132 | 133 | bibtex_bibfiles = [] 134 | 135 | # -- Options for myst ------------------------------------------------------- 136 | 137 | jupyter_execute_notebooks = 'force' 138 | execution_allow_errors = False 139 | 140 | # -- Options for katex ------------------------------------------------------ 141 | 142 | # See: https://sphinxcontrib-katex.readthedocs.io/en/0.4.1/macros.html 143 | latex_macros = r""" 144 | \def \d #1{\operatorname{#1}} 145 | """ 146 | 147 | # Translate LaTeX macros to KaTeX and add to options for HTML builder 148 | katex_macros = katex.latex_defs_to_katex_macros(latex_macros) 149 | katex_options = 'macros: {' + katex_macros + '}' 150 | 151 | # Add LaTeX macros for LATEX builder 152 | latex_elements = {'preamble': latex_macros} 153 | 154 | # -- Source code links ------------------------------------------------------- 155 | 156 | 157 | def linkcode_resolve(domain, info): 158 | """Resolve a GitHub URL corresponding to Python object.""" 159 | if domain != 'py': 160 | return None 161 | 162 | try: 163 | mod = 
sys.modules[info['module']] 164 | except ImportError: 165 | return None 166 | 167 | obj = mod 168 | try: 169 | for attr in info['fullname'].split('.'): 170 | obj = getattr(obj, attr) 171 | except AttributeError: 172 | return None 173 | else: 174 | obj = inspect.unwrap(obj) 175 | 176 | try: 177 | filename = inspect.getsourcefile(obj) 178 | except TypeError: 179 | return None 180 | 181 | try: 182 | source, lineno = inspect.getsourcelines(obj) 183 | except OSError: 184 | return None 185 | 186 | return 'https://github.com/deepmind/dm_pix/tree/master/dm_pix/%s#L%d#L%d' % ( 187 | os.path.relpath(filename, start=os.path.dirname( 188 | pix.__file__)), lineno, lineno + len(source) - 1) 189 | 190 | 191 | # -- Intersphinx configuration ----------------------------------------------- 192 | 193 | intersphinx_mapping = { 194 | 'jax': ('https://jax.readthedocs.io/en/latest/', None), 195 | } 196 | 197 | source_suffix = ['.rst', '.md', '.ipynb'] 198 | -------------------------------------------------------------------------------- /docs/ext/coverage_check.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | """Asserts all public symbols are covered in the docs.""" 15 | 16 | from typing import Any, Mapping, Set 17 | 18 | import dm_pix as pix 19 | from dm_pix._src import test_utils 20 | from sphinx import application 21 | from sphinx import builders 22 | from sphinx import errors 23 | 24 | 25 | class PixCoverageCheck(builders.Builder): 26 | """Builder that checks all public symbols are included.""" 27 | 28 | name = "coverage_check" 29 | 30 | def get_outdated_docs(self) -> str: 31 | return "coverage_check" 32 | 33 | def write(self, *ignored: Any) -> None: 34 | pass 35 | 36 | def finish(self) -> None: 37 | 38 | def dm_pix_public_symbols() -> Set[str]: 39 | symbols = set() 40 | for symbol_name, _ in test_utils.get_public_symbols(pix): 41 | symbols.add("dm_pix." + symbol_name) 42 | return symbols 43 | 44 | documented_objects = frozenset(self.env.domaindata["py"]["objects"]) 45 | undocumented_objects = dm_pix_public_symbols() - documented_objects 46 | if undocumented_objects: 47 | undocumented_objects = tuple(sorted(undocumented_objects)) 48 | raise errors.SphinxError( 49 | "All public symbols must be included in our documentation, did you " 50 | "forget to add an entry to `api.rst`?\n" 51 | f"Undocumented symbols: {undocumented_objects}.") 52 | 53 | 54 | def setup(app: application.Sphinx) -> Mapping[str, Any]: 55 | app.add_builder(PixCoverageCheck) 56 | return dict(version=pix.__version__, parallel_read_safe=True) 57 | -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | :github_url: https://github.com/deepmind/dm_pix/tree/master/docs 2 | 3 | === 4 | PIX 5 | === 6 | 7 | PIX is an image processing library in JAX, for JAX. 8 | 9 | Overview 10 | ======== 11 | 12 | JAX is a library resulting from the union of Autograd and XLA for 13 | high-performance machine learning research. 
It provides NumPy, SciPy, 14 | automatic differentiation and first-class GPU/TPU support. 15 | 16 | PIX is a library built on top of JAX with the goal of providing image processing 17 | functions and tools to JAX in a way that they can be optimised and parallelised 18 | through `jax.jit`, `jax.vmap` and `jax.pmap`. 19 | 20 | Installation 21 | ============ 22 | 23 | PIX is written in pure Python, but depends on C++ code via JAX. 24 | 25 | Because JAX installation is different depending on your CUDA version, PIX does 26 | not list JAX as a dependency in `requirements.txt`, although it is technically 27 | listed for reference, but commented. 28 | 29 | First, follow JAX installation instructions to install JAX with the relevant 30 | accelerator support. 31 | 32 | Then, install PIX using ``pip``: 33 | 34 | .. code-block:: bash 35 | 36 | pip install dm-pix 37 | 38 | .. toctree:: 39 | :caption: API Documentation 40 | :maxdepth: 2 41 | 42 | api 43 | 44 | 45 | Contribute 46 | ========== 47 | 48 | - `Issue tracker `_ 49 | - `Source code `_ 50 | 51 | Support 52 | ======= 53 | 54 | If you are having issues, please let us know by filing an issue on our 55 | `issue tracker `_. 56 | 57 | License 58 | ======= 59 | 60 | PIX is licensed under the Apache 2.0 License. 61 | 62 | 63 | Indices and Tables 64 | ================== 65 | 66 | * :ref:`genindex` 67 | -------------------------------------------------------------------------------- /examples/README.md: -------------------------------------------------------------------------------- 1 | # PIX examples 2 | 3 | To run the examples you probably need to install extra requirements. 
To do so, 4 | use the following `pip` command: 5 | 6 | ```bash 7 | $ pip install -r requirements_examples.txt 8 | ``` 9 | 10 | Then, to run the examples, you can simply run the following commands: 11 | 12 | ```bash 13 | $ cd examples/ 14 | $ python image_augmentation.py 15 | ``` 16 | You can also view an interactive version of the examples on [Colab](./image_augmentation.ipynb). 17 | -------------------------------------------------------------------------------- /examples/assets/adjust_brightness_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/adjust_brightness_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/adjust_contrast_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/adjust_contrast_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/adjust_gamma_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/adjust_gamma_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/adjust_hue_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/adjust_hue_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/flip_left_right_jax_logo.jpg: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/flip_left_right_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/flip_up_down_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/flip_up_down_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/gaussian_blur_jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/gaussian_blur_jax_logo.jpg -------------------------------------------------------------------------------- /examples/assets/jax_logo.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/google-deepmind/dm_pix/ae478298f3dc6f611fb968fe2e831e8c9925b243/examples/assets/jax_logo.jpg -------------------------------------------------------------------------------- /examples/image_augmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Examples of image augmentations with Pix.""" 16 | 17 | from absl import app 18 | import dm_pix as pix 19 | import jax.numpy as jnp 20 | import numpy as np 21 | import PIL.Image as pil 22 | 23 | _KERNEL_SIGMA = 5 24 | _KERNEL_SIZE = 5 25 | _MAGIC_VALUE = 0.42 26 | 27 | 28 | def main(_) -> None: 29 | # Load an image. 30 | image = _get_image() 31 | 32 | # Flip up-down the image and visual it. 33 | flip_up_down_image = pix.flip_up_down(image=image) 34 | _imshow(flip_up_down_image) 35 | 36 | # Apply a Gaussian filter to the image and visual it. 37 | gaussian_blur_image = pix.gaussian_blur( 38 | image=image, 39 | sigma=_KERNEL_SIGMA, 40 | kernel_size=_KERNEL_SIZE, 41 | ) 42 | _imshow(gaussian_blur_image) 43 | 44 | # Change image brightness and visual it. 45 | adjust_brightness_image = pix.adjust_brightness( 46 | image=image, 47 | delta=_MAGIC_VALUE, 48 | ) 49 | _imshow(adjust_brightness_image) 50 | 51 | # Change image contrast and visual it. 52 | adjust_contrast_image = pix.adjust_contrast( 53 | image=image, 54 | factor=_MAGIC_VALUE, 55 | ) 56 | _imshow(adjust_contrast_image) 57 | 58 | # Change image gamma and visual it. 59 | adjust_gamma_image = pix.adjust_gamma( 60 | image=image, 61 | gamma=_MAGIC_VALUE, 62 | ) 63 | _imshow(adjust_gamma_image) 64 | 65 | # Change image hue and visual it. 66 | adjust_hue_image = pix.adjust_hue( 67 | image=image, 68 | delta=_MAGIC_VALUE, 69 | ) 70 | _imshow(adjust_hue_image) 71 | 72 | 73 | def _get_image(): 74 | return jnp.array(pil.open("./assets/jax_logo.jpg"), dtype=jnp.float32) / 255. 
75 | 76 | 77 | def _imshow(image: jnp.ndarray) -> None: 78 | """Showes the input image using PIL/Pillow backend.""" 79 | image = pil.fromarray(np.asarray(image * 255.).astype(np.uint8), "RGB") 80 | image.show() 81 | 82 | 83 | if __name__ == "__main__": 84 | app.run(main) 85 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["flit_core >=3.2,<4"] 3 | build-backend = "flit_core.buildapi" 4 | 5 | [project] 6 | name = "dm_pix" 7 | dynamic = ["version"] 8 | description = 'PIX is an image processing library in JAX, for JAX.' 9 | readme = "README.md" 10 | license = { file = "LICENSE" } 11 | requires-python = ">=3.8" 12 | authors = [ 13 | {name = "Google DeepMind", email = "pix-dev@google.com"}, 14 | ] 15 | classifiers = [ 16 | "Development Status :: 4 - Beta", 17 | "Intended Audience :: Developers", 18 | "Intended Audience :: Education", 19 | "Intended Audience :: Science/Research", 20 | "License :: OSI Approved :: Apache Software License", 21 | "Programming Language :: Python", 22 | "Programming Language :: Python :: 3", 23 | "Topic :: Scientific/Engineering :: Artificial Intelligence", 24 | "Topic :: Scientific/Engineering :: Image Processing", 25 | "Topic :: Scientific/Engineering :: Mathematics", 26 | "Topic :: Software Development :: Libraries", 27 | "Topic :: Software Development :: Libraries :: Python Modules", 28 | ] 29 | dependencies = [ 30 | "chex>=0.0.6", 31 | # jax>=0.2.17 32 | # jaxlib>=0.1.69 33 | ] 34 | 35 | [project.optional-dependencies] 36 | extras = [ 37 | "jax>=0.2.17", 38 | "jaxlib>=0.1.69", 39 | ] 40 | test = [ 41 | "scipy", 42 | "tensorflow", 43 | "pytest-xdist", 44 | ] 45 | docs = [ 46 | "sphinx>=6.0.0", 47 | "sphinx_rtd_theme", 48 | "sphinxcontrib-katex", 49 | "sphinxcontrib-bibtex", 50 | "sphinx-autodoc-typehints", 51 | "IPython", 52 | "ipykernel", 53 | "pandoc", 54 | "myst_nb", 55 | 
"docutils", 56 | "matplotlib", 57 | ] 58 | examples = [ 59 | "Pillow", 60 | ] 61 | 62 | [tool.setuptools.packages.find] 63 | include=["dm_pix/py.typed"] 64 | exclude = ["*_test.py", "examples"] 65 | -------------------------------------------------------------------------------- /readthedocs.yml: -------------------------------------------------------------------------------- 1 | # Read the Docs configuration file 2 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 3 | 4 | version: 2 5 | 6 | build: 7 | os: ubuntu-22.04 8 | tools: 9 | python: "3.11" 10 | 11 | sphinx: 12 | builder: html 13 | configuration: docs/conf.py 14 | fail_on_warning: false 15 | 16 | python: 17 | install: 18 | - method: pip 19 | path: . 20 | - method: pip 21 | path: . 22 | extra_requirements: 23 | - docs 24 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | set -e 17 | 18 | readonly VENV_DIR=/tmp/dm_pix_test_env 19 | echo "Creating virtual environment under ${VENV_DIR}." 20 | echo "You might want to remove this when you no longer need it." 21 | 22 | # Install deps in a virtual env. 
23 | python -m venv "${VENV_DIR}" 24 | source "${VENV_DIR}/bin/activate" 25 | python --version 26 | 27 | # Update pip + setuptools and install dependencies specified in pyproject.toml 28 | python -m pip install --upgrade pip setuptools 29 | python -m pip install . 30 | 31 | # print jax version 32 | python -c 'import jax; print(jax.__version__)' 33 | 34 | # Python test dependencies. 35 | python -m pip install -e ".[test]" 36 | 37 | # CPU count on macos or linux 38 | if [ "$(uname)" == "Darwin" ]; then 39 | N_JOBS=$(sysctl -n hw.logicalcpu) 40 | else 41 | N_JOBS=$(grep -c ^processor /proc/cpuinfo) 42 | fi 43 | 44 | # Run tests using pytest. 45 | python -m pytest -n "${N_JOBS}" dm_pix 46 | --------------------------------------------------------------------------------