├── .github ├── ISSUE_TEMPLATE.md └── PULL_REQUEST_TEMPLATE.md ├── .gitignore ├── .pylintrc ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── checkpoints ├── slice-mel-512.pkl └── slice-multi-fb512.pkl ├── config.py ├── configs ├── ddpm-base.cfg ├── ddpm-mel-1seq-512.cfg ├── ddpm-mel-32seq-512-large.cfg ├── ddpm-mel-32seq-512.cfg ├── ddpm-multi-1seq-512.cfg ├── ddpm-multi-32seq-512.cfg ├── mdn-base.cfg ├── mdn-mel-32seq-512-large.cfg ├── mdn-mel-32seq-512.cfg ├── mixture │ ├── mixture-single-2.cfg │ └── mixture-single-ddpm-2.cfg ├── ncsn-mel-1seq-512.cfg └── ncsn-multi-1seq-512.cfg ├── input_pipeline.py ├── models ├── autoregressive.py ├── ncsn.py └── shared.py ├── requirements.txt ├── sample_mdn.py ├── sample_ncsn.py ├── scripts ├── decode_dataset_beam.py ├── generate_compressed_transform.py ├── generate_song_data_beam.py ├── sample_audio.py └── transform_encoded_data.py ├── train_mdn.py ├── train_ncsn.py └── utils ├── data_utils.py ├── ebm_utils.py ├── losses.py ├── metrics.py ├── plot_utils.py ├── song_utils.py └── train_utils.py /.github/ISSUE_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | ## Expected Behavior 2 | 3 | 4 | ## Actual Behavior 5 | 6 | 7 | ## Steps to Reproduce the Problem 8 | 9 | 1. 10 | 1. 11 | 1. 12 | 13 | ## Specifications 14 | 15 | - Version: 16 | - Platform: -------------------------------------------------------------------------------- /.github/PULL_REQUEST_TEMPLATE.md: -------------------------------------------------------------------------------- 1 | Fixes # 2 | 3 | > It's a good idea to open an issue first for discussion. 
4 | 5 | - [ ] Tests pass 6 | - [ ] Appropriate changes to README are included in PR -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints 2 | .DS_Store 3 | .vscode 4 | __pycache__ 5 | output 6 | .style.yapf 7 | format.sh 8 | save 9 | *.mid 10 | checkpoints/eval 11 | *.wav 12 | experiments/ -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | # This Pylint rcfile contains a best-effort configuration to uphold the 2 | # best-practices and style described in the Google Python style guide: 3 | # https://google.github.io/styleguide/pyguide.html 4 | # 5 | # Its canonical open-source location is: 6 | # https://google.github.io/styleguide/pylintrc 7 | 8 | [MASTER] 9 | 10 | # Files or directories to be skipped. They should be base names, not paths. 11 | ignore=third_party 12 | 13 | # Files or directories matching the regex patterns are skipped. The regex 14 | # matches against base names, not paths. 15 | ignore-patterns= 16 | 17 | # Pickle collected data for later comparisons. 18 | persistent=no 19 | 20 | # List of plugins (as comma separated values of python modules names) to load, 21 | # usually to register additional checkers. 22 | load-plugins= 23 | 24 | # Use multiple processes to speed up Pylint. 25 | jobs=4 26 | 27 | # Allow loading of arbitrary C extensions. Extensions are imported into the 28 | # active Python interpreter and may run arbitrary code. 29 | unsafe-load-any-extension=no 30 | 31 | 32 | [MESSAGES CONTROL] 33 | 34 | # Only show warnings with the listed confidence levels. Leave empty to show 35 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED 36 | confidence= 37 | 38 | # Enable the message, report, category or checker with the given id(s). 
You can 39 | # either give multiple identifier separated by comma (,) or put this option 40 | # multiple time (only on the command line, not in the configuration file where 41 | # it should appear only once). See also the "--disable" option for examples. 42 | #enable= 43 | 44 | # Disable the message, report, category or checker with the given id(s). You 45 | # can either give multiple identifiers separated by comma (,) or put this 46 | # option multiple times (only on the command line, not in the configuration 47 | # file where it should appear only once).You can also use "--disable=all" to 48 | # disable everything first and then reenable specific checks. For example, if 49 | # you want to run only the similarities checker, you can use "--disable=all 50 | # --enable=similarities". If you want to run only the classes checker, but have 51 | # no Warning level messages displayed, use"--disable=all --enable=classes 52 | # --disable=W" 53 | disable=abstract-method, 54 | apply-builtin, 55 | arguments-differ, 56 | attribute-defined-outside-init, 57 | backtick, 58 | bad-option-value, 59 | basestring-builtin, 60 | buffer-builtin, 61 | c-extension-no-member, 62 | consider-using-enumerate, 63 | cmp-builtin, 64 | cmp-method, 65 | coerce-builtin, 66 | coerce-method, 67 | delslice-method, 68 | div-method, 69 | duplicate-code, 70 | eq-without-hash, 71 | execfile-builtin, 72 | file-builtin, 73 | filter-builtin-not-iterating, 74 | fixme, 75 | getslice-method, 76 | global-statement, 77 | hex-method, 78 | idiv-method, 79 | implicit-str-concat-in-sequence, 80 | import-error, 81 | import-self, 82 | import-star-module-level, 83 | inconsistent-return-statements, 84 | input-builtin, 85 | intern-builtin, 86 | invalid-str-codec, 87 | locally-disabled, 88 | long-builtin, 89 | long-suffix, 90 | map-builtin-not-iterating, 91 | misplaced-comparison-constant, 92 | missing-function-docstring, 93 | metaclass-assignment, 94 | next-method-called, 95 | next-method-defined, 96 | no-absolute-import, 
97 | no-else-break, 98 | no-else-continue, 99 | no-else-raise, 100 | no-else-return, 101 | no-init, # added 102 | no-member, 103 | no-name-in-module, 104 | no-self-use, 105 | nonzero-method, 106 | oct-method, 107 | old-division, 108 | old-ne-operator, 109 | old-octal-literal, 110 | old-raise-syntax, 111 | parameter-unpacking, 112 | print-statement, 113 | raising-string, 114 | range-builtin-not-iterating, 115 | raw_input-builtin, 116 | rdiv-method, 117 | reduce-builtin, 118 | relative-import, 119 | reload-builtin, 120 | round-builtin, 121 | setslice-method, 122 | signature-differs, 123 | standarderror-builtin, 124 | suppressed-message, 125 | sys-max-int, 126 | too-few-public-methods, 127 | too-many-ancestors, 128 | too-many-arguments, 129 | too-many-boolean-expressions, 130 | too-many-branches, 131 | too-many-instance-attributes, 132 | too-many-locals, 133 | too-many-nested-blocks, 134 | too-many-public-methods, 135 | too-many-return-statements, 136 | too-many-statements, 137 | trailing-newlines, 138 | unichr-builtin, 139 | unicode-builtin, 140 | unnecessary-pass, 141 | unpacking-in-except, 142 | useless-else-on-loop, 143 | useless-object-inheritance, 144 | useless-suppression, 145 | using-cmp-argument, 146 | wrong-import-order, 147 | xrange-builtin, 148 | zip-builtin-not-iterating, 149 | 150 | 151 | [REPORTS] 152 | 153 | # Set the output format. Available formats are text, parseable, colorized, msvs 154 | # (visual studio) and html. You can also give a reporter class, eg 155 | # mypackage.mymodule.MyReporterClass. 156 | output-format=text 157 | 158 | # Put messages in a separate file for each module / package specified on the 159 | # command line instead of printing them on stdout. Reports (if any) will be 160 | # written in a file name "pylint_global.[txt|html]". This option is deprecated 161 | # and it will be removed in Pylint 2.0. 
162 | files-output=no 163 | 164 | # Tells whether to display a full report or only the messages 165 | reports=no 166 | 167 | # Python expression which should return a note less than 10 (10 is the highest 168 | # note). You have access to the variables errors warning, statement which 169 | # respectively contain the number of errors / warnings messages and the total 170 | # number of statements analyzed. This is used by the global evaluation report 171 | # (RP0004). 172 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 173 | 174 | # Template used to display messages. This is a python new-style format string 175 | # used to format the message information. See doc for all details 176 | #msg-template= 177 | 178 | 179 | [BASIC] 180 | 181 | # Good variable names which should always be accepted, separated by a comma 182 | good-names=main,_ 183 | 184 | # Bad variable names which should always be refused, separated by a comma 185 | bad-names= 186 | 187 | # Colon-delimited sets of names that determine each other's naming style when 188 | # the name regexes allow several styles. 189 | name-group= 190 | 191 | # Include a hint for the correct naming format with invalid-name 192 | include-naming-hint=no 193 | 194 | # List of decorators that produce properties, such as abc.abstractproperty. Add 195 | # to this list to register other decorators that produce valid properties. 
196 | property-classes=abc.abstractproperty,cached_property.cached_property,cached_property.threaded_cached_property,cached_property.cached_property_with_ttl,cached_property.threaded_cached_property_with_ttl 197 | 198 | # Regular expression matching correct function names 199 | function-rgx=^(?:(?PsetUp|tearDown|setUpModule|tearDownModule)|(?P_?[A-Z][a-zA-Z0-9]*)|(?P_?[a-z][a-z0-9_]*))$ 200 | 201 | # Regular expression matching correct variable names 202 | variable-rgx=^[a-z][a-z0-9_]*$ 203 | 204 | # Regular expression matching correct constant names 205 | const-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 206 | 207 | # Regular expression matching correct attribute names 208 | attr-rgx=^_{0,2}[a-z][a-z0-9_]*$ 209 | 210 | # Regular expression matching correct argument names 211 | argument-rgx=^[a-z][a-z0-9_]*$ 212 | 213 | # Regular expression matching correct class attribute names 214 | class-attribute-rgx=^(_?[A-Z][A-Z0-9_]*|__[a-z0-9_]+__|_?[a-z][a-z0-9_]*)$ 215 | 216 | # Regular expression matching correct inline iteration names 217 | inlinevar-rgx=^[a-z][a-z0-9_]*$ 218 | 219 | # Regular expression matching correct class names 220 | class-rgx=^_?[A-Z][a-zA-Z0-9]*$ 221 | 222 | # Regular expression matching correct module names 223 | module-rgx=^(_?[a-z][a-z0-9_]*|__init__)$ 224 | 225 | # Regular expression matching correct method names 226 | method-rgx=(?x)^(?:(?P_[a-z0-9_]+__|runTest|setUp|tearDown|setUpTestCase|tearDownTestCase|setupSelf|tearDownClass|setUpClass|(test|assert)_*[A-Z0-9][a-zA-Z0-9_]*|next)|(?P_{0,2}[A-Z][a-zA-Z0-9_]*)|(?P_{0,2}[a-z][a-z0-9_]*))$ 227 | 228 | # Regular expression which should only match function or class names that do 229 | # not require a docstring. 230 | no-docstring-rgx=(__.*__|main|test.*|.*test|.*Test)$ 231 | 232 | # Minimum line length for functions/classes that require docstrings, shorter 233 | # ones are exempt. 
234 | docstring-min-length=10 235 | 236 | 237 | [TYPECHECK] 238 | 239 | # List of decorators that produce context managers, such as 240 | # contextlib.contextmanager. Add to this list to register other decorators that 241 | # produce valid context managers. 242 | contextmanager-decorators=contextlib.contextmanager,contextlib2.contextmanager 243 | 244 | # Tells whether missing members accessed in mixin class should be ignored. A 245 | # mixin class is detected if its name ends with "mixin" (case insensitive). 246 | ignore-mixin-members=yes 247 | 248 | # List of module names for which member attributes should not be checked 249 | # (useful for modules/projects where namespaces are manipulated during runtime 250 | # and thus existing member attributes cannot be deduced by static analysis. It 251 | # supports qualified module names, as well as Unix pattern matching. 252 | ignored-modules= 253 | 254 | # List of class names for which member attributes should not be checked (useful 255 | # for classes with dynamically set attributes). This supports the use of 256 | # qualified names. 257 | ignored-classes=optparse.Values,thread._local,_thread._local 258 | 259 | # List of members which are set dynamically and missed by pylint inference 260 | # system, and so shouldn't trigger E1101 when accessed. Python regular 261 | # expressions are accepted. 262 | generated-members= 263 | 264 | 265 | [FORMAT] 266 | 267 | # Maximum number of characters on a single line. 268 | max-line-length=80 269 | 270 | # TODO(https://github.com/PyCQA/pylint/issues/3352): Direct pylint to exempt 271 | # lines made too long by directives to pytype. 272 | 273 | # Regexp for a line that is allowed to be longer than the limit. 274 | ignore-long-lines=(?x)( 275 | ^\s*(\#\ )??$| 276 | ^\s*(from\s+\S+\s+)?import\s+.+$) 277 | 278 | # Allow the body of an if to be on the same line as the test if there is no 279 | # else. 
280 | single-line-if-stmt=yes 281 | 282 | # List of optional constructs for which whitespace checking is disabled. `dict- 283 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 284 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 285 | # `empty-line` allows space-only lines. 286 | no-space-check= 287 | 288 | # Maximum number of lines in a module 289 | max-module-lines=99999 290 | 291 | # String used as indentation unit. The internal Google style guide mandates 2 292 | # spaces. Google's externaly-published style guide says 4, consistent with 293 | # PEP 8. Here, we use 2 spaces, for conformity with many open-sourced Google 294 | # projects (like TensorFlow). 295 | indent-string=' ' 296 | 297 | # Number of spaces of indent required inside a hanging or continued line. 298 | indent-after-paren=4 299 | 300 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 301 | expected-line-ending-format= 302 | 303 | 304 | [MISCELLANEOUS] 305 | 306 | # List of note tags to take in consideration, separated by a comma. 307 | notes=TODO 308 | 309 | 310 | [STRING] 311 | 312 | # This flag controls whether inconsistent-quotes generates a warning when the 313 | # character used as a quote delimiter is used inconsistently within a module. 314 | check-quote-consistency=yes 315 | 316 | 317 | [VARIABLES] 318 | 319 | # Tells whether we should check for unused import in __init__ files. 320 | init-import=no 321 | 322 | # A regular expression matching the name of dummy variables (i.e. expectedly 323 | # not used). 324 | dummy-variables-rgx=^\*{0,2}(_$|unused_|dummy_) 325 | 326 | # List of additional names supposed to be defined in builtins. Remember that 327 | # you should avoid to define new builtins when possible. 328 | additional-builtins= 329 | 330 | # List of strings which can identify a callback function by name. A callback 331 | # name must start or end with one of those strings. 
332 | callbacks=cb_,_cb 333 | 334 | # List of qualified module names which can have objects that can redefine 335 | # builtins. 336 | redefining-builtins-modules=six,six.moves,past.builtins,future.builtins,functools 337 | 338 | 339 | [LOGGING] 340 | 341 | # Logging modules to check that the string format arguments are in logging 342 | # function parameter format 343 | logging-modules=logging,absl.logging,tensorflow.io.logging 344 | 345 | 346 | [SIMILARITIES] 347 | 348 | # Minimum lines number of a similarity. 349 | min-similarity-lines=4 350 | 351 | # Ignore comments when computing similarities. 352 | ignore-comments=yes 353 | 354 | # Ignore docstrings when computing similarities. 355 | ignore-docstrings=yes 356 | 357 | # Ignore imports when computing similarities. 358 | ignore-imports=no 359 | 360 | 361 | [SPELLING] 362 | 363 | # Spelling dictionary name. Available dictionaries: none. To make it working 364 | # install python-enchant package. 365 | spelling-dict= 366 | 367 | # List of comma separated words that should not be checked. 368 | spelling-ignore-words= 369 | 370 | # A path to a file that contains private dictionary; one word per line. 371 | spelling-private-dict-file= 372 | 373 | # Tells whether to store unknown words to indicated private dictionary in 374 | # --spelling-private-dict-file option instead of raising a message. 375 | spelling-store-unknown-words=no 376 | 377 | 378 | [IMPORTS] 379 | 380 | # Deprecated modules which should not be used, separated by a comma 381 | deprecated-modules=regsub, 382 | TERMIOS, 383 | Bastion, 384 | rexec, 385 | sets 386 | 387 | # Create a graph of every (i.e. 
internal and external) dependencies in the 388 | # given file (report RP0402 must not be disabled) 389 | import-graph= 390 | 391 | # Create a graph of external dependencies in the given file (report RP0402 must 392 | # not be disabled) 393 | ext-import-graph= 394 | 395 | # Create a graph of internal dependencies in the given file (report RP0402 must 396 | # not be disabled) 397 | int-import-graph= 398 | 399 | # Force import order to recognize a module as part of the standard 400 | # compatibility libraries. 401 | known-standard-library= 402 | 403 | # Force import order to recognize a module as part of a third party library. 404 | known-third-party=enchant, absl 405 | 406 | # Analyse import fallback blocks. This can be used to support both Python 2 and 407 | # 3 compatible code, which means that the block might have code that exists 408 | # only in one or another interpreter, leading to false positives when analysed. 409 | analyse-fallback-blocks=no 410 | 411 | 412 | [CLASSES] 413 | 414 | # List of method names used to declare (i.e. assign) instance attributes. 415 | defining-attr-methods=__init__, 416 | __new__, 417 | setUp 418 | 419 | # List of member names, which should be excluded from the protected access 420 | # warning. 421 | exclude-protected=_asdict, 422 | _fields, 423 | _replace, 424 | _source, 425 | _make 426 | 427 | # List of valid names for the first argument in a class method. 428 | valid-classmethod-first-arg=cls, 429 | class_ 430 | 431 | # List of valid names for the first argument in a metaclass class method. 432 | valid-metaclass-classmethod-first-arg=mcs 433 | 434 | 435 | [EXCEPTIONS] 436 | 437 | # Exceptions that will emit a warning when being caught. 
Defaults to 438 | # "Exception" 439 | overgeneral-exceptions=StandardError, 440 | Exception, 441 | BaseException 442 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # How to Contribute 2 | 3 | We'd love to accept your patches and contributions to this project. There are 4 | just a few small guidelines you need to follow. 5 | 6 | ## Contributor License Agreement 7 | 8 | Contributions to this project must be accompanied by a Contributor License 9 | Agreement. You (or your employer) retain the copyright to your contribution; 10 | this simply gives us permission to use and redistribute your contributions as 11 | part of the project. Head over to <https://cla.developers.google.com/> to see 12 | your current agreements on file or to sign a new one. 13 | 14 | You generally only need to submit a CLA once, so if you've already submitted one 15 | (even if it was for a different project), you probably don't need to do it 16 | again. 17 | 18 | ## Code Reviews 19 | 20 | All submissions, including submissions by project members, require review. We 21 | use GitHub pull requests for this purpose. Consult 22 | [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more 23 | information on using pull requests. 24 | 25 | ## Community Guidelines 26 | 27 | This project follows [Google's Open Source Community 28 | Guidelines](https://opensource.google/conduct/). 29 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | 2 | Apache License 3 | Version 2.0, January 2004 4 | http://www.apache.org/licenses/ 5 | 6 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 7 | 8 | 1. Definitions. 9 | 10 | "License" shall mean the terms and conditions for use, reproduction, 11 | and distribution as defined by Sections 1 through 9 of this document. 
12 | 13 | "Licensor" shall mean the copyright owner or entity authorized by 14 | the copyright owner that is granting the License. 15 | 16 | "Legal Entity" shall mean the union of the acting entity and all 17 | other entities that control, are controlled by, or are under common 18 | control with that entity. For the purposes of this definition, 19 | "control" means (i) the power, direct or indirect, to cause the 20 | direction or management of such entity, whether by contract or 21 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 22 | outstanding shares, or (iii) beneficial ownership of such entity. 23 | 24 | "You" (or "Your") shall mean an individual or Legal Entity 25 | exercising permissions granted by this License. 26 | 27 | "Source" form shall mean the preferred form for making modifications, 28 | including but not limited to software source code, documentation 29 | source, and configuration files. 30 | 31 | "Object" form shall mean any form resulting from mechanical 32 | transformation or translation of a Source form, including but 33 | not limited to compiled object code, generated documentation, 34 | and conversions to other media types. 35 | 36 | "Work" shall mean the work of authorship, whether in Source or 37 | Object form, made available under the License, as indicated by a 38 | copyright notice that is included in or attached to the work 39 | (an example is provided in the Appendix below). 40 | 41 | "Derivative Works" shall mean any work, whether in Source or Object 42 | form, that is based on (or derived from) the Work and for which the 43 | editorial revisions, annotations, elaborations, or other modifications 44 | represent, as a whole, an original work of authorship. For the purposes 45 | of this License, Derivative Works shall not include works that remain 46 | separable from, or merely link (or bind by name) to the interfaces of, 47 | the Work and Derivative Works thereof. 
48 | 49 | "Contribution" shall mean any work of authorship, including 50 | the original version of the Work and any modifications or additions 51 | to that Work or Derivative Works thereof, that is intentionally 52 | submitted to Licensor for inclusion in the Work by the copyright owner 53 | or by an individual or Legal Entity authorized to submit on behalf of 54 | the copyright owner. For the purposes of this definition, "submitted" 55 | means any form of electronic, verbal, or written communication sent 56 | to the Licensor or its representatives, including but not limited to 57 | communication on electronic mailing lists, source code control systems, 58 | and issue tracking systems that are managed by, or on behalf of, the 59 | Licensor for the purpose of discussing and improving the Work, but 60 | excluding communication that is conspicuously marked or otherwise 61 | designated in writing by the copyright owner as "Not a Contribution." 62 | 63 | "Contributor" shall mean Licensor and any individual or Legal Entity 64 | on behalf of whom a Contribution has been received by Licensor and 65 | subsequently incorporated within the Work. 66 | 67 | 2. Grant of Copyright License. Subject to the terms and conditions of 68 | this License, each Contributor hereby grants to You a perpetual, 69 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 70 | copyright license to reproduce, prepare Derivative Works of, 71 | publicly display, publicly perform, sublicense, and distribute the 72 | Work and such Derivative Works in Source or Object form. 73 | 74 | 3. Grant of Patent License. 
Subject to the terms and conditions of 75 | this License, each Contributor hereby grants to You a perpetual, 76 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 77 | (except as stated in this section) patent license to make, have made, 78 | use, offer to sell, sell, import, and otherwise transfer the Work, 79 | where such license applies only to those patent claims licensable 80 | by such Contributor that are necessarily infringed by their 81 | Contribution(s) alone or by combination of their Contribution(s) 82 | with the Work to which such Contribution(s) was submitted. If You 83 | institute patent litigation against any entity (including a 84 | cross-claim or counterclaim in a lawsuit) alleging that the Work 85 | or a Contribution incorporated within the Work constitutes direct 86 | or contributory patent infringement, then any patent licenses 87 | granted to You under this License for that Work shall terminate 88 | as of the date such litigation is filed. 89 | 90 | 4. Redistribution. 
You may reproduce and distribute copies of the 91 | Work or Derivative Works thereof in any medium, with or without 92 | modifications, and in Source or Object form, provided that You 93 | meet the following conditions: 94 | 95 | (a) You must give any other recipients of the Work or 96 | Derivative Works a copy of this License; and 97 | 98 | (b) You must cause any modified files to carry prominent notices 99 | stating that You changed the files; and 100 | 101 | (c) You must retain, in the Source form of any Derivative Works 102 | that You distribute, all copyright, patent, trademark, and 103 | attribution notices from the Source form of the Work, 104 | excluding those notices that do not pertain to any part of 105 | the Derivative Works; and 106 | 107 | (d) If the Work includes a "NOTICE" text file as part of its 108 | distribution, then any Derivative Works that You distribute must 109 | include a readable copy of the attribution notices contained 110 | within such NOTICE file, excluding those notices that do not 111 | pertain to any part of the Derivative Works, in at least one 112 | of the following places: within a NOTICE text file distributed 113 | as part of the Derivative Works; within the Source form or 114 | documentation, if provided along with the Derivative Works; or, 115 | within a display generated by the Derivative Works, if and 116 | wherever such third-party notices normally appear. The contents 117 | of the NOTICE file are for informational purposes only and 118 | do not modify the License. You may add Your own attribution 119 | notices within Derivative Works that You distribute, alongside 120 | or as an addendum to the NOTICE text from the Work, provided 121 | that such additional attribution notices cannot be construed 122 | as modifying the License. 
123 | 124 | You may add Your own copyright statement to Your modifications and 125 | may provide additional or different license terms and conditions 126 | for use, reproduction, or distribution of Your modifications, or 127 | for any such Derivative Works as a whole, provided Your use, 128 | reproduction, and distribution of the Work otherwise complies with 129 | the conditions stated in this License. 130 | 131 | 5. Submission of Contributions. Unless You explicitly state otherwise, 132 | any Contribution intentionally submitted for inclusion in the Work 133 | by You to the Licensor shall be under the terms and conditions of 134 | this License, without any additional terms or conditions. 135 | Notwithstanding the above, nothing herein shall supersede or modify 136 | the terms of any separate license agreement you may have executed 137 | with Licensor regarding such Contributions. 138 | 139 | 6. Trademarks. This License does not grant permission to use the trade 140 | names, trademarks, service marks, or product names of the Licensor, 141 | except as required for reasonable and customary use in describing the 142 | origin of the Work and reproducing the content of the NOTICE file. 143 | 144 | 7. Disclaimer of Warranty. Unless required by applicable law or 145 | agreed to in writing, Licensor provides the Work (and each 146 | Contributor provides its Contributions) on an "AS IS" BASIS, 147 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 148 | implied, including, without limitation, any warranties or conditions 149 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 150 | PARTICULAR PURPOSE. You are solely responsible for determining the 151 | appropriateness of using or redistributing the Work and assume any 152 | risks associated with Your exercise of permissions under this License. 153 | 154 | 8. Limitation of Liability. 
In no event and under no legal theory, 155 | whether in tort (including negligence), contract, or otherwise, 156 | unless required by applicable law (such as deliberate and grossly 157 | negligent acts) or agreed to in writing, shall any Contributor be 158 | liable to You for damages, including any direct, indirect, special, 159 | incidental, or consequential damages of any character arising as a 160 | result of this License or out of the use or inability to use the 161 | Work (including but not limited to damages for loss of goodwill, 162 | work stoppage, computer failure or malfunction, or any and all 163 | other commercial damages or losses), even if such Contributor 164 | has been advised of the possibility of such damages. 165 | 166 | 9. Accepting Warranty or Additional Liability. While redistributing 167 | the Work or Derivative Works thereof, You may choose to offer, 168 | and charge a fee for, acceptance of support, warranty, indemnity, 169 | or other liability obligations and/or rights consistent with this 170 | License. However, in accepting such obligations, You may act only 171 | on Your own behalf and on Your sole responsibility, not on behalf 172 | of any other Contributor, and only if You agree to indemnify, 173 | defend, and hold each Contributor harmless for any liability 174 | incurred by, or claims asserted against, such Contributor by reason 175 | of your accepting any such warranty or additional liability. 176 | 177 | END OF TERMS AND CONDITIONS 178 | 179 | APPENDIX: How to apply the Apache License to your work. 180 | 181 | To apply the Apache License to your work, attach the following 182 | boilerplate notice, with the fields enclosed by brackets "[]" 183 | replaced with your own identifying information. (Don't include 184 | the brackets!) The text should be enclosed in the appropriate 185 | comment syntax for the file format. 
We also recommend that a 186 | file or class name and description of purpose be included on the 187 | same "printed page" as the copyright notice for easier 188 | identification within third-party archives. 189 | 190 | Copyright [yyyy] [name of copyright owner] 191 | 192 | Licensed under the Apache License, Version 2.0 (the "License"); 193 | you may not use this file except in compliance with the License. 194 | You may obtain a copy of the License at 195 | 196 | http://www.apache.org/licenses/LICENSE-2.0 197 | 198 | Unless required by applicable law or agreed to in writing, software 199 | distributed under the License is distributed on an "AS IS" BASIS, 200 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 201 | See the License for the specific language governing permissions and 202 | limitations under the License. 203 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Symbolic Music Generation with Diffusion Models 2 | Supplementary code release for our work [Symbolic Music Generation with Diffusion Models](https://archives.ismir.net/ismir2021/paper/000058.pdf). 3 | 4 | ## Installation 5 | All code is written in Python 3 ([Anaconda](https://www.anaconda.com/) recommended). To install the dependencies: 6 | 7 | ``` 8 | pip install -r requirements.txt 9 | ``` 10 | 11 | A copy of the [Magenta](https://github.com/magenta/magenta) codebase is required for access to MusicVAE and related components. [Installation instructions](https://github.com/magenta/magenta#installation) can be found on the Magenta public repository. You will also need to download [pretrained MusicVAE checkpoints](https://github.com/magenta/magenta/tree/master/magenta/models/music_vae). For our experiments, we use the [2-bar melody model](https://storage.googleapis.com/magentadata/models/music_vae/checkpoints/cat-mel_2bar_big.tar). 
## Datasets
We use the [Lakh MIDI Dataset](https://colinraffel.com/projects/lmd/) to train our models. Follow [these instructions](https://github.com/magenta/magenta/blob/master/magenta/scripts/README.md) to download and build the Lakh MIDI Dataset.


To encode the Lakh dataset with MusicVAE, use `scripts/generate_song_data_beam.py`:
```
python scripts/generate_song_data_beam.py \
  --checkpoint=/path/to/musicvae-ckpt \
  --input=/path/to/lakh_tfrecords \
  --output=/path/to/encoded_tfrecords
```

To preprocess and generate fixed-length latent sequences for training diffusion and autoregressive models, refer to `scripts/transform_encoded_data.py`:
```
python scripts/transform_encoded_data.py \
  --encoded_data=/path/to/encoded_tfrecords \
  --output_path=/path/to/preprocess_tfrecords \
  --mode=sequences \
  --context_length=32
```
## Training
#### Diffusion
```python train_ncsn.py --flagfile=configs/ddpm-mel-32seq-512.cfg```

#### TransformerMDN
```python train_mdn.py --flagfile=configs/mdn-mel-32seq-512.cfg```

## Sampling and Generation
#### Diffusion
```
python sample_ncsn.py \
  --flagfile=configs/ddpm-mel-32seq-512.cfg \
  --sample_seed=42 \
  --sample_size=1000 \
  --sampling_dir=/path/to/latent-samples
```

#### TransformerMDN
```
python sample_mdn.py \
  --flagfile=configs/mdn-mel-32seq-512.cfg \
  --sample_seed=42 \
  --sample_size=1000 \
  --sampling_dir=/path/to/latent-samples
```

#### Decoding sequences
To convert sequences of embeddings (generated by diffusion or TransformerMDN models) to sequences of MIDI events, refer to `scripts/sample_audio.py`.
61 | 62 | ``` 63 | python scripts/sample_audio.py 64 | --input=/path/to/latent-samples/[ncsn|mdn] \ 65 | --output=/path/to/audio-midi \ 66 | --n_synth=1000 \ 67 | --include_wav=True 68 | ``` 69 | 70 | ## Citing 71 | If you use this code please cite it as: 72 | 73 | ``` 74 | @inproceedings{ 75 | mittal2021symbolicdiffusion, 76 | title={Symbolic Music Generation with Diffusion Models}, 77 | author={Gautam Mittal and Jesse Engel and Curtis Hawthorne and Ian Simon}, 78 | booktitle={Proceedings of the 22nd International Society for Music Information Retrieval Conference}, 79 | year={2021}, 80 | url={https://archives.ismir.net/ismir2021/paper/000058.pdf} 81 | } 82 | ``` 83 | 84 | ## Note 85 | This is not an official Google product. 86 | -------------------------------------------------------------------------------- /checkpoints/slice-mel-512.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magenta/symbolic-music-diffusion/469204d73c39af94302be0c748327ef8c432307a/checkpoints/slice-mel-512.pkl -------------------------------------------------------------------------------- /checkpoints/slice-multi-fb512.pkl: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/magenta/symbolic-music-diffusion/469204d73c39af94302be0c748327ef8c432307a/checkpoints/slice-multi-fb512.pkl -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Model configurations.

Registry of MusicVAE configurations (data converter + hyperparameters) used
to encode NoteSequences into latent vectors and decode latents back to music.
"""
from magenta.models.music_vae import configs
from magenta.models.music_vae import data
from magenta.models.music_vae import data_hierarchical

# Maps a short experiment name -> MusicVAE Config object.
MUSIC_VAE_CONFIG = {}

# 2-bar one-hot melody converter; polyphonic inputs are kept
# (skip_polyphony=False) rather than discarded.
melody_2bar_converter = data.OneHotMelodyConverter(
    skip_polyphony=False,
    max_bars=100,  # Truncate long melodies before slicing.
    max_tensors_per_notesequence=None,
    slice_bars=2,
    gap_bars=None,
    steps_per_quarter=4,
    dedupe_event_lists=False)

# Same 2-bar converter, but polyphonic inputs are skipped entirely.
mel_2bar_nopoly_converter = data.OneHotMelodyConverter(
    skip_polyphony=True,
    max_bars=100,  # Truncate long melodies before slicing.
    max_tensors_per_notesequence=None,
    slice_bars=2,
    gap_bars=None,
    steps_per_quarter=4,
    dedupe_event_lists=False)

# 16-bar melody converter used with the hierarchical-decoder model below.
melody_16bar_converter = data.OneHotMelodyConverter(
    skip_polyphony=False,
    max_bars=100,  # Truncate long melodies before slicing.
    slice_bars=16,
    gap_bars=16,
    max_tensors_per_notesequence=None,
    steps_per_quarter=4,
    dedupe_event_lists=False)

# 1-bar multi-instrument performance converter (2-8 instruments required).
multitrack_default_1bar_converter = data_hierarchical.MultiInstrumentPerformanceConverter(
    num_velocity_bins=8,
    hop_size_bars=1,
    min_num_instruments=2,
    max_num_instruments=8,
    max_events_per_instrument=64)

# Variant that also admits empty/sparse bars (min_num_instruments=0,
# min_total_events=0) and truncates overly dense tracks instead of
# rejecting them.
multitrack_zero_1bar_converter = data_hierarchical.MultiInstrumentPerformanceConverter(
    num_velocity_bins=8,
    hop_size_bars=1,
    min_num_instruments=0,
    max_num_instruments=8,
    min_total_events=0,
    max_events_per_instrument=64,
    drop_tracks_and_truncate=True)

# The entries below reuse pretrained Magenta configs, swapping in the
# converters defined above via namedtuple._replace.
MUSIC_VAE_CONFIG['melody-2-big'] = configs.CONFIG_MAP[
    'cat-mel_2bar_big']._replace(data_converter=melody_2bar_converter)

MUSIC_VAE_CONFIG['melody-16-big'] = configs.CONFIG_MAP[
    'hierdec-mel_16bar']._replace(data_converter=melody_16bar_converter)

MUSIC_VAE_CONFIG['multi-1-big'] = configs.CONFIG_MAP[
    'hier-multiperf_vel_1bar_big']._replace(
        data_converter=multitrack_default_1bar_converter)

MUSIC_VAE_CONFIG['multi-0min-1-big'] = configs.CONFIG_MAP[
    'hier-multiperf_vel_1bar_big']._replace(
        data_converter=multitrack_zero_1bar_converter)

# Custom 2-bar no-polyphony model built from scratch rather than derived
# from a CONFIG_MAP entry; augments training data by transposing +/-5
# semitones.
MUSIC_VAE_CONFIG['melody-2-big-nopoly'] = configs.Config(
    model=configs.MusicVAE(configs.lstm_models.BidirectionalLstmEncoder(),
                           configs.lstm_models.CategoricalLstmDecoder()),
    hparams=configs.merge_hparams(
        configs.lstm_models.get_default_hparams(),
        configs.HParams(
            batch_size=512,
            max_seq_len=32,  # 2 bars w/ 16 steps per bar
            z_size=512,
            enc_rnn_size=[2048],
            dec_rnn_size=[2048, 2048, 2048],
        )),
    note_sequence_augmenter=data.NoteSequenceAugmenter(transpose_range=(-5, 5)),
    data_converter=mel_2bar_nopoly_converter)
-------------------------------------------------------------------------------- 1 | --epochs=100 2 | --batch_size=64 3 | --learning_rate=1e-3 4 | --sigma_begin=1e-6 5 | --sigma_end=0.01 6 | --num_sigmas=1000 7 | --problem=vae 8 | --ema=False 9 | --continuous_noise 10 | --normalize 11 | --loss=ddpm 12 | --sampling=ddpm 13 | --schedule_type=linear 14 | --nosnapshot_sampling 15 | --max_steps=500000 16 | -------------------------------------------------------------------------------- /configs/ddpm-mel-1seq-512.cfg: -------------------------------------------------------------------------------- 1 | --architecture=DenseDDPM 2 | --epochs=10 3 | --batch_size=64 4 | --learning_rate=1e-3 5 | --sigma_begin=1e-6 6 | --sigma_end=0.01 7 | --num_sigmas=1000 8 | --problem=vae 9 | --ema 10 | --normalize 11 | --continuous_noise 12 | --loss=ddpm 13 | --sampling=ddpm 14 | --schedule_type=linear 15 | --data_shape=512 16 | --model_dir=save/mel512 17 | --dataset=./output/melody-512 18 | --nosnapshot_sampling 19 | --max_steps=250000 20 | -------------------------------------------------------------------------------- /configs/ddpm-mel-32seq-512-large.cfg: -------------------------------------------------------------------------------- 1 | --flagfile=configs/ddpm-mel-32seq-512.cfg 2 | --num_layers=8 3 | --num_heads=16 4 | --num_mlp_layers=3 5 | --mlp_dims=2048 6 | --model_dir=save/mel512-ddpm-32seq-large 7 | -------------------------------------------------------------------------------- /configs/ddpm-mel-32seq-512.cfg: -------------------------------------------------------------------------------- 1 | --flagfile=configs/ddpm-base.cfg 2 | --architecture=TransformerDDPM 3 | --num_layers=6 4 | --num_heads=8 5 | --num_mlp_layers=2 6 | --mlp_dims=2048 7 | --data_shape=32,512 8 | --dataset=./output/mel-32step-512 9 | --slice_ckpt=./checkpoints/slice-mel-512.pkl 10 | --model_dir=save/mel512-ddpm-32seq 11 | -------------------------------------------------------------------------------- 
/configs/ddpm-multi-1seq-512.cfg: -------------------------------------------------------------------------------- 1 | --architecture=DenseDDPM 2 | --epochs=10 3 | --batch_size=64 4 | --learning_rate=1e-3 5 | --sigma_begin=1e-6 6 | --sigma_end=0.01 7 | --num_sigmas=1000 8 | --problem=vae 9 | --ema 10 | --normalize 11 | --continuous_noise 12 | --loss=ddpm 13 | --sampling=ddpm 14 | --schedule_type=linear 15 | --data_shape=512 16 | --model_dir=save/lakh512 17 | --dataset=./output/multi-fb512 18 | --nosnapshot_sampling 19 | --max_steps=150000 20 | 21 | -------------------------------------------------------------------------------- /configs/ddpm-multi-32seq-512.cfg: -------------------------------------------------------------------------------- 1 | --architecture=TransformerDDPM4 2 | --epochs=40 3 | --batch_size=64 4 | --learning_rate=1e-3 5 | --sigma_begin=1e-6 6 | --sigma_end=0.01 7 | --num_sigmas=1000 8 | --problem=vae 9 | --ema=False 10 | --continuous_noise 11 | --normalize 12 | --loss=ddpm 13 | --sampling=ddpm 14 | --schedule_type=linear 15 | --data_shape=32,512 16 | --slice_ckpt=./checkpoints/slice-multi-fb512.pkl 17 | --dataset=./output/lakh-32step-512 18 | --model_dir=save/lakh512-32seq 19 | --nosnapshot_sampling 20 | --max_steps=500000 21 | -------------------------------------------------------------------------------- /configs/mdn-base.cfg: -------------------------------------------------------------------------------- 1 | --epochs=1000 2 | --learning_rate=3e-4 3 | --batch_size=128 4 | --max_steps=250000 5 | --mdn_components=100 6 | --num_layers=4 7 | --num_heads=8 8 | --num_mlp_layers=2 9 | --mlp_dims=2048 10 | -------------------------------------------------------------------------------- /configs/mdn-mel-32seq-512-large.cfg: -------------------------------------------------------------------------------- 1 | --flagfile=configs/mdn-mel-32seq-512.cfg 2 | --num_layers=8 3 | --num_heads=16 4 | --num_mlp_layers=3 5 | --mlp_dims=2048 6 | 
--model_dir=save/mel512-mdn-32seq 7 | -------------------------------------------------------------------------------- /configs/mdn-mel-32seq-512.cfg: -------------------------------------------------------------------------------- 1 | --flagfile=configs/mdn-base.cfg 2 | --architecture=TransformerMDN 3 | --num_layers=6 4 | --num_heads=8 5 | --num_mlp_layers=2 6 | --mlp_dims=2048 7 | --data_shape=32,512 8 | --dataset=./output/mel-32step-512 9 | --slice_ckpt=./checkpoints/slice-mel-512.pkl 10 | --model_dir=save/mel512-mdn-32seq 11 | -------------------------------------------------------------------------------- /configs/mixture/mixture-single-2.cfg: -------------------------------------------------------------------------------- 1 | --architecture=ToyNCSN 2 | --epochs=5 3 | --sigma_begin=0.5 4 | --num_sigmas=10 5 | --ld_epsilon=1.2e-5 6 | --ld_steps=10 7 | --problem=toy 8 | --ema 9 | --normalize 10 | --continuous_noise 11 | --schedule_type=linear 12 | --loss=ssm 13 | --data_shape=2 14 | --dataset=./output/mix2d 15 | --model_dir=save/toy2d 16 | -------------------------------------------------------------------------------- /configs/mixture/mixture-single-ddpm-2.cfg: -------------------------------------------------------------------------------- 1 | --architecture=ToyDDPM 2 | --epochs=10 3 | --batch_size=64 4 | --learning_rate=1e-3 5 | --sigma_begin=1e-6 6 | --sigma_end=0.01 7 | --num_sigmas=1000 8 | --problem=toy 9 | --ema 10 | --continuous_noise=True 11 | --loss=ddpm 12 | --sampling=ddpm 13 | --schedule_type=linear 14 | --data_shape=2 15 | --dataset=./output/mix2d 16 | --model_dir=save/ddpm2d 17 | -------------------------------------------------------------------------------- /configs/ncsn-mel-1seq-512.cfg: -------------------------------------------------------------------------------- 1 | --architecture=DenseNCSN 2 | --epochs=700 3 | --sigma_begin=15 4 | --num_sigmas=500 5 | --ld_epsilon=9.64e-7 6 | --ld_steps=100 7 | --problem=vae 8 | --ema 9 | --normalize 10 
| --data_shape=512 11 | --model_dir=save/mel512 12 | --dataset=./output/melody-512 -------------------------------------------------------------------------------- /configs/ncsn-multi-1seq-512.cfg: -------------------------------------------------------------------------------- 1 | --architecture=DenseNCSN 2 | --epochs=700 3 | --sigma_begin=30 4 | --num_sigmas=500 5 | --ld_epsilon=9.91e-7 6 | --ld_steps=100 7 | --problem=vae 8 | --ema 9 | --normalize 10 | --data_shape=512 11 | --model_dir=save/lakh512 12 | --dataset=./output/multi-fb512 13 | -------------------------------------------------------------------------------- /input_pipeline.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
# Lint as: python3
"""Input data pipeline.

Builds normalized tf.data pipelines over MNIST, toy mixture data, or
MusicVAE latent sequences stored as TFRecords.
"""
import os
import time

import numpy as np
import tensorflow as tf
import tensorflow_datasets as tfds

from absl import logging
from functools import partial
import utils.data_utils as data_utils

AUTOTUNE = tf.data.experimental.AUTOTUNE


def deconstruct_dict(batch_dict, problem):
  """Pulls the data tensor out of an example dict for the given problem."""
  # TFDS MNIST stores pixels under 'image'; our TFRecord datasets use
  # 'inputs'.
  key = 'image' if problem == 'mnist' else 'inputs'
  return batch_dict[key]


def normalize_dataset(batch, data_min, data_max):
  """Normalize dataset to range [-1, 1]."""
  batch = (batch - data_min) / (data_max - data_min)
  batch = 2. * batch - 1.
  return batch


def slice_transform(batch, problem='vae', slice_idx=None, dim_weights=None):
  """Optionally reweights dimensions, then keeps only the sliced channels."""
  if dim_weights is not None:
    batch = batch * dim_weights
  if slice_idx is not None:
    batch = tf.gather(batch, slice_idx, axis=-1)
  return batch


def data_transform(batch, problem='vae', pca=None):
  """Data transform.

  Args:
    batch: A batch of data samples.
    problem: Problem type. 'mnist' batches are flattened and scaled to
      [-1, 1]; other problems pass through unchanged.
    pca: PCA transform object.

  Returns:
    Transformed batch array.
  """
  if problem == 'mnist':
    batch = tf.reshape(batch, (batch.shape[0], -1))
    batch = tf.cast(batch, tf.float32) / 255.
    batch = 2. * batch - 1.

  if pca is not None:
    if batch.ndim > 2:
      # Flatten, project, and restore the original shape.
      # NOTE(review): restoring init_shape assumes pca.transform preserves
      # the flattened dimensionality -- confirm for the checkpoints in use.
      init_shape = batch.shape
      batch = batch.reshape(batch.shape[0], -1)
      batch = pca.transform(batch)
      batch = batch.reshape(*init_shape)
    else:
      batch = pca.transform(batch)

  return batch


def inverse_data_transform(batch,
                           normalize=True,
                           pca=None,
                           data_min=0.,
                           data_max=1.,
                           slice_idx=None,
                           dim_weights=None,
                           out_channels=512):
  """Inverse data transform.

  Undoes, in reverse order, the transforms applied by the training pipeline:
  [-1, 1] normalization, PCA projection, channel slicing, and dimension
  weighting.

  Args:
    batch: Transformed batch array.
    normalize: If True, maps values from [-1, 1] back to
      [data_min, data_max].
    pca: PCA transform object.
    data_min: Minimum of the original (pre-normalization) data.
    data_max: Maximum of the original (pre-normalization) data.
    slice_idx: Channel indices that were kept by the slice transform.
    dim_weights: Per-dimension weights applied in the forward transform.
    out_channels: Full channel count to restore when un-slicing.

  Returns:
    Original batch array.
  """
  if normalize:
    batch = (batch + 1.) / 2.
    batch = (data_max - data_min) * batch + data_min

  if pca is not None:
    batch = pca.inverse_transform(batch)

  if slice_idx is not None:
    # Un-slice: channels that were dropped are filled with standard Gaussian
    # noise rather than zeros.
    transformed = np.random.randn(*batch.shape[:-1], out_channels)
    transformed[..., slice_idx] = batch
    batch = transformed

  if dim_weights is not None:
    batch = batch / dim_weights

  return batch


def get_dataset(dataset='',
                data_shape=(2,),
                problem='vae',
                batch_size=128,
                normalize=True,
                pca_ckpt='',
                slice_ckpt='',
                dim_weights_ckpt='',
                include_cardinality=True):
  """Builds train/eval tf.data pipelines for the given problem.

  Args:
    dataset: Directory containing train-*/eval-* TFRecord shards (ignored
      for 'mnist'); also used as the cache directory for min/max and
      cardinality statistics.
    data_shape: Per-example shape of the serialized records.
    problem: One of 'mnist', 'vae', 'toy', 'tokens'.
    batch_size: Batch size; partial batches are dropped.
    normalize: If True, rescales both splits to [-1, 1] using cached
      per-split min/max statistics.
    pca_ckpt: Optional pickled PCA transform to apply.
    slice_ckpt: Optional pickled channel indices for the slice transform.
    dim_weights_ckpt: Optional pickled per-dimension weights.
    include_cardinality: If True, precomputes (and caches) the number of
      batches in each split.

  Returns:
    A (train_ds, eval_ds) tuple with `min`/`max` attributes attached for
    later use by `inverse_data_transform`.

  Raises:
    ValueError: If `problem` is not recognized.
  """
  if problem == 'mnist':
    train_ds = tfds.load('mnist', split='train', shuffle_files=True)
    eval_ds = tfds.load('mnist', split='test', shuffle_files=True)
  elif problem in ['vae', 'toy', 'tokens']:
    shape = tuple(map(int, data_shape))
    tokens = problem == 'tokens'
    train_ds = data_utils.get_tf_record_dataset(
        file_pattern=f'{dataset}/train-*.tfrecord',
        shape=shape,
        batch_size=batch_size,
        shuffle=True,
        tokens=tokens)
    eval_ds = data_utils.get_tf_record_dataset(
        file_pattern=f'{dataset}/eval-*.tfrecord',
        shape=shape,
        batch_size=batch_size,
        shuffle=True,
        tokens=tokens)
  else:
    raise ValueError(f'Unknown problem type: {problem}')

  # Dataset loading and transformation (PCA, Slice).
  pca = data_utils.load(os.path.expanduser(pca_ckpt)) if pca_ckpt else None
  slice_idx = data_utils.load(
      os.path.expanduser(slice_ckpt)) if slice_ckpt else None
  dim_weights = data_utils.load(
      os.path.expanduser(dim_weights_ckpt)) if dim_weights_ckpt else None

  # Batch.
  train_ds = train_ds.batch(batch_size, drop_remainder=True)
  eval_ds = eval_ds.batch(batch_size, drop_remainder=True)

  # Extract the raw tensor from each example dict.
  train_ds = train_ds.map(partial(deconstruct_dict, problem=problem),
                          num_parallel_calls=AUTOTUNE)
  eval_ds = eval_ds.map(partial(deconstruct_dict, problem=problem),
                        num_parallel_calls=AUTOTUNE)

  # PCA transform
  # (wrapped in tf.py_function because the PCA object is a Python callable).
  if problem != 'tokens':
    train_ds = train_ds.map(lambda example: tf.py_function(
        partial(data_transform, problem=problem, pca=pca), [example], tf.float32
    ),
                            num_parallel_calls=AUTOTUNE)
    eval_ds = eval_ds.map(lambda example: tf.py_function(
        partial(data_transform, problem=problem, pca=pca), [example], tf.float32
    ),
                          num_parallel_calls=AUTOTUNE)

  # Slice + weight transform
  train_ds = train_ds.map(partial(slice_transform,
                                  problem=problem,
                                  slice_idx=slice_idx,
                                  dim_weights=dim_weights),
                          num_parallel_calls=AUTOTUNE)
  eval_ds = eval_ds.map(partial(slice_transform,
                                problem=problem,
                                slice_idx=slice_idx,
                                dim_weights=dim_weights),
                        num_parallel_calls=AUTOTUNE)

  # Dataset normalization.
  train_min, train_max = 0., 1.
  eval_min, eval_max = 0., 1.
  if normalize:
    logging.info('Normalizing dataset to have range [-1, 1].')
    # The cache key encodes which transform checkpoints were applied, so
    # different transform combinations get separate min/max caches.
    config_name = pca_ckpt.split('/')[-1].split('.')[0]
    config_name += slice_ckpt.split('/')[-1].split('.')[0]
    config_name += dim_weights_ckpt.split('/')[-1].split('.')[0]
    train_min, train_max = data_utils.compute_dataset_min_max(
        train_ds,
        ds_split='train',
        cache=True,
        cache_dir=os.path.expanduser(dataset),
        config=config_name)
    eval_min, eval_max = data_utils.compute_dataset_min_max(
        eval_ds,
        ds_split='eval',
        cache=True,
        cache_dir=os.path.expanduser(dataset),
        config=config_name)
    train_ds = train_ds.map(lambda example: normalize_dataset(
        example, train_min, train_max),
                            num_parallel_calls=AUTOTUNE)
    eval_ds = eval_ds.map(lambda example: normalize_dataset(
        example, eval_min, eval_max),
                          num_parallel_calls=AUTOTUNE)

  train_ds = train_ds.prefetch(AUTOTUNE)
  eval_ds = eval_ds.prefetch(AUTOTUNE)
  eval_ds = eval_ds.cache()

  # Stash the normalization range on the dataset objects so samplers can
  # invert the transform later.
  setattr(train_ds, 'min', train_min)
  setattr(train_ds, 'max', train_max)
  setattr(eval_ds, 'min', eval_min)
  setattr(eval_ds, 'max', eval_max)

  if include_cardinality:
    t0 = time.time()
    config_name = str(batch_size)
    data_utils.compute_dataset_cardinality(
        train_ds,
        ds_split='train',
        cache=True,
        cache_dir=os.path.expanduser(dataset),
        config=config_name)
    data_utils.compute_dataset_cardinality(
        eval_ds,
        ds_split='eval',
        cache=True,
        cache_dir=os.path.expanduser(dataset),
        config=config_name)
    logging.info('Computed dataset cardinality in %f seconds', time.time() - t0)

  return train_ds, eval_ds
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Autoregressive models."""
import jax.numpy as jnp
import jax

from flax import jax_utils
from flax import nn

from models.shared import TransformerPositionalEncoding, DenseResBlock, MDN


def shift_right(x):
  """Shift the input to the right by padding on axis 1.

  Prepends a zero frame and drops the last frame, so position i of the
  output holds position i-1 of the input (teacher forcing).
  """
  pad_widths = [(0, 0)] * len(x.shape)
  pad_widths[1] = (1, 0)  # Padding on axis=1
  padded = jnp.pad(x,
                   pad_widths,
                   mode='constant',
                   constant_values=x.dtype.type(0))
  return padded[:, :-1]


class TransformerMDN(nn.Module):
  """Transformer with continuous outputs.

  A causal Transformer whose output head is a mixture density network,
  predicting per-step Gaussian mixture parameters over latent channels.
  """

  def apply(self,
            inputs,
            shift=True,
            num_layers=6,
            num_heads=8,
            num_mlp_layers=2,
            mlp_dims=2048,
            mdn_mixtures=100):
    # inputs.shape = (batch_size, seq_len, data_channels)
    batch_size, seq_len, data_channels = inputs.shape
    x = inputs
    if shift:
      # Teacher forcing: predict frame i from frames < i only.
      x = shift_right(x)
    embed_channels = 128
    # Fixed sinusoidal positional encodings, broadcast over the batch.
    temb = TransformerPositionalEncoding(jnp.arange(seq_len), embed_channels)
    temb = temb[None, :, :]
    assert temb.shape[1:] == (seq_len, embed_channels), temb.shape
    x = nn.Dense(x, embed_channels)

    x = x + temb
    # Pre-LayerNorm Transformer blocks with causal self-attention.
    for _ in range(num_layers):
      shortcut = x
      x = nn.LayerNorm(x)
      x = nn.SelfAttention(x, causal_mask=True, num_heads=num_heads)
      x = x + shortcut
      shortcut2 = x
      x = nn.LayerNorm(x)
      x = nn.Dense(x, mlp_dims)
      x = nn.gelu(x)
      x = nn.Dense(x, embed_channels)
      x = x + shortcut2

    x = nn.LayerNorm(x)
    x = nn.Dense(x, mlp_dims)

    # Extra fully-connected residual head before the mixture output.
    for _ in range(num_mlp_layers):
      x = DenseResBlock(x, mlp_dims)

    x = nn.LayerNorm(x)
    # Mixture density output: per-step mixture logits plus per-component
    # mean and log-sigma parameters over the data channels.
    mdn = MDN.partial(out_channels=data_channels,
                      num_components=mdn_mixtures,
                      name='mdn')
    pi, mu, log_sigma = mdn(x)
    return pi, mu, log_sigma
# Lint as: python3
"""Noise-conditional iterative refinement networks."""
import jax
import jax.numpy as jnp
from flax import jax_utils
from flax import nn

# BUG FIX: ConvResBlock is used by ConvNCSN below but was never imported,
# which raised a NameError as soon as ConvNCSN was applied.
from models.shared import (TransformerPositionalEncoding, DenseResBlock,
                           ConvResBlock)


class NoiseEncoding(nn.Module):
  """Sinusoidal noise encoding block.

  Maps a continuous noise level to a `channels`-dimensional sinusoidal
  embedding, analogous to Transformer positional encodings.
  """

  def apply(self, noise, channels):
    # noise.shape = (batch_size, 1)
    # channels.shape = ()
    noise = noise.squeeze(-1)
    assert len(noise.shape) == 1
    half_dim = channels // 2
    # Geometric frequency ladder spanning [1, 1/10000].
    emb = jnp.log(10000) / float(half_dim - 1)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    # Empirical 5000x scaling of the continuous noise level before encoding.
    emb = 5000 * noise[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)
    if channels % 2 == 1:
      # Odd channel counts get one zero-padded column.
      emb = jnp.pad(emb, [[0, 0], [0, 1]])
    assert emb.shape == (noise.shape[0], channels)
    return emb


class DenseFiLM(nn.Module):
  """Feature-wise linear modulation (FiLM) generator.

  Produces (scale, shift) conditioning tensors from a scalar noise level /
  timestep via a small MLP over its sinusoidal encoding.
  """

  def apply(self, position, embedding_channels, out_channels, sequence=False):
    # position.shape = (batch_size, 1)
    # embedding_channels.shape, out_channels.shape = (), ()
    assert len(position.shape) == 2
    pos_encoding = NoiseEncoding(position, embedding_channels)
    pos_encoding = nn.Dense(pos_encoding, embedding_channels * 4)
    pos_encoding = nn.swish(pos_encoding)
    pos_encoding = nn.Dense(pos_encoding, embedding_channels * 4)

    if sequence:
      # Add a sequence axis so the FiLM parameters broadcast over time.
      pos_encoding = pos_encoding[:, None, :]

    scale = nn.Dense(pos_encoding, out_channels)
    shift = nn.Dense(pos_encoding, out_channels)
    return scale, shift


class ConvFiLM(nn.Module):
  """Convolutional FiLM generator."""

  def apply(self, position, embedding_channels, out_channels):
    # position.shape = (batch_size, 1, 1)
    # embedding_channels.shape, out_channels.shape = (), ()
    assert len(position.shape) == 3
    position = position.squeeze(-1)
    pos_encoding = NoiseEncoding(position, embedding_channels)
    pos_encoding = nn.Dense(pos_encoding, embedding_channels * 4)
    pos_encoding = nn.swish(pos_encoding)
    pos_encoding = nn.Dense(pos_encoding, embedding_channels * 4)
    pos_encoding = pos_encoding[:, None, :]

    scale = nn.Conv(pos_encoding, out_channels, kernel_size=(3,), strides=(1,))
    shift = nn.Conv(pos_encoding, out_channels, kernel_size=(3,), strides=(1,))
    return scale, shift


class DenseNCSN(nn.Module):
  """Small fully-connected score network."""

  def apply(self, inputs, sigmas, num_layers=3, mlp_dims=2048):
    # inputs.shape = (batch_size, z_dims)
    # sigmas.shape = (batch_size, 1)
    x = inputs
    x = nn.Dense(x, mlp_dims)
    for _ in range(num_layers):
      # BUG FIX: the original referenced an undefined name `t`. The FiLM
      # conditioning signal here is the per-example noise level `sigmas`
      # (cf. DenseDDPM, which conditions on its `t` argument the same way).
      scale, shift = DenseFiLM(sigmas, 128, mlp_dims)
      x = DenseResBlock(x, mlp_dims, scale=scale, shift=shift)
    x = nn.LayerNorm(x)
    x = nn.Dense(x, inputs.shape[-1])

    # NCSN parameterization: scale the prediction by 1/sigma so the output
    # magnitude tracks the noise level.
    output = x / sigmas
    return output


class ConvNCSN(nn.Module):
  """Convolutional score network for sequences."""

  def apply(self, inputs, sigmas):
    # inputs.shape = (batch_size, seq_len, z_dims)
    # sigmas.shape = (batch_size, 1, 1)
    input_channels = inputs.shape[-1]
    x = nn.Conv(inputs, 128, kernel_size=(2,), strides=(1,))

    # Hourglass-style channel schedule, two residual blocks per width.
    for channels in (128, 256, 256, 128):
      x = ConvResBlock(x, channels)
      x = ConvResBlock(x, channels)

    x = nn.LayerNorm(x)
    x = nn.relu(x)
    x = nn.Conv(x, input_channels, kernel_size=(2,), strides=(1,))

    # Same 1/sigma output scaling as DenseNCSN.
    output = x / sigmas
    return output


class DenseDDPM(nn.Module):
  """Fully-connected diffusion network."""

  def apply(self, inputs, t, num_layers=3, mlp_dims=2048):
    # inputs.shape = (batch_size, z_dims)
    # t.shape = (batch_size, 1)
    x = inputs
    x = nn.Dense(x, mlp_dims)
    for _ in range(num_layers):
      # Condition each residual block on the diffusion timestep via FiLM.
      scale, shift = DenseFiLM(t, 128, mlp_dims)
      x = DenseResBlock(x, mlp_dims, scale=scale, shift=shift)
    x = nn.LayerNorm(x)
    x = nn.Dense(x, inputs.shape[-1])
    return x


class TransformerDDPM(nn.Module):
  """Transformer-based diffusion model."""

  def apply(self,
            inputs,
            t,
            num_layers=6,
            num_heads=8,
            num_mlp_layers=2,
            mlp_dims=2048):
    # inputs.shape = (batch_size, seq_len, data_channels)
    batch_size, seq_len, data_channels = inputs.shape

    x = inputs
    embed_channels = 128
    # Fixed sinusoidal positional encodings, broadcast over the batch.
    temb = TransformerPositionalEncoding(jnp.arange(seq_len), embed_channels)
    temb = temb[None, :, :]
    assert temb.shape[1:] == (seq_len, embed_channels), temb.shape
    x = nn.Dense(x, embed_channels)

    x = x + temb
    # Pre-LayerNorm Transformer blocks; attention is NOT causal here, since
    # diffusion denoising conditions on the full sequence.
    for _ in range(num_layers):
      shortcut = x
      x = nn.LayerNorm(x)
      x = nn.SelfAttention(x, num_heads=num_heads)
      x = x + shortcut
      shortcut2 = x
      x = nn.LayerNorm(x)
      x = nn.Dense(x, mlp_dims)
      x = nn.gelu(x)
      x = nn.Dense(x, embed_channels)
      x = x + shortcut2

    x = nn.LayerNorm(x)
    x = nn.Dense(x, mlp_dims)

    # Timestep-conditioned residual head.
    for _ in range(num_mlp_layers):
      # NOTE(review): t appears to arrive with an extra trailing axis here,
      # i.e. (batch, 1, 1) -- squeeze makes it match DenseFiLM's 2-D
      # assertion. Confirm against the trainer.
      scale, shift = DenseFiLM(t.squeeze(-1), 128, mlp_dims, sequence=True)
      x = DenseResBlock(x, mlp_dims, scale=scale, shift=shift)

    x = nn.LayerNorm(x)
    x = nn.Dense(x, data_channels)
    return x
class MDN(nn.Module):
  """Mixture density output layer.

  Produces the parameters of a Gaussian mixture over the output space:
  mixture logits, component means, and component log standard deviations.
  """

  def apply(self, inputs, out_channels=512, num_components=10):
    # inputs.shape = (batch_size, seq_len, channels)
    x = inputs
    # Means and log-sigmas are emitted flattened as
    # (..., out_channels * num_components); callers reshape per component.
    mu = nn.Dense(x, out_channels * num_components)
    log_sigma = nn.Dense(x, out_channels * num_components)
    # Unnormalized mixture weights, one logit per component.
    pi = nn.Dense(x, num_components)
    return pi, mu, log_sigma


class TransformerPositionalEncoding(nn.Module):
  """Transformer positional encoding block.

  Standard fixed sinusoidal embedding: half the channels carry sin terms,
  half carry cos terms, with geometrically spaced frequencies.
  """

  def apply(self, timesteps, channels):
    # timesteps.shape = (seq_len,)
    # channels.shape = ()
    assert len(timesteps.shape) == 1
    half_dim = channels // 2
    # log-spaced frequencies in (0, 1]; NOTE(review): half_dim must be > 1
    # or this divides by zero — callers appear to use channels >= 4.
    emb = jnp.log(10000) / float(half_dim - 1)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    emb = timesteps[:, None] * emb[None, :]
    emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)
    if channels % 2 == 1:
      # Odd channel counts get a zero-padded final column.
      emb = jnp.pad(emb, [[0, 0], [0, 1]])
    assert emb.shape == (timesteps.shape[0], channels)
    return emb


class FeaturewiseAffine(nn.Module):
  """Feature-wise affine layer.

  FiLM-style modulation: element-wise `scale * x + shift`, relying on
  broadcasting when scale/shift are per-feature or scalar.
  """

  def apply(self, x, scale, shift):
    return scale * x + shift


class DenseResBlock(nn.Module):
  """Fully-connected residual block.

  Two LayerNorm -> FiLM -> swish -> Dense stages, plus an identity (or
  linearly projected) shortcut when the channel count changes.
  """

  def apply(self, inputs, output_size, scale=1., shift=0.):
    output = nn.LayerNorm(inputs)
    output = FeaturewiseAffine(output, scale, shift)
    output = nn.swish(output)
    output = nn.Dense(output, output_size)
    output = nn.LayerNorm(output)
    # Same (scale, shift) conditioning is applied to both stages.
    output = FeaturewiseAffine(output, scale, shift)
    output = nn.swish(output)
    output = nn.Dense(output, output_size)

    # Project the shortcut only when the feature dimension changes.
    shortcut = inputs
    if inputs.shape[-1] != output_size:
      shortcut = nn.Dense(inputs, output_size)

    return output + shortcut


class ConvResBlock(nn.Module):
  """Convolutional residual block.

  1-D convolutions with a shortcut taken after the first conv+swish, so the
  input itself is never added directly — the residual is over the second
  conv stage only.
  """

  def apply(self, inputs, out_channels, scale=1., shift=0.):
    output = nn.Conv(inputs, out_channels, kernel_size=(3,), strides=(1,))
    output = nn.swish(output)
    # Shortcut branches off after the first conv, guaranteeing it already
    # has out_channels features (asserted below).
    shortcut = output
    output = nn.Conv(output, out_channels, kernel_size=(3,), strides=(1,))
    output = nn.GroupNorm(output)
    output = FeaturewiseAffine(output, scale, shift)
    output = nn.swish(output)
    assert shortcut.shape[-1] == out_channels
    return output + shortcut
| defusedxml==0.6.0 35 | dill==0.3.1.1 36 | dm-sonnet==2.0.0 37 | dm-tree==0.1.5 38 | docopt==0.6.2 39 | dopamine-rl==3.0.1 40 | entrypoints==0.3 41 | fastavro==0.23.4 42 | fasteners==0.15 43 | filelock==3.0.12 44 | flake8==3.8.3 45 | Flask==1.1.2 46 | flatbuffers==1.12 47 | flax==0.3.0 48 | future==0.18.2 49 | gast==0.3.3 50 | gevent==20.6.2 51 | gin-config==0.3.0 52 | google==3.0.0 53 | google-api-core==1.21.0 54 | google-api-python-client==1.9.3 55 | google-apitools==0.5.31 56 | google-auth==1.18.0 57 | google-auth-httplib2==0.0.3 58 | google-auth-oauthlib==0.4.1 59 | google-cloud-bigquery==1.24.0 60 | google-cloud-bigtable==1.0.0 61 | google-cloud-core==1.3.0 62 | google-cloud-datastore==1.7.4 63 | google-cloud-dlp==0.13.0 64 | google-cloud-language==1.3.0 65 | google-cloud-pubsub==1.0.2 66 | google-cloud-spanner==1.13.0 67 | google-cloud-videointelligence==1.13.0 68 | google-cloud-vision==0.42.0 69 | google-pasta==0.2.0 70 | google-resumable-media==0.5.1 71 | googleapis-common-protos==1.52.0 72 | gpustat==0.6.0 73 | greenlet==0.4.16 74 | grpc-google-iam-v1==0.12.3 75 | grpcio==1.29.0 76 | grpcio-gcp==0.2.2 77 | gunicorn==20.0.4 78 | gviz-api==1.9.0 79 | gym==0.17.2 80 | h5py==2.10.0 81 | hdfs==2.5.8 82 | hiredis==1.1.0 83 | httplib2==0.17.4 84 | idna==2.10 85 | imageio==2.8.0 86 | importlib-metadata==1.6.1 87 | intervaltree==3.0.2 88 | ipykernel==5.3.0 89 | ipython==7.15.0 90 | ipython-genutils==0.2.0 91 | ipywidgets==7.5.1 92 | isort==4.3.21 93 | itsdangerous==1.1.0 94 | jax==0.2.8 95 | jaxlib==0.1.57+cuda101 96 | jedi==0.17.0 97 | Jinja2==2.11.2 98 | joblib==0.15.1 99 | jsonschema==3.2.0 100 | jupyter==1.0.0 101 | jupyter-client==6.1.3 102 | jupyter-console==6.1.0 103 | jupyter-core==4.6.3 104 | Keras==2.4.3 105 | Keras-Preprocessing==1.1.2 106 | kfac==0.2.0 107 | kiwisolver==1.0.1 108 | lazy-object-proxy==1.4.3 109 | librosa==0.7.2 110 | llvmlite==0.31.0 111 | -e git+https://github.com/magenta/magenta@2d0fd456d7faa272733b57d286f5f26998082cf8#egg=magenta 112 
| Markdown==3.2.2 113 | MarkupSafe==1.1.1 114 | matplotlib==2.2.3 115 | mccabe==0.6.1 116 | mesh-tensorflow==0.1.13 117 | mido==1.2.6 118 | mir-eval==0.6 119 | mistune==0.8.4 120 | mock==2.0.0 121 | monotonic==1.5 122 | mpmath==1.1.0 123 | msgpack==1.0.0 124 | MulticoreTSNE==0.1 125 | multidict==4.7.6 126 | nbconvert==5.6.1 127 | nbformat==5.0.7 128 | networkx==2.4 129 | note-seq==0.0.2 130 | notebook==6.0.3 131 | numba==0.48.0 132 | numpy==1.19.4 133 | nvidia-ml-py3==7.352.0 134 | oauth2client==3.0.0 135 | oauthlib==3.1.0 136 | opencensus==0.7.12 137 | opencensus-context==0.1.2 138 | opencv-python==4.2.0.34 139 | opt-einsum==3.2.1 140 | packaging==20.4 141 | pandas==1.0.5 142 | pandocfilters==1.4.2 143 | parso==0.7.0 144 | pbr==5.4.5 145 | pexpect==4.8.0 146 | pickleshare==0.7.5 147 | Pillow==7.1.2 148 | pretty-midi==0.2.9 149 | prometheus-client==0.8.0 150 | promise==2.3 151 | prompt-toolkit==3.0.5 152 | protobuf==3.12.2 153 | psutil==5.8.0 154 | ptyprocess==0.6.0 155 | py-spy==0.3.3 156 | pyarrow==0.17.1 157 | pyasn1==0.4.8 158 | pyasn1-modules==0.2.8 159 | pycodestyle==2.6.0 160 | pycparser @ file:///tmp/build/80754af9/pycparser_1594388511720/work 161 | pydot==1.4.1 162 | pydub==0.24.1 163 | pyflakes==2.2.0 164 | pyFluidSynth==1.2.5 165 | pyglet==1.5.0 166 | Pygments==2.6.1 167 | pygtrie==2.3.3 168 | pylint==2.5.3 169 | pymongo==3.10.1 170 | pyOpenSSL @ file:///tmp/build/80754af9/pyopenssl_1608057966937/work 171 | pyparsing==2.2.0 172 | pypng==0.0.20 173 | pyrsistent==0.16.0 174 | PySocks @ file:///tmp/build/80754af9/pysocks_1594394576006/work 175 | python-dateutil==2.7.3 176 | python-rtmidi==1.1.2 177 | pytz==2018.5 178 | PyWavelets==1.1.1 179 | PyYAML==5.3.1 180 | pyzmq==19.0.1 181 | qtconsole==4.7.4 182 | QtPy==1.9.0 183 | ray==1.1.0 184 | redis==3.5.3 185 | requests==2.24.0 186 | requests-oauthlib==1.3.0 187 | resampy==0.2.2 188 | rsa==4.6 189 | scikit-image==0.17.2 190 | scikit-learn==0.19.2 191 | scipy==1.1.0 192 | seaborn==0.10.1 193 | selenium @ 
file:///tmp/build/80754af9/selenium_1594420457392/work 194 | Send2Trash==1.5.0 195 | six @ file:///tmp/build/80754af9/six_1605205313296/work 196 | sk-video==1.1.10 197 | sklearn==0.0 198 | sortedcontainers==2.2.2 199 | SoundFile==0.10.3.post1 200 | soupsieve==2.0.1 201 | sox==1.3.7 202 | subprocess32==3.5.2 203 | sympy==1.6 204 | tabulate==0.8.7 205 | tensor2tensor==1.15.7 206 | tensorboard==2.4.1 207 | tensorboard-plugin-profile==2.3.0 208 | tensorboard-plugin-wit==1.6.0.post3 209 | tensorflow==2.3.1 210 | tensorflow-addons==0.10.0 211 | tensorflow-datasets==3.1.0 212 | tensorflow-estimator==2.3.0 213 | tensorflow-gan==2.0.0 214 | tensorflow-hub==0.8.0 215 | tensorflow-metadata==0.22.2 216 | tensorflow-probability==0.7.0 217 | termcolor==1.1.0 218 | terminado==0.8.3 219 | testpath==0.4.4 220 | tf-slim==1.1.0 221 | tfp-nightly==0.12.0.dev20201127 222 | threadpoolctl==2.1.0 223 | tifffile==2020.6.3 224 | toml==0.10.1 225 | torch==1.5.1 226 | tornado==6.0.4 227 | tqdm==4.46.1 228 | traitlets==4.3.3 229 | tsnecuda==0.1.1 230 | typed-ast==1.4.1 231 | typeguard==2.9.1 232 | typing-extensions==3.7.4.2 233 | uritemplate==3.0.1 234 | urllib3==1.25.11 235 | wcwidth==0.2.4 236 | webencodings==0.5.1 237 | Werkzeug==1.0.1 238 | widgetsnbextension==3.5.1 239 | wrapt==1.12.1 240 | yapf==0.27.0 241 | yarl==1.4.2 242 | zipp==3.1.0 243 | zope.event==4.4 244 | zope.interface==5.1.0 245 | -------------------------------------------------------------------------------- /sample_mdn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
def sample(num_samples=2400, steps=32, embedding_dims=42, rng_seed=1,
           real=None):
  """Generate samples using autoregressive decoding.

  Restores a trained autoregressive MDN from FLAGS.model_dir and decodes
  left-to-right: at each position the model's mixture parameters are turned
  into a TFP Gaussian mixture and sampled.

  Args:
    num_samples: The number of samples to generate.
    steps: Number of sampling steps.
    embedding_dims: Number of dimensions per embedding.
    rng_seed: Initialization seed.
    real: Unused in this function; accepted for API symmetry with callers.

  Returns:
    generated: An array of generated samples.
  """
  rng = jax.random.PRNGKey(rng_seed)
  rng, model_rng = jax.random.split(rng)

  # Create a model with dummy parameters and a dummy optimizer
  # NOTE(review): `train_transformer` is not present in the repo tree
  # (only train_mdn.py / train_ncsn.py exist) — confirm the intended module.
  lm_kwargs = {
      'num_layers': FLAGS.num_layers,
      'num_heads': FLAGS.num_heads,
      'mdn_mixtures': FLAGS.mdn_components,
      'num_mlp_layers': FLAGS.num_mlp_layers,
      'mlp_dims': FLAGS.mlp_dims
  }
  model = train_transformer.create_model(model_rng, (steps, embedding_dims),
                                         lm_kwargs,
                                         batch_size=1,
                                         verbose=True)
  optimizer = train_transformer.create_optimizer(model, 0)
  early_stop = train_utils.EarlyStopping()

  # Load learned parameters
  optimizer, early_stop = checkpoints.restore_checkpoint(
      FLAGS.model_dir, (optimizer, early_stop))

  # Autoregressive decoding
  t0 = time.time()
  # Position 0 acts as an all-zeros start token; generated embeddings are
  # written into positions 1..steps-1 as decoding proceeds.
  tokens = jnp.zeros((num_samples, steps, embedding_dims))

  for i in range(steps):
    pi, mu, log_sigma = optimizer.target(tokens, shift=False)

    channels = tokens.shape[-1]
    mdn_k = pi.shape[-1]
    # Flatten (batch, seq) so each position is one mixture draw.
    out_pi = pi.reshape(-1, mdn_k)
    out_mu = mu.reshape(-1, channels * mdn_k)
    out_log_sigma = log_sigma.reshape(-1, channels * mdn_k)
    mix_dist = tfd.Categorical(logits=out_pi)
    mus = out_mu.reshape(-1, mdn_k, channels)
    log_sigmas = out_log_sigma.reshape(-1, mdn_k, channels)
    sigmas = jnp.exp(log_sigmas)
    component_dist = tfd.MultivariateNormalDiag(loc=mus, scale_diag=sigmas)
    mixture = tfd.MixtureSameFamily(mixture_distribution=mix_dist,
                                    components_distribution=component_dist)

    rng, embed_rng = jax.random.split(rng)
    next_tokens = mixture.sample(seed=embed_rng).reshape(*tokens.shape)
    # Only the prediction at the current position i is kept.
    next_z = next_tokens[:, i]

    if i < steps - 1:
      # jax.ops.index_update matches the pinned jax==0.2.8 API.
      tokens = jax.ops.index_update(tokens, jax.ops.index[:, i + 1], next_z)
    else:
      tokens = next_tokens  # remove start token

  logging.info('Generated samples in %f seconds', time.time() - t0)
  return tokens
167 | generated_t = input_pipeline.inverse_data_transform(generated, 168 | FLAGS.normalize, pca, 169 | train_ds.min, 170 | train_ds.max, slice_idx, 171 | dim_weights) 172 | real_t = input_pipeline.inverse_data_transform(real, FLAGS.normalize, pca, 173 | eval_min, eval_max, 174 | slice_idx, dim_weights) 175 | data_utils.save(real_t, os.path.join(log_dir, 'mdn/real.pkl')) 176 | data_utils.save(generated_t, os.path.join(log_dir, 'mdn/generated.pkl')) 177 | 178 | 179 | if __name__ == '__main__': 180 | app.run(main) 181 | -------------------------------------------------------------------------------- /sample_ncsn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
def evaluate(writer, real, collection, baseline, valid_real):
  """Evaluation metrics.

  NOTE: It is important for collection and real to be normalized for
  accurate evaluation statistics.

  Computes precision/recall-style metrics, realism, NDB, and distribution
  distances for several sample sources ('baseline', 'ncsn', 'random',
  'real'), logging everything to TensorBoard.

  Args:
    writer: TensorBoard summary writer.
    real: An array of real data samples of shape [N, *data_shape].
    collection: Generated samples at varying timesteps. Each array
      is of shape [sampling_steps, N, *data_shape].
    baseline: Generated samples from baseline.
    valid_real: Real samples from training distribution.

  Returns:
    A dict of evaluation metrics (taken from the last evaluated model's
    last test point).
  """
  assert collection.shape[1:] == real.shape

  logging.info(
      f'Generated sample range: [{collection[-1].min()}, {collection[-1].max()}]'
  )
  logging.info(f'Test sample range: [{real.min()}, {real.max()}]')

  if collection[-1].min() < -1. or collection[-1].max() > 1. \
      or real.min() < -1 or real.max() > 1.:
    warnings.warn(
        'Normalize test samples and generated samples to [-1, 1] range.')

  # 20 evenly spaced snapshots of the sampling trajectory.
  gen_test_points = collection[np.linspace(0,
                                           len(collection) - 1,
                                           20).astype(np.uint32)]

  if FLAGS.compute_final_only:
    gen_test_points = [gen_test_points[-1]]

  random_points = [np.random.randn(*collection[0].shape)]
  real_points = [valid_real]

  # For 2-D toy data, also render the initial noise as a scatter plot.
  if collection.shape[-1] == 2:
    im_buf = plot_utils.scatter_2d(collection[0])
    im = tf.image.decode_png(im_buf.getvalue(), channels=4)
    writer.image('init', im, step=0)

  init = collection[0]
  prd_init = metrics.precision_recall_distribution(real, init)
  prd_perfect = metrics.precision_recall_distribution(real, real)

  for model, test_points in [('baseline', [baseline]),
                             ('ncsn', gen_test_points),
                             ('random', random_points), ('real', real_points)]:
    log_dir = f'{model}/'
    # Skip sources that were not provided (e.g. baseline=None).
    if any([point is None for point in test_points]):
      continue

    for i, samples in enumerate(test_points):
      # Render samples
      if samples.shape[-1] == 2:
        im_buf = plot_utils.scatter_2d(samples)
        im = tf.image.decode_png(im_buf.getvalue(), channels=4)
        writer.image(f'{log_dir}fake', im, step=i)

      # K-means histogram evaluation.
      prd_dist = metrics.precision_recall_distribution(real, samples)
      buf = io.BytesIO()
      metrics.prd.plot([prd_dist, prd_init, prd_perfect],
                       [model, 'noise', 'real'])
      plt.savefig(buf, format='png')
      plt.close()
      buf.seek(0)
      im = tf.image.decode_png(buf.getvalue(), channels=4)
      writer.image(f'{log_dir}prd', im, step=i)

      # NOTE(review): unpack order assumes prd_f_beta_score returns
      # (recall, precision) — confirm against utils/metrics.py.
      recall, precision = metrics.prd_f_beta_score(prd_dist)  # F8, F1/8 scores.
      f1 = metrics.f1_score(precision, recall)
      writer.scalar(f'{log_dir}precision', precision, step=i)
      writer.scalar(f'{log_dir}recall', recall, step=i)
      writer.scalar(f'{log_dir}f1', f1, step=i)

      # Nearest neighbor evaluation.
      improved_p, improved_r = metrics.precision_recall(real, samples)
      improved_f1 = metrics.f1_score(improved_p, improved_r)
      writer.scalar(f'{log_dir}improved_precision', improved_p, step=i)
      writer.scalar(f'{log_dir}improved_recall', improved_r, step=i)
      writer.scalar(f'{log_dir}improved_f1', improved_f1, step=i)

      realism_scores = metrics.realism_scores(real, samples)
      realism = realism_scores.mean()
      writer.scalar(f'{log_dir}ipr_realism', realism, step=i)

      ndb_over_k = metrics.ndb_score(real, samples, k=50)
      writer.scalar(f'{log_dir}ndb', ndb_over_k, step=i)

      # Distance evaluation.
      frechet_dist = metrics.frechet_distance(real, samples)
      writer.scalar(f'{log_dir}frechet_distance', frechet_dist, step=i)

      mmd_rbf = metrics.mmd_rbf(real, samples)
      writer.scalar(f'{log_dir}mmd_rbf', mmd_rbf, step=i)

      mmd_polynomial = metrics.mmd_polynomial(real, samples)
      writer.scalar(f'{log_dir}mmd_polynomial', mmd_polynomial, step=i)

  writer.flush()

  # NOTE(review): these locals come from the final loop iteration; if every
  # source were skipped (all None) this would raise NameError — confirm
  # callers always pass at least one evaluable source.
  stats = {
      'precision': precision,
      'recall': recall,
      'f1': f1,
      'improved_precision': improved_p,
      'improved_recall': improved_r,
      'improved_f1': improved_f1,
      'realism': realism,
      'frechet_dist': frechet_dist,
      'mmd_rbf': mmd_rbf,
      'mmd_polynomial': mmd_polynomial
  }
  return stats
def diffusion_stochastic_encoder(samples, rng_seed=1):
  """Stochastic encoder for diffusion process (DDPM models).

  Estimates q(x_T | x_0) given real samples (x_0) and a noise schedule:
  q(x_T | x_0) = N(sqrt(alpha_bar_T) * x_0, (1 - alpha_bar_T) * I).

  Args:
    samples: Array of real samples x_0; any shape.
    rng_seed: Seed for the sampling PRNG key.

  Returns:
    z: Noised latents x_T with the same shape as `samples`.
  """
  assert FLAGS.sampling == 'ddpm'
  rng = jax.random.PRNGKey(rng_seed)
  betas = ebm_utils.create_noise_schedule(FLAGS.sigma_begin,
                                          FLAGS.sigma_end,
                                          FLAGS.num_sigmas,
                                          schedule=FLAGS.schedule_type)
  T = len(betas)
  alphas = 1. - betas
  alphas_prod = jnp.cumprod(alphas)

  rng, noise_rng = jax.random.split(rng)
  # Fix: use the freshly split noise_rng (previously the split result was
  # dead and `rng` was reused for the normal draw).
  noise = jax.random.normal(key=noise_rng, shape=samples.shape)
  # Fix: alphas_prod has length T, so the final cumulative product is at
  # index T - 1 (the old `alphas_prod[T]` only worked because JAX clamps
  # out-of-bounds indices). `[-1]` makes the intent explicit.
  mu = jnp.sqrt(alphas_prod[-1]) * samples
  sigma = jnp.sqrt(1 - alphas_prod[-1])
  z = mu + sigma * noise
  return z


def diffusion_decoder(z_list, rng_seed=1):
  """Generate samples given a list of latent z as an initialization.

  Restores a trained DDPM score network and runs the reverse diffusion
  dynamics once per latent in `z_list`.

  Args:
    z_list: List of latent arrays, each of shape (batch, *sample_shape).
    rng_seed: Seed for the sampling PRNG key.

  Returns:
    Tuple (gen, collects, sampling_metrics): per-latent generated samples,
    intermediate sample collections, and collated sampling metrics.
  """
  assert FLAGS.sampling == 'ddpm'

  rng = jax.random.PRNGKey(rng_seed)
  rng, ld_rng, model_rng = jax.random.split(rng, num=3)
  betas = ebm_utils.create_noise_schedule(FLAGS.sigma_begin,
                                          FLAGS.sigma_end,
                                          FLAGS.num_sigmas,
                                          schedule=FLAGS.schedule_type)

  # Create a model with dummy parameters and a dummy optimizer
  model_kwargs = {
      'num_layers': FLAGS.num_layers,
      'num_heads': FLAGS.num_heads,
      'num_mlp_layers': FLAGS.num_mlp_layers,
      'mlp_dims': FLAGS.mlp_dims
  }
  model = train_ncsn.create_model(model_rng,
                                  z_list[0].shape[1:],
                                  model_kwargs,
                                  batch_size=1,
                                  verbose=True)
  optimizer = train_ncsn.create_optimizer(model, 0)
  ema = train_utils.EMAHelper(mu=0, params=model.params)
  early_stop = train_utils.EarlyStopping()

  # Load learned parameters
  optimizer, ema, early_stop = checkpoints.restore_checkpoint(
      FLAGS.model_dir, (optimizer, ema, early_stop))

  gen, collects, sampling_metrics = [], [], []
  # NOTE: the same ld_rng is deliberately reused for every latent so that
  # interpolation endpoints share identical sampling noise.
  for i, z in enumerate(z_list):
    generated, collection, ld_metrics = ebm_utils.diffusion_dynamics(
        ld_rng, optimizer.target, betas, z, FLAGS.ld_epsilon, FLAGS.ld_steps,
        FLAGS.denoise, False)
    ld_metrics = ebm_utils.collate_sampling_metrics(ld_metrics)
    gen.append(generated)
    collects.append(collection)
    sampling_metrics.append(ld_metrics)
    # Fix: report 1-based progress (previously logged `i`, so the first
    # message read "0 out of N").
    logging.info('Generated samples %i out of %i', i + 1, len(z_list))

  return gen, collects, sampling_metrics
def generate_samples(sample_shape, num_samples, rng_seed=1):
  """Generate samples using pre-trained score network.

  Args:
    sample_shape: Shape of each sample.
    num_samples: Number of samples to generate.
    rng_seed: Random number generator for sampling.

  Returns:
    Tuple (generated, collection, ld_metrics) from train_ncsn.sample.
  """
  rng = jax.random.PRNGKey(rng_seed)
  rng, model_rng = jax.random.split(rng)

  # Create a model with dummy parameters and a dummy optimizer
  model_kwargs = {
      'num_layers': FLAGS.num_layers,
      'num_heads': FLAGS.num_heads,
      'num_mlp_layers': FLAGS.num_mlp_layers,
      'mlp_dims': FLAGS.mlp_dims
  }
  model = train_ncsn.create_model(model_rng,
                                  sample_shape,
                                  model_kwargs,
                                  batch_size=1,
                                  verbose=True)
  optimizer = train_ncsn.create_optimizer(model, 0)
  ema = train_utils.EMAHelper(mu=0, params=model.params)
  early_stop = train_utils.EarlyStopping()

  # Load learned parameters
  optimizer, ema, early_stop = checkpoints.restore_checkpoint(
      FLAGS.model_dir, (optimizer, ema, early_stop))

  # Create noise schedule
  sigmas = ebm_utils.create_noise_schedule(FLAGS.sigma_begin,
                                           FLAGS.sigma_end,
                                           FLAGS.num_sigmas,
                                           schedule=FLAGS.schedule_type)

  rng, sample_rng = jax.random.split(rng)

  t0 = time.time()
  generated, collection, ld_metrics = train_ncsn.sample(
      optimizer.target,
      sigmas,
      sample_rng,
      sample_shape,
      num_samples=num_samples,
      sampling=FLAGS.sampling,
      epsilon=FLAGS.ld_epsilon,
      steps=FLAGS.ld_steps,
      denoise=FLAGS.denoise)
  logging.info('Generated samples in %f seconds', time.time() - t0)

  return generated, collection, ld_metrics


def main(argv):
  """Entry point: load data, generate samples, flush and evaluate them."""
  del argv  # unused

  logging.info(FLAGS.flags_into_string())
  logging.info('Platform: %s', jax.lib.xla_bridge.get_backend().platform)

  # Make sure TensorFlow does not allocate GPU memory.
  tf.config.experimental.set_visible_devices([], 'GPU')

  log_dir = FLAGS.sampling_dir
  writer = tensorboard.SummaryWriter(log_dir)

  pca = data_utils.load(os.path.expanduser(
      FLAGS.pca_ckpt)) if FLAGS.pca_ckpt else None
  slice_idx = data_utils.load(os.path.expanduser(
      FLAGS.slice_ckpt)) if FLAGS.slice_ckpt else None
  dim_weights = data_utils.load(os.path.expanduser(
      FLAGS.dim_weights_ckpt)) if FLAGS.dim_weights_ckpt else None

  train_ds, eval_ds = input_pipeline.get_dataset(
      dataset=FLAGS.dataset,
      data_shape=FLAGS.data_shape,
      problem=FLAGS.problem,
      batch_size=FLAGS.batch_size,
      normalize=FLAGS.normalize,
      pca_ckpt=FLAGS.pca_ckpt,
      slice_ckpt=FLAGS.slice_ckpt,
      dim_weights_ckpt=FLAGS.dim_weights_ckpt,
      include_cardinality=False)
  eval_min, eval_max = eval_ds.min, eval_ds.max
  eval_ds = eval_ds.unbatch()
  if FLAGS.sample_size is not None:
    eval_ds = eval_ds.take(FLAGS.sample_size)
  real = np.stack([ex for ex in tfds.as_numpy(eval_ds)])
  shape = real[0].shape

  # Generation.
  if FLAGS.infill:  # Infilling.
    if FLAGS.problem == 'toy' and real.shape[-1] == 2:
      # 2-D toy data: fix the first coordinate, infill the second.
      samples = np.copy(real)
      samples[:, 1] = 0
      masks = np.zeros(samples.shape)
      masks[:, 0] = 1
    else:
      samples = np.copy(real)

      # Infill middle 16 latents
      idx = list(range(32))
      fixed_idx = idx[:8] + idx[-8:]
      infilled_idx = idx[8:-8]

      samples[:, infilled_idx, :] = 0  # infilled
      masks = np.zeros(samples.shape)
      masks[:, fixed_idx, :] = 1  # hold fixed

    generated, collection, ld_metrics = infill_samples(
        samples, masks, rng_seed=FLAGS.sample_seed)

  elif FLAGS.interpolate:  # Interpolation.
    # Pair each sample with its neighbor (cyclic shift) and linearly
    # interpolate their diffusion latents across 9 alphas in [0, 1].
    starts = real
    goals = np.roll(starts, shift=1, axis=0)
    starts_z = diffusion_stochastic_encoder(starts, rng_seed=FLAGS.sample_seed)
    goals_z = diffusion_stochastic_encoder(goals, rng_seed=FLAGS.sample_seed)
    interp_zs = [(1 - alpha) * starts_z + alpha * goals_z
                 for alpha in np.linspace(0., 1., 9)]
    generated, collection, ld_metrics = diffusion_decoder(
        interp_zs, rng_seed=FLAGS.sample_seed)
    generated, collection = np.stack(generated), np.stack(collection)

  else:  # Unconditional generation.
    generated, collection, ld_metrics = generate_samples(
        shape, len(real), rng_seed=FLAGS.sample_seed)

  # Animation (for 2D samples).
  if FLAGS.animate and shape[-1] == 2:
    im_buf = plot_utils.animate_scatter_2d(collection[::2], fps=240)
    # Fix: the redundant f.close() inside the `with` block was removed;
    # the context manager already closes the file.
    with open(os.path.join(log_dir, 'animated.gif'), 'wb') as f:
      f.write(im_buf.getvalue())

  # Dump generated to CPU.
  generated = np.array(generated)
  collection = np.array(collection)

  # Write samples to disk (used for listening).
  if FLAGS.flush:
    # Inverse transform data back to listenable/unnormalized latent space.
    generated_t = input_pipeline.inverse_data_transform(generated,
                                                        FLAGS.normalize, pca,
                                                        train_ds.min,
                                                        train_ds.max, slice_idx,
                                                        dim_weights)
    if not FLAGS.interpolate:
      collection_t = input_pipeline.inverse_data_transform(
          collection, FLAGS.normalize, pca, train_ds.min, train_ds.max,
          slice_idx, dim_weights)
      data_utils.save(collection_t, os.path.join(log_dir,
                                                 'ncsn/collection.pkl'))

    real_t = input_pipeline.inverse_data_transform(real, FLAGS.normalize, pca,
                                                   eval_min, eval_max,
                                                   slice_idx, dim_weights)
    data_utils.save(real_t, os.path.join(log_dir, 'ncsn/real.pkl'))
    data_utils.save(generated_t, os.path.join(log_dir, 'ncsn/generated.pkl'))

  # Run evaluation metrics.
  if FLAGS.compute_metrics:
    train_ncsn.log_langevin_dynamics(ld_metrics, 0, log_dir)
    # Fix: renamed local from `metrics` to `eval_stats` so it no longer
    # shadows the imported utils.metrics module.
    eval_stats = evaluate(writer, real, collection, None, real)
    train_utils.log_metrics(eval_stats, 1, 1)


if __name__ == '__main__':
  app.run(main)
import config 32 | from ../utils/ import song_utils 33 | from ../utils/ import data_utils 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | flags.DEFINE_string( 38 | 'pipeline_options', '--runner=DirectRunner', 39 | 'Command line flags to use in constructing the Beam pipeline options.') 40 | 41 | # Model 42 | flags.DEFINE_string('model', 'melody-2-big', 'Model configuration.') 43 | flags.DEFINE_string('checkpoint', '~/checkpoints/cat-mel_2bar_big.tar', 44 | 'Model checkpoint.') 45 | flags.DEFINE_boolean('melody', True, 'If True, decode melodies.') 46 | 47 | # Dataset 48 | flags.DEFINE_list('data_shape', [32, 512], 'Data shape.') 49 | flags.DEFINE_string('input', './output/mel-32step-512', 50 | 'Path to TFRecord dataset.') 51 | flags.DEFINE_string('output', './decoded', 'Output directory.') 52 | 53 | 54 | class DecodeSong(beam.DoFn): 55 | """Decode MusicVAE embeddings into one-hot NoteSequence tensor.""" 56 | 57 | def setup(self): 58 | logging.info('Loading pre-trained model %s', FLAGS.model) 59 | self.model_config = config.MUSIC_VAE_CONFIG[FLAGS.model] 60 | self.model = TrainedModel(self.model_config, 61 | batch_size=1, 62 | checkpoint_dir_or_path=FLAGS.checkpoint) 63 | 64 | shape = tuple(map(int, FLAGS.data_shape)) 65 | prod = lambda a: functools.reduce(lambda x, y: x * y, a) 66 | flattened_shape = prod(shape) 67 | self.decode_fn = lambda x: data_utils._decode_record(x, flattened_shape, len(shape)) 68 | 69 | def process(self, example): 70 | parsed = self.decode_fn(example) 71 | encoding = parsed['inputs'] 72 | Metrics.counter('DecodeSong', 'decoding_song').inc() 73 | 74 | chunk_length = 2 if FLAGS.melody else 1 75 | chunks = song_utils.embeddings_to_chunks(encoding, self.model) 76 | song_utils.fix_instruments_for_concatenation(chunks) 77 | ns = note_seq.sequences_lib.concatenate_sequences(chunks) 78 | 79 | tensor = np.array( 80 | self.model_config.data_converter.to_tensors(ns).inputs[::chunk_length]) 81 | tensor = tensor.reshape(-1, tensor.shape[-1]) 82 | yield 
pickle.dumps(tensor) 83 | 84 | 85 | def main(argv): 86 | del argv # unused 87 | 88 | pipeline_options = beam.options.pipeline_options.PipelineOptions( 89 | FLAGS.pipeline_options.split(',')) 90 | 91 | with beam.Pipeline(options=pipeline_options) as p: 92 | p |= 'tfrecord_list' >> beam.Create(tf.io.gfile.glob(FLAGS.input)) 93 | p |= 'read_tfrecord' >> beam.io.tfrecordio.ReadAllFromTFRecord() 94 | p |= 'shuffle_input' >> beam.Reshuffle() 95 | p |= 'decode_song' >> beam.ParDo(DecodeSong()) 96 | p |= 'shuffle_output' >> beam.Reshuffle() 97 | p |= 'write' >> beam.io.WriteToTFRecord(FLAGS.output) 98 | 99 | 100 | if __name__ == '__main__': 101 | app.run(main) 102 | -------------------------------------------------------------------------------- /scripts/generate_compressed_transform.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Fit a compression transform on chunk embedding data.""" 17 | import glob 18 | import os 19 | import sys 20 | 21 | import matplotlib.pyplot as plt 22 | import numpy as np 23 | import tensorflow as tf 24 | import tensorflow_datasets as tfds 25 | from absl import app 26 | from absl import flags 27 | from absl import logging 28 | from sklearn.preprocessing import StandardScaler 29 | from sklearn.decomposition import PCA 30 | from sklearn.pipeline import Pipeline 31 | 32 | sys.path.append("{}/../".format(os.path.dirname(os.path.abspath(__file__)))) 33 | import utils.data_utils as data_utils 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | # Input dataset. 38 | flags.DEFINE_string('dataset', None, 'Path to input data (TFRecord format).') 39 | flags.DEFINE_list('data_shape', [ 40 | 512, 41 | ], 'Shape of data.') 42 | flags.DEFINE_integer('samples', int(2e6), 'Number of data samples to train on.') 43 | 44 | # PCA. 45 | flags.DEFINE_integer('dims', 200, 'Rank of compressed embedding.') 46 | flags.DEFINE_enum('mode', 'pca', ['slice', 'pca'], 'Data generation mode.') 47 | flags.DEFINE_boolean('normalize', True, 48 | 'Add normalization to transform pipeline.') 49 | flags.DEFINE_string('ckpt', './output/pca.pkl', 50 | 'Path to file containing transform checkpoint.') 51 | 52 | # Visualization. 
53 | flags.DEFINE_boolean('compute_dims', False, 54 | 'Compute the expected number of dimensions required.') 55 | flags.DEFINE_float('var_threshold', .85, 56 | 'Explained variance threshold for computing dimensions.') 57 | 58 | 59 | class SliceTransform(object): 60 | """Slice transform.""" 61 | 62 | def __init__(self, component_idx, fill_idx, fill): 63 | self.orig_dims = len(component_idx) + len(fill) 64 | self.dims = len(component_idx) 65 | 66 | self.component_idx = component_idx # important components 67 | self.fill_idx = fill_idx # unimportant components 68 | self.fill = fill # values to fill the unimportant components with 69 | 70 | def transform(self, x): 71 | # x.shape = (batch_size, self.orig_dims) 72 | compressed = x[:, self.component_idx] 73 | assert compressed.shape == (x.shape[0], self.dims) 74 | return compressed 75 | 76 | def inverse_transform(self, x): 77 | # x.shape = (batch_size, self.dims) 78 | recon = np.zeros((x.shape[0], self.orig_dims)) 79 | recon[:, self.component_idx] = x 80 | recon[:, self.fill_idx] = self.fill 81 | assert recon.shape == (x.shape[0], self.orig_dims) 82 | return recon 83 | 84 | 85 | def main(argv): 86 | del argv # unused 87 | 88 | tf.config.experimental.set_visible_devices([], 'GPU') 89 | 90 | shape = tuple(map(int, FLAGS.data_shape)) 91 | train_ds = data_utils.get_tf_record_dataset( 92 | file_pattern=f'{FLAGS.dataset}/train-*.tfrecord', 93 | shape=shape, 94 | batch_size=2048, 95 | shuffle=True) 96 | train_ds = train_ds.take(FLAGS.samples) 97 | train_ds = np.stack([ex['inputs'] for ex in tfds.as_numpy(train_ds)]) 98 | 99 | if len(shape) == 2: 100 | assert shape[0] == 3 101 | means = train_ds[:, 1, :] 102 | sigmas = train_ds[:, 2, :] 103 | avg_mean = means.mean(0) 104 | avg_sigma = sigmas.mean(0) 105 | weights = 1 / avg_sigma**2 106 | # idx = np.where(avg_sigma < 0.98)[0] 107 | logging.info('Creating slice transform weights.') 108 | data_utils.save(weights, os.path.expanduser(FLAGS.ckpt)) 109 | return -1 110 | 111 | 
singular_values = np.linalg.svd(train_ds, 112 | full_matrices=False, 113 | compute_uv=False) 114 | variance_gain = singular_values.cumsum() / singular_values.sum() 115 | 116 | if FLAGS.compute_dims: 117 | dims = np.where(variance_gain >= FLAGS.var_threshold)[0][0] 118 | variance = variance_gain[dims] 119 | plt.text(0, variance + 0.05, '{:.3f}'.format(variance), rotation=0) 120 | plt.text(dims + 0.1, 0.2, dims, rotation=0) 121 | plt.axhline(y=variance, color='r', linestyle='--') 122 | plt.axvline(x=dims, color='r', linestyle='--') 123 | plt.plot(variance_gain) 124 | plt.show() 125 | 126 | logging.info('Explained variance ratio: %f, Rank: %i.', variance, dims) 127 | 128 | else: 129 | logging.info('Creating %s transform with rank %i.', FLAGS.mode, FLAGS.dims) 130 | 131 | operations = [] 132 | if FLAGS.normalize: 133 | logging.info('Adding normalization.') 134 | operations.append(('scaling', StandardScaler())) 135 | if FLAGS.mode == 'pca': 136 | operations.append(('pca', PCA(n_components=FLAGS.dims))) 137 | else: 138 | raise ValueError(f'Unsupported mode: {FLAGS.mode}') 139 | 140 | logging.info('Fitting transform.') 141 | pipeline = Pipeline(operations) 142 | pipeline = pipeline.fit(train_ds) 143 | data_utils.save(pipeline, os.path.expanduser(FLAGS.ckpt)) 144 | 145 | 146 | if __name__ == '__main__': 147 | app.run(main) 148 | -------------------------------------------------------------------------------- /scripts/generate_song_data_beam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
r"""Dataset generation."""

import pickle

from absl import app
from absl import flags
from absl import logging
import apache_beam as beam
from apache_beam.metrics import Metrics
from magenta.models.music_vae import TrainedModel
import note_seq

# FIX: `from ../utils/ import song_utils` is not valid Python syntax. Use a
# proper package-relative import instead.
from .. import config
from ..utils import song_utils

FLAGS = flags.FLAGS

flags.DEFINE_string(
    'pipeline_options', '--runner=DirectRunner',
    'Command line flags to use in constructing the Beam pipeline options.')

# Model
flags.DEFINE_string('model', 'melody-2-big', 'Model configuration.')
flags.DEFINE_string('checkpoint', 'fb512_0trackmin/model.ckpt-99967',
                    'Model checkpoint.')

# Data transformation
flags.DEFINE_enum('mode', 'melody', ['melody', 'multitrack'],
                  'Data generation mode.')
flags.DEFINE_string('input', None, 'Path to tfrecord files.')
flags.DEFINE_string('output', None, 'Output path.')


class EncodeSong(beam.DoFn):
  """Encode song into MusicVAE embeddings."""

  def setup(self):
    """Loads the pre-trained MusicVAE model once per worker."""
    logging.info('Loading pre-trained model %s', FLAGS.model)
    self.model_config = config.MUSIC_VAE_CONFIG[FLAGS.model]
    self.model = TrainedModel(self.model_config,
                              batch_size=1,
                              checkpoint_dir_or_path=FLAGS.checkpoint)

  def process(self, ns):
    """Encodes one NoteSequence into (z, mean, sigma) embedding matrices.

    Args:
      ns: A note_seq.NoteSequence proto.

    Yields:
      Pickled numpy arrays of shape (3, n_chunks, 512).
    """
    logging.info('Processing %s::%s (%f)', ns.id, ns.filename, ns.total_time)
    # Very long sequences are disproportionately expensive to encode.
    if ns.total_time > 60 * 60:
      logging.info('Skipping notesequence with >1 hour duration')
      Metrics.counter('EncodeSong', 'skipped_long_song').inc()
      return

    Metrics.counter('EncodeSong', 'encoding_song').inc()

    if FLAGS.mode == 'melody':
      chunk_length = 2
      melodies = song_utils.extract_melodies(ns)
      if not melodies:
        Metrics.counter('EncodeSong', 'extracted_no_melodies').inc()
        return
      Metrics.counter('EncodeSong', 'extracted_melody').inc(len(melodies))
      songs = [
          song_utils.Song(melody, self.model_config.data_converter,
                          chunk_length) for melody in melodies
      ]
      encoding_matrices = song_utils.encode_songs(self.model, songs)
    elif FLAGS.mode == 'multitrack':
      chunk_length = 1
      song = song_utils.Song(ns,
                             self.model_config.data_converter,
                             chunk_length,
                             multitrack=True)
      encoding_matrices = song_utils.encode_songs(self.model, [song])
    else:
      raise ValueError(f'Unsupported mode: {FLAGS.mode}')

    for matrix in encoding_matrices:
      # Rows are (z, mean, sigma); embedding dimension is fixed at 512.
      assert matrix.shape[0] == 3 and matrix.shape[-1] == 512
      if matrix.shape[1] == 0:
        Metrics.counter('EncodeSong', 'skipped_matrix').inc()
        continue
      Metrics.counter('EncodeSong', 'encoded_matrix').inc()
      yield pickle.dumps(matrix)


def main(argv):
  del argv  # unused

  pipeline_options = beam.options.pipeline_options.PipelineOptions(
      FLAGS.pipeline_options.split(','))

  with beam.Pipeline(options=pipeline_options) as p:
    # NOTE(review): beam.Create over a plain string yields its characters;
    # this likely wants a list of file paths (cf. decode_dataset_beam.py,
    # which globs FLAGS.input) — confirm against how this script is invoked.
    p |= 'tfrecord_list' >> beam.Create(FLAGS.input)
    p |= 'read_tfrecord' >> beam.io.tfrecordio.ReadAllFromTFRecord(
        coder=beam.coders.ProtoCoder(note_seq.NoteSequence))
    p |= 'shuffle_input' >> beam.Reshuffle()
    p |= 'encode_song' >> beam.ParDo(EncodeSong())
    p |= 'shuffle_output' >> beam.Reshuffle()
    p |= 'write' >> beam.io.WriteToTFRecord(FLAGS.output)


if __name__ == '__main__':
  app.run(main)
-------------------------------------------------------------------------------- /scripts/sample_audio.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Generate wav files from samples.""" 17 | import os 18 | import sys 19 | 20 | import jax 21 | import jax.numpy as jnp 22 | import note_seq 23 | import numpy as np 24 | import ray 25 | import tensorflow as tf 26 | 27 | from absl import app 28 | from absl import flags 29 | from absl import logging 30 | from flax import nn 31 | from flax.training import checkpoints 32 | from bokeh.io import export_png 33 | from magenta.models.music_vae import TrainedModel 34 | from pathlib import Path 35 | from scipy.io import wavfile 36 | 37 | sys.path.append("{}/../".format(os.path.dirname(os.path.abspath(__file__)))) 38 | import utils.data_utils as data_utils 39 | import utils.song_utils as song_utils 40 | import utils.train_utils as train_utils 41 | import utils.metrics as metrics 42 | import config 43 | import train_lm 44 | 45 | FLAGS = flags.FLAGS 46 | SYNTH = note_seq.fluidsynth 47 | SAMPLE_RATE = 44100 48 | ray.init() 49 | 50 | flags.DEFINE_integer('eval_seed', 42, 'Random number generator seed.') 51 | flags.DEFINE_string('input', 'sample/ncsn', 'Sampling (input) directory.') 52 | flags.DEFINE_string('output', './audio', 'Output 
directory.') 53 | flags.DEFINE_integer('n_synth', 1000, 'Number of samples to decode.') 54 | flags.DEFINE_boolean('include_wav', True, 'Include audio waveforms.') 55 | flags.DEFINE_boolean('include_plots', True, 'Include Bokeh plots of MIDI.') 56 | flags.DEFINE_boolean('gen_only', False, 'Only generate the fake audio.') 57 | 58 | flags.DEFINE_boolean('melody', True, 'If True, decode melodies.') 59 | flags.DEFINE_boolean('infill', False, 'Evaluate quality of infilled measures.') 60 | flags.DEFINE_boolean('interpolate', False, 'Evaluate interpolations.') 61 | 62 | 63 | def synthesize_ns(path, ns, synth=SYNTH, sample_rate=SAMPLE_RATE): 64 | """Synthesizes and saves NoteSequence to waveform file.""" 65 | array_of_floats = synth(ns, sample_rate=sample_rate) 66 | normalizer = float(np.iinfo(np.int16).max) 67 | array_of_ints = np.array(np.asarray(array_of_floats) * normalizer, 68 | dtype=np.int16) 69 | wavfile.write(path, sample_rate, array_of_ints) 70 | 71 | 72 | def decode_emb(emb, model, data_converter, chunks_only=False): 73 | """Generates NoteSequence objects from set of embeddings. 74 | 75 | Args: 76 | emb: Embeddings of shape (n_seqs, seq_length, 512). 77 | model: Pre-trained MusicVAE model used for decoding. 78 | data_converter: Corresponding data converter for model. 79 | chunks_only: If True, assumes embeddings are of the shape (n_seqs, 512) 80 | where each generated NoteSequence corresponds to one embedding. 81 | 82 | Returns: 83 | A list of decoded NoteSequence objects. 
84 | """ 85 | if chunks_only: 86 | assert len(emb.shape) == 2 87 | samples = song_utils.embeddings_to_chunks(emb, model) 88 | samples = [ 89 | song_utils.Song(sample, data_converter, reconstructed=True) 90 | for sample in samples 91 | ] 92 | else: 93 | samples = [] 94 | count = 0 95 | for emb_sample in emb: 96 | if count % 100 == 0: 97 | logging.info(f'Decoded {count} sequences.') 98 | count += 1 99 | recon = song_utils.embeddings_to_song(emb_sample, model, data_converter) 100 | samples.append(recon) 101 | 102 | return samples 103 | 104 | 105 | @ray.remote 106 | def parallel_synth(song, i, ns_dir, audio_dir, image_dir, include_wav, 107 | include_plots): 108 | """Synthesizes NoteSequences (and plots) in parallel.""" 109 | audio_path = os.path.join(audio_dir, f'{i + 1}.wav') 110 | plot_path = os.path.join(image_dir, f'{i + 1}.png') 111 | ns_path = os.path.join(ns_dir, f'{i+1}.pkl') 112 | logging.info(audio_path) 113 | ns = song.play() 114 | 115 | if include_plots: 116 | fig = note_seq.plot_sequence(ns, show_figure=False) 117 | export_png(fig, filename=plot_path) 118 | 119 | if include_wav: 120 | synthesize_ns(audio_path, ns) 121 | 122 | data_utils.save(ns, ns_path) 123 | return ns 124 | 125 | 126 | def main(argv): 127 | del argv # unused 128 | 129 | # Get VAE model. 
130 | if FLAGS.melody: 131 | model_config = config.MUSIC_VAE_CONFIG['melody-2-big'] 132 | ckpt = os.path.expanduser('~/checkpoints/cat-mel_2bar_big.tar') 133 | vae_model = TrainedModel(model_config, 134 | batch_size=1, 135 | checkpoint_dir_or_path=ckpt) 136 | else: 137 | model_config = config.MUSIC_VAE_CONFIG['multi-0min-1-big'] 138 | ckpt = os.path.expanduser( 139 | '~/checkpoints/multitrack/fb512_0trackmin/model.ckpt') 140 | vae_model = TrainedModel(model_config, 141 | batch_size=1, 142 | checkpoint_dir_or_path=ckpt) 143 | logging.info(f'Loaded {ckpt}') 144 | 145 | log_dir = FLAGS.input 146 | real = data_utils.load(os.path.join(log_dir, 'real.pkl')) 147 | generated = data_utils.load(os.path.join(log_dir, 'generated.pkl')) 148 | 149 | collection = data_utils.load(os.path.join(log_dir, 'collection.pkl')) 150 | idx = np.linspace(0, 40, 10).astype(np.int32) 151 | collection = collection[idx] 152 | 153 | 154 | # Get baselines. 155 | start_emb = real[:, 7, :] 156 | end_emb = real[:, 24, :] 157 | idx = list(range(32)) 158 | if FLAGS.infill: 159 | fixed_idx = idx[:8] + idx[-8:] 160 | infilled_idx = idx[8:-8] 161 | 162 | # HACK: Since scaling of eval/train is different, re-add the real bars. 163 | generated[:, fixed_idx, :] = real[:, fixed_idx, :] 164 | 165 | # Prior baseline. 166 | prior = np.random.randn(*generated.shape) 167 | prior[:, fixed_idx, :] = real[:, fixed_idx, :] 168 | else: 169 | prior = np.random.randn(*generated.shape) 170 | 171 | # Interpolation baseline. 
172 | interp_baseline = [ 173 | song_utils.spherical_interpolation(start_emb, end_emb, alpha) 174 | for alpha in np.linspace(0., 1., 16+2) 175 | ] 176 | interp_baseline = np.stack(interp_baseline).transpose(1, 0, 2) 177 | start_real = real[:, idx[:7], :] 178 | end_real = real[:, idx[-7:], :] 179 | interp_baseline = np.concatenate((start_real, interp_baseline, end_real), axis=1) 180 | assert interp_baseline.shape == generated.shape 181 | 182 | assert real.shape == generated.shape 183 | is_multi_bar = len(generated.shape) > 2 184 | 185 | logging.info('Decoding sequences.') 186 | eval_seqs = {} 187 | for sample_split, sample_emb in (('real', real), ('gen', generated), 188 | ('prior', prior), ('interp', 189 | interp_baseline)): 190 | if FLAGS.gen_only and sample_split != 'gen': 191 | continue 192 | 193 | sample_split = str(sample_split) 194 | audio_dir = os.path.join(FLAGS.output, sample_split, 'audio') 195 | image_dir = os.path.join(FLAGS.output, sample_split, 'images') 196 | ns_dir = os.path.join(FLAGS.output, sample_split, 'ns') 197 | Path(audio_dir).mkdir(parents=True, exist_ok=True) 198 | Path(image_dir).mkdir(parents=True, exist_ok=True) 199 | 200 | sequences = decode_emb(sample_emb[:FLAGS.n_synth], 201 | vae_model, 202 | model_config.data_converter, 203 | chunks_only=not is_multi_bar) 204 | assert len(sequences) == FLAGS.n_synth 205 | 206 | futures = [ 207 | parallel_synth.remote(song, i, ns_dir, audio_dir, image_dir, 208 | FLAGS.include_wav, FLAGS.include_plots) 209 | for i, song in enumerate(sequences) 210 | ] 211 | ns = ray.get(futures) 212 | eval_seqs[sample_split] = ns 213 | 214 | logging.info(f'Sythesized {sample_split} at {audio_dir}') 215 | 216 | 217 | if __name__ == '__main__': 218 | app.run(main) 219 | -------------------------------------------------------------------------------- /scripts/transform_encoded_data.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Transforms an encoded song dataset into an appropriate format 17 | for a model. 18 | """ 19 | import glob 20 | import os 21 | import pickle 22 | import sys 23 | 24 | from absl import app 25 | from absl import flags 26 | from absl import logging 27 | from functools import reduce 28 | 29 | import numpy as np 30 | import tensorflow as tf 31 | 32 | sys.path.append("{}/../".format(os.path.dirname(os.path.abspath(__file__)))) 33 | import utils.data_utils as data_utils 34 | 35 | FLAGS = flags.FLAGS 36 | 37 | flags.DEFINE_boolean('toy_data', False, 'Create a toy dataset.') 38 | flags.DEFINE_string('encoded_data', '~/data/encoded_lmd', 39 | 'Path to encoded data TFRecord directory.') 40 | flags.DEFINE_string('output_path', './output/transform/', 'Output directory.') 41 | flags.DEFINE_integer('shard_size', 2**17, 'Number of vectors per shard.') 42 | flags.DEFINE_enum('output_format', 'tfrecord', ['tfrecord', 'pkl'], 43 | 'Shard file type.') 44 | 45 | flags.DEFINE_enum('mode', 'flatten', ['flatten', 'sequences', 'decoded'], 46 | 'Transformation mode.') 47 | flags.DEFINE_boolean('remove_zeros', True, 'Remove zero vectors.') 48 | flags.DEFINE_integer('context_length', 4, 49 | 'The length of the context window in a sequence.') 50 | flags.DEFINE_integer('stride', 1, 'The stride used for generating sequences.') 51 | flags.DEFINE_integer('max_songs', None, 52 | 
'The maximum number of songs to process.') 53 | flags.DEFINE_integer('max_examples', None, 54 | 'The maximum number of examples to process.') 55 | 56 | 57 | def _bytes_feature(value): 58 | if isinstance(value, type(tf.constant(0))): 59 | value = value.numpy() 60 | return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) 61 | 62 | 63 | def _float_feature(values): 64 | return tf.train.Feature(float_list=tf.train.FloatList(value=values)) 65 | 66 | 67 | def _int_feature(values): 68 | return tf.train.Feature(int64_list=tf.train.Int64List(value=values)) 69 | 70 | 71 | def _serialize(writer, input_tensor, target_tensor=None): 72 | assert writer is not None 73 | prod = lambda a: reduce(lambda x, y: x * y, a) 74 | input_shape = input_tensor.shape 75 | inputs = input_tensor.reshape(prod(input_shape),) 76 | 77 | if FLAGS.mode == 'decoded': 78 | sequence = tf.io.serialize_tensor(input_tensor) 79 | features = _bytes_feature(sequence) 80 | else: 81 | features = _float_feature(inputs) 82 | 83 | features = {'inputs': features, 'input_shape': _int_feature(input_shape)} 84 | 85 | if target_tensor is not None: 86 | target_shape = target_tensor.shape 87 | targets = target_tensor.reshape(prod(target_shape),) 88 | features['targets'] = _float_feature(targets) 89 | features['target_shape'] = _int_feature(target_shape) 90 | 91 | tf_example = tf.train.Example(features=tf.train.Features(feature=features)) 92 | writer.write(tf_example.SerializeToString()) 93 | 94 | 95 | def _serialize_tf_shard(shard, output_path): 96 | with tf.io.TFRecordWriter(os.path.expanduser(output_path)) as writer: 97 | if FLAGS.mode == 'sequences': 98 | for context, target in zip(*shard): 99 | _serialize(writer, context, target_tensor=target) 100 | elif FLAGS.mode == 'flatten' or FLAGS.mode == 'decoded': 101 | for example in shard: 102 | _serialize(writer, example) 103 | logging.info('Saved to %s', output_path) 104 | 105 | 106 | def save_shard(contexts, targets, output_path): 107 | if FLAGS.mode == 
'flatten' or FLAGS.mode == 'decoded': 108 | shard = targets[:FLAGS.shard_size] 109 | 110 | shard_type = np.bool if FLAGS.mode == 'decoded' else np.float32 111 | shard = np.stack(shard).astype(shard_type) 112 | 113 | targets = targets[FLAGS.shard_size:] 114 | elif FLAGS.mode == 'sequences': 115 | context_shard = contexts[:FLAGS.shard_size] 116 | target_shard = targets[:FLAGS.shard_size] 117 | context_shard = np.stack(context_shard).astype(np.float32) 118 | target_shard = np.stack(target_shard).astype(np.float32) 119 | shard = (context_shard, target_shard) 120 | 121 | contexts = contexts[FLAGS.shard_size:] 122 | targets = targets[FLAGS.shard_size:] 123 | 124 | output_path += '.' + FLAGS.output_format 125 | 126 | # Serialize shard 127 | if FLAGS.output_format == 'pkl': 128 | data_utils.save(shard, output_path) 129 | elif FLAGS.output_format == 'tfrecord': 130 | _serialize_tf_shard(shard, output_path) 131 | 132 | return contexts, targets 133 | 134 | 135 | def toy_distribution_fn(batch_size=512): 136 | """Samples from a 0.2 * N(-5, 1) + 0.8 * N(5, 1).""" 137 | 138 | c1 = (np.random.randn(batch_size, 2) + 5) 139 | c2 = (np.random.randn(batch_size, 2) + -5) 140 | mask = np.random.uniform(size=batch_size) < 0.8 141 | mask = mask[:, np.newaxis] 142 | mixture = mask * c1 + (1 - mask) * c2 143 | return mixture 144 | 145 | 146 | def toy_sequence_distribution_fn(trajectory_length=10, batch_size=512): 147 | c1 = 0.01 * np.random.randn(batch_size, 2) + 5 148 | c2 = 0.01 * np.random.randn(batch_size, 2) + -5 149 | mask = np.random.uniform(size=batch_size) < 0.8 150 | mask = mask[:, np.newaxis] 151 | center = mask * c1 + (1 - mask) * c2 152 | step_size = 0.1 * np.random.randn(batch_size, 2) 153 | deltas = np.expand_dims(step_size, 1).repeat( 154 | trajectory_length, axis=1) * np.arange(trajectory_length).reshape( 155 | trajectory_length, 1) 156 | center = np.expand_dims(center, 1).repeat(trajectory_length, axis=1) 157 | return center + deltas 158 | 159 | 160 | def main(argv): 161 | 
del argv # unused 162 | 163 | if FLAGS.mode == 'decoded': 164 | train_glob = f'{FLAGS.encoded_data}/decoded-train.tfrecord-*' 165 | eval_glob = f'{FLAGS.encoded_data}/decoded-eval.tfrecord-*' 166 | else: 167 | train_glob = f'{FLAGS.encoded_data}/training_seqs.tfrecord-*' 168 | eval_glob = f'{FLAGS.encoded_data}/eval_seqs.tfrecord-*' 169 | 170 | train_files = glob.glob(os.path.expanduser(train_glob)) 171 | eval_files = glob.glob(os.path.expanduser(eval_glob)) 172 | 173 | tensor_shape = [tf.float64] 174 | train_dataset = tf.data.TFRecordDataset( 175 | train_files).map(lambda x: tf.py_function( 176 | lambda binary: pickle.loads(binary.numpy()), [x], tensor_shape), 177 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 178 | eval_dataset = tf.data.TFRecordDataset( 179 | eval_files).map(lambda x: tf.py_function( 180 | lambda binary: pickle.loads(binary.numpy()), [x], tensor_shape), 181 | num_parallel_calls=tf.data.experimental.AUTOTUNE) 182 | 183 | ctx_window = FLAGS.context_length 184 | stride = FLAGS.stride 185 | 186 | for ds, split in [(train_dataset, 'train'), (eval_dataset, 'eval')]: 187 | if FLAGS.max_songs is not None: 188 | ds = ds.take(FLAGS.max_songs) 189 | 190 | output_fp = '{}/{}-{:04d}' 191 | contexts, targets = [], [] 192 | count = 0 193 | discard = 0 194 | example_count, should_terminate = 0, False 195 | for song_data in ds.as_numpy_iterator(): 196 | song_embeddings = song_data[0] 197 | 198 | if FLAGS.mode != 'decoded': 199 | assert song_embeddings.ndim == 3 and song_embeddings.shape[0] == 3 200 | 201 | # Use the full VAE embedding 202 | song = song_embeddings[0] 203 | 204 | else: 205 | song = song_data[0] 206 | if song.shape[0] < 896: 207 | discard += 1 208 | continue 209 | 210 | pad_len = 1024 - song.shape[0] 211 | padding = np.zeros((pad_len, song.shape[-1])) 212 | padding[:, 0] = 1. 
213 | song = np.concatenate((song, padding)) 214 | assert song.shape[0] == 1024 and song.ndim == 2 215 | 216 | if FLAGS.mode == 'decoded': 217 | example_count += 1 218 | targets.append(song) 219 | 220 | if FLAGS.toy_data: 221 | song = toy_distribution_fn(batch_size=len(song)) 222 | 223 | if FLAGS.mode == 'flatten': 224 | for vec in song: 225 | if FLAGS.remove_zeros and np.linalg.norm(vec) < 1e-6: 226 | continue 227 | if FLAGS.max_examples is not None and example_count >= FLAGS.max_examples: 228 | should_terminate = True 229 | break 230 | example_count += 1 231 | targets.append(vec) 232 | elif FLAGS.mode == 'sequences': 233 | for i in range(0, len(song) - ctx_window, stride): 234 | context = song[i:i + ctx_window] 235 | if FLAGS.remove_zeros and np.where( 236 | np.linalg.norm(context, axis=1) < 1e-6)[0].any(): 237 | continue 238 | if FLAGS.max_examples is not None and example_count >= FLAGS.max_examples: 239 | should_terminate = True 240 | break 241 | example_count += 1 242 | contexts.append(context) 243 | targets.append(song[i + ctx_window]) 244 | 245 | if len(targets) >= FLAGS.shard_size: 246 | contexts, targets = save_shard( 247 | contexts, targets, output_fp.format(FLAGS.output_path, split, 248 | count)) 249 | count += 1 250 | 251 | if should_terminate: 252 | break 253 | 254 | logging.info(f'Discarded {discard} invalid sequences.') 255 | if len(targets) > 0: 256 | save_shard(contexts, targets, 257 | output_fp.format(FLAGS.output_path, split, count)) 258 | 259 | 260 | if __name__ == '__main__': 261 | app.run(main) 262 | -------------------------------------------------------------------------------- /train_mdn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Train Transformer-based continuous language model.""" 17 | import os 18 | import time 19 | 20 | from absl import app 21 | from absl import flags 22 | from absl import logging 23 | from functools import partial 24 | 25 | import jax 26 | import jax.numpy as jnp 27 | import jax.experimental.optimizers 28 | import numpy as np 29 | import tensorflow as tf 30 | import tensorflow_datasets as tfds 31 | 32 | from flax import nn 33 | from flax import optim 34 | from flax.metrics import tensorboard 35 | from flax.training import checkpoints 36 | from flax.training import lr_schedule 37 | 38 | import input_pipeline 39 | import utils.train_utils as train_utils 40 | import utils.data_utils as data_utils 41 | import models.autoregressive as ar 42 | from utils.losses import reduce_fn 43 | 44 | from tensorflow_probability.substrates import jax as tfp 45 | tfd = tfp.distributions 46 | 47 | FLAGS = flags.FLAGS 48 | 49 | flags.DEFINE_integer('seed', 0, 'Random seed for network initialization.') 50 | 51 | # Training 52 | flags.DEFINE_float('learning_rate', 3e-4, 'Learning rate for optimizer.') 53 | flags.DEFINE_integer('batch_size', 128, 'Batch size for training.') 54 | flags.DEFINE_integer('epochs', 1000, 'Number of training epochs.') 55 | flags.DEFINE_integer('max_steps', 100000, 'Maximum number of training steps.') 56 | 57 | # Training stability 58 | flags.DEFINE_boolean('early_stopping', False, 59 | 'Use early stopping to prevent overfitting.') 60 | flags.DEFINE_float('grad_clip', 1., 'Max gradient norm for 
training.') 61 | flags.DEFINE_float('lr_gamma', 0.98, 'Gamma for learning rate scheduler.') 62 | flags.DEFINE_integer('lr_schedule_interval', 4000, 63 | 'Number of steps between LR changes.') 64 | flags.DEFINE_float('lr_warmup', 0, 'Learning rate warmup (epochs).') 65 | 66 | # Model 67 | flags.DEFINE_string('architecture', 'TransformerMDN', 68 | 'Class name of model architecture.') 69 | flags.DEFINE_integer('mdn_components', 100, 'Number of mixtures.') 70 | flags.DEFINE_integer('num_heads', 8, 'Number of attention heads.') 71 | flags.DEFINE_integer('num_layers', 6, 'Number of encoder layers.') 72 | flags.DEFINE_integer('num_mlp_layers', 2, 'Number of output MLP layers.') 73 | flags.DEFINE_integer('mlp_dims', 2048, 'Number of channels per MLP layer.') 74 | 75 | # Data 76 | flags.DEFINE_list('data_shape', [32, 512], 'Shape of data.') 77 | flags.DEFINE_string( 78 | 'dataset', './output/mel-32step-512', 79 | 'Path to directory containing data as train/eval tfrecord files.') 80 | flags.DEFINE_string('pca_ckpt', '', 'PCA transform.') 81 | flags.DEFINE_string('slice_ckpt', '', 'Slice transform.') 82 | flags.DEFINE_string('dim_weights_ckpt', '', 'Dimension scale transform.') 83 | flags.DEFINE_boolean('normalize', True, 'Normalize dataset to [-1, 1].') 84 | 85 | # Logging, checkpointing, and evaluation 86 | flags.DEFINE_integer('logging_freq', 100, 'Logging frequency.') 87 | flags.DEFINE_integer('snapshot_freq', 5000, 88 | 'Evaluation and checkpoint frequency.') 89 | flags.DEFINE_boolean('snapshot_sampling', True, 90 | 'Sample from score network during evaluation.') 91 | flags.DEFINE_integer('eval_samples', 3000, 'Number of samples to generate.') 92 | flags.DEFINE_integer('checkpoints_to_keep', 50, 93 | 'Number of checkpoints to keep.') 94 | flags.DEFINE_boolean('save_ckpt', True, 95 | 'Save model checkpoints at each evaluation step.') 96 | flags.DEFINE_string('model_dir', './save/mdn', 'Directory to store model data.') 97 | flags.DEFINE_boolean('verbose', True, 'Toggle 
logging to stdout.') 98 | 99 | 100 | def mdn_loss(pi, mu, log_sigma, x, reduction='mean'): 101 | """Mixture density loss. 102 | 103 | Args: 104 | pi: Unnormalized component mixture distribution. 105 | mu: Mean vectors. 106 | log_sigma: Log standard deviation vectors. 107 | reduction: Type of reduction to apply to loss. 108 | 109 | Returns: 110 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 111 | otherwise, it is scalar. 112 | """ 113 | channels = x.shape[-1] 114 | mdn_k = pi.shape[-1] 115 | out_pi = pi.reshape(-1, mdn_k) 116 | out_mu = mu.reshape(-1, channels * mdn_k) 117 | out_log_sigma = log_sigma.reshape(-1, channels * mdn_k) 118 | 119 | # Create mixture distribution 120 | mix_dist = tfd.Categorical(logits=out_pi) 121 | 122 | # Create component distribution 123 | mus = out_mu.reshape(-1, mdn_k, channels) 124 | log_sigmas = out_log_sigma.reshape(-1, mdn_k, channels) 125 | sigmas = jnp.exp(log_sigmas) 126 | component_dist = tfd.MultivariateNormalDiag(loc=mus, scale_diag=sigmas) 127 | 128 | # Compute loss 129 | mixture = tfd.MixtureSameFamily(mixture_distribution=mix_dist, 130 | components_distribution=component_dist) 131 | x = x.reshape(-1, channels) 132 | loss = -1 * mixture.log_prob(x) 133 | return reduce_fn(loss, reduction) 134 | 135 | 136 | def create_optimizer(model, learning_rate): 137 | optimizer_def = optim.Adam(learning_rate=learning_rate) 138 | optimizer = optimizer_def.create(model) 139 | return optimizer 140 | 141 | 142 | def create_model(rng, input_shape, model_kwargs, batch_size=32, verbose=False): 143 | clazz = getattr(ar, FLAGS.architecture) 144 | module = clazz.partial(**model_kwargs) 145 | output, initial_params = module.init_by_shape( 146 | rng, [((batch_size, *input_shape), jnp.float32)]) 147 | model = nn.Model(module, initial_params) 148 | 149 | if verbose: 150 | train_utils.report_model(model) 151 | return model 152 | 153 | 154 | @jax.jit 155 | def eval_step(batch, model): 156 | """A single evaluation step. 
157 | 158 | Args: 159 | batch: A batch of inputs. 160 | model: The model to be used for this evaluation step. 161 | 162 | Returns: 163 | loss: The summed loss on this batch. 164 | examples: Number of examples in this batch. 165 | """ 166 | pi, mu, log_sigma = model(batch) 167 | loss = mdn_loss(pi, mu, log_sigma, batch, 'none') 168 | return loss.sum(), loss.shape[0] 169 | 170 | 171 | def evaluate(dataset, model): 172 | """Evaluates the model on a dataset. 173 | 174 | Args: 175 | dataset: A dataset to be used for the evaluation. Typically valid or test. 176 | model: A model to be evaluated. 177 | 178 | Returns: 179 | A dict with the evaluation results. 180 | """ 181 | count = 0 182 | total_loss = 0. 183 | 184 | for inputs in tfds.as_numpy(dataset): 185 | loss, examples = eval_step(inputs, model) 186 | count += examples 187 | total_loss += loss.item() 188 | 189 | loss = total_loss / count 190 | metrics = {'loss': loss} 191 | 192 | return metrics 193 | 194 | 195 | @jax.jit 196 | def train_step(batch, optimizer, learning_rate): 197 | """Single optimized training step. 198 | 199 | Args: 200 | batch: A batch of inputs. 201 | optimizer: The optimizer to use to update the weights. 202 | learning_rate: Current learning rate. 203 | 204 | Returns: 205 | optimizer: The optimizer in its new state. 206 | train_metrics: A dict with training statistics for the step. 
207 | """ 208 | 209 | def loss_fn(model): 210 | pi, mu, log_sigma = model(batch) 211 | loss = mdn_loss(pi, mu, log_sigma, batch, 'mean') 212 | train_metrics = {'loss': loss} 213 | return loss, train_metrics 214 | 215 | grad_fn = jax.value_and_grad(loss_fn, has_aux=True) 216 | (loss, train_metrics), grad = grad_fn(optimizer.target) 217 | grad = jax.experimental.optimizers.clip_grads(grad, FLAGS.grad_clip) 218 | train_metrics['grad'] = jax.experimental.optimizers.l2_norm(grad) 219 | train_metrics['lr'] = learning_rate 220 | optimizer = optimizer.apply_gradient(grad, learning_rate=learning_rate) 221 | return optimizer, train_metrics 222 | 223 | 224 | def train(train_batches, valid_batches, output_dir=None, verbose=True): 225 | """Training loop. 226 | 227 | Args: 228 | train_batches: Training batches from tf.data.Dataset. 229 | valid_batches: Validation batches from tf.data.Dataset. 230 | output_dir: Output directory for checkpoints, logs, and samples. 231 | verbose: Logging verbosity. 232 | 233 | Returns: 234 | An optimizer object with final state. 
235 | """ 236 | train_writer = tensorboard.SummaryWriter(os.path.join(output_dir, 'train')) 237 | eval_writer = tensorboard.SummaryWriter(os.path.join(output_dir, 'eval')) 238 | 239 | tfds_batch = valid_batches.take(1) 240 | tfds_batch = list(valid_batches.as_numpy_iterator())[0] 241 | batch_size, *input_shape = tfds_batch.shape 242 | 243 | rng = jax.random.PRNGKey(FLAGS.seed) 244 | rng, model_rng = jax.random.split(rng) 245 | 246 | lm_kwargs = { 247 | 'num_layers': FLAGS.num_layers, 248 | 'num_heads': FLAGS.num_heads, 249 | 'mdn_mixtures': FLAGS.mdn_components, 250 | 'num_mlp_layers': FLAGS.num_mlp_layers, 251 | 'mlp_dims': FLAGS.mlp_dims 252 | } 253 | model = create_model(model_rng, 254 | input_shape, 255 | lm_kwargs, 256 | batch_size, 257 | verbose=verbose) 258 | optimizer = create_optimizer(model, FLAGS.learning_rate) 259 | early_stop = train_utils.EarlyStopping(patience=1) 260 | 261 | # Learning rate schedule 262 | lr_step_schedule = [(i, FLAGS.lr_gamma**i) for i in range(1000)] 263 | lr_scheduler = lr_schedule.create_stepped_learning_rate_schedule( 264 | FLAGS.learning_rate, 265 | FLAGS.lr_schedule_interval, 266 | lr_step_schedule, 267 | warmup_length=FLAGS.lr_warmup) 268 | 269 | sampling_step = -1 270 | for epoch in range(FLAGS.epochs): 271 | start_time = time.time() 272 | for step, batch in enumerate(tfds.as_numpy(train_batches)): 273 | global_step = step + epoch * train_batches.examples 274 | optimizer, train_metrics = train_step(batch, optimizer, 275 | lr_scheduler(global_step)) 276 | 277 | if step % FLAGS.logging_freq == 0: 278 | elapsed = time.time() - start_time 279 | batch_per_sec = (step + 1) / elapsed 280 | ms_per_batch = elapsed * 1000 / (step + 1) 281 | train_metrics['batch/s'] = batch_per_sec 282 | train_metrics['ms/batch'] = ms_per_batch 283 | train_utils.log_metrics(train_metrics, 284 | step, 285 | train_batches.examples, 286 | epoch=epoch, 287 | summary_writer=train_writer, 288 | verbose=verbose) 289 | 290 | if (step % FLAGS.snapshot_freq == 0 
and 291 | step > 0) or step == train_batches.examples - 1: 292 | 293 | sampling_step += 1 294 | 295 | eval_metrics = evaluate(valid_batches, optimizer.target) 296 | train_utils.log_metrics(eval_metrics, 297 | global_step, 298 | train_batches.examples * FLAGS.epochs, 299 | summary_writer=eval_writer, 300 | verbose=verbose) 301 | improved, early_stop = early_stop.update(eval_metrics['loss']) 302 | 303 | if (not FLAGS.early_stopping and FLAGS.save_ckpt) or \ 304 | (FLAGS.early_stopping and improved and FLAGS.save_ckpt): 305 | checkpoints.save_checkpoint(output_dir, (optimizer, early_stop), 306 | sampling_step, 307 | keep=FLAGS.checkpoints_to_keep) 308 | 309 | if FLAGS.early_stopping and early_stop.should_stop: 310 | logging.info('EARLY STOP: Ended training after %s epochs.', epoch + 1) 311 | return 312 | 313 | train_writer.flush() 314 | eval_writer.flush() 315 | 316 | # Early termination of training loop. 317 | if FLAGS.max_steps is not None and \ 318 | global_step >= FLAGS.max_steps: 319 | return optimizer 320 | 321 | return optimizer 322 | 323 | 324 | def main(argv): 325 | del argv # unused 326 | 327 | logging.info(FLAGS.flags_into_string()) 328 | logging.info('Platform: %s', jax.lib.xla_bridge.get_backend().platform) 329 | 330 | # Make sure TensorFlow does not allocate GPU memory. 
331 | tf.config.experimental.set_visible_devices([], 'GPU') 332 | 333 | train_ds, eval_ds = input_pipeline.get_dataset( 334 | dataset=FLAGS.dataset, 335 | data_shape=FLAGS.data_shape, 336 | problem='vae', 337 | batch_size=FLAGS.batch_size, 338 | normalize=FLAGS.normalize, 339 | pca_ckpt=FLAGS.pca_ckpt, 340 | slice_ckpt=FLAGS.slice_ckpt, 341 | dim_weights_ckpt=FLAGS.dim_weights_ckpt) 342 | 343 | train(train_batches=train_ds, 344 | valid_batches=eval_ds, 345 | output_dir=FLAGS.model_dir, 346 | verbose=FLAGS.verbose) 347 | 348 | 349 | if __name__ == '__main__': 350 | app.run(main) 351 | -------------------------------------------------------------------------------- /utils/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Dataset utilities.""" 17 | import os 18 | import pickle 19 | 20 | import jax 21 | import numpy as np 22 | import tensorflow as tf 23 | 24 | from absl import logging 25 | from functools import reduce 26 | 27 | AUTOTUNE = tf.data.experimental.AUTOTUNE 28 | 29 | 30 | def save(obj, path): 31 | """Save an object to disk as a pickle file.""" 32 | os.makedirs(os.path.dirname(path), exist_ok=True) 33 | with open(path, 'wb') as f: 34 | pickle.dump(obj, f, protocol=4) 35 | logging.info('Saved to %s', path) 36 | 37 | 38 | def load(path): 39 | """Load pickled object into memory.""" 40 | with open(path, 'rb') as f: 41 | return pickle.load(f) 42 | 43 | 44 | def _decode_record(record, flattened_shape, shape_len, tokens=False): 45 | if not tokens: 46 | input_parser = tf.io.FixedLenFeature([flattened_shape], tf.float32) 47 | else: 48 | input_parser = tf.io.FixedLenFeature((), tf.string) 49 | 50 | parsed = tf.io.parse_example( 51 | record, { 52 | 'inputs': input_parser, 53 | 'input_shape': tf.io.FixedLenFeature([shape_len], tf.int64) 54 | }) 55 | 56 | if tokens: 57 | parsed['inputs'] = tf.io.parse_tensor(parsed['inputs'], out_type=np.bool) 58 | 59 | parsed['inputs'] = tf.reshape(parsed['inputs'], parsed['input_shape']) 60 | return parsed 61 | 62 | 63 | def compute_dataset_cardinality(ds, 64 | ds_split='train', 65 | cache=False, 66 | cache_dir=None, 67 | config=''): 68 | """Computes and optionally caches cardinality of tf.data.Dataset.""" 69 | card_cache_path = os.path.join(cache_dir, 70 | f'cache/{ds_split}_{config}_cardinality.pkl') 71 | 72 | if os.path.exists(card_cache_path): 73 | logging.info('Using cached dataset cardinality at %s', cache_dir) 74 | cardinality = load(card_cache_path) 75 | else: 76 | cardinality = -1 77 | if hasattr(ds, 'cardinality'): 78 | cardinality = ds.cardinality().numpy() 79 | if cardinality <= 0: 80 | cardinality = 0 81 | for e in ds.as_numpy_iterator(): 82 | cardinality += 1 83 | 84 | if cache: 85 | assert cache_dir 
is not None 86 | assert not hasattr(ds, 'examples') 87 | setattr(ds, 'examples', cardinality) 88 | save(cardinality, card_cache_path) 89 | 90 | return cardinality 91 | 92 | 93 | def compute_dataset_statistics(ds, 94 | ds_split='train', 95 | cache=False, 96 | cache_dir=None, 97 | config=''): 98 | """Computes the mean and standard deviation of tf.data.Dataset.""" 99 | mean_cache_path = os.path.join(cache_dir, 100 | f'cache/{ds_split}_{config}_mean.pkl') 101 | stddev_cache_path = os.path.join(cache_dir, 102 | f'cache/{ds_split}_{config}_stddev.pkl') 103 | 104 | if os.path.exists(mean_cache_path) and os.path.exists(stddev_cache_path): 105 | logging.info('Using cached dataset statistics at %s', cache_dir) 106 | ds_mean = load(mean_cache_path) 107 | ds_std = load(stddev_cache_path) 108 | else: 109 | cardinality = compute_dataset_cardinality(ds, cache=False) 110 | ds_sum = ds.reduce(0., lambda x, y: x + y).numpy() 111 | ds_squared = ds.map(lambda x: x**2, num_parallel_calls=AUTOTUNE) 112 | ds_squared_sum = ds_squared.reduce(0., lambda x, y: x + y).numpy() 113 | ds_mean = ds_sum / cardinality 114 | ds_second_moment = ds_squared_sum / cardinality 115 | ds_std = np.sqrt(ds_second_moment - ds_mean**2) 116 | 117 | if cache: 118 | assert cache_dir is not None 119 | assert not hasattr(ds, 'mean') and not hasattr(ds, 'stddev') 120 | setattr(ds, 'mean', ds_mean) 121 | setattr(ds, 'stddev', ds_std) 122 | save(ds_mean, mean_cache_path) 123 | save(ds_std, stddev_cache_path) 124 | 125 | return ds_mean, ds_std 126 | 127 | 128 | def compute_dataset_min_max(ds, 129 | ds_split='train', 130 | cache=False, 131 | cache_dir=None, 132 | config=''): 133 | """Computes the min and max of (batched) tf.data.Dataset.""" 134 | min_cache_path = os.path.join(cache_dir, f'cache/{ds_split}_{config}_min.pkl') 135 | max_cache_path = os.path.join(cache_dir, f'cache/{ds_split}_{config}_max.pkl') 136 | 137 | if os.path.exists(min_cache_path) and os.path.exists(max_cache_path): 138 | logging.info('Using cached 
dataset min/max at %s', cache_dir) 139 | ds_min = load(min_cache_path) 140 | ds_max = load(max_cache_path) 141 | else: 142 | ds_maxes = ds.map(lambda x: tf.reduce_max(x), num_parallel_calls=AUTOTUNE) 143 | ds_mins = ds.map(lambda x: tf.reduce_min(x), num_parallel_calls=AUTOTUNE) 144 | ds_min = ds_mins.reduce(tf.float32.max, lambda x, y: tf.math.minimum(x, y)) 145 | ds_max = ds_maxes.reduce(tf.float32.min, lambda x, y: tf.math.maximum(x, y)) 146 | ds_min, ds_max = ds_min.numpy(), ds_max.numpy() 147 | 148 | if cache: 149 | assert cache_dir is not None 150 | assert not hasattr(ds, 'min') and not hasattr(ds, 'max') 151 | setattr(ds, 'min', ds_min) 152 | setattr(ds, 'max', ds_max) 153 | save(ds_min, min_cache_path) 154 | save(ds_max, max_cache_path) 155 | 156 | return ds_min, ds_max 157 | 158 | 159 | def get_tf_record_dataset(file_pattern=None, 160 | shape=(512,), 161 | batch_size=512, 162 | shuffle=True, 163 | tokens=False): 164 | """Generates a TFRecord dataset given a file pattern. 165 | 166 | Args: 167 | file_pattern: TFRecord file pattern (e.g train-*.tfrecord) 168 | shape: Example shape. 169 | batch_size: Number of examples per batch during training. This 170 | argument is used for setting the shuffle buffer size. 171 | shuffle: Whether to shuffle the data or not. 172 | tokens: Extract one-hot token dataset. 173 | 174 | Returns: 175 | A tf.data.Dataset iterator. 
176 | """ 177 | filenames = tf.data.Dataset.list_files(os.path.expanduser(file_pattern), 178 | shuffle=shuffle) 179 | dataset = filenames.interleave(map_func=tf.data.TFRecordDataset, 180 | cycle_length=40, 181 | num_parallel_calls=AUTOTUNE, 182 | deterministic=False) 183 | if shuffle: 184 | dataset = dataset.shuffle(8 * batch_size) 185 | 186 | prod = lambda a: reduce(lambda x, y: x * y, a) 187 | flattened_shape = prod(shape) 188 | shape_len = len(shape) 189 | decode_fn = lambda x: _decode_record(x, flattened_shape, shape_len, tokens) 190 | dataset = dataset.map(decode_fn, num_parallel_calls=AUTOTUNE) 191 | return dataset 192 | 193 | 194 | def _truncate_embeddings(embeddings, length): 195 | """Truncate embedding matrix.""" 196 | pad_length = length - len(embeddings) 197 | if pad_length <= 0: 198 | embeddings = embeddings[:length] 199 | else: 200 | padding = np.zeros((pad_length, embeddings.shape[-1])) 201 | embeddings = np.concatenate((embeddings, padding)) 202 | 203 | assert len(embeddings) == length 204 | 205 | return embeddings 206 | 207 | 208 | def self_similarity(embeddings, normalized=True, max_len=80): 209 | """Generates self-similarity matrix for sequence of embeddings.""" 210 | embeddings = _truncate_embeddings(embeddings, max_len) 211 | 212 | self_sim = np.dot(embeddings, embeddings.T) 213 | if normalized: 214 | norm_embeddings = embeddings / np.linalg.norm( 215 | embeddings, ord=2, axis=1, keepdims=True) 216 | self_sim = np.dot(norm_embeddings, norm_embeddings.T) 217 | self_sim[np.isnan(self_sim)] = 0 # hack to overcome division by zero NaNs 218 | return self_sim 219 | 220 | 221 | def unroll_upper_triangular(matrix): 222 | """Converts square matrix to vector by unrolling upper triangle.""" 223 | rows, cols = matrix.shape 224 | assert rows == cols, 'Not a square matrix.' 
225 | 226 | row_idx, col_idx = np.triu_indices(rows, 1) 227 | unrolled = [] 228 | for i, j in zip(row_idx, col_idx): 229 | unrolled.append(matrix[i][j]) 230 | assert len(unrolled) == rows * (rows - 1) // 2 231 | return unrolled 232 | 233 | 234 | def roll_upper_triangular(vector, size): 235 | """Converts unrolled upper triangle into square matrix.""" 236 | matrix = np.ones((size, size)) 237 | offset = 0 238 | for i in range(size): 239 | stream = vector[offset:] 240 | row = stream[:size - (i + 1)] 241 | matrix[i, i + 1:size] = row 242 | matrix[i + 1:size, i] = row 243 | offset += len(row) 244 | assert offset == len(vector) 245 | return matrix 246 | 247 | 248 | def erase_bars(embeddings, indices): 249 | """Erases vectors from a given a set of embeddings. 250 | 251 | Args: 252 | embeddings: A numpy matrix of embeddings. 253 | indices: A list of indices corresponding to vectors that will be erased. 254 | 255 | Returns: 256 | A modified embedding matrix. 257 | """ 258 | return jax.ops.index_update(embeddings, jax.ops.index[indices], 0) 259 | 260 | 261 | def infill_bars(embeddings, chunk_params, erased_chunk_indices): 262 | """Infills a partially erased embedding matrix with sampled embedding parameters. 263 | 264 | Args: 265 | embeddings: A partially incomplete embedding matrix. 266 | chunk_params: A list of sampled embedding vectors. 267 | erased_bar_indices: A list of indices corresponding to vector positions in the 268 | embedding matrix that will be replaced by sampled embeddings. 269 | 270 | Returns: 271 | A modified embedding matrix. 272 | """ 273 | assert len(chunk_params) == len(erased_chunk_indices) 274 | return jax.ops.index_update(embeddings, jax.ops.index[erased_chunk_indices], 275 | chunk_params) 276 | 277 | 278 | def batches(data, labels=None, batch_size=32): 279 | """Generate batches of data. 280 | 281 | Args: 282 | data: A numpy matrix of data of shape [num_examples, *data_shape]. 
283 | labels: An optional matrix of corresponding labels for each entry in data of 284 | shape [num_examples, *label_shape] 285 | batch_size: Batch size. 286 | 287 | Returns: 288 | An iterator that yields batches of data with their labels. 289 | """ 290 | num_batches = data.shape[0] // batch_size 291 | for i in range(num_batches): 292 | j, k = i * batch_size, (i + 1) * batch_size 293 | if labels is not None: 294 | assert len(data) == len(labels) 295 | batch = (data[j:k], labels[j:k]) 296 | else: 297 | batch = data[j:k] 298 | yield batch 299 | 300 | 301 | def shuffle(data, labels=None): 302 | """Shuffle dataset. 303 | 304 | Args: 305 | data: A numpy matrix of data of shape [num_examples, *data_shape]. 306 | labels: An optional matrix of corresponding labels for each entry in data of 307 | shape [num_examples, *label_shape]. 308 | 309 | Returns: 310 | Shuffled data and label matrices. 311 | """ 312 | idx = np.random.permutation(len(data)) 313 | shuffled_data = data[idx] 314 | if labels is not None: 315 | assert len(data) == len(labels) 316 | shuffled_labels = labels[idx] 317 | return shuffled_data, shuffled_labels 318 | else: 319 | return shuffled_data 320 | -------------------------------------------------------------------------------- /utils/ebm_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | # Lint as: python3 16 | """Utilities for training energy-based models.""" 17 | import jax 18 | import jax.numpy as jnp 19 | import numpy as np 20 | 21 | from flax import struct 22 | from functools import partial 23 | 24 | 25 | @struct.dataclass 26 | class ReplayBuffer(object): 27 | """Replay buffer for sampling.""" 28 | buffer_size: int 29 | dims: int 30 | data: any 31 | 32 | def add(self, samples): 33 | num_samples = len(samples) 34 | new_data = jnp.concatenate((samples, self.data[:-num_samples])) 35 | return self.replace(buffer_size=self.buffer_size, 36 | dims=self.dims, 37 | data=new_data) 38 | 39 | def sample(self, rng, n, p=0.95): 40 | """Generates a set of samples. With probability p, each sample 41 | will come from the replay buffer. With probability (1-p) each 42 | sample will be sampled from a uniform distribution. 43 | """ 44 | buf_mask = jax.random.bernoulli(rng, p=p, shape=(n,))[:, jnp.newaxis] 45 | rand_mask = 1 - buf_mask 46 | idx = jax.random.choice(rng, self.buffer_size, shape=(n,), replace=False) 47 | buf = self.data[idx] 48 | rand = jax.random.uniform(rng, shape=(n, self.dims)) 49 | return jnp.where(rand_mask, rand, buf) 50 | 51 | 52 | def vgrad(f, x): 53 | """Computes gradients for a vector-valued function. 54 | 55 | >>> vgrad(lambda x: 3*x**2, jnp.ones((1,))) 56 | DeviceArray([6.], dtype=float32) 57 | """ 58 | y, vjp_fn = jax.vjp(f, x) 59 | return vjp_fn(jnp.ones(y.shape))[0] 60 | 61 | 62 | def create_noise_schedule(sigma_begin=1, 63 | sigma_end=1e-2, 64 | L=10, 65 | schedule='geometric'): 66 | """Creates a noise schedule. 67 | 68 | Args: 69 | sigma_begin: Starting variance. 70 | sigma_end: Ending variance. 71 | L: Number of values in the noise schedule. 72 | schedule: Type of schedule. 
73 | """ 74 | if schedule == 'geometric': 75 | sigmas = jnp.exp(jnp.linspace(jnp.log(sigma_begin), jnp.log(sigma_end), L)) 76 | elif schedule == 'linear': 77 | sigmas = jnp.linspace(sigma_begin, sigma_end, L) 78 | elif schedule == 'fibonacci': 79 | sigmas = [1e-6, 2e-6] 80 | for i in range(L - 2): 81 | sigmas.append(sigmas[-1] + sigmas[-2]) 82 | sigmas = jnp.array(sigmas) 83 | else: 84 | raise ValueError(f'Unsupported schedule: {schedule}') 85 | 86 | return sigmas 87 | 88 | 89 | @partial(jax.jit, static_argnums=( 90 | 4, 91 | 5, 92 | 6, 93 | 7, 94 | )) 95 | def annealed_langevin_dynamics(rng, 96 | model, 97 | sigmas, 98 | init, 99 | epsilon, 100 | T, 101 | denoise, 102 | infill=False, 103 | infill_samples=None, 104 | infill_masks=None): 105 | """Annealed Langevin dynamics sampling from Song et al. 106 | 107 | Args: 108 | rng: Random number generator key. 109 | model: Score network. 110 | sigmas: Noise schedule. 111 | init: Initial state for Langevin dynamics (usually uniform noise). 112 | epsilon: Step size coefficient. 113 | T: Number of steps per noise level. 114 | denoise: Apply an additional denoising step to final samples. 115 | infill: Infill partially complete samples. 116 | infill_samples: Partially complete samples to infill. 117 | infill_masks: Binary mask for infilling partially complete samples. 118 | A zero indicates an element that must be infilled by Langevin dynamics. 119 | 120 | Returns: 121 | state: Final state sampled from Langevin dynamics. 122 | collection: Array of state at each step of sampling with shape 123 | (num_sigmas * T + 1 + int(denoise), :). 124 | ld_metrics: Metrics collected for each noise level with shape (num_sigmas, T). 
125 | """ 126 | if not infill: 127 | infill_samples = jnp.zeros(init.shape) 128 | infill_masks = jnp.zeros(init.shape) 129 | 130 | collection_steps = 100 131 | start = init * (1 - infill_masks) + infill_samples * infill_masks 132 | images = np.zeros((collection_steps + 1 + int(denoise), *init.shape)) 133 | images = jax.ops.index_update(images, jax.ops.index[0, :], start) 134 | collection_idx = jnp.linspace(1, 135 | len(sigmas) * T, 136 | collection_steps).astype(jnp.int32) 137 | 138 | def langevin_step(params, i): 139 | state, rng, sigma_i, alpha, collection = params 140 | rng, step_rng, infill_rng = jax.random.split(rng, num=3) 141 | sigma = sigmas[sigma_i] 142 | 143 | y = infill_samples + sigma * jax.random.normal(key=infill_rng, 144 | shape=infill_samples.shape) 145 | 146 | grad = model(state, sigma) 147 | noise = jnp.sqrt(2 * alpha) * jax.random.normal(key=step_rng, 148 | shape=state.shape) 149 | next_state = state + alpha * grad + noise # gradient ascent 150 | 151 | # Apply infilling mask 152 | next_state = next_state * (1 - infill_masks) + y * infill_masks 153 | 154 | # Collect samples 155 | image_idx = sigma_i * T + i + 1 156 | idx_mask = jnp.in1d(collection_idx, image_idx) 157 | idx = jnp.sum(jnp.arange(len(collection_idx)) * idx_mask) + 1 158 | collection = jax.lax.cond(idx_mask.any(), 159 | lambda op: jax.ops.index_update( 160 | collection, jax.ops.index[op, :], next_state), 161 | lambda op: collection, 162 | operand=idx) 163 | 164 | # Collect metrics 165 | grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad), axis=1) + 1e-10).mean() 166 | noise_norm = jnp.sqrt(jnp.sum(jnp.square(noise), axis=1) + 1e-10).mean() 167 | step_norm = jnp.sqrt(jnp.sum(jnp.square(alpha * grad), axis=1) + 168 | 1e-10).mean() 169 | metrics = grad_norm, step_norm, alpha, noise_norm 170 | 171 | next_params = (next_state, rng, sigma_i, alpha, collection) 172 | return next_params, metrics 173 | 174 | def sample_with_sigma(params, sigma_i): 175 | state, rng, collection = params 176 | sigma = 
sigmas[sigma_i] 177 | alpha = epsilon * (sigma / sigmas[-1])**2 178 | 179 | ld_params = (state, rng, sigma_i, alpha, collection) 180 | next_ld_state, metrics = jax.lax.scan(langevin_step, ld_params, 181 | jnp.arange(T)) 182 | next_state, rng, sigma_i, alpha, collection = next_ld_state 183 | 184 | next_params = (next_state, rng, collection) 185 | return next_params, metrics 186 | 187 | assert len(sigmas) >= 2 188 | init_params = (init, rng, images) 189 | ld_state, ld_metrics = jax.lax.scan(sample_with_sigma, init_params, 190 | jnp.arange(len(sigmas))) 191 | state, rng, collection = ld_state 192 | 193 | # Additional denoising step. 194 | if denoise: 195 | state = state + sigmas[-1]**2 * model(state, sigmas[-1]) 196 | collection = jax.ops.index_update(collection, jax.ops.index[-1, :], state) 197 | 198 | return state, collection, jnp.stack(ld_metrics) 199 | 200 | 201 | @partial(jax.jit, static_argnums=( 202 | 4, 203 | 5, 204 | 6, 205 | 7, 206 | )) 207 | def consistent_langevin_dynamics(rng, 208 | model, 209 | sigmas, 210 | init, 211 | epsilon, 212 | T, 213 | denoise=True, 214 | infill=False, 215 | infill_samples=None, 216 | infill_masks=None): 217 | """Consistent annealed Langevin dynamics sampling from Jolicoeur-Martineau et al. 218 | 219 | Args: 220 | rng: Random number generator key. 221 | model: Score network. 222 | sigmas: Noise schedule. 223 | init: Initial state. 224 | epsilon: Step size coefficient. 225 | T: Number of steps per noise level. 226 | 227 | Returns: 228 | state: Final state sampled from Langevin dynamics. 229 | ld_metrics: Metrics collected for each noise level with shape (num_sigmas, T). 230 | """ 231 | if infill: 232 | raise NotImplementedError 233 | 234 | def langevin_step(params, i): 235 | state, rng = params 236 | rng, step_rng = jax.random.split(rng) 237 | 238 | sigma = sigmas[i] 239 | next_sigma = jnp.where(i < len(sigmas) - 1, sigmas[i + 1], 0.) 
240 | 241 | alpha = epsilon * (sigma / sigmas[-1])**2 242 | grad = model(state, sigma) 243 | noise = beta * next_sigma * jax.random.normal(key=step_rng, 244 | shape=state.shape) 245 | next_state = state + alpha * grad + noise 246 | 247 | # Collect metrics 248 | grad_norm = jnp.sqrt(jnp.sum(jnp.square(grad), axis=1) + 1e-10).mean() 249 | noise_norm = jnp.sqrt(jnp.sum(jnp.square(noise), axis=1) + 1e-10).mean() 250 | step_norm = jnp.sqrt(jnp.sum(jnp.square(alpha * grad), axis=1) + 251 | 1e-10).mean() 252 | metrics = grad_norm, step_norm, alpha, noise_norm 253 | 254 | next_params = (next_state, rng) 255 | return next_params, metrics 256 | 257 | assert len(sigmas) >= 2 258 | gamma = sigmas[0] / sigmas[1] 259 | beta = jnp.sqrt(1 - (1 - epsilon / (sigmas[-1]**2))**2) 260 | 261 | init_params = (init, rng) 262 | ld_state, ld_metrics = jax.lax.scan(langevin_step, init_params, 263 | jnp.arange(len(sigmas))) 264 | state, rng = ld_state 265 | 266 | if denoise: 267 | state = state + sigmas[-1]**2 * model(state, sigmas[-1]) 268 | 269 | ld_metrics = jnp.stack(ld_metrics) 270 | ld_metrics = jnp.expand_dims(ld_metrics, axis=2) 271 | return state, ld_metrics 272 | 273 | 274 | @partial(jax.jit, static_argnums=( 275 | 4, 276 | 5, 277 | 6, 278 | 7, 279 | )) 280 | def diffusion_dynamics(rng, 281 | model, 282 | betas, 283 | init, 284 | epsilon, 285 | T, 286 | denoise, 287 | infill=False, 288 | infill_samples=None, 289 | infill_masks=None): 290 | """Diffusion dynamics (reverse process decoder). 291 | 292 | Args: 293 | rng: Random number generator key. 294 | model: Diffusion probabilistic network. 295 | betas: Noise schedule. 296 | init: Initial state for Langevin dynamics (usually Gaussian noise). 297 | epsilon: Null parameter. 298 | T: Null parameter. 299 | denoise: Null parameter used in other methods to find EDS. 300 | infill: Infill partially complete samples. 301 | infill_samples: Partially complete samples to infill. 
302 | infill_masks: Binary mask for infilling partially complete samples. 303 | A zero indicates an element that must be infilled by Langevin dynamics. 304 | 305 | Returns: 306 | state: Final state sampled from Langevin dynamics. 307 | collection: Array of state at each step of sampling with shape 308 | (num_sigmas * T + 1 + int(denoise), :). 309 | ld_metrics: Metrics collected for each noise level with shape (num_sigmas, T). 310 | """ 311 | if not infill: 312 | infill_samples = jnp.zeros(init.shape) 313 | infill_masks = jnp.zeros(init.shape) 314 | 315 | alphas = 1 - betas 316 | alphas_prod = jnp.cumprod(alphas) 317 | alphas_prod_prev = jnp.concatenate([jnp.ones((1,)), alphas_prod[:-1]]) 318 | assert alphas.shape == alphas_prod.shape == alphas_prod_prev.shape 319 | 320 | collection_steps = 40 321 | start = init * (1 - infill_masks) + infill_samples * infill_masks 322 | images = np.zeros((collection_steps + 1, *init.shape)) 323 | collection = jax.ops.index_update(images, jax.ops.index[0, :], start) 324 | collection_idx = jnp.linspace(1, len(betas), 325 | collection_steps).astype(jnp.int32) 326 | 327 | def sample_with_beta(params, t): 328 | state, rng, collection = params 329 | rng, key = jax.random.split(rng) 330 | 331 | # Noise schedule constants 332 | beta = betas[t] 333 | alpha = alphas[t] 334 | alpha_prod = alphas_prod[t] 335 | alpha_prod_prev = alphas_prod_prev[t] 336 | 337 | # Constants for posterior q(x_t|x_0) 338 | sqrt_reciprocal_alpha_prod = jnp.sqrt(1 / alpha_prod) 339 | sqrt_alpha_prod_m1 = jnp.sqrt(1 - alpha_prod) * sqrt_reciprocal_alpha_prod 340 | 341 | # Create infilling template 342 | rng, infill_noise_rng = jax.random.split(rng) 343 | infill_noise_cond = t > 0 344 | infill_noise = jax.random.normal(key=infill_noise_rng, 345 | shape=infill_samples.shape) 346 | noisy_y = jnp.sqrt(alpha_prod) * infill_samples + jnp.sqrt( 347 | 1 - alpha_prod) * infill_noise 348 | y = infill_noise_cond * noisy_y + (1 - infill_noise_cond) * infill_samples 349 | 350 | # 
Constants for posterior q(x_t-1|x_t, x_0) 351 | posterior_mu1 = beta * jnp.sqrt(alpha_prod_prev) / (1 - alpha_prod) 352 | posterior_mu2 = (1 - alpha_prod_prev) * jnp.sqrt(alpha) / (1 - alpha_prod) 353 | 354 | # Clipped variance (must be non-zero) 355 | posterior_var = beta * (1 - alpha_prod_prev) / (1 - alpha_prod) 356 | posterior_var_clipped = jnp.maximum(posterior_var, 1e-20) 357 | posterior_log_var = jnp.log(posterior_var_clipped) 358 | 359 | # Noise 360 | rng, noise_rng = jax.random.split(rng) 361 | noise_cond = t > 0 362 | noise = jax.random.normal(key=noise_rng, shape=state.shape) 363 | noise = noise_cond * noise + (1 - noise_cond) * jnp.zeros(state.shape) 364 | noise = noise * jnp.exp(0.5 * posterior_log_var) 365 | 366 | # Reverse process (reconstruction) 367 | noise_condition_vec = jnp.sqrt(alpha_prod) * jnp.ones((noise.shape[0], 1)) 368 | noise_condition_vec = noise_condition_vec.reshape( 369 | init.shape[0], *([1] * len(init.shape[1:]))) 370 | eps_recon = model(state, noise_condition_vec) 371 | state_recon = sqrt_reciprocal_alpha_prod * state - sqrt_alpha_prod_m1 * eps_recon 372 | state_recon = jnp.clip(state_recon, -1., 1.) 
373 | posterior_mu = posterior_mu1 * state_recon + posterior_mu2 * state 374 | next_state = posterior_mu + noise 375 | 376 | # Infill 377 | next_state = next_state * (1 - infill_masks) + y * infill_masks 378 | 379 | # Collect metrics 380 | step = state - next_state 381 | grad_norm = jnp.sqrt(jnp.sum(jnp.square(eps_recon), axis=1) + 1e-10).mean() 382 | noise_norm = jnp.sqrt(jnp.sum(jnp.square(noise), axis=1) + 1e-10).mean() 383 | step_norm = jnp.sqrt(jnp.sum(jnp.square(step), axis=1) + 1e-10).mean() 384 | metrics = (grad_norm, step_norm, alpha_prod, noise_norm) 385 | 386 | # Collect samples 387 | image_idx = len(betas) - t + 1 388 | idx_mask = jnp.in1d(collection_idx, image_idx) 389 | idx = jnp.sum(jnp.arange(len(collection_idx)) * idx_mask) + 1 390 | collection = jax.lax.cond(idx_mask.any(), 391 | lambda op: jax.ops.index_update( 392 | collection, jax.ops.index[op, :], next_state), 393 | lambda op: collection, 394 | operand=idx) 395 | 396 | next_params = (next_state, rng, collection) 397 | return next_params, metrics 398 | 399 | init_params = (init, rng, collection) 400 | beta_steps = jnp.arange(len(betas) - 1, -1, -1) 401 | ld_state, ld_metrics = jax.lax.scan(sample_with_beta, init_params, beta_steps) 402 | state, rng, collection = ld_state 403 | ld_metrics = jnp.stack(ld_metrics) 404 | ld_metrics = jnp.expand_dims(ld_metrics, 2) 405 | return state, collection, ld_metrics 406 | 407 | 408 | def collate_sampling_metrics(ld_metrics): 409 | """Converts Langevin metrics into TensorBoard-readable format. 410 | 411 | Args: 412 | ld_metrics: A tensor with metrics returned by annealed_langevin_dynamics 413 | sampling procedure. 
414 | """ 415 | num_metrics, num_sigmas, num_steps = ld_metrics.shape 416 | del num_metrics # unused 417 | sampling_metrics = [[] for i in range(num_sigmas)] 418 | for i in range(num_sigmas): 419 | grad_norm, step_norm, alpha, noise_norm = ld_metrics[:, i, :] 420 | for j in range(num_steps): 421 | metrics = { 422 | 'slope': grad_norm[j], 423 | 'step': step_norm[j], 424 | 'alpha': alpha[j], 425 | 'noise': noise_norm[j] 426 | } 427 | sampling_metrics[i].append(metrics) 428 | return sampling_metrics 429 | -------------------------------------------------------------------------------- /utils/losses.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Loss functions.""" 17 | import jax 18 | import jax.numpy as jnp 19 | from flax import nn 20 | 21 | 22 | def reduce_fn(x, mode): 23 | if mode == "none" or mode is None: 24 | return jnp.asarray(x) 25 | elif mode == "sum": 26 | return jnp.sum(x) 27 | elif mode == "mean": 28 | return jnp.mean(jnp.asarray(x)) 29 | else: 30 | raise ValueError("Unsupported reduction option.") 31 | 32 | 33 | def series_loss(context, true_target, pred_target, reduction="mean"): 34 | """Series loss (self-similarity + MSE loss). 
35 | 36 | Compute the loss between a predicted target embedding and the true target embedding 37 | conditioned on a sequence of previous embeddings (context). 38 | 39 | Args: 40 | context: A sequence of continuous embeddings. 41 | true_target: The true next embedding in the sequence. 42 | pred_target: The predicted next embedding in the sequence. 43 | reduction: Type of reduction to apply to loss. 44 | 45 | Returns: 46 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 47 | otherwise, it is scalar. 48 | """ 49 | ss = context @ true_target.T 50 | ss_hat = context @ pred_target.T 51 | loss = mean_squared_error(ss.T, ss_hat.T) + mean_squared_error( 52 | true_target, pred_target) 53 | return reduce_fn(loss, reduction) 54 | 55 | 56 | def _log_gaussian_pdf(y, mu, log_sigma): 57 | """The log probability density function of a Gaussian distribution.""" 58 | norm_const = jnp.log(jnp.sqrt(2.0 * jnp.pi)) 59 | return -0.5 * ((y - mu) / jnp.exp(log_sigma))**2 - log_sigma - norm_const 60 | 61 | 62 | def gaussian_mixture_loss(log_pi, mu, log_sigma, data, reduction="mean"): 63 | """Mixture density network loss. 64 | 65 | Computes the negative log-likelihood of data under a Gaussian mixture. 66 | 67 | Args: 68 | log_pi: The log of the relative weights of each Gaussian. 69 | mu: The mean of the Gaussians. 70 | log_sigma: The log of the standard deviation of each Gaussian. 71 | data: A batch of data to compute the loss for. 72 | reduction: Type of reduction to apply to loss. 73 | 74 | Returns: 75 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 76 | otherwise, it is scalar. 
77 | """ 78 | k = log_pi.shape[-1] 79 | data = jnp.repeat(data[:, jnp.newaxis, :,], k, axis=1) 80 | loglik = _log_gaussian_pdf(data, mu, log_sigma) # dimension-wise density 81 | loglik = loglik.sum(axis=2) # works because of diagonal covariance 82 | loss = jax.scipy.special.logsumexp(log_pi + loglik, axis=1) 83 | return -reduce_fn(loss, reduction) 84 | 85 | 86 | def binary_cross_entropy_with_logits(logits, labels): 87 | """Numerically stable binary cross entropy loss.""" 88 | return labels * nn.softplus(-logits) + (1 - labels) * nn.softplus(logits) 89 | 90 | 91 | def sigmoid_cross_entropy(logits, labels, reduction="sum"): 92 | """Computes sigmoid cross entropy given logits and multiple class labels. 93 | 94 | Measures the probability error in discrete classification tasks in which each 95 | class is independent and not mutually exclusive. 96 | 97 | `logits` and `labels` must have the same type and shape. 98 | 99 | Args: 100 | logits: Logit output values. 101 | labels: Ground truth integer labels in {0, 1}. 102 | reduction: Type of reduction to apply to loss. 103 | 104 | Returns: 105 | Loss value. If `reduction` is `none`, this has the same shape as `labels`; 106 | otherwise, it is scalar. 107 | 108 | Raises: 109 | ValueError: If the type of `reduction` is unsupported. 110 | """ 111 | log_p = nn.log_sigmoid(logits) 112 | # log(1 - sigmoid(x)) = log_sigmoid(-x), the latter more numerically stable 113 | log_not_p = nn.log_sigmoid(-logits) 114 | loss = -labels * log_p - (1. 
- labels) * log_not_p 115 | return reduce_fn(loss, reduction) 116 | 117 | 118 | def mean_squared_error(logits, labels, reduction="mean"): 119 | """Mean squared error.""" 120 | loss = jnp.square(logits - labels).mean(axis=1) 121 | return reduce_fn(loss, reduction) 122 | 123 | 124 | def kl_divergence(mu, var): 125 | """KL divergence from a standard unit Gaussian.""" 126 | return 0.5 * jnp.sum(jnp.square(mu) + var - 1 - jnp.log(var)) 127 | 128 | 129 | def denoising_score_matching_loss(batch, 130 | model, 131 | sigmas, 132 | rng, 133 | continuous_noise=False, 134 | reduction="mean"): 135 | """Denoising score matching objective used to train NCSNs. 136 | 137 | Args: 138 | batch: A batch of data to compute the loss for. 139 | model: A noise-conditioned score network. 140 | sigmas: A noise schedule (list of standard deviations). 141 | rng: Random number generator key to sample sigmas. 142 | continuous_noise: If True, uses continuous noise conditioning. 143 | reduction: Type of reduction to apply to loss. 144 | 145 | Returns: 146 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 147 | otherwise, it is scalar. 
148 | """ 149 | rng, label_rng, sample_rng = jax.random.split(rng, num=3) 150 | labels = jax.random.randint(key=label_rng, 151 | shape=(batch.shape[0],), 152 | minval=int(continuous_noise), 153 | maxval=len(sigmas)) 154 | 155 | if continuous_noise: 156 | rng, noise_rng = jax.random.split(rng) 157 | used_sigmas = jax.random.uniform(key=noise_rng, 158 | shape=labels.shape, 159 | minval=sigmas[labels - 1], 160 | maxval=sigmas[labels]) 161 | else: 162 | used_sigmas = sigmas[labels] 163 | 164 | used_sigmas = used_sigmas.reshape(batch.shape[0], 165 | *([1] * len(batch.shape[1:]))) 166 | noise = jax.random.normal(key=sample_rng, shape=batch.shape) * used_sigmas 167 | perturbed_samples = batch + noise 168 | target = -1 / (used_sigmas**2) * noise 169 | scores = model(perturbed_samples, used_sigmas) 170 | 171 | assert target.shape == batch.shape 172 | assert scores.shape == batch.shape 173 | 174 | # Compute loss 175 | target = target.reshape(target.shape[0], -1) 176 | scores = scores.reshape(scores.shape[0], -1) 177 | loss = 0.5 * jnp.sum(jnp.square(scores - target), 178 | axis=-1) * used_sigmas.squeeze()**2 179 | return reduce_fn(loss, reduction) 180 | 181 | 182 | def sliced_score_matching_loss(batch, 183 | model, 184 | sigmas, 185 | rng, 186 | continuous_noise=False, 187 | reduction="mean"): 188 | """Sliced score matching objective used to train NCSNs. 189 | 190 | Args: 191 | batch: A batch of data to compute the loss for. 192 | model: A noise-conditioned score network. 193 | sigmas: A noise schedule (list of standard deviations). 194 | rng: Random number generator key to sample sigmas. 195 | continuous_noise: If True, uses continuous noise conditioning. 196 | reduction: Type of reduction to apply to loss. 197 | 198 | Returns: 199 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 200 | otherwise, it is scalar. 201 | """ 202 | n_particles = 1. # TODO: does not support more than 1 particle. 
203 | rng, label_rng, sample_rng, score_rng = jax.random.split(rng, num=4) 204 | labels = jax.random.randint(key=label_rng, 205 | shape=(batch.shape[0],), 206 | minval=int(continuous_noise), 207 | maxval=len(sigmas)) 208 | 209 | if continuous_noise: 210 | rng, noise_rng = jax.random.split(rng) 211 | used_sigmas = jax.random.uniform(key=noise_rng, 212 | shape=labels.shape, 213 | minval=sigmas[labels - 1], 214 | maxval=sigmas[labels]) 215 | else: 216 | used_sigmas = sigmas[labels] 217 | 218 | used_sigmas = used_sigmas.reshape(batch.shape[0], 219 | *([1] * len(batch.shape[1:]))) 220 | noise = jax.random.normal(key=sample_rng, shape=batch.shape) * used_sigmas 221 | perturbed_samples = batch + noise 222 | 223 | dup_samples = perturbed_samples[:] 224 | dup_sigmas = used_sigmas[:] 225 | 226 | vectors = jax.random.rademacher(key=score_rng, shape=dup_samples.shape) 227 | 228 | # Compute gradients. 229 | first_grad = model(dup_samples, dup_sigmas) 230 | score_fn = lambda x: jnp.sum(model(x, dup_sigmas) * vectors) 231 | first_grad_v, second_grad = jax.value_and_grad(score_fn)(dup_samples) 232 | assert second_grad.shape == first_grad.shape 233 | 234 | # Score loss. 235 | first_grad = first_grad.reshape(dup_samples.shape[0], -1) 236 | score_loss = 0.5 * jnp.sum(jnp.square(first_grad), axis=-1) 237 | 238 | # Hessian loss. 239 | hessian_loss = jnp.sum( 240 | (vectors * second_grad).reshape(dup_samples.shape[0], -1), axis=-1) 241 | 242 | # Compute loss. 243 | score_loss = score_loss.reshape(n_particles, -1).mean(axis=0) 244 | hessian_loss = hessian_loss.reshape(n_particles, -1).mean(axis=0) 245 | loss = (score_loss + hessian_loss) * (used_sigmas.squeeze()**2) 246 | 247 | return reduce_fn(loss, reduction) 248 | 249 | 250 | def diffusion_loss(batch, 251 | model, 252 | betas, 253 | rng, 254 | continuous_noise=False, 255 | reduction="mean"): 256 | """Diffusion denoising probabilistic model loss. 257 | 258 | Args: 259 | batch: A batch of data to compute the loss for. 
260 | model: A diffusion probabilistic model. 261 | betas: A noise schedule. 262 | rng: Random number generator key to sample sigmas. 263 | continuous_noise: If True, uses continuous noise conditioning. 264 | reduction: Type of reduction to apply to loss. 265 | 266 | Returns: 267 | Loss value. If `reduction` is `none`, this has the same shape as `data`; 268 | otherwise, it is scalar. 269 | """ 270 | T = len(betas) 271 | rng, label_rng, sample_rng = jax.random.split(rng, num=3) 272 | labels = jax.random.randint(key=label_rng, 273 | shape=(batch.shape[0],), 274 | minval=int(continuous_noise), 275 | maxval=T + int(continuous_noise)) 276 | 277 | alphas = 1. - betas 278 | alphas_prod = jnp.cumprod(alphas) 279 | 280 | # if continuous_noise: 281 | alphas_prod = jnp.concatenate([jnp.ones((1,)), alphas_prod]) 282 | rng, noise_rng = jax.random.split(rng) 283 | used_alphas = jax.random.uniform(key=noise_rng, 284 | shape=labels.shape, 285 | minval=alphas_prod[labels - 1], 286 | maxval=alphas_prod[labels]) 287 | # else: 288 | # used_alphas = alphas_prod[labels] 289 | 290 | used_alphas = used_alphas.reshape(batch.shape[0], 291 | *([1] * len(batch.shape[1:]))) 292 | t = labels.reshape(batch.shape[0], *([1] * len(batch.shape[1:]))) 293 | 294 | eps = jax.random.normal(key=sample_rng, shape=batch.shape) 295 | perturbed_sample = jnp.sqrt(used_alphas) * batch + jnp.sqrt(1 - 296 | used_alphas) * eps 297 | 298 | # if continuous_noise: 299 | pred = model(perturbed_sample, 300 | jnp.sqrt(used_alphas)) # condition on noise level. 301 | # else: 302 | # pred = model(perturbed_sample, t) # condition on timestep. 
303 | 304 | loss = jnp.square(eps - pred) 305 | loss = jnp.mean(loss, axis=tuple(range(1, len(loss.shape)))) 306 | assert loss.shape == batch.shape[:1] 307 | 308 | return reduce_fn(loss, reduction) 309 | -------------------------------------------------------------------------------- /utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Metrics.""" 17 | import math 18 | import note_seq 19 | import numpy as np 20 | import scipy 21 | from sklearn import metrics 22 | 23 | 24 | def frechet_distance(real, fake): 25 | """Frechet distance. 26 | 27 | Lower score is better. 
28 | """ 29 | mu1, sigma1 = np.mean(real, axis=0), np.cov(real, rowvar=False) 30 | mu2, sigma2 = np.mean(fake, axis=0), np.cov(fake, rowvar=False) 31 | diff = mu1 - mu2 32 | covmean, _ = scipy.linalg.sqrtm(sigma1.dot(sigma2), disp=False) 33 | 34 | if not np.isfinite(covmean).all(): 35 | msg = ('fid calculation produces singular product; ' 36 | 'adding %s to diagonal of cov estimates') % eps 37 | print(msg) 38 | offset = np.eye(sigma1.shape[0]) * eps 39 | covmean = scipy.linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 40 | 41 | # Numerical error might give slight imaginary component 42 | if np.iscomplexobj(covmean): 43 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 44 | m = np.max(np.abs(covmean.imag)) 45 | raise ValueError('Imaginary component {}'.format(m)) 46 | covmean = covmean.real 47 | 48 | assert np.isfinite(covmean).all() and not np.iscomplexobj(covmean) 49 | 50 | tr_covmean = np.trace(covmean) 51 | 52 | frechet_dist = diff.dot(diff) 53 | frechet_dist += np.trace(sigma1) + np.trace(sigma2) 54 | frechet_dist -= 2 * tr_covmean 55 | return frechet_dist 56 | 57 | 58 | def mmd_rbf(real, fake, gamma=1.0): 59 | """(RBF) kernel distance. 60 | 61 | Lower score is better. 62 | """ 63 | XX = metrics.pairwise.rbf_kernel(real, real, gamma) 64 | YY = metrics.pairwise.rbf_kernel(fake, fake, gamma) 65 | XY = metrics.pairwise.rbf_kernel(real, fake, gamma) 66 | return XX.mean() + YY.mean() - 2 * XY.mean() 67 | 68 | 69 | def mmd_polynomial(real, fake, degree=2, gamma=1, coef0=0): 70 | """(Polynomial) kernel distance. 71 | 72 | Lower score is better. 
73 | """ 74 | XX = metrics.pairwise.polynomial_kernel(real, real, degree, gamma, coef0) 75 | YY = metrics.pairwise.polynomial_kernel(fake, fake, degree, gamma, coef0) 76 | XY = metrics.pairwise.polynomial_kernel(real, fake, degree, gamma, coef0) 77 | return XX.mean() + YY.mean() - 2 * XY.mean() 78 | 79 | 80 | def framewise_statistic(ns, stat_fn, hop_size=1, frame_size=1): 81 | """Computes framewise MIDI statistic.""" 82 | total_time = int(math.ceil(ns.total_time)) 83 | frames = [] 84 | trim = frame_size - hop_size 85 | for i in range(0, total_time - trim, hop_size): 86 | one_sec_chunk = note_seq.sequences_lib.trim_note_sequence( 87 | ns, i, i + frame_size) 88 | value = stat_fn(one_sec_chunk.notes) 89 | frames.append(value) 90 | return np.array(frames) 91 | 92 | 93 | def note_density(ns, hop_size=1, frame_size=1): 94 | stat_fn = lambda notes: len(notes) 95 | return framewise_statistic(ns, 96 | stat_fn, 97 | hop_size=hop_size, 98 | frame_size=frame_size) 99 | 100 | 101 | def pitch_range(ns, hop_size=1, frame_size=1): 102 | 103 | def stat_fn(notes): 104 | pitches = [note.pitch for note in notes] 105 | return max(pitches) - min(pitches) if len(pitches) > 0 else 0 106 | 107 | return framewise_statistic(ns, 108 | stat_fn, 109 | hop_size=hop_size, 110 | frame_size=frame_size) 111 | 112 | 113 | def mean_pitch(ns, hop_size=1, frame_size=1): 114 | 115 | def stat_fn(notes): 116 | pitches = np.array([note.pitch for note in notes]) 117 | return pitches.mean() if len(pitches) > 0 else 0 118 | 119 | return framewise_statistic(ns, 120 | stat_fn, 121 | hop_size=hop_size, 122 | frame_size=frame_size) 123 | 124 | 125 | def var_pitch(ns, hop_size=1, frame_size=1): 126 | 127 | def stat_fn(notes): 128 | pitches = np.array([note.pitch for note in notes]) 129 | return pitches.var() if len(pitches) > 0 else 0 130 | 131 | return framewise_statistic(ns, 132 | stat_fn, 133 | hop_size=hop_size, 134 | frame_size=frame_size) 135 | 136 | 137 | def mean_note_duration(ns, hop_size=1, frame_size=1): 
138 | 139 | def stat_fn(notes): 140 | durations = np.array([note.end_time - note.start_time for note in notes]) 141 | return durations.mean() if len(durations) > 0 else 0 142 | 143 | return framewise_statistic(ns, 144 | stat_fn, 145 | hop_size=hop_size, 146 | frame_size=frame_size) 147 | 148 | 149 | def var_note_duration(ns, hop_size=1, frame_size=1): 150 | 151 | def stat_fn(notes): 152 | durations = np.array([note.end_time - note.start_time for note in notes]) 153 | return durations.var() if len(durations) > 0 else 0 154 | 155 | return framewise_statistic(ns, 156 | stat_fn, 157 | hop_size=hop_size, 158 | frame_size=frame_size) 159 | 160 | 161 | def perceptual_midi_histograms(ns, interval=1): 162 | """Generates histograms for each MIDI feature.""" 163 | return dict( 164 | nd=note_density(ns, interval=interval), 165 | pr=pitch_range(ns, interval=interval), 166 | mp=mean_pitch(ns, interval=interval), 167 | vp=var_pitch(ns, interval=interval), 168 | md=mean_note_duration(ns, interval=interval), 169 | vd=var_note_duration(ns, interval=interval), 170 | ) 171 | 172 | 173 | def perceptual_midi_statistics(ns, interval=1, vector=False): 174 | """Feature vector of means and variances of MIDI histograms. 175 | 176 | Args: 177 | ns: NoteSequence object. 178 | interval: Integer time interval (in seconds) for each histogram bin. 179 | vector: If True, returns statistics as a feature vector. 180 | """ 181 | features = {} 182 | histograms = perceptual_midi_histograms(ns, interval=interval) 183 | for key in histograms: 184 | mu = histograms[key].mean() 185 | var = histograms[key].var() 186 | features[key] = (mu, var) 187 | 188 | if vector: 189 | vec = np.array(list(features.values())) 190 | return vec.reshape(-1) 191 | 192 | return features 193 | 194 | 195 | def perceptual_similarity(ns1, ns2, interval=1): 196 | """Perceptual similarity as determined by Overlapping Area Metric. 197 | 198 | Determines pairwise similarity for two NoteSequence objects. 
199 | 200 | Args: 201 | ns1: NoteSequence object. 202 | ns2: NoteSequence object. 203 | interval: Integer time interval (in seconds) for each histogram bin. 204 | """ 205 | stats1 = perceptual_midi_statistics(ns1, interval, vector=False) 206 | stats2 = perceptual_midi_statistics(ns2, interval, vector=False) 207 | similarity = {} 208 | for key in stats1: 209 | mu1, var1 = stats1[key] 210 | mu2, var2 = stats2[key] 211 | similarity[key] = overlapping_area(mu1, mu2, var1, var2) 212 | return similarity 213 | 214 | 215 | def overlapping_area(mu1, mu2, var1, var2): 216 | """Compute overlapping area of two Gaussians. 217 | 218 | Args: 219 | mu1: Mean of first Gaussian pdf. 220 | mu2: Mean of second Gaussian pdf. 221 | var1: Variance of first Gaussian pdf. 222 | var2: Variance of second Gaussian pdf. 223 | Returns: 224 | Overlapping area of the two density functions. 225 | """ 226 | idx = mu2 < mu1 227 | mu_a = mu2 * idx + np.logical_not(idx) * mu1 228 | mu_b = mu1 * idx + np.logical_not(idx) * mu2 229 | var_a = var2 * idx + np.logical_not(idx) * var1 230 | var_b = var1 * idx + np.logical_not(idx) * var2 231 | 232 | c_sqrt_factor = (mu_a - mu_b)**2 + 2 * (var_a - var_b) * np.log( 233 | np.sqrt(var_a + 1e-6) / np.sqrt(var_b + 1e-6)) 234 | c_sqrt_factor = np.sqrt(c_sqrt_factor) 235 | c = mu_b * var_a - np.sqrt(var_b) * (mu_a * np.sqrt(var_b) + 236 | np.sqrt(var_a) * c_sqrt_factor) 237 | c = c / (var_a - var_b + 1e-6) 238 | 239 | sqrt_2 = np.sqrt(2) 240 | oa = 1 - 0.5 * scipy.special.erf( 241 | (c - mu_a) / (sqrt_2 * np.sqrt(var_a + 1e-6))) 242 | oa = oa + 0.5 * scipy.special.erf( 243 | (c - mu_b) / (sqrt_2 * np.sqrt(var_b + 1e-6))) 244 | return oa 245 | -------------------------------------------------------------------------------- /utils/plot_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 
2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | # Lint as: python3 16 | """Utilities for TensorBoard plotting.""" 17 | import io 18 | import os 19 | 20 | import matplotlib.pyplot as plt 21 | import numpy as np 22 | from matplotlib.animation import FuncAnimation 23 | 24 | import utils.ebm_utils as ebm_utils 25 | 26 | 27 | def scatter_2d(data, scale=8, show=False): 28 | """Create a plot to compare generated 2D samples. 29 | 30 | Args: 31 | data: An array of 2D points to draw. 32 | scale: A number that specifies the grid bounds. 33 | """ 34 | assert data.shape[-1] == 2 35 | x = data[:, 0] 36 | y = data[:, 1] 37 | 38 | buf = io.BytesIO() 39 | plt.figure() 40 | plt.scatter(x, y, s=0.1) 41 | plt.axis('square') 42 | plt.title('Samples') 43 | plt.xlim([-scale, scale]) 44 | plt.ylim([-scale, scale]) 45 | if show: 46 | plt.show() 47 | plt.savefig(buf, format='png') 48 | plt.close() 49 | buf.seek(0) 50 | return buf 51 | 52 | 53 | def animate_scatter_2d(data, scale=8, show=False, fps=60): 54 | """Create an animation to compare generated 2D samples. 55 | 56 | Args: 57 | data: An array of 2D points to draw of shape [timesteps, N, 2]. 58 | scale: A number that specifies the grid bounds. 
59 | """ 60 | assert data.shape[-1] == 2 and data.ndim == 3 61 | 62 | buf = io.BytesIO() 63 | plt.figure() 64 | fig, ax = plt.subplots() 65 | sc = ax.scatter(data[0, :, 0], data[0, :, 1], s=0.1) 66 | title = ax.text(0.5, 67 | 0, 68 | "", 69 | bbox={ 70 | 'facecolor': 'w', 71 | 'alpha': 0.5, 72 | 'pad': 5 73 | }, 74 | transform=ax.transAxes, 75 | ha="center") 76 | plt.axis('square') 77 | plt.title('Samples') 78 | plt.xlim([-scale, scale]) 79 | plt.ylim([-scale, scale]) 80 | 81 | def animate(i): 82 | title.set_text(f'frame: {i}') 83 | sc.set_offsets(data[i]) 84 | 85 | anim = FuncAnimation(fig, animate, frames=len(data), interval=1000 / fps) 86 | 87 | anim.save('tmp.gif', writer='imagemagick') 88 | with open('tmp.gif', 'rb') as f: 89 | buf.write(f.read()) 90 | f.close() 91 | os.remove('tmp.gif') 92 | 93 | if show: 94 | plt.draw() 95 | plt.show() 96 | 97 | plt.close() 98 | buf.seek(0) 99 | return buf 100 | 101 | 102 | def energy_contour_2d(model, scale=5, show=False): 103 | """Create contour plot of the energy landscape. 104 | 105 | Args: 106 | model: An energy-based model (R^2 -> R). 107 | scale: A number that specifies the grid bounds. 108 | """ 109 | x = np.arange(-scale, scale, 0.05) 110 | y = np.arange(-scale, scale, 0.05) 111 | xx, yy = np.meshgrid(x, y, sparse=False) 112 | coords = np.stack((xx, yy), axis=2) 113 | coords_flat = coords.reshape(-1, 2) 114 | z_flat = model(coords_flat) 115 | 116 | # highlight regions where energy is low (high density) 117 | z = -1 * z_flat.reshape(*coords.shape[:-1]) 118 | 119 | buf = io.BytesIO() 120 | plt.figure() 121 | plt.contourf(x, y, z) 122 | if show: 123 | plt.show() 124 | plt.savefig(buf, format='png') 125 | plt.close() 126 | buf.seek(0) 127 | return buf 128 | 129 | 130 | def score_field_2d(model, sigma=None, scale=8, show=False): 131 | """Create plot of the gradient field of the energy landscape. 132 | 133 | Args: 134 | model: An energy-based model (R^2 -> R) or a score network (R^2 -> R^2). 
135 | sigma: A noise value for an NCSN. Only required if model is a score network. 136 | scale: A number that specifies the grid bounds. 137 | """ 138 | mesh = [] 139 | x = np.linspace(-scale, scale, 20) 140 | y = np.linspace(-scale, scale, 20) 141 | for i in x: 142 | for j in y: 143 | mesh.append(np.asarray([i, j])) 144 | mesh = np.stack(mesh, axis=0) 145 | 146 | if sigma is not None: 147 | sigma = sigma * np.ones((mesh.shape[0], 1)) 148 | scores = model(mesh, sigma) 149 | else: 150 | scores = ebm_utils.vgrad(model, mesh) 151 | assert scores.shape == mesh.shape 152 | 153 | buf = io.BytesIO() 154 | plt.figure() 155 | plt.quiver(mesh[:, 0], mesh[:, 1], scores[:, 0], scores[:, 1], width=0.005) 156 | plt.title('Estimated scores', fontsize=16) 157 | plt.axis('square') 158 | if show: 159 | plt.show() 160 | plt.savefig(buf, format='png') 161 | plt.close() 162 | buf.seek(0) 163 | return buf 164 | 165 | 166 | def image_tiles(images, shape=(28, 28), show=False): 167 | n = len(images) 168 | for i in range(n): 169 | ax = plt.subplot(1, n, i + 1) 170 | plt.imshow(images[i].reshape(*shape)) 171 | plt.gray() 172 | ax.get_xaxis().set_visible(False) 173 | ax.get_yaxis().set_visible(False) 174 | 175 | if show: 176 | plt.show() 177 | buf = io.BytesIO() 178 | plt.savefig(buf, format='png') 179 | plt.close() 180 | buf.seek(0) 181 | return buf 182 | -------------------------------------------------------------------------------- /utils/song_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 The Magenta Authors. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Utilities for manipulating multi-measure NoteSequences."""
import os
import sys

import note_seq
import numpy as np

# Make the repository root importable so `config` resolves when this module is
# run from within utils/.
sys.path.append("{}/../".format(os.path.dirname(os.path.abspath(__file__))))
from config import melody_2bar_converter


def spherical_interpolation(p0, p1, alpha):
  """Spherical linear interpolation (slerp) between batches of vectors.

  Args:
    p0: A (batch, dims) array of interpolation start points.
    p1: A (batch, dims) array of interpolation end points.
    alpha: Scalar interpolation coefficient, typically in [0, 1].

  Returns:
    A (batch, dims) array interpolated along the great circle between the
    (normalized directions of the) corresponding rows of p0 and p1.
  """
  assert p0.shape == p1.shape
  assert p0.ndim == 2 and p1.ndim == 2
  unit_p0 = p0 / np.linalg.norm(p0, axis=1, keepdims=True)
  unit_p1 = p1 / np.linalg.norm(p1, axis=1, keepdims=True)
  # Row-wise dot products: equivalent to np.diag(unit_p0 @ unit_p1.T) but
  # avoids materializing the full (batch, batch) product.
  omega = np.arccos(np.sum(unit_p0 * unit_p1, axis=1))
  so = np.sin(omega)
  c1 = (np.sin((1.0 - alpha) * omega) / so)[:, np.newaxis]
  c2 = (np.sin(alpha * omega) / so)[:, np.newaxis]
  return c1 * p0 + c2 * p1


def count_measures(note_sequence):
  """Approximate number of measures in the song.

  Args:
    note_sequence: A NoteSequence object.

  Returns:
    A float count of bars, accumulated across every tempo/time-signature
    split of the sequence.
  """
  splits = note_seq.sequences_lib.split_note_sequence_on_time_changes(
      note_sequence)
  bars = 0
  for split in splits:
    # Each split has a single tempo and time signature by construction.
    time_signature = split.time_signatures[0]
    tempo = split.tempos[0]
    quarters_per_bar = 4 * time_signature.numerator / time_signature.denominator
    seconds_per_bar = 60 * quarters_per_bar / tempo.qpm
    num_bars = split.total_time / seconds_per_bar
    bars += num_bars
  return bars


def extract_melodies(note_sequence, keep_longest_split=False):
  """Extracts all melodies in a polyphonic note sequence.

  Args:
    note_sequence: A polyphonic NoteSequence object.
    keep_longest_split: Whether to discard all subsequences with tempo changes
      other than the longest one.

  Returns:
    List of monophonic NoteSequence objects.
  """
  splits = note_seq.sequences_lib.split_note_sequence_on_time_changes(
      note_sequence)

  if keep_longest_split:
    # "Longest" is measured by note count, not wall-clock duration.
    ns = max(splits, key=lambda x: len(x.notes))
    splits = [ns]

  melodies = []
  for split_ns in splits:
    qs = note_seq.sequences_lib.quantize_note_sequence(split_ns,
                                                       steps_per_quarter=4)

    instruments = list(set([note.instrument for note in qs.notes]))

    for instrument in instruments:
      melody = note_seq.melodies_lib.Melody()
      try:
        melody.from_quantized_sequence(qs,
                                       ignore_polyphonic_notes=True,
                                       instrument=instrument,
                                       gap_bars=np.inf)
      except note_seq.NonIntegerStepsPerBarError:
        # Skip instruments whose meter cannot be quantized cleanly.
        continue
      melody_ns = melody.to_sequence()
      melodies.append(melody_ns)

  return melodies


def generate_shifted_sequences(song, resolution=1):
  """Generates shifted and overlapping versions of a Song.

  Args:
    song: A multitrack Song object.
    resolution: The number of shifted examples, with computed timing offsets
      uniformly spaced.

  Returns:
    A list of multitrack Song objects.
  """
  # Offsets are spread uniformly over a 2-second window (one 2-bar chunk at
  # 120 QPM) -- TODO confirm against the dataset's assumed tempo.
  offset = 2.0 / resolution
  base = song.note_sequence
  dc = song.data_converter
  results = []
  for step in range(resolution):
    shift = note_seq.extract_subsequence(base, offset * step, base.total_time)
    results.append(Song(shift, dc, chunk_length=1))
  assert len(results) == resolution
  return results


def fix_instruments_for_concatenation(note_sequences):
  """Adjusts instruments for concatenating multitrack measures.

  Remaps each program to a stable instrument slot across all sequences so
  concatenated chunks keep consistent instrument assignments. Drums are
  always placed on instrument 9 (the MIDI percussion channel); slot 9 is
  skipped for pitched programs once 8 slots are in use.
  """
  instruments = {}
  for i in range(len(note_sequences)):
    for note in note_sequences[i].notes:
      if not note.is_drum:
        if note.program not in instruments:
          if len(instruments) >= 8:
            # Skip over instrument 9, which is reserved for drums.
            instruments[note.program] = len(instruments) + 2
          else:
            instruments[note.program] = len(instruments) + 1
        note.instrument = instruments[note.program]
      else:
        note.instrument = 9


def fix_chunk_lengths_for_concatenation(note_sequences):
  """Adjusts the total_time of each tokenized chunk for concatenating
  multitrack measures.

  All chunks are padded (in metadata only) to the duration of the longest
  chunk so concatenation yields evenly spaced measures.
  """
  max_chunk_time = max([ns.total_time for ns in note_sequences])
  for chunk in note_sequences:
    chunk.total_time = max_chunk_time


def chunks_to_embeddings(sequences, model, data_converter):
  """Convert NoteSequence objects into latent space embeddings.

  Rest chunks (those the converter produces no tensors for) are left as
  all-zero rows in each returned matrix.

  Args:
    sequences: A list of NoteSequence objects.
    model: A TrainedModel object used for inference.
    data_converter: A data converter (e.g. OneHotMelodyConverter,
        TrioConverter) used to convert NoteSequence objects into
        tensor encodings for model inference.

  Returns:
    A (zs, mus, sigmas) tuple of numpy matrices, each of shape
    [len(sequences), latent_dims].
  """
  assert model is not None, 'No model provided.'

  # NOTE(review): reaches into the model's private _z_input to discover the
  # latent dimensionality.
  latent_dims = model._z_input.shape[1]
  idx = []
  non_rest_chunks = []
  zs = np.zeros((len(sequences), latent_dims))
  mus = np.zeros((len(sequences), latent_dims))
  sigmas = np.zeros((len(sequences), latent_dims))
  for i, chunk in enumerate(sequences):
    # Chunks that tokenize to nothing are rests; keep their rows zero.
    if len(data_converter.to_tensors(chunk).inputs) > 0:
      idx.append(i)
      non_rest_chunks.append(chunk)
  if non_rest_chunks:
    z, mu, sigma = model.encode(non_rest_chunks)
    assert z.shape == mu.shape == sigma.shape
    # Scatter the encodings back to their original positions.
    for i, mean in enumerate(mu):
      zs[idx[i]] = z[i]
      mus[idx[i]] = mean
      sigmas[idx[i]] = sigma[i]
  return zs, mus, sigmas


def embeddings_to_chunks(embeddings, model, temperature=1e-3):
  """Decode latent embeddings as NoteSequences.

  Args:
    embeddings: A numpy array of latent embeddings.
    model: A TrainedModel object used for decoding embeddings.
    temperature: Sampling temperature passed to the decoder; the low default
        makes decoding near-deterministic.

  Returns:
    A list of NoteSequence objects.
  """
  assert model is not None, 'No model provided.'
  assert len(embeddings) > 0

  reconstructed_chunks = model.decode(embeddings,
                                      temperature=temperature,
                                      length=model._config.hparams.max_seq_len)
  assert len(reconstructed_chunks) == len(embeddings)

  embedding_norms = np.linalg.norm(embeddings, axis=1)
  rest_chunk_idx = np.where(
      embedding_norms == 0)[0]  # rests correspond to zero-length embeddings

  # Replace decodings of zero embeddings with empty (rest) sequences of the
  # same duration.
  for idx in rest_chunk_idx:
    rest_ns = note_seq.NoteSequence()
    rest_ns.total_time = reconstructed_chunks[idx].total_time
    reconstructed_chunks[idx] = rest_ns
  return reconstructed_chunks


def embeddings_to_song(embeddings,
                       model,
                       data_converter,
                       fix_instruments=True,
                       temperature=1e-3):
  """Decode latent embeddings as a concatenated NoteSequence.

  Args:
    embeddings: A numpy array of latent embeddings.
    model: A TrainedModel object used for decoding.
    data_converter: A data converter used by the returned Song
        object.
    fix_instruments: A boolean determining whether instruments in
        multitrack measures should be fixed before concatenation.
    temperature: Sampling temperature forwarded to embeddings_to_chunks.

  Returns:
    A Song object.
  """
  chunks = embeddings_to_chunks(embeddings, model, temperature)
  if fix_instruments:
    fix_instruments_for_concatenation(chunks)
  concat_chunks = note_seq.sequences_lib.concatenate_sequences(chunks)
  return Song(concat_chunks, data_converter, reconstructed=True)


def encode_songs(model, songs, chunk_length=None, programs=None):
  """Generate embeddings for a batch of songs.

  Args:
    model: A TrainedModel object used for inference.
    songs: A list of Song objects.
    chunk_length: An integer describing the number of measures
        each chunk of each song should contain.
    programs: A list of integers specifying which MIDI programs to use.
        Default is to keep all available programs.

  Returns:
    A list of numpy matrices each with shape [3, len(song_chunks), latent_dims].
  """
  assert model is not None, 'No model provided.'
  assert len(songs) > 0, 'No songs provided.'

  # Flatten all songs into one chunk list; `splits` records where each song's
  # chunks begin so the batched encoding can be split back apart.
  chunks, splits = [], []
  data_converter = songs[0].data_converter
  i = 0
  for song in songs:
    chunk_tensors, chunk_sequences = song.chunks(chunk_length=chunk_length,
                                                 programs=programs)
    del chunk_tensors  # unused
    chunks.extend(chunk_sequences)
    splits.append(i)
    i += len(chunk_sequences)

  z, mu, sigma = chunks_to_embeddings(chunks, model, data_converter)

  encoding = []
  for i in range(len(splits)):
    # Slice out this song's chunks; the last song runs to the end.
    j, k = splits[i], None if i + 1 == len(splits) else splits[i + 1]
    song_encoding = [z[j:k], mu[j:k], sigma[j:k]]
    song_encoding = np.stack(song_encoding)
    encoding.append(song_encoding)

  assert len(encoding) == len(splits) == len(songs)
  return encoding


class Song(object):
  """Song object used to provide additional abstractions for NoteSequences.

  Attributes:
    note_sequence: A NoteSequence object holding the Song's MIDI data.
    data_converter: A data converter used for preprocessing and tokenization
        for a corresponding MusicVAE model.
    chunk_length: The number of measures in each tokenized chunk of MIDI
        (dependent on the model configuration).
    multitrack: Whether this Song is multitrack or not.
    reconstructed: A boolean describing whether this Song is reconstructed
        from the decoder of a MusicVAE model.
  """

  def __init__(self,
               note_sequence,
               data_converter,
               chunk_length=2,
               multitrack=False,
               reconstructed=False):
    self.note_sequence = note_sequence
    self.data_converter = data_converter
    self.chunk_length = chunk_length
    self.reconstructed = reconstructed
    self.multitrack = multitrack

  def encode(self, model, chunk_length=None, programs=None):
    """Encode song chunks (and full-chunk rests).

    Args:
      model: A TrainedModel object used for inference.
      chunk_length: Optional override for the number of measures per chunk.
      programs: Optional list of MIDI programs to keep before encoding.

    Returns:
      z: A (chunks, latent_dims) numpy matrix of sampled latent embeddings.
      The per-chunk means and variances are computed but intentionally not
      returned; use chunks_to_embeddings directly if you need them.
    """
    chunk_tensors, chunk_sequences = self.chunks(chunk_length=chunk_length,
                                                 programs=programs)
    z, means, sigmas = chunks_to_embeddings(chunk_sequences, model,
                                            self.data_converter)
    del chunk_tensors  # unused
    return z

  def chunks(self, chunk_length=None, programs=None, fix_instruments=True):
    """Split and featurize song into chunks of tensors and NoteSequences."""
    assert not self.reconstructed, 'Not safe to tokenize reconstructed Songs.'

    data = self.note_sequence
    step_size = self.chunk_length
    if chunk_length is not None:
      step_size = chunk_length
    if programs is not None:
      data = self.select_programs(programs)

    # Use the data converter to preprocess sequences. The stride assumes the
    # converter emits one overlapping window per measure -- TODO confirm.
    tensors = self.data_converter.to_tensors(data).inputs[::step_size]
    sequences = self.data_converter.from_tensors(tensors)

    if fix_instruments and self.multitrack:
      fix_instruments_for_concatenation(sequences)

    return tensors, sequences

  def count_chunks(self, chunk_length=None):
    """Approximate chunk count (floored measures-per-chunk division)."""
    length = self.chunk_length if chunk_length is None else chunk_length
    return count_measures(self.note_sequence) // length

  @property
  def programs(self):
    """MIDI programs used in this song."""
    return list(set([note.program for note in self.note_sequence.notes]))

  def select_programs(self, programs):
    """Keeps selected programs of MIDI (e.g. melody program)."""
    assert len(programs) > 0
    assert all([program >= 0 for program in programs])

    # Copy everything but the notes, then re-add only matching notes.
    ns = note_seq.NoteSequence()
    ns.CopyFrom(self.note_sequence)
    del ns.notes[:]

    for note in self.note_sequence.notes[:]:
      if note.program in programs:
        new_note = ns.notes.add()
        new_note.CopyFrom(note)
    return ns

  def truncate(self, chunks=0, offset=0):
    """Returns a truncated version of the song.

    Args:
      chunks: The number of chunks in the truncated sequence.
      offset: The offset in chunks to begin truncation.

    Returns:
      A truncated Song object.
    """
    tensors = self.data_converter.to_tensors(
        self.note_sequence).inputs[::self.chunk_length]
    sequences = self.data_converter.from_tensors(tensors)[offset:offset +
                                                          chunks]
    fix_instruments_for_concatenation(sequences)
    concat_chunks = note_seq.sequences_lib.concatenate_sequences(sequences)
    return Song(concat_chunks,
                self.data_converter,
                chunk_length=self.chunk_length)

  def _count_melody_chunks(self, program):
    """Determines the number of 2-measure chunks using the melody data pipeline."""
    ns = self.select_programs([program])
    tensors = melody_2bar_converter.to_tensors(ns).inputs[::2]
    sequences = melody_2bar_converter.from_tensors(tensors)
    return len(sequences)

  def find_programs(self):
    """Search for the most important MIDI programs in the song."""

    def heuristic(program):
      # A program is "important" if it yields roughly as many melody chunks
      # as the song has 2-bar chunks overall (within 50%).
      expected = self.count_chunks(chunk_length=2)
      extracted = self._count_melody_chunks(program)
      if extracted > 0 and abs(extracted - expected) < 0.5 * expected:
        return True
      return False

    midi_programs = self.programs
    top_programs = [p for p in midi_programs if heuristic(p)]
    return top_programs

  def stripped_song(self):
    """A stripped down version using programs found by a special heuristic."""
    top_programs = self.find_programs()
    ns = self.select_programs(top_programs)
    return Song(ns, self.data_converter, self.chunk_length)

  def download(self, filename, preprocessed=True, programs=None):
    """Download song as MIDI file."""
    assert filename is not None, 'No filename specified.'

    data = self.note_sequence
    if programs is not None:
      data = self.select_programs(programs)

    if not self.reconstructed and preprocessed:  # do not tokenize again if reconstructed
      tensors, chunks = self.chunks(programs=programs)
      del tensors  # unused
      data = note_seq.sequences_lib.concatenate_sequences(chunks)

    note_seq.sequence_proto_to_midi_file(data, filename)

  def play(self, preprocessed=True, programs=None):
    """Play a song with fluidsynth."""
    data = self.note_sequence
    if programs is not None:
      data = self.select_programs(programs)

    if not self.reconstructed and preprocessed:  # do not tokenize again if reconstructed
      tensors, chunks = self.chunks(programs=programs)
      del tensors  # unused
      data = note_seq.sequences_lib.concatenate_sequences(chunks)

    note_seq.play_sequence(data, synth=note_seq.fluidsynth)
    return data

# ------------------------------------------------------------------------------
# /utils/train_utils.py:
# ------------------------------------------------------------------------------
# Copyright 2021 The Magenta Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Training utilities."""
import jax
import math
import numpy as np

from absl import logging
from flax import struct
from functools import partial
from typing import Any


@struct.dataclass
class EarlyStopping:
  """Early stopping to avoid overfitting during training.

  Attributes:
    min_delta: Minimum delta between updates to be considered an
      improvement.
    patience: Number of steps of no improvement before stopping.
    best_metric: Current best metric value.
    patience_count: Number of steps since last improving update.
    should_stop: Whether the training loop should stop to avoid
      overfitting.
  """
  min_delta: float = 0
  patience: int = 0
  best_metric: float = float('inf')
  patience_count: int = 0
  should_stop: bool = False

  def update(self, metric):
    """Update the state based on metric.

    Args:
      metric: The current value of the tracked metric (lower is better).

    Returns:
      Whether there was an improvement greater than min_delta from
      the previous best_metric and the updated EarlyStop object.
    """

    # First update (best_metric still inf) always counts as an improvement.
    if math.isinf(
        self.best_metric) or self.best_metric - metric > self.min_delta:
      return True, self.replace(best_metric=metric, patience_count=0)
    else:
      # should_stop latches: once True it stays True on later updates.
      should_stop = self.patience_count >= self.patience or self.should_stop
      return False, self.replace(patience_count=self.patience_count + 1,
                                 should_stop=should_stop)


@struct.dataclass
class EMAHelper:
  """Exponential moving average of model parameters.

  Attributes:
    mu: Momentum parameter.
    params: Flax network parameters to update.
  """
  mu: float
  params: Any

  @jax.jit
  def update(self, model):
    """Returns a copy of self whose params are moved toward model.params.

    Each EMA leaf becomes mu * ema + (1 - mu) * current.
    NOTE(review): jax.tree_multimap is deprecated in newer JAX releases in
    favor of jax.tree_util.tree_map; kept for compatibility with the JAX
    version this project targets -- confirm before upgrading JAX.
    """
    ema_params = jax.tree_multimap(
        lambda p_ema, p: p_ema * self.mu + p * (1 - self.mu), self.params,
        model.params)
    # mu is immutable here; only the averaged parameters change.
    return self.replace(params=ema_params)


def log_metrics(metrics,
                step,
                total_steps,
                epoch=None,
                summary_writer=None,
                verbose=True):
  """Log metrics.

  Args:
    metrics: A dictionary of scalar metrics.
    step: The current step.
    total_steps: The total number of steps.
    epoch: The current epoch.
    summary_writer: A TensorBoard summary writer.
    verbose: Whether to flush values to stdout.
  """
  metrics_str = ''
  for metric, value in metrics.items():
    # Learning rates need more precision than loss-style metrics.
    if metric == 'lr':
      metrics_str += '{} {:5.4f} | '.format(metric, value)
    else:
      metrics_str += '{} {:5.2f} | '.format(metric, value)

    if summary_writer is not None:
      # When an epoch is given, write at a globally increasing step index.
      writer_step = step
      if epoch is not None:
        writer_step = total_steps * epoch + step
      summary_writer.scalar(metric, value, writer_step)

  if epoch is not None:
    epoch_str = '| epoch {:3d} '.format(epoch)
  else:
    epoch_str = ''

  if verbose:
    logging.info('{}| {:5d}/{:5d} steps | {}'.format(epoch_str, step,
                                                     total_steps, metrics_str))


def report_model(model):
  """Log number of trainable parameters and their memory footprint."""
  trainable_params = np.sum(
      [param.size for param in jax.tree_leaves(model.params)])
  footprint_bytes = np.sum([
      param.size * param.dtype.itemsize
      for param in jax.tree_leaves(model.params)
  ])

  logging.info('Number of trainable parameters: {:,}'.format(trainable_params))
  logging.info('Memory footprint: %dMB', footprint_bytes / 2**20)