├── .gitignore ├── .pylintrc ├── LICENSE ├── MANIFEST.in ├── README.md ├── clrs ├── __init__.py ├── _src │ ├── __init__.py │ ├── algorithms │ │ ├── __init__.py │ │ ├── divide_and_conquer.py │ │ ├── divide_and_conquer_test.py │ │ ├── dynamic_programming.py │ │ ├── dynamic_programming_test.py │ │ ├── geometry.py │ │ ├── geometry_test.py │ │ ├── graphs.py │ │ ├── graphs_test.py │ │ ├── greedy.py │ │ ├── greedy_test.py │ │ ├── searching.py │ │ ├── searching_test.py │ │ ├── sorting.py │ │ ├── sorting_test.py │ │ ├── strings.py │ │ └── strings_test.py │ ├── baselines.py │ ├── baselines_test.py │ ├── dataset.py │ ├── dataset_test.py │ ├── decoders.py │ ├── encoders.py │ ├── losses.py │ ├── losses_test.py │ ├── model.py │ ├── nets.py │ ├── probing.py │ ├── probing_test.py │ ├── processors.py │ ├── processors_test.py │ ├── samplers.py │ ├── samplers_test.py │ ├── scratch.py │ ├── specs.py │ └── third_party │ │ ├── __init__.py │ │ └── haiku_transformer.py ├── clrs_test.py ├── examples │ └── run.py ├── models.py └── py.typed ├── requirements └── requirements.txt ├── scripts ├── make_datasets.sh └── run_experiments.sh └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Building and releasing library: 2 | *.egg-info 3 | *.pyc 4 | *.so 5 | build/ 6 | dist/ 7 | venv/ 8 | 9 | # Mac OS 10 | .DS_Store 11 | 12 | # Python tools 13 | .mypy_cache/ 14 | .pytype/ 15 | .ipynb_checkpoints 16 | 17 | # Editors 18 | .idea 19 | .vscode 20 | 21 | wandb/ 22 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Add files or directories to the blacklist. They should be base names, not 11 | # paths. 12 | ignore=third_party 13 | 14 | # Add files or directories matching the regex patterns to the blacklist. The 15 | # regex matches against base names, not paths. 16 | ignore-patterns= 17 | 18 | # Pickle collected data for later comparisons. 19 | persistent=no 20 | 21 | # List of plugins (as comma separated values of python modules names) to load, 22 | # usually to register additional checkers. 23 | load-plugins= 24 | 25 | # Use multiple processes to speed up Pylint. 26 | jobs=4 27 | 28 | # Allow loading of arbitrary C extensions. Extensions are imported into the 29 | # active Python interpreter and may run arbitrary code. 30 | unsafe-load-any-extension=no 31 | 32 | # A comma-separated list of package or module names from where C extensions may 33 | # be loaded. Extensions are loading into the active Python interpreter and may 34 | # run arbitrary code 35 | extension-pkg-whitelist= 36 | 37 | 38 | [MESSAGES CONTROL] 39 | 40 | # Only show warnings with the listed confidence levels. Leave empty to show 41 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 42 | confidence= 43 | 44 | # Enable the message, report, category or checker with the given id(s). You can 45 | # either give multiple identifier separated by comma (,) or put this option 46 | # multiple time (only on the command line, not in the configuration file where 47 | # it should appear only once). See also the "--disable" option for examples. 
48 | #enable= 49 | 50 | # Disable the message, report, category or checker with the given id(s). You 51 | # can either give multiple identifiers separated by comma (,) or put this 52 | # option multiple times (only on the command line, not in the configuration 53 | # file where it should appear only once).You can also use "--disable=all" to 54 | # disable everything first and then reenable specific checks. For example, if 55 | # you want to run only the similarities checker, you can use "--disable=all 56 | # --enable=similarities". If you want to run only the classes checker, but have 57 | # no Warning level messages displayed, use"--disable=all --enable=classes 58 | # --disable=W" 59 | disable=apply-builtin, 60 | attribute-defined-outside-init, 61 | backtick, 62 | bad-option-value, 63 | buffer-builtin, 64 | c-extension-no-member, 65 | cmp-builtin, 66 | cmp-method, 67 | coerce-builtin, 68 | coerce-method, 69 | delslice-method, 70 | div-method, 71 | duplicate-code, 72 | eq-without-hash, 73 | execfile-builtin, 74 | file-builtin, 75 | filter-builtin-not-iterating, 76 | fixme, 77 | getslice-method, 78 | global-statement, 79 | hex-method, 80 | idiv-method, 81 | implicit-str-concat-in-sequence, 82 | import-error, 83 | import-self, 84 | import-star-module-level, 85 | input-builtin, 86 | intern-builtin, 87 | invalid-str-codec, 88 | invalid-unary-operand-type, 89 | locally-disabled, 90 | long-builtin, 91 | long-suffix, 92 | map-builtin-not-iterating, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-member, 102 | no-self-use, 103 | nonzero-method, 104 | oct-method, 105 | old-division, 106 | old-ne-operator, 107 | old-octal-literal, 108 | old-raise-syntax, 109 | parameter-unpacking, 110 | print-statement, 111 | raising-string, 112 | range-builtin-not-iterating, 113 | raw_input-builtin, 114 | rdiv-method, 115 | reduce-builtin, 116 | relative-import, 117 | reload-builtin, 118 | round-builtin, 119 | setslice-method, 120 | signature-differs, 121 | standarderror-builtin, 122 | suppressed-message, 123 | sys-max-int, 124 | too-few-public-methods, 125 | too-many-ancestors, 126 | too-many-arguments, 127 | too-many-boolean-expressions, 128 | too-many-branches, 129 | too-many-instance-attributes, 130 | too-many-locals, 131 | too-many-public-methods, 132 | too-many-return-statements, 133 | too-many-statements, 134 | trailing-newlines, 135 | unichr-builtin, 136 | unicode-builtin, 137 | unpacking-in-except, 138 | useless-else-on-loop, 139 | useless-suppression, 140 | using-cmp-argument, 141 | xrange-builtin, 142 | wrong-import-order, 143 | zip-builtin-not-iterating, 144 | 145 | 146 | [REPORTS] 147 | 148 | # Set the output format. Available formats are text, parseable, colorized, msvs 149 | # (visual studio) and html. You can also give a reporter class, eg 150 | # mypackage.mymodule.MyReporterClass. 151 | output-format=text 152 | 153 | # Put messages in a separate file for each module / package specified on the 154 | # command line instead of printing them on stdout. Reports (if any) will be 155 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 156 | # and it will be removed in Pylint 2.0. 157 | files-output=no 158 | 159 | # Tells whether to display a full report or only the messages 160 | reports=no 161 | 162 | # Python expression which should return a note less than 10 (10 is the highest 163 | # note). 
You have access to the variables errors warning, statement which 164 | # respectively contain the number of errors / warnings messages and the total 165 | # number of statements analyzed. This is used by the global evaluation report 166 | # (RP0004). 167 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 168 | 169 | # Template used to display messages. This is a python new-style format string 170 | # used to format the message information. See doc for all details 171 | #msg-template= 172 | 173 | 174 | [BASIC] 175 | 176 | # Good variable names which should always be accepted, separated by a comma 177 | good-names=main,_ 178 | 179 | # Bad variable names which should always be refused, separated by a comma 180 | bad-names= 181 | 182 | # Colon-delimited sets of names that determine each other's naming style when 183 | # the name regexes allow several styles. 184 | name-group= 185 | 186 | # Include a hint for the correct naming format with invalid-name 187 | include-naming-hint=no 188 | 189 | # List of decorators that produce properties, such as abc.abstractproperty. Add 190 | # to this list to register other decorators that produce valid properties. 191 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 192 | 193 | # Regular expression matching correct function names 194 | function-rgx=^(?:(?P<exempt>setUp|tearDown|setUpModule|tearDownModule)|(?P<camel_case>_?[A-Z][a-zA-Z0-9]*)|(?P<snake_case>_?[a-z][a-z0-9_]*))$ 195 | 196 | # Regular expression matching correct variable names 197 | variable-rgx=^[a-z][a-z0-9_]*$ 198 | 199 | # Regular expression matching correct constant names 200 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 201 | 202 | # Regular expression matching correct attribute names 203 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 204 | 205 | # Regular expression matching correct argument names 206 | argument-rgx=^[a-z][a-z0-9_]*$ 207 | 208 | # Regular expression matching correct class attribute names 209 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 210 | 211 | # Regular expression matching correct inline iteration names 212 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 213 | 214 | # Regular expression matching correct class names 215 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 216 | 217 | # Regular expression matching correct module names 218 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 219 | 220 | # Regular expression matching correct method names 221 | method-rgx=(?x)^(?:(?P<exempt>_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P<camel_case>_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P<snake_case>_{0,2}[a-z][a-z0-9_]*))$ 222 | 223 | # Regular expression which should only match function or class names that do 224 | # not require a docstring. 225 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 226 | 227 | # Minimum line length for functions/classes that require docstrings, shorter 228 | # ones are exempt. 229 | docstring-min-length=10 230 | 231 | 232 | [TYPECHECK] 233 | 234 | # List of decorators that produce context managers, such as 235 | # contextlib.contextmanager. Add to this list to register other decorators that 236 | # produce valid context managers. 237 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 238 | 239 | # Tells whether missing members accessed in mixin class should be ignored. 
A 240 | # mixin class is detected if its name ends with "mixin" (case insensitive). 241 | ignore-mixin-members=yes 242 | 243 | # List of module names for which member attributes should not be checked 244 | # (useful for modules/projects where namespaces are manipulated during runtime 245 | # and thus existing member attributes cannot be deduced by static analysis. It 246 | # supports qualified module names, as well as Unix pattern matching. 247 | ignored-modules= 248 | 249 | # List of class names for which member attributes should not be checked (useful 250 | # for classes with dynamically set attributes). This supports the use of 251 | # qualified names. 252 | ignored-classes=optparse.Values,thread._local,_thread._local 253 | 254 | # List of members which are set dynamically and missed by pylint inference 255 | # system, and so shouldn't trigger E1101 when accessed. Python regular 256 | # expressions are accepted. 257 | generated-members= 258 | 259 | # List of decorators that change the signature of a decorated function. 260 | signature-mutators=toolz.functoolz.curry 261 | 262 | [FORMAT] 263 | 264 | # Maximum number of characters on a single line. 265 | max-line-length=80 266 | 267 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 268 | # lines made too long by directives to pytype. 269 | 270 | # Regexp for a line that is allowed to be longer than the limit. 271 | ignore-long-lines=(?x)( 272 | ^\s*(\#\ )?<?https?://\S+>?$| 273 | ^\s*(from\s+\S+\s+)?import\s+.+$) 274 | 275 | # Allow the body of an if to be on the same line as the test if there is no 276 | # else. 277 | single-line-if-stmt=yes 278 | 279 | # List of optional constructs for which whitespace checking is disabled. `dict- 280 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 281 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 282 | # `empty-line` allows space-only lines. 283 | no-space-check= 284 | 285 | # Maximum number of lines in a module 286 | max-module-lines=99999 287 | 288 | # String used as indentation unit. The internal Google style guide mandates 2 289 | # spaces. Google's externally-published style guide says 4, consistent with 290 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 291 | # projects (like TensorFlow). 292 | indent-string=' ' 293 | 294 | # Number of spaces of indent required inside a hanging or continued line. 295 | indent-after-paren=4 296 | 297 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 298 | expected-line-ending-format= 299 | 300 | 301 | [MISCELLANEOUS] 302 | 303 | # List of note tags to take in consideration, separated by a comma. 304 | notes=TODO 305 | 306 | 307 | [VARIABLES] 308 | 309 | # Tells whether we should check for unused import in __init__ files. 310 | init-import=no 311 | 312 | # A regular expression matching the name of dummy variables (i.e. expectedly 313 | # not used). 314 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 315 | 316 | # List of additional names supposed to be defined in builtins. Remember that 317 | # you should avoid to define new builtins when possible. 318 | additional-builtins= 319 | 320 | # List of strings which can identify a callback function by name. A callback 321 | # name must start or end with one of those strings. 322 | callbacks=cb_,_cb 323 | 324 | # List of qualified module names which can have objects that can redefine 325 | # builtins. 
326 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 327 | 328 | 329 | [LOGGING] 330 | 331 | # Logging modules to check that the string format arguments are in logging 332 | # function parameter format 333 | logging-modules=logging,absl.logging,tensorflow.google.logging 334 | 335 | 336 | [SIMILARITIES] 337 | 338 | # Minimum lines number of a similarity. 339 | min-similarity-lines=4 340 | 341 | # Ignore comments when computing similarities. 342 | ignore-comments=yes 343 | 344 | # Ignore docstrings when computing similarities. 345 | ignore-docstrings=yes 346 | 347 | # Ignore imports when computing similarities. 348 | ignore-imports=no 349 | 350 | 351 | [SPELLING] 352 | 353 | # Spelling dictionary name. Available dictionaries: none. To make it working 354 | # install python-enchant package. 355 | spelling-dict= 356 | 357 | # List of comma separated words that should not be checked. 358 | spelling-ignore-words= 359 | 360 | # A path to a file that contains private dictionary; one word per line. 361 | spelling-private-dict-file= 362 | 363 | # Tells whether to store unknown words to indicated private dictionary in 364 | # --spelling-private-dict-file option instead of raising a message. 365 | spelling-store-unknown-words=no 366 | 367 | 368 | [IMPORTS] 369 | 370 | # Deprecated modules which should not be used, separated by a comma 371 | deprecated-modules=regsub, 372 | TERMIOS, 373 | Bastion, 374 | rexec, 375 | sets 376 | 377 | # Create a graph of every (i.e. internal and external) dependencies in the 378 | # given file (report RP0402 must not be disabled) 379 | import-graph= 380 | 381 | # Create a graph of external dependencies in the given file (report RP0402 must 382 | # not be disabled) 383 | ext-import-graph= 384 | 385 | # Create a graph of internal dependencies in the given file (report RP0402 must 386 | # not be disabled) 387 | int-import-graph= 388 | 389 | # Force import order to recognize a module as part of the standard 390 | # compatibility libraries. 391 | known-standard-library= 392 | 393 | # Force import order to recognize a module as part of a third party library. 394 | known-third-party=enchant, absl 395 | 396 | # Analyse import fallback blocks. This can be used to support both Python 2 and 397 | # 3 compatible code, which means that the block might have code that exists 398 | # only in one or another interpreter, leading to false positives when analysed. 399 | analyse-fallback-blocks=no 400 | 401 | 402 | [CLASSES] 403 | 404 | # List of method names used to declare (i.e. assign) instance attributes. 405 | defining-attr-methods=__init__, 406 | __new__, 407 | setUp 408 | 409 | # List of member names, which should be excluded from the protected access 410 | # warning. 411 | exclude-protected=_asdict, 412 | _fields, 413 | _replace, 414 | _source, 415 | _make 416 | 417 | # List of valid names for the first argument in a class method. 418 | valid-classmethod-first-arg=cls, 419 | class_ 420 | 421 | # List of valid names for the first argument in a metaclass class method. 422 | valid-metaclass-classmethod-first-arg=mcs 423 | 424 | 425 | [EXCEPTIONS] 426 | 427 | # Exceptions that will emit a warning when being caught. 
Defaults to 428 | # "Exception" 429 | overgeneral-exceptions=StandardError, 430 | Exception, 431 | BaseException 432 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 
62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 
180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include README.md 2 | include LICENSE 3 | include requirements/* 4 | include clrs/py.typed 5 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Towards Better Out-of-Distribution Generalization of Neural Algorithmic Reasoning Tasks 2 | This is the code for "Towards Better Out-of-Distribution Generalization of Neural Algorithmic Reasoning Tasks". 3 | 4 | ## Dependencies 5 | 6 | Please refer to the CLRS benchmark's [documentation](https://github.com/deepmind/clrs/tree/b3bd8d1e912b5333964e0de49842a2019739fc53) for instructions on how to install this library and its dependencies. Note that a separate virtualenv is needed for installing the package from this repository if you have already installed CLRS. 7 | 8 | ## Generating Datasets 9 | 10 | To generate the datasets reported in our paper, run the following command: 11 | 12 | ```console 13 | CLRS_VENV_PATH=/path/to/venv \ 14 | CLRS_ROOT=/path/to/clrs \ 15 | CLRS_DATASET_PATH=/path/to/clrs/dataset \ 16 | ./scripts/make_datasets.sh 17 | ``` 18 | 19 | where ```CLRS_VENV_PATH``` is the path to the virtual environment containing CLRS, ```CLRS_ROOT``` is the path to the CLRS code root, and ```CLRS_DATASET_PATH``` is the desired directory for creating the CLRS dataset. 20 | 21 | 22 | ## Reproducing the Results 23 | 24 | First, generate the datasets using the command above. Then, run the following script: 25 | 26 | ```console 27 | CLRS_DATASET_PATH=/path/to/saved/datasets \ 28 | CLRS_LOG_PATH=/tmp/clrs_logs \ 29 | CLRS_CHECKPOINT_PATH=/tmp/clrs_checkpoints \ 30 | ./scripts/run_experiments.sh 31 | ``` 32 | 33 | where ```CLRS_DATASET_PATH``` is the path to the generated datasets (created by the command above), ```CLRS_LOG_PATH``` is the path for saving run logs, and ```CLRS_CHECKPOINT_PATH``` is the path for saving checkpoints. 34 | 35 | The raw logs will be saved in the specified logging directory. In addition, WandB summaries will also be saved locally. 
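36 | 37 | As a quick sanity check before launching the full experiment suite, you can also interact with the CLRS sampler API programmatically. The snippet below is a minimal sketch based on the upstream CLRS interface (argument names may differ slightly in this fork): 38 | 39 | ```python 40 | import clrs 41 | 42 | # Build a BFS sampler that generates 100 random graphs of 16 nodes each. 43 | train_sampler, spec = clrs.build_sampler(name='bfs', num_samples=100, length=16) 44 | 45 | # Draw one batch; `feedback.features` holds the inputs (and hints, if enabled), 46 | # while `feedback.outputs` holds the ground-truth outputs. 47 | feedback = train_sampler.next(batch_size=32) 48 | ``` 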
49 | 50 | ## Colab 51 | 52 | We also provide a simplified notebook in Google Colab, which contains data generation, training, and visualization for the BFS algorithm. 53 | 54 | Open In Colab 55 | 56 | 57 | ## Notes 58 | 59 | * Datasets are generated without any hints. This saves 10X space and allows a larger number of samples to be generated. If you would like the hints to be included, disable the ```CLRS_DISABLE_HINTS``` flag inside the dataset generation script. However, there is currently no mechanism for efficient dataset generation with hints, and it requires a large amount of RAM. 60 | * Although the ```seed``` variable controls the seed for neural network initialization, the order in which batches are created is not controlled. As a result, exact reproduction of the numbers is not guaranteed in the CLRS benchmark. 61 | * The experiments script runs all the experiments sequentially. You may want to parallelize them depending on your cluster system. 62 | * The code base also contains the ideas mentioned in the appendix of the paper, but they are deactivated by default. 63 | * By default, the logs are written to their corresponding log directory and to WandB locally. When debugging, you can activate the ```debug``` flag to print the logs to stdout and disable WandB logging. 64 | * The parameter values in the Google Colab notebook are changed and simplified to facilitate readability and quick training. For the complete set of parameters used in our paper, please refer to our GitHub code base. 65 | 66 | ## Cite 67 | Please consider citing our paper if you use this code in your research work: 68 | 69 | ``` 70 | @article{ 71 | mahdavi2023towards, 72 | title={Towards Better Out-of-Distribution Generalization of Neural Algorithmic Reasoning Tasks}, 73 | author={Sadegh Mahdavi and Kevin Swersky and Thomas Kipf and Milad Hashemi and Christos Thrampoulidis and Renjie Liao}, 74 | journal={Transactions on Machine Learning Research}, 75 | issn={2835-8856}, 76 | year={2023}, 77 | url={https://openreview.net/forum?id=xkrtvHlp3P}, 78 | note={} 79 | } 80 | ``` 81 | 82 | ## Questions/Bugs 83 | Please submit a GitHub issue or contact smahdavi@ece.ubc.ca if you have any questions or find any bugs. 84 | -------------------------------------------------------------------------------- /clrs/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """The CLRS Algorithmic Reasoning Benchmark.""" 17 | 18 | from clrs import models 19 | from clrs._src import scratch 20 | from clrs._src import algorithms 21 | from clrs._src import decoders 22 | from clrs._src import processors 23 | from clrs._src.dataset import CLRSDataset 24 | from clrs._src.dataset import create_chunked_dataset 25 | from clrs._src.dataset import create_dataset 26 | from clrs._src.dataset import get_clrs_folder 27 | from clrs._src.dataset import get_dataset_gcp_url 28 | from clrs._src.model import evaluate 29 | from clrs._src.model import evaluate_hints 30 | from clrs._src.model import Model 31 | from clrs._src.probing import DataPoint 32 | from clrs._src.processors import get_processor_factory 33 | from clrs._src.samplers import build_sampler 34 | from clrs._src.samplers import CLRS30 35 | from clrs._src.samplers import Features 36 | from clrs._src.samplers import Feedback 37 | from clrs._src.samplers import Sampler 38 | from clrs._src.samplers import Trajectory 39 | from clrs._src.specs import CLRS_30_ALGS 40 | from clrs._src.specs import Location 41 | from clrs._src.specs import OutputClass 42 | from clrs._src.specs import Spec 43 | from clrs._src.specs import SPECS 44 | from clrs._src.specs import Stage 45 | from clrs._src.specs import Type 46 | 47 | __version__ = "1.0.0" 48 | 49 | __all__ = ( 50 | "build_sampler", 51 | "CLRS30", 52 | "create_chunked_dataset", 53 | "create_dataset", 54 | "get_clrs_folder", 55 | "get_dataset_gcp_url", 56 | "get_processor_factory", 57 | "DataPoint", 58 | "evaluate", 59 | "evaluate_hints", 60 | "Features", 61 | "Feedback", 62 | "Location", 63 | "Model", 64 | "Sampler", 65 | "Spec", 66 | "SPECS", 67 | "Stage", 68 | "Trajectory", 69 | "Type", 70 | "scratch", 71 | ) 72 | -------------------------------------------------------------------------------- /clrs/_src/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """CLRS algorithm implementations.""" 17 | 18 | # pylint:disable=g-bad-import-order 19 | 20 | from clrs._src.algorithms.divide_and_conquer import find_maximum_subarray 21 | from clrs._src.algorithms.divide_and_conquer import find_maximum_subarray_kadane 22 | 23 | from clrs._src.algorithms.dynamic_programming import matrix_chain_order 24 | from clrs._src.algorithms.dynamic_programming import lcs_length 25 | from clrs._src.algorithms.dynamic_programming import optimal_bst 26 | 27 | from clrs._src.algorithms.geometry import segments_intersect 28 | from clrs._src.algorithms.geometry import graham_scan 29 | from clrs._src.algorithms.geometry import jarvis_march 30 | 31 | from clrs._src.algorithms.graphs import dfs 32 | from clrs._src.algorithms.graphs import bfs 33 | from clrs._src.algorithms.graphs import topological_sort 34 | from clrs._src.algorithms.graphs import articulation_points 35 | from clrs._src.algorithms.graphs import bridges 36 | from clrs._src.algorithms.graphs import strongly_connected_components 37 | from clrs._src.algorithms.graphs import mst_kruskal 38 | from clrs._src.algorithms.graphs import mst_prim 39 | from clrs._src.algorithms.graphs import bellman_ford 40 | from clrs._src.algorithms.graphs import dijkstra 41 | from clrs._src.algorithms.graphs import dag_shortest_paths 42 | from clrs._src.algorithms.graphs import floyd_warshall 43 | from clrs._src.algorithms.graphs import bipartite_matching 44 | 45 | from clrs._src.algorithms.greedy import activity_selector 46 | from clrs._src.algorithms.greedy import task_scheduling 47 | 48 | from clrs._src.algorithms.searching import minimum 49 | from clrs._src.algorithms.searching import binary_search 50 | from clrs._src.algorithms.searching import quickselect 51 | 52 | from clrs._src.algorithms.sorting import insertion_sort 53 | from clrs._src.algorithms.sorting import bubble_sort 54 | from clrs._src.algorithms.sorting import heapsort 55 | from clrs._src.algorithms.sorting import quicksort 56 | 57 | from clrs._src.algorithms.strings import naive_string_matcher 58 | from clrs._src.algorithms.strings import kmp_matcher 59 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/divide_and_conquer_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for `divide_and_conquer.py`.""" 17 | # pylint: disable=invalid-name 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | 22 | from clrs._src.algorithms import divide_and_conquer 23 | import numpy as np 24 | 25 | 26 | class DivideAndConquerTest(parameterized.TestCase): 27 | 28 | @parameterized.named_parameters( 29 | ("Maximum subarray", divide_and_conquer.find_maximum_subarray), 30 | ("Kadane's variant", divide_and_conquer.find_maximum_subarray_kadane), 31 | ) 32 | def test_find_maximum_subarray_pos(self, algorithm): 33 | A = np.random.randint(0, 100, size=(13,)) 34 | (low, high, sum_), _ = algorithm(A) 35 | self.assertEqual(low, 0) 36 | self.assertEqual(high, len(A) - 1) 37 | self.assertEqual(sum_, np.sum(A)) 38 | 39 | @parameterized.named_parameters( 40 | ("Maximum subarray", divide_and_conquer.find_maximum_subarray), 41 | ("Kadane's variant", divide_and_conquer.find_maximum_subarray_kadane), 42 | ) 43 | def test_find_maximum_subarray(self, algorithm): 44 | A = np.random.randint(-100, 100, size=(13,)) 45 | (low, high, sum_), _ = algorithm(A.copy()) 46 | 47 | # Brute force solution. 48 | best = (0, len(A) - 1) 49 | best_sum = np.sum(A) 50 | for start in range(len(A)): 51 | for stop in range(start, len(A)): 52 | range_sum = np.sum(A[start:stop + 1]) 53 | if range_sum > best_sum: 54 | best = (start, stop) 55 | best_sum = range_sum 56 | 57 | self.assertEqual(low, best[0]) 58 | self.assertEqual(high, best[1]) 59 | self.assertEqual(sum_, best_sum) 60 | 61 | 62 | if __name__ == "__main__": 63 | absltest.main() 64 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/dynamic_programming.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Dynamic programming algorithm generators. 17 | 18 | Currently implements the following: 19 | - Matrix-chain multiplication 20 | - Longest common subsequence 21 | - Optimal binary search tree (Aho et al., 1974) 22 | 23 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 
24 | 25 | """ 26 | # pylint: disable=invalid-name 27 | 28 | 29 | from typing import Tuple 30 | 31 | import chex 32 | from clrs._src import probing 33 | from clrs._src import specs 34 | import numpy as np 35 | 36 | 37 | _Array = np.ndarray 38 | _Out = Tuple[_Array, probing.ProbesDict] 39 | 40 | 41 | def matrix_chain_order(p: _Array) -> _Out: 42 | """Matrix-chain multiplication.""" 43 | 44 | chex.assert_rank(p, 1) 45 | probes = probing.initialize(specs.SPECS['matrix_chain_order']) 46 | 47 | A_pos = np.arange(p.shape[0]) 48 | 49 | probing.push( 50 | probes, 51 | specs.Stage.INPUT, 52 | next_probe={ 53 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 54 | 'p': np.copy(p) 55 | }) 56 | 57 | m = np.zeros((p.shape[0], p.shape[0])) 58 | s = np.zeros((p.shape[0], p.shape[0])) 59 | msk = np.zeros((p.shape[0], p.shape[0])) 60 | for i in range(1, p.shape[0]): 61 | m[i, i] = 0 62 | msk[i, i] = 1 63 | while True: 64 | prev_m = np.copy(m) 65 | prev_msk = np.copy(msk) 66 | probing.push( 67 | probes, 68 | specs.Stage.HINT, 69 | next_probe={ 70 | 'pred_h': probing.array(np.copy(A_pos)), 71 | 'm': np.copy(prev_m), 72 | 's_h': np.copy(s), 73 | 'msk': np.copy(msk) 74 | }) 75 | for i in range(1, p.shape[0]): 76 | for j in range(i + 1, p.shape[0]): 77 | flag = prev_msk[i, j] 78 | for k in range(i, j): 79 | if prev_msk[i, k] == 1 and prev_msk[k + 1, j] == 1: 80 | msk[i, j] = 1 81 | q = prev_m[i, k] + prev_m[k + 1, j] + p[i - 1] * p[k] * p[j] 82 | if flag == 0 or q < m[i, j]: 83 | m[i, j] = q 84 | s[i, j] = k 85 | flag = 1 86 | if np.all(prev_m == m): 87 | break 88 | 89 | probing.push(probes, specs.Stage.OUTPUT, next_probe={'s': np.copy(s)}) 90 | probing.finalize(probes) 91 | 92 | return s[1:, 1:], probes 93 | 94 | 95 | def lcs_length(x: _Array, y: _Array) -> _Out: 96 | """Longest common subsequence.""" 97 | chex.assert_rank([x, y], 1) 98 | probes = probing.initialize(specs.SPECS['lcs_length']) 99 | 100 | x_pos = np.arange(x.shape[0]) 101 | y_pos = np.arange(y.shape[0]) 102 | b = np.zeros((x.shape[0], y.shape[0])) 103 | c = np.zeros((x.shape[0], y.shape[0])) 104 | 105 | probing.push( 106 | probes, 107 | specs.Stage.INPUT, 108 | next_probe={ 109 | 'string': probing.strings_id(x_pos, y_pos), 110 | 'pos': probing.strings_pos(x_pos, y_pos), 111 | 'key': probing.array_cat(np.concatenate([np.copy(x), np.copy(y)]), 4) 112 | }) 113 | 114 | for i in range(x.shape[0]): 115 | if x[i] == y[0]: 116 | c[i, 0] = 1 117 | b[i, 0] = 0 118 | elif i > 0 and c[i - 1, 0] == 1: 119 | c[i, 0] = 1 120 | b[i, 0] = 1 121 | else: 122 | c[i, 0] = 0 123 | b[i, 0] = 1 124 | for j in range(y.shape[0]): 125 | if x[0] == y[j]: 126 | c[0, j] = 1 127 | b[0, j] = 0 128 | elif j > 0 and c[0, j - 1] == 1: 129 | c[0, j] = 1 130 | b[0, j] = 2 131 | else: 132 | c[0, j] = 0 133 | b[0, j] = 1 134 | 135 | while True: 136 | prev_c = np.copy(c) 137 | 138 | probing.push( 139 | probes, 140 | specs.Stage.HINT, 141 | next_probe={ 142 | 'pred_h': probing.strings_pred(x_pos, y_pos), 143 | 'b_h': probing.strings_pair_cat(np.copy(b), 3), 144 | 'c': probing.strings_pair(prev_c) 145 | }) 146 | 147 | for i in range(1, x.shape[0]): 148 | for j in range(1, y.shape[0]): 149 | if x[i] == y[j]: 150 | c[i, j] = prev_c[i - 1, j - 1] + 1 151 | b[i, j] = 0 152 | elif prev_c[i - 1, j] >= prev_c[i, j - 1]: 153 | c[i, j] = prev_c[i - 1, j] 154 | b[i, j] = 1 155 | else: 156 | c[i, j] = prev_c[i, j - 1] 157 | b[i, j] = 2 158 | if np.all(prev_c == c): 159 | break 160 | 161 | probing.push( 162 | probes, 163 | specs.Stage.OUTPUT, 164 | next_probe={'b': probing.strings_pair_cat(np.copy(b), 
3)}) 165 | probing.finalize(probes) 166 | 167 | return b, probes 168 | 169 | 170 | def optimal_bst(p: _Array, q: _Array) -> _Out: 171 | """Optimal binary search tree (Aho et al., 1974).""" 172 | 173 | chex.assert_rank([p, q], 1) 174 | probes = probing.initialize(specs.SPECS['optimal_bst']) 175 | 176 | A_pos = np.arange(q.shape[0]) 177 | p_cpy = np.zeros(q.shape[0]) 178 | p_cpy[:-1] = np.copy(p) 179 | 180 | probing.push( 181 | probes, 182 | specs.Stage.INPUT, 183 | next_probe={ 184 | 'pos': np.copy(A_pos) * 1.0 / q.shape[0], 185 | 'p': np.copy(p_cpy), 186 | 'q': np.copy(q) 187 | }) 188 | 189 | e = np.zeros((q.shape[0], q.shape[0])) 190 | w = np.zeros((q.shape[0], q.shape[0])) 191 | root = np.zeros((q.shape[0], q.shape[0])) 192 | msks = np.zeros((q.shape[0], q.shape[0])) 193 | 194 | for i in range(q.shape[0]): 195 | e[i, i] = q[i] 196 | w[i, i] = q[i] 197 | msks[i, i] = 1 198 | 199 | probing.push( 200 | probes, 201 | specs.Stage.HINT, 202 | next_probe={ 203 | 'pred_h': probing.array(np.copy(A_pos)), 204 | 'root_h': np.copy(root), 205 | 'e': np.copy(e), 206 | 'w': np.copy(w), 207 | 'msk': np.copy(msks) 208 | }) 209 | 210 | for l in range(1, p.shape[0] + 1): 211 | for i in range(p.shape[0] - l + 1): 212 | j = i + l 213 | e[i, j] = 1e9 214 | w[i, j] = w[i, j - 1] + p[j - 1] + q[j] 215 | for r in range(i, j): 216 | t = e[i, r] + e[r + 1, j] + w[i, j] 217 | if t < e[i, j]: 218 | e[i, j] = t 219 | root[i, j] = r 220 | msks[i, j] = 1 221 | probing.push( 222 | probes, 223 | specs.Stage.HINT, 224 | next_probe={ 225 | 'pred_h': probing.array(np.copy(A_pos)), 226 | 'root_h': np.copy(root), 227 | 'e': np.copy(e), 228 | 'w': np.copy(w), 229 | 'msk': np.copy(msks) 230 | }) 231 | 232 | probing.push(probes, specs.Stage.OUTPUT, next_probe={'root': np.copy(root)}) 233 | probing.finalize(probes) 234 | 235 | return root, probes 236 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/dynamic_programming_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for `dynamic_programming.py`.""" 17 | 18 | from absl.testing import absltest 19 | 20 | from clrs._src.algorithms import dynamic_programming 21 | import numpy as np 22 | 23 | 24 | class DynamicProgrammingTest(absltest.TestCase): 25 | 26 | def test_matrix_chain_order_1(self): 27 | 28 | expected = np.array([ 29 | [0, 1, 1, 3, 3, 3], 30 | [0, 0, 2, 3, 3, 3], 31 | [0, 0, 0, 3, 3, 3], 32 | [0, 0, 0, 0, 4, 5], 33 | [0, 0, 0, 0, 0, 5], 34 | [0, 0, 0, 0, 0, 0], 35 | ]) 36 | for shift in [0, 1, 2]: 37 | for scale in [1, 3, 17]: 38 | ps = shift + scale * np.array([30, 35, 15, 5, 10, 20, 25]) 39 | order, _ = dynamic_programming.matrix_chain_order(ps) 40 | np.testing.assert_array_equal(expected, order) 41 | 42 | def test_matrix_chain_order_2(self): 43 | 44 | expected = np.array([ 45 | [0, 1, 2, 2, 4, 2], 46 | [0, 0, 2, 2, 2, 2], 47 | [0, 0, 0, 3, 4, 4], 48 | [0, 0, 0, 0, 4, 4], 49 | [0, 0, 0, 0, 0, 5], 50 | [0, 0, 0, 0, 0, 0], 51 | ]) 52 | 53 | for shift in [0, 1]: 54 | for scale in [1, 3, 17]: 55 | ps = shift + scale * np.array([5, 10, 3, 12, 5, 50, 6]) 56 | order, _ = dynamic_programming.matrix_chain_order(ps) 57 | np.testing.assert_array_equal(expected, order) 58 | 59 | def test_lcs_length(self): 60 | xs = np.array([0, 1, 2, 1, 3, 0, 1]) 61 | ys = np.array([1, 3, 2, 0, 1, 0]) 62 | 63 | expected = np.array([ 64 | [1, 1, 1, 0, 2, 0], 65 | [0, 2, 2, 1, 0, 2], 66 | [1, 1, 0, 2, 1, 1], 67 | [0, 1, 1, 1, 0, 2], 68 | [1, 0, 1, 1, 1, 1], 69 | [1, 1, 1, 0, 1, 0], 70 | [0, 1, 1, 1, 0, 1], 71 | ]) 72 | out, _ = dynamic_programming.lcs_length(xs, ys) 73 | np.testing.assert_array_equal(expected, out) 74 | 75 | def test_optimal_bst(self): 76 | p = np.array([0.15, 0.10, 0.05, 0.10, 0.2]) 77 | q = np.array([0.05, 0.10, 0.05, 0.05, 0.05, 0.10]) 78 | assert p.sum() + q.sum() == 1. 79 | 80 | expected = np.array([ 81 | [0, 0, 0, 1, 1, 1], 82 | [0, 0, 1, 1, 1, 3], 83 | [0, 0, 0, 2, 3, 4], 84 | [0, 0, 0, 0, 3, 4], 85 | [0, 0, 0, 0, 0, 4], 86 | [0, 0, 0, 0, 0, 0], 87 | ]) 88 | 89 | out, _ = dynamic_programming.optimal_bst(p, q) 90 | np.testing.assert_array_equal(expected, out) 91 | 92 | 93 | if __name__ == "__main__": 94 | absltest.main() 95 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/geometry.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Geometry algorithm generators. 17 | 18 | Currently implements the following: 19 | - Segment intersection 20 | - Graham scan convex hull (Graham, 1972) 21 | - Jarvis' march convex hull (Jarvis, 1973) 22 | 23 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 
24 | 25 | """ 26 | # pylint: disable=invalid-name 27 | 28 | import math 29 | from typing import Any, Tuple 30 | 31 | import chex 32 | from clrs._src import probing 33 | from clrs._src import specs 34 | import numpy as np 35 | 36 | 37 | _Array = np.ndarray 38 | _Out = Tuple[Any, probing.ProbesDict] 39 | 40 | 41 | def segments_intersect(xs: _Array, ys: _Array) -> _Out: 42 | """Segment intersection.""" 43 | 44 | assert xs.shape == (4,) 45 | assert ys.shape == (4,) 46 | probes = probing.initialize(specs.SPECS['segments_intersect']) 47 | 48 | A_pos = np.arange(xs.shape[0]) 49 | dirs = np.zeros(xs.shape[0]) 50 | on_seg = np.zeros(xs.shape[0]) 51 | 52 | probing.push( 53 | probes, 54 | specs.Stage.INPUT, 55 | next_probe={ 56 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 57 | 'x': np.copy(xs), 58 | 'y': np.copy(ys) 59 | }) 60 | 61 | probing.push( 62 | probes, 63 | specs.Stage.HINT, 64 | next_probe={ 65 | 'i': probing.mask_one(0, xs.shape[0]), 66 | 'j': probing.mask_one(0, xs.shape[0]), 67 | 'k': probing.mask_one(0, xs.shape[0]), 68 | 'dir': np.copy(dirs), 69 | 'on_seg': np.copy(on_seg) 70 | }) 71 | 72 | def cross_product(x1, y1, x2, y2): 73 | return x1 * y2 - x2 * y1 74 | 75 | def direction(xs, ys, i, j, k): 76 | return cross_product(xs[k] - xs[i], ys[k] - ys[i], xs[j] - xs[i], 77 | ys[j] - ys[i]) 78 | 79 | def on_segment(xs, ys, i, j, k): 80 | if min(xs[i], xs[j]) <= xs[k] and xs[k] <= max(xs[i], xs[j]): 81 | if min(ys[i], ys[j]) <= ys[k] and ys[k] <= max(ys[i], ys[j]): 82 | return 1 83 | return 0 84 | 85 | dirs[0] = direction(xs, ys, 2, 3, 0) 86 | on_seg[0] = on_segment(xs, ys, 2, 3, 0) 87 | 88 | probing.push( 89 | probes, 90 | specs.Stage.HINT, 91 | next_probe={ 92 | 'i': probing.mask_one(2, xs.shape[0]), 93 | 'j': probing.mask_one(3, xs.shape[0]), 94 | 'k': probing.mask_one(0, xs.shape[0]), 95 | 'dir': np.copy(dirs), 96 | 'on_seg': np.copy(on_seg) 97 | }) 98 | 99 | dirs[1] = direction(xs, ys, 2, 3, 1) 100 | on_seg[1] = on_segment(xs, ys, 2, 3, 1) 101 | 102 | probing.push( 103 | probes, 104 | specs.Stage.HINT, 105 | next_probe={ 106 | 'i': probing.mask_one(2, xs.shape[0]), 107 | 'j': probing.mask_one(3, xs.shape[0]), 108 | 'k': probing.mask_one(1, xs.shape[0]), 109 | 'dir': np.copy(dirs), 110 | 'on_seg': np.copy(on_seg) 111 | }) 112 | 113 | dirs[2] = direction(xs, ys, 0, 1, 2) 114 | on_seg[2] = on_segment(xs, ys, 0, 1, 2) 115 | 116 | probing.push( 117 | probes, 118 | specs.Stage.HINT, 119 | next_probe={ 120 | 'i': probing.mask_one(0, xs.shape[0]), 121 | 'j': probing.mask_one(1, xs.shape[0]), 122 | 'k': probing.mask_one(2, xs.shape[0]), 123 | 'dir': np.copy(dirs), 124 | 'on_seg': np.copy(on_seg) 125 | }) 126 | 127 | dirs[3] = direction(xs, ys, 0, 1, 3) 128 | on_seg[3] = on_segment(xs, ys, 0, 1, 3) 129 | 130 | probing.push( 131 | probes, 132 | specs.Stage.HINT, 133 | next_probe={ 134 | 'i': probing.mask_one(0, xs.shape[0]), 135 | 'j': probing.mask_one(1, xs.shape[0]), 136 | 'k': probing.mask_one(3, xs.shape[0]), 137 | 'dir': np.copy(dirs), 138 | 'on_seg': np.copy(on_seg) 139 | }) 140 | 141 | ret = 0 142 | 143 | if ((dirs[0] > 0 and dirs[1] < 0) or 144 | (dirs[0] < 0 and dirs[1] > 0)) and ((dirs[2] > 0 and dirs[3] < 0) or 145 | (dirs[2] < 0 and dirs[3] > 0)): 146 | ret = 1 147 | elif dirs[0] == 0 and on_seg[0]: 148 | ret = 1 149 | elif dirs[1] == 0 and on_seg[1]: 150 | ret = 1 151 | elif dirs[2] == 0 and on_seg[2]: 152 | ret = 1 153 | elif dirs[3] == 0 and on_seg[3]: 154 | ret = 1 155 | 156 | probing.push(probes, specs.Stage.OUTPUT, next_probe={'intersect': ret}) 157 | probing.finalize(probes) 
158 | 159 | return ret, probes 160 | 161 | 162 | def graham_scan(xs: _Array, ys: _Array) -> _Out: 163 | """Graham scan convex hull (Graham, 1972).""" 164 | 165 | chex.assert_rank([xs, ys], 1) 166 | probes = probing.initialize(specs.SPECS['graham_scan']) 167 | 168 | A_pos = np.arange(xs.shape[0]) 169 | in_hull = np.zeros(xs.shape[0]) 170 | stack_prev = np.arange(xs.shape[0]) 171 | atans = np.zeros(xs.shape[0]) 172 | 173 | probing.push( 174 | probes, 175 | specs.Stage.INPUT, 176 | next_probe={ 177 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 178 | 'x': np.copy(xs), 179 | 'y': np.copy(ys) 180 | }) 181 | 182 | probing.push( 183 | probes, 184 | specs.Stage.HINT, 185 | next_probe={ 186 | 'best': probing.mask_one(0, xs.shape[0]), 187 | 'atans': np.copy(atans), 188 | 'in_hull_h': np.copy(in_hull), 189 | 'stack_prev': np.copy(stack_prev), 190 | 'last_stack': probing.mask_one(0, xs.shape[0]), 191 | 'i': probing.mask_one(0, xs.shape[0]), 192 | 'phase': probing.mask_one(0, 5) 193 | }) 194 | 195 | def counter_clockwise(xs, ys, i, j, k): 196 | return ((xs[j] - xs[i]) * (ys[k] - ys[i]) - (ys[j] - ys[i]) * 197 | (xs[k] - xs[i])) <= 0 198 | 199 | best = 0 200 | for i in range(xs.shape[0]): 201 | if ys[i] < ys[best] or (ys[i] == ys[best] and xs[i] < xs[best]): 202 | best = i 203 | 204 | in_hull[best] = 1 205 | last_stack = best 206 | 207 | probing.push( 208 | probes, 209 | specs.Stage.HINT, 210 | next_probe={ 211 | 'best': probing.mask_one(best, xs.shape[0]), 212 | 'atans': np.copy(atans), 213 | 'in_hull_h': np.copy(in_hull), 214 | 'stack_prev': np.copy(stack_prev), 215 | 'last_stack': probing.mask_one(last_stack, xs.shape[0]), 216 | 'i': probing.mask_one(best, xs.shape[0]), 217 | 'phase': probing.mask_one(1, 5) 218 | }) 219 | 220 | for i in range(xs.shape[0]): 221 | if i != best: 222 | atans[i] = math.atan2(ys[i] - ys[best], xs[i] - xs[best]) 223 | atans[best] = -123456789 224 | ind = np.argsort(atans) 225 | atans[best] = 0 226 | 227 | probing.push( 228 | probes, 229 | specs.Stage.HINT, 230 | next_probe={ 231 | 'best': probing.mask_one(best, xs.shape[0]), 232 | 'atans': np.copy(atans), 233 | 'in_hull_h': np.copy(in_hull), 234 | 'stack_prev': np.copy(stack_prev), 235 | 'last_stack': probing.mask_one(last_stack, xs.shape[0]), 236 | 'i': probing.mask_one(best, xs.shape[0]), 237 | 'phase': probing.mask_one(2, 5) 238 | }) 239 | 240 | for i in range(1, xs.shape[0]): 241 | if i >= 3: 242 | while counter_clockwise(xs, ys, stack_prev[last_stack], last_stack, 243 | ind[i]): 244 | prev_last = last_stack 245 | last_stack = stack_prev[last_stack] 246 | stack_prev[prev_last] = prev_last 247 | in_hull[prev_last] = 0 248 | probing.push( 249 | probes, 250 | specs.Stage.HINT, 251 | next_probe={ 252 | 'best': probing.mask_one(best, xs.shape[0]), 253 | 'atans': np.copy(atans), 254 | 'in_hull_h': np.copy(in_hull), 255 | 'stack_prev': np.copy(stack_prev), 256 | 'last_stack': probing.mask_one(last_stack, xs.shape[0]), 257 | 'i': probing.mask_one(A_pos[ind[i]], xs.shape[0]), 258 | 'phase': probing.mask_one(3, 5) 259 | }) 260 | 261 | in_hull[ind[i]] = 1 262 | stack_prev[ind[i]] = last_stack 263 | last_stack = ind[i] 264 | 265 | probing.push( 266 | probes, 267 | specs.Stage.HINT, 268 | next_probe={ 269 | 'best': probing.mask_one(best, xs.shape[0]), 270 | 'atans': np.copy(atans), 271 | 'in_hull_h': np.copy(in_hull), 272 | 'stack_prev': np.copy(stack_prev), 273 | 'last_stack': probing.mask_one(last_stack, xs.shape[0]), 274 | 'i': probing.mask_one(A_pos[ind[i]], xs.shape[0]), 275 | 'phase': probing.mask_one(4, 5) 276 | }) 277 | 278 
| probing.push( 279 | probes, 280 | specs.Stage.OUTPUT, 281 | next_probe={'in_hull': np.copy(in_hull)}, 282 | ) 283 | probing.finalize(probes) 284 | 285 | return in_hull, probes 286 | 287 | 288 | def jarvis_march(xs: _Array, ys: _Array) -> _Out: 289 | """Jarvis' march convex hull (Jarvis, 1973).""" 290 | 291 | chex.assert_rank([xs, ys], 1) 292 | probes = probing.initialize(specs.SPECS['jarvis_march']) 293 | 294 | A_pos = np.arange(xs.shape[0]) 295 | in_hull = np.zeros(xs.shape[0]) 296 | 297 | probing.push( 298 | probes, 299 | specs.Stage.INPUT, 300 | next_probe={ 301 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 302 | 'x': np.copy(xs), 303 | 'y': np.copy(ys) 304 | }) 305 | 306 | probing.push( 307 | probes, 308 | specs.Stage.HINT, 309 | next_probe={ 310 | 'pred_h': probing.array(np.copy(A_pos)), 311 | 'in_hull_h': np.copy(in_hull), 312 | 'best': probing.mask_one(0, xs.shape[0]), 313 | 'last_point': probing.mask_one(0, xs.shape[0]), 314 | 'endpoint': probing.mask_one(0, xs.shape[0]), 315 | 'i': probing.mask_one(0, xs.shape[0]), 316 | 'phase': probing.mask_one(0, 2) 317 | }) 318 | 319 | def counter_clockwise(xs, ys, i, j, k): 320 | if (k == i) or (k == j): 321 | return False 322 | return ((xs[j] - xs[i]) * (ys[k] - ys[i]) - (ys[j] - ys[i]) * 323 | (xs[k] - xs[i])) <= 0 324 | 325 | best = 0 326 | for i in range(xs.shape[0]): 327 | if ys[i] < ys[best] or (ys[i] == ys[best] and xs[i] < xs[best]): 328 | best = i 329 | 330 | in_hull[best] = 1 331 | last_point = best 332 | endpoint = 0 333 | 334 | probing.push( 335 | probes, 336 | specs.Stage.HINT, 337 | next_probe={ 338 | 'pred_h': probing.array(np.copy(A_pos)), 339 | 'in_hull_h': np.copy(in_hull), 340 | 'best': probing.mask_one(best, xs.shape[0]), 341 | 'last_point': probing.mask_one(last_point, xs.shape[0]), 342 | 'endpoint': probing.mask_one(endpoint, xs.shape[0]), 343 | 'i': probing.mask_one(0, xs.shape[0]), 344 | 'phase': probing.mask_one(1, 2) 345 | }) 346 | 347 | while True: 348 | for i in range(xs.shape[0]): 349 | if endpoint == last_point or counter_clockwise(xs, ys, last_point, 350 | endpoint, i): 351 | endpoint = i 352 | probing.push( 353 | probes, 354 | specs.Stage.HINT, 355 | next_probe={ 356 | 'pred_h': probing.array(np.copy(A_pos)), 357 | 'in_hull_h': np.copy(in_hull), 358 | 'best': probing.mask_one(best, xs.shape[0]), 359 | 'last_point': probing.mask_one(last_point, xs.shape[0]), 360 | 'endpoint': probing.mask_one(endpoint, xs.shape[0]), 361 | 'i': probing.mask_one(i, xs.shape[0]), 362 | 'phase': probing.mask_one(1, 2) 363 | }) 364 | if in_hull[endpoint] > 0: 365 | break 366 | in_hull[endpoint] = 1 367 | last_point = endpoint 368 | endpoint = 0 369 | probing.push( 370 | probes, 371 | specs.Stage.HINT, 372 | next_probe={ 373 | 'pred_h': probing.array(np.copy(A_pos)), 374 | 'in_hull_h': np.copy(in_hull), 375 | 'best': probing.mask_one(best, xs.shape[0]), 376 | 'last_point': probing.mask_one(last_point, xs.shape[0]), 377 | 'endpoint': probing.mask_one(endpoint, xs.shape[0]), 378 | 'i': probing.mask_one(0, xs.shape[0]), 379 | 'phase': probing.mask_one(1, 2) 380 | }) 381 | 382 | probing.push( 383 | probes, specs.Stage.OUTPUT, next_probe={'in_hull': np.copy(in_hull)}) 384 | probing.finalize(probes) 385 | 386 | return in_hull, probes 387 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/geometry_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `geometry.py`.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from clrs._src.algorithms import geometry 22 | import numpy as np 23 | 24 | 25 | class GeometryTest(parameterized.TestCase): 26 | 27 | def test_segments_simple(self): 28 | xs_no = np.array([0, 0, 1, 1]) 29 | ys_no = np.array([0, 1, 0, 1]) 30 | out, _ = geometry.segments_intersect(xs_no, ys_no) 31 | self.assertFalse(out) 32 | 33 | xs_yes = np.array([0, 1, 1, 0]) 34 | ys_yes = np.array([0, 1, 0, 1]) 35 | out, _ = geometry.segments_intersect(xs_yes, ys_yes) 36 | self.assertTrue(out) 37 | 38 | xs_just = np.array([-3, 5, 5, -4]) 39 | ys_just = np.array([-3, 5, 5, -4]) 40 | out, _ = geometry.segments_intersect(xs_just, ys_just) 41 | self.assertTrue(out) 42 | 43 | def test_segments_colinear(self): 44 | xs_no = np.array([-1, 1, 2, 4]) 45 | ys_no = np.array([-1, 1, 2, 4]) 46 | out, _ = geometry.segments_intersect(xs_no, ys_no) 47 | self.assertFalse(out) 48 | 49 | xs_yes = np.array([-3, 5, 1, 2]) 50 | ys_yes = np.array([-3, 5, 1, 2]) 51 | out, _ = geometry.segments_intersect(xs_yes, ys_yes) 52 | self.assertTrue(out) 53 | 54 | xs_just = np.array([-3, 5, 5, 7]) 55 | ys_just = np.array([-3, 5, 5, 7]) 56 | out, _ = geometry.segments_intersect(xs_just, ys_just) 57 | self.assertTrue(out) 58 | 59 | @parameterized.named_parameters( 60 | ("Graham scan convex hull", geometry.graham_scan), 61 | ("Jarvis' march convex hull", geometry.jarvis_march), 62 | ) 63 | def test_convex_hull_simple(self, algorithm): 64 | tt = np.linspace(-np.pi, np.pi, 10)[:-1] 65 | xs = np.cos(tt) 66 | ys = np.sin(tt) 67 | in_hull, _ = algorithm(xs, ys) 68 | self.assertTrue(np.all(in_hull == 1)) 69 | 70 | xs = np.append(xs, [0.1]) 71 | ys = np.append(ys, [0.1]) 72 | in_hull, _ = algorithm(xs, ys) 73 | self.assertTrue(np.all(in_hull[:-1] == 1)) 74 | self.assertTrue(np.all(in_hull[-1:] == 0)) 75 | 76 | @parameterized.named_parameters( 77 | ("Graham scan convex hull", geometry.graham_scan), 78 | ("Jarvis' march convex hull", geometry.jarvis_march), 79 | ) 80 | def test_convex_hull_points(self, algorithm): 81 | xs = np.array([0, 15, 20, 30, 50, 50, 55, 70]) 82 | ys = np.array([30, 25, 0, 60, 40, 10, 20, 30]) 83 | expected = np.array([1, 0, 1, 1, 0, 1, 0, 1]) 84 | out, _ = algorithm(xs, ys) 85 | np.testing.assert_array_equal(expected, out) 86 | 87 | 88 | if __name__ == "__main__": 89 | absltest.main() 90 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/graphs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `graphs.py`.""" 17 | # pylint: disable=invalid-name 18 | 19 | from absl.testing import absltest 20 | 21 | from clrs._src.algorithms import graphs 22 | import numpy as np 23 | 24 | 25 | # Unweighted graphs. 26 | 27 | DAG = np.array([ 28 | [0, 1, 0, 1, 0], 29 | [0, 0, 0, 0, 1], 30 | [0, 0, 0, 0, 1], 31 | [0, 1, 0, 0, 0], 32 | [0, 0, 0, 0, 0], 33 | ]) 34 | 35 | DIRECTED = np.array([ 36 | [0, 1, 0, 1, 0, 0], 37 | [0, 0, 0, 0, 1, 0], 38 | [0, 0, 0, 0, 1, 1], 39 | [0, 1, 0, 0, 0, 0], 40 | [0, 0, 0, 1, 0, 0], 41 | [0, 0, 0, 0, 0, 1], 42 | ]) 43 | 44 | UNDIRECTED = np.array([ 45 | [0, 1, 0, 0, 1], 46 | [1, 0, 1, 1, 1], 47 | [0, 1, 0, 1, 0], 48 | [0, 1, 1, 0, 1], 49 | [1, 1, 0, 1, 0], 50 | ]) 51 | 52 | ANOTHER_UNDIRECTED = np.array([ 53 | [0, 1, 1, 1, 0], 54 | [1, 0, 1, 0, 0], 55 | [1, 1, 0, 0, 0], 56 | [1, 0, 0, 0, 1], 57 | [0, 0, 0, 1, 0], 58 | ]) 59 | 60 | 61 | # Weighted graphs. 62 | 63 | X = np.iinfo(np.int32).max # not connected 64 | 65 | WEIGHTED_DAG = np.array([ 66 | [X, 9, 3, X, X], 67 | [X, X, 6, X, 2], 68 | [X, X, X, 1, X], 69 | [X, X, X, X, 2], 70 | [X, X, X, X, X], 71 | ]) 72 | 73 | WEIGHTED_DIRECTED = np.array([ 74 | [X, 9, 3, X, X], 75 | [X, X, 6, X, 2], 76 | [X, 2, X, 1, X], 77 | [X, X, 2, X, 2], 78 | [X, X, X, X, X], 79 | ]) 80 | 81 | WEIGHTED_UNDIRECTED = np.array([ 82 | [X, 2, 3, X, X], 83 | [2, X, 1, 3, 2], 84 | [3, 1, X, X, 1], 85 | [X, 3, X, X, 5], 86 | [X, 2, 1, 5, X], 87 | ]) 88 | 89 | 90 | # Bipartite graphs. 
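# Judging from the tests further down, the bipartite instances below
# are encoded as source/sink flow networks: node 0 is the source, the
# next block of nodes is the left set, the following block the right
# set, and the last node is the sink, with edges directed
# source -> left -> right -> sink. For instance,
# `bipartite_matching(BIPARTITE, 5, 4, 0, 10)` reads BIPARTITE as 5
# left nodes (1..5), 4 right nodes (6..9), source 0 and sink 10.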
91 | 92 | BIPARTITE = np.array([ 93 | [0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0], 94 | [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 95 | [0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0], 96 | [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0], 97 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 98 | [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], 99 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 100 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 101 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 102 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 103 | [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 104 | ]) 105 | 106 | BIPARTITE_2 = np.array([ 107 | [0, 1, 1, 1, 0, 0, 0, 0], 108 | [0, 0, 0, 0, 1, 1, 0, 0], 109 | [0, 0, 0, 0, 1, 0, 1, 0], 110 | [0, 0, 0, 0, 0, 0, 1, 0], 111 | [0, 0, 0, 0, 0, 0, 0, 1], 112 | [0, 0, 0, 0, 0, 0, 0, 1], 113 | [0, 0, 0, 0, 0, 0, 0, 1], 114 | [0, 0, 0, 0, 0, 0, 0, 0], 115 | ]) 116 | 117 | 118 | class GraphsTest(absltest.TestCase): 119 | 120 | def test_dfs(self): 121 | expected_directed = np.array([0, 0, 2, 4, 1, 2]) 122 | out, _ = graphs.dfs(DIRECTED) 123 | np.testing.assert_array_equal(expected_directed, out) 124 | 125 | expected_undirected = np.array([0, 0, 1, 2, 3]) 126 | out, _ = graphs.dfs(UNDIRECTED) 127 | np.testing.assert_array_equal(expected_undirected, out) 128 | 129 | def test_bfs(self): 130 | expected_directed = np.array([0, 0, 2, 0, 1, 5]) 131 | out, _ = graphs.bfs(DIRECTED, 0) 132 | np.testing.assert_array_equal(expected_directed, out) 133 | 134 | expected_undirected = np.array([0, 0, 1, 1, 0]) 135 | out, _ = graphs.bfs(UNDIRECTED, 0) 136 | np.testing.assert_array_equal(expected_undirected, out) 137 | 138 | def test_topological_sort(self): 139 | expected_dag = np.array([3, 4, 0, 1, 4]) 140 | out, _ = graphs.topological_sort(DAG) 141 | np.testing.assert_array_equal(expected_dag, out) 142 | 143 | def test_articulation_points(self): 144 | expected = np.array([1, 0, 0, 1, 0]) 145 | out, _ = graphs.articulation_points(ANOTHER_UNDIRECTED) 146 | np.testing.assert_array_equal(expected, out) 147 | 148 | def test_bridges(self): 149 | expected = np.array([ 150 | [0, 0, 0, 1, -1], 151 | [0, 0, 0, -1, -1], 152 | [0, 0, 0, -1, -1], 153 | [1, -1, -1, 0, 1], 154 | [-1, -1, -1, 1, 0], 155 | ]) 156 | out, _ = graphs.bridges(ANOTHER_UNDIRECTED) 157 | np.testing.assert_array_equal(expected, out) 158 | 159 | def test_strongly_connected_components(self): 160 | expected_directed = np.array([0, 1, 2, 1, 1, 5]) 161 | out, _ = graphs.strongly_connected_components(DIRECTED) 162 | np.testing.assert_array_equal(expected_directed, out) 163 | 164 | expected_undirected = np.array([0, 0, 0, 0, 0]) 165 | out, _ = graphs.strongly_connected_components(UNDIRECTED) 166 | np.testing.assert_array_equal(expected_undirected, out) 167 | 168 | def test_mst_kruskal(self): 169 | expected = np.array([ 170 | [0, 1, 0, 0, 0], 171 | [1, 0, 1, 1, 0], 172 | [0, 1, 0, 0, 1], 173 | [0, 1, 0, 0, 0], 174 | [0, 0, 1, 0, 0], 175 | ]) 176 | out, _ = graphs.mst_kruskal(WEIGHTED_UNDIRECTED) 177 | np.testing.assert_array_equal(expected, out) 178 | 179 | def test_mst_prim(self): 180 | expected = np.array([0, 0, 1, 1, 2]) 181 | out, _ = graphs.mst_prim(WEIGHTED_UNDIRECTED, 0) 182 | np.testing.assert_array_equal(expected, out) 183 | 184 | def test_bellman_ford(self): 185 | expected = np.array([0, 2, 0, 2, 3]) 186 | out, _ = graphs.bellman_ford(WEIGHTED_DIRECTED, 0) 187 | np.testing.assert_array_equal(expected, out) 188 | 189 | def test_dag_shortest_paths(self): 190 | expected = np.array([0, 0, 0, 2, 3]) 191 | out, _ = graphs.dag_shortest_paths(WEIGHTED_DAG, 0) 192 | np.testing.assert_array_equal(expected, out) 193 | 194 | def
test_dijkstra(self): 195 | expected = np.array([0, 2, 0, 2, 3]) 196 | out, _ = graphs.dijkstra(WEIGHTED_DIRECTED, 0) 197 | np.testing.assert_array_equal(expected, out) 198 | 199 | def test_floyd_warshall(self): 200 | expected = np.array([0, 2, 0, 2, 3]) 201 | out, _ = graphs.floyd_warshall(WEIGHTED_DIRECTED) 202 | np.testing.assert_array_equal(expected, out[0]) 203 | 204 | def test_bipartite_matching(self): 205 | expected = np.array([ 206 | [1, 1, 1, 1, 0, 0, -1, -1, -1, -1, -1], 207 | [0, 1, -1, -1, -1, -1, 1, -1, -1, -1, -1], 208 | [0, -1, 1, -1, -1, -1, 0, -1, 1, -1, -1], 209 | [0, -1, -1, 1, -1, -1, -1, 1, 0, 0, -1], 210 | [0, -1, -1, -1, 1, -1, -1, -1, 0, -1, -1], 211 | [0, -1, -1, -1, -1, 1, -1, -1, 0, -1, -1], 212 | [-1, 0, 0, -1, -1, -1, 1, -1, -1, -1, 1], 213 | [-1, -1, -1, 0, -1, -1, -1, 1, -1, -1, 1], 214 | [-1, -1, 0, 0, 0, 0, -1, -1, 1, -1, 1], 215 | [-1, -1, -1, 0, -1, -1, -1, -1, -1, 1, 0], 216 | [-1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 1], 217 | ]) 218 | out, _ = graphs.bipartite_matching(BIPARTITE, 5, 4, 0, 10) 219 | np.testing.assert_array_equal(expected, out) 220 | 221 | expected_2 = np.array([ 222 | [1, 1, 1, 1, -1, -1, -1, -1], 223 | [0, 1, -1, -1, 0, 1, -1, -1], 224 | [0, -1, 1, -1, 1, -1, 0, -1], 225 | [0, -1, -1, 1, -1, -1, 1, -1], 226 | [-1, 0, 0, -1, 1, -1, -1, 1], 227 | [-1, 0, -1, -1, -1, 1, -1, 1], 228 | [-1, -1, 0, 0, -1, -1, 1, 1], 229 | [-1, -1, -1, -1, 0, 0, 0, 1], 230 | ]) 231 | out_2, _ = graphs.bipartite_matching(BIPARTITE_2, 3, 3, 0, 7) 232 | np.testing.assert_array_equal(expected_2, out_2) 233 | 234 | if __name__ == "__main__": 235 | absltest.main() 236 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/greedy.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Greedy algorithm generators. 17 | 18 | Currently implements the following: 19 | - Activity selection (Gavril, 1972) 20 | - Task scheduling (Lawler, 1985) 21 | 22 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 
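Both generators return a 0/1 selection mask over the inputs plus a `ProbesDict` trajectory. A minimal usage sketch (the times are illustrative only):

  import numpy as np
  from clrs._src.algorithms import greedy

  s = np.array([1, 3, 0, 5, 3])  # start times
  f = np.array([4, 5, 6, 7, 9])  # finish times (argsorted internally)
  selected, probes = greedy.activity_selector(s, f)
  # selected -> [1, 0, 0, 1, 0]: activities 0 and 3 are mutually
  # compatible, and no larger compatible set exists for this input.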
23 | 24 | """ 25 | # pylint: disable=invalid-name 26 | 27 | from typing import Tuple 28 | 29 | import chex 30 | from clrs._src import probing 31 | from clrs._src import specs 32 | import numpy as np 33 | 34 | 35 | _Array = np.ndarray 36 | _Out = Tuple[_Array, probing.ProbesDict] 37 | 38 | 39 | def activity_selector(s: _Array, f: _Array) -> _Out: 40 | """Activity selection (Gavril, 1972).""" 41 | 42 | chex.assert_rank([s, f], 1) 43 | probes = probing.initialize(specs.SPECS['activity_selector']) 44 | 45 | A_pos = np.arange(s.shape[0]) 46 | 47 | probing.push( 48 | probes, 49 | specs.Stage.INPUT, 50 | next_probe={ 51 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 52 | 's': np.copy(s), 53 | 'f': np.copy(f) 54 | }) 55 | 56 | A = np.zeros(s.shape[0]) 57 | 58 | probing.push( 59 | probes, 60 | specs.Stage.HINT, 61 | next_probe={ 62 | 'pred_h': probing.array(np.copy(A_pos)), 63 | 'selected_h': np.copy(A), 64 | 'm': probing.mask_one(0, A_pos.shape[0]), 65 | 'k': probing.mask_one(0, A_pos.shape[0]) 66 | }) 67 | 68 | ind = np.argsort(f) 69 | A[ind[0]] = 1 70 | k = ind[0] 71 | 72 | probing.push( 73 | probes, 74 | specs.Stage.HINT, 75 | next_probe={ 76 | 'pred_h': probing.array(np.copy(A_pos)), 77 | 'selected_h': np.copy(A), 78 | 'm': probing.mask_one(ind[0], A_pos.shape[0]), 79 | 'k': probing.mask_one(k, A_pos.shape[0]) 80 | }) 81 | 82 | for m in range(1, s.shape[0]): 83 | if s[ind[m]] >= f[k]: 84 | A[ind[m]] = 1 85 | k = ind[m] 86 | probing.push( 87 | probes, 88 | specs.Stage.HINT, 89 | next_probe={ 90 | 'pred_h': probing.array(np.copy(A_pos)), 91 | 'selected_h': np.copy(A), 92 | 'm': probing.mask_one(ind[m], A_pos.shape[0]), 93 | 'k': probing.mask_one(k, A_pos.shape[0]) 94 | }) 95 | 96 | probing.push(probes, specs.Stage.OUTPUT, next_probe={'selected': np.copy(A)}) 97 | probing.finalize(probes) 98 | 99 | return A, probes 100 | 101 | 102 | def task_scheduling(d: _Array, w: _Array) -> _Out: 103 | """Task scheduling (Lawler, 1985).""" 104 | 105 | chex.assert_rank([d, w], 1) 106 | probes = probing.initialize(specs.SPECS['task_scheduling']) 107 | 108 | A_pos = np.arange(d.shape[0]) 109 | 110 | probing.push( 111 | probes, 112 | specs.Stage.INPUT, 113 | next_probe={ 114 | 'pos': np.copy(A_pos) * 1.0 / A_pos.shape[0], 115 | 'd': np.copy(d), 116 | 'w': np.copy(w) 117 | }) 118 | 119 | A = np.zeros(d.shape[0]) 120 | 121 | probing.push( 122 | probes, 123 | specs.Stage.HINT, 124 | next_probe={ 125 | 'pred_h': probing.array(np.copy(A_pos)), 126 | 'selected_h': np.copy(A), 127 | 'i': probing.mask_one(0, A_pos.shape[0]), 128 | 't': 0 129 | }) 130 | 131 | ind = np.argsort(-w) 132 | A[ind[0]] = 1 133 | t = 1 134 | 135 | probing.push( 136 | probes, 137 | specs.Stage.HINT, 138 | next_probe={ 139 | 'pred_h': probing.array(np.copy(A_pos)), 140 | 'selected_h': np.copy(A), 141 | 'i': probing.mask_one(ind[0], A_pos.shape[0]), 142 | 't': t 143 | }) 144 | 145 | for i in range(1, d.shape[0]): 146 | if t < d[ind[i]]: 147 | A[ind[i]] = 1 148 | t += 1 149 | probing.push( 150 | probes, 151 | specs.Stage.HINT, 152 | next_probe={ 153 | 'pred_h': probing.array(np.copy(A_pos)), 154 | 'selected_h': np.copy(A), 155 | 'i': probing.mask_one(ind[i], A_pos.shape[0]), 156 | 't': t 157 | }) 158 | 159 | probing.push(probes, specs.Stage.OUTPUT, next_probe={'selected': np.copy(A)}) 160 | probing.finalize(probes) 161 | 162 | return A, probes 163 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/greedy_test.py: -------------------------------------------------------------------------------- 1 | 
# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `greedy.py`.""" 17 | 18 | from absl.testing import absltest 19 | 20 | from clrs._src.algorithms import greedy 21 | import numpy as np 22 | 23 | 24 | class GreedyTest(absltest.TestCase): 25 | 26 | def test_greedy_activity_selector(self): 27 | s = np.array([1, 3, 0, 5, 3, 5, 6, 8, 8, 2, 12]) 28 | f = np.array([4, 5, 6, 7, 9, 9, 10, 11, 12, 14, 16]) 29 | expected = np.array([1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1]) 30 | out, _ = greedy.activity_selector(s, f) 31 | np.testing.assert_array_equal(expected, out) 32 | 33 | def test_task_scheduling(self): 34 | d = np.array([4, 2, 4, 3, 1, 4, 6]) 35 | w = np.array([70, 60, 50, 40, 30, 20, 10]) 36 | expected = np.array([1, 1, 1, 0, 0, 1, 1]) 37 | out, _ = greedy.task_scheduling(d, w) 38 | np.testing.assert_array_equal(expected, out) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/searching.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Searching algorithm generators. 17 | 18 | Currently implements the following: 19 | - Minimum 20 | - Binary search 21 | - Quickselect (Hoare, 1961) 22 | 23 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 
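Each generator returns its answer (an index for minimum and binary search, the selected value for quickselect) plus a `ProbesDict`. A minimal usage sketch (the keys are illustrative only):

  import numpy as np
  from clrs._src.algorithms import searching

  A = np.array([2, 3, 5, 7, 11])  # binary search expects sorted keys
  idx, probes = searching.binary_search(7, A)
  # idx -> 3. `high` starts at len(A) - 1, so the returned index is
  # always inside the array even when the target is absent.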
24 | 25 | """ 26 | # pylint: disable=invalid-name 27 | 28 | 29 | from typing import Tuple, Union 30 | 31 | import chex 32 | from clrs._src import probing 33 | from clrs._src import specs 34 | import numpy as np 35 | 36 | 37 | _Array = np.ndarray 38 | _Numeric = Union[int, float] 39 | _Out = Tuple[int, probing.ProbesDict] 40 | 41 | 42 | def minimum(A: _Array) -> _Out: 43 | """Minimum.""" 44 | 45 | chex.assert_rank(A, 1) 46 | probes = probing.initialize(specs.SPECS['minimum']) 47 | 48 | A_pos = np.arange(A.shape[0]) 49 | 50 | probing.push( 51 | probes, 52 | specs.Stage.INPUT, 53 | next_probe={ 54 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 55 | 'key': np.copy(A) 56 | }) 57 | 58 | probing.push( 59 | probes, 60 | specs.Stage.HINT, 61 | next_probe={ 62 | 'pred_h': probing.array(np.copy(A_pos)), 63 | 'min_h': probing.mask_one(0, A.shape[0]), 64 | 'i': probing.mask_one(0, A.shape[0]) 65 | }) 66 | 67 | min_ = 0 68 | for i in range(1, A.shape[0]): 69 | if A[min_] > A[i]: 70 | min_ = i 71 | 72 | probing.push( 73 | probes, 74 | specs.Stage.HINT, 75 | next_probe={ 76 | 'pred_h': probing.array(np.copy(A_pos)), 77 | 'min_h': probing.mask_one(min_, A.shape[0]), 78 | 'i': probing.mask_one(i, A.shape[0]) 79 | }) 80 | 81 | probing.push( 82 | probes, 83 | specs.Stage.OUTPUT, 84 | next_probe={'min': probing.mask_one(min_, A.shape[0])}) 85 | 86 | probing.finalize(probes) 87 | 88 | return min_, probes 89 | 90 | 91 | def binary_search(x: _Numeric, A: _Array) -> _Out: 92 | """Binary search.""" 93 | 94 | chex.assert_rank(A, 1) 95 | probes = probing.initialize(specs.SPECS['binary_search']) 96 | 97 | T_pos = np.arange(A.shape[0]) 98 | 99 | probing.push( 100 | probes, 101 | specs.Stage.INPUT, 102 | next_probe={ 103 | 'pos': np.copy(T_pos) * 1.0 / A.shape[0], 104 | 'key': np.copy(A), 105 | 'target': x 106 | }) 107 | 108 | probing.push( 109 | probes, 110 | specs.Stage.HINT, 111 | next_probe={ 112 | 'pred_h': probing.array(np.copy(T_pos)), 113 | 'low': probing.mask_one(0, A.shape[0]), 114 | 'high': probing.mask_one(A.shape[0] - 1, A.shape[0]), 115 | 'mid': probing.mask_one((A.shape[0] - 1) // 2, A.shape[0]), 116 | }) 117 | 118 | low = 0 119 | high = A.shape[0] - 1 # make sure return is always in array 120 | while low < high: 121 | mid = (low + high) // 2 122 | if x <= A[mid]: 123 | high = mid 124 | else: 125 | low = mid + 1 126 | 127 | probing.push( 128 | probes, 129 | specs.Stage.HINT, 130 | next_probe={ 131 | 'pred_h': probing.array(np.copy(T_pos)), 132 | 'low': probing.mask_one(low, A.shape[0]), 133 | 'high': probing.mask_one(high, A.shape[0]), 134 | 'mid': probing.mask_one((low + high) // 2, A.shape[0]), 135 | }) 136 | 137 | probing.push( 138 | probes, 139 | specs.Stage.OUTPUT, 140 | next_probe={'return': probing.mask_one(high, A.shape[0])}) 141 | 142 | probing.finalize(probes) 143 | 144 | return high, probes 145 | 146 | 147 | def quickselect( 148 | A: _Array, 149 | A_pos=None, 150 | p=None, 151 | r=None, 152 | i=None, 153 | probes=None, 154 | ) -> _Out: 155 | """Quickselect (Hoare, 1961).""" 156 | 157 | chex.assert_rank(A, 1) 158 | 159 | def partition(A, A_pos, p, r, target, probes): 160 | x = A[r] 161 | i = p - 1 162 | for j in range(p, r): 163 | if A[j] <= x: 164 | i += 1 165 | tmp = A[i] 166 | A[i] = A[j] 167 | A[j] = tmp 168 | tmp = A_pos[i] 169 | A_pos[i] = A_pos[j] 170 | A_pos[j] = tmp 171 | 172 | probing.push( 173 | probes, 174 | specs.Stage.HINT, 175 | next_probe={ 176 | 'pred_h': probing.array(np.copy(A_pos)), 177 | 'p': probing.mask_one(A_pos[p], A.shape[0]), 178 | 'r': probing.mask_one(A_pos[r], 
A.shape[0]), 179 | 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 180 | 'j': probing.mask_one(A_pos[j], A.shape[0]), 181 | 'i_rank': (i + 1) * 1.0 / A.shape[0], 182 | 'target': target * 1.0 / A.shape[0] 183 | }) 184 | 185 | tmp = A[i + 1] 186 | A[i + 1] = A[r] 187 | A[r] = tmp 188 | tmp = A_pos[i + 1] 189 | A_pos[i + 1] = A_pos[r] 190 | A_pos[r] = tmp 191 | 192 | probing.push( 193 | probes, 194 | specs.Stage.HINT, 195 | next_probe={ 196 | 'pred_h': probing.array(np.copy(A_pos)), 197 | 'p': probing.mask_one(A_pos[p], A.shape[0]), 198 | 'r': probing.mask_one(A_pos[r], A.shape[0]), 199 | 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 200 | 'j': probing.mask_one(A_pos[r], A.shape[0]), 201 | 'i_rank': (i + 1 - p) * 1.0 / A.shape[0], 202 | 'target': target * 1.0 / A.shape[0] 203 | }) 204 | 205 | return i + 1 206 | 207 | if A_pos is None: 208 | A_pos = np.arange(A.shape[0]) 209 | if p is None: 210 | p = 0 211 | if r is None: 212 | r = len(A) - 1 213 | if i is None: 214 | i = len(A) // 2 215 | if probes is None: 216 | probes = probing.initialize(specs.SPECS['quickselect']) 217 | probing.push( 218 | probes, 219 | specs.Stage.INPUT, 220 | next_probe={ 221 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 222 | 'key': np.copy(A) 223 | }) 224 | 225 | q = partition(A, A_pos, p, r, i, probes) 226 | k = q - p 227 | if i == k: 228 | probing.push( 229 | probes, 230 | specs.Stage.OUTPUT, 231 | next_probe={'median': probing.mask_one(A_pos[q], A.shape[0])}) 232 | probing.finalize(probes) 233 | return A[q], probes 234 | elif i < k: 235 | return quickselect(A, A_pos, p, q - 1, i, probes) 236 | else: 237 | return quickselect(A, A_pos, q + 1, r, i - k - 1, probes) 238 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/searching_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for `searching.py`.""" 17 | # pylint: disable=invalid-name 18 | 19 | from absl.testing import absltest 20 | 21 | from clrs._src.algorithms import searching 22 | import numpy as np 23 | 24 | 25 | EmptyArray = np.asarray([], dtype=np.int32) 26 | 27 | 28 | class SearchingTest(absltest.TestCase): 29 | 30 | def test_minimum(self): 31 | for _ in range(17): 32 | A = np.random.randint(0, 100, size=(13,)) 33 | idx, _ = searching.minimum(A) 34 | self.assertEqual(A.min(), A[idx]) 35 | 36 | def test_binary_search(self): 37 | A = np.random.randint(0, 100, size=(13,)) 38 | A.sort() 39 | x = np.random.choice(A) 40 | idx, _ = searching.binary_search(x, A) 41 | self.assertEqual(A[idx], x) 42 | 43 | def test_quickselect(self): 44 | A = np.random.randint(0, 100, size=(13,)) 45 | idx, _ = searching.quickselect(A) 46 | self.assertEqual(sorted(A)[len(A) // 2], idx) 47 | 48 | 49 | if __name__ == '__main__': 50 | absltest.main() 51 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/sorting.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Sorting algorithm generators. 17 | 18 | Currently implements the following: 19 | - Insertion sort 20 | - Bubble sort 21 | - Heapsort (Williams, 1964) 22 | - Quicksort (Hoare, 1962) 23 | 24 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 25 | 26 | """ 27 | # pylint: disable=invalid-name 28 | 29 | 30 | from typing import Tuple 31 | 32 | import chex 33 | from clrs._src import probing 34 | from clrs._src import specs 35 | import numpy as np 36 | 37 | 38 | _Array = np.ndarray 39 | _Out = Tuple[_Array, probing.ProbesDict] 40 | 41 | 42 | def insertion_sort(A: _Array) -> _Out: 43 | """Insertion sort.""" 44 | 45 | chex.assert_rank(A, 1) 46 | probes = probing.initialize(specs.SPECS['insertion_sort']) 47 | 48 | A_pos = np.arange(A.shape[0]) 49 | 50 | probing.push( 51 | probes, 52 | specs.Stage.INPUT, 53 | next_probe={ 54 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 55 | 'key': np.copy(A) 56 | }) 57 | 58 | probing.push( 59 | probes, 60 | specs.Stage.HINT, 61 | next_probe={ 62 | 'pred_h': probing.array(np.copy(A_pos)), 63 | 'i': probing.mask_one(0, A.shape[0]), 64 | 'j': probing.mask_one(0, A.shape[0]) 65 | }) 66 | 67 | for j in range(1, A.shape[0]): 68 | key = A[j] 69 | # Insert A[j] into the sorted sequence A[1 .. 
j - 1] 70 | i = j - 1 71 | while i >= 0 and A[i] > key: 72 | A[i + 1] = A[i] 73 | A_pos[i + 1] = A_pos[i] 74 | i -= 1 75 | A[i + 1] = key 76 | stor_pos = A_pos[i + 1] 77 | A_pos[i + 1] = j 78 | 79 | probing.push( 80 | probes, 81 | specs.Stage.HINT, 82 | next_probe={ 83 | 'pred_h': probing.array(np.copy(A_pos)), 84 | 'i': probing.mask_one(stor_pos, np.copy(A.shape[0])), 85 | 'j': probing.mask_one(j, np.copy(A.shape[0])) 86 | }) 87 | 88 | probing.push( 89 | probes, 90 | specs.Stage.OUTPUT, 91 | next_probe={'pred': probing.array(np.copy(A_pos))}) 92 | 93 | probing.finalize(probes) 94 | 95 | return A, probes 96 | 97 | 98 | def bubble_sort(A: _Array) -> _Out: 99 | """Bubble sort.""" 100 | 101 | chex.assert_rank(A, 1) 102 | probes = probing.initialize(specs.SPECS['bubble_sort']) 103 | 104 | A_pos = np.arange(A.shape[0]) 105 | 106 | probing.push( 107 | probes, 108 | specs.Stage.INPUT, 109 | next_probe={ 110 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 111 | 'key': np.copy(A) 112 | }) 113 | 114 | probing.push( 115 | probes, 116 | specs.Stage.HINT, 117 | next_probe={ 118 | 'pred_h': probing.array(np.copy(A_pos)), 119 | 'i': probing.mask_one(0, A.shape[0]), 120 | 'j': probing.mask_one(0, A.shape[0]) 121 | }) 122 | 123 | for i in range(A.shape[0] - 1): 124 | for j in reversed(range(i + 1, A.shape[0])): 125 | if A[j] < A[j - 1]: 126 | tmp = A[j] 127 | A[j] = A[j - 1] 128 | A[j - 1] = tmp 129 | 130 | tmp = A_pos[j] 131 | A_pos[j] = A_pos[j - 1] 132 | A_pos[j - 1] = tmp 133 | 134 | probing.push( 135 | probes, 136 | specs.Stage.HINT, 137 | next_probe={ 138 | 'pred_h': probing.array(np.copy(A_pos)), 139 | 'i': probing.mask_one(A_pos[i], np.copy(A.shape[0])), 140 | 'j': probing.mask_one(A_pos[j], np.copy(A.shape[0])) 141 | }) 142 | 143 | probing.push( 144 | probes, 145 | specs.Stage.OUTPUT, 146 | next_probe={'pred': probing.array(np.copy(A_pos))}, 147 | ) 148 | 149 | probing.finalize(probes) 150 | 151 | return A, probes 152 | 153 | 154 | def heapsort(A: _Array) -> _Out: 155 | """Heapsort (Williams, 1964).""" 156 | 157 | chex.assert_rank(A, 1) 158 | probes = probing.initialize(specs.SPECS['heapsort']) 159 | 160 | A_pos = np.arange(A.shape[0]) 161 | 162 | probing.push( 163 | probes, 164 | specs.Stage.INPUT, 165 | next_probe={ 166 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 167 | 'key': np.copy(A) 168 | }) 169 | 170 | probing.push( 171 | probes, 172 | specs.Stage.HINT, 173 | next_probe={ 174 | 'pred_h': probing.array(np.copy(A_pos)), 175 | 'parent': probing.heap(np.copy(A_pos), A.shape[0]), 176 | 'i': probing.mask_one(A.shape[0] - 1, A.shape[0]), 177 | 'j': probing.mask_one(A.shape[0] - 1, A.shape[0]), 178 | 'largest': probing.mask_one(A.shape[0] - 1, A.shape[0]), 179 | 'heap_size': probing.mask_one(A.shape[0] - 1, A.shape[0]), 180 | 'phase': probing.mask_one(0, 3) 181 | }) 182 | 183 | def max_heapify(A, i, heap_size, ind, phase): 184 | l = 2 * i + 1 185 | r = 2 * i + 2 186 | if l < heap_size and A[l] > A[i]: 187 | largest = l 188 | else: 189 | largest = i 190 | if r < heap_size and A[r] > A[largest]: 191 | largest = r 192 | if largest != i: 193 | tmp = A[i] 194 | A[i] = A[largest] 195 | A[largest] = tmp 196 | 197 | tmp = A_pos[i] 198 | A_pos[i] = A_pos[largest] 199 | A_pos[largest] = tmp 200 | 201 | probing.push( 202 | probes, 203 | specs.Stage.HINT, 204 | next_probe={ 205 | 'pred_h': probing.array(np.copy(A_pos)), 206 | 'parent': probing.heap(np.copy(A_pos), heap_size), 207 | 'i': probing.mask_one(A_pos[ind], A.shape[0]), 208 | 'j': probing.mask_one(A_pos[i], A.shape[0]), 209 | 'largest': 
probing.mask_one(A_pos[largest], A.shape[0]), 210 | 'heap_size': probing.mask_one(A_pos[heap_size - 1], A.shape[0]), 211 | 'phase': probing.mask_one(phase, 3) 212 | }) 213 | 214 | if largest != i: 215 | max_heapify(A, largest, heap_size, ind, phase) 216 | 217 | def build_max_heap(A): 218 | for i in reversed(range(A.shape[0])): 219 | max_heapify(A, i, A.shape[0], i, 0) 220 | 221 | build_max_heap(A) 222 | heap_size = A.shape[0] 223 | for i in reversed(range(1, A.shape[0])): 224 | tmp = A[0] 225 | A[0] = A[i] 226 | A[i] = tmp 227 | 228 | tmp = A_pos[0] 229 | A_pos[0] = A_pos[i] 230 | A_pos[i] = tmp 231 | 232 | heap_size -= 1 233 | 234 | probing.push( 235 | probes, 236 | specs.Stage.HINT, 237 | next_probe={ 238 | 'pred_h': probing.array(np.copy(A_pos)), 239 | 'parent': probing.heap(np.copy(A_pos), heap_size), 240 | 'i': probing.mask_one(A_pos[0], A.shape[0]), 241 | 'j': probing.mask_one(A_pos[i], A.shape[0]), 242 | 'largest': probing.mask_one(0, A.shape[0]), # Consider masking 243 | 'heap_size': probing.mask_one(A_pos[heap_size - 1], A.shape[0]), 244 | 'phase': probing.mask_one(1, 3) 245 | }) 246 | 247 | max_heapify(A, 0, heap_size, i, 2) # reduce heap_size! 248 | 249 | probing.push( 250 | probes, 251 | specs.Stage.OUTPUT, 252 | next_probe={'pred': probing.array(np.copy(A_pos))}, 253 | ) 254 | 255 | probing.finalize(probes) 256 | 257 | return A, probes 258 | 259 | 260 | def quicksort(A: _Array, A_pos=None, p=None, r=None, probes=None) -> _Out: 261 | """Quicksort (Hoare, 1962).""" 262 | 263 | chex.assert_rank(A, 1) 264 | 265 | def partition(A, A_pos, p, r, probes): 266 | x = A[r] 267 | i = p - 1 268 | for j in range(p, r): 269 | if A[j] <= x: 270 | i += 1 271 | tmp = A[i] 272 | A[i] = A[j] 273 | A[j] = tmp 274 | tmp = A_pos[i] 275 | A_pos[i] = A_pos[j] 276 | A_pos[j] = tmp 277 | 278 | probing.push( 279 | probes, 280 | specs.Stage.HINT, 281 | next_probe={ 282 | 'pred_h': probing.array(np.copy(A_pos)), 283 | 'p': probing.mask_one(A_pos[p], A.shape[0]), 284 | 'r': probing.mask_one(A_pos[r], A.shape[0]), 285 | 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 286 | 'j': probing.mask_one(A_pos[j], A.shape[0]) 287 | }) 288 | 289 | tmp = A[i + 1] 290 | A[i + 1] = A[r] 291 | A[r] = tmp 292 | tmp = A_pos[i + 1] 293 | A_pos[i + 1] = A_pos[r] 294 | A_pos[r] = tmp 295 | 296 | probing.push( 297 | probes, 298 | specs.Stage.HINT, 299 | next_probe={ 300 | 'pred_h': probing.array(np.copy(A_pos)), 301 | 'p': probing.mask_one(A_pos[p], A.shape[0]), 302 | 'r': probing.mask_one(A_pos[r], A.shape[0]), 303 | 'i': probing.mask_one(A_pos[i + 1], A.shape[0]), 304 | 'j': probing.mask_one(A_pos[r], A.shape[0]) 305 | }) 306 | 307 | return i + 1 308 | 309 | if A_pos is None: 310 | A_pos = np.arange(A.shape[0]) 311 | if p is None: 312 | p = 0 313 | if r is None: 314 | r = len(A) - 1 315 | if probes is None: 316 | probes = probing.initialize(specs.SPECS['quicksort']) 317 | probing.push( 318 | probes, 319 | specs.Stage.INPUT, 320 | next_probe={ 321 | 'pos': np.copy(A_pos) * 1.0 / A.shape[0], 322 | 'key': np.copy(A) 323 | }) 324 | 325 | if p < r: 326 | q = partition(A, A_pos, p, r, probes) 327 | quicksort(A, A_pos, p, q - 1, probes) 328 | quicksort(A, A_pos, q + 1, r, probes) 329 | 330 | if p == 0 and r == len(A) - 1: 331 | probing.push( 332 | probes, 333 | specs.Stage.OUTPUT, 334 | next_probe={'pred': probing.array(np.copy(A_pos))}, 335 | ) 336 | probing.finalize(probes) 337 | 338 | return A, probes 339 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/sorting_test.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `sorting.py`.""" 17 | # pylint: disable=invalid-name 18 | 19 | from absl.testing import absltest 20 | from absl.testing import parameterized 21 | 22 | from clrs._src.algorithms import sorting 23 | import numpy as np 24 | 25 | 26 | class SortingTest(parameterized.TestCase): 27 | 28 | @parameterized.named_parameters( 29 | ("Insertion sort", sorting.insertion_sort), 30 | ("Bubble sort", sorting.bubble_sort), 31 | ("Heapsort", sorting.heapsort), 32 | ("Quicksort", sorting.quicksort), 33 | ) 34 | def test_sorted(self, algorithm): 35 | for _ in range(17): 36 | A = np.random.randint(0, 100, size=(13,)) 37 | output, _ = algorithm(A) 38 | np.testing.assert_array_equal(sorted(A), output) 39 | 40 | 41 | if __name__ == "__main__": 42 | absltest.main() 43 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/strings.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Strings algorithm generators. 17 | 18 | Currently implements the following: 19 | - Naive string matching 20 | - Knuth-Morris-Pratt string matching (Knuth et al., 1977) 21 | 22 | See "Introduction to Algorithms" 3ed (CLRS3) for more information. 
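Both matchers return the offset of the first occurrence of P in T, or len(T) by convention when there is no match, together with a `ProbesDict`. A minimal usage sketch (symbols drawn from the 4-letter alphabet assumed below):

  import numpy as np
  from clrs._src.algorithms import strings

  T = np.array([1, 2, 1, 2, 3])  # text
  P = np.array([1, 2, 3])        # pattern
  offset, probes = strings.kmp_matcher(T, P)
  # offset -> 2; a miss would return len(T) == 5 instead.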
23 | 24 | """ 25 | # pylint: disable=invalid-name 26 | 27 | 28 | from typing import Tuple 29 | 30 | import chex 31 | from clrs._src import probing 32 | from clrs._src import specs 33 | import numpy as np 34 | 35 | 36 | _Array = np.ndarray 37 | _Out = Tuple[int, probing.ProbesDict] 38 | 39 | _ALPHABET_SIZE = 4 40 | 41 | 42 | def naive_string_matcher(T: _Array, P: _Array) -> _Out: 43 | """Naive string matching.""" 44 | 45 | chex.assert_rank([T, P], 1) 46 | probes = probing.initialize(specs.SPECS['naive_string_matcher']) 47 | 48 | T_pos = np.arange(T.shape[0]) 49 | P_pos = np.arange(P.shape[0]) 50 | 51 | probing.push( 52 | probes, 53 | specs.Stage.INPUT, 54 | next_probe={ 55 | 'string': 56 | probing.strings_id(T_pos, P_pos), 57 | 'pos': 58 | probing.strings_pos(T_pos, P_pos), 59 | 'key': 60 | probing.array_cat( 61 | np.concatenate([np.copy(T), np.copy(P)]), _ALPHABET_SIZE), 62 | }) 63 | 64 | s = 0 65 | while s <= T.shape[0] - P.shape[0]: 66 | i = s 67 | j = 0 68 | 69 | probing.push( 70 | probes, 71 | specs.Stage.HINT, 72 | next_probe={ 73 | 'pred_h': probing.strings_pred(T_pos, P_pos), 74 | 's': probing.mask_one(s, T.shape[0] + P.shape[0]), 75 | 'i': probing.mask_one(i, T.shape[0] + P.shape[0]), 76 | 'j': probing.mask_one(T.shape[0] + j, T.shape[0] + P.shape[0]) 77 | }) 78 | 79 | while True: 80 | if T[i] != P[j]: 81 | break 82 | elif j == P.shape[0] - 1: 83 | probing.push( 84 | probes, 85 | specs.Stage.OUTPUT, 86 | next_probe={'match': probing.mask_one(s, T.shape[0] + P.shape[0])}) 87 | probing.finalize(probes) 88 | return s, probes 89 | else: 90 | i += 1 91 | j += 1 92 | probing.push( 93 | probes, 94 | specs.Stage.HINT, 95 | next_probe={ 96 | 'pred_h': probing.strings_pred(T_pos, P_pos), 97 | 's': probing.mask_one(s, T.shape[0] + P.shape[0]), 98 | 'i': probing.mask_one(i, T.shape[0] + P.shape[0]), 99 | 'j': probing.mask_one(T.shape[0] + j, T.shape[0] + P.shape[0]) 100 | }) 101 | 102 | s += 1 103 | 104 | # By convention, set probe to head of needle if no match is found 105 | probing.push( 106 | probes, 107 | specs.Stage.OUTPUT, 108 | next_probe={ 109 | 'match': probing.mask_one(T.shape[0], T.shape[0] + P.shape[0]) 110 | }) probing.finalize(probes) 111 | return T.shape[0], probes 112 | 113 | 114 | def kmp_matcher(T: _Array, P: _Array) -> _Out: 115 | """Knuth-Morris-Pratt string matching (Knuth et al., 1977).""" 116 | 117 | chex.assert_rank([T, P], 1) 118 | probes = probing.initialize(specs.SPECS['kmp_matcher']) 119 | 120 | T_pos = np.arange(T.shape[0]) 121 | P_pos = np.arange(P.shape[0]) 122 | 123 | probing.push( 124 | probes, 125 | specs.Stage.INPUT, 126 | next_probe={ 127 | 'string': 128 | probing.strings_id(T_pos, P_pos), 129 | 'pos': 130 | probing.strings_pos(T_pos, P_pos), 131 | 'key': 132 | probing.array_cat( 133 | np.concatenate([np.copy(T), np.copy(P)]), _ALPHABET_SIZE), 134 | }) 135 | 136 | pi = np.arange(P.shape[0]) 137 | is_reset = np.zeros(P.shape[0]) 138 | 139 | k = 0 140 | k_reset = 1 141 | is_reset[0] = 1 142 | 143 | # Cover the edge case where |P| = 1, and the first half is not executed.
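# The prefix-function phase below has no way to probe k = -1 directly
# (mask_one hints need a valid node index), so the reset state is
# tracked out of band: `k_reset == 1` plays the role of k = -1, and
# `is_reset[q] == 1` marks positions whose failure link is that reset
# state. `delta` (next line) nudges the initial `q` hint so that the
# |P| = 1 edge case, which skips the prefix loop, still emits a
# well-formed first hint.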
144 | delta = 1 if P.shape[0] > 1 else 0 145 | 146 | probing.push( 147 | probes, 148 | specs.Stage.HINT, 149 | next_probe={ 150 | 'pred_h': probing.strings_pred(T_pos, P_pos), 151 | 'pi': probing.strings_pi(T_pos, P_pos, pi), 152 | 'is_reset': np.concatenate( 153 | [np.zeros(T.shape[0]), np.copy(is_reset)]), 154 | 'k': probing.mask_one(T.shape[0], T.shape[0] + P.shape[0]), 155 | 'k_reset': k_reset, 156 | 'q': probing.mask_one(T.shape[0] + delta, T.shape[0] + P.shape[0]), 157 | 'q_reset': 1, 158 | 's': probing.mask_one(0, T.shape[0] + P.shape[0]), 159 | 'i': probing.mask_one(0, T.shape[0] + P.shape[0]), 160 | 'phase': 0 161 | }) 162 | 163 | for q in range(1, P.shape[0]): 164 | while k_reset == 0 and P[k + 1] != P[q]: 165 | if is_reset[k] == 1: 166 | k_reset = 1 167 | k = 0 168 | else: 169 | k = pi[k] 170 | probing.push( 171 | probes, 172 | specs.Stage.HINT, 173 | next_probe={ 174 | 'pred_h': probing.strings_pred(T_pos, P_pos), 175 | 'pi': probing.strings_pi(T_pos, P_pos, pi), 176 | 'is_reset': np.concatenate( 177 | [np.zeros(T.shape[0]), np.copy(is_reset)]), 178 | 'k': probing.mask_one(T.shape[0] + k, T.shape[0] + P.shape[0]), 179 | 'k_reset': k_reset, 180 | 'q': probing.mask_one(T.shape[0] + q, T.shape[0] + P.shape[0]), 181 | 'q_reset': 1, 182 | 's': probing.mask_one(0, T.shape[0] + P.shape[0]), 183 | 'i': probing.mask_one(0, T.shape[0] + P.shape[0]), 184 | 'phase': 0 185 | }) 186 | if k_reset == 1: 187 | k_reset = 0 188 | k = -1 189 | if P[k + 1] == P[q]: 190 | k += 1 191 | if k == -1: 192 | k = 0 193 | k_reset = 1 194 | is_reset[q] = 1 195 | pi[q] = k 196 | probing.push( 197 | probes, 198 | specs.Stage.HINT, 199 | next_probe={ 200 | 'pred_h': probing.strings_pred(T_pos, P_pos), 201 | 'pi': probing.strings_pi(T_pos, P_pos, pi), 202 | 'is_reset': np.concatenate( 203 | [np.zeros(T.shape[0]), np.copy(is_reset)]), 204 | 'k': probing.mask_one(T.shape[0] + k, T.shape[0] + P.shape[0]), 205 | 'k_reset': k_reset, 206 | 'q': probing.mask_one(T.shape[0] + q, T.shape[0] + P.shape[0]), 207 | 'q_reset': 1, 208 | 's': probing.mask_one(0, T.shape[0] + P.shape[0]), 209 | 'i': probing.mask_one(0, T.shape[0] + P.shape[0]), 210 | 'phase': 0 211 | }) 212 | q = 0 213 | q_reset = 1 214 | s = 0 215 | for i in range(T.shape[0]): 216 | if i >= P.shape[0]: 217 | s += 1 218 | probing.push( 219 | probes, 220 | specs.Stage.HINT, 221 | next_probe={ 222 | 'pred_h': probing.strings_pred(T_pos, P_pos), 223 | 'pi': probing.strings_pi(T_pos, P_pos, pi), 224 | 'is_reset': np.concatenate( 225 | [np.zeros(T.shape[0]), np.copy(is_reset)]), 226 | 'k': probing.mask_one(T.shape[0] + k, T.shape[0] + P.shape[0]), 227 | 'k_reset': k_reset, 228 | 'q': probing.mask_one(T.shape[0] + q, T.shape[0] + P.shape[0]), 229 | 'q_reset': q_reset, 230 | 's': probing.mask_one(s, T.shape[0] + P.shape[0]), 231 | 'i': probing.mask_one(i, T.shape[0] + P.shape[0]), 232 | 'phase': 1 233 | }) 234 | while q_reset == 0 and P[q + 1] != T[i]: 235 | if is_reset[q] == 1: 236 | q = 0 237 | q_reset = 1 238 | else: 239 | q = pi[q] 240 | probing.push( 241 | probes, 242 | specs.Stage.HINT, 243 | next_probe={ 244 | 'pred_h': probing.strings_pred(T_pos, P_pos), 245 | 'pi': probing.strings_pi(T_pos, P_pos, pi), 246 | 'is_reset': np.concatenate( 247 | [np.zeros(T.shape[0]), np.copy(is_reset)]), 248 | 'k': probing.mask_one(T.shape[0] + k, T.shape[0] + P.shape[0]), 249 | 'k_reset': k_reset, 250 | 'q': probing.mask_one(T.shape[0] + q, T.shape[0] + P.shape[0]), 251 | 'q_reset': q_reset, 252 | 's': probing.mask_one(s, T.shape[0] + P.shape[0]), 253 | 'i': probing.mask_one(i, 
T.shape[0] + P.shape[0]), 254 | 'phase': 1 255 | }) 256 | if q_reset == 1: 257 | q = -1 258 | q_reset = 0 259 | if P[q + 1] == T[i]: 260 | if q == P.shape[0] - 2: 261 | probing.push( 262 | probes, 263 | specs.Stage.OUTPUT, 264 | next_probe={'match': probing.mask_one(s, T.shape[0] + P.shape[0])}) 265 | probing.finalize(probes) 266 | return s, probes 267 | q += 1 268 | if q == -1: 269 | q_reset = 1 270 | q = 0 271 | 272 | # By convention, set probe to head of needle if no match is found 273 | probing.push( 274 | probes, 275 | specs.Stage.OUTPUT, 276 | next_probe={ 277 | 'match': probing.mask_one(T.shape[0], T.shape[0] + P.shape[0]) 278 | }) 279 | probing.finalize(probes) 280 | 281 | return T.shape[0], probes 282 | -------------------------------------------------------------------------------- /clrs/_src/algorithms/strings_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `strings.py`.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from clrs._src.algorithms import strings 22 | import numpy as np 23 | 24 | 25 | class StringsTest(parameterized.TestCase): 26 | 27 | @parameterized.named_parameters( 28 | ("Naive string matching", strings.naive_string_matcher), 29 | ("KMP string matching", strings.kmp_matcher), 30 | ) 31 | def test_string_matching(self, algorithm): 32 | offset, _ = algorithm(np.array([1, 2, 3]), np.array([1, 2, 3])) 33 | self.assertEqual(offset, 0) 34 | offset, _ = algorithm(np.array([1, 2, 3, 1, 2]), np.array([1, 2, 3])) 35 | self.assertEqual(offset, 0) 36 | offset, _ = algorithm(np.array([1, 2, 3, 1, 2, 3]), np.array([1, 2, 3])) 37 | self.assertEqual(offset, 0) 38 | offset, _ = algorithm(np.array([1, 2, 1, 2, 3]), np.array([1, 2, 3])) 39 | self.assertEqual(offset, 2) 40 | offset, _ = algorithm(np.array([3, 2, 1]), np.array([1, 2, 3])) 41 | self.assertEqual(offset, 3) 42 | offset, _ = algorithm(np.array( 43 | [ 44 | 3, 2, 2, 1, 2, 1, 2, 3, 0, 0, 2, 3, 0, 0, 1, 0 45 | ]), np.array([2, 1, 2, 3])) 46 | self.assertEqual(offset, 4) 47 | 48 | if __name__ == "__main__": 49 | absltest.main() 50 | -------------------------------------------------------------------------------- /clrs/_src/baselines_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `baselines.py`.""" 17 | 18 | import copy 19 | import functools 20 | from typing import Generator 21 | 22 | from absl.testing import absltest 23 | from absl.testing import parameterized 24 | import chex 25 | 26 | from clrs._src import baselines 27 | from clrs._src import dataset 28 | from clrs._src import processors 29 | from clrs._src import samplers 30 | from clrs._src import specs 31 | 32 | import haiku as hk 33 | import jax 34 | import numpy as np 35 | 36 | _Array = np.ndarray 37 | 38 | 39 | def _error(x, y): 40 | return np.sum(np.abs(x-y)) 41 | 42 | 43 | def _make_sampler(algo: str, length: int) -> samplers.Sampler: 44 | sampler, _ = samplers.build_sampler( 45 | algo, 46 | seed=samplers.CLRS30['val']['seed'], 47 | num_samples=samplers.CLRS30['val']['num_samples'], 48 | length=length, 49 | ) 50 | return sampler 51 | 52 | 53 | def _make_iterable_sampler( 54 | algo: str, batch_size: int, 55 | length: int) -> Generator[samplers.Feedback, None, None]: 56 | sampler = _make_sampler(algo, length) 57 | while True: 58 | yield sampler.next(batch_size) 59 | 60 | 61 | class BaselinesTest(parameterized.TestCase): 62 | 63 | def test_full_vs_chunked(self): 64 | """Test that chunking does not affect gradients.""" 65 | 66 | batch_size = 4 67 | length = 8 68 | algo = 'insertion_sort' 69 | spec = specs.SPECS[algo] 70 | rng_key = jax.random.PRNGKey(42) 71 | 72 | full_ds = _make_iterable_sampler(algo, batch_size, length) 73 | chunked_ds = dataset.chunkify( 74 | _make_iterable_sampler(algo, batch_size, length), 75 | length) 76 | double_chunked_ds = dataset.chunkify( 77 | _make_iterable_sampler(algo, batch_size, length), 78 | length * 2) 79 | 80 | full_batches = [next(full_ds) for _ in range(2)] 81 | chunked_batches = [next(chunked_ds) for _ in range(2)] 82 | double_chunk_batch = next(double_chunked_ds) 83 | 84 | with chex.fake_jit(): # jitting makes test longer 85 | 86 | processor_factory = processors.get_processor_factory('mpnn', use_ln=False) 87 | common_args = dict(processor_factory=processor_factory, hidden_dim=8, 88 | learning_rate=0.01, decode_diffs=True, 89 | decode_hints=True, encode_hints=True) 90 | 91 | b_full = baselines.BaselineModel( 92 | spec, dummy_trajectory=full_batches[0], **common_args) 93 | b_full.init(full_batches[0].features, seed=0) 94 | full_params = b_full.params 95 | full_loss_0 = b_full.feedback(rng_key, full_batches[0]) 96 | b_full.params = full_params 97 | full_loss_1 = b_full.feedback(rng_key, full_batches[1]) 98 | new_full_params = b_full.params 99 | 100 | b_chunked = baselines.BaselineModelChunked( 101 | spec, dummy_trajectory=chunked_batches[0], **common_args) 102 | b_chunked.init(chunked_batches[0].features, seed=0) 103 | chunked_params = b_chunked.params 104 | jax.tree_map(np.testing.assert_array_equal, 105 | full_params, chunked_params) 106 | chunked_loss_0 = b_chunked.feedback(rng_key, chunked_batches[0]) 107 | b_chunked.params = chunked_params 108 | chunked_loss_1 = b_chunked.feedback(rng_key, chunked_batches[1]) 109 | 
new_chunked_params = b_chunked.params 110 | 111 | b_chunked.params = chunked_params 112 | double_chunked_loss = b_chunked.feedback(rng_key, double_chunk_batch) 113 | 114 | # Test that losses match 115 | np.testing.assert_allclose(full_loss_0, chunked_loss_0, rtol=1e-4) 116 | np.testing.assert_allclose(full_loss_1, chunked_loss_1, rtol=1e-4) 117 | np.testing.assert_allclose(full_loss_0 + full_loss_1, 118 | 2 * double_chunked_loss, 119 | rtol=1e-4) 120 | 121 | # Test that gradients are the same (parameters changed equally). 122 | # First check that gradients were not zero, i.e., parameters have changed. 123 | param_change, _ = jax.tree_flatten( 124 | jax.tree_map(_error, full_params, new_full_params)) 125 | self.assertGreater(np.mean(param_change), 0.1) 126 | # Now check that full and chunked gradients are the same. 127 | jax.tree_map(functools.partial(np.testing.assert_allclose, rtol=1e-4), 128 | new_full_params, new_chunked_params) 129 | 130 | def test_multi_vs_single(self): 131 | """Test that multi = single when we only train one of the algorithms.""" 132 | 133 | batch_size = 4 134 | length = 16 135 | algos = ['insertion_sort', 'activity_selector', 'bfs'] 136 | spec = [specs.SPECS[algo] for algo in algos] 137 | rng_key = jax.random.PRNGKey(42) 138 | 139 | full_ds = [_make_iterable_sampler(algo, batch_size, length) 140 | for algo in algos] 141 | full_batches = [next(ds) for ds in full_ds] 142 | full_batches_2 = [next(ds) for ds in full_ds] 143 | 144 | with chex.fake_jit(): # jitting makes test longer 145 | 146 | processor_factory = processors.get_processor_factory('mpnn', use_ln=False) 147 | common_args = dict(processor_factory=processor_factory, hidden_dim=8, 148 | learning_rate=0.01, decode_diffs=True, 149 | decode_hints=True, encode_hints=True) 150 | 151 | b_single = baselines.BaselineModel( 152 | spec[0], dummy_trajectory=full_batches[0], **common_args) 153 | b_multi = baselines.BaselineModel( 154 | spec, dummy_trajectory=full_batches, **common_args) 155 | b_single.init(full_batches[0].features, seed=0) 156 | b_multi.init([f.features for f in full_batches], seed=0) 157 | 158 | single_params = [] 159 | single_losses = [] 160 | multi_params = [] 161 | multi_losses = [] 162 | 163 | single_params.append(copy.deepcopy(b_single.params)) 164 | single_losses.append(b_single.feedback(rng_key, full_batches[0])) 165 | single_params.append(copy.deepcopy(b_single.params)) 166 | single_losses.append(b_single.feedback(rng_key, full_batches_2[0])) 167 | single_params.append(copy.deepcopy(b_single.params)) 168 | 169 | multi_params.append(copy.deepcopy(b_multi.params)) 170 | multi_losses.append(b_multi.feedback(rng_key, full_batches[0], 171 | algorithm_index=0)) 172 | multi_params.append(copy.deepcopy(b_multi.params)) 173 | multi_losses.append(b_multi.feedback(rng_key, full_batches_2[0], 174 | algorithm_index=0)) 175 | multi_params.append(copy.deepcopy(b_multi.params)) 176 | 177 | # Test that losses match 178 | np.testing.assert_array_equal(single_losses, multi_losses) 179 | # Test that loss decreased 180 | assert single_losses[1] < single_losses[0] 181 | 182 | # Test that param changes were the same in single and multi-algorithm 183 | for single, multi in zip(single_params, multi_params): 184 | assert hk.data_structures.is_subset(subset=single, superset=multi) 185 | for module_name, params in single.items(): 186 | jax.tree_map(np.testing.assert_array_equal, params, multi[module_name]) 187 | 188 | # Test that params change for the trained algorithm, but not the others 189 | for module_name, params in 
multi_params[0].items(): 190 | param_changes = jax.tree_map(lambda a, b: np.sum(np.abs(a-b)), 191 | params, multi_params[1][module_name]) 192 | param_change = sum(param_changes.values()) 193 | if module_name in single_params[0]: # params of trained algorithm 194 | assert param_change > 1e-3 195 | else: # params of non-trained algorithms 196 | assert param_change == 0.0 197 | 198 | @parameterized.parameters(True, False) 199 | def test_multi_algorithm_idx(self, is_chunked): 200 | """Test that algorithm selection works as intended.""" 201 | 202 | batch_size = 4 203 | length = 8 204 | algos = ['insertion_sort', 'activity_selector', 'bfs'] 205 | spec = [specs.SPECS[algo] for algo in algos] 206 | rng_key = jax.random.PRNGKey(42) 207 | 208 | if is_chunked: 209 | ds = [dataset.chunkify(_make_iterable_sampler(algo, batch_size, length), 210 | 2 * length) for algo in algos] 211 | else: 212 | ds = [_make_iterable_sampler(algo, batch_size, length) for algo in algos] 213 | batches = [next(d) for d in ds] 214 | 215 | with chex.fake_jit(): # jitting makes test longer 216 | processor_factory = processors.get_processor_factory('mpnn', use_ln=False) 217 | common_args = dict(processor_factory=processor_factory, hidden_dim=8, 218 | learning_rate=0.01, decode_diffs=True, 219 | decode_hints=True, encode_hints=True) 220 | if is_chunked: 221 | baseline = baselines.BaselineModelChunked( 222 | spec, dummy_trajectory=batches, **common_args) 223 | else: 224 | baseline = baselines.BaselineModel( 225 | spec, dummy_trajectory=batches, **common_args) 226 | baseline.init([f.features for f in batches], seed=0) 227 | 228 | # Find out what parameters change when we train each algorithm 229 | def _change(x, y): 230 | changes = {} 231 | for module_name, params in x.items(): 232 | changes[module_name] = sum( 233 | jax.tree_map( 234 | lambda a, b: np.sum(np.abs(a-b)), params, y[module_name] 235 | ).values()) 236 | return changes 237 | 238 | param_changes = [] 239 | for algo_idx in range(len(algos)): 240 | init_params = copy.deepcopy(baseline.params) 241 | _ = baseline.feedback( 242 | rng_key, batches[algo_idx], algorithm_index=algo_idx) 243 | param_changes.append(_change(init_params, baseline.params)) 244 | 245 | # Test that non-changing parameters correspond to encoders/decoders 246 | # associated with the non-trained algorithms 247 | unchanged = [[k for k in pc if pc[k] == 0] for pc in param_changes] 248 | 249 | def _get_other_algos(algo_idx, modules): 250 | return set([k for k in modules if '_construct_encoders_decoders' in k 251 | and f'algo_{algo_idx}' not in k]) 252 | 253 | for algo_idx in range(len(algos)): 254 | expected_unchanged = _get_other_algos(algo_idx, baseline.params.keys()) 255 | self.assertNotEmpty(expected_unchanged) 256 | self.assertSetEqual(expected_unchanged, set(unchanged[algo_idx])) 257 | 258 | 259 | if __name__ == '__main__': 260 | absltest.main() 261 | -------------------------------------------------------------------------------- /clrs/_src/dataset_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `dataset.py`.""" 17 | 18 | from typing import Generator, List 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | from clrs._src import dataset 24 | from clrs._src import samplers 25 | from clrs._src import specs 26 | import numpy as np 27 | 28 | _Array = np.ndarray 29 | 30 | 31 | def _stack_to_shortest(x: List[_Array]) -> _Array: 32 | min_len = min(map(len, x)) 33 | return np.array([a[:min_len] for a in x]) 34 | 35 | 36 | def _make_sampler(algo: str) -> samplers.Sampler: 37 | sampler, _ = samplers.build_sampler( 38 | algo, 39 | seed=samplers.CLRS30['val']['seed'], 40 | num_samples=samplers.CLRS30['val']['num_samples'], 41 | length=samplers.CLRS30['val']['length'], 42 | ) 43 | return sampler 44 | 45 | 46 | def _make_iterable_sampler( 47 | algo: str, batch_size: int) -> Generator[samplers.Feedback, None, None]: 48 | sampler = _make_sampler(algo) 49 | while True: 50 | yield sampler.next(batch_size) 51 | 52 | 53 | class DatasetTest(parameterized.TestCase): 54 | 55 | @parameterized.product( 56 | name=specs.CLRS_30_ALGS[:5], 57 | chunk_length=[20, 50]) 58 | def test_chunkify(self, name: str, chunk_length: int): 59 | """Test that samples are concatenated and split in chunks correctly.""" 60 | batch_size = 8 61 | 62 | ds = _make_iterable_sampler(name, batch_size) 63 | chunked_ds = dataset.chunkify( 64 | _make_iterable_sampler(name, batch_size), 65 | chunk_length) 66 | 67 | samples = [next(ds) for _ in range(20)] 68 | cum_lengths = np.cumsum([s.features.lengths for s in samples], axis=0) 69 | n_chunks = np.amax(cum_lengths[-1]).astype(int) // chunk_length + 1 70 | chunks = [next(chunked_ds) for _ in range(n_chunks)] 71 | 72 | # Check correctness of `is_first` and `is_last` markers 73 | start_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate( 74 | [c.features.is_first for c in chunks]).T]).T 75 | end_idx = _stack_to_shortest([np.where(x)[0] for x in np.concatenate( 76 | [c.features.is_last for c in chunks]).T]).T 77 | assert len(start_idx) >= len(cum_lengths) 78 | start_idx = start_idx[:len(cum_lengths)] 79 | assert len(end_idx) >= len(cum_lengths) 80 | end_idx = end_idx[:len(cum_lengths)] 81 | 82 | np.testing.assert_equal(start_idx[0], 0) 83 | np.testing.assert_array_equal(cum_lengths - 1, end_idx) 84 | np.testing.assert_array_equal(cum_lengths[:-1], start_idx[1:]) 85 | 86 | # Check that inputs, outputs and hints have been copied correctly 87 | all_input = np.concatenate([c.features.inputs[0].data for c in chunks]) 88 | all_output = np.concatenate([c.outputs[0].data for c in chunks]) 89 | all_hint = np.concatenate([c.features.hints[0].data for c in chunks]) 90 | for i in range(batch_size): 91 | length0 = int(samples[0].features.lengths[i]) 92 | length1 = int(samples[1].features.lengths[i]) 93 | # Check first sample 94 | np.testing.assert_array_equal( 95 | all_input[:length0, i], 96 | np.tile(samples[0].features.inputs[0].data[i], [length0, 1])) 97 | np.testing.assert_array_equal( 98 | 
all_output[:length0, i], 99 | np.tile(samples[0].outputs[0].data[i], [length0, 1])) 100 | np.testing.assert_array_equal( 101 | all_hint[:length0, i], 102 | samples[0].features.hints[0].data[:length0, i]) 103 | # Check second sample 104 | np.testing.assert_array_equal( 105 | all_input[length0:length0 + length1, i], 106 | np.tile(samples[1].features.inputs[0].data[i], [length1, 1])) 107 | np.testing.assert_array_equal( 108 | all_output[length0:length0 + length1, i], 109 | np.tile(samples[1].outputs[0].data[i], [length1, 1])) 110 | np.testing.assert_array_equal( 111 | all_hint[length0:length0 + length1, i], 112 | samples[1].features.hints[0].data[:length1, i]) 113 | 114 | 115 | if __name__ == '__main__': 116 | absltest.main() 117 | -------------------------------------------------------------------------------- /clrs/_src/decoders.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """decoders utilities.""" 16 | 17 | import functools 18 | from typing import Dict 19 | import chex 20 | from clrs._src import probing 21 | from clrs._src import specs 22 | import haiku as hk 23 | import jax.numpy as jnp 24 | 25 | _Array = chex.Array 26 | _DataPoint = probing.DataPoint 27 | _Location = specs.Location 28 | _Spec = specs.Spec 29 | _Stage = specs.Stage 30 | _Type = specs.Type 31 | 32 | 33 | def construct_decoders(loc: str, t: str, hidden_dim: int, nb_dims: int, 34 | name: str): 35 | """Constructs decoders.""" 36 | linear = functools.partial(hk.Linear, name=f"{name}_dec_linear") 37 | if loc == _Location.NODE: 38 | # Node decoders. 39 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 40 | decoders = (linear(1),) 41 | elif t == _Type.CATEGORICAL: 42 | decoders = (linear(nb_dims),) 43 | elif t == _Type.POINTER: 44 | decoders = (linear(hidden_dim), linear(hidden_dim), linear(hidden_dim), 45 | linear(1)) 46 | else: 47 | raise ValueError(f"Invalid Type {t}") 48 | 49 | elif loc == _Location.EDGE: 50 | # Edge decoders. 51 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 52 | decoders = (linear(1), linear(1), linear(1)) 53 | elif t == _Type.CATEGORICAL: 54 | decoders = (linear(nb_dims), linear(nb_dims), linear(nb_dims)) 55 | elif t == _Type.POINTER: 56 | decoders = (linear(hidden_dim), linear(hidden_dim), 57 | linear(hidden_dim), linear(hidden_dim), linear(1)) 58 | else: 59 | raise ValueError(f"Invalid Type {t}") 60 | 61 | elif loc == _Location.GRAPH: 62 | # Graph decoders. 
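    | # A worked sketch (hypothetical arguments): for a graph-level
    | # CATEGORICAL probe with nb_dims=4, this branch returns two
    | # hk.Linear(4) heads -- one applied to the max-pooled node
    | # embeddings and one to the graph features -- whose logits are
    | # summed in `_decode_graph_fts` below. POINTER probes instead get
    | # three scalar heads that score every node as the pointer target.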
63 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 64 | decoders = (linear(1), linear(1)) 65 | elif t == _Type.CATEGORICAL: 66 | decoders = (linear(nb_dims), linear(nb_dims)) 67 | elif t == _Type.POINTER: 68 | decoders = (linear(1), linear(1), 69 | linear(1)) 70 | else: 71 | raise ValueError(f"Invalid Type {t}") 72 | 73 | else: 74 | raise ValueError(f"Invalid Location {loc}") 75 | 76 | return decoders 77 | 78 | 79 | def construct_diff_decoders(name: str): 80 | """Constructs diff decoders.""" 81 | linear = functools.partial(hk.Linear, name=f"{name}_diffdec_linear") 82 | decoders = {} 83 | decoders[_Location.NODE] = linear(1) 84 | decoders[_Location.EDGE] = (linear(1), linear(1), linear(1)) 85 | decoders[_Location.GRAPH] = (linear(1), linear(1)) 86 | 87 | return decoders 88 | 89 | 90 | def postprocess(spec: _Spec, preds: Dict[str, _Array]) -> Dict[str, _DataPoint]: 91 | """Postprocesses decoder output.""" 92 | result = {} 93 | for name in preds.keys(): 94 | _, loc, t = spec[name] 95 | data = preds[name] 96 | if t == _Type.SCALAR: 97 | pass 98 | elif t == _Type.MASK: 99 | data = (data > 0.0) * 1.0 100 | elif t in [_Type.MASK_ONE, _Type.CATEGORICAL]: 101 | cat_size = data.shape[-1] 102 | best = jnp.argmax(data, -1) 103 | data = hk.one_hot(best, cat_size) 104 | elif t == _Type.POINTER: 105 | data = jnp.argmax(data, -1) 106 | else: 107 | raise ValueError("Invalid type") 108 | result[name] = probing.DataPoint( 109 | name=name, location=loc, type_=t, data=data) 110 | 111 | return result 112 | 113 | 114 | def decode_fts( 115 | decoders, 116 | spec: _Spec, 117 | h_t: _Array, 118 | adj_mat: _Array, 119 | edge_fts: _Array, 120 | graph_fts: _Array, 121 | inf_bias: bool, 122 | inf_bias_edge: bool, 123 | ): 124 | """Decodes node, edge and graph features.""" 125 | output_preds = {} 126 | hint_preds = {} 127 | 128 | for name in decoders: 129 | decoder = decoders[name] 130 | stage, loc, t = spec[name] 131 | 132 | if loc == _Location.NODE: 133 | preds = _decode_node_fts(decoder, t, h_t, edge_fts, adj_mat, 134 | inf_bias) 135 | elif loc == _Location.EDGE: 136 | preds = _decode_edge_fts(decoder, t, h_t, edge_fts, adj_mat, 137 | inf_bias_edge) 138 | elif loc == _Location.GRAPH: 139 | preds = _decode_graph_fts(decoder, t, h_t, graph_fts) 140 | else: 141 | raise ValueError("Invalid output type") 142 | 143 | if stage == _Stage.OUTPUT: 144 | output_preds[name] = preds 145 | elif stage == _Stage.HINT: 146 | hint_preds[name] = preds 147 | else: 148 | raise ValueError(f"Found unexpected decoder {name}") 149 | 150 | return hint_preds, output_preds 151 | 152 | 153 | def _decode_node_fts(decoders, t: str, h_t: _Array, edge_fts: _Array, 154 | adj_mat: _Array, inf_bias: bool) -> _Array: 155 | """Decodes node features.""" 156 | 157 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 158 | preds = jnp.squeeze(decoders[0](h_t), -1) 159 | elif t == _Type.CATEGORICAL: 160 | preds = decoders[0](h_t) 161 | elif t == _Type.POINTER: 162 | p_1 = decoders[0](h_t) 163 | p_2 = decoders[1](h_t) 164 | p_3 = decoders[2](edge_fts) 165 | 166 | p_e = jnp.expand_dims(p_2, -2) + p_3 167 | p_m = jnp.maximum(jnp.expand_dims(p_1, -2), 168 | jnp.transpose(p_e, (0, 2, 1, 3))) 169 | 170 | preds = jnp.squeeze(decoders[3](p_m), -1) 171 | 172 | if inf_bias: 173 | per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True) 174 | preds = jnp.where(adj_mat > 0.5, 175 | preds, 176 | jnp.minimum(-1.0, per_batch_min - 1.0)) 177 | else: 178 | raise ValueError("Invalid output type") 179 | 180 | return preds 181 | 182 | 183 | def 
_decode_edge_fts(decoders, t: str, h_t: _Array, edge_fts: _Array, 184 | adj_mat: _Array, inf_bias_edge: bool) -> _Array: 185 | """Decodes edge features.""" 186 | 187 | pred_1 = decoders[0](h_t) 188 | pred_2 = decoders[1](h_t) 189 | pred_e = decoders[2](edge_fts) 190 | pred = (jnp.expand_dims(pred_1, -2) + jnp.expand_dims(pred_2, -3) + pred_e) 191 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 192 | preds = jnp.squeeze(pred, -1) 193 | elif t == _Type.CATEGORICAL: 194 | preds = pred 195 | elif t == _Type.POINTER: 196 | pred_2 = decoders[3](h_t) 197 | 198 | p_m = jnp.maximum(jnp.expand_dims(pred, -2), 199 | jnp.expand_dims( 200 | jnp.expand_dims(pred_2, -3), -3)) 201 | 202 | preds = jnp.squeeze(decoders[4](p_m), -1) 203 | else: 204 | raise ValueError("Invalid output type") 205 | if inf_bias_edge and t in [_Type.MASK, _Type.MASK_ONE]: 206 | per_batch_min = jnp.min(preds, axis=range(1, preds.ndim), keepdims=True) 207 | preds = jnp.where(adj_mat > 0.5, 208 | preds, 209 | jnp.minimum(-1.0, per_batch_min - 1.0)) 210 | 211 | return preds 212 | 213 | 214 | def _decode_graph_fts(decoders, t: str, h_t: _Array, 215 | graph_fts: _Array) -> _Array: 216 | """Decodes graph features.""" 217 | 218 | gr_emb = jnp.max(h_t, axis=-2) 219 | pred_n = decoders[0](gr_emb) 220 | pred_g = decoders[1](graph_fts) 221 | pred = pred_n + pred_g 222 | if t in [_Type.SCALAR, _Type.MASK, _Type.MASK_ONE]: 223 | preds = jnp.squeeze(pred, -1) 224 | elif t == _Type.CATEGORICAL: 225 | preds = pred 226 | elif t == _Type.POINTER: 227 | pred_2 = decoders[2](h_t) 228 | ptr_p = jnp.expand_dims(pred, 1) + jnp.transpose(pred_2, (0, 2, 1)) 229 | preds = jnp.squeeze(ptr_p, 1) 230 | 231 | return preds 232 | 233 | 234 | def maybe_decode_diffs( 235 | diff_decoders, 236 | h_t: _Array, 237 | edge_fts: _Array, 238 | graph_fts: _Array, 239 | batch_size: int, 240 | nb_nodes: int, 241 | decode_diffs: bool, 242 | ) -> Dict[str, _Array]: 243 | """Optionally decodes node, edge and graph diffs.""" 244 | 245 | if decode_diffs: 246 | preds = {} 247 | node = _Location.NODE 248 | edge = _Location.EDGE 249 | graph = _Location.GRAPH 250 | preds[node] = _decode_node_diffs(diff_decoders[node], h_t) 251 | preds[edge] = _decode_edge_diffs(diff_decoders[edge], h_t, edge_fts) 252 | preds[graph] = _decode_graph_diffs(diff_decoders[graph], h_t, graph_fts) 253 | 254 | else: 255 | preds = { 256 | _Location.NODE: jnp.ones((batch_size, nb_nodes)), 257 | _Location.EDGE: jnp.ones((batch_size, nb_nodes, nb_nodes)), 258 | _Location.GRAPH: jnp.ones((batch_size)) 259 | } 260 | 261 | return preds 262 | 263 | 264 | def _decode_node_diffs(decoders, h_t: _Array) -> _Array: 265 | """Decodes node diffs.""" 266 | return jnp.squeeze(decoders(h_t), -1) 267 | 268 | 269 | def _decode_edge_diffs(decoders, h_t: _Array, edge_fts: _Array) -> _Array: 270 | """Decodes edge diffs.""" 271 | 272 | e_pred_1 = decoders[0](h_t) 273 | e_pred_2 = decoders[1](h_t) 274 | e_pred_e = decoders[2](edge_fts) 275 | preds = jnp.squeeze( 276 | jnp.expand_dims(e_pred_1, -1) + jnp.expand_dims(e_pred_2, -2) + e_pred_e, 277 | -1, 278 | ) 279 | 280 | return preds 281 | 282 | 283 | def _decode_graph_diffs(decoders, h_t: _Array, graph_fts: _Array) -> _Array: 284 | """Decodes graph diffs.""" 285 | 286 | gr_emb = jnp.max(h_t, axis=-2) 287 | g_pred_n = decoders[0](gr_emb) 288 | g_pred_g = decoders[1](graph_fts) 289 | preds = jnp.squeeze(g_pred_n + g_pred_g, -1) 290 | 291 | return preds 292 | -------------------------------------------------------------------------------- /clrs/_src/encoders.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | """Encoder utilities.""" 16 | 17 | import functools 18 | import chex 19 | from clrs._src import probing 20 | from clrs._src import specs 21 | import haiku as hk 22 | import jax.numpy as jnp 23 | 24 | _Array = chex.Array 25 | _DataPoint = probing.DataPoint 26 | _Location = specs.Location 27 | _Spec = specs.Spec 28 | _Type = specs.Type 29 | 30 | 31 | def construct_encoders(loc: str, t: str, hidden_dim: int, name: str, algorithm: str): 32 | """Constructs encoders.""" 33 | linear = functools.partial(hk.Linear, name=f'{name}_enc_linear') 34 | encoders = [linear(hidden_dim)] 35 | if loc == _Location.EDGE and t == _Type.POINTER: 36 | # Edge pointers need two-way encoders. 37 | encoders.append(linear(hidden_dim)) 38 | 39 | return encoders 40 | 41 | 42 | def preprocess(dp: _DataPoint, nb_nodes: int) -> _Array: 43 | """Pre-process data point.""" 44 | if dp.type_ == _Type.POINTER: 45 | data = hk.one_hot(dp.data, nb_nodes) 46 | else: 47 | data = dp.data.astype(jnp.float32) 48 | 49 | return data 50 | 51 | 52 | def accum_adj_mat(dp: _DataPoint, data: _Array, adj_mat: _Array) -> _Array: 53 | """Accumulates adjacency matrix.""" 54 | if dp.name == 'pos': # ignore edge position for accumulating adjacency matrix 55 | return (adj_mat > 0.).astype('float32') 56 | 57 | if dp.location == _Location.NODE and dp.type_ == _Type.POINTER: 58 | adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0) 59 | elif dp.location == _Location.EDGE and dp.type_ == _Type.MASK: 60 | adj_mat += ((data + jnp.transpose(data, (0, 2, 1))) > 0.0) 61 | 62 | return (adj_mat > 0.).astype('float32') 63 | 64 | 65 | def accum_edge_fts(encoders, dp: _DataPoint, data: _Array, 66 | edge_fts: _Array) -> _Array: 67 | """Encodes and accumulates edge features.""" 68 | encoding = _encode_inputs(encoders, dp, data) 69 | 70 | if dp.location == _Location.NODE and dp.type_ == _Type.POINTER: 71 | edge_fts += encoding 72 | 73 | elif dp.location == _Location.EDGE: 74 | if dp.type_ == _Type.POINTER: 75 | # Aggregate pointer contributions across sender and receiver nodes. 
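      | # Shape sketch (assuming one-hot edge pointers of shape
      | # [B, N, N, N], with data[b, i, j, k] == 1 iff edge (i, j)
      | # points at node k): each encoder maps the data to
      | # [B, N, N, N, H]; averaging over axis 1 (senders) and axis 2
      | # (receivers) reduces both terms back to [B, N, N, H] so they
      | # can be accumulated into `edge_fts`.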
76 | encoding_2 = encoders[1](jnp.expand_dims(data, -1)) 77 | edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2) 78 | else: 79 | edge_fts += encoding 80 | 81 | return edge_fts 82 | 83 | 84 | def accum_node_fts(encoders, dp: _DataPoint, data: _Array, 85 | node_fts: _Array) -> _Array: 86 | """Encodes and accumulates node features.""" 87 | encoding = _encode_inputs(encoders, dp, data) 88 | 89 | if ((dp.location == _Location.NODE and dp.type_ != _Type.POINTER) or 90 | (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER)): 91 | node_fts += encoding 92 | 93 | return node_fts 94 | 95 | 96 | def accum_graph_fts(encoders, dp: _DataPoint, data: _Array, 97 | graph_fts: _Array) -> _Array: 98 | """Encodes and accumulates graph features.""" 99 | encoding = _encode_inputs(encoders, dp, data) 100 | 101 | if dp.location == _Location.GRAPH and dp.type_ != _Type.POINTER: 102 | graph_fts += encoding 103 | 104 | return graph_fts 105 | 106 | 107 | def _encode_inputs(encoders, dp: _DataPoint, data: _Array) -> _Array: 108 | if dp.type_ == _Type.CATEGORICAL: 109 | encoding = encoders[0](data) 110 | else: 111 | encoding = encoders[0](jnp.expand_dims(data, -1)) 112 | return encoding 113 | -------------------------------------------------------------------------------- /clrs/_src/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | """Utilities for calculating losses.""" 16 | 17 | from typing import Dict, List, Tuple 18 | import chex 19 | from clrs._src import probing 20 | from clrs._src import specs 21 | 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | 26 | _Array = chex.Array 27 | _DataPoint = probing.DataPoint 28 | _Location = specs.Location 29 | _OutputClass = specs.OutputClass 30 | _PredTrajectory = Dict[str, _Array] 31 | _PredTrajectories = List[_PredTrajectory] 32 | _Type = specs.Type 33 | 34 | EPS = 1e-12 35 | 36 | 37 | def _expand_to(x: _Array, y: _Array) -> _Array: 38 | while len(y.shape) > len(x.shape): 39 | x = jnp.expand_dims(x, -1) 40 | return x 41 | 42 | 43 | def _expand_and_broadcast_to(x: _Array, y: _Array) -> _Array: 44 | return jnp.broadcast_to(_expand_to(x, y), y.shape) 45 | 46 | 47 | def output_loss_chunked(truth: _DataPoint, pred: _Array, 48 | is_last: _Array, nb_nodes: int) -> float: 49 | """Output loss for time-chunked training.""" 50 | 51 | mask = None 52 | 53 | if truth.type_ == _Type.SCALAR: 54 | loss = (pred - truth.data)**2 55 | 56 | elif truth.type_ == _Type.MASK: 57 | loss = ( 58 | jnp.maximum(pred, 0) - pred * truth.data + 59 | jnp.log1p(jnp.exp(-jnp.abs(pred)))) 60 | mask = (truth.data != _OutputClass.MASKED) 61 | 62 | elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]: 63 | mask = jnp.any(truth.data == _OutputClass.POSITIVE, axis=-1) 64 | masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype( 65 | jnp.float32) 66 | loss = -jnp.sum(masked_truth * jax.nn.log_softmax(pred), axis=-1) 67 | 68 | elif truth.type_ == _Type.POINTER: 69 | loss = -jnp.sum( 70 | hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), axis=-1) 71 | 72 | if mask is not None: 73 | mask = mask * _expand_and_broadcast_to(is_last, loss) 74 | else: 75 | mask = _expand_and_broadcast_to(is_last, loss) 76 | total_mask = jnp.maximum(jnp.sum(mask), EPS) 77 | return jnp.sum(jnp.where(mask, loss, 0.0)) / total_mask 78 | 79 | 80 | def output_loss(truth: _DataPoint, pred: _Array, nb_nodes: int) -> float: 81 | """Output loss for full-sample training.""" 82 | 83 | if truth.type_ == _Type.SCALAR: 84 | total_loss = jnp.mean((pred - truth.data)**2) 85 | 86 | elif truth.type_ == _Type.MASK: 87 | loss = ( 88 | jnp.maximum(pred, 0) - pred * truth.data + 89 | jnp.log1p(jnp.exp(-jnp.abs(pred)))) 90 | mask = (truth.data != _OutputClass.MASKED).astype(jnp.float32) 91 | total_loss = jnp.sum(loss * mask) / jnp.sum(mask) 92 | 93 | elif truth.type_ in [_Type.MASK_ONE, _Type.CATEGORICAL]: 94 | masked_truth = truth.data * (truth.data != _OutputClass.MASKED).astype( 95 | jnp.float32) 96 | total_loss = (-jnp.sum(masked_truth * jax.nn.log_softmax(pred)) / 97 | jnp.sum(truth.data == _OutputClass.POSITIVE)) 98 | 99 | elif truth.type_ == _Type.POINTER: 100 | total_loss = ( 101 | jnp.mean(-jnp.sum( 102 | hk.one_hot(truth.data, nb_nodes) * jax.nn.log_softmax(pred), 103 | axis=-1))) 104 | 105 | return total_loss 106 | 107 | 108 | def diff_loss_chunked(diff_logits, gt_diffs, is_first): 109 | """Diff loss for time-chunked training.""" 110 | total_loss = 0. 
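  | # The three-term expression below is the numerically stable form of
  | # sigmoid cross-entropy with logits x and targets z:
  | #   max(x, 0) - x * z + log(1 + exp(-|x|))
  | # e.g. x = 2.0, z = 1.0 gives log(1 + exp(-2)) ~= 0.127, i.e.
  | # -log(sigmoid(2.0)). Positions flagged `is_first` begin a new
  | # sample, so they have no previous step to diff against and are
  | # excluded from the average.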
111 | for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: 112 | valid = (1 - _expand_and_broadcast_to(is_first, diff_logits[loc])).astype( 113 | jnp.float32) 114 | total_valid = jnp.maximum(jnp.sum(valid), EPS) 115 | loss = ( 116 | jnp.maximum(diff_logits[loc], 0) - 117 | diff_logits[loc] * gt_diffs[loc] + 118 | jnp.log1p(jnp.exp(-jnp.abs(diff_logits[loc])))) 119 | total_loss += jnp.sum(jnp.where(valid, loss, 0.0)) / total_valid 120 | return total_loss 121 | 122 | 123 | def diff_loss(diff_logits, gt_diffs, lengths, verbose=False): 124 | """Diff loss for full-sample training.""" 125 | total_loss = 0. 126 | verbose_loss = dict() 127 | length = len(gt_diffs) 128 | 129 | for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: 130 | for i in range(length): 131 | loss = _diff_loss(loc, i, diff_logits, gt_diffs, lengths) / length 132 | if verbose: 133 | verbose_loss[loc + '_diff_%d' % i] = loss 134 | else: 135 | total_loss += loss 136 | 137 | return verbose_loss if verbose else total_loss 138 | 139 | 140 | def _diff_loss(loc, i, diff_logits, gt_diffs, lengths) -> float: 141 | """Full-sample diff loss helper.""" 142 | is_not_done = _is_not_done_broadcast(lengths, i, diff_logits[i][loc]) 143 | loss = ( 144 | jnp.maximum(diff_logits[i][loc], 0) - 145 | diff_logits[i][loc] * gt_diffs[i][loc] + 146 | jnp.log1p(jnp.exp(-jnp.abs(diff_logits[i][loc]))) * is_not_done) 147 | 148 | return jnp.mean(loss) 149 | 150 | 151 | def hint_loss_chunked( 152 | truth: _DataPoint, 153 | pred: _Array, 154 | gt_diffs: _PredTrajectory, 155 | is_first: _Array, 156 | nb_nodes: int, 157 | decode_diffs: bool, 158 | ): 159 | """Hint loss for time-chunked training.""" 160 | loss, mask = _hint_loss( 161 | truth_data=truth.data, 162 | truth_type=truth.type_, 163 | pred=pred, 164 | nb_nodes=nb_nodes, 165 | ) 166 | 167 | mask *= (1 - _expand_to(is_first, loss)).astype(jnp.float32) 168 | if decode_diffs: 169 | mask *= gt_diffs[truth.location] 170 | loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS) 171 | return loss 172 | 173 | 174 | def hint_loss( 175 | truth: _DataPoint, 176 | preds: List[_Array], 177 | gt_diffs: _PredTrajectories, 178 | lengths: _Array, 179 | nb_nodes: int, 180 | decode_diffs: bool, 181 | verbose: bool = False, 182 | ): 183 | """Hint loss for full-sample training.""" 184 | total_loss = 0. 
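  | # Time layout sketch: `truth.data` has shape [T + 1, B, ...] (the
  | # initial hint plus T steps), while `preds` holds T arrays, one per
  | # processor step. Prediction t is therefore scored against
  | # `truth.data[1:][t]`, and `_is_not_done_broadcast` zeroes out steps
  | # at or beyond each sample's true length.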
185 | verbose_loss = {} 186 | length = truth.data.shape[0] - 1 187 | 188 | loss, mask = _hint_loss( 189 | truth_data=truth.data[1:], 190 | truth_type=truth.type_, 191 | pred=jnp.stack(preds), 192 | nb_nodes=nb_nodes, 193 | ) 194 | mask *= _is_not_done_broadcast(lengths, jnp.arange(length)[:, None], loss) 195 | if decode_diffs: 196 | mask *= jnp.stack([g[truth.location] for g in gt_diffs]) 197 | loss = jnp.sum(loss * mask) / jnp.maximum(jnp.sum(mask), EPS) 198 | if verbose: 199 | verbose_loss['loss_' + truth.name] = loss 200 | else: 201 | total_loss += loss 202 | 203 | return verbose_loss if verbose else total_loss 204 | 205 | 206 | def _hint_loss( 207 | truth_data: _Array, 208 | truth_type: str, 209 | pred: _Array, 210 | nb_nodes: int, 211 | ) -> Tuple[_Array, _Array]: 212 | """Hint loss helper.""" 213 | mask = None 214 | if truth_type == _Type.SCALAR: 215 | loss = (pred - truth_data)**2 216 | 217 | elif truth_type == _Type.MASK: 218 | loss = (jnp.maximum(pred, 0) - pred * truth_data + 219 | jnp.log1p(jnp.exp(-jnp.abs(pred)))) 220 | mask = (truth_data != _OutputClass.MASKED).astype(jnp.float32) 221 | 222 | elif truth_type == _Type.MASK_ONE: 223 | loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1, 224 | keepdims=True) 225 | 226 | elif truth_type == _Type.CATEGORICAL: 227 | loss = -jnp.sum(truth_data * jax.nn.log_softmax(pred), axis=-1) 228 | mask = jnp.any(truth_data == _OutputClass.POSITIVE, axis=-1).astype( 229 | jnp.float32) 230 | 231 | elif truth_type == _Type.POINTER: 232 | loss = -jnp.sum( 233 | hk.one_hot(truth_data, nb_nodes) * jax.nn.log_softmax(pred), 234 | axis=-1) 235 | 236 | if mask is None: 237 | mask = jnp.ones_like(loss) 238 | return loss, mask 239 | 240 | 241 | def _is_not_done_broadcast(lengths, i, tensor): 242 | is_not_done = (lengths > i + 1) * 1.0 243 | while len(is_not_done.shape) < len(tensor.shape): 244 | is_not_done = jnp.expand_dims(is_not_done, -1) 245 | return is_not_done 246 | -------------------------------------------------------------------------------- /clrs/_src/losses_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for `losses.py`.""" 17 | 18 | from typing import Generator 19 | 20 | from absl.testing import absltest 21 | from absl.testing import parameterized 22 | 23 | from clrs._src import dataset 24 | from clrs._src import losses 25 | from clrs._src import probing 26 | from clrs._src import samplers 27 | from clrs._src import specs 28 | import jax 29 | import jax.numpy as jnp 30 | import numpy as np 31 | 32 | _Array = np.ndarray 33 | _Location = specs.Location 34 | 35 | 36 | def _make_sampler(algo: str, nb_nodes: int) -> samplers.Sampler: 37 | sampler, _ = samplers.build_sampler( 38 | algo, 39 | seed=samplers.CLRS30['val']['seed'], 40 | num_samples=samplers.CLRS30['val']['num_samples'], 41 | length=nb_nodes, 42 | ) 43 | return sampler 44 | 45 | 46 | def _make_iterable_sampler( 47 | algo: str, batch_size: int, 48 | nb_nodes: int) -> Generator[samplers.Feedback, None, None]: 49 | sampler = _make_sampler(algo, nb_nodes) 50 | while True: 51 | yield sampler.next(batch_size) 52 | 53 | 54 | def _as_pred_data(x, nb_nodes, seed, batch_axis): 55 | """Fake a prediction from a data point.""" 56 | # Permute along batch axis to make the prediction different. 57 | key = jax.random.PRNGKey(seed) 58 | data = jax.random.permutation(key, x.data, axis=batch_axis) 59 | # Extend to one-hot for pointer types. 60 | if x.type_ == specs.Type.POINTER: 61 | return jax.nn.one_hot(data, nb_nodes) 62 | return data 63 | 64 | 65 | def _mask_datapoint(x, seed, t_axis=None): 66 | """Add some masking to data.""" 67 | key = jax.random.PRNGKey(seed) 68 | data = x.data 69 | if x.type_ == specs.Type.MASK: 70 | # mask some data at random 71 | mask_shape = list(data.shape) 72 | if t_axis is not None: 73 | mask_shape[t_axis] = 1 74 | mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2 75 | data = jnp.where(mask, specs.OutputClass.MASKED, data) 76 | elif x.type_ in [specs.Type.CATEGORICAL, specs.Type.MASK_ONE]: 77 | # mask some data at random (all categories together) 78 | mask_shape = list(data.shape)[:-1] 79 | if t_axis is not None: 80 | mask_shape[t_axis] = 1 81 | mask = jax.random.uniform(key, tuple(mask_shape)) < 0.2 82 | data = jnp.where(mask[..., None], specs.OutputClass.MASKED, data) 83 | return probing.DataPoint(name=x.name, location=x.location, type_=x.type_, 84 | data=data) 85 | 86 | 87 | def _rand_diff(seed, shape): 88 | return 2.0 * jax.random.uniform(jax.random.PRNGKey(seed), shape) - 1.0 89 | 90 | 91 | def _rand_mask(seed, shape, p=0.5): 92 | return (jax.random.uniform(jax.random.PRNGKey(seed), shape) > p).astype(float) 93 | 94 | 95 | def invert(d): 96 | """Dict of lists -> list of dicts.""" 97 | if d: 98 | return [dict(zip(d, i)) for i in zip(*d.values())] 99 | 100 | 101 | def _create_data(algo, nb_nodes): 102 | batch_size = 8 103 | 104 | ds = _make_iterable_sampler(algo, batch_size, nb_nodes) 105 | full_sample = next(ds) 106 | 107 | chunk_length = full_sample.features.lengths[0].astype(int) 108 | chunked_ds = dataset.chunkify( 109 | _make_iterable_sampler(algo, batch_size, nb_nodes), 110 | chunk_length) 111 | chunk_sample = next(chunked_ds) 112 | 113 | gt_diffs = { 114 | _Location.NODE: _rand_mask(0, (chunk_length, batch_size, 115 | nb_nodes)), 116 | _Location.EDGE: _rand_mask(1, (chunk_length, batch_size, 117 | nb_nodes, nb_nodes)), 118 | _Location.GRAPH: _rand_mask(2, (chunk_length, batch_size)), 119 | } 120 | diff_logits = { 121 | _Location.NODE: _rand_diff(3, (chunk_length, batch_size, 122 | nb_nodes)), 123 | 
_Location.EDGE: _rand_diff(4, (chunk_length, batch_size, 124 | nb_nodes, nb_nodes)), 125 | _Location.GRAPH: _rand_diff(5, (chunk_length, batch_size)), 126 | } 127 | 128 | return full_sample, chunk_sample, gt_diffs, diff_logits 129 | 130 | 131 | class FullVsChunkLossesTest(parameterized.TestCase): 132 | """Test that the full and chunked versions of the losses match.""" 133 | 134 | # Test two algorithms with fixed-length, covering all data types 135 | @parameterized.parameters('dfs', 'floyd_warshall') 136 | def test_output_loss(self, algo): 137 | nb_nodes = 16 138 | full_sample, chunk_sample, _, _ = _create_data(algo, nb_nodes) 139 | 140 | # Calculate output loss. 141 | for truth_full, truth_chunked in zip(full_sample.outputs, 142 | chunk_sample.outputs): 143 | chunk_output_loss = losses.output_loss_chunked( 144 | truth=_mask_datapoint(truth_chunked, seed=0), 145 | pred=_as_pred_data(truth_chunked, nb_nodes, 0, 1), 146 | is_last=chunk_sample.features.is_last, 147 | nb_nodes=nb_nodes, 148 | ) 149 | full_output_loss = losses.output_loss( 150 | truth=_mask_datapoint(truth_full, seed=0), 151 | pred=_as_pred_data(truth_full, nb_nodes, 0, 0), 152 | nb_nodes=nb_nodes, 153 | ) 154 | np.testing.assert_allclose(chunk_output_loss, full_output_loss, rtol=1e-4) 155 | 156 | @parameterized.parameters('dfs', 'floyd_warshall') 157 | def test_diff_loss(self, algo): 158 | nb_nodes = 16 159 | full_sample, chunk_sample, gt_diffs, diff_logits = _create_data( 160 | algo, nb_nodes) 161 | chunk_diff_loss = losses.diff_loss_chunked( 162 | diff_logits=diff_logits, 163 | gt_diffs=gt_diffs, 164 | is_first=chunk_sample.features.is_first, 165 | ) 166 | full_diff_loss = losses.diff_loss( 167 | diff_logits=invert(diff_logits)[1:], 168 | gt_diffs=invert(gt_diffs)[1:], 169 | lengths=full_sample.features.lengths, 170 | ) 171 | np.testing.assert_allclose(chunk_diff_loss, full_diff_loss, rtol=1e-4) 172 | 173 | @parameterized.parameters('dfs', 'floyd_warshall') 174 | def test_hint_loss(self, algo): 175 | nb_nodes = 16 176 | full_sample, chunk_sample, gt_diffs, unused_diff_logits = _create_data( 177 | algo, nb_nodes) 178 | for decode_diffs in [False, True]: 179 | for truth_full, truth_chunked in zip(full_sample.features.hints, 180 | chunk_sample.features.hints): 181 | np.testing.assert_array_equal(truth_full.data, truth_chunked.data) 182 | pred = _as_pred_data(truth_chunked, nb_nodes, 0, 1) 183 | chunk_hint_loss = losses.hint_loss_chunked( 184 | truth=_mask_datapoint(truth_chunked, seed=1, t_axis=0), 185 | pred=pred, 186 | gt_diffs=gt_diffs, 187 | is_first=chunk_sample.features.is_first, 188 | nb_nodes=nb_nodes, 189 | decode_diffs=decode_diffs, 190 | ) 191 | 192 | full_preds = pred[1:] 193 | full_gt_diffs = invert(gt_diffs)[1:] 194 | full_hint_loss = losses.hint_loss( 195 | truth=_mask_datapoint(truth_full, 1, t_axis=0), 196 | preds=full_preds, 197 | gt_diffs=full_gt_diffs, 198 | lengths=full_sample.features.lengths, 199 | nb_nodes=nb_nodes, 200 | decode_diffs=decode_diffs, 201 | ) 202 | np.testing.assert_allclose(chunk_hint_loss, full_hint_loss, rtol=1e-4) 203 | 204 | 205 | if __name__ == '__main__': 206 | absltest.main() 207 | -------------------------------------------------------------------------------- /clrs/_src/model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Model base classes and utilities.""" 17 | 18 | import abc 19 | from typing import Dict, List, Optional, Tuple, Union 20 | import chex 21 | 22 | from clrs._src import probing 23 | from clrs._src import samplers 24 | from clrs._src import specs 25 | import numpy as np 26 | 27 | 28 | _Array = chex.Array 29 | Result = Dict[str, probing.DataPoint] 30 | 31 | 32 | class Model(abc.ABC): 33 | """Abstract base class for CLRS3-B models.""" 34 | 35 | def __init__(self, spec: Union[specs.Spec, List[specs.Spec]]): 36 | """Set up the problem, prepare to predict on first task.""" 37 | if not isinstance(spec, list): 38 | spec = [spec] 39 | self._spec = spec 40 | 41 | @abc.abstractmethod 42 | def predict(self, features: samplers.Features) -> Result: 43 | """Make predictions about the current task.""" 44 | pass 45 | 46 | @abc.abstractmethod 47 | def feedback(self, feedback: Optional[samplers.Feedback]): 48 | """Advance to the next task, incorporating any available feedback.""" 49 | pass 50 | 51 | 52 | def evaluate_hints( 53 | hints: Tuple[probing.DataPoint], 54 | lengths: _Array, 55 | hint_preds: List[Result], 56 | ) -> Dict[str, _Array]: 57 | """Evaluate hint predictions.""" 58 | evals = {} 59 | for truth in hints: 60 | assert truth.name in hint_preds[0] 61 | eval_along_time = [_evaluate(truth, p[truth.name], hints, 62 | idx=i+1, lengths=lengths)[0] 63 | for (i, p) in enumerate(hint_preds)] 64 | evals[truth.name] = np.sum( 65 | [x * np.sum(i+1 < lengths) 66 | for i, x in enumerate(eval_along_time)]) / np.sum(lengths - 1) 67 | evals[truth.name + '_along_time'] = np.array(eval_along_time) 68 | 69 | # Unlike outputs, the hints sometimes include scalars, which don't have 70 | # a meaningful eval score. So we don't compute a global 'hint score' as we 71 | # do for outputs. 72 | return evals 73 | 74 | 75 | def evaluate( 76 | outputs: Tuple[probing.DataPoint], 77 | predictions: Result, 78 | ) -> Dict[str, float]: 79 | """Evaluate output predictions.""" 80 | evals = {} 81 | node_level_metrics = {} 82 | graph_level_metrics = {} 83 | for truth in outputs: 84 | assert truth.name in predictions 85 | pred = predictions[truth.name] 86 | node_level, graph_level = _evaluate(truth, pred, outputs) 87 | node_level_metrics[truth.name] = node_level 88 | graph_level_metrics[truth.name + '_glevel'] = graph_level 89 | # Return a single scalar score that is the mean of all output scores. 
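  | # (Unweighted mean: e.g., hypothetically, two output probes scoring
  | # 1.0 and 0.5 give evals['score'] == 0.75, regardless of how many
  | # nodes each probe covers.)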
90 |   evals['score'] = sum([v.item() for v in node_level_metrics.values()]) / len(node_level_metrics)
91 |   evals['graph_score'] = sum([v.item() for v in graph_level_metrics.values()]) / len(graph_level_metrics)
92 |   evals.update(node_level_metrics)
93 |   evals.update(graph_level_metrics)
94 |   return evals
95 | 
96 | 
97 | def _evaluate(truth, pred, full_truth, idx=None, lengths=None):
98 |   """Evaluate single prediction of hint or output."""
99 |   assert pred.name == truth.name
100 |   assert pred.location == truth.location
101 |   assert pred.type_ == truth.type_
102 |   mask_name = f'{truth.name}_mask'
103 |   if mask_name in full_truth:
104 |     # Mask probes are never produced for these outputs/hints, so this
105 |     # legacy branch should be unreachable; fail loudly if it is taken.
106 |     raise ValueError(f'Unexpected mask probe {mask_name}.')
107 |   else:
108 |     if truth.type_ not in _EVAL_FN:
109 |       raise ValueError('Invalid type')
110 |     truth_data = truth.data
111 |     pred_data = pred.data
112 |     if idx is not None:
113 |       if np.all(idx >= lengths):
114 |         return 0.
115 |       truth_data = truth_data[idx][idx < lengths]
116 |       pred_data = pred_data[idx < lengths]
117 |     return (
118 |         _EVAL_FN[truth.type_](pred_data, truth_data), graph_level_eval(truth_data.shape, truth.type_)(pred_data, truth_data)
119 |     )
120 | 
121 | 
122 | def _eval_one(pred, truth):
123 |   mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
124 |   return np.sum(
125 |       (np.argmax(pred, -1) == np.argmax(truth, -1)) * mask) / np.sum(mask)
126 | 
127 | 
128 | def _mask_fn(pred, truth):
129 |   """Evaluate outputs of type MASK, and account for any class imbalance."""
130 |   mask = (truth != specs.OutputClass.MASKED).astype(np.float32)
131 | 
132 |   # Use F1 score for the masked outputs to address any imbalance
133 |   tp = np.sum((((pred > 0.5) * (truth > 0.5)) * 1.0) * mask)
134 |   fp = np.sum((((pred > 0.5) * (truth < 0.5)) * 1.0) * mask)
135 |   fn = np.sum((((pred < 0.5) * (truth > 0.5)) * 1.0) * mask)
136 | 
137 |   # Protect against division by zero
138 |   if tp + fp > 0:
139 |     precision = tp / (tp + fp)
140 |   else:
141 |     precision = np.float32(1.0)
142 |   if tp + fn > 0:
143 |     recall = tp / (tp + fn)
144 |   else:
145 |     recall = np.float32(1.0)
146 | 
147 |   if precision + recall > 0.0:
148 |     f_1 = 2.0 * precision * recall / (precision + recall)
149 |   else:
150 |     f_1 = np.float32(0.0)
151 | 
152 |   return f_1
153 | 
154 | _EVAL_FN = {
155 |     specs.Type.SCALAR:
156 |         lambda pred, truth: np.mean((pred - truth)**2),
157 |     specs.Type.MASK: _mask_fn,
158 |     specs.Type.MASK_ONE:
159 |         _eval_one,
160 |     specs.Type.CATEGORICAL:
161 |         _eval_one,
162 |     specs.Type.POINTER:
163 |         lambda pred, truth: np.mean((pred == truth) * 1.0)
164 | }
165 | 
166 | def graph_level_eval(truth_shape, out_type: specs.Type):
167 |   def _eval_one_graph(pred, truth):
168 |     mask = np.all(truth != specs.OutputClass.MASKED, axis=-1)
169 |     final_pred = np.argmax(pred, -1)
170 |     final_truth = np.argmax(truth, -1)
171 |     correct = np.logical_or(np.invert(mask), final_truth == final_pred)
172 |     return np.mean(np.all(correct, axis=tuple(range(1, len(correct.shape)))) * 1.0)
173 | 
174 |   def _mask_fn_graph(pred, truth):  # All masks within the same graph must be correct.
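    | # A graph counts as correct only if every non-masked entry matches.
    | # Hypothetical example: pred = [[1, 0], [1, 1]] vs.
    | # truth = [[1, 0], [0, 1]] (nothing masked) scores 0.5 --
    | # graph 0 matches everywhere, graph 1 has one wrong bit.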
175 |     mask = (truth != specs.OutputClass.MASKED)
176 |     correct = np.logical_or(np.invert(mask), truth == pred)
177 |     return np.mean(np.all(correct, axis=tuple(range(1, len(correct.shape)))) * 1.0)
178 | 
179 |   if len(truth_shape) < 2:
180 |     return _EVAL_FN[out_type]
181 | 
182 |   eval_fn = {
183 |       specs.Type.SCALAR: _EVAL_FN[specs.Type.SCALAR],
184 |       specs.Type.MASK: _mask_fn_graph,
185 |       specs.Type.POINTER:
186 |           lambda pred, truth: np.mean(np.all(pred == truth, axis=tuple(range(1, len(truth_shape)))) * 1.0),
187 |   }
188 |   if len(truth_shape) == 2:
189 |     eval_fn.update({
190 |         specs.Type.MASK_ONE: _EVAL_FN[specs.Type.CATEGORICAL],
191 |         specs.Type.CATEGORICAL: _EVAL_FN[specs.Type.CATEGORICAL],
192 |     })
193 |   else:
194 |     eval_fn.update({
195 |         specs.Type.MASK_ONE: _eval_one_graph,
196 |         specs.Type.CATEGORICAL: _eval_one_graph,
197 |     })
198 |   return eval_fn[out_type]
199 | 
200 | 
--------------------------------------------------------------------------------
/clrs/_src/probing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | #     http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 | 
16 | """Probing utilities.
17 | 
18 | The dataflow for an algorithm is represented by `(stage, loc, type, data)`
19 | "probes" that are valid under that algorithm's spec (see `specs.py`).
20 | 
21 | When constructing probes, it is convenient to represent these fields in a nested
22 | format (`ProbesDict`) to facilitate efficient context-based look-up.
23 | 
24 | """
25 | 
26 | from typing import Dict, List, Tuple, Union
27 | 
28 | import attr
29 | from clrs._src import specs
30 | import jax
31 | import numpy as np
32 | import tensorflow as tf
33 | 
34 | 
35 | _Location = specs.Location
36 | _Stage = specs.Stage
37 | _Type = specs.Type
38 | _OutputClass = specs.OutputClass
39 | 
40 | _Array = np.ndarray
41 | _Data = Union[_Array, List[_Array]]
42 | _DataOrType = Union[_Data, str]
43 | 
44 | ProbesDict = Dict[
45 |     str, Dict[str, Dict[str, Dict[str, _DataOrType]]]]
46 | 
47 | 
48 | def _convert_to_str(element):
49 |   if isinstance(element, tf.Tensor):
50 |     return element.numpy().decode('utf-8')
51 |   elif isinstance(element, (np.ndarray, bytes)):
52 |     return element.decode('utf-8')
53 |   else:
54 |     return element
55 | 
56 | 
57 | # First annotation makes this object jax.jit/pmap friendly, second one makes this
58 | # tf.data.Datasets friendly.
59 | @jax.tree_util.register_pytree_node_class
60 | @attr.define
61 | class DataPoint:
62 |   """Describes a data point."""
63 | 
64 |   _name: str
65 |   _location: str
66 |   _type_: str
67 |   data: _Array
68 | 
69 |   @property
70 |   def name(self):
71 |     return _convert_to_str(self._name)
72 | 
73 |   @property
74 |   def location(self):
75 |     return _convert_to_str(self._location)
76 | 
77 |   @property
78 |   def type_(self):
79 |     return _convert_to_str(self._type_)
80 | 
81 |   def __repr__(self):
82 |     s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t'
83 |     return s + f'type={self.type_},\tdata=Array{self.data.shape})'
84 | 
85 |   def tree_flatten(self):
86 |     data = (self.data,)
87 |     meta = (self.name, self.location, self.type_)
88 |     return data, meta
89 | 
90 |   @classmethod
91 |   def tree_unflatten(cls, meta, data):
92 |     name, location, type_ = meta
93 |     subdata, = data
94 |     return DataPoint(name, location, type_, subdata)
95 | 
96 | 
97 | class ProbeError(Exception):
98 |   pass
99 | 
100 | 
101 | def initialize(spec: specs.Spec) -> ProbesDict:
102 |   """Initializes an empty `ProbesDict` corresponding to the provided spec."""
103 |   probes = dict()
104 |   for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
105 |     probes[stage] = {}
106 |     for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
107 |       probes[stage][loc] = {}
108 | 
109 |   for name in spec:
110 |     stage, loc, t = spec[name]
111 |     probes[stage][loc][name] = {}
112 |     probes[stage][loc][name]['data'] = []
113 |     probes[stage][loc][name]['type_'] = t
114 |   # Pytype thinks initialize() returns a ProbesDict with a str for all final
115 |   # values instead of _DataOrType.
116 |   return probes  # pytype: disable=bad-return-type
117 | 
118 | 
119 | def push(probes: ProbesDict, stage: str, next_probe):
120 |   """Pushes a probe into an existing `ProbesDict`."""
121 |   for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
122 |     for name in probes[stage][loc]:
123 |       if name not in next_probe:
124 |         raise ProbeError(f'Missing probe for {name}.')
125 |       if isinstance(probes[stage][loc][name]['data'], _Array):
126 |         raise ProbeError('Attempting to push to finalized `ProbesDict`.')
127 |       # Pytype thinks initialize() returns a ProbesDict with a str for all final
128 |       # values instead of _DataOrType.
129 |       probes[stage][loc][name]['data'].append(next_probe[name])  # pytype: disable=attribute-error
130 | 
131 | 
132 | def finalize(probes: ProbesDict):
133 |   """Finalizes a `ProbesDict` by stacking/squeezing `data` field."""
134 |   for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]:
135 |     for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]:
136 |       for name in probes[stage][loc]:
137 |         if isinstance(probes[stage][loc][name]['data'], _Array):
138 |           raise ProbeError('Attempting to re-finalize a finalized `ProbesDict`.')
139 |         if stage == _Stage.HINT:
140 |           # Hints are provided for each timestep. Stack them here.
141 |           probes[stage][loc][name]['data'] = np.stack(
142 |               probes[stage][loc][name]['data'])
143 |         else:
144 |           # Only one instance of input/output exists. Remove leading axis.
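          | # e.g. a node-level input probe of shape (N,) pushed exactly
          | # once is stored as a one-element list; np.array(...) gives
          | # shape (1, N) and np.squeeze drops the leading singleton
          | # axis, restoring (N,).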
145 | probes[stage][loc][name]['data'] = np.squeeze( 146 | np.array(probes[stage][loc][name]['data'])) 147 | 148 | 149 | def split_stages( 150 | probes: ProbesDict, 151 | spec: specs.Spec, 152 | ) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]: 153 | """Splits contents of `ProbesDict` into `DataPoint`s by stage.""" 154 | 155 | inputs = [] 156 | outputs = [] 157 | hints = [] 158 | 159 | for name in spec: 160 | stage, loc, t = spec[name] 161 | 162 | if stage not in probes: 163 | raise ProbeError(f'Missing stage {stage}.') 164 | if loc not in probes[stage]: 165 | raise ProbeError(f'Missing location {loc}.') 166 | if name not in probes[stage][loc]: 167 | raise ProbeError(f'Missing probe {name}.') 168 | if 'type_' not in probes[stage][loc][name]: 169 | raise ProbeError(f'Probe {name} missing attribute `type_`.') 170 | if 'data' not in probes[stage][loc][name]: 171 | raise ProbeError(f'Probe {name} missing attribute `data`.') 172 | if t != probes[stage][loc][name]['type_']: 173 | raise ProbeError(f'Probe {name} of incorrect type {t}.') 174 | 175 | data = probes[stage][loc][name]['data'] 176 | if not isinstance(probes[stage][loc][name]['data'], _Array): 177 | raise ProbeError((f'Invalid `data` for probe "{name}". ' + 178 | 'Did you forget to call `probing.finalize`?')) 179 | 180 | if t in [_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL]: 181 | # pytype: disable=attribute-error 182 | if not ((data == 0) | (data == 1) | (data == -1)).all(): 183 | raise ProbeError(f'0|1|-1 `data` for probe "{name}"') 184 | # pytype: enable=attribute-error 185 | if t in [_Type.MASK_ONE, _Type.CATEGORICAL 186 | ] and not np.all(np.sum(np.abs(data), -1) == 1): 187 | raise ProbeError(f'Expected one-hot `data` for probe "{name}"') 188 | 189 | dim_to_expand = 1 if stage == _Stage.HINT else 0 190 | data_point = DataPoint(name=name, location=loc, type_=t, 191 | data=np.expand_dims(data, dim_to_expand)) 192 | 193 | if stage == _Stage.INPUT: 194 | inputs.append(data_point) 195 | elif stage == _Stage.OUTPUT: 196 | outputs.append(data_point) 197 | else: 198 | hints.append(data_point) 199 | 200 | return inputs, outputs, hints 201 | 202 | 203 | # pylint: disable=invalid-name 204 | 205 | 206 | def array(A_pos: np.ndarray) -> np.ndarray: 207 | """Constructs an `array` probe.""" 208 | probe = np.arange(A_pos.shape[0]) 209 | for i in range(1, A_pos.shape[0]): 210 | probe[A_pos[i]] = A_pos[i - 1] 211 | return probe 212 | 213 | 214 | def array_cat(A: np.ndarray, n: int) -> np.ndarray: 215 | """Constructs an `array_cat` probe.""" 216 | assert n > 0 217 | probe = np.zeros((A.shape[0], n)) 218 | for i in range(A.shape[0]): 219 | probe[i, A[i]] = 1 220 | return probe 221 | 222 | 223 | def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray: 224 | """Constructs a `heap` probe.""" 225 | assert heap_size > 0 226 | probe = np.arange(A_pos.shape[0]) 227 | for i in range(1, heap_size): 228 | probe[A_pos[i]] = A_pos[(i - 1) // 2] 229 | return probe 230 | 231 | 232 | def graph(A: np.ndarray) -> np.ndarray: 233 | """Constructs a `graph` probe.""" 234 | probe = (A != 0) * 1.0 235 | probe = ((A + np.eye(A.shape[0])) != 0) * 1.0 236 | return probe 237 | 238 | 239 | def mask_one(i: int, n: int) -> np.ndarray: 240 | """Constructs a `mask_one` probe.""" 241 | assert n > i 242 | probe = np.zeros(n) 243 | probe[i] = 1 244 | return probe 245 | 246 | 247 | def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: 248 | """Constructs a `strings_id` probe.""" 249 | probe_T = np.zeros(T_pos.shape[0]) 250 | probe_P = 
np.ones(P_pos.shape[0]) 251 | return np.concatenate([probe_T, probe_P]) 252 | 253 | 254 | def strings_pair(pair_probe: np.ndarray) -> np.ndarray: 255 | """Constructs a `strings_pair` probe.""" 256 | n = pair_probe.shape[0] 257 | m = pair_probe.shape[1] 258 | probe_ret = np.zeros((n + m, n + m)) 259 | for i in range(0, n): 260 | for j in range(0, m): 261 | probe_ret[i, j + n] = pair_probe[i, j] 262 | return probe_ret 263 | 264 | 265 | def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray: 266 | """Constructs a `strings_pair_cat` probe.""" 267 | assert nb_classes > 0 268 | n = pair_probe.shape[0] 269 | m = pair_probe.shape[1] 270 | 271 | # Add an extra class for 'this cell left blank.' 272 | probe_ret = np.zeros((n + m, n + m, nb_classes + 1)) 273 | for i in range(0, n): 274 | for j in range(0, m): 275 | probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE 276 | 277 | # Fill the blank cells. 278 | for i_1 in range(0, n): 279 | for i_2 in range(0, n): 280 | probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED 281 | for j_1 in range(0, m): 282 | for x in range(0, n + m): 283 | probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED 284 | return probe_ret 285 | 286 | 287 | def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray, 288 | pi: np.ndarray) -> np.ndarray: 289 | """Constructs a `strings_pi` probe.""" 290 | probe = np.arange(T_pos.shape[0] + P_pos.shape[0]) 291 | for j in range(P_pos.shape[0]): 292 | probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + pi[P_pos[j]] 293 | return probe 294 | 295 | 296 | def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: 297 | """Constructs a `strings_pos` probe.""" 298 | probe_T = np.copy(T_pos) * 1.0 / T_pos.shape[0] 299 | probe_P = np.copy(P_pos) * 1.0 / P_pos.shape[0] 300 | return np.concatenate([probe_T, probe_P]) 301 | 302 | 303 | def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: 304 | """Constructs a `strings_pred` probe.""" 305 | probe = np.arange(T_pos.shape[0] + P_pos.shape[0]) 306 | for i in range(1, T_pos.shape[0]): 307 | probe[T_pos[i]] = T_pos[i - 1] 308 | for j in range(1, P_pos.shape[0]): 309 | probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1] 310 | return probe 311 | -------------------------------------------------------------------------------- /clrs/_src/probing_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Unit tests for `probing.py`.""" 17 | 18 | from absl.testing import absltest 19 | 20 | from clrs._src import probing 21 | import numpy as np 22 | 23 | 24 | # pylint: disable=invalid-name 25 | 26 | 27 | class ProbingTest(absltest.TestCase): 28 | 29 | def test_array(self): 30 | A_pos = np.array([1, 2, 0, 4, 3]) 31 | expected = np.array([2, 1, 1, 4, 0]) 32 | out = probing.array(A_pos) 33 | np.testing.assert_array_equal(expected, out) 34 | 35 | def test_array_cat(self): 36 | A = np.array([2, 1, 0, 1, 1]) 37 | expected = np.array([ 38 | [0, 0, 1], 39 | [0, 1, 0], 40 | [1, 0, 0], 41 | [0, 1, 0], 42 | [0, 1, 0] 43 | ]) 44 | out = probing.array_cat(A, 3) 45 | np.testing.assert_array_equal(expected, out) 46 | 47 | def test_heap(self): 48 | A_pos = np.array([1, 3, 5, 0, 7, 4, 2, 6]) 49 | expected = np.array([3, 1, 2, 1, 5, 1, 6, 3]) 50 | out = probing.heap(A_pos, heap_size=6) 51 | np.testing.assert_array_equal(expected, out) 52 | 53 | def test_graph(self): 54 | G = np.array([ 55 | [0.0, 7.0, -1.0, -3.9, 7.452], 56 | [0.0, 0.0, 133.0, 0.0, 9.3], 57 | [0.5, 0.1, 0.22, 0.55, 0.666], 58 | [7.0, 6.1, 0.2, 0.0, 0.0], 59 | [0.0, 3.0, 0.0, 1.0, 0.5] 60 | ]) 61 | expected = np.array([ 62 | [1.0, 1.0, 1.0, 1.0, 1.0], 63 | [0.0, 1.0, 1.0, 0.0, 1.0], 64 | [1.0, 1.0, 1.0, 1.0, 1.0], 65 | [1.0, 1.0, 1.0, 1.0, 0.0], 66 | [0.0, 1.0, 0.0, 1.0, 1.0] 67 | ]) 68 | out = probing.graph(G) 69 | np.testing.assert_array_equal(expected, out) 70 | 71 | def test_mask_one(self): 72 | expected = np.array([0, 0, 0, 1, 0]) 73 | out = probing.mask_one(3, 5) 74 | np.testing.assert_array_equal(expected, out) 75 | 76 | def test_strings_id(self): 77 | T_pos = np.array([0, 1, 2, 3, 4]) 78 | P_pos = np.array([0, 1, 2]) 79 | expected = np.array([0, 0, 0, 0, 0, 1, 1, 1]) 80 | out = probing.strings_id(T_pos, P_pos) 81 | np.testing.assert_array_equal(expected, out) 82 | 83 | def test_strings_pair(self): 84 | pair_probe = np.array([ 85 | [0.5, 3.1, 9.1, 7.3], 86 | [1.0, 0.0, 8.0, 9.3], 87 | [0.1, 5.0, 0.0, 1.2] 88 | ]) 89 | expected = np.array([ 90 | [0.0, 0.0, 0.0, 0.5, 3.1, 9.1, 7.3], 91 | [0.0, 0.0, 0.0, 1.0, 0.0, 8.0, 9.3], 92 | [0.0, 0.0, 0.0, 0.1, 5.0, 0.0, 1.2], 93 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 94 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 95 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], 96 | [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0] 97 | ]) 98 | out = probing.strings_pair(pair_probe) 99 | np.testing.assert_equal(expected, out) 100 | 101 | def test_strings_pair_cat(self): 102 | pair_probe = np.array([ 103 | [0, 2, 1], 104 | [2, 2, 0] 105 | ]) 106 | expected = np.array([ 107 | [ 108 | [0, 0, 0, -1], 109 | [0, 0, 0, -1], 110 | [1, 0, 0, 0], 111 | [0, 0, 1, 0], 112 | [0, 1, 0, 0], 113 | ], 114 | [ 115 | [0, 0, 0, -1], 116 | [0, 0, 0, -1], 117 | [0, 0, 1, 0], 118 | [0, 0, 1, 0], 119 | [1, 0, 0, 0], 120 | ], 121 | [ 122 | [0, 0, 0, -1], 123 | [0, 0, 0, -1], 124 | [0, 0, 0, -1], 125 | [0, 0, 0, -1], 126 | [0, 0, 0, -1], 127 | ], 128 | [ 129 | [0, 0, 0, -1], 130 | [0, 0, 0, -1], 131 | [0, 0, 0, -1], 132 | [0, 0, 0, -1], 133 | [0, 0, 0, -1], 134 | ], 135 | [ 136 | [0, 0, 0, -1], 137 | [0, 0, 0, -1], 138 | [0, 0, 0, -1], 139 | [0, 0, 0, -1], 140 | [0, 0, 0, -1], 141 | ], 142 | ]) 143 | out = probing.strings_pair_cat(pair_probe, 3) 144 | np.testing.assert_equal(expected, out) 145 | 146 | def test_strings_pi(self): 147 | T_pos = np.array([0, 1, 2, 3, 4, 5]) 148 | P_pos = np.array([0, 1, 2, 3]) 149 | pi = np.array([3, 1, 0, 2]) 150 | expected = np.array( 151 
| [0, 1, 2, 3, 4, 5, 9, 7, 6, 8] 152 | ) 153 | out = probing.strings_pi(T_pos, P_pos, pi) 154 | np.testing.assert_array_equal(expected, out) 155 | 156 | def test_strings_pos(self): 157 | T_pos = np.array([0, 1, 2, 3, 4]) 158 | P_pos = np.array([0, 1, 2, 3]) 159 | expected = np.array( 160 | [0.0, 0.2, 0.4, 0.6, 0.8, 161 | 0.0, 0.25, 0.5, 0.75] 162 | ) 163 | out = probing.strings_pos(T_pos, P_pos) 164 | np.testing.assert_array_equal(expected, out) 165 | 166 | def test_strings_pred(self): 167 | T_pos = np.array([0, 1, 2, 3, 4]) 168 | P_pos = np.array([0, 1, 2]) 169 | expected = np.array([0, 0, 1, 2, 3, 5, 5, 6]) 170 | out = probing.strings_pred(T_pos, P_pos) 171 | np.testing.assert_array_equal(expected, out) 172 | 173 | if __name__ == "__main__": 174 | absltest.main() 175 | -------------------------------------------------------------------------------- /clrs/_src/processors_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Tests for processors.py.""" 17 | 18 | from absl.testing import absltest 19 | import chex 20 | from clrs._src import processors 21 | import haiku as hk 22 | import jax.numpy as jnp 23 | 24 | 25 | class MemnetTest(absltest.TestCase): 26 | 27 | def test_simple_run_and_check_shapes(self): 28 | 29 | batch_size = 64 30 | vocab_size = 177 31 | embedding_size = 64 32 | sentence_size = 11 33 | memory_size = 320 34 | linear_output_size = 128 35 | num_hops = 2 36 | use_ln = True 37 | 38 | def forward_fn(queries, stories): 39 | model = processors.MemNetFull( 40 | vocab_size=vocab_size, 41 | embedding_size=embedding_size, 42 | sentence_size=sentence_size, 43 | memory_size=memory_size, 44 | linear_output_size=linear_output_size, 45 | num_hops=num_hops, 46 | use_ln=use_ln) 47 | return model._apply(queries, stories) 48 | 49 | forward = hk.transform(forward_fn) 50 | 51 | queries = jnp.ones([batch_size, sentence_size], dtype=jnp.int32) 52 | stories = jnp.ones([batch_size, memory_size, sentence_size], 53 | dtype=jnp.int32) 54 | 55 | key = hk.PRNGSequence(42) 56 | params = forward.init(next(key), queries, stories) 57 | 58 | model_output = forward.apply(params, None, queries, stories) 59 | chex.assert_shape(model_output, [batch_size, vocab_size]) 60 | chex.assert_type(model_output, jnp.float32) 61 | 62 | 63 | if __name__ == '__main__': 64 | absltest.main() 65 | -------------------------------------------------------------------------------- /clrs/_src/samplers_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Unit tests for `samplers.py`.""" 17 | 18 | from absl.testing import absltest 19 | from absl.testing import parameterized 20 | 21 | from clrs._src import probing 22 | from clrs._src import samplers 23 | from clrs._src import specs 24 | import jax 25 | import numpy as np 26 | 27 | 28 | class SamplersTest(parameterized.TestCase): 29 | 30 | @parameterized.parameters(*specs.CLRS_30_ALGS) 31 | def test_sampler_determinism(self, name): 32 | num_samples = 3 33 | num_nodes = 10 34 | sampler, _ = samplers.build_sampler(name, num_samples, num_nodes) 35 | 36 | np.random.seed(47) # Set seed 37 | feedback = sampler.next() 38 | expected = feedback.outputs[0].data.copy() 39 | 40 | np.random.seed(48) # Set a different seed 41 | feedback = sampler.next() 42 | actual = feedback.outputs[0].data.copy() 43 | 44 | # Validate that datasets are the same. 45 | np.testing.assert_array_equal(expected, actual) 46 | 47 | @parameterized.parameters(*specs.CLRS_30_ALGS) 48 | def test_sampler_batch_determinism(self, name): 49 | num_samples = 10 50 | batch_size = 5 51 | num_nodes = 10 52 | seed = 0 53 | sampler_1, _ = samplers.build_sampler( 54 | name, num_samples, num_nodes, seed=seed) 55 | sampler_2, _ = samplers.build_sampler( 56 | name, num_samples, num_nodes, seed=seed) 57 | 58 | feedback_1 = sampler_1.next(batch_size) 59 | feedback_2 = sampler_2.next(batch_size) 60 | 61 | # Validate that datasets are the same. 62 | jax.tree_map(np.testing.assert_array_equal, feedback_1, feedback_2) 63 | 64 | def test_end_to_end(self): 65 | num_samples = 7 66 | num_nodes = 3 67 | sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes) 68 | feedback = sampler.next() 69 | 70 | inputs = feedback.features.inputs 71 | self.assertLen(inputs, 4) 72 | self.assertEqual(inputs[0].name, "pos") 73 | self.assertEqual(inputs[0].data.shape, (num_samples, num_nodes)) 74 | 75 | outputs = feedback.outputs 76 | self.assertLen(outputs, 1) 77 | self.assertEqual(outputs[0].name, "pi") 78 | self.assertEqual(outputs[0].data.shape, (num_samples, num_nodes)) 79 | 80 | def test_batch_size(self): 81 | num_samples = 7 82 | num_nodes = 3 83 | sampler, _ = samplers.build_sampler("bfs", num_samples, num_nodes) 84 | 85 | # Full-batch. 86 | feedback = sampler.next() 87 | for dp in feedback.features.inputs: # [B, ...] 88 | self.assertEqual(dp.data.shape[0], num_samples) 89 | 90 | for dp in feedback.outputs: # [B, ...] 91 | self.assertEqual(dp.data.shape[0], num_samples) 92 | 93 | for dp in feedback.features.hints: # [T, B, ...] 94 | self.assertEqual(dp.data.shape[1], num_samples) 95 | 96 | self.assertLen(feedback.features.lengths, num_samples) 97 | 98 | # Specified batch. 99 | batch_size = 5 100 | feedback = sampler.next(batch_size) 101 | 102 | for dp in feedback.features.inputs: # [B, ...] 103 | self.assertEqual(dp.data.shape[0], batch_size) 104 | 105 | for dp in feedback.outputs: # [B, ...] 106 | self.assertEqual(dp.data.shape[0], batch_size) 107 | 108 | for dp in feedback.features.hints: # [T, B, ...] 
109 |       self.assertEqual(dp.data.shape[1], batch_size)
110 | 
111 |     self.assertLen(feedback.features.lengths, batch_size)
112 | 
113 |   def test_batch_io(self):
114 |     sample = [
115 |         probing.DataPoint(
116 |             name="x",
117 |             location=specs.Location.NODE,
118 |             type_=specs.Type.SCALAR,
119 |             data=np.zeros([1, 3]),
120 |         ),
121 |         probing.DataPoint(
122 |             name="y",
123 |             location=specs.Location.EDGE,
124 |             type_=specs.Type.MASK,
125 |             data=np.zeros([1, 3, 3]),
126 |         ),
127 |     ]
128 | 
129 |     trajectory = [sample.copy(), sample.copy(), sample.copy(), sample.copy()]
130 |     batched = samplers._batch_io(trajectory)
131 | 
132 |     np.testing.assert_array_equal(batched[0].data, np.zeros([4, 3]))
133 |     np.testing.assert_array_equal(batched[1].data, np.zeros([4, 3, 3]))
134 | 
135 |   def test_batch_hint(self):
136 |     sample0 = [
137 |         probing.DataPoint(
138 |             name="x",
139 |             location=specs.Location.NODE,
140 |             type_=specs.Type.MASK,
141 |             data=np.zeros([2, 1, 3]),
142 |         ),
143 |         probing.DataPoint(
144 |             name="y",
145 |             location=specs.Location.NODE,
146 |             type_=specs.Type.POINTER,
147 |             data=np.zeros([2, 1, 3]),
148 |         ),
149 |     ]
150 | 
151 |     sample1 = [
152 |         probing.DataPoint(
153 |             name="x",
154 |             location=specs.Location.NODE,
155 |             type_=specs.Type.MASK,
156 |             data=np.zeros([1, 1, 3]),
157 |         ),
158 |         probing.DataPoint(
159 |             name="y",
160 |             location=specs.Location.NODE,
161 |             type_=specs.Type.POINTER,
162 |             data=np.zeros([1, 1, 3]),
163 |         ),
164 |     ]
165 | 
166 |     trajectory = [sample0, sample1]
167 |     batched, lengths = samplers._batch_hints(trajectory)
168 | 
169 |     np.testing.assert_array_equal(batched[0].data, np.zeros([2, 2, 3]))
170 |     np.testing.assert_array_equal(batched[1].data, np.zeros([2, 2, 3]))
171 |     np.testing.assert_array_equal(lengths, np.array([2, 1]))
172 | 
173 |   def test_padding(self):
174 |     lens = np.random.choice(10, (10,), replace=True) + 1
175 |     trajectory = []
176 |     for len_ in lens:
177 |       trajectory.append([
178 |           probing.DataPoint(
179 |               name="x",
180 |               location=specs.Location.NODE,
181 |               type_=specs.Type.MASK,
182 |               data=np.ones([len_, 1, 3]),
183 |           )
184 |       ])
185 | 
186 |     batched, lengths = samplers._batch_hints(trajectory)
187 |     np.testing.assert_array_equal(lengths, lens)
188 | 
189 |     for i in range(len(lens)):
190 |       ones = batched[0].data[:lens[i], i, :]
191 |       zeros = batched[0].data[lens[i]:, i, :]
192 |       np.testing.assert_array_equal(ones, np.ones_like(ones))
193 |       np.testing.assert_array_equal(zeros, np.zeros_like(zeros))
194 | 
195 | 
196 | if __name__ == "__main__":
197 |   absltest.main()
198 | 
-------------------------------------------------------------------------------- /clrs/_src/scratch.py: --------------------------------------------------------------------------------
1 | import hashlib
2 | import logging
3 | 
4 | import jax
5 | try:  # Only needed by some of the functions in this file.
6 |     import networkx as nx
7 |     import matplotlib.pyplot as plt
8 | except ImportError:
9 |     pass
10 | import numpy as np
11 | import jax.numpy as jnp
12 | 
13 | import tensorflow_datasets as tfds
14 | import pickle
15 | import clrs
16 | 
17 | def compare_dataset_io(file_name1, file_name2, algorithm, split):
18 |     dataset1 = tfds.load(f'clrs_dataset/{algorithm}_{split}',
19 |                          data_dir=file_name1, split=split).as_numpy_iterator()
20 |     dataset2 = tfds.load(f'clrs_dataset/{algorithm}_{split}',
21 |                          data_dir=file_name2, split=split).as_numpy_iterator()
22 |     ok = 0
23 |     for d1, d2 in zip(dataset1, dataset2):
24 |         for k in d1.keys():
25 |             assert np.abs(d1[k] - d2[k]).sum() < 1e-7, f"Mismatch!, {k}, {d1[k]}, {d2[k]}, {ok}, {algorithm}"
{algorithm}" 26 | ok += 1 27 | print(f"OK: {ok}") 28 | 29 | 30 | def load_and_save_logits(rng_key, eval_model: clrs.models.BaselineModel, param_file: str, sampler, sample_count, spec, 31 | batch_size, file_name, save_file=True): 32 | eval_model.restore_model(param_file, replace_params=True) 33 | predict_fn = eval_model.predict 34 | 35 | processed_samples = 0 36 | inputs = [] 37 | preds = [] 38 | hint_preds = [] 39 | outputs = [] 40 | output_logits = [] 41 | hints = [] 42 | lengths = [] 43 | while processed_samples < sample_count: 44 | feedback = next(sampler) 45 | inputs.append(feedback.features.inputs) 46 | outputs.append(feedback.outputs) 47 | rng_key, new_rng_key = jax.random.split(rng_key) 48 | cur_preds, (cur_hint_preds, _, _, hiddens), net_aux = predict_fn(rng_key, feedback.features) 49 | output_logits.append(net_aux['output_logits']) 50 | preds.append(cur_preds) 51 | lengths.append(feedback.features.lengths) 52 | hints.append(feedback.features.hints) 53 | hint_preds.append(cur_hint_preds) 54 | rng_key = new_rng_key 55 | processed_samples += batch_size 56 | outputs = _concat(outputs, axis=0) 57 | output_logits = _concat(output_logits, axis=0) 58 | preds = _concat(preds, axis=0) 59 | lengths = _concat(lengths, axis=0) 60 | inputs = _concat(inputs, axis=0) 61 | # for hints, axis=1 because hints have time dimension first 62 | hints = _concat(hints, axis=1) 63 | # for hint_preds, axis=0 because the time dim is unrolled as a list 64 | hint_preds = _concat(hint_preds, axis=0) 65 | if save_file: 66 | jnp.save(f'/data/smahdavi/tmp/{file_name}.npy', dict( 67 | inputs=inputs, preds=preds, outputs=outputs, hints=hints, lengths=lengths, hint_preds=hint_preds, 68 | output_logits=output_logits 69 | )) 70 | return preds, outputs, hints, lengths, hint_preds, spec 71 | 72 | def _concat(dps, axis): 73 | return jax.tree_util.tree_map(lambda *x: jnp.concatenate(x, axis), *dps) 74 | 75 | 76 | def load_and_save_model_soup(param_files: list, output_file:str, coefs: list): 77 | param_states = [] 78 | opt_state = [] 79 | for file_name in param_files: 80 | with open(file_name, 'rb') as f: 81 | restored_state = pickle.load(f) 82 | param_states.append(restored_state['params']) 83 | opt_state = restored_state['opt_state'] 84 | param_state = jax.tree_util.tree_map(lambda *x: sum([x_ * coef_ for (x_, coef_) in zip(x, coefs)]), *param_states) 85 | to_save = {'params': param_state, 'opt_state': opt_state} 86 | path = output_file 87 | with open(path, 'wb') as f: 88 | pickle.dump(to_save, f) 89 | 90 | 91 | 92 | def visualize_graph(A: jnp.ndarray, bridges: jnp.ndarray, pred_bridges: jnp.ndarray): 93 | A = A - jnp.diag(A.diagonal()) 94 | G = nx.from_numpy_array(A).to_undirected() 95 | plt.figure(figsize=(20, 20)) 96 | pos = nx.spring_layout(G, seed=1, scale=2, k=0.25) 97 | 98 | if bridges is None: 99 | bridges = A * 2 - 1 100 | logging.warning("Replacing bridges with all edges") 101 | 102 | pred_bridges = jnp.where(bridges != -1, pred_bridges, bridges) 103 | bridge_list = [(node1.item(), node2.item()) for (node1, node2) in zip(*jnp.where(bridges == 1))] 104 | pred_bridges_list = [(node1.item(), node2.item()) for (node1, node2) in zip(*jnp.where(pred_bridges == 1))] 105 | print(bridge_list) 106 | print(G.edges()) 107 | nx.draw_networkx_nodes(G, pos) 108 | nx.draw_networkx_labels(G, pos) 109 | nx.draw_networkx_edges(G, pos, edgelist=G.edges(), edge_color='black', alpha=0.5) 110 | nx.draw_networkx_edges(G, pos, edgelist=bridge_list, edge_color='green', style='dashed') 111 | nx.draw_networkx_edges(G, pos, 
92 | def visualize_graph(A: jnp.ndarray, bridges: jnp.ndarray, pred_bridges: jnp.ndarray):
93 |     A = A - jnp.diag(A.diagonal())
94 |     G = nx.from_numpy_array(A).to_undirected()
95 |     plt.figure(figsize=(20, 20))
96 |     pos = nx.spring_layout(G, seed=1, scale=2, k=0.25)
97 | 
98 |     if bridges is None:
99 |         bridges = A * 2 - 1
100 |         logging.warning("Replacing bridges with all edges")
101 | 
102 |     pred_bridges = jnp.where(bridges != -1, pred_bridges, bridges)
103 |     bridge_list = [(node1.item(), node2.item()) for (node1, node2) in zip(*jnp.where(bridges == 1))]
104 |     pred_bridges_list = [(node1.item(), node2.item()) for (node1, node2) in zip(*jnp.where(pred_bridges == 1))]
105 |     print(bridge_list)
106 |     print(G.edges())
107 |     nx.draw_networkx_nodes(G, pos)
108 |     nx.draw_networkx_labels(G, pos)
109 |     nx.draw_networkx_edges(G, pos, edgelist=G.edges(), edge_color='black', alpha=0.5)
110 |     nx.draw_networkx_edges(G, pos, edgelist=bridge_list, edge_color='green', style='dashed')
111 |     nx.draw_networkx_edges(G, pos, edgelist=pred_bridges_list, edge_color='red', style='dashed')
112 |     plt.savefig('/tmp/1.png')
113 |     plt.close()
114 | 
115 | 
116 | def _generate_bridge_graphs(seed, n_nodes, p=0.3):  # A slightly larger p than usual.
117 |     np_rng = np.random.RandomState(seed)
118 | 
119 |     def _random_connected_graph(rng):
120 |         while True:
121 |             mat = rng.binomial(1, p, size=(n_nodes//2, n_nodes//2))
122 |             mat *= np.transpose(mat)
123 |             G = nx.from_numpy_array(mat).to_undirected()
124 |             if nx.is_connected(G):
125 |                 break
126 |         return mat
127 | 
128 |     mat1 = _random_connected_graph(np_rng)
129 |     mat2 = _random_connected_graph(np_rng)
130 |     block_with_bridge = np.zeros((n_nodes//2, n_nodes//2))
131 |     block_with_bridge[np_rng.randint(0, n_nodes//2), np_rng.randint(0, n_nodes//2)] = 1  # randint's high is exclusive (replaces deprecated random_integers)
132 |     block_no_bridge = np.copy(block_with_bridge)
133 |     while True:
134 |         block_no_bridge[np_rng.randint(0, n_nodes//2), np_rng.randint(0, n_nodes//2)] = 1
135 |         if ((block_no_bridge - block_with_bridge) ** 2).sum() > 0.5:
136 |             break
137 | 
138 |     mat_bridge = np.block([
139 |         [mat1, block_with_bridge],
140 |         [block_with_bridge.T, mat1],
141 |     ])
142 |     mat_no_bridge = np.block([
143 |         [mat2, block_no_bridge],
144 |         [block_no_bridge.T, mat2],
145 |     ])
146 |     output = np.block([
147 |         [np.zeros((n_nodes//2, n_nodes//2)), block_with_bridge],
148 |         [block_with_bridge.T, np.zeros((n_nodes//2, n_nodes//2))],
149 |     ]) * 2 - 1  # 0 -> -1
150 | 
151 |     return mat_bridge, mat_no_bridge, output
152 | 
153 | def _add_self_loops(mat):
154 |     adj = np.copy(mat)
155 |     np.fill_diagonal(adj, 1)
156 |     return adj
157 | 
158 | def _get_feedback_sampler(example_feedbacks, n_iters):
159 |     batch_size, n_nodes = example_feedbacks.features.inputs[0].data.shape[:2]
160 |     for i in range(n_iters):
161 |         A_batch = np.zeros_like(example_feedbacks.features.inputs[1].data)
162 |         adj_batch = np.zeros_like(A_batch)
163 |         out_batch = np.zeros((batch_size // 2, n_nodes, n_nodes))
164 |         for j in range(batch_size//2):
165 |             mat_bridge, mat_no_bridge, output = _generate_bridge_graphs(i * (batch_size // 2) + j, n_nodes)  # distinct seed per sample, so graphs differ within a batch
166 |             adj_bridge, adj_no_bridge = _add_self_loops(mat_bridge), _add_self_loops(mat_no_bridge)
167 |             A_batch[2*j] = mat_bridge
168 |             A_batch[2*j+1] = mat_no_bridge
169 |             adj_batch[2*j] = adj_bridge
170 |             adj_batch[2*j+1] = adj_no_bridge
171 |             out_batch[j] = output
172 |         inputs = []
173 |         for inp in example_feedbacks.features.inputs:
174 |             if inp.name == 'A':
175 |                 data = A_batch
176 |             elif inp.name == 'adj':
177 |                 data = adj_batch
178 |             else:
179 |                 data = inp.data
180 |             inputs.append(clrs.DataPoint(inp.name, inp.location, inp.type_, data))
181 |         feedback = clrs.Feedback(
182 |             clrs.Features(
183 |                 tuple(inputs), tuple(), example_feedbacks.features.lengths
184 |             ), np.stack(out_batch)
185 |         )
186 |         yield feedback
187 |     return
188 | 
189 | def _eval_bridge_outputs(outputs, preds):
190 |     b, n, _ = outputs.shape
191 |     b2 = preds['is_bridge'].data.shape[0]
192 |     assert 2 * b == b2
193 |     correct = 0
194 |     for i in range(b):
195 |         x_bridge = jnp.where(outputs[i] == 1, preds['is_bridge'].data[2 * i], 0).sum()
196 |         x_no_bridge = jnp.where(outputs[i] == 1, preds['is_bridge'].data[2 * i + 1], 0).sum()
197 |         print(
198 |             i,
199 |             x_bridge,
200 |             x_no_bridge,
201 |         )
202 |         if (x_bridge > 0.5) and (x_no_bridge < 0.5):
203 |             correct += 1
204 |     print(correct / b)
205 | 
206 | def eval_bridges(rng_key, eval_model: clrs.models.BaselineModel, param_file: str, sampler, sample_count, spec,
207 |                  batch_size, file_name, save_file=True):
208 |     eval_model.restore_model(param_file, 
replace_params=True) 209 | predict_fn = eval_model.predict 210 | 211 | processed_samples = 0 212 | inputs = [] 213 | preds = [] 214 | hint_preds = [] 215 | outputs = [] 216 | output_logits = [] 217 | hints = [] 218 | lengths = [] 219 | example_feedback = next(sampler) 220 | for feedback in _get_feedback_sampler(example_feedback, n_iters=125): 221 | inputs.append(feedback.features.inputs) 222 | outputs.append(feedback.outputs) 223 | rng_key, new_rng_key = jax.random.split(rng_key) 224 | cur_preds, (cur_hint_preds, _, _, hiddens), net_aux = predict_fn(rng_key, feedback.features) 225 | output_logits.append(net_aux['output_logits']) 226 | preds.append(cur_preds) 227 | lengths.append(feedback.features.lengths) 228 | hints.append(feedback.features.hints) 229 | hint_preds.append(cur_hint_preds) 230 | rng_key = new_rng_key 231 | outputs = _concat(outputs, axis=0) 232 | output_logits = _concat(output_logits, axis=0) 233 | preds = _concat(preds, axis=0) 234 | lengths = _concat(lengths, axis=0) 235 | inputs = _concat(inputs, axis=0) 236 | _eval_bridge_outputs(outputs, preds) 237 | 238 | # print("Hi") 239 | # vis_idx = 71*2 240 | # visualize_graph(inputs[1].data[vis_idx], outputs[vis_idx // 2], preds['is_bridge'].data[vis_idx]) 241 | # visualize_graph(inputs[1]) 242 | return 243 | -------------------------------------------------------------------------------- /clrs/_src/third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSL-Lab/clrs/8de0e6a42be569a8ed85decaf60620d3233ca67e/clrs/_src/third_party/__init__.py -------------------------------------------------------------------------------- /clrs/_src/third_party/haiku_transformer.py: -------------------------------------------------------------------------------- 1 | # Changed causal attention to fully connected 2 | 3 | 4 | # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved. 5 | # 6 | # Licensed under the Apache License, Version 2.0 (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at 9 | # 10 | # http://www.apache.org/licenses/LICENSE-2.0 11 | # 12 | # Unless required by applicable law or agreed to in writing, software 13 | # distributed under the License is distributed on an "AS IS" BASIS, 14 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 | # See the License for the specific language governing permissions and 16 | # limitations under the License. 
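# A minimal usage sketch for the Transformer stack defined below (shapes and
# hyperparameters are illustrative; assumes standard Haiku transform/init/apply
# semantics -- an rng key is needed at apply time because of dropout):
#
#   def forward(h):
#       return Transformer(num_heads=4, num_layers=2, dropout_rate=0.1)(
#           h, mask=None, is_training=True)
#   forward = hk.transform(forward)
#   h = jnp.zeros([8, 16, 64])  # [B, T, D]
#   params = forward.init(jax.random.PRNGKey(0), h)
#   out = forward.apply(params, jax.random.PRNGKey(1), h)  # -> [B, T, D]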
17 | # ============================================================================== 18 | """Transformer model components.""" 19 | 20 | from typing import Optional 21 | 22 | import haiku as hk 23 | import jax 24 | import jax.numpy as jnp 25 | import numpy as np 26 | 27 | 28 | class DenseBlock(hk.Module): 29 | """A 2-layer MLP which widens then narrows the input.""" 30 | 31 | def __init__(self, 32 | init_scale: float, 33 | widening_factor: float = 1, 34 | name: Optional[str] = None): 35 | super().__init__(name=name) 36 | self._init_scale = init_scale 37 | self._widening_factor = widening_factor 38 | 39 | def __call__(self, x: jnp.ndarray) -> jnp.ndarray: 40 | hiddens = x.shape[-1] 41 | initializer = hk.initializers.VarianceScaling(self._init_scale) 42 | x = hk.Linear(int(self._widening_factor * hiddens), w_init=initializer)(x) 43 | x = jax.nn.gelu(x) 44 | return hk.Linear(hiddens, w_init=initializer)(x) 45 | 46 | 47 | class Transformer(hk.Module): 48 | """A transformer stack.""" 49 | 50 | def __init__(self, 51 | num_heads: int, 52 | num_layers: int, 53 | dropout_rate: float, 54 | name: Optional[str] = None, 55 | add_ln: bool = False): 56 | super().__init__(name=name) 57 | self._num_layers = num_layers 58 | self._num_heads = num_heads 59 | self._dropout_rate = dropout_rate 60 | self._add_ln = add_ln 61 | 62 | def __call__(self, 63 | h: jnp.ndarray, 64 | mask: Optional[jnp.ndarray], 65 | is_training: bool) -> jnp.ndarray: 66 | """Connects the transformer. 67 | Args: 68 | h: Inputs, [B, T, D]. 69 | mask: Padding mask, [B, T]. 70 | is_training: Whether we're training or not. 71 | Returns: 72 | Array of shape [B, T, D]. 73 | """ 74 | 75 | init_scale = 2. / self._num_layers 76 | dropout_rate = self._dropout_rate if is_training else 0. 77 | if mask is not None: 78 | mask = mask[:, None, :] # None for the head part 79 | 80 | # Note: names chosen to approximately match those used in the GPT-2 code; 81 | # see https://github.com/openai/gpt-2/blob/master/src/model.py. 82 | for i in range(self._num_layers): 83 | h_norm = layer_norm(h, name=f'h{i}_ln_1') 84 | h_attn = hk.MultiHeadAttention( 85 | num_heads=self._num_heads, 86 | key_size=h.shape[-1], 87 | model_size=h.shape[-1], 88 | w_init_scale=init_scale, 89 | name=f'h{i}_attn')(query=h_norm, key=h_norm, value=h_norm, mask=mask) 90 | h_attn = hk.dropout(hk.next_rng_key(), dropout_rate, h_attn) 91 | h = h + h_attn 92 | h_norm = layer_norm(h, name=f'h{i}_ln_2') 93 | h_dense = DenseBlock(init_scale, name=f'h{i}_mlp')(h_norm) 94 | h_dense = hk.dropout(hk.next_rng_key(), dropout_rate, h_dense) 95 | h = h + h_dense 96 | if self._add_ln: 97 | h = layer_norm(h, name='ln_f') 98 | 99 | return h 100 | 101 | 102 | def layer_norm(x: jnp.ndarray, name: Optional[str] = None) -> jnp.ndarray: 103 | """Apply a unique LayerNorm to x with default settings.""" 104 | return hk.LayerNorm(axis=-1, 105 | create_scale=True, 106 | create_offset=True, 107 | name=name)(x) 108 | -------------------------------------------------------------------------------- /clrs/clrs_test.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Import test for CLRS.""" 17 | 18 | from absl.testing import absltest 19 | import clrs 20 | 21 | 22 | class ClrsTest(absltest.TestCase): 23 | """Test CLRS can be imported correctly.""" 24 | 25 | def test_import(self): 26 | self.assertTrue(hasattr(clrs, 'Model')) 27 | 28 | 29 | if __name__ == '__main__': 30 | absltest.main() 31 | -------------------------------------------------------------------------------- /clrs/models.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """The CLRS Algorithmic Reasoning Benchmark.""" 17 | 18 | from clrs._src.baselines import BaselineModel 19 | from clrs._src.baselines import BaselineModelChunked 20 | from clrs._src.nets import Net 21 | from clrs._src.nets import NetChunked 22 | from clrs._src.processors import GAT 23 | from clrs._src.processors import MPNN 24 | 25 | __all__ = ( 26 | "BaselineModel", 27 | "BaselineModelChunked", 28 | "GAT", 29 | "MPNN", 30 | "Net", 31 | "NetChunked", 32 | ) 33 | -------------------------------------------------------------------------------- /clrs/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DSL-Lab/clrs/8de0e6a42be569a8ed85decaf60620d3233ca67e/clrs/py.typed -------------------------------------------------------------------------------- /requirements/requirements.txt: -------------------------------------------------------------------------------- 1 | six>=1.15.0 2 | absl-py>=0.13.0 3 | attrs>=21.4.0 4 | chex>=0.0.8 5 | dm-haiku>=0.0.4 6 | jax>=0.2.18 7 | jaxlib>=0.1.69 8 | numpy>=1.21.1 9 | opt-einsum>=3.3.0 10 | optax>=0.0.9 11 | tensorflow>=2.9.0 12 | tfds-nightly==4.5.2.dev202204190046 13 | toolz>=0.11.1 14 | ml-collections>=0.1.1 15 | wandb>=0.12.21 16 | -------------------------------------------------------------------------------- /scripts/make_datasets.sh: -------------------------------------------------------------------------------- 1 | venv_path=${CLRS_VENV_PATH:-/path/to/clrs_venv/bin/activate} 2 | clrs_root=${CLRS_ROOT:-/path/to/clrs_code/clrs} 3 | clrs_dataset_path=${CLRS_DATASET_PATH:-/path/to/save/clrs_datasets} 4 | 5 | source ${venv_path} 6 | 7 | # Make sure correct version of CLRS is being used 8 | python3 
-c 'import clrs; clrs.Sampler._random_er_or_k_reg_graph' || { echo 'Error: Either CLRS is not installed, or another version of CLRS is being used' ; exit 1; } 9 | 10 | graph_algs=(2 3 7 8 9 11 20 21 27 29) # Graph Algorithm indices from specs.py, excluding articulation points and bridges 11 | all_algs=(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29) # All algorithms 12 | 13 | ### No Hint L 14 | 15 | for alg_idx in "${all_algs[@]}" 16 | do 17 | for offset in 0 30 60 # Train validation test 18 | do 19 | idx=$((alg_idx + offset)) 20 | train_samples="100000" 21 | val_samples="32" 22 | test_samples="32" 23 | trainval_length="16" 24 | test_length="64" 25 | disable_hints="true" 26 | sampling_strategy="standard" 27 | make_dataset="True" 28 | 29 | echo "Generating L-CLRS for algorithm ${alg_idx} and split ${offset}" 30 | CLRS_MAKE_DATASET=${make_dataset} \ 31 | CLRS_DISABLE_HINTS=${disable_hints} \ 32 | CLRS_SAMPLING_STRATEGY=${sampling_strategy} \ 33 | CLRS_TRAIN_SAMPLES=${train_samples} \ 34 | CLRS_TRAIN_LENGTH=${trainval_length} \ 35 | CLRS_VAL_SAMPLES=${val_samples} \ 36 | CLRS_VAL_LENGTH=${trainval_length} \ 37 | CLRS_TEST_SAMPLES=${test_samples} \ 38 | CLRS_TEST_LENGTH=${test_length} \ 39 | tfds build "${clrs_root}"/clrs/_src/dataset.py --data_dir "${clrs_dataset_path}"/CLRS30_${sampling_strategy}_L/ --config_idx=${idx} 40 | 41 | 42 | if [[ ! " ${graph_algs[*]} " =~ " ${alg_idx} " ]]; then 43 | continue 44 | fi 45 | 46 | train_samples="100000" 47 | val_samples="1000" 48 | test_samples="1000" 49 | trainval_length="16" 50 | test_length="32" 51 | 52 | sampling_strategy="reg_same_deg" 53 | 54 | echo "Generating L-CLRS-Len for algorithm ${alg_idx} and split ${offset}" 55 | CLRS_MAKE_DATASET=${make_dataset} \ 56 | CLRS_DISABLE_HINTS=${disable_hints} \ 57 | CLRS_SAMPLING_STRATEGY=${sampling_strategy} \ 58 | CLRS_TRAIN_SAMPLES=${train_samples} \ 59 | CLRS_TRAIN_LENGTH=${trainval_length} \ 60 | CLRS_VAL_SAMPLES=${val_samples} \ 61 | CLRS_VAL_LENGTH=${trainval_length} \ 62 | CLRS_TEST_SAMPLES=${test_samples} \ 63 | CLRS_TEST_LENGTH=${test_length} \ 64 | tfds build "${clrs_root}"/clrs/_src/dataset.py --data_dir "${clrs_dataset_path}"/CLRS30_32_${sampling_strategy}_L/ --config_idx=${idx} 65 | 66 | sampling_strategy="reg" 67 | 68 | echo "Generating L-CLRS-Len-Deg for algorithm ${alg_idx} and split ${offset}" 69 | CLRS_MAKE_DATASET=${make_dataset} \ 70 | CLRS_DISABLE_HINTS=${disable_hints} \ 71 | CLRS_SAMPLING_STRATEGY=${sampling_strategy} \ 72 | CLRS_TRAIN_SAMPLES=${train_samples} \ 73 | CLRS_TRAIN_LENGTH=${trainval_length} \ 74 | CLRS_VAL_SAMPLES=${val_samples} \ 75 | CLRS_VAL_LENGTH=${trainval_length} \ 76 | CLRS_TEST_SAMPLES=${test_samples} \ 77 | CLRS_TEST_LENGTH=${test_length} \ 78 | tfds build "${clrs_root}"/clrs/_src/dataset.py --data_dir "${clrs_dataset_path}"/CLRS30_32_${sampling_strategy}_L/ --config_idx=${idx} 79 | 80 | 81 | sampling_strategy="reg_same_nodes" 82 | trainval_length="32" 83 | 84 | echo "Generating L-CLRS-Deg for algorithm ${alg_idx} and split ${offset}" 85 | CLRS_MAKE_DATASET=${make_dataset} \ 86 | CLRS_DISABLE_HINTS=${disable_hints} \ 87 | CLRS_SAMPLING_STRATEGY=${sampling_strategy} \ 88 | CLRS_TRAIN_SAMPLES=${train_samples} \ 89 | CLRS_TRAIN_LENGTH=${trainval_length} \ 90 | CLRS_VAL_SAMPLES=${val_samples} \ 91 | CLRS_VAL_LENGTH=${trainval_length} \ 92 | CLRS_TEST_SAMPLES=${test_samples} \ 93 | CLRS_TEST_LENGTH=${test_length} \ 94 | tfds build "${clrs_root}"/clrs/_src/dataset.py --data_dir "${clrs_dataset_path}"/CLRS30_32_${sampling_strategy}_L/ 
--config_idx=${idx} 95 | done 96 | done 97 | -------------------------------------------------------------------------------- /scripts/run_experiments.sh: -------------------------------------------------------------------------------- 1 | TMPDIR=${TMPDIR:-/tmp} 2 | data_root=${CLRS_DATASET_PATH:-/path/to/saved/clrs_datasets} 3 | log_root=${CLRS_LOG_PATH:-/tmp/clrs_logs} 4 | checkpoint_root=${CLRS_CHECKPOINT_PATH:-/tmp/clrs_checkpoints} 5 | steps=${steps:-20000} 6 | 7 | all_algorithms=("articulation_points" "activity_selector" "bellman_ford" "bfs" "binary_search" "bridges" "bubble_sort" "dag_shortest_paths" "dfs" "dijkstra" "find_maximum_subarray_kadane" "floyd_warshall" "graham_scan" "heapsort" "insertion_sort" "jarvis_march" "kmp_matcher" "lcs_length" "matrix_chain_order" "minimum" "mst_kruskal" "mst_prim" "naive_string_matcher" "optimal_bst" "quickselect" "quicksort" "segments_intersect" "strongly_connected_components" "task_scheduling" "topological_sort") 8 | distinct_algorithms=("articulation_points" "activity_selector" "bellman_ford" "bfs" "binary_search" "bridges" "dag_shortest_paths" "dfs" "find_maximum_subarray_kadane" "floyd_warshall" "graham_scan" "lcs_length" "matrix_chain_order" "minimum" "mst_kruskal" "mst_prim" "naive_string_matcher" "optimal_bst" "quickselect" "quicksort" "segments_intersect" "strongly_connected_components" "task_scheduling" "topological_sort") 9 | 10 | # Processor Comparisons, Table 5 Row 1 11 | # mpnn processor is MPNN-FC in the paper 12 | # pgn_mpnn processor is MPNN-G in the paper 13 | # edge_att processor is 2WL in the paper 14 | for algorithm in "${distinct_algorithms[@]}" 15 | do 16 | batch_size_mp=32 17 | batch_size_2wl=16 18 | dataset="${data_root}/CLRS30_standard_L" 19 | train_items_mp=$(( steps*batch_size_mp )) 20 | train_items_2wl=$(( steps*batch_size_2wl )) 21 | exp_name="main_table2" 22 | hidden_size=128 23 | hidden_size_hybrid=108 # Stay in the same parameter budget 24 | hint_mode="none" 25 | for seed in {42..44} 26 | do 27 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn --seed=${seed} --batch_size=${batch_size_mp} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items_mp} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 28 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/2wl_2 --seed=${seed} --batch_size=${batch_size_2wl} --exp_flags.infrequent_test_eval=True --processor_type edge_att --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 29 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_2 --seed=${seed} --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 30 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_sigmoid --seed=${seed} --exp_flags.hybrid_type=sigmoid --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path 
"${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 31 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_pp --seed=${seed} --exp_flags.hybrid_processors=p_p --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 32 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_ee --seed=${seed} --exp_flags.hybrid_processors=e_e --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 33 | done 34 | done 35 | 36 | # Processors + Random Scalar Position, Table 5 row 2 37 | for algorithm in "${distinct_algorithms[@]}" 38 | do 39 | batch_size_mp=32 40 | batch_size_2wl=16 41 | dataset="${data_root}/CLRS30_standard_L" 42 | train_items_mp=$(( steps*batch_size_mp )) 43 | train_items_2wl=$(( steps*batch_size_2wl )) 44 | exp_name="final_results" 45 | hidden_size=128 46 | hidden_size_hybrid=108 # Stay in the same parameter budget 47 | hint_mode="none" 48 | for seed in {42..44} 49 | do 50 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn --exp_flags.random_pos=True --seed=${seed} --batch_size=${batch_size_mp} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items_mp} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 51 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/2wl_2 --exp_flags.random_pos=True --seed=${seed} --batch_size=${batch_size_2wl} --exp_flags.infrequent_test_eval=True --processor_type edge_att --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 52 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_2 --exp_flags.random_pos=True --seed=${seed} --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 53 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_sigmoid --exp_flags.random_pos=True --seed=${seed} --exp_flags.hybrid_type=sigmoid --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 54 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_pp --exp_flags.random_pos=True --seed=${seed} --exp_flags.hybrid_processors=p_p --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size 
${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
55 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/hybrid_avg_ee --exp_flags.random_pos=True --seed=${seed} --exp_flags.hybrid_processors=e_e --exp_flags.hybrid_type=avg --exp_flags.infrequent_test_eval=True --batch_size=${batch_size_2wl} --processor_type hybrid --hint_mode=${hint_mode} --hidden_size ${hidden_size_hybrid} --algorithm ${algorithm} --train_items ${train_items_2wl} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
56 | done
57 | done
58 | 
59 | 
60 | # Position Encoding Ablations, Tables 4 and 6
61 | for algorithm in "naive_string_matcher" "dfs" "find_maximum_subarray_kadane" "dag_shortest_paths" "matrix_chain_order" "topological_sort" "bfs"
62 | do
63 | batch_size=32
64 | dataset="${data_root}/CLRS30_standard_L"
65 | train_items=$(( steps*batch_size ))
66 | exp_name="pos_encodings_ablation2"
67 | hidden_size=128
68 | hint_mode="none"
69 | for seed in {42..44}
70 | do
71 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn/standard --seed=${seed} --batch_size=${batch_size} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
72 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn/random_pos --seed=${seed} --exp_flags.random_pos=True --batch_size=${batch_size} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
73 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn/trans_pos_enc --seed=${seed} --exp_flags.trans_pos_enc=True --batch_size=${batch_size} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
74 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/mpnn/standard --seed=${seed} --batch_size=${batch_size} --processor_type mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
75 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/mpnn/edgewise_pos --seed=${seed} --exp_flags.edgewise_pos=True --batch_size=${batch_size} --processor_type mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
76 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/mpnn/random_pos --seed=${seed} --exp_flags.random_pos=True --batch_size=${batch_size} --processor_type mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}"
77 | done
78 | done
79 | 
80 | # Generalization Modes Ablations, Table 2
81 | # CLRS30_32_reg_same_nodes_L is when the number of nodes stays the same and corresponds to L-CLRS-Deg in 
the paper 82 | # CLRS30_32_reg_same_deg_L is when degree stays the same, and corresponds to L-CLRS-Len in the paper 83 | # CLRS30_32_reg_L is when both degree and number of nodes change, and corresponds to L-CLRS-Len-Deg in the paper 84 | for algorithm in "bellman_ford" "bfs" "dag_shortest_paths" "dfs" "floyd_warshall" "mst_kruskal" "mst_prim" "strongly_connected_components" "topological_sort" 85 | do 86 | batch_size=8 87 | train_items=$(( steps*batch_size )) 88 | exp_name="ood_mode_table2" 89 | hint_mode="none" 90 | hidden_size=256 91 | for seed in {42..44} 92 | do 93 | for dataset_name in "CLRS30_32_reg_same_nodes_L" "CLRS30_32_reg_L" "CLRS30_32_reg_same_deg_L" 94 | do 95 | dataset="${data_root}/${dataset_name}" 96 | python3 -m clrs.examples.run --log_prefix ${exp_name}/${algorithm}/pgn_mpnn_${dataset_name} --batch_size=${batch_size} --processor_type pgn_mpnn --hint_mode=${hint_mode} --hidden_size ${hidden_size} --seed=${seed} --algorithm ${algorithm} --train_items ${train_items} --checkpoint_path "${checkpoint_root}" --dataset_path "${dataset}" --log_path "${log_root}" 97 | done 98 | done 99 | done 100 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # ============================================================================== 15 | 16 | """Install script for setuptools.""" 17 | 18 | import os 19 | from setuptools import find_namespace_packages 20 | from setuptools import setup 21 | 22 | _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | 24 | 25 | def _get_version(): 26 | with open('clrs/__init__.py') as fp: 27 | for line in fp: 28 | if line.startswith('__version__') and '=' in line: 29 | version = line[line.find('=') + 1:].strip(' \'"\n') 30 | if version: 31 | return version 32 | raise ValueError('`__version__` not defined in `clrs/__init__.py`') 33 | 34 | 35 | def _parse_requirements(path): 36 | 37 | with open(os.path.join(_CURRENT_DIR, path)) as f: 38 | packages = [] 39 | for line in f: 40 | line = line.strip() 41 | # let's also ignore empty lines and comments 42 | if not line or line.startswith('#'): 43 | continue 44 | if 'https://' in line: 45 | tail = line.rsplit('/', 1)[1] 46 | tail = tail.split('#')[0] 47 | line = tail.replace('@', '==').replace('.git', '') 48 | packages.append(line) 49 | return packages 50 | 51 | 52 | setup( 53 | name='dm-clrs', 54 | version=_get_version(), 55 | url='https://github.com/deepmind/clrs', 56 | license='Apache 2.0', 57 | author='DeepMind', 58 | description=('The CLRS Algorithmic Reasoning Benchmark.'), 59 | long_description=open(os.path.join(_CURRENT_DIR, 'README.md')).read(), 60 | long_description_content_type='text/markdown', 61 | author_email='clrs-dev@google.com', 62 | keywords='python machine learning', 63 | packages=find_namespace_packages(exclude=['*_test.py']), 64 | install_requires=_parse_requirements( 65 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')), 66 | tests_require=_parse_requirements( 67 | os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')), 68 | zip_safe=False, # Required for full installation. 69 | python_requires='>=3.6', 70 | classifiers=[ 71 | 'Development Status :: 4 - Beta', 72 | 'Environment :: Console', 73 | 'Intended Audience :: Science/Research', 74 | 'License :: OSI Approved :: Apache Software License', 75 | 'Operating System :: POSIX :: Linux', 76 | 'Operating System :: Microsoft :: Windows', 77 | 'Operating System :: MacOS :: MacOS X', 78 | 'Programming Language :: Python :: 3', 79 | 'Programming Language :: Python :: 3.6', 80 | 'Programming Language :: Python :: 3.7', 81 | 'Topic :: Scientific/Engineering :: Artificial Intelligence', 82 | ], 83 | ) 84 | --------------------------------------------------------------------------------
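# A quick worked check of the VCS-line handling in `_parse_requirements`
# above (the URL is hypothetical; plain lines such as 'numpy>=1.21.1' pass
# through unchanged):

line = 'git+https://github.com/example/somepkg.git@1.2.3#egg=somepkg'
tail = line.rsplit('/', 1)[1].split('#')[0]  # -> 'somepkg.git@1.2.3'
assert tail.replace('@', '==').replace('.git', '') == 'somepkg==1.2.3'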