├── .github
│   └── workflows
│       ├── deploy.yaml
│       └── test.yaml
├── .gitignore
├── 00_core.ipynb
├── 01_datasets.ipynb
├── 02_viz.ipynb
├── 03_chunkadelic.ipynb
├── 03b_chunk_one_file.ipynb
├── 04_spectrofu.ipynb
├── 05_hpc.ipynb
├── LICENSE
├── MANIFEST.in
├── README.md
├── _quarto.yml
├── aeiou
│   ├── __init__.py
│   ├── _modidx.py
│   ├── chunk_one_file.py
│   ├── chunkadelic.py
│   ├── core.py
│   ├── datasets.py
│   ├── hpc.py
│   ├── spectrofu.py
│   └── viz.py
├── examples
│   ├── accel_config.yaml
│   ├── example.wav
│   └── stereo_pewpew.mp3
├── index.ipynb
├── nbdev.yml
├── settings.ini
├── setup.py
├── sidebar.yml
└── styles.css

--------------------------------------------------------------------------------
/.github/workflows/deploy.yaml:
--------------------------------------------------------------------------------
1 | name: Deploy to GitHub Pages
2 | on:
3 |   push:
4 |     branches: [ "main", "master" ]
5 |   workflow_dispatch:
6 | jobs:
7 |   deploy:
8 |     runs-on: ubuntu-latest
9 |     steps:
10 |       - run: sudo apt-get update; sudo apt-get install ffmpeg libsndfile-dev
11 |       - uses: fastai/workflows/quarto-ghp@master
12 | 
--------------------------------------------------------------------------------
/.github/workflows/test.yaml:
--------------------------------------------------------------------------------
1 | name: CI
2 | on: [workflow_dispatch, pull_request, push]
3 | 
4 | jobs:
5 |   test:
6 |     runs-on: ubuntu-latest
7 |     steps:
8 |       - run: sudo apt-get update; sudo apt-get install ffmpeg libsndfile-dev
9 |       - uses: fastai/workflows/nbdev-ci@master
10 | 
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | _docs/
2 | _proc/
3 | 
4 | *.bak
5 | .gitattributes
6 | .last_checked
7 | .gitconfig
8 | *.bak
9 | *.log
10 | *~
11 | ~*
12 | _tmp*
13 | tmp*
14 | tags
15 | *.pkg
16 | 
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 | *$py.class
21 | 
22 | # C extensions
23 | *.so
24 | 
25 | # Distribution / packaging
26 | .Python
27 | env/
28 | build/
29 | develop-eggs/
30 | dist/
31 | downloads/
32 | eggs/
33 | .eggs/
34 | lib/
35 | lib64/
36 | parts/
37 | sdist/
38 | var/
39 | wheels/
40 | *.egg-info/
41 | .installed.cfg
42 | *.egg
43 | 
44 | # PyInstaller
45 | # Usually these files are written by a python script from a template
46 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
47 | *.manifest
48 | *.spec
49 | 
50 | # Installer logs
51 | pip-log.txt
52 | pip-delete-this-directory.txt
53 | 
54 | # Unit test / coverage reports
55 | htmlcov/
56 | .tox/
57 | .coverage
58 | .coverage.*
59 | .cache
60 | nosetests.xml
61 | coverage.xml
62 | *.cover
63 | .hypothesis/
64 | 
65 | # Translations
66 | *.mo
67 | *.pot
68 | 
69 | # Django stuff:
70 | *.log
71 | local_settings.py
72 | 
73 | # Flask stuff:
74 | instance/
75 | .webassets-cache
76 | 
77 | # Scrapy stuff:
78 | .scrapy
79 | 
80 | # Sphinx documentation
81 | docs/_build/
82 | 
83 | # PyBuilder
84 | target/
85 | 
86 | # Jupyter Notebook
87 | .ipynb_checkpoints
88 | 
89 | # pyenv
90 | .python-version
91 | 
92 | # celery beat schedule file
93 | celerybeat-schedule
94 | 
95 | # SageMath parsed files
96 | *.sage.py
97 | 
98 | # dotenv
99 | .env
100 | 
101 | # virtualenv
102 | .venv
103 | venv/
104 | ENV/
105 | 
106 | # Spyder project settings
107 | .spyderproject
108 | .spyproject
109 | 
110 | # Rope project settings
111 | .ropeproject
112 | 
113 | # mkdocs documentation
114 | /site
115 | 
116 | # mypy
117 | .mypy_cache/
118 | 
119 | .vscode
120 | *.swp
121 | 
122 | # osx generated files
123 | .DS_Store
124 | .DS_Store?
125 | .Trashes
126 | ehthumbs.db
127 | Thumbs.db
128 | .idea
129 | 
130 | # pytest
131 | .pytest_cache
132 | 
133 | # tools/trust-doc-nbs
134 | docs_src/.last_checked
135 | 
136 | # symlinks to fastai
137 | docs_src/fastai
138 | tools/fastai
139 | 
140 | # link checker
141 | checklink/cookies.txt
142 | 
143 | # .gitconfig is now autogenerated
144 | .gitconfig
145 | 
146 | # Quarto installer
147 | .deb
148 | .pkg
149 | 
150 | # Quarto
151 | .quarto
152 | 
153 | /.quarto/
--------------------------------------------------------------------------------
/03_chunkadelic.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "d404833b",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "%load_ext autoreload\n",
11 |     "%autoreload 2"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": null,
17 |    "id": "9d2cefc8",
18 |    "metadata": {},
19 |    "outputs": [],
20 |    "source": [
21 |     "#| default_exp chunkadelic"
22 |    ]
23 |   },
24 |   {
25 |    "cell_type": "markdown",
26 |    "id": "5d79fbe5",
27 |    "metadata": {},
28 |    "source": [
29 |     "# chunkadelic"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "markdown",
34 |    "id": "262d835a",
35 |    "metadata": {},
36 |    "source": [
37 |     "\n",
38 |     "> Console script and callable function for preprocessing a dataset of disparate-sized audio files into uniform chunks\n",
39 |     "\n",
40 |     "Note: Duplicates the directory structure(s) referenced by input paths. \n",
41 |     "\n",
42 |     "\n",
43 |     "```bash\n",
44 |     "$ chunkadelic -h \n",
45 |     "usage: chunkadelic [-h] [--chunk_size CHUNK_SIZE] [--sr SR] [--norm [{False,global,channel}]] [--spacing SPACING] [--strip] [--thresh THRESH]\n",
46 |     "                   [--bits BITS] [--workers WORKERS] [--nomix] [--nopad] [--verbose] [--debug]\n",
47 |     "                   output_path input_paths [input_paths ...]\n",
48 |     "\n",
49 |     "positional arguments:\n",
50 |     "  output_path           Path of output for chunkified data\n",
51 |     "  input_paths           Path(s) of a file or a folder of files. 
(recursive)\n", 52 | "\n", 53 | "options:\n", 54 | " -h, --help show this help message and exit\n", 55 | " --chunk_size CHUNK_SIZE\n", 56 | " Length of chunks (default: 131072)\n", 57 | " --sr SR Output sample rate (default: 48000)\n", 58 | " --norm [{False,global,channel}]\n", 59 | " Normalize audio, based on the max of the absolute value [global/channel/False] (default: False)\n", 60 | " --spacing SPACING Spacing factor, advance this fraction of a chunk per copy (default: 0.5)\n", 61 | " --strip Strips silence: chunks with max dB below are not outputted (default: False)\n", 62 | " --thresh THRESH threshold in dB for determining what constitutes silence (default: -70)\n", 63 | " --bits BITS Bit depth: \"None\" uses torchaudio default | \"match\"=match input audio files | or specify an int (default: None)\n", 64 | " --workers WORKERS Maximum number of workers to use (default: all) (default: 20)\n", 65 | " --nomix (BDCT Dataset specific) exclude output of \"*/Audio Files/*Mix*\" (default: False)\n", 66 | " --nopad Disable zero padding for audio shorter than chunk_size (default: False)\n", 67 | " --verbose Extra output logging (default: False)\n", 68 | " --debug Extra EXTRA output logging (default: False)\n", 69 | "```" 70 | ] 71 | }, 72 | { 73 | "cell_type": "code", 74 | "execution_count": null, 75 | "id": "0be2e849", 76 | "metadata": {}, 77 | "outputs": [], 78 | "source": [ 79 | "#|hide\n", 80 | "from nbdev.showdoc import *" 81 | ] 82 | }, 83 | { 84 | "cell_type": "code", 85 | "execution_count": null, 86 | "id": "5df76a87", 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "#|export\n", 91 | "import argparse \n", 92 | "import os \n", 93 | "from functools import partial\n", 94 | "from tqdm.contrib.concurrent import process_map \n", 95 | "import torch\n", 96 | "import torchaudio\n", 97 | "import math\n", 98 | "from aeiou.core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax\n", 99 | "import multiprocessing as mp\n", 100 | "from multiprocessing import Pool, cpu_count, Barrier" 101 | ] 102 | }, 103 | { 104 | "cell_type": "code", 105 | "execution_count": null, 106 | "id": "519f8740", 107 | "metadata": {}, 108 | "outputs": [], 109 | "source": [ 110 | "#|export\n", 111 | "def blow_chunks(\n", 112 | " audio:torch.tensor, # long audio file to be chunked\n", 113 | " new_filename:str, # stem of new filename(s) to be output as chunks\n", 114 | " chunk_size:int, # how big each audio chunk is, in samples\n", 115 | " sr=48000, # audio sample rate in Hz\n", 116 | " norm='False', # normalize input audio, based on the max of the absolute value ['global','channel', or anything else for None, e.g. 
False]\n",
117 |     "    spacing=0.5,      # fraction of each chunk to advance between hops\n",
118 |     "    strip=False,      # strip silence: chunks with max power in dB below this value will not be saved to files\n",
119 |     "    thresh=-70,       # threshold in dB for determining what counts as silence\n",
120 |     "    bits_per_sample=None, # kwarg for torchaudio.save, None means use defaults\n",
121 |     "    nopad=False,      # disable zero-padding, allowing samples to be shorter than chunk_size (including \"leftovers\" on the \"ends\")\n",
122 |     "    debug=False,      # print debugging information \n",
123 |     "    ):\n",
124 |     "    \"chunks up the audio and saves them with --{i} on the end of each chunk filename\"\n",
125 |     "    if (debug): print(f\"       blow_chunks: audio.shape = {audio.shape}\",flush=True)\n",
126 |     "    \n",
127 |     "    #chunk = torch.zeros(audio.shape[0], chunk_size) \n",
128 |     "    _, ext = os.path.splitext(new_filename)\n",
129 |     "    \n",
130 |     "    if norm in ['global','channel']: audio = normalize_audio(audio, norm) \n",
131 |     "\n",
132 |     "    spacing = 0.5 if spacing == 0 else spacing # handle degenerate case as a request for the defaults\n",
133 |     "    \n",
134 |     "    start, i = 0, 0\n",
135 |     "    while start < audio.shape[-1]:\n",
136 |     "        out_filename = new_filename.replace(ext, f'--{i}'+ext) \n",
137 |     "        end = min(start + chunk_size, audio.shape[-1])\n",
138 |     "        if (end-start < chunk_size) and not nopad: # audio shorter than chunk_size: pad with zeros\n",
139 |     "            chunk = torch.zeros(audio.shape[0], chunk_size) \n",
140 |     "            chunk[:,0:end-start] = audio[:,start:end]\n",
141 |     "        else:\n",
142 |     "            chunk = audio[:,start:end]\n",
143 |     "        if (not strip) or (not is_silence(chunk, thresh=thresh)):\n",
144 |     "            if debug: print(f\"     Saving output chunk {out_filename}, bits_per_sample={bits_per_sample}, chunk.shape={chunk.shape}\", flush=True)\n",
145 |     "            torchaudio.save(out_filename, chunk, sr, bits_per_sample=bits_per_sample)\n",
146 |     "        else:\n",
147 |     "            print(f\"Skipping chunk {out_filename} because it's 'silent' (below threshold of {thresh} dB).\",flush=True)\n",
148 |     "        start, i = start + int(spacing * chunk_size), i + 1\n",
149 |     "    return "
150 |    ]
151 |   },
152 |   {
153 |    "cell_type": "code",
154 |    "execution_count": null,
155 |    "id": "701f4ae2-aa3b-4763-8bc8-4883fe22d586",
156 |    "metadata": {},
157 |    "outputs": [],
158 |    "source": [
159 |     "#|export\n",
160 |     "def set_bit_rate(bits, filename, debug=False):\n",
161 |     "    if (bits is None) or isinstance(bits, int): bits_per_sample = bits\n",
162 |     "    elif bits.lower()=='none': \n",
163 |     "        bits_per_sample = None   # use torchaudio default \n",
164 |     "    elif bits.lower()=='match':\n",
165 |     "        try:\n",
166 |     "            bits_per_sample = torchaudio.info(filename).bits_per_sample\n",
167 |     "        except Exception as e:\n",
168 |     "            print(\"   Error with bits=match: Can't get audio metadata. 
Choosing default=None\")\n",
169 |     "            bits_per_sample=None\n",
170 |     "    else:\n",
171 |     "        bits_per_sample = int(bits)\n",
172 |     "    if debug: print(\"    set_bit_rate: bits_per_sample =\",bits_per_sample,flush=True)\n",
173 |     "    return bits_per_sample"
174 |    ]
175 |   },
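  {
   "cell_type": "markdown",
   "id": "set-bit-rate-demo-md",
   "metadata": {},
   "source": [
    "A quick, hedged sketch of the three ways `set_bit_rate` resolves its `bits` argument (the `examples/example.wav` path is just an assumption for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "set-bit-rate-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "print(set_bit_rate('None',  'examples/example.wav'))  # -> None, i.e. the torchaudio default\n",
    "print(set_bit_rate('24',    'examples/example.wav'))  # -> 24, an explicit bit depth\n",
    "print(set_bit_rate('match', 'examples/example.wav'))  # -> bit depth read from the file's metadata, e.g. 16"
   ]
  },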
176 |   {
177 |    "cell_type": "code",
178 |    "execution_count": null,
179 |    "id": "2fc2b499",
180 |    "metadata": {},
181 |    "outputs": [],
182 |    "source": [
183 |     "#|export \n",
184 |     "def chunk_one_file(\n",
185 |     "    filenames:list,  # list of filenames from which we'll pick one\n",
186 |     "    args,            # output of argparse\n",
187 |     "    file_ind         # index from filenames list to read from\n",
188 |     "    ):\n",
189 |     "    \"this chunks up one file by setting things up and then calling blow_chunks\"\n",
190 |     "    filename = filenames[file_ind]  # this is actually input_path+/+filename\n",
191 |     "    output_path, input_paths = args.output_path, args.input_paths\n",
192 |     "    new_filename = None\n",
193 |     "    if args.debug: print(f\" --- process_one_file: filenames[{file_ind}] = {filename}\\n\", flush=True)\n",
194 |     "    \n",
195 |     "    for ipath in input_paths: # set up the output filename & any folders it needs\n",
196 |     "        if args.nomix and ('Mix' in ipath) and ('Audio Files' in ipath): return  # this is specific to the BDCT dataset, otherwise ignore\n",
197 |     "        if ipath in filename:\n",
198 |     "            last_ipath = ipath.split('/')[-1]           # get the last part of ipath\n",
199 |     "            clean_filename = filename.replace(ipath,'') # remove all of ipath from the front of filename\n",
200 |     "            new_filename = f\"{output_path}/{last_ipath}/{clean_filename}\".replace('//','/') \n",
201 |     "            makedir(os.path.dirname(new_filename))      # we might need to make a directory for the output file\n",
202 |     "            break\n",
203 |     "\n",
204 |     "    if new_filename is None:\n",
205 |     "        print(f\"ERROR: Something went wrong with the name of input file {filename}. Skipping.\",flush=True) \n",
206 |     "        return \n",
207 |     "    \n",
208 |     "    try:    # try to load the audio file and chunk it up\n",
209 |     "        if args.debug: print(f\"   About to load filenames[{file_ind}] = {filename}\\n\", flush=True)\n",
210 |     "        audio = load_audio(filename, sr=args.sr, verbose=args.debug)\n",
211 |     "        if args.debug: print(f\"   We loaded the audio, audio.shape = {audio.shape}. Setting bit rate.\",flush=True)  \n",
212 |     "        bits_per_sample = set_bit_rate(args.bits, filename, debug=args.debug)\n",
213 |     "        if args.debug: print(f\"   Bit rate set.  Calling blow_chunks...\", flush=True)\n",
214 |     "        blow_chunks(audio, new_filename, args.chunk_size, sr=args.sr, spacing=args.spacing, \n",
215 |     "            strip=args.strip, thresh=args.thresh, bits_per_sample=bits_per_sample, nopad=args.nopad, debug=args.debug)\n",
216 |     "    except Exception as e: \n",
217 |     "        print(f\"Error '{e}' while loading {filename} or writing chunks. Skipping.\", flush=True)\n",
218 |     "\n",
219 |     "    if args.debug: print(f\" --- File {file_ind}: {filename} completed.\\n\", flush=True)\n",
220 |     "    return"
221 |    ]
222 |   },
223 |   {
224 |    "cell_type": "markdown",
225 |    "id": "f3c24fb9",
226 |    "metadata": {},
227 |    "source": [
228 |     "Testing sequential execution, one file at a time:"
229 |    ]
230 |   },
231 |   {
232 |    "cell_type": "code",
233 |    "execution_count": null,
234 |    "id": "208bdaac",
235 |    "metadata": {},
236 |    "outputs": [
237 |     {
238 |      "name": "stdout",
239 |      "output_type": "stream",
240 |      "text": [
241 |       "filenames = ['examples/stereo_pewpew.mp3', 'examples/example.wav']\n",
242 |       "file 1/2: examples/stereo_pewpew.mp3:\n",
243 |       " --- process_one_file: filenames[0] = examples/stereo_pewpew.mp3\n",
244 |       "\n",
245 |       "   About to load filenames[0] = examples/stereo_pewpew.mp3\n",
246 |       "\n",
247 |       "Resampling examples/stereo_pewpew.mp3 from 44100.0 Hz to 48000 Hz\n",
248 |       "   We loaded the audio, audio.shape = torch.Size([2, 234505]). Setting bit rate.\n",
249 |       "   Error with bits=match: Can't get audio metadata. Choosing default=None\n",
250 |       "    set_bit_rate: bits_per_sample = None\n",
251 |       "   Bit rate set.  Calling blow_chunks...\n",
252 |       "       blow_chunks: audio.shape = torch.Size([2, 234505])\n",
253 |       "     Saving output chunk test_chunks/stereo_pewpew--0.mp3, bits_per_sample=None, chunk.shape=torch.Size([2, 131072])\n",
254 |       "     Saving output chunk test_chunks/stereo_pewpew--1.mp3, bits_per_sample=None, chunk.shape=torch.Size([2, 131072])\n",
255 |       "     Saving output chunk test_chunks/stereo_pewpew--2.mp3, bits_per_sample=None, chunk.shape=torch.Size([2, 103433])\n",
256 |       "     Saving output chunk test_chunks/stereo_pewpew--3.mp3, bits_per_sample=None, chunk.shape=torch.Size([2, 37897])\n",
257 |       " --- File 0: examples/stereo_pewpew.mp3 completed.\n",
258 |       "\n",
259 |       "file 2/2: examples/example.wav:\n",
260 |       " --- process_one_file: filenames[1] = examples/example.wav\n",
261 |       "\n",
262 |       "   About to load filenames[1] = examples/example.wav\n",
263 |       "\n",
264 |       "Resampling examples/example.wav from 44100 Hz to 48000 Hz\n",
265 |       "   We loaded the audio, audio.shape = torch.Size([1, 55728]). Setting bit rate.\n",
266 |       "    set_bit_rate: bits_per_sample = 16\n",
267 |       "   Bit rate set.  Calling blow_chunks...\n",
268 |       "       blow_chunks: audio.shape = torch.Size([1, 55728])\n",
269 |       "     Saving output chunk test_chunks/example--0.wav, bits_per_sample=16, chunk.shape=torch.Size([1, 55728])\n",
270 |       " --- File 1: examples/example.wav completed.\n",
271 |       "\n"
272 |      ]
273 |     }
274 |    ],
275 |    "source": [
276 |     "#| eval: false\n",
277 |     "class AttrDict(dict): # cf. 
https://stackoverflow.com/a/14620633/4259243\n", 278 | " \"setup an object to hold args\"\n", 279 | " def __init__(self, *args, **kwargs):\n", 280 | " super(AttrDict, self).__init__(*args, **kwargs)\n", 281 | " self.__dict__ = self\n", 282 | " \n", 283 | "args = AttrDict() # setup something akin to what argparse gives\n", 284 | "args.update( {'output_path':'test_chunks', 'input_paths':['examples/'], 'sr':48000, 'chunk_size':131072, 'spacing':0.5,\n", 285 | " 'norm':'global', 'strip':False, 'thresh':-70, 'nomix':False, 'verbose':True, 'nopad':True,\n", 286 | " 'workers':min(32, os.cpu_count() + 4), 'debug':True, 'bits':'match' })\n", 287 | "\n", 288 | "filenames = get_audio_filenames(args.input_paths)\n", 289 | "print(\"filenames =\",filenames)\n", 290 | "for i in range(len(filenames)):\n", 291 | " print(f\"file {i+1}/{len(filenames)}: {filenames[i]}:\")\n", 292 | " chunk_one_file(filenames, args, i)" 293 | ] 294 | }, 295 | { 296 | "cell_type": "markdown", 297 | "id": "a0b0e456", 298 | "metadata": {}, 299 | "source": [ 300 | "The main executable `chunkadelic` does the same as the previous sequential execution, albeit in parallel. \n", 301 | "\n", 302 | "> Note: Restrictions in Python's `ProcessPoolExecutor` prevent directly invoking parallel execution of `chunk_one_file` while in interactive mode or inside a Jupyter notebook: You must use the CLI (or subprocess it). " 303 | ] 304 | }, 305 | { 306 | "cell_type": "code", 307 | "execution_count": null, 308 | "id": "56e01f54", 309 | "metadata": {}, 310 | "outputs": [], 311 | "source": [ 312 | "#|export\n", 313 | "def main():\n", 314 | " parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n", 315 | " parser.add_argument('--chunk_size', type=int, default=2**17, help='Length of chunks')\n", 316 | " parser.add_argument('--sr', type=int, default=48000, help='Output sample rate')\n", 317 | " parser.add_argument('--norm', default='False', const='False', nargs='?', choices=['False', 'global', 'channel'],\n", 318 | " help='Normalize audio, based on the max of the absolute value [global/channel/False]')\n", 319 | " parser.add_argument('--spacing', type=float, default=0.5, help='Spacing factor, advance this fraction of a chunk per copy')\n", 320 | " parser.add_argument('--strip', action='store_true', help='Strips silence: chunks with max dB below are not outputted')\n", 321 | " parser.add_argument('--thresh', type=int, default=-70, help='threshold in dB for determining what constitutes silence')\n", 322 | " parser.add_argument('--bits', type=str, default='None', help='Bit depth: \"None\" uses torchaudio default | \"match\"=match input audio files | or specify an int')\n", 323 | " parser.add_argument('--workers', type=int, default=min(32, os.cpu_count() + 4), help='Maximum number of workers to use (default: all)')\n", 324 | " parser.add_argument('--nomix', action='store_true', help='(BDCT Dataset specific) exclude output of \"*/Audio Files/*Mix*\"')\n", 325 | " parser.add_argument('--nopad', action='store_true', help='Disable zero padding for audio shorter than chunk_size')\n", 326 | " parser.add_argument('output_path', help='Path of output for chunkified data')\n", 327 | " parser.add_argument('input_paths', nargs='+', help='Path(s) of a file or a folder of files. 
(recursive)')\n", 328 | " parser.add_argument('--verbose', action='store_true', help='Extra output logging')\n", 329 | " parser.add_argument('--debug', action='store_true', help='Extra EXTRA output logging')\n", 330 | " args = parser.parse_args()\n", 331 | " \n", 332 | " if args.verbose: \n", 333 | " print(\"chunkadelic: args = \",args)\n", 334 | " print(\"Getting list of input filenames\")\n", 335 | " filenames = get_audio_filenames(args.input_paths)\n", 336 | " if args.verbose:\n", 337 | " print(f\" Got {len(filenames)} input filenames\") \n", 338 | " if not (args.norm in ['global','channel']): \n", 339 | " print(f\"Warning: since norm = {args.norm}, no normalizations will be performed.\")\n", 340 | " print(\"Processing files (in parallel)...\")\n", 341 | " \n", 342 | " wrapper = partial(chunk_one_file, filenames, args)\n", 343 | " r = process_map(wrapper, range(len(filenames)), chunksize=1, max_workers=args.workers) # different chunksize used by tqdm. max_workers is to avoid annoying other ppl\n", 344 | " \n", 345 | " if args.verbose: print(\"Finished\") " 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "id": "1bf27260", 352 | "metadata": {}, 353 | "outputs": [], 354 | "source": [ 355 | "#| hide\n", 356 | "from nbdev import nbdev_export\n", 357 | "nbdev_export()" 358 | ] 359 | }, 360 | { 361 | "cell_type": "markdown", 362 | "id": "0ecd62d3", 363 | "metadata": {}, 364 | "source": [ 365 | "\n", 366 | "---\n", 367 | "Testing of CLI run: (don't run this on GitHub CI or it will hang)" 368 | ] 369 | }, 370 | { 371 | "cell_type": "code", 372 | "execution_count": null, 373 | "id": "af8472be", 374 | "metadata": {}, 375 | "outputs": [ 376 | { 377 | "name": "stdout", 378 | "output_type": "stream", 379 | "text": [ 380 | "Traceback (most recent call last):\n", 381 | " File \u001b[35m\"/Users/shawley/envs/aeiou2/bin/chunkadelic\"\u001b[0m, line \u001b[35m5\u001b[0m, in \u001b[35m\u001b[0m\n", 382 | " from aeiou.chunkadelic import main\n", 383 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/chunkadelic.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 384 | " from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax\n", 385 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/core.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 386 | " from librosa import load as lr_load\n", 387 | "\u001b[1;35mModuleNotFoundError\u001b[0m: \u001b[35mNo module named 'librosa'\u001b[0m\n" 388 | ] 389 | } 390 | ], 391 | "source": [ 392 | "#| eval: false\n", 393 | "! 
chunkadelic -h" 394 | ] 395 | }, 396 | { 397 | "cell_type": "code", 398 | "execution_count": null, 399 | "id": "b7ff61bf", 400 | "metadata": {}, 401 | "outputs": [], 402 | "source": [ 403 | "#| eval: false\n", 404 | "import subprocess" 405 | ] 406 | }, 407 | { 408 | "cell_type": "code", 409 | "execution_count": null, 410 | "id": "9009283f", 411 | "metadata": {}, 412 | "outputs": [ 413 | { 414 | "name": "stdout", 415 | "output_type": "stream", 416 | "text": [ 417 | "chunkadelic: running tests of normalizations:\n", 418 | "-----\n", 419 | " norm = False\n" 420 | ] 421 | }, 422 | { 423 | "name": "stderr", 424 | "output_type": "stream", 425 | "text": [ 426 | "Traceback (most recent call last):\n", 427 | " File \u001b[35m\"/Users/shawley/envs/aeiou2/bin/chunkadelic\"\u001b[0m, line \u001b[35m5\u001b[0m, in \u001b[35m\u001b[0m\n", 428 | " from aeiou.chunkadelic import main\n", 429 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/chunkadelic.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 430 | " from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax\n", 431 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/core.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 432 | " from librosa import load as lr_load\n", 433 | "\u001b[1;35mModuleNotFoundError\u001b[0m: \u001b[35mNo module named 'librosa'\u001b[0m\n" 434 | ] 435 | }, 436 | { 437 | "name": "stdout", 438 | "output_type": "stream", 439 | "text": [ 440 | "\n", 441 | "-----\n", 442 | " norm = global\n" 443 | ] 444 | }, 445 | { 446 | "name": "stderr", 447 | "output_type": "stream", 448 | "text": [ 449 | "Traceback (most recent call last):\n", 450 | " File \u001b[35m\"/Users/shawley/envs/aeiou2/bin/chunkadelic\"\u001b[0m, line \u001b[35m5\u001b[0m, in \u001b[35m\u001b[0m\n", 451 | " from aeiou.chunkadelic import main\n", 452 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/chunkadelic.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 453 | " from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax\n", 454 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/core.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 455 | " from librosa import load as lr_load\n", 456 | "\u001b[1;35mModuleNotFoundError\u001b[0m: \u001b[35mNo module named 'librosa'\u001b[0m\n" 457 | ] 458 | }, 459 | { 460 | "name": "stdout", 461 | "output_type": "stream", 462 | "text": [ 463 | "\n", 464 | "-----\n", 465 | " norm = channel\n", 466 | "\n" 467 | ] 468 | }, 469 | { 470 | "name": "stderr", 471 | "output_type": "stream", 472 | "text": [ 473 | "Traceback (most recent call last):\n", 474 | " File \u001b[35m\"/Users/shawley/envs/aeiou2/bin/chunkadelic\"\u001b[0m, line \u001b[35m5\u001b[0m, in \u001b[35m\u001b[0m\n", 475 | " from aeiou.chunkadelic import main\n", 476 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/chunkadelic.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 477 | " from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax\n", 478 | " File \u001b[35m\"/Users/shawley/github/aeiou/aeiou/core.py\"\u001b[0m, line \u001b[35m14\u001b[0m, in \u001b[35m\u001b[0m\n", 479 | " from librosa import load as lr_load\n", 480 | "\u001b[1;35mModuleNotFoundError\u001b[0m: \u001b[35mNo module named 'librosa'\u001b[0m\n" 481 | ] 482 | } 483 | ], 484 | "source": [ 485 | "#| eval: false\n", 486 | "print(\"chunkadelic: running tests of 
normalizations:\")\n",
487 |     "for norm in ['False', 'global','channel']:\n",
488 |     "    print(\"-----\\n norm =\",norm)\n",
489 |     "    result = subprocess.run(['chunkadelic', '--norm', norm, 'test_chunks','examples/'], stdout=subprocess.PIPE)\n",
490 |     "    out = result.stdout.decode(\"utf-8\") \n",
491 |     "    print(out)\n",
492 |     "    assert 'error' not in out.lower(), f'Error occurred while running with norm={norm}' # for CI testing"
493 |    ]
494 |   },
495 |   {
496 |    "cell_type": "code",
497 |    "execution_count": null,
498 |    "id": "7f793cd7",
499 |    "metadata": {},
500 |    "outputs": [],
501 |    "source": []
502 |   }
503 |  ],
504 |  "metadata": {
505 |   "kernelspec": {
506 |    "display_name": "aa",
507 |    "language": "python",
508 |    "name": "aa"
509 |   }
510 |  },
511 |  "nbformat": 4,
512 |  "nbformat_minor": 5
513 | }
514 | 
--------------------------------------------------------------------------------
/03b_chunk_one_file.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "id": "d404833b",
7 |    "metadata": {},
8 |    "outputs": [],
9 |    "source": [
10 |     "%load_ext autoreload\n",
11 |     "%autoreload 2"
12 |    ]
13 |   },
14 |   {
15 |    "cell_type": "code",
16 |    "execution_count": null,
17 |    "id": "9d2cefc8",
18 |    "metadata": {},
19 |    "outputs": [],
20 |    "source": [
21 |     "#| default_exp chunk_one_file"
22 |    ]
23 |   },
24 |   {
25 |    "cell_type": "markdown",
26 |    "id": "5d79fbe5",
27 |    "metadata": {},
28 |    "source": [
29 |     "# chunk_one_file"
30 |    ]
31 |   },
32 |   {
33 |    "cell_type": "markdown",
34 |    "id": "262d835a",
35 |    "metadata": {},
36 |    "source": [
37 |     "\n",
38 |     "> Turns one file into chunks. Intended to be called only from `chunkadelic`. See `chunkadelic` for further info.\n",
39 |     "\n",
40 |     "Note: Duplicates the directory structure(s) referenced by input paths. "
41 |    ]
42 |   },
43 |   {
44 |    "cell_type": "code",
45 |    "execution_count": null,
46 |    "id": "0be2e849",
47 |    "metadata": {},
48 |    "outputs": [],
49 |    "source": [
50 |     "#|hide\n",
51 |     "from nbdev.showdoc import *"
52 |    ]
53 |   },
54 |   {
55 |    "cell_type": "code",
56 |    "execution_count": null,
57 |    "id": "5df76a87",
58 |    "metadata": {},
59 |    "outputs": [],
60 |    "source": [
61 |     "#|export\n",
62 |     "import os \n",
63 |     "import torch\n",
64 |     "import torchaudio\n",
65 |     "from aeiou.core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax"
66 |    ]
67 |   },
68 |   {
69 |    "cell_type": "code",
70 |    "execution_count": null,
71 |    "id": "519f8740",
72 |    "metadata": {},
73 |    "outputs": [],
74 |    "source": [
75 |     "#|export\n",
76 |     "def blow_chunks(\n",
77 |     "    audio:torch.tensor,  # long audio file to be chunked\n",
78 |     "    new_filename:str,    # stem of new filename(s) to be output as chunks\n",
79 |     "    chunk_size:int,      # how big each audio chunk is, in samples\n",
80 |     "    sr=48000,            # audio sample rate in Hz\n",
81 |     "    norm='False',        # normalize input audio, based on the max of the absolute value ['global','channel', or anything else for None, e.g. 
False]\n",
82 |     "    spacing=0.5,      # fraction of each chunk to advance between hops\n",
83 |     "    strip=False,      # strip silence: chunks with max power in dB below this value will not be saved to files\n",
84 |     "    thresh=-70,       # threshold in dB for determining what counts as silence\n",
85 |     "    debug=False,      # print debugging information \n",
86 |     "    ):\n",
87 |     "    \"chunks up the audio and saves them with --{i} on the end of each chunk filename\"\n",
88 |     "    if (debug): print(f\"       blow_chunks: audio.shape = {audio.shape}\",flush=True)\n",
89 |     "    \n",
90 |     "    chunk = torch.zeros(audio.shape[0], chunk_size)\n",
91 |     "    _, ext = os.path.splitext(new_filename)\n",
92 |     "    \n",
93 |     "    if norm in ['global','channel']: audio = normalize_audio(audio, norm) \n",
94 |     "\n",
95 |     "    spacing = 0.5 if spacing == 0 else spacing # handle degenerate case as a request for the defaults\n",
96 |     "    \n",
97 |     "    start, i = 0, 0\n",
98 |     "    while start < audio.shape[-1]:\n",
99 |     "        out_filename = new_filename.replace(ext, f'--{i}'+ext) \n",
100 |     "        end = min(start + chunk_size, audio.shape[-1])\n",
101 |     "        if end-start < chunk_size:   # needs zero padding on end\n",
102 |     "            chunk = torch.zeros(audio.shape[0], chunk_size)\n",
103 |     "        chunk[:,0:end-start] = audio[:,start:end]\n",
104 |     "        if (not strip) or (not is_silence(chunk, thresh=thresh)):\n",
105 |     "            torchaudio.save(out_filename, chunk, sr)\n",
106 |     "        else:\n",
107 |     "            print(f\"Skipping chunk {out_filename} because it's 'silent' (below threshold of {thresh} dB).\",flush=True)\n",
108 |     "        start, i = start + int(spacing * chunk_size), i + 1\n",
109 |     "    return "
110 |    ]
111 |   },
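  {
   "cell_type": "markdown",
   "id": "blow-chunks-demo-md",
   "metadata": {},
   "source": [
    "A minimal, hedged sketch of calling `blow_chunks` directly on synthetic audio (the noise signal and the `test_chunks/` output folder are just assumptions for illustration):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "blow-chunks-demo",
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "# hedged sketch: chunk one second of synthetic stereo noise into half-second chunks\n",
    "audio = 2 * torch.rand(2, 48000) - 1   # fake stereo signal in [-1, 1]\n",
    "makedir('test_chunks')                 # the output folder must exist before saving\n",
    "blow_chunks(audio, 'test_chunks/noise.wav', chunk_size=24000, sr=48000, debug=True)\n",
    "# with the default spacing=0.5, chunks hop 12000 samples: noise--0.wav, noise--1.wav, ..."
   ]
  },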
112 |   {
113 |    "cell_type": "code",
114 |    "execution_count": null,
115 |    "id": "2fc2b499",
116 |    "metadata": {},
117 |    "outputs": [],
118 |    "source": [
119 |     "#|export \n",
120 |     "def chunk_one_file(\n",
121 |     "    filenames:list,  # list of filenames from which we'll pick one\n",
122 |     "    args,            # output of argparse\n",
123 |     "    file_ind         # index from filenames list to read from\n",
124 |     "    ):\n",
125 |     "    \"this chunks up one file by setting things up and then calling blow_chunks\"\n",
126 |     "    filename = filenames[file_ind]  # this is actually input_path+/+filename\n",
127 |     "    output_path, input_paths = args.output_path, args.input_paths\n",
128 |     "    new_filename = None\n",
129 |     "    if args.debug: print(f\" --- process_one_file: filenames[{file_ind}] = {filename}\\n\", flush=True)\n",
130 |     "    \n",
131 |     "    for ipath in input_paths: # set up the output filename & any folders it needs\n",
132 |     "        if args.nomix and ('Mix' in ipath) and ('Audio Files' in ipath): return  # this is specific to the BDCT dataset, otherwise ignore\n",
133 |     "        if ipath in filename:\n",
134 |     "            last_ipath = ipath.split('/')[-1]           # get the last part of ipath\n",
135 |     "            clean_filename = filename.replace(ipath,'') # remove all of ipath from the front of filename\n",
136 |     "            new_filename = f\"{output_path}/{last_ipath}/{clean_filename}\".replace('//','/') \n",
137 |     "            makedir(os.path.dirname(new_filename))      # we might need to make a directory for the output file\n",
138 |     "            break\n",
139 |     "\n",
140 |     "    if new_filename is None:\n",
141 |     "        print(f\"ERROR: Something went wrong with the name of input file {filename}. Skipping.\",flush=True) \n",
142 |     "        return \n",
143 |     "    \n",
144 |     "    try:\n",
145 |     "        if args.debug: print(f\"   About to load filenames[{file_ind}] = {filename}\\n\", flush=True)\n",
146 |     "        audio = load_audio(filename, sr=args.sr, verbose=args.debug)\n",
147 |     "        if args.debug: print(f\"   We loaded the audio, audio.shape = {audio.shape}\\n   Calling blow_chunks...\", flush=True)\n",
148 |     "        blow_chunks(audio, new_filename, args.chunk_size, sr=args.sr, spacing=args.spacing, strip=args.strip, thresh=args.thresh, debug=args.debug)\n",
149 |     "    except Exception as e: \n",
150 |     "        print(f\"Error '{e}' while loading {filename} or writing chunks. Skipping.\", flush=True)\n",
151 |     "\n",
152 |     "    if args.debug: print(f\" --- File {file_ind}: {filename} completed.\\n\", flush=True)\n",
153 |     "    return"
154 |    ]
155 |   },
156 |   {
157 |    "cell_type": "markdown",
158 |    "id": "f3c24fb9",
159 |    "metadata": {},
160 |    "source": [
161 |     "Testing sequential execution, one file at a time:"
162 |    ]
163 |   },
164 |   {
165 |    "cell_type": "code",
166 |    "execution_count": null,
167 |    "id": "233f88e6",
168 |    "metadata": {},
169 |    "outputs": [
170 |     {
171 |      "name": "stdout",
172 |      "output_type": "stream",
173 |      "text": [
174 |       "filenames = ['examples/stereo_pewpew.mp3', 'examples/example.wav']\n",
175 |       "file 1/2: examples/stereo_pewpew.mp3:\n",
176 |       " --- process_one_file: filenames[0] = examples/stereo_pewpew.mp3\n",
177 |       "\n",
178 |       "   About to load filenames[0] = examples/stereo_pewpew.mp3\n",
179 |       "\n",
180 |       "Resampling examples/stereo_pewpew.mp3 from 44100.0 Hz to 48000 Hz\n",
181 |       "   We loaded the audio, audio.shape = torch.Size([2, 234505])\n",
182 |       "   Calling blow_chunks...\n",
183 |       "       blow_chunks: audio.shape = torch.Size([2, 234505])\n",
184 |       " --- File 0: examples/stereo_pewpew.mp3 completed.\n",
185 |       "\n",
186 |       "file 2/2: examples/example.wav:\n",
187 |       " --- process_one_file: filenames[1] = examples/example.wav\n",
188 |       "\n",
189 |       "   About to load filenames[1] = examples/example.wav\n",
190 |       "\n",
191 |       "Resampling examples/example.wav from 44100 Hz to 48000 Hz\n",
192 |       "   We loaded the audio, audio.shape = torch.Size([1, 55728])\n",
193 |       "   Calling blow_chunks...\n",
194 |       "       blow_chunks: audio.shape = torch.Size([1, 55728])\n",
195 |       " --- File 1: examples/example.wav completed.\n",
196 |       "\n"
197 |      ]
198 |     }
199 |    ],
200 |    "source": [
201 |     "class AttrDict(dict): # cf. 
https://stackoverflow.com/a/14620633/4259243\n", 202 | " \"setup an object to hold args\"\n", 203 | " def __init__(self, *args, **kwargs):\n", 204 | " super(AttrDict, self).__init__(*args, **kwargs)\n", 205 | " self.__dict__ = self\n", 206 | " \n", 207 | "args = AttrDict() # setup something akin to what argparse gives\n", 208 | "args.update( {'output_path':'test_chunks', 'input_paths':['examples/'], 'sr':48000, 'chunk_size':131072, 'spacing':0.5,\n", 209 | " 'norm':'global', 'strip':False, 'thresh':-70, 'nomix':False, 'verbose':True,\n", 210 | " 'workers':min(32, os.cpu_count() + 4), 'debug':True })\n", 211 | "\n", 212 | "filenames = get_audio_filenames(args.input_paths)\n", 213 | "print(\"filenames =\",filenames)\n", 214 | "for i in range(len(filenames)):\n", 215 | " print(f\"file {i+1}/{len(filenames)}: {filenames[i]}:\")\n", 216 | " chunk_one_file(filenames, args, i)" 217 | ] 218 | }, 219 | { 220 | "cell_type": "code", 221 | "execution_count": null, 222 | "id": "1bf27260", 223 | "metadata": {}, 224 | "outputs": [], 225 | "source": [ 226 | "#| hide\n", 227 | "from nbdev import nbdev_export\n", 228 | "nbdev_export()" 229 | ] 230 | }, 231 | { 232 | "cell_type": "code", 233 | "execution_count": null, 234 | "id": "193feea4", 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [] 238 | } 239 | ], 240 | "metadata": { 241 | "kernelspec": { 242 | "display_name": "Python 3 (ipykernel)", 243 | "language": "python", 244 | "name": "python3" 245 | } 246 | }, 247 | "nbformat": 4, 248 | "nbformat_minor": 5 249 | } 250 | -------------------------------------------------------------------------------- /04_spectrofu.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "3c35e2cf", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "code", 16 | "execution_count": null, 17 | "id": "9d2cefc8", 18 | "metadata": {}, 19 | "outputs": [], 20 | "source": [ 21 | "#| default_exp spectrofu" 22 | ] 23 | }, 24 | { 25 | "cell_type": "markdown", 26 | "id": "5d79fbe5", 27 | "metadata": {}, 28 | "source": [ 29 | "# spectrofu" 30 | ] 31 | }, 32 | { 33 | "cell_type": "markdown", 34 | "id": "262d835a", 35 | "metadata": {}, 36 | "source": [ 37 | "> Command-line script that preprocesses a dataset of audio and turns it into spectrograms. \n", 38 | "\n", 39 | "Assumes pre-chunking e.g. via `chunkadelic` --- This is pretty much a simplified duplicate of `chunkadelic`.\n", 40 | "\n", 41 | "Note: Duplicates the directory structure(s) referenced by input paths. 
" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "id": "a023138f", 48 | "metadata": {}, 49 | "outputs": [], 50 | "source": [ 51 | "#|hide\n", 52 | "from nbdev.showdoc import *" 53 | ] 54 | }, 55 | { 56 | "cell_type": "code", 57 | "execution_count": null, 58 | "id": "72c40976", 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "#|all_slow" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": null, 68 | "id": "5df76a87", 69 | "metadata": {}, 70 | "outputs": [], 71 | "source": [ 72 | "#|export\n", 73 | "import argparse \n", 74 | "from glob import glob \n", 75 | "from pathlib import Path\n", 76 | "import os \n", 77 | "import math\n", 78 | "from multiprocessing import Pool, cpu_count, Barrier\n", 79 | "from functools import partial\n", 80 | "from tqdm.contrib.concurrent import process_map \n", 81 | "import torch\n", 82 | "import torchaudio\n", 83 | "from aeiou.core import is_silence, load_audio, makedir, get_audio_filenames\n", 84 | "from aeiou.viz import audio_spectrogram_image" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": null, 90 | "id": "b662abdf", 91 | "metadata": {}, 92 | "outputs": [], 93 | "source": [ 94 | "#|export\n", 95 | "def save_stft(\n", 96 | " audio:torch.tensor, # long audio file to be chunked\n", 97 | " new_filename:str # stem of new filename(s) to be output as spectrogram images\n", 98 | " ):\n", 99 | " \"coverts audio to stft image and saves it\"\n", 100 | " im = audio_spectrogram_image(audio, justimage=True) # should already be a PIL image\n", 101 | " print(f\"saving new file = {new_filename}\")\n", 102 | " im.save(new_filename)\n", 103 | " return" 104 | ] 105 | }, 106 | { 107 | "cell_type": "code", 108 | "execution_count": null, 109 | "id": "56e01f54", 110 | "metadata": {}, 111 | "outputs": [], 112 | "source": [ 113 | "#|export\n", 114 | "def process_one_file(\n", 115 | " filenames:list, # list of filenames from which we'll pick one\n", 116 | " args, # output of argparse\n", 117 | " file_ind # index from filenames list to read from\n", 118 | " ):\n", 119 | " \"this turns one audio file into a spectrogram. left channel only for now\"\n", 120 | " filename = filenames[file_ind] # this is actually input_path+/+filename\n", 121 | " output_path, input_paths = args.output_path, args.input_paths\n", 122 | " new_filename = None\n", 123 | " \n", 124 | " for ipath in input_paths: # set up the output filename & any folders it needs\n", 125 | " if ipath in filename: # this just avoids repeats/ weirdness.\n", 126 | " last_ipath = ipath.split('/')[-1] # get the last part of ipath\n", 127 | " clean_filename = filename.replace(ipath,'') # remove all of ipath from the front of filename\n", 128 | " new_filename = f\"{output_path}/{last_ipath}/{clean_filename}\".replace('//','/')\n", 129 | " new_filename = str(Path(new_filename).with_suffix(\".png\")) # give it file extension for image\n", 130 | " makedir(os.path.dirname(new_filename)) # we might need to make a directory for the output file\n", 131 | " break\n", 132 | " \n", 133 | " if new_filename is None:\n", 134 | " print(f\"ERROR: Something went wrong with name of input file {filename}. Skipping.\",flush=True) \n", 135 | " return \n", 136 | "\n", 137 | " try:\n", 138 | " audio = load_audio(filename, sr=args.sr)\n", 139 | " save_stft(audio, new_filename)\n", 140 | " except Exception as e: \n", 141 | " print(f\"Some kind of error happened with {filename}, either loading or writing images. 
Skipping.\", flush=True)\n", 142 | "\n", 143 | " return\n", 144 | "\n", 145 | "\n", 146 | "def main():\n", 147 | " parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n", 148 | " parser.add_argument('--sr', type=int, default=48000, help='Output sample rate')\n", 149 | " parser.add_argument('--workers', type=int, default=min(32, os.cpu_count() + 4), help='Maximum number of workers to use (default: all)')\n", 150 | " parser.add_argument('output_path', help='Path of output for spectrogram-ified data')\n", 151 | " parser.add_argument('input_paths', nargs='+', help='Path(s) of a file or a folder of files. (recursive)')\n", 152 | " args = parser.parse_args()\n", 153 | "\n", 154 | " print(f\" output_path = {args.output_path}\")\n", 155 | "\n", 156 | " print(\"Getting list of input filenames\")\n", 157 | " filenames = get_audio_filenames(args.input_paths) \n", 158 | " n = len(filenames) \n", 159 | " print(f\" Got {n} input filenames\") \n", 160 | "\n", 161 | " print(\"Processing files (in parallel)\")\n", 162 | " wrapper = partial(process_one_file, filenames, args)\n", 163 | " r = process_map(wrapper, range(0, n), chunksize=1, max_workers=args.workers) # different chunksize used by tqdm. max_workers is to avoid annoying other ppl\n", 164 | "\n", 165 | " print(\"Finished\")" 166 | ] 167 | }, 168 | { 169 | "cell_type": "code", 170 | "execution_count": null, 171 | "id": "a159380a", 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "usage: spectrofu [-h] [--sr SR] [--workers WORKERS]\n", 179 | " output_path input_paths [input_paths ...]\n", 180 | "\n", 181 | "positional arguments:\n", 182 | " output_path Path of output for spectrogram-ified data\n", 183 | " input_paths Path(s) of a file or a folder of files. (recursive)\n", 184 | "\n", 185 | "options:\n", 186 | " -h, --help show this help message and exit\n", 187 | " --sr SR Output sample rate (default: 48000)\n", 188 | " --workers WORKERS Maximum number of workers to use (default: all) (default:\n", 189 | " 14)\n" 190 | ] 191 | } 192 | ], 193 | "source": [ 194 | "! 
spectrofu -h "
195 |    ]
196 |   },
197 |   {
198 |    "cell_type": "code",
199 |    "execution_count": null,
200 |    "id": "1bf27260",
201 |    "metadata": {},
202 |    "outputs": [],
203 |    "source": [
204 |     "#| hide\n",
205 |     "from nbdev import nbdev_export\n",
206 |     "nbdev_export()"
207 |    ]
208 |   },
209 |   {
210 |    "cell_type": "code",
211 |    "execution_count": null,
212 |    "id": "afc3ed0d-d9c8-4916-b1a4-10879c95cd73",
213 |    "metadata": {},
214 |    "outputs": [],
215 |    "source": []
216 |   }
217 |  ],
218 |  "metadata": {
219 |   "kernelspec": {
220 |    "display_name": "Python 3 (ipykernel)",
221 |    "language": "python",
222 |    "name": "python3"
223 |   }
224 |  },
225 |  "nbformat": 4,
226 |  "nbformat_minor": 5
227 | }
228 | 
--------------------------------------------------------------------------------
/05_hpc.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "%load_ext autoreload\n",
10 |     "%autoreload 2"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "code",
15 |    "execution_count": null,
16 |    "metadata": {},
17 |    "outputs": [],
18 |    "source": [
19 |     "#| default_exp hpc"
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "markdown",
24 |    "metadata": {},
25 |    "source": [
26 |     "# hpc\n",
27 |     "\n",
28 |     "> routines for running on clusters"
29 |    ]
30 |   },
31 |   {
32 |    "cell_type": "markdown",
33 |    "metadata": {},
34 |    "source": [
35 |     "This part isn't strictly for audio i/o, but is nevertheless a normal part of Harmonai's operations. The point of this package is to reduce code-copying between Harmonai projects. \n",
36 |     "\n",
37 |     "**Heads up**: Huggingface `accelerate` support will likely be *deprecated* soon. We found `accelerate` necessary because of problems running PyTorch Lightning on multiple nodes, but those problems have now been resolved. Thus we will likely be using Lightning, so you will see that dependency being added and perhaps accelerate being removed. "
38 |    ]
39 |   },
40 |   {
41 |    "cell_type": "code",
42 |    "execution_count": null,
43 |    "metadata": {},
44 |    "outputs": [],
45 |    "source": [
46 |     "#|hide\n",
47 |     "from nbdev.showdoc import *"
48 |    ]
49 |   },
50 |   {
51 |    "cell_type": "code",
52 |    "execution_count": null,
53 |    "metadata": {},
54 |    "outputs": [
55 |     {
56 |      "name": "stderr",
57 |      "output_type": "stream",
58 |      "text": [
59 |       "/Users/shawley/envs/aeiou/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", 60 | " from .autonotebook import tqdm as notebook_tqdm\n" 61 | ] 62 | } 63 | ], 64 | "source": [ 65 | "#|export \n", 66 | "import yaml\n", 67 | "import accelerate\n", 68 | "from pathlib import Path\n", 69 | "import tqdm\n", 70 | "import torch\n", 71 | "import torchaudio\n", 72 | "from torchaudio import transforms as T\n", 73 | "import os\n" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | "execution_count": null, 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "#|export \n", 83 | "def get_accel_config(filename='~/.cache/huggingface/accelerate/default_config.yaml'):\n", 84 | " \"get huggingface accelerate config info\" \n", 85 | " try: # first try to use the default file\n", 86 | " filename = filename.replace('~', str(Path.home()))\n", 87 | " with open(filename, 'r') as file:\n", 88 | " ac = yaml.safe_load(file)\n", 89 | " except OSError:\n", 90 | " ac = {}\n", 91 | " \n", 92 | " # then update using any environment variables\n", 93 | " if os.getenv('MAIN_PROCESS_IP') is not None: ac['main_process_ip'] = os.getenv('MAIN_PROCESS_IP')\n", 94 | " if os.getenv('MACHINE_RANK') is not None: ac['machine_rank'] = os.getenv('MACHINE_RANK')\n", 95 | " if os.getenv('NUM_MACHINES') is not None: ac['num_machines'] = os.getenv('NUM_MACHINES')\n", 96 | " if os.getenv('NUM_PROCESSES') is not None: ac['num_processes'] = os.getenv('NUM_PROCESSES')\n", 97 | "\n", 98 | " return ac" 99 | ] 100 | }, 101 | { 102 | "cell_type": "markdown", 103 | "metadata": {}, 104 | "source": [ 105 | "Let's test that:" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [ 113 | { 114 | "data": { 115 | "text/plain": [ 116 | "{'compute_environment': 'LOCAL_MACHINE',\n", 117 | " 'deepspeed_config': {},\n", 118 | " 'distributed_type': 'MULTI_GPU',\n", 119 | " 'fsdp_config': {},\n", 120 | " 'machine_rank': 0,\n", 121 | " 'main_process_ip': '',\n", 122 | " 'main_process_port': 12332,\n", 123 | " 'main_training_function': 'main',\n", 124 | " 'mixed_precision': 'no',\n", 125 | " 'num_machines': 2,\n", 126 | " 'num_processes': 8,\n", 127 | " 'use_cpu': False}" 128 | ] 129 | }, 130 | "execution_count": null, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "ac = get_accel_config('examples/accel_config.yaml')\n", 137 | "ac" 138 | ] 139 | }, 140 | { 141 | "cell_type": "markdown", 142 | "metadata": {}, 143 | "source": [ 144 | "Next is a little utility to replace `print`, where it'll only print on the cluster headnode. Note that you can only send one string to `hprint`, so use f-strings. Also we use ANSI codes to color the text (currently cyan) to help it stand out from all the other text that's probably scrolling by!" 
145 |    ]
146 |   },
147 |   {
148 |    "cell_type": "code",
149 |    "execution_count": null,
150 |    "metadata": {},
151 |    "outputs": [],
152 |    "source": [
153 |     "#|export \n",
154 |     "class HostPrinter():\n",
155 |     "    \"lil accelerate utility for only printing on host node\"\n",
156 |     "    def __init__(\n",
157 |     "        self, \n",
158 |     "        accelerator,     # huggingface accelerator object\n",
159 |     "        tag='\\033[96m',  # starting color\n",
160 |     "        untag='\\033[0m'  # reset to default color\n",
161 |     "        ): \n",
162 |     "        self.accelerator, self.tag, self.untag = accelerator, tag, untag\n",
163 |     "    def __call__(self, s:str):\n",
164 |     "        if self.accelerator.is_main_process:\n",
165 |     "            print(self.tag + s + self.untag, flush=True)"
166 |    ]
167 |   },
168 |   {
169 |    "cell_type": "markdown",
170 |    "metadata": {},
171 |    "source": [
172 |     "Here's a test:"
173 |    ]
174 |   },
175 |   {
176 |    "cell_type": "code",
177 |    "execution_count": null,
178 |    "metadata": {},
179 |    "outputs": [
180 |     {
181 |      "name": "stdout",
182 |      "output_type": "stream",
183 |      "text": [
184 |       "\u001b[96mUsing device: cpu\u001b[0m\n"
185 |      ]
186 |     }
187 |    ],
188 |    "source": [
189 |     "accelerator = accelerate.Accelerator()\n",
190 |     "device = accelerator.device\n",
191 |     "hprint = HostPrinter(accelerator)  # hprint only prints on head node\n",
192 |     "hprint(f'Using device: {device}')"
193 |    ]
194 |   },
195 |   {
196 |    "cell_type": "markdown",
197 |    "metadata": {},
198 |    "source": [
199 |     "## PyTorch+Accelerate Model routines\n",
200 |     "For when the model is wrapped in an `accelerate` accelerator"
201 |    ]
202 |   },
203 |   {
204 |    "cell_type": "code",
205 |    "execution_count": null,
206 |    "metadata": {},
207 |    "outputs": [],
208 |    "source": [
209 |     "#|export \n",
210 |     "def save(\n",
211 |     "    accelerator,  # Huggingface accelerator object\n",
212 |     "    args,         # prefigure args dict, (we only use args.name)\n",
213 |     "    model,        # the model, pre-unwrapped\n",
214 |     "    opt=None,     # optimizer state\n",
215 |     "    epoch=None,   # training epoch number\n",
216 |     "    step=None     # training step number\n",
217 |     "    ):\n",
218 |     "    \"for checkpointing & model saves\"\n",
219 |     "    #accelerator.wait_for_everyone() # hangs\n",
220 |     "    filename = f'{args.name}_{step:08}.pth' if (step is not None) else f'{args.name}.pth'\n",
221 |     "    if accelerator.is_main_process:\n",
222 |     "        print(f'\\nSaving checkpoint to {filename}...')\n",
223 |     "        obj = {'model': accelerator.unwrap_model(model).state_dict() }\n",
224 |     "        if opt is not None:   obj['opt']   = opt.state_dict()\n",
225 |     "        if epoch is not None: obj['epoch'] = epoch\n",
226 |     "        if step is not None:  obj['step']  = step\n",
227 |     "        accelerator.save(obj, filename)"
228 |    ]
229 |   },
230 |   {
231 |    "cell_type": "code",
232 |    "execution_count": null,
233 |    "metadata": {},
234 |    "outputs": [],
235 |    "source": [
236 |     "#|export \n",
237 |     "def load(\n",
238 |     "    accelerator,   # Huggingface accelerator object\n",
239 |     "    model,         # an uninitialized model (pre-unwrapped) whose weights will be overwritten\n",
240 |     "    filename:str,  # name of the checkpoint file\n",
241 |     "    opt=None,      # optimizer state UNUSED FOR NOW\n",
242 |     "    ):\n",
243 |     "    \"load a saved model checkpoint\"\n",
244 |     "    #accelerator.wait_for_everyone() # hangs\n",
245 |     "    if accelerator.is_main_process:\n",
246 |     "        print(f'\\nLoading checkpoint from {filename}...')\n",
247 |     "    accelerator.unwrap_model(model).load_state_dict(torch.load(filename)['model'])\n",
248 |     "    return model # this return isn't actually needed since model is already updated at this point"
249 |    ]
250 |   },
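  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here's a minimal, hedged sketch of how `save` and `load` pair up, assuming a toy model and an illustrative checkpoint name (`demo` is not a real project name):"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#| eval: false\n",
    "from types import SimpleNamespace\n",
    "\n",
    "model = accelerator.prepare(torch.nn.Linear(8, 8))  # toy model, just for illustration\n",
    "args = SimpleNamespace(name='demo')                 # save() only reads args.name\n",
    "save(accelerator, args, model, step=100)            # writes demo_00000100.pth on the main process\n",
    "model = load(accelerator, model, 'demo_00000100.pth')"
   ]
  },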
| "metadata": {}, 254 | "source": [ 255 | "## Utils for Accelerate or Lightning\n", 256 | "Be sure to use \"unwrap\" any accelerate model when calling these" 257 | ] 258 | }, 259 | { 260 | "cell_type": "code", 261 | "execution_count": null, 262 | "metadata": {}, 263 | "outputs": [], 264 | "source": [ 265 | "#|export \n", 266 | "def n_params(\n", 267 | " module # raw PyTorch model/module, e.g. returned by accelerator.unwrap_model()\n", 268 | " ):\n", 269 | " \"\"\"Returns the number of trainable parameters in a module.\n", 270 | " Be sure to use accelerator.unwrap_model when calling this. \"\"\"\n", 271 | " return sum(p.numel() for p in module.parameters())" 272 | ] 273 | }, 274 | { 275 | "cell_type": "code", 276 | "execution_count": null, 277 | "metadata": {}, 278 | "outputs": [], 279 | "source": [ 280 | "#|export \n", 281 | "def freeze(\n", 282 | " model # raw PyTorch model, e.g. returned by accelerator.unwrap_model()\n", 283 | " ):\n", 284 | " \"\"\"freezes model weights; turns off gradient info\n", 285 | " If using accelerate, call thisaccelerator.unwrap_model when calling this. \"\"\"\n", 286 | " for param in model.parameters(): \n", 287 | " param.requires_grad = False" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": null, 293 | "metadata": {}, 294 | "outputs": [], 295 | "source": [ 296 | "#| hide\n", 297 | "from nbdev import nbdev_export\n", 298 | "nbdev_export()" 299 | ] 300 | }, 301 | { 302 | "cell_type": "code", 303 | "execution_count": null, 304 | "metadata": {}, 305 | "outputs": [], 306 | "source": [] 307 | } 308 | ], 309 | "metadata": { 310 | "kernelspec": { 311 | "display_name": "aa", 312 | "language": "python", 313 | "name": "aa" 314 | } 315 | }, 316 | "nbformat": 4, 317 | "nbformat_minor": 4 318 | } 319 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 
34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | aeiou 2 | ================ 3 | 4 | 5 | 6 | Pronounced “[ayoo](https://youtu.be/Hv6RbEOlqRo?t=24)” 7 | 8 | ## Install 9 | 10 | It is recommended you install the latest version from GitHub via 11 | 12 | ``` sh 13 | pip install git+https://github.com/drscotthawley/aeiou.git 14 | ``` 15 | 16 | However, binaries will occasionally be uploaded to PyPI, and can be installed via 17 | 18 | ``` sh 19 | pip install aeiou 20 | ``` 21 | 22 | ## How to use 23 | 24 | This is a series of utility routines developed in support of multiple 25 | projects within the [Harmonai](https://www.harmonai.org/) organization. 26 | See individual documentation pages for more specific instructions on how 27 | these can be used, and see the short example at the end of this README. Note that this is *research code*, so it’s a) in flux 28 | and b) in need of improvements to documentation. 29 | 30 | ## Documentation 31 | 32 | Documentation for this library is hosted on the [aeiou GitHub Pages 33 | site](https://drscotthawley.github.io/aeiou/). 34 | 35 | ## Contributing 36 | 37 | Contributions are welcome – especially for improvements to 38 | documentation! To contribute: 39 | 40 | 1. Fork this repo and then clone your fork to your local machine. 41 | 42 | 2. Create a new (local) branch: `git checkout -b mybranch` (or whatever you want 43 | to call it). 44 | 45 | 3. This library is written entirely in [nbdev](https://nbdev.fast.ai/) 46 | version 2, using Jupyter notebooks. 47 | 48 | 4. [Install nbdev](https://nbdev.fast.ai/getting_started.html#install) 49 | and then you can edit the Jupyter notebooks. 50 | 51 | **NOTE:** Edit the notebook (`.ipynb`) files, *not* the `.py` files, as the latter get overwritten by `nbdev`. 52 | 53 | 5. After editing notebooks, run `nbdev_prepare`. 54 | 55 | 6. If that succeeds, you can do 56 | `git add *.ipynb aeiou/*.py; git commit` and then `git push` to get 57 | your changes back to your fork on GitHub. 58 | 59 | 7. Then send a Pull Request from your fork to the `dev` branch of this original `aeiou` 60 | repository. 61 | 62 | ## Attribution 63 | 64 | Please include attribution of this code if you reproduce sections of it 65 | in your own code: 66 | 67 | aeiou: audio engineering i/o utilities: Copyright (c) Scott H. Hawley, 2022-2023. https://github.com/drscotthawley/aeiou 68 | 69 | In research papers, please cite this software if you find it useful: 70 | 71 | ``` bibtex 72 | @misc{aeiou, 73 | author = {Scott H. Hawley}, 74 | title = {aeiou: audio engineering i/o utilities}, 75 | year = {2022}, 76 | url = {https://github.com/drscotthawley/aeiou}, 77 | } 78 | ``` 79 | 80 | Copyright (c) Scott H. Hawley, 2022-2023. 81 | 82 | ## License 83 | 84 | [License](https://github.com/drscotthawley/aeiou/blob/main/LICENSE) is 85 | APACHE 2.0.
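
## Quick example

A minimal sketch of how a few of these utilities fit together (run from the repo root so the bundled `examples/example.wav` file is present; see the documentation pages for real usage):

``` python
from aeiou.core import get_device, load_audio, batch_it_crazy

device = get_device()                                 # suggests cuda / mps / cpu, whichever is available
audio = load_audio('examples/example.wav', sr=48000)  # torch tensor of shape [channels, samples]
batch = batch_it_crazy(audio, 2**16)                  # chop long audio into a batch of equal-length windows
print(device, audio.shape, batch.shape)
```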
86 | -------------------------------------------------------------------------------- /_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | css: styles.css 8 | toc: true 9 | 10 | website: 11 | twitter-card: true 12 | open-graph: true 13 | repo-actions: [issue] 14 | navbar: 15 | background: primary 16 | search: true 17 | sidebar: 18 | style: floating 19 | 20 | metadata-files: [nbdev.yml, sidebar.yml] -------------------------------------------------------------------------------- /aeiou/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.21" 2 | -------------------------------------------------------------------------------- /aeiou/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'main', 4 | 'doc_baseurl': '/aeiou/', 5 | 'doc_host': 'https://drscotthawley.github.io', 6 | 'git_url': 'https://github.com/drscotthawley/aeiou/', 7 | 'lib_path': 'aeiou'}, 8 | 'syms': { 'aeiou.chunk_one_file': { 'aeiou.chunk_one_file.blow_chunks': ('chunk_one_file.html#blow_chunks', 'aeiou/chunk_one_file.py'), 9 | 'aeiou.chunk_one_file.chunk_one_file': ( 'chunk_one_file.html#chunk_one_file', 10 | 'aeiou/chunk_one_file.py')}, 11 | 'aeiou.chunkadelic': { 'aeiou.chunkadelic.blow_chunks': ('chunkadelic.html#blow_chunks', 'aeiou/chunkadelic.py'), 12 | 'aeiou.chunkadelic.chunk_one_file': ('chunkadelic.html#chunk_one_file', 'aeiou/chunkadelic.py'), 13 | 'aeiou.chunkadelic.main': ('chunkadelic.html#main', 'aeiou/chunkadelic.py'), 14 | 'aeiou.chunkadelic.set_bit_rate': ('chunkadelic.html#set_bit_rate', 'aeiou/chunkadelic.py')}, 15 | 'aeiou.core': { 'aeiou.core.audio_float_to_int': ('core.html#audio_float_to_int', 'aeiou/core.py'), 16 | 'aeiou.core.batch_it_crazy': ('core.html#batch_it_crazy', 'aeiou/core.py'), 17 | 'aeiou.core.fast_scandir': ('core.html#fast_scandir', 'aeiou/core.py'), 18 | 'aeiou.core.get_audio_filenames': ('core.html#get_audio_filenames', 'aeiou/core.py'), 19 | 'aeiou.core.get_dbmax': ('core.html#get_dbmax', 'aeiou/core.py'), 20 | 'aeiou.core.get_device': ('core.html#get_device', 'aeiou/core.py'), 21 | 'aeiou.core.get_latest_ckpt': ('core.html#get_latest_ckpt', 'aeiou/core.py'), 22 | 'aeiou.core.get_run_info': ('core.html#get_run_info', 'aeiou/core.py'), 23 | 'aeiou.core.is_silence': ('core.html#is_silence', 'aeiou/core.py'), 24 | 'aeiou.core.is_tool': ('core.html#is_tool', 'aeiou/core.py'), 25 | 'aeiou.core.load_audio': ('core.html#load_audio', 'aeiou/core.py'), 26 | 'aeiou.core.makedir': ('core.html#makedir', 'aeiou/core.py'), 27 | 'aeiou.core.normalize_audio': ('core.html#normalize_audio', 'aeiou/core.py'), 28 | 'aeiou.core.rnd_string': ('core.html#rnd_string', 'aeiou/core.py'), 29 | 'aeiou.core.untuple': ('core.html#untuple', 'aeiou/core.py')}, 30 | 'aeiou.datasets': { 'aeiou.datasets.AudioDataset': ('datasets.html#audiodataset', 'aeiou/datasets.py'), 31 | 'aeiou.datasets.AudioDataset.__getitem__': ('datasets.html#audiodataset.__getitem__', 'aeiou/datasets.py'), 32 | 'aeiou.datasets.AudioDataset.__init__': ('datasets.html#audiodataset.__init__', 'aeiou/datasets.py'), 33 | 'aeiou.datasets.AudioDataset.__len__': ('datasets.html#audiodataset.__len__', 'aeiou/datasets.py'), 34 | 'aeiou.datasets.AudioDataset.get_data_range': ( 'datasets.html#audiodataset.get_data_range', 35 | 'aeiou/datasets.py'), 36 | 
'aeiou.datasets.AudioDataset.get_next_chunk': ( 'datasets.html#audiodataset.get_next_chunk', 37 | 'aeiou/datasets.py'), 38 | 'aeiou.datasets.AudioDataset.load_file_ind': ( 'datasets.html#audiodataset.load_file_ind', 39 | 'aeiou/datasets.py'), 40 | 'aeiou.datasets.AudioDataset.preload_files': ( 'datasets.html#audiodataset.preload_files', 41 | 'aeiou/datasets.py'), 42 | 'aeiou.datasets.AudioWebDataLoader': ('datasets.html#audiowebdataloader', 'aeiou/datasets.py'), 43 | 'aeiou.datasets.FillTheNoise': ('datasets.html#fillthenoise', 'aeiou/datasets.py'), 44 | 'aeiou.datasets.FillTheNoise.__call__': ('datasets.html#fillthenoise.__call__', 'aeiou/datasets.py'), 45 | 'aeiou.datasets.FillTheNoise.__init__': ('datasets.html#fillthenoise.__init__', 'aeiou/datasets.py'), 46 | 'aeiou.datasets.IterableAudioDataset': ('datasets.html#iterableaudiodataset', 'aeiou/datasets.py'), 47 | 'aeiou.datasets.IterableAudioDataset.__init__': ( 'datasets.html#iterableaudiodataset.__init__', 48 | 'aeiou/datasets.py'), 49 | 'aeiou.datasets.IterableAudioDataset.__iter__': ( 'datasets.html#iterableaudiodataset.__iter__', 50 | 'aeiou/datasets.py'), 51 | 'aeiou.datasets.Mono': ('datasets.html#mono', 'aeiou/datasets.py'), 52 | 'aeiou.datasets.Mono.__call__': ('datasets.html#mono.__call__', 'aeiou/datasets.py'), 53 | 'aeiou.datasets.NormInputs': ('datasets.html#norminputs', 'aeiou/datasets.py'), 54 | 'aeiou.datasets.NormInputs.__call__': ('datasets.html#norminputs.__call__', 'aeiou/datasets.py'), 55 | 'aeiou.datasets.NormInputs.__init__': ('datasets.html#norminputs.__init__', 'aeiou/datasets.py'), 56 | 'aeiou.datasets.PadCrop': ('datasets.html#padcrop', 'aeiou/datasets.py'), 57 | 'aeiou.datasets.PadCrop.__call__': ('datasets.html#padcrop.__call__', 'aeiou/datasets.py'), 58 | 'aeiou.datasets.PadCrop.__init__': ('datasets.html#padcrop.__init__', 'aeiou/datasets.py'), 59 | 'aeiou.datasets.PadCrop.draw_chunk': ('datasets.html#padcrop.draw_chunk', 'aeiou/datasets.py'), 60 | 'aeiou.datasets.PadCrop_Normalized_T': ('datasets.html#padcrop_normalized_t', 'aeiou/datasets.py'), 61 | 'aeiou.datasets.PadCrop_Normalized_T.__call__': ( 'datasets.html#padcrop_normalized_t.__call__', 62 | 'aeiou/datasets.py'), 63 | 'aeiou.datasets.PadCrop_Normalized_T.__init__': ( 'datasets.html#padcrop_normalized_t.__init__', 64 | 'aeiou/datasets.py'), 65 | 'aeiou.datasets.PadCrop_Normalized_T_old': ('datasets.html#padcrop_normalized_t_old', 'aeiou/datasets.py'), 66 | 'aeiou.datasets.PadCrop_Normalized_T_old.__call__': ( 'datasets.html#padcrop_normalized_t_old.__call__', 67 | 'aeiou/datasets.py'), 68 | 'aeiou.datasets.PadCrop_Normalized_T_old.__init__': ( 'datasets.html#padcrop_normalized_t_old.__init__', 69 | 'aeiou/datasets.py'), 70 | 'aeiou.datasets.PhaseFlipper': ('datasets.html#phaseflipper', 'aeiou/datasets.py'), 71 | 'aeiou.datasets.PhaseFlipper.__call__': ('datasets.html#phaseflipper.__call__', 'aeiou/datasets.py'), 72 | 'aeiou.datasets.PhaseFlipper.__init__': ('datasets.html#phaseflipper.__init__', 'aeiou/datasets.py'), 73 | 'aeiou.datasets.RandMask1D': ('datasets.html#randmask1d', 'aeiou/datasets.py'), 74 | 'aeiou.datasets.RandMask1D.__init__': ('datasets.html#randmask1d.__init__', 'aeiou/datasets.py'), 75 | 'aeiou.datasets.RandMask1D.forward': ('datasets.html#randmask1d.forward', 'aeiou/datasets.py'), 76 | 'aeiou.datasets.RandMask1D.make_single_mask': ( 'datasets.html#randmask1d.make_single_mask', 77 | 'aeiou/datasets.py'), 78 | 'aeiou.datasets.RandMask1D.mask_once_1channel': ( 'datasets.html#randmask1d.mask_once_1channel', 79 | 
'aeiou/datasets.py'), 80 | 'aeiou.datasets.RandPool': ('datasets.html#randpool', 'aeiou/datasets.py'), 81 | 'aeiou.datasets.RandPool.__call__': ('datasets.html#randpool.__call__', 'aeiou/datasets.py'), 82 | 'aeiou.datasets.RandPool.__init__': ('datasets.html#randpool.__init__', 'aeiou/datasets.py'), 83 | 'aeiou.datasets.RandomGain': ('datasets.html#randomgain', 'aeiou/datasets.py'), 84 | 'aeiou.datasets.RandomGain.__call__': ('datasets.html#randomgain.__call__', 'aeiou/datasets.py'), 85 | 'aeiou.datasets.RandomGain.__init__': ('datasets.html#randomgain.__init__', 'aeiou/datasets.py'), 86 | 'aeiou.datasets.Stereo': ('datasets.html#stereo', 'aeiou/datasets.py'), 87 | 'aeiou.datasets.Stereo.__call__': ('datasets.html#stereo.__call__', 'aeiou/datasets.py'), 88 | 'aeiou.datasets.fix_double_slashes': ('datasets.html#fix_double_slashes', 'aeiou/datasets.py'), 89 | 'aeiou.datasets.get_all_s3_urls': ('datasets.html#get_all_s3_urls', 'aeiou/datasets.py'), 90 | 'aeiou.datasets.get_all_s3_urls_zach': ('datasets.html#get_all_s3_urls_zach', 'aeiou/datasets.py'), 91 | 'aeiou.datasets.get_contiguous_range': ('datasets.html#get_contiguous_range', 'aeiou/datasets.py'), 92 | 'aeiou.datasets.get_s3_contents': ('datasets.html#get_s3_contents', 'aeiou/datasets.py'), 93 | 'aeiou.datasets.get_wds_loader': ('datasets.html#get_wds_loader', 'aeiou/datasets.py'), 94 | 'aeiou.datasets.is_valid_sample': ('datasets.html#is_valid_sample', 'aeiou/datasets.py'), 95 | 'aeiou.datasets.log_and_continue': ('datasets.html#log_and_continue', 'aeiou/datasets.py'), 96 | 'aeiou.datasets.name_cache_file': ('datasets.html#name_cache_file', 'aeiou/datasets.py'), 97 | 'aeiou.datasets.pipeline_return': ('datasets.html#pipeline_return', 'aeiou/datasets.py'), 98 | 'aeiou.datasets.smoothstep': ('datasets.html#smoothstep', 'aeiou/datasets.py'), 99 | 'aeiou.datasets.smoothstep_box': ('datasets.html#smoothstep_box', 'aeiou/datasets.py'), 100 | 'aeiou.datasets.wds_preprocess': ('datasets.html#wds_preprocess', 'aeiou/datasets.py')}, 101 | 'aeiou.hpc': { 'aeiou.hpc.HostPrinter': ('hpc.html#hostprinter', 'aeiou/hpc.py'), 102 | 'aeiou.hpc.HostPrinter.__call__': ('hpc.html#hostprinter.__call__', 'aeiou/hpc.py'), 103 | 'aeiou.hpc.HostPrinter.__init__': ('hpc.html#hostprinter.__init__', 'aeiou/hpc.py'), 104 | 'aeiou.hpc.freeze': ('hpc.html#freeze', 'aeiou/hpc.py'), 105 | 'aeiou.hpc.get_accel_config': ('hpc.html#get_accel_config', 'aeiou/hpc.py'), 106 | 'aeiou.hpc.load': ('hpc.html#load', 'aeiou/hpc.py'), 107 | 'aeiou.hpc.n_params': ('hpc.html#n_params', 'aeiou/hpc.py'), 108 | 'aeiou.hpc.save': ('hpc.html#save', 'aeiou/hpc.py')}, 109 | 'aeiou.spectrofu': { 'aeiou.spectrofu.main': ('spectrofu.html#main', 'aeiou/spectrofu.py'), 110 | 'aeiou.spectrofu.process_one_file': ('spectrofu.html#process_one_file', 'aeiou/spectrofu.py'), 111 | 'aeiou.spectrofu.save_stft': ('spectrofu.html#save_stft', 'aeiou/spectrofu.py')}, 112 | 'aeiou.viz': { 'aeiou.viz.audio_spectrogram_image': ('viz.html#audio_spectrogram_image', 'aeiou/viz.py'), 113 | 'aeiou.viz.embeddings_table': ('viz.html#embeddings_table', 'aeiou/viz.py'), 114 | 'aeiou.viz.generate_melspec': ('viz.html#generate_melspec', 'aeiou/viz.py'), 115 | 'aeiou.viz.mel_spectrogram': ('viz.html#mel_spectrogram', 'aeiou/viz.py'), 116 | 'aeiou.viz.on_colab': ('viz.html#on_colab', 'aeiou/viz.py'), 117 | 'aeiou.viz.pca_point_cloud': ('viz.html#pca_point_cloud', 'aeiou/viz.py'), 118 | 'aeiou.viz.playable_spectrogram': ('viz.html#playable_spectrogram', 'aeiou/viz.py'), 119 | 'aeiou.viz.plot_jukebox_embeddings': 
('viz.html#plot_jukebox_embeddings', 'aeiou/viz.py'), 120 | 'aeiou.viz.point_cloud': ('viz.html#point_cloud', 'aeiou/viz.py'), 121 | 'aeiou.viz.print_stats': ('viz.html#print_stats', 'aeiou/viz.py'), 122 | 'aeiou.viz.proj_pca': ('viz.html#proj_pca', 'aeiou/viz.py'), 123 | 'aeiou.viz.project_down': ('viz.html#project_down', 'aeiou/viz.py'), 124 | 'aeiou.viz.setup_plotly': ('viz.html#setup_plotly', 'aeiou/viz.py'), 125 | 'aeiou.viz.show_pca_point_cloud': ('viz.html#show_pca_point_cloud', 'aeiou/viz.py'), 126 | 'aeiou.viz.show_point_cloud': ('viz.html#show_point_cloud', 'aeiou/viz.py'), 127 | 'aeiou.viz.spectrogram_image': ('viz.html#spectrogram_image', 'aeiou/viz.py'), 128 | 'aeiou.viz.tokens_spectrogram_image': ('viz.html#tokens_spectrogram_image', 'aeiou/viz.py')}}} 129 | -------------------------------------------------------------------------------- /aeiou/chunk_one_file.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../03b_chunk_one_file.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['blow_chunks', 'chunk_one_file'] 5 | 6 | # %% ../03b_chunk_one_file.ipynb 5 7 | import os 8 | import torch 9 | import torchaudio 10 | from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax 11 | 12 | # %% ../03b_chunk_one_file.ipynb 6 13 | def blow_chunks( 14 | audio:torch.tensor, # long audio file to be chunked 15 | new_filename:str, # stem of new filename(s) to be output as chunks 16 | chunk_size:int, # how big each audio chunk is, in samples 17 | sr=48000, # audio sample rate in Hz 18 | norm='False', # normalize input audio, based on the max of the absolute value ['global','channel', or anything else for None, e.g. False] 19 | spacing=0.5, # fraction of each chunk to advance between hops 20 | strip=False, # strip silence: chunks with max power in dB below this value will not be saved to files 21 | thresh=-70, # threshold in dB for determining what counts as silence 22 | debug=False, # print debugging information 23 | ): 24 | "chunks up the audio and saves them with --{i} on the end of each chunk filename" 25 | if (debug): print(f" blow_chunks: audio.shape = {audio.shape}",flush=True) 26 | 27 | chunk = torch.zeros(audio.shape[0], chunk_size) 28 | _, ext = os.path.splitext(new_filename) 29 | 30 | if norm in ['global','channel']: audio = normalize_audio(audio, norm) 31 | 32 | spacing = 0.5 if spacing == 0 else spacing # handle degenerate case as a request for the defaults 33 | 34 | start, i = 0, 0 35 | while start < audio.shape[-1]: 36 | out_filename = new_filename.replace(ext, f'--{i}'+ext) 37 | end = min(start + chunk_size, audio.shape[-1]) 38 | if end-start < chunk_size: # needs zero padding on end 39 | chunk = torch.zeros(audio.shape[0], chunk_size) 40 | chunk[:,0:end-start] = audio[:,start:end] 41 | if (not strip) or (not is_silence(chunk, thresh=thresh)): 42 | torchaudio.save(out_filename, chunk, sr) 43 | else: 44 | print(f"Skipping chunk {out_filename} because it's 'silent' (below threshold of {thresh} dB).",flush=True) 45 | start, i = start + int(spacing * chunk_size), i + 1 46 | return 47 | 48 | # %% ../03b_chunk_one_file.ipynb 7 49 | def chunk_one_file( 50 | filenames:list, # list of filenames from which we'll pick one 51 | args, # output of argparse 52 | file_ind # index from filenames list to read from 53 | ): 54 | "this chunks up one file by setting things up and then calling blow_chunks" 55 | filename = filenames[file_ind] # this is actually input_path+/+filename 56 |
output_path, input_paths = args.output_path, args.input_paths 57 | new_filename = None 58 | if args.debug: print(f" --- process_one_file: filenames[{file_ind}] = {filename}\n", flush=True) 59 | 60 | for ipath in input_paths: # set up the output filename & any folders it needs 61 | if args.nomix and ('Mix' in ipath) and ('Audio Files' in ipath): return # this is specific to the BDCT dataset, otherwise ignore 62 | if ipath in filename: 63 | last_ipath = ipath.split('/')[-1] # get the last part of ipath 64 | clean_filename = filename.replace(ipath,'') # remove all of ipath from the front of filename 65 | new_filename = f"{output_path}/{last_ipath}/{clean_filename}".replace('//','/') 66 | makedir(os.path.dirname(new_filename)) # we might need to make a directory for the output file 67 | break 68 | 69 | if new_filename is None: 70 | print(f"ERROR: Something went wrong with name of input file {filename}. Skipping.",flush=True) 71 | return 72 | 73 | try: 74 | if args.debug: print(f" About to load filenames[{file_ind}] = {filename}\n", flush=True) 75 | audio = load_audio(filename, sr=args.sr, verbose=args.debug) 76 | if args.debug: print(f" We loaded the audio, audio.shape = {audio.shape}\n Calling blow_chunks...", flush=True) 77 | blow_chunks(audio, new_filename, args.chunk_size, sr=args.sr, spacing=args.spacing, strip=args.strip, thresh=args.thresh, debug=args.debug) 78 | except Exception as e: 79 | print(f"Error '{e}' while loading {filename} or writing chunks. Skipping.", flush=True) 80 | 81 | if args.debug: print(f" --- File {file_ind}: {filename} completed.\n", flush=True) 82 | return 83 | -------------------------------------------------------------------------------- /aeiou/chunkadelic.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../03_chunkadelic.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['blow_chunks', 'set_bit_rate', 'chunk_one_file', 'main'] 5 | 6 | # %% ../03_chunkadelic.ipynb 5 7 | import argparse 8 | import os 9 | from functools import partial 10 | from tqdm.contrib.concurrent import process_map 11 | import torch 12 | import torchaudio 13 | import math 14 | from .core import is_silence, load_audio, makedir, get_audio_filenames, normalize_audio, get_dbmax 15 | import multiprocessing as mp 16 | from multiprocessing import Pool, cpu_count, Barrier 17 | 18 | # %% ../03_chunkadelic.ipynb 6 19 | def blow_chunks( 20 | audio:torch.tensor, # long audio file to be chunked 21 | new_filename:str, # stem of new filename(s) to be output as chunks 22 | chunk_size:int, # how big each audio chunk is, in samples 23 | sr=48000, # audio sample rate in Hz 24 | norm='False', # normalize input audio, based on the max of the absolute value ['global','channel', or anything else for None, e.g. 
False] 25 | spacing=0.5, # fraction of each chunk to advance between hops 26 | strip=False, # strip silence: chunks with max power in dB below this value will not be saved to files 27 | thresh=-70, # threshold in dB for determining what counts as silence 28 | bits_per_sample=None, # kwarg for torchaudio.save, None means use defaults 29 | nopad=False, # disable zero-padding, allowing samples to be shorter than chunk_size (including "leftovers" on the "ends") 30 | debug=False, # print debugging information 31 | ): 32 | "chunks up the audio and saves them with --{i} on the end of each chunk filename" 33 | if (debug): print(f" blow_chunks: audio.shape = {audio.shape}",flush=True) 34 | 35 | #chunk = torch.zeros(audio.shape[0], chunk_size) 36 | _, ext = os.path.splitext(new_filename) 37 | 38 | if norm in ['global','channel']: audio = normalize_audio(audio, norm) 39 | 40 | spacing = 0.5 if spacing == 0 else spacing # handle degenerate case as a request for the defaults 41 | 42 | start, i = 0, 0 43 | while start < audio.shape[-1]: 44 | out_filename = new_filename.replace(ext, f'--{i}'+ext) 45 | end = min(start + chunk_size, audio.shape[-1]) 46 | if (end-start < chunk_size) and not nopad: # audio shorter than chunk_size: pad with zeros 47 | chunk = torch.zeros(audio.shape[0], chunk_size) 48 | chunk[:,0:end-start] = audio[:,start:end] 49 | else: 50 | chunk = audio[:,start:end] 51 | if (not strip) or (not is_silence(chunk, thresh=thresh)): 52 | if debug: print(f" Saving output chunk {out_filename}, bits_per_sample={bits_per_sample}, chunk.shape={chunk.shape}", flush=True) 53 | torchaudio.save(out_filename, chunk, sr, bits_per_sample=bits_per_sample) 54 | else: 55 | print(f"Skipping chunk {out_filename} because it's 'silent' (below threshold of {thresh} dB).",flush=True) 56 | start, i = start + int(spacing * chunk_size), i + 1 57 | return 58 | 59 | # %% ../03_chunkadelic.ipynb 7 60 | def set_bit_rate(bits, filename, debug=False): 61 | if (bits is None) or isinstance(bits, int): bits_per_sample = bits 62 | elif bits.lower()=='none': 63 | bits_per_sample = None # use torchaudio default 64 | elif bits.lower()=='match': 65 | try: 66 | bits_per_sample = torchaudio.info(filename).bits_per_sample 67 | except Exception as e: 68 | print(" Error with bits=match: Can't get audio metadata. Choosing default=None") 69 | bits_per_sample=None 70 | else: 71 | bits_per_sample = int(bits) 72 | if debug: print(" set_bit_rate: bits_per_sample =",bits_per_sample,flush=True) 73 | return bits_per_sample 74 |
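# Illustrative note (comments only, not part of the generated module): how the CLI's
# --bits option maps through set_bit_rate(), per the branches above:
#   set_bit_rate('None', fn)  -> None   (use torchaudio's default bit depth)
#   set_bit_rate('match', fn) -> bit depth read from fn via torchaudio.info(fn)
#   set_bit_rate('16', fn)    -> 16
#   set_bit_rate(None, fn)    -> None   (None and ints pass straight through unchanged)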
75 | # %% ../03_chunkadelic.ipynb 8 76 | def chunk_one_file( 77 | filenames:list, # list of filenames from which we'll pick one 78 | args, # output of argparse 79 | file_ind # index from filenames list to read from 80 | ): 81 | "this chunks up one file by setting things up and then calling blow_chunks" 82 | filename = filenames[file_ind] # this is actually input_path+/+filename 83 | output_path, input_paths = args.output_path, args.input_paths 84 | new_filename = None 85 | if args.debug: print(f" --- process_one_file: filenames[{file_ind}] = {filename}\n", flush=True) 86 | 87 | for ipath in input_paths: # set up the output filename & any folders it needs 88 | if args.nomix and ('Mix' in ipath) and ('Audio Files' in ipath): return # this is specific to the BDCT dataset, otherwise ignore 89 | if ipath in filename: 90 | last_ipath = ipath.split('/')[-1] # get the last part of ipath 91 | clean_filename = filename.replace(ipath,'') # remove all of ipath from the front of filename 92 | new_filename = f"{output_path}/{last_ipath}/{clean_filename}".replace('//','/') 93 | makedir(os.path.dirname(new_filename)) # we might need to make a directory for the output file 94 | break 95 | 96 | if new_filename is None: 97 | print(f"ERROR: Something went wrong with name of input file {filename}. Skipping.",flush=True) 98 | return 99 | 100 | try: # try to load the audio file and chunk it up 101 | if args.debug: print(f" About to load filenames[{file_ind}] = {filename}\n", flush=True) 102 | audio = load_audio(filename, sr=args.sr, verbose=args.debug) 103 | if args.debug: print(f" We loaded the audio, audio.shape = {audio.shape}. Setting bit rate.",flush=True) 104 | bits_per_sample = set_bit_rate(args.bits, filename, debug=args.debug) 105 | if args.debug: print(f" Bit rate set. Calling blow_chunks...", flush=True) 106 | blow_chunks(audio, new_filename, args.chunk_size, sr=args.sr, spacing=args.spacing, 107 | strip=args.strip, thresh=args.thresh, bits_per_sample=bits_per_sample, nopad=args.nopad, debug=args.debug) 108 | except Exception as e: 109 | print(f"Error '{e}' while loading {filename} or writing chunks.
Skipping.", flush=True) 110 | 111 | if args.debug: print(f" --- File {file_ind}: {filename} completed.\n", flush=True) 112 | return 113 | 114 | # %% ../03_chunkadelic.ipynb 12 115 | def main(): 116 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 117 | parser.add_argument('--chunk_size', type=int, default=2**17, help='Length of chunks') 118 | parser.add_argument('--sr', type=int, default=48000, help='Output sample rate') 119 | parser.add_argument('--norm', default='False', const='False', nargs='?', choices=['False', 'global', 'channel'], 120 | help='Normalize audio, based on the max of the absolute value [global/channel/False]') 121 | parser.add_argument('--spacing', type=float, default=0.5, help='Spacing factor, advance this fraction of a chunk per copy') 122 | parser.add_argument('--strip', action='store_true', help='Strips silence: chunks with max dB below are not outputted') 123 | parser.add_argument('--thresh', type=int, default=-70, help='threshold in dB for determining what constitutes silence') 124 | parser.add_argument('--bits', type=str, default='None', help='Bit depth: "None" uses torchaudio default | "match"=match input audio files | or specify an int') 125 | parser.add_argument('--workers', type=int, default=min(32, os.cpu_count() + 4), help='Maximum number of workers to use (default: all)') 126 | parser.add_argument('--nomix', action='store_true', help='(BDCT Dataset specific) exclude output of "*/Audio Files/*Mix*"') 127 | parser.add_argument('--nopad', action='store_true', help='Disable zero padding for audio shorter than chunk_size') 128 | parser.add_argument('output_path', help='Path of output for chunkified data') 129 | parser.add_argument('input_paths', nargs='+', help='Path(s) of a file or a folder of files. (recursive)') 130 | parser.add_argument('--verbose', action='store_true', help='Extra output logging') 131 | parser.add_argument('--debug', action='store_true', help='Extra EXTRA output logging') 132 | args = parser.parse_args() 133 | 134 | if args.verbose: 135 | print("chunkadelic: args = ",args) 136 | print("Getting list of input filenames") 137 | filenames = get_audio_filenames(args.input_paths) 138 | if args.verbose: 139 | print(f" Got {len(filenames)} input filenames") 140 | if not (args.norm in ['global','channel']): 141 | print(f"Warning: since norm = {args.norm}, no normalizations will be performed.") 142 | print("Processing files (in parallel)...") 143 | 144 | wrapper = partial(chunk_one_file, filenames, args) 145 | r = process_map(wrapper, range(len(filenames)), chunksize=1, max_workers=args.workers) # different chunksize used by tqdm. max_workers is to avoid annoying other ppl 146 | 147 | if args.verbose: print("Finished") 148 | -------------------------------------------------------------------------------- /aeiou/core.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../00_core.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['pdlbd_exts', 'get_device', 'is_tool', 'normalize_audio', 'load_audio', 'get_dbmax', 'audio_float_to_int', 5 | 'is_silence', 'batch_it_crazy', 'makedir', 'fast_scandir', 'get_audio_filenames', 'untuple', 6 | 'get_latest_ckpt', 'rnd_string', 'get_run_info'] 7 | 8 | # %% ../00_core.ipynb 4 9 | import torch 10 | import torchaudio 11 | from torchaudio import transforms as T 12 | from torch.nn import functional as F 13 | import numpy as np 14 | from librosa import load as lr_load 15 | from pedalboard.io import AudioFile, get_supported_read_formats 16 | import os 17 | import math 18 | from einops import rearrange 19 | import random 20 | import string 21 | import glob 22 | from pathlib import Path 23 | import warnings 24 | 25 | # %% ../00_core.ipynb 5 26 | def get_device(gpu_str=''): 27 | "utility to suggest which pytorch device to use" 28 | #return torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu') 29 | device_str = 'cpu' 30 | if torch.cuda.is_available(): 31 | device_str = 'cuda' if gpu_str=='' else f'cuda:{gpu_str}' 32 | elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available(): # must check for mps attr if using older pytorch 33 | device_str = 'mps' 34 | os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' 35 | return torch.device(device_str) 36 | 37 | # %% ../00_core.ipynb 7 38 | def is_tool(name): 39 | """Check whether `name` is on PATH and marked as executable.""" 40 | 41 | # from whichcraft import which 42 | from shutil import which 43 | 44 | return which(name) is not None 45 | 46 | # %% ../00_core.ipynb 9 47 | def normalize_audio( 48 | audio_in, # input array/tensor (numpy or Pytorch) 49 | norm='global', # global (use max-abs of whole clip) | channel (per-channel norm'd individually) | ''/None 50 | ): 51 | "normalize audio, based on the max of the absolute value" 52 | audio_out = audio_in.clone() if torch.is_tensor(audio_in) else audio_in.copy() # rudimentary PyTorch/NumPy support 53 | if ('global' == norm) or len(audio_in.shape)< 2: 54 | absmax = abs(audio_in).max() 55 | audio_out = 0.99*audio_in/absmax if absmax != 0 else audio_in # 0.99 = just below clipping 56 | elif 'channel' == norm: 57 | for c in range(audio_in.shape[0]): # this loop is slow but sure. TODO: do it fast but still avoid div by zero 58 | absmax = abs(audio_in[c]).max() 59 | audio_out[c] = 0.99*audio_in[c]/absmax if absmax != 0 else audio_in[c] # 0.99 = just below clipping 60 | #anything else, pass unchanged 61 | return audio_out 62 | 63 | # %% ../00_core.ipynb 19 64 | pdlbd_exts = None # stores supported pedalboard file extensions. 
Global so it updates once per run 65 | 66 | def load_audio( 67 | filename:str, # name of file to load 68 | sr=48000, # sample rate in Hz 69 | verbose=True, # whether or not to print notices of resampling 70 | norm='', # passed to normalize_audio(), see above 71 | )->torch.tensor: 72 | "loads an audio file as a torch tensor" 73 | global pdlbd_exts 74 | 75 | if '.mp3' in filename.lower(): # don't rely on torchaudio for mp3s 76 | pdlbd_exts = get_supported_read_formats() if pdlbd_exts==None else pdlbd_exts 77 | if '.mp3' in pdlbd_exts: # first try pedalboard's mp3 support 78 | with AudioFile(filename) as f: 79 | audio, in_sr = f.read(f.frames), f.samplerate 80 | else: 81 | if verbose: print("Warning: pedalboard mp3 support failed, falling back to librosa") 82 | audio, in_sr = lr_load(filename, mono=False, sr=sr) 83 | audio = torch.tensor(audio) 84 | else: 85 | audio, in_sr = torchaudio.load(filename) 86 | if in_sr != sr: 87 | if verbose: print(f"Resampling {filename} from {in_sr} Hz to {sr} Hz",flush=True) 88 | resample_tf = T.Resample(in_sr, sr) 89 | audio = resample_tf(audio) 90 | 91 | if norm in ['global','channel']: audio = normalize_audio(audio, norm=norm) 92 | return audio 93 | 94 | # %% ../00_core.ipynb 29 95 | def get_dbmax( 96 | audio, # torch tensor of (multichannel) audio 97 | ): 98 | "finds the loudest value in the entire clip and puts that into dB (full scale)" 99 | return 20*torch.log10(torch.flatten(audio.abs()).max()).cpu().numpy() 100 | 101 | # %% ../00_core.ipynb 32 102 | def audio_float_to_int(waveform): 103 | "converts torch float to numpy int16 (for playback in notebooks)" 104 | return np.clip( waveform.cpu().numpy()*32768 , -32768, 32767).astype('int16') # clip upper bound at 32767 = int16 max, to avoid overflow 105 | 106 | # %% ../00_core.ipynb 34 107 | def is_silence( 108 | audio, # torch tensor of (multichannel) audio 109 | thresh=-60, # threshold in dB below which we declare to be silence 110 | ): 111 | "checks if entire clip is 'silence' below some dB threshold" 112 | dBmax = get_dbmax(audio) 113 | return dBmax < thresh 114 | 115 | # %% ../00_core.ipynb 38 116 | def batch_it_crazy( 117 | x, # a time series as a PyTorch tensor, e.g. stereo or mono audio 118 | win_len, # length of each "window", i.e. length of each element in new batch 119 | ): 120 | "(pun intended) Chop up long sequence into a batch of win_len windows" 121 | if len(x.shape) < 2: x = x.unsqueeze(0) # guard against 1-d arrays 122 | x_len = x.shape[-1] 123 | n_windows = (x_len // win_len) + 1 124 | pad_amt = win_len * n_windows - x_len # pad end w. zeros to make lengths even when split 125 | xpad = F.pad(x, (0, pad_amt)) 126 | return rearrange(xpad, 'd (b n) -> b d n', n=win_len) 127 |
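# Illustrative note (comments only, not part of the generated module): batch_it_crazy()
# usage, e.g. for a mono clip of 100_000 samples split into windows of 2**16:
#   x = torch.randn(1, 100_000)
#   batch_it_crazy(x, 2**16).shape  ->  torch.Size([2, 1, 65536])   (end of last window is zero-padded)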
128 | # %% ../00_core.ipynb 45 129 | def makedir( 130 | path:str, # directory or nested set of directories 131 | ): 132 | "creates directories where they don't exist" 133 | if os.path.isdir(path): return # don't make it if it already exists 134 | #print(f" Making directory {path}") 135 | try: 136 | os.makedirs(path) # recursively make all dirs named in path 137 | except: # don't really care about errors 138 | pass 139 | 140 | # %% ../00_core.ipynb 47 141 | def fast_scandir( 142 | dir:str, # top-level directory at which to begin scanning 143 | ext:list # list of allowed file extensions 144 | ): 145 | "very fast `glob` alternative. from https://stackoverflow.com/a/59803793/4259243" 146 | subfolders, files = [], [] 147 | ext = ['.'+x if x[0]!='.' else x for x in ext] # add starting period to extensions if needed 148 | try: # hope to avoid 'permission denied' by this try 149 | for f in os.scandir(dir): 150 | try: # 'hope to avoid too many levels of symbolic links' error 151 | if f.is_dir(): 152 | subfolders.append(f.path) 153 | elif f.is_file(): 154 | if os.path.splitext(f.name)[1].lower() in ext: 155 | files.append(f.path) 156 | except: 157 | pass 158 | except: 159 | pass 160 | 161 | for dir in list(subfolders): 162 | sf, f = fast_scandir(dir, ext) 163 | subfolders.extend(sf) 164 | files.extend(f) 165 | return subfolders, files 166 | 167 | # %% ../00_core.ipynb 51 168 | def get_audio_filenames( 169 | paths:list # directories in which to search 170 | ): 171 | "recursively get a list of audio filenames" 172 | filenames = [] 173 | if type(paths) is str: paths = [paths] 174 | for path in paths: # get a list of relevant filenames 175 | subfolders, files = fast_scandir(path, ['.wav','.flac','.ogg','.aiff','.aif','.mp3']) 176 | filenames.extend(files) 177 | return filenames 178 | 179 | # %% ../00_core.ipynb 54 180 | def untuple(x, verbose=False): 181 | """Recursive. For when you're sick of tuples and lists: 182 | keeps peeling off elements until we get a non-tuple or non-list, 183 | i.e., returns the 'first' data element we can 'actually use'""" 184 | if isinstance(x, tuple) or isinstance(x, list): 185 | if verbose: print("yea: x = ",x) 186 | return untuple(x[0], verbose=verbose) 187 | else: 188 | if verbose: print("no: x = ",x) 189 | return x 190 | 191 | # %% ../00_core.ipynb 57 192 | def get_latest_ckpt( 193 | dir_tree, # directory tree under which to search for checkpoint files 194 | run_name_prefix='', # name of the run without its unique identifier 195 | sim_ckpts=[''], # string or list of strings. other paths to check under if nothing's in dir_tree 196 | verbose=True, # whether to print message(s) 197 | ): 198 | "This will grab the most recent checkpoint filename in dir tree given by name" 199 | list_of_files = list(Path(dir_tree).glob(f'**/{run_name_prefix}*/checkpoints/*.ckpt')) 200 | if [] != list_of_files: return max(list_of_files, key=os.path.getctime) 201 | print(f" Nothing relevant found in {dir_tree}. Checking also in {sim_ckpts}.") 202 | if isinstance(sim_ckpts, str): sim_ckpts = [sim_ckpts] 203 | for pattern in sim_ckpts: 204 | if verbose: print(" pattern = ",pattern) 205 | directories = [dir_path for dir_path in glob.glob(pattern) if not os.path.isfile(dir_path)] 206 | if verbose: print(" Also checking in ",directories) 207 | for directory in directories: 208 | if verbose: print(" directory = ",directory) 209 | list_of_files += list(Path(directory).glob(f'**/{run_name_prefix}*/checkpoints/*.ckpt')) 210 | if [] != list_of_files: return max(list_of_files, key=os.path.getctime) 211 | warnings.warn(" No matching checkpoint files found anywhere. Starting run from scratch.") 212 | return "" 213 |
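# Illustrative note (comments only, not part of the generated module): with checkpoints
# saved as <dir_tree>/<run_name>/checkpoints/*.ckpt,
#   get_latest_ckpt('logs', run_name_prefix='myrun')
# returns the newest file matching logs/**/myrun*/checkpoints/*.ckpt
# ('logs' and 'myrun' here are hypothetical names, not part of the library).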
case sensitive" 217 | raise DeprecationWarning("Better to generate random string in SLURM script") 218 | return ''.join(random.choice(string.ascii_letters+string.digits) for i in range(n)) 219 | 220 | def get_run_info(run_name, verbose=True): 221 | """ 222 | parses run_name into (ideally) prefix & id using underscore as separator, and/or fills in missing info if needed 223 | NOTE: do not trust generated strings to be same on other processes 224 | """ 225 | run_info = run_name.split('_') 226 | prefix = run_info[0] 227 | if len(run_info)>1: 228 | run_id = run_info[-1] 229 | else: 230 | run_id = rnd_string() 231 | if verbose: print(f"WARNING: generating random run_id as {run_id}. Might be different on different process(es).") 232 | raise DeprecationWarning("Instead, generate random string in SLURM script") 233 | new_run_name = f"{prefix}_{run_id}" if prefix !='' else f"{run_id}" 234 | return {'prefix': prefix, 'id':run_id, 'run_name':new_run_name} 235 | -------------------------------------------------------------------------------- /aeiou/datasets.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../01_datasets.ipynb. 2 | 3 | # %% ../01_datasets.ipynb 5 4 | from __future__ import annotations # for type hints, in LAION code samples 5 | import numpy as np 6 | import torch 7 | import torch.nn as nn 8 | import torchaudio 9 | from torchaudio import transforms as T 10 | from torchvision import transforms as VT 11 | import random 12 | import os 13 | import json 14 | import tqdm 15 | from multiprocessing import Pool, cpu_count 16 | from urllib.parse import urlparse 17 | from functools import partial 18 | from .core import load_audio, get_audio_filenames, is_silence, untuple 19 | from fastcore.utils import * 20 | import webdataset as wds 21 | import subprocess 22 | import re 23 | import pedalboard 24 | from typing import Tuple 25 | 26 | # %% auto 0 27 | __all__ = ['pipeline_return', 'RandomGain', 'PadCrop', 'PadCrop_Normalized_T_old', 'PadCrop_Normalized_T', 'PhaseFlipper', 28 | 'FillTheNoise', 'RandPool', 'NormInputs', 'Mono', 'Stereo', 'smoothstep', 'smoothstep_box', 'RandMask1D', 29 | 'AudioDataset', 'fix_double_slashes', 'get_s3_contents', 'get_contiguous_range', 'get_all_s3_urls', 30 | 'get_all_s3_urls_zach', 'IterableAudioDataset', 'name_cache_file', 'wds_preprocess', 'log_and_continue', 31 | 'is_valid_sample', 'AudioWebDataLoader', 'get_wds_loader'] 32 | 33 | # %% ../01_datasets.ipynb 8 34 | def pipeline_return( 35 | val, # value to be returned (by calling function) 36 | x, # original data-container that was passed in (tensor or dict) 37 | key='inputs', # if x is dict, this key gets overwritten/added 38 | ): 39 | "little helper routine that appears at end of most augmentations, to compress code" 40 | if not isinstance(x, dict): 41 | return val 42 | else: 43 | x[key] = val 44 | return x 45 | 46 | # %% ../01_datasets.ipynb 9 47 | class RandomGain(nn.Module): 48 | "apply a random gain to audio" 49 | def __init__(self, 50 | min_gain, # minimum gain to apply 51 | max_gain, # maximum gain to apply 52 | ): 53 | super().__init__() 54 | self.min_gain = min_gain 55 | self.max_gain = max_gain 56 | 57 | def __call__(self, x): 58 | signal = x if not isinstance(x, dict) else x['inputs'] 59 | gain = random.uniform(self.min_gain, self.max_gain) 60 | signal = signal * gain 61 | return pipeline_return(signal, x) 62 | 63 | # %% ../01_datasets.ipynb 14 64 | class PadCrop(nn.Module): 65 | "Grabs a randomly-located section from an 
63 | # %% ../01_datasets.ipynb 14 64 | class PadCrop(nn.Module): 65 | "Grabs a randomly-located section from an audio file, padding with zeros in case of any misalignment" 66 | def __init__(self, 67 | n_samples, # length of chunk to extract from longer signal 68 | randomize=True, # draw cropped chunk from a random position in audio file 69 | redraw_silence=True, # a chunk containing silence will be replaced with a new one 70 | silence_thresh=-60, # threshold in dB below which we declare to be silence 71 | max_redraws=2 # when redrawing silences, don't do it more than this many 72 | ): 73 | super().__init__() 74 | store_attr() # sets self.___ vars automatically 75 | 76 | def draw_chunk(self, signal): 77 | "here's the part that actually draws a cropped/padded chunk of audio from signal" 78 | if len(signal.shape) < 2: signal = torch.unsqueeze(signal,0) 79 | n, s = signal.shape 80 | start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() 81 | end = start + self.n_samples 82 | chunk = signal.new_zeros([n, self.n_samples]) 83 | chunk[:, :min(s, self.n_samples)] = signal[:, start:end] 84 | crop_range = torch.tensor([start,end],dtype=int).to(signal.device) # making this a tensor helps preserve order in DataLoader 85 | return chunk, crop_range 86 | 87 | def __call__(self, x): 88 | "when part of the pipeline, this will grab a padded/cropped chunk from signal" 89 | signal = x if not isinstance(x, dict) else x['inputs'] 90 | chunk, crop_range = self.draw_chunk(signal) 91 | num_redraws = 0 92 | while self.redraw_silence and is_silence(chunk, thresh=self.silence_thresh) and (num_redraws < self.max_redraws): 93 | chunk, crop_range = self.draw_chunk(signal) 94 | num_redraws = num_redraws+1 95 | if not isinstance(x, dict): # multiple values, not handled by pipeline_return 96 | return chunk 97 | else: 98 | ##SHH: don't save original as x['uncropped'] unless all input files have the same length, otherwise torch.utils.data.DataLoader will complain about collating different lengths 99 | ##x['uncropped'] = x['inputs'] # save a copy (of the pointer) in case we want to quickly re-crop the same audio file 100 | x['inputs'], x['crop_range'] = chunk, crop_range # crop_range reports where chunk was taken from 101 | return x 102 | 103 | # %% ../01_datasets.ipynb 16 104 | class PadCrop_Normalized_T_old(nn.Module): 105 | """Variation on PadCrop.
source: Zach Evan's audio-diffusion repo""" 106 | def __init__(self, n_samples: int, randomize:bool = True): 107 | 108 | super().__init__() 109 | 110 | self.n_samples = n_samples 111 | self.randomize = randomize 112 | 113 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float]: 114 | 115 | n_channels, n_samples = source.shape 116 | 117 | upper_bound = max(0, n_samples - self.n_samples) 118 | 119 | offset = 0 120 | if(self.randomize and n_samples > self.n_samples): 121 | offset = random.randint(0, upper_bound + 1) 122 | 123 | t_start = offset / (upper_bound + self.n_samples) 124 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 125 | 126 | chunk = source.new_zeros([n_channels, self.n_samples]) 127 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 128 | 129 | return ( 130 | chunk, 131 | t_start, 132 | t_end 133 | ) 134 | 135 | 136 | 137 | 138 | class PadCrop_Normalized_T(nn.Module): 139 | 140 | def __init__(self, n_samples: int, sample_rate: int, randomize: bool = True): 141 | "Newer version as per Zach's edits" 142 | 143 | super().__init__() 144 | 145 | self.n_samples = n_samples 146 | self.sample_rate = sample_rate 147 | self.randomize = randomize 148 | 149 | def __call__(self, source: torch.Tensor) -> Tuple[torch.Tensor, float, float, int, int]: 150 | 151 | n_channels, n_samples = source.shape 152 | 153 | upper_bound = max(0, n_samples - self.n_samples) 154 | 155 | offset = 0 156 | if(self.randomize and n_samples > self.n_samples): 157 | offset = random.randint(0, upper_bound + 1) 158 | 159 | t_start = offset / (upper_bound + self.n_samples) 160 | t_end = (offset + self.n_samples) / (upper_bound + self.n_samples) 161 | 162 | chunk = source.new_zeros([n_channels, self.n_samples]) 163 | chunk[:, :min(n_samples, self.n_samples)] = source[:, offset:offset + self.n_samples] 164 | 165 | seconds_start = math.floor(offset / self.sample_rate) 166 | seconds_total = math.ceil(n_samples / self.sample_rate) 167 | 168 | return ( 169 | chunk, 170 | t_start, 171 | t_end, 172 | seconds_start, 173 | seconds_total 174 | ) 175 | 176 | # %% ../01_datasets.ipynb 21 177 | class PhaseFlipper(nn.Module): 178 | "she was PHAAAAAAA-AAAASE FLIPPER, a random invert yeah" 179 | def __init__(self, 180 | p=0.5 # probability that phase flip will be applied 181 | ): 182 | super().__init__() 183 | self.p = p 184 | def __call__(self, x): 185 | signal = x if not isinstance(x, dict) else x['inputs'] 186 | out = -signal if (random.random() < self.p) else signal 187 | return pipeline_return(out, x) 188 | 189 | # %% ../01_datasets.ipynb 22 190 | class FillTheNoise(nn.Module): 191 | "randomly adds a bit of noise, or not, just to spice things up. 
(Name is an homage to DJ/artist/collaborator Kill the Noise)" 192 | def __init__(self, 193 | p=0.33 # probability that noise will be added 194 | ): 195 | super().__init__() 196 | self.p = p 197 | def __call__(self, x): 198 | signal = x if not isinstance(x, dict) else x['inputs'] 199 | out = signal + 0.25*random.random()*(2*torch.rand_like(signal)-1) if (random.random() < self.p) else signal 200 | return pipeline_return(out, x) 201 | 202 | # %% ../01_datasets.ipynb 23 203 | class RandPool(nn.Module): 204 | "maybe (or maybe not) do an avgpool operation, with a random-sized kernel " 205 | def __init__(self, p=0.2): 206 | self.p, self.maxkern = p, 100 207 | def __call__(self, x): 208 | if (random.random() < self.p): 209 | signal = x if not isinstance(x, dict) else x['inputs'] 210 | ksize = int(random.random()*self.maxkern) 211 | avger = nn.AvgPool1d(kernel_size=ksize, stride=1, padding=1) 212 | return pipeline_return( avger(signal), x ) 213 | else: 214 | return x # do nothing 215 | 216 | # %% ../01_datasets.ipynb 24 217 | class NormInputs(nn.Module): 218 | "Normalize inputs to [-1,1]. Useful for quiet inputs" 219 | def __init__(self, 220 | do_norm=True # controllable parameter for turning normalization on/off 221 | ): 222 | super().__init__() 223 | self.do_norm = do_norm 224 | self.eps = 1e-2 225 | def __call__(self, x): 226 | signal = x if not isinstance(x, dict) else x['inputs'] 227 | out = signal if (not self.do_norm) else signal/(torch.amax(signal,-1)[0] + self.eps) 228 | return pipeline_return(out, x) 229 | 230 | # %% ../01_datasets.ipynb 25 231 | class Mono(nn.Module): 232 | "convert audio to mono" 233 | def __call__(self, x): 234 | signal = x if not isinstance(x, dict) else x['inputs'] 235 | out = torch.mean(signal, dim=0) if len(signal.shape) > 1 else signal # average across channels 236 | return pipeline_return(out, x) 237 | 238 | # %% ../01_datasets.ipynb 26 239 | class Stereo(nn.Module): 240 | "convert audio to stereo" 241 | def __call__(self, x): 242 | signal = x if not isinstance(x, dict) else x['inputs'] 243 | # Check if it's mono 244 | if len(signal.shape) == 1: # s -> 2, s 245 | signal = signal.unsqueeze(0).repeat(2, 1) 246 | elif len(signal.shape) == 2: 247 | if signal.shape[0] == 1: #1, s -> 2, s 248 | signal = signal.repeat(2, 1) # copy mono to stereo 249 | elif signal.shape[0] > 2: #?, s -> 2,s 250 | signal = signal[:2, :] # grab only first two channels 251 | return pipeline_return(signal, x) 252 | 253 | # %% ../01_datasets.ipynb 28 254 | def smoothstep(x, # a tensor of coordinates across a domain, e.g. [0,1] 255 | edge0=0.4, # "zero"/"left" side of smoothstep 256 | edge1=0.6, # "one"/"right" side of smoothstep 257 | ): 258 | "an s-shaped curve, 0's on left side and 1's at right side, with gradient zero at all 1's and 0's. cf. 
https://en.wikipedia.org/wiki/Smoothstep" 259 | x = torch.where(x < edge0, 0, x) 260 | x = torch.where(x > edge1, 1, x) 261 | x = torch.where( torch.logical_and(x >= edge0, x <= edge1) , (x - edge0) / (edge1 - edge0), x ) 262 | return x * x * (3 - 2 * x) 263 | 264 | # %% ../01_datasets.ipynb 29 265 | def smoothstep_box( 266 | coords, # tensor of coordinate values 267 | edges = (0.2,0.3,0.5,0.6) # (left 1's boundary, left 0's boundary, right 0's boundary, right 1's boundary) 268 | ): 269 | "makes a flat region of zeros that transitions smoothly to 1's via smoothsteps at the sides" 270 | assert edges[0] < edges[1] and edges[1] < edges[2] and edges[2] < edges[3], f"Edges should be in increasing order but you have edges = {edges}" 271 | right = smoothstep(coords, edge0=edges[2], edge1=edges[3]) 272 | left = 1 - smoothstep(coords, edge0=edges[0], edge1=edges[1]) 273 | return left + right 274 | 275 | # %% ../01_datasets.ipynb 34 276 | class RandMask1D(nn.Module): 277 | "Performs masking or 'cutout' along 1d data. Can support 'smooth sides' to the cutouts. Note that you probably want masking to be the *last* step in the augmentation pipeline" 278 | def __init__(self, 279 | mask_frac=0.25, # fraction of total input that is to be masked (helps compute no. of masked regions) 280 | mask_width=0.1, # either a fraction of the total length (float < 1) or an exact integer value for length of each masked region 281 | mask_type='simple', # 'simple'=hard sides to cuts, 'softstep'=smooth sides, 'nyquist'=nyquist-freq wave 0.5*(1,-1,1,-1,..) 282 | edge_width=0.2, # for mask_type=smoothstep, fraction or integer value of transition regions to come in from the sides of zeros region 283 | per_channel=False, # different masks on different channels; model can cheat if your inputs are mono 284 | verbose = False, # show logging info 285 | ): 286 | super().__init__() 287 | if mask_width < 1: self.mask_width_frac = mask_width # if float is given, set fraction of chunk length for each mask 288 | self.mask_frac, self.mask_width, self.mask_type, self.edge_width, self.verbose = mask_frac, mask_width, mask_type, edge_width, verbose 289 | self.per_channel = per_channel 290 | self.mask = None # mask is only setup (once and for all) when forward() is called 291 | 292 | def make_single_mask(self, x, mask_val=0): 293 | "allocate a 1D group of min_vals (zeros) amidst a bunch of 1's. 
Put the zeros/min_vals values in the middle" 294 | start = max(0, (x.shape[-1] - self.mask_width)//2 ) 295 | end = min(start + self.mask_width, x.shape[-1]) # don't go over the edge 296 | with torch.no_grad(): 297 | self.mask = torch.ones(x.shape[-1]).to(x.device) 298 | if self.mask_type == 'simple': 299 | self.mask[start:end] = mask_val 300 | elif self.mask_type == 'smoothstep': 301 | coords = torch.linspace(0,1, steps=x.shape[-1]).to(x.device) 302 | ew = self.edge_width if isinstance(self.edge_width,int) else int((end-start)*self.edge_width) # edge width in samples 303 | self.mask = smoothstep_box(coords, edges=[coords[i] for i in [start, start+ew, end-ew, end]]) 304 | elif self.mask_type == 'nyquist': 305 | self.mask[start:end:2], self.mask[start+1:end:2] = 0.5, -0.5 # nyquist noise, amplitude 0.5 seems good 306 | else: 307 | assert False, f"Error: Unsupported mask type: '{self.mask_type}'" 308 | 309 | def mask_once_1channel(self, 310 | xc, # one channel of x 311 | move=None, # amount by which to shift the mask around, in samples 312 | start_loc = None, # can specify where to start from (typically leave this as None) 313 | ): 314 | "excises one mask region for one channel (hence '_1c') in one batch" 315 | # shift the mask forward or backward 316 | shift_by = int((2*np.random.rand()-1)*xc.shape[-1]) if start_loc is None else start_loc 317 | with torch.no_grad(): 318 | mask_view = torch.roll(self.mask, shift_by, -1).to(xc.device) # move the mask around (as a view of original mask tensor) 319 | if self.mask_type != 'nyquist': 320 | return xc * mask_view # this does the excising 321 | else: 322 | return torch.where(mask_view == 1, xc, mask_view) 323 | 324 | 325 | def forward(self, x): 326 | signal = x if not isinstance(x, dict) else x['inputs'] 327 | if self.mask is None: # setup the mask if it hasn't been setup already 328 | if isinstance(self.mask_width, float): # convert it from a fraction to an integer number of samples 329 | self.mask_width = int(signal.shape[-1] * self.mask_width_frac) 330 | self.make_single_mask(signal) 331 | self.n_masks = int(self.mask_frac * signal.shape[-1]/self.mask_width) # number of mask regions to add per channel. we will not worry about whether masks end up overlapping or not 332 | if self.verbose: print("\n MMMM- RandMask1D: Mask engaged! self.mask_width, self.n_masks = ",self.mask_width, self.n_masks,"\n") 333 | 334 | out = signal.clone().to(signal.device) # make a copy so that we don't overwrite x 335 | while len(out.shape) < 3: # add batch dim and channel dim for loop below if needed 336 | out = out.unsqueeze(0) 337 | assert len(out.shape) >= 3, f"Expected x to have 3 or more dimensions but x.shape = {x.shape}" # x.shape should be [b,c,n_samples] 338 | for bi in range(out.shape[0]): # TODO: gotta be a way to do this all at once instead of 3 loops! 339 | if self.per_channel: 340 | for c in range(out.shape[1]): 341 | for i in range(self.n_masks): 342 | out[bi,c,:] = self.mask_once_1channel(out[bi,c,:]) 343 | else: # mask all channels at once. 
keeps model from cheating when mono has been doubled to L&R
344 |                 for i in range(self.n_masks):
345 |                     out[bi,:,:] = self.mask_once_1channel(out[bi,:,:])
346 |         out = torch.reshape(out, signal.shape)
347 |         if not isinstance(x, dict):  # too complex for pipeline_return
348 |             return out
349 |         else:
350 |             x['unmasked'] = x['inputs']  # save a copy (of the pointer) in case we want it later
351 |             x['inputs'] = out
352 |             return x
353 | 
354 | # %% ../01_datasets.ipynb 44
355 | class AudioDataset(torch.utils.data.Dataset):
356 |     """
357 |     Reads from a tree of directories and serves up cropped bits from any and all audio files
358 |     found therein. For efficiency, it's best if you "chunk" these files via chunkadelic.
359 |     modified from https://github.com/drscotthawley/audio-diffusion/blob/main/dataset/dataset.py
360 |     """
361 |     def __init__(self,
362 |         paths,  # list of strings of directory (/tree) names to draw audio files from
363 |         sample_rate=48000,  # audio sample rate in Hz
364 |         sample_size=65536,  # how many audio samples in each "chunk"
365 |         random_crop=True,  # take chunks from random positions within files
366 |         load_frac=1.0,  # fraction of total dataset to load
367 |         cache_training_data=False,  # True = pre-load whole dataset into memory (not fully supported)
368 |         num_gpus=8,  # used only when `cache_training_data=True`, to avoid duplicates
369 |         redraw_silence=True,  # a chunk containing silence will be replaced with a new one
370 |         silence_thresh=-60,  # threshold in dB below which we declare to be silence
371 |         max_redraws=2,  # when redrawing silences, don't do it more than this many times
372 |         augs='Stereo(), PhaseFlipper()',  # list of augmentation transforms **after PadCrop**, as a string
373 |         verbose=False,  # whether to print notices of resampling or not
374 |         return_dict=False  # False=return raw audio only, True=return dict of all kinds of info
375 |         ):
376 |         super().__init__()
377 | 
378 |         print("augs =", augs)
379 |         # base_augs are always applied
380 |         base_augs = 'PadCrop(sample_size, randomize=random_crop, redraw_silence=redraw_silence, silence_thresh=silence_thresh, max_redraws=max_redraws)'
381 |         self.augs = eval(f'torch.nn.Sequential( {base_augs}, {augs} )') if augs is not None else None
382 |         self.silence_thresh = silence_thresh
383 |         self.redraw_silence = redraw_silence
384 |         self.max_redraws = max_redraws
385 |         self.sr = sample_rate
386 |         self.cache_training_data, self.num_gpus = cache_training_data, num_gpus
387 |         self.verbose = verbose
388 |         self.return_dict = return_dict
389 | 
390 |         self.filenames = get_audio_filenames(paths)
391 |         print(f"AudioDataset: {len(self.filenames)} files found.")
392 |         self.n_files = int(len(self.filenames)*load_frac)
393 |         self.filenames = self.filenames[0:self.n_files]
394 |         if cache_training_data: self.preload_files()
395 | 
396 |         self.convert_tensor = VT.ToTensor()
397 | 
398 |     def load_file_ind(self, file_list, i):  # used when caching training data
399 |         return load_audio(file_list[i], sr=self.sr, verbose=self.verbose).cpu()
400 | 
401 |     def get_data_range(self):  # for parallel runs, only grab part of the data -- OBVIATED BY CHUNKING.
402 |         start, stop = 0, len(self.filenames)
403 |         try:
404 |             local_rank = int(os.environ["LOCAL_RANK"])
405 |             world_size = int(os.environ["WORLD_SIZE"])
406 |             interval = stop//world_size
407 |             start, stop = local_rank*interval, (local_rank+1)*interval
408 |             return start, stop
409 |         except KeyError:  # we're on GPU 0 and the others haven't been initialized yet
410 |             start, stop = 0, len(self.filenames)//self.num_gpus
411 |             return start, stop
412 | 
413 |     def preload_files(self):
414 |         print(f"Caching {self.n_files} input audio files:")
415 |         wrapper = partial(self.load_file_ind, self.filenames)
416 |         start, stop = self.get_data_range()
417 |         with Pool(processes=cpu_count()) as p:  # consider fewer processes to avoid a filesystem bottleneck and/or too many processes (b/c * num_gpus)
418 |             self.audio_files = list(tqdm.tqdm(p.imap(wrapper, range(start,stop)), total=stop-start))
419 | 
420 |     def __len__(self):
421 |         return len(self.filenames)
422 | 
423 | 
424 |     def get_next_chunk(self,
425 |         idx,  # the index of the file within the list of files
426 |         ):
427 |         "The heart of this whole dataset routine: Loads file, crops & runs other augmentations"
428 |         audio_filename = self.filenames[idx]
429 |         try:
430 |             if self.cache_training_data:
431 |                 audio = self.audio_files[idx]  # .copy()
432 |             else:
433 |                 audio = load_audio(audio_filename, sr=self.sr, verbose=self.verbose)
434 |             x = {'filename':audio_filename, 'inputs':audio} if self.return_dict else audio  # x is either audio or a dict
435 |             x = self.augs(x)  # RUN AUGMENTATION PIPELINE
436 |             if isinstance(x, dict):
437 |                 x['inputs'] = x['inputs'].clamp(-1, 1)
438 |             else:
439 |                 x = x.clamp(-1, 1)
440 |             return x
441 | 
442 |         except Exception as e:
443 |             print(f'AudioDataset.get_next_chunk: Error loading file {audio_filename}: {e}')
444 |             return None
445 | 
446 | 
447 |     def __getitem__(self,
448 |         idx  # the index of the file within the list of files
449 |         ):
450 |         "returns either an audio tensor or a dict with lots of info"
451 |         x = self.get_next_chunk(idx)  # x is either audio or a dict, depending on self.return_dict
452 |         audio = x if not isinstance(x, dict) else x['inputs']
453 | 
454 |         # even with PadCrop set to reject silences, it could be that the whole file is silence;
455 |         num_redraws = 0
456 |         while (audio is None) or (self.redraw_silence and is_silence(audio, thresh=self.silence_thresh) \
457 |             and (num_redraws < self.max_redraws)):
458 |             next_idx = random.randint(0, len(self.filenames)-1)  # pick some other file at random
459 |             x, num_redraws = self.get_next_chunk(next_idx), num_redraws+1
460 |             audio = x if not isinstance(x, dict) else x['inputs']
461 | 
462 |         #if self.verbose: print("__getitem__: x =",x)  # turning this off. verbose should only show resampling notices
463 |         return self[random.randrange(len(self))] if (x is None) else x
464 | 
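# Example (a minimal sketch; the directory path is hypothetical):
#
#     ds = AudioDataset(['/path/to/chunked/audio'], sample_rate=48000, sample_size=65536)
#     dl = torch.utils.data.DataLoader(ds, batch_size=4, shuffle=True)
#     batch = next(iter(dl))   # with the default augs: a [4, 2, 65536] tensor in [-1, 1]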
465 | # %% ../01_datasets.ipynb 57
466 | def fix_double_slashes(s, debug=False):
467 |     "aws is pretty unforgiving compared to 'normal' filesystems. so here's some 'cleanup'"
468 |     cdsh_split = s.split('://')  # peel off the double-slashes associated with the URL scheme
469 |     assert (len(cdsh_split) <= 2) and (len(cdsh_split) > 0), f'what kind of string are you using? s={s}'
470 |     post = cdsh_split[-1]
471 |     while '//' in post:
472 |         post = post.replace('//','/')
473 |     if len(cdsh_split) > 1:
474 |         return cdsh_split[0] + '://' + post
475 |     else:
476 |         return post
477 | 
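# Example (sketch): fix_double_slashes leaves the scheme's '//' alone and collapses
# repeated slashes everywhere else:
#
#     fix_double_slashes('s3://bucket//path///file.tar')   # -> 's3://bucket/path/file.tar'
#     fix_double_slashes('local//path')                    # -> 'local/path'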
478 | # %% ../01_datasets.ipynb 62
479 | def get_s3_contents(
480 |     dataset_path,  # "name" of the dataset on s3
481 |     s3_url_prefix='s3://s-laion-audio/webdataset_tar/',  # s3 bucket to check
482 |     filter='',  # only grab certain filenames / extensions
483 |     recursive=True,  # check all subdirectories. RECOMMEND LEAVING THIS TRUE
484 |     debug=False,  # print debugging info (don't rely on this info staying consistent)
485 |     profile='default',  # name of the AWS profile credentials
486 |     ):
487 |     "Gets a list of names of files or subdirectories on an s3 path"
488 |     if (dataset_path != '') and (not dataset_path.endswith('/')):
489 |         dataset_path = dataset_path + '/'
490 |     dataset_path = fix_double_slashes(dataset_path)
491 |     #if debug: print(f"cmd string: aws s3 ls {s3_url_prefix}{dataset_path} --profile {profile}")
492 |     if not recursive:
493 |         run_ls = subprocess.run(['aws','s3','ls',f'{s3_url_prefix}{dataset_path}','--profile',profile], capture_output=True)
494 |     else:
495 |         run_ls = subprocess.run(['aws','s3','ls',f'{s3_url_prefix}{dataset_path}','--recursive','--profile',profile], capture_output=True)
496 |     run_ls = subprocess.run(["awk",'{$1=$2=$3=""; print $0}'], input=run_ls.stdout, capture_output=True)
497 |     run_ls = subprocess.run(["sed",'s/^[ \t]*//'], input=run_ls.stdout, capture_output=True)
498 |     contents = run_ls.stdout.decode('utf-8')
499 |     if debug: print("1 contents[:10] = \n", contents[:10])  # WARNING: the full list is long
500 |     contents = contents.split('\n')
501 |     contents = [x.strip() for x in contents if x]  # list of non-empty strings, without leading whitespace
502 |     contents = [x.replace('PRE ','') if (x[-1]=='/') else x for x in contents]  # directories
503 |     #if recursive:  # the recursive flag weirdly adds a redundant extra directory name taken from the s3 url, so we should strip
504 |     #    in recursive cases we'll get the full directory path off the host name
505 |     #main_dir = s3_url_prefix.split('/')[-2]  # everything after the netloc
506 |     #if debug: print("main_dir =", main_dir)
507 |     #contents = [x.replace(f'{main_dir}/','').replace(dataset_path,'').replace('//','/') for x in contents]
508 |     #contents = [x.replace(dataset_path,'').replace('//','/') for x in contents]
509 | 
510 |     #if debug: print("2 recursive contents[:10] = ", contents[:10])
511 |     return [x for x in contents if filter in x]  # return the filtered list
512 | 
513 | # %% ../01_datasets.ipynb 71
514 | def get_contiguous_range(
515 |     tar_names,  # list of tar file names, although the .tar part is actually optional
516 |     ):
517 |     "given a list of tar file names, return a string of their numerical range if the numbers are contiguous. 
Otherwise return an empty string"
518 |     if len(tar_names) == 0: return ''
519 |     elif len(tar_names) == 1: return tar_names[-1]
520 |     just_nums = [Path(x).stem for x in tar_names]  # get just the filenames, no extension or directory
521 |     just_nums.sort(key=int)  # sorts numerically but meaningfully preserves leading zeros in strings
522 |     nums_arr = np.asarray(just_nums, dtype=int)
523 |     is_contiguous = np.abs( (nums_arr - np.roll(nums_arr,1))[1:] ).max() == 1
524 |     if is_contiguous:  # {000000..000999}
525 |         return '{' + f'{just_nums[0]}..{just_nums[-1]}' + '}'
526 |     else:
527 |         print("get_contiguous_range: File numbers not contiguous")  # have to do more work
528 |         return ''  # an empty string signifies no dice, i.e. a signal that more work needs to be done
529 | 
530 | # %% ../01_datasets.ipynb 87
531 | def get_all_s3_urls(
532 |     names=[],  # list of all valid [LAION AudioDataset] dataset names; can include URLs, in which case s3_url_prefix is ignored
533 |     subsets=[''],  # list of subsets you want from those datasets, e.g. ['train','valid']
534 |     s3_url_prefix=None,  # prefix for those dataset names if no s3:// is supplied in names, e.g. 's3://s-laion-audio/webdataset_tar/'
535 |     recursive=True,  # recursively list all tar files in all subdirs
536 |     filter_str='tar',  # only grab files with this substring
537 |     debug=False,  # print debugging info -- note: info displayed likely to change at dev's whims
538 |     profiles={},  # dict of S3 profiles to use, e.g. {'s3://s-laion-audio':'default'}
539 |     **kwargs
540 |     ):
541 |     "get urls of shards (tar files) for multiple datasets in one s3 bucket"
542 |     if s3_url_prefix is None:
543 |         s3_url_prefix = ''
544 |     urls = []
545 |     names = [''] if names == [] else names  # make sure it's a non-empty list, for the loop below
546 |     subsets = [''] if subsets == [] else subsets  # make sure it's a non-empty list, for the loop below
547 |     for name in names:
548 |         purl = urlparse(name)  # check if name already has a URL in it; if so, ignore s3_url_prefix
549 |         if purl.scheme == '':
550 |             s3_prefix = s3_url_prefix
551 |         else:
552 |             s3_prefix = f"{purl.scheme}://{purl.netloc}"
553 |             name = name.replace(s3_prefix,'')
554 |         if debug:
555 |             print(f"s3_prefix = {s3_prefix}, name = {name}")
556 |         if debug: print(f"get_all_s3_urls: {s3_prefix}{name}:")
557 |         for subset in subsets:
558 |             contents_str = fix_double_slashes(f'{name}/{subset}/')
559 |             # match profile with name, or use default
560 |             profile = profiles.get(s3_prefix, 'default')
561 |             if debug:
562 |                 print(f"   name = {name}, profile = {profile}")
563 |                 print("   contents_str =", contents_str, ", s3_prefix =", s3_prefix)
564 |             tar_list = get_s3_contents(contents_str, s3_url_prefix=s3_prefix, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
565 |             for tar in tar_list:
566 |                 tar = tar.replace(" ","\\ ").replace("(","\\(").replace(")","\\)")  # escape spaces and parentheses for the shell
567 |                 s3_path = fix_double_slashes(f"{s3_prefix}/{tar} -")
568 |                 request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
569 |                 if profile != '': request_str += f" --profile {profile}"
570 |                 if debug: print("request_str =", request_str)
571 |                 urls.append(fix_double_slashes(request_str))
572 |     #urls = [x.replace('tar//','tar/') for x in urls]  # one last double-check
573 |     return urls
574 | 
575 | 
576 | 
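# Example (sketch; assumes working AWS credentials and an existing bucket layout):
#
#     urls = get_all_s3_urls(names=['FSD50K'], subsets=['train'],
#                            s3_url_prefix='s3://s-laion-audio/webdataset_tar/')
#     # each entry is a shell-pipe string for WebDataset, roughly of the form
#     #   "pipe:aws s3 --cli-connect-timeout 0 cp s3://.../<shard>.tar - --profile default"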
577 | import posixpath
578 | 
579 | def get_all_s3_urls_zach(
580 |     names=[],  # list of all valid [LAION AudioDataset] dataset names
581 |     subsets=[''],  # list of subsets you want from those datasets, e.g. ['train','valid']
582 |     s3_url_prefix=None,  # prefix for those dataset names
583 |     recursive=True,  # recursively list all tar files in all subdirs
584 |     filter_str='tar',  # only grab files with this substring
585 |     debug=False,  # print debugging info -- note: info displayed likely to change at dev's whims
586 |     profiles={},  # dictionary of profiles for each item in names, e.g. {'dataset1': 'profile1', 'dataset2': 'profile2'}
587 |     ):
588 |     "get urls of shards (tar files) for multiple datasets in one s3 bucket"
589 |     urls = []
590 |     for name in names:
591 |         # If s3_url_prefix is not specified, assume the full S3 path is included in each element of the names list
592 |         if s3_url_prefix is None or '' == s3_url_prefix:
593 |             contents_str = name
594 |         else:
595 |             # Construct the S3 path using the s3_url_prefix and the current name value
596 |             contents_str = posixpath.join(s3_url_prefix, name)
597 |         if debug:
598 |             print(f"get_all_s3_urls: {contents_str}")
599 |         for subset in subsets:
600 |             subset_str = posixpath.join(contents_str, subset)
601 |             if debug:
602 |                 print(f"   subset_str = {subset_str}")
603 |             # Get the list of tar files in the current subset directory
604 |             profile = profiles.get(name, 'default')
605 |             if debug: print(f"   name = {name}, profile = {profile}")
606 |             tar_list = get_s3_contents(subset_str, s3_url_prefix=None, recursive=recursive, filter=filter_str, debug=debug, profile=profile)
607 |             for tar in tar_list:
608 |                 # Escape spaces and parentheses in the tar filename for use in the shell command
609 |                 tar = tar.replace(" ","\\ ").replace("(","\\(").replace(")","\\)")
610 |                 # Construct the S3 path to the current tar file
611 |                 s3_path = posixpath.join(name, subset, tar) + " -"
612 |                 # Construct the AWS CLI command to download the current tar file
613 |                 if s3_url_prefix is None:
614 |                     request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {s3_path}"
615 |                 else:
616 |                     request_str = f"pipe:aws s3 --cli-connect-timeout 0 cp {posixpath.join(s3_url_prefix, s3_path)}"
617 |                 if profiles.get(name):
618 |                     request_str += f" --profile {profiles.get(name)}"
619 |                 if debug:
620 |                     print("request_str = ", request_str)
621 |                 # Add the constructed URL to the list of URLs
622 |                 urls.append(request_str)
623 |     return urls
624 | 
625 | # %% ../01_datasets.ipynb 90
626 | class IterableAudioDataset(torch.utils.data.IterableDataset):
627 |     "Iterable version of AudioDataset, used with Chain (below)"
628 |     def __init__(self,
629 |         paths,  # list of strings of directory (/tree) names to draw audio files from
630 |         sample_rate=48000,  # audio sample rate in Hz
631 |         sample_size=65536,  # how many audio samples in each "chunk"
632 |         random_crop=True,  # take chunks from random positions within files
633 |         load_frac=1.0,  # fraction of total dataset to load
634 |         cache_training_data=False,  # True = pre-load whole dataset into memory (not fully supported)
635 |         num_gpus=8,  # used only when `cache_training_data=True`, to avoid duplicates
636 |         redraw_silence=True,  # a chunk containing silence will be replaced with a new one
637 |         silence_thresh=-60,  # threshold in dB below which we declare to be silence
638 |         max_redraws=2,  # when redrawing silences, don't do it more than this many times
639 |         augs='Stereo(), PhaseFlipper()',  # list of augmentation transforms **after PadCrop**, as a string
640 |         verbose=False,  # whether to print notices of resampling or not
641 |         ):
642 |         super().__init__()
643 |         self.this = AudioDataset(paths, sample_rate=sample_rate, sample_size=sample_size, random_crop=random_crop,
644 | 
load_frac=load_frac, cache_training_data=cache_training_data, num_gpus=num_gpus,
645 |             redraw_silence=redraw_silence, silence_thresh=silence_thresh, max_redraws=max_redraws,
646 |             augs=augs, verbose=verbose)
647 |         self.len = len(self.this)
648 | 
649 |     def __iter__(self):
650 |         yield self.this.__getitem__(random.randrange(self.len))  # randrange keeps the index in bounds
651 | 
652 | # %% ../01_datasets.ipynb 94
653 | def name_cache_file(url):
654 |     "provides the filename to which to cache a url"
655 |     return re.findall(r's3:.* -', url)[0][:-2].replace('/','_').replace(' ','\\ ').replace(':','_')
656 | 
657 | """
658 | # old version by drscotthawley, replaced by Zach's edits, scroll down
659 | pp_calls = 0
660 | def wds_preprocess_old(sample, sample_size=65536, sample_rate=48000, random_crop=True, verbose=False):
661 |     "sampling and processing callback/handler for AudioWebDataLoader, below"
662 |     global pp_calls
663 |     pp_calls += 1
664 |     if verbose: print("pp_calls =", pp_calls)
665 |     audio_keys = ("flac", "wav", "mp3", "aiff")
666 |     found_key, rewrite_key = '', 'audio'  # SHH added 'audio' key to match zach's webdataloader
667 |     if verbose: print(f"----> Starting wds_preprocess: sample.items() = {sample.items()}")
668 |     for k,v in sample.items():  # check all the entries in the dict
669 |         for akey in audio_keys:
670 |             if k.endswith(akey):
671 |                 found_key, rewrite_key = k, akey  # to rename long/weird key with its simpler counterpart
672 |                 break
673 |         if '' != found_key: break
674 |     if '' == found_key:  # got no audio!
675 |         print("  wds_preprocess: Error: No audio in this sample:")
676 |         for k,v in sample.items():  # print all the entries in the dict
677 |             print(f"    {k:20s} {repr(v)[:50]}")
678 |         print("  wds_preprocess: Skipping it.")
679 |         return None  # try returning None to tell WebDataset to skip this one?
680 | 
681 |     audio, in_sr = sample[found_key]
682 |     if in_sr != sample_rate:
683 |         if verbose: print(f"wds_preprocess: Resampling {filename} from {in_sr} Hz to {sample_rate} Hz", flush=True)
684 |         resample_tf = T.Resample(in_sr, sample_rate)
685 |         audio = resample_tf(audio)
686 | 
687 |     # apply cropping and normalization
688 |     #myop = torch.nn.Sequential(PadCrop(sample_size, randomize=random_crop), Stereo(), PhaseFlipper())
689 |     #audio = myop(audio)
690 | 
691 |     # Pad/crop and get the relative timestamp
692 |     #pad_crop = PadCrop(sample_size, randomize=random_crop)
693 |     pad_crop = PadCrop_Normalized_T(sample_size, randomize=random_crop)
694 |     audio, t_start, t_end = pad_crop(audio)
695 | 
696 |     # Make the audio stereo and augment by randomly inverting phase
697 |     augs = torch.nn.Sequential(Stereo(), PhaseFlipper())
698 |     audio = augs(audio)
699 | 
700 |     sample["timestamps"] = (t_start, t_end)
701 |     sample["audio"] = audio  # regardless of what's above, let's also make a key pointing to the audio
702 | 
703 |     if found_key != rewrite_key:  # rename long/weird key with its simpler counterpart
704 |         del sample[found_key]
705 |         sample[rewrite_key] = audio
706 | 
707 | 
708 |     if verbose: print(f" ----> Leaving wds_preprocess: sample.items() = {sample.items()}")
709 | 
710 |     return sample
711 | """
712 | 
713 | def wds_preprocess(
714 |     sample,
715 |     sample_size=65536,
716 |     sample_rate=48000,
717 |     verbose=False,
718 |     random_crop=True,
719 |     normalize_lufs=None,
720 |     metadata_prompt_funcs=None,
721 |     force_channels="stereo",
722 |     augment_phase=True,
723 |     ):
724 |     """utility routine for QuickWebDataLoader, below.
725 |     New version by Zach Evans, from https://github.com/zqevans/audio-diffusion/dataset.py. 
Old version in source, commented out 726 | """ 727 | audio_keys = ("flac", "wav", "mp3", "m4a", "ogg") 728 | 729 | found_key, rewrite_key = '', '' 730 | for k,v in sample.items(): # print the all entries in dict 731 | for akey in audio_keys: 732 | if k.endswith(akey): 733 | found_key, rewrite_key = k, akey # to rename long/weird key with its simpler counterpart 734 | break 735 | if '' != found_key: break 736 | if '' == found_key: # got no audio! 737 | # print(" Error: No audio in this sample:") 738 | # for k,v in sample.items(): # print the all entries in dict 739 | # print(f" {k:20s} {repr(v)[:50]}") 740 | # print(" Skipping it.") 741 | return None # try returning None to tell WebDataset to skip this one ? 742 | 743 | audio, in_sr = sample[found_key] 744 | if in_sr != sample_rate: 745 | if in_sr < 8000: 746 | print(f"Very low SR ({in_sr}) for file {sample['url']}") 747 | if verbose: print(f"Resampling from {in_sr} Hz to {sample_rate} Hz",flush=True) 748 | resample_tf = T.Resample(in_sr, sample_rate) 749 | audio = resample_tf(audio) 750 | 751 | if normalize_lufs is not None: 752 | # Loudness normalization to -12 LKFS, adapted from pyloudnorm 753 | meter = pyln.Meter(sample_rate) 754 | loudness = meter.integrated_loudness(audio.transpose(-2, -1).numpy()) 755 | delta_loudness = (normalize_lufs - float(loudness)) 756 | gain = 10.0 ** (delta_loudness/20.0) 757 | audio = gain * audio 758 | 759 | if sample_size is not None: 760 | # Pad/crop and get the relative timestamp 761 | pad_crop = PadCrop_Normalized_T(sample_size, randomize=random_crop, sample_rate=sample_rate) 762 | audio, t_start, t_end, seconds_start, seconds_total = pad_crop(audio) 763 | sample["json"]["seconds_start"] = seconds_start 764 | sample["json"]["seconds_total"] = seconds_total 765 | else: 766 | t_start, t_end = 0, 1 767 | 768 | #Check if audio is length zero, initialize to a single zero if so 769 | if audio.shape[-1] == 0: 770 | audio = torch.zeros(1, 1) 771 | 772 | # Make the audio stereo and augment by randomly inverting phase 773 | augs = torch.nn.Sequential( 774 | Stereo() if force_channels == "stereo" else torch.nn.Identity(), 775 | Mono() if force_channels == "mono" else torch.nn.Identity(), 776 | PhaseFlipper() if augment_phase else torch.nn.Identity() 777 | ) 778 | audio = augs(audio) 779 | 780 | sample["timestamps"] = (t_start, t_end) 781 | 782 | if "text" in sample["json"]: 783 | sample["json"]["prompt"] = sample["json"]["text"] 784 | 785 | if metadata_prompt_funcs is not None: 786 | for key, prompt_func in metadata_prompt_funcs.items(): 787 | if key in sample["__url__"]: 788 | prompt = prompt_func(sample["json"]) 789 | sample["json"]["prompt"] = prompt 790 | 791 | if found_key != rewrite_key: # rename long/weird key with its simpler counterpart 792 | del sample[found_key] 793 | sample["audio"] = audio 794 | return sample 795 | 796 | # %% ../01_datasets.ipynb 97 797 | def log_and_continue(exn): 798 | """Call in an exception handler to ignore any exception, isssue a warning, and continue. 799 | source: audio-diffusion repo""" 800 | print(f"Handling webdataset error ({repr(exn)}). 
Ignoring.") 801 | rank, world_size, worker, num_workers = wds.utils.pytorch_worker_info() 802 | print(f"Rank: {rank}, worker: {worker}") 803 | return True 804 | 805 | def is_valid_sample(sample): 806 | """source: audio-diffusion repo""" 807 | silence = is_silence(sample["audio"]) 808 | result = ("json" in sample) and ("audio" in sample) and not silence 809 | if result==False: 810 | print(f'is_valid_sample: result=False: ("json" in sample)={("json" in sample)}, ("audio" in sample) = {("audio" in sample)}, silence = {silence} ') 811 | return result 812 | 813 | # %% ../01_datasets.ipynb 99 814 | def AudioWebDataLoader( 815 | names=['FSD50K'], # names of datasets. will search all available s3 urls 816 | subsets=[''], # list of subsets you want from those datasets, e.g. ['train','valid'] 817 | s3_url_prefix='s3://s-laion-audio/webdataset_tar/', # prefix for those dataset names 818 | profile='', # AWS S3 profile string to pass in (default: none) 819 | audio_file_ext="wav;flac;mp3;ogg;aiff;aif", # extension(s) of audio files; passed to wds.to_tuple 820 | filter_str='tar', # only grab files with this substring 821 | recursive=True, # recursively list all tar files in all subdirs 822 | sample_size=65536, # how long each sample to grab via PadCrop 823 | sample_rate=48000, # standard sr in Hz 824 | random_crop=True, # take chunks from random positions within files 825 | num_workers=os.cpu_count()//2,# number of PyTorch DataLoaders 826 | prefetch_factor=10, # number of batches to pre-fetch 827 | batch_size=4, # typical batch size 828 | shuffle_vals=[1000, 10000], # values passed into shuffle as per WDS tutorials 829 | epoch_len=1000, # how many passes/loads make for an epoch? wds part of this is not well documented IMHO 830 | debug=False, # print info on internal workings 831 | verbose=False, # unlike debug. 
this only prints in the callback
832 |     callback=wds_preprocess,  # function to call for additional user-based processing
833 |     shuffle_urls=True,  # shuffle the url list before it's passed to WebDataset
834 |     shuffle_seed=None,  # seed for shuffling of urls
835 |     zachs=True,  # use zach's data pipeline or hawley's
836 |     **kwargs,  # what else to pass to the callback
837 |     ):
838 |     "Sets up a WebDataLoader pipeline with some typical defaults for audio files"
839 |     if verbose:
840 |         print("AudioWebDataLoader: Note: 'Broken pipe' messages you might get aren't a big deal, but may indicate files that are too big.")
841 |         print("AudioWebDataLoader: ", ', '.join(['{}={!r}'.format(k, v) for k, v in locals().items()]))
842 |     if not isinstance(names, (list, tuple)): names = [names]  # wrap a single dataset name in a list
843 |     urls = get_all_s3_urls(names=names, subsets=subsets, s3_url_prefix=s3_url_prefix, recursive=recursive,
844 |         profile=profile, filter_str=filter_str, debug=debug)
845 |     if debug: print("AudioWebDataLoader: urls =\n", urls)
846 |     os.environ["WDS_VERBOSE_CACHE"] = "1"  # tell webdataset to cache stuff
847 |     if len(urls) > 0:
848 |         if shuffle_urls:
849 |             if shuffle_seed is not None:
850 |                 random.seed(shuffle_seed)
851 |             random.shuffle(urls)
852 |             if debug: print("AudioWebDataLoader: shuffled urls =\n", urls)
853 |         if zachs:
854 |             dataset = wds.DataPipeline(
855 |                 wds.ResampledShards(urls),  # Yields a single .tar URL
856 |                 wds.tarfile_to_samples(handler=log_and_continue),  # Opens up a stream to the TAR file, yields files grouped by keys
857 |                 wds.shuffle(shuffle_vals[0], handler=log_and_continue),  # SHH added
858 |                 wds.decode(wds.torch_audio, handler=log_and_continue),
859 |                 wds.map(partial(callback, sample_size=sample_size, sample_rate=sample_rate, verbose=verbose, random_crop=random_crop, **kwargs), handler=log_and_continue),
860 |                 wds.shuffle(shuffle_vals[1], handler=log_and_continue),  # SHH added
861 |                 wds.select(is_valid_sample),
862 |                 wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue),
863 |                 wds.batched(batch_size, partial=False)
864 |             ).with_epoch(epoch_len//num_workers if num_workers > 0 else epoch_len)
865 |         else:
866 |             dataset = wds.DataPipeline(
867 |                 wds.ResampledShards(urls),  # cache_dir='./_mycache'), <-- not allowed
868 |                 wds.tarfile_to_samples(),
869 |                 wds.shuffle(shuffle_vals[0]),
870 |                 wds.decode(wds.torch_audio),
871 |                 wds.map(partial(callback, sample_size=sample_size, sample_rate=sample_rate, verbose=verbose, random_crop=random_crop, **kwargs)),
872 |                 wds.shuffle(shuffle_vals[1]),
873 |                 wds.to_tuple(audio_file_ext),  # here's where it searches for the file extension
874 |                 wds.batched(batch_size),
875 |             ).with_epoch(epoch_len)
876 | 
877 |         return wds.WebLoader(dataset, num_workers=num_workers, prefetch_factor=prefetch_factor, **kwargs)
878 |     else:
879 |         print("*****ERROR: AudioWebDataLoader: No URLs found. 
Returning 'None'") 880 | return None 881 | 882 | # %% ../01_datasets.ipynb 105 883 | def get_wds_loader(batch_size, sample_size, names, s3_url_prefix=None, sample_rate=48000, num_workers=8, 884 | recursive=True, profiles={}, epoch_steps=1000, random_crop=True, normalize_lufs=None, 885 | metadata_prompt_funcs=None, force_channels="stereo", augment_phase=True): 886 | "Simpler loader from https://github.com/zqevans/audio-diffusion/dataset.py" 887 | 888 | preprocess_fn = partial(wds_preprocess, sample_size=sample_size, sample_rate=sample_rate, random_crop=random_crop, normalize_lufs=normalize_lufs, metadata_prompt_funcs=metadata_prompt_funcs, force_channels=force_channels, augment_phase=augment_phase) 889 | 890 | urls = get_all_s3_urls( 891 | names=names, 892 | s3_url_prefix=s3_url_prefix, 893 | recursive=recursive, 894 | profiles=profiles 895 | ) 896 | 897 | dataset = wds.DataPipeline( 898 | wds.ResampledShards(urls), # Yields a single .tar URL 899 | wds.tarfile_to_samples(handler=log_and_continue), # Opens up a stream to the TAR file, yields files grouped by keys 900 | wds.decode(wds.torch_audio, handler=log_and_continue), 901 | wds.map(preprocess_fn, handler=log_and_continue), 902 | #wds.shuffle(bufsize=100, initial=10, handler=log_and_continue), # Pulls from iterator until initial value 903 | wds.select(is_valid_sample), 904 | wds.to_tuple("audio", "json", "timestamps", handler=log_and_continue), 905 | wds.batched(batch_size, partial=False) 906 | ).with_epoch(epoch_steps//num_workers if num_workers > 0 else epoch_steps) 907 | 908 | return wds.WebLoader(dataset, num_workers=num_workers) 909 | -------------------------------------------------------------------------------- /aeiou/hpc.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../05_hpc.ipynb. 
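# Example (a sketch of typical usage of this module's utilities, assuming a
# configured huggingface `accelerate` environment; all names are defined below):
#
#     accelerator = accelerate.Accelerator()
#     hprint = HostPrinter(accelerator)        # prints only on the main process
#     hprint(f"accel config = {get_accel_config()}")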
2 | 
3 | # %% auto 0
4 | __all__ = ['get_accel_config', 'HostPrinter', 'save', 'load', 'n_params', 'freeze']
5 | 
6 | # %% ../05_hpc.ipynb 5
7 | import yaml
8 | import accelerate
9 | from pathlib import Path
10 | import tqdm
11 | import torch
12 | import torchaudio
13 | from torchaudio import transforms as T
14 | import os
15 | 
16 | 
17 | # %% ../05_hpc.ipynb 6
18 | def get_accel_config(filename='~/.cache/huggingface/accelerate/default_config.yaml'):
19 |     "get huggingface accelerate config info"
20 |     try:  # first try to use the default file
21 |         filename = filename.replace('~', str(Path.home()))
22 |         with open(filename, 'r') as file:
23 |             ac = yaml.safe_load(file)
24 |     except OSError:
25 |         ac = {}
26 | 
27 |     # then update using any environment variables
28 |     if os.getenv('MAIN_PROCESS_IP') is not None: ac['main_process_ip'] = os.getenv('MAIN_PROCESS_IP')
29 |     if os.getenv('MACHINE_RANK') is not None: ac['machine_rank'] = os.getenv('MACHINE_RANK')
30 |     if os.getenv('NUM_MACHINES') is not None: ac['num_machines'] = os.getenv('NUM_MACHINES')
31 |     if os.getenv('NUM_PROCESSES') is not None: ac['num_processes'] = os.getenv('NUM_PROCESSES')
32 | 
33 |     return ac
34 | 
35 | # %% ../05_hpc.ipynb 10
36 | class HostPrinter():
37 |     "lil accelerate utility for only printing on host node"
38 |     def __init__(
39 |         self,
40 |         accelerator,  # huggingface accelerator object
41 |         tag='\033[96m',  # starting color
42 |         untag='\033[0m'  # reset to default color
43 |         ):
44 |         self.accelerator, self.tag, self.untag = accelerator, tag, untag
45 |     def __call__(self, s:str):
46 |         if self.accelerator.is_main_process:
47 |             print(self.tag + s + self.untag, flush=True)
48 | 
49 | # %% ../05_hpc.ipynb 14
50 | def save(
51 |     accelerator,  # Huggingface accelerator object
52 |     args,  # prefigure args dict (we only use args.name)
53 |     model,  # the model, pre-unwrapped
54 |     opt=None,  # optimizer state
55 |     epoch=None,  # training epoch number
56 |     step=None  # training step number
57 |     ):
58 |     "for checkpointing & model saves"
59 |     #accelerator.wait_for_everyone()  # hangs
60 |     filename = f'{args.name}_{step:08}.pth' if (step is not None) else f'{args.name}.pth'
61 |     if accelerator.is_main_process:
62 |         print(f'\nSaving checkpoint to {filename}...')
63 |     obj = {'model': accelerator.unwrap_model(model).state_dict()}
64 |     if opt is not None: obj['opt'] = opt.state_dict()
65 |     if epoch is not None: obj['epoch'] = epoch
66 |     if step is not None: obj['step'] = step
67 |     accelerator.save(obj, filename)
68 | 
69 | # %% ../05_hpc.ipynb 15
70 | def load(
71 |     accelerator,  # Huggingface accelerator object
72 |     model,  # an uninitialized model (pre-unwrapped) whose weights will be overwritten
73 |     filename:str,  # name of the checkpoint file
74 |     opt=None,  # optimizer state UNUSED FOR NOW
75 |     ):
76 |     "load a saved model checkpoint"
77 |     #accelerator.wait_for_everyone()  # hangs
78 |     if accelerator.is_main_process:
79 |         print(f'\nLoading checkpoint from {filename}...')
80 |     accelerator.unwrap_model(model).load_state_dict(torch.load(filename)['model'])
81 |     return model  # this return isn't actually needed, since model is already updated at this point
82 | 
83 | # %% ../05_hpc.ipynb 17
84 | def n_params(
85 |     module  # raw PyTorch model/module, e.g. returned by accelerator.unwrap_model()
86 |     ):
87 |     """Returns the number of trainable parameters in a module.
88 |     Be sure to use accelerator.unwrap_model when calling this. """
89 |     return sum(p.numel() for p in module.parameters())
90 | 
91 | # %% ../05_hpc.ipynb 18
92 | def freeze(
93 |     model  # raw PyTorch model, e.g. returned by accelerator.unwrap_model()
94 |     ):
95 |     """freezes model weights; turns off gradient info.
96 |     If using accelerate, call this on the model returned by accelerator.unwrap_model(model). """
97 |     for param in model.parameters():
98 |         param.requires_grad = False
--------------------------------------------------------------------------------
/aeiou/spectrofu.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../04_spectrofu.ipynb.
2 | 
3 | # %% auto 0
4 | __all__ = ['save_stft', 'process_one_file', 'main']
5 | 
6 | # %% ../04_spectrofu.ipynb 6
7 | import argparse
8 | from glob import glob
9 | from pathlib import Path
10 | import os
11 | import math
12 | from multiprocessing import Pool, cpu_count, Barrier
13 | from functools import partial
14 | from tqdm.contrib.concurrent import process_map
15 | import torch
16 | import torchaudio
17 | from .core import is_silence, load_audio, makedir, get_audio_filenames
18 | from .viz import audio_spectrogram_image
19 | 
20 | # %% ../04_spectrofu.ipynb 7
21 | def save_stft(
22 |     audio:torch.tensor,  # long audio file to be chunked
23 |     new_filename:str  # stem of new filename(s) to be output as spectrogram images
24 |     ):
25 |     "converts audio to a spectrogram image and saves it"
26 |     im = audio_spectrogram_image(audio, justimage=True)  # should already be a PIL image
27 |     print(f"saving new file = {new_filename}")
28 |     im.save(new_filename)
29 |     return
30 | 
31 | # %% ../04_spectrofu.ipynb 8
32 | def process_one_file(
33 |     filenames:list,  # list of filenames from which we'll pick one
34 |     args,  # output of argparse
35 |     file_ind  # index from filenames list to read from
36 |     ):
37 |     "this turns one audio file into a spectrogram. left channel only for now"
38 |     filename = filenames[file_ind]  # this is actually input_path + / + filename
39 |     output_path, input_paths = args.output_path, args.input_paths
40 |     new_filename = None
41 | 
42 |     for ipath in input_paths:  # set up the output filename & any folders it needs
43 |         if ipath in filename:  # this just avoids repeats / weirdness
44 |             last_ipath = ipath.split('/')[-1]  # get the last part of ipath
45 |             clean_filename = filename.replace(ipath,'')  # remove all of ipath from the front of filename
46 |             new_filename = f"{output_path}/{last_ipath}/{clean_filename}".replace('//','/')
47 |             new_filename = str(Path(new_filename).with_suffix(".png"))  # give it the file extension for an image
48 |             makedir(os.path.dirname(new_filename))  # we might need to make a directory for the output file
49 |             break
50 | 
51 |     if new_filename is None:
52 |         print(f"ERROR: Something went wrong with the name of input file {filename}. Skipping.", flush=True)
53 |         return
54 | 
55 |     try:
56 |         audio = load_audio(filename, sr=args.sr)
57 |         save_stft(audio, new_filename)
58 |     except Exception as e:
59 |         print(f"Error ({e}) occurred with {filename}, either loading or writing images. Skipping.", flush=True)
60 | 
61 |     return
62 | 
63 | 
64 | def main():
65 |     parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
66 |     parser.add_argument('--sr', type=int, default=48000, help='Output sample rate')
67 |     parser.add_argument('--workers', type=int, default=min(32, os.cpu_count() + 4), help='Maximum number of workers to use (default: min(32, n_cpus + 4))')
68 |     parser.add_argument('output_path', help='Path of output for spectrogram-ified data')
69 |     parser.add_argument('input_paths', nargs='+', help='Path(s) of a file or a folder of files. (recursive)')
70 |     args = parser.parse_args()
71 | 
72 |     print(f"  output_path = {args.output_path}")
73 | 
74 |     print("Getting list of input filenames")
75 |     filenames = get_audio_filenames(args.input_paths)
76 |     n = len(filenames)
77 |     print(f"  Got {n} input filenames")
78 | 
79 |     print("Processing files (in parallel)")
80 |     wrapper = partial(process_one_file, filenames, args)
81 |     r = process_map(wrapper, range(0, n), chunksize=1, max_workers=args.workers)  # different chunksize used by tqdm. max_workers is to avoid annoying other ppl
82 | 
83 |     print("Finished")
--------------------------------------------------------------------------------
/aeiou/viz.py:
--------------------------------------------------------------------------------
1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../02_viz.ipynb.
2 | 
3 | # %% auto 0
4 | __all__ = ['plotly_already_setup', 'embeddings_table', 'project_down', 'proj_pca', 'point_cloud', 'pca_point_cloud', 'on_colab',
5 |            'setup_plotly', 'show_point_cloud', 'show_pca_point_cloud', 'print_stats', 'mel_spectrogram',
6 |            'spectrogram_image', 'audio_spectrogram_image', 'generate_melspec', 'playable_spectrogram',
7 |            'tokens_spectrogram_image', 'plot_jukebox_embeddings']
8 | 
9 | # %% ../02_viz.ipynb 5
10 | import math
11 | import os
12 | from pathlib import Path
13 | from matplotlib.backends.backend_agg import FigureCanvasAgg
14 | import matplotlib.cm as cm
15 | import matplotlib.pyplot as plt
16 | from matplotlib.colors import Normalize
17 | from matplotlib.figure import Figure
18 | import numpy as np
19 | from PIL import Image
20 | 
21 | import torch
22 | from torch import optim, nn
23 | from torch.nn import functional as F
24 | import torchaudio
25 | import torchaudio.transforms as T
26 | from librosa import power_to_db
27 | from einops import rearrange
28 | 
29 | import wandb
30 | 
31 | from pandas import DataFrame
32 | import umap
33 | 
34 | from IPython.display import display, HTML  # just for displaying inside notebooks
35 | 
36 | from .core import load_audio
37 | 
38 | # for playable spectrograms:
39 | import plotly.graph_objects as go
40 | import holoviews as hv
41 | import panel as pn
42 | from bokeh.resources import INLINE
43 | from scipy.signal import spectrogram
44 | #from bokeh.io import output_notebook
45 | 
46 | # %% ../02_viz.ipynb 6
47 | def embeddings_table(tokens):
48 |     "make a table of embeddings for use with wandb"
49 |     features, labels = [], []
50 |     embeddings = rearrange(tokens, 'b d n -> b n d')  # each demo sample is n vectors in d-dim space
51 |     for i in range(embeddings.size()[0]):  # nested for's are slow but sure ;-)
52 |         for j in range(embeddings.size()[1]):
53 |             features.append(embeddings[i,j].detach().cpu().numpy())
54 |             labels.append([f'demo{i}'])  # labels does the grouping / color for each point
55 |     features = np.array(features)
56 |     #print("\nfeatures.shape = ", features.shape)
57 |     labels = np.concatenate(labels, axis=0)
58 |     cols = [f"dim_{i}" for i in range(features.shape[1])]
59 |     df = DataFrame(features, columns=cols)
60 |     df['LABEL'] = labels
61 |     return wandb.Table(columns=df.columns.to_list(), data=df.values)
62 | 
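# Example (sketch; assumes an active wandb run):
#
#     tokens = torch.randn(4, 64, 32)   # hypothetical [batch, dim, n_vectors] embeddings
#     wandb.log({"embeddings": embeddings_table(tokens)})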
63 | # %% ../02_viz.ipynb 7
64 | def project_down(tokens,  # batched high-dimensional data with dims (b,d,n)
65 |     proj_dims=3,  # dimensions to project to
66 |     method='pca',  # projection method: 'pca'|'umap'
67 |     n_neighbors=10,  # umap parameter for number of neighbors
68 |     min_dist=0.3,  # umap param for minimum distance
69 |     debug=False,  # print more info while running
70 |     **kwargs,  # other params to pass to umap, cf. https://umap-learn.readthedocs.io/en/latest/parameters.html
71 |     ):
72 |     "this projects to lower dimensions, grabbing the first _`proj_dims`_ dimensions"
73 |     method = method.lower()
74 |     A = rearrange(tokens, 'b d n -> (b n) d')  # put all the vectors into the same d-dim space
75 |     if A.shape[-1] > proj_dims:
76 |         if method == 'umap':
77 |             proj_data = umap.UMAP(n_components=proj_dims, n_neighbors=n_neighbors, min_dist=min_dist,
78 |                 metric='correlation', **kwargs).fit_transform(A.cpu().numpy())
79 |             proj_data = torch.from_numpy(proj_data).to(tokens.device)
80 |         else:  # pca
81 |             (U, S, V) = torch.pca_lowrank(A)
82 |             proj_data = torch.matmul(A, V[:, :proj_dims])  # this is the actual PCA projection step
83 |     else:
84 |         proj_data = A
85 |     if debug: print("proj_data.shape =", proj_data.shape)
86 |     return torch.reshape(proj_data, (tokens.size()[0], -1, proj_dims))  # put it in shape [batch, n, proj_dims]
87 | 
88 | 
89 | def proj_pca(tokens, proj_dims=3):
90 |     return project_down(tokens, method='pca', proj_dims=proj_dims)
91 | 
92 | # %% ../02_viz.ipynb 9
93 | def point_cloud(
94 |     tokens,  # embeddings / latent vectors. shape = (b, d, n)
95 |     method='pca',  # projection method for 3d mapping: 'pca' | 'umap'
96 |     color_scheme='batch',  # 'batch': group by sample; integer n: n groups, sequentially; otherwise color sequentially by time step
97 |     output_type='wandbobj',  # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly!
98 |     mode='markers',  # plotly scatter mode. 'lines+markers' or 'markers'
99 |     size=3,  # size of the dots
100 |     line=dict(color='rgba(10,10,10,0.01)'),  # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line
101 |     ds_preproj=1,  # EXPERIMENTAL: downsampling factor before projecting (1=no downsampling). Could screw up colors
102 |     ds_preplot=1,  # EXPERIMENTAL: downsampling factor before plotting (1=no downsampling). 
Could screw up colors 103 | debug=False, # print more info 104 | colormap=None, # valid color map to use, None=defaults 105 | darkmode=False, # dark background, white fonts 106 | layout_dict=None, # extra plotly layout options such as camera orientation 107 | **kwargs, # anything else to pass along 108 | ): 109 | "returns a 3D point cloud of the tokens" 110 | if ds_preproj != 1: 111 | tokens = tokens[torch.randperm(tokens.size()[0])] # EXPERIMENTAL: to 'correct' for possible weird effects of downsampling 112 | tokens = tokens[::ds_preproj] 113 | if debug: print("tokens.shape =",tokens.shape) 114 | 115 | data = project_down(tokens, method=method, debug=debug, **kwargs).cpu().numpy() 116 | if debug: print("data.shape =",data.shape) 117 | if data.shape[-1] < 3: # for data less than 3D, embed it in 3D 118 | data = np.pad(data, ((0,0),(0,0),(0, 3-data.shape[-1])), mode='constant', constant_values=0) 119 | 120 | bytime = False 121 | points = [] 122 | if color_scheme=='batch': # all dots in same batch index same color, each batch-index unique (almost) 123 | ncolors = data.shape[0] 124 | cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors) 125 | elif isinstance(color_scheme, int) or color_scheme.isnumeric(): # n groups, by batch-indices, sequentially 126 | ncolors = int(color_scheme) 127 | cmap, norm = cm.tab20, Normalize(vmin=0, vmax=ncolors) 128 | else: # time steps match up 129 | bytime, ncolors = True, data.shape[1] 130 | cmap, norm = cm.viridis, Normalize(vmin=0, vmax=ncolors) 131 | 132 | cmap = cmap if colormap is None else colormap # overwrite default cmap with user choice if given 133 | 134 | points = [] 135 | for bi in range(data.shape[0]): # batch index 136 | if color_scheme=='batch': 137 | [r, g, b, _] = [int(255*x) for x in cmap(norm(bi+1))] 138 | elif isinstance(color_scheme, int) or color_scheme.isnumeric(): 139 | grouplen = data.shape[0]//(ncolors) 140 | #if debug: print(f"bi, grouplen, bi//grouplen = ",bi, grouplen, bi//grouplen) 141 | [r, g, b, _] = [int(255*x) for x in cmap(norm(bi//grouplen))] 142 | #if debug: print("r,g,b = ",r,g,b) 143 | for n in range(data.shape[1]): # across time 144 | if bytime: [r, g, b, _] = [int(255*x) for x in cmap(norm(n))] 145 | points.append([data[bi,n,0], data[bi,n,1], data[bi,n,2], r, g, b]) # include dot colors with point coordinates 146 | 147 | point_cloud = np.array(points) 148 | 149 | if output_type == 'points': 150 | return point_cloud 151 | elif output_type =='plotly': 152 | fig = go.Figure(data=[go.Scatter3d( 153 | x=point_cloud[::ds_preplot,0], y=point_cloud[::ds_preplot,1], z=point_cloud[::ds_preplot,2], 154 | marker=dict(size=size, color=point_cloud[:,3:6]), 155 | mode=mode, 156 | # show batch index and time index in tooltips: 157 | text=[ f'bi: {i*ds_preplot}, ti: {j}' for i in range(data.shape[0]//ds_preplot) for j in range(data.shape[1]) ], 158 | line=line, 159 | )]) 160 | fig.update_layout(margin=dict(l=0, r=0, b=0, t=0)) # tight layout 161 | if darkmode: 162 | fig.layout.template = 'plotly_dark' 163 | if isinstance(darkmode, str): # 'rgb(12,15,24)'gradio margins in dark mode 164 | fig.update_layout( paper_bgcolor=darkmode) 165 | if layout_dict: 166 | fig.update_layout( **layout_dict ) 167 | 168 | if debug: print("point_cloud: fig made. returning") 169 | return fig 170 | else: 171 | return wandb.Object3D(point_cloud) 172 | 173 | 174 | def pca_point_cloud( 175 | tokens, # embeddings / latent vectors. 
shape = (b, d, n) 176 | color_scheme='batch', # 'batch': group by sample, otherwise color sequentially 177 | output_type='wandbobj', # plotly | points | wandbobj. NOTE: WandB can do 'plotly' directly! 178 | mode='markers', # plotly scatter mode. 'lines+markers' or 'markers' 179 | size=3, # size of the dots 180 | line=dict(color='rgba(10,10,10,0.01)'), # if mode='lines+markers', plotly line specifier. cf. https://plotly.github.io/plotly.py-docs/generated/plotly.graph_objects.scatter3d.html#plotly.graph_objects.scatter3d.Line 181 | **kwargs, 182 | ): 183 | return point_cloud(tokens, method='pca', color_scheme=color_scheme, output_type=output_type, 184 | mode=mode, size=size, line=line, **kwargs) 185 | 186 | # %% ../02_viz.ipynb 11 187 | # have to do a little extra stuff to make this come out in the docs. This part taken from drscotthawley's `mrspuff` library 188 | def on_colab(): # cf https://stackoverflow.com/questions/53581278/test-if-notebook-is-running-on-google-colab 189 | """Returns true if code is being executed on Colab, false otherwise""" 190 | try: 191 | return 'google.colab' in str(get_ipython()) 192 | except NameError: # no get_ipython, so definitely not on Colab 193 | return False 194 | 195 | plotly_already_setup = False 196 | def setup_plotly(nbdev=True): 197 | """Plotly is already 'setup' on colab, but on regular Jupyter notebooks we need to do a couple things""" 198 | global plotly_already_setup 199 | if plotly_already_setup: return 200 | if nbdev and not on_colab(): # Nick Burrus' code for normal-Juptyer use with plotly & nbdev 201 | import plotly.io as pio 202 | pio.renderers.default = 'notebook_connected' 203 | js = '' 204 | display(HTML(js)) 205 | plotly_already_setup = True 206 | 207 | def show_point_cloud(tokens, # same arts as point_cloud 208 | method='pca', 209 | color_scheme='batch', 210 | mode='markers', 211 | line=dict(color='rgba(10,10,10,0.01)'), 212 | ds_preproj=1, 213 | ds_preplot=1, 214 | debug=False, 215 | **kwargs): 216 | "display a 3d scatter plot of tokens in notebook" 217 | setup_plotly() 218 | fig = point_cloud(tokens, ds_preproj=ds_preproj, ds_preplot=ds_preplot, debug=debug, method=method, 219 | color_scheme=color_scheme, output_type='plotly', mode=mode, line=line, **kwargs) 220 | fig.show() 221 | 222 | def show_pca_point_cloud(tokens, 223 | color_scheme='batch', 224 | mode='markers', 225 | colormap=None, 226 | line=dict(color='rgba(10,10,10,0.01)'), 227 | **kwargs, 228 | ): 229 | "display a 3d scatter plot of tokens in notebook" 230 | show_point_cloud(tokens, color_scheme=color_scheme, mode=mode, colormap=colormap, line=line, **kwargs) 231 | 232 | # %% ../02_viz.ipynb 21 233 | def print_stats(waveform, sample_rate=None, src=None, print=print): 234 | "print stats about waveform. Taken verbatim from pytorch docs." 
235 | if src: 236 | print(f"-" * 10) 237 | print(f"Source: {src}") 238 | print(f"-" * 10) 239 | if sample_rate: 240 | print(f"Sample Rate: {sample_rate}") 241 | print(f"Shape: {tuple(waveform.shape)}") 242 | print(f"Dtype: {waveform.dtype}") 243 | print(f" - Max: {waveform.max().item():6.3f}") 244 | print(f" - Min: {waveform.min().item():6.3f}") 245 | print(f" - Mean: {waveform.mean().item():6.3f}") 246 | print(f" - Std Dev: {waveform.std().item():6.3f}") 247 | print('') 248 | print(f"{waveform}") 249 | print('') 250 | 251 | # %% ../02_viz.ipynb 25 252 | def mel_spectrogram(waveform, power=2.0, sample_rate=48000, db=False, n_fft=1024, n_mels=128, debug=False): 253 | "calculates data array for mel spectrogram (in however many channels)" 254 | win_length = None 255 | hop_length = n_fft//2 # 512 256 | 257 | mel_spectrogram_op = T.MelSpectrogram( 258 | sample_rate=sample_rate, n_fft=n_fft, win_length=win_length, 259 | hop_length=hop_length, center=True, pad_mode="reflect", power=power, 260 | norm='slaney', onesided=True, n_mels=n_mels, mel_scale="htk") 261 | 262 | melspec = mel_spectrogram_op(waveform.float()) 263 | if db: 264 | amp_to_db_op = T.AmplitudeToDB() 265 | melspec = amp_to_db_op(melspec) 266 | if debug: 267 | print_stats(melspec, print=print) 268 | print(f"torch.max(melspec) = {torch.max(melspec)}") 269 | print(f"melspec.shape = {melspec.shape}") 270 | return melspec 271 | 272 | # %% ../02_viz.ipynb 26 273 | def spectrogram_image( 274 | spec, 275 | title=None, 276 | ylabel='freq_bin', 277 | aspect='auto', 278 | xmax=None, 279 | db_range=[35,120], 280 | justimage=False, 281 | figsize=(5, 4), # size of plot (if justimage==False) 282 | ): 283 | "Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html" 284 | fig = Figure(figsize=figsize, dpi=100) if not justimage else Figure(figsize=(4.145, 4.145), dpi=100, tight_layout=True) 285 | canvas = FigureCanvasAgg(fig) 286 | axs = fig.add_subplot() 287 | spec = spec.squeeze() 288 | im = axs.imshow(power_to_db(spec), origin='lower', aspect=aspect, vmin=db_range[0], vmax=db_range[1]) 289 | if xmax: 290 | axs.set_xlim((0, xmax)) 291 | if justimage: 292 | axs.axis('off') 293 | plt.tight_layout() 294 | else: 295 | axs.set_ylabel(ylabel) 296 | axs.set_xlabel('frame') 297 | axs.set_title(title or 'Spectrogram (dB)') 298 | fig.colorbar(im, ax=axs) 299 | canvas.draw() 300 | rgba = np.asarray(canvas.buffer_rgba()) 301 | im = Image.fromarray(rgba) 302 | if justimage: # remove tiny white border 303 | b = 15 # border size 304 | im = im.crop((b,b, im.size[0]-b, im.size[1]-b)) 305 | #print(f"im.size = {im.size}") 306 | return im 307 | 308 | # %% ../02_viz.ipynb 27 309 | def audio_spectrogram_image(waveform, power=2.0, sample_rate=48000, print=print, db=False, db_range=[35,120], justimage=False, log=False, figsize=(5, 4)): 310 | "Wrapper for calling above two routines at once, does Mel scale; Modified from PyTorch tutorial https://pytorch.org/tutorials/beginner/audio_feature_extractions_tutorial.html" 311 | melspec = mel_spectrogram(waveform, power=power, db=db, sample_rate=sample_rate, debug=log) 312 | melspec = melspec[0] # TODO: only left channel for now 313 | return spectrogram_image(melspec, title="MelSpectrogram", ylabel='mel bins (log freq)', db_range=db_range, justimage=justimage, figsize=figsize) 314 | 315 | # %% ../02_viz.ipynb 31 316 | # Original code by Scott Condron (@scottire) of Weights and Biases, edited by @drscotthawley 317 | # cf. 
@scottire's original code here: https://gist.github.com/scottire/a8e5b74efca37945c0f1b0670761d568 318 | # and Morgan McGuire's edit here; https://github.com/morganmcg1/wandb_spectrogram 319 | 320 | 321 | # helper routine; a bit redundant given what else is in this repo 322 | def generate_melspec(audio_data, sample_rate=48000, power=2.0, n_fft = 1024, win_length = None, hop_length = None, n_mels = 128): 323 | "helper routine for playable_spectrogram" 324 | if hop_length is None: 325 | hop_length = n_fft//2 326 | 327 | # convert to torch 328 | audio_data = torch.tensor(audio_data, dtype=torch.float32) 329 | 330 | mel_spectrogram_op = T.MelSpectrogram( 331 | sample_rate=sample_rate, 332 | n_fft=n_fft, 333 | win_length=win_length, 334 | hop_length=hop_length, 335 | center=True, 336 | pad_mode="reflect", 337 | power=power, 338 | norm="slaney", 339 | onesided=True, 340 | n_mels=n_mels, 341 | mel_scale="htk", 342 | ) 343 | 344 | melspec = mel_spectrogram_op(audio_data).numpy() 345 | mel_db = np.flipud(power_to_db(melspec)) 346 | return mel_db 347 | 348 | 349 | # the main routine 350 | def playable_spectrogram( 351 | waveform, # audio, PyTorch tensor 352 | sample_rate=48000, # sample rate in Hz 353 | specs:str='all', # see docstring below 354 | layout:str='row', # 'row' or 'grid' 355 | height=170, # height of spectrogram image 356 | width=400, # width of spectrogram image 357 | cmap='viridis', # colormap string for Holoviews, see https://holoviews.org/user_guide/Colormaps.html 358 | output_type='wandb', # 'wandb', 'html_file', 'live': use live for notebooks 359 | debug=True # flag for internal print statements 360 | ): 361 | ''' 362 | Takes a tensor input and returns a [wandb.]HTML object with spectrograms of the audio 363 | specs : 364 | "all_specs", spectrograms only 365 | "all", all plots 366 | "melspec", melspectrogram only 367 | "spec", spectrogram only 368 | "wave_mel", waveform and melspectrogram only 369 | "waveform", waveform only, equivalent to wandb.Audio object 370 | 371 | Limitations: spectrograms show channel 0 only (i.e., mono) 372 | ''' 373 | hv.extension("bokeh", logo=False) 374 | 375 | audio_data = waveform.cpu().numpy() 376 | 377 | duration = audio_data.shape[-1]/sample_rate 378 | if len(audio_data.shape) > 1: 379 | mono_audio = audio_data[0,:] # MONO ONLY get one channel, for spectrograms 380 | 381 | # for the audio widget, it works best if you save-a-file-read-a-file 382 | # need to convert to int for Panel Audio element 383 | #audio_ints = np.clip( audio_data*32768 , -32768, 32768).astype('int16') 384 | # Audio widget 385 | #audio = pn.pane.Audio(audio_ints, sample_rate=sample_rate, name='Audio', throttle=10) 386 | tmp_audio_file = f'audio_out.wav' # holoview expects file to persist _{int(np.random.rand()*10000)}.wav' # rand number is just to allow parallel operation 387 | torchaudio.save(tmp_audio_file, waveform.cpu() ,sample_rate) 388 | audio = pn.pane.Audio(tmp_audio_file, name='Audio', throttle=10) 389 | #os.remove(tmp_audio_file) # but we don't want a ton of files to accumulate on the disk 390 | 391 | # Add HTML components 392 | line = hv.VLine(0).opts(color='red') 393 | line2 = hv.VLine(0).opts(color='green') 394 | line3 = hv.VLine(0).opts(color='white') 395 | 396 | slider = pn.widgets.FloatSlider(end=duration, visible=False, step=0.001) 397 | slider.jslink(audio, value='time', bidirectional=True) 398 | slider.jslink(line, value='glyph.location') 399 | slider.jslink(line2, value='glyph.location') 400 | slider.jslink(line3, value='glyph.location') 401 | 402 | # 
452 | 
453 | # %% ../02_viz.ipynb 35
454 | from matplotlib.ticker import AutoLocator
455 | def tokens_spectrogram_image(
456 |     tokens,             # the embeddings themselves (in some diffusion codes these are called 'tokens')
457 |     aspect='auto',      # aspect ratio of plot
458 |     title='Embeddings', # title to put on top
459 |     ylabel='index',     # label for y axis of plot
460 |     cmap='coolwarm',    # colormap to use (default used to be 'viridis')
461 |     symmetric=True,     # make color scale symmetric about zero, i.e. +/- same extremes
462 |     figsize=(8, 4),     # matplotlib size of the figure
463 |     dpi=100,            # dpi of figure
464 |     mark_batches=False, # separate batches with dividing lines
465 |     debug=False,        # print debugging info
466 |     ):
467 |     "for visualizing embeddings in a spectrogram-like way"
468 |     batch_size, dim, samples = tokens.shape
469 |     embeddings = rearrange(tokens, 'b d n -> (b n) d')  # expand batches in time
470 |     vmin, vmax = None, None
471 |     if symmetric:
472 |         vmax = torch.abs(embeddings).max()
473 |         vmin = -vmax
474 | 
475 |     fig = Figure(figsize=figsize, dpi=dpi)
476 |     canvas = FigureCanvasAgg(fig)
477 |     ax = fig.add_subplot()
478 |     if symmetric:
479 |         subtitle = f'min={embeddings.min():0.4g}, max={embeddings.max():0.4g}'
480 |         ax.set_title(title+'\n')
481 |         ax.text(x=0.435, y=0.9, s=subtitle, fontsize=11, ha="center", transform=fig.transFigure)
482 |     else:
483 |         ax.set_title(title)
484 |     ax.set_ylabel(ylabel)
485 |     ax.set_xlabel('time frame (samples, in batches)')
486 |     if mark_batches:
487 |         intervals = np.arange(batch_size)*samples
488 |         if debug: print("intervals = ", intervals)
489 |         ax.vlines(intervals, -10, dim+10, color='black', linestyle='dashed', linewidth=1)
490 | 
491 |     im = ax.imshow(embeddings.cpu().numpy().T, origin='lower', aspect=aspect, interpolation='none', cmap=cmap, vmin=vmin, vmax=vmax)  # .T because numpy is x/y 'backwards'
492 |     fig.colorbar(im, ax=ax)
493 |     fig.tight_layout()
494 |     canvas.draw()
495 |     rgba = np.asarray(canvas.buffer_rgba())
496 |     return Image.fromarray(rgba)
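A usage sketch for tokens_spectrogram_image (an illustrative addition with made-up tensor dimensions, not part of viz.py; assumes torch and the routine above are in scope):

    import torch
    tokens = torch.randn(4, 64, 256)  # fake embeddings: (batch, embedding dim, time frames)
    im = tokens_spectrogram_image(tokens, mark_batches=True)  # PIL Image; dashed lines mark batch boundaries
    im.save('embeddings.png')  # hypothetical output filename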
497 | 
498 | # %% ../02_viz.ipynb 39
499 | def plot_jukebox_embeddings(zs, aspect='auto'):
500 |     "makes a plot of jukebox embeddings"
501 |     fig, ax = plt.subplots(nrows=len(zs))
502 |     for i, z in enumerate(zs):
503 |         #z = torch.squeeze(z)
504 |         z = z.cpu().numpy()
505 |         x = np.arange(z.shape[-1])
506 |         im = ax[i].imshow(z, origin='lower', aspect=aspect, interpolation='none')
507 | 
508 |     #plt.legend()
509 |     plt.ylabel("emb (top=fine, bottom=coarse)")
510 |     return {"chart": plt}
511 | 
--------------------------------------------------------------------------------
/examples/accel_config.yaml:
--------------------------------------------------------------------------------
1 | compute_environment: LOCAL_MACHINE
2 | deepspeed_config: {}
3 | distributed_type: MULTI_GPU
4 | fsdp_config: {}
5 | machine_rank: 0
6 | main_process_ip: ''
7 | main_process_port: 12332
8 | main_training_function: main
9 | mixed_precision: 'no'
10 | num_machines: 2
11 | num_processes: 8
12 | use_cpu: false
13 | 
14 | 
--------------------------------------------------------------------------------
/examples/example.wav:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drscotthawley/aeiou/cd580bbea8250662e411feef30fe62f1cdc049b4/examples/example.wav
--------------------------------------------------------------------------------
/examples/stereo_pewpew.mp3:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/drscotthawley/aeiou/cd580bbea8250662e411feef30fe62f1cdc049b4/examples/stereo_pewpew.mp3
--------------------------------------------------------------------------------
/index.ipynb:
--------------------------------------------------------------------------------
1 | {
2 |  "cells": [
3 |   {
4 |    "cell_type": "code",
5 |    "execution_count": null,
6 |    "metadata": {},
7 |    "outputs": [],
8 |    "source": [
9 |     "#|hide\n",
10 |     "from aeiou.core import *"
11 |    ]
12 |   },
13 |   {
14 |    "cell_type": "markdown",
15 |    "metadata": {},
16 |    "source": [
17 |     "# aeiou\n",
18 |     "\n",
19 |     "> audio engineering i/o utils."
20 |    ]
21 |   },
22 |   {
23 |    "cell_type": "markdown",
24 |    "metadata": {},
25 |    "source": [
26 |     "Pronounced \"[ayoo](https://youtu.be/Hv6RbEOlqRo?t=24)\""
27 |    ]
28 |   },
29 |   {
30 |    "cell_type": "markdown",
31 |    "metadata": {},
32 |    "source": [
33 |     "## Install"
34 |    ]
35 |   },
36 |   {
37 |    "cell_type": "markdown",
38 |    "metadata": {},
39 |    "source": [
40 |     "It is recommended that you install the latest version from GitHub via\n",
41 |     "\n",
42 |     "```sh\n",
43 |     "pip install git+https://github.com/drscotthawley/aeiou.git\n",
44 |     "```\n",
45 |     "However, releases will occasionally be updated on PyPI; these can be installed via\n",
46 |     "\n",
47 |     "```sh\n",
48 |     "pip install aeiou\n",
49 |     "```"
50 |    ]
51 |   },
52 |   {
53 |    "cell_type": "markdown",
54 |    "metadata": {},
55 |    "source": [
56 |     "## How to use"
57 |    ]
58 |   },
59 |   {
60 |    "cell_type": "markdown",
61 |    "metadata": {},
62 |    "source": [
63 |     "This is a series of utility routines developed in support of multiple projects within the [Harmonai](https://www.harmonai.org/) organization. See the individual documentation pages for more specific instructions on how these can be used. Note that this is *research code*, so it is a) in flux and b) in need of improvements to documentation. "
64 |    ]
65 |   },
66 |   {
67 |    "cell_type": "markdown",
68 |    "metadata": {},
69 |    "source": [
70 |     "## Documentation"
71 |    ]
72 |   },
73 |   {
74 |    "cell_type": "markdown",
75 |    "metadata": {},
76 |    "source": [
77 |     "Documentation for this library is hosted on the [aeiou GitHub Pages site](https://drscotthawley.github.io/aeiou/)."
78 |    ]
79 |   },
80 |   {
81 |    "cell_type": "markdown",
82 |    "metadata": {},
83 |    "source": [
84 |     "## Contributing"
85 |    ]
86 |   },
87 |   {
88 |    "cell_type": "markdown",
89 |    "metadata": {},
90 |    "source": [
91 |     "Contributions are welcome -- especially for improvements to documentation! To contribute:\n",
92 |     "\n",
93 |     "1. Fork this repo and then clone your fork to your local machine. \n",
94 |     "\n",
95 |     "1. Create a new (local) branch: `git checkout -b mybranch` (or whatever you want to call it). \n",
96 |     "1. This library is written entirely in [nbdev](https://nbdev.fast.ai/) version 2, using Jupyter notebooks. \n",
97 |     "\n",
98 |     "1. [Install nbdev](https://nbdev.fast.ai/getting_started.html#install) and then you can edit the Jupyter notebooks. \n",
99 |     "\n",
100 |     "1. After editing notebooks, run `nbdev_prepare`. \n",
101 |     "\n",
102 |     "1. If that succeeds, you can run `git add *.ipynb aeiou/*.py; git commit` and then `git push` to get your changes back to your fork on GitHub. \n",
103 |     "\n",
104 |     "1. Then send a Pull Request from your fork to the main `aeiou` repository. "
105 |    ]
106 |   },
107 |   {
108 |    "cell_type": "markdown",
109 |    "metadata": {},
110 |    "source": [
111 |     "## Attribution\n",
112 |     "Please include attribution of this code if you reproduce sections of it in your own code:\n",
113 |     "```\n",
114 |     "aeiou: audio engineering i/o utilities: Copyright (c) Scott H. Hawley, 2022-2023. https://github.com/drscotthawley/aeiou\n",
115 |     "```\n",
116 |     " \n",
117 |     " \n",
118 |     "In research papers, please cite this software if you find it useful:\n",
119 |     "```bibtex\n",
120 |     "@misc{aeiou,\n",
121 |     "  author = {Scott H. Hawley},\n",
122 |     "  title = {aeiou: audio engineering i/o utilities},\n",
123 |     "  year = {2022},\n",
124 |     "  url = {https://github.com/drscotthawley/aeiou},\n",
125 |     "}\n",
126 |     "```\n",
127 |     "Copyright (c) Scott H. Hawley, 2022-2023. "
128 |    ]
129 |   },
130 |   {
131 |    "cell_type": "markdown",
132 |    "metadata": {},
133 |    "source": [
134 |     "## License\n",
135 |     "The [license](https://github.com/drscotthawley/aeiou/blob/main/LICENSE) is Apache 2.0. "
136 |    ]
137 |   },
138 |   {
139 |    "cell_type": "code",
140 |    "execution_count": null,
141 |    "metadata": {},
142 |    "outputs": [],
143 |    "source": [
144 |     "#| hide\n",
145 |     "from nbdev import nbdev_export\n",
146 |     "nbdev_export()"
147 |    ]
148 |   }
149 |  ],
150 |  "metadata": {
151 |   "kernelspec": {
152 |    "display_name": "python3",
153 |    "language": "python",
154 |    "name": "python3"
155 |   }
156 |  },
157 |  "nbformat": 4,
158 |  "nbformat_minor": 4
159 | }
160 | 
--------------------------------------------------------------------------------
/nbdev.yml:
--------------------------------------------------------------------------------
1 | project:
2 |   output-dir: _docs
3 | 
4 | website:
5 |   title: "aeiou"
6 |   site-url: "https://drscotthawley.github.io/aeiou/"
7 |   description: "audio engineering i/o utils"
8 |   repo-branch: main
9 |   repo-url: "https://github.com/drscotthawley/aeiou/"
10 | 
--------------------------------------------------------------------------------
/settings.ini:
--------------------------------------------------------------------------------
1 | [DEFAULT]
2 | # All sections below are required unless otherwise specified
3 | # see https://github.com/fastai/nbdev/blob/master/settings.ini for examples
4 | 
5 | ### Python Library ###
6 | lib_name = aeiou
7 | min_python = 3.7
8 | version = 0.0.21
9 | 
10 | ### OPTIONAL ###
11 | 
12 | requirements = fastcore pandas numpy plotly bokeh holoviews scipy torch torchvision torchaudio wandb matplotlib pillow tqdm einops ipython accelerate webdataset pedalboard umap-learn
13 | dev_requirements = nbformat>=4.2.0
14 | console_scripts = chunkadelic=aeiou.chunkadelic:main spectrofu=aeiou.spectrofu:main
15 | 
16 | 
17 | ### nbdev ###
18 | nbs_path = .
19 | doc_path = _docs
20 | recursive = False
21 | tst_flags = notest
22 | 
23 | ### Documentation ###
24 | host = github
25 | repo = aeiou
26 | branch = main
27 | custom_sidebar = False
28 | 
29 | ### PyPI ###
30 | audience = Developers
31 | author = Scott Hawley
32 | author_email = scott.hawley@belmont.edu
33 | copyright = (c) 2022-2025 Scott H. Hawley
34 | description = audio engineering i/o utils
35 | keywords = nbdev
36 | language = English
37 | license = apache2
38 | status = 2
39 | user = drscotthawley
40 | 
41 | ### Inferred From Other Values ###
42 | doc_host = https://%(user)s.github.io
43 | doc_baseurl = /%(lib_name)s/
44 | git_url = https://github.com/%(user)s/%(repo)s/
45 | lib_path = %(lib_name)s
46 | title = %(lib_name)s
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from pkg_resources import parse_version
2 | from configparser import ConfigParser
3 | import setuptools
4 | assert parse_version(setuptools.__version__) >= parse_version('36.2')
5 | 
6 | # note: all settings are in settings.ini; edit there, not here
7 | config = ConfigParser(delimiters=['='])
8 | config.read('settings.ini')
9 | cfg = config['DEFAULT']
10 | 
11 | cfg_keys = 'version description keywords author author_email'.split()
12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split()
13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o)
14 | setup_cfg = {o:cfg[o] for o in cfg_keys}
15 | 
16 | licenses = {
17 |     'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'),
18 |     'mit': ('MIT License', 'OSI Approved :: MIT License'),
19 |     'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'),
20 |     'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'),
21 |     'bsd3': ('BSD License', 'OSI Approved :: BSD License'),
22 | }
23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha',
24 |     '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ]
25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split()
26 | 
27 | requirements = cfg.get('requirements','').split()
28 | if cfg.get('pip_requirements'): requirements += cfg.get('pip_requirements','').split()
29 | min_python = cfg['min_python']
30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None))
31 | dev_requirements = (cfg.get('dev_requirements') or '').split()
32 | 
33 | setuptools.setup(
34 |     name = cfg['lib_name'],
35 |     license = lic[0],
36 |     classifiers = [
37 |         'Development Status :: ' + statuses[int(cfg['status'])],
38 |         'Intended Audience :: ' + cfg['audience'].title(),
39 |         'Natural Language :: ' + cfg['language'].title(),
40 |     ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1]] if lic[1] else []),
41 |     url = cfg['git_url'],
42 |     packages = setuptools.find_packages(),
43 |     include_package_data = True,
44 |     install_requires = requirements,
45 |     extras_require = { 'dev': dev_requirements },
46 |     dependency_links = cfg.get('dep_links','').split(),
47 |     python_requires = '>=' + cfg['min_python'],
48 |     long_description = open('README.md', encoding="utf8").read(),
49 |     long_description_content_type = 'text/markdown',
50 |     zip_safe = False,
51 |     entry_points = {
52 |         'console_scripts': cfg.get('console_scripts','').split(),
53 |         'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d']
54 |     },
55 |     **setup_cfg)
56 | 
57 | 
--------------------------------------------------------------------------------
/sidebar.yml:
--------------------------------------------------------------------------------
1 | website:
2 |   sidebar:
3 |     contents:
4 |       - index.ipynb
5 |       - 00_core.ipynb
6 |       - 01_datasets.ipynb
7 |       - 02_viz.ipynb
8 |       - 03_chunkadelic.ipynb
9 |       - 04_spectrofu.ipynb
10 |       - 05_hpc.ipynb
11 | 
--------------------------------------------------------------------------------
/styles.css:
--------------------------------------------------------------------------------
1 | .cell {
2 |     margin-bottom: 1rem;
3 | }
4 | 
5 | .cell > .sourceCode {
6 |     margin-bottom: 0;
7 | }
8 | 
9 | .cell-output > pre {
10 |     margin-bottom: 0;
11 | }
12 | 
13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre {
14 |     margin-left: 0.8rem;
15 |     margin-top: 0;
16 |     background: none;
17 |     border-left: 2px solid lightsalmon;
18 |     border-top-left-radius: 0;
19 |     border-top-right-radius: 0;
20 | }
21 | 
22 | .cell-output > .sourceCode {
23 |     border: none;
24 | }
25 | 
26 | .cell-output > .sourceCode {
27 |     background: none;
28 |     margin-top: 0;
29 | }
30 | 
31 | div.description {
32 |     padding-left: 2px;
33 |     padding-top: 5px;
34 |     font-style: italic;
35 |     font-size: 135%;
36 |     opacity: 70%;
37 | }
38 | 
--------------------------------------------------------------------------------