├── .github └── workflows │ ├── deploy.yaml │ └── test.yaml ├── .gitignore ├── LICENSE ├── MANIFEST.in ├── README.md ├── fad_pytorch ├── __init__.py ├── _modidx.py ├── fad_embed.py ├── fad_gen.py ├── fad_score.py ├── pann.py ├── pann_pytorch_utils.py └── sqrtm.py ├── nbs ├── 01_fad_gen.ipynb ├── 02_fad_embed.ipynb ├── 03_fad_score.ipynb ├── 04_sqrtm.ipynb ├── _quarto.yml ├── index.ipynb ├── nbdev.yml └── styles.css ├── settings.ini └── setup.py /.github/workflows/deploy.yaml: -------------------------------------------------------------------------------- 1 | name: Deploy to GitHub Pages 2 | 3 | permissions: 4 | contents: write 5 | pages: write 6 | 7 | on: 8 | push: 9 | branches: [ "main", "master" ] 10 | workflow_dispatch: 11 | jobs: 12 | deploy: 13 | runs-on: ubuntu-latest 14 | steps: [run: sudo apt-get update; sudo apt-get install ffmpeg libsndfile-dev, uses: fastai/workflows/quarto-ghp@master] 15 | -------------------------------------------------------------------------------- /.github/workflows/test.yaml: -------------------------------------------------------------------------------- 1 | name: CI 2 | on: [workflow_dispatch, pull_request, push] 3 | 4 | jobs: 5 | test: 6 | runs-on: ubuntu-latest 7 | steps: [run: sudo apt-get update; sudo apt-get install ffmpeg libsndfile-dev, uses: fastai/workflows/nbdev-ci@master, ] 8 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | _docs/ 2 | _proc/ 3 | 4 | *.bak 5 | .gitattributes 6 | .last_checked 7 | .gitconfig 8 | *.bak 9 | *.log 10 | *~ 11 | ~* 12 | _tmp* 13 | tmp* 14 | tags 15 | *.pkg 16 | 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | env/ 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | *.egg-info/ 41 | .installed.cfg 42 | *.egg 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .coverage 58 | .coverage.* 59 | .cache 60 | nosetests.xml 61 | coverage.xml 62 | *.cover 63 | .hypothesis/ 64 | 65 | # Translations 66 | *.mo 67 | *.pot 68 | 69 | # Django stuff: 70 | *.log 71 | local_settings.py 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # pyenv 90 | .python-version 91 | 92 | # celery beat schedule file 93 | celerybeat-schedule 94 | 95 | # SageMath parsed files 96 | *.sage.py 97 | 98 | # dotenv 99 | .env 100 | 101 | # virtualenv 102 | .venv 103 | venv/ 104 | ENV/ 105 | 106 | # Spyder project settings 107 | .spyderproject 108 | .spyproject 109 | 110 | # Rope project settings 111 | .ropeproject 112 | 113 | # mkdocs documentation 114 | /site 115 | 116 | # mypy 117 | .mypy_cache/ 118 | 119 | .vscode 120 | *.swp 121 | 122 | # osx generated files 123 | .DS_Store 124 | .DS_Store? 
125 | .Trashes 126 | ehthumbs.db 127 | Thumbs.db 128 | .idea 129 | 130 | # pytest 131 | .pytest_cache 132 | 133 | # tools/trust-doc-nbs 134 | docs_src/.last_checked 135 | 136 | # symlinks to fastai 137 | docs_src/fastai 138 | tools/fastai 139 | 140 | # link checker 141 | checklink/cookies.txt 142 | 143 | # .gitconfig is now autogenerated 144 | .gitconfig 145 | 146 | # Quarto installer 147 | .deb 148 | .pkg 149 | 150 | # Quarto 151 | .quarto 152 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2023 Scott H. Hawley 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include settings.ini 2 | include LICENSE 3 | include CONTRIBUTING.md 4 | include README.md 5 | recursive-exclude * __pycache__ 6 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | fad_pytorch 2 | ================ 3 | 4 | 5 | 6 | [Original FAD paper (PDF)](https://arxiv.org/pdf/1812.08466.pdf) 7 | 8 | ## Install 9 | 10 | ``` sh 11 | pip install fad_pytorch 12 | ``` 13 | 14 | ## Features: 15 | 16 | - runs in parallel on multiple processors and multiple GPUs (via 17 | `accelerate`) 18 | - supports multiple embedding methods: 19 | - VGGish and PANN, both mono @ 16kHz 20 | - OpenL3 and (LAION-)CLAP, stereo @ 48kHz 21 | - uses publicly-available pretrained checkpoints for music (+other 22 | sources) for those models. (if you want Speech, submit a PR or an 23 | Issue; I don’t do speech.) 24 | - favors ops in PyTorch rather than numpy (or tensorflow) 25 | - `fad_gen` supports local data read or WebDataset (audio data stored in 26 | S3 buckets) 27 | - runs on CPU, CUDA, or MPS 28 | 29 | ## Instructions: 30 | 31 | This is designed to be run as 3 command-line scripts in succession. The 32 | latter 2 (`fad_embed` and `fad_score`) are probably what most people 33 | will want: 34 | 35 | 1. `fad_gen`: produces directories of real & fake audio (given real 36 | data). See `fad_gen` 37 | [documentation](https://drscotthawley.github.io/fad_pytorch/fad_gen.html) 38 | for calling sequence. 39 | 2. 
`fad_embed [options] <embed_model> <real_path> <fake_path>`: produces 40 | directories of *embeddings* of real & fake audio 41 | 3. `fad_score [options] <real_emb_path> <fake_emb_path>`: reads the 42 | embeddings & generates FAD score, for real (“$r$”) and fake (“$f$”): 43 | 44 | $$ FAD = || \mu_r - \mu_f ||^2 + tr\left(\Sigma_r + \Sigma_f - 2 \sqrt{\Sigma_r \Sigma_f}\right)$$ 45 | 
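For reference, here is a minimal sketch of that formula in plain PyTorch. It mirrors what `fad_score` computes internally (see `fad_pytorch/fad_score.py` below), reusing this repo's `fad_pytorch.sqrtm.sqrtm` for the matrix square root; the random tensors at the bottom are just stand-in "embeddings":

``` python
import torch
from fad_pytorch.sqrtm import sqrtm  # this repo's matrix sqrt; method = 'maji' | 'li', as in fad_score's --method flag

def fad(emb_real, emb_fake, method='maji'):
    "FAD between two (n_samples, emb_dim) tensors of embeddings"
    mu_r, mu_f = emb_real.mean(dim=0), emb_fake.mean(dim=0)
    sigma_r, sigma_f = torch.cov(emb_real.T), torch.cov(emb_fake.T)
    mu_diff = mu_r - mu_f
    covmean = sqrtm(sigma_r @ sigma_f, method=method)  # sqrt of the matrix product
    return mu_diff.dot(mu_diff) + torch.trace(sigma_r) + torch.trace(sigma_f) \
        - 2*torch.trace(torch.real(covmean))

# toy example: two clouds of 512-dim "embeddings", slightly offset from each other
emb_real, emb_fake = torch.randn(1000, 512), torch.randn(1000, 512) + 0.1
print(fad(emb_real, emb_fake))
```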
46 | ## Documentation 47 | 48 | See the [Documentation 49 | Website](https://drscotthawley.github.io/fad_pytorch/). 50 | 51 | ## Comments / FAQ / Troubleshooting 52 | 53 | - “`RuntimeError: CUDA error: invalid device ordinal`”: This happens 54 | when you have a “bad node” on an AWS cluster. [Haven’t yet figured out 55 | what causes it or how to fix 56 | it](https://discuss.huggingface.co/t/solved-accelerate-accelerator-cuda-error-invalid-device-ordinal/21509/1). 57 | Workaround: Just add the current node to your SLURM `--exclude` list, 58 | exit and retry. Note: it may take as many as 5 to 7 retries before you 59 | get a “good node”. 60 | - “FAD scores obtained from different embedding methods are *wildly* 61 | different!” …Yea. It’s not obvious that scores from different 62 | embedding methods should be comparable. Rather, compare different 63 | groups of audio files using the same embedding method, and/or check 64 | that FAD scores go *down* as similarity improves. 65 | - “FAD score for the same dataset repeated (twice) is not exactly zero!” 66 | …Yea. There seems to be an uncertainty of around +/- 0.008. I’d say, 67 | don’t quote any numbers past the first decimal point. 68 | 69 | ## Contributing 70 | 71 | This repo is still fairly “bare bones” and will benefit from more 72 | documentation and features as time goes on. Note that it is written 73 | using [nbdev](https://nbdev.fast.ai/), so the things to do are: 74 | 75 | 1. Fork this repo 76 | 2. Clone your fork to your (local) machine 77 | 3. Install nbdev: `python3 -m pip install -U nbdev` 78 | 4. Make changes by editing the notebooks in `nbs/`, not the `.py` files 79 | in `fad_pytorch/`. 80 | 5. Run `nbdev_export` to export notebook changes to `.py` files 81 | 6. For good measure, run `nbdev_install_hooks` and `nbdev_clean` - 82 | especially if you’ve *added* any notebooks. 83 | 7. Do a `git status` to see all the `.ipynb` and `.py` files that need 84 | to be added & committed 85 | 8. `git add` those files and then `git commit`, and then `git push` 86 | 9. Take a look in your fork’s GitHub Actions tab, and see if the “test” 87 | and “deploy” CI runs finish properly (green light) or fail (red 88 | light) 89 | 10. Once you get green lights, send in a Pull Request! 90 | 91 | *Feel free to ask me for tips with nbdev, it has quite a learning curve. 92 | You can also ask on [fast.ai forums](https://forums.fast.ai/) and/or 93 | [fast.ai 94 | Discord](https://discord.com/channels/689892369998676007/887694559952400424)* 95 | 96 | ## Citations / Blame / Disclaimer 97 | 98 | This repo is 2 weeks old. I’m not ready for this to be cited in your 99 | papers. I’d hate for there to be some mistake I haven’t found yet. 100 | Perhaps a later version will have citation info. For now, instead, 101 | there’s: 102 | 103 | **Disclaimer:** Results from this repo are still a work in progress. 104 | While every effort has been made to test model outputs, the author takes 105 | no responsibility for mistakes. If you want to double-check via another 106 | source, see “Related Repos” below. 107 | 108 | ## Related Repos 109 | 110 | There are \[several\] others, but this one is mine. These repos didn’t 111 | have all the features I wanted, but I used them for inspiration: 112 | 113 | - https://github.com/gudgud96/frechet-audio-distance 114 | - https://github.com/google-research/google-research/tree/master/frechet_audio_distance: 115 | Goes with [Original FAD paper](https://arxiv.org/pdf/1812.08466.pdf) 116 | - https://github.com/AndreevP/speech_distances 117 | -------------------------------------------------------------------------------- /fad_pytorch/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.6" 2 | -------------------------------------------------------------------------------- /fad_pytorch/_modidx.py: -------------------------------------------------------------------------------- 1 | # Autogenerated by nbdev 2 | 3 | d = { 'settings': { 'branch': 'main', 4 | 'doc_baseurl': '/fad_pytorch', 5 | 'doc_host': 'https://drscotthawley.github.io', 6 | 'git_url': 'https://github.com/drscotthawley/fad_pytorch', 7 | 'lib_path': 'fad_pytorch'}, 8 | 'syms': { 'fad_pytorch.fad_embed': { 'fad_pytorch.fad_embed.download_file': ('fad_embed.html#download_file', 'fad_pytorch/fad_embed.py'), 9 | 'fad_pytorch.fad_embed.download_if_needed': ( 'fad_embed.html#download_if_needed', 10 | 'fad_pytorch/fad_embed.py'), 11 | 'fad_pytorch.fad_embed.embed': ('fad_embed.html#embed', 'fad_pytorch/fad_embed.py'), 12 | 'fad_pytorch.fad_embed.get_ckpt': ('fad_embed.html#get_ckpt', 'fad_pytorch/fad_embed.py'), 13 | 'fad_pytorch.fad_embed.main': ('fad_embed.html#main', 'fad_pytorch/fad_embed.py'), 14 | 'fad_pytorch.fad_embed.setup_embedder': ( 'fad_embed.html#setup_embedder', 15 | 'fad_pytorch/fad_embed.py')}, 16 | 'fad_pytorch.fad_gen': { 'fad_pytorch.fad_gen.gen': ('fad_gen.html#gen', 'fad_pytorch/fad_gen.py'), 17 | 'fad_pytorch.fad_gen.main': ('fad_gen.html#main', 'fad_pytorch/fad_gen.py')}, 18 | 'fad_pytorch.fad_score': { 'fad_pytorch.fad_score.calc_mu_sigma': ('fad_score.html#calc_mu_sigma', 'fad_pytorch/fad_score.py'), 19 | 'fad_pytorch.fad_score.calc_score': ('fad_score.html#calc_score', 'fad_pytorch/fad_score.py'), 20 | 'fad_pytorch.fad_score.main': ('fad_score.html#main', 'fad_pytorch/fad_score.py'), 21 | 'fad_pytorch.fad_score.read_embeddings': ( 'fad_score.html#read_embeddings', 22 | 'fad_pytorch/fad_score.py')}, 23 | 'fad_pytorch.pann': { 'fad_pytorch.pann.AttBlock': ('pann.html#attblock', 'fad_pytorch/pann.py'), 24 | 'fad_pytorch.pann.AttBlock.__init__': ('pann.html#attblock.__init__', 'fad_pytorch/pann.py'), 25 | 'fad_pytorch.pann.AttBlock.forward': ('pann.html#attblock.forward', 'fad_pytorch/pann.py'), 26 | 'fad_pytorch.pann.AttBlock.init_weights': ('pann.html#attblock.init_weights', 'fad_pytorch/pann.py'), 27 | 'fad_pytorch.pann.AttBlock.nonlinear_transform': ( 'pann.html#attblock.nonlinear_transform', 28 | 'fad_pytorch/pann.py'), 29 | 'fad_pytorch.pann.Cnn10': ('pann.html#cnn10', 'fad_pytorch/pann.py'), 30 | 'fad_pytorch.pann.Cnn10.__init__': ('pann.html#cnn10.__init__', 'fad_pytorch/pann.py'), 31 | 'fad_pytorch.pann.Cnn10.forward': ('pann.html#cnn10.forward', 'fad_pytorch/pann.py'), 32 | 'fad_pytorch.pann.Cnn10.init_weight': ('pann.html#cnn10.init_weight', 'fad_pytorch/pann.py'), 33 | 'fad_pytorch.pann.Cnn14': ('pann.html#cnn14', 'fad_pytorch/pann.py'), 34 | 'fad_pytorch.pann.Cnn14.__init__': ('pann.html#cnn14.__init__', 'fad_pytorch/pann.py'), 35 | 'fad_pytorch.pann.Cnn14.forward': ('pann.html#cnn14.forward', 'fad_pytorch/pann.py'), 36 | 'fad_pytorch.pann.Cnn14.init_weight': ('pann.html#cnn14.init_weight', 
'fad_pytorch/pann.py'), 37 | 'fad_pytorch.pann.Cnn14_16k': ('pann.html#cnn14_16k', 'fad_pytorch/pann.py'), 38 | 'fad_pytorch.pann.Cnn14_16k.__init__': ('pann.html#cnn14_16k.__init__', 'fad_pytorch/pann.py'), 39 | 'fad_pytorch.pann.Cnn14_16k.forward': ('pann.html#cnn14_16k.forward', 'fad_pytorch/pann.py'), 40 | 'fad_pytorch.pann.Cnn14_16k.init_weight': ('pann.html#cnn14_16k.init_weight', 'fad_pytorch/pann.py'), 41 | 'fad_pytorch.pann.Cnn14_8k': ('pann.html#cnn14_8k', 'fad_pytorch/pann.py'), 42 | 'fad_pytorch.pann.Cnn14_8k.__init__': ('pann.html#cnn14_8k.__init__', 'fad_pytorch/pann.py'), 43 | 'fad_pytorch.pann.Cnn14_8k.forward': ('pann.html#cnn14_8k.forward', 'fad_pytorch/pann.py'), 44 | 'fad_pytorch.pann.Cnn14_8k.init_weight': ('pann.html#cnn14_8k.init_weight', 'fad_pytorch/pann.py'), 45 | 'fad_pytorch.pann.Cnn14_DecisionLevelAtt': ('pann.html#cnn14_decisionlevelatt', 'fad_pytorch/pann.py'), 46 | 'fad_pytorch.pann.Cnn14_DecisionLevelAtt.__init__': ( 'pann.html#cnn14_decisionlevelatt.__init__', 47 | 'fad_pytorch/pann.py'), 48 | 'fad_pytorch.pann.Cnn14_DecisionLevelAtt.forward': ( 'pann.html#cnn14_decisionlevelatt.forward', 49 | 'fad_pytorch/pann.py'), 50 | 'fad_pytorch.pann.Cnn14_DecisionLevelAtt.init_weight': ( 'pann.html#cnn14_decisionlevelatt.init_weight', 51 | 'fad_pytorch/pann.py'), 52 | 'fad_pytorch.pann.Cnn14_DecisionLevelAvg': ('pann.html#cnn14_decisionlevelavg', 'fad_pytorch/pann.py'), 53 | 'fad_pytorch.pann.Cnn14_DecisionLevelAvg.__init__': ( 'pann.html#cnn14_decisionlevelavg.__init__', 54 | 'fad_pytorch/pann.py'), 55 | 'fad_pytorch.pann.Cnn14_DecisionLevelAvg.forward': ( 'pann.html#cnn14_decisionlevelavg.forward', 56 | 'fad_pytorch/pann.py'), 57 | 'fad_pytorch.pann.Cnn14_DecisionLevelAvg.init_weight': ( 'pann.html#cnn14_decisionlevelavg.init_weight', 58 | 'fad_pytorch/pann.py'), 59 | 'fad_pytorch.pann.Cnn14_DecisionLevelMax': ('pann.html#cnn14_decisionlevelmax', 'fad_pytorch/pann.py'), 60 | 'fad_pytorch.pann.Cnn14_DecisionLevelMax.__init__': ( 'pann.html#cnn14_decisionlevelmax.__init__', 61 | 'fad_pytorch/pann.py'), 62 | 'fad_pytorch.pann.Cnn14_DecisionLevelMax.forward': ( 'pann.html#cnn14_decisionlevelmax.forward', 63 | 'fad_pytorch/pann.py'), 64 | 'fad_pytorch.pann.Cnn14_DecisionLevelMax.init_weight': ( 'pann.html#cnn14_decisionlevelmax.init_weight', 65 | 'fad_pytorch/pann.py'), 66 | 'fad_pytorch.pann.Cnn14_emb128': ('pann.html#cnn14_emb128', 'fad_pytorch/pann.py'), 67 | 'fad_pytorch.pann.Cnn14_emb128.__init__': ('pann.html#cnn14_emb128.__init__', 'fad_pytorch/pann.py'), 68 | 'fad_pytorch.pann.Cnn14_emb128.forward': ('pann.html#cnn14_emb128.forward', 'fad_pytorch/pann.py'), 69 | 'fad_pytorch.pann.Cnn14_emb128.init_weight': ( 'pann.html#cnn14_emb128.init_weight', 70 | 'fad_pytorch/pann.py'), 71 | 'fad_pytorch.pann.Cnn14_emb32': ('pann.html#cnn14_emb32', 'fad_pytorch/pann.py'), 72 | 'fad_pytorch.pann.Cnn14_emb32.__init__': ('pann.html#cnn14_emb32.__init__', 'fad_pytorch/pann.py'), 73 | 'fad_pytorch.pann.Cnn14_emb32.forward': ('pann.html#cnn14_emb32.forward', 'fad_pytorch/pann.py'), 74 | 'fad_pytorch.pann.Cnn14_emb32.init_weight': ('pann.html#cnn14_emb32.init_weight', 'fad_pytorch/pann.py'), 75 | 'fad_pytorch.pann.Cnn14_emb512': ('pann.html#cnn14_emb512', 'fad_pytorch/pann.py'), 76 | 'fad_pytorch.pann.Cnn14_emb512.__init__': ('pann.html#cnn14_emb512.__init__', 'fad_pytorch/pann.py'), 77 | 'fad_pytorch.pann.Cnn14_emb512.forward': ('pann.html#cnn14_emb512.forward', 'fad_pytorch/pann.py'), 78 | 'fad_pytorch.pann.Cnn14_emb512.init_weight': ( 'pann.html#cnn14_emb512.init_weight', 79 | 
'fad_pytorch/pann.py'), 80 | 'fad_pytorch.pann.Cnn14_mel128': ('pann.html#cnn14_mel128', 'fad_pytorch/pann.py'), 81 | 'fad_pytorch.pann.Cnn14_mel128.__init__': ('pann.html#cnn14_mel128.__init__', 'fad_pytorch/pann.py'), 82 | 'fad_pytorch.pann.Cnn14_mel128.forward': ('pann.html#cnn14_mel128.forward', 'fad_pytorch/pann.py'), 83 | 'fad_pytorch.pann.Cnn14_mel128.init_weight': ( 'pann.html#cnn14_mel128.init_weight', 84 | 'fad_pytorch/pann.py'), 85 | 'fad_pytorch.pann.Cnn14_mel32': ('pann.html#cnn14_mel32', 'fad_pytorch/pann.py'), 86 | 'fad_pytorch.pann.Cnn14_mel32.__init__': ('pann.html#cnn14_mel32.__init__', 'fad_pytorch/pann.py'), 87 | 'fad_pytorch.pann.Cnn14_mel32.forward': ('pann.html#cnn14_mel32.forward', 'fad_pytorch/pann.py'), 88 | 'fad_pytorch.pann.Cnn14_mel32.init_weight': ('pann.html#cnn14_mel32.init_weight', 'fad_pytorch/pann.py'), 89 | 'fad_pytorch.pann.Cnn14_mixup_time_domain': ('pann.html#cnn14_mixup_time_domain', 'fad_pytorch/pann.py'), 90 | 'fad_pytorch.pann.Cnn14_mixup_time_domain.__init__': ( 'pann.html#cnn14_mixup_time_domain.__init__', 91 | 'fad_pytorch/pann.py'), 92 | 'fad_pytorch.pann.Cnn14_mixup_time_domain.forward': ( 'pann.html#cnn14_mixup_time_domain.forward', 93 | 'fad_pytorch/pann.py'), 94 | 'fad_pytorch.pann.Cnn14_mixup_time_domain.init_weight': ( 'pann.html#cnn14_mixup_time_domain.init_weight', 95 | 'fad_pytorch/pann.py'), 96 | 'fad_pytorch.pann.Cnn14_no_dropout': ('pann.html#cnn14_no_dropout', 'fad_pytorch/pann.py'), 97 | 'fad_pytorch.pann.Cnn14_no_dropout.__init__': ( 'pann.html#cnn14_no_dropout.__init__', 98 | 'fad_pytorch/pann.py'), 99 | 'fad_pytorch.pann.Cnn14_no_dropout.forward': ( 'pann.html#cnn14_no_dropout.forward', 100 | 'fad_pytorch/pann.py'), 101 | 'fad_pytorch.pann.Cnn14_no_dropout.init_weight': ( 'pann.html#cnn14_no_dropout.init_weight', 102 | 'fad_pytorch/pann.py'), 103 | 'fad_pytorch.pann.Cnn14_no_specaug': ('pann.html#cnn14_no_specaug', 'fad_pytorch/pann.py'), 104 | 'fad_pytorch.pann.Cnn14_no_specaug.__init__': ( 'pann.html#cnn14_no_specaug.__init__', 105 | 'fad_pytorch/pann.py'), 106 | 'fad_pytorch.pann.Cnn14_no_specaug.forward': ( 'pann.html#cnn14_no_specaug.forward', 107 | 'fad_pytorch/pann.py'), 108 | 'fad_pytorch.pann.Cnn14_no_specaug.init_weight': ( 'pann.html#cnn14_no_specaug.init_weight', 109 | 'fad_pytorch/pann.py'), 110 | 'fad_pytorch.pann.Cnn6': ('pann.html#cnn6', 'fad_pytorch/pann.py'), 111 | 'fad_pytorch.pann.Cnn6.__init__': ('pann.html#cnn6.__init__', 'fad_pytorch/pann.py'), 112 | 'fad_pytorch.pann.Cnn6.forward': ('pann.html#cnn6.forward', 'fad_pytorch/pann.py'), 113 | 'fad_pytorch.pann.Cnn6.init_weight': ('pann.html#cnn6.init_weight', 'fad_pytorch/pann.py'), 114 | 'fad_pytorch.pann.ConvBlock': ('pann.html#convblock', 'fad_pytorch/pann.py'), 115 | 'fad_pytorch.pann.ConvBlock.__init__': ('pann.html#convblock.__init__', 'fad_pytorch/pann.py'), 116 | 'fad_pytorch.pann.ConvBlock.forward': ('pann.html#convblock.forward', 'fad_pytorch/pann.py'), 117 | 'fad_pytorch.pann.ConvBlock.init_weight': ('pann.html#convblock.init_weight', 'fad_pytorch/pann.py'), 118 | 'fad_pytorch.pann.ConvBlock5x5': ('pann.html#convblock5x5', 'fad_pytorch/pann.py'), 119 | 'fad_pytorch.pann.ConvBlock5x5.__init__': ('pann.html#convblock5x5.__init__', 'fad_pytorch/pann.py'), 120 | 'fad_pytorch.pann.ConvBlock5x5.forward': ('pann.html#convblock5x5.forward', 'fad_pytorch/pann.py'), 121 | 'fad_pytorch.pann.ConvBlock5x5.init_weight': ( 'pann.html#convblock5x5.init_weight', 122 | 'fad_pytorch/pann.py'), 123 | 'fad_pytorch.pann.ConvPreWavBlock': ('pann.html#convprewavblock', 
'fad_pytorch/pann.py'), 124 | 'fad_pytorch.pann.ConvPreWavBlock.__init__': ( 'pann.html#convprewavblock.__init__', 125 | 'fad_pytorch/pann.py'), 126 | 'fad_pytorch.pann.ConvPreWavBlock.forward': ('pann.html#convprewavblock.forward', 'fad_pytorch/pann.py'), 127 | 'fad_pytorch.pann.ConvPreWavBlock.init_weight': ( 'pann.html#convprewavblock.init_weight', 128 | 'fad_pytorch/pann.py'), 129 | 'fad_pytorch.pann.DaiNet19': ('pann.html#dainet19', 'fad_pytorch/pann.py'), 130 | 'fad_pytorch.pann.DaiNet19.__init__': ('pann.html#dainet19.__init__', 'fad_pytorch/pann.py'), 131 | 'fad_pytorch.pann.DaiNet19.forward': ('pann.html#dainet19.forward', 'fad_pytorch/pann.py'), 132 | 'fad_pytorch.pann.DaiNet19.init_weight': ('pann.html#dainet19.init_weight', 'fad_pytorch/pann.py'), 133 | 'fad_pytorch.pann.DaiNetResBlock': ('pann.html#dainetresblock', 'fad_pytorch/pann.py'), 134 | 'fad_pytorch.pann.DaiNetResBlock.__init__': ('pann.html#dainetresblock.__init__', 'fad_pytorch/pann.py'), 135 | 'fad_pytorch.pann.DaiNetResBlock.forward': ('pann.html#dainetresblock.forward', 'fad_pytorch/pann.py'), 136 | 'fad_pytorch.pann.DaiNetResBlock.init_weight': ( 'pann.html#dainetresblock.init_weight', 137 | 'fad_pytorch/pann.py'), 138 | 'fad_pytorch.pann.InvertedResidual': ('pann.html#invertedresidual', 'fad_pytorch/pann.py'), 139 | 'fad_pytorch.pann.InvertedResidual.__init__': ( 'pann.html#invertedresidual.__init__', 140 | 'fad_pytorch/pann.py'), 141 | 'fad_pytorch.pann.InvertedResidual.forward': ( 'pann.html#invertedresidual.forward', 142 | 'fad_pytorch/pann.py'), 143 | 'fad_pytorch.pann.LeeNet11': ('pann.html#leenet11', 'fad_pytorch/pann.py'), 144 | 'fad_pytorch.pann.LeeNet11.__init__': ('pann.html#leenet11.__init__', 'fad_pytorch/pann.py'), 145 | 'fad_pytorch.pann.LeeNet11.forward': ('pann.html#leenet11.forward', 'fad_pytorch/pann.py'), 146 | 'fad_pytorch.pann.LeeNet11.init_weight': ('pann.html#leenet11.init_weight', 'fad_pytorch/pann.py'), 147 | 'fad_pytorch.pann.LeeNet24': ('pann.html#leenet24', 'fad_pytorch/pann.py'), 148 | 'fad_pytorch.pann.LeeNet24.__init__': ('pann.html#leenet24.__init__', 'fad_pytorch/pann.py'), 149 | 'fad_pytorch.pann.LeeNet24.forward': ('pann.html#leenet24.forward', 'fad_pytorch/pann.py'), 150 | 'fad_pytorch.pann.LeeNet24.init_weight': ('pann.html#leenet24.init_weight', 'fad_pytorch/pann.py'), 151 | 'fad_pytorch.pann.LeeNetConvBlock': ('pann.html#leenetconvblock', 'fad_pytorch/pann.py'), 152 | 'fad_pytorch.pann.LeeNetConvBlock.__init__': ( 'pann.html#leenetconvblock.__init__', 153 | 'fad_pytorch/pann.py'), 154 | 'fad_pytorch.pann.LeeNetConvBlock.forward': ('pann.html#leenetconvblock.forward', 'fad_pytorch/pann.py'), 155 | 'fad_pytorch.pann.LeeNetConvBlock.init_weight': ( 'pann.html#leenetconvblock.init_weight', 156 | 'fad_pytorch/pann.py'), 157 | 'fad_pytorch.pann.LeeNetConvBlock2': ('pann.html#leenetconvblock2', 'fad_pytorch/pann.py'), 158 | 'fad_pytorch.pann.LeeNetConvBlock2.__init__': ( 'pann.html#leenetconvblock2.__init__', 159 | 'fad_pytorch/pann.py'), 160 | 'fad_pytorch.pann.LeeNetConvBlock2.forward': ( 'pann.html#leenetconvblock2.forward', 161 | 'fad_pytorch/pann.py'), 162 | 'fad_pytorch.pann.LeeNetConvBlock2.init_weight': ( 'pann.html#leenetconvblock2.init_weight', 163 | 'fad_pytorch/pann.py'), 164 | 'fad_pytorch.pann.MobileNetV1': ('pann.html#mobilenetv1', 'fad_pytorch/pann.py'), 165 | 'fad_pytorch.pann.MobileNetV1.__init__': ('pann.html#mobilenetv1.__init__', 'fad_pytorch/pann.py'), 166 | 'fad_pytorch.pann.MobileNetV1.forward': ('pann.html#mobilenetv1.forward', 'fad_pytorch/pann.py'), 167 
| 'fad_pytorch.pann.MobileNetV1.init_weights': ( 'pann.html#mobilenetv1.init_weights', 168 | 'fad_pytorch/pann.py'), 169 | 'fad_pytorch.pann.MobileNetV2': ('pann.html#mobilenetv2', 'fad_pytorch/pann.py'), 170 | 'fad_pytorch.pann.MobileNetV2.__init__': ('pann.html#mobilenetv2.__init__', 'fad_pytorch/pann.py'), 171 | 'fad_pytorch.pann.MobileNetV2.forward': ('pann.html#mobilenetv2.forward', 'fad_pytorch/pann.py'), 172 | 'fad_pytorch.pann.MobileNetV2.init_weight': ('pann.html#mobilenetv2.init_weight', 'fad_pytorch/pann.py'), 173 | 'fad_pytorch.pann.Res1dNet31': ('pann.html#res1dnet31', 'fad_pytorch/pann.py'), 174 | 'fad_pytorch.pann.Res1dNet31.__init__': ('pann.html#res1dnet31.__init__', 'fad_pytorch/pann.py'), 175 | 'fad_pytorch.pann.Res1dNet31.forward': ('pann.html#res1dnet31.forward', 'fad_pytorch/pann.py'), 176 | 'fad_pytorch.pann.Res1dNet31.init_weight': ('pann.html#res1dnet31.init_weight', 'fad_pytorch/pann.py'), 177 | 'fad_pytorch.pann.Res1dNet51': ('pann.html#res1dnet51', 'fad_pytorch/pann.py'), 178 | 'fad_pytorch.pann.Res1dNet51.__init__': ('pann.html#res1dnet51.__init__', 'fad_pytorch/pann.py'), 179 | 'fad_pytorch.pann.Res1dNet51.forward': ('pann.html#res1dnet51.forward', 'fad_pytorch/pann.py'), 180 | 'fad_pytorch.pann.Res1dNet51.init_weight': ('pann.html#res1dnet51.init_weight', 'fad_pytorch/pann.py'), 181 | 'fad_pytorch.pann.ResNet22': ('pann.html#resnet22', 'fad_pytorch/pann.py'), 182 | 'fad_pytorch.pann.ResNet22.__init__': ('pann.html#resnet22.__init__', 'fad_pytorch/pann.py'), 183 | 'fad_pytorch.pann.ResNet22.forward': ('pann.html#resnet22.forward', 'fad_pytorch/pann.py'), 184 | 'fad_pytorch.pann.ResNet22.init_weights': ('pann.html#resnet22.init_weights', 'fad_pytorch/pann.py'), 185 | 'fad_pytorch.pann.ResNet38': ('pann.html#resnet38', 'fad_pytorch/pann.py'), 186 | 'fad_pytorch.pann.ResNet38.__init__': ('pann.html#resnet38.__init__', 'fad_pytorch/pann.py'), 187 | 'fad_pytorch.pann.ResNet38.forward': ('pann.html#resnet38.forward', 'fad_pytorch/pann.py'), 188 | 'fad_pytorch.pann.ResNet38.init_weights': ('pann.html#resnet38.init_weights', 'fad_pytorch/pann.py'), 189 | 'fad_pytorch.pann.ResNet54': ('pann.html#resnet54', 'fad_pytorch/pann.py'), 190 | 'fad_pytorch.pann.ResNet54.__init__': ('pann.html#resnet54.__init__', 'fad_pytorch/pann.py'), 191 | 'fad_pytorch.pann.ResNet54.forward': ('pann.html#resnet54.forward', 'fad_pytorch/pann.py'), 192 | 'fad_pytorch.pann.ResNet54.init_weights': ('pann.html#resnet54.init_weights', 'fad_pytorch/pann.py'), 193 | 'fad_pytorch.pann.Wavegram_Cnn14': ('pann.html#wavegram_cnn14', 'fad_pytorch/pann.py'), 194 | 'fad_pytorch.pann.Wavegram_Cnn14.__init__': ('pann.html#wavegram_cnn14.__init__', 'fad_pytorch/pann.py'), 195 | 'fad_pytorch.pann.Wavegram_Cnn14.forward': ('pann.html#wavegram_cnn14.forward', 'fad_pytorch/pann.py'), 196 | 'fad_pytorch.pann.Wavegram_Cnn14.init_weight': ( 'pann.html#wavegram_cnn14.init_weight', 197 | 'fad_pytorch/pann.py'), 198 | 'fad_pytorch.pann.Wavegram_Logmel128_Cnn14': ( 'pann.html#wavegram_logmel128_cnn14', 199 | 'fad_pytorch/pann.py'), 200 | 'fad_pytorch.pann.Wavegram_Logmel128_Cnn14.__init__': ( 'pann.html#wavegram_logmel128_cnn14.__init__', 201 | 'fad_pytorch/pann.py'), 202 | 'fad_pytorch.pann.Wavegram_Logmel128_Cnn14.forward': ( 'pann.html#wavegram_logmel128_cnn14.forward', 203 | 'fad_pytorch/pann.py'), 204 | 'fad_pytorch.pann.Wavegram_Logmel128_Cnn14.init_weight': ( 'pann.html#wavegram_logmel128_cnn14.init_weight', 205 | 'fad_pytorch/pann.py'), 206 | 'fad_pytorch.pann.Wavegram_Logmel_Cnn14': 
('pann.html#wavegram_logmel_cnn14', 'fad_pytorch/pann.py'), 207 | 'fad_pytorch.pann.Wavegram_Logmel_Cnn14.__init__': ( 'pann.html#wavegram_logmel_cnn14.__init__', 208 | 'fad_pytorch/pann.py'), 209 | 'fad_pytorch.pann.Wavegram_Logmel_Cnn14.forward': ( 'pann.html#wavegram_logmel_cnn14.forward', 210 | 'fad_pytorch/pann.py'), 211 | 'fad_pytorch.pann.Wavegram_Logmel_Cnn14.init_weight': ( 'pann.html#wavegram_logmel_cnn14.init_weight', 212 | 'fad_pytorch/pann.py'), 213 | 'fad_pytorch.pann._ResNet': ('pann.html#_resnet', 'fad_pytorch/pann.py'), 214 | 'fad_pytorch.pann._ResNet.__init__': ('pann.html#_resnet.__init__', 'fad_pytorch/pann.py'), 215 | 'fad_pytorch.pann._ResNet._make_layer': ('pann.html#_resnet._make_layer', 'fad_pytorch/pann.py'), 216 | 'fad_pytorch.pann._ResNet.forward': ('pann.html#_resnet.forward', 'fad_pytorch/pann.py'), 217 | 'fad_pytorch.pann._ResNetWav1d': ('pann.html#_resnetwav1d', 'fad_pytorch/pann.py'), 218 | 'fad_pytorch.pann._ResNetWav1d.__init__': ('pann.html#_resnetwav1d.__init__', 'fad_pytorch/pann.py'), 219 | 'fad_pytorch.pann._ResNetWav1d._make_layer': ( 'pann.html#_resnetwav1d._make_layer', 220 | 'fad_pytorch/pann.py'), 221 | 'fad_pytorch.pann._ResNetWav1d.forward': ('pann.html#_resnetwav1d.forward', 'fad_pytorch/pann.py'), 222 | 'fad_pytorch.pann._ResnetBasicBlock': ('pann.html#_resnetbasicblock', 'fad_pytorch/pann.py'), 223 | 'fad_pytorch.pann._ResnetBasicBlock.__init__': ( 'pann.html#_resnetbasicblock.__init__', 224 | 'fad_pytorch/pann.py'), 225 | 'fad_pytorch.pann._ResnetBasicBlock.forward': ( 'pann.html#_resnetbasicblock.forward', 226 | 'fad_pytorch/pann.py'), 227 | 'fad_pytorch.pann._ResnetBasicBlock.init_weights': ( 'pann.html#_resnetbasicblock.init_weights', 228 | 'fad_pytorch/pann.py'), 229 | 'fad_pytorch.pann._ResnetBasicBlockWav1d': ('pann.html#_resnetbasicblockwav1d', 'fad_pytorch/pann.py'), 230 | 'fad_pytorch.pann._ResnetBasicBlockWav1d.__init__': ( 'pann.html#_resnetbasicblockwav1d.__init__', 231 | 'fad_pytorch/pann.py'), 232 | 'fad_pytorch.pann._ResnetBasicBlockWav1d.forward': ( 'pann.html#_resnetbasicblockwav1d.forward', 233 | 'fad_pytorch/pann.py'), 234 | 'fad_pytorch.pann._ResnetBasicBlockWav1d.init_weights': ( 'pann.html#_resnetbasicblockwav1d.init_weights', 235 | 'fad_pytorch/pann.py'), 236 | 'fad_pytorch.pann._ResnetBottleneck': ('pann.html#_resnetbottleneck', 'fad_pytorch/pann.py'), 237 | 'fad_pytorch.pann._ResnetBottleneck.__init__': ( 'pann.html#_resnetbottleneck.__init__', 238 | 'fad_pytorch/pann.py'), 239 | 'fad_pytorch.pann._ResnetBottleneck.forward': ( 'pann.html#_resnetbottleneck.forward', 240 | 'fad_pytorch/pann.py'), 241 | 'fad_pytorch.pann._ResnetBottleneck.init_weights': ( 'pann.html#_resnetbottleneck.init_weights', 242 | 'fad_pytorch/pann.py'), 243 | 'fad_pytorch.pann._resnet_conv1x1': ('pann.html#_resnet_conv1x1', 'fad_pytorch/pann.py'), 244 | 'fad_pytorch.pann._resnet_conv1x1_wav1d': ('pann.html#_resnet_conv1x1_wav1d', 'fad_pytorch/pann.py'), 245 | 'fad_pytorch.pann._resnet_conv3x1_wav1d': ('pann.html#_resnet_conv3x1_wav1d', 'fad_pytorch/pann.py'), 246 | 'fad_pytorch.pann._resnet_conv3x3': ('pann.html#_resnet_conv3x3', 'fad_pytorch/pann.py'), 247 | 'fad_pytorch.pann.init_bn': ('pann.html#init_bn', 'fad_pytorch/pann.py'), 248 | 'fad_pytorch.pann.init_layer': ('pann.html#init_layer', 'fad_pytorch/pann.py')}, 249 | 'fad_pytorch.pann_pytorch_utils': {}, 250 | 'fad_pytorch.sqrtm': { 'fad_pytorch.sqrtm.MatrixSquareRoot_li': ('sqrtm.html#matrixsquareroot_li', 'fad_pytorch/sqrtm.py'), 251 | 'fad_pytorch.sqrtm.MatrixSquareRoot_li.backward': 
( 'sqrtm.html#matrixsquareroot_li.backward', 252 | 'fad_pytorch/sqrtm.py'), 253 | 'fad_pytorch.sqrtm.MatrixSquareRoot_li.forward': ( 'sqrtm.html#matrixsquareroot_li.forward', 254 | 'fad_pytorch/sqrtm.py'), 255 | 'fad_pytorch.sqrtm.compute_error': ('sqrtm.html#compute_error', 'fad_pytorch/sqrtm.py'), 256 | 'fad_pytorch.sqrtm.sqrt_newton_schulz': ('sqrtm.html#sqrt_newton_schulz', 'fad_pytorch/sqrtm.py'), 257 | 'fad_pytorch.sqrtm.sqrt_newton_schulz_autograd': ( 'sqrtm.html#sqrt_newton_schulz_autograd', 258 | 'fad_pytorch/sqrtm.py'), 259 | 'fad_pytorch.sqrtm.sqrtm': ('sqrtm.html#sqrtm', 'fad_pytorch/sqrtm.py')}}} 260 | -------------------------------------------------------------------------------- /fad_pytorch/fad_embed.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/02_fad_embed.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['OPENL3_VERSION', 'GUDGUD_LICENSE', 'download_file', 'download_if_needed', 'get_ckpt', 'setup_embedder', 'embed', 5 | 'main'] 6 | 7 | # %% ../nbs/02_fad_embed.ipynb 5 8 | import os 9 | import numpy as np 10 | import argparse 11 | import laion_clap 12 | from laion_clap.training.data import get_audio_features 13 | from accelerate import Accelerator 14 | import warnings 15 | import torch 16 | 17 | from aeiou.core import get_device, load_audio, get_audio_filenames, makedir 18 | from aeiou.datasets import AudioDataset 19 | from aeiou.hpc import HostPrinter 20 | from torch.utils.data import DataLoader 21 | from pathlib import Path 22 | import requests 23 | from tqdm import tqdm 24 | import site 25 | from einops import rearrange 26 | 27 | try: 28 | from fad_pytorch.pann import Cnn14_16k 29 | except: 30 | from pann import Cnn14_16k 31 | 32 | # there are TWO 'torchopenl3' repos! they operate differently. 33 | OPENL3_VERSION = "turian" # # "hugo" | "turian". set to which version you've installed 34 | import torchopenl3 35 | 36 | # %% ../nbs/02_fad_embed.ipynb 7 37 | def download_file(url, local_filename): 38 | "Includes a progress bar. 
from https://stackoverflow.com/a/37573701/4259243" 39 | response = requests.get(url, stream=True) 40 | total_size_in_bytes= int(response.headers.get('content-length', 0)) 41 | block_size = 1024 #1 Kilobyte 42 | progress_bar = tqdm(total=total_size_in_bytes, unit='iB', unit_scale=True) 43 | with open(local_filename, 'wb') as file: 44 | for data in response.iter_content(block_size): 45 | progress_bar.update(len(data)) 46 | file.write(data) 47 | progress_bar.close() 48 | if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes: 49 | print("ERROR, something went wrong") 50 | return local_filename 51 | 52 | def download_if_needed(url, local_filename, accelerator=None): 53 | "wrapper for download file" 54 | if accelerator is None or accelerator.is_local_main_process: # Only do this on one process instead of all 55 | if not os.path.isfile(local_filename): 56 | print(f"File {local_filename} not found, downloading from {url}") 57 | download_file( url, local_filename) 58 | if accelerator is not None: accelerator.wait_for_everyone() 59 | return local_filename 60 | 61 | def get_ckpt(ckpt_file='music_speech_audioset_epoch_15_esc_89.98.pt', 62 | ckpt_base_url='https://huggingface.co/lukewys/laion_clap/resolve/main', # 'resolve' (not 'blob') so the raw file downloads 63 | ckpt_dl_path=os.path.expanduser("~/checkpoints"), 64 | accelerator=None, 65 | ): 66 | ckpt_path = f"{ckpt_dl_path}/{ckpt_file}" 67 | download_if_needed( f"{ckpt_base_url}/{ckpt_file}" , ckpt_path) 68 | return ckpt_path 69 | 70 | # %% ../nbs/02_fad_embed.ipynb 8 71 | def setup_embedder( 72 | model_choice='clap', # 'clap' | 'vggish' | 'pann' 73 | device='cuda', 74 | ckpt_file='music_speech_audioset_epoch_15_esc_89.98.pt', # NOTE: 'CLAP_CKPT' env var overrides ckpt_file kwarg 75 | ckpt_base_url='https://huggingface.co/lukewys/laion_clap/resolve/main', 76 | # https://huggingface.co/lukewys/laion_clap/resolve/main/music_speech_audioset_epoch_15_esc_89.98.pt 77 | accelerator=None, 78 | ckpt_dl_path=os.path.expanduser("~/checkpoints"), 79 | ): 80 | "load the embedder model" 81 | embedder = None 82 | 83 | sample_rate = 16000 84 | if model_choice == 'clap': 85 | print(f"Starting basic CLAP setup") 86 | clap_fusion, clap_amodel = True, "HTSAT-base" 87 | #doesn't work: warnings.filterwarnings('ignore') # temporarily disable CLAP warnings as they are super annoying. 88 | clap_module = laion_clap.CLAP_Module(enable_fusion=clap_fusion, device=device, amodel=clap_amodel).requires_grad_(False).eval() 89 | clap_ckpt_path = os.getenv('CLAP_CKPT') # NOTE: CLAP_CKPT env var overrides ckpt_file kwarg 90 | if clap_ckpt_path is not None: 91 | #print(f"Loading CLAP from {clap_ckpt_path}") 92 | clap_module.load_ckpt(ckpt=clap_ckpt_path, verbose=False) 93 | else: 94 | print(f"No CLAP checkpoint specified, using {ckpt_file}") 95 | clap_module = laion_clap.CLAP_Module(enable_fusion=False, amodel= 'HTSAT-base') 96 | ckpt_path = get_ckpt(ckpt_file=ckpt_file, ckpt_base_url=ckpt_base_url, ckpt_dl_path=ckpt_dl_path, accelerator=accelerator) 97 | clap_module.load_ckpt(ckpt_path, verbose=False) 98 | #clap_module.load_ckpt(model_id=1, verbose=False) 99 | #warnings.filterwarnings("default") # turn warnings back on. 
100 | embedder = clap_module # synonyms 101 | sample_rate = 48000 102 | 103 | # next two model loading codes from gudgud96's repo: https://github.com/gudgud96/frechet-audio-distance, LICENSE below 104 | elif model_choice == "vggish": # https://arxiv.org/abs/1609.09430 105 | embedder = torch.hub.load('harritaylor/torchvggish', 'vggish') 106 | use_pca=False 107 | use_activation=False 108 | if not use_pca: embedder.postprocess = False 109 | if not use_activation: embedder.embeddings = torch.nn.Sequential(*list(embedder.embeddings.children())[:-1]) 110 | sample_rate = 16000 111 | 112 | elif model_choice == "pann": # https://arxiv.org/abs/1912.10211 113 | sample_rate = 16000 114 | model_path = os.path.join(torch.hub.get_dir(), "Cnn14_16k_mAP%3D0.438.pth") 115 | if accelerator is None or accelerator.is_local_main_process: 116 | if not(os.path.exists(model_path)): 117 | torch.hub.download_url_to_file('https://zenodo.org/record/3987831/files/Cnn14_16k_mAP%3D0.438.pth', model_path) 118 | if accelerator is not None: accelerator.wait_for_everyone() 119 | embedder = Cnn14_16k(sample_rate=sample_rate, window_size=512, hop_size=160, mel_bins=64, fmin=50, fmax=8000, classes_num=527) 120 | checkpoint = torch.load(model_path, map_location=device) 121 | embedder.load_state_dict(checkpoint['model']) 122 | 123 | elif model_choice == "openl3" and OPENL3_VERSION == "hugo": # hugo flores garcia's torchopenl3, https://github.com/hugofloresgarcia/torchopenl3 124 | # openl3 repo doesn't install its weights if you do "pip install git+...", so here we download them separately 125 | weights_dir = f"{site.getsitepackages()[0]}/torchopenl3/assets/weights" 126 | makedir(weights_dir) 127 | download_if_needed("https://github.com/hugofloresgarcia/torchopenl3/raw/main/torchopenl3/assets/weights/env-mel128", 128 | f"{weights_dir}/music-mel128", accelerator=accelerator) 129 | embedder = torchopenl3.OpenL3Embedding(input_repr='mel128', embedding_size=512, content_type='music') 130 | sample_rate = 48000 131 | 132 | elif model_choice == "openl3" and OPENL3_VERSION == "turian": # turian et al's torchopenl3, https://github.com/torchopenl3/torchopenl3 133 | sample_rate = 48000 134 | embedder = torchopenl3.models.load_audio_embedding_model(input_repr="mel256", content_type="music", embedding_size=512) 135 | pass # turian et al's does all its setup when it's invoked 136 | else: 137 | raise ValueError("Sorry, other models not supported yet") 138 | 139 | if hasattr(embedder,'eval'): embedder.eval() 140 | return embedder, sample_rate 141 | 142 | 143 | GUDGUD_LICENSE = """For VGGish implementation: 144 | MIT License 145 | 146 | Copyright (c) 2022 Hao Hao Tan 147 | 148 | Permission is hereby granted, free of charge, to any person obtaining a copy 149 | of this software and associated documentation files (the "Software"), to deal 150 | in the Software without restriction, including without limitation the rights 151 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 152 | copies of the Software, and to permit persons to whom the Software is 153 | furnished to do so, subject to the following conditions: 154 | 155 | The above copyright notice and this permission notice shall be included in all 156 | copies or substantial portions of the Software. 157 | 158 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 159 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 160 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 161 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 162 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 163 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 164 | SOFTWARE. 165 | """ 166 | 167 | # %% ../nbs/02_fad_embed.ipynb 10 168 | def embed(args): 169 | model_choice, real_path, fake_path, chunk_size, sr, max_batch_size, debug = args.embed_model, args.real_path, args.fake_path, args.chunk_size, args.sr, args.batch_size, args.debug 170 | 171 | sample_rate = sr 172 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 173 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 174 | ddps = f"[{local_rank}/{world_size}]" # string for distributed computing info, e.g. "[1/8]" 175 | 176 | accelerator = Accelerator() 177 | hprint = HostPrinter(accelerator) # hprint only prints on head node 178 | device = accelerator.device # get_device() 179 | hprint(f"{ddps} args = {args}") 180 | hprint(f'{ddps} Using device: {device}') 181 | 182 | 183 | """ # let accelerate split up the files among processors 184 | # get the list(s) of audio files 185 | real_filenames = get_audio_filenames(real_path) 186 | #hprint(f"{ddps} real_path, real_filenames = {real_path}, {real_filenames}") 187 | fake_filenames = get_audio_filenames(fake_path) 188 | minlen = len(real_filenames) 189 | if len(real_filenames) != len(fake_filenames): 190 | hprint(f"{ddps} WARNING: len(real_filenames)=={len(real_filenames)} != len(fake_filenames)=={len(fake_filenames)}. Truncating to shorter list") 191 | minlen = min( len(real_filenames) , len(fake_filenames) ) 192 | 193 | # subdivide file lists by process number 194 | num_per_proc = minlen // world_size 195 | start = local_rank * num_per_proc 196 | end = minlen if local_rank == world_size-1 else (local_rank+1) * num_per_proc 197 | #print(f"{ddps} start, end = ",start,end) 198 | real_filenames, fake_filenames = real_filenames[start:end], fake_filenames[start:end] 199 | """ 200 | 201 | model_choices = [model_choice] if model_choice != 'all' else ['clap','vggish','pann','openl3'] 202 | 203 | for model_choice in model_choices: # loop over multiple embedders 204 | hprint(f"\n ** Model_choice = {model_choice}") 205 | # setup embedder and dataloader 206 | embedder, emb_sample_rate = setup_embedder(model_choice, device=device, accelerator=accelerator) 207 | if sr != emb_sample_rate: 208 | hprint(f"\n*******\nWARNING: sr={sr} != {model_choice}'s emb_sample_rate={emb_sample_rate}. 
Will resample audio to the latter\n*******\n") 209 | sr = emb_sample_rate 210 | hprint(f"{ddps} Embedder '{model_choice}' ready to go!") 211 | 212 | # we read audio in length args.sample_size, cut it into chunks of args.chunk_size to embed, and skip args.hop_size between chunks 213 | # pads with zeros btw 214 | real_dataset = AudioDataset(real_path, sample_rate=emb_sample_rate, sample_size=args.sample_size, return_dict=True, verbose=args.verbose) 215 | fake_dataset = AudioDataset(fake_path, sample_rate=emb_sample_rate, sample_size=args.sample_size, return_dict=True, verbose=args.verbose) 216 | batch_size = min( len(real_dataset) // world_size , max_batch_size ) 217 | hprint(f"\nGiven max_batch_size = {max_batch_size}, len(real_dataset) = {len(real_dataset)}, and world_size = {world_size}, we'll use batch_size = {batch_size}") 218 | real_dl = DataLoader(real_dataset, batch_size=batch_size, shuffle=False) 219 | fake_dl = DataLoader(fake_dataset, batch_size=batch_size, shuffle=False) 220 | 221 | real_dl, fake_dl, embedder = accelerator.prepare( real_dl, fake_dl, embedder ) # prepare handles distributing things among GPUs 222 | 223 | # note that we don't actually care if real & fake files are pulled in the same order; we'll only be comparing the *distributions* of the data. 224 | with torch.no_grad(): 225 | for dl, name in zip([real_dl, fake_dl],['real','fake']): 226 | for i, data_dict in enumerate(dl): # load audio files 227 | audio_sample_batch, filename_batch = data_dict['inputs'], data_dict['filename'] 228 | newdir_already = False 229 | if not newdir_already: 230 | p = Path( filename_batch[0] ) 231 | dir_already = True 232 | newdir = f"{p.parents[0]}_emb_{model_choice}" 233 | hprint(f"creating new directory = {newdir}") 234 | makedir(newdir) 235 | newdir_already = True 236 | # cut audio samples into chunks spaced out by hops, and loop over them 237 | hop_samples = int(args.hop_size * args.sample_size) 238 | hop_starts = np.arange(0, args.sample_size, hop_samples) 239 | if args.max_hops > 0: 240 | hop_starts = hop_starts[:min(len(hop_starts), args.max_hops)] 241 | if args.sample_size - hop_starts[-1] < hop_samples: # judgement call: let's not zero-pad on the very end, rather just don't do the last hop 242 | hop_starts = hop_starts[:-1] 243 | for h_ind, hop_loc in enumerate(hop_starts): # proceed through audio file batch via chunks, skipping by hop_samples each time 244 | chunk = audio_sample_batch[:,:,hop_loc:hop_loc+hop_samples] 245 | audio = chunk 246 | 247 | #print(f"{ddps} i = {i}/{len(real_dataset)}, filename = {filename_batch[0]}") 248 | audio = audio.to(device) 249 | 250 | 251 | if model_choice == 'clap': 252 | while len(audio.shape) < 3: 253 | audio = audio.unsqueeze(0) # add batch and/or channel dims 254 | embeddings = accelerator.unwrap_model(embedder).get_audio_embedding_from_data(audio.mean(dim=1).to(device), use_tensor=True).to(audio.dtype) 255 | 256 | elif model_choice == "vggish": 257 | audio = torch.mean(audio, dim=1) # vggish requires we convert to mono 258 | embeddings = [] # ...whoa, vggish can't even handle batches? we have to pass 'em through singly? 259 | for bi, waveform in enumerate(audio): 260 | e = accelerator.unwrap_model(embedder).forward(waveform.cpu().numpy(), emb_sample_rate) 261 | embeddings.append(e) 262 | embeddings = torch.cat(embeddings, dim=0) 263 | 264 | elif model_choice == "pann": 265 | audio = torch.mean(audio, dim=1) # mono only. todo: keepdim=True ? 
266 | out = embedder.forward(audio, None) 267 | embeddings = out['embedding'].data 268 | 269 | elif model_choice == "openl3" and OPENL3_VERSION == "hugo": 270 | ##audio = torch.mean(audio, dim=1) # mono only. 271 | embeddings = [] 272 | for bi, waveform in enumerate( audio.cpu().numpy() ): # no batch processing, expects numpy 273 | e = torchopenl3.embed(model=embedder, 274 | audio=waveform, # shape should be (channels, samples) 275 | sample_rate=emb_sample_rate, # sample rate of input file 276 | hop_size=1, device=device) 277 | if debug: hprint(f"bi = {bi}, waveform.shape = {waveform.shape}, e.shape = {e.shape}") 278 | embeddings.append(torch.tensor(e)) 279 | embeddings = torch.cat(embeddings, dim=0) 280 | 281 | elif model_choice == "openl3" and OPENL3_VERSION == "turian": 282 | # Note: turian's can/will do multiple time-stamped embeddings if the sample_size is long enough. but our chunks/hops precludes this 283 | 284 | #not needed, turns out: audio = rearrange(audio, 'b c s -> b s c') # this torchopenl3 expects channels-first ordering 285 | embeddings, timestamps = torchopenl3.get_audio_embedding(audio, emb_sample_rate, model=embedder) 286 | embeddings = torch.squeeze(embeddings, 1) # get rid of any spurious dimensions of 1 in middle position 287 | 288 | else: 289 | raise ValueError(f"Unknown model_choice = {model_choice}") 290 | 291 | hprint(f"embeddings.shape = {embeddings.shape}") 292 | # TODO: for now we'll just dump each batch on each proc to its own file; this could be improved 293 | outfilename = f"{newdir}/emb_p{local_rank}_b{i}_h{h_ind}.pt" 294 | hprint(f"{ddps} Saving embeddings to {outfilename}") 295 | torch.save(embeddings.cpu().detach(), outfilename) 296 | 297 | del embedder 298 | torch.cuda.empty_cache() 299 | # end loop over various embedders 300 | return 301 | 302 | # %% ../nbs/02_fad_embed.ipynb 11 303 | def main(): 304 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 305 | parser.add_argument('embed_model', help='choice of embedding model(s): clap | vggish | pann | openl3 | all ', default='clap') 306 | parser.add_argument('real_path', help='Path of files of real audio', default='real/') 307 | parser.add_argument('fake_path', help='Path of files of fake audio', default='fake/') 308 | parser.add_argument('--batch_size', type=int, default=64, help='MAXIMUM Batch size for computing embeddings (may go smaller)') 309 | parser.add_argument('--sample_size', type=int, default=2**18, help='Number of audio samples to read from each audio file') 310 | parser.add_argument('--chunk_size', type=int, default=24000, help='Length of chunks (in audio samples) to embed') 311 | parser.add_argument('--hop_size', type=float, default=0.100, help='Hop between chunks, as a fraction of sample_size') 312 | parser.add_argument('--max_hops', type=int, default=-1, help="Don't exceed this many hops/chunks/embeddings per audio file. 
<= 0 disables this.") 313 | parser.add_argument('--sr', type=int, default=48000, help='sample rate (will resample inputs at this rate)') 314 | parser.add_argument('--verbose', action='store_true', help='Show notices of resampling when reading files') 315 | parser.add_argument('--debug', action='store_true', help='Extra messages for debugging this program') 316 | 317 | args = parser.parse_args() 318 | embed(args) 319 | 320 | # %% ../nbs/02_fad_embed.ipynb 12 321 | if __name__ == '__main__' and "get_ipython" not in dir(): 322 | main() 323 | -------------------------------------------------------------------------------- /fad_pytorch/fad_gen.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/01_fad_gen.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['gen', 'main'] 5 | 6 | # %% ../nbs/01_fad_gen.ipynb 6 7 | import os 8 | import argparse 9 | from accelerate import Accelerator 10 | import warnings 11 | import torch 12 | 13 | from aeiou.core import get_device, load_audio, get_audio_filenames, makedir 14 | from aeiou.datasets import get_wds_loader, AudioDataset 15 | from aeiou.hpc import HostPrinter 16 | from pathlib import Path 17 | #from audio_algebra.given_models import StackedDiffAEWrapper 18 | import ast 19 | import torchaudio 20 | from tqdm.auto import tqdm 21 | import math 22 | 23 | # %% ../nbs/01_fad_gen.ipynb 7 24 | def gen(args): 25 | 26 | # HPC / parallel setup 27 | local_rank = int(os.environ.get("LOCAL_RANK", 0)) 28 | world_size = int(os.environ.get("WORLD_SIZE", 1)) 29 | ddps = f"[{local_rank}/{world_size}]" # string for distributed computing info, e.g. "[1/8]" 30 | accelerator = Accelerator() 31 | hprint = HostPrinter(accelerator) # hprint only prints on head node 32 | device = accelerator.device # get_device() 33 | hprint(f"gen: args = {args}") 34 | hprint(f'{ddps} Using device: {device}') 35 | 36 | 37 | model_ckpt, data_sources, profiles, n = args.model_ckpt, args.data_sources, args.profiles, args.n 38 | names = data_sources.split(' ') 39 | #hprint(f"names = {names}") 40 | local_data = False 41 | if 's3://' in data_sources: 42 | hprint("Data sources are on S3") 43 | profiles = ast.literal_eval(profiles) 44 | hprint(f"profiles = {profiles}") 45 | 46 | dl = get_wds_loader( 47 | batch_size=args.batch_size, 48 | s3_url_prefix=None, 49 | sample_size=args.sample_size, 50 | names=names, 51 | sample_rate=args.sample_rate, 52 | num_workers=args.num_workers, 53 | recursive=True, 54 | random_crop=True, 55 | epoch_steps=10000, 56 | profiles=profiles, 57 | ) 58 | else: 59 | hprint("Data sources are local") 60 | dataset = AudioDataset(names, sample_rate=args.sample_rate, sample_size=args.sample_size) 61 | dl = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers) 62 | local_data = True 63 | 64 | print(f"loading {model_ckpt}....") 65 | if model_ckpt.endswith('.ts'): 66 | model = torch.jit.load(model_ckpt) 67 | #else: # default is stacked diffae 68 | # model = StackedDiffAEWrapper(ckpt_info={'ckpt_path':model_ckpt}) 69 | try: 70 | model.setup() # if it needs setup call 71 | except: 72 | pass 73 | 74 | model.eval() 75 | model = model.to(device) 76 | 77 | model, dl = accelerator.prepare( model, dl ) # prepare handles distributing things among GPUs 78 | 79 | reals_path, fakes_path = f"{args.name}_reals", f"{args.name}_fakes" 80 | makedir(reals_path) 81 | makedir(fakes_path) 82 | 83 | progress_bar = tqdm(range(math.ceil(args.n/args.batch_size)), disable=not 
accelerator.is_local_main_process) 84 | 85 | for i, data in enumerate(dl): 86 | reals = data if local_data else data[0][0] 87 | if args.debug: hprint(f"{ddps} i = {i}, reals.shape = {reals.shape}") 88 | 89 | with torch.no_grad(): 90 | fakes = accelerator.unwrap_model(model).forward(reals.to(device)).cpu() 91 | #hprint(f"fakes.shape = {fakes.shape}") 92 | 93 | for b in range(reals.shape[0]): 94 | waveform = reals[b] 95 | torchaudio.save(f"{reals_path}/{i}_{b}.wav", waveform.cpu(), args.sample_rate) 96 | waveform = fakes[b] 97 | torchaudio.save(f"{fakes_path}/{i}_{b}.wav", waveform.cpu(), args.sample_rate) 98 | 99 | progress_bar.update(1) 100 | if (i+1)*args.batch_size > args.n: 101 | hprint(f"\nGot all the data we needed: {(i+1)*args.batch_size}. Stopping") 102 | break 103 | 104 | hprint("Success!") 105 | 106 | # %% ../nbs/01_fad_gen.ipynb 8 107 | def main(): 108 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 109 | parser.add_argument('name', help='Name prefix for output directories: _reals/ & _fakes/') 110 | parser.add_argument('model_ckpt', help='TorchScript (.ts) (Generative) Model checkpoint file') 111 | parser.add_argument('data_sources', help='Space-separated string listing either S3 resources or local directories (but not a mix of both!) for real data') 112 | parser.add_argument('-d','--debug', action="store_true", help='Enable extra debugging messages') 113 | parser.add_argument('-b',"--batch_size", type=int, default=2, help='batch size') 114 | parser.add_argument('--n', type=int, default=256, help='Number of real/fake samples to grab/generate, respectively') 115 | parser.add_argument('--num_workers', type=int, default=12, help='Number of pytorch workers to use in DataLoader') 116 | parser.add_argument('-p',"--profiles", default='', help='String representation of dict {resource:profile} of AWS credentials') 117 | parser.add_argument('--sample_rate', type=int, default=48000, help='sample rate (will resample inputs at this rate)') 118 | parser.add_argument('-s','--sample_size', type=int, default=2**18, help='Number of samples per clip') 119 | args = parser.parse_args() 120 | gen( args ) 121 | 122 | # %% ../nbs/01_fad_gen.ipynb 9 123 | if __name__ == '__main__' and "get_ipython" not in dir(): 124 | main() 125 | -------------------------------------------------------------------------------- /fad_pytorch/fad_score.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/03_fad_score.ipynb. 
2 | 3 | # %% auto 0 4 | __all__ = ['read_embeddings', 'calc_mu_sigma', 'calc_score', 'main'] 5 | 6 | # %% ../nbs/03_fad_score.ipynb 5 7 | import torch 8 | import argparse 9 | from .sqrtm import sqrtm 10 | from aeiou.core import fast_scandir 11 | 12 | # %% ../nbs/03_fad_score.ipynb 6 13 | def read_embeddings(emb_path='real_emb_clap/', debug=False): 14 | "reads any .pt files in emb_path and concatenates them into one tensor" 15 | if debug: print("searching in ",emb_path) 16 | _, file_list = fast_scandir(emb_path, ['pt']) 17 | if file_list == []: 18 | _, file_list = fast_scandir('/fsx/shawley/code/fad_pytorch/'+emb_path, ['pt']) # yea, cheap hack just for my testing in nbs/ dir 19 | assert file_list != [] 20 | embeddings = [] 21 | for file_path in file_list: 22 | emb_batch = torch.load(file_path, map_location='cpu') 23 | embeddings.append(emb_batch) 24 | return torch.cat(embeddings, dim=0) 25 | 26 | # %% ../nbs/03_fad_score.ipynb 8 27 | def calc_mu_sigma(emb): 28 | "calculates mean and covariance matrix of batched embeddings" 29 | mu = torch.mean(emb, axis=0) 30 | sigma = torch.cov(emb.T) 31 | return mu, sigma 32 | 33 | # %% ../nbs/03_fad_score.ipynb 10 34 | def calc_score(real_emb_path, # where real embeddings are stored 35 | fake_emb_path, # where fake embeddings are stored 36 | method='maji', # sqrtm calc method: 'maji'|'li' 37 | debug=False 38 | ): 39 | print(f"Calculating FAD score for files in {real_emb_path}/ vs. {fake_emb_path}/") 40 | emb_real = read_embeddings(emb_path=real_emb_path, debug=debug) 41 | emb_fake = read_embeddings(emb_path=fake_emb_path, debug=debug) 42 | if debug: print(emb_real.shape, emb_fake.shape) 43 | 44 | mu_real, sigma_real = calc_mu_sigma(emb_real) 45 | mu_fake, sigma_fake = calc_mu_sigma(emb_fake) 46 | if debug: 47 | print("mu_real.shape, sigma_real.shape =",mu_real.shape, sigma_real.shape) 48 | print("mu_fake.shape, sigma_fake.shape =",mu_fake.shape, sigma_fake.shape) 49 | 50 | mu_diff = mu_real - mu_fake 51 | if debug: 52 | print("mu_diff = ",mu_diff) 53 | score1 = mu_diff.dot(mu_diff) 54 | print("score1: mu_diff.dot(mu_diff) = ",score1) 55 | score2 = torch.trace(sigma_real) 56 | print("score2: torch.trace(sigma_real) = ", score2) 57 | score3 = torch.trace(sigma_fake) 58 | print("score3: torch.trace(sigma_fake) = ",score3) 59 | score_p = sqrtm( torch.matmul( sigma_real, sigma_fake) ) 60 | print("score_p.shape (matmul) = ",score_p.shape) 61 | score4 = -2* torch.trace ( torch.real ( sqrtm( torch.matmul( sigma_real, sigma_fake) , method=method ) ) ) 62 | print("score4 (-2*tr(sqrtm(matmul(sigma_r sigma_f)))) = ",score4) 63 | score = score1 + score2 + score3 + score4 64 | score = mu_diff.dot(mu_diff) + torch.trace(sigma_real) + torch.trace(sigma_fake) -2* torch.trace ( torch.real ( sqrtm( torch.matmul( sigma_real, sigma_fake), method=method ) ) ) 65 | return score 66 | 67 | # %% ../nbs/03_fad_score.ipynb 16 68 | def main(): 69 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 70 | parser.add_argument('real_emb_path', help='Path of files of embeddings of real data', default='real_emb_clap/') 71 | parser.add_argument('fake_emb_path', help='Path of files of embeddings of fake data', default='fake_emb_clap/') 72 | parser.add_argument('-d','--debug', action='store_true', help='Enable debugging') 73 | parser.add_argument('-m','--method', default='maji', help='Method for sqrtm calculation: "maji" or "li" ') 74 | 75 | args = parser.parse_args() 76 | score = calc_score( args.real_emb_path, args.fake_emb_path, method=args.method, 
debug=args.debug ) 77 | print("FAD score = ",score.cpu().numpy()) 78 | 79 | # %% ../nbs/03_fad_score.ipynb 17 80 | if __name__ == '__main__' and "get_ipython" not in dir(): 81 | main() 82 | -------------------------------------------------------------------------------- /fad_pytorch/pann_pytorch_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | From: https://github.com/qiuqiangkong/audioset_tagging_cnn/blob/master/pytorch/pytorch_utils.py 3 | 4 | Copyright (c) 2018-2020 Qiuqiang Kong 5 | """ 6 | LICENSE_KONG = """ 7 | The MIT License 8 | 9 | Copyright (c) 2018-2020 Qiuqiang Kong 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy 12 | of this software and associated documentation files (the "Software"), to deal 13 | in the Software without restriction, including without limitation the rights 14 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 15 | copies of the Software, and to permit persons to whom the Software is 16 | furnished to do so, subject to the following conditions: 17 | 18 | The above copyright notice and this permission notice shall be included in 19 | all copies or substantial portions of the Software. 20 | 21 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 22 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 23 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 24 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 25 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 26 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN 27 | THE SOFTWARE.""" 28 | 29 | 30 | 31 | 32 | import numpy as np 33 | import time 34 | import torch 35 | import torch.nn as nn 36 | 37 | 38 | def move_data_to_device(x, device): 39 | if 'float' in str(x.dtype): 40 | x = torch.Tensor(x) 41 | elif 'int' in str(x.dtype): 42 | x = torch.LongTensor(x) 43 | else: 44 | return x 45 | 46 | return x.to(device) 47 | 48 | 49 | def do_mixup(x, mixup_lambda): 50 | """Mixup x of even indexes (0, 2, 4, ...) with x of odd indexes 51 | (1, 3, 5, ...). 52 | 53 | Args: 54 | x: (batch_size * 2, ...) 55 | mixup_lambda: (batch_size * 2,) 56 | 57 | Returns: 58 | out: (batch_size, ...) 59 | """ 60 | out = (x[0 :: 2].transpose(0, -1) * mixup_lambda[0 :: 2] + \ 61 | x[1 :: 2].transpose(0, -1) * mixup_lambda[1 :: 2]).transpose(0, -1) 62 | return out 63 | 64 | 65 | def append_to_dict(dict, key, value): 66 | if key in dict.keys(): 67 | dict[key].append(value) 68 | else: 69 | dict[key] = [value] 70 | 71 | 72 | def forward(model, generator, return_input=False, 73 | return_target=False): 74 | """Forward data to a model. 
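Iterates the generator in mini-batches, runs the model under torch.no_grad(), accumulates each returned head (clipwise / segmentwise / framewise output, plus optional inputs and targets) into a dict of lists, and concatenates along the batch axis before returning.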
75 | 76 | Args: 77 | model: object 78 | generator: object 79 | return_input: bool 80 | return_target: bool 81 | 82 | Returns: 83 | audio_name: (audios_num,) 84 | clipwise_output: (audios_num, classes_num) 85 | (ifexist) segmentwise_output: (audios_num, segments_num, classes_num) 86 | (ifexist) framewise_output: (audios_num, frames_num, classes_num) 87 | (optional) return_input: (audios_num, segment_samples) 88 | (optional) return_target: (audios_num, classes_num) 89 | """ 90 | output_dict = {} 91 | device = next(model.parameters()).device 92 | time1 = time.time() 93 | 94 | # Forward data to a model in mini-batches 95 | for n, batch_data_dict in enumerate(generator): 96 | print(n) 97 | batch_waveform = move_data_to_device(batch_data_dict['waveform'], device) 98 | 99 | with torch.no_grad(): 100 | model.eval() 101 | batch_output = model(batch_waveform) 102 | 103 | append_to_dict(output_dict, 'audio_name', batch_data_dict['audio_name']) 104 | 105 | append_to_dict(output_dict, 'clipwise_output', 106 | batch_output['clipwise_output'].data.cpu().numpy()) 107 | 108 | if 'segmentwise_output' in batch_output.keys(): 109 | append_to_dict(output_dict, 'segmentwise_output', 110 | batch_output['segmentwise_output'].data.cpu().numpy()) 111 | 112 | if 'framewise_output' in batch_output.keys(): 113 | append_to_dict(output_dict, 'framewise_output', 114 | batch_output['framewise_output'].data.cpu().numpy()) 115 | 116 | if return_input: 117 | append_to_dict(output_dict, 'waveform', batch_data_dict['waveform']) 118 | 119 | if return_target: 120 | if 'target' in batch_data_dict.keys(): 121 | append_to_dict(output_dict, 'target', batch_data_dict['target']) 122 | 123 | if n % 10 == 0: 124 | print(' --- Inference time: {:.3f} s / 10 iterations ---'.format( 125 | time.time() - time1)) 126 | time1 = time.time() 127 | 128 | for key in output_dict.keys(): 129 | output_dict[key] = np.concatenate(output_dict[key], axis=0) 130 | 131 | return output_dict 132 | 133 | 134 | def interpolate(x, ratio): 135 | """Interpolate data in time domain. This is used to compensate the 136 | resolution reduction in downsampling of a CNN. 137 | 138 | Args: 139 | x: (batch_size, time_steps, classes_num) 140 | ratio: int, ratio to interpolate 141 | 142 | Returns: 143 | upsampled: (batch_size, time_steps * ratio, classes_num) 144 | """ 145 | (batch_size, time_steps, classes_num) = x.shape 146 | upsampled = x[:, :, None, :].repeat(1, 1, ratio, 1) 147 | upsampled = upsampled.reshape(batch_size, time_steps * ratio, classes_num) 148 | return upsampled 149 | 150 | 151 | def pad_framewise_output(framewise_output, frames_num): 152 | """Pad framewise_output to the same length as input frames. The pad value 153 | is the same as the value of the last frame. 154 | 155 | Args: 156 | framewise_output: (batch_size, frames_num, classes_num) 157 | frames_num: int, number of frames to pad 158 | 159 | Outputs: 160 | output: (batch_size, frames_num, classes_num) 161 | """ 162 | pad = framewise_output[:, -1 :, :].repeat(1, frames_num - framewise_output.shape[1], 1) 163 | """tensor for padding""" 164 | 165 | output = torch.cat((framewise_output, pad), dim=1) 166 | """(batch_size, frames_num, classes_num)""" 167 | 168 | return output 169 | 170 | 171 | def count_parameters(model): 172 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 173 | 174 | 175 | def count_flops(model, audio_length): 176 | """Count flops. Code modified from others' implementation. 
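    Registers a forward hook on every leaf Conv1d/Conv2d/Linear/BatchNorm/ReLU/pooling module, runs one dummy forward pass on a random (1, audio_length) input, and sums the per-layer FLOP estimates recorded by the hooks; leaf modules of other types are skipped with a printed warning. Example usage (clip length hypothetical): flops = count_flops(model, audio_length=32000)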
177 | """ 178 | multiply_adds = True 179 | list_conv2d=[] 180 | def conv2d_hook(self, input, output): 181 | batch_size, input_channels, input_height, input_width = input[0].size() 182 | output_channels, output_height, output_width = output[0].size() 183 | 184 | kernel_ops = self.kernel_size[0] * self.kernel_size[1] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 185 | bias_ops = 1 if self.bias is not None else 0 186 | 187 | params = output_channels * (kernel_ops + bias_ops) 188 | flops = batch_size * params * output_height * output_width 189 | 190 | list_conv2d.append(flops) 191 | 192 | list_conv1d=[] 193 | def conv1d_hook(self, input, output): 194 | batch_size, input_channels, input_length = input[0].size() 195 | output_channels, output_length = output[0].size() 196 | 197 | kernel_ops = self.kernel_size[0] * (self.in_channels / self.groups) * (2 if multiply_adds else 1) 198 | bias_ops = 1 if self.bias is not None else 0 199 | 200 | params = output_channels * (kernel_ops + bias_ops) 201 | flops = batch_size * params * output_length 202 | 203 | list_conv1d.append(flops) 204 | 205 | list_linear=[] 206 | def linear_hook(self, input, output): 207 | batch_size = input[0].size(0) if input[0].dim() == 2 else 1 208 | 209 | weight_ops = self.weight.nelement() * (2 if multiply_adds else 1) 210 | bias_ops = self.bias.nelement() 211 | 212 | flops = batch_size * (weight_ops + bias_ops) 213 | list_linear.append(flops) 214 | 215 | list_bn=[] 216 | def bn_hook(self, input, output): 217 | list_bn.append(input[0].nelement() * 2) 218 | 219 | list_relu=[] 220 | def relu_hook(self, input, output): 221 | list_relu.append(input[0].nelement() * 2) 222 | 223 | list_pooling2d=[] 224 | def pooling2d_hook(self, input, output): 225 | batch_size, input_channels, input_height, input_width = input[0].size() 226 | output_channels, output_height, output_width = output[0].size() 227 | 228 | kernel_ops = self.kernel_size * self.kernel_size 229 | bias_ops = 0 230 | params = output_channels * (kernel_ops + bias_ops) 231 | flops = batch_size * params * output_height * output_width 232 | 233 | list_pooling2d.append(flops) 234 | 235 | list_pooling1d=[] 236 | def pooling1d_hook(self, input, output): 237 | batch_size, input_channels, input_length = input[0].size() 238 | output_channels, output_length = output[0].size() 239 | 240 | kernel_ops = self.kernel_size[0] 241 | bias_ops = 0 242 | 243 | params = output_channels * (kernel_ops + bias_ops) 244 | flops = batch_size * params * output_length 245 | 246 | list_pooling2d.append(flops) 247 | 248 | def foo(net): 249 | childrens = list(net.children()) 250 | if not childrens: 251 | if isinstance(net, nn.Conv2d): 252 | net.register_forward_hook(conv2d_hook) 253 | elif isinstance(net, nn.Conv1d): 254 | net.register_forward_hook(conv1d_hook) 255 | elif isinstance(net, nn.Linear): 256 | net.register_forward_hook(linear_hook) 257 | elif isinstance(net, nn.BatchNorm2d) or isinstance(net, nn.BatchNorm1d): 258 | net.register_forward_hook(bn_hook) 259 | elif isinstance(net, nn.ReLU): 260 | net.register_forward_hook(relu_hook) 261 | elif isinstance(net, nn.AvgPool2d) or isinstance(net, nn.MaxPool2d): 262 | net.register_forward_hook(pooling2d_hook) 263 | elif isinstance(net, nn.AvgPool1d) or isinstance(net, nn.MaxPool1d): 264 | net.register_forward_hook(pooling1d_hook) 265 | else: 266 | print('Warning: flop of module {} is not counted!'.format(net)) 267 | return 268 | for c in childrens: 269 | foo(c) 270 | 271 | # Register hook 272 | foo(model) 273 | 274 | device = device = 
next(model.parameters()).device 275 | input = torch.rand(1, audio_length).to(device) 276 | 277 | out = model(input) 278 | 279 | total_flops = sum(list_conv2d) + sum(list_conv1d) + sum(list_linear) + \ 280 | sum(list_bn) + sum(list_relu) + sum(list_pooling2d) + sum(list_pooling1d) 281 | 282 | return total_flops -------------------------------------------------------------------------------- /fad_pytorch/sqrtm.py: -------------------------------------------------------------------------------- 1 | # AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/04_sqrtm.ipynb. 2 | 3 | # %% auto 0 4 | __all__ = ['use_li', 'sqrtm_li', 'LICENCE_LI', 'LICENCE_SM', 'MatrixSquareRoot_li', 'compute_error', 5 | 'sqrt_newton_schulz_autograd', 'sqrt_newton_schulz', 'sqrtm'] 6 | 7 | # %% ../nbs/04_sqrtm.ipynb 4 8 | import torch 9 | from torch.autograd import Function, Variable 10 | 11 | # %% ../nbs/04_sqrtm.ipynb 6 12 | use_li = True # come back and turn this on if you want to see/use the full code 13 | 14 | if use_li: # lighten the load of imports since we won't use li's in production 15 | import numpy as np 16 | import scipy.linalg 17 | 18 | # %% ../nbs/04_sqrtm.ipynb 7 19 | class MatrixSquareRoot_li(Function): 20 | """ 21 | From https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py, which sadly does not install as a package. LICENSE included below 22 | Square root of a positive definite matrix. 23 | 24 | NOTE: matrix square root is not differentiable for matrices with 25 | zero eigenvalues. 26 | """ 27 | @staticmethod 28 | def forward(ctx, input): 29 | m = input.detach().cpu().numpy().astype(np.float_) # SHH: note how this immediately switches to CPU & numpy :-( 30 | sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).to(input) 31 | ctx.save_for_backward(sqrtm) 32 | return sqrtm 33 | 34 | @staticmethod 35 | def backward(ctx, grad_output): 36 | grad_input = None 37 | if ctx.needs_input_grad[0]: 38 | sqrtm, = ctx.saved_tensors 39 | sqrtm = sqrtm.data.cpu().numpy().astype(np.float_) 40 | gm = grad_output.data.cpu().numpy().astype(np.float_) 41 | 42 | # Given a positive semi-definite matrix X, 43 | # since X = X^{1/2}X^{1/2}, we can compute the gradient of the 44 | # matrix square root dX^{1/2} by solving the Sylvester equation: 45 | # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}). 46 | grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm) 47 | 48 | grad_input = torch.from_numpy(grad_sqrtm).to(grad_output) 49 | return grad_input 50 | 51 | 52 | sqrtm_li = MatrixSquareRoot_li.apply 53 | 54 | 55 | LICENCE_LI = """ 56 | MIT License 57 | 58 | Copyright (c) 2022 Steven Cheng-Xian Li 59 | 60 | Permission is hereby granted, free of charge, to any person obtaining a copy 61 | of this software and associated documentation files (the "Software"), to deal 62 | in the Software without restriction, including without limitation the rights 63 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 64 | copies of the Software, and to permit persons to whom the Software is 65 | furnished to do so, subject to the following conditions: 66 | 67 | The above copyright notice and this permission notice shall be included in all 68 | copies or substantial portions of the Software. 69 | 70 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 71 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 72 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 73 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 74 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 75 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 76 | SOFTWARE. 77 | """ 78 | 79 | # %% ../nbs/04_sqrtm.ipynb 12 80 | # Compute error 81 | def compute_error(A, sA): 82 | normA = torch.sqrt(torch.sum(torch.sum(A * A, dim=1),dim=1)) 83 | error = A - torch.bmm(sA, sA) 84 | error = torch.sqrt((error * error).sum(dim=1).sum(dim=1)) / normA 85 | return torch.mean(error) 86 | 87 | 88 | def sqrt_newton_schulz_autograd(A, 89 | numIters=20, # found experimentally by SHH, comparing w/ Li's method 90 | calc_error=False,): 91 | """Modified from from https://people.cs.umass.edu/~smaji/projects/matrix-sqrt/ 92 | "The drawback of the autograd approach [i.e., this approach] is that a naive implementation stores all the intermediate results. 93 | Thus the memory overhead scales linearly with the number of iterations which is problematic for large matrices." 94 | """ 95 | if len(A.data.shape) < 3: A = A.unsqueeze(0) 96 | batchSize, dim1, dim2 = A.data.shape 97 | assert dim1==dim2 98 | dim, dtype, device = dim1, A.dtype, A.device 99 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() 100 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)); 101 | I = Variable(torch.eye(dim,dim, device=device).view(1, dim, dim). 102 | repeat(batchSize,1,1).type(dtype),requires_grad=False) 103 | Z = Variable(torch.eye(dim,dim, device=device).view(1, dim, dim). 104 | repeat(batchSize,1,1).type(dtype),requires_grad=False) 105 | 106 | for i in range(numIters): 107 | T = 0.5*(3.0*I - Z.bmm(Y)) 108 | Y = Y.bmm(T) 109 | Z = T.bmm(Z) 110 | 111 | sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) 112 | if calc_error: 113 | error = compute_error(A, sA) 114 | return sA, error 115 | return sA 116 | 117 | 118 | def sqrt_newton_schulz(A, # matrix to be sqrt-ified 119 | numIters=20, # numIters=7 found via experimentation 120 | calc_error=False, # setting False disables Maji's error reporting 121 | ): 122 | """ 123 | Sqrt of matrix via Newton-Schulz algorithm 124 | Modified from https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#LL72C1-L87C19 125 | # Forward via Newton-Schulz iterations (non autograd version) 126 | # Seems to be slighlty faster and has much lower memory overhead 127 | 128 | ... Original code didn't preserve device, had no batch dim checking -SHH 129 | """ 130 | while len(A.data.shape) < 3: # needs a batch dimension 131 | A = A.unsqueeze(0) 132 | batchSize, dim1, dim2 = A.data.shape 133 | assert dim1==dim2 134 | dim, dtype, device = dim1, A.dtype, A.device 135 | normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt() 136 | Y = A.div(normA.view(batchSize, 1, 1).expand_as(A)); 137 | I = torch.eye(dim,dim, device=device, dtype=dtype).view(1, dim, dim).repeat(batchSize,1,1) 138 | Z = torch.eye(dim,dim, device=device, dtype=dtype).view(1, dim, dim).repeat(batchSize,1,1) 139 | for i in range(numIters): 140 | T = 0.5*(3.0*I - Z.bmm(Y)) 141 | Y = Y.bmm(T) 142 | Z = T.bmm(Z) 143 | 144 | sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A) 145 | if calc_error: 146 | error = compute_error(A, sA) 147 | return sA, error 148 | else: 149 | return sA 150 | 151 | 152 | """ 153 | # Only used if backprop needed, which it isn't for FAD. Leaving it here anyway. -SHH 154 | def lyap_newton_schulz(z, dldz, numIters, dtype): 155 | # Backward via iterative Lyapunov solver. 
156 | batchSize = z.shape[0] 157 | dim = z.shape[1] 158 | normz = z.mul(z).sum(dim=1).sum(dim=1).sqrt() 159 | a = z.div(normz.view(batchSize, 1, 1).expand_as(z)) 160 | I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype) 161 | q = dldz.div(normz.view(batchSize, 1, 1).expand_as(z)) 162 | for i in range(numIters): 163 | q = 0.5*(q.bmm(3.0*I - a.bmm(a)) - a.transpose(1, 2).bmm(a.transpose(1,2).bmm(q) - q.bmm(a)) ) 164 | a = 0.5*a.bmm(3.0*I - a.bmm(a)) 165 | dlda = 0.5*q 166 | return dlda 167 | """ 168 | 169 | 170 | LICENCE_SM = """ 171 | MIT License 172 | 173 | Copyright (c) 2017 Subhransu Maji 174 | 175 | Permission is hereby granted, free of charge, to any person obtaining a copy 176 | of this software and associated documentation files (the "Software"), to deal 177 | in the Software without restriction, including without limitation the rights 178 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 179 | copies of the Software, and to permit persons to whom the Software is 180 | furnished to do so, subject to the following conditions: 181 | 182 | The above copyright notice and this permission notice shall be included in all 183 | copies or substantial portions of the Software. 184 | 185 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 186 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 187 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 188 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 189 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 190 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 191 | SOFTWARE. 192 | """ 193 | 194 | # %% ../nbs/04_sqrtm.ipynb 27 195 | def sqrtm(A, method='maji', numIters=20): 196 | "wrapper function for matrix sqrt algorithm of choice. Also we'll turn off all gradients" 197 | with torch.no_grad(): 198 | if method=='maji': 199 | return sqrt_newton_schulz(A, numIters=numIters, calc_error=False).squeeze() # get rid of any useless batch dimensions 200 | elif method=='li': 201 | return sqrtm_li(A) 202 | else: 203 | raise ValueError(f"Invalid method: {method}") 204 | -------------------------------------------------------------------------------- /nbs/01_fad_gen.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#| hide\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# fad_gen\n", 19 | "\n", 20 | "> Produce directories of real and fake audio" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "This program may not be needed if you already have directories of real & fake audio. 
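Given a TorchScript model checkpoint and one or more sources of real audio, it writes matched directories of clips: the real inputs land in `<name>_reals/` and the model's outputs in `<name>_fakes/`, as .wav files named by batch and item index.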
" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": null, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "#| default_exp fad_gen" 37 | ] 38 | }, 39 | { 40 | "cell_type": "code", 41 | "execution_count": null, 42 | "metadata": {}, 43 | "outputs": [], 44 | "source": [ 45 | "#| hide\n", 46 | "from nbdev.showdoc import *" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Sample calling sequence(s):\n", 54 | "\n", 55 | "Single GPU, local data:\n", 56 | "```\n", 57 | "fad_gen test autoencoder.ts \"real1/ real2/ real3/\"\n", 58 | "```\n", 59 | "\n", 60 | "\n", 61 | "Multiple GPUs, data on S3:\n", 62 | "```\n", 63 | "accelerate launch fad_pytorch/fad_gen.py 5s_simple model_checkpoint.ts \"s3://s-laion-audio/webdataset_tar/freesound_no_overlap/ s3://s-laion-audio/webdataset_tar/epidemic_sound_effects/\" -p \"{'s3://s-laion-audio':'default'}\"\n", 64 | "```\n", 65 | "\n", 66 | "\n", 67 | "\n", 68 | "General calling sequence:\n", 69 | "```\n", 70 | "$ fad_gen -h\n", 71 | "usage: fad_gen [-h] [-b BATCH_SIZE] [--n N] [--num_workers NUM_WORKERS] [-p PROFILES] [--sample_rate SAMPLE_RATE] [-s SAMPLE_SIZE]\n", 72 | " name model_ckpt data_sources\n", 73 | "\n", 74 | "positional arguments:\n", 75 | " name Name prefix for output directories: _reals/ & _fakes/\n", 76 | " model_ckpt TorchScript (.ts) (Generative) Model checkpoint file\n", 77 | " data_sources Space-separated string listing either S3 resources or local directories (but not a mix of both!) for real data\n", 78 | "\n", 79 | "optional arguments:\n", 80 | " -h, --help show this help message and exit\n", 81 | " -b BATCH_SIZE, --batch_size BATCH_SIZE\n", 82 | " batch size (default: 2)\n", 83 | " --n N Number of real/fake samples to grab/generate, respectively (default: 256)\n", 84 | " --num_workers NUM_WORKERS\n", 85 | " Number of pytorch workers to use in DataLoader (default: 12)\n", 86 | " -p PROFILES, --profiles PROFILES\n", 87 | " String representation of dict {resource:profile} (default: )\n", 88 | " --sample_rate SAMPLE_RATE\n", 89 | " sample rate (will resample inputs at this rate) (default: 48000)\n", 90 | " -s SAMPLE_SIZE, --sample_size SAMPLE_SIZE\n", 91 | " Number of samples per clip (default: 262144)\n", 92 | "```" 93 | ] 94 | }, 95 | { 96 | "cell_type": "code", 97 | "execution_count": null, 98 | "metadata": {}, 99 | "outputs": [], 100 | "source": [ 101 | "#| export\n", 102 | "import os\n", 103 | "import argparse\n", 104 | "from accelerate import Accelerator\n", 105 | "import warnings\n", 106 | "import torch\n", 107 | "\n", 108 | "from aeiou.core import get_device, load_audio, get_audio_filenames, makedir\n", 109 | "from aeiou.datasets import get_wds_loader, AudioDataset\n", 110 | "from aeiou.hpc import HostPrinter\n", 111 | "from pathlib import Path\n", 112 | "#from audio_algebra.given_models import StackedDiffAEWrapper\n", 113 | "import ast\n", 114 | "import torchaudio\n", 115 | "from tqdm.auto import tqdm\n", 116 | "import math" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": null, 122 | "metadata": {}, 123 | "outputs": [], 124 | "source": [ 125 | "#| export\n", 126 | "def gen(args):\n", 127 | " \n", 128 | " # HPC / parallel setup\n", 129 | " local_rank = int(os.environ.get(\"LOCAL_RANK\", 0))\n", 130 | " world_size = int(os.environ.get(\"WORLD_SIZE\", 1))\n", 131 | " ddps = f\"[{local_rank}/{world_size}]\" # string for distributed computing info, e.g. 
\"[1/8]\" \n", 132 | " accelerator = Accelerator()\n", 133 | " hprint = HostPrinter(accelerator) # hprint only prints on head node\n", 134 | " device = accelerator.device # get_device()\n", 135 | " hprint(f\"gen: args = {args}\")\n", 136 | " hprint(f'{ddps} Using device: {device}')\n", 137 | " \n", 138 | " \n", 139 | " model_ckpt, data_sources, profiles, n = args.model_ckpt, args.data_sources, args.profiles, args.n\n", 140 | " names = data_sources.split(' ')\n", 141 | " #hprint(f\"names = {names}\")\n", 142 | " local_data = False\n", 143 | " if 's3://' in data_sources: \n", 144 | " hprint(\"Data sources are on S3\")\n", 145 | " profiles = ast.literal_eval(profiles)\n", 146 | " hprint(f\"profiles = {profiles}\")\n", 147 | "\n", 148 | " dl = get_wds_loader(\n", 149 | " batch_size=args.batch_size,\n", 150 | " s3_url_prefix=None,\n", 151 | " sample_size=args.sample_size,\n", 152 | " names=names,\n", 153 | " sample_rate=args.sample_rate,\n", 154 | " num_workers=args.num_workers,\n", 155 | " recursive=True,\n", 156 | " random_crop=True,\n", 157 | " epoch_steps=10000,\n", 158 | " profiles=profiles,\n", 159 | " )\n", 160 | " else:\n", 161 | " hprint(\"Data sources are local\")\n", 162 | " dataset = AudioDataset(names, sample_rate=args.sample_rate, sample_size=args.sample_size)\n", 163 | " dl = torch.utils.data.DataLoader(dataset, batch_size=args.batch_size, num_workers=args.num_workers)\n", 164 | " local_data = True\n", 165 | " \n", 166 | " print(f\"loading {model_ckpt}....\")\n", 167 | " if model_ckpt.endswith('.ts'):\n", 168 | " model = torch.jit.load(model_ckpt)\n", 169 | " #else: # default is stacked diffae\n", 170 | " # model = StackedDiffAEWrapper(ckpt_info={'ckpt_path':model_ckpt})\n", 171 | " try:\n", 172 | " model.setup() # if it needs setup call\n", 173 | " except: \n", 174 | " pass \n", 175 | " \n", 176 | " model.eval()\n", 177 | " model = model.to(device)\n", 178 | " \n", 179 | " model, dl = accelerator.prepare( model, dl ) # prepare handles distributing things among GPUs\n", 180 | "\n", 181 | " reals_path, fakes_path = f\"{args.name}_reals\", f\"{args.name}_fakes\"\n", 182 | " makedir(reals_path)\n", 183 | " makedir(fakes_path)\n", 184 | "\n", 185 | " progress_bar = tqdm(range(math.ceil(args.n/args.batch_size)), disable=not accelerator.is_local_main_process)\n", 186 | "\n", 187 | " for i, data in enumerate(dl):\n", 188 | " reals = data if local_data else data[0][0]\n", 189 | " if args.debug: hprint(f\"{ddps} i = {i}, reals.shape = {reals.shape}\")\n", 190 | " \n", 191 | " with torch.no_grad():\n", 192 | " fakes = accelerator.unwrap_model(model).forward(reals.to(device)).cpu()\n", 193 | " #hprint(f\"fakes.shape = {fakes.shape}\")\n", 194 | " \n", 195 | " for b in range(reals.shape[0]):\n", 196 | " waveform = reals[b]\n", 197 | " torchaudio.save(f\"{reals_path}/{i}_{b}.wav\", waveform.cpu(), args.sample_rate)\n", 198 | " waveform = fakes[b]\n", 199 | " torchaudio.save(f\"{fakes_path}/{i}_{b}.wav\", waveform.cpu(), args.sample_rate)\n", 200 | " \n", 201 | " progress_bar.update(1)\n", 202 | " if (i+1)*args.batch_size > args.n:\n", 203 | " hprint(f\"\\nGot all the data we needed: {i*args.batch_size}. 
Stopping\")\n", 204 | " break\n", 205 | " \n", 206 | " hprint(\"Success!\")" 207 | ] 208 | }, 209 | { 210 | "cell_type": "code", 211 | "execution_count": null, 212 | "metadata": {}, 213 | "outputs": [], 214 | "source": [ 215 | "#| export\n", 216 | "def main(): \n", 217 | " parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n", 218 | " parser.add_argument('name', help='Name prefix for output directories: _reals/ & _fakes/')\n", 219 | " parser.add_argument('model_ckpt', help='TorchScript (.ts) (Generative) Model checkpoint file')\n", 220 | " parser.add_argument('data_sources', help='Space-separated string listing either S3 resources or local directories (but not a mix of both!) for real data')\n", 221 | " parser.add_argument('-d','--debug', action=\"store_true\", help='Enable extra debugging messages')\n", 222 | " parser.add_argument('-b',\"--batch_size\", default=2, help='batch size')\n", 223 | " parser.add_argument('--n', type=int, default=256, help='Number of real/fake samples to grab/generate, respectively')\n", 224 | " parser.add_argument('--num_workers', type=int, default=12, help='Number of pytorch workers to use in DataLoader')\n", 225 | " parser.add_argument('-p',\"--profiles\", default='', help='String representation of dict {resource:profile} of AWS credentials')\n", 226 | " parser.add_argument('--sample_rate', type=int, default=48000, help='sample rate (will resample inputs at this rate)')\n", 227 | " parser.add_argument('-s','--sample_size', type=int, default=2**18, help='Number of samples per clip')\n", 228 | " args = parser.parse_args()\n", 229 | " gen( args )" 230 | ] 231 | }, 232 | { 233 | "cell_type": "code", 234 | "execution_count": null, 235 | "metadata": {}, 236 | "outputs": [], 237 | "source": [ 238 | "#| export\n", 239 | "if __name__ == '__main__' and \"get_ipython\" not in dir():\n", 240 | " main()" 241 | ] 242 | }, 243 | { 244 | "cell_type": "code", 245 | "execution_count": null, 246 | "metadata": {}, 247 | "outputs": [], 248 | "source": [ 249 | "#| hide\n", 250 | "import nbdev; nbdev.nbdev_export()" 251 | ] 252 | }, 253 | { 254 | "cell_type": "code", 255 | "execution_count": null, 256 | "metadata": {}, 257 | "outputs": [], 258 | "source": [] 259 | } 260 | ], 261 | "metadata": { 262 | "kernelspec": { 263 | "display_name": "aa", 264 | "language": "python", 265 | "name": "aa" 266 | } 267 | }, 268 | "nbformat": 4, 269 | "nbformat_minor": 4 270 | } 271 | -------------------------------------------------------------------------------- /nbs/02_fad_embed.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#| hide\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# fad_embed\n", 19 | "\n", 20 | "> Command-line script to generate embeddings from audio files" 21 | ] 22 | }, 23 | { 24 | "cell_type": "code", 25 | "execution_count": null, 26 | "metadata": {}, 27 | "outputs": [], 28 | "source": [ 29 | "#| default_exp fad_embed" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#| hide\n", 39 | "from nbdev.showdoc import *" 40 | ] 41 | }, 42 | { 43 | "cell_type": "markdown", 44 | "metadata": {}, 45 | "source": [ 46 | "## Sample calling sequences\n", 47 | "\n", 48 | "Single 
processor, single GPU: \n", 49 | "\n", 50 | "```\n", 51 | "fad_embed clap real/ fake\n", 52 | "```\n", 53 | "\n", 54 | "\n", 55 | "Multiple GPUs, multiple processors (single node): (this example syntax is to run from within main fad_pytorch package directory)\n", 56 | "```\n", 57 | "accelerate launch fad_pytorch/fad_embed.py clap real/ fake/\n", 58 | "```\n", 59 | "\n", 60 | "General invocation: \n", 61 | "\n", 62 | "```bash\n", 63 | "$ fad_embed -h\n", 64 | "usage: fad_embed [-h] [--batch_size BATCH_SIZE] [--sample_size SAMPLE_SIZE] [--chunk_size CHUNK_SIZE] [--hop_size HOP_SIZE] [--max_hops MAX_HOPS] [--sr SR] [--verbose]\n", 65 | " [--debug]\n", 66 | " embed_model real_path fake_path\n", 67 | "\n", 68 | "positional arguments:\n", 69 | " embed_model choice of embedding model(s): clap | vggish | pann | openl3 | all\n", 70 | " real_path Path of files of real audio\n", 71 | " fake_path Path of files of fake audio\n", 72 | "\n", 73 | "options:\n", 74 | " -h, --help show this help message and exit\n", 75 | " --batch_size BATCH_SIZE\n", 76 | " MAXIMUM Batch size for computing embeddings (may go smaller) (default: 64)\n", 77 | " --sample_size SAMPLE_SIZE\n", 78 | " Number of audio samples to read from each audio file (default: 262144)\n", 79 | " --chunk_size CHUNK_SIZE\n", 80 | " Length of chunks (in audio samples) to embed (default: 24000)\n", 81 | " --hop_size HOP_SIZE (approximate) time difference (in seconds) between each chunk (default: 0.1)\n", 82 | " --max_hops MAX_HOPS Don't exceed this many hops/chunks/embeddings per audio file. <= 0 disables this. (default: -1)\n", 83 | " --sr SR sample rate (will resample inputs at this rate) (default: 48000)\n", 84 | " --verbose Show notices of resampling when reading files (default: False)\n", 85 | " --debug Extra messages for debugging this program (default: False)\n", 86 | "\n", 87 | "```" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": null, 93 | "metadata": {}, 94 | "outputs": [ 95 | { 96 | "data": { 97 | "application/vnd.jupyter.widget-view+json": { 98 | "model_id": "cccf171fa6f145d88f0f73c00fe2ff13", 99 | "version_major": 2, 100 | "version_minor": 0 101 | }, 102 | "text/plain": [ 103 | "Downloading (…)solve/main/vocab.txt: 0%| | 0.00/232k [00:00 b s c') # this torchopen3 expects channels-first ordering\n", 551 | " embeddings, timestamps = torchopenl3.get_audio_embedding(audio, emb_sample_rate, model=embedder)\n", 552 | " embeddings = torch.squeeze(embeddings, 1) # get rid of any spurious dimensions of 1 in middle position \n", 553 | "\n", 554 | " else:\n", 555 | " raise ValueError(f\"Unknown model_choice = {model_choice}\")\n", 556 | "\n", 557 | " hprint(f\"embeddings.shape = {embeddings.shape}\")\n", 558 | " # TODO: for now we'll just dump each batch on each proc to its own file; this could be improved\n", 559 | " outfilename = f\"{newdir}/emb_p{local_rank}_b{i}_h{h_ind}.pt\"\n", 560 | " hprint(f\"{ddps} Saving embeddings to {outfilename}\")\n", 561 | " torch.save(embeddings.cpu().detach(), outfilename)\n", 562 | " \n", 563 | " del embedder\n", 564 | " torch.cuda.empty_cache()\n", 565 | " # end loop over various embedders\n", 566 | " return " 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "#| export\n", 576 | "def main(): \n", 577 | " parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n", 578 | " parser.add_argument('embed_model', help='choice of embedding model(s): clap | vggish 
| pann | openl3 | all ', default='clap')\n", 579 | " parser.add_argument('real_path', help='Path of files of real audio', default='real/')\n", 580 | " parser.add_argument('fake_path', help='Path of files of fake audio', default='fake/')\n", 581 | " parser.add_argument('--batch_size', type=int, default=64, help='MAXIMUM Batch size for computing embeddings (may go smaller)')\n", 582 | " parser.add_argument('--sample_size', type=int, default=2**18, help='Number of audio samples to read from each audio file')\n", 583 | " parser.add_argument('--chunk_size', type=int, default=24000, help='Length of chunks (in audio samples) to embed')\n", 584 | " parser.add_argument('--hop_size', type=float, default=0.100, help='(approximate) time difference (in seconds) between each chunk')\n", 585 | " parser.add_argument('--max_hops', type=int, default=-1, help=\"Don't exceed this many hops/chunks/embeddings per audio file. <= 0 disables this.\")\n", 586 | " parser.add_argument('--sr', type=int, default=48000, help='sample rate (will resample inputs at this rate)')\n", 587 | " parser.add_argument('--verbose', action='store_true', help='Show notices of resampling when reading files')\n", 588 | " parser.add_argument('--debug', action='store_true', help='Extra messages for debugging this program')\n", 589 | "\n", 590 | " args = parser.parse_args()\n", 591 | " embed(args)" 592 | ] 593 | }, 594 | { 595 | "cell_type": "code", 596 | "execution_count": null, 597 | "metadata": {}, 598 | "outputs": [], 599 | "source": [ 600 | "#| export\n", 601 | "if __name__ == '__main__' and \"get_ipython\" not in dir():\n", 602 | " main()" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": null, 608 | "metadata": {}, 609 | "outputs": [], 610 | "source": [ 611 | "#| hide\n", 612 | "import nbdev; nbdev.nbdev_export()" 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": null, 618 | "metadata": {}, 619 | "outputs": [], 620 | "source": [] 621 | } 622 | ], 623 | "metadata": { 624 | "kernelspec": { 625 | "display_name": "aa", 626 | "language": "python", 627 | "name": "aa" 628 | } 629 | }, 630 | "nbformat": 4, 631 | "nbformat_minor": 4 632 | } 633 | -------------------------------------------------------------------------------- /nbs/03_fad_score.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#| hide\n", 10 | "%load_ext autoreload\n", 11 | "%autoreload 2" 12 | ] 13 | }, 14 | { 15 | "cell_type": "markdown", 16 | "metadata": {}, 17 | "source": [ 18 | "# fad_score\n", 19 | "\n", 20 | "> Produce FAD score based on files of embeddings of real and fake data" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "$$ FAD = || \\mu_r - \\mu_f ||^2 + tr\\left(\\Sigma_r + \\Sigma_f - 2 \\sqrt{\\Sigma_r \\Sigma_f}\\right)$$\n", 28 | "\n", 29 | "The embeddings are small enough that this can typically be run on a single processor, on a CPU. However, all the supporting code is GPU-friendly if so desired. 
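As a quick check on the formula: in one dimension it reduces to $(\\mu_r - \\mu_f)^2 + (\\sigma_r - \\sigma_f)^2$, since $tr\\left(\\Sigma_r + \\Sigma_f - 2 \\sqrt{\\Sigma_r \\Sigma_f}\\right) = \\sigma_r^2 + \\sigma_f^2 - 2\\sigma_r\\sigma_f$ for scalars, so the score vanishes only when both the means and the covariances of the two embedding distributions agree.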
" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#| default_exp fad_score" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#| hide\n", 48 | "from nbdev.showdoc import *" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "#| export\n", 58 | "import torch \n", 59 | "import argparse\n", 60 | "from fad_pytorch.sqrtm import sqrtm\n", 61 | "from aeiou.core import fast_scandir" 62 | ] 63 | }, 64 | { 65 | "cell_type": "code", 66 | "execution_count": null, 67 | "metadata": {}, 68 | "outputs": [], 69 | "source": [ 70 | "#| export\n", 71 | "def read_embeddings(emb_path='real_emb_clap/', debug=False):\n", 72 | " \"reads any .pt files in emb_path and concatenates them into one tensor\"\n", 73 | " if debug: print(\"searching in \",emb_path) \n", 74 | " _, file_list = fast_scandir(emb_path, ['pt'])\n", 75 | " if file_list == []:\n", 76 | " _, file_list = fast_scandir('/fsx/shawley/code/fad_pytorch/'+emb_path, ['pt']) # yea, cheap hack just for my testing in nbs/ dir\n", 77 | " assert file_list != []\n", 78 | " embeddings = []\n", 79 | " for file_path in file_list:\n", 80 | " emb_batch = torch.load(file_path, map_location='cpu') \n", 81 | " embeddings.append(emb_batch)\n", 82 | " return torch.cat(embeddings, dim=0)" 83 | ] 84 | }, 85 | { 86 | "cell_type": "code", 87 | "execution_count": null, 88 | "metadata": {}, 89 | "outputs": [ 90 | { 91 | "data": { 92 | "text/plain": [ 93 | "torch.Size([256, 512])" 94 | ] 95 | }, 96 | "execution_count": null, 97 | "metadata": {}, 98 | "output_type": "execute_result" 99 | } 100 | ], 101 | "source": [ 102 | "#| eval: false\n", 103 | "# lil test of that\n", 104 | "e = read_embeddings()\n", 105 | "e.shape" 106 | ] 107 | }, 108 | { 109 | "cell_type": "code", 110 | "execution_count": null, 111 | "metadata": {}, 112 | "outputs": [], 113 | "source": [ 114 | "#| export \n", 115 | "def calc_mu_sigma(emb): \n", 116 | " \"calculates mean and covariance matrix of batched embeddings\"\n", 117 | " mu = torch.mean(emb, axis=0)\n", 118 | " sigma = torch.cov(emb.T)\n", 119 | " return mu, sigma" 120 | ] 121 | }, 122 | { 123 | "cell_type": "code", 124 | "execution_count": null, 125 | "metadata": {}, 126 | "outputs": [ 127 | { 128 | "data": { 129 | "text/plain": [ 130 | "(torch.Size([512]), torch.Size([512, 512]))" 131 | ] 132 | }, 133 | "execution_count": null, 134 | "metadata": {}, 135 | "output_type": "execute_result" 136 | } 137 | ], 138 | "source": [ 139 | "#| eval: false\n", 140 | "# quick test:\n", 141 | "x = torch.rand(32,512) \n", 142 | "mu, sigma = calc_mu_sigma(x) \n", 143 | "mu.shape, sigma.shape " 144 | ] 145 | }, 146 | { 147 | "cell_type": "code", 148 | "execution_count": null, 149 | "metadata": {}, 150 | "outputs": [], 151 | "source": [ 152 | "#| export\n", 153 | "def calc_score(real_emb_path, # where real embeddings are stored\n", 154 | " fake_emb_path, # where fake embeddings are stored\n", 155 | " method='maji', # sqrtm calc method: 'maji'|'li'\n", 156 | " debug=False\n", 157 | " ): \n", 158 | " print(f\"Calculating FAD score for files in {real_emb_path}/ vs. 
{fake_emb_path}/\")\n", 159 | " emb_real = read_embeddings(emb_path=real_emb_path, debug=debug)\n", 160 | " emb_fake = read_embeddings(emb_path=fake_emb_path, debug=debug)\n", 161 | " if debug: print(emb_real.shape, emb_fake.shape)\n", 162 | " \n", 163 | " mu_real, sigma_real = calc_mu_sigma(emb_real) \n", 164 | " mu_fake, sigma_fake = calc_mu_sigma(emb_fake) \n", 165 | " if debug:\n", 166 | " print(\"mu_real.shape, sigma_real.shape =\",mu_real.shape, sigma_real.shape)\n", 167 | " print(\"mu_fake.shape, sigma_fake.shape =\",mu_fake.shape, sigma_fake.shape)\n", 168 | " \n", 169 | " mu_diff = mu_real - mu_fake\n", 170 | " if debug:\n", 171 | " print(\"mu_diff = \",mu_diff) \n", 172 | " score1 = mu_diff.dot(mu_diff)\n", 173 | " print(\"score1: mu_diff.dot(mu_diff) = \",score1)\n", 174 | " score2 = torch.trace(sigma_real)\n", 175 | " print(\"score2: torch.trace(sigma_real) = \", score2)\n", 176 | " score3 = torch.trace(sigma_fake)\n", 177 | " print(\"score3: torch.trace(sigma_fake) = \",score3)\n", 178 | " score_p = sqrtm( torch.matmul( sigma_real, sigma_fake) )\n", 179 | " print(\"score_p.shape (matmul) = \",score_p.shape) \n", 180 | " score4 = -2* torch.trace ( torch.real ( sqrtm( torch.matmul( sigma_real, sigma_fake) , method=method ) ) )\n", 181 | " print(\"score4 (-2*tr(sqrtm(matmul(sigma_r sigma_f)))) = \",score4) \n", 182 | " score = score1 + score2 + score3 + score4\n", 183 | " score = mu_diff.dot(mu_diff) + torch.trace(sigma_real) + torch.trace(sigma_fake) -2* torch.trace ( torch.real ( sqrtm( torch.matmul( sigma_real, sigma_fake), method=method ) ) )\n", 184 | " return score" 185 | ] 186 | }, 187 | { 188 | "cell_type": "markdown", 189 | "metadata": {}, 190 | "source": [ 191 | "Test the score function:" 192 | ] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "execution_count": null, 197 | "metadata": {}, 198 | "outputs": [ 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "Calculating FAD score for files in real_emb_clap// vs. fake_emb_clap//\n", 204 | "tensor(0.0951)\n" 205 | ] 206 | } 207 | ], 208 | "source": [ 209 | "#| eval: false\n", 210 | "score = calc_score( 'real_emb_clap/', 'fake_emb_clap/', method='maji')\n", 211 | "print(score)" 212 | ] 213 | }, 214 | { 215 | "cell_type": "markdown", 216 | "metadata": {}, 217 | "source": [ 218 | "Try sending using the exact same data for both distributions: Do we get zero? " 219 | ] 220 | }, 221 | { 222 | "cell_type": "code", 223 | "execution_count": null, 224 | "metadata": {}, 225 | "outputs": [ 226 | { 227 | "name": "stdout", 228 | "output_type": "stream", 229 | "text": [ 230 | "Calculating FAD score for files in real_emb_clap// vs. 
real_emb_clap//\n", 231 | "searching in real_emb_clap/\n", 232 | "searching in real_emb_clap/\n", 233 | "torch.Size([256, 512]) torch.Size([256, 512])\n", 234 | "mu_real.shape, sigma_real.shape = torch.Size([512]) torch.Size([512, 512])\n", 235 | "mu_fake.shape, sigma_fake.shape = torch.Size([512]) torch.Size([512, 512])\n", 236 | "mu_diff = tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 237 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 238 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 239 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 240 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 241 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 242 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 243 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 244 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 245 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 246 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 247 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 248 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 249 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 250 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 251 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 252 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 253 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 254 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 255 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 256 | " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", 257 | " 0., 0., 0., 0., 0., 0., 0., 0.])\n", 258 | "score1: mu_diff.dot(mu_diff) = tensor(0.)\n", 259 | "score2: torch.trace(sigma_real) = tensor(0.4448)\n", 260 | "score3: torch.trace(sigma_fake) = tensor(0.4448)\n", 261 | "score_p.shape (matmul) = torch.Size([512, 512])\n", 262 | "score4 (-2*tr(sqrtm(matmul(sigma_r sigma_f)))) = tensor(-0.8888)\n", 263 | "tensor(0.0008)\n" 264 | ] 265 | } 266 | ], 267 | "source": [ 268 | "#| eval: false\n", 269 | "score = calc_score( 'real_emb_clap/', 'real_emb_clap/', method='maji', debug=True)\n", 270 | "print(score)" 271 | ] 272 | }, 273 | { 274 | "cell_type": "markdown", 275 | "metadata": {}, 276 | "source": [ 277 | "Ok, so not zero, but small." 
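, "\n\nThat residual is expected with `method='maji'`: the Newton-Schulz iteration is only approximate, so $-2\\,tr\\sqrt{\\Sigma_r \\Sigma_r}$ (-0.8888 above) does not quite cancel $2\\,tr(\\Sigma_r)$ (0.8896); the exact (but CPU-bound) `method='li'` should land still closer to zero."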
278 | ] 279 | }, 280 | { 281 | "cell_type": "code", 282 | "execution_count": null, 283 | "metadata": {}, 284 | "outputs": [], 285 | "source": [ 286 | "#| export\n", 287 | "def main(): \n", 288 | " parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)\n", 289 | " parser.add_argument('real_emb_path', help='Path of files of embeddings of real data', default='real_emb_clap/')\n", 290 | " parser.add_argument('fake_emb_path', help='Path of files of embeddings of fake data', default='fake_emb_clap/')\n", 291 | " parser.add_argument('-d','--debug', action='store_true', help='Enable debugging')\n", 292 | " parser.add_argument('-m','--method', default='maji', help='Method for sqrtm calculation: \"maji\" or \"li\" ')\n", 293 | "\n", 294 | " args = parser.parse_args()\n", 295 | " score = calc_score( args.real_emb_path, args.fake_emb_path, method=args.method, debug=args.debug )\n", 296 | " print(\"FAD score = \",score.cpu().numpy())" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "#| export\n", 306 | "if __name__ == '__main__' and \"get_ipython\" not in dir():\n", 307 | " main()" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": null, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "#| hide\n", 317 | "import nbdev; nbdev.nbdev_export()" 318 | ] 319 | }, 320 | { 321 | "cell_type": "code", 322 | "execution_count": null, 323 | "metadata": {}, 324 | "outputs": [], 325 | "source": [] 326 | } 327 | ], 328 | "metadata": { 329 | "kernelspec": { 330 | "display_name": "aa", 331 | "language": "python", 332 | "name": "aa" 333 | } 334 | }, 335 | "nbformat": 4, 336 | "nbformat_minor": 4 337 | } 338 | -------------------------------------------------------------------------------- /nbs/04_sqrtm.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stdout", 10 | "output_type": "stream", 11 | "text": [ 12 | "The autoreload extension is already loaded. 
To reload it, use:\n", 13 | " %reload_ext autoreload\n" 14 | ] 15 | } 16 | ], 17 | "source": [ 18 | "#| hide\n", 19 | "%load_ext autoreload\n", 20 | "%autoreload 2" 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "# sqrtm\n", 28 | "\n", 29 | "> Methods for computing sqrt of a matrix" 30 | ] 31 | }, 32 | { 33 | "cell_type": "code", 34 | "execution_count": null, 35 | "metadata": {}, 36 | "outputs": [], 37 | "source": [ 38 | "#| default_exp sqrtm" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": null, 44 | "metadata": {}, 45 | "outputs": [], 46 | "source": [ 47 | "#| hide\n", 48 | "from nbdev.showdoc import *" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": null, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "#| export\n", 58 | "import torch\n", 59 | "from torch.autograd import Function, Variable" 60 | ] 61 | }, 62 | { 63 | "cell_type": "markdown", 64 | "metadata": {}, 65 | "source": [ 66 | "## Steve Li's method" 67 | ] 68 | }, 69 | { 70 | "cell_type": "code", 71 | "execution_count": null, 72 | "metadata": {}, 73 | "outputs": [], 74 | "source": [ 75 | "#| export\n", 76 | "\n", 77 | "use_li = True # come back and turn this on if you want to see/use the full code\n", 78 | "\n", 79 | "if use_li: # lighten the load of imports since we won't use li's in production\n", 80 | " import numpy as np\n", 81 | " import scipy.linalg" 82 | ] 83 | }, 84 | { 85 | "cell_type": "code", 86 | "execution_count": null, 87 | "metadata": {}, 88 | "outputs": [], 89 | "source": [ 90 | "#| export\n", 91 | "class MatrixSquareRoot_li(Function):\n", 92 | " \"\"\"\n", 93 | " From https://github.com/steveli/pytorch-sqrtm/blob/master/sqrtm.py, which sadly does not install as a package. 
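It wraps scipy.linalg.sqrtm in a torch.autograd.Function, so every call round-trips through CPU/NumPy: exact, but slower and unbatched compared with the Newton-Schulz approach below.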
LICENSE included below\n", 94 | " Square root of a positive definite matrix.\n", 95 | "\n", 96 | " NOTE: matrix square root is not differentiable for matrices with\n", 97 | " zero eigenvalues.\n", 98 | " \"\"\"\n", 99 | " @staticmethod\n", 100 | " def forward(ctx, input):\n", 101 | " m = input.detach().cpu().numpy().astype(np.float_) # SHH: note how this immediately switches to CPU & numpy :-( \n", 102 | " sqrtm = torch.from_numpy(scipy.linalg.sqrtm(m).real).to(input)\n", 103 | " ctx.save_for_backward(sqrtm)\n", 104 | " return sqrtm\n", 105 | "\n", 106 | " @staticmethod\n", 107 | " def backward(ctx, grad_output):\n", 108 | " grad_input = None\n", 109 | " if ctx.needs_input_grad[0]:\n", 110 | " sqrtm, = ctx.saved_tensors\n", 111 | " sqrtm = sqrtm.data.cpu().numpy().astype(np.float_)\n", 112 | " gm = grad_output.data.cpu().numpy().astype(np.float_)\n", 113 | "\n", 114 | " # Given a positive semi-definite matrix X,\n", 115 | " # since X = X^{1/2}X^{1/2}, we can compute the gradient of the\n", 116 | " # matrix square root dX^{1/2} by solving the Sylvester equation:\n", 117 | " # dX = (d(X^{1/2})X^{1/2} + X^{1/2}(dX^{1/2}).\n", 118 | " grad_sqrtm = scipy.linalg.solve_sylvester(sqrtm, sqrtm, gm)\n", 119 | "\n", 120 | " grad_input = torch.from_numpy(grad_sqrtm).to(grad_output)\n", 121 | " return grad_input\n", 122 | "\n", 123 | "\n", 124 | "sqrtm_li = MatrixSquareRoot_li.apply\n", 125 | "\n", 126 | "\n", 127 | "LICENCE_LI = \"\"\"\n", 128 | "MIT License\n", 129 | "\n", 130 | "Copyright (c) 2022 Steven Cheng-Xian Li\n", 131 | "\n", 132 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 133 | "of this software and associated documentation files (the \"Software\"), to deal\n", 134 | "in the Software without restriction, including without limitation the rights\n", 135 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 136 | "copies of the Software, and to permit persons to whom the Software is\n", 137 | "furnished to do so, subject to the following conditions:\n", 138 | "\n", 139 | "The above copyright notice and this permission notice shall be included in all\n", 140 | "copies or substantial portions of the Software.\n", 141 | "\n", 142 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 143 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 144 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE\n", 145 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 146 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 147 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 148 | "SOFTWARE.\n", 149 | "\"\"\"" 150 | ] 151 | }, 152 | { 153 | "cell_type": "markdown", 154 | "metadata": {}, 155 | "source": [ 156 | "Steve Li's test code for the above:\n" 157 | ] 158 | }, 159 | { 160 | "cell_type": "code", 161 | "execution_count": null, 162 | "metadata": {}, 163 | "outputs": [], 164 | "source": [ 165 | "#| eval: false \n", 166 | "from torch.autograd import gradcheck" 167 | ] 168 | }, 169 | { 170 | "cell_type": "code", 171 | "execution_count": null, 172 | "metadata": {}, 173 | "outputs": [ 174 | { 175 | "name": "stdout", 176 | "output_type": "stream", 177 | "text": [ 178 | "sq =\n", 179 | " tensor([[ 2.6360e+01, -5.2896e-01, 4.6020e-01, ..., 4.6385e-01,\n", 180 | " -2.5534e-01, 3.3804e-01],\n", 181 | " [-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6621e-02,\n", 182 | " 3.0918e-01, -7.8089e-02],\n", 183 | " [ 4.6020e-01, -7.2415e-01, 2.5863e+01, ..., -5.2346e-01,\n", 184 | " 1.4617e-01, -2.5943e-01],\n", 185 | " ...,\n", 186 | " [ 4.6385e-01, -2.6621e-02, -5.2346e-01, ..., 2.6959e+01,\n", 187 | " 3.6158e-01, -4.6653e-01],\n", 188 | " [-2.5534e-01, 3.0918e-01, 1.4617e-01, ..., 3.6158e-01,\n", 189 | " 2.6692e+01, -4.4417e-01],\n", 190 | " [ 3.3804e-01, -7.8089e-02, -2.5943e-01, ..., -4.6653e-01,\n", 191 | " -4.4417e-01, 2.8916e+01]], dtype=torch.float64)\n" 192 | ] 193 | } 194 | ], 195 | "source": [ 196 | "#| eval: false \n", 197 | "if use_li:\n", 198 | " k = torch.randn(1000, 1000).double()\n", 199 | " # Create a positive definite matrix\n", 200 | " pd_mat = (k.t().matmul(k)).requires_grad_()\n", 201 | " with torch.no_grad():\n", 202 | " sq = sqrtm_li(pd_mat)\n", 203 | " print(\"sq =\\n\",sq)\n", 204 | " #print(\"Running gradcheck...\")\n", 205 | " #test = gradcheck(sqrtm_li, (pd_mat,))\n", 206 | " #print(test)" 207 | ] 208 | }, 209 | { 210 | "cell_type": "markdown", 211 | "metadata": {}, 212 | "source": [ 213 | "## Subhransu Maji's method(s)\n", 214 | "\n", 215 | "From https://github.com/msubhransu/matrix-sqrt" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": null, 221 | "metadata": {}, 222 | "outputs": [], 223 | "source": [ 224 | "#| export\n", 225 | "\n", 226 | "# Compute error\n", 227 | "def compute_error(A, sA):\n", 228 | " normA = torch.sqrt(torch.sum(torch.sum(A * A, dim=1),dim=1))\n", 229 | " error = A - torch.bmm(sA, sA)\n", 230 | " error = torch.sqrt((error * error).sum(dim=1).sum(dim=1)) / normA\n", 231 | " return torch.mean(error)\n", 232 | "\n", 233 | "\n", 234 | "def sqrt_newton_schulz_autograd(A, \n", 235 | " numIters=20, # found experimentally by SHH, comparing w/ Li's method\n", 236 | " calc_error=False,):\n", 237 | " \"\"\"Modified from from https://people.cs.umass.edu/~smaji/projects/matrix-sqrt/\n", 238 | " \"The drawback of the autograd approach [i.e., this approach] is that a naive implementation stores all the intermediate results. 
\n", 239 | " Thus the memory overhead scales linearly with the number of iterations which is problematic for large matrices.\"\n", 240 | " \"\"\"\n", 241 | " if len(A.data.shape) < 3: A = A.unsqueeze(0)\n", 242 | " batchSize, dim1, dim2 = A.data.shape\n", 243 | " assert dim1==dim2\n", 244 | " dim, dtype, device = dim1, A.dtype, A.device\n", 245 | " normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()\n", 246 | " Y = A.div(normA.view(batchSize, 1, 1).expand_as(A));\n", 247 | " I = Variable(torch.eye(dim,dim, device=device).view(1, dim, dim).\n", 248 | " repeat(batchSize,1,1).type(dtype),requires_grad=False)\n", 249 | " Z = Variable(torch.eye(dim,dim, device=device).view(1, dim, dim).\n", 250 | " repeat(batchSize,1,1).type(dtype),requires_grad=False)\n", 251 | "\n", 252 | " for i in range(numIters):\n", 253 | " T = 0.5*(3.0*I - Z.bmm(Y))\n", 254 | " Y = Y.bmm(T)\n", 255 | " Z = T.bmm(Z)\n", 256 | " \n", 257 | " sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)\n", 258 | " if calc_error:\n", 259 | " error = compute_error(A, sA)\n", 260 | " return sA, error\n", 261 | " return sA\n", 262 | "\n", 263 | "\n", 264 | "def sqrt_newton_schulz(A, # matrix to be sqrt-ified\n", 265 | " numIters=20, # numIters=7 found via experimentation\n", 266 | " calc_error=False, # setting False disables Maji's error reporting\n", 267 | " ):\n", 268 | " \"\"\"\n", 269 | " Sqrt of matrix via Newton-Schulz algorithm\n", 270 | " Modified from https://github.com/msubhransu/matrix-sqrt/blob/cc2289a3ed7042b8dbacd53ce8a34da1f814ed2f/matrix_sqrt.py#LL72C1-L87C19\n", 271 | " # Forward via Newton-Schulz iterations (non autograd version)\n", 272 | " # Seems to be slighlty faster and has much lower memory overhead\n", 273 | " \n", 274 | " ... Original code didn't preserve device, had no batch dim checking -SHH\n", 275 | " \"\"\"\n", 276 | " while len(A.data.shape) < 3: # needs a batch dimension\n", 277 | " A = A.unsqueeze(0)\n", 278 | " batchSize, dim1, dim2 = A.data.shape\n", 279 | " assert dim1==dim2\n", 280 | " dim, dtype, device = dim1, A.dtype, A.device\n", 281 | " normA = A.mul(A).sum(dim=1).sum(dim=1).sqrt()\n", 282 | " Y = A.div(normA.view(batchSize, 1, 1).expand_as(A));\n", 283 | " I = torch.eye(dim,dim, device=device, dtype=dtype).view(1, dim, dim).repeat(batchSize,1,1)\n", 284 | " Z = torch.eye(dim,dim, device=device, dtype=dtype).view(1, dim, dim).repeat(batchSize,1,1)\n", 285 | " for i in range(numIters):\n", 286 | " T = 0.5*(3.0*I - Z.bmm(Y))\n", 287 | " Y = Y.bmm(T)\n", 288 | " Z = T.bmm(Z)\n", 289 | " \n", 290 | " sA = Y*torch.sqrt(normA).view(batchSize, 1, 1).expand_as(A)\n", 291 | " if calc_error: \n", 292 | " error = compute_error(A, sA)\n", 293 | " return sA, error\n", 294 | " else:\n", 295 | " return sA\n", 296 | "\n", 297 | " \n", 298 | "\"\"\" \n", 299 | "# Only used if backprop needed, which it isn't for FAD. Leaving it here anyway. 
-SHH\n", 300 | "def lyap_newton_schulz(z, dldz, numIters, dtype):\n", 301 | " # Backward via iterative Lyapunov solver.\n", 302 | " batchSize = z.shape[0]\n", 303 | " dim = z.shape[1]\n", 304 | " normz = z.mul(z).sum(dim=1).sum(dim=1).sqrt()\n", 305 | " a = z.div(normz.view(batchSize, 1, 1).expand_as(z))\n", 306 | " I = torch.eye(dim,dim).view(1, dim, dim).repeat(batchSize,1,1).type(dtype)\n", 307 | " q = dldz.div(normz.view(batchSize, 1, 1).expand_as(z))\n", 308 | " for i in range(numIters):\n", 309 | " q = 0.5*(q.bmm(3.0*I - a.bmm(a)) - a.transpose(1, 2).bmm(a.transpose(1,2).bmm(q) - q.bmm(a)) )\n", 310 | " a = 0.5*a.bmm(3.0*I - a.bmm(a))\n", 311 | " dlda = 0.5*q\n", 312 | " return dlda\n", 313 | "\"\"\"\n", 314 | "\n", 315 | "\n", 316 | "LICENCE_SM = \"\"\"\n", 317 | "MIT License\n", 318 | "\n", 319 | "Copyright (c) 2017 Subhransu Maji\n", 320 | "\n", 321 | "Permission is hereby granted, free of charge, to any person obtaining a copy\n", 322 | "of this software and associated documentation files (the \"Software\"), to deal\n", 323 | "in the Software without restriction, including without limitation the rights\n", 324 | "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n", 325 | "copies of the Software, and to permit persons to whom the Software is\n", 326 | "furnished to do so, subject to the following conditions:\n", 327 | "\n", 328 | "The above copyright notice and this permission notice shall be included in all\n", 329 | "copies or substantial portions of the Software.\n", 330 | "\n", 331 | "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n", 332 | "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n", 333 | "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n", 334 | "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n", 335 | "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n", 336 | "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n", 337 | "SOFTWARE.\n", 338 | "\"\"\"" 339 | ] 340 | }, 341 | { 342 | "cell_type": "markdown", 343 | "metadata": {}, 344 | "source": [ 345 | "## Error tests" 346 | ] 347 | }, 348 | { 349 | "cell_type": "code", 350 | "execution_count": null, 351 | "metadata": {}, 352 | "outputs": [ 353 | { 354 | "name": "stdout", 355 | "output_type": "stream", 356 | "text": [ 357 | "sa1 =\n", 358 | " tensor([[[ 2.6360e+01, -5.2896e-01, 4.6021e-01, ..., 4.6384e-01,\n", 359 | " -2.5535e-01, 3.3805e-01],\n", 360 | " [-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6632e-02,\n", 361 | " 3.0917e-01, -7.8092e-02],\n", 362 | " [ 4.6021e-01, -7.2415e-01, 2.5863e+01, ..., -5.2344e-01,\n", 363 | " 1.4618e-01, -2.5943e-01],\n", 364 | " ...,\n", 365 | " [ 4.6384e-01, -2.6632e-02, -5.2344e-01, ..., 2.6959e+01,\n", 366 | " 3.6150e-01, -4.6653e-01],\n", 367 | " [-2.5535e-01, 3.0917e-01, 1.4618e-01, ..., 3.6150e-01,\n", 368 | " 2.6692e+01, -4.4418e-01],\n", 369 | " [ 3.3805e-01, -7.8092e-02, -2.5943e-01, ..., -4.6653e-01,\n", 370 | " -4.4418e-01, 2.8916e+01]]], dtype=torch.float64,\n", 371 | " grad_fn=<MulBackward0>)\n", 372 | "error = 4.759428080865442e-08\n" 373 | ] 374 | } 375 | ], 376 | "source": [ 377 | "#| eval: false \n", 378 | "sa1, error = sqrt_newton_schulz_autograd( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) \n", 379 | "print(\"sa1 =\\n\",sa1)\n", 380 | "print(\"error =\",error.detach().item())" 381 | ] 382 | }, 383 | { 384 | "cell_type": "code", 385 | "execution_count": null, 386 | "metadata": {}, 387 
| "outputs": [ 388 | { 389 | "name": "stdout", 390 | "output_type": "stream", 391 | "text": [ 392 | "sa2 =\n", 393 | " tensor([[[ 2.6360e+01, -5.2896e-01, 4.6021e-01, ..., 4.6384e-01,\n", 394 | " -2.5535e-01, 3.3805e-01],\n", 395 | " [-5.2896e-01, 2.5773e+01, -7.2415e-01, ..., -2.6632e-02,\n", 396 | " 3.0917e-01, -7.8092e-02],\n", 397 | " [ 4.6021e-01, -7.2415e-01, 2.5863e+01, ..., -5.2344e-01,\n", 398 | " 1.4618e-01, -2.5943e-01],\n", 399 | " ...,\n", 400 | " [ 4.6384e-01, -2.6632e-02, -5.2344e-01, ..., 2.6959e+01,\n", 401 | " 3.6150e-01, -4.6653e-01],\n", 402 | " [-2.5535e-01, 3.0917e-01, 1.4618e-01, ..., 3.6150e-01,\n", 403 | " 2.6692e+01, -4.4418e-01],\n", 404 | " [ 3.3805e-01, -7.8092e-02, -2.5943e-01, ..., -4.6653e-01,\n", 405 | " -4.4418e-01, 2.8916e+01]]], dtype=torch.float64,\n", 406 | " grad_fn=)\n", 407 | "error = 4.759428080865442e-08\n" 408 | ] 409 | } 410 | ], 411 | "source": [ 412 | "#| eval: false \n", 413 | "sa2, error = sqrt_newton_schulz( pd_mat.unsqueeze(0), numIters=20, calc_error=True ) \n", 414 | "print(\"sa2 =\\n\",sa2)\n", 415 | "print(\"error =\",error.detach().item())" 416 | ] 417 | }, 418 | { 419 | "cell_type": "code", 420 | "execution_count": null, 421 | "metadata": {}, 422 | "outputs": [ 423 | { 424 | "name": "stdout", 425 | "output_type": "stream", 426 | "text": [ 427 | "diff = \n", 428 | " tensor([[[-3.7835e-05, 3.8437e-07, 8.7208e-06, ..., -1.2816e-05,\n", 429 | " -1.8744e-05, 1.3386e-05],\n", 430 | " [ 3.8437e-07, -1.9896e-06, 2.2977e-06, ..., -1.1004e-05,\n", 431 | " -1.0735e-05, -3.3250e-06],\n", 432 | " [ 8.7208e-06, 2.2977e-06, -7.5329e-06, ..., 2.2518e-05,\n", 433 | " 1.9394e-05, -3.9655e-06],\n", 434 | " ...,\n", 435 | " [-1.2816e-05, -1.1004e-05, 2.2518e-05, ..., -8.1416e-05,\n", 436 | " -7.1274e-05, -1.4799e-06],\n", 437 | " [-1.8744e-05, -1.0735e-05, 1.9394e-05, ..., -7.1274e-05,\n", 438 | " -8.0163e-05, -1.6137e-05],\n", 439 | " [ 1.3386e-05, -3.3250e-06, -3.9655e-06, ..., -1.4799e-06,\n", 440 | " -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,\n", 441 | " grad_fn=)\n" 442 | ] 443 | } 444 | ], 445 | "source": [ 446 | "#| eval: false \n", 447 | "if use_li:\n", 448 | " diff = sa1 - sq\n", 449 | " print(\"diff = \\n\",diff) " 450 | ] 451 | }, 452 | { 453 | "cell_type": "code", 454 | "execution_count": null, 455 | "metadata": {}, 456 | "outputs": [ 457 | { 458 | "name": "stdout", 459 | "output_type": "stream", 460 | "text": [ 461 | "diff = \n", 462 | " tensor([[[-3.7835e-05, 3.8437e-07, 8.7208e-06, ..., -1.2816e-05,\n", 463 | " -1.8744e-05, 1.3386e-05],\n", 464 | " [ 3.8437e-07, -1.9896e-06, 2.2977e-06, ..., -1.1004e-05,\n", 465 | " -1.0735e-05, -3.3250e-06],\n", 466 | " [ 8.7208e-06, 2.2977e-06, -7.5329e-06, ..., 2.2518e-05,\n", 467 | " 1.9394e-05, -3.9655e-06],\n", 468 | " ...,\n", 469 | " [-1.2816e-05, -1.1004e-05, 2.2518e-05, ..., -8.1416e-05,\n", 470 | " -7.1274e-05, -1.4799e-06],\n", 471 | " [-1.8744e-05, -1.0735e-05, 1.9394e-05, ..., -7.1274e-05,\n", 472 | " -8.0163e-05, -1.6137e-05],\n", 473 | " [ 1.3386e-05, -3.3250e-06, -3.9655e-06, ..., -1.4799e-06,\n", 474 | " -1.6137e-05, -2.6324e-05]]], dtype=torch.float64,\n", 475 | " grad_fn=)\n" 476 | ] 477 | } 478 | ], 479 | "source": [ 480 | "#| eval: false \n", 481 | "if use_li:\n", 482 | " diff = sa2 - sq\n", 483 | " print(\"diff = \\n\",diff) " 484 | ] 485 | }, 486 | { 487 | "cell_type": "markdown", 488 | "metadata": {}, 489 | "source": [ 490 | "## Speed & device tests" 491 | ] 492 | }, 493 | { 494 | "cell_type": "code", 495 | "execution_count": null, 496 | "metadata": {}, 497 | "outputs": [], 
498 | "source": [ 499 | "#| eval: false \n", 500 | "from aeiou.core import get_device" 501 | ] 502 | }, 503 | { 504 | "cell_type": "code", 505 | "execution_count": null, 506 | "metadata": {}, 507 | "outputs": [ 508 | { 509 | "name": "stdout", 510 | "output_type": "stream", 511 | "text": [ 512 | "device = cuda\n" 513 | ] 514 | } 515 | ], 516 | "source": [ 517 | "#| eval: false \n", 518 | "device = get_device()\n", 519 | "print('device = ',device)\n", 520 | "n,m = 1000, 1000\n", 521 | "with torch.no_grad(): \n", 522 | " k = torch.randn(n, m, device=device)\n", 523 | " pd_mat2 = (k.t().matmul(k)) # Create a positive definite matrix, no grad" 524 | ] 525 | }, 526 | { 527 | "cell_type": "code", 528 | "execution_count": null, 529 | "metadata": {}, 530 | "outputs": [], 531 | "source": [ 532 | "#| eval: false \n", 533 | "# %%timeit\n", 534 | "if use_li:\n", 535 | " sq2 = sqrtm_li(pd_mat2)" 536 | ] 537 | }, 538 | { 539 | "cell_type": "markdown", 540 | "metadata": {}, 541 | "source": [ 542 | "Result of `%%timeit`:\n", 543 | "\n", 544 | "`1.12 s ± 191 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)`" 545 | ] 546 | }, 547 | { 548 | "cell_type": "code", 549 | "execution_count": null, 550 | "metadata": {}, 551 | "outputs": [ 552 | { 553 | "ename": "NameError", 554 | "evalue": "name 'sq2' is not defined", 555 | "output_type": "error", 556 | "traceback": [ 557 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 558 | "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", 559 | "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;66;03m#| eval: false \u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[43msq2\u001b[49m)\n", 560 | "\u001b[0;31mNameError\u001b[0m: name 'sq2' is not defined" 561 | ] 562 | } 563 | ], 564 | "source": [ 565 | "#| eval: false \n", 566 | "print(sq2)" 567 | ] 568 | }, 569 | { 570 | "cell_type": "code", 571 | "execution_count": null, 572 | "metadata": {}, 573 | "outputs": [], 574 | "source": [ 575 | "#| eval: false \n", 576 | "# %%timeit\n", 577 | "sq3 = sqrt_newton_schulz(pd_mat2.unsqueeze(0), numIters=20)[0]" 578 | ] 579 | }, 580 | { 581 | "cell_type": "markdown", 582 | "metadata": {}, 583 | "source": [ 584 | "Result of `%%timeit`:\n", 585 | "\n", 586 | "`8.8 ms ± 23.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)`" 587 | ] 588 | }, 589 | { 590 | "cell_type": "markdown", 591 | "metadata": {}, 592 | "source": [ 593 | "## Wrapper around our method of choice:\n", 594 | "TLDR, we'll use Maji's Newton-Schulz method. Newton-Schulz is an approximate iterative method rather than an exact matrix sqrt, however, with 7 iterations the error is below 1e-5, (presumably significantly) lower than other errors in the problem. " 595 | ] 596 | }, 597 | { 598 | "cell_type": "code", 599 | "execution_count": null, 600 | "metadata": {}, 601 | "outputs": [], 602 | "source": [ 603 | "#|export \n", 604 | "def sqrtm(A, method='maji', numIters=20):\n", 605 | " \"wrapper function for matrix sqrt algorithm of choice. 
Also we'll turn off all gradients\"\n", 606 | " with torch.no_grad():\n", 607 | " if method=='maji':\n", 608 | " return sqrt_newton_schulz(A, numIters=numIters, calc_error=False).squeeze() # get rid of any useless batch dimensions\n", 609 | " elif method=='li': \n", 610 | " return sqrtm_li(A)\n", 611 | " else:\n", 612 | " raise ValueError(f\"Invalid method: {method}\") " 613 | ] 614 | }, 615 | { 616 | "cell_type": "code", 617 | "execution_count": null, 618 | "metadata": {}, 619 | "outputs": [ 620 | { 621 | "data": { 622 | "text/plain": [ 623 | "tensor([[ 1.9073e-06, 1.5616e-05, 9.0003e-06, ..., -8.9407e-07,\n", 624 | " -3.7104e-05, 6.7353e-06],\n", 625 | " [ 1.9193e-05, -6.1035e-05, -4.7386e-06, ..., 7.8678e-06,\n", 626 | " 2.9802e-05, -5.8711e-06],\n", 627 | " [ 7.5698e-06, -5.4836e-06, -3.8147e-05, ..., 2.0236e-05,\n", 628 | " 1.6876e-06, -2.9024e-05],\n", 629 | " ...,\n", 630 | " [-3.3975e-06, 3.6955e-06, 2.2471e-05, ..., -1.1444e-05,\n", 631 | " 3.6955e-06, 1.2971e-05],\n", 632 | " [-3.6355e-05, 2.5883e-05, 7.5437e-06, ..., 9.4771e-06,\n", 633 | " -6.8665e-05, 6.7949e-06],\n", 634 | " [ 4.2319e-06, -1.3113e-05, -2.6686e-05, ..., 9.9763e-06,\n", 635 | " 1.1921e-05, 5.7220e-06]], device='cuda:0')" 636 | ] 637 | }, 638 | "execution_count": null, 639 | "metadata": {}, 640 | "output_type": "execute_result" 641 | } 642 | ], 643 | "source": [ 644 | "#| eval: false \n", 645 | "sqrtm(pd_mat2, method='maji') - sqrtm(pd_mat2, method='li') " 646 | ] 647 | }, 648 | { 649 | "cell_type": "code", 650 | "execution_count": null, 651 | "metadata": {}, 652 | "outputs": [], 653 | "source": [ 654 | "#| hide\n", 655 | "import nbdev; nbdev.nbdev_export()" 656 | ] 657 | }, 658 | { 659 | "cell_type": "code", 660 | "execution_count": null, 661 | "metadata": {}, 662 | "outputs": [], 663 | "source": [] 664 | } 665 | ], 666 | "metadata": { 667 | "kernelspec": { 668 | "display_name": "aa", 669 | "language": "python", 670 | "name": "aa" 671 | } 672 | }, 673 | "nbformat": 4, 674 | "nbformat_minor": 4 675 | } 676 | -------------------------------------------------------------------------------- /nbs/_quarto.yml: -------------------------------------------------------------------------------- 1 | project: 2 | type: website 3 | 4 | format: 5 | html: 6 | theme: cosmo 7 | css: styles.css 8 | toc: true 9 | 10 | website: 11 | twitter-card: true 12 | open-graph: true 13 | repo-actions: [issue] 14 | navbar: 15 | background: primary 16 | search: true 17 | sidebar: 18 | style: floating 19 | 20 | metadata-files: [nbdev.yml, sidebar.yml] -------------------------------------------------------------------------------- /nbs/index.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "#| hide\n", 10 | "#from fad_pytorch.core import *" 11 | ] 12 | }, 13 | { 14 | "cell_type": "markdown", 15 | "metadata": {}, 16 | "source": [ 17 | "# fad_pytorch\n", 18 | "\n", 19 | "> Frechet Audio Distance evaluation in PyTorch" 20 | ] 21 | }, 22 | { 23 | "cell_type": "markdown", 24 | "metadata": {}, 25 | "source": [ 26 | "[Original FAD paper (PDF)](https://arxiv.org/pdf/1812.08466.pdf)" 27 | ] 28 | }, 29 | { 30 | "cell_type": "markdown", 31 | "metadata": {}, 32 | "source": [ 33 | "## Install" 34 | ] 35 | }, 36 | { 37 | "cell_type": "markdown", 38 | "metadata": {}, 39 | "source": [ 40 | "```sh\n", 41 | "pip install fad_pytorch\n", 42 | "```" 43 | ] 44 | }, 45 | { 46 | "cell_type": 
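"markdown", "metadata": {}, "source": [ "For orientation, the quantity everything below computes is the Fréchet Audio Distance between two sets of embeddings. Here is a minimal sketch (illustrative only; not the CLI, and not necessarily how `fad_score` implements it internally), assuming you already have two `[N, D]` embedding tensors and the nbdev export layout shown in this repo:\n", "\n", "```python\n", "import torch\n", "from fad_pytorch.sqrtm import sqrtm  # Newton-Schulz matrix sqrt (see 04_sqrtm)\n", "\n", "def fad(emb_r, emb_f):  # hypothetical helper; emb_*: [N, D] embedding tensors\n", "    mu_r, mu_f = emb_r.mean(dim=0), emb_f.mean(dim=0)\n", "    sigma_r, sigma_f = torch.cov(emb_r.T), torch.cov(emb_f.T)  # [D, D] covariances\n", "    diff = mu_r - mu_f\n", "    covmean = sqrtm(sigma_r @ sigma_f)  # approximate sqrt of the product\n", "    return (diff @ diff + torch.trace(sigma_r + sigma_f - 2.0*covmean)).item()\n", "```" ] }, { "cell_type": 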
"markdown", 47 | "metadata": {}, 48 | "source": [ 49 | "## Features:\n", 50 | "\n", 51 | "- runs in parallel on multiple processors and multiple GPUs (via `accelerate`)\n", 52 | "- supports multiple embedding methods:\n", 53 | " - VGGish and PANN, both mono @ 16kHz\n", 54 | " - OpenL3 and (LAION-)CLAP, stereo @ 48kHz\n", 55 | "- uses publicly-available pretrained checkpoints for music (+other sources) for those models. (if you want Speech, submit a PR or an Issue; I don't do speech.)\n", 56 | "- favors ops in PyTorch rather than numpy (or tensorflow)\n", 57 | "- `fad_gen` supports local data read or WebDataset (audio data stored in S3 buckets)\n", 58 | "- runs on CPU, CUDA, or MPS \n", 59 | "\n", 60 | "## Instructions:\n", 61 | "\n", 62 | "This is designed to be run as 3 command-line scripts in succession. The latter 2 (`fad_embed` and `fad_score`) are probably what most people will want:\n", 63 | "\n", 64 | "1. `fad_gen`: produces directories of real & fake audio (given real data). See `fad_gen` [documentation](https://drscotthawley.github.io/fad_pytorch/fad_gen.html) for calling sequence.\n", 65 | "2. `fad_embed [options] `: produces directories of *embeddings* of real & fake audio\n", 66 | "3. `fad_score [options] `: reads the embeddings & generates FAD score, for real (\"$r$\") and fake (\"$f$\"): \n", 67 | "\n", 68 | "$$ FAD = || \\mu_r - \\mu_f ||^2 + tr\\left(\\Sigma_r + \\Sigma_f - 2 \\sqrt{\\Sigma_r \\Sigma_f}\\right)$$\n", 69 | "\n", 70 | "## Documentation\n", 71 | "See the [Documentation Website](https://drscotthawley.github.io/fad_pytorch/). \n" 72 | ] 73 | }, 74 | { 75 | "cell_type": "markdown", 76 | "metadata": {}, 77 | "source": [ 78 | "## Comments / FAQ / Troubleshooting\n", 79 | "\n", 80 | "- \"`RuntimeError: CUDA error: invalid device ordinal`\": This happens when you have a \"bad node\" on an AWS cluster. [Haven't yet figured out what causes it or how to fix it](https://discuss.huggingface.co/t/solved-accelerate-accelerator-cuda-error-invalid-device-ordinal/21509/1). Workaround: Just add the current node to your SLURM `--exclude` list, exit and retry. Note: it may take as many as 5 to 7 retries before you get a \"good node\". \n", 81 | "- \"FAD scores obtained from different embedding methods are *wildly* different!\" ...Yea. It's not obvious that scores from different embedding methods should be comparable. Rather, compare different groups of audio files using the same embedding method, and/or check that FAD scores go *down* as similarity improves.\n", 82 | "- \"FAD score for the same dataset repeated (twice) is not exactly zero!\" ...Yea. There seems to be an uncertainty of around +/- 0.008. I'd say, don't quote any numbers past the first decimal point. " 83 | ] 84 | }, 85 | { 86 | "cell_type": "markdown", 87 | "metadata": {}, 88 | "source": [ 89 | "## Contributing\n", 90 | "This repo is still fairly \"bare bones\" and will benefit from more documentation and features as time goes on. Note that it is written using [nbdev](https://nbdev.fast.ai/), so the things to do are:\n", 91 | "\n", 92 | "1. Fork this repo\n", 93 | "1. Clone your fork to your (local) machine \n", 94 | "1. Install nbdev: `python3 -m pip install -U nbdev`\n", 95 | "1. Make changes by editing the notebooks in `nbs/`, not the `.py` files in `fad_pytorch/`. \n", 96 | "1. Run `nbdev_export` to export notebook changes to `.py` files \n", 97 | "1. For good measure, run `nbdev_install_hooks` and `nbdev_clean` - especially if you've *added* any notebooks. \n", 98 | "1. 
Do a `git status` to see all the `.ipynb` and `.py` files that need to be added & committed\n", 99 | "1. `git add` those files and then `git commit`, and then `git push`\n", 100 | "1. Take a look in your fork's GitHub Actions tab, and see if the \"test\" and \"deploy\" CI runs finish properly (green light) or fail (red light) \n", 101 | "1. Once you get green lights, send in a Pull Request! \n", 102 | "\n", 103 | "*Feel free to ask me for tips with nbdev; it has quite a learning curve. You can also ask on [fast.ai forums](https://forums.fast.ai/) and/or [fast.ai Discord](https://discord.com/channels/689892369998676007/887694559952400424)* " 104 | ] 105 | }, 106 | { 107 | "cell_type": "markdown", 108 | "metadata": {}, 109 | "source": [ 110 | "## Citations / Blame / Disclaimer\n", 111 | "\n", 112 | "This repo is 2 weeks old. I'm not ready for this to be cited in your papers. I'd hate for there to be some mistake I haven't found yet. Perhaps a later version will have citation info. For now, instead, there's:\n", 113 | "\n", 114 | "**Disclaimer:** Results from this repo are still a work in progress. While every effort has been made to test model outputs, the author takes no responsibility for mistakes. If you want to double-check via another source, see \"Related Repos\" below. " 115 | ] 116 | }, 117 | { 118 | "cell_type": "markdown", 119 | "metadata": {}, 120 | "source": [ 121 | "## Related Repos\n", 122 | "There are [several] others, but this one is mine. These repos didn't have all the features I wanted, but I used them for inspiration:\n", 123 | "\n", 124 | "- https://github.com/gudgud96/frechet-audio-distance\n", 125 | "- https://github.com/google-research/google-research/tree/master/frechet_audio_distance: Goes with [Original FAD paper](https://arxiv.org/pdf/1812.08466.pdf)\n", 126 | "- https://github.com/AndreevP/speech_distances" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": null, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [] 135 | } 136 | ], 137 | "metadata": { 138 | "kernelspec": { 139 | "display_name": "aa", 140 | "language": "python", 141 | "name": "aa" 142 | } 143 | }, 144 | "nbformat": 4, 145 | "nbformat_minor": 4 146 | } 147 | -------------------------------------------------------------------------------- /nbs/nbdev.yml: -------------------------------------------------------------------------------- 1 | project: 2 | output-dir: _docs 3 | 4 | website: 5 | title: "fad_pytorch" 6 | site-url: "https://drscotthawley.github.io/fad_pytorch" 7 | description: "Frechet Audio Distance evaluation in PyTorch" 8 | repo-branch: main 9 | repo-url: "https://github.com/drscotthawley/fad_pytorch" 10 | -------------------------------------------------------------------------------- /nbs/styles.css: -------------------------------------------------------------------------------- 1 | .cell { 2 | margin-bottom: 1rem; 3 | } 4 | 5 | .cell > .sourceCode { 6 | margin-bottom: 0; 7 | } 8 | 9 | .cell-output > pre { 10 | margin-bottom: 0; 11 | } 12 | 13 | .cell-output > pre, .cell-output > .sourceCode > pre, .cell-output-stdout > pre { 14 | margin-left: 0.8rem; 15 | margin-top: 0; 16 | background: none; 17 | border-left: 2px solid lightsalmon; 18 | border-top-left-radius: 0; 19 | border-top-right-radius: 0; 20 | } 21 | 22 | .cell-output > .sourceCode { 23 | border: none; 24 | } 25 | 26 | .cell-output > .sourceCode { 27 | background: none; 28 | margin-top: 0; 29 | } 30 | 31 | div.description { 32 | padding-left: 2px; 33 | padding-top: 5px; 34 | font-style: 
italic; 35 | font-size: 135%; 36 | opacity: 70%; 37 | } 38 | -------------------------------------------------------------------------------- /settings.ini: -------------------------------------------------------------------------------- 1 | [DEFAULT] 2 | # All sections below are required unless otherwise specified. 3 | # See https://github.com/fastai/nbdev/blob/master/settings.ini for examples. 4 | 5 | ### Python library ### 6 | repo = fad_pytorch 7 | lib_name = %(repo)s 8 | version = 0.0.6 9 | min_python = 3.7 10 | license = apache2 11 | black_formatting = False 12 | 13 | ### nbdev ### 14 | doc_path = _docs 15 | lib_path = fad_pytorch 16 | nbs_path = nbs 17 | recursive = True 18 | tst_flags = notest 19 | put_version_in_init = True 20 | 21 | ### Docs ### 22 | branch = main 23 | custom_sidebar = False 24 | doc_host = https://%(user)s.github.io 25 | doc_baseurl = /%(repo)s 26 | git_url = https://github.com/%(user)s/%(repo)s 27 | title = %(lib_name)s 28 | 29 | ### PyPI ### 30 | audience = Developers 31 | author = Scott H. Hawley 32 | author_email = scott.hawley@belmont.edu 33 | copyright = 2023 onwards, %(author)s 34 | description = Frechet Audio Distance evaluation in PyTorch 35 | keywords = nbdev jupyter notebook python 36 | language = English 37 | status = 3 38 | user = drscotthawley 39 | 40 | ### Optional ### 41 | requirements = aeiou torch>=1.13.1 torchaudio>=0.13.1 laion-clap accelerate torchlibrosa torchopenl3 42 | # torchopenl3 is turian et al's. pypi won't let us use hugo's: git+https://github.com/hugofloresgarcia/torchopenl3.git 43 | 44 | # dev_requirements = 45 | console_scripts = fad_gen=fad_pytorch.fad_gen:main fad_score=fad_pytorch.fad_score:main fad_embed=fad_pytorch.fad_embed:main -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from pkg_resources import parse_version 2 | from configparser import ConfigParser 3 | import setuptools, shlex 4 | assert parse_version(setuptools.__version__)>=parse_version('36.2') 5 | 6 | # note: all settings are in settings.ini; edit there, not here 7 | config = ConfigParser(delimiters=['=']) 8 | config.read('settings.ini', encoding='utf-8') 9 | cfg = config['DEFAULT'] 10 | 11 | cfg_keys = 'version description keywords author author_email'.split() 12 | expected = cfg_keys + "lib_name user branch license status min_python audience language".split() 13 | for o in expected: assert o in cfg, "missing expected setting: {}".format(o) 14 | setup_cfg = {o:cfg[o] for o in cfg_keys} 15 | 16 | licenses = { 17 | 'apache2': ('Apache Software License 2.0','OSI Approved :: Apache Software License'), 18 | 'mit': ('MIT License', 'OSI Approved :: MIT License'), 19 | 'gpl2': ('GNU General Public License v2', 'OSI Approved :: GNU General Public License v2 (GPLv2)'), 20 | 'gpl3': ('GNU General Public License v3', 'OSI Approved :: GNU General Public License v3 (GPLv3)'), 21 | 'bsd3': ('BSD License', 'OSI Approved :: BSD License'), 22 | } 23 | statuses = [ '1 - Planning', '2 - Pre-Alpha', '3 - Alpha', 24 | '4 - Beta', '5 - Production/Stable', '6 - Mature', '7 - Inactive' ] 25 | py_versions = '3.6 3.7 3.8 3.9 3.10'.split() 26 | 27 | requirements = shlex.split(cfg.get('requirements', '')) 28 | if cfg.get('pip_requirements'): requirements += shlex.split(cfg.get('pip_requirements', '')) 29 | min_python = cfg['min_python'] 30 | lic = licenses.get(cfg['license'].lower(), (cfg['license'], None)) 31 | dev_requirements = (cfg.get('dev_requirements') or 
'').split() 32 | 33 | setuptools.setup( 34 | name = cfg['lib_name'], 35 | license = lic[0], 36 | classifiers = [ 37 | 'Development Status :: ' + statuses[int(cfg['status'])], 38 | 'Intended Audience :: ' + cfg['audience'].title(), 39 | 'Natural Language :: ' + cfg['language'].title(), 40 | ] + ['Programming Language :: Python :: '+o for o in py_versions[py_versions.index(min_python):]] + (['License :: ' + lic[1] ] if lic[1] else []), 41 | url = cfg['git_url'], 42 | packages = setuptools.find_packages(), 43 | include_package_data = True, 44 | install_requires = requirements, 45 | extras_require={ 'dev': dev_requirements }, 46 | dependency_links = cfg.get('dep_links','').split(), 47 | python_requires = '>=' + cfg['min_python'], 48 | long_description = open('README.md', encoding='utf-8').read(), 49 | long_description_content_type = 'text/markdown', 50 | zip_safe = False, 51 | entry_points = { 52 | 'console_scripts': cfg.get('console_scripts','').split(), 53 | 'nbdev': [f'{cfg.get("lib_path")}={cfg.get("lib_path")}._modidx:d'] 54 | }, 55 | **setup_cfg) 56 | 57 | 58 | --------------------------------------------------------------------------------
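The `setup.py` above carries no project metadata of its own; everything (version, requirements, entry points, classifiers) is read from `settings.ini`, per the nbdev convention noted in its comments. A minimal sketch of that mechanism (a hypothetical standalone script, not a file in this repo), run from the repo root:

```python
# Read the same settings.ini that setup.py consumes. Each 'name=module:func'
# token in console_scripts becomes a command-line entry point (fad_gen, etc.).
from configparser import ConfigParser
import shlex

config = ConfigParser(delimiters=['='])
config.read('settings.ini', encoding='utf-8')
cfg = config['DEFAULT']

print(cfg['version'])                            # '0.0.6'
print(shlex.split(cfg.get('requirements', '')))  # -> install_requires list
print(cfg.get('console_scripts', '').split())    # -> CLI entry points
```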