├── .clang_format ├── .gitignore ├── .pylintrc ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── augment ├── __init__.py └── effects.py ├── examples └── python │ ├── README.md │ ├── WavAugment_walkthrough.ipynb │ ├── librispeech_selfsupervised.py │ └── process_file.py ├── requirements.txt ├── setup.py └── tests ├── augment_test.py ├── compare_to_sox_test.py └── test.wav /.clang_format: -------------------------------------------------------------------------------- 1 | --- 2 | Language: Cpp 3 | # BasedOnStyle: LLVM 4 | AccessModifierOffset: -2 5 | AlignAfterOpenBracket: Align 6 | AlignConsecutiveAssignments: false 7 | AlignConsecutiveDeclarations: false 8 | AlignEscapedNewlines: Right 9 | AlignOperands: true 10 | AlignTrailingComments: true 11 | AllowAllParametersOfDeclarationOnNextLine: true 12 | AllowShortBlocksOnASingleLine: false 13 | AllowShortCaseLabelsOnASingleLine: false 14 | AllowShortFunctionsOnASingleLine: All 15 | AllowShortIfStatementsOnASingleLine: false 16 | AllowShortLoopsOnASingleLine: false 17 | AlwaysBreakAfterDefinitionReturnType: None 18 | AlwaysBreakAfterReturnType: None 19 | AlwaysBreakBeforeMultilineStrings: false 20 | AlwaysBreakTemplateDeclarations: false 21 | BinPackArguments: true 22 | BinPackParameters: true 23 | BraceWrapping: 24 | AfterClass: false 25 | AfterControlStatement: false 26 | AfterEnum: false 27 | AfterFunction: false 28 | AfterNamespace: false 29 | AfterObjCDeclaration: false 30 | AfterStruct: false 31 | AfterUnion: false 32 | BeforeCatch: false 33 | BeforeElse: false 34 | IndentBraces: false 35 | SplitEmptyFunction: true 36 | SplitEmptyRecord: true 37 | SplitEmptyNamespace: true 38 | BreakBeforeBinaryOperators: None 39 | BreakBeforeBraces: Attach 40 | BreakBeforeInheritanceComma: false 41 | BreakBeforeTernaryOperators: true 42 | BreakConstructorInitializersBeforeComma: false 43 | BreakConstructorInitializers: BeforeColon 44 | BreakAfterJavaFieldAnnotations: false 45 | BreakStringLiterals: true 46 | 
ColumnLimit: 80 47 | CommentPragmas: '^ IWYU pragma:' 48 | CompactNamespaces: false 49 | ConstructorInitializerAllOnOneLineOrOnePerLine: false 50 | ConstructorInitializerIndentWidth: 4 51 | ContinuationIndentWidth: 4 52 | Cpp11BracedListStyle: true 53 | DerivePointerAlignment: false 54 | DisableFormat: false 55 | ExperimentalAutoDetectBinPacking: false 56 | FixNamespaceComments: true 57 | ForEachMacros: 58 | - foreach 59 | - Q_FOREACH 60 | - BOOST_FOREACH 61 | IncludeCategories: 62 | - Regex: '^"(llvm|llvm-c|clang|clang-c)/' 63 | Priority: 2 64 | - Regex: '^(<|"(gtest|gmock|isl|json)/)' 65 | Priority: 3 66 | - Regex: '.*' 67 | Priority: 1 68 | IncludeIsMainRegex: '(Test)?$' 69 | IndentCaseLabels: false 70 | IndentWidth: 2 71 | IndentWrappedFunctionNames: false 72 | JavaScriptQuotes: Leave 73 | JavaScriptWrapImports: true 74 | KeepEmptyLinesAtTheStartOfBlocks: true 75 | MacroBlockBegin: '' 76 | MacroBlockEnd: '' 77 | MaxEmptyLinesToKeep: 1 78 | NamespaceIndentation: None 79 | ObjCBlockIndentWidth: 2 80 | ObjCSpaceAfterProperty: false 81 | ObjCSpaceBeforeProtocolList: true 82 | PenaltyBreakAssignment: 2 83 | PenaltyBreakBeforeFirstCallParameter: 19 84 | PenaltyBreakComment: 300 85 | PenaltyBreakFirstLessLess: 120 86 | PenaltyBreakString: 1000 87 | PenaltyExcessCharacter: 1000000 88 | PenaltyReturnTypeOnItsOwnLine: 60 89 | PointerAlignment: Right 90 | ReflowComments: true 91 | SortIncludes: true 92 | SortUsingDeclarations: true 93 | SpaceAfterCStyleCast: false 94 | SpaceAfterTemplateKeyword: true 95 | SpaceBeforeAssignmentOperators: true 96 | SpaceBeforeParens: ControlStatements 97 | SpaceInEmptyParentheses: false 98 | SpacesBeforeTrailingComments: 1 99 | SpacesInAngles: false 100 | SpacesInContainerLiterals: true 101 | SpacesInCStyleCastParentheses: false 102 | SpacesInParentheses: false 103 | SpacesInSquareBrackets: false 104 | Standard: Cpp11 105 | TabWidth: 8 106 | UseTab: Never 107 | ... 
108 | 109 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | build/ 2 | __pycache__/ 3 | *.so 4 | *.code-workspace 5 | *.egg-info/ 6 | .vscode/ 7 | -------------------------------------------------------------------------------- /.pylintrc: -------------------------------------------------------------------------------- 1 | [MASTER] 2 | 3 | disable= 4 | C0114, # missing-module-docstring 5 | C0115, # missing-class-docstring 6 | C0116, # missing-function-docstring 7 | R0903 # Too few public methods 8 | 9 | # A comma-separated list of package or module names from where C extensions may 10 | # be loaded. Extensions are loading into the active Python interpreter and may 11 | # run arbitrary code. 12 | extension-pkg-whitelist=torch,augment_cpp 13 | 14 | # Specify a score threshold to be exceeded before program exits with error. 15 | fail-under=10 16 | 17 | # Add files or directories to the blacklist. They should be base names, not 18 | # paths. 19 | ignore=CVS 20 | 21 | # Add files or directories matching the regex patterns to the blacklist. The 22 | # regex matches against base names, not paths. 23 | ignore-patterns= 24 | 25 | # Python code to execute, usually for sys.path manipulation such as 26 | # pygtk.require(). 27 | #init-hook= 28 | 29 | # Use multiple processes to speed up Pylint. Specifying 0 will auto-detect the 30 | # number of processors available to use. 31 | jobs=1 32 | 33 | # Control the amount of potential inferred values when inferring a single 34 | # object. This can help the performance when dealing with large functions or 35 | # complex, nested conditions. 36 | limit-inference-results=100 37 | 38 | # List of plugins (as comma separated values of python module names) to load, 39 | # usually to register additional checkers. 40 | load-plugins= 41 | 42 | # Pickle collected data for later comparisons. 
43 | persistent=yes 44 | 45 | # When enabled, pylint would attempt to guess common misconfiguration and emit 46 | # user-friendly hints instead of false-positive error messages. 47 | suggestion-mode=yes 48 | 49 | # Allow loading of arbitrary C extensions. Extensions are imported into the 50 | # active Python interpreter and may run arbitrary code. 51 | unsafe-load-any-extension=no 52 | 53 | 54 | [MESSAGES CONTROL] 55 | 56 | # Only show warnings with the listed confidence levels. Leave empty to show 57 | # all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED. 58 | confidence= 59 | 60 | # Disable the message, report, category or checker with the given id(s). You 61 | # can either give multiple identifiers separated by comma (,) or put this 62 | # option multiple times (only on the command line, not in the configuration 63 | # file where it should appear only once). You can also use "--disable=all" to 64 | # disable everything first and then reenable specific checks. For example, if 65 | # you want to run only the similarities checker, you can use "--disable=all 66 | # --enable=similarities". If you want to run only the classes checker, but have 67 | # no Warning level messages displayed, use "--disable=all --enable=classes 68 | # --disable=W". 
69 | disable=print-statement, 70 | parameter-unpacking, 71 | unpacking-in-except, 72 | old-raise-syntax, 73 | backtick, 74 | long-suffix, 75 | old-ne-operator, 76 | old-octal-literal, 77 | import-star-module-level, 78 | non-ascii-bytes-literal, 79 | raw-checker-failed, 80 | bad-inline-option, 81 | locally-disabled, 82 | file-ignored, 83 | suppressed-message, 84 | useless-suppression, 85 | deprecated-pragma, 86 | use-symbolic-message-instead, 87 | apply-builtin, 88 | basestring-builtin, 89 | buffer-builtin, 90 | cmp-builtin, 91 | coerce-builtin, 92 | execfile-builtin, 93 | file-builtin, 94 | long-builtin, 95 | raw_input-builtin, 96 | reduce-builtin, 97 | standarderror-builtin, 98 | unicode-builtin, 99 | xrange-builtin, 100 | coerce-method, 101 | delslice-method, 102 | getslice-method, 103 | setslice-method, 104 | no-absolute-import, 105 | old-division, 106 | dict-iter-method, 107 | dict-view-method, 108 | next-method-called, 109 | metaclass-assignment, 110 | indexing-exception, 111 | raising-string, 112 | reload-builtin, 113 | oct-method, 114 | hex-method, 115 | nonzero-method, 116 | cmp-method, 117 | input-builtin, 118 | round-builtin, 119 | intern-builtin, 120 | unichr-builtin, 121 | map-builtin-not-iterating, 122 | zip-builtin-not-iterating, 123 | range-builtin-not-iterating, 124 | filter-builtin-not-iterating, 125 | using-cmp-argument, 126 | eq-without-hash, 127 | div-method, 128 | idiv-method, 129 | rdiv-method, 130 | exception-message-attribute, 131 | invalid-str-codec, 132 | sys-max-int, 133 | bad-python3-import, 134 | deprecated-string-function, 135 | deprecated-str-translate-call, 136 | deprecated-itertools-function, 137 | deprecated-types-field, 138 | next-method-defined, 139 | dict-items-not-iterating, 140 | dict-keys-not-iterating, 141 | dict-values-not-iterating, 142 | deprecated-operator-function, 143 | deprecated-urllib-function, 144 | xreadlines-attribute, 145 | deprecated-sys-function, 146 | exception-escape, 147 | comprehension-escape 148 | 149 | # 
Enable the message, report, category or checker with the given id(s). You can 150 | # either give multiple identifier separated by comma (,) or put this option 151 | # multiple time (only on the command line, not in the configuration file where 152 | # it should appear only once). See also the "--disable" option for examples. 153 | enable=c-extension-no-member 154 | 155 | 156 | [REPORTS] 157 | 158 | # Python expression which should return a score less than or equal to 10. You 159 | # have access to the variables 'error', 'warning', 'refactor', and 'convention' 160 | # which contain the number of messages in each category, as well as 'statement' 161 | # which is the total number of statements analyzed. This score is used by the 162 | # global evaluation report (RP0004). 163 | evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) 164 | 165 | # Template used to display messages. This is a python new-style format string 166 | # used to format the message information. See doc for all details. 167 | #msg-template= 168 | 169 | # Set the output format. Available formats are text, parseable, colorized, json 170 | # and msvs (visual studio). You can also give a reporter class, e.g. 171 | # mypackage.mymodule.MyReporterClass. 172 | output-format=text 173 | 174 | # Tells whether to display a full report or only the messages. 175 | reports=no 176 | 177 | # Activate the evaluation score. 178 | score=yes 179 | 180 | 181 | [REFACTORING] 182 | 183 | # Maximum number of nested blocks for function / method body 184 | max-nested-blocks=5 185 | 186 | # Complete name of functions that never returns. When checking for 187 | # inconsistent-return-statements if a never returning function is called then 188 | # it will be considered as an explicit return statement and no message will be 189 | # printed. 
190 | never-returning-functions=sys.exit 191 | 192 | 193 | [STRING] 194 | 195 | # This flag controls whether inconsistent-quotes generates a warning when the 196 | # character used as a quote delimiter is used inconsistently within a module. 197 | check-quote-consistency=no 198 | 199 | # This flag controls whether the implicit-str-concat should generate a warning 200 | # on implicit string concatenation in sequences defined over several lines. 201 | check-str-concat-over-line-jumps=no 202 | 203 | 204 | [TYPECHECK] 205 | 206 | # List of decorators that produce context managers, such as 207 | # contextlib.contextmanager. Add to this list to register other decorators that 208 | # produce valid context managers. 209 | contextmanager-decorators=contextlib.contextmanager 210 | 211 | # List of members which are set dynamically and missed by pylint inference 212 | # system, and so shouldn't trigger E1101 when accessed. Python regular 213 | # expressions are accepted. 214 | generated-members= 215 | 216 | # Tells whether missing members accessed in mixin class should be ignored. A 217 | # mixin class is detected if its name ends with "mixin" (case insensitive). 218 | ignore-mixin-members=yes 219 | 220 | # Tells whether to warn about missing members when the owner of the attribute 221 | # is inferred to be None. 222 | ignore-none=yes 223 | 224 | # This flag controls whether pylint should warn about no-member and similar 225 | # checks whenever an opaque object is returned when inferring. The inference 226 | # can return multiple potential results while evaluating a Python object, but 227 | # some branches might not be evaluated, which results in partial inference. In 228 | # that case, it might be useful to still emit no-member and other checks for 229 | # the rest of the inferred objects. 230 | ignore-on-opaque-inference=yes 231 | 232 | # List of class names for which member attributes should not be checked (useful 233 | # for classes with dynamically set attributes). 
This supports the use of 234 | # qualified names. 235 | ignored-classes=optparse.Values,thread._local,_thread._local 236 | 237 | # List of module names for which member attributes should not be checked 238 | # (useful for modules/projects where namespaces are manipulated during runtime 239 | # and thus existing member attributes cannot be deduced by static analysis). It 240 | # supports qualified module names, as well as Unix pattern matching. 241 | ignored-modules= 242 | 243 | # Show a hint with possible names when a member name was not found. The aspect 244 | # of finding the hint is based on edit distance. 245 | missing-member-hint=yes 246 | 247 | # The minimum edit distance a name should have in order to be considered a 248 | # similar match for a missing member name. 249 | missing-member-hint-distance=1 250 | 251 | # The total number of similar names that should be taken in consideration when 252 | # showing a hint for a missing member. 253 | missing-member-max-choices=1 254 | 255 | # List of decorators that change the signature of a decorated function. 256 | signature-mutators= 257 | 258 | 259 | [SPELLING] 260 | 261 | # Limits count of emitted suggestions for spelling mistakes. 262 | max-spelling-suggestions=4 263 | 264 | # Spelling dictionary name. Available dictionaries: none. To make it work, 265 | # install the python-enchant package. 266 | spelling-dict= 267 | 268 | # List of comma separated words that should not be checked. 269 | spelling-ignore-words= 270 | 271 | # A path to a file that contains the private dictionary; one word per line. 272 | spelling-private-dict-file= 273 | 274 | # Tells whether to store unknown words to the private dictionary (see the 275 | # --spelling-private-dict-file option) instead of raising a message. 276 | spelling-store-unknown-words=no 277 | 278 | 279 | [FORMAT] 280 | 281 | # Expected format of line ending, e.g. empty (any line ending), LF or CRLF. 
282 | expected-line-ending-format= 283 | 284 | # Regexp for a line that is allowed to be longer than the limit. 285 | ignore-long-lines=^\s*(# )??$ 286 | 287 | # Number of spaces of indent required inside a hanging or continued line. 288 | indent-after-paren=4 289 | 290 | # String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 291 | # tab). 292 | indent-string=' ' 293 | 294 | # Maximum number of characters on a single line. 295 | max-line-length=100 296 | 297 | # Maximum number of lines in a module. 298 | max-module-lines=1000 299 | 300 | # List of optional constructs for which whitespace checking is disabled. `dict- 301 | # separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. 302 | # `trailing-comma` allows a space between comma and closing bracket: (a, ). 303 | # `empty-line` allows space-only lines. 304 | no-space-check=trailing-comma, 305 | dict-separator 306 | 307 | # Allow the body of a class to be on the same line as the declaration if body 308 | # contains single statement. 309 | single-line-class-stmt=no 310 | 311 | # Allow the body of an if to be on the same line as the test if there is no 312 | # else. 313 | single-line-if-stmt=no 314 | 315 | 316 | [SIMILARITIES] 317 | 318 | # Ignore comments when computing similarities. 319 | ignore-comments=yes 320 | 321 | # Ignore docstrings when computing similarities. 322 | ignore-docstrings=yes 323 | 324 | # Ignore imports when computing similarities. 325 | ignore-imports=no 326 | 327 | # Minimum lines number of a similarity. 328 | min-similarity-lines=4 329 | 330 | 331 | [MISCELLANEOUS] 332 | 333 | # List of note tags to take in consideration, separated by a comma. 334 | notes=FIXME, 335 | XXX, 336 | TODO 337 | 338 | # Regular expression of note tags to take in consideration. 339 | #notes-rgx= 340 | 341 | 342 | [LOGGING] 343 | 344 | # The type of string formatting that logging methods do. `old` means using % 345 | # formatting, `new` is for `{}` formatting. 
346 | logging-format-style=old 347 | 348 | # Logging modules to check that the string format arguments are in logging 349 | # function parameter format. 350 | logging-modules=logging 351 | 352 | 353 | [VARIABLES] 354 | 355 | # List of additional names supposed to be defined in builtins. Remember that 356 | # you should avoid defining new builtins when possible. 357 | additional-builtins= 358 | 359 | # Tells whether unused global variables should be treated as a violation. 360 | allow-global-unused-variables=yes 361 | 362 | # List of strings which can identify a callback function by name. A callback 363 | # name must start or end with one of those strings. 364 | callbacks=cb_, 365 | _cb 366 | 367 | # A regular expression matching the name of dummy variables (i.e. expected to 368 | # not be used). 369 | dummy-variables-rgx=_+$|(_[a-zA-Z0-9_]*[a-zA-Z0-9]+?$)|dummy|^ignored_|^unused_ 370 | 371 | # Argument names that match this expression will be ignored. Default to name 372 | # with leading underscore. 373 | ignored-argument-names=_.*|^ignored_|^unused_ 374 | 375 | # Tells whether we should check for unused import in __init__ files. 376 | init-import=no 377 | 378 | # List of qualified module names which can have objects that can redefine 379 | # builtins. 380 | redefining-builtins-modules=six.moves,past.builtins,future.builtins,builtins,io 381 | 382 | 383 | [BASIC] 384 | 385 | # Naming style matching correct argument names. 386 | argument-naming-style=snake_case 387 | 388 | # Regular expression matching correct argument names. Overrides argument- 389 | # naming-style. 390 | #argument-rgx= 391 | 392 | # Naming style matching correct attribute names. 393 | attr-naming-style=snake_case 394 | 395 | # Regular expression matching correct attribute names. Overrides attr-naming- 396 | # style. 397 | #attr-rgx= 398 | 399 | # Bad variable names which should always be refused, separated by a comma. 
400 | bad-names=foo, 401 | bar, 402 | baz, 403 | toto, 404 | tutu, 405 | tata 406 | 407 | # Bad variable names regexes, separated by a comma. If names match any regex, 408 | # they will always be refused 409 | bad-names-rgxs= 410 | 411 | # Naming style matching correct class attribute names. 412 | class-attribute-naming-style=any 413 | 414 | # Regular expression matching correct class attribute names. Overrides class- 415 | # attribute-naming-style. 416 | #class-attribute-rgx= 417 | 418 | # Naming style matching correct class names. 419 | class-naming-style=PascalCase 420 | 421 | # Regular expression matching correct class names. Overrides class-naming- 422 | # style. 423 | #class-rgx= 424 | 425 | # Naming style matching correct constant names. 426 | const-naming-style=UPPER_CASE 427 | 428 | # Regular expression matching correct constant names. Overrides const-naming- 429 | # style. 430 | #const-rgx= 431 | 432 | # Minimum line length for functions/classes that require docstrings, shorter 433 | # ones are exempt. 434 | docstring-min-length=-1 435 | 436 | # Naming style matching correct function names. 437 | function-naming-style=snake_case 438 | 439 | # Regular expression matching correct function names. Overrides function- 440 | # naming-style. 441 | #function-rgx= 442 | 443 | # Good variable names which should always be accepted, separated by a comma. 444 | good-names=i, 445 | j, 446 | k, 447 | ex, 448 | Run, 449 | _ 450 | 451 | # Good variable names regexes, separated by a comma. If names match any regex, 452 | # they will always be accepted 453 | good-names-rgxs= 454 | 455 | # Include a hint for the correct naming format with invalid-name. 456 | include-naming-hint=no 457 | 458 | # Naming style matching correct inline iteration names. 459 | inlinevar-naming-style=any 460 | 461 | # Regular expression matching correct inline iteration names. Overrides 462 | # inlinevar-naming-style. 463 | #inlinevar-rgx= 464 | 465 | # Naming style matching correct method names. 
466 | method-naming-style=snake_case 467 | 468 | # Regular expression matching correct method names. Overrides method-naming- 469 | # style. 470 | #method-rgx= 471 | 472 | # Naming style matching correct module names. 473 | module-naming-style=snake_case 474 | 475 | # Regular expression matching correct module names. Overrides module-naming- 476 | # style. 477 | #module-rgx= 478 | 479 | # Colon-delimited sets of names that determine each other's naming style when 480 | # the name regexes allow several styles. 481 | name-group= 482 | 483 | # Regular expression which should only match function or class names that do 484 | # not require a docstring. 485 | no-docstring-rgx=^_ 486 | 487 | # List of decorators that produce properties, such as abc.abstractproperty. Add 488 | # to this list to register other decorators that produce valid properties. 489 | # These decorators are taken in consideration only for invalid-name. 490 | property-classes=abc.abstractproperty 491 | 492 | # Naming style matching correct variable names. 493 | variable-naming-style=snake_case 494 | 495 | # Regular expression matching correct variable names. Overrides variable- 496 | # naming-style. 497 | #variable-rgx= 498 | 499 | 500 | [DESIGN] 501 | 502 | # Maximum number of arguments for function / method. 503 | max-args=5 504 | 505 | # Maximum number of attributes for a class (see R0902). 506 | max-attributes=7 507 | 508 | # Maximum number of boolean expressions in an if statement (see R0916). 509 | max-bool-expr=5 510 | 511 | # Maximum number of branch for function / method body. 512 | max-branches=12 513 | 514 | # Maximum number of locals for function / method body. 515 | max-locals=15 516 | 517 | # Maximum number of parents for a class (see R0901). 518 | max-parents=7 519 | 520 | # Maximum number of public methods for a class (see R0904). 521 | max-public-methods=20 522 | 523 | # Maximum number of return / yield for function / method body. 
524 | max-returns=6 525 | 526 | # Maximum number of statements in function / method body. 527 | max-statements=50 528 | 529 | # Minimum number of public methods for a class (see R0903). 530 | min-public-methods=2 531 | 532 | 533 | [CLASSES] 534 | 535 | # List of method names used to declare (i.e. assign) instance attributes. 536 | defining-attr-methods=__init__, 537 | __new__, 538 | setUp, 539 | __post_init__ 540 | 541 | # List of member names, which should be excluded from the protected access 542 | # warning. 543 | exclude-protected=_asdict, 544 | _fields, 545 | _replace, 546 | _source, 547 | _make 548 | 549 | # List of valid names for the first argument in a class method. 550 | valid-classmethod-first-arg=cls 551 | 552 | # List of valid names for the first argument in a metaclass class method. 553 | valid-metaclass-classmethod-first-arg=cls 554 | 555 | 556 | [IMPORTS] 557 | 558 | # List of modules that can be imported at any level, not just the top level 559 | # one. 560 | allow-any-import-level= 561 | 562 | # Allow wildcard imports from modules that define __all__. 563 | allow-wildcard-with-all=no 564 | 565 | # Analyse import fallback blocks. This can be used to support both Python 2 and 566 | # 3 compatible code, which means that the block might have code that exists 567 | # only in one or another interpreter, leading to false positives when analysed. 568 | analyse-fallback-blocks=no 569 | 570 | # Deprecated modules which should not be used, separated by a comma. 571 | deprecated-modules=optparse,tkinter.tix 572 | 573 | # Create a graph of external dependencies in the given file (report RP0402 must 574 | # not be disabled). 575 | ext-import-graph= 576 | 577 | # Create a graph of every (i.e. internal and external) dependencies in the 578 | # given file (report RP0402 must not be disabled). 579 | import-graph= 580 | 581 | # Create a graph of internal dependencies in the given file (report RP0402 must 582 | # not be disabled). 
583 | int-import-graph= 584 | 585 | # Force import order to recognize a module as part of the standard 586 | # compatibility libraries. 587 | known-standard-library= 588 | 589 | # Force import order to recognize a module as part of a third party library. 590 | known-third-party=enchant 591 | 592 | # Couples of modules and preferred modules, separated by a comma. 593 | preferred-modules= 594 | 595 | 596 | [EXCEPTIONS] 597 | 598 | # Exceptions that will emit a warning when being caught. Defaults to 599 | # "BaseException, Exception". 600 | overgeneral-exceptions=BaseException, 601 | Exception 602 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 
11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. 
Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | ## Enforcement 56 | 57 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 58 | reported by contacting the project team at . All 59 | complaints will be reviewed and investigated and will result in a response that 60 | is deemed necessary and appropriate to the circumstances. The project team is 61 | obligated to maintain confidentiality with regard to the reporter of an incident. 62 | Further details of specific enforcement policies may be posted separately. 63 | 64 | Project maintainers who do not follow or enforce the Code of Conduct in good 65 | faith may face temporary or permanent repercussions as determined by other 66 | members of the project's leadership. 67 | 68 | ## Attribution 69 | 70 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 71 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 72 | 73 | [homepage]: https://www.contributor-covenant.org 74 | 75 | For answers to common questions about this code of conduct, see 76 | https://www.contributor-covenant.org/faq -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to WavAugment 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Our Development Process 6 | We use this github repository for the development. 7 | 8 | ## Pull Requests 9 | We actively welcome your pull requests. 10 | 11 | 1. Fork the repo and create your branch from `master`. 12 | 2. If you've added code that should be tested, add tests. 13 | 3. If you've changed APIs, update the documentation. 14 | 4. Ensure the test suite passes. 15 | 5. Make sure your code lints. 16 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 
17 | 18 | ## Contributor License Agreement ("CLA") 19 | In order to accept your pull request, we need you to submit a CLA. You only need 20 | to do this once to work on any of Facebook's open source projects. 21 | 22 | Complete your CLA here: 23 | 24 | ## Issues 25 | We use GitHub issues to track public bugs. Please ensure your description is 26 | clear and has sufficient instructions to be able to reproduce the issue. 27 | 28 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 29 | disclosure of security bugs. In those cases, please go through the process 30 | outlined on that page and do not file a public issue. 31 | 32 | ## Coding Style 33 | * We use 4 spaces for indentation in Python; 34 | * Please use clang-format and pylint to format your changes: 35 |   * Run `clang-format -i` to format C++ code in-place; 36 |   * Make sure that `pylint augment` does not result in new warnings. 37 | 38 | ## License 39 | By contributing to WavAugment, you agree that your contributions will be licensed 40 | under the LICENSE file in the root directory of this source tree. 41 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Facebook, Inc. and its affiliates. 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software.
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # WavAugment 2 | 3 | WavAugment performs data augmentation on audio data. 4 | The audio data is represented as [pytorch](https://pytorch.org/) tensors. 5 | 6 | It is particularly useful for speech data. 7 | Among others, it implements the augmentations that we found to be most useful for self-supervised learning 8 | (_Data Augmenting Contrastive Learning of Speech Representations in the Time Domain_, E. Kharitonov, M. Rivière, G. Synnaeve, L. Wolf, P.-E. Mazaré, M. Douze, E. Dupoux. [[arxiv]](https://arxiv.org/abs/2007.00991)): 9 | 10 | * Pitch randomization, 11 | * Reverberation, 12 | * Additive noise, 13 | * Time dropout (temporal masking), 14 | * Band reject, 15 | * Clipping 16 | 17 | Internally, WavAugment uses [libsox](http://sox.sourceforge.net/libsox.html) and allows interleaving of libsox- and pytorch-based effects. 
18 | 19 | ### Requirements 20 | * Linux or macOS 21 | * [pytorch](https://pytorch.org/) >= 1.7 22 | * [torchaudio](https://pytorch.org/audio) >= 0.7 23 | 24 | ### Installation 25 | To install WavAugment, run the following command: 26 | ```bash 27 | git clone git@github.com:facebookresearch/WavAugment.git && cd WavAugment && python setup.py develop 28 | ``` 29 | 30 | ### Testing 31 | Requires pytest (`pip install pytest`) 32 | 33 | ```bash 34 | python -m pytest -v --doctest-modules 35 | ``` 36 | 37 | ## Usage 38 | 39 | First of all, we provide thoroughly documented [examples](./examples/python), where we demonstrate how a data-augmented dataset interface works. We also provide a Jupyter-based [tutorial](./examples/python/WavAugment_walkthrough.ipynb) [(open in colab)](https://colab.research.google.com/github/facebookresearch/WavAugment/blob/master/examples/python/WavAugment_walkthrough.ipynb) that illustrates how one can apply various useful effects to a piece of speech (recorded over the mic or pre-recorded). 40 | 41 | ### The `EffectChain` 42 | 43 | The central object is `EffectChain`, a chain of effects that is applied to a `torch.Tensor` to produce another `torch.Tensor`. 44 | This chain can have multiple effects composed: 45 | ```python 46 | import augment 47 | effect_chain = augment.EffectChain().pitch(100).rate(16_000) 48 | ``` 49 | Effect parameters coincide with those of libsox (http://sox.sourceforge.net/libsox.html); however, you can also randomize the parameters by providing a python `Callable` and mix them with standard parameters: 50 | ```python 51 | import numpy as np 52 | random_pitch_shift = lambda: np.random.randint(-100, +100) 53 | # the pitch will be changed by a shift somewhere between (-100, +100) 54 | effect_chain = augment.EffectChain().pitch("-q", random_pitch_shift).rate(16_000) 55 | ``` 56 | Here, the flag `-q` makes `pitch` run faster at some expense of quality.
57 | If some parameters are provided by a Callable, this Callable will be invoked every time `EffectChain` is applied (e.g., to generate random parameters). 58 | 59 | ### Applying the chain 60 | 61 | To apply a chain of effects to a torch.Tensor, we write the following: 62 | ```python 63 | output_tensor = augment.EffectChain().pitch(100).rate(16_000).apply(input_tensor, \ 64 | src_info=src_info, target_info=target_info) 65 | ``` 66 | WavAugment expects `input_tensor` to have a shape of (channels, length). As `input_tensor` does not contain important meta-information, such as sampling rate, we need to provide it manually. 67 | This is done by passing two dictionaries, `src_info` (meta-information about the input format) and `target_info` (our expected format for the output). 68 | 69 | At minimum, we need to set the sampling rate for the input tensor: `{'rate': 16_000}`. 70 | 71 | ### Example usage 72 | 73 | Below is a short example of potential usage: 74 | 75 | ```python 76 | import augment 77 | import numpy as np 78 | import torchaudio 79 | x, sr = torchaudio.load('tests/test.wav') 80 | 81 | # input signal properties 82 | src_info = {'rate': sr} 83 | 84 | # output signal properties 85 | target_info = {'channels': 1, 86 | 'length': 0, # not known beforehand 87 | 'rate': 16_000} 88 | # write down the chain of effects with their string parameters and call .apply() 89 | # effects are specified as a chain of method calls with parameters that can be 90 | # strings, numbers, or callables. The latter case is used for generating randomized 91 | # transformations 92 | random_pitch = lambda: np.random.randint(-400, -200) 93 | y = augment.EffectChain().pitch(random_pitch).rate(16_000).apply(x, \ 94 | src_info=src_info, target_info=target_info) 95 | ``` 96 | 97 | ## Important notes 98 | A command-line invocation of sox often changes the effect chain implicitly.
To get a better idea of what sox executes internally, you can launch it with the `-V` flag, e.g. by running: 99 | ```bash 100 | sox -V tests/test.wav out.wav reverb 0 50 100 101 | ``` 102 | you will see something like: 103 | ``` 104 | sox INFO sox: effects chain: input 16000Hz 1 channels 105 | sox INFO sox: effects chain: reverb 16000Hz 2 channels 106 | sox INFO sox: effects chain: channels 16000Hz 1 channels 107 | sox INFO sox: effects chain: dither 16000Hz 1 channels 108 | sox INFO sox: effects chain: output 16000Hz 1 channels 109 | ``` 110 | This output tells us that the `reverb` effect changes the number of channels, which are squashed into 1 channel by the `channels` effect. Sox also adds a `dither` effect to hide processing artifacts. 111 | 112 | WavAugment remains explicit and doesn't add effects under the hood. 113 | If you want to emulate a Sox command that decomposes into several effects, we advise consulting `sox -V` and applying the effects manually. 114 | Try it out on some files before running a heavy machine-learning job. 115 | 116 | ## Citation 117 | If you find WavAugment useful in your research, please consider citing: 118 | ``` 119 | @article{wavaugment2020, 120 | title={Data Augmenting Contrastive Learning of Speech Representations in the Time Domain}, 121 | author={Kharitonov, Eugene and Rivi{\`e}re, Morgane and Synnaeve, Gabriel and Wolf, Lior and Mazar{\'e}, Pierre-Emmanuel and Douze, Matthijs and Dupoux, Emmanuel}, 122 | journal={arXiv preprint arXiv:2007.00991}, 123 | year={2020} 124 | } 125 | ``` 126 | 127 | ## Contributing 128 | See the [CONTRIBUTING](CONTRIBUTING.md) file for how to help out. 129 | 130 | ## License 131 | WavAugment is MIT licensed, as found in the LICENSE file. 132 | 133 | -------------------------------------------------------------------------------- /augment/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree 5 | 6 | from .effects import ( 7 | EffectChain, 8 | shutdown_sox, 9 | ) 10 | 11 | __all__ = [ 12 | 'EffectChain', 13 | 'shutdown_sox' 14 | ] 15 | -------------------------------------------------------------------------------- /augment/effects.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | from typing import Dict, List, Optional, Callable, Union, Tuple, Set 7 | 8 | import torch 9 | import numpy as np 10 | import torchaudio 11 | from torchaudio.sox_effects.sox_effects import effect_names as get_effect_names 12 | 13 | def shutdown_sox() -> None: 14 | pass 15 | 16 | 17 | # Arguments that we can pass to effects 18 | SoxArg = Optional[List[Union[str, int, Callable]]] 19 | 20 | 21 | class _PyEffectChain: 22 | 23 | def __init__(self): 24 | self._effects = [] 25 | 26 | def add_effect(self, effect_name, effect_params): 27 | params = [str(e) for e in effect_params] 28 | self._effects.append([effect_name, *params]) 29 | 30 | def apply_flow_effects(self, tensor, src_info, target_info): 31 | return torchaudio.sox_effects.apply_effects_tensor(tensor, int(src_info['rate']), self._effects) 32 | 33 | 34 | class SoxEffect: 35 | def __init__(self, name: str, args: SoxArg): 36 | self.name = name 37 | self.args = args if args else [] 38 | 39 | def instantiate(self): 40 | """ 41 | >>> import random; random.seed(7) 42 | >>> effect = SoxEffect("pitch", [lambda: random.randint(-100, 100)]) 43 | >>> effect.instantiate() 44 | ['pitch', ['-18']] 45 | >>> effect.instantiate() 46 | ['pitch', ['-62']] 47 | """ 48 | instantiated_args = [] 49 | for arg in self.args: 50 | if callable(arg): 51 | arg = arg() 52 | if isinstance(arg, 
list): 53 | instantiated_args.extend([str(v) for v in arg]) 54 | else: 55 | instantiated_args.append(str(arg)) 56 | else: 57 | instantiated_args.append(str(arg)) 58 | 59 | return [self.name, instantiated_args] 60 | 61 | 62 | class EffectChain: 63 | """ 64 | >>> chain = EffectChain() 65 | >>> _ = chain.pitch("100").rate(16_000).dither() 66 | >>> len(chain) 67 | 3 68 | >>> [e.name for e in chain._chain] 69 | ['pitch', 'rate', 'dither'] 70 | """ 71 | 72 | # libsox sample_t is an int between [-1 << 31, 1 << 31); 73 | # while torchaudio operates with [-1, 1]. Hence, 74 | # each time we pass something to/from libsox, we rescale 75 | _NORMALIZER: int = 1 << 31 76 | KNOWN_EFFECTS: Set[str] = set() 77 | 78 | def __init__(self, in_place: bool = False): 79 | self._chain: List[Union[Callable, SoxEffect]] = [] 80 | self.in_place: bool = in_place 81 | 82 | def _append_effect_to_chain(self, name: str, args: SoxArg = None): 83 | effect = SoxEffect(name, args) 84 | self._chain.append(effect) 85 | return self 86 | 87 | def clear(self): 88 | self._chain = [] 89 | 90 | def __len__(self): 91 | return len(self._chain) 92 | 93 | @staticmethod 94 | def _apply_sox_effects(chain: List[SoxEffect], 95 | input_tensor: torch.Tensor, 96 | src_info: Dict, 97 | target_info: Dict) -> Tuple[torch.Tensor, int]: 98 | instantiated_chain = [x.instantiate() for x in chain] 99 | sox_chain = _PyEffectChain() 100 | for effect_name, effect_args in instantiated_chain: 101 | sox_chain.add_effect(effect_name, effect_args) 102 | 103 | out, sr = sox_chain.apply_flow_effects(input_tensor, 104 | src_info, 105 | target_info) 106 | return out, sr 107 | 108 | def apply(self, 109 | input_tensor: torch.Tensor, 110 | src_info: Dict[str, Union[int, float]], 111 | target_info: Optional[Dict[str, Union[int, float]]] = None) -> torch.Tensor: 112 | """ 113 | input_tensor (torch.Tensor): the input wave to be transformed; 114 | expected shape is (n_channels, length). 
If it has only 115 | one dimension, it is automatically expanded to have 1 channel. 116 | src_info (Dict): description of the input signal 117 | target_info (Dict): description of the output signal 118 | 119 | Fields that src_info and target_info can contain: 120 | * rate (mandatory for input), 121 | * length, 122 | * precision, 123 | * bits_per_sample. 124 | Those fields are only used by sox-based effects. 125 | 126 | Minimally, src_info must contain rate field (e.g. `{rate: 16_000}`). 127 | Both src_info and target_info can set `length`. If src_info specifies 128 | `length`, only first `length` samples are used; if target_info has `length`, 129 | further output is trimmed. 130 | 131 | It is might happen that sox will return wave shorter than `length` - in this 132 | case the output will be padded with zeroes. 133 | 134 | returns: torch.Tensor with transformed sound. 135 | """ 136 | target_info = dict() if target_info is None else target_info 137 | 138 | if not torch.is_tensor(input_tensor) or input_tensor.is_cuda: 139 | raise RuntimeError('Expected a CPU tensor') 140 | 141 | if not self.in_place: 142 | input_tensor = input_tensor.clone() 143 | if 'rate' not in src_info: 144 | raise RuntimeError("'rate' must be specified for the input") 145 | if len(input_tensor.size()) == 1: 146 | input_tensor = input_tensor.unsqueeze(0) 147 | if 'length' in src_info and src_info['length'] > input_tensor.size(1): 148 | raise RuntimeError("'length' is beyond the tensor length") 149 | 150 | src_info = dict(src_info) # can be mutated in process 151 | sr = src_info['rate'] 152 | 153 | if not self._chain: 154 | out = input_tensor 155 | return out 156 | 157 | sox_effects: List[SoxEffect] = [] 158 | x = input_tensor 159 | 160 | # To minimize the number of sox calls, we stack consequent sox effects together in a single 161 | # sox-side chain. In contrast, we apply python effects immediately. 
162 | 163 | for effect in self._chain: 164 | if callable(effect): 165 | if sox_effects: 166 | x, sr = EffectChain._apply_sox_effects( 167 | sox_effects, x, src_info, target_info) 168 | src_info = dict(target_info) 169 | assert src_info['rate'] == sr 170 | 171 | sox_effects = [] 172 | # effect should not mutate src_info or target_info, but 173 | # return new ones if changed 174 | x, src_info, target_info = effect(x, src_info, target_info) 175 | elif isinstance(effect, SoxEffect): 176 | sox_effects.append(effect) 177 | else: 178 | assert False 179 | 180 | if sox_effects: 181 | x, _ = EffectChain._apply_sox_effects( 182 | sox_effects, x, src_info, target_info) 183 | return x 184 | 185 | def time_dropout(self, max_frames: Optional[int] = None, max_seconds: Optional[float] = None): 186 | """ 187 | >>> np.random.seed(1) 188 | >>> chain = EffectChain().time_dropout(max_seconds=0.1) 189 | >>> t = torch.ones([1, 16000]) 190 | >>> x = chain.apply(t, {'rate': 16000}, {'rate': 16000}) 191 | >>> x.min().item(), x.max().item() 192 | (0.0, 1.0) 193 | >>> (x == 0).sum().item() 194 | 1061 195 | >>> (x[:, 235:1296] == 0).all().item() 196 | True 197 | >>> (x[:, :235] == 0).any().item() 198 | False 199 | >>> (x[:, 235 + 1061 + 1:] == 0).any().item() 200 | False 201 | """ 202 | self._chain.append(TimeDropout(max_frames, max_seconds)) 203 | 204 | return self 205 | 206 | def additive_noise(self, noise_generator: Callable, snr: float): 207 | """ 208 | >>> signal = torch.zeros((1, 100)).uniform_() 209 | >>> noise_generator = lambda: torch.zeros((1, 100)) 210 | >>> chain = EffectChain().additive_noise(noise_generator, snr=0) 211 | >>> x = chain.apply(signal, {'rate': 16000}, {'rate': 16000}) 212 | >>> (x == signal.mul(0.5)).all().item() 213 | True 214 | """ 215 | self._chain.append(AdditiveNoise( 216 | noise_generator=noise_generator, snr=snr)) 217 | return self 218 | 219 | def clip(self, clamp_factor: Union[Callable, float]): 220 | """ 221 | >>> signal = torch.tensor([-10, -5, 0, 5, 
10]).float() 222 | >>> factor_generator = 0.5 223 | >>> chain = EffectChain().clip(factor_generator) 224 | >>> x = chain.apply(signal, {'rate': 16000}, {}) 225 | >>> x 226 | tensor([[-5., -5., 0., 5., 5.]]) 227 | """ 228 | self._chain.append(ClipValue(clamp_factor)) 229 | return self 230 | 231 | KNOWN_EFFECTS.add('additive_noise') 232 | KNOWN_EFFECTS.add('clip') 233 | KNOWN_EFFECTS.add('time_dropout') 234 | 235 | 236 | class TimeDropout: 237 | def __init__(self, max_frames: Optional[int] = None, max_seconds: Optional[float] = None): 238 | assert max_frames or max_seconds 239 | self.max_frames = max_frames 240 | self.max_seconds = max_seconds 241 | 242 | def __call__(self, x, src_info, dst_info): 243 | if self.max_frames is None: 244 | max_frames = int(src_info['rate'] * self.max_seconds) 245 | else: 246 | max_frames = self.max_frames 247 | 248 | length = np.random.randint(0, max_frames) 249 | 250 | start = np.random.randint(0, x.size(1) - length) 251 | end = start + length 252 | 253 | x[:, start:end, ...].zero_() 254 | return x, src_info, dst_info 255 | 256 | 257 | class AdditiveNoise: 258 | def __init__(self, noise_generator: Callable, snr: float): 259 | self.noise_generator = noise_generator 260 | self.snr = snr 261 | 262 | r = np.exp(snr * np.log(10) / 10) 263 | self.coeff = r / (1.0 + r) 264 | 265 | def __call__(self, x, src_info, dst_info): 266 | noise_instance = self.noise_generator() 267 | assert noise_instance.numel() == x.numel( 268 | ), 'Noise and signal shapes are incompatible' 269 | 270 | noised = self.coeff * x + (1.0 - self.coeff) * noise_instance.view_as(x) 271 | return noised, src_info, dst_info 272 | 273 | 274 | class ClipValue: 275 | def __init__(self, clamp_factor: Union[Callable, float]): 276 | self.clamp_factor = clamp_factor 277 | 278 | def __call__(self, x, src_info, dst_info): 279 | factor = self.clamp_factor() if callable(self.clamp_factor) else self.clamp_factor 280 | x_min, x_max = x.min(), x.max() 281 | 282 | x.clamp_(min=x_min * factor, 
max=x_max * factor) 283 | return x, src_info, dst_info 284 | 285 | 286 | def create_method(name): 287 | EffectChain.KNOWN_EFFECTS.add(name) 288 | return lambda s, *x: s._append_effect_to_chain(name, list(x)) # pylint: disable=protected-access 289 | 290 | 291 | for _effect_name in get_effect_names(): 292 | setattr(EffectChain, _effect_name, create_method(_effect_name)) 293 | -------------------------------------------------------------------------------- /examples/python/README.md: -------------------------------------------------------------------------------- 1 | # Python examples 2 | 3 | In this directory, we provide a couple of examples, described below. 4 | 5 | ## Walkthrough tutorial 6 | 7 | [WavAugment_walkthrough.ipynb](./examples/python/WavAugment_walkthrough.ipynb) [(open in colab)](https://colab.research.google.com/github/facebookresearch/WavAugment/blob/master/examples/python/WavAugment_walkthrough.ipynb) provides a succint walkthrough tutorial, showing how effects can be applied on a piece of speech (recorded over the mic or pre-recorded). 8 | 9 | ## Processing a single file 10 | 11 | The script [process_file.py](./process_file.py) gives a taste on how different compositions of speech augmentation techniques sound like, by allowing to augment single files. It suppors a few randomized data augmentations: pitch, reverberation, temporal masking, band rejection, and clipping. 12 | 13 | A typical usage is: 14 | ```bash 15 | python process_file.py --input_file=./tests/test.wav \ 16 | --output_file=augmented.wav \ 17 | --chain=pitch,reverb,time_drop \ 18 | --pitch_shift_max=500 \ 19 | --t_ms=100 20 | ``` 21 | where `--chain` specifies a list of augmentations applied sequentially, left-to-right; `t_ms` and `pitch_shift_max` specify parameters of the augmentations. `augmented.wav` would contain the randomly augmented sound. 
22 | 23 | 24 | ## Usage in self-supervised learning 25 | 26 | In [librispeech_selfsupervised.py](./librispeech_selfsupervised.py) we use WavAugment in a way that can be used for self-supervised learning. 27 | We define a dataset that iterates over Librispeech data, reads a (randomly shifted) sequence of pre-defined length from each file 28 | and returns two copies of it, which can be augmented independently (in the script, only the `past` copy is augmented). This example does not train a model; it only measures the dataset reading time. 29 | 30 | The code is thoroughly documented. This command will download `dev-clean` into the `./data` directory (if needed) and iterate over it, 31 | extracting sequences of 1 second length. Batches of size 32 will be prepared by 8 DataLoader workers. 32 | 33 | ```bash 34 | python librispeech_selfsupervised.py --data=./data --subset=dev-clean --sequence_length_seconds=1 --n_workers=8 --download --batch_size=32 35 | ``` 36 | 37 | Iterating over Librispeech train-clean-100 (100 hours of audio) with 16 workers takes 2 seconds without any data augmentation. 38 | With WavAugment data augmentation it takes around 5 seconds (on a solid server). 39 | -------------------------------------------------------------------------------- /examples/python/librispeech_selfsupervised.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates.
2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree 5 | 6 | import torch 7 | 8 | import torchaudio 9 | from torchaudio.datasets.librispeech import FOLDER_IN_ARCHIVE, URL 10 | from torchaudio.datasets.librispeech import LIBRISPEECH as Librispeech 11 | from functools import lru_cache 12 | import augment 13 | import numpy as np 14 | import os 15 | import time 16 | import argparse 17 | 18 | """ 19 | In this example, we implement a simple Librispeech-based dataset for self-supervised learning 20 | with data augmentations implemented via WavAugment. 21 | """ 22 | 23 | 24 | # cache all files lengths in-mem to reduce disk IO 25 | @lru_cache(maxsize=None) 26 | def get_file_length(filepath): 27 | """ 28 | Returns the length of the sequence in the file specified by `filepath` 29 | """ 30 | signal_info, encoding_info = torchaudio.info(filepath) 31 | return signal_info.length 32 | 33 | 34 | class LibrispeechSelfSupervised(Librispeech): 35 | """ 36 | Extends the standard Librispeech dataset to a self-supervised use: 37 | * hides speaker and text labels 38 | * loads a sequences of speech of a predefined length, randomly shifted within a file 39 | * return two copies of this sequence, called `past` and `future` 40 | * those two sequences are independently augmented 41 | """ 42 | 43 | def __init__(self, root, sequence_length, augment_past=None, augment_future=None, url=URL, folder_in_archive=FOLDER_IN_ARCHIVE, download=False): 44 | """ 45 | root: where the dataset is stored 46 | sequence_length: expected length of the sequence 47 | augment_past: a Callable that applies data augmentation on `past` sequences 48 | augment_future: a Callable that applies data augmentation on `future` sequences 49 | """ 50 | super().__init__(root, url, folder_in_archive, download) 51 | self.sequence_length = sequence_length 52 | self.augment_past = augment_past 53 | self.augment_future = augment_future 54 | 55 | def 
__getitem__(self, n): 56 | fileid = self._walker[n] 57 | waveform = self.load_librispeech_item(fileid) 58 | past, future = waveform, waveform 59 | 60 | if self.augment_past: 61 | past = self.augment_past(past) 62 | if self.augment_future: 63 | future = self.augment_future(future) 64 | 65 | return past, future 66 | 67 | def load_librispeech_item(self, fileid): 68 | speaker_id, chapter_id, utterance_id = fileid.split("-") 69 | 70 | file_audio = fileid + self._ext_audio 71 | file_audio = os.path.join( 72 | self._path, speaker_id, chapter_id, file_audio) 73 | 74 | # Get the sequence length 75 | length = get_file_length(file_audio) 76 | 77 | assert length >= self.sequence_length, \ 78 | f'Sequence {file_audio} is too short for the required length {self.sequence_length}' 79 | # Generate a random offset within the file 80 | offset = np.random.randint(0, length - self.sequence_length) 81 | 82 | # Load a randomly shifted piece of audio 83 | waveform, sample_rate = torchaudio.load( 84 | file_audio, offset=offset, num_frames=self.sequence_length) 85 | assert waveform.size(1) == self.sequence_length 86 | return waveform 87 | 88 | 89 | class ChainRunner: 90 | """ 91 | Takes an instance of augment.EffectChain and applies it on pytorch tensors. 92 | """ 93 | 94 | def __init__(self, chain): 95 | self.chain = chain 96 | 97 | def __call__(self, x): 98 | """ 99 | x: torch.Tensor, (channels, length). Must be placed on CPU. 
100 | """ 101 | src_info = {'channels': x.size(0), # number of channels 102 | 'length': x.size(1), # length of the sequence 103 | 'precision': 32, # precision (16, 32 bits) 104 | 'rate': 16000.0, # sampling rate 105 | 'bits_per_sample': 32} # size of the sample 106 | 107 | target_info = {'channels': 1, 108 | 'length': x.size(1), 109 | 'precision': 32, 110 | 'rate': 16000.0, 111 | 'bits_per_sample': 32} 112 | 113 | y = self.chain.apply( 114 | x, src_info=src_info, target_info=target_info) 115 | 116 | # sox might misbehave sometimes by giving nan/inf if sequences are too short (or silent) 117 | # and the effect chain includes eg `pitch` 118 | if torch.isnan(y).any() or torch.isinf(y).any(): 119 | return x.clone() 120 | return y 121 | 122 | 123 | # Generate a random shift applied to the speaker's pitch 124 | def random_pitch_shift(): 125 | return np.random.randint(-300, 300) 126 | 127 | # Generate a random size of the room 128 | def random_room_size(): 129 | return np.random.randint(0, 100) 130 | 131 | 132 | def get_args(): 133 | parser = argparse.ArgumentParser() 134 | parser.add_argument('--data', type=str, default='./data', help='Where Librispeech is placed') 135 | parser.add_argument('--download', action='store_true', help='Whether the dataset can be downloaded automatically if not found') 136 | parser.add_argument('--subset', choices=["dev-clean", "dev-other", "test-clean", "test-other", "train-clean-100", 137 | "train-clean-360", "train-other-500"], default='dev-clean', help='Librispeech subset to use') 138 | parser.add_argument('--sequence_length_seconds', type=int, default=1, help='Sample sequence length') 139 | parser.add_argument('--batch_size', type=int, default=32, help="Batch size") 140 | parser.add_argument('--n_workers', type=int, default=8, help="Number of parallel workers to read/preprocess data") 141 | parser.add_argument('--n_epochs', type=int, default=3, help="Number of epochs to run") 142 | parser.add_argument('--dump', action="store_true", 
help="Dump examples of (non)augmented sequences." 143 | "They would be saved in 'original.wav' and 'augmented.wav'") 144 | 145 | 146 | args = parser.parse_args() 147 | return args 148 | 149 | if __name__ == '__main__': 150 | args = get_args() 151 | 152 | effect_chain_past = augment.EffectChain() 153 | # The pitch effect changes the sampling ratio; we have to compensate for that. 154 | # Here, we specify 'quick' options on both pitch and rate effects, to speed up things 155 | effect_chain_past.pitch("-q", random_pitch_shift).rate("-q", 16_000) 156 | # Next effect we add is `reverb`; it adds makes the signal to have two channels, 157 | # which we combine into 1 by running `channels` w/o parameters 158 | effect_chain_past.reverb(50, 50, random_room_size).channels() 159 | # Futher, we add an effect that randomly drops one 50ms subsequence 160 | effect_chain_past.time_dropout(max_seconds=50 / 1000) 161 | 162 | effect_chain_past_runner = ChainRunner(effect_chain_past) 163 | 164 | # the second, `future` copy would be non-augmented 165 | effect_chain_future = None 166 | effect_chain_future_runner = None 167 | 168 | dataset = LibrispeechSelfSupervised( 169 | root=args.data, 170 | augment_past=effect_chain_past_runner, 171 | augment_future=effect_chain_future_runner, 172 | # In Librispeech, sampling rate is 16000 173 | sequence_length=args.sequence_length_seconds* 16_000, 174 | url=args.subset, 175 | download=args.download) 176 | 177 | if args.dump: 178 | augmented, original = dataset[0] 179 | torchaudio.save('augmented.wav', augmented, 16_000) 180 | torchaudio.save('original.wav', original, 16_000) 181 | print('Saved examples of augmented and non-augmented sequences to augmented.wav and original.wav') 182 | 183 | dataloader = torch.utils.data.DataLoader( 184 | dataset=dataset, 185 | batch_size=args.batch_size, 186 | shuffle=True, 187 | num_workers=args.n_workers 188 | ) 189 | 190 | for epoch in range(args.n_epochs): 191 | start = time.time() 192 | for batch in dataloader: 
193 | pass 194 | print(f'Finished epoch {epoch} in {time.time() - start}') 195 | -------------------------------------------------------------------------------- /examples/python/process_file.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | import numpy as np 7 | import torch 8 | import torchaudio 9 | import augment 10 | import argparse 11 | 12 | from dataclasses import dataclass 13 | 14 | class RandomPitchShift: 15 | def __init__(self, shift_max=300): 16 | self.shift_max = shift_max 17 | 18 | def __call__(self): 19 | return np.random.randint(-self.shift_max, self.shift_max) 20 | 21 | class RandomClipFactor: 22 | def __init__(self, factor_min=0.0, factor_max=1.0): 23 | self.factor_min = factor_min 24 | self.factor_max = factor_max 25 | def __call__(self): 26 | return np.random.triangular(self.factor_min, self.factor_max, self.factor_max) 27 | 28 | @dataclass 29 | class RandomReverb: 30 | reverberance_min: int = 50 31 | reverberance_max: int = 50 32 | damping_min: int = 50 33 | damping_max: int = 50 34 | room_scale_min: int = 0 35 | room_scale_max: int = 100 36 | 37 | def __call__(self): 38 | reverberance = np.random.randint(self.reverberance_min, self.reverberance_max + 1) 39 | damping = np.random.randint(self.damping_min, self.damping_max + 1) 40 | room_scale = np.random.randint(self.room_scale_min, self.room_scale_max + 1) 41 | 42 | return [reverberance, damping, room_scale] 43 | 44 | class SpecAugmentBand: 45 | def __init__(self, sampling_rate, scaler): 46 | self.sampling_rate = sampling_rate 47 | self.scaler = scaler 48 | 49 | @staticmethod 50 | def freq2mel(f): 51 | return 2595. * np.log10(1 + f / 700) 52 | 53 | @staticmethod 54 | def mel2freq(m): 55 | return ((10.**(m / 2595.) 
- 1) * 700) 56 | 57 | def __call__(self): 58 | F = 27.0 * self.scaler 59 | melfmax = freq2mel(self.sample_rate / 2) 60 | meldf = np.random.uniform(0, melfmax * F / 256.) 61 | melf0 = np.random.uniform(0, melfmax - meldf) 62 | low = mel2freq(melf0) 63 | high = mel2freq(melf0 + meldf) 64 | return f'{high}-{low}' 65 | 66 | 67 | def augmentation_factory(description, sampling_rate, args): 68 | chain = augment.EffectChain() 69 | description = description.split(',') 70 | 71 | for effect in description: 72 | if effect == 'bandreject': 73 | chain = chain.sinc('-a', '120', SpecAugmentBand(sampling_rate, args.band_scaler)) 74 | elif effect == 'pitch': 75 | pitch_randomizer = RandomPitchShift(args.pitch_shift_max) 76 | if args.pitch_quick: 77 | chain = chain.pitch('-q', pitch_randomizer).rate('-q', sampling_rate) 78 | else: 79 | chain = chain.pitch(pitch_randomizer).rate(sampling_rate) 80 | elif effect == 'reverb': 81 | randomized_params = RandomReverb(args.reverberance_min, args.reverberance_max, 82 | args.damping_min, args.damping_max, args.room_scale_min, args.room_scale_max) 83 | chain = chain.reverb(randomized_params).channels() 84 | elif effect == 'time_drop': 85 | chain = chain.time_dropout(max_seconds=args.t_ms / 1000.0) 86 | elif effect == 'clip': 87 | chain = chain.clip(RandomClipFactor(args.clip_min, args.clip_max)) 88 | elif effect == 'none': 89 | pass 90 | else: 91 | raise RuntimeError(f'Unknown augmentation type {effect}') 92 | return chain 93 | 94 | 95 | def get_args(): 96 | parser = argparse.ArgumentParser() 97 | parser.add_argument('--input_file', type=str, help='File to procecss') 98 | parser.add_argument('--output_file', type=str, help='Output file') 99 | 100 | parser.add_argument('--chain', type=str, help='Comma-separated list of effects to apply, e.g. 
"pitch,dropout"', 101 | default='none') 102 | 103 | parser.add_argument('--t_ms', type=int, help='Size of a time dropout sequence', default=50) 104 | 105 | parser.add_argument('--pitch_shift_max', type=int, help='Amplitude of a pitch shift; measured in 1/100th of a tone', default=300) 106 | parser.add_argument('--pitch_quick', action='store_true', help='Speech up the pitch effect at some quality cost') 107 | 108 | parser.add_argument('--room_scale_min', type=int, help='Minimal room size used in randomized reverb (0..100)', default=0) 109 | parser.add_argument('--room_scale_max', type=int, help='Maximal room size used in randomized reverb (0..100)', default=100) 110 | parser.add_argument('--reverberance_min', type=int, help='Minimal reverberance used in randomized reverb (0..100)', default=50) 111 | parser.add_argument('--reverberance_max', type=int, help='Maximal reverberance used in randomized reverb (0..100)', default=50) 112 | parser.add_argument('--damping_min', type=int, help='Minimal damping used in randomized reverb (0..100)', default=50) 113 | parser.add_argument('--damping_max', type=int, help='Maximal damping used in randomized reverb (0..100)', default=50) 114 | parser.add_argument('--clip_min', type=float, help='Minimal clip factor (0.0..1.0)', default=0.5) 115 | parser.add_argument('--clip_max', type=float, help='Maximal clip factor (0.0..1.0)', default=1.0) 116 | 117 | args = parser.parse_args() 118 | args.chain = args.chain.lower() 119 | 120 | if not args.input_file or not args.output_file: 121 | raise RuntimeError('You need to specify "--input_file" and "--output_file"') 122 | 123 | if not (0 <= args.room_scale_min <= args.room_scale_max <= 100): 124 | raise RuntimeError('It should be that 0 <= room_scale_min <= room_scale_max <= 100') 125 | 126 | if not (0 <= args.reverberance_min <= args.reverberance_max <= 100): 127 | raise RuntimeError('It should be that 0 <= reverberance_min <= reverberance_max <= 100') 128 | 129 | if not (0 <= args.damping_min 
<= args.damping_max <= 100): 130 | raise RuntimeError('It should be that 0 <= damping_min <= damping_max <= 100') 131 | 132 | if not (0.0 <= args.clip_min <= args.clip_max <= 1.0): 133 | raise RuntimeError('It should be that 0 <= clip_min <= clip_max <= 1.0') 134 | 135 | return args 136 | 137 | if __name__ == '__main__': 138 | args = get_args() 139 | 140 | x, sampling_rate = torchaudio.load(args.input_file) 141 | augmentation_chain = augmentation_factory(args.chain, sampling_rate, args) 142 | 143 | y = augmentation_chain.apply(x, 144 | src_info=dict(rate=sampling_rate, length=x.size(1), channels=x.size(0)), 145 | target_info=dict(rate=sampling_rate, length=0) 146 | ) 147 | 148 | torchaudio.save(args.output_file, y, sampling_rate) 149 | 150 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchaudio>=0.7 3 | pytest 4 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree 5 | 6 | import os 7 | import subprocess 8 | 9 | from setuptools import setup, find_packages 10 | 11 | # Creating the version file 12 | cwd = os.path.dirname(os.path.abspath(__file__)) 13 | version = '0.2' 14 | sha = 'Unknown' 15 | 16 | try: 17 | sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() 18 | except Exception: 19 | pass 20 | 21 | setup( 22 | name="augment", 23 | version=version, 24 | # Exclude the build files. 
25 | packages=find_packages(exclude=["build"]), 26 | install_requires=['torch', 'torchaudio'] 27 | ) 28 | -------------------------------------------------------------------------------- /tests/augment_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | import warnings 9 | import tempfile 10 | import subprocess 11 | import torchaudio 12 | import pathlib 13 | import numpy as np 14 | 15 | import augment 16 | 17 | test_wav = pathlib.Path(__file__).parent / 'test.wav' 18 | assert test_wav.exists() 19 | 20 | def test_empty_chain(): 21 | x = torch.arange(0, 8000).float() 22 | 23 | src_info = {'channels': 1, 24 | 'length': x.size(0), 25 | 'precision': 32, 26 | 'rate': 16000.0, 27 | 'bits_per_sample': 32} 28 | 29 | target_info = {'channels': 1, 30 | 'length': 0, 31 | 'precision': 32, 32 | 'rate': 16000.0, 33 | 'bits_per_sample': 32} 34 | 35 | y = augment.EffectChain().apply( 36 | x, src_info=src_info, target_info=target_info) 37 | 38 | assert x.view(-1).allclose(y.view(-1)) 39 | 40 | 41 | def test_non_empty_chain(): 42 | x, sr = torchaudio.load(test_wav) 43 | 44 | src_info = {'channels': 1, 45 | 'length': x.size(1), 46 | 'precision': 32, 47 | 'rate': 16000.0, 48 | 'bits_per_sample': 32} 49 | 50 | target_info = {'channels': 1, 51 | 'length': 0, 52 | 'precision': 32, 53 | 'rate': 16000.0, 54 | 'bits_per_sample': 32} 55 | 56 | effects = augment.EffectChain().bandreject(1, 20000) 57 | 58 | y = effects.apply(x, src_info=src_info, 59 | target_info=target_info) 60 | 61 | assert x.size() == y.size(), f'{y.size()}' 62 | assert not x.allclose(y) 63 | 64 | def convert_pitch_cl(test_wav): 65 | with tempfile.NamedTemporaryFile(suffix='.wav') as t_file: 66 | output_name = t_file.name 67 | res = subprocess.run( 68 | ['sox', 
str(test_wav), output_name, 'pitch', '100']) 69 | assert res.returncode == 0 70 | 71 | y, sr = torchaudio.load(output_name) 72 | return y, sr 73 | 74 | 75 | def convert_pitch_augment(test_wav): 76 | x, sr = torchaudio.load(test_wav) 77 | 78 | assert sr == 16000 79 | 80 | src_info = {'channels': x.size(0), 81 | 'length': x.size(1), 82 | 'precision': 32, 83 | 'rate': 16000.0, 84 | 'bits_per_sample': 32} 85 | 86 | target_info = {'channels': 1, 87 | 'length': 0, 88 | 'precision': 32, 89 | 'rate': 16000.0, 90 | 'bits_per_sample': 32} 91 | 92 | y = augment.EffectChain().pitch(100).rate(16000).apply( 93 | x, src_info=src_info, target_info=target_info) 94 | return y, sr 95 | 96 | def test_agains_cl(): 97 | y1, _ = convert_pitch_cl(test_wav) 98 | y2, _ = convert_pitch_augment(test_wav) 99 | 100 | assert y1.size() == y2.size() 101 | 102 | # NB: higher tolerance due to all the discretization done on save/load 103 | assert torch.allclose(y1, y2, rtol=1e-3, atol=1e-3) 104 | 105 | # just to make sure something is happening 106 | x, sr = torchaudio.load(test_wav) 107 | assert not torch.allclose(x, y2, rtol=1e-3, atol=1e-3) 108 | 109 | def test_stochastic_pitch(): 110 | x, sr = torchaudio.load(test_wav) 111 | 112 | assert sr == 16000 113 | 114 | src_info = {'channels': x.size(0), 115 | 'length': x.size(1), 116 | 'precision': 32, 117 | 'rate': 16000.0, 118 | 'bits_per_sample': 32} 119 | 120 | target_info = {'channels': 1, 121 | 'length': 0, 122 | 'precision': 32, 123 | 'rate': 16000.0, 124 | 'bits_per_sample': 32} 125 | 126 | def random_pitch(): return np.random.randint(100, 500) 127 | y = augment.EffectChain().pitch(random_pitch).rate(16000).apply( 128 | x, src_info=src_info, target_info=target_info) 129 | assert not torch.allclose(x, y, rtol=1e-3, atol=1e-3) 130 | 131 | 132 | def test_additive_noise(): 133 | x, sr = torchaudio.load(test_wav) 134 | 135 | noise = torch.zeros_like(x) 136 | 137 | src_info = {'channels': 1, 138 | 'length': x.size(1), 139 | 'precision': 32, 140 | 
'rate': 16000.0, 141 | 'bits_per_sample': 32} 142 | 143 | target_info = {'channels': 1, 144 | 'length': 0, 145 | 'precision': 32, 146 | 'rate': 16000.0, 147 | 'bits_per_sample': 32} 148 | 149 | y = augment.EffectChain() \ 150 | .additive_noise(noise_generator=lambda: x, snr=10.0) \ 151 | .apply(x, src_info=src_info, target_info=target_info) 152 | 153 | assert torch.allclose(x, y) 154 | 155 | def test_number_effects(): 156 | assert len(augment.EffectChain.KNOWN_EFFECTS) == 61 157 | -------------------------------------------------------------------------------- /tests/compare_to_sox_test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 2 | # 3 | # This source code is licensed under the MIT license found in the 4 | # LICENSE file in the root directory of this source tree. 5 | 6 | 7 | import torch 8 | import warnings 9 | import tempfile 10 | import subprocess 11 | import torchaudio 12 | import pathlib 13 | import numpy as np 14 | 15 | import augment 16 | 17 | test_wav = pathlib.Path(__file__).parent / 'test.wav' 18 | assert test_wav.exists() 19 | 20 | 21 | def run_sox_command(test_wav, cl_parameters): 22 | with tempfile.NamedTemporaryFile(suffix='.wav') as t_file: 23 | output_name = t_file.name 24 | res = subprocess.run( 25 | ['sox', str(test_wav), output_name] + cl_parameters 26 | ) 27 | assert res.returncode == 0 28 | 29 | y, sr = torchaudio.load(output_name) 30 | return y, sr 31 | 32 | 33 | def apply_chain(test_wav, chain): 34 | x, sr = torchaudio.load(test_wav) 35 | 36 | assert sr == 16000 37 | 38 | src_info = {'channels': x.size(0), 39 | 'length': x.size(1), 40 | 'precision': 32, 41 | 'rate': 16000.0, 42 | 'bits_per_sample': 32} 43 | 44 | target_info = {'channels': 1, 45 | 'length': 0, 46 | 'precision': 32, 47 | 'rate': 16000.0, 48 | 'bits_per_sample': 32} 49 | 50 | y = chain.apply( 51 | x, src_info=src_info, target_info=target_info) 52 | return y 53 | 54 | def 
test_pitch(): 55 | y1, _ = run_sox_command(test_wav, ["pitch", "-100"]) 56 | 57 | chain = augment.EffectChain().pitch(-100).rate(16000) 58 | y2 = apply_chain(test_wav, chain) 59 | 60 | assert y1.size() == y2.size() 61 | 62 | # NB: higher tolerance due to all the discretization done on save/load 63 | assert torch.allclose(y1, y2, rtol=1e-4, atol=1e-4) 64 | 65 | 66 | def test_reverb(): 67 | y1, _ = run_sox_command(test_wav, ["reverb", "50", "50", "100"]) 68 | 69 | chain = augment.EffectChain().reverb(50, 50, 100).channels() 70 | y2 = apply_chain(test_wav, chain) 71 | 72 | assert y1.size() == y2.size() 73 | 74 | # NB: higher tolerance due to all the discretization done on save/load 75 | assert torch.allclose(y1, y2, rtol=1e-4, atol=1e-4) 76 | 77 | 78 | def test_bandreject(): 79 | y1, _ = run_sox_command(test_wav, ["sinc", "-a", "120", "2000-1000"]) 80 | 81 | chain = augment.EffectChain().sinc("-a", "120", "2000-1000") 82 | y2 = apply_chain(test_wav, chain) 83 | 84 | assert y1.size() == y2.size() 85 | 86 | # NB: higher tolerance due to all the discretization done on save/load 87 | assert torch.allclose(y1, y2, rtol=1e-4, atol=1e-4) 88 | -------------------------------------------------------------------------------- /tests/test.wav: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/WavAugment/54afcdb00ccc852c2f030f239f8532c9562b550e/tests/test.wav --------------------------------------------------------------------------------