├── .gitignore ├── LICENSE ├── LICENSE-AUDIOREACTIVE ├── LICENSE-AUTOENCODER ├── LICENSE-CONTRASTIVE-LEARNER ├── LICENSE-FID ├── LICENSE-LPIPS ├── LICENSE-LUCIDRAINS ├── LICENSE-NVIDIA ├── LICENSE-ROSINALITY └── LICENSE-VGG ├── README.md ├── accelerate ├── accelerate_inception.py ├── accelerate_logcosh.py └── accelerate_segnet.py ├── audioreactive ├── __init__.py ├── bend.py ├── examples │ ├── Wavefunk - Dwelling in the Kelp.mp3 │ ├── Wavefunk - Tau Ceti Alpha.mp3 │ ├── Wavefunk - Temper.mp3 │ ├── __init__.py │ ├── default.py │ ├── kelp.py │ ├── tauceti.py │ └── temper.py ├── latent.py ├── signal.py └── util.py ├── augment.py ├── contrastive_learner.py ├── convert_weight.py ├── dataset.py ├── distributed.py ├── generate.py ├── generate_audiovisual.py ├── generate_video.py ├── gpu_profile.py ├── gpumon.py ├── lightning.py ├── lookahead_minimax.py ├── lucidrains.py ├── models ├── autoencoder.py ├── stylegan1.py └── stylegan2.py ├── op ├── __init__.py ├── fused_act.py ├── fused_bias_act.cpp ├── fused_bias_act_kernel.cu ├── upfirdn2d.cpp ├── upfirdn2d.py └── upfirdn2d_kernel.cu ├── prepare_data.py ├── prepare_vae_codes.py ├── projector.py ├── render.py ├── requirements.txt ├── select_latents.py ├── train.py ├── train_profile.py ├── validation ├── __init__.py ├── calc_fid.py ├── calc_inception.py ├── calc_ppl.py ├── inception.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ ├── util.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── metrics.py └── spectral_norm.py └── workspace ├── naamloos_average_pitch.npy ├── naamloos_bass_sum.npy ├── naamloos_drop_latents.npy ├── naamloos_drop_latents_1.npy ├── naamloos_high_average_pitch.npy ├── naamloos_high_pitches_mean.npy ├── naamloos_intro_latents.npy ├── naamloos_metadata.json ├── naamloos_onsets.npy ├── naamloos_params.json ├── naamloos_pitches_mean.npy └── naamloos_rms.npy /.gitignore: -------------------------------------------------------------------------------- 1 | pretrained_models/ 2 | wandb 3 | wandb/ 4 | *.lmdb/ 5 | *.pkl 6 | checkpoints/ 7 | maua-stylegan/ 8 | .vscode 9 | output/ 10 | workspace/* 11 | !workspace 12 | output/* 13 | !output 14 | 15 | # Byte-compiled / optimized / DLL files 16 | __pycache__/ 17 | *.py[cod] 18 | *$py.class 19 | 20 | # C extensions 21 | *.so 22 | 23 | # Distribution / packaging 24 | .Python 25 | build/ 26 | develop-eggs/ 27 | dist/ 28 | downloads/ 29 | eggs/ 30 | .eggs/ 31 | lib/ 32 | lib64/ 33 | parts/ 34 | sdist/ 35 | var/ 36 | wheels/ 37 | pip-wheel-metadata/ 38 | share/python-wheels/ 39 | *.egg-info/ 40 | .installed.cfg 41 | *.egg 42 | MANIFEST 43 | 44 | # PyInstaller 45 | # Usually these files are written by a python script from a template 46 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
47 | *.manifest 48 | *.spec 49 | 50 | # Installer logs 51 | pip-log.txt 52 | pip-delete-this-directory.txt 53 | 54 | # Unit test / coverage reports 55 | htmlcov/ 56 | .tox/ 57 | .nox/ 58 | .coverage 59 | .coverage.* 60 | .cache 61 | nosetests.xml 62 | coverage.xml 63 | *.cover 64 | *.py,cover 65 | .hypothesis/ 66 | .pytest_cache/ 67 | 68 | # Translations 69 | *.mo 70 | *.pot 71 | 72 | # Django stuff: 73 | *.log 74 | local_settings.py 75 | db.sqlite3 76 | db.sqlite3-journal 77 | 78 | # Flask stuff: 79 | instance/ 80 | .webassets-cache 81 | 82 | # Scrapy stuff: 83 | .scrapy 84 | 85 | # Sphinx documentation 86 | docs/_build/ 87 | 88 | # PyBuilder 89 | target/ 90 | 91 | # Jupyter Notebook 92 | .ipynb_checkpoints 93 | 94 | # IPython 95 | profile_default/ 96 | ipython_config.py 97 | 98 | # pyenv 99 | .python-version 100 | 101 | # pipenv 102 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 103 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 104 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 105 | # install all needed dependencies. 106 | #Pipfile.lock 107 | 108 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 109 | __pypackages__/ 110 | 111 | # Celery stuff 112 | celerybeat-schedule 113 | celerybeat.pid 114 | 115 | # SageMath parsed files 116 | *.sage.py 117 | 118 | # Environments 119 | .env 120 | .venv 121 | env/ 122 | venv/ 123 | ENV/ 124 | env.bak/ 125 | venv.bak/ 126 | 127 | # Spyder project settings 128 | .spyderproject 129 | .spyproject 130 | 131 | # Rope project settings 132 | .ropeproject 133 | 134 | # mkdocs documentation 135 | /site 136 | 137 | # mypy 138 | .mypy_cache/ 139 | .dmypy.json 140 | dmypy.json 141 | 142 | # Pyre type checker 143 | .pyre/ 144 | 145 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-CONTRASTIVE-LEARNER: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Phil Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-LPIPS: -------------------------------------------------------------------------------- 1 | Copyright (c) 2018, Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang 2 | All rights reserved. 
3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 15 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 16 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 17 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 18 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 19 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 20 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 21 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 22 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 23 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 24 | 25 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-NVIDIA: -------------------------------------------------------------------------------- 1 | Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | 3 | 4 | Nvidia Source Code License-NC 5 | 6 | ======================================================================= 7 | 8 | 1. Definitions 9 | 10 | "Licensor" means any person or entity that distributes its Work. 11 | 12 | "Software" means the original work of authorship made available under 13 | this License. 14 | 15 | "Work" means the Software and any additions to or derivative works of 16 | the Software that are made available under this License. 17 | 18 | "Nvidia Processors" means any central processing unit (CPU), graphics 19 | processing unit (GPU), field-programmable gate array (FPGA), 20 | application-specific integrated circuit (ASIC) or any combination 21 | thereof designed, made, sold, or provided by Nvidia or its affiliates. 22 | 23 | The terms "reproduce," "reproduction," "derivative works," and 24 | "distribution" have the meaning as provided under U.S. copyright law; 25 | provided, however, that for the purposes of this License, derivative 26 | works shall not include works that remain separable from, or merely 27 | link (or bind by name) to the interfaces of, the Work. 28 | 29 | Works, including the Software, are "made available" under this License 30 | by including in or with the Work either (a) a copyright notice 31 | referencing the applicability of this License to the Work, or (b) a 32 | copy of this License. 33 | 34 | 2. License Grants 35 | 36 | 2.1 Copyright Grant. Subject to the terms and conditions of this 37 | License, each Licensor grants to you a perpetual, worldwide, 38 | non-exclusive, royalty-free, copyright license to reproduce, 39 | prepare derivative works of, publicly display, publicly perform, 40 | sublicense and distribute its Work and any resulting derivative 41 | works in any form. 42 | 43 | 3. Limitations 44 | 45 | 3.1 Redistribution. 
You may reproduce or distribute the Work only 46 | if (a) you do so under this License, (b) you include a complete 47 | copy of this License with your distribution, and (c) you retain 48 | without modification any copyright, patent, trademark, or 49 | attribution notices that are present in the Work. 50 | 51 | 3.2 Derivative Works. You may specify that additional or different 52 | terms apply to the use, reproduction, and distribution of your 53 | derivative works of the Work ("Your Terms") only if (a) Your Terms 54 | provide that the use limitation in Section 3.3 applies to your 55 | derivative works, and (b) you identify the specific derivative 56 | works that are subject to Your Terms. Notwithstanding Your Terms, 57 | this License (including the redistribution requirements in Section 58 | 3.1) will continue to apply to the Work itself. 59 | 60 | 3.3 Use Limitation. The Work and any derivative works thereof only 61 | may be used or intended for use non-commercially. The Work or 62 | derivative works thereof may be used or intended for use by Nvidia 63 | or its affiliates commercially or non-commercially. As used herein, 64 | "non-commercially" means for research or evaluation purposes only. 65 | 66 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 67 | against any Licensor (including any claim, cross-claim or 68 | counterclaim in a lawsuit) to enforce any patents that you allege 69 | are infringed by any Work, then your rights under this License from 70 | such Licensor (including the grants in Sections 2.1 and 2.2) will 71 | terminate immediately. 72 | 73 | 3.5 Trademarks. This License does not grant any rights to use any 74 | Licensor's or its affiliates' names, logos, or trademarks, except 75 | as necessary to reproduce the notices described in this License. 76 | 77 | 3.6 Termination. If you violate any term of this License, then your 78 | rights under this License (including the grants in Sections 2.1 and 79 | 2.2) will terminate immediately. 80 | 81 | 4. Disclaimer of Warranty. 82 | 83 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 84 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 85 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 86 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 87 | THIS LICENSE. 88 | 89 | 5. Limitation of Liability. 90 | 91 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 92 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 93 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 94 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 95 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 96 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 97 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 98 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 99 | THE POSSIBILITY OF SUCH DAMAGES. 
100 | 101 | ======================================================================= 102 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-ROSINALITY: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kim Seonghyeon 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /LICENSE/LICENSE-VGG: -------------------------------------------------------------------------------- 1 | Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu. 2 | BSD License. All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution. 13 | 14 | THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL 15 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE. 16 | IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL 17 | DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, 18 | WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING 19 | OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. 20 | 21 | 22 | --------------------------- LICENSE FOR pytorch-CycleGAN-and-pix2pix ---------------- 23 | Copyright (c) 2017, Jun-Yan Zhu and Taesung Park 24 | All rights reserved. 25 | 26 | Redistribution and use in source and binary forms, with or without 27 | modification, are permitted provided that the following conditions are met: 28 | 29 | * Redistributions of source code must retain the above copyright notice, this 30 | list of conditions and the following disclaimer. 31 | 32 | * Redistributions in binary form must reproduce the above copyright notice, 33 | this list of conditions and the following disclaimer in the documentation 34 | and/or other materials provided with the distribution. 
35 | 36 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 37 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 38 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 39 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 40 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 41 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 42 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 43 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 44 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 45 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # maua-stylegan2 2 | 3 | This is the repo for my experiments with StyleGAN2. There are many like it, but this one is mine. 4 | 5 | It contains the code for [Audio-reactive Latent Interpolations with StyleGAN](https://wavefunk.xyz/assets/audio-reactive-stylegan/paper.pdf) for the NeurIPS 2020 [Workshop on Machine Learning for Creativity and Design](https://neurips2020creativity.github.io/). 6 | 7 | The original base is [Kim Seonghyeon's excellent implementation](https://github.com/rosinality/stylegan2-pytorch), but I've gathered code from multiple different repositories or other places online and hacked/grafted it all together. License information for the code should all be in the LICENSE folder, but if you find anything missing or incorrect please let me know and I'll fix it immediately. Tread carefully when trying to distribute any code from this repo, it's meant for research and demonstration. 8 | 9 | The files/folders of interest and their purpose are: 10 | 11 | | File/Folder | Description 12 | | :--- | :---------- 13 | | generate_audiovisual.py | used to generate audio-reactive interpolations 14 | | audioreactive/ | contains the main functions needed for audioreactiveness + examples demonstrating how they can be used 15 | | render.py | renders interpolations using ffmpeg 16 | | select_latents.py | GUI for selecting latents, left click to add to top set, right click to add to bottom 17 | | models/ | StyleGAN networks 18 | | workspace/ | place to store intermediate results, latents, or inputs, etc. 19 | | output/ | default generated output folder 20 | | train.py | code for training models 21 | 22 | The rest of the code is experimental, probably broken, and unsupported. 23 | 24 | ## Installation 25 | 26 | ```bash 27 | git clone https://github.com/JCBrouwer/maua-stylegan2 28 | cd maua-stylegan2 29 | pip install -r requirements.txt 30 | ``` 31 | 32 | Alternatively, check out this [Colab Notebook](https://colab.research.google.com/drive/1Ig1EXfmBC01qik11Q32P0ZffFtNipiBR) 33 | 34 | ## Generating audio-reactive interpolations 35 | 36 | The simplest way to get started is to try either (in shell): 37 | ```bash 38 | python generate_audiovisual.py --ckpt "/path/to/model.pt" --audio_file "/path/to/audio.wav" 39 | ``` 40 | or (in e.g. a jupyter notebook): 41 | ```python 42 | from generate_audiovisual import generate 43 | generate("/path/to/model.pt", "/path/to/audio.wav") 44 | ``` 45 | 46 | This will use the default audio-reactive settings (which aren't great). 
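The defaults are presumably the functions in `audioreactive/examples/default.py`. Before writing custom reactive functions, the optional settings documented in the argument reference further down (e.g. `out_size`, `fps`, `truncation`) are worth adjusting first. A minimal sketch, assuming the keyword arguments to `generate()` mirror the command-line flags (the values below are placeholders):

```python
from generate_audiovisual import generate

generate(
    "/path/to/model.pt",   # --ckpt
    "/path/to/audio.wav",  # --audio_file
    out_size=1024,         # output video size: 512, 1024, or 1920
    fps=30,                # output video framerate
    truncation=1.0,        # constant truncation value
    offset=0,              # starting time in the audio in seconds
    duration=60,           # duration of the interpolation in seconds
)
```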
47 | 48 | To customize the generated interpolation, more functions can be defined to generate latents, noise, network bends, model rewrites, and truncation. 49 | 50 | ```python 51 | import audioreactive as ar 52 | from generate_audiovisual import generate 53 | 54 | def initialize(args): 55 | args.onsets = ar.onsets(args.audio, args.sr, ...) 56 | args.chroma = ar.chroma(args.audio, args.sr, ...) 57 | return args 58 | 59 | def get_latents(selection, args): 60 | latents = ar.chroma_weight_latents(args.chroma, selection) 61 | return latents 62 | 63 | def get_noise(height, width, scale, num_scales, args): 64 | noise = ar.perlin_noise(...) 65 | noise *= 1 + args.onsets 66 | return noise 67 | 68 | generate(ckpt="/path/to/model.pt", audio_file="/path/to/audio.wav", initialize=initialize, get_latents=get_latents, get_noise=get_noise) 69 | ``` 70 | 71 | When running from the command line, the `generate()` call at the end can be left out and the interpolation can be generated with: 72 | 73 | ```bash 74 | python generate_audiovisual.py --ckpt "/path/to/model.pt" --audio_file "/path/to/audio.wav" --audioreactive_file "/path/to/the/code_above.py" 75 | ``` 76 | 77 | This lets you change arguments on the command line rather than having to add them to the `generate()` call in your python file (use whichever you prefer). 78 | 79 | Within these functions, you can execute any python code to make the inputs to the network react to the music. There are a number of useful functions provided in `audioreactive/` (imported above as `ar`). 80 | 81 | Examples showing how to use the library and demonstrating some of the techniques discussed in the paper can be found in `audioreactive/examples/`. A playlist with example results can be found [here](https://www.youtube.com/watch?v=2LxHRGppdpA&list=PLkain1QGMwiWndQwr3U4shvNpoFC21E3a). 82 | 83 | One important thing to note is that the outputs of the functions must adhere strictly to the expected formats. 84 | 85 | Each of the functions is called with all of the arguments from the command line (or `generate()`) in the `args` variable. On top of the arguments, `args` also contains: 86 | - audio: raw audio signal 87 | - sr: sampling rate of audio 88 | - n_frames: total number of interpolation frames 89 | - duration: length of audio in seconds 90 | 91 | ```python 92 | def initialize(args): 93 | # initialize values used in more than one of the following functions here 94 | # e.g. onsets, chroma, RMS, segmentations, bpms, etc. 95 | # this is useful to prevent duplicate computations (get_noise is called for each noise size) 96 | # remember to store them back in args 97 | ... 98 | return args 99 | 100 | def get_latents(selection, args): 101 | # selection holds some latent vectors (generated randomly or from a file) 102 | # generate an audioreactive latent tensor of shape [n_frames, layers, latent_dim] 103 | ... 104 | return latents 105 | 106 | def get_noise(height, width, scale, num_scales, args): 107 | # height and width are the spatial dimensions of the current noise layer 108 | # scale is the index and num_scales the total number of noise layers 109 | # generate an audioreactive noise tensor of shape [n_frames, 1, height, width] 110 | ...
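    # a rough example (not part of the required interface, just a sketch): assuming
    # initialize() stored an onset envelope in args.onsets and torch is imported as th,
    # smooth noise could be blended towards unsmoothed noise whenever an onset hits:
    #   onsets = args.onsets[:, None, None, None]
    #   noise = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width)), 32)
    #   noise = onsets * th.randn_like(noise) + (1 - onsets) * noise
    #   noise = noise / noise.std()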
111 | return noise 112 | 113 | def get_bends(args): 114 | # generate a list of dictionaries specifying network bends 115 | # these must follow one of two forms: 116 | # 117 | # either: { 118 | # "layer": layer index to apply bend to, 119 | # "transform": torch.nn.Module that applies the transformation, 120 | # } 121 | # or: { 122 | # "layer": layer index to apply bend to, 123 | # "modulation": time-dependent modulation of the transformation, shape=(n_frames, ...), 124 | # "transform": function that takes a batch of modulation and returns a torch.nn.Module 125 | # that applies the transformation (given the modulation batch), 126 | # } 127 | # (The second one is technical debt in a nutshell. It's a workaround to get kornia transforms 128 | # to play nicely. You're probably better off using the first option with a th.nn.Module that 129 | # has its modulation as an attribute and keeps count of which frame it's rendering internally). 130 | ... 131 | return bends 132 | 133 | def get_rewrites(args): 134 | # generate a dictionary specifying model rewrites 135 | # each key-value pair should follow: 136 | # param_name -> [transform, modulation] 137 | # where: param_name is the fully-qualified parameter name (see generator.named_children()) 138 | # transform & modulation follow the form of the second network bending dict option above 139 | ... 140 | return rewrites 141 | 142 | def get_truncation(args): 143 | # generate a sequence of truncation values of shape (n_frames,) 144 | ... 145 | return truncation 146 | ``` 147 | 148 | The arguments to `generate_audiovisual.py` are as follows. The first two are required, and the remaining are optional. 149 | ```bash 150 | generate_audiovisual.py 151 | --ckpt CKPT # path to model checkpoint 152 | --audio_file AUDIO_FILE # path to audio file to react to 153 | --audioreactive_file AUDIOREACTIVE_FILE # file with audio-reactive functions defined (as above) 154 | --output_dir OUTPUT_DIR # path to output dir 155 | --offset OFFSET # starting time in audio in seconds (defaults to 0) 156 | --duration DURATION # duration of interpolation to generate in seconds (leave empty for length of audio file) 157 | --latent_file LATENT_FILE # path to latents saved as numpy array 158 | --shuffle_latents # whether to shuffle the supplied latents or not 159 | --out_size OUT_SIZE # output video size: [512, 1024, or 1920] 160 | --fps FPS # output video framerate 161 | --batch BATCH # batch size to render with 162 | --truncation TRUNCATION # truncation to render with (leave empty if get_truncation() is in --audioreactive_file) 163 | --randomize_noise # whether to randomize noise 164 | --dataparallel # whether to use data parallel rendering 165 | --stylegan1 # if the model checkpoint is StyleGAN1 166 | --G_res G_RES # training resolution of the generator 167 | --base_res_factor BASE_RES_FACTOR # factor to increase generator noise maps by (useful when e.g. doubling 512px net to 1024px) 168 | --noconst # whether the generator was trained without a constant input layer 169 | --latent_dim LATENT_DIM # latent vector size of the generator 170 | --n_mlp N_MLP # number of mapping network layers 171 | --channel_multiplier CHANNEL_MULTIPLIER # generator's channel scaling multiplier 172 | ``` 173 | 174 | Alternatively, `generate()` can be called directly from python. It takes the same arguments as generate_audiovisual.py except that instead of supplying an audioreactive_file, the functions should be supplied directly (i.e.
initialize, get_latents, get_noise, get_bends, get_rewrites, and get_truncation as arguments). 175 | 176 | Model checkpoints can be converted from tensorflow .pkl's with [Kim Seonghyeon's script](https://github.com/rosinality/stylegan2-pytorch/blob/master/convert_weight.py) (the one in this repo is broken). Both StyleGAN2 and StyleGAN2-ADA tensorflow checkpoints should work once converted. A good place to find models is [this repo](https://github.com/justinpinkney/awesome-pretrained-stylegan2). 177 | 178 | There is minimal support for rendering with StyleGAN1 checkpoints as well, although only with latent and noise (no network bending or model rewriting). 179 | 180 | ## Citation 181 | 182 | If you use the techniques introduced in the paper or the code in this repository for your research, please cite the paper: 183 | ``` 184 | @InProceedings{Brouwer_2020_NeurIPS_Workshops}, 185 | author = {Brouwer, Hans}, 186 | title = {Audio-reactive Latent Interpolations with StyleGAN}, 187 | booktitle = {Proceedings of the 4th Workshop on Machine Learning for Creativity and Design at NeurIPS 2020}, 188 | month = {December}, 189 | year = {2020}, 190 | url={https://jcbrouwer.github.io/assets/audio-reactive-stylegan/paper.pdf} 191 | } 192 | ``` 193 | -------------------------------------------------------------------------------- /accelerate/accelerate_inception.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import wandb 4 | import argparse 5 | import torch as th 6 | from tqdm import tqdm 7 | from torch.utils import data 8 | import torch.nn.functional as F 9 | from inception_vae import InceptionVAE 10 | from dataset import MultiResolutionDataset 11 | from torchvision import transforms, utils, models 12 | 13 | 14 | def info(x): 15 | print(x.shape, x.detach().cpu().min(), x.detach().cpu().mean(), x.detach().cpu().max()) 16 | 17 | 18 | def sample_data(loader): 19 | while True: 20 | for batch in loader: 21 | yield batch 22 | 23 | 24 | class VGG19(th.nn.Module): 25 | """ 26 | Adapted from https://github.com/NVIDIA/pix2pixHD 27 | See LICENSE-VGG 28 | """ 29 | 30 | def __init__(self, requires_grad=False): 31 | super(VGG19, self).__init__() 32 | vgg_pretrained_features = models.vgg19(pretrained=True).features 33 | self.slice1 = th.nn.Sequential() 34 | self.slice2 = th.nn.Sequential() 35 | self.slice3 = th.nn.Sequential() 36 | self.slice4 = th.nn.Sequential() 37 | self.slice5 = th.nn.Sequential() 38 | for x in range(2): 39 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 40 | for x in range(2, 7): 41 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 42 | for x in range(7, 12): 43 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 44 | for x in range(12, 21): 45 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 46 | for x in range(21, 30): 47 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 48 | if not requires_grad: 49 | for param in self.parameters(): 50 | param.requires_grad = False 51 | 52 | def forward(self, X): 53 | h_relu1 = self.slice1(X) 54 | h_relu2 = self.slice2(h_relu1) 55 | h_relu3 = self.slice3(h_relu2) 56 | h_relu4 = self.slice4(h_relu3) 57 | h_relu5 = self.slice5(h_relu4) 58 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 59 | return out 60 | 61 | 62 | class VGGLoss(th.nn.Module): 63 | """ 64 | Adapted from https://github.com/NVIDIA/pix2pixHD 65 | See LICENSE-VGG 66 | """ 67 | 68 | def __init__(self): 69 | super(VGGLoss, self).__init__() 70 | self.vgg = VGG19().cuda() 71 
| self.criterion = th.nn.L1Loss() 72 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 73 | 74 | def forward(self, x, y): 75 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 76 | loss = 0 77 | for i in range(len(x_vgg)): 78 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 79 | return loss 80 | 81 | 82 | def train(latent_dim, num_repeats, learning_rate, lambda_vgg, lambda_mse): 83 | print( 84 | f"latent_dim={latent_dim:.4f}", 85 | f"num_repeats={num_repeats:.4f}", 86 | f"learning_rate={learning_rate:.4f}", 87 | f"lambda_vgg={lambda_vgg:.4f}", 88 | f"lambda_mse={lambda_mse:.4f}", 89 | ) 90 | 91 | transform = transforms.Compose( 92 | [ 93 | transforms.Resize(128), 94 | transforms.RandomHorizontalFlip(p=0.5), 95 | transforms.ToTensor(), 96 | # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 97 | ] 98 | ) 99 | batch_size = 72 100 | data_path = "/home/hans/trainsets/cyphis" 101 | name = os.path.splitext(os.path.basename(data_path))[0] 102 | dataset = MultiResolutionDataset(data_path, transform, 256) 103 | dataloader = data.DataLoader( 104 | dataset, batch_size=batch_size, sampler=data.RandomSampler(dataset), num_workers=12, drop_last=True, 105 | ) 106 | loader = sample_data(dataloader) 107 | sample_imgs = next(loader)[:24] 108 | wandb.log({"Real Images": [wandb.Image(utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(0, 1)))]}) 109 | 110 | vae, vae_optim = None, None 111 | vae = InceptionVAE(latent_dim=latent_dim, repeat_per_block=num_repeats).to(device) 112 | vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate) 113 | 114 | vgg = VGGLoss() 115 | 116 | # sample_z = th.randn(size=(24, 512)) 117 | 118 | scores = [] 119 | num_iters = 100_000 120 | pbar = tqdm(range(num_iters), smoothing=0.1) 121 | for i in pbar: 122 | vae.train() 123 | 124 | real = next(loader).to(device) 125 | 126 | fake, mu, log_var = vae(real) 127 | 128 | bce = F.binary_cross_entropy(fake, real, size_average=False) 129 | kld = -0.5 * th.sum(1 + log_var - mu.pow(2) - log_var.exp()) 130 | vgg_loss = vgg(fake, real) 131 | mse_loss = th.sqrt((fake - real).pow(2).mean()) 132 | 133 | loss = bce + kld + lambda_vgg * vgg_loss + lambda_mse * mse_loss 134 | 135 | loss_dict = { 136 | "Total": loss, 137 | "BCE": bce, 138 | "Kullback Leibler Divergence": kld, 139 | "MSE": mse_loss, 140 | "VGG": vgg_loss, 141 | } 142 | 143 | vae.zero_grad() 144 | loss.backward() 145 | vae_optim.step() 146 | 147 | wandb.log(loss_dict) 148 | 149 | with th.no_grad(): 150 | if i % int(num_iters / 100) == 0 or i + 1 == num_iters: 151 | vae.eval() 152 | 153 | sample, _, _ = vae(sample_imgs.to(device)) 154 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(0, 1)) 155 | del sample 156 | wandb.log({"Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 157 | 158 | sample = vae.sampling() 159 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(0, 1)) 160 | del sample 161 | wandb.log({"Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 162 | 163 | gc.collect() 164 | th.cuda.empty_cache() 165 | 166 | th.save( 167 | {"vae": vae.state_dict(), "vae_optim": vae_optim.state_dict()}, 168 | f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt", 169 | ) 170 | 171 | if th.isnan(loss).any() or th.isinf(loss).any(): 172 | print("NaN losses, exiting...") 173 | print( 174 | { 175 | "Total": loss, 176 | "\nBCE": bce, 177 | "\nKullback Leibler Divergence": kld, 178 | "\nMSE": mse_loss, 179 | "\nVGG": vgg_loss, 180 | } 181 | ) 
182 | wandb.log({"Total": 27000}) 183 | return 184 | 185 | 186 | if __name__ == "__main__": 187 | parser = argparse.ArgumentParser() 188 | parser.add_argument("--latent_dim", type=float, default=512) 189 | parser.add_argument("--num_repeats", type=float, default=1) 190 | parser.add_argument("--learning_rate", type=float, default=0.005) 191 | parser.add_argument("--lambda_vgg", type=float, default=1.0) 192 | parser.add_argument("--lambda_mse", type=float, default=1.0) 193 | args = parser.parse_args() 194 | 195 | device = "cuda" 196 | th.backends.cudnn.benchmark = True 197 | 198 | wandb.init(project=f"maua-stylegan") 199 | 200 | train( 201 | args.latent_dim, args.num_repeats, args.learning_rate, args.lambda_vgg, args.lambda_mse, 202 | ) 203 | 204 | -------------------------------------------------------------------------------- /accelerate/accelerate_logcosh.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import wandb 4 | import argparse 5 | import validation 6 | import torch as th 7 | from tqdm import tqdm 8 | from torch.utils import data 9 | from autoencoder import LogCoshVAE 10 | from dataset import MultiResolutionDataset 11 | from torchvision import transforms, utils, models 12 | 13 | 14 | def data_sampler(dataset, shuffle, distributed): 15 | if distributed: 16 | return data.distributed.DistributedSampler(dataset, shuffle=shuffle) 17 | if shuffle: 18 | return data.RandomSampler(dataset) 19 | else: 20 | return data.SequentialSampler(dataset) 21 | 22 | 23 | def sample_data(loader): 24 | while True: 25 | for batch in loader: 26 | yield batch 27 | 28 | 29 | class VGG19(th.nn.Module): 30 | """ 31 | Adapted from https://github.com/NVIDIA/pix2pixHD 32 | See LICENSE-VGG 33 | """ 34 | 35 | def __init__(self, requires_grad=False): 36 | super(VGG19, self).__init__() 37 | vgg_pretrained_features = models.vgg19(pretrained=True).features 38 | self.slice1 = th.nn.Sequential() 39 | self.slice2 = th.nn.Sequential() 40 | self.slice3 = th.nn.Sequential() 41 | self.slice4 = th.nn.Sequential() 42 | self.slice5 = th.nn.Sequential() 43 | for x in range(2): 44 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 45 | for x in range(2, 7): 46 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 47 | for x in range(7, 12): 48 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 49 | for x in range(12, 21): 50 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 51 | for x in range(21, 30): 52 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 53 | if not requires_grad: 54 | for param in self.parameters(): 55 | param.requires_grad = False 56 | 57 | def forward(self, X): 58 | h_relu1 = self.slice1(X) 59 | h_relu2 = self.slice2(h_relu1) 60 | h_relu3 = self.slice3(h_relu2) 61 | h_relu4 = self.slice4(h_relu3) 62 | h_relu5 = self.slice5(h_relu4) 63 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 64 | return out 65 | 66 | 67 | class VGGLoss(th.nn.Module): 68 | """ 69 | Adapted from https://github.com/NVIDIA/pix2pixHD 70 | See LICENSE-VGG 71 | """ 72 | 73 | def __init__(self): 74 | super(VGGLoss, self).__init__() 75 | self.vgg = VGG19().cuda() 76 | self.criterion = th.nn.L1Loss() 77 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 78 | 79 | def forward(self, x, y): 80 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 81 | loss = 0 82 | for i in range(len(x_vgg)): 83 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 84 | return loss 85 | 86 | 87 | device = "cuda" 88 | 
th.backends.cudnn.benchmark = True 89 | 90 | wandb.init(project=f"maua-stylegan") 91 | 92 | 93 | def train(latent_dim, learning_rate, number_filters, vae_alpha, vae_beta, kl_divergence_weight): 94 | print( 95 | f"latent_dim={latent_dim}", 96 | f"learning_rate={learning_rate}", 97 | f"number_filters={number_filters}", 98 | f"vae_alpha={vae_alpha}", 99 | f"vae_beta={vae_beta}", 100 | f"kl_divergence_weight={kl_divergence_weight}", 101 | ) 102 | 103 | batch_size = 64 104 | i = None 105 | while batch_size >= 1: 106 | try: 107 | transform = transforms.Compose( 108 | [ 109 | transforms.Resize(128), 110 | transforms.RandomHorizontalFlip(p=0.5), 111 | transforms.ToTensor(), 112 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 113 | ] 114 | ) 115 | data_path = "/home/hans/trainsets/cyphis" 116 | name = os.path.splitext(os.path.basename(data_path))[0] 117 | dataset = MultiResolutionDataset(data_path, transform, 256) 118 | dataloader = data.DataLoader( 119 | dataset, 120 | batch_size=int(batch_size), 121 | sampler=data_sampler(dataset, shuffle=True, distributed=False), 122 | num_workers=12, 123 | drop_last=True, 124 | ) 125 | loader = sample_data(dataloader) 126 | sample_imgs = next(loader)[:24] 127 | wandb.log( 128 | {"Real Images": [wandb.Image(utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(-1, 1)))]} 129 | ) 130 | 131 | hidden_dims = [min(int(number_filters) * 2 ** i, latent_dim) for i in range(5)] + [latent_dim] 132 | vae, vae_optim = None, None 133 | vae = LogCoshVAE( 134 | 3, latent_dim, hidden_dims=hidden_dims, alpha=vae_alpha, beta=vae_beta, kld_weight=kl_divergence_weight, 135 | ).to(device) 136 | vae.train() 137 | vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate) 138 | 139 | mse_loss = th.nn.MSELoss() 140 | vgg = VGGLoss() 141 | 142 | sample_z = th.randn(size=(24, latent_dim)) 143 | 144 | scores = [] 145 | num_iters = 100_000 146 | pbar = range(num_iters) 147 | pbar = tqdm(pbar, smoothing=0.1) 148 | for i in pbar: 149 | vae.train() 150 | 151 | real = next(loader).to(device) 152 | fake, mu, log_var = vae(real) 153 | 154 | loss_dict = vae.loss(real, fake, mu, log_var) 155 | vgg_loss = vgg(fake, real) 156 | loss = loss_dict["Total"] + vgg_loss 157 | 158 | vae.zero_grad() 159 | loss.backward() 160 | vae_optim.step() 161 | 162 | wandb.log( 163 | { 164 | "Total": loss, 165 | "VGG": vgg_loss, 166 | "Reconstruction": loss_dict["Reconstruction"], 167 | "Kullback Leibler Divergence": loss_dict["Kullback Leibler Divergence"], 168 | } 169 | ) 170 | 171 | if i % int(num_iters / 1000) == 0 or i + 1 == num_iters: 172 | with th.no_grad(): 173 | vae.eval() 174 | 175 | sample, _, _ = vae(sample_imgs.to(device)) 176 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1),) 177 | del sample 178 | wandb.log({"Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 179 | 180 | sample = vae.decode(sample_z.to(device)) 181 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1),) 182 | del sample 183 | wandb.log({"Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 184 | 185 | if i % int(num_iters / 40) == 0 or i + 1 == num_iters: 186 | with th.no_grad(): 187 | fid_dict = validation.vae_fid(vae, int(batch_size), (latent_dim,), 5000, name) 188 | wandb.log(fid_dict) 189 | mse = mse_loss(fake, real) * 5000 190 | score = fid_dict["FID"] + mse + 1000 * vgg_loss 191 | wandb.log({"Score": score}) 192 | pbar.set_description(f"FID: {fid_dict['FID']:.2f} MSE: {mse:.2f} VGG: {1000 * vgg_loss:.2f}") 193 | 194 | if i >= 
num_iters / 2: 195 | scores.append(score) 196 | 197 | if th.isnan(loss).any() or th.isinf(loss).any(): 198 | print("NaN losses, exiting...") 199 | print( 200 | { 201 | "Total": loss.detach().cpu().item(), 202 | "\nVGG": vgg_loss.detach().cpu().item(), 203 | "\nReconstruction": loss_dict["Reconstruction"].detach().cpu().item(), 204 | "\nKullback Leibler Divergence": loss_dict["Kullback Leibler Divergence"] 205 | .detach() 206 | .cpu() 207 | .item(), 208 | } 209 | ) 210 | wandb.log({"Score": 27000}) 211 | return 212 | 213 | return 214 | 215 | except RuntimeError as e: 216 | if "CUDA out of memory" in str(e): 217 | batch_size = batch_size / 2 218 | 219 | if batch_size < 1: 220 | print("This configuration does not fit into memory, exiting...") 221 | wandb.log({"Score": 27000}) 222 | return 223 | 224 | print(f"Out of memory, halving batch size... {batch_size}") 225 | if vae is not None: 226 | del vae 227 | if vae_optim is not None: 228 | del vae_optim 229 | gc.collect() 230 | th.cuda.empty_cache() 231 | 232 | else: 233 | print(e) 234 | return 235 | 236 | 237 | parser = argparse.ArgumentParser() 238 | parser.add_argument("--latent_dim", type=int, default=1024) 239 | parser.add_argument("--learning_rate", type=float, default=0.005) 240 | parser.add_argument("--number_filters", type=int, default=64) 241 | parser.add_argument("--vae_alpha", type=float, default=10.0) 242 | parser.add_argument("--vae_beta", type=float, default=1.0) 243 | parser.add_argument("--kl_divergence_weight", type=float, default=1.0) 244 | args = parser.parse_args() 245 | 246 | train( 247 | args.latent_dim, args.learning_rate, args.number_filters, args.vae_alpha, args.vae_beta, args.kl_divergence_weight, 248 | ) 249 | 250 | -------------------------------------------------------------------------------- /accelerate/accelerate_segnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | import gc 3 | import wandb 4 | import argparse 5 | import torch as th 6 | from tqdm import tqdm 7 | from torch.utils import data 8 | from autoencoder import ConvSegNet 9 | from dataset import MultiResolutionDataset 10 | from torchvision import transforms, utils, models 11 | 12 | 13 | def info(x): 14 | print(x.shape, x.detach().cpu().min(), x.detach().cpu().mean(), x.detach().cpu().max()) 15 | 16 | 17 | def sample_data(loader): 18 | while True: 19 | for batch in loader: 20 | yield batch 21 | 22 | 23 | class VGG19(th.nn.Module): 24 | """ 25 | Adapted from https://github.com/NVIDIA/pix2pixHD 26 | See LICENSE-VGG 27 | """ 28 | 29 | def __init__(self, requires_grad=False): 30 | super(VGG19, self).__init__() 31 | vgg_pretrained_features = models.vgg19(pretrained=True).features 32 | self.slice1 = th.nn.Sequential() 33 | self.slice2 = th.nn.Sequential() 34 | self.slice3 = th.nn.Sequential() 35 | self.slice4 = th.nn.Sequential() 36 | self.slice5 = th.nn.Sequential() 37 | for x in range(2): 38 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 39 | for x in range(2, 7): 40 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 41 | for x in range(7, 12): 42 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 43 | for x in range(12, 21): 44 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 45 | for x in range(21, 30): 46 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 47 | if not requires_grad: 48 | for param in self.parameters(): 49 | param.requires_grad = False 50 | 51 | def forward(self, X): 52 | h_relu1 = self.slice1(X) 53 | h_relu2 = self.slice2(h_relu1) 
54 | h_relu3 = self.slice3(h_relu2) 55 | h_relu4 = self.slice4(h_relu3) 56 | h_relu5 = self.slice5(h_relu4) 57 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 58 | return out 59 | 60 | 61 | class VGGLoss(th.nn.Module): 62 | """ 63 | Adapted from https://github.com/NVIDIA/pix2pixHD 64 | See LICENSE-VGG 65 | """ 66 | 67 | def __init__(self): 68 | super(VGGLoss, self).__init__() 69 | self.vgg = VGG19().cuda() 70 | self.criterion = th.nn.L1Loss() 71 | self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0] 72 | 73 | def forward(self, x, y): 74 | x_vgg, y_vgg = self.vgg(x), self.vgg(y) 75 | loss = 0 76 | for i in range(len(x_vgg)): 77 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 78 | return loss 79 | 80 | 81 | def align(x, y, alpha=2): 82 | return (x - y).norm(p=2, dim=1).pow(alpha).mean() 83 | 84 | 85 | def uniform(x, t=2): 86 | return (th.pdist(x.view(x.size(0), -1), p=2).pow(2).mul(-t).exp().mean() + 1e-27).log() 87 | 88 | 89 | def train(learning_rate, lambda_mse): 90 | print( 91 | f"learning_rate={learning_rate:.4f}", f"lambda_mse={lambda_mse:.4f}", 92 | ) 93 | 94 | transform = transforms.Compose( 95 | [ 96 | transforms.Resize(128), 97 | transforms.RandomHorizontalFlip(p=0.5), 98 | transforms.ToTensor(), 99 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 100 | ] 101 | ) 102 | batch_size = 72 103 | data_path = "/home/hans/trainsets/cyphis" 104 | name = os.path.splitext(os.path.basename(data_path))[0] 105 | dataset = MultiResolutionDataset(data_path, transform, 256) 106 | dataloader = data.DataLoader( 107 | dataset, batch_size=batch_size, sampler=data.RandomSampler(dataset), num_workers=12, drop_last=True, 108 | ) 109 | loader = sample_data(dataloader) 110 | sample_imgs = next(loader)[:24] 111 | wandb.log({"Real Images": [wandb.Image(utils.make_grid(sample_imgs, nrow=6, normalize=True, range=(-1, 1)))]}) 112 | 113 | vae, vae_optim = None, None 114 | vae = ConvSegNet().to(device) 115 | vae_optim = th.optim.Adam(vae.parameters(), lr=learning_rate) 116 | 117 | vgg = VGGLoss() 118 | 119 | sample_z = th.randn(size=(24, 512, 16, 16)) 120 | sample_z /= sample_z.abs().max() 121 | 122 | scores = [] 123 | num_iters = 100_000 124 | pbar = tqdm(range(num_iters), smoothing=0.1) 125 | for i in pbar: 126 | vae.train() 127 | 128 | real = next(loader).to(device) 129 | 130 | z = vae.encode(real) 131 | fake = vae.decode(z) 132 | 133 | vgg_loss = vgg(fake, real) 134 | 135 | mse_loss = th.sqrt((fake - real).pow(2).mean()) 136 | 137 | # diff = fake - real 138 | # recons_loss = recons_alpha * diff + th.log(1.0 + th.exp(-2 * recons_alpha * diff)) - th.log(th.tensor(2.0)) 139 | # recons_loss = (1.0 / recons_alpha) * recons_loss.mean() 140 | # recons_loss = recons_loss if not th.isinf(recons_loss).any() else 0 141 | 142 | # x, y = z.chunk(2) 143 | # align_loss = align(x, y, alpha=align_alpha) 144 | # unif_loss = -(uniform(x, t=unif_t) + uniform(y, t=unif_t)) / 2.0 145 | 146 | loss = ( 147 | vgg_loss 148 | + lambda_mse * mse_loss 149 | # + lambda_recons * recons_loss 150 | # + lambda_align * align_loss 151 | # + lambda_unif * unif_loss 152 | ) 153 | # print(vgg_loss.detach().cpu().item()) 154 | # print(lambda_mse * mse_loss.detach().cpu().item()) 155 | # # print(lambda_recons * recons_loss.detach().cpu().item()) 156 | # print(lambda_align * align_loss.detach().cpu().item()) 157 | # print(lambda_unif * unif_loss.detach().cpu().item()) 158 | 159 | loss_dict = { 160 | "Total": loss, 161 | "MSE": mse_loss, 162 | "VGG": vgg_loss, 163 | # "Reconstruction": 
recons_loss, 164 | # "Alignment": align_loss, 165 | # "Uniformity": unif_loss, 166 | } 167 | 168 | vae.zero_grad() 169 | loss.backward() 170 | vae_optim.step() 171 | 172 | wandb.log(loss_dict) 173 | # pbar.set_description(" ".join()) 174 | 175 | with th.no_grad(): 176 | if i % int(num_iters / 100) == 0 or i + 1 == num_iters: 177 | vae.eval() 178 | 179 | sample = vae(sample_imgs.to(device)) 180 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1)) 181 | del sample 182 | wandb.log({"Reconstructed Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 183 | 184 | sample = vae.decode(sample_z.to(device)) 185 | grid = utils.make_grid(sample, nrow=6, normalize=True, range=(-1, 1)) 186 | del sample 187 | wandb.log({"Generated Images VAE": [wandb.Image(grid, caption=f"Step {i}")]}) 188 | 189 | gc.collect() 190 | th.cuda.empty_cache() 191 | 192 | th.save( 193 | {"vae": vae.state_dict(), "vae_optim": vae_optim.state_dict()}, 194 | f"/home/hans/modelzoo/maua-sg2/vae-{name}-{wandb.run.dir.split('/')[-1].split('-')[-1]}.pt", 195 | ) 196 | 197 | if th.isnan(loss).any(): 198 | print("NaN losses, exiting...") 199 | wandb.log({"Total": 27000}) 200 | return 201 | 202 | 203 | if __name__ == "__main__": 204 | parser = argparse.ArgumentParser() 205 | parser.add_argument("--learning_rate", type=float, default=0.005) 206 | parser.add_argument("--lambda_mse", type=float, default=1.0) 207 | # parser.add_argument("--lambda_recons", type=float, default=0.0) 208 | # parser.add_argument("--recons_alpha", type=float, default=5.0) 209 | # parser.add_argument("--lambda_align", type=float, default=1.0) 210 | # parser.add_argument("--align_alpha", type=float, default=2.0) 211 | # parser.add_argument("--lambda_unif", type=float, default=1.0) 212 | # parser.add_argument("--unif_t", type=float, default=0.001) 213 | args = parser.parse_args() 214 | 215 | device = "cuda" 216 | th.backends.cudnn.benchmark = True 217 | 218 | wandb.init(project=f"maua-stylegan") 219 | 220 | train( 221 | args.learning_rate, 222 | args.lambda_mse, 223 | # args.lambda_recons, 224 | # args.recons_alpha, 225 | # args.lambda_align, 226 | # args.align_alpha, 227 | # args.lambda_unif, 228 | # args.unif_t, 229 | ) 230 | 231 | -------------------------------------------------------------------------------- /audioreactive/__init__.py: -------------------------------------------------------------------------------- 1 | from .bend import * 2 | from .examples import * 3 | from .latent import * 4 | from .signal import * 5 | from .util import * 6 | -------------------------------------------------------------------------------- /audioreactive/bend.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import kornia.augmentation as kA 4 | import kornia.geometry.transform as kT 5 | import torch as th 6 | 7 | # ==================================================================================== 8 | # ================================= network bending ================================== 9 | # ==================================================================================== 10 | 11 | 12 | class NetworkBend(th.nn.Module): 13 | """Base network bending class 14 | 15 | Args: 16 | sequential_fn (function): Function that takes a batch of modulation and creates th.nn.Sequential 17 | modulation (th.tensor): Modulation batch 18 | """ 19 | 20 | def __init__(self, sequential_fn, modulation): 21 | super(NetworkBend, self).__init__() 22 | self.sequential = sequential_fn(modulation) 23 | 24 | def 
forward(self, x): 25 | return self.sequential(x) 26 | 27 | 28 | class AddNoise(th.nn.Module): 29 | """Adds static noise to output 30 | 31 | Args: 32 | noise (th.tensor): Noise to be added 33 | """ 34 | 35 | def __init__(self, noise): 36 | super(AddNoise, self).__init__() 37 | self.noise = noise 38 | 39 | def forward(self, x): 40 | return x + self.noise.to(x.device) 41 | 42 | 43 | class Print(th.nn.Module): 44 | """Prints intermediate feature statistics (useful for debugging complicated network bends).""" 45 | 46 | def forward(self, x): 47 | print(x.shape, [x.min().item(), x.mean().item(), x.max().item()], th.std(x).item()) 48 | return x 49 | 50 | 51 | class Translate(NetworkBend): 52 | """Creates horizontal translating effect where repeated linear interpolations from 0 to 1 (saw tooth wave) creates seamless scrolling effect. 53 | 54 | Args: 55 | modulation (th.tensor): [0.0-1.0]. Batch of modulation 56 | h (int): Height of intermediate features that the network bend is applied to 57 | w (int): Width of intermediate features that the network bend is applied to 58 | noise (int): Noise to be added (must be 5 * width wide) 59 | """ 60 | 61 | def __init__(self, modulation, h, w, noise): 62 | sequential_fn = lambda b: th.nn.Sequential( 63 | th.nn.ReflectionPad2d((int(w / 2), int(w / 2), 0, 0)), 64 | th.nn.ReflectionPad2d((w, w, 0, 0)), 65 | th.nn.ReflectionPad2d((w, 0, 0, 0)), 66 | AddNoise(noise), 67 | kT.Translate(b), 68 | kA.CenterCrop((h, w)), 69 | ) 70 | super(Translate, self).__init__(sequential_fn, modulation) 71 | 72 | 73 | class Zoom(NetworkBend): 74 | """Creates zooming effect. 75 | 76 | Args: 77 | modulation (th.tensor): [0.0-1.0]. Batch of modulation 78 | h (int): height of intermediate features that the network bend is applied to 79 | w (int): width of intermediate features that the network bend is applied to 80 | """ 81 | 82 | def __init__(self, modulation, h, w): 83 | padding = int(max(h, w)) - 1 84 | sequential_fn = lambda b: th.nn.Sequential(th.nn.ReflectionPad2d(padding), kT.Scale(b), kA.CenterCrop((h, w))) 85 | super(Zoom, self).__init__(sequential_fn, modulation) 86 | 87 | 88 | class Rotate(NetworkBend): 89 | """Creates rotation effect. 90 | 91 | Args: 92 | modulation (th.tensor): [0.0-1.0]. 
Batch of modulation 93 | h (int): height of intermediate features that the network bend is applied to 94 | w (int): width of intermediate features that the network bend is applied to 95 | """ 96 | 97 | def __init__(self, modulation, h, w): 98 | # worst case rotation brings sqrt(2) * max_side_length out-of-frame pixels into frame 99 | # padding should cover that exactly 100 | padding = int(max(h, w) * (1 - math.sqrt(2) / 2)) 101 | sequential_fn = lambda b: th.nn.Sequential(th.nn.ReflectionPad2d(padding), kT.Rotate(b), kA.CenterCrop((h, w))) 102 | super(Rotate, self).__init__(sequential_fn, modulation) 103 | -------------------------------------------------------------------------------- /audioreactive/examples/Wavefunk - Dwelling in the Kelp.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/audioreactive/examples/Wavefunk - Dwelling in the Kelp.mp3 -------------------------------------------------------------------------------- /audioreactive/examples/Wavefunk - Tau Ceti Alpha.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/audioreactive/examples/Wavefunk - Tau Ceti Alpha.mp3 -------------------------------------------------------------------------------- /audioreactive/examples/Wavefunk - Temper.mp3: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/audioreactive/examples/Wavefunk - Temper.mp3 -------------------------------------------------------------------------------- /audioreactive/examples/__init__.py: -------------------------------------------------------------------------------- 1 | from . 
import * 2 | -------------------------------------------------------------------------------- /audioreactive/examples/default.py: -------------------------------------------------------------------------------- 1 | import torch as th 2 | 3 | import audioreactive as ar 4 | 5 | 6 | def initialize(args): 7 | args.lo_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmax=150, smooth=5, clip=97, power=2) 8 | args.hi_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmin=500, smooth=5, clip=99, power=2) 9 | return args 10 | 11 | 12 | def get_latents(selection, args): 13 | chroma = ar.chroma(args.audio, args.sr, args.n_frames) 14 | chroma_latents = ar.chroma_weight_latents(chroma, selection) 15 | latents = ar.gaussian_filter(chroma_latents, 4) 16 | 17 | lo_onsets = args.lo_onsets[:, None, None] 18 | hi_onsets = args.hi_onsets[:, None, None] 19 | 20 | latents = hi_onsets * selection[[-4]] + (1 - hi_onsets) * latents 21 | latents = lo_onsets * selection[[-7]] + (1 - lo_onsets) * latents 22 | 23 | latents = ar.gaussian_filter(latents, 2, causal=0.2) 24 | 25 | return latents 26 | 27 | 28 | def get_noise(height, width, scale, num_scales, args): 29 | if width > 256: 30 | return None 31 | 32 | lo_onsets = args.lo_onsets[:, None, None, None].cuda() 33 | hi_onsets = args.hi_onsets[:, None, None, None].cuda() 34 | 35 | noise_noisy = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 5) 36 | noise = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 128) 37 | 38 | if width < 128: 39 | noise = lo_onsets * noise_noisy + (1 - lo_onsets) * noise 40 | if width > 32: 41 | noise = hi_onsets * noise_noisy + (1 - hi_onsets) * noise 42 | 43 | noise /= noise.std() * 2.5 44 | 45 | return noise.cpu() 46 | -------------------------------------------------------------------------------- /audioreactive/examples/kelp.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file shows an example of a loop based interpolation 3 | Here sections are identified with laplacian segmentation and looping latents are generated for each section 4 | The noise is looping perlin noise 5 | Long term section analysis is done with the RMS to interpolate between latent sequences for the intro/outro and drop 6 | """ 7 | 8 | 9 | import librosa as rosa 10 | import torch as th 11 | 12 | import audioreactive as ar 13 | 14 | OVERRIDE = dict(audio_file="audioreactive/examples/Wavefunk - Dwelling in the Kelp.mp3", out_size=1920) 15 | BPM = 130 16 | 17 | 18 | def initialize(args): 19 | # RMS can be used to distinguish between the drop sections and intro/outros 20 | rms = ar.rms(args.audio, args.sr, args.n_frames, smooth=10, clip=60, power=1) 21 | rms = ar.expand(rms, threshold=0.8, ratio=10) 22 | rms = ar.gaussian_filter(rms, 4) 23 | rms = ar.normalize(rms) 24 | args.rms = rms 25 | 26 | # cheating a little here, this my song so I have the multitracks 27 | # this is much easier than fiddling with onsets until you have envelopes that dance nicely to the drums 28 | audio, sr = rosa.load("workspace/kelpkick.wav", offset=args.offset, duration=args.duration) 29 | args.kick_onsets = ar.onsets(audio, sr, args.n_frames, margin=1, smooth=4) 30 | audio, sr = rosa.load("workspace/kelpsnare.wav", offset=args.offset, duration=args.duration) 31 | args.snare_onsets = ar.onsets(audio, sr, args.n_frames, margin=1, smooth=4) 32 | 33 | ar.plot_signals([args.rms, args.kick_onsets, args.snare_onsets]) 34 | 35 | return args 36 | 37 | 38 | def 
get_latents(selection, args): 39 | # expand envelopes to latent shape 40 | rms = args.rms[:, None, None] 41 | low_onsets = args.kick_onsets[:, None, None] 42 | high_onsets = args.snare_onsets[:, None, None] 43 | 44 | # get timestamps and labels with laplacian segmentation 45 | # k is the number of labels the algorithm may use 46 | # try multiple values with plot=True to see which value correlates best with the sections of the song 47 | timestamps, labels = ar.laplacian_segmentation(args.audio, args.sr, k=7) 48 | 49 | # a second set of latents for the drop section, the 'selection' variable is the other set for the intro 50 | drop_selection = ar.load_latents("workspace/cyphept_kelp_drop_latents.npy") 51 | color_layer = 9 52 | 53 | latents = [] 54 | for (start, stop), l in zip(zip(timestamps, timestamps[1:]), labels): 55 | start_frame = int(round(start / args.duration * args.n_frames)) 56 | stop_frame = int(round(stop / args.duration * args.n_frames)) 57 | section_frames = stop_frame - start_frame 58 | section_bars = (stop - start) * (BPM / 60) / 4 59 | 60 | # get portion of latent selection (wrapping around to start) 61 | latent_selection_slice = ar.wrapping_slice(selection, l, 4) 62 | # spline interpolation loops through selection slice 63 | latent_section = ar.spline_loops(latent_selection_slice, n_frames=section_frames, n_loops=section_bars / 4) 64 | # set the color with laplacian segmentation label, (1 latent repeated for entire section in upper layers) 65 | latent_section[:, color_layer:] = th.cat([selection[[l], color_layer:]] * section_frames) 66 | 67 | # same as above but for the drop latents (with faster loops) 68 | drop_selection_slice = ar.wrapping_slice(drop_selection, l, 4) 69 | drop_section = ar.spline_loops(drop_selection_slice, n_frames=section_frames, n_loops=section_bars / 2) 70 | drop_section[:, color_layer:] = th.cat([drop_selection[[l], color_layer:]] * section_frames) 71 | 72 | # merged based on RMS (drop section or not) 73 | latents.append((1 - rms[start_frame:stop_frame]) * latent_section + rms[start_frame:stop_frame] * drop_section) 74 | 75 | # concatenate latents to correct length & smooth over the junctions 76 | len_latents = sum([len(l) for l in latents]) 77 | if len_latents != args.n_frames: 78 | latents.append(th.cat([latents[-1][[-1]]] * (args.n_frames - len_latents))) 79 | latents = th.cat(latents).float() 80 | latents = ar.gaussian_filter(latents, 3) 81 | 82 | # use onsets to modulate towards latents 83 | latents = 0.666 * low_onsets * selection[[2]] + (1 - 0.666 * low_onsets) * latents 84 | latents = 0.666 * high_onsets * selection[[1]] + (1 - 0.666 * high_onsets) * latents 85 | 86 | latents = ar.gaussian_filter(latents, 1, causal=0.2) 87 | return latents 88 | 89 | 90 | def get_noise(height, width, scale, num_scales, args): 91 | if width > 512: # larger sizes don't fit in VRAM, just use default or randomize 92 | return 93 | 94 | num_bars = int(round(args.duration * (BPM / 60) / 4)) 95 | frames_per_loop = int(args.n_frames / num_bars * 2) # loop every 2 bars 96 | 97 | def perlin_pls(resolution): 98 | perlin = ar.perlin_noise(shape=(frames_per_loop, height, width), res=resolution)[:, None, ...].cpu() 99 | perlin = th.cat([perlin] * int(num_bars / 2)) # concatenate multiple copies for looping 100 | if args.n_frames - len(perlin) > 0: 101 | perlin = th.cat([perlin, th.cat([perlin[[-1]]] * (args.n_frames - len(perlin)))]) # fix up rounding errors 102 | return perlin 103 | 104 | smooth = perlin_pls(resolution=(1, 1, 1)) # (time res, x res, y res) 105 | noise = 
perlin_pls(resolution=(8, 4, 4)) # higher resolution => higher frequency noise => more movement in video 106 | 107 | rms = args.rms[:, None, None, None] 108 | noise = rms * noise + (1 - rms) * smooth # blend between noises based on drop (high rms) or not 109 | 110 | return noise 111 | 112 | 113 | def get_bends(args): 114 | # repeat the intermediate features outwards on both sides (2:1 aspect ratio) 115 | # + add some noise to give the whole thing a little variation (disguises the repetition) 116 | transform = th.nn.Sequential( 117 | th.nn.ReplicationPad2d((2, 2, 0, 0)), ar.AddNoise(0.025 * th.randn(size=(1, 1, 4, 8), device="cuda")), 118 | ) 119 | bends = [{"layer": 0, "transform": transform}] 120 | 121 | return bends 122 | -------------------------------------------------------------------------------- /audioreactive/examples/tauceti.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file shows an example of network bending 3 | The latents and noise are similar to temper.py (although without spatial noise controls) 4 | The latents cycle through different colors for different sections of the drop 5 | During the drop, a translation is applied which makes the video seem to scroll endlessly 6 | """ 7 | 8 | from functools import partial 9 | 10 | import numpy as np 11 | import torch as th 12 | 13 | import audioreactive as ar 14 | 15 | OVERRIDE = dict( 16 | audio_file="audioreactive/examples/Wavefunk - Tau Ceti Alpha.mp3", 17 | out_size=1920, # get bends assumes 1920x1080 output size 18 | dataparallel=False, # makes use of a kornia transform during network bending => not compatible with dataparallel 19 | fps=30, # 5591 magic number below is based on number of frames in output video with fps of 30 20 | ) 21 | 22 | 23 | def initialize(args): 24 | args.low_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmax=150, smooth=5, clip=97, power=2) 25 | args.high_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmin=500, smooth=5, clip=99, power=2) 26 | return args 27 | 28 | 29 | def get_latents(selection, args): 30 | chroma = ar.chroma(args.audio, args.sr, args.n_frames) 31 | chroma_latents = ar.chroma_weight_latents(chroma, selection[:12]) # shape [n_frames, 18, 512] 32 | latents = ar.gaussian_filter(chroma_latents, 5) 33 | 34 | lo_onsets = args.low_onsets[:, None, None] # expand to same shape as latents [n_frames, 1, 1] 35 | hi_onsets = args.high_onsets[:, None, None] 36 | 37 | latents = hi_onsets * selection[[-4]] + (1 - hi_onsets) * latents 38 | latents = lo_onsets * selection[[-7]] + (1 - lo_onsets) * latents 39 | 40 | latents = ar.gaussian_filter(latents, 5, causal=0) 41 | 42 | # cheating a little, you could probably do this with laplacian segmentation, but is it worth the effort? 
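# note: 5591 is the total number of frames of the rendered video (see the fps note in OVERRIDE at the top of this file); multiplying it by 45/args.duration and 135/args.duration below converts the drop's start and end timestamps in seconds to frame indices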
43 | drop_start = int(5591 * (45 / args.duration)) 44 | drop_end = int(5591 * (135 / args.duration)) 45 | 46 | # selection of latents with different colors (chosen with select_latents.py) 47 | color_latent_selection = th.from_numpy(np.load("workspace/cyphept-multicolor-latents.npy")) 48 | 49 | # build sequence of latents for just the upper layers 50 | color_layer = 9 51 | color_latents = [latents[:drop_start, color_layer:]] 52 | 53 | # for 4 different sections in the drop, use a different color latent 54 | drop_length = drop_end - drop_start 55 | section_length = int(drop_length / 4) 56 | for i, section_start in enumerate(range(0, drop_length, section_length)): 57 | if i > 3: 58 | break 59 | color_latents.append(th.cat([color_latent_selection[[i], color_layer:]] * section_length)) 60 | 61 | # ensure color sequence is correct length and concatenate 62 | if drop_length - 4 * section_length != 0: 63 | color_latents.append(th.cat([color_latent_selection[[i], color_layer:]] * (drop_length - 4 * section_length))) 64 | color_latents.append(latents[drop_end:, color_layer:]) 65 | color_latents = th.cat(color_latents, axis=0) 66 | 67 | color_latents = ar.gaussian_filter(color_latents, 5) 68 | 69 | # set upper layers of latent sequence to the colored sequence 70 | latents[:, color_layer:] = color_latents 71 | 72 | return latents 73 | 74 | 75 | def get_noise(height, width, scale, num_scales, args): 76 | if width > 256: 77 | return None 78 | 79 | lo_onsets = 1.25 * args.low_onsets[:, None, None, None].cuda() 80 | hi_onsets = 1.25 * args.high_onsets[:, None, None, None].cuda() 81 | 82 | noise_noisy = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 5) 83 | 84 | noise = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 128) 85 | if width > 8: 86 | noise = lo_onsets * noise_noisy + (1 - lo_onsets) * noise 87 | noise = hi_onsets * noise_noisy + (1 - hi_onsets) * noise 88 | 89 | noise /= noise.std() * 2.5 90 | 91 | return noise.cpu() 92 | 93 | 94 | def get_bends(args): 95 | # repeat the intermediate features outwards on both sides (2:1 aspect ratio) 96 | # + add some noise to give the whole thing a little variation (disguises the repetition) 97 | transform = th.nn.Sequential( 98 | th.nn.ReplicationPad2d((2, 2, 0, 0)), ar.AddNoise(0.025 * th.randn(size=(1, 1, 4, 8), device="cuda")), 99 | ) 100 | bends = [{"layer": 0, "transform": transform}] 101 | 102 | # during the drop, create scrolling effect 103 | drop_start = int(5591 * (45 / args.duration)) 104 | drop_end = int(5591 * (135 / args.duration)) 105 | 106 | # calculate length of loops, number of loops, and remainder at end of drop 107 | scroll_loop_length = int(6 * args.fps) 108 | scroll_loop_num = int((drop_end - drop_start) / scroll_loop_length) 109 | scroll_trunc = (drop_end - drop_start) - scroll_loop_num * scroll_loop_length 110 | 111 | # apply network bending to 4th layer in StyleGAN 112 | # lower layer network bends have more fluid outcomes 113 | tl = 4 114 | h = 2 ** tl 115 | w = 2 * h 116 | 117 | # create values between 0 and 1 corresponding to fraction of scroll from left to right completed 118 | # all 0s during intro 119 | intro_tl8 = np.zeros(drop_start) 120 | # repeating linear interpolation from 0 to 1 during drop 121 | loops_tl8 = np.concatenate([np.linspace(0, w, scroll_loop_length)] * scroll_loop_num) 122 | # truncated interp 123 | last_loop_tl8 = np.linspace(0, w, scroll_loop_length)[:scroll_trunc] 124 | # static at final truncated value during outro 125 | outro_tl8 = 
np.ones(args.n_frames - drop_end) * np.linspace(0, w, scroll_loop_length)[scroll_trunc + 1] 126 | 127 | # create 2D array of translations in x and y directions 128 | x_tl8 = np.concatenate([intro_tl8, loops_tl8, last_loop_tl8, outro_tl8]) 129 | y_tl8 = np.zeros(args.n_frames) 130 | translation = (th.tensor([x_tl8, y_tl8]).float().T)[: args.n_frames] 131 | 132 | # smooth the transition from intro to drop to prevent jerk 133 | translation.T[0, drop_start - args.fps : drop_start + args.fps] = ar.gaussian_filter( 134 | translation.T[0, drop_start - 5 * args.fps : drop_start + 5 * args.fps], 5 135 | )[4 * args.fps : -4 * args.fps] 136 | 137 | class Translate(NetworkBend): 138 | """From audioreactive/examples/bend.py""" 139 | 140 | def __init__(self, modulation, h, w, noise): 141 | sequential_fn = lambda b: th.nn.Sequential( 142 | th.nn.ReflectionPad2d((int(w / 2), int(w / 2), 0, 0)), # < Reflect out to 5x width (so that after 143 | th.nn.ReflectionPad2d((w, w, 0, 0)), # < translating w pixels, center crop gives 144 | th.nn.ReflectionPad2d((w, 0, 0, 0)), # < same features as translating 0 pixels) 145 | AddNoise(noise), # add some noise to disguise reflections 146 | kT.Translate(b), 147 | kA.CenterCrop((h, w)), 148 | ) 149 | super(Translate, self).__init__(sequential_fn, modulation) 150 | 151 | # create static noise for translate bend 152 | noise = 0.2 * th.randn((1, 1, h, 5 * w), device="cuda") 153 | # create function which returns an initialized Translate object when fed a batch of modulation 154 | # this is so that creation of the object is delayed until the specific batch is sent into the generator 155 | # (there's probably an easier way to do this without the kornia transforms, e.g. using Broad et al.'s transform implementations) 156 | transform = lambda batch: partial(Translate, h=h, w=w, noise=noise)(batch) 157 | bends += [{"layer": tl, "transform": transform, "modulation": translation}] # add network bend to list dict 158 | 159 | return bends 160 | -------------------------------------------------------------------------------- /audioreactive/examples/temper.py: -------------------------------------------------------------------------------- 1 | """ 2 | This file shows an example of spatial control of the noise using a simple circular mask 3 | The latents are a chromagram weighted sequence, modulated by drum onsets 4 | """ 5 | 6 | import scipy.ndimage.filters as ndi 7 | import torch as th 8 | 9 | import audioreactive as ar 10 | 11 | OVERRIDE = dict(audio_file="audioreactive/examples/Wavefunk - Temper.mp3", out_size=1024) 12 | 13 | 14 | def initialize(args): 15 | # these onsets can definitely use some tweaking, the drum reactivity isn't great for this one 16 | # the main bass makes it hard to identify both the kick and the snare because it is so loud and covers the whole spectrum 17 | args.lo_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmax=150, smooth=5, clip=97, power=2) 18 | args.hi_onsets = ar.onsets(args.audio, args.sr, args.n_frames, fmin=500, smooth=5, clip=99, power=2) 19 | return args 20 | 21 | 22 | def get_latents(selection, args): 23 | # create chromagram weighted sequence 24 | chroma = ar.chroma(args.audio, args.sr, args.n_frames) 25 | chroma_latents = ar.chroma_weight_latents(chroma, selection) 26 | latents = ar.gaussian_filter(chroma_latents, 4) 27 | 28 | # expand onsets to latent shape 29 | lo_onsets = args.lo_onsets[:, None, None] 30 | hi_onsets = args.hi_onsets[:, None, None] 31 | 32 | # modulate latents to specific latent vectors 33 | latents = hi_onsets * 
selection[[-4]] + (1 - hi_onsets) * latents 34 | latents = lo_onsets * selection[[-7]] + (1 - lo_onsets) * latents 35 | 36 | latents = ar.gaussian_filter(latents, 2, causal=0.2) 37 | 38 | return latents 39 | 40 | 41 | def circular_mask(h, w, center=None, radius=None, soft=0): 42 | if center is None: # use the middle of the image 43 | center = (int(w / 2), int(h / 2)) 44 | if radius is None: # use the smallest distance between the center and image walls 45 | radius = min(center[0], center[1], w - center[0], h - center[1]) 46 | 47 | Y, X = np.ogrid[:h, :w] 48 | dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2) 49 | mask = dist_from_center <= radius 50 | 51 | if soft > 0: 52 | mask = ndi.gaussian_filter(mask, sigma=int(round(soft))) # blur mask for smoother transition 53 | 54 | return th.from_numpy(mask) 55 | 56 | 57 | def get_noise(height, width, scale, num_scales, args): 58 | if width > 256: # larger sizes don't fit in VRAM, just use default or randomize 59 | return None 60 | 61 | # expand onsets to noise shape 62 | # send to GPU as gaussian_filter on large noise tensors with high standard deviation is slow 63 | lo_onsets = args.lo_onsets[:, None, None, None].cuda() 64 | hi_onsets = args.hi_onsets[:, None, None, None].cuda() 65 | 66 | # 1s inside circle of radius, 0s outside 67 | mask = circular_mask(height, width, radius=int(width / 2), soft=2)[None, None, ...].float().cuda() 68 | 69 | # create noise which changes quickly (small standard deviation smoothing) 70 | noise_noisy = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 5) 71 | 72 | # create noise which changes slowly (large standard deviation smoothing) 73 | noise = ar.gaussian_filter(th.randn((args.n_frames, 1, height, width), device="cuda"), 128) 74 | 75 | # for lower layers, noise inside circle are affected by low onsets 76 | if width < 128: 77 | noise = 2 * mask * lo_onsets * noise_noisy + (1 - mask) * (1 - lo_onsets) * noise 78 | # for upper layers, noise outside circle are affected by high onsets 79 | if width > 32: 80 | noise = 0.75 * (1 - mask) * hi_onsets * noise_noisy + mask * (1 - 0.75 * hi_onsets) * noise 81 | 82 | # ensure amplitude of noise is close to standard normal distribution (dividing by std. dev. 
gets it exactly there) 83 | noise /= noise.std() * 2 84 | 85 | return noise.cpu() 86 | -------------------------------------------------------------------------------- /audioreactive/latent.py: -------------------------------------------------------------------------------- 1 | import gc 2 | 3 | import numpy as np 4 | import torch as th 5 | from scipy import interpolate 6 | 7 | from models.stylegan2 import Generator 8 | from .signal import gaussian_filter 9 | 10 | # ==================================================================================== 11 | # ================================= latent/noise ops ================================= 12 | # ==================================================================================== 13 | 14 | 15 | def chroma_weight_latents(chroma, latents): 16 | """Creates chromagram weighted latent sequence 17 | 18 | Args: 19 | chroma (th.tensor): Chromagram 20 | latents (th.tensor): Latents (must have same number as number of notes in chromagram) 21 | 22 | Returns: 23 | th.tensor: Chromagram weighted latent sequence 24 | """ 25 | base_latents = (chroma[..., None, None] * latents[None, ...]).sum(1) 26 | return base_latents 27 | 28 | 29 | def slerp(val, low, high): 30 | """Interpolation along geodesic of n-dimensional unit sphere 31 | from https://github.com/soumith/dcgan.torch/issues/14#issuecomment-200025792 32 | 33 | Args: 34 | val (float): Value between 0 and 1 representing fraction of interpolation completed 35 | low (float): Starting value 36 | high (float): Ending value 37 | 38 | Returns: 39 | float: Interpolated value 40 | """ 41 | omega = np.arccos(np.clip(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1, 1)) 42 | so = np.sin(omega) 43 | if so == 0: 44 | return (1.0 - val) * low + val * high # L'Hopital's rule/LERP 45 | return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high 46 | 47 | 48 | def slerp_loops(latent_selection, n_frames, n_loops, smoothing=1, loop=True): 49 | """Get looping latents using geodesic interpolation. Total length of n_frames with n_loops repeats. 50 | 51 | Args: 52 | latent_selection (th.tensor): Set of latents to loop between (in order) 53 | n_frames (int): Total length of output looping sequence 54 | n_loops (int): Number of times to loop 55 | smoothing (int, optional): Standard deviation of gaussian smoothing kernel. Defaults to 1. 56 | loop (bool, optional): Whether to return to first latent. Defaults to True. 
57 | 58 | Returns: 59 | th.tensor: Sequence of smoothly looping latents 60 | """ 61 | if loop: 62 | latent_selection = np.concatenate([latent_selection, latent_selection[[0]]]) 63 | 64 | base_latents = [] 65 | for n in range(len(latent_selection)): 66 | for val in np.linspace(0.0, 1.0, int(n_frames // max(1, n_loops) // len(latent_selection))): 67 | base_latents.append( 68 | th.from_numpy( 69 | slerp( 70 | val, 71 | latent_selection[n % len(latent_selection)][0], 72 | latent_selection[(n + 1) % len(latent_selection)][0], 73 | ) 74 | ) 75 | ) 76 | base_latents = th.stack(base_latents) 77 | base_latents = gaussian_filter(base_latents, smoothing) 78 | base_latents = th.cat([base_latents] * int(n_frames / len(base_latents)), axis=0) 79 | base_latents = th.cat([base_latents[:, None, :]] * 18, axis=1) 80 | if n_frames - len(base_latents) != 0: 81 | base_latents = th.cat([base_latents, base_latents[0 : n_frames - len(base_latents)]]) 82 | return base_latents 83 | 84 | 85 | def spline_loops(latent_selection, n_frames, n_loops, loop=True): 86 | """Get looping latents using spline interpolation. Total length of n_frames with n_loops repeats. 87 | 88 | Args: 89 | latent_selection (th.tensor): Set of latents to loop between (in order) 90 | n_frames (int): Total length of output looping sequence 91 | n_loops (int): Number of times to loop 92 | loop (bool, optional): Whether to return to first latent. Defaults to True. 93 | 94 | Returns: 95 | th.tensor: Sequence of smoothly looping latents 96 | """ 97 | if loop: 98 | latent_selection = np.concatenate([latent_selection, latent_selection[[0]]]) 99 | 100 | x = np.linspace(0, 1, int(n_frames // max(1, n_loops))) 101 | base_latents = np.zeros((len(x), *latent_selection.shape[1:])) 102 | for lay in range(latent_selection.shape[1]): 103 | for lat in range(latent_selection.shape[2]): 104 | tck = interpolate.splrep(np.linspace(0, 1, latent_selection.shape[0]), latent_selection[:, lay, lat]) 105 | base_latents[:, lay, lat] = interpolate.splev(x, tck) 106 | 107 | base_latents = th.cat([th.from_numpy(base_latents)] * int(n_frames / len(base_latents)), axis=0) 108 | if n_frames - len(base_latents) > 0: 109 | base_latents = th.cat([base_latents, base_latents[0 : n_frames - len(base_latents)]]) 110 | return base_latents[:n_frames] 111 | 112 | 113 | def wrapping_slice(tensor, start, length, return_indices=False): 114 | """Gets slice of tensor of a given length that wraps around to beginning 115 | 116 | Args: 117 | tensor (th.tensor): Tensor to slice 118 | start (int): Starting index 119 | length (int): Size of slice 120 | return_indices (bool, optional): Whether to return indices rather than values. Defaults to False. 
121 | 122 | Returns: 123 | th.tensor: Values or indices of slice 124 | """ 125 | if start + length <= tensor.shape[0]: 126 | indices = th.arange(start, start + length) 127 | else: 128 | indices = th.cat((th.arange(start, tensor.shape[0]), th.arange(0, (start + length) % tensor.shape[0]))) 129 | if tensor.shape[0] == 1: 130 | indices = th.zeros(1, dtype=th.int64) 131 | if return_indices: 132 | return indices 133 | return tensor[indices] 134 | 135 | 136 | def generate_latents(n_latents, ckpt, G_res, noconst=False, latent_dim=512, n_mlp=8, channel_multiplier=2): 137 | """Generates random, mapped latents 138 | 139 | Args: 140 | n_latents (int): Number of mapped latents to generate 141 | ckpt (str): Generator checkpoint to use 142 | G_res (int): Generator's training resolution 143 | noconst (bool, optional): Whether the generator was trained without constant starting layer. Defaults to False. 144 | latent_dim (int, optional): Size of generator's latent vectors. Defaults to 512. 145 | n_mlp (int, optional): Number of layers in the generator's mapping network. Defaults to 8. 146 | channel_multiplier (int, optional): Scaling multiplier for generator's channel depth. Defaults to 2. 147 | 148 | Returns: 149 | th.tensor: Set of mapped latents 150 | """ 151 | generator = Generator( 152 | G_res, latent_dim, n_mlp, channel_multiplier=channel_multiplier, constant_input=not noconst, checkpoint=ckpt, 153 | ).cuda() 154 | zs = th.randn((n_latents, latent_dim), device="cuda") 155 | latent_selection = generator(zs, map_latents=True).cpu() 156 | del generator, zs 157 | gc.collect() 158 | th.cuda.empty_cache() 159 | return latent_selection 160 | 161 | 162 | def save_latents(latents, filename): 163 | """Saves latent vectors to file 164 | 165 | Args: 166 | latents (th.tensor): Latent vector(s) to save 167 | filename (str): Filename to save to 168 | """ 169 | np.save(filename, latents) 170 | 171 | 172 | def load_latents(filename): 173 | """Load latents from numpy file 174 | 175 | Args: 176 | filename (str): Filename to load from 177 | 178 | Returns: 179 | th.tensor: Latent vectors 180 | """ 181 | return th.from_numpy(np.load(filename)) 182 | 183 | 184 | def _perlinterpolant(t): 185 | return t * t * t * (t * (t * 6 - 15) + 10) 186 | 187 | 188 | def perlin_noise(shape, res, tileable=(True, False, False), interpolant=_perlinterpolant): 189 | """Generate a 3D tensor of perlin noise. 190 | 191 | Args: 192 | shape: The shape of the generated tensor (tuple of three ints). This must be a multiple of res. 193 | res: The number of periods of noise to generate along each axis (tuple of three ints). Note shape must be a multiple of res. 194 | tileable: If the noise should be tileable along each axis (tuple of three bools). Defaults to (False, False, False). 195 | interpolant: The interpolation function, defaults to t*t*t*(t*(t*6 - 15) + 10). 196 | 197 | Returns: 198 | A tensor of shape shape with the generated noise. 199 | 200 | Raises: 201 | ValueError: If shape is not a multiple of res. 
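Example (illustrative values; each entry of shape must be an integer multiple of the matching res entry):
noise = perlin_noise(shape=(64, 32, 32), res=(4, 2, 2))  # -> CUDA tensor of shape (64, 32, 32)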
202 | """ 203 | delta = (res[0] / shape[0], res[1] / shape[1], res[2] / shape[2]) 204 | d = (shape[0] // res[0], shape[1] // res[1], shape[2] // res[2]) 205 | grid = np.mgrid[0 : res[0] : delta[0], 0 : res[1] : delta[1], 0 : res[2] : delta[2]] 206 | grid = grid.transpose(1, 2, 3, 0) % 1 207 | grid = th.from_numpy(grid).cuda() 208 | # Gradients 209 | theta = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1, res[2] + 1) 210 | phi = 2 * np.pi * np.random.rand(res[0] + 1, res[1] + 1, res[2] + 1) 211 | gradients = np.stack((np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)), axis=3) 212 | if tileable[0]: 213 | gradients[-1, :, :] = gradients[0, :, :] 214 | if tileable[1]: 215 | gradients[:, -1, :] = gradients[:, 0, :] 216 | if tileable[2]: 217 | gradients[:, :, -1] = gradients[:, :, 0] 218 | gradients = gradients.repeat(d[0], 0).repeat(d[1], 1).repeat(d[2], 2) 219 | gradients = th.from_numpy(gradients).cuda() 220 | g000 = gradients[: -d[0], : -d[1], : -d[2]] 221 | g100 = gradients[d[0] :, : -d[1], : -d[2]] 222 | g010 = gradients[: -d[0], d[1] :, : -d[2]] 223 | g110 = gradients[d[0] :, d[1] :, : -d[2]] 224 | g001 = gradients[: -d[0], : -d[1], d[2] :] 225 | g101 = gradients[d[0] :, : -d[1], d[2] :] 226 | g011 = gradients[: -d[0], d[1] :, d[2] :] 227 | g111 = gradients[d[0] :, d[1] :, d[2] :] 228 | # Ramps 229 | n000 = th.sum(th.stack((grid[:, :, :, 0], grid[:, :, :, 1], grid[:, :, :, 2]), axis=3) * g000, 3) 230 | n100 = th.sum(th.stack((grid[:, :, :, 0] - 1, grid[:, :, :, 1], grid[:, :, :, 2]), axis=3) * g100, 3) 231 | n010 = th.sum(th.stack((grid[:, :, :, 0], grid[:, :, :, 1] - 1, grid[:, :, :, 2]), axis=3) * g010, 3) 232 | n110 = th.sum(th.stack((grid[:, :, :, 0] - 1, grid[:, :, :, 1] - 1, grid[:, :, :, 2]), axis=3) * g110, 3) 233 | n001 = th.sum(th.stack((grid[:, :, :, 0], grid[:, :, :, 1], grid[:, :, :, 2] - 1), axis=3) * g001, 3) 234 | n101 = th.sum(th.stack((grid[:, :, :, 0] - 1, grid[:, :, :, 1], grid[:, :, :, 2] - 1), axis=3) * g101, 3) 235 | n011 = th.sum(th.stack((grid[:, :, :, 0], grid[:, :, :, 1] - 1, grid[:, :, :, 2] - 1), axis=3) * g011, 3) 236 | n111 = th.sum(th.stack((grid[:, :, :, 0] - 1, grid[:, :, :, 1] - 1, grid[:, :, :, 2] - 1), axis=3) * g111, 3) 237 | # Interpolation 238 | t = interpolant(grid) 239 | n00 = n000 * (1 - t[:, :, :, 0]) + t[:, :, :, 0] * n100 240 | n10 = n010 * (1 - t[:, :, :, 0]) + t[:, :, :, 0] * n110 241 | n01 = n001 * (1 - t[:, :, :, 0]) + t[:, :, :, 0] * n101 242 | n11 = n011 * (1 - t[:, :, :, 0]) + t[:, :, :, 0] * n111 243 | n0 = (1 - t[:, :, :, 1]) * n00 + t[:, :, :, 1] * n10 244 | n1 = (1 - t[:, :, :, 1]) * n01 + t[:, :, :, 1] * n11 245 | perlin = (1 - t[:, :, :, 2]) * n0 + t[:, :, :, 2] * n1 246 | return perlin * 2 - 1 # stretch from -1 to 1 247 | -------------------------------------------------------------------------------- /audioreactive/util.py: -------------------------------------------------------------------------------- 1 | import librosa as rosa 2 | import librosa.display 3 | import matplotlib.pyplot as plt 4 | import numpy as np 5 | 6 | # ==================================================================================== 7 | # ==================================== utilities ===================================== 8 | # ==================================================================================== 9 | 10 | 11 | def info(arr): 12 | """Shows statistics and shape information of (lists of) np.arrays/th.tensors 13 | 14 | Args: 15 | arr (np.array/th.tensor/list): List of or single np.array or th.tensor 16 | """ 17 | if 
isinstance(arr, list): 18 | print([(list(a.shape), f"{a.min():.2f}", f"{a.mean():.2f}", f"{a.max():.2f}") for a in arr]) 19 | else: 20 | print(list(arr.shape), f"{arr.min():.2f}", f"{arr.mean():.2f}", f"{arr.max():.2f}") 21 | 22 | 23 | def plot_signals(signals): 24 | """Shows plot of (multiple) 1D signals 25 | 26 | Args: 27 | signals (np.array/th.tensor): List of signals (1 non-unit dimension) 28 | """ 29 | plt.figure(figsize=(16, 4 * len(signals))) 30 | for sbplt, y in enumerate(signals): 31 | try: 32 | y = y.cpu().numpy()  # convert (possibly GPU) tensors to numpy before plotting 33 | except: 34 | pass 35 | plt.subplot(len(signals), 1, sbplt + 1) 36 | plt.plot(y.squeeze()) 37 | plt.tight_layout() 38 | plt.show() 39 | 40 | 41 | def plot_spectra(spectra, chroma=False): 42 | """Shows plot of (multiple) spectrograms 43 | 44 | Args: 45 | spectra (np.array/th.tensor): List of spectrograms 46 | chroma (bool, optional): Whether to plot with chromagram y-axis label. Defaults to False. 47 | """ 48 | fig, axes = plt.subplots(len(spectra), 1, figsize=(16, 4 * len(spectra))) 49 | for ax, spectrum in zip(axes if len(spectra) > 1 else [axes], spectra): 50 | try: 51 | spectrum = spectrum.cpu().numpy() 52 | except: 53 | pass 54 | if spectrum.shape[1] == 12: 55 | spectrum = spectrum.T 56 | rosa.display.specshow(spectrum, y_axis="chroma" if chroma else None, x_axis="time", ax=ax) 57 | plt.tight_layout() 58 | plt.show() 59 | 60 | 61 | def plot_audio(audio, sr): 62 | """Shows spectrogram of audio signal 63 | 64 | Args: 65 | audio (np.array): Audio signal to be plotted 66 | sr (int): Sampling rate of the audio 67 | """ 68 | plt.figure(figsize=(16, 9)) 69 | rosa.display.specshow( 70 | rosa.power_to_db(rosa.feature.melspectrogram(y=audio, sr=sr), ref=np.max), y_axis="mel", x_axis="time" 71 | ) 72 | plt.colorbar(format="%+2.f dB") 73 | plt.tight_layout() 74 | plt.show() 75 | 76 | 77 | def plot_chroma_comparison(audio, sr): 78 | """Shows plot comparing different chromagram strategies.
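Plots the cens, cqt, deep, clp and stft chroma variants side by side for the same audio signal.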
79 | 80 | Args: 81 | audio (np.array): Audio signal to be plotted 82 | sr (int): Sampling rate of the audio 83 | """ 84 | fig, ax = plt.subplots(nrows=2, ncols=3, figsize=(16, 9)) 85 | for col, types in enumerate([["cens", "cqt"], ["deep", "clp"], ["stft"]]): 86 | for row, type in enumerate(types): 87 | ch = raw_chroma(audio, sr, type=type) 88 | if ch.shape[1] == 12: 89 | ch = ch.T 90 | librosa.display.specshow(ch, y_axis="chroma", x_axis="time", ax=ax[row, col]) 91 | ax[row, col].set(title=type) 92 | ax[row, col].label_outer() 93 | plt.tight_layout() 94 | plt.show() 95 | -------------------------------------------------------------------------------- /contrastive_learner.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import random 3 | from functools import wraps 4 | 5 | import torch 6 | from torch import nn 7 | import torch.nn.functional as F 8 | 9 | 10 | def identity(x): 11 | return x 12 | 13 | 14 | def default(val, def_val): 15 | return def_val if val is None else val 16 | 17 | 18 | def flatten(t): 19 | return t.reshape(t.shape[0], -1) 20 | 21 | 22 | def safe_concat(arr, el, dim=0): 23 | if arr is None: 24 | return el 25 | return torch.cat((arr, el), dim=dim) 26 | 27 | 28 | def singleton(cache_key): 29 | def inner_fn(fn): 30 | @wraps(fn) 31 | def wrapper(self, *args, **kwargs): 32 | instance = getattr(self, cache_key) 33 | if instance is not None: 34 | return instance 35 | 36 | instance = fn(self, *args, **kwargs) 37 | setattr(self, cache_key, instance) 38 | return instance 39 | 40 | return wrapper 41 | 42 | return inner_fn 43 | 44 | 45 | # losses 46 | 47 | 48 | def contrastive_loss(queries, keys, temperature=0.1): 49 | b, device = queries.shape[0], queries.device 50 | logits = queries @ keys.t() 51 | logits = logits - logits.max(dim=-1, keepdim=True).values 52 | logits /= temperature 53 | return F.cross_entropy(logits, torch.arange(b, device=device)) 54 | 55 | 56 | def nt_xent_loss(queries, keys, temperature=0.1): 57 | b, device = queries.shape[0], queries.device 58 | 59 | n = b * 2 60 | projs = torch.cat((queries, keys)) 61 | logits = projs @ projs.t() 62 | 63 | mask = torch.eye(n, device=device).bool() 64 | logits = logits[~mask].reshape(n, n - 1) 65 | logits /= temperature 66 | 67 | labels = torch.cat(((torch.arange(b, device=device) + b - 1), torch.arange(b, device=device)), dim=0) 68 | loss = F.cross_entropy(logits, labels, reduction="sum") 69 | loss /= 2 * (b - 1) 70 | return loss 71 | 72 | 73 | # augmentation utils 74 | 75 | 76 | class RandomApply(nn.Module): 77 | def __init__(self, fn, p): 78 | super().__init__() 79 | self.fn = fn 80 | self.p = p 81 | 82 | def forward(self, x): 83 | x_out = [] 84 | for ex in x: 85 | if random.random() > self.p: 86 | x_out.append(ex[None, :]) 87 | else: 88 | x_out.append(self.fn(ex)) 89 | return torch.cat(x_out) 90 | 91 | 92 | # exponential moving average 93 | 94 | 95 | class EMA: 96 | def __init__(self, beta): 97 | super().__init__() 98 | self.beta = beta 99 | 100 | def update_average(self, old, new): 101 | if old is None: 102 | return new 103 | return old * self.beta + (1 - self.beta) * new 104 | 105 | 106 | def update_moving_average(ema_updater, ma_model, current_model): 107 | for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): 108 | old_weight, up_weight = ma_params.data, current_params.data 109 | ma_params.data = ema_updater.update_average(old_weight, up_weight) 110 | 111 | 112 | # hidden layer extractor class 113 | 114 | 115 | class 
OutputHiddenLayer(nn.Module): 116 | def __init__(self, net, layer=-2): 117 | super().__init__() 118 | self.net = net 119 | self.layer = layer 120 | 121 | self.hidden = None 122 | self._register_hook() 123 | 124 | def _find_layer(self): 125 | if type(self.layer) == str: 126 | modules = dict([*self.net.named_modules()]) 127 | return modules.get(self.layer, None) 128 | elif type(self.layer) == int: 129 | children = [*self.net.children()] 130 | return children[self.layer] 131 | elif type(self.layer) == tuple: 132 | children = [*self.net.children()] 133 | grand_children = [*children[self.layer[0]].children()] 134 | return grand_children[self.layer[1]] 135 | return None 136 | 137 | def _register_hook(self): 138 | def hook(_, __, output): 139 | self.hidden = output 140 | 141 | layer = self._find_layer() 142 | assert layer is not None, f"hidden layer ({self.layer}) not found" 143 | handle = layer.register_forward_hook(hook) 144 | 145 | def forward(self, x): 146 | if self.layer == -1: 147 | return self.net(x) 148 | 149 | _ = self.net(x) 150 | hidden = self.hidden 151 | self.hidden = None 152 | assert hidden is not None, f"hidden layer {self.layer} never emitted an output" 153 | return hidden 154 | 155 | 156 | class ContrastiveLearner(nn.Module): 157 | def __init__( 158 | self, 159 | net, 160 | image_size, 161 | hidden_layer=-2, 162 | project_hidden=True, 163 | project_dim=128, 164 | use_nt_xent_loss=False, 165 | use_bilinear=False, 166 | use_momentum=False, 167 | momentum_value=0.999, 168 | key_encoder=None, 169 | temperature=0.1, 170 | ): 171 | super().__init__() 172 | self.net = OutputHiddenLayer(net, layer=hidden_layer) 173 | 174 | self.temperature = temperature 175 | self.use_nt_xent_loss = use_nt_xent_loss 176 | 177 | self.project_hidden = project_hidden 178 | self.projection = None 179 | self.project_dim = project_dim 180 | 181 | self.use_bilinear = use_bilinear 182 | self.bilinear_w = None 183 | 184 | self.use_momentum = use_momentum 185 | self.ema_updater = EMA(momentum_value) 186 | self.key_encoder = key_encoder 187 | 188 | # for accumulating queries and keys across calls 189 | self.queries = None 190 | self.keys = None 191 | 192 | # send a mock image tensor to instantiate parameters 193 | init = torch.randn(1, 3, image_size, image_size, device="cuda") 194 | self.forward(init) 195 | 196 | @singleton("key_encoder") 197 | def _get_key_encoder(self): 198 | key_encoder = copy.deepcopy(self.net) 199 | key_encoder._register_hook() 200 | return key_encoder 201 | 202 | @singleton("bilinear_w") 203 | def _get_bilinear(self, hidden): 204 | _, dim = hidden.shape 205 | return nn.Parameter(torch.eye(dim, device=device, dtype=dtype)).to(hidden) 206 | 207 | @singleton("projection") 208 | def _get_projection_fn(self, hidden): 209 | _, dim = hidden.shape 210 | return nn.Sequential( 211 | nn.Linear(dim, dim, bias=False), nn.LeakyReLU(inplace=True), nn.Linear(dim, self.project_dim, bias=False) 212 | ).to(hidden) 213 | 214 | def reset_moving_average(self): 215 | assert self.use_momentum, "must be using momentum method for key encoder" 216 | del self.key_encoder 217 | self.key_encoder = None 218 | 219 | def update_moving_average(self): 220 | assert self.key_encoder is not None, "key encoder has not been created yet" 221 | self.key_encoder = update_moving_average(self.ema_updater, self.key_encoder, self.net) 222 | 223 | def calculate_loss(self): 224 | assert self.queries is not None and self.keys is not None, "no queries or keys accumulated" 225 | loss_fn = nt_xent_loss if self.use_nt_xent_loss else 
contrastive_loss 226 | loss = loss_fn(self.queries, self.keys, temperature=self.temperature) 227 | self.queries = self.keys = None 228 | return loss 229 | 230 | def forward(self, x, aug_x, accumulate=False): 231 | b, c, h, w, device = *x.shape, x.device 232 | 233 | queries = self.net(aug_x) 234 | 235 | key_encoder = self.net if not self.use_momentum else self._get_key_encoder() 236 | keys = key_encoder(aug_x) 237 | 238 | if self.use_momentum: 239 | keys = keys.detach() 240 | 241 | queries, keys = map(flatten, (queries, keys)) 242 | 243 | if self.use_bilinear: 244 | W = self._get_bilinear(keys) 245 | keys = (W @ keys.t()).t() 246 | 247 | project_fn = self._get_projection_fn(queries) if self.project_hidden else identity 248 | queries, keys = map(project_fn, (queries, keys)) 249 | 250 | self.queries = safe_concat(self.queries, queries) 251 | self.keys = safe_concat(self.keys, keys) 252 | 253 | return self.calculate_loss() if not accumulate else None 254 | -------------------------------------------------------------------------------- /convert_weight.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | import pickle 5 | import sys 6 | 7 | import numpy as np 8 | import torch 9 | from torchvision import utils 10 | 11 | from model import Discriminator, Generator 12 | 13 | 14 | def convert_modconv(vars, source_name, target_name, flip=False): 15 | weight = vars[source_name + "/weight"].value().eval() 16 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 17 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 18 | noise = vars[source_name + "/noise_strength"].value().eval() 19 | bias = vars[source_name + "/bias"].value().eval() 20 | 21 | dic = { 22 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 23 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 24 | "conv.modulation.bias": mod_bias + 1, 25 | "noise.weight": np.array([noise]), 26 | "activate.bias": bias, 27 | } 28 | 29 | dic_torch = {} 30 | 31 | for k, v in dic.items(): 32 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 33 | 34 | if flip: 35 | dic_torch[target_name + ".conv.weight"] = torch.flip(dic_torch[target_name + ".conv.weight"], [3, 4]) 36 | 37 | return dic_torch 38 | 39 | 40 | def convert_conv(vars, source_name, target_name, bias=True, start=0): 41 | weight = vars[source_name + "/weight"].value().eval() 42 | 43 | dic = {"weight": weight.transpose((3, 2, 0, 1))} 44 | 45 | if bias: 46 | dic["bias"] = vars[source_name + "/bias"].value().eval() 47 | 48 | dic_torch = {} 49 | 50 | dic_torch[target_name + f".{start}.weight"] = torch.from_numpy(dic["weight"]) 51 | 52 | if bias: 53 | dic_torch[target_name + f".{start + 1}.bias"] = torch.from_numpy(dic["bias"]) 54 | 55 | return dic_torch 56 | 57 | 58 | def convert_torgb(vars, source_name, target_name): 59 | weight = vars[source_name + "/weight"].value().eval() 60 | mod_weight = vars[source_name + "/mod_weight"].value().eval() 61 | mod_bias = vars[source_name + "/mod_bias"].value().eval() 62 | bias = vars[source_name + "/bias"].value().eval() 63 | 64 | dic = { 65 | "conv.weight": np.expand_dims(weight.transpose((3, 2, 0, 1)), 0), 66 | "conv.modulation.weight": mod_weight.transpose((1, 0)), 67 | "conv.modulation.bias": mod_bias + 1, 68 | "bias": bias.reshape((1, 3, 1, 1)), 69 | } 70 | 71 | dic_torch = {} 72 | 73 | for k, v in dic.items(): 74 | dic_torch[target_name + "." 
+ k] = torch.from_numpy(v) 75 | 76 | return dic_torch 77 | 78 | 79 | def convert_dense(vars, source_name, target_name): 80 | weight = vars[source_name + "/weight"].value().eval() 81 | bias = vars[source_name + "/bias"].value().eval() 82 | 83 | dic = {"weight": weight.transpose((1, 0)), "bias": bias} 84 | 85 | dic_torch = {} 86 | 87 | for k, v in dic.items(): 88 | dic_torch[target_name + "." + k] = torch.from_numpy(v) 89 | 90 | return dic_torch 91 | 92 | 93 | def update(state_dict, new): 94 | for k, v in new.items(): 95 | if k not in state_dict: 96 | raise KeyError(k + " is not found") 97 | 98 | if v.shape != state_dict[k].shape: 99 | raise ValueError(f"Shape mismatch: {v.shape} vs {state_dict[k].shape}") 100 | 101 | state_dict[k] = v 102 | 103 | 104 | def discriminator_fill_statedict(statedict, vars, size): 105 | log_size = int(math.log(size, 2)) 106 | 107 | update(statedict, convert_conv(vars, f"{size}x{size}/FromRGB", "convs.0")) 108 | 109 | conv_i = 1 110 | 111 | for i in range(log_size - 2, 0, -1): 112 | reso = 4 * 2 ** i 113 | update( 114 | statedict, convert_conv(vars, f"{reso}x{reso}/Conv0", f"convs.{conv_i}.conv1"), 115 | ) 116 | update( 117 | statedict, convert_conv(vars, f"{reso}x{reso}/Conv1_down", f"convs.{conv_i}.conv2", start=1), 118 | ) 119 | update( 120 | statedict, convert_conv(vars, f"{reso}x{reso}/Skip", f"convs.{conv_i}.skip", start=1, bias=False), 121 | ) 122 | conv_i += 1 123 | 124 | update(statedict, convert_conv(vars, f"4x4/Conv", "final_conv")) 125 | update(statedict, convert_dense(vars, f"4x4/Dense0", "final_linear.0")) 126 | update(statedict, convert_dense(vars, f"Output", "final_linear.1")) 127 | 128 | return statedict 129 | 130 | 131 | def fill_statedict(state_dict, vars, size): 132 | log_size = int(math.log(size, 2)) 133 | 134 | for i in range(8): 135 | update(state_dict, convert_dense(vars, f"G_mapping/Dense{i}", f"style.{i + 1}")) 136 | 137 | update( 138 | state_dict, {"input.input": torch.from_numpy(vars["G_synthesis/4x4/Const/const"].value().eval())}, 139 | ) 140 | 141 | update(state_dict, convert_torgb(vars, "G_synthesis/4x4/ToRGB", "to_rgb1")) 142 | 143 | for i in range(log_size - 2): 144 | reso = 4 * 2 ** (i + 1) 145 | update( 146 | state_dict, convert_torgb(vars, f"G_synthesis/{reso}x{reso}/ToRGB", f"to_rgbs.{i}"), 147 | ) 148 | 149 | update(state_dict, convert_modconv(vars, "G_synthesis/4x4/Conv", "conv1")) 150 | 151 | conv_i = 0 152 | 153 | for i in range(log_size - 2): 154 | reso = 4 * 2 ** (i + 1) 155 | update( 156 | state_dict, convert_modconv(vars, f"G_synthesis/{reso}x{reso}/Conv0_up", f"convs.{conv_i}", flip=True,), 157 | ) 158 | update( 159 | state_dict, convert_modconv(vars, f"G_synthesis/{reso}x{reso}/Conv1", f"convs.{conv_i + 1}"), 160 | ) 161 | conv_i += 2 162 | 163 | for i in range(0, (log_size - 2) * 2 + 1): 164 | update( 165 | state_dict, {f"noises.noise_{i}": torch.from_numpy(vars[f"G_synthesis/noise{i}"].value().eval())}, 166 | ) 167 | 168 | return state_dict 169 | 170 | 171 | if __name__ == "__main__": 172 | device = "cuda" 173 | 174 | parser = argparse.ArgumentParser() 175 | parser.add_argument("--repo", type=str, required=True) 176 | parser.add_argument("--gen", action="store_true") 177 | parser.add_argument("--disc", action="store_true") 178 | parser.add_argument("--channel_multiplier", type=int, default=2) 179 | parser.add_argument("path", metavar="PATH") 180 | 181 | args = parser.parse_args() 182 | 183 | sys.path.append(args.repo) 184 | 185 | import dnnlib 186 | from dnnlib import tflib 187 | 188 | tflib.init_tf() 189 | 190 | 
with open(args.path, "rb") as f: 191 | generator, discriminator, g_ema = pickle.load(f) 192 | 193 | size = g_ema.output_shape[2] 194 | 195 | g = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 196 | state_dict = g.state_dict() 197 | state_dict = fill_statedict(state_dict, g_ema.vars, size) 198 | 199 | g.load_state_dict(state_dict) 200 | 201 | latent_avg = torch.from_numpy(g_ema.vars["dlatent_avg"].value().eval()) 202 | 203 | ckpt = {"g_ema": state_dict, "latent_avg": latent_avg} 204 | 205 | if args.gen: 206 | g_train = Generator(size, 512, 8, channel_multiplier=args.channel_multiplier) 207 | g_train_state = g_train.state_dict() 208 | g_train_state = fill_statedict(g_train_state, generator.vars, size) 209 | ckpt["g"] = g_train_state 210 | 211 | if args.disc: 212 | disc = Discriminator(size, channel_multiplier=args.channel_multiplier) 213 | d_state = disc.state_dict() 214 | d_state = discriminator_fill_statedict(d_state, discriminator.vars, size) 215 | ckpt["d"] = d_state 216 | 217 | name = os.path.splitext(os.path.basename(args.path))[0] 218 | torch.save(ckpt, name + ".pt") 219 | 220 | batch_size = {256: 16, 512: 9, 1024: 4} 221 | n_sample = batch_size.get(size, 25) 222 | 223 | g = g.to(device) 224 | 225 | z = np.random.RandomState(0).randn(n_sample, 512).astype("float32") 226 | 227 | with torch.no_grad(): 228 | img_pt, _ = g([torch.from_numpy(z).to(device)], truncation=0.5, truncation_latent=latent_avg.to(device),) 229 | 230 | Gs_kwargs = dnnlib.EasyDict() 231 | Gs_kwargs.randomize_noise = False 232 | img_tf = g_ema.run(z, None, **Gs_kwargs) 233 | img_tf = torch.from_numpy(img_tf).to(device) 234 | 235 | img_diff = ((img_pt + 1) / 2).clamp(0.0, 1.0) - ((img_tf.to(device) + 1) / 2).clamp(0.0, 1.0) 236 | 237 | img_concat = torch.cat((img_tf, img_pt, img_diff), dim=0) 238 | utils.save_image(img_concat, name + ".png", nrow=n_sample, normalize=True, range=(-1, 1)) 239 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | from io import BytesIO 2 | 3 | import lmdb 4 | import numpy as np 5 | from PIL import Image 6 | from PIL import Image 7 | from torch.utils.data import Dataset 8 | 9 | 10 | class MultiResolutionDataset(Dataset): 11 | def __init__(self, path, transform, resolution=256): 12 | self.env = lmdb.open(path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False,) 13 | 14 | if not self.env: 15 | raise IOError("Cannot open lmdb dataset", path) 16 | 17 | with self.env.begin(write=False) as txn: 18 | self.length = int(txn.get("length".encode("utf-8")).decode("utf-8")) 19 | 20 | self.resolution = resolution 21 | self.transform = transform 22 | 23 | def __len__(self): 24 | return self.length 25 | 26 | def __getitem__(self, index): 27 | while True: 28 | try: 29 | with self.env.begin(write=False) as txn: 30 | key = f"{self.resolution}-{str(index).zfill(5)}".encode("utf-8") 31 | img_bytes = txn.get(key) 32 | 33 | buffer = BytesIO(img_bytes) 34 | img = Image.open(buffer) 35 | break 36 | except: 37 | print(f"ERROR loading image {index}") 38 | index = int(np.random.rand() * self.length) 39 | print(f"Trying again with {index}...") 40 | img = self.transform(img) 41 | 42 | return img 43 | -------------------------------------------------------------------------------- /distributed.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | 3 | import torch 4 | from torch import distributed as 
dist 5 | 6 | 7 | def get_rank(): 8 | if not dist.is_available(): 9 | return 0 10 | 11 | if not dist.is_initialized(): 12 | return 0 13 | 14 | return dist.get_rank() 15 | 16 | 17 | def synchronize(): 18 | if not dist.is_available(): 19 | return 20 | 21 | if not dist.is_initialized(): 22 | return 23 | 24 | world_size = dist.get_world_size() 25 | 26 | if world_size == 1: 27 | return 28 | 29 | dist.barrier() 30 | 31 | 32 | def get_world_size(): 33 | if not dist.is_available(): 34 | return 1 35 | 36 | if not dist.is_initialized(): 37 | return 1 38 | 39 | return dist.get_world_size() 40 | 41 | 42 | def reduce_sum(tensor): 43 | if not dist.is_available(): 44 | return tensor 45 | 46 | if not dist.is_initialized(): 47 | return tensor 48 | 49 | tensor = tensor.clone() 50 | dist.all_reduce(tensor, op=dist.ReduceOp.SUM) 51 | 52 | return tensor 53 | 54 | 55 | def gather_grad(params): 56 | world_size = get_world_size() 57 | 58 | if world_size == 1: 59 | return 60 | 61 | for param in params: 62 | if param.grad is not None: 63 | dist.all_reduce(param.grad.data, op=dist.ReduceOp.SUM) 64 | param.grad.data.div_(world_size) 65 | 66 | 67 | def all_gather(data): 68 | world_size = get_world_size() 69 | 70 | if world_size == 1: 71 | return [data] 72 | 73 | buffer = pickle.dumps(data) 74 | storage = torch.ByteStorage.from_buffer(buffer) 75 | tensor = torch.ByteTensor(storage).to("cuda") 76 | 77 | local_size = torch.IntTensor([tensor.numel()]).to("cuda") 78 | size_list = [torch.IntTensor([0]).to("cuda") for _ in range(world_size)] 79 | dist.all_gather(size_list, local_size) 80 | size_list = [int(size.item()) for size in size_list] 81 | max_size = max(size_list) 82 | 83 | tensor_list = [] 84 | for _ in size_list: 85 | tensor_list.append(torch.ByteTensor(size=(max_size,)).to("cuda")) 86 | 87 | if local_size != max_size: 88 | padding = torch.ByteTensor(size=(max_size - local_size,)).to("cuda") 89 | tensor = torch.cat((tensor, padding), 0) 90 | 91 | dist.all_gather(tensor_list, tensor) 92 | 93 | data_list = [] 94 | 95 | for size, tensor in zip(size_list, tensor_list): 96 | buffer = tensor.cpu().numpy().tobytes()[:size] 97 | data_list.append(pickle.loads(buffer)) 98 | 99 | return data_list 100 | 101 | 102 | def reduce_loss_dict(loss_dict): 103 | world_size = get_world_size() 104 | 105 | if world_size < 2: 106 | return loss_dict 107 | 108 | with torch.no_grad(): 109 | keys = [] 110 | losses = [] 111 | 112 | for k in sorted(loss_dict.keys()): 113 | keys.append(k) 114 | losses.append(loss_dict[k]) 115 | 116 | losses = torch.stack(losses, 0) 117 | dist.reduce(losses, dst=0) 118 | 119 | if dist.get_rank() == 0: 120 | losses /= world_size 121 | 122 | reduced_losses = {k: v for k, v in zip(keys, losses)} 123 | 124 | return reduced_losses 125 | -------------------------------------------------------------------------------- /generate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import torch 3 | from torchvision import utils 4 | from models.stylegan2 import Generator 5 | from tqdm import tqdm 6 | 7 | 8 | def generate(args, g_ema, device, mean_latent): 9 | 10 | with torch.no_grad(): 11 | g_ema.eval() 12 | for i in tqdm(range(args.pics)): 13 | sample_z = torch.randn(args.sample, args.latent, device=device) 14 | 15 | sample, _ = g_ema([sample_z], truncation=args.truncation, truncation_latent=mean_latent) 16 | 17 | utils.save_image( 18 | sample, f"sample/{str(i).zfill(6)}.png", nrow=1, normalize=True, range=(-1, 1), 19 | ) 20 | 21 | 22 | if __name__ == "__main__": 23 | 
device = "cuda" 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument("--size", type=int, default=1024) 28 | parser.add_argument("--sample", type=int, default=1) 29 | parser.add_argument("--pics", type=int, default=20) 30 | parser.add_argument("--truncation", type=float, default=1) 31 | parser.add_argument("--truncation_mean", type=int, default=4096) 32 | parser.add_argument("--ckpt", type=str, default="stylegan2-ffhq-config-f.pt") 33 | parser.add_argument("--channel_multiplier", type=int, default=2) 34 | 35 | args = parser.parse_args() 36 | 37 | args.latent = 512 38 | args.n_mlp = 8 39 | 40 | g_ema = Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device) 41 | checkpoint = torch.load(args.ckpt) 42 | 43 | g_ema.load_state_dict(checkpoint["g_ema"]) 44 | 45 | if args.truncation < 1: 46 | with torch.no_grad(): 47 | mean_latent = g_ema.mean_latent(args.truncation_mean) 48 | else: 49 | mean_latent = None 50 | 51 | generate(args, g_ema, device, mean_latent) 52 | -------------------------------------------------------------------------------- /gpu_profile.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import linecache 3 | import os 4 | 5 | os.environ["CUDA_LAUNCH_BLOCKING"] = "1" 6 | 7 | from py3nvml import py3nvml 8 | import torch 9 | import socket 10 | 11 | # different settings 12 | print_tensor_sizes = True 13 | use_incremental = False 14 | 15 | if "GPU_DEBUG" in os.environ: 16 | gpu_profile_fn = f"host_{socket.gethostname()}_gpu{os.environ['GPU_DEBUG']}_mem_prof-{datetime.datetime.now():%d-%b-%y-%H-%M-%S}.prof.txt" 17 | print("profiling gpu usage to ", gpu_profile_fn) 18 | 19 | ## Global variables 20 | last_tensor_sizes = set() 21 | last_meminfo_used = 0 22 | lineno = None 23 | func_name = None 24 | filename = None 25 | module_name = None 26 | 27 | 28 | def gpu_profile(frame, event, arg): 29 | # it is _about to_ execute (!) 30 | global last_tensor_sizes 31 | global last_meminfo_used 32 | global lineno, func_name, filename, module_name 33 | 34 | if event == "line": 35 | try: 36 | # about _previous_ line (!) 
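# (the 'line' trace event fires before a line actually runs, so the memory reading taken here is attributed to the line whose location was stored during the previous event)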
37 | if lineno is not None: 38 | py3nvml.nvmlInit() 39 | handle = py3nvml.nvmlDeviceGetHandleByIndex(int(os.environ["GPU_DEBUG"])) 40 | meminfo = py3nvml.nvmlDeviceGetMemoryInfo(handle) 41 | line = linecache.getline(filename, lineno) 42 | where_str = module_name + " " + func_name + ":" + str(lineno) 43 | 44 | new_meminfo_used = meminfo.used 45 | mem_display = new_meminfo_used - last_meminfo_used if use_incremental else new_meminfo_used 46 | if abs(new_meminfo_used - last_meminfo_used) / 1024 ** 2 > 256: 47 | with open(gpu_profile_fn, "a+") as f: 48 | f.write(f"{where_str:<50}" f":{(mem_display)/1024**2:<7.1f}Mb " f"{line.rstrip()}\n") 49 | 50 | last_meminfo_used = new_meminfo_used 51 | if print_tensor_sizes is True: 52 | for tensor in get_tensors(): 53 | if not hasattr(tensor, "dbg_alloc_where"): 54 | tensor.dbg_alloc_where = where_str 55 | new_tensor_sizes = {(type(x), tuple(x.size()), x.dbg_alloc_where) for x in get_tensors()} 56 | for t, s, loc in new_tensor_sizes - last_tensor_sizes: 57 | f.write(f"+ {loc:<50} {str(s):<20} {str(t):<10}\n") 58 | for t, s, loc in last_tensor_sizes - new_tensor_sizes: 59 | f.write(f"- {loc:<50} {str(s):<20} {str(t):<10}\n") 60 | last_tensor_sizes = new_tensor_sizes 61 | py3nvml.nvmlShutdown() 62 | 63 | # save details about line _to be_ executed 64 | lineno = None 65 | 66 | func_name = frame.f_code.co_name 67 | filename = frame.f_globals["__file__"] 68 | if filename.endswith(".pyc") or filename.endswith(".pyo"): 69 | filename = filename[:-1] 70 | module_name = frame.f_globals["__name__"] 71 | lineno = frame.f_lineno 72 | 73 | # only profile codes within the parent folder, otherwise there are too many function calls into other pytorch scripts 74 | # need to modify the key words below to suit your case. 75 | if "maua-stylegan2" not in os.path.dirname(os.path.abspath(filename)): 76 | lineno = None # skip current line evaluation 77 | 78 | if ( 79 | "car_datasets" in filename 80 | or "_exec_config" in func_name 81 | or "gpu_profile" in module_name 82 | or "tee_stdout" in module_name 83 | or "PIL" in module_name 84 | ): 85 | lineno = None # skip othe unnecessary lines 86 | 87 | return gpu_profile 88 | 89 | except (KeyError, AttributeError): 90 | pass 91 | 92 | return gpu_profile 93 | 94 | 95 | def get_tensors(gpu_only=True): 96 | import gc 97 | 98 | for obj in gc.get_objects(): 99 | try: 100 | if torch.is_tensor(obj): 101 | tensor = obj 102 | elif hasattr(obj, "data") and torch.is_tensor(obj.data): 103 | tensor = obj.data 104 | else: 105 | continue 106 | 107 | if tensor.is_cuda: 108 | yield tensor 109 | except Exception as e: 110 | pass 111 | -------------------------------------------------------------------------------- /gpumon.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import signal 4 | import subprocess 5 | import time 6 | from queue import Empty, Queue 7 | from threading import Thread 8 | 9 | import numpy as np 10 | import wandb 11 | 12 | parser = argparse.ArgumentParser() 13 | 14 | parser.add_argument("--wbname", type=str, required=True) 15 | parser.add_argument("--wbproj", type=str, required=True) 16 | parser.add_argument("--wbgroup", type=str, default=None) 17 | 18 | args = parser.parse_args() 19 | 20 | if args.wbgroup is None: 21 | wandb.init(project=args.wbproj, name=args.wbname, settings=wandb.Settings(_disable_stats=True)) 22 | else: 23 | wandb.init(project=args.wbproj, group=args.wbgroup, name=args.wbname, settings=wandb.Settings(_disable_stats=True)) 24 | 25 | 26 | def 
enqueue_output(out, queue): 27 | for line in iter(out.readline, b""): 28 | queue.put(line) 29 | out.close() 30 | 31 | 32 | os.setpgrp() 33 | 34 | clock_proc = subprocess.Popen("nvidia-smi dmon -s c", shell=True, stdout=subprocess.PIPE, bufsize=1) 35 | clock_proc.daemon = True 36 | 37 | time.sleep(0.5) 38 | 39 | throttle_reasons = [ 40 | "clocks_throttle_reasons.gpu_idle", 41 | "clocks_throttle_reasons.applications_clocks_setting", 42 | "clocks_throttle_reasons.sw_power_cap", 43 | "clocks_throttle_reasons.sw_thermal_slowdown", 44 | "clocks_throttle_reasons.hw_slowdown", 45 | "clocks_throttle_reasons.hw_thermal_slowdown", 46 | "clocks_throttle_reasons.hw_power_brake_slowdown", 47 | "clocks_throttle_reasons.sync_boost", 48 | ] 49 | throttle_proc = subprocess.Popen( 50 | f"nvidia-smi --query-gpu=index,{','.join(throttle_reasons)} --format=csv,noheader --loop=1", 51 | shell=True, 52 | stdout=subprocess.PIPE, 53 | bufsize=1, 54 | ) 55 | throttle_proc.daemon = True 56 | 57 | # create queue that gets the output lines from both processes 58 | q = Queue() 59 | clock_thread = Thread(target=enqueue_output, args=(clock_proc.stdout, q)) 60 | clock_thread.daemon = True 61 | thottle_thread = Thread(target=enqueue_output, args=(throttle_proc.stdout, q)) 62 | thottle_thread.daemon = True 63 | 64 | clock_thread.start() 65 | thottle_thread.start() 66 | 67 | throttles = [[], []] 68 | clocks = [[], []] 69 | while clock_proc.poll() is None or not q.empty(): 70 | try: 71 | line = q.get_nowait() 72 | except Empty: 73 | pass 74 | else: 75 | line = line.decode("utf-8").strip() 76 | if "#" in line: 77 | continue 78 | if "," in line: 79 | raw = line.split(",") 80 | gpu = int(raw[0]) 81 | bits = [0 if "Not" in a else 1 for a in raw[1:]] 82 | throttles[gpu].append(bits) 83 | # print(gpu, bits) 84 | else: 85 | raw = line.split(" ") 86 | gpu = int(raw[0]) 87 | clock = int(raw[-1]) 88 | clocks[gpu].append(clock) 89 | # print(gpu, clock) 90 | 91 | if len(clocks[0]) > 30: 92 | try: 93 | throttles = np.array(throttles) 94 | clocks = np.array(clocks) 95 | log_dict = {} 96 | for gpu in [0, 1]: 97 | log_dict[f"gpu.{gpu}.clock.speed"] = np.mean(clocks[gpu]) 98 | 99 | for r, reason in enumerate(throttle_reasons): 100 | log_dict[f"gpu.{gpu}.{reason}"] = np.mean(throttles[gpu, :, r]) 101 | 102 | print("\n".join([k.ljust(80) + str(v) for k, v in log_dict.items()])) 103 | wandb.log(log_dict) 104 | except: 105 | pass 106 | 107 | throttles = [[], []] 108 | clocks = [[], []] 109 | 110 | os.kill(throttle_proc.pid, signal.SIGINT) 111 | os.kill(clock_proc.pid, signal.SIGINT) 112 | -------------------------------------------------------------------------------- /lookahead_minimax.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | 3 | import torch 4 | from torch.optim.optimizer import Optimizer 5 | 6 | 7 | class LookaheadMinimax(Optimizer): 8 | r""" 9 | A PyTorch implementation of the lookahead wrapper for GANs. 10 | 11 | This optimizer performs the lookahead step on both the discriminator and generator optimizers after the generator's 12 | optimizer takes a step. This ensures that joint minimax lookahead is used rather than alternating minimax lookahead 13 | (which would result from simply applying the original Lookahead Optimizer to both networks separately). 
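Every la_steps generator steps the cached (slow) weights of both networks are pulled toward the current (fast) weights as slow <- slow + la_alpha * (fast - slow), as described in the Lookahead papers referenced below.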
14 | 15 | Lookahead Minimax Optimizer: https://arxiv.org/abs/2006.14567 16 | Lookahead Optimizer: https://arxiv.org/abs/1907.08610 17 | """ 18 | 19 | def __init__(self, G_optimizer, D_optimizer, la_steps=5, la_alpha=0.5, pullback_momentum="none", accumulate=1): 20 | """ 21 | G_optimizer: generator optimizer 22 | D_optimizer: discriminator optimizer 23 | la_steps (int): number of lookahead steps 24 | la_alpha (float): linear interpolation factor. 1.0 recovers the inner optimizer. 25 | pullback_momentum (str): change to inner optimizer momentum on interpolation update 26 | acumulate (int): number of gradient accumulation steps 27 | """ 28 | self.G_optimizer = G_optimizer 29 | self.D_optimizer = D_optimizer 30 | 31 | self._la_step = 0 # counter for inner optimizer 32 | self.la_alpha = la_alpha 33 | self._total_la_steps = la_steps * accumulate 34 | self._la_steps = la_steps 35 | 36 | pullback_momentum = pullback_momentum.lower() 37 | assert pullback_momentum in ["reset", "pullback", "none"] 38 | self.pullback_momentum = pullback_momentum 39 | 40 | self.state = defaultdict(dict) 41 | 42 | # Cache the current optimizer parameters 43 | for group in G_optimizer.param_groups: 44 | for p in group["params"]: 45 | param_state = self.state[p] 46 | param_state["cached_G_params"] = torch.zeros_like(p.data) 47 | param_state["cached_G_params"].copy_(p.data) 48 | if self.pullback_momentum == "pullback": 49 | param_state["cached_G_mom"] = torch.zeros_like(p.data) 50 | 51 | for group in D_optimizer.param_groups: 52 | for p in group["params"]: 53 | param_state = self.state[p] 54 | param_state["cached_D_params"] = torch.zeros_like(p.data) 55 | param_state["cached_D_params"].copy_(p.data) 56 | if self.pullback_momentum == "pullback": 57 | param_state["cached_D_mom"] = torch.zeros_like(p.data) 58 | 59 | def __getstate__(self): 60 | return { 61 | "state": self.state, 62 | "G_optimizer": self.G_optimizer, 63 | "D_optimizer": self.D_optimizer, 64 | "la_alpha": self.la_alpha, 65 | "_la_step": self._la_step, 66 | "_total_la_steps": self._la_steps, 67 | "pullback_momentum": self.pullback_momentum, 68 | } 69 | 70 | def zero_grad(self): 71 | self.G_optimizer.zero_grad() 72 | 73 | def get_la_step(self): 74 | return self._la_step 75 | 76 | def state_dict(self): 77 | return self.G_optimizer.state_dict() 78 | 79 | def load_state_dict(self, G_state_dict, D_state_dict): 80 | self.G_optimizer.load_state_dict(G_state_dict) 81 | self.D_optimizer.load_state_dict(D_state_dict) 82 | 83 | # Cache the current optimizer parameters 84 | for group in self.G_optimizer.param_groups: 85 | for p in group["params"]: 86 | param_state = self.state[p] 87 | param_state["cached_G_params"] = torch.zeros_like(p.data) 88 | param_state["cached_G_params"].copy_(p.data) 89 | if self.pullback_momentum == "pullback": 90 | param_state["cached_G_mom"] = self.G_optimizer.state[p]["momentum_buffer"] 91 | 92 | for group in self.D_optimizer.param_groups: 93 | for p in group["params"]: 94 | param_state = self.state[p] 95 | param_state["cached_D_params"] = torch.zeros_like(p.data) 96 | param_state["cached_D_params"].copy_(p.data) 97 | if self.pullback_momentum == "pullback": 98 | param_state["cached_D_mom"] = self.D_optimizer.state[p]["momentum_buffer"] 99 | 100 | def _backup_and_load_cache(self): 101 | """ 102 | Useful for performing evaluation on the slow weights (which typically generalize better) 103 | """ 104 | for group in self.G_optimizer.param_groups: 105 | for p in group["params"]: 106 | param_state = self.state[p] 107 | param_state["backup_G_params"] = 
torch.zeros_like(p.data) 108 | param_state["backup_G_params"].copy_(p.data) 109 | p.data.copy_(param_state["cached_G_params"]) 110 | 111 | for group in self.D_optimizer.param_groups: 112 | for p in group["params"]: 113 | param_state = self.state[p] 114 | param_state["backup_D_params"] = torch.zeros_like(p.data) 115 | param_state["backup_D_params"].copy_(p.data) 116 | p.data.copy_(param_state["cached_D_params"]) 117 | 118 | def _clear_and_load_backup(self): 119 | for group in self.G_optimizer.param_groups: 120 | for p in group["params"]: 121 | param_state = self.state[p] 122 | p.data.copy_(param_state["backup_G_params"]) 123 | del param_state["backup_G_params"] 124 | 125 | for group in self.D_optimizer.param_groups: 126 | for p in group["params"]: 127 | param_state = self.state[p] 128 | p.data.copy_(param_state["backup_D_params"]) 129 | del param_state["backup_D_params"] 130 | 131 | @property 132 | def param_groups(self): 133 | return self.G_optimizer.param_groups 134 | 135 | def step(self, closure=None): 136 | """ 137 | Performs a single Lookahead optimization step on BOTH optimizers after the generator's optimizer step. 138 | 139 | This allows the discriminator's optimizer to take more steps when using a higher step ratio and still have the 140 | lookahead step being performed once after k generator steps. This also ensures the optimizers are updated with 141 | the lookahead step simultaneously, rather than in alternating fashion. 142 | 143 | Arguments: 144 | closure (callable, optional): A closure that reevaluates the model 145 | and returns the loss. 146 | """ 147 | loss = self.G_optimizer.step(closure) 148 | self._la_step += 1 149 | 150 | if self._la_step >= self._total_la_steps: 151 | with torch.cuda.amp.autocast(enabled=False): 152 | self._la_step = 0 153 | 154 | # Lookahead and cache the current generator optimizer parameters 155 | for group in self.G_optimizer.param_groups: 156 | for p in group["params"]: 157 | param_state = self.state[p] 158 | p.data.mul_(self.la_alpha).add_(1.0 - self.la_alpha, param_state["cached_G_params"]) 159 | param_state["cached_G_params"].copy_(p.data) 160 | 161 | if self.pullback_momentum == "pullback": 162 | internal_momentum = self.G_optimizer.state[p]["momentum_buffer"] 163 | self.G_optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_( 164 | 1.0 - self.la_alpha, param_state["cached_G_mom"] 165 | ) 166 | param_state["cached_G_mom"] = self.G_optimizer.state[p]["momentum_buffer"] 167 | elif self.pullback_momentum == "reset": 168 | self.G_optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data) 169 | 170 | # Lookahead and cache the current discriminator optimizer parameters 171 | for group in self.D_optimizer.param_groups: 172 | for p in group["params"]: 173 | param_state = self.state[p] 174 | p.data.mul_(self.la_alpha).add_(1.0 - self.la_alpha, param_state["cached_D_params"]) 175 | param_state["cached_D_params"].copy_(p.data) 176 | 177 | if self.pullback_momentum == "pullback": 178 | internal_momentum = self.D_optimizer.state[p]["momentum_buffer"] 179 | self.D_optimizer.state[p]["momentum_buffer"] = internal_momentum.mul_(self.la_alpha).add_( 180 | 1.0 - self.la_alpha, param_state["cached_D_mom"] 181 | ) 182 | param_state["cached_D_mom"] = self.D_optimizer.state[p]["momentum_buffer"] 183 | elif self.pullback_momentum == "reset": 184 | self.D_optimizer.state[p]["momentum_buffer"] = torch.zeros_like(p.data) 185 | 186 | return loss 187 | --------------------------------------------------------------------------------
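Usage note: a minimal sketch of how the wrapper above is meant to be driven. The stand-in modules, learning rates, and placeholder losses are illustrative assumptions only, not the repository's training defaults:

    import torch
    from torch import nn
    from lookahead_minimax import LookaheadMinimax

    # tiny stand-in networks purely for illustration
    generator = nn.Linear(512, 512)
    discriminator = nn.Linear(512, 1)

    g_optim = torch.optim.Adam(generator.parameters(), lr=2e-3, betas=(0.0, 0.99))
    d_optim = torch.optim.Adam(discriminator.parameters(), lr=2e-3, betas=(0.0, 0.99))

    # joint lookahead over both optimizers, applied once every la_steps generator updates
    lookahead = LookaheadMinimax(g_optim, d_optim, la_steps=5, la_alpha=0.5, pullback_momentum="none")

    for _ in range(20):
        d_optim.zero_grad()
        d_loss = discriminator(generator(torch.randn(8, 512))).mean()  # placeholder loss
        d_loss.backward()
        d_optim.step()  # the discriminator steps through its own optimizer

        lookahead.zero_grad()  # zeroes the generator's gradients
        g_loss = -discriminator(generator(torch.randn(8, 512))).mean()  # placeholder loss
        g_loss.backward()
        lookahead.step()  # steps g_optim; every la_steps-th call interpolates both networks toward the slow weights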
/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.autograd import Function 7 | from torch.utils.cpp_extension import load 8 | 9 | 10 | module_path = os.path.dirname(__file__) 11 | fused = load( 12 | "fused", 13 | sources=[ 14 | os.path.join(module_path, "fused_bias_act.cpp"), 15 | os.path.join(module_path, "fused_bias_act_kernel.cu"), 16 | ], 17 | ) 18 | 19 | 20 | class FusedLeakyReLUFunctionBackward(Function): 21 | @staticmethod 22 | def forward(ctx, grad_output, out, negative_slope, scale): 23 | ctx.save_for_backward(out) 24 | ctx.negative_slope = negative_slope 25 | ctx.scale = scale 26 | 27 | empty = grad_output.new_empty(0) 28 | 29 | grad_input = fused.fused_bias_act( 30 | grad_output, empty, out, 3, 1, negative_slope, scale 31 | ) 32 | 33 | dim = [0] 34 | 35 | if grad_input.ndim > 2: 36 | dim += list(range(2, grad_input.ndim)) 37 | 38 | grad_bias = grad_input.sum(dim).detach() 39 | 40 | return grad_input, grad_bias 41 | 42 | @staticmethod 43 | def backward(ctx, gradgrad_input, gradgrad_bias): 44 | out, = ctx.saved_tensors 45 | gradgrad_out = fused.fused_bias_act( 46 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 47 | ) 48 | 49 | return gradgrad_out, None, None, None 50 | 51 | 52 | class FusedLeakyReLUFunction(Function): 53 | @staticmethod 54 | def forward(ctx, input, bias, negative_slope, scale): 55 | empty = input.new_empty(0) 56 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 57 | ctx.save_for_backward(out) 58 | ctx.negative_slope = negative_slope 59 | ctx.scale = scale 60 | 61 | return out 62 | 63 | @staticmethod 64 | def backward(ctx, grad_output): 65 | out, = ctx.saved_tensors 66 | 67 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 68 | grad_output, out, ctx.negative_slope, ctx.scale 69 | ) 70 | 71 | return grad_input, grad_bias, None, None 72 | 73 | 74 | class FusedLeakyReLU(nn.Module): 75 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 76 | super().__init__() 77 | 78 | self.bias = nn.Parameter(torch.zeros(channel)) 79 | self.negative_slope = negative_slope 80 | self.scale = scale 81 | 82 | def forward(self, input): 83 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 84 | 85 | 86 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 87 | if input.device.type == "cpu": 88 | rest_dim = [1] * (input.ndim - bias.ndim - 1) 89 | return ( 90 | F.leaky_relu( 91 | input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=0.2 92 | ) 93 | * scale 94 | ) 95 | 96 | else: 97 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 98 | -------------------------------------------------------------------------------- /op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | 
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 
1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | upfirdn2d_op = load( 11 | "upfirdn2d", 12 | sources=[ 13 | os.path.join(module_path, "upfirdn2d.cpp"), 14 | os.path.join(module_path, "upfirdn2d_kernel.cu"), 15 | ], 16 | ) 17 | 18 | 19 | class UpFirDn2dBackward(Function): 20 | @staticmethod 21 | def forward( 22 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 23 | ): 24 | 25 | up_x, up_y = up 26 | down_x, down_y = down 27 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 28 | 29 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 30 | 31 | grad_input = upfirdn2d_op.upfirdn2d( 32 | grad_output, 33 | grad_kernel, 34 | down_x, 35 | down_y, 36 | up_x, 37 | up_y, 38 | g_pad_x0, 39 | g_pad_x1, 40 | g_pad_y0, 41 | g_pad_y1, 42 | ) 43 | grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 44 | 45 | ctx.save_for_backward(kernel) 46 | 47 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 48 | 49 | ctx.up_x = up_x 50 | ctx.up_y = up_y 51 | ctx.down_x = down_x 52 | ctx.down_y = down_y 53 | ctx.pad_x0 = pad_x0 54 | ctx.pad_x1 = pad_x1 55 | ctx.pad_y0 = pad_y0 56 | ctx.pad_y1 = pad_y1 57 | ctx.in_size = in_size 58 | ctx.out_size = out_size 59 | 60 | return grad_input 61 | 62 | @staticmethod 63 | def backward(ctx, gradgrad_input): 64 | kernel, = ctx.saved_tensors 65 | 66 | gradgrad_input = gradgrad_input.reshape(-1, 
ctx.in_size[2], ctx.in_size[3], 1) 67 | 68 | gradgrad_out = upfirdn2d_op.upfirdn2d( 69 | gradgrad_input, 70 | kernel, 71 | ctx.up_x, 72 | ctx.up_y, 73 | ctx.down_x, 74 | ctx.down_y, 75 | ctx.pad_x0, 76 | ctx.pad_x1, 77 | ctx.pad_y0, 78 | ctx.pad_y1, 79 | ) 80 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 81 | gradgrad_out = gradgrad_out.view( 82 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 83 | ) 84 | 85 | return gradgrad_out, None, None, None, None, None, None, None, None 86 | 87 | 88 | class UpFirDn2d(Function): 89 | @staticmethod 90 | def forward(ctx, input, kernel, up, down, pad): 91 | up_x, up_y = up 92 | down_x, down_y = down 93 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 94 | 95 | kernel_h, kernel_w = kernel.shape 96 | batch, channel, in_h, in_w = input.shape 97 | ctx.in_size = input.shape 98 | 99 | input = input.reshape(-1, in_h, in_w, 1) 100 | 101 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 102 | 103 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 104 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 105 | ctx.out_size = (out_h, out_w) 106 | 107 | ctx.up = (up_x, up_y) 108 | ctx.down = (down_x, down_y) 109 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 110 | 111 | g_pad_x0 = kernel_w - pad_x0 - 1 112 | g_pad_y0 = kernel_h - pad_y0 - 1 113 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 114 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 115 | 116 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 117 | 118 | out = upfirdn2d_op.upfirdn2d( 119 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 120 | ) 121 | # out = out.view(major, out_h, out_w, minor) 122 | out = out.view(-1, channel, out_h, out_w) 123 | 124 | return out 125 | 126 | @staticmethod 127 | def backward(ctx, grad_output): 128 | kernel, grad_kernel = ctx.saved_tensors 129 | 130 | grad_input = UpFirDn2dBackward.apply( 131 | grad_output, 132 | kernel, 133 | grad_kernel, 134 | ctx.up, 135 | ctx.down, 136 | ctx.pad, 137 | ctx.g_pad, 138 | ctx.in_size, 139 | ctx.out_size, 140 | ) 141 | 142 | return grad_input, None, None, None, None 143 | 144 | 145 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 146 | if input.device.type == "cpu": 147 | out = upfirdn2d_native( 148 | input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1] 149 | ) 150 | 151 | else: 152 | out = UpFirDn2d.apply( 153 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 154 | ) 155 | 156 | return out 157 | 158 | 159 | def upfirdn2d_native( 160 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 161 | ): 162 | _, channel, in_h, in_w = input.shape 163 | input = input.reshape(-1, in_h, in_w, 1) 164 | 165 | _, in_h, in_w, minor = input.shape 166 | kernel_h, kernel_w = kernel.shape 167 | 168 | out = input.view(-1, in_h, 1, in_w, 1, minor) 169 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 170 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 171 | 172 | out = F.pad( 173 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 174 | ) 175 | out = out[ 176 | :, 177 | max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 178 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 179 | :, 180 | ] 181 | 182 | out = out.permute(0, 3, 1, 2) 183 | out = out.reshape( 184 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 185 | ) 186 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, 
kernel_w) 187 | out = F.conv2d(out, w) 188 | out = out.reshape( 189 | -1, 190 | minor, 191 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 192 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 193 | ) 194 | out = out.permute(0, 2, 3, 1) 195 | out = out[:, ::down_y, ::down_x, :] 196 | 197 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 198 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 199 | 200 | return out.view(-1, channel, out_h, out_w) 201 | -------------------------------------------------------------------------------- /prepare_data.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from io import BytesIO 3 | import multiprocessing 4 | from functools import partial 5 | 6 | from PIL import Image 7 | 8 | import lmdb 9 | from tqdm import tqdm 10 | from torchvision import datasets 11 | from torchvision.transforms import functional as trans_fn 12 | 13 | # ImageFile.LOAD_TRUNCATED_IMAGES = True 14 | 15 | 16 | def resize_and_convert(img, size, resample, quality=100): 17 | img = trans_fn.resize(img, size, resample) 18 | img = trans_fn.center_crop(img, size) 19 | buffer = BytesIO() 20 | img.save(buffer, format="jpeg", quality=quality) 21 | val = buffer.getvalue() 22 | 23 | return val 24 | 25 | 26 | def resize_multiple(img, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS, quality=100): 27 | imgs = [] 28 | 29 | for size in sizes: 30 | imgs.append(resize_and_convert(img, size, resample, quality)) 31 | 32 | return imgs 33 | 34 | 35 | def resize_worker(img_file, sizes, resample): 36 | i, file = img_file 37 | try: 38 | img = Image.open(file) 39 | img = img.convert("RGB") 40 | except: 41 | print(file, "truncated") 42 | out = resize_multiple(img, sizes=sizes, resample=resample) 43 | 44 | return i, out 45 | 46 | 47 | def prepare(env, dataset, n_worker, sizes=(128, 256, 512, 1024), resample=Image.LANCZOS): 48 | resize_fn = partial(resize_worker, sizes=sizes, resample=resample) 49 | 50 | files = sorted(dataset.imgs, key=lambda x: x[0]) 51 | files = [(i, file) for i, (file, label) in enumerate(files)] 52 | total = 0 53 | 54 | with multiprocessing.Pool(n_worker) as pool: 55 | for i, imgs in tqdm(pool.imap_unordered(resize_fn, files)): 56 | for size, img in zip(sizes, imgs): 57 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 58 | 59 | with env.begin(write=True) as txn: 60 | txn.put(key, img) 61 | 62 | total += 1 63 | 64 | with env.begin(write=True) as txn: 65 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 66 | 67 | 68 | if __name__ == "__main__": 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument("--out", type=str) 71 | parser.add_argument("--size", type=str, default="128,256,512,1024") 72 | parser.add_argument("--n_worker", type=int, default=8) 73 | parser.add_argument("--resample", type=str, default="bilinear") 74 | parser.add_argument("path", type=str) 75 | 76 | args = parser.parse_args() 77 | 78 | resample_map = {"lanczos": Image.LANCZOS, "bilinear": Image.BILINEAR} 79 | resample = resample_map[args.resample] 80 | 81 | sizes = [int(s.strip()) for s in args.size.split(",")] 82 | 83 | print(f"Make dataset of image sizes:", ", ".join(str(s) for s in sizes)) 84 | 85 | imgset = datasets.ImageFolder(args.path) 86 | 87 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 88 | prepare(env, imgset, args.n_worker, sizes=sizes, resample=resample) 89 | -------------------------------------------------------------------------------- /prepare_vae_codes.py: 
-------------------------------------------------------------------------------- 1 | import argparse 2 | import numpy as np 3 | import multiprocessing 4 | from functools import partial 5 | 6 | import lmdb 7 | from tqdm import tqdm 8 | 9 | import torch as th 10 | from autoencoder import ConvSegNet 11 | from torchvision import datasets 12 | import torchvision.transforms as transforms 13 | 14 | 15 | def lmdmb_write_worker(i_code, env, size): 16 | i, code = i_code.cpu().numpy() 17 | key = f"{size}-{str(i).zfill(5)}".encode("utf-8") 18 | with env.begin(write=True) as txn: 19 | txn.put(key, code) 20 | 21 | 22 | def prepare(env, vae, loader, total, batch_size, n_worker, size=1024): 23 | write_fn = partial(lmdmb_write_worker, env=env, size=size) 24 | 25 | b = 0 26 | with multiprocessing.Pool(n_worker) as pool: 27 | for batch in tqdm(loader): 28 | code_nums = np.arange(b * batch_size, (b + 1) * batch_size) 29 | 30 | with th.no_grad(): 31 | codes = vae.module.encode(batch[0].cuda()) 32 | 33 | pool.imap_unordered(write_fn, zip(code_nums, codes)) 34 | 35 | b += 1 36 | 37 | with env.begin(write=True) as txn: 38 | txn.put("length".encode("utf-8"), str(total).encode("utf-8")) 39 | 40 | 41 | if __name__ == "__main__": 42 | parser = argparse.ArgumentParser() 43 | parser.add_argument("--out", type=str) 44 | parser.add_argument("--size", type=int, default=1024) 45 | parser.add_argument("--n_worker", type=int, default=24) 46 | parser.add_argument("--batch_size", type=int, default=4) 47 | parser.add_argument("--resample", type=str, default="bilinear") 48 | parser.add_argument("data_path", type=str) 49 | parser.add_argument("vae_checkpoint", type=str) 50 | 51 | args = parser.parse_args() 52 | 53 | print(f"Make dataset of image size:", args.size) 54 | 55 | transform = transforms.Compose( 56 | [ 57 | transforms.Resize(args.size), 58 | transforms.ToTensor(), 59 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True), 60 | ] 61 | ) 62 | imgset = datasets.ImageFolder(args.data_path, transform=transform) 63 | loader = th.utils.data.DataLoader(imgset, batch_size=args.batch_size, num_workers=int(args.n_worker / 2)) 64 | print(args.batch_size) 65 | print(loader) 66 | 67 | vae = ConvSegNet() 68 | vae.load_state_dict(th.load(args.vae_checkpoint)["vae"]) 69 | vae = th.nn.DataParallel(vae).eval().cuda() 70 | 71 | with lmdb.open(args.out, map_size=1024 ** 4, readahead=False) as env: 72 | prepare( 73 | env, 74 | vae, 75 | loader, 76 | total=len(imgset), 77 | batch_size=args.batch_size, 78 | n_worker=int(args.n_worker / 2), 79 | size=args.size, 80 | ) 81 | -------------------------------------------------------------------------------- /projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import math 3 | import os 4 | 5 | import torch 6 | from torch import optim 7 | from torch.nn import functional as F 8 | from torchvision import transforms 9 | from PIL import Image 10 | from tqdm import tqdm 11 | 12 | import lpips 13 | from model import Generator 14 | 15 | 16 | def noise_regularize(noises): 17 | loss = 0 18 | 19 | for noise in noises: 20 | size = noise.shape[2] 21 | 22 | while True: 23 | loss = ( 24 | loss 25 | + (noise * torch.roll(noise, shifts=1, dims=3)).mean().pow(2) 26 | + (noise * torch.roll(noise, shifts=1, dims=2)).mean().pow(2) 27 | ) 28 | 29 | if size <= 8: 30 | break 31 | 32 | noise = noise.reshape([1, 1, size // 2, 2, size // 2, 2]) 33 | noise = noise.mean([3, 5]) 34 | size //= 2 35 | 36 | return loss 37 | 38 | 39 | def 
noise_normalize_(noises): 40 | for noise in noises: 41 | mean = noise.mean() 42 | std = noise.std() 43 | 44 | noise.data.add_(-mean).div_(std) 45 | 46 | 47 | def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05): 48 | lr_ramp = min(1, (1 - t) / rampdown) 49 | lr_ramp = 0.5 - 0.5 * math.cos(lr_ramp * math.pi) 50 | lr_ramp = lr_ramp * min(1, t / rampup) 51 | 52 | return initial_lr * lr_ramp 53 | 54 | 55 | def latent_noise(latent, strength): 56 | noise = torch.randn_like(latent) * strength 57 | 58 | return latent + noise 59 | 60 | 61 | def make_image(tensor): 62 | return ( 63 | tensor.detach() 64 | .clamp_(min=-1, max=1) 65 | .add(1) 66 | .div_(2) 67 | .mul(255) 68 | .type(torch.uint8) 69 | .permute(0, 2, 3, 1) 70 | .to("cpu") 71 | .numpy() 72 | ) 73 | 74 | 75 | if __name__ == "__main__": 76 | device = "cuda" 77 | 78 | parser = argparse.ArgumentParser() 79 | parser.add_argument("--ckpt", type=str, required=True) 80 | parser.add_argument("--size", type=int, default=256) 81 | parser.add_argument("--lr_rampup", type=float, default=0.05) 82 | parser.add_argument("--lr_rampdown", type=float, default=0.25) 83 | parser.add_argument("--lr", type=float, default=0.1) 84 | parser.add_argument("--noise", type=float, default=0.05) 85 | parser.add_argument("--noise_ramp", type=float, default=0.75) 86 | parser.add_argument("--step", type=int, default=1000) 87 | parser.add_argument("--noise_regularize", type=float, default=1e5) 88 | parser.add_argument("--mse", type=float, default=0) 89 | parser.add_argument("--w_plus", action="store_true") 90 | parser.add_argument("files", metavar="FILES", nargs="+") 91 | 92 | args = parser.parse_args() 93 | 94 | n_mean_latent = 10000 95 | 96 | resize = min(args.size, 256) 97 | 98 | transform = transforms.Compose( 99 | [ 100 | transforms.Resize(resize), 101 | transforms.CenterCrop(resize), 102 | transforms.ToTensor(), 103 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 104 | ] 105 | ) 106 | 107 | imgs = [] 108 | 109 | for imgfile in args.files: 110 | img = transform(Image.open(imgfile).convert("RGB")) 111 | imgs.append(img) 112 | 113 | imgs = torch.stack(imgs, 0).to(device) 114 | 115 | g_ema = Generator(args.size, 512, 8) 116 | g_ema.load_state_dict(torch.load(args.ckpt)["g_ema"], strict=False) 117 | g_ema.eval() 118 | g_ema = g_ema.to(device) 119 | 120 | with torch.no_grad(): 121 | noise_sample = torch.randn(n_mean_latent, 512, device=device) 122 | latent_out = g_ema.style(noise_sample) 123 | 124 | latent_mean = latent_out.mean(0) 125 | latent_std = ((latent_out - latent_mean).pow(2).sum() / n_mean_latent) ** 0.5 126 | 127 | percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda")) 128 | 129 | noises = g_ema.make_noise() 130 | 131 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(2, 1) 132 | 133 | if args.w_plus: 134 | latent_in = latent_in.unsqueeze(1).repeat(1, g_ema.n_latent, 1) 135 | 136 | latent_in.requires_grad = True 137 | 138 | for noise in noises: 139 | noise.requires_grad = True 140 | 141 | optimizer = optim.Adam([latent_in] + noises, lr=args.lr) 142 | 143 | pbar = tqdm(range(args.step)) 144 | latent_path = [] 145 | 146 | for i in pbar: 147 | t = i / args.step 148 | lr = get_lr(t, args.lr) 149 | optimizer.param_groups[0]["lr"] = lr 150 | noise_strength = latent_std * args.noise * max(0, 1 - t / args.noise_ramp) ** 2 151 | latent_n = latent_noise(latent_in, noise_strength.item()) 152 | 153 | img_gen, _ = g_ema([latent_n], input_is_latent=True, noise=noises) 154 | 155 | batch, channel, height, width = 
img_gen.shape 156 | 157 | if height > 256: 158 | factor = height // 256 159 | 160 | img_gen = img_gen.reshape(batch, channel, height // factor, factor, width // factor, factor) 161 | img_gen = img_gen.mean([3, 5]) 162 | 163 | p_loss = percept(img_gen, imgs).sum() 164 | n_loss = noise_regularize(noises) 165 | mse_loss = F.mse_loss(img_gen, imgs) 166 | 167 | loss = p_loss + args.noise_regularize * n_loss + args.mse * mse_loss 168 | 169 | optimizer.zero_grad() 170 | loss.backward() 171 | optimizer.step() 172 | 173 | noise_normalize_(noises) 174 | 175 | if (i + 1) % 100 == 0: 176 | latent_path.append(latent_in.detach().clone()) 177 | 178 | pbar.set_description( 179 | ( 180 | f"perceptual: {p_loss.item():.4f}; noise regularize: {n_loss.item():.4f};" 181 | f" mse: {mse_loss.item():.4f}; lr: {lr:.4f}" 182 | ) 183 | ) 184 | 185 | result_file = {"noises": noises} 186 | 187 | img_gen, _ = g_ema([latent_path[-1]], input_is_latent=True, noise=noises) 188 | 189 | filename = os.path.splitext(os.path.basename(args.files[0]))[0] + ".pt" 190 | 191 | img_ar = make_image(img_gen) 192 | 193 | for i, input_name in enumerate(args.files): 194 | result_file[input_name] = {"img": img_gen[i], "latent": latent_in[i]} 195 | img_name = os.path.splitext(os.path.basename(input_name))[0] + "-project.png" 196 | pil_img = Image.fromarray(img_ar[i]) 197 | pil_img.save(img_name) 198 | 199 | torch.save(result_file, filename) 200 | -------------------------------------------------------------------------------- /render.py: -------------------------------------------------------------------------------- 1 | import queue 2 | from threading import Thread 3 | 4 | import ffmpeg 5 | import numpy as np 6 | import PIL.Image 7 | import torch as th 8 | from tqdm import tqdm 9 | 10 | th.set_grad_enabled(False) 11 | th.backends.cudnn.benchmark = True 12 | 13 | 14 | def render( 15 | generator, 16 | latents, 17 | noise, 18 | offset, 19 | duration, 20 | batch_size, 21 | out_size, 22 | output_file, 23 | audio_file=None, 24 | truncation=1.0, 25 | bends=[], 26 | rewrites={}, 27 | randomize_noise=False, 28 | ffmpeg_preset="slow", 29 | ): 30 | split_queue = queue.Queue() 31 | render_queue = queue.Queue() 32 | 33 | # postprocesses batched torch tensors to individual RGB numpy arrays 34 | def split_batches(jobs_in, jobs_out): 35 | while True: 36 | try: 37 | imgs = jobs_in.get(timeout=5) 38 | except queue.Empty: 39 | return 40 | imgs = (imgs.clamp_(-1, 1) + 1) * 127.5 41 | imgs = imgs.permute(0, 2, 3, 1) 42 | for img in imgs: 43 | jobs_out.put(img.cpu().numpy().astype(np.uint8)) 44 | jobs_in.task_done() 45 | 46 | # start background ffmpeg process that listens on stdin for frame data 47 | if out_size == 512: 48 | output_size = "512x512" 49 | elif out_size == 1024: 50 | output_size = "1024x1024" 51 | elif out_size == 1920: 52 | output_size = "1920x1080" 53 | elif out_size == 1080: 54 | output_size = "1080x1920" 55 | else: 56 | raise Exception("The only output sizes currently supported are: 512, 1024, 1080, or 1920") 57 | 58 | if audio_file is not None: 59 | audio = ffmpeg.input(audio_file, ss=offset, t=duration, guess_layout_max=0) 60 | video = ( 61 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", framerate=len(latents) / duration, s=output_size) 62 | .output( 63 | audio, 64 | output_file, 65 | framerate=len(latents) / duration, 66 | vcodec="libx264", 67 | pix_fmt="yuv420p", 68 | preset=ffmpeg_preset, 69 | audio_bitrate="320K", 70 | ac=2, 71 | v="warning", 72 | ) 73 | .global_args("-hide_banner") 74 | .overwrite_output() 75 | 
.run_async(pipe_stdin=True) 76 | ) 77 | else: 78 | video = ( 79 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", framerate=len(latents) / duration, s=output_size) 80 | .output( 81 | output_file, 82 | framerate=len(latents) / duration, 83 | vcodec="libx264", 84 | pix_fmt="yuv420p", 85 | preset=ffmpeg_preset, 86 | v="warning", 87 | ) 88 | .global_args("-hide_banner") 89 | .overwrite_output() 90 | .run_async(pipe_stdin=True) 91 | ) 92 | 93 | # writes numpy frames to ffmpeg stdin as raw rgb24 bytes 94 | def make_video(jobs_in): 95 | w, h = [int(dim) for dim in output_size.split("x")] 96 | for _ in tqdm(range(len(latents)), position=0, leave=True, ncols=80): 97 | img = jobs_in.get(timeout=5) 98 | if img.shape[1] == 2048: 99 | img = img[:, 112:-112, :] 100 | im = PIL.Image.fromarray(img) 101 | img = np.array(im.resize((1920, 1080), PIL.Image.BILINEAR)) 102 | elif img.shape[0] == 2048: 103 | img = img[112:-112, :, :] 104 | im = PIL.Image.fromarray(img) 105 | img = np.array(im.resize((1080, 1920), PIL.Image.BILINEAR)) 106 | assert ( 107 | img.shape[1] == w and img.shape[0] == h 108 | ), f"""generator's output image size does not match specified output size: \n 109 | got: {img.shape[1]}x{img.shape[0]}\t\tshould be {output_size}""" 110 | video.stdin.write(img.tobytes()) 111 | jobs_in.task_done() 112 | video.stdin.close() 113 | video.wait() 114 | 115 | splitter = Thread(target=split_batches, args=(split_queue, render_queue)) 116 | splitter.daemon = True 117 | renderer = Thread(target=make_video, args=(render_queue,)) 118 | renderer.daemon = True 119 | 120 | # make all data that needs to be loaded to the GPU float, contiguous, and pinned 121 | # the entire process is severly memory-transfer bound, but at least this might help a little 122 | latents = latents.float().contiguous().pin_memory() 123 | 124 | for ni, noise_scale in enumerate(noise): 125 | noise[ni] = noise_scale.float().contiguous().pin_memory() if noise_scale is not None else None 126 | 127 | param_dict = dict(generator.named_parameters()) 128 | original_weights = {} 129 | for param, (rewrite, modulation) in rewrites.items(): 130 | rewrites[param] = [rewrite, modulation.float().contiguous().pin_memory()] 131 | original_weights[param] = param_dict[param].copy().cpu().float().contiguous().pin_memory() 132 | 133 | for bend in bends: 134 | if "modulation" in bend: 135 | bend["modulation"] = bend["modulation"].float().contiguous().pin_memory() 136 | 137 | if not isinstance(truncation, float): 138 | truncation = truncation.float().contiguous().pin_memory() 139 | 140 | for n in range(0, len(latents), batch_size): 141 | # load batches of data onto the GPU 142 | latent_batch = latents[n : n + batch_size].cuda(non_blocking=True) 143 | 144 | noise_batch = [] 145 | for noise_scale in noise: 146 | if noise_scale is not None: 147 | noise_batch.append(noise_scale[n : n + batch_size].cuda(non_blocking=True)) 148 | else: 149 | noise_batch.append(None) 150 | 151 | bend_batch = [] 152 | if bends is not None: 153 | for bend in bends: 154 | if "modulation" in bend: 155 | transform = bend["transform"](bend["modulation"][n : n + batch_size].cuda(non_blocking=True)) 156 | bend_batch.append({"layer": bend["layer"], "transform": transform}) 157 | else: 158 | bend_batch.append({"layer": bend["layer"], "transform": bend["transform"]}) 159 | 160 | for param, (rewrite, modulation) in rewrites.items(): 161 | transform = rewrite(modulation[n : n + batch_size]) 162 | rewritten_weight = transform(original_weights[param]).cuda(non_blocking=True) 163 | param_attrs 
= param.split(".") 164 | mod = generator 165 | for attr in param_attrs[:-1]: 166 | mod = getattr(mod, attr) 167 | setattr(mod, param_attrs[-1], th.nn.Parameter(rewritten_weight)) 168 | 169 | if not isinstance(truncation, float): 170 | truncation_batch = truncation[n : n + batch_size].cuda(non_blocking=True) 171 | else: 172 | truncation_batch = truncation 173 | 174 | # forward through the generator 175 | outputs, _ = generator( 176 | styles=latent_batch, 177 | noise=noise_batch, 178 | truncation=truncation_batch, 179 | transform_dict_list=bend_batch, 180 | randomize_noise=randomize_noise, 181 | input_is_latent=True, 182 | ) 183 | 184 | # send output to be split into frames and rendered one by one 185 | split_queue.put(outputs) 186 | 187 | if n == 0: 188 | splitter.start() 189 | renderer.start() 190 | 191 | splitter.join() 192 | renderer.join() 193 | 194 | 195 | def write_video(arr, output_file, fps): 196 | print(f"writing {arr.shape[0]} frames...") 197 | 198 | output_size = "x".join(reversed([str(s) for s in arr.shape[1:-1]])) 199 | 200 | ffmpeg_proc = ( 201 | ffmpeg.input("pipe:", format="rawvideo", pix_fmt="rgb24", framerate=fps, s=output_size) 202 | .output(output_file, framerate=fps, vcodec="libx264", preset="slow", v="warning") 203 | .global_args("-benchmark", "-stats", "-hide_banner") 204 | .overwrite_output() 205 | .run_async(pipe_stdin=True) 206 | ) 207 | 208 | for frame in arr: 209 | ffmpeg_proc.stdin.write(frame.astype(np.uint8).tobytes()) 210 | 211 | ffmpeg_proc.stdin.close() 212 | ffmpeg_proc.wait() 213 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | numpy 4 | librosa 5 | cython 6 | madmom 7 | tqdm 8 | kornia 9 | matplotlib 10 | ffmpeg-python 11 | joblib -------------------------------------------------------------------------------- /validation/__init__.py: -------------------------------------------------------------------------------- 1 | from .metrics import vae_fid, fid, get_dataset_inception_features, ppl, prdc 2 | from .spectral_norm import track_spectral_norm 3 | -------------------------------------------------------------------------------- /validation/calc_fid.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | 4 | import torch 5 | from torch import nn 6 | import numpy as np 7 | from scipy import linalg 8 | from tqdm import tqdm 9 | 10 | from model import Generator 11 | from inception import InceptionV3 12 | 13 | 14 | @torch.no_grad() 15 | def extract_feature_from_samples(generator, inception, truncation, truncation_latent, batch_size, n_sample, device): 16 | n_batch = n_sample // batch_size 17 | resid = n_sample - (n_batch * batch_size) 18 | batch_sizes = [batch_size] * n_batch + [resid] 19 | features = [] 20 | 21 | for batch in tqdm(batch_sizes): 22 | latent = torch.randn(batch, 512, device=device) 23 | img, _ = g([latent], truncation=truncation, truncation_latent=truncation_latent) 24 | feat = inception(img)[0].view(img.shape[0], -1) 25 | features.append(feat.to("cpu")) 26 | 27 | features = torch.cat(features, 0) 28 | 29 | return features 30 | 31 | 32 | def calc_fid(sample_mean, sample_cov, real_mean, real_cov, eps=1e-6): 33 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 34 | 35 | if not np.isfinite(cov_sqrt).all(): 36 | print("product of cov matrices is singular") 37 | offset = np.eye(sample_cov.shape[0]) * eps 
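        # note: the small diagonal offset regularizes both covariance estimates so that sqrtm of their
        # product stays finite; the FID returned below is ||mu_s - mu_r||^2 + Tr(C_s + C_r - 2*sqrt(C_s @ C_r))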
38 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 39 | 40 | if np.iscomplexobj(cov_sqrt): 41 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 42 | m = np.max(np.abs(cov_sqrt.imag)) 43 | 44 | raise ValueError(f"Imaginary component {m}") 45 | 46 | cov_sqrt = cov_sqrt.real 47 | 48 | mean_diff = sample_mean - real_mean 49 | mean_norm = mean_diff @ mean_diff 50 | 51 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 52 | 53 | fid = mean_norm + trace 54 | 55 | return fid 56 | 57 | 58 | if __name__ == "__main__": 59 | device = "cuda" 60 | 61 | parser = argparse.ArgumentParser() 62 | 63 | parser.add_argument("--truncation", type=float, default=1) 64 | parser.add_argument("--truncation_mean", type=int, default=4096 * 8) 65 | parser.add_argument("--batch", type=int, default=64) 66 | parser.add_argument("--n_sample", type=int, default=50000) 67 | parser.add_argument("--size", type=int, default=256) 68 | parser.add_argument("--inception", type=str, default=None, required=True) 69 | parser.add_argument("ckpt", metavar="CHECKPOINT") 70 | 71 | args = parser.parse_args() 72 | 73 | ckpt = torch.load(args.ckpt) 74 | 75 | g = Generator(args.size, 512, 8).to(device) 76 | g.load_state_dict(ckpt["g_ema"]) 77 | g = nn.DataParallel(g) 78 | g.eval() 79 | 80 | if args.truncation < 1: 81 | with torch.no_grad(): 82 | mean_latent = g.mean_latent(args.truncation_mean) 83 | 84 | else: 85 | mean_latent = None 86 | 87 | inception = InceptionV3([3], normalize_input=False, init_weights=False) 88 | inception = nn.DataParallel(inception).eval().cuda() 89 | 90 | features = extract_feature_from_samples( 91 | g, inception, args.truncation, mean_latent, args.batch, args.n_sample, device 92 | ).numpy() 93 | print(f"extracted {features.shape[0]} features") 94 | 95 | sample_mean = np.mean(features, 0) 96 | sample_cov = np.cov(features, rowvar=False) 97 | 98 | with open(args.inception, "rb") as f: 99 | embeds = pickle.load(f) 100 | real_mean = embeds["mean"] 101 | real_cov = embeds["cov"] 102 | 103 | fid = calc_fid(sample_mean, sample_cov, real_mean, real_cov) 104 | 105 | print("fid:", fid) 106 | -------------------------------------------------------------------------------- /validation/calc_inception.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import pickle 3 | import os 4 | 5 | import torch 6 | from torch import nn 7 | from torch.nn import functional as F 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | from torchvision.models import Inception3 11 | import numpy as np 12 | from tqdm import tqdm 13 | 14 | from inception import InceptionV3 15 | from dataset import MultiResolutionDataset 16 | 17 | 18 | class Inception3Feature(Inception3): 19 | def forward(self, x): 20 | if x.shape[2] != 299 or x.shape[3] != 299: 21 | x = F.interpolate(x, size=(299, 299), mode="bilinear", align_corners=True) 22 | 23 | x = self.Conv2d_1a_3x3(x) # 299 x 299 x 3 24 | x = self.Conv2d_2a_3x3(x) # 149 x 149 x 32 25 | x = self.Conv2d_2b_3x3(x) # 147 x 147 x 32 26 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 147 x 147 x 64 27 | 28 | x = self.Conv2d_3b_1x1(x) # 73 x 73 x 64 29 | x = self.Conv2d_4a_3x3(x) # 73 x 73 x 80 30 | x = F.max_pool2d(x, kernel_size=3, stride=2) # 71 x 71 x 192 31 | 32 | x = self.Mixed_5b(x) # 35 x 35 x 192 33 | x = self.Mixed_5c(x) # 35 x 35 x 256 34 | x = self.Mixed_5d(x) # 35 x 35 x 288 35 | 36 | x = self.Mixed_6a(x) # 35 x 35 x 288 37 | x = self.Mixed_6b(x) # 17 x 
17 x 768 38 | x = self.Mixed_6c(x) # 17 x 17 x 768 39 | x = self.Mixed_6d(x) # 17 x 17 x 768 40 | x = self.Mixed_6e(x) # 17 x 17 x 768 41 | 42 | x = self.Mixed_7a(x) # 17 x 17 x 768 43 | x = self.Mixed_7b(x) # 8 x 8 x 1280 44 | x = self.Mixed_7c(x) # 8 x 8 x 2048 45 | 46 | x = F.avg_pool2d(x, kernel_size=8) # 8 x 8 x 2048 47 | 48 | return x.view(x.shape[0], x.shape[1]) # 1 x 1 x 2048 49 | 50 | 51 | def load_patched_inception_v3(): 52 | # inception = inception_v3(pretrained=True) 53 | # inception_feat = Inception3Feature() 54 | # inception_feat.load_state_dict(inception.state_dict()) 55 | inception_feat = InceptionV3([3], normalize_input=False, init_weights=False) 56 | 57 | return inception_feat 58 | 59 | 60 | @torch.no_grad() 61 | def extract_features(loader, inception, device): 62 | pbar = tqdm(loader) 63 | 64 | feature_list = [] 65 | 66 | for img in pbar: 67 | img = img.to(device) 68 | feature = inception(img)[0].view(img.shape[0], -1) 69 | feature_list.append(feature.to("cpu")) 70 | 71 | features = torch.cat(feature_list, 0) 72 | 73 | return features 74 | 75 | 76 | if __name__ == "__main__": 77 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 78 | 79 | parser = argparse.ArgumentParser(description="Calculate Inception v3 features for datasets") 80 | parser.add_argument("--size", type=int, default=256) 81 | parser.add_argument("--batch", default=64, type=int, help="batch size") 82 | parser.add_argument("--n_sample", type=int, default=50000) 83 | parser.add_argument("--vflip", action="store_true") 84 | parser.add_argument("--hflip", action="store_true") 85 | parser.add_argument("path", metavar="PATH", help="path to datset lmdb file") 86 | 87 | args = parser.parse_args() 88 | 89 | inception = load_patched_inception_v3() 90 | inception = nn.DataParallel(inception).eval().to(device) 91 | 92 | transform = transforms.Compose( 93 | [ 94 | transforms.RandomVerticalFlip(p=0.5 if args.vflip else 0), 95 | transforms.RandomHorizontalFlip(p=0.5 if args.hflip else 0), 96 | transforms.ToTensor(), 97 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 98 | ] 99 | ) 100 | 101 | dset = MultiResolutionDataset(args.path, transform=transform, resolution=args.size) 102 | loader = DataLoader(dset, batch_size=args.batch, num_workers=4) 103 | 104 | features = extract_features(loader, inception, device).numpy() 105 | 106 | features = features[: args.n_sample] 107 | 108 | print(f"extracted {features.shape[0]} features") 109 | 110 | mean = np.mean(features, 0) 111 | cov = np.cov(features, rowvar=False) 112 | 113 | name = os.path.splitext(os.path.basename(args.path))[0] 114 | 115 | with open(f"inception_{name}.pkl", "wb") as f: 116 | pickle.dump({"mean": mean, "cov": cov, "size": args.size, "path": args.path}, f) 117 | -------------------------------------------------------------------------------- /validation/calc_ppl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from torch.nn import functional as F 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | import lpips 9 | from model import Generator 10 | 11 | 12 | def normalize(x): 13 | return x / torch.sqrt(x.pow(2).sum(-1, keepdim=True)) 14 | 15 | 16 | def slerp(a, b, t): 17 | a = normalize(a) 18 | b = normalize(b) 19 | d = (a * b).sum(-1, keepdim=True) 20 | p = t * torch.acos(d) 21 | c = normalize(b - d * a) 22 | d = a * torch.cos(p) + c * torch.sin(p) 23 | 24 | return normalize(d) 25 | 26 | 27 | def lerp(a, b, t): 28 | return a + (b - a) * t 29 | 30 | 
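# Perceptual Path Length, as computed in the __main__ block below: sample pairs of latents, take two
# points a small step eps apart on the interpolation path between them, decode both to images, and
# divide their LPIPS distance by eps**2; the reported PPL is the mean after clipping to the 1st-99th percentiles.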
31 | if __name__ == "__main__": 32 | device = "cuda" 33 | 34 | parser = argparse.ArgumentParser() 35 | 36 | parser.add_argument("--space", choices=["z", "w"]) 37 | parser.add_argument("--batch", type=int, default=64) 38 | parser.add_argument("--n_sample", type=int, default=5000) 39 | parser.add_argument("--size", type=int, default=256) 40 | parser.add_argument("--eps", type=float, default=1e-4) 41 | parser.add_argument("--crop", action="store_true") 42 | parser.add_argument("ckpt", metavar="CHECKPOINT") 43 | 44 | args = parser.parse_args() 45 | 46 | latent_dim = 512 47 | 48 | ckpt = torch.load(args.ckpt) 49 | 50 | g = Generator(args.size, latent_dim, 8).to(device) 51 | g.load_state_dict(ckpt["g_ema"]) 52 | g.eval() 53 | 54 | percept = lpips.PerceptualLoss(model="net-lin", net="vgg", use_gpu=device.startswith("cuda")) 55 | 56 | distances = [] 57 | 58 | n_batch = args.n_sample // args.batch 59 | resid = args.n_sample - (n_batch * args.batch) 60 | batch_sizes = [args.batch] * n_batch + [resid] 61 | 62 | with torch.no_grad(): 63 | for batch in tqdm(batch_sizes): 64 | noise = g.make_noise() 65 | 66 | inputs = torch.randn([batch * 2, latent_dim], device=device) 67 | lerp_t = torch.rand(batch, device=device) 68 | 69 | if args.space == "w": 70 | latent = g.get_latent(inputs) 71 | latent_t0, latent_t1 = latent[::2], latent[1::2] 72 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 73 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + args.eps) 74 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 75 | 76 | image, _ = g([latent_e], input_is_latent=True, noise=noise) 77 | 78 | if args.crop: 79 | c = image.shape[2] // 8 80 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 81 | 82 | factor = image.shape[2] // 256 83 | 84 | if factor > 1: 85 | image = F.interpolate(image, size=(256, 256), mode="bilinear", align_corners=False) 86 | 87 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (args.eps ** 2) 88 | distances.append(dist.to("cpu").numpy()) 89 | 90 | distances = np.concatenate(distances, 0) 91 | 92 | lo = np.percentile(distances, 1, interpolation="lower") 93 | hi = np.percentile(distances, 99, interpolation="higher") 94 | filtered_dist = np.extract(np.logical_and(lo <= distances, distances <= hi), distances) 95 | 96 | print("ppl:", filtered_dist.mean()) 97 | -------------------------------------------------------------------------------- /validation/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from skimage.measure import compare_ssim 7 | import torch 8 | from torch.autograd import Variable 9 | 10 | from . 
import dist_model 11 | 12 | 13 | class PerceptualLoss(torch.nn.Module): 14 | def __init__( 15 | self, model="net-lin", net="alex", colorspace="rgb", spatial=False, use_gpu=True, gpu_ids=[0] 16 | ): # VGG using our perceptually-learned weights (LPIPS metric) 17 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 18 | super(PerceptualLoss, self).__init__() 19 | # print('Setting up Perceptual loss...') 20 | self.use_gpu = use_gpu 21 | self.spatial = spatial 22 | self.gpu_ids = gpu_ids 23 | self.model = dist_model.DistModel() 24 | self.model.initialize( 25 | model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids 26 | ) 27 | # print('...[%s] initialized'%self.model.name()) 28 | # print('...Done') 29 | 30 | def forward(self, pred, target, normalize=False): 31 | """ 32 | Pred and target are Variables. 33 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 34 | If normalize is False, assumes the images are already between [-1,+1] 35 | 36 | Inputs pred and target are Nx3xHxW 37 | Output pytorch Variable N long 38 | """ 39 | 40 | if normalize: 41 | target = 2 * target - 1 42 | pred = 2 * pred - 1 43 | 44 | return self.model.forward(target, pred) 45 | -------------------------------------------------------------------------------- /validation/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | 4 | 5 | class BaseModel: 6 | def __init__(self): 7 | pass 8 | 9 | def name(self): 10 | return "BaseModel" 11 | 12 | def initialize(self, use_gpu=True, gpu_ids=[0]): 13 | self.use_gpu = use_gpu 14 | self.gpu_ids = gpu_ids 15 | 16 | def forward(self): 17 | pass 18 | 19 | def get_image_paths(self): 20 | pass 21 | 22 | def optimize_parameters(self): 23 | pass 24 | 25 | def get_current_visuals(self): 26 | return self.input 27 | 28 | def get_current_errors(self): 29 | return {} 30 | 31 | def save(self, label): 32 | pass 33 | 34 | # helper saving function that can be used by subclasses 35 | def save_network(self, network, path, network_label, epoch_label): 36 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 37 | save_path = os.path.join(path, save_filename) 38 | torch.save(network.state_dict(), save_path) 39 | 40 | # helper loading function that can be used by subclasses 41 | def load_network(self, network, network_label, epoch_label): 42 | save_filename = "%s_net_%s.pth" % (epoch_label, network_label) 43 | save_path = os.path.join(self.save_dir, save_filename) 44 | print("Loading network from %s" % save_path) 45 | network.load_state_dict(torch.load(save_path)) 46 | 47 | def update_learning_rate(): 48 | pass 49 | 50 | def get_image_paths(self): 51 | return self.image_paths 52 | 53 | def save_done(self, flag=False): 54 | np.save(os.path.join(self.save_dir, "done_flag"), flag) 55 | np.savetxt(os.path.join(self.save_dir, "done_flag"), [flag,], fmt="%i") 56 | 57 | -------------------------------------------------------------------------------- /validation/lpips/networks_basic.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | from . import pretrained_networks as pn 5 | 6 | from . 
import util 7 | 8 | 9 | def spatial_average(in_tens, keepdim=True): 10 | return in_tens.mean([2, 3], keepdim=keepdim) 11 | 12 | 13 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 14 | in_H = in_tens.shape[2] 15 | scale_factor = 1.0 * out_H / in_H 16 | 17 | return nn.Upsample(scale_factor=scale_factor, mode="bilinear", align_corners=False)(in_tens) 18 | 19 | 20 | # Learned perceptual metric 21 | class PNetLin(nn.Module): 22 | def __init__( 23 | self, 24 | pnet_type="vgg", 25 | pnet_rand=False, 26 | pnet_tune=False, 27 | use_dropout=True, 28 | spatial=False, 29 | version="0.1", 30 | lpips=True, 31 | ): 32 | super(PNetLin, self).__init__() 33 | 34 | self.pnet_type = pnet_type 35 | self.pnet_tune = pnet_tune 36 | self.pnet_rand = pnet_rand 37 | self.spatial = spatial 38 | self.lpips = lpips 39 | self.version = version 40 | self.scaling_layer = ScalingLayer() 41 | 42 | if self.pnet_type in ["vgg", "vgg16"]: 43 | net_type = pn.vgg16 44 | self.chns = [64, 128, 256, 512, 512] 45 | elif self.pnet_type == "alex": 46 | net_type = pn.alexnet 47 | self.chns = [64, 192, 384, 256, 256] 48 | elif self.pnet_type == "squeeze": 49 | net_type = pn.squeezenet 50 | self.chns = [64, 128, 256, 384, 384, 512, 512] 51 | self.L = len(self.chns) 52 | 53 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 54 | 55 | if lpips: 56 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 57 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 58 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 59 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 60 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 61 | self.lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] 62 | if self.pnet_type == "squeeze": # 7 layers for squeezenet 63 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 64 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 65 | self.lins += [self.lin5, self.lin6] 66 | 67 | def forward(self, in0, in1, retPerLayer=False): 68 | # v0.0 - original release had a bug, where input was not scaled 69 | in0_input, in1_input = ( 70 | (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version == "0.1" else (in0, in1) 71 | ) 72 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 73 | feats0, feats1, diffs = {}, {}, {} 74 | 75 | for kk in range(self.L): 76 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 77 | diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 78 | 79 | if self.lpips: 80 | if self.spatial: 81 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 82 | else: 83 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 84 | else: 85 | if self.spatial: 86 | res = [upsample(diffs[kk].sum(dim=1, keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 87 | else: 88 | res = [spatial_average(diffs[kk].sum(dim=1, keepdim=True), keepdim=True) for kk in range(self.L)] 89 | 90 | val = res[0] 91 | for l in range(1, self.L): 92 | val += res[l] 93 | 94 | if retPerLayer: 95 | return (val, res) 96 | else: 97 | return val 98 | 99 | 100 | class ScalingLayer(nn.Module): 101 | def __init__(self): 102 | super(ScalingLayer, self).__init__() 103 | self.register_buffer("shift", torch.Tensor([-0.030, -0.088, -0.188])[None, :, None, None]) 104 | self.register_buffer("scale", torch.Tensor([0.458, 0.448, 0.450])[None, :, None, None]) 105 
| 106 | def forward(self, inp): 107 | return (inp - self.shift) / self.scale 108 | 109 | 110 | class NetLinLayer(nn.Module): 111 | """ A single linear layer which does a 1x1 conv """ 112 | 113 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 114 | super(NetLinLayer, self).__init__() 115 | 116 | layers = [nn.Dropout(),] if (use_dropout) else [] 117 | layers += [ 118 | nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False), 119 | ] 120 | self.model = nn.Sequential(*layers) 121 | 122 | 123 | class Dist2LogitLayer(nn.Module): 124 | """ takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) """ 125 | 126 | def __init__(self, chn_mid=32, use_sigmoid=True): 127 | super(Dist2LogitLayer, self).__init__() 128 | 129 | layers = [ 130 | nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True), 131 | ] 132 | layers += [ 133 | nn.LeakyReLU(0.2, True), 134 | ] 135 | layers += [ 136 | nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True), 137 | ] 138 | layers += [ 139 | nn.LeakyReLU(0.2, True), 140 | ] 141 | layers += [ 142 | nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True), 143 | ] 144 | if use_sigmoid: 145 | layers += [ 146 | nn.Sigmoid(), 147 | ] 148 | self.model = nn.Sequential(*layers) 149 | 150 | def forward(self, d0, d1, eps=0.1): 151 | return self.model.forward(torch.cat((d0, d1, d0 - d1, d0 / (d1 + eps), d1 / (d0 + eps)), dim=1)) 152 | 153 | 154 | class BCERankingLoss(nn.Module): 155 | def __init__(self, chn_mid=32): 156 | super(BCERankingLoss, self).__init__() 157 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 158 | # self.parameters = list(self.net.parameters()) 159 | self.loss = torch.nn.BCELoss() 160 | 161 | def forward(self, d0, d1, judge): 162 | per = (judge + 1.0) / 2.0 163 | self.logit = self.net.forward(d0, d1) 164 | return self.loss(self.logit, per) 165 | 166 | 167 | # L2, DSSIM metrics 168 | class FakeNet(nn.Module): 169 | def __init__(self, use_gpu=True, colorspace="Lab"): 170 | super(FakeNet, self).__init__() 171 | self.use_gpu = use_gpu 172 | self.colorspace = colorspace 173 | 174 | 175 | class L2(FakeNet): 176 | def forward(self, in0, in1, retPerLayer=None): 177 | assert in0.size()[0] == 1 # currently only supports batchSize 1 178 | 179 | if self.colorspace == "RGB": 180 | (N, C, X, Y) = in0.size() 181 | value = torch.mean( 182 | torch.mean(torch.mean((in0 - in1) ** 2, dim=1).view(N, 1, X, Y), dim=2).view(N, 1, 1, Y), dim=3 183 | ).view(N) 184 | return value 185 | elif self.colorspace == "Lab": 186 | value = util.l2( 187 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 188 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 189 | range=100.0, 190 | ).astype("float") 191 | ret_var = Variable(torch.Tensor((value,))) 192 | if self.use_gpu: 193 | ret_var = ret_var.cuda() 194 | return ret_var 195 | 196 | 197 | class DSSIM(FakeNet): 198 | def forward(self, in0, in1, retPerLayer=None): 199 | assert in0.size()[0] == 1 # currently only supports batchSize 1 200 | 201 | if self.colorspace == "RGB": 202 | value = util.dssim(1.0 * util.tensor2im(in0.data), 1.0 * util.tensor2im(in1.data), range=255.0).astype( 203 | "float" 204 | ) 205 | elif self.colorspace == "Lab": 206 | value = util.dssim( 207 | util.tensor2np(util.tensor2tensorlab(in0.data, to_norm=False)), 208 | util.tensor2np(util.tensor2tensorlab(in1.data, to_norm=False)), 209 | range=100.0, 210 | ).astype("float") 211 | ret_var = Variable(torch.Tensor((value,))) 212 | if self.use_gpu: 213 | ret_var = ret_var.cuda() 214 | return 
ret_var 215 | 216 | 217 | def print_network(net): 218 | num_params = 0 219 | for param in net.parameters(): 220 | num_params += param.numel() 221 | print("Network", net) 222 | print("Total number of parameters: %d" % num_params) 223 | 224 | -------------------------------------------------------------------------------- /validation/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2, 5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5", "relu6", "relu7"]) 52 | out = vgg_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5, h_relu6, h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = 
self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ["relu1", "relu2", "relu3", "relu4", "relu5"]) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | 98 | class vgg16(torch.nn.Module): 99 | def __init__(self, requires_grad=False, pretrained=True): 100 | super(vgg16, self).__init__() 101 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 102 | self.slice1 = torch.nn.Sequential() 103 | self.slice2 = torch.nn.Sequential() 104 | self.slice3 = torch.nn.Sequential() 105 | self.slice4 = torch.nn.Sequential() 106 | self.slice5 = torch.nn.Sequential() 107 | self.N_slices = 5 108 | for x in range(4): 109 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 110 | for x in range(4, 9): 111 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 112 | for x in range(9, 16): 113 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 114 | for x in range(16, 23): 115 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 116 | for x in range(23, 30): 117 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 118 | if not requires_grad: 119 | for param in self.parameters(): 120 | param.requires_grad = False 121 | 122 | def forward(self, X): 123 | h = self.slice1(X) 124 | h_relu1_2 = h 125 | h = self.slice2(h) 126 | h_relu2_2 = h 127 | h = self.slice3(h) 128 | h_relu3_3 = h 129 | h = self.slice4(h) 130 | h_relu4_3 = h 131 | h = self.slice5(h) 132 | h_relu5_3 = h 133 | vgg_outputs = namedtuple("VggOutputs", ["relu1_2", "relu2_2", "relu3_3", "relu4_3", "relu5_3"]) 134 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 135 | 136 | return out 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | if num == 18: 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif num == 34: 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif num == 50: 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif num == 101: 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif num == 152: 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ["relu1", "conv2", "conv3", "conv4", "conv5"]) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /validation/lpips/util.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from __future__ import division 3 | from __future__ import print_function 4 | 5 | import numpy as np 6 | from skimage.measure import compare_ssim 7 | import torch 8 | 9 | 10 | def normalize_tensor(in_feat, eps=1e-10): 11 | norm_factor = torch.sqrt(torch.sum(in_feat ** 2, 
dim=1, keepdim=True)) 12 | return in_feat / (norm_factor + eps) 13 | 14 | 15 | def l2(p0, p1, range=255.0): 16 | return 0.5 * np.mean((p0 / range - p1 / range) ** 2) 17 | 18 | 19 | def psnr(p0, p1, peak=255.0): 20 | return 10 * np.log10(peak ** 2 / np.mean((1.0 * p0 - 1.0 * p1) ** 2)) 21 | 22 | 23 | def dssim(p0, p1, range=255.0): 24 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.0 25 | 26 | 27 | def rgb2lab(in_img, mean_cent=False): 28 | from skimage import color 29 | 30 | img_lab = color.rgb2lab(in_img) 31 | if mean_cent: 32 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 33 | return img_lab 34 | 35 | 36 | def tensor2np(tensor_obj): 37 | # change dimension of a tensor object into a numpy array 38 | return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0)) 39 | 40 | 41 | def np2tensor(np_obj): 42 | # change dimenion of np array into tensor array 43 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 44 | 45 | 46 | def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False): 47 | # image tensor to lab tensor 48 | from skimage import color 49 | 50 | img = tensor2im(image_tensor) 51 | img_lab = color.rgb2lab(img) 52 | if mc_only: 53 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 54 | if to_norm and not mc_only: 55 | img_lab[:, :, 0] = img_lab[:, :, 0] - 50 56 | img_lab = img_lab / 100.0 57 | 58 | return np2tensor(img_lab) 59 | 60 | 61 | def tensorlab2tensor(lab_tensor, return_inbnd=False): 62 | from skimage import color 63 | import warnings 64 | 65 | warnings.filterwarnings("ignore") 66 | 67 | lab = tensor2np(lab_tensor) * 100.0 68 | lab[:, :, 0] = lab[:, :, 0] + 50 69 | 70 | rgb_back = 255.0 * np.clip(color.lab2rgb(lab.astype("float")), 0, 1) 71 | if return_inbnd: 72 | # convert back to lab, see if we match 73 | lab_back = color.rgb2lab(rgb_back.astype("uint8")) 74 | mask = 1.0 * np.isclose(lab_back, lab, atol=2.0) 75 | mask = np2tensor(np.prod(mask, axis=2)[:, :, np.newaxis]) 76 | return (im2tensor(rgb_back), mask) 77 | else: 78 | return im2tensor(rgb_back) 79 | 80 | 81 | def rgb2lab(input): 82 | from skimage import color 83 | 84 | return color.rgb2lab(input / 255.0) 85 | 86 | 87 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 88 | image_numpy = image_tensor[0].cpu().float().numpy() 89 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 90 | return image_numpy.astype(imtype) 91 | 92 | 93 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 94 | return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 95 | 96 | 97 | def tensor2vec(vector_tensor): 98 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 99 | 100 | 101 | def voc_ap(rec, prec, use_07_metric=False): 102 | """ ap = voc_ap(rec, prec, [use_07_metric]) 103 | Compute VOC AP given precision and recall. 104 | If use_07_metric is true, uses the 105 | VOC 07 11 point method (default:False). 
106 | """ 107 | if use_07_metric: 108 | # 11 point metric 109 | ap = 0.0 110 | for t in np.arange(0.0, 1.1, 0.1): 111 | if np.sum(rec >= t) == 0: 112 | p = 0 113 | else: 114 | p = np.max(prec[rec >= t]) 115 | ap = ap + p / 11.0 116 | else: 117 | # correct AP calculation 118 | # first append sentinel values at the end 119 | mrec = np.concatenate(([0.0], rec, [1.0])) 120 | mpre = np.concatenate(([0.0], prec, [0.0])) 121 | 122 | # compute the precision envelope 123 | for i in range(mpre.size - 1, 0, -1): 124 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 125 | 126 | # to calculate area under PR curve, look for points 127 | # where X axis (recall) changes value 128 | i = np.where(mrec[1:] != mrec[:-1])[0] 129 | 130 | # and sum (\Delta recall) * prec 131 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 132 | return ap 133 | 134 | 135 | def tensor2im(image_tensor, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 136 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 137 | image_numpy = image_tensor[0].cpu().float().numpy() 138 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 139 | return image_numpy.astype(imtype) 140 | 141 | 142 | def im2tensor(image, imtype=np.uint8, cent=1.0, factor=255.0 / 2.0): 143 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 144 | return torch.Tensor((image / factor - cent)[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 145 | 146 | -------------------------------------------------------------------------------- /validation/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /validation/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /validation/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /validation/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /validation/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /validation/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/validation/lpips/weights/v0.1/vgg.pth 
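# ---------------------------------------------------------------------------
# Editor's example (not a file in this repository): a minimal sketch of how the
# PNetLin module from validation/lpips/networks_basic.py above could be used to
# compute an LPIPS-style perceptual distance directly. Assumptions: the snippet
# is run from the repository root, inputs are float tensors of shape (N, 3, H, W)
# scaled to roughly [-1, 1], torchvision can fetch the VGG-16 backbone, a
# scikit-image version that still provides skimage.measure.compare_ssim is
# installed (util.py imports it at module level), and the linear-layer checkpoint
# validation/lpips/weights/v0.1/vgg.pth stores keys matching lin0..lin4. If any
# of that does not hold, prefer the repository's own lpips.PerceptualLoss
# wrapper, which is how validation/metrics.py constructs the metric in ppl().
import sys
import torch

sys.path.insert(0, "validation/lpips")  # networks_basic.py imports its siblings (e.g. `import util`) without a package prefix
from networks_basic import PNetLin

model = PNetLin(pnet_type="vgg", version="0.1", lpips=True, spatial=False).eval()
state = torch.load("validation/lpips/weights/v0.1/vgg.pth", map_location="cpu")
model.load_state_dict(state, strict=False)  # only the lin0..lin4 1x1-conv weights need to match

img0 = torch.rand(2, 3, 256, 256) * 2 - 1  # placeholder batches in [-1, 1]
img1 = torch.rand(2, 3, 256, 256) * 2 - 1
with torch.no_grad():
    dist = model(img0, img1)  # shape (2, 1, 1, 1); lower means perceptually more similar
print(dist.flatten())
# ---------------------------------------------------------------------------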
-------------------------------------------------------------------------------- /validation/metrics.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pickle 3 | import random 4 | 5 | from sklearn.metrics import pairwise_distances 6 | from tqdm import tqdm 7 | import torch 8 | from torch.nn import functional as F 9 | import numpy as np 10 | from scipy import linalg 11 | 12 | from .inception import InceptionV3 13 | from . import lpips 14 | 15 | 16 | @torch.no_grad() 17 | def vae_fid(vae, batch_size, latent_dim, n_sample, inception_name, calculate_prdc=True): 18 | vae.eval() 19 | 20 | inception = InceptionV3([3], normalize_input=False, init_weights=False) 21 | inception = inception.eval().to(next(vae.parameters()).device) 22 | 23 | n_batch = n_sample // batch_size 24 | resid = n_sample - (n_batch * batch_size) 25 | if resid == 0: 26 | batch_sizes = [batch_size] * n_batch 27 | else: 28 | batch_sizes = [batch_size] * n_batch + [resid] 29 | features = [] 30 | 31 | for batch in batch_sizes: 32 | latent = torch.randn(batch, *latent_dim).cuda() 33 | img = vae.decode(latent) 34 | feat = inception(img)[0].view(img.shape[0], -1) 35 | features.append(feat.to("cpu")) 36 | features = torch.cat(features, 0).numpy() 37 | 38 | del inception 39 | 40 | sample_mean = np.mean(features, 0) 41 | sample_cov = np.cov(features, rowvar=False) 42 | 43 | with open(f"inception_{inception_name}_stats.pkl", "rb") as f: 44 | embeds = pickle.load(f) 45 | real_mean = embeds["mean"] 46 | real_cov = embeds["cov"] 47 | 48 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 49 | 50 | if not np.isfinite(cov_sqrt).all(): 51 | print("product of cov matrices is singular") 52 | offset = np.eye(sample_cov.shape[0]) * 1e-6 53 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 54 | 55 | if np.iscomplexobj(cov_sqrt): 56 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 57 | m = np.max(np.abs(cov_sqrt.imag)) 58 | 59 | raise ValueError(f"Imaginary component {m}") 60 | 61 | cov_sqrt = cov_sqrt.real 62 | 63 | mean_diff = sample_mean - real_mean 64 | mean_norm = mean_diff @ mean_diff 65 | 66 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 67 | 68 | fid = mean_norm + trace 69 | 70 | ret_dict = {"FID": fid} 71 | 72 | if calculate_prdc: 73 | with open(f"inception_{inception_name}_features.pkl", "rb") as f: 74 | embeds = pickle.load(f) 75 | real_feats = embeds["features"] 76 | _, _, density, coverage = prdc(real_feats[:80000], features[:80000]) 77 | ret_dict["Density"] = density 78 | ret_dict["Coverage"] = coverage 79 | 80 | return ret_dict 81 | 82 | 83 | @torch.no_grad() 84 | def fid(generator, batch_size, n_sample, truncation, inception_name, calculate_prdc=True): 85 | generator.eval() 86 | mean_latent = generator.mean_latent(2 ** 14) 87 | 88 | inception = InceptionV3([3], normalize_input=False, init_weights=False) 89 | inception = inception.eval().to(next(generator.parameters()).device) 90 | 91 | n_batch = n_sample // batch_size 92 | resid = n_sample - (n_batch * batch_size) 93 | if resid == 0: 94 | batch_sizes = [batch_size] * n_batch 95 | else: 96 | batch_sizes = [batch_size] * n_batch + [resid] 97 | features = [] 98 | 99 | for batch in batch_sizes: 100 | if truncation is None: 101 | trunc = random.uniform(0.9, 1.5) 102 | else: 103 | trunc = truncation 104 | latent = torch.randn(batch, 512).cuda() 105 | img, _ = generator([latent], truncation=trunc, truncation_latent=mean_latent) 106 | feat = 
inception(img)[0].view(img.shape[0], -1) 107 | features.append(feat.to("cpu")) 108 | features = torch.cat(features, 0).numpy() 109 | 110 | del inception 111 | 112 | sample_mean = np.mean(features, 0) 113 | sample_cov = np.cov(features, rowvar=False) 114 | 115 | with open(f"inception_{inception_name}_stats.pkl", "rb") as f: 116 | embeds = pickle.load(f) 117 | real_mean = embeds["mean"] 118 | real_cov = embeds["cov"] 119 | 120 | cov_sqrt, _ = linalg.sqrtm(sample_cov @ real_cov, disp=False) 121 | 122 | if not np.isfinite(cov_sqrt).all(): 123 | print("product of cov matrices is singular") 124 | offset = np.eye(sample_cov.shape[0]) * 1e-6 125 | cov_sqrt = linalg.sqrtm((sample_cov + offset) @ (real_cov + offset)) 126 | 127 | if np.iscomplexobj(cov_sqrt): 128 | if not np.allclose(np.diagonal(cov_sqrt).imag, 0, atol=1e-3): 129 | m = np.max(np.abs(cov_sqrt.imag)) 130 | 131 | raise ValueError(f"Imaginary component {m}") 132 | 133 | cov_sqrt = cov_sqrt.real 134 | 135 | mean_diff = sample_mean - real_mean 136 | mean_norm = mean_diff @ mean_diff 137 | 138 | trace = np.trace(sample_cov) + np.trace(real_cov) - 2 * np.trace(cov_sqrt) 139 | 140 | fid = mean_norm + trace 141 | 142 | ret_dict = {"FID": fid} 143 | 144 | if calculate_prdc: 145 | with open(f"inception_{inception_name}_features.pkl", "rb") as f: 146 | embeds = pickle.load(f) 147 | real_feats = embeds["features"] 148 | _, _, density, coverage = prdc(real_feats[:80000], features[:80000]) 149 | ret_dict["Density"] = density 150 | ret_dict["Coverage"] = coverage 151 | 152 | return ret_dict 153 | 154 | 155 | def get_dataset_inception_features(loader, inception_name, size): 156 | if not os.path.exists(f"inception_{inception_name}_stats.pkl"): 157 | print("calculating inception features for FID....") 158 | inception = InceptionV3([3], normalize_input=False, init_weights=False) 159 | inception = torch.nn.DataParallel(inception).eval().cuda() 160 | 161 | feature_list = [] 162 | for img in tqdm(loader): 163 | img = img.cuda() 164 | feature = inception(img)[0].view(img.shape[0], -1) 165 | feature_list.append(feature.to("cpu")) 166 | features = torch.cat(feature_list, 0).numpy() 167 | 168 | mean = np.mean(features, 0) 169 | cov = np.cov(features, rowvar=False) 170 | 171 | with open(f"inception_{inception_name}_stats.pkl", "wb") as f: 172 | pickle.dump({"mean": mean, "cov": cov, "size": size, "feat": features}, f) 173 | with open(f"inception_{inception_name}_features.pkl", "wb") as f: 174 | pickle.dump({"features": features}, f) 175 | else: 176 | print(f"Found inception features: inception_{inception_name}_stats.pkl") 177 | 178 | 179 | def compute_pairwise_distance(data_x, data_y=None, metric="l2"): 180 | if data_y is None: 181 | data_y = data_x 182 | dists = pairwise_distances( 183 | data_x.reshape((len(data_x), -1)), data_y.reshape((len(data_y), -1)), metric=metric, n_jobs=24 184 | ) 185 | return dists 186 | 187 | 188 | def get_kth_value(unsorted, k, axis=-1): 189 | indices = np.argpartition(unsorted, k, axis=axis)[..., :k] 190 | k_smallests = np.take_along_axis(unsorted, indices, axis=axis) 191 | kth_values = k_smallests.max(axis=axis) 192 | return kth_values 193 | 194 | 195 | def compute_nearest_neighbour_distances(input_features, nearest_k, metric): 196 | distances = compute_pairwise_distance(input_features, metric=metric) 197 | radii = get_kth_value(distances, k=nearest_k + 1, axis=-1) 198 | return radii 199 | 200 | 201 | def prdc(real_features, fake_features, nearest_k=10, metric="l2"): 202 | real_nearest_neighbour_distances = 
compute_nearest_neighbour_distances(real_features, nearest_k, metric=metric) 203 | fake_nearest_neighbour_distances = compute_nearest_neighbour_distances(fake_features, nearest_k, metric=metric) 204 | distance_real_fake = compute_pairwise_distance(real_features, fake_features, metric=metric) 205 | 206 | precision = (distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1)).any(axis=0).mean() 207 | recall = (distance_real_fake < np.expand_dims(fake_nearest_neighbour_distances, axis=0)).any(axis=1).mean() 208 | 209 | density = (1.0 / float(nearest_k)) * ( 210 | distance_real_fake < np.expand_dims(real_nearest_neighbour_distances, axis=1) 211 | ).sum(axis=0).mean() 212 | coverage = (distance_real_fake.min(axis=1) < real_nearest_neighbour_distances).mean() 213 | 214 | return precision, recall, density, coverage 215 | 216 | 217 | def lerp(a, b, t): 218 | return a + (b - a) * t 219 | 220 | 221 | @torch.no_grad() 222 | def ppl(generator, batch_size, n_sample, space, crop, latent_dim, eps=1e-4): 223 | generator.eval() 224 | 225 | percept = lpips.PerceptualLoss( 226 | model="net-lin", net="vgg", use_gpu=True, gpu_ids=[next(generator.parameters()).device.index] 227 | ) 228 | 229 | distances = [] 230 | 231 | n_batch = n_sample // batch_size 232 | resid = n_sample - (n_batch * batch_size) 233 | if resid == 0: 234 | batch_sizes = [batch_size] * n_batch 235 | else: 236 | batch_sizes = [batch_size] * n_batch + [resid] 237 | 238 | for batch_size in batch_sizes: 239 | noise = generator.make_noise() 240 | 241 | inputs = torch.randn([batch_size * 2, latent_dim]).cuda() 242 | lerp_t = torch.rand(batch_size).cuda() 243 | 244 | if space == "w": 245 | latent = generator.get_latent(inputs) 246 | latent_t0, latent_t1 = latent[::2], latent[1::2] 247 | latent_e0 = lerp(latent_t0, latent_t1, lerp_t[:, None]) 248 | latent_e1 = lerp(latent_t0, latent_t1, lerp_t[:, None] + eps) 249 | latent_e = torch.stack([latent_e0, latent_e1], 1).view(*latent.shape) 250 | 251 | image, _ = generator(latent_e, input_is_latent=True, noise=noise) 252 | 253 | if crop: 254 | c = image.shape[2] // 8 255 | image = image[:, :, c * 3 : c * 7, c * 2 : c * 6] 256 | 257 | factor = image.shape[2] // 256 258 | 259 | if factor > 1: 260 | image = F.interpolate(image, size=(256, 256), mode="bilinear", align_corners=False) 261 | 262 | dist = percept(image[::2], image[1::2]).view(image.shape[0] // 2) / (eps ** 2) 263 | distances.append(dist.to("cpu").numpy()) 264 | 265 | distances = np.concatenate(distances, 0) 266 | 267 | lo = np.percentile(distances, 1, interpolation="lower") 268 | hi = np.percentile(distances, 99, interpolation="higher") 269 | filtered_dist = np.extract(np.logical_and(lo <= distances, distances <= hi), distances) 270 | path_length = filtered_dist.mean() 271 | 272 | del percept, inputs, lerp_t, image, dist 273 | 274 | return path_length 275 | -------------------------------------------------------------------------------- /validation/spectral_norm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class SpectralNorm(object): 5 | def __init__(self, name="weight", n_power_iterations=1, dim=0, eps=1e-12): 6 | self.name = name 7 | self.dim = dim 8 | if n_power_iterations <= 0: 9 | raise ValueError( 10 | "Expected n_power_iterations to be positive, but " 11 | "got n_power_iterations={}".format(n_power_iterations) 12 | ) 13 | self.n_power_iterations = n_power_iterations 14 | self.eps = eps 15 | 16 | def reshape_weight_to_matrix(self, weight): 17 | weight_mat 
= weight 18 | if self.dim != 0: 19 | # permute dim to front 20 | weight_mat = weight_mat.permute(self.dim, *[d for d in range(weight_mat.dim()) if d != self.dim]) 21 | height = weight_mat.size(0) 22 | return weight_mat.reshape(height, -1) 23 | 24 | def compute_sigma(self, module): 25 | with torch.no_grad(): 26 | weight = getattr(module, self.name) 27 | weight_mat = self.reshape_weight_to_matrix(weight) 28 | 29 | u = getattr(module, self.name + "_u") 30 | v = getattr(module, self.name + "_v") 31 | for _ in range(self.n_power_iterations): 32 | v = torch.nn.functional.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=self.eps) 33 | u = torch.nn.functional.normalize(torch.mv(weight_mat, v), dim=0, eps=self.eps) 34 | setattr(module, self.name + "_u", u) 35 | setattr(module, self.name + "_v", v) 36 | 37 | sigma = torch.dot(u, torch.mv(weight_mat, v)) 38 | setattr(module, "spectral_norm", sigma) 39 | 40 | def remove(self, module): 41 | delattr(module, self.name) 42 | delattr(module, self.name + "_u") 43 | delattr(module, self.name + "_v") 44 | delattr(module, "spectral_norm") 45 | 46 | def __call__(self, module, inputs): 47 | self.compute_sigma(module) 48 | 49 | def _solve_v_and_rescale(self, weight_mat, u, target_sigma): 50 | v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), weight_mat.t(), u.unsqueeze(1)).squeeze(1) 51 | return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) 52 | 53 | @staticmethod 54 | def apply(module, name, n_power_iterations, dim, eps, normalize=True): 55 | for k, hook in module._forward_pre_hooks.items(): 56 | if isinstance(hook, SpectralNorm) and hook.name == name: 57 | raise RuntimeError("Cannot register two spectral_norm hooks on " "the same parameter {}".format(name)) 58 | 59 | fn = SpectralNorm(name, n_power_iterations, dim, eps) 60 | weight = module._parameters[name] 61 | with torch.no_grad(): 62 | weight_mat = fn.reshape_weight_to_matrix(weight) 63 | 64 | h, w = weight_mat.size() 65 | u = torch.nn.functional.normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) 66 | v = torch.nn.functional.normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) 67 | 68 | module.register_buffer(fn.name + "_u", u) 69 | module.register_buffer(fn.name + "_v", v) 70 | module.register_buffer("spectral_norm", torch.tensor(-1, device=next(module.parameters()).device)) 71 | 72 | module.register_forward_pre_hook(fn) 73 | return fn 74 | 75 | 76 | def track_spectral_norm(module, name="weight", n_power_iterations=1, eps=1e-12, dim=None): 77 | r"""Tracks the spectral norm of a module's weight parameter 78 | Args: 79 | module (nn.Module): containing module 80 | name (str, optional): name of weight parameter 81 | n_power_iterations (int, optional): number of power iterations to 82 | calculate spectral norm 83 | eps (float, optional): epsilon for numerical stability in 84 | calculating norms 85 | dim (int, optional): dimension corresponding to number of outputs, 86 | the default is ``0``, except for modules that are instances of 87 | ConvTranspose{1,2,3}d, when it is ``1`` 88 | Returns: 89 | The original module with the spectral norm hook 90 | Example:: 91 | >>> m = spectral_norm(nn.Linear(20, 40)) 92 | >>> m 93 | Linear(in_features=20, out_features=40, bias=True) 94 | >>> m.weight_u.size() 95 | torch.Size([40]) 96 | """ 97 | if dim is None: 98 | if isinstance(module, (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d)): 99 | dim = 1 100 | else: 101 | dim = 0 102 | SpectralNorm.apply(module, name, n_power_iterations, dim, eps) 
103 | return module 104 | 105 | 106 | def remove_spectral_norm(module, name="weight"): 107 | r"""Removes the spectral normalization reparameterization from a module. 108 | Args: 109 | module (Module): containing module 110 | name (str, optional): name of weight parameter 111 | Example: 112 | >>> m = spectral_norm(nn.Linear(40, 10)) 113 | >>> remove_spectral_norm(m) 114 | """ 115 | for k, hook in module._forward_pre_hooks.items(): 116 | if isinstance(hook, SpectralNorm) and hook.name == name: 117 | hook.remove(module) 118 | del module._forward_pre_hooks[k] 119 | break 120 | else: 121 | raise ValueError("spectral_norm of '{}' not found in {}".format(name, module)) 122 | 123 | return module 124 | -------------------------------------------------------------------------------- /workspace/naamloos_average_pitch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_average_pitch.npy -------------------------------------------------------------------------------- /workspace/naamloos_bass_sum.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_bass_sum.npy -------------------------------------------------------------------------------- /workspace/naamloos_drop_latents.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_drop_latents.npy -------------------------------------------------------------------------------- /workspace/naamloos_drop_latents_1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_drop_latents_1.npy -------------------------------------------------------------------------------- /workspace/naamloos_high_average_pitch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_high_average_pitch.npy -------------------------------------------------------------------------------- /workspace/naamloos_high_pitches_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_high_pitches_mean.npy -------------------------------------------------------------------------------- /workspace/naamloos_intro_latents.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_intro_latents.npy -------------------------------------------------------------------------------- /workspace/naamloos_metadata.json: -------------------------------------------------------------------------------- 1 | {"total_frames": 4986} 2 | -------------------------------------------------------------------------------- /workspace/naamloos_onsets.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_onsets.npy -------------------------------------------------------------------------------- /workspace/naamloos_params.json: -------------------------------------------------------------------------------- 1 | {"intro_num_beats": 64, "intro_loop_smoothing": 30, "intro_loop_factor": 0.4, "intro_loop_len": 12, "drop_num_beats": 32, "drop_loop_smoothing": 15, "drop_loop_factor": 1, "drop_loop_len": 6, "onset_smooth": 2, "onset_clip": 95, "freq_mod": 10, "freq_mod_offset": 0, "freq_smooth": 5, "freq_latent_smooth": 4, "freq_latent_layer": 1, "freq_latent_weight": 2, "high_freq_mod": 10, "high_freq_mod_offset": 0, "high_freq_smooth": 4, "high_freq_latent_smooth": 5, "high_freq_latent_layer": 2, "high_freq_latent_weight": 1.5, "rms_smooth": 5, "bass_smooth": 5, "bass_clip": 65, "drop_clip": 75, "drop_smooth": 5, "drop_weight": 1, "high_noise_clip": 100, "high_noise_weight": 1.5, "low_noise_weight": 1} -------------------------------------------------------------------------------- /workspace/naamloos_pitches_mean.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_pitches_mean.npy -------------------------------------------------------------------------------- /workspace/naamloos_rms.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JCBrouwer/maua-stylegan2/7f9282141053be85ecb1ecc4a19f11bda90298b7/workspace/naamloos_rms.npy --------------------------------------------------------------------------------
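# ---------------------------------------------------------------------------
# Editor's example (not a file in this repository): a short sketch of how the
# spectral-norm tracking hook defined in validation/spectral_norm.py might be
# used. track_spectral_norm() registers a forward pre-hook that runs one power
# iteration per forward pass and stores the current estimate of the weight's
# largest singular value in the module's `spectral_norm` buffer. Assumptions:
# the snippet is run from the repository root and only torch is required; if
# validation/__init__.py pulls in heavier optional dependencies, import
# spectral_norm.py directly instead.
import torch
from validation.spectral_norm import track_spectral_norm

conv = torch.nn.Conv2d(3, 16, kernel_size=3)
track_spectral_norm(conv)  # adds weight_u / weight_v buffers and registers the pre-hook

_ = conv(torch.randn(1, 3, 32, 32))  # the hook updates the estimate before this forward pass
print(float(conv.spectral_norm))     # estimated largest singular value of the reshaped weight

# A few more forward passes refine the power-iteration estimate; the hook can be
# removed again with remove_spectral_norm(conv) from the same module.
for _ in range(5):
    _ = conv(torch.randn(1, 3, 32, 32))
print(float(conv.spectral_norm))
# ---------------------------------------------------------------------------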