├── .gitattributes ├── .gitignore ├── LICENSE.txt ├── README.md ├── StyleGAN2a.ipynb ├── StyleGAN2a_colab.ipynb ├── _in ├── blonde458.npy ├── dlats │ ├── ffhq-1024-f-1.npy │ ├── ffhq-1024-f-1.txt │ ├── ffhq-1024-f-2.npy │ └── ffhq-1024-f-2.txt ├── mask.jpg ├── photo │ └── blonde458.jpg └── vectors_ffhq │ ├── age.npy │ ├── bald.npy │ ├── beard.npy │ ├── blur.npy │ ├── exposure.npy │ ├── eye_distance.npy │ ├── eye_eyebrow_distance.npy │ ├── eye_makeup.npy │ ├── eye_occluded.npy │ ├── eye_open.npy │ ├── eye_ratio.npy │ ├── eye_size.npy │ ├── eyes_open.npy │ ├── feminine.npy │ ├── forehead_height.npy │ ├── forehead_occluded.npy │ ├── gender.npy │ ├── glasses.npy │ ├── hair_black.npy │ ├── hair_blond.npy │ ├── hair_brown.npy │ ├── hair_gray.npy │ ├── hair_other.npy │ ├── hair_red.npy │ ├── headwear.npy │ ├── jaw_height.npy │ ├── lip_height.npy │ ├── lip_makeup.npy │ ├── lip_ratio.npy │ ├── moustache.npy │ ├── mouth_occluded.npy │ ├── mouth_open.npy │ ├── mouth_ratio.npy │ ├── mouth_size.npy │ ├── noise.npy │ ├── nose_mouth_distance.npy │ ├── nose_ratio.npy │ ├── nose_size.npy │ ├── nose_tip.npy │ ├── pitch.npy │ ├── pupils.npy │ ├── roll.npy │ ├── sideburns.npy │ ├── smile.npy │ ├── squint.npy │ ├── sunglasses.npy │ └── yaw.npy ├── _out ├── ffhq-1024-2048x1024-4x1-00.jpg ├── ffhq-1024-2048x1024-4x1-07.jpg ├── ffhq-1024-2048x1024-4x1-16.jpg └── ffhq-1024-2048x1024-4x1-digress-15.jpg ├── data └── multicrop.bat ├── gen.bat ├── model_convert.bat ├── models └── .gitignore ├── play_dlatents.bat ├── play_vectors.bat ├── project.bat ├── requirements.txt ├── src ├── _genSGAN2.py ├── _play_dlatents.py ├── _play_vectors.py ├── dataset_tool.py ├── dnnlib │ ├── __init__.py │ └── util.py ├── legacy.py ├── model_convert.py ├── model_pt2pkl.py ├── notes.txt ├── projector.py ├── torch_utils │ ├── __init__.py │ ├── custom_ops.py │ ├── misc.py │ ├── ops │ │ ├── __init__.py │ │ ├── bias_act.cpp │ │ ├── bias_act.cu │ │ ├── bias_act.h │ │ ├── bias_act.py │ │ ├── conv2d_gradfix.py │ │ ├── conv2d_resample.py │ │ ├── fma.py │ │ ├── grid_sample_gradfix.py │ │ ├── upfirdn2d.cpp │ │ ├── upfirdn2d.cu │ │ ├── upfirdn2d.h │ │ └── upfirdn2d.py │ ├── persistence.py │ └── training_stats.py ├── train.py ├── training │ ├── __init__.py │ ├── augment.py │ ├── dataset.py │ ├── loss.py │ ├── networks.py │ ├── stylegan2_multi.py │ └── training_loop.py └── util │ ├── multicrop.py │ ├── progress_bar.py │ └── utilgan.py ├── train.bat ├── train └── .gitkeep └── train_resume.bat /.gitattributes: -------------------------------------------------------------------------------- 1 | # *.pkl filter=lfs diff=lfs merge=lfs -text 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | _cudacache/ 4 | 5 | # Jupyter Notebook 6 | .ipynb_checkpoints/ 7 | 8 | *.pkl 9 | *.pt 10 | *.avi 11 | *.mp4 12 | *.zip 13 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | ----------------------- LICENSE FOR stylegan2-ada --------------------- 2 | 3 | Copyright (c) 2021, NVIDIA Corporation. All rights reserved. 4 | 5 | 6 | NVIDIA Source Code License for StyleGAN2 with Adaptive Discriminator Augmentation (ADA) 7 | 8 | 9 | ======================================================================= 10 | 11 | 1. 
Definitions 12 | 13 | "Licensor" means any person or entity that distributes its Work. 14 | 15 | "Software" means the original work of authorship made available under 16 | this License. 17 | 18 | "Work" means the Software and any additions to or derivative works of 19 | the Software that are made available under this License. 20 | 21 | The terms "reproduce," "reproduction," "derivative works," and 22 | "distribution" have the meaning as provided under U.S. copyright law; 23 | provided, however, that for the purposes of this License, derivative 24 | works shall not include works that remain separable from, or merely 25 | link (or bind by name) to the interfaces of, the Work. 26 | 27 | Works, including the Software, are "made available" under this License 28 | by including in or with the Work either (a) a copyright notice 29 | referencing the applicability of this License to the Work, or (b) a 30 | copy of this License. 31 | 32 | 2. License Grants 33 | 34 | 2.1 Copyright Grant. Subject to the terms and conditions of this 35 | License, each Licensor grants to you a perpetual, worldwide, 36 | non-exclusive, royalty-free, copyright license to reproduce, 37 | prepare derivative works of, publicly display, publicly perform, 38 | sublicense and distribute its Work and any resulting derivative 39 | works in any form. 40 | 41 | 3. Limitations 42 | 43 | 3.1 Redistribution. You may reproduce or distribute the Work only 44 | if (a) you do so under this License, (b) you include a complete 45 | copy of this License with your distribution, and (c) you retain 46 | without modification any copyright, patent, trademark, or 47 | attribution notices that are present in the Work. 48 | 49 | 3.2 Derivative Works. You may specify that additional or different 50 | terms apply to the use, reproduction, and distribution of your 51 | derivative works of the Work ("Your Terms") only if (a) Your Terms 52 | provide that the use limitation in Section 3.3 applies to your 53 | derivative works, and (b) you identify the specific derivative 54 | works that are subject to Your Terms. Notwithstanding Your Terms, 55 | this License (including the redistribution requirements in Section 56 | 3.1) will continue to apply to the Work itself. 57 | 58 | 3.3 Use Limitation. The Work and any derivative works thereof only 59 | may be used or intended for use non-commercially. Notwithstanding 60 | the foregoing, NVIDIA and its affiliates may use the Work and any 61 | derivative works commercially. As used herein, "non-commercially" 62 | means for research or evaluation purposes only. 63 | 64 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 65 | against any Licensor (including any claim, cross-claim or 66 | counterclaim in a lawsuit) to enforce any patents that you allege 67 | are infringed by any Work, then your rights under this License from 68 | such Licensor (including the grant in Section 2.1) will terminate 69 | immediately. 70 | 71 | 3.5 Trademarks. This License does not grant any rights to use any 72 | Licensor’s or its affiliates’ names, logos, or trademarks, except 73 | as necessary to reproduce the notices described in this License. 74 | 75 | 3.6 Termination. If you violate any term of this License, then your 76 | rights under this License (including the grant in Section 2.1) will 77 | terminate immediately. 78 | 79 | 4. Disclaimer of Warranty. 
80 | 81 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 82 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 83 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 84 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 85 | THIS LICENSE. 86 | 87 | 5. Limitation of Liability. 88 | 89 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 90 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 91 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 92 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 93 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 94 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 95 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 96 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 97 | THE POSSIBILITY OF SUCH DAMAGES. 98 | 99 | ======================================================================= 100 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/README.md -------------------------------------------------------------------------------- /_in/blonde458.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/blonde458.npy -------------------------------------------------------------------------------- /_in/dlats/ffhq-1024-f-1.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/dlats/ffhq-1024-f-1.npy -------------------------------------------------------------------------------- /_in/dlats/ffhq-1024-f-1.txt: -------------------------------------------------------------------------------- 1 | 5,8 -------------------------------------------------------------------------------- /_in/dlats/ffhq-1024-f-2.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/dlats/ffhq-1024-f-2.npy -------------------------------------------------------------------------------- /_in/dlats/ffhq-1024-f-2.txt: -------------------------------------------------------------------------------- 1 | 11,14 -------------------------------------------------------------------------------- /_in/mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/mask.jpg -------------------------------------------------------------------------------- /_in/photo/blonde458.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/photo/blonde458.jpg -------------------------------------------------------------------------------- /_in/vectors_ffhq/age.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/age.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/bald.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/bald.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/beard.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/beard.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/blur.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/blur.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/exposure.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/exposure.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_distance.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_eyebrow_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_eyebrow_distance.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_makeup.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_makeup.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_occluded.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_occluded.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_open.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_ratio.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eye_size.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eye_size.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/eyes_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/eyes_open.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/feminine.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/feminine.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/forehead_height.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/forehead_height.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/forehead_occluded.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/forehead_occluded.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/gender.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/gender.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/glasses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/glasses.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_black.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_black.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_blond.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_blond.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_brown.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_brown.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_gray.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_gray.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_other.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_other.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/hair_red.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/hair_red.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/headwear.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/headwear.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/jaw_height.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/jaw_height.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/lip_height.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/lip_height.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/lip_makeup.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/lip_makeup.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/lip_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/lip_ratio.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/moustache.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/moustache.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/mouth_occluded.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/mouth_occluded.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/mouth_open.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/mouth_open.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/mouth_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/mouth_ratio.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/mouth_size.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/mouth_size.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/noise.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/noise.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/nose_mouth_distance.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/nose_mouth_distance.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/nose_ratio.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/nose_ratio.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/nose_size.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/nose_size.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/nose_tip.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/nose_tip.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/pitch.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/pitch.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/pupils.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/pupils.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/roll.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/roll.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/sideburns.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/sideburns.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/smile.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/smile.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/squint.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/squint.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/sunglasses.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/sunglasses.npy -------------------------------------------------------------------------------- /_in/vectors_ffhq/yaw.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_in/vectors_ffhq/yaw.npy -------------------------------------------------------------------------------- /_out/ffhq-1024-2048x1024-4x1-00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_out/ffhq-1024-2048x1024-4x1-00.jpg -------------------------------------------------------------------------------- /_out/ffhq-1024-2048x1024-4x1-07.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_out/ffhq-1024-2048x1024-4x1-07.jpg -------------------------------------------------------------------------------- /_out/ffhq-1024-2048x1024-4x1-16.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_out/ffhq-1024-2048x1024-4x1-16.jpg -------------------------------------------------------------------------------- /_out/ffhq-1024-2048x1024-4x1-digress-15.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/_out/ffhq-1024-2048x1024-4x1-digress-15.jpg -------------------------------------------------------------------------------- /data/multicrop.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | 3 | : multicrop images to squares of specific size 4 | : %1 = video file or directory with images 5 | : %2 = size 6 | : %3 = shift step 7 | 8 | if exist %1\* goto dir 9 | 10 | :video 11 | echo .. cropping video 12 | if not exist "%~dp1\%~n1-tmp" md "%~dp1\%~n1-tmp" 13 | ffmpeg -y -v error -i %1 -q:v 2 "%~dp1\%~n1-tmp\%~n1-c-%%06d.png" 14 | python ../src/util/multicrop.py --in_dir %~dp1/%~n1-tmp --out_dir %~dp1/%~n1-sub --size %2 --step %3 15 | rmdir /s /q %~dp1\%~n1-tmp 16 | goto end 17 | 18 | :dir 19 | echo .. cropping images
20 | python ../src/util/multicrop.py --in_dir %1 --out_dir %~dp1/%~n1-sub --size %2 --step %3 21 | goto end 22 | 23 | :end -------------------------------------------------------------------------------- /gen.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 > NUL 3 | rem set TORCH_HOME=C:\X\torch 4 | rem set TORCH_EXTENSIONS_DIR=src\torch_utils\ops\.cache 5 | 6 | if "%1"=="" goto help 7 | 8 | if "%2"=="" goto test 9 | if "%2"=="1" goto test 10 | 11 | set model=%1 12 | set name=%~n1 13 | set size=%2 14 | set frames=%3 15 | set args=%4 %5 %6 %7 %8 %9 16 | for %%q in (1 2 3 4 5 6 7 8 9 10) do shift 17 | set args=%args% %0 %1 %2 %3 %4 %5 %6 %7 %8 %9 18 | for %%q in (1 2 3 4 5 6 7 8 9 10) do shift 19 | set args=%args% %0 %1 %2 %3 %4 %5 %6 %7 %8 %9 20 | for %%q in (1 2 3 4 5 6 7 8 9 10) do shift 21 | set args=%args% %0 %1 %2 %3 %4 %5 %6 %7 %8 %9 22 | 23 | python src/_genSGAN2.py --model models/%model% --out_dir _out/%name% --size %size% --frames %frames% %args% 24 | goto ff 25 | 26 | :test 27 | python src/_genSGAN2.py --model models/%name%.pkl --out_dir _out/%name% --frames 200-20 ^ 28 | %3 %4 %5 %6 %7 %8 %9 29 | 30 | :ff 31 | ffmpeg -y -v warning -i _out\%name%\%%06d.jpg -c:v mjpeg -q:v 2 _out/%name%-%2.avi 32 | rem rmdir /s /q _out\%name% 33 | 34 | goto end 35 | 36 | 37 | :help 38 | echo Usage: gen model x-y framecount-transit 39 | echo e.g.: gen ffhq-1024 1280-720 100-25 40 | 41 | :end 42 | -------------------------------------------------------------------------------- /model_convert.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 > NUL 3 | 4 | python src/model_convert.py --source %1 %2 %3 %4 %5 %6 %7 %8 %9 5 | 6 | -------------------------------------------------------------------------------- /models/.gitignore: -------------------------------------------------------------------------------- 1 | # Ignore everything in this directory 2 | * 3 | # Except this file 4 | !.gitignore -------------------------------------------------------------------------------- /play_dlatents.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 > NUL 3 | 4 | if "%1"=="" goto help 5 | 6 | python src/_play_dlatents.py --model models/%1 --dlatents _in/%2 --out_dir _out/%~n1-%~n2 --fstep %3 --size %4 ^ 7 | %5 %6 %7 %8 %9 8 | 9 | ffmpeg -y -v warning -i _out\%~n1-%~n2\%%06d.jpg -c:v mjpeg -q:v 2 _out/%~n1-%~n2.avi 10 | rem rmdir /s /q _out\%~n1-%~n2 11 | 12 | goto end 13 | 14 | :help 15 | echo Usage: play_dlatents model latentsdir fstep size 16 | echo e.g.: play_dlatents ffhq-1024-f npy 25 1920-1080 17 | 18 | :end -------------------------------------------------------------------------------- /play_vectors.bat: -------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 > NUL 3 | 4 | if "%1"=="" goto help 5 | 6 | python src/_play_vectors.py --model models/%1 --base_lat _in/%2 --vector_dir _in/%3 --out_dir _out/%~n1-%~n3 ^ 7 | %4 %5 %6 %7 %8 %9 8 | 9 | ffmpeg -y -v warning -i _out\%~n1-%~n3\%%06d.jpg -c:v mjpeg -q:v 2 _out/%~n1-%~n3.avi 10 | rmdir /s /q _out\%~n1-%~n3 11 | 12 | goto end 13 | 14 | :help 15 | echo Usage: play_vectors model baselatent vectordir 16 | echo e.g.: play_vectors ffhq-1024 blonde458.npy vectors_ffhq 17 | 18 | :end -------------------------------------------------------------------------------- /project.bat:
-------------------------------------------------------------------------------- 1 | @echo off 2 | chcp 65001 > NUL 3 | 4 | python src/projector.py --model=models/%1 --in_dir=_in/%2 --out_dir=_out/proj/%2 ^ 5 | --save_video ^ 6 | %3 %4 %5 %6 %7 %8 %9 7 | 8 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | click 2 | psutil 3 | ninja 4 | imageio 5 | imageio-ffmpeg -------------------------------------------------------------------------------- /src/_genSGAN2.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | import os.path as osp 4 | import argparse 5 | import numpy as np 6 | from imageio import imsave 7 | import random 8 | import torch 9 | 10 | import dnnlib 11 | import legacy 12 | 13 | from util.utilgan import latent_anima, basename, img_read, img_list 14 | try: # progress bar for notebooks 15 | get_ipython().__class__.__name__ 16 | from util.progress_bar import ProgressIPy as ProgressBar 17 | except: # normal console 18 | from util.progress_bar import ProgressBar 19 | 20 | desc = "Customized StyleGAN2-ada on PyTorch" 21 | parser = argparse.ArgumentParser(description=desc) 22 | parser.add_argument('-o', '--out_dir', default='_out', help='output directory') 23 | parser.add_argument('-m', '--model', default='models/ffhq-1024.pkl', help='path to pkl checkpoint file') 24 | parser.add_argument('-l', '--labels', default=None, help='labels/categories for conditioning') 25 | # custom 26 | parser.add_argument('-s', '--size', default=None, help='output resolution, set in X-Y format') 27 | parser.add_argument('-sc', '--scale_type', default='pad', help="main types: pad, padside, symm, symmside") 28 | parser.add_argument('-lm', '--latmask', default=None, help='external mask file (or directory) for multi latent blending') 29 | parser.add_argument('-n', '--nXY', default='1-1', help='multi latent frame split count by X (width) and Y (height)') 30 | parser.add_argument( '--splitfine', default=0, type=float, help='multi latent frame split edge sharpness (0 = smooth, higher => finer)') 31 | parser.add_argument('-tr','--trunc', default=0.8, type=float, help='truncation psi 0..1 (lower = stable, higher = various)') 32 | parser.add_argument('-d', '--digress', default=0, type=float, help='distortion technique by Aydao (strength of the effect)') 33 | parser.add_argument('--save_lat', action='store_true', help='save latent vectors to file') 34 | parser.add_argument( '--ext', default='jpg', help='save as jpg or png') 35 | parser.add_argument( '--seed', default=None, type=int) 36 | parser.add_argument('-v', '--verbose', action='store_true') 37 | # animation 38 | parser.add_argument('--frames', default='200-25', help='total frames to generate, length of interpolation step') 39 | parser.add_argument("--cubic", action='store_true', help="use cubic splines for smoothing") 40 | parser.add_argument("--gauss", action='store_true', help="use Gaussian smoothing") 41 | a = parser.parse_args() 42 | 43 | if a.size is not None: 44 | a.size = [int(s) for s in a.size.split('-')][::-1] 45 | if len(a.size) == 1: a.size = a.size * 2 46 | [a.frames, a.fstep] = [int(s) for s in a.frames.split('-')] 47 | 48 | def generate(): 49 | os.makedirs(a.out_dir, exist_ok=True) 50 | if a.seed==0: a.seed = None 51 | np.random.seed(seed=a.seed) 52 | device = torch.device('cuda') 53 | 54 | # setup generator 55 | Gs_kwargs =
dnnlib.EasyDict() 56 | Gs_kwargs.verbose = a.verbose 57 | Gs_kwargs.size = a.size 58 | Gs_kwargs.scale_type = a.scale_type 59 | 60 | # mask/blend latents with external latmask or by splitting the frame 61 | if a.latmask is None: 62 | nHW = [int(s) for s in a.nXY.split('-')][::-1] 63 | assert len(nHW)==2, ' Wrong count nXY: %d (must be 2)' % len(nHW) 64 | n_mult = nHW[0] * nHW[1] 65 | Gs_kwargs.countHW = nHW 66 | Gs_kwargs.splitfine = a.splitfine 67 | if a.verbose is True and n_mult > 1: print(' Latent blending w/split frame %d x %d' % (nHW[1], nHW[0])) 68 | lmask = [None] 69 | else: 70 | n_mult = 2 71 | if osp.isfile(a.latmask): # single file 72 | lmask = np.asarray([[img_read(a.latmask)[:,:,0] / 255.]]) # [1,1,h,w] 73 | elif osp.isdir(a.latmask): # directory with frame sequence 74 | lmask = np.expand_dims(np.asarray([img_read(f)[:,:,0] / 255. for f in img_list(a.latmask)]), 1) # [n,1,h,w] 75 | else: 76 | print(' !! Blending mask not found:', a.latmask); exit(1) 77 | if a.verbose is True: print(' Latent blending with mask', a.latmask, lmask.shape) 78 | lmask = np.concatenate((lmask, 1 - lmask), 1) # [frm,2,h,w] 79 | lmask = torch.from_numpy(lmask).to(device) 80 | 81 | # load base or custom network 82 | pkl_name = osp.splitext(a.model)[0] 83 | if '.pkl' in a.model.lower(): 84 | custom = False 85 | print(' .. Gs from pkl ..', basename(a.model)) 86 | else: 87 | custom = True 88 | print(' .. Gs custom ..', basename(a.model)) 89 | with dnnlib.util.open_url(pkl_name + '.pkl') as f: 90 | Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device) # type: ignore 91 | 92 | if a.verbose is True: print(' out shape', Gs.output_shape[1:]) 93 | 94 | if a.verbose is True: print(' making timeline..') 95 | lats = [] # list of [frm,1,512] 96 | for i in range(n_mult): 97 | lat_tmp = latent_anima((1, Gs.z_dim), a.frames, a.fstep, cubic=a.cubic, gauss=a.gauss, seed=a.seed, verbose=False) # [frm,1,512] 98 | lats.append(lat_tmp) # list of [frm,1,512] 99 | latents = np.concatenate(lats, 1) # [frm,X,512] 100 | print(' latents', latents.shape) 101 | latents = torch.from_numpy(latents).to(device) 102 | frame_count = latents.shape[0] 103 | 104 | # distort image by tweaking initial const layer 105 | if a.digress > 0: 106 | try: init_res = Gs.init_res 107 | except: init_res = (4,4) # default initial layer size 108 | dconst = [] 109 | for i in range(n_mult): 110 | dc_tmp = a.digress * latent_anima([1, Gs.z_dim, *init_res], a.frames, a.fstep, cubic=True, seed=a.seed, verbose=False) 111 | dconst.append(dc_tmp) 112 | dconst = np.concatenate(dconst, 1) 113 | else: 114 | dconst = np.zeros([frame_count, 1, 1, 1, 1]) 115 | dconst = torch.from_numpy(dconst).to(device) 116 | 117 | # labels / conditions 118 | label_size = Gs.c_dim 119 | if label_size > 0: 120 | labels = torch.zeros((frame_count, n_mult, label_size), device=device) # [frm,X,lbl] 121 | if a.labels is None: 122 | label_ids = [] 123 | for i in range(n_mult): 124 | label_ids.append(random.randint(0, label_size-1)) 125 | else: 126 | label_ids = [int(x) for x in a.labels.split('-')] 127 | label_ids = label_ids[:n_mult] # ensure we have enough labels 128 | for i, l in enumerate(label_ids): 129 | labels[:,i,l] = 1 130 | else: 131 | labels = [None] 132 | 133 | # warm up 134 | if custom: 135 | _ = Gs(latents[0], labels[0], lmask[0], dconst[0], noise_mode='const') 136 | else: 137 | _ = Gs(latents[0], labels[0], noise_mode='const') 138 | 139 | # generate images from latent timeline 140 | pbar = ProgressBar(frame_count) 141 | for i in range(frame_count): 
142 | 143 | latent = latents[i] # [X,512] 144 | label = labels[i % len(labels)] 145 | latmask = lmask[i % len(lmask)] # [X,h,w] or None 146 | dc = dconst[i % len(dconst)] # [X,512,4,4] 147 | 148 | # generate multi-latent result 149 | if custom: 150 | output = Gs(latent, label, latmask, dc, truncation_psi=a.trunc, noise_mode='const') 151 | else: 152 | output = Gs(latent, label, truncation_psi=a.trunc, noise_mode='const') 153 | output = (output.permute(0,2,3,1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy() 154 | 155 | # save image 156 | ext = 'png' if output.shape[3]==4 else a.ext if a.ext is not None else 'jpg' 157 | filename = osp.join(a.out_dir, "%06d.%s" % (i,ext)) 158 | imsave(filename, output[0]) 159 | pbar.upd() 160 | 161 | # convert latents to dlatents, save them 162 | if a.save_lat is True: 163 | latents = latents.squeeze(1) # [frm,512] 164 | dlatents = Gs.mapping(latents, label) # [frm,18,512] 165 | if a.size is None: a.size = ['']*2 166 | filename = '{}-{}-{}.npy'.format(basename(a.model), a.size[1], a.size[0]) 167 | filename = osp.join(osp.dirname(a.out_dir), filename) 168 | dlatents = dlatents.cpu().numpy() 169 | np.save(filename, dlatents) 170 | print('saved dlatents', dlatents.shape, 'to', filename) 171 | 172 | 173 | if __name__ == '__main__': 174 | generate() 175 | -------------------------------------------------------------------------------- /src/_play_dlatents.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | import os.path as osp 4 | import argparse 5 | import numpy as np 6 | from imageio import imsave 7 | 8 | import torch 9 | 10 | import dnnlib 11 | import legacy 12 | 13 | from util.utilgan import latent_anima, load_latents, file_list, basename 14 | try: # progress bar for notebooks 15 | get_ipython().__class__.__name__ 16 | from util.progress_bar import ProgressIPy as ProgressBar 17 | except: # normal console 18 | from util.progress_bar import ProgressBar 19 | 20 | desc = "Customized StyleGAN2-ada on PyTorch" 21 | parser = argparse.ArgumentParser(description=desc) 22 | parser.add_argument('--dlatents', default=None, help='Saved dlatent vectors in single *.npy file or directory with such files') 23 | parser.add_argument('--style_dlat', default=None, help='Saved latent vector for hi res (style) features') 24 | parser.add_argument('--out_dir', default='_out', help='Output directory') 25 | parser.add_argument('--model', default='models/ffhq-1024-f.pkl', help='path to checkpoint file') 26 | parser.add_argument('--size', default=None, help='Output resolution') 27 | parser.add_argument('--scale_type', default='pad', help="main types: pad, padside, symm, symmside") 28 | parser.add_argument('--trunc', type=float, default=1, help='Truncation psi 0..1 (lower = stable, higher = various)') 29 | parser.add_argument('--digress', type=float, default=0, help='distortion technique by Aydao (strength of the effect)') 30 | parser.add_argument('--ext', default='jpg', help='save as jpg or png') 31 | parser.add_argument('--verbose', action='store_true') 32 | parser.add_argument('--ops', default='cuda', help='custom op implementation (cuda or ref)') 33 | # animation 34 | parser.add_argument("--fstep", type=int, default=25, help="Number of frames for smooth interpolation") 35 | parser.add_argument("--cubic", action='store_true', help="Use cubic splines for smoothing") 36 | a = parser.parse_args() 37 | 38 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 39 | 40 | def main():
41 | os.makedirs(a.out_dir, exist_ok=True) 42 | device = torch.device('cuda') 43 | 44 | # setup generator 45 | Gs_kwargs = dnnlib.EasyDict() 46 | Gs_kwargs.verbose = a.verbose 47 | Gs_kwargs.size = a.size 48 | Gs_kwargs.scale_type = a.scale_type 49 | 50 | # load base or custom network 51 | pkl_name = osp.splitext(a.model)[0] 52 | if '.pkl' in a.model.lower(): 53 | custom = False 54 | print(' .. Gs from pkl ..', basename(a.model)) 55 | else: 56 | custom = True 57 | print(' .. Gs custom ..', basename(a.model)) 58 | with dnnlib.util.open_url(pkl_name + '.pkl') as f: 59 | Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device) # type: ignore 60 | 61 | dlat_shape = (1, Gs.num_ws, Gs.w_dim) # [1,18,512] 62 | 63 | # read saved latents 64 | if a.dlatents is not None and osp.isfile(a.dlatents): 65 | key_dlatents = load_latents(a.dlatents) 66 | if len(key_dlatents.shape) == 2: key_dlatents = np.expand_dims(key_dlatents, 0) 67 | elif a.dlatents is not None and osp.isdir(a.dlatents): 68 | # if a.dlatents.endswith('/') or a.dlatents.endswith('\\'): a.dlatents = a.dlatents[:-1] 69 | key_dlatents = [] 70 | npy_list = file_list(a.dlatents, 'npy') 71 | for npy in npy_list: 72 | key_dlatent = load_latents(npy) 73 | if len(key_dlatent.shape) == 2: key_dlatent = np.expand_dims(key_dlatent, 0) 74 | key_dlatents.append(key_dlatent) 75 | key_dlatents = np.concatenate(key_dlatents) # [frm,18,512] 76 | else: 77 | print(' No input dlatents found'); exit() 78 | key_dlatents = key_dlatents[:, np.newaxis] # [frm,1,18,512] 79 | print(' key dlatents', key_dlatents.shape) 80 | 81 | # replace higher layers with single (style) latent 82 | if a.style_dlat is not None: 83 | print(' styling with dlatent', a.style_dlat) 84 | style_dlatent = load_latents(a.style_dlat) 85 | while len(style_dlatent.shape) < 4: style_dlatent = np.expand_dims(style_dlatent, 0) 86 | # try replacing 5 by other value, less than Gs.num_ws 87 | key_dlatents[:, :, range(5, Gs.num_ws), :] = style_dlatent[:, :, range(5, Gs.num_ws), :] 88 | 89 | frames = key_dlatents.shape[0] * a.fstep 90 | 91 | dlatents = latent_anima(dlat_shape, frames, a.fstep, key_latents=key_dlatents, cubic=a.cubic, verbose=True) # [frm,1,18,512] 92 | print(' dlatents', dlatents.shape) 93 | frame_count = dlatents.shape[0] 94 | dlatents = torch.from_numpy(dlatents).to(device) 95 | 96 | # distort image by tweaking initial const layer 97 | if a.digress > 0: 98 | try: init_res = Gs.init_res 99 | except: init_res = (4,4) # default initial layer size 100 | dconst = a.digress * latent_anima([1, Gs.z_dim, *init_res], frame_count, a.fstep, cubic=True, verbose=False) 101 | else: 102 | dconst = np.zeros([frame_count, 1, 1, 1, 1]) 103 | dconst = torch.from_numpy(dconst).to(device) 104 | 105 | # generate images from latent timeline 106 | pbar = ProgressBar(frame_count) 107 | for i in range(frame_count): 108 | 109 | # generate multi-latent result 110 | if custom: 111 | output = Gs.synthesis(dlatents[i], None, dconst[i], noise_mode='const') 112 | else: 113 | output = Gs.synthesis(dlatents[i], noise_mode='const') 114 | output = (output.permute(0,2,3,1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy() 115 | 116 | ext = 'png' if output.shape[3]==4 else a.ext if a.ext is not None else 'jpg' 117 | filename = osp.join(a.out_dir, "%06d.%s" % (i,ext)) 118 | imsave(filename, output[0]) 119 | pbar.upd() 120 | 121 | 122 | if __name__ == '__main__': 123 | main() 124 | 125 | --------------------------------------------------------------------------------
/src/_play_vectors.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ['TF_CPP_MIN_LOG_LEVEL']='2' 3 | import os.path as osp 4 | import argparse 5 | import numpy as np 6 | from imageio import imsave 7 | import pickle 8 | import cv2 9 | 10 | import torch 11 | 12 | import dnnlib 13 | import legacy 14 | 15 | from util.utilgan import load_latents, file_list, basename 16 | try: # progress bar for notebooks 17 | get_ipython().__class__.__name__ 18 | from util.progress_bar import ProgressIPy as ProgressBar 19 | except: # normal console 20 | from util.progress_bar import ProgressBar 21 | 22 | desc = "Customized StyleGAN2-ada on PyTorch" 23 | parser = argparse.ArgumentParser(description=desc) 24 | parser.add_argument('--vector_dir', default=None, help='Saved latent directions in *.npy format') 25 | parser.add_argument('--base_lat', default=None, help='Saved latent vector as *.npy file') 26 | parser.add_argument('--out_dir', default='_out/ttt', help='Output directory') 27 | parser.add_argument('--model', default='models/ffhq-1024.pkl', help='path to checkpoint file') 28 | parser.add_argument('--size', default=None, help='output resolution, set in X-Y format') 29 | parser.add_argument('--scale_type', default='pad', help="main types: pad, padside, symm, symmside") 30 | parser.add_argument('--trunc', type=float, default=0.8, help='truncation psi 0..1 (lower = stable, higher = various)') 31 | parser.add_argument('--ext', default='jpg', help='save as jpg or png') 32 | parser.add_argument('--verbose', action='store_true') 33 | parser.add_argument('--ops', default='cuda', help='custom op implementation (cuda or ref)') 34 | # animation 35 | parser.add_argument("--fstep", type=int, default=25, help="Number of frames for interpolation step") 36 | a = parser.parse_args() 37 | 38 | if a.size is not None: a.size = [int(s) for s in a.size.split('-')][::-1] 39 | 40 | def generate_image(latent): 41 | args = [None, 0] if custom else [] # extra latmask/dconst args for custom model
42 | if use_d: 43 | img = Gs.synthesis(latent, *args, noise_mode='const') 44 | else: 45 | img = Gs(latent, None, *args, truncation_psi=a.trunc, noise_mode='const') 46 | img = (img.permute(0,2,3,1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu().numpy()[0] 47 | return img 48 | 49 | def render_latent_dir(latent, direction, coeff): 50 | new_latent = latent + coeff*direction 51 | img = generate_image(new_latent) 52 | return img 53 | 54 | def render_latent_mix(latent1, latent2, coeff): 55 | new_latent = latent1 * (1-coeff) + latent2 * coeff 56 | img = generate_image(new_latent) 57 | return img 58 | 59 | def pingpong(x, delta): 60 | x = (x + delta) % 2 61 | if x > 1: 62 | x = 1 - (x%1) 63 | delta = -delta 64 | return x, delta 65 | 66 | def get_coeffs_dir(lrange, count): 67 | dx = 1 / count 68 | x = -lrange[0] / (lrange[1] - lrange[0]) 69 | xs = [0] 70 | for _ in range(count*2): 71 | x, dx = pingpong(x, dx) 72 | xs.append( x * (lrange[1] - lrange[0]) + lrange[0] ) 73 | return xs 74 | 75 | def make_loop(base_latent, direction, lrange, fcount, start_frame=0, ext='jpg'): 76 | coeffs = get_coeffs_dir(lrange, fcount//2) 77 | # pbar = ProgressBar(fcount) 78 | for i in range(fcount): 79 | img = render_latent_dir(base_latent, direction, coeffs[i]) 80 | fname1 = os.path.join(a.out_dir, "%06d.%s" % (i+start_frame, ext)) 81 | if i%2==0 and a.verbose is True: 82 | cv2.imshow('latent', img[:,:,::-1]) 83 | cv2.waitKey(10) 84 | imsave(fname1, img) 85 | # pbar.upd() 86 | 87 | def make_transit(lat1, lat2, fcount, start_frame=0, ext='jpg'): 88 | # pbar = ProgressBar(fcount) 89 | for i in range(fcount): 90 | img = render_latent_mix(lat1, lat2, i/fcount) 91 | fname = os.path.join(a.out_dir, "%06d.%s" % (i+start_frame, ext)) 92 | if i%2==0 and a.verbose is True: 93 | cv2.imshow('latent', img[:,:,::-1]) 94 | cv2.waitKey(10) 95 | imsave(fname, img) 96 | # pbar.upd() 97 | 98 | def main(): 99 | if a.vector_dir is not None: 100 | if a.vector_dir.endswith('/') or a.vector_dir.endswith('\\'): a.vector_dir = a.vector_dir[:-1] 101 | os.makedirs(a.out_dir, exist_ok=True) 102 | device = torch.device('cuda') 103 | 104 | global Gs, use_d, custom 105 | 106 | # setup generator 107 | Gs_kwargs = dnnlib.EasyDict() 108 | Gs_kwargs.verbose = a.verbose 109 | Gs_kwargs.size = a.size 110 | Gs_kwargs.scale_type = a.scale_type 111 | 112 | # load base or custom network 113 | pkl_name = osp.splitext(a.model)[0] 114 | if '.pkl' in a.model.lower(): 115 | custom = False 116 | print(' .. Gs from pkl ..', basename(a.model)) 117 | else: 118 | custom = True 119 | print(' .. Gs custom ..', basename(a.model))
120 | with dnnlib.util.open_url(pkl_name + '.pkl') as f: 121 | Gs = legacy.load_network_pkl(f, custom=custom, **Gs_kwargs)['G_ema'].to(device) # type: ignore 122 | 123 | # load directions 124 | if a.vector_dir is not None: 125 | directions = [] 126 | vector_list = file_list(a.vector_dir, 'npy') 127 | for v in vector_list: 128 | direction = load_latents(v) 129 | if len(direction.shape) == 2: direction = np.expand_dims(direction, 0) 130 | directions.append(direction) 131 | directions = np.concatenate(directions)[:, np.newaxis] # [frm,1,18,512] 132 | else: 133 | print(' No vectors found'); exit() 134 | use_d = False # True when directions are in w+ (dlatent) space 135 | if len(direction[0].shape) > 1 and direction[0].shape[0] > 1: 136 | use_d = True 137 | print(' directions', directions.shape, 'using d' if use_d else 'using w') 138 | directions = torch.from_numpy(directions).to(device) 139 | 140 | # latent direction range 141 | lrange = [-0.5, 0.5] 142 | 143 | # load saved latents 144 | if a.base_lat is not None: 145 | base_latent = load_latents(a.base_lat) 146 | base_latent = torch.from_numpy(base_latent).to(device) 147 | else: 148 | print(' No NPY input given, making random') 149 | base_latent = torch.from_numpy(np.random.randn(1, Gs.z_dim)).float().to(device) 150 | if use_d: 151 | base_latent = Gs.mapping(base_latent, None) # [1,18,512] 152 | 153 | pbar = ProgressBar(len(directions)) 154 | for i, direction in enumerate(directions): 155 | make_loop(base_latent, direction, lrange, a.fstep*2, a.fstep*2 * i, a.ext) 156 | pbar.upd() 157 | 158 | # make_transit(base_lats[i], base_lats[(i+1)%len(base_lats)], n, 2*n*i + n, a.ext) 159 | 160 | 161 | if __name__ == '__main__': 162 | main() 163 | 164 | -------------------------------------------------------------------------------- /src/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /src/model_convert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | import pickle 4 | import numpy as np 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | import dnnlib 10 | import legacy 11 | 12 | from util.utilgan import basename, calc_init_res 13 | try: # progress bar for notebooks 14 | get_ipython().__class__.__name__ 15 | from util.progress_bar import ProgressIPy as ProgressBar 16 | except: # normal console 17 | from util.progress_bar import ProgressBar 18 | 19 | parser = argparse.ArgumentParser() 20 | parser.add_argument('--source', required=True, help='Source model path') 21 | parser.add_argument('--out_dir', default='./', help='Output directory for reduced/reconstructed model') 22 | parser.add_argument('-r', '--reconstruct', action='store_true', help='Reconstruct model (add internal arguments)') 23 | parser.add_argument('-s', '--res', default=None, help='Target resolution in format X-Y') 24 | parser.add_argument('-a', '--alpha', action='store_true', help='Add alpha channel for RGBA processing') 25 | parser.add_argument('-l', '--labels', default=None, type=int, help='Labels for conditional model') 26 | parser.add_argument('-f', '--full', action='store_true', help='Save full model') 27 | parser.add_argument('-v', '--verbose', action='store_true') 28 | a = parser.parse_args() 29 | 30 | if a.res is not None: 31 | a.res = [int(s) for s in a.res.split('-')][::-1] 32 | if len(a.res) == 1: a.res = a.res + a.res 33 | 34 | def load_pkl(filepath): 35 | with dnnlib.util.open_url(filepath) as f: 36 | nets = legacy.load_network_pkl(f, custom=False) # ['G', 'D', 'G_ema', 'training_set_kwargs', 'augment_pipe'] 37 | return nets 38 | 39 | def save_pkl(nets, filepath): 40 | with open(filepath, 'wb') as file: 41 | pickle.dump(nets, file) # , protocol=pickle.HIGHEST_PROTOCOL 42 | 43 | def create_model(net_in, data_shape, labels=None, full=False, custom=False, init=False): 44 | init_res, resolution, res_log2 = calc_init_res(data_shape[1:]) 45 | net_in['G_ema'].img_resolution = resolution 46 | net_in['G_ema'].img_channels = data_shape[0] 47 | net_in['G_ema'].init_res = init_res 48 | net_out = legacy.create_networks(net_in, full=full, custom=custom, init=init, labels=labels) 49 | return net_out 50 | 51 | def add_channel(x, subnet): # [BCHW] 52 | if subnet == 'D': # pad second dim [1] 53 | padding = [0] * (len(x.shape)-2)*2 54 | padding += [0,1,0,0] 55 | else: # pad last dim [-1] 56 | padding = [0] * (len(x.shape)-1)*2 57 | padding += [0,1] 58 | y = F.pad(x, padding, 'constant', 1) 59 | return y 60 | 61 | def pad_up_to(x, size, type='side'): 62 | sh = x.shape 63 | if list(x.shape) == list(size): return x 64 | padding = [] 65 | for i, s in enumerate(size): 66 | p0 = (s-sh[i]) // 2 67 | p1 = s-sh[i] - p0 68 | padding = padding + [p0,p1] 69 | y = F.pad(x, padding[::-1], 'constant', 0) 70 | return y 71 | 72 | def copy_vars(src_net, tgt_net, add_alpha=False, xtile=False) -> None: 73 | for subnet in ['G_ema', 'G', 'D']: 74 | if subnet in src_net.keys() and subnet in tgt_net.keys(): 75 | src_dict = src_net[subnet].state_dict() 76 | tgt_dict = tgt_net[subnet].state_dict() 77 | vars = [name for name in src_dict.keys() if name in tgt_dict.keys()] 78 | pbar = ProgressBar(len(vars)) 79 | for name in vars: 80 | source_shape = src_dict[name].shape 81 | target_shape = tgt_dict[name].shape 82 | 
82 | if source_shape == target_shape: 83 | tgt_dict[name].copy_(src_dict[name]).requires_grad_(False) 84 | else: 85 | if add_alpha: 86 | update = add_channel(src_dict[name], subnet) 87 | assert target_shape == update.shape, 'Diff shapes yet: src %s tgt %s' % (str(update.shape), str(target_shape)) 88 | tgt_dict[name].copy_(update).requires_grad_(False) 89 | elif xtile: 90 | assert len(source_shape) == len(target_shape), "Diff shape ranks: src %s tgt %s" % (str(source_shape), str(target_shape)) 91 | try: 92 | update = src_dict[name][:target_shape[0], :target_shape[1], ...] # !!! corrects only first two dims 93 | except: 94 | update = src_dict[name][:target_shape[0]] 95 | if np.greater(target_shape, source_shape).any(): 96 | tile_count = [target_shape[i] // source_shape[i] for i in range(len(source_shape))] 97 | update = src_dict[name].repeat(*tile_count) # [512,512] => [1024,512] 98 | if a.verbose is True: print(name, tile_count, source_shape, '=>', target_shape, '\n\n') # G_mapping/Dense0, D/Output 99 | tgt_dict[name].copy_(update).requires_grad_(False) 100 | else: # crop/pad 101 | update = pad_up_to(src_dict[name], target_shape) 102 | if a.verbose is True: print(name, source_shape, '=>', update.shape, '\n\n') 103 | tgt_dict[name].copy_(update).requires_grad_(False) 104 | pbar.upd(name) 105 | 106 | def main(): 107 | 108 | net_in = load_pkl(a.source) 109 | Gs_in = net_in['G_ema'] 110 | if hasattr(Gs_in, 'output_shape'): 111 | out_shape = Gs_in.output_shape 112 | print(' Loading model', a.source, out_shape) 113 | _, res_in, _ = calc_init_res(out_shape[1:]) 114 | else: # original model 115 | res_in = Gs_in.img_resolution 116 | out_shape = [None, Gs_in.img_channels, res_in, res_in] 117 | # netdict = net_in['G_ema'].state_dict() 118 | # for k in netdict.keys(): 119 | # print(k, netdict[k].shape) 120 | 121 | if a.res is not None or a.alpha is True: 122 | if a.res is None: a.res = out_shape[2:] 123 | colors = 4 if a.alpha is True else out_shape[1] 124 | _, res_out, _ = calc_init_res([colors, *a.res]) 125 | 126 | if res_in != res_out or a.alpha is True: # add or remove layers 127 | assert 'G' in net_in.keys() and 'D' in net_in.keys(), " !! G/D subnets not found in source model !!" 128 | data_shape = [colors, res_out, res_out] 129 | print(' Reconstructing full model with shape', data_shape) 130 | net_out = create_model(net_in, data_shape, full=True) 131 | copy_vars(net_in, net_out, add_alpha=True) 132 | a.full = True 133 | 134 | if a.res[0] != res_out or a.res[1] != res_out: # crop or pad layers 135 | data_shape = [colors, *a.res] 136 | net_out = create_model(net_in, data_shape, full=True) 137 | copy_vars(net_in, net_out) 138 | 139 | if a.labels is not None: 140 | assert 'G' in net_in.keys() and 'D' in net_in.keys(), " !! G/D subnets not found in source model !!"
141 | print(' Reconstructing full model with labels', a.labels) 142 | data_shape = out_shape[1:] 143 | net_out = create_model(net_in, data_shape, labels=a.labels, full=True) 144 | copy_vars(net_in, net_out, xtile=True) 145 | a.full = True 146 | 147 | if a.labels is None and a.res is None and a.alpha is not True: 148 | if a.reconstruct is True: 149 | print(' Reconstructing model with same size /', 'full' if a.full else 'Gs') 150 | data_shape = out_shape[1:] 151 | net_out = create_model(net_in, data_shape, full=a.full, init=True) 152 | else: 153 | net_out = dict(G_ema = Gs_in) 154 | 155 | out_name = basename(a.source) 156 | if a.res is not None: out_name += '-%dx%d' % (a.res[1], a.res[0]) 157 | if a.alpha is True: out_name += 'a' 158 | if a.labels is not None: out_name += '-c%d' % a.labels 159 | if not a.full: out_name += '-Gs' 160 | 161 | save_pkl(net_out, os.path.join(a.out_dir, '%s.pkl' % out_name)) 162 | print(' Done') 163 | 164 | 165 | if __name__ == '__main__': 166 | main() 167 | -------------------------------------------------------------------------------- /src/model_pt2pkl.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import math 4 | import argparse 5 | import numpy as np 6 | import pickle 7 | 8 | import torch 9 | 10 | import dnnlib, legacy 11 | try: # progress bar for notebooks 12 | get_ipython().__class__.__name__ 13 | from util.progress_bar import ProgressIPy as ProgressBar 14 | except: # normal console 15 | from util.progress_bar import ProgressBar 16 | 17 | def get_args(): 18 | parser = argparse.ArgumentParser(description="Rosinality (pytorch) to Nvidia (pkl) checkpoint converter") 19 | parser.add_argument("--model_pkl", metavar="PATH", help="path to the source sg2ada pytorch (pkl) weights") 20 | parser.add_argument("--model_pt", metavar="PATH", help="path to the updated pytorch (pt) weights") 21 | args = parser.parse_args() 22 | return args 23 | 24 | def load_pkl(filepath): 25 | with dnnlib.util.open_url(filepath) as f: 26 | nets = legacy.load_network_pkl(f, custom=False) 27 | return nets 28 | 29 | def save_pkl(nets, filepath): 30 | with open(filepath, 'wb') as file: 31 | pickle.dump(nets, file) 32 | 33 | def update(tgt_dict, name, value): 34 | tgt_dict[name].copy_(value).requires_grad_(False) 35 | 36 | def convert_modconv(tgt_dict, src_dict, target_name, source_name): 37 | conv_weight = src_dict[source_name + ".conv.weight"].squeeze(0) 38 | update(tgt_dict, target_name + ".weight", conv_weight) 39 | update(tgt_dict, target_name + ".affine.weight", src_dict[source_name + ".conv.modulation.weight"]) 40 | update(tgt_dict, target_name + ".affine.bias", src_dict[source_name + ".conv.modulation.bias"]) 41 | update(tgt_dict, target_name + ".noise_strength", src_dict[source_name + ".noise.weight"].squeeze()) 42 | update(tgt_dict, target_name + ".bias", src_dict[source_name + ".activate.bias"].squeeze()) 43 | 44 | def convert_torgb(tgt_dict, src_dict, target_name, source_name): 45 | update(tgt_dict, target_name + ".weight", src_dict[source_name + ".conv.weight"].squeeze(0).squeeze(0)) 46 | update(tgt_dict, target_name + ".affine.weight", src_dict[source_name + ".conv.modulation.weight"]) 47 | update(tgt_dict, target_name + ".affine.bias", src_dict[source_name + ".conv.modulation.bias"]) 48 | update(tgt_dict, target_name + ".bias", src_dict[source_name + ".bias"].squeeze()) 49 | 50 | def convert_dense(tgt_dict, src_dict, target_name, source_name): 51 | update(tgt_dict, target_name + ".weight", src_dict[source_name 
+ ".weight"]) 52 | update(tgt_dict, target_name + ".bias", src_dict[source_name + ".bias"]) 53 | 54 | def update_G(src_dict, tgt_dict, size, n_mlp): 55 | log_size = int(math.log(size, 2)) 56 | 57 | pbar = ProgressBar(n_mlp + log_size-2 + log_size-2 + (log_size-2)*2+1 + 2) 58 | for i in range(n_mlp): 59 | convert_dense(tgt_dict, src_dict, f"mapping.fc{i}", f"style.{i+1}") 60 | pbar.upd() 61 | update(tgt_dict, "synthesis.b4.const", src_dict["input.input"].squeeze(0)) 62 | convert_torgb(tgt_dict, src_dict, "synthesis.b4.torgb", "to_rgb1") 63 | pbar.upd() 64 | 65 | for i in range(log_size-2): 66 | reso = 4 * 2 ** (i+1) 67 | convert_torgb(tgt_dict, src_dict, f"synthesis.b{reso}.torgb", f"to_rgbs.{i}") 68 | pbar.upd() 69 | convert_modconv(tgt_dict, src_dict, "synthesis.b4.conv1", "conv1") 70 | pbar.upd() 71 | 72 | conv_i = 0 73 | for i in range(log_size-2): 74 | reso = 4 * 2 ** (i+1) 75 | convert_modconv(tgt_dict, src_dict, f"synthesis.b{reso}.conv0", f"convs.{conv_i}") 76 | convert_modconv(tgt_dict, src_dict, f"synthesis.b{reso}.conv1", f"convs.{conv_i + 1}") 77 | conv_i += 2 78 | pbar.upd() 79 | 80 | for i in range(0, (log_size-2) * 2 + 1): 81 | reso = 4 * 2 ** (math.ceil(i/2)) 82 | update(tgt_dict, f"synthesis.b{reso}.conv{(i+1)%2}.noise_const", src_dict[f"noises.noise_{i}"].squeeze()) 83 | pbar.upd() 84 | 85 | src_kernels = [k for k in src_dict.keys() if 'kernel' in k] 86 | src_kernel = src_dict[src_kernels[0]] 87 | tgt_kernels = [k for k in tgt_dict.keys() if 'resample_filter' in k] # [0] 88 | for tgt_k in tgt_kernels: 89 | update(tgt_dict, tgt_k, src_kernel/4) 90 | 91 | def load_net_from_pkl(path): 92 | tgt_net = load_pkl(path) 93 | Gs = tgt_net['G_ema'] 94 | tgt_dict = tgt_net['G_ema'].state_dict() 95 | n_mlp = len([l for l in tgt_dict.keys() if l.startswith('mapping.fc')]) // 2 96 | size = tgt_net['G_ema'].img_resolution 97 | return tgt_net, size, n_mlp 98 | 99 | if __name__ == "__main__": 100 | args = get_args() 101 | 102 | tgt_net, size, n_mlp = load_net_from_pkl(args.model_pkl) 103 | tgt_dict = tgt_net['G_ema'].state_dict() 104 | src_dict = torch.load(args.model_pt) 105 | update_G(src_dict['g_ema'], tgt_dict, size, n_mlp) 106 | 107 | out_name = args.model_pt.replace('.pt', '.pkl') 108 | save_pkl(dict(G_ema=tgt_net['G_ema']), out_name) 109 | 110 | -------------------------------------------------------------------------------- /src/notes.txt: -------------------------------------------------------------------------------- 1 | https://github.com/NVlabs/stylegan2-ada-pytorch 2 | 3 | set DNNLIB_CACHE_DIR=models 4 | set TORCH_EXTENSIONS_DIR=src/.cache/torch -------------------------------------------------------------------------------- /src/projector.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Project given image to the latent space of pretrained network pickle.""" 10 | 11 | import os 12 | import copy 13 | import argparse 14 | # from time import perf_counter 15 | 16 | import imageio 17 | import numpy as np 18 | import PIL.Image 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | import dnnlib 24 | import legacy 25 | 26 | from util.utilgan import img_list, img_read, basename 27 | try: # progress bar for notebooks 28 | get_ipython().__class__.__name__ 29 | from util.progress_bar import ProgressIPy as ProgressBar 30 | except: # normal console 31 | from util.progress_bar import ProgressBar 32 | 33 | def project( 34 | G, 35 | target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution 36 | *, 37 | num_steps = 1000, 38 | w_avg_samples = 10000, 39 | initial_learning_rate = 0.1, 40 | initial_noise_factor = 0.05, 41 | lr_rampdown_length = 0.25, 42 | lr_rampup_length = 0.05, 43 | noise_ramp_length = 0.75, 44 | regularize_noise_weight = 1e5, 45 | verbose = False, 46 | device: torch.device 47 | ): 48 | assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution) 49 | 50 | # def logprint(*args): 51 | # if verbose: 52 | # print(*args) 53 | 54 | G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore 55 | 56 | # Compute w stats. 57 | # logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...') 58 | z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim) 59 | w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None) # [N, L, C] 60 | w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32) # [N, 1, C] 61 | w_avg = np.mean(w_samples, axis=0, keepdims=True) # [1, 1, C] 62 | w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5 63 | 64 | # Setup noise inputs. 65 | noise_bufs = { name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name } 66 | 67 | # Load VGG16 feature detector. 68 | vgg_file = 'models/vgg/vgg16.pt' 69 | if os.path.isfile(vgg_file) and os.stat(vgg_file).st_size == 553469545: 70 | with dnnlib.util.open_url(vgg_file) as file: 71 | # network = pickle.load(file, encoding='latin1') 72 | vgg16 = torch.jit.load(file).eval().to(device) 73 | else: 74 | with dnnlib.util.open_url('https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt') as file: 75 | vgg16 = torch.jit.load(file).eval().to(device) 76 | 77 | # Features for target image. 78 | target_images = target.unsqueeze(0).to(device).to(torch.float32) 79 | if target_images.shape[2] > 256: 80 | target_images = F.interpolate(target_images, size=(256, 256), mode='area') 81 | target_features = vgg16(target_images, resize_images=False, return_lpips=True) 82 | 83 | w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable 84 | w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device) 85 | optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate) 86 | 87 | # Init noise. 88 | for buf in noise_bufs.values(): 89 | buf[:] = torch.randn_like(buf) 90 | buf.requires_grad = True 91 | 92 | pbar = ProgressBar(num_steps) 93 | for step in range(num_steps): 94 | # Learning rate schedule. 
95 | t = step / num_steps
96 | w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
97 | lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
98 | lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
99 | lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
100 | lr = initial_learning_rate * lr_ramp
101 | for param_group in optimizer.param_groups:
102 | param_group['lr'] = lr
103 |
104 | # Synth images from opt_w.
105 | w_noise = torch.randn_like(w_opt) * w_noise_scale
106 | ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
107 | synth_images = G.synthesis(ws, noise_mode='const')
108 |
109 | # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
110 | synth_images = (synth_images + 1) * (255/2)
111 | if synth_images.shape[2] > 256:
112 | synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')
113 |
114 | # Features for synth images.
115 | synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
116 | dist = (target_features - synth_features).square().sum()
117 |
118 | # Noise regularization.
119 | reg_loss = 0.0
120 | for v in noise_bufs.values():
121 | noise = v[None,None,:,:] # must be [1,1,H,W] for F.avg_pool2d()
122 | while True:
123 | reg_loss += (noise*torch.roll(noise, shifts=1, dims=3)).mean()**2
124 | reg_loss += (noise*torch.roll(noise, shifts=1, dims=2)).mean()**2
125 | if noise.shape[2] <= 8:
126 | break
127 | noise = F.avg_pool2d(noise, kernel_size=2)
128 | loss = dist + reg_loss * regularize_noise_weight
129 |
130 | # Step
131 | optimizer.zero_grad(set_to_none=True)
132 | loss.backward()
133 | optimizer.step()
134 | # logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')
135 |
136 | # Save projected W for each optimization step.
137 | w_out[step] = w_opt.detach()[0]
138 |
139 | # Normalize noise.
140 | with torch.no_grad():
141 | for buf in noise_bufs.values():
142 | buf -= buf.mean()
143 | buf *= buf.square().mean().rsqrt()
144 | pbar.upd()
145 |
146 | return w_out.repeat([1, G.mapping.num_ws, 1])
147 |
148 | #----------------------------------------------------------------------------
149 |
150 | def run_projection(
151 | network_pkl: str,
152 | in_dir: str,
153 | out_dir: str,
154 | save_video: bool,
155 | seed: int,
156 | steps: int
157 | ):
158 | """Project given images to the latent space of a pretrained network pickle.
159 | Examples:
160 | python projector.py --in_dir=_in --out_dir=_out \\
161 | --model=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
162 | """
163 | np.random.seed(seed)
164 | torch.manual_seed(seed)
165 |
166 | # Load networks.
167 | print('Loading networks from "%s"...' % network_pkl)
168 | device = torch.device('cuda')
169 | with dnnlib.util.open_url(network_pkl) as fp:
170 | G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore
171 |
172 | img_files = img_list(in_dir)
173 | num_images = len(img_files)
174 |
175 | for image_idx in range(num_images):
176 | fname = basename(img_files[image_idx])
177 | print('Projecting image %d/%d .. %s' % (image_idx+1, num_images, fname))
178 | work_dir = os.path.join(out_dir, fname)
179 | os.makedirs(work_dir, exist_ok=True)
180 |
181 | # Load target image.
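# Each target is center-cropped to a square and resized to the generator
# resolution below, since project() asserts target.shape ==
# (G.img_channels, G.img_resolution, G.img_resolution).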
182 | target_pil = PIL.Image.open(img_files[image_idx]).convert('RGB') 183 | w, h = target_pil.size 184 | s = min(w, h) 185 | target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2)) 186 | target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS) 187 | target_uint8 = np.array(target_pil, dtype=np.uint8) 188 | 189 | # Optimize projection. 190 | # start_time = perf_counter() 191 | projected_w_steps = project( 192 | G, 193 | target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable 194 | num_steps=steps, 195 | device=device, 196 | verbose=True 197 | ) 198 | # print (f'Elapsed: {(perf_counter()-start_time):.1f} s') 199 | 200 | # Render debug output: optional video and projected image and W vector. 201 | os.makedirs(out_dir, exist_ok=True) 202 | if save_video: 203 | vfile = '%s/proj.mp4' % work_dir 204 | video = imageio.get_writer(vfile, mode='I', fps=25, codec='libx264', bitrate='16M') 205 | print ('Saving optimization progress video %s' % vfile) 206 | for projected_w in projected_w_steps: 207 | synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const') 208 | synth_image = (synth_image + 1) * (255/2) 209 | synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy() 210 | video.append_data(np.concatenate([target_uint8, synth_image], axis=1)) 211 | video.close() 212 | 213 | # Save final projected frame and W vector. 214 | target_pil.save('%s/target.jpg' % work_dir) 215 | projected_w = projected_w_steps[-1] 216 | synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const') 217 | synth_image = (synth_image + 1) * (255/2) 218 | synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy() 219 | 220 | PIL.Image.fromarray(synth_image, 'RGB').save('%s/%s.jpg' % (work_dir, fname)) 221 | np.savez('%s/%s.npz' % (work_dir, fname), w=projected_w.unsqueeze(0).cpu().numpy()) 222 | 223 | #---------------------------------------------------------------------------- 224 | 225 | def main(): 226 | parser = argparse.ArgumentParser( 227 | description='Project given image to the latent space of pretrained network pickle.', 228 | formatter_class=argparse.RawDescriptionHelpFormatter 229 | ) 230 | parser.add_argument('--model', help='Network pickle filename', dest='network_pkl', required=True) 231 | parser.add_argument('--in_dir', default='_in', help='Where to load the input images from', metavar='DIR') 232 | parser.add_argument('--out_dir', default='_out', help='Where to save the output images', metavar='DIR') 233 | parser.add_argument('--steps', default=1000, type=int, help='Number of iterations (default: %(default)s)') # 1000 234 | parser.add_argument('--save_video', action='store_true', help='Save an mp4 video of optimization progress') 235 | parser.add_argument('--seed', help='Random seed', type=int, default=696) 236 | 237 | run_projection(**vars(parser.parse_args())) 238 | 239 | if __name__ == "__main__": 240 | main() 241 | 242 | -------------------------------------------------------------------------------- /src/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /src/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import glob 11 | import torch 12 | import torch.utils.cpp_extension 13 | import importlib 14 | import hashlib 15 | import shutil 16 | from pathlib import Path 17 | 18 | from torch.utils.file_baton import FileBaton 19 | 20 | #---------------------------------------------------------------------------- 21 | # Global options. 22 | 23 | verbosity = 'none' # Verbosity level: 'none', 'brief', 'full' 24 | 25 | #---------------------------------------------------------------------------- 26 | # Internal helper funcs. 27 | 28 | def _find_compiler_bindir(): 29 | patterns = [ 30 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 34 | ] 35 | for pattern in patterns: 36 | matches = sorted(glob.glob(pattern)) 37 | if len(matches): 38 | return matches[-1] 39 | return None 40 | 41 | #---------------------------------------------------------------------------- 42 | # Main entry point for compiling and loading C++/CUDA plugins. 43 | 44 | _cached_plugins = dict() 45 | 46 | def get_plugin(module_name, sources, **build_kwargs): 47 | assert verbosity in ['none', 'brief', 'full'] 48 | 49 | # Already cached? 50 | if module_name in _cached_plugins: 51 | return _cached_plugins[module_name] 52 | 53 | # Print status. 54 | if verbosity == 'full': 55 | print(f'Setting up PyTorch plugin "{module_name}"...') 56 | elif verbosity == 'brief': 57 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 58 | 59 | try: # pylint: disable=too-many-nested-blocks 60 | # Make sure we can find the necessary compiler binaries. 61 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 62 | compiler_bindir = _find_compiler_bindir() 63 | if compiler_bindir is None: 64 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 65 | os.environ['PATH'] += ';' + compiler_bindir 66 | 67 | # Compile and load. 68 | verbose_build = (verbosity == 'full') 69 | 70 | # Incremental build md5sum trickery. Copies all the input source files 71 | # into a cached build directory under a combined md5 digest of the input 72 | # source files. Copying is done only if the combined digest has changed. 73 | # This keeps input file timestamps and filenames the same as in previous 74 | # extension builds, allowing for fast incremental rebuilds. 
75 | # 76 | # This optimization is done only in case all the source files reside in 77 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 78 | # environment variable is set (we take this as a signal that the user 79 | # actually cares about this.) 80 | source_dirs_set = set(os.path.dirname(source) for source in sources) 81 | if len(source_dirs_set) == 1 and ('TORCH_EXTENSIONS_DIR' in os.environ): 82 | all_source_files = sorted(list(x for x in Path(list(source_dirs_set)[0]).iterdir() if x.is_file())) 83 | 84 | # Compute a combined hash digest for all source files in the same 85 | # custom op directory (usually .cu, .cpp, .py and .h files). 86 | hash_md5 = hashlib.md5() 87 | for src in all_source_files: 88 | with open(src, 'rb') as f: 89 | hash_md5.update(f.read()) 90 | build_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 91 | digest_build_dir = os.path.join(build_dir, hash_md5.hexdigest()) 92 | 93 | if not os.path.isdir(digest_build_dir): 94 | os.makedirs(digest_build_dir, exist_ok=True) 95 | baton = FileBaton(os.path.join(digest_build_dir, 'lock')) 96 | if baton.try_acquire(): 97 | try: 98 | for src in all_source_files: 99 | shutil.copyfile(src, os.path.join(digest_build_dir, os.path.basename(src))) 100 | finally: 101 | baton.release() 102 | else: 103 | # Someone else is copying source files under the digest dir, 104 | # wait until done and continue. 105 | baton.wait() 106 | digest_sources = [os.path.join(digest_build_dir, os.path.basename(x)) for x in sources] 107 | torch.utils.cpp_extension.load(name=module_name, build_directory=build_dir, 108 | verbose=verbose_build, sources=digest_sources, **build_kwargs) 109 | else: 110 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 111 | module = importlib.import_module(module_name) 112 | 113 | except: 114 | if verbosity == 'brief': 115 | print('Failed!') 116 | raise 117 | 118 | # Print status and add to cache. 119 | if verbosity == 'full': 120 | print(f'Done setting up PyTorch plugin "{module_name}".') 121 | elif verbosity == 'brief': 122 | print('Done.') 123 | _cached_plugins[module_name] = module 124 | return module 125 | 126 | #---------------------------------------------------------------------------- 127 | -------------------------------------------------------------------------------- /src/torch_utils/misc.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import re 10 | import contextlib 11 | import numpy as np 12 | import torch 13 | import warnings 14 | import dnnlib 15 | 16 | #---------------------------------------------------------------------------- 17 | # Cached construction of constant tensors. Avoids CPU=>GPU copy when the 18 | # same constant is used multiple times. 
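# A minimal usage sketch (the tensor value and device here are illustrative
# only, not taken from this file):
#   f = constant([1, 2, 1], device=torch.device('cuda'))
# Repeated calls with the same value/shape/dtype/device/memory format are
# served from _constant_cache instead of re-uploading to the GPU.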
19 | 20 | _constant_cache = dict() 21 | 22 | def constant(value, shape=None, dtype=None, device=None, memory_format=None): 23 | value = np.asarray(value) 24 | if shape is not None: 25 | shape = tuple(shape) 26 | if dtype is None: 27 | dtype = torch.get_default_dtype() 28 | if device is None: 29 | device = torch.device('cpu') 30 | if memory_format is None: 31 | memory_format = torch.contiguous_format 32 | 33 | key = (value.shape, value.dtype, value.tobytes(), shape, dtype, device, memory_format) 34 | tensor = _constant_cache.get(key, None) 35 | if tensor is None: 36 | tensor = torch.as_tensor(value.copy(), dtype=dtype, device=device) 37 | if shape is not None: 38 | tensor, _ = torch.broadcast_tensors(tensor, torch.empty(shape)) 39 | tensor = tensor.contiguous(memory_format=memory_format) 40 | _constant_cache[key] = tensor 41 | return tensor 42 | 43 | #---------------------------------------------------------------------------- 44 | # Replace NaN/Inf with specified numerical values. 45 | 46 | try: 47 | nan_to_num = torch.nan_to_num # 1.8.0a0 48 | except AttributeError: 49 | def nan_to_num(input, nan=0.0, posinf=None, neginf=None, *, out=None): # pylint: disable=redefined-builtin 50 | assert isinstance(input, torch.Tensor) 51 | if posinf is None: 52 | posinf = torch.finfo(input.dtype).max 53 | if neginf is None: 54 | neginf = torch.finfo(input.dtype).min 55 | assert nan == 0 56 | return torch.clamp(input.unsqueeze(0).nansum(0), min=neginf, max=posinf, out=out) 57 | 58 | #---------------------------------------------------------------------------- 59 | # Symbolic assert. 60 | 61 | try: 62 | symbolic_assert = torch._assert # 1.8.0a0 # pylint: disable=protected-access 63 | except AttributeError: 64 | symbolic_assert = torch.Assert # 1.7.0 65 | 66 | #---------------------------------------------------------------------------- 67 | # Context manager to suppress known warnings in torch.jit.trace(). 68 | 69 | class suppress_tracer_warnings(warnings.catch_warnings): 70 | def __enter__(self): 71 | super().__enter__() 72 | warnings.simplefilter('ignore', category=torch.jit.TracerWarning) 73 | return self 74 | 75 | #---------------------------------------------------------------------------- 76 | # Assert that the shape of a tensor matches the given list of integers. 77 | # None indicates that the size of a dimension is allowed to vary. 78 | # Performs symbolic assertion when used in torch.jit.trace(). 79 | 80 | def assert_shape(tensor, ref_shape): 81 | if tensor.ndim != len(ref_shape): 82 | raise AssertionError(f'Wrong number of dimensions: got {tensor.ndim}, expected {len(ref_shape)}') 83 | for idx, (size, ref_size) in enumerate(zip(tensor.shape, ref_shape)): 84 | if ref_size is None: 85 | pass 86 | elif isinstance(ref_size, torch.Tensor): 87 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 88 | symbolic_assert(torch.equal(torch.as_tensor(size), ref_size), f'Wrong size for dimension {idx}') 89 | elif isinstance(size, torch.Tensor): 90 | with suppress_tracer_warnings(): # as_tensor results are registered as constants 91 | symbolic_assert(torch.equal(size, torch.as_tensor(ref_size)), f'Wrong size for dimension {idx}: expected {ref_size}') 92 | elif size != ref_size: 93 | raise AssertionError(f'Wrong size for dimension {idx}: got {size}, expected {ref_size}') 94 | 95 | #---------------------------------------------------------------------------- 96 | # Function decorator that calls torch.autograd.profiler.record_function(). 
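# Typical (hypothetical) usage:
#   @profiled_function
#   def my_op(x): ...
# makes each call show up as a named range in torch.autograd.profiler traces.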
97 | 98 | def profiled_function(fn): 99 | def decorator(*args, **kwargs): 100 | with torch.autograd.profiler.record_function(fn.__name__): 101 | return fn(*args, **kwargs) 102 | decorator.__name__ = fn.__name__ 103 | return decorator 104 | 105 | #---------------------------------------------------------------------------- 106 | # Sampler for torch.utils.data.DataLoader that loops over the dataset 107 | # indefinitely, shuffling items as it goes. 108 | 109 | class InfiniteSampler(torch.utils.data.Sampler): 110 | def __init__(self, dataset, rank=0, num_replicas=1, shuffle=True, seed=0, window_size=0.5): 111 | assert len(dataset) > 0 112 | assert num_replicas > 0 113 | assert 0 <= rank < num_replicas 114 | assert 0 <= window_size <= 1 115 | super().__init__(dataset) 116 | self.dataset = dataset 117 | self.rank = rank 118 | self.num_replicas = num_replicas 119 | self.shuffle = shuffle 120 | self.seed = seed 121 | self.window_size = window_size 122 | 123 | def __iter__(self): 124 | order = np.arange(len(self.dataset)) 125 | rnd = None 126 | window = 0 127 | if self.shuffle: 128 | rnd = np.random.RandomState(self.seed) 129 | rnd.shuffle(order) 130 | window = int(np.rint(order.size * self.window_size)) 131 | 132 | idx = 0 133 | while True: 134 | i = idx % order.size 135 | if idx % self.num_replicas == self.rank: 136 | yield order[i] 137 | if window >= 2: 138 | j = (i - rnd.randint(window)) % order.size 139 | order[i], order[j] = order[j], order[i] 140 | idx += 1 141 | 142 | #---------------------------------------------------------------------------- 143 | # Utilities for operating with torch.nn.Module parameters and buffers. 144 | 145 | def params_and_buffers(module): 146 | assert isinstance(module, torch.nn.Module) 147 | return list(module.parameters()) + list(module.buffers()) 148 | 149 | def named_params_and_buffers(module): 150 | assert isinstance(module, torch.nn.Module) 151 | return list(module.named_parameters()) + list(module.named_buffers()) 152 | 153 | def copy_params_and_buffers(src_module, dst_module, require_all=False): 154 | assert isinstance(src_module, torch.nn.Module) 155 | assert isinstance(dst_module, torch.nn.Module) 156 | src_tensors = {name: tensor for name, tensor in named_params_and_buffers(src_module)} 157 | for name, tensor in named_params_and_buffers(dst_module): 158 | assert (name in src_tensors) or (not require_all) 159 | if name in src_tensors: 160 | tensor.copy_(src_tensors[name].detach()).requires_grad_(tensor.requires_grad) 161 | 162 | #---------------------------------------------------------------------------- 163 | # Context manager for easily enabling/disabling DistributedDataParallel 164 | # synchronization. 165 | 166 | @contextlib.contextmanager 167 | def ddp_sync(module, sync): 168 | assert isinstance(module, torch.nn.Module) 169 | if sync or not isinstance(module, torch.nn.parallel.DistributedDataParallel): 170 | yield 171 | else: 172 | with module.no_sync(): 173 | yield 174 | 175 | #---------------------------------------------------------------------------- 176 | # Check DistributedDataParallel consistency across processes. 177 | 178 | def check_ddp_consistency(module, ignore_regex=None): 179 | assert isinstance(module, torch.nn.Module) 180 | for name, tensor in named_params_and_buffers(module): 181 | fullname = type(module).__name__ + '.' 
+ name
182 | if ignore_regex is not None and re.fullmatch(ignore_regex, fullname):
183 | continue
184 | tensor = tensor.detach()
185 | other = tensor.clone()
186 | torch.distributed.broadcast(tensor=other, src=0)
187 | assert (nan_to_num(tensor) == nan_to_num(other)).all(), fullname
188 |
189 | #----------------------------------------------------------------------------
190 | # Print summary table of module hierarchy.
191 |
192 | def print_module_summary(module, inputs, max_nesting=3, skip_redundant=True):
193 | assert isinstance(module, torch.nn.Module)
194 | assert not isinstance(module, torch.jit.ScriptModule)
195 | assert isinstance(inputs, (tuple, list))
196 |
197 | # Register hooks.
198 | entries = []
199 | nesting = [0]
200 | def pre_hook(_mod, _inputs):
201 | nesting[0] += 1
202 | def post_hook(mod, _inputs, outputs):
203 | nesting[0] -= 1
204 | if nesting[0] <= max_nesting:
205 | outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
206 | outputs = [t for t in outputs if isinstance(t, torch.Tensor)]
207 | entries.append(dnnlib.EasyDict(mod=mod, outputs=outputs))
208 | hooks = [mod.register_forward_pre_hook(pre_hook) for mod in module.modules()]
209 | hooks += [mod.register_forward_hook(post_hook) for mod in module.modules()]
210 |
211 | # Run module.
212 | outputs = module(*inputs)
213 | for hook in hooks:
214 | hook.remove()
215 |
216 | # Identify unique outputs, parameters, and buffers.
217 | tensors_seen = set()
218 | for e in entries:
219 | e.unique_params = [t for t in e.mod.parameters() if id(t) not in tensors_seen]
220 | e.unique_buffers = [t for t in e.mod.buffers() if id(t) not in tensors_seen]
221 | e.unique_outputs = [t for t in e.outputs if id(t) not in tensors_seen]
222 | tensors_seen |= {id(t) for t in e.unique_params + e.unique_buffers + e.unique_outputs}
223 |
224 | # Filter out redundant entries.
225 | if skip_redundant:
226 | entries = [e for e in entries if len(e.unique_params) or len(e.unique_buffers) or len(e.unique_outputs)]
227 |
228 | # Construct table.
229 | rows = [[type(module).__name__, 'Parameters', 'Buffers', 'Output shape', 'Datatype']]
230 | rows += [['---'] * len(rows[0])]
231 | param_total = 0
232 | buffer_total = 0
233 | submodule_names = {mod: name for name, mod in module.named_modules()}
234 | for e in entries:
235 | name = '<top-level>' if e.mod is module else submodule_names[e.mod]
236 | param_size = sum(t.numel() for t in e.unique_params)
237 | buffer_size = sum(t.numel() for t in e.unique_buffers)
238 | output_shapes = [str(list(t.shape)) for t in e.outputs]
239 | output_dtypes = [str(t.dtype).split('.')[-1] for t in e.outputs]
240 | rows += [[
241 | name + (':0' if len(e.outputs) >= 2 else ''),
242 | str(param_size) if param_size else '-',
243 | str(buffer_size) if buffer_size else '-',
244 | (output_shapes + ['-'])[0],
245 | (output_dtypes + ['-'])[0],
246 | ]]
247 | for idx in range(1, len(e.outputs)):
248 | rows += [[name + f':{idx}', '-', '-', output_shapes[idx], output_dtypes[idx]]]
249 | param_total += param_size
250 | buffer_total += buffer_size
251 | rows += [['---'] * len(rows[0])]
252 | rows += [['Total', str(param_total), str(buffer_total), '-', '-']]
253 |
254 | # Print table.
255 | widths = [max(len(cell) for cell in column) for column in zip(*rows)]
256 | print()
257 | for row in rows:
258 | print(' '.join(cell + ' ' * (width - len(cell)) for cell, width in zip(row, widths)))
259 | print()
260 | return outputs
261 |
262 | #----------------------------------------------------------------------------
263 |
--------------------------------------------------------------------------------
/src/torch_utils/ops/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | # empty
10 |
--------------------------------------------------------------------------------
/src/torch_utils/ops/bias_act.cpp:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "bias_act.h"
13 |
14 | //------------------------------------------------------------------------
15 |
16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y)
17 | {
18 | if (x.dim() != y.dim())
19 | return false;
20 | for (int64_t i = 0; i < x.dim(); i++)
21 | {
22 | if (x.size(i) != y.size(i))
23 | return false;
24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i))
25 | return false;
26 | }
27 | return true;
28 | }
29 |
30 | //------------------------------------------------------------------------
31 |
32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp)
33 | {
34 | // Validate arguments.
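// Note: every optional tensor (b, xref, yref, dy) is checked below to match
// x in dtype/device (and layout where applicable), so the kernel can address
// all of the buffers with the same flat element index.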
35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x");
37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x");
38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x");
39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x");
40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1");
42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds");
43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements");
44 | TORCH_CHECK(grad >= 0, "grad must be non-negative");
45 |
46 | // Validate layout.
47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense");
48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous");
49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x");
50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x");
51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x");
52 |
53 | // Create output tensor.
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
55 | torch::Tensor y = torch::empty_like(x);
56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x");
57 |
58 | // Initialize CUDA kernel parameters.
59 | bias_act_kernel_params p;
60 | p.x = x.data_ptr();
61 | p.b = (b.numel()) ? b.data_ptr() : NULL;
62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL;
63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL;
64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL;
65 | p.y = y.data_ptr();
66 | p.grad = grad;
67 | p.act = act;
68 | p.alpha = alpha;
69 | p.gain = gain;
70 | p.clamp = clamp;
71 | p.sizeX = (int)x.numel();
72 | p.sizeB = (int)b.numel();
73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1;
74 |
75 | // Choose CUDA kernel.
76 | void* kernel;
77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
78 | {
79 | kernel = choose_bias_act_kernel<scalar_t>(p);
80 | });
81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func");
82 |
83 | // Launch CUDA kernel.
84 | p.loopX = 4;
85 | int blockSize = 4 * 32;
86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1;
87 | void* args[] = {&p};
88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
89 | return y;
90 | }
91 |
92 | //------------------------------------------------------------------------
93 |
94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
95 | {
96 | m.def("bias_act", &bias_act);
97 | }
98 |
99 | //------------------------------------------------------------------------
100 |
--------------------------------------------------------------------------------
/src/torch_utils/ops/bias_act.cu:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | #include <c10/util/Half.h>
10 | #include "bias_act.h"
11 |
12 | //------------------------------------------------------------------------
13 | // Helpers.
14 |
15 | template <class T> struct InternalType;
16 | template <> struct InternalType<double> { typedef double scalar_t; };
17 | template <> struct InternalType<float> { typedef float scalar_t; };
18 | template <> struct InternalType<c10::Half> { typedef float scalar_t; };
19 |
20 | //------------------------------------------------------------------------
21 | // CUDA kernel.
22 |
23 | template <class T, int A>
24 | __global__ void bias_act_kernel(bias_act_kernel_params p)
25 | {
26 | typedef typename InternalType<T>::scalar_t scalar_t;
27 | int G = p.grad;
28 | scalar_t alpha = (scalar_t)p.alpha;
29 | scalar_t gain = (scalar_t)p.gain;
30 | scalar_t clamp = (scalar_t)p.clamp;
31 | scalar_t one = (scalar_t)1;
32 | scalar_t two = (scalar_t)2;
33 | scalar_t expRange = (scalar_t)80;
34 | scalar_t halfExpRange = (scalar_t)40;
35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946;
36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717;
37 |
38 | // Loop over elements.
39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x;
40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x)
41 | {
42 | // Load.
43 | scalar_t x = (scalar_t)((const T*)p.x)[xi];
44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0;
45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0;
46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0;
47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one;
48 | scalar_t yy = (gain != 0) ? yref / gain : 0;
49 | scalar_t y = 0;
50 |
51 | // Apply bias.
52 | ((G == 0) ? x : xref) += b;
53 |
54 | // linear
55 | if (A == 1)
56 | {
57 | if (G == 0) y = x;
58 | if (G == 1) y = x;
59 | }
60 |
61 | // relu
62 | if (A == 2)
63 | {
64 | if (G == 0) y = (x > 0) ? x : 0;
65 | if (G == 1) y = (yy > 0) ? x : 0;
66 | }
67 |
68 | // lrelu
69 | if (A == 3)
70 | {
71 | if (G == 0) y = (x > 0) ? x : x * alpha;
72 | if (G == 1) y = (yy > 0) ? x : x * alpha;
73 | }
74 |
75 | // tanh
76 | if (A == 4)
77 | {
78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); }
79 | if (G == 1) y = x * (one - yy * yy);
80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy);
81 | }
82 |
83 | // sigmoid
84 | if (A == 5)
85 | {
86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one);
87 | if (G == 1) y = x * yy * (one - yy);
88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy);
89 | }
90 |
91 | // elu
92 | if (A == 6)
93 | {
94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one;
95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one);
96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one);
97 | }
98 |
99 | // selu
100 | if (A == 7)
101 | {
102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one);
103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha);
104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha);
105 | }
106 |
107 | // softplus
108 | if (A == 8)
109 | {
110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one);
111 | if (G == 1) y = x * (one - exp(-yy));
112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); }
113 | }
114 |
115 | // swish
116 | if (A == 9)
117 | {
118 | if (G == 0)
119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one);
120 | else
121 | {
122 | scalar_t c = exp(xref);
123 | scalar_t d = c + one;
124 | if (G == 1)
125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d);
126 | else
127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d);
128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain;
129 | }
130 | }
131 |
132 | // Apply gain.
133 | y *= gain * dy;
134 |
135 | // Clamp.
136 | if (clamp >= 0)
137 | {
138 | if (G == 0)
139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp;
140 | else
141 | y = (yref > -clamp & yref < clamp) ? y : 0;
142 | }
143 |
144 | // Store.
145 | ((T*)p.y)[xi] = (T)y;
146 | }
147 | }
148 |
149 | //------------------------------------------------------------------------
150 | // CUDA kernel selection.
151 |
152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p)
153 | {
154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>;
155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>;
156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>;
157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>;
158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>;
159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>;
160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>;
161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>;
162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>;
163 | return NULL;
164 | }
165 |
166 | //------------------------------------------------------------------------
167 | // Template specializations.
168 |
169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p);
170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p);
171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p);
172 |
173 | //------------------------------------------------------------------------
174 |
--------------------------------------------------------------------------------
/src/torch_utils/ops/bias_act.h:
--------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | //------------------------------------------------------------------------
10 | // CUDA kernel parameters.
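// grad selects the pass: 0 = forward value, 1 = first derivative (backward),
// 2 = second derivative (only used by activations flagged has_2nd_grad in
// bias_act.py); act is the cuda_idx of the activation in activation_funcs.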
11 |
12 | struct bias_act_kernel_params
13 | {
14 | const void* x; // [sizeX]
15 | const void* b; // [sizeB] or NULL
16 | const void* xref; // [sizeX] or NULL
17 | const void* yref; // [sizeX] or NULL
18 | const void* dy; // [sizeX] or NULL
19 | void* y; // [sizeX]
20 |
21 | int grad;
22 | int act;
23 | float alpha;
24 | float gain;
25 | float clamp;
26 |
27 | int sizeX;
28 | int sizeB;
29 | int stepB;
30 | int loopX;
31 | };
32 |
33 | //------------------------------------------------------------------------
34 | // CUDA kernel selection.
35 |
36 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p);
37 |
38 | //------------------------------------------------------------------------
39 |
--------------------------------------------------------------------------------
/src/torch_utils/ops/bias_act.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 |
9 | """Custom PyTorch ops for efficient bias and activation."""
10 |
11 | import os
12 | import sys
13 | import warnings
14 | import numpy as np
15 | import torch
16 | import dnnlib
17 |
18 | from .. import custom_ops
19 | from .. import misc
20 |
21 | #----------------------------------------------------------------------------
22 |
23 | activation_funcs = {
24 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False),
25 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False),
26 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False),
27 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True),
28 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True),
29 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True),
30 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True),
31 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True),
32 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True),
33 | }
34 |
35 | #----------------------------------------------------------------------------
36 |
37 | _inited = False
38 | _plugin = None
39 | _null_tensor = torch.empty([0])
40 |
41 | def _init():
42 | global _inited, _plugin
43 | if not _inited:
44 | _inited = True
45 | sources = ['bias_act.cpp', 'bias_act.cu']
46 | sources = [os.path.join(os.path.dirname(__file__), s) for s in sources]
47 | try:
48 | _plugin = custom_ops.get_plugin('bias_act_plugin', sources=sources, extra_cuda_cflags=['--use_fast_math'])
49 | except:
50 | warnings.warn('Failed to build CUDA kernels for bias_act. Falling back to slow reference implementation. Details:\n\n' + str(sys.exc_info()[1]))
51 | return _plugin is not None
52 |
53 | #----------------------------------------------------------------------------
54 |
55 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'):
56 | r"""Fused bias and activation function.
57 |
58 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`,
59 | and scales the result by `gain`. Each of the steps is optional. In most cases,
60 | the fused op is considerably more efficient than performing the same calculation
61 | using standard PyTorch ops. It supports first and second order gradients,
62 | but not third order gradients.
63 |
64 | Args:
65 | x: Input activation tensor. Can be of any shape.
66 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type
67 | as `x`. The shape must be known, and it must match the dimension of `x`
68 | corresponding to `dim`.
69 | dim: The dimension in `x` corresponding to the elements of `b`.
70 | The value of `dim` is ignored if `b` is not specified.
71 | act: Name of the activation function to evaluate, or `"linear"` to disable.
72 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc.
73 | See `activation_funcs` for a full list. `None` is not allowed.
74 | alpha: Shape parameter for the activation function, or `None` to use the default.
75 | gain: Scaling factor for the output tensor, or `None` to use default.
76 | See `activation_funcs` for the default scaling of each activation function.
77 | If unsure, consider specifying 1.
78 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable
79 | the clamping (default).
80 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default).
81 |
82 | Returns:
83 | Tensor of the same shape and datatype as `x`.
84 | """
85 | assert isinstance(x, torch.Tensor)
86 | assert impl in ['ref', 'cuda']
87 | if impl == 'cuda' and x.device.type == 'cuda' and _init():
88 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b)
89 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp)
90 |
91 | #----------------------------------------------------------------------------
92 |
93 | @misc.profiled_function
94 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None):
95 | """Slow reference implementation of `bias_act()` using standard PyTorch ops.
96 | """
97 | assert isinstance(x, torch.Tensor)
98 | assert clamp is None or clamp >= 0
99 | spec = activation_funcs[act]
100 | alpha = float(alpha if alpha is not None else spec.def_alpha)
101 | gain = float(gain if gain is not None else spec.def_gain)
102 | clamp = float(clamp if clamp is not None else -1)
103 |
104 | # Add bias.
105 | if b is not None:
106 | assert isinstance(b, torch.Tensor) and b.ndim == 1
107 | assert 0 <= dim < x.ndim
108 | assert b.shape[0] == x.shape[dim]
109 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)])
110 |
111 | # Evaluate activation function.
112 | alpha = float(alpha)
113 | x = spec.func(x, alpha=alpha)
114 |
115 | # Scale by gain.
116 | gain = float(gain)
117 | if gain != 1:
118 | x = x * gain
119 |
120 | # Clamp.
121 | if clamp >= 0: 122 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 123 | return x 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | _bias_act_cuda_cache = dict() 128 | 129 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 130 | """Fast CUDA implementation of `bias_act()` using custom ops. 131 | """ 132 | # Parse arguments. 133 | assert clamp is None or clamp >= 0 134 | spec = activation_funcs[act] 135 | alpha = float(alpha if alpha is not None else spec.def_alpha) 136 | gain = float(gain if gain is not None else spec.def_gain) 137 | clamp = float(clamp if clamp is not None else -1) 138 | 139 | # Lookup from cache. 140 | key = (dim, act, alpha, gain, clamp) 141 | if key in _bias_act_cuda_cache: 142 | return _bias_act_cuda_cache[key] 143 | 144 | # Forward op. 145 | class BiasActCuda(torch.autograd.Function): 146 | @staticmethod 147 | def forward(ctx, x, b): # pylint: disable=arguments-differ 148 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride()[1] == 1 else torch.contiguous_format 149 | x = x.contiguous(memory_format=ctx.memory_format) 150 | b = b.contiguous() if b is not None else _null_tensor 151 | y = x 152 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 153 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 154 | ctx.save_for_backward( 155 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 156 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 157 | y if 'y' in spec.ref else _null_tensor) 158 | return y 159 | 160 | @staticmethod 161 | def backward(ctx, dy): # pylint: disable=arguments-differ 162 | dy = dy.contiguous(memory_format=ctx.memory_format) 163 | x, b, y = ctx.saved_tensors 164 | dx = None 165 | db = None 166 | 167 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 168 | dx = dy 169 | if act != 'linear' or gain != 1 or clamp >= 0: 170 | dx = BiasActCudaGrad.apply(dy, x, b, y) 171 | 172 | if ctx.needs_input_grad[1]: 173 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 174 | 175 | return dx, db 176 | 177 | # Backward op. 178 | class BiasActCudaGrad(torch.autograd.Function): 179 | @staticmethod 180 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 181 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride()[1] == 1 else torch.contiguous_format 182 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 183 | ctx.save_for_backward( 184 | dy if spec.has_2nd_grad else _null_tensor, 185 | x, b, y) 186 | return dx 187 | 188 | @staticmethod 189 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 190 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 191 | dy, x, b, y = ctx.saved_tensors 192 | d_dy = None 193 | d_x = None 194 | d_b = None 195 | d_y = None 196 | 197 | if ctx.needs_input_grad[0]: 198 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 199 | 200 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 201 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 202 | 203 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 204 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 205 | 206 | return d_dy, d_x, d_b, d_y 207 | 208 | # Add to cache. 
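# One autograd Function class is built per (dim, act, alpha, gain, clamp)
# combination and reused for all subsequent calls, so e.g. every lrelu with
# default parameters in a network shares a single cached BiasActCuda class.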
209 | _bias_act_cuda_cache[key] = BiasActCuda 210 | return BiasActCuda 211 | 212 | #---------------------------------------------------------------------------- 213 | -------------------------------------------------------------------------------- /src/torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | from pkg_resources import parse_version 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 23 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | @contextlib.contextmanager 27 | def no_weight_gradients(disable=True): 28 | global weight_gradients_disabled 29 | old = weight_gradients_disabled 30 | if disable: 31 | weight_gradients_disabled = True 32 | yield 33 | weight_gradients_disabled = old 34 | 35 | #---------------------------------------------------------------------------- 36 | 37 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 38 | if _should_use_custom_op(input): 39 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 40 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 41 | 42 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 43 | if _should_use_custom_op(input): 44 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 45 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _should_use_custom_op(input): 50 | assert isinstance(input, torch.Tensor) 51 | if (not enabled) or (not torch.backends.cudnn.enabled): 52 | return False 53 | if _use_pytorch_1_11_api: 54 | # The work-around code doesn't work on PyTorch 1.11.0 onwards 55 | return False 56 | if input.device.type != 'cuda': 57 | return False 58 | return True 59 | 60 | def _tuple_of_ints(xs, ndim): 61 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 62 | assert len(xs) == ndim 63 | assert all(isinstance(x, 
int) for x in xs) 64 | return xs 65 | 66 | #---------------------------------------------------------------------------- 67 | 68 | _conv2d_gradfix_cache = dict() 69 | _null_tensor = torch.empty([0]) 70 | 71 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 72 | # Parse arguments. 73 | ndim = 2 74 | weight_shape = tuple(weight_shape) 75 | stride = _tuple_of_ints(stride, ndim) 76 | padding = _tuple_of_ints(padding, ndim) 77 | output_padding = _tuple_of_ints(output_padding, ndim) 78 | dilation = _tuple_of_ints(dilation, ndim) 79 | 80 | # Lookup from cache. 81 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 82 | if key in _conv2d_gradfix_cache: 83 | return _conv2d_gradfix_cache[key] 84 | 85 | # Validate arguments. 86 | assert groups >= 1 87 | assert len(weight_shape) == ndim + 2 88 | assert all(stride[i] >= 1 for i in range(ndim)) 89 | assert all(padding[i] >= 0 for i in range(ndim)) 90 | assert all(dilation[i] >= 0 for i in range(ndim)) 91 | if not transpose: 92 | assert all(output_padding[i] == 0 for i in range(ndim)) 93 | else: # transpose 94 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 95 | 96 | # Helpers. 97 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 98 | def calc_output_padding(input_shape, output_shape): 99 | if transpose: 100 | return [0, 0] 101 | return [ 102 | input_shape[i + 2] 103 | - (output_shape[i + 2] - 1) * stride[i] 104 | - (1 - 2 * padding[i]) 105 | - dilation[i] * (weight_shape[i + 2] - 1) 106 | for i in range(ndim) 107 | ] 108 | 109 | # Forward & backward. 110 | class Conv2d(torch.autograd.Function): 111 | @staticmethod 112 | def forward(ctx, input, weight, bias): 113 | assert weight.shape == weight_shape 114 | ctx.save_for_backward( 115 | input if weight.requires_grad else _null_tensor, 116 | weight if input.requires_grad else _null_tensor, 117 | ) 118 | ctx.input_shape = input.shape 119 | 120 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 121 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 122 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 123 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 124 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 125 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 126 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 127 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 128 | 129 | # General case => cuDNN. 
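        # Note: the forward computation itself is the stock functional op;
        # what this class adds is a backward() that re-enters _conv2d_gradfix,
        # so every derivative is again expressed as a convolution and
        # gradients compose to arbitrary order.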
130 | if transpose: 131 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 132 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 133 | 134 | @staticmethod 135 | def backward(ctx, grad_output): 136 | input, weight = ctx.saved_tensors 137 | input_shape = ctx.input_shape 138 | grad_input = None 139 | grad_weight = None 140 | grad_bias = None 141 | 142 | if ctx.needs_input_grad[0]: 143 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 144 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 145 | grad_input = op.apply(grad_output, weight, None) 146 | assert grad_input.shape == input_shape 147 | 148 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 149 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 150 | assert grad_weight.shape == weight_shape 151 | 152 | if ctx.needs_input_grad[2]: 153 | grad_bias = grad_output.sum([0, 2, 3]) 154 | 155 | return grad_input, grad_weight, grad_bias 156 | 157 | # Gradient with respect to the weights. 158 | class Conv2dGradWeight(torch.autograd.Function): 159 | @staticmethod 160 | def forward(ctx, grad_output, input): 161 | ctx.save_for_backward( 162 | grad_output if input.requires_grad else _null_tensor, 163 | input if grad_output.requires_grad else _null_tensor, 164 | ) 165 | ctx.grad_output_shape = grad_output.shape 166 | ctx.input_shape = input.shape 167 | 168 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 169 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 170 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 171 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 172 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 173 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 174 | 175 | # General case => cuDNN. 
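        # These private aten:: entry points were removed in PyTorch 1.11,
        # which is why _should_use_custom_op() refuses the custom path on
        # 1.11+ builds (see the _use_pytorch_1_11_api check above).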
176 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 177 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 178 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 179 | 180 | @staticmethod 181 | def backward(ctx, grad2_grad_weight): 182 | grad_output, input = ctx.saved_tensors 183 | grad_output_shape = ctx.grad_output_shape 184 | input_shape = ctx.input_shape 185 | grad2_grad_output = None 186 | grad2_input = None 187 | 188 | if ctx.needs_input_grad[0]: 189 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 190 | assert grad2_grad_output.shape == grad_output_shape 191 | 192 | if ctx.needs_input_grad[1]: 193 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 194 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 195 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 196 | assert grad2_input.shape == input_shape 197 | 198 | return grad2_grad_output, grad2_input 199 | 200 | _conv2d_gradfix_cache[key] = Conv2d 201 | return Conv2d 202 | 203 | #---------------------------------------------------------------------------- 204 | -------------------------------------------------------------------------------- /src/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | if not flip_weight: # conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | w = w.flip([2, 3]) 37 | 38 | # Workaround performance pitfall in cuDNN 8.0.5, triggered when using 39 | # 1x1 kernel + memory_format=channels_last + less than 64 channels. 
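    # A 1x1 convolution is just a per-pixel matrix multiply, so for small
    # channel counts it can be cheaper to reshape and use `@` directly than
    # to go through cuDNN. Schematically (shapes are illustrative):
    #     w: [O, I, 1, 1] -> [O, I];  x: [N, I, H, W] -> [N, I, H*W]
    #     y = (w @ x).reshape([N, O, H, W])
    # which is what the first branch below does for groups == 1.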
40 | if kw == 1 and kh == 1 and stride == 1 and padding in [0, [0, 0], (0, 0)] and not transpose: 41 | if x.stride()[1] == 1 and min(out_channels, in_channels_per_group) < 64: 42 | if out_channels <= 4 and groups == 1: 43 | in_shape = x.shape 44 | x = w.squeeze(3).squeeze(2) @ x.reshape([in_shape[0], in_channels_per_group, -1]) 45 | x = x.reshape([in_shape[0], out_channels, in_shape[2], in_shape[3]]) 46 | else: 47 | x = x.to(memory_format=torch.contiguous_format) 48 | w = w.to(memory_format=torch.contiguous_format) 49 | x = conv2d_gradfix.conv2d(x, w, groups=groups) 50 | return x.to(memory_format=torch.channels_last) 51 | 52 | # Otherwise => execute using conv2d_gradfix. 53 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 54 | return op(x, w, stride=stride, padding=padding, groups=groups) 55 | 56 | #---------------------------------------------------------------------------- 57 | 58 | @misc.profiled_function 59 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 60 | r"""2D convolution with optional up/downsampling. 61 | 62 | Padding is performed only once at the beginning, not between the operations. 63 | 64 | Args: 65 | x: Input tensor of shape 66 | `[batch_size, in_channels, in_height, in_width]`. 67 | w: Weight tensor of shape 68 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 69 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 70 | calling upfirdn2d.setup_filter(). None = identity (default). 71 | up: Integer upsampling factor (default: 1). 72 | down: Integer downsampling factor (default: 1). 73 | padding: Padding with respect to the upsampled image. Can be a single number 74 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 75 | (default: 0). 76 | groups: Split input channels into N groups (default: 1). 77 | flip_weight: False = convolution, True = correlation (default: True). 78 | flip_filter: False = convolution, True = correlation (default: False). 79 | 80 | Returns: 81 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 82 | """ 83 | # Validate arguments. 84 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 85 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 86 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 87 | assert isinstance(up, int) and (up >= 1) 88 | assert isinstance(down, int) and (down >= 1) 89 | assert isinstance(groups, int) and (groups >= 1) 90 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 91 | fw, fh = _get_filter_size(f) 92 | px0, px1, py0, py1 = _parse_padding(padding) 93 | 94 | # Adjust padding to account for up/downsampling. 95 | if up > 1: 96 | px0 += (fw + up - 1) // 2 97 | px1 += (fw - up) // 2 98 | py0 += (fh + up - 1) // 2 99 | py1 += (fh - up) // 2 100 | if down > 1: 101 | px0 += (fw - down + 1) // 2 102 | px1 += (fw - down) // 2 103 | py0 += (fh - down + 1) // 2 104 | py1 += (fh - down) // 2 105 | 106 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 107 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 108 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 109 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 110 | return x 111 | 112 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 
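    # Convolving at the low resolution first touches up**2 times fewer
    # pixels, and the reordering is exact because a 1x1 kernel has no
    # spatial extent; the gain=up**2 passed below compensates for the
    # energy the low-pass filter spreads over the zeros inserted by
    # upsampling.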
113 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 114 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 115 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 116 | return x 117 | 118 | # Fast path: downsampling only => use strided convolution. 119 | if down > 1 and up == 1: 120 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 121 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 122 | return x 123 | 124 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 125 | if up > 1: 126 | if groups == 1: 127 | w = w.transpose(0, 1) 128 | else: 129 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 130 | w = w.transpose(1, 2) 131 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 132 | px0 -= kw - 1 133 | px1 -= kw - up 134 | py0 -= kh - 1 135 | py1 -= kh - up 136 | pxt = max(min(-px0, -px1), 0) 137 | pyt = max(min(-py0, -py1), 0) 138 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 139 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 140 | if down > 1: 141 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 142 | return x 143 | 144 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 145 | if up == 1 and down == 1: 146 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 147 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 148 | 149 | # Fallback: Generic reference implementation. 150 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 151 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 152 | if down > 1: 153 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 154 | return x 155 | 156 | #---------------------------------------------------------------------------- 157 | -------------------------------------------------------------------------------- /src/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /src/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
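# Callers are expected to flip this flag explicitly once at startup, e.g.
# (illustrative sketch):
#     from torch_utils.ops import grid_sample_gradfix
#     grid_sample_gradfix.enabled = True  # grid_sample() below now uses the custom op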
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11
25 | 
26 | #----------------------------------------------------------------------------
27 | 
28 | def grid_sample(input, grid):
29 |     if _should_use_custom_op():
30 |         return _GridSample2dForward.apply(input, grid)
31 |     return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
32 | 
33 | #----------------------------------------------------------------------------
34 | 
35 | def _should_use_custom_op():
36 |     return enabled
37 | 
38 | #----------------------------------------------------------------------------
39 | 
40 | class _GridSample2dForward(torch.autograd.Function):
41 |     @staticmethod
42 |     def forward(ctx, input, grid):
43 |         assert input.ndim == 4
44 |         assert grid.ndim == 4
45 |         output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False)
46 |         ctx.save_for_backward(input, grid)
47 |         return output
48 | 
49 |     @staticmethod
50 |     def backward(ctx, grad_output):
51 |         input, grid = ctx.saved_tensors
52 |         grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid)
53 |         return grad_input, grad_grid
54 | 
55 | #----------------------------------------------------------------------------
56 | 
57 | class _GridSample2dBackward(torch.autograd.Function):
58 |     @staticmethod
59 |     def forward(ctx, grad_output, input, grid):
60 |         op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward')
61 |         # !!! fix for pytorch > 1.10 https://github.com/NVlabs/stylegan3/issues/188
62 |         if isinstance(op, tuple): op = op[0] # newer builds return (op, overload_names); calling len() on a bare op would raise
63 |         if _use_pytorch_1_11_api:
64 |             output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2])
65 |             grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask)
66 |         else:
67 |             grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False)
68 |         ctx.save_for_backward(grid)
69 |         return grad_input, grad_grid
70 | 
71 |     @staticmethod
72 |     def backward(ctx, grad2_grad_input, grad2_grad_grid):
73 |         _ = grad2_grad_grid # unused
74 |         grid, = ctx.saved_tensors
75 |         grad2_grad_output = None
76 |         grad2_input = None
77 |         grad2_grid = None
78 | 
79 |         if ctx.needs_input_grad[0]:
80 |             grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid)
81 | 
82 |         assert not ctx.needs_input_grad[2]
83 |         return grad2_grad_output, grad2_input, grad2_grid
84 | 
85 | #----------------------------------------------------------------------------
86 | 
-------------------------------------------------------------------------------- /src/torch_utils/ops/upfirdn2d.cpp: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <torch/extension.h>
10 | #include <ATen/cuda/CUDAContext.h>
11 | #include <c10/cuda/CUDAGuard.h>
12 | #include "upfirdn2d.h"
13 | 
14 | //------------------------------------------------------------------------
15 | 
16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain)
17 | {
18 |     // Validate arguments.
19 |     TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device");
20 |     TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x");
21 |     TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32");
22 |     TORCH_CHECK(x.numel() <= INT_MAX, "x is too large");
23 |     TORCH_CHECK(f.numel() <= INT_MAX, "f is too large");
24 |     TORCH_CHECK(x.dim() == 4, "x must be rank 4");
25 |     TORCH_CHECK(f.dim() == 2, "f must be rank 2");
26 |     TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1");
27 |     TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1");
28 |     TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1");
29 | 
30 |     // Create output tensor.
31 |     const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
32 |     int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx;
33 |     int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy;
34 |     TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1");
35 |     torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format());
36 |     TORCH_CHECK(y.numel() <= INT_MAX, "output is too large");
37 | 
38 |     // Initialize CUDA kernel parameters.
39 |     upfirdn2d_kernel_params p;
40 |     p.x = x.data_ptr();
41 |     p.f = f.data_ptr<float>();
42 |     p.y = y.data_ptr();
43 |     p.up = make_int2(upx, upy);
44 |     p.down = make_int2(downx, downy);
45 |     p.pad0 = make_int2(padx0, pady0);
46 |     p.flip = (flip) ? 1 : 0;
47 |     p.gain = gain;
48 |     p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0));
49 |     p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0));
50 |     p.filterSize = make_int2((int)f.size(1), (int)f.size(0));
51 |     p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0));
52 |     p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0));
53 |     p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0));
54 |     p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z;
55 |     p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1;
56 | 
57 |     // Choose CUDA kernel.
58 |     upfirdn2d_kernel_spec spec;
59 |     AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&]
60 |     {
61 |         spec = choose_upfirdn2d_kernel<scalar_t>(p);
62 |     });
63 | 
64 |     // Set looping options.
65 |     p.loopMajor = (p.sizeMajor - 1) / 16384 + 1;
66 |     p.loopMinor = spec.loopMinor;
67 |     p.loopX = spec.loopX;
68 |     p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1;
69 |     p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1;
70 | 
71 |     // Compute grid size.
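    // Two launch shapes are used: kernels that report tileOutW < 0 are
    // generic "large" variants covered by a 4x32 thread block that loops
    // over the output, while specialized "small" variants use 256 threads
    // per tileOutW x tileOutH output tile.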
72 |     dim3 blockSize, gridSize;
73 |     if (spec.tileOutW < 0) // large
74 |     {
75 |         blockSize = dim3(4, 32, 1);
76 |         gridSize = dim3(
77 |             ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor,
78 |             (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1,
79 |             p.launchMajor);
80 |     }
81 |     else // small
82 |     {
83 |         blockSize = dim3(256, 1, 1);
84 |         gridSize = dim3(
85 |             ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor,
86 |             (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1,
87 |             p.launchMajor);
88 |     }
89 | 
90 |     // Launch CUDA kernel.
91 |     void* args[] = {&p};
92 |     AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream()));
93 |     return y;
94 | }
95 | 
96 | //------------------------------------------------------------------------
97 | 
98 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
99 | {
100 |     m.def("upfirdn2d", &upfirdn2d);
101 | }
102 | 
103 | //------------------------------------------------------------------------
104 | 
-------------------------------------------------------------------------------- /src/torch_utils/ops/upfirdn2d.h: --------------------------------------------------------------------------------
1 | // Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | //
3 | // NVIDIA CORPORATION and its licensors retain all intellectual property
4 | // and proprietary rights in and to this software, related documentation
5 | // and any modifications thereto. Any use, reproduction, disclosure or
6 | // distribution of this software and related documentation without an express
7 | // license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | #include <cuda_runtime.h>
10 | 
11 | //------------------------------------------------------------------------
12 | // CUDA kernel parameters.
13 | 
14 | struct upfirdn2d_kernel_params
15 | {
16 |     const void* x;
17 |     const float* f;
18 |     void* y;
19 | 
20 |     int2 up;
21 |     int2 down;
22 |     int2 pad0;
23 |     int flip;
24 |     float gain;
25 | 
26 |     int4 inSize; // [width, height, channel, batch]
27 |     int4 inStride;
28 |     int2 filterSize; // [width, height]
29 |     int2 filterStride;
30 |     int4 outSize; // [width, height, channel, batch]
31 |     int4 outStride;
32 |     int sizeMinor;
33 |     int sizeMajor;
34 | 
35 |     int loopMinor;
36 |     int loopMajor;
37 |     int loopX;
38 |     int launchMinor;
39 |     int launchMajor;
40 | };
41 | 
42 | //------------------------------------------------------------------------
43 | // CUDA kernel specialization.
44 | 
45 | struct upfirdn2d_kernel_spec
46 | {
47 |     void* kernel;
48 |     int tileOutW;
49 |     int tileOutH;
50 |     int loopMinor;
51 |     int loopX;
52 | };
53 | 
54 | //------------------------------------------------------------------------
55 | // CUDA kernel selection.
56 | 
57 | template <class T> upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p);
58 | 
59 | //------------------------------------------------------------------------
60 | 
-------------------------------------------------------------------------------- /src/torch_utils/persistence.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous
83 |     instance of a persistent class, and then upgrade it to use the latest
84 |     version of the source code:
85 | 
86 |         with open('old_pickle.pkl', 'rb') as f:
87 |             old_net = pickle.load(f)
88 |         new_net = MyNetwork(*old_net.init_args, **old_net.init_kwargs)
89 |         misc.copy_params_and_buffers(old_net, new_net, require_all=True)
90 |     """
91 |     assert isinstance(orig_class, type)
92 |     if is_persistent(orig_class):
93 |         return orig_class
94 | 
95 |     assert orig_class.__module__ in sys.modules
96 |     orig_module = sys.modules[orig_class.__module__]
97 |     orig_module_src = _module_to_src(orig_module)
98 | 
99 |     class Decorator(orig_class):
100 |         _orig_module_src = orig_module_src
101 |         _orig_class_name = orig_class.__name__
102 | 
103 |         def __init__(self, *args, **kwargs):
104 |             super().__init__(*args, **kwargs)
105 |             self._init_args = copy.deepcopy(args)
106 |             self._init_kwargs = copy.deepcopy(kwargs)
107 |             assert orig_class.__name__ in orig_module.__dict__
108 |             _check_pickleable(self.__reduce__())
109 | 
110 |         @property
111 |         def init_args(self):
112 |             return copy.deepcopy(self._init_args)
113 | 
114 |         @property
115 |         def init_kwargs(self):
116 |             return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs))
117 | 
118 |         def __reduce__(self):
119 |             fields = list(super().__reduce__())
120 |             fields += [None] * max(3 - len(fields), 0)
121 |             if fields[0] is not _reconstruct_persistent_obj:
122 |                 meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2])
123 |                 fields[0] = _reconstruct_persistent_obj # reconstruct func
124 |                 fields[1] = (meta,) # reconstruct args
125 |                 fields[2] = None # state dict
126 |             return tuple(fields)
127 | 
128 |     Decorator.__name__ = orig_class.__name__
129 |     _decorators.add(Decorator)
130 |     return Decorator
131 | 
132 | #----------------------------------------------------------------------------
133 | 
134 | def is_persistent(obj):
135 |     r"""Test whether the given object or class is persistent, i.e.,
136 |     whether it will save its source code when pickled.
137 |     """
138 |     try:
139 |         if obj in _decorators:
140 |             return True
141 |     except TypeError:
142 |         pass
143 |     return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck
144 | 
145 | #----------------------------------------------------------------------------
146 | 
147 | def import_hook(hook):
148 |     r"""Register an import hook that is called whenever a persistent object
149 |     is being unpickled. A typical use case is to patch the pickled source
150 |     code to avoid errors and inconsistencies when the API of some imported
151 |     module has changed.
152 | 
153 |     The hook should have the following signature:
154 | 
155 |         hook(meta) -> modified meta
156 | 
157 |     `meta` is an instance of `dnnlib.EasyDict` with the following fields:
158 | 
159 |         type:        Type of the persistent object, e.g. `'class'`.
160 |         version:     Internal version number of `torch_utils.persistence`.
161 |         module_src:  Original source code of the Python module.
162 |         class_name:  Class name in the original Python module.
163 |         state:       Internal state of the object.
164 | 
165 |     Example:
166 | 
167 |         @persistence.import_hook
168 |         def wreck_my_network(meta):
169 |             if meta.class_name == 'MyNetwork':
170 |                 print('MyNetwork is being imported.
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /src/torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 |         """
160 |         if not self._keep_previous:
161 |             self._moments.clear()
162 |         for name, cumulative in _sync(self.names()):
163 |             if name not in self._cumulative:
164 |                 self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
165 |             delta = cumulative - self._cumulative[name]
166 |             self._cumulative[name].copy_(cumulative)
167 |             if float(delta[0]) != 0:
168 |                 self._moments[name] = delta
169 | 
170 |     def _get_delta(self, name):
171 |         r"""Returns the raw moments that were accumulated for the given
172 |         statistic between the last two calls to `update()`, or zero if
173 |         no scalars were collected.
174 |         """
175 |         assert self._regex.fullmatch(name)
176 |         if name not in self._moments:
177 |             self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
178 |         return self._moments[name]
179 | 
180 |     def num(self, name):
181 |         r"""Returns the number of scalars that were accumulated for the given
182 |         statistic between the last two calls to `update()`, or zero if
183 |         no scalars were collected.
184 |         """
185 |         delta = self._get_delta(name)
186 |         return int(delta[0])
187 | 
188 |     def mean(self, name):
189 |         r"""Returns the mean of the scalars that were accumulated for the
190 |         given statistic between the last two calls to `update()`, or NaN if
191 |         no scalars were collected.
192 |         """
193 |         delta = self._get_delta(name)
194 |         if int(delta[0]) == 0:
195 |             return float('nan')
196 |         return float(delta[1] / delta[0])
197 | 
198 |     def std(self, name):
199 |         r"""Returns the standard deviation of the scalars that were
200 |         accumulated for the given statistic between the last two calls to
201 |         `update()`, or NaN if no scalars were collected.
202 |         """
203 |         delta = self._get_delta(name)
204 |         if int(delta[0]) == 0 or not np.isfinite(float(delta[1])):
205 |             return float('nan')
206 |         if int(delta[0]) == 1:
207 |             return float(0)
208 |         mean = float(delta[1] / delta[0])
209 |         raw_var = float(delta[2] / delta[0])
210 |         return np.sqrt(max(raw_var - np.square(mean), 0))
211 | 
212 |     def as_dict(self):
213 |         r"""Returns the averages accumulated between the last two calls to
214 |         `update()` as a `dnnlib.EasyDict`. The contents are as follows:
215 | 
216 |             dnnlib.EasyDict(
217 |                 NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT),
218 |                 ...
219 |             )
220 |         """
221 |         stats = dnnlib.EasyDict()
222 |         for name in self.names():
223 |             stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name))
224 |         return stats
225 | 
226 |     def __getitem__(self, name):
227 |         r"""Convenience getter.
228 |         `collector[name]` is a synonym for `collector.mean(name)`.
229 |         """
230 |         return self.mean(name)
231 | 
232 | #----------------------------------------------------------------------------
233 | 
234 | def _sync(names):
235 |     r"""Synchronize the global cumulative counters across devices and
236 |     processes. Called internally by `Collector.update()`.
237 |     """
238 |     if len(names) == 0:
239 |         return []
240 |     global _sync_called
241 |     _sync_called = True
242 | 
243 |     # Collect deltas within current rank.
244 |     deltas = []
245 |     device = _sync_device if _sync_device is not None else torch.device('cpu')
246 |     for name in names:
247 |         delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device)
248 |         for counter in _counters[name].values():
249 |             delta.add_(counter.to(device))
250 |             counter.copy_(torch.zeros_like(counter))
251 |         deltas.append(delta)
252 |     deltas = torch.stack(deltas)
253 | 
254 |     # Sum deltas across ranks.
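    # Stacking every per-name delta into a single tensor keeps the cost of
    # update() at one collective per call, no matter how many statistics
    # are being tracked.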
255 | if _sync_device is not None: 256 | torch.distributed.all_reduce(deltas) 257 | 258 | # Update cumulative values. 259 | deltas = deltas.cpu() 260 | for idx, name in enumerate(names): 261 | if name not in _cumulative: 262 | _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 263 | _cumulative[name].add_(deltas[idx]) 264 | 265 | # Return name-value pairs. 266 | return [(name, _cumulative[name]) for name in names] 267 | 268 | #---------------------------------------------------------------------------- 269 | -------------------------------------------------------------------------------- /src/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /src/training/dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import numpy as np 11 | import zipfile 12 | import PIL.Image 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | try: 18 | import pyspng 19 | except ImportError: 20 | pyspng = None 21 | 22 | from util.utilgan import calc_res 23 | 24 | class Dataset(torch.utils.data.Dataset): 25 | def __init__(self, 26 | name, # Name of the dataset. 27 | raw_shape, # Shape of the raw image data (NCHW). 28 | max_size = None, # Artificially limit the size of the dataset. None = no limit. Applied before xflip. 29 | use_labels = False, # Enable conditioning labels? False = label dimension is zero. 30 | xflip = False, # Artificially double the size of the dataset via x-flips. Applied after max_size. 31 | # yflip = False, # Artificially double the size of the dataset via y-flips. Applied after max_size. 32 | random_seed = 0, # Random seed to use when applying max_size. 33 | ): 34 | self._name = name 35 | self._raw_shape = list(raw_shape) 36 | self._use_labels = use_labels 37 | self._raw_labels = None 38 | self._label_shape = None 39 | 40 | # Apply max_size. 41 | self._raw_idx = np.arange(self._raw_shape[0], dtype=np.int64) 42 | if (max_size is not None) and (self._raw_idx.size > max_size): 43 | np.random.RandomState(random_seed).shuffle(self._raw_idx) 44 | self._raw_idx = np.sort(self._raw_idx[:max_size]) 45 | 46 | # Apply xflip. 47 | self._xflip = np.zeros(self._raw_idx.size, dtype=np.uint8) 48 | if xflip: 49 | self._raw_idx = np.tile(self._raw_idx, 2) 50 | self._xflip = np.concatenate([self._xflip, np.ones_like(self._xflip)]) 51 | # Apply yflip. 
52 |         # self._yflip = np.zeros(self._raw_idx.size, dtype=np.uint8)
53 |         # if yflip:
54 |         #     self._raw_idx = np.tile(self._raw_idx, 2)
55 |         #     self._yflip = np.concatenate([self._yflip, np.ones_like(self._yflip)])
56 | 
57 |     def _get_raw_labels(self):
58 |         if self._raw_labels is None:
59 |             # !!! cond dir labels
60 |             self._raw_labels = self._load_dir_labels() if self._use_labels else None
61 |             # self._raw_labels = self._load_raw_labels() if self._use_labels else None
62 |             if self._raw_labels is None:
63 |                 self._raw_labels = np.zeros([self._raw_shape[0], 0], dtype=np.float32)
64 |             assert isinstance(self._raw_labels, np.ndarray)
65 |             assert self._raw_labels.shape[0] == self._raw_shape[0]
66 |             assert self._raw_labels.dtype in [np.float32, np.int64]
67 |             if self._raw_labels.dtype == np.int64:
68 |                 assert self._raw_labels.ndim == 1
69 |                 assert np.all(self._raw_labels >= 0)
70 |         return self._raw_labels
71 | 
72 |     def close(self): # to be overridden by subclass
73 |         pass
74 | 
75 |     def _load_raw_image(self, raw_idx): # to be overridden by subclass
76 |         raise NotImplementedError
77 | 
78 |     def _load_raw_labels(self): # to be overridden by subclass
79 |         raise NotImplementedError
80 | 
81 |     def _load_dir_labels(self): # to be overridden by subclass
82 |         raise NotImplementedError
83 | 
84 |     def __getstate__(self):
85 |         return dict(self.__dict__, _raw_labels=None)
86 | 
87 |     def __del__(self):
88 |         try:
89 |             self.close()
90 |         except:
91 |             pass
92 | 
93 |     def __len__(self):
94 |         return self._raw_idx.size
95 | 
96 |     def __getitem__(self, idx):
97 |         image = self._load_raw_image(self._raw_idx[idx])
98 |         assert isinstance(image, np.ndarray)
99 |         assert list(image.shape) == self.image_shape
100 |         assert image.dtype == np.uint8
101 |         if self._xflip[idx]:
102 |             assert image.ndim == 3 # CHW
103 |             image = image[:, :, ::-1]
104 |         return image.copy(), self.get_label(idx)
105 | 
106 |     def get_label(self, idx):
107 |         label = self._get_raw_labels()[self._raw_idx[idx]]
108 |         if label.dtype == np.int64:
109 |             onehot = np.zeros(self.label_shape, dtype=np.float32)
110 |             onehot[label] = 1
111 |             label = onehot
112 |         return label.copy()
113 | 
114 |     def get_details(self, idx):
115 |         d = dnnlib.EasyDict()
116 |         d.raw_idx = int(self._raw_idx[idx])
117 |         d.xflip = (int(self._xflip[idx]) != 0)
118 |         d.raw_label = self._get_raw_labels()[d.raw_idx].copy()
119 |         return d
120 | 
121 |     @property
122 |     def name(self):
123 |         return self._name
124 | 
125 |     @property
126 |     def image_shape(self):
127 |         return list(self._raw_shape[1:])
128 | 
129 |     @property
130 |     def num_channels(self):
131 |         assert len(self.image_shape) == 3 # CHW
132 |         return self.image_shape[0]
133 | 
134 |     @property
135 |     def resolution(self):
136 |         assert len(self.image_shape) == 3 # CHW
137 |         # !!! custom init res
138 |         max_res = calc_res(self.image_shape[1:])
139 |         return max_res
140 |         # assert self.image_shape[1] == self.image_shape[2]
141 |         # return self.image_shape[1]
142 | 
143 |     # !!! custom init res
144 |     @property
145 |     def res_log2(self):
146 |         return int(np.ceil(np.log2(self.resolution)))
147 | 
148 |     # !!! custom init res
149 |     @property
150 |     def init_res(self):
151 |         return [int(s * 2**(2-self.res_log2)) for s in self.image_shape[1:]]
152 | 
153 |     @property
154 |     def label_shape(self):
155 |         if self._label_shape is None:
156 |             raw_labels = self._get_raw_labels()
157 |             if raw_labels.dtype == np.int64:
158 |                 self._label_shape = [int(np.max(raw_labels)) + 1]
159 |             else:
160 |                 self._label_shape = raw_labels.shape[1:]
161 |         return list(self._label_shape)
162 | 
163 |     @property
164 |     def label_dim(self):
165 |         assert len(self.label_shape) == 1
166 |         return self.label_shape[0]
167 | 
168 |     @property
169 |     def has_labels(self):
170 |         return any(x != 0 for x in self.label_shape)
171 | 
172 |     @property
173 |     def has_onehot_labels(self):
174 |         return self._get_raw_labels().dtype == np.int64
175 | 
176 | #----------------------------------------------------------------------------
177 | 
178 | class ImageFolderDataset(Dataset):
179 |     def __init__(self,
180 |         path,                   # Path to directory or zip.
181 |         resolution      = None, # Ensure specific resolution, None = highest available.
182 |         **super_kwargs,         # Additional arguments for the Dataset base class.
183 |     ):
184 |         self._path = path
185 |         self._zipfile = None
186 | 
187 |         if os.path.isdir(self._path):
188 |             self._type = 'dir'
189 |             self._all_fnames = {os.path.relpath(os.path.join(root, fname), start=self._path) for root, _dirs, files in os.walk(self._path) for fname in files}
190 |         elif self._file_ext(self._path) == '.zip':
191 |             self._type = 'zip'
192 |             self._all_fnames = set(self._get_zipfile().namelist())
193 |         else:
194 |             raise IOError('Path must point to a directory or zip')
195 | 
196 |         PIL.Image.init()
197 |         self._image_fnames = sorted(fname for fname in self._all_fnames if self._file_ext(fname) in PIL.Image.EXTENSION)
198 |         if len(self._image_fnames) == 0:
199 |             raise IOError('No image files found in the specified path')
200 | 
201 |         name = os.path.splitext(os.path.basename(self._path))[0]
202 |         raw_shape = [len(self._image_fnames)] + list(self._load_raw_image(0).shape)
203 |         # if resolution is not None and (raw_shape[2] != resolution or raw_shape[3] != resolution):
204 |         #     raise IOError('Image files do not match the specified resolution')
205 |         super().__init__(name=name, raw_shape=raw_shape, **super_kwargs)
206 | 
207 |     @staticmethod
208 |     def _file_ext(fname):
209 |         return os.path.splitext(fname)[1].lower()
210 | 
211 |     def _get_zipfile(self):
212 |         assert self._type == 'zip'
213 |         if self._zipfile is None:
214 |             self._zipfile = zipfile.ZipFile(self._path)
215 |         return self._zipfile
216 | 
217 |     def _open_file(self, fname):
218 |         if self._type == 'dir':
219 |             return open(os.path.join(self._path, fname), 'rb')
220 |         if self._type == 'zip':
221 |             return self._get_zipfile().open(fname, 'r')
222 |         return None
223 | 
224 |     def close(self):
225 |         try:
226 |             if self._zipfile is not None:
227 |                 self._zipfile.close()
228 |         finally:
229 |             self._zipfile = None
230 | 
231 |     def __getstate__(self):
232 |         return dict(super().__getstate__(), _zipfile=None)
233 | 
234 |     def _load_raw_image(self, raw_idx):
235 |         fname = self._image_fnames[raw_idx]
236 |         with self._open_file(fname) as f:
237 |             if pyspng is not None and self._file_ext(fname) == '.png':
238 |                 image = pyspng.load(f.read())
239 |             else:
240 |                 image = np.array(PIL.Image.open(f))
241 |         if image.ndim == 2:
242 |             image = image[:, :, np.newaxis] # HW => HWC
243 |         image = image.transpose(2, 0, 1) # HWC => CHW
244 |         return image
245 | 
246 |     # !!! cond dir labels
247 |     def _load_dir_labels(self):
248 |         dir_levels = {len(fname.replace('\\', '/').split('/')) for fname in self._image_fnames} # set = unique only
249 |         if dir_levels == {2}:
250 |             print(' Dataset subdirs are set for labels')
251 |             dir_names = {fname.replace('\\', '/').split('/')[0] for fname in self._image_fnames} # set = unique only
252 |             dir_labels = {}
253 |             for i, dir in enumerate(sorted(dir_names)):
254 |                 dir_labels[dir] = i
255 |             all_dirs = [fname.replace('\\', '/').split('/')[0] for fname in self._image_fnames] # list = for all files
256 |             labels = [dir_labels[d] for d in all_dirs]
257 |             # labels = [dir_labels[fname.replace('\\', '/').split('/')[0]] for fname in self._image_fnames] # oneliner
258 |             labels = np.array(labels)
259 |             labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
260 |             return labels
261 |         else:
262 |             return None
263 | 
264 |     def _load_raw_labels(self):
265 |         fname = 'dataset.json'
266 |         if fname not in self._all_fnames:
267 |             return None
268 |         with self._open_file(fname) as f:
269 |             labels = json.load(f)['labels']
270 |         if labels is None:
271 |             return None
272 |         labels = dict(labels)
273 |         labels = [labels[fname.replace('\\', '/')] for fname in self._image_fnames]
274 |         labels = np.array(labels)
275 |         labels = labels.astype({1: np.int64, 2: np.float32}[labels.ndim])
276 |         return labels
277 | 
278 | 
279 | #----------------------------------------------------------------------------
280 | 
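The `_load_dir_labels` hook above is this fork's alternative to the stock `dataset.json` lookup: if every image sits exactly one subdirectory deep, each subdirectory becomes one class, numbered alphabetically from 0, and `Dataset.get_label` later expands the int64 indices to one-hot vectors. A minimal sketch of that mapping in isolation, with hypothetical file names (not part of the repository):

    import numpy as np

    fnames = ['cats/001.jpg', 'dogs/001.jpg', 'cats/002.jpg']  # hypothetical
    dirs = sorted({f.split('/')[0] for f in fnames})           # ['cats', 'dogs']
    dir_labels = {d: i for i, d in enumerate(dirs)}            # {'cats': 0, 'dogs': 1}
    labels = np.array([dir_labels[f.split('/')[0]] for f in fnames], dtype=np.int64)
    # labels -> [0 1 0]; get_label() turns each index into a one-hot float32 vector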
--------------------------------------------------------------------------------
/src/training/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
2 | #
3 | # NVIDIA CORPORATION and its licensors retain all intellectual property
4 | # and proprietary rights in and to this software, related documentation
5 | # and any modifications thereto. Any use, reproduction, disclosure or
6 | # distribution of this software and related documentation without an express
7 | # license agreement from NVIDIA CORPORATION is strictly prohibited.
8 | 
9 | import numpy as np
10 | import torch
11 | from torch_utils import training_stats
12 | from torch_utils import misc
13 | from torch_utils.ops import conv2d_gradfix
14 | 
15 | #----------------------------------------------------------------------------
16 | 
17 | class Loss:
18 |     def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain): # to be overridden by subclass
19 |         raise NotImplementedError()
20 | 
21 | #----------------------------------------------------------------------------
22 | 
23 | class StyleGAN2Loss(Loss):
24 |     # !!! apa
25 |     def __init__(self, device, G_mapping, G_synthesis, D, apa=False, augment_pipe=None, style_mixing_prob=0.9, r1_gamma=10, pl_batch_shrink=2, pl_decay=0.01, pl_weight=2):
26 |         super().__init__()
27 |         self.device = device
28 |         self.G_mapping = G_mapping
29 |         self.G_synthesis = G_synthesis
30 |         self.D = D
31 |         self.augment_pipe = augment_pipe
32 |         self.style_mixing_prob = style_mixing_prob
33 |         self.r1_gamma = r1_gamma
34 |         self.pl_batch_shrink = pl_batch_shrink
35 |         self.pl_decay = pl_decay
36 |         self.pl_weight = pl_weight
37 |         self.pl_mean = torch.zeros([], device=device)
38 |         # !!! apa
39 |         self.apa = apa
40 |         self.pseudo_data = None
41 | 
42 |     def run_G(self, z, c, sync):
43 |         with misc.ddp_sync(self.G_mapping, sync):
44 |             ws = self.G_mapping(z, c)
45 |             if self.style_mixing_prob > 0:
46 |                 with torch.autograd.profiler.record_function('style_mixing'):
47 |                     cutoff = torch.empty([], dtype=torch.int64, device=ws.device).random_(1, ws.shape[1])
48 |                     cutoff = torch.where(torch.rand([], device=ws.device) < self.style_mixing_prob, cutoff, torch.full_like(cutoff, ws.shape[1]))
49 |                     ws[:, cutoff:] = self.G_mapping(torch.randn_like(z), c, skip_w_avg_update=True)[:, cutoff:]
50 |         with misc.ddp_sync(self.G_synthesis, sync):
51 |             img = self.G_synthesis(ws)
52 |         return img, ws
53 | 
54 |     def run_D(self, img, c, sync):
55 |         if self.augment_pipe is not None:
56 |             img = self.augment_pipe(img)
57 |         with misc.ddp_sync(self.D, sync):
58 |             logits = self.D(img, c)
59 |         return logits
60 | 
61 |     # !!! apa
62 |     def adaptive_pseudo_augmentation(self, real_img):
63 |         # Apply Adaptive Pseudo Augmentation (APA)
64 |         batch_size = real_img.shape[0]
65 |         pseudo_flag = torch.ones([batch_size, 1, 1, 1], device=self.device)
66 |         pseudo_flag = torch.where(torch.rand([batch_size, 1, 1, 1], device=self.device) < self.augment_pipe.p, pseudo_flag, torch.zeros_like(pseudo_flag))
67 |         if torch.allclose(pseudo_flag, torch.zeros_like(pseudo_flag)):
68 |             return real_img
69 |         else:
70 |             assert self.pseudo_data is not None
71 |             return self.pseudo_data * pseudo_flag + real_img * (1 - pseudo_flag)
72 | 
73 |     def accumulate_gradients(self, phase, real_img, real_c, gen_z, gen_c, sync, gain):
74 |         assert phase in ['Gmain', 'Greg', 'Gboth', 'Dmain', 'Dreg', 'Dboth']
75 |         do_Gmain = (phase in ['Gmain', 'Gboth'])
76 |         do_Dmain = (phase in ['Dmain', 'Dboth'])
77 |         do_Gpl   = (phase in ['Greg', 'Gboth']) and (self.pl_weight != 0)
78 |         do_Dr1   = (phase in ['Dreg', 'Dboth']) and (self.r1_gamma != 0)
79 | 
80 |         # Gmain: Maximize logits for generated images.
81 |         if do_Gmain:
82 |             with torch.autograd.profiler.record_function('Gmain_forward'):
83 |                 gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=(sync and not do_Gpl)) # May get synced by Gpl.
84 |                 # !!! apa
85 |                 # Update pseudo data
86 |                 if self.apa is True:
87 |                     self.pseudo_data = gen_img.detach()
88 |                 gen_logits = self.run_D(gen_img, gen_c, sync=False)
89 |                 training_stats.report('Loss/scores/fake', gen_logits)
90 |                 training_stats.report('Loss/signs/fake', gen_logits.sign())
91 |                 loss_Gmain = torch.nn.functional.softplus(-gen_logits) # -log(sigmoid(gen_logits))
92 |                 training_stats.report('Loss/G/loss', loss_Gmain)
93 |             with torch.autograd.profiler.record_function('Gmain_backward'):
94 |                 loss_Gmain.mean().mul(gain).backward()
95 | 
96 |         # Gpl: Apply path length regularization.
97 |         if do_Gpl:
98 |             with torch.autograd.profiler.record_function('Gpl_forward'):
99 |                 batch_size = gen_z.shape[0] // self.pl_batch_shrink
100 |                 gen_img, gen_ws = self.run_G(gen_z[:batch_size], gen_c[:batch_size], sync=sync)
101 |                 pl_noise = torch.randn_like(gen_img) / np.sqrt(gen_img.shape[2] * gen_img.shape[3])
102 |                 with torch.autograd.profiler.record_function('pl_grads'), conv2d_gradfix.no_weight_gradients():
103 |                     pl_grads = torch.autograd.grad(outputs=[(gen_img * pl_noise).sum()], inputs=[gen_ws], create_graph=True, only_inputs=True)[0]
104 |                 pl_lengths = pl_grads.square().sum(2).mean(1).sqrt()
105 |                 pl_mean = self.pl_mean.lerp(pl_lengths.mean(), self.pl_decay)
106 |                 self.pl_mean.copy_(pl_mean.detach())
107 |                 pl_penalty = (pl_lengths - pl_mean).square()
108 |                 training_stats.report('Loss/pl_penalty', pl_penalty)
109 |                 loss_Gpl = pl_penalty * self.pl_weight
110 |                 training_stats.report('Loss/G/reg', loss_Gpl)
111 |             with torch.autograd.profiler.record_function('Gpl_backward'):
112 |                 (gen_img[:, 0, 0, 0] * 0 + loss_Gpl).mean().mul(gain).backward()
113 | 
114 |         # Dmain: Minimize logits for generated images.
115 |         loss_Dgen = 0
116 |         if do_Dmain:
117 |             with torch.autograd.profiler.record_function('Dgen_forward'):
118 |                 gen_img, _gen_ws = self.run_G(gen_z, gen_c, sync=False)
119 |                 gen_logits = self.run_D(gen_img, gen_c, sync=False) # Gets synced by loss_Dreal.
120 |                 training_stats.report('Loss/scores/fake', gen_logits)
121 |                 training_stats.report('Loss/signs/fake', gen_logits.sign())
122 |                 loss_Dgen = torch.nn.functional.softplus(gen_logits) # -log(1 - sigmoid(gen_logits))
123 |             with torch.autograd.profiler.record_function('Dgen_backward'):
124 |                 loss_Dgen.mean().mul(gain).backward()
125 | 
126 |         # Dmain: Maximize logits for real images.
127 |         # Dr1: Apply R1 regularization.
128 |         if do_Dmain or do_Dr1:
129 |             name = 'Dreal_Dr1' if do_Dmain and do_Dr1 else 'Dreal' if do_Dmain else 'Dr1'
130 |             with torch.autograd.profiler.record_function(name + '_forward'):
131 |                 # !!! apa
132 |                 # Apply Adaptive Pseudo Augmentation (APA) when --aug!='noaug'
133 |                 if self.apa is True and self.augment_pipe is not None:
134 |                     real_img_tmp = self.adaptive_pseudo_augmentation(real_img)
135 |                 else:
136 |                     real_img_tmp = real_img
137 |                 real_img_tmp = real_img_tmp.detach().requires_grad_(do_Dr1)
138 |                 # real_img_tmp = real_img.detach().requires_grad_(do_Dr1)
139 |                 real_logits = self.run_D(real_img_tmp, real_c, sync=sync)
140 |                 training_stats.report('Loss/scores/real', real_logits)
141 |                 training_stats.report('Loss/signs/real', real_logits.sign())
142 | 
143 |                 loss_Dreal = 0
144 |                 if do_Dmain:
145 |                     loss_Dreal = torch.nn.functional.softplus(-real_logits) # -log(sigmoid(real_logits))
146 |                     training_stats.report('Loss/D/loss', loss_Dgen + loss_Dreal)
147 | 
148 |                 loss_Dr1 = 0
149 |                 if do_Dr1:
150 |                     with torch.autograd.profiler.record_function('r1_grads'), conv2d_gradfix.no_weight_gradients():
151 |                         r1_grads = torch.autograd.grad(outputs=[real_logits.sum()], inputs=[real_img_tmp], create_graph=True, only_inputs=True)[0]
152 |                     r1_penalty = r1_grads.square().sum([1,2,3])
153 |                     loss_Dr1 = r1_penalty * (self.r1_gamma / 2)
154 |                     training_stats.report('Loss/r1_penalty', r1_penalty)
155 |                     training_stats.report('Loss/D/reg', loss_Dr1)
156 | 
157 |             with torch.autograd.profiler.record_function(name + '_backward'):
158 |                 (real_logits * 0 + loss_Dreal + loss_Dr1).mean().mul(gain).backward()
159 | 
160 | #----------------------------------------------------------------------------
161 | 
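The APA branch above counters discriminator overfitting by occasionally presenting recent generator outputs as "real"; `pseudo_flag` is 0 or 1 per batch element, so each image is either kept or swapped whole, never blended. The same selection logic in isolation (a sketch with illustrative shapes and probability only, not repository code):

    import torch

    p = 0.3                                      # stands in for augment_pipe.p
    real = torch.randn(4, 3, 8, 8)               # stand-in real batch
    fake = torch.randn(4, 3, 8, 8)               # stand-in detached generator outputs
    flag = (torch.rand(4, 1, 1, 1) < p).float()  # 1 = substitute this sample
    mixed = fake * flag + real * (1 - flag)      # per-sample hard swap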
--------------------------------------------------------------------------------
/src/util/multicrop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | from multiprocessing import Pool
4 | from shutil import get_terminal_size
5 | import time
6 | import argparse
7 | 
8 | import numpy as np
9 | import cv2
10 | 
11 | from utilgan import img_list, basename
12 | try: # progress bar for notebooks
13 |     get_ipython().__class__.__name__
14 |     from progress_bar import ProgressIPy as ProgressBar
15 | except: # normal console
16 |     from progress_bar import ProgressBar
17 | 
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('-i', '--in_dir', help='Input directory')
20 | parser.add_argument('-o', '--out_dir', help='Output directory')
21 | parser.add_argument('-s', '--size', type=int, default=512, help='Output crop size')
22 | parser.add_argument('--step', type=int, default=None, help='Step')
23 | parser.add_argument('--workers', type=int, default=8, help='number of workers (8, as of cpu#)')
24 | parser.add_argument('--png_compression', type=int, default=1, help='png compression (0 to 9; 0 = uncompressed, fast)')
25 | parser.add_argument('--jpg_quality', type=int, default=95, help='jpeg quality (0 to 100; 95 = max reasonable)')
26 | a = parser.parse_args()
27 | 
28 | # https://pillow.readthedocs.io/en/3.0.x/handbook/image-file-formats.html#jpeg
29 | # image quality = from 1 (worst) to 95 (best); default 75. Values above 95 should be avoided;
30 | # 100 disables portions of the JPEG compression algorithm => results in large files with hardly any gain in image quality.
31 | 
32 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer
33 | # compression time. If reading raw images during training, use 0 for faster IO speed.
34 | 
35 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
36 | 
37 | def worker(path, save_folder, crop_size, step, min_step):
38 |     img_name = basename(path)
39 |     img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
40 | 
41 |     # convert monochrome to RGB if needed
42 |     if len(img.shape) == 2:
43 |         img = img[:,:,np.newaxis]
44 |     if img.shape[2] == 1:
45 |         img = img[:, :, (0,0,0)]
46 |     h, w, c = img.shape
47 | 
48 |     ext = 'png' if img.shape[2]==4 else 'jpg'
49 | 
50 |     min_size = min(h,w)
51 |     if min_size < crop_size:
52 |         h = int(h * crop_size/min_size)
53 |         w = int(w * crop_size/min_size)
54 |         img = cv2.resize(img, (w,h), interpolation = cv2.INTER_AREA)
55 | 
56 |     h_space = np.arange(0, h - crop_size + 1, step)
57 |     if h - (h_space[-1] + crop_size) > min_step:
58 |         h_space = np.append(h_space, h - crop_size)
59 |     w_space = np.arange(0, w - crop_size + 1, step)
60 |     if w - (w_space[-1] + crop_size) > min_step:
61 |         w_space = np.append(w_space, w - crop_size)
62 | 
63 |     index = 0
64 |     for x in h_space:
65 |         for y in w_space:
66 |             index += 1
67 |             crop_img = img[x:x + crop_size, y:y + crop_size, :]
68 |             crop_img = np.ascontiguousarray(crop_img)
69 |             if ext=='png':
70 |                 cv2.imwrite(os.path.join(save_folder, '%s-s%03d.%s' % (img_name, index, ext)), crop_img, [cv2.IMWRITE_PNG_COMPRESSION, a.png_compression])
71 |             else:
72 |                 cv2.imwrite(os.path.join(save_folder, '%s-s%03d.%s' % (img_name, index, ext)), crop_img, [cv2.IMWRITE_JPEG_QUALITY, a.jpg_quality])
73 |     return 'Processing {:s} ...'.format(img_name)
74 | 
75 | def main():
76 |     """A multi-thread tool to crop sub images."""
77 |     input_folder = a.in_dir
78 |     save_folder = a.out_dir
79 |     n_thread = a.workers
80 |     crop_size = a.size
81 |     step = a.size // 2 if a.step is None else a.step
82 |     min_step = a.size // 8
83 | 
84 |     os.makedirs(save_folder, exist_ok=True)
85 | 
86 |     images = img_list(input_folder, subdir=True)
87 | 
88 |     def update(arg):
89 |         pbar.upd(arg)
90 | 
91 |     pbar = ProgressBar(len(images))
92 | 
93 |     pool = Pool(n_thread)
94 |     for path in images:
95 |         pool.apply_async(worker,
96 |             args=(path, save_folder, crop_size, step, min_step),
97 |             callback=update)
98 |     pool.close()
99 |     pool.join()
100 |     print('All subprocesses done.')
101 | 
102 | 
103 | if __name__ == '__main__':
104 |     # workaround for multithreading in jupyter console
105 |     __spec__ = "ModuleSpec(name='builtins', loader=)"
106 |     main()
107 | 
108 | 
--------------------------------------------------------------------------------
/src/util/progress_bar.py:
--------------------------------------------------------------------------------
1 | """
2 | from progress_bar import ProgressBar
3 | 
4 | pbar = ProgressBar(steps)
5 | pbar.upd()
6 | """
7 | 
8 | import os
9 | import sys
10 | import math
11 | os.system('') # enable VT100 Escape Sequence for WINDOWS 10 Ver. 1607
12 | 
13 | from shutil import get_terminal_size
14 | import time
15 | 
16 | import ipywidgets as ipy
17 | import IPython
18 | class ProgressIPy(object):
19 |     def __init__(self, task_num=10):
20 |         self.pbar = ipy.IntProgress(min=0, max=task_num, bar_style='') # (value=0, min=0, max=max, step=1, description=description, bar_style='')
21 |         self.labl = ipy.Label()
22 |         IPython.display.display(ipy.HBox([self.pbar, self.labl]))
23 |         self.task_num = task_num
24 |         self.completed = 0
25 |         self.start()
26 | 
27 |     def start(self, task_num=None):
28 |         if task_num is not None:
29 |             self.task_num = task_num
30 |         if self.task_num > 0:
31 |             self.labl.value = '0/{}'.format(self.task_num)
32 |         else:
33 |             self.labl.value = 'completed: 0, elapsed: 0s'
34 |         self.start_time = time.time()
35 | 
36 |     def upd(self, *p, **kw):
37 |         self.completed += 1
38 |         elapsed = time.time() - self.start_time + 0.0000000000001
39 |         fps = self.completed / elapsed if elapsed>0 else 0
40 |         if self.task_num > 0:
41 |             finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed)))
42 |             fin = ' end %s' % finaltime[11:16]
43 |             percentage = self.completed / float(self.task_num)
44 |             eta = int(elapsed * (1 - percentage) / percentage + 0.5)
45 |             self.labl.value = '{}/{}, rate {:.3g}s, time {}s, left {}s, {}'.format(self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin)
46 |         else:
47 |             self.labl.value = 'completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps)
48 |         self.pbar.value += 1
49 |         if self.completed == self.task_num: self.pbar.bar_style = 'success'
50 |         return self.completed
51 | 
52 | 
53 | class ProgressBar(object):
54 |     '''A progress bar which can print the progress
55 |     modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
56 |     '''
57 |     def __init__(self, task_num=0, bar_width=50, start=True):
58 |         self.task_num = task_num
59 |         max_bar_width = self._get_max_bar_width()
60 |         self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
61 |         self.completed = 0
62 |         if start:
63 |             self.start()
64 | 
65 |     def _get_max_bar_width(self):
66 |         terminal_width, _ = get_terminal_size()
67 |         max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
68 |         if max_bar_width < 10:
69 |             print('terminal is small ({}), make it bigger for proper visualization'.format(terminal_width))
70 |             max_bar_width = 10
71 |         return max_bar_width
72 | 
73 |     def start(self, task_num=None):
74 |         if task_num is not None:
75 |             self.task_num = task_num
76 |         if self.task_num > 0:
77 |             sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(' ' * self.bar_width, self.task_num, 'Start...'))
78 |         else:
79 |             sys.stdout.write('completed: 0, elapsed: 0s')
80 |         sys.stdout.flush()
81 |         self.start_time = time.time()
82 | 
83 |     def upd(self, msg=None):
84 |         self.completed += 1
85 |         elapsed = time.time() - self.start_time + 0.0000000000001
86 |         fps = self.completed / elapsed if elapsed>0 else 0
87 |         if self.task_num > 0:
88 |             percentage = self.completed / float(self.task_num)
89 |             eta = int(elapsed * (1 - percentage) / percentage + 0.5)
90 |             finaltime = time.asctime(time.localtime(self.start_time + self.task_num * elapsed / float(self.completed)))
91 |             fin_msg = ' %ss left, end %s' % (shortime(eta), finaltime[11:16])
92 |             if msg is not None: fin_msg += '  ' + str(msg)
93 |             mark_width = int(self.bar_width * percentage)
94 |             bar_chars = 'X' * mark_width + '-' * (self.bar_width - mark_width) # ▒ ▓ █
95 |             sys.stdout.write('\033[2A') # cursor up 2 lines
96 |             sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
97 |             try:
98 |                 sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format(
99 |                     bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), fin_msg))
100 |             except:
101 |                 sys.stdout.write('[{}] {}/{}, rate {:.3g}s, time {}s, left {}s \n{}\n'.format(
102 |                     bar_chars, self.completed, self.task_num, 1./fps, shortime(elapsed), shortime(eta), '<< unprintable >>'))
103 |         else:
104 |             sys.stdout.write('completed {}, time {}s, {:.1f} steps/s'.format(self.completed, int(elapsed + 0.5), fps))
105 |         sys.stdout.flush()
106 | 
107 |     def reset(self, count=None, newline=False):
108 |         self.start_time = time.time()
109 |         if count is not None:
110 |             self.task_num = count
111 |         if newline is True:
112 |             sys.stdout.write('\n\n')
113 | 
114 | def time_days(sec):
115 |     return '%dd %d:%02d:%02d' % (sec/86400, (sec/3600)%24, (sec/60)%60, sec%60)
116 | def time_hrs(sec):
117 |     return '%d:%02d:%02d' % (sec/3600, (sec/60)%60, sec%60)
118 | def shortime(sec):
119 |     if sec < 60:
120 |         time_short = '%d' % (sec)
121 |     elif sec < 3600:
122 |         time_short = '%d:%02d' % ((sec/60)%60, sec%60)
123 |     elif sec < 86400:
124 |         time_short = time_hrs(sec)
125 |     else:
126 |         time_short = time_days(sec)
127 |     return time_short
128 | 
129 | 
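`ProgressBar` redraws in place with the VT100 escapes above ('\033[2A' moves the cursor up two lines, '\033[J' clears below), so it assumes it owns the last two terminal lines between updates. A typical loop (a usage sketch, not repository code):

    import time
    from progress_bar import ProgressBar

    pbar = ProgressBar(100)        # prints the empty bar immediately
    for i in range(100):
        time.sleep(0.01)           # stand-in for real work
        pbar.upd('step %d' % i)    # optional msg is appended to the status line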
--------------------------------------------------------------------------------
/src/util/utilgan.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import time
4 | import math
5 | import numpy as np
6 | from scipy.ndimage import gaussian_filter
7 | from scipy.interpolate import CubicSpline as CubSpline
8 | from scipy.special import comb
9 | import scipy
10 | from imageio import imread
11 | 
12 | import torch
13 | import torch.nn.functional as F
14 | 
15 | # from perlin import PerlinNoiseFactory as Perlin
16 | # noise = Perlin(1)
17 | 
18 | # def latent_noise(t, dim, noise_step=78564.543):
19 | #     latent = np.zeros((1, dim))
20 | #     for i in range(dim):
21 | #         latent[0][i] = noise(t + i * noise_step)
22 | #     return latent
23 | 
24 | def load_latents(npy_file):
25 |     key_latents = np.load(npy_file)
26 |     try:
27 |         key_latents = key_latents[key_latents.files[0]]
28 |     except:
29 |         pass
30 |     idx_file = os.path.splitext(npy_file)[0] + '.txt'
31 |     if os.path.exists(idx_file):
32 |         with open(idx_file) as f:
33 |             lat_idx = f.readline()
34 |             lat_idx = [int(l.strip()) for l in lat_idx.split(',') if '\n' not in l and len(l.strip())>0]
35 |         key_latents = [key_latents[i] for i in lat_idx]
36 |     return np.asarray(key_latents)
37 | 
38 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
39 | 
40 | def get_z(shape, rnd, uniform=False):
41 |     if uniform:
42 |         return rnd.uniform(0., 1., shape)
43 |     else:
44 |         return rnd.randn(*shape) # *x unpacks tuple/list to sequence
45 | 
46 | def smoothstep(x, NN=1., xmin=0., xmax=1.):
47 |     N = math.ceil(NN)
48 |     x = np.clip((x - xmin) / (xmax - xmin), 0, 1)
49 |     result = 0
50 |     for n in range(0, N+1):
51 |         result += scipy.special.comb(N+n, n) * scipy.special.comb(2*N+1, N-n) * (-x)**n
52 |     result *= x**(N+1)
53 |     if NN != N: result = (x + result) / 2
54 |     return result
55 | 
56 | def lerp(z1, z2, num_steps, smooth=0.):
57 |     vectors = []
58 |     xs = [step / (num_steps - 1) for step in range(num_steps)]
59 |     if smooth > 0: xs = [smoothstep(x, smooth) for x in xs]
60 |     for x in xs:
61 |         interpol = z1 + (z2 - z1) * x
62 |         vectors.append(interpol)
63 |     return np.array(vectors)
64 | 
65 | # interpolate on hypersphere
66 | def slerp(z1, z2, num_steps, smooth=0.):
67 |     z1_norm = np.linalg.norm(z1)
68 |     z2_norm = np.linalg.norm(z2)
69 |     z2_normal = z2 * (z1_norm / z2_norm)
70 |     vectors = []
71 |     xs = [step / (num_steps - 1) for step in range(num_steps)]
72 |     if smooth > 0: xs = [smoothstep(x, smooth) for x in xs]
73 |     for x in xs:
74 |         interplain = z1 + (z2 - z1) * x
75 |         interp = z1 + (z2_normal - z1) * x
76 |         interp_norm = np.linalg.norm(interp)
77 |         interpol_normal = interplain * (z1_norm / interp_norm)
78 |         # interpol_normal = interp * (z1_norm / interp_norm)
79 |         vectors.append(interpol_normal)
80 |     return np.array(vectors)
81 | 
82 | def cublerp(points, steps, fstep):
83 |     keys = np.array([i*fstep for i in range(steps)] + [steps*fstep])
84 |     points = np.concatenate((points, np.expand_dims(points[0], 0)))
85 |     cspline = CubSpline(keys, points)
86 |     return cspline(range(steps*fstep+1))
87 | 
88 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
89 | 
90 | def latent_anima(shape, frames, transit, key_latents=None, smooth=0.5, cubic=False, gauss=False, seed=None, verbose=True):
91 |     if key_latents is None:
92 |         transit = int(max(1, min(frames, transit)))
93 |     steps = max(1, int(frames // transit))
94 |     log = ' timeline: %d steps by %d' % (steps, transit)
95 | 
96 |     if seed is None:
97 |         seed = np.random.seed(int((time.time()%1) * 9999))
98 |     rnd = np.random.RandomState(seed)
99 | 
100 |     # make key points
101 |     if key_latents is None:
102 |         key_latents = np.array([get_z(shape, rnd) for i in range(steps)])
103 | 
104 |     latents = np.expand_dims(key_latents[0], 0)
105 | 
106 |     # populate lerp between key points
107 |     if transit == 1:
108 |         latents = key_latents
109 |     else:
110 |         if cubic:
111 |             latents = cublerp(key_latents, steps, transit)
112 |             log += ', cubic'
113 |         else:
114 |             for i in range(steps):
115 |                 zA = key_latents[i]
116 |                 zB = key_latents[(i+1) % steps]
117 |                 interps_z = slerp(zA, zB, transit, smooth=smooth)
118 |                 latents = np.concatenate((latents, interps_z))
119 |     latents = np.array(latents)
120 | 
121 |     if gauss:
122 |         lats_post = gaussian_filter(latents, [transit, 0, 0], mode="wrap")
123 |         lats_post = (lats_post / np.linalg.norm(lats_post, axis=-1, keepdims=True)) * math.sqrt(np.prod(shape))
124 |         log += ', gauss'
125 |         latents = lats_post
126 | 
127 |     if verbose: print(log)
128 |     if latents.shape[0] > frames: # extra frame
129 |         latents = latents[1:]
130 |     return latents
131 | 
132 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
133 | 
134 | def multimask(x, size, latmask=None, countHW=[1,1], delta=0.):
135 |     Hx, Wx = countHW
136 |     bcount = x.shape[0]
137 | 
138 |     if max(countHW) > 1:
139 |         W = x.shape[3] # width
140 |         H = x.shape[2] # height
141 |         if Wx > 1:
142 |             stripe_mask = []
143 |             for i in range(Wx):
144 |                 ch_mask = peak_roll(W, Wx, i, delta).unsqueeze(0).unsqueeze(0) # [1,1,w] th
145 |                 ch_mask = ch_mask.repeat(1,H,1) # [1,h,w]
146 |                 stripe_mask.append(ch_mask)
147 |             maskW = torch.cat(stripe_mask, 0).unsqueeze(1) # [x,1,h,w]
148 |         else: maskW = [1]
149 |         if Hx > 1:
150 |             stripe_mask = []
151 |             for i in range(Hx):
152 |                 ch_mask = peak_roll(H, Hx, i, delta).unsqueeze(1).unsqueeze(0) # [1,h,1] th
153 |                 ch_mask = ch_mask.repeat(1,1,W) # [1,h,w]
154 |                 stripe_mask.append(ch_mask)
155 |             maskH = torch.cat(stripe_mask, 0).unsqueeze(1) # [y,1,h,w]
156 |         else: maskH = [1]
157 | 
158 |         mask = []
159 |         for i in range(Wx):
160 |             for j in range(Hx):
161 |                 mask.append(maskW[i] * maskH[j])
162 |         mask = torch.cat(mask, 0).unsqueeze(1) # [xy,1,h,w]
163 |         mask = mask.to(x.device)
164 |         x = torch.sum(x[:Hx*Wx] * mask, 0, keepdim=True)
165 | 
166 |     elif latmask is not None:
167 |         if len(latmask.shape) < 4:
168 |             latmask = latmask.unsqueeze(1) # [b,1,h,w]
169 |         lms = latmask.shape
170 |         if list(lms[2:]) != list(size) and np.prod(lms) > 1:
171 |             latmask = F.interpolate(latmask, size) # , mode='nearest'
172 |         latmask = latmask.type(x.dtype)
173 |         x = torch.sum(x[:lms[0]] * latmask, 0, keepdim=True)
174 |     else:
175 |         return x
176 | 
177 |     x = x.repeat(bcount,1,1,1)
178 |     return x # [b,f,h,w]
179 | 
180 | def peak_roll(width, count, num, delta):
181 |     step = width // count
182 |     if width > step*2:
183 |         fill_range = torch.zeros([width-step*2])
184 |         full_ax = torch.cat((peak(step, delta), fill_range), 0)
185 |     else:
186 |         full_ax = peak(step, delta)[:width]
187 |     if num == 0:
188 |         shift = max(width - (step//2), 0.) # must be positive!
189 |     else:
190 |         shift = step*num - (step//2)
191 |     full_ax = torch.roll(full_ax, shift, 0)
192 |     return full_ax # [width,]
193 | 
194 | def peak(steps, delta):
195 |     x = torch.linspace(0.-delta, 1.+delta, steps)
196 |     x_rev = torch.flip(x,[0])
197 |     x = torch.cat((x, x_rev), 0)
198 |     x = torch.clip(x, 0., 1.)
199 |     return x # [steps*2,]
200 | 
201 | # = = = = = = = = = = = = = = = = = = = = = = = = = = =
202 | 
203 | def ups2d(x, factor=2):
204 |     assert isinstance(factor, int) and factor >= 1
205 |     if factor == 1: return x
206 |     s = x.shape
207 |     x = x.reshape(-1, s[1], s[2], 1, s[3], 1)
208 |     x = x.repeat(1, 1, 1, factor, 1, factor)
209 |     x = x.reshape(-1, s[1], s[2] * factor, s[3] * factor)
210 |     return x
211 | 
212 | # Tiles an array around two points, allowing for pad lengths greater than the input length
213 | # NB: if symm=True, every second tile is mirrored = messed up in GAN
214 | # adapted from https://discuss.pytorch.org/t/symmetric-padding/19866/3
215 | def tile_pad(xt, padding, symm=True):
216 |     h, w = xt.shape[-2:]
217 |     left, right, top, bottom = padding
218 | 
219 |     def tile(x, minx, maxx, symm=True):
220 |         rng = maxx - minx
221 |         if symm is True: # triangular reflection
222 |             double_rng = 2*rng
223 |             mod = np.fmod(x - minx, double_rng)
224 |             normed_mod = np.where(mod < 0, mod+double_rng, mod)
225 |             out = np.where(normed_mod >= rng, double_rng - normed_mod, normed_mod) + minx
226 |         else: # repeating tiles
227 |             mod = np.remainder(x - minx, rng)
228 |             out = mod + minx
229 |         return np.array(out, dtype=x.dtype)
230 | 
231 |     x_idx = np.arange(-left, w+right)
232 |     y_idx = np.arange(-top, h+bottom)
233 |     x_pad = tile(x_idx, -0.5, w-0.5, symm)
234 |     y_pad = tile(y_idx, -0.5, h-0.5, symm)
235 |     xx, yy = np.meshgrid(x_pad, y_pad)
236 |     return xt[..., yy, xx]
237 | 
238 | def pad_up_to(x, size, type='centr'):
239 |     sh = x.shape[2:][::-1]
240 |     if list(x.shape[2:]) == list(size): return x
241 |     padding = []
242 |     for i, s in enumerate(size[::-1]):
243 |         if 'side' in type.lower():
244 |             padding = padding + [0, s-sh[i]]
245 |         else: # centr
246 |             p0 = (s-sh[i]) // 2
247 |             p1 = s-sh[i] - p0
248 |             padding = padding + [p0,p1]
249 |     y = tile_pad(x, padding, symm = 'symm' in type.lower())
250 |     # if 'symm' in type.lower():
251 |     #     y = tile_pad(x, padding, symm=True)
252 |     # else:
253 |     #     y = F.pad(x, padding, 'circular')
254 |     return y
255 | 
256 | # scale_type may include pad, side, symm
257 | def fix_size(x, size, scale_type='centr'):
258 |     if not len(x.shape) == 4:
259 |         raise Exception(" Wrong data rank, shape:", x.shape)
260 |     if x.shape[2:] == size:
261 |         return x
262 |     if (x.shape[2]*2, x.shape[3]*2) == size:
263 |         return ups2d(x)
264 | 
265 |     if scale_type.lower() == 'fit':
266 |         return F.interpolate(x, size, mode='nearest') # , align_corners=True
267 |     elif 'pad' in scale_type.lower():
268 |         pass
269 |     else: # proportional scale to smaller side, then pad to bigger side
270 |         sh0 = x.shape[2:]
271 |         upsc = np.min(size) / np.min(sh0)
272 |         new_size = [int(sh0[i]*upsc) for i in [0,1]]
273 |         x = F.interpolate(x, new_size, mode='nearest') # , align_corners=True
274 | 
275 |     x = pad_up_to(x, size, scale_type)
276 |     return x
277 | 
278 | # Make list of odd sizes for upsampling to arbitrary resolution
279 | def hw_scales(size, base, n, keep_first_layers=None, verbose=False):
280 |     if isinstance(base, int): base = (base, base)
281 |     start_res = [int(b * 2 ** (-n)) for b in base]
282 | 
283 |     start_res[0] = int(start_res[0] * size[0] // base[0])
284 |     start_res[1] = int(start_res[1] * size[1] // base[1])
285 | 
286 |     hw_list = []
287 | 
288 |     if base[0] != base[1] and verbose is True:
289 |         print(' size', size, 'base', base, 'start_res', start_res, 'n', n)
290 |     if keep_first_layers is not None and keep_first_layers > 0:
291 |         for i in range(keep_first_layers):
292 |             hw_list.append(start_res)
293 |             start_res = [x*2 for x in start_res]
294 |             n -= 1
295 | 
296 |     ch = (size[0] / start_res[0]) ** (1/n)
297 |     cw = (size[1] / start_res[1]) ** (1/n)
298 |     for i in range(n):
299 |         h = math.floor(start_res[0] * ch**i)
300 |         w = math.floor(start_res[1] * cw**i)
301 |         hw_list.append((h,w))
302 | 
303 |     hw_list.append(size)
304 |     return hw_list
305 | 
306 | def calc_res(shape):
307 |     base0 = 2**int(np.log2(shape[0]))
308 |     base1 = 2**int(np.log2(shape[1]))
309 |     base = min(base0, base1)
310 |     min_res = min(shape[0], shape[1])
311 | 
312 |     def int_log2(xs, base):
313 |         return [x * 2**(2-int(np.log2(base))) % 1 == 0 for x in xs]
314 |     if min_res != base or max(*shape) / min(*shape) >= 2:
315 |         if np.log2(base) < 10 and all(int_log2(shape, base*2)):
316 |             base = base * 2
317 | 
318 |     return base # , [shape[0]/base, shape[1]/base]
319 | 
320 | def calc_init_res(shape, resolution=None):
321 |     if len(shape) == 1:
322 |         shape = [shape[0], shape[0], 1]
323 |     elif len(shape) == 2:
324 |         shape = [*shape, 1]
325 |     size = shape[:2] if shape[2] < min(*shape[:2]) else shape[1:] # fewer colors than pixels
326 |     if resolution is None:
327 |         resolution = calc_res(size)
328 |     res_log2 = int(np.log2(resolution))
329 |     init_res = [int(s * 2**(2-res_log2)) for s in size]
330 |     return init_res, resolution, res_log2
331 | 
332 | def basename(file):
333 |     return os.path.splitext(os.path.basename(file))[0]
334 | 
335 | def file_list(path, ext=None, subdir=None):
336 |     if subdir is True:
337 |         files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
338 |     else:
339 |         files = [os.path.join(path, f) for f in os.listdir(path)]
340 |     if ext is not None:
341 |         if isinstance(ext, list):
342 |             files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ext]
343 |         elif isinstance(ext, str):
344 |             files = [f for f in files if f.endswith(ext)]
345 |         else:
346 |             print(' Unknown extension/type for file list!')
347 |     return sorted([f for f in files if os.path.isfile(f)])
348 | 
349 | def dir_list(in_dir):
350 |     dirs = [os.path.join(in_dir, x) for x in os.listdir(in_dir)]
351 |     return sorted([f for f in dirs if os.path.isdir(f)])
352 | 
353 | def img_list(path, subdir=None):
354 |     if subdir is True:
355 |         files = [os.path.join(dp, f) for dp, dn, fn in os.walk(path) for f in fn]
356 |     else:
357 |         files = [os.path.join(path, f) for f in os.listdir(path)]
358 |     files = [f for f in files if os.path.splitext(f.lower())[1][1:] in ['jpg', 'jpeg', 'png', 'ppm', 'tif']]
359 |     files = [f for f in files if not '/__MACOSX/' in f.replace('\\', '/')] # workaround fix for macos phantom files
360 |     return sorted([f for f in files if os.path.isfile(f)])
361 | 
362 | def img_read(path):
363 |     img = imread(path)
364 |     # grayscale to RGB
365 |     if (img.ndim == 2) or (img.shape[2] == 1):
366 |         img = np.dstack((img,img,img))
367 |     # rgba to rgb
368 |     if img.shape[2] == 4:
369 |         img = img[:,:,:3]
370 |     return img
371 | 
372 | 
--------------------------------------------------------------------------------
/train.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | chcp 65001 > NUL
3 | 
4 | python src/train.py --data data/%1 ^
5 | %2 %3 %4 %5 %6 %7 %8 %9
6 | 
7 | 
--------------------------------------------------------------------------------
/train/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/eps696/stylegan2ada/cfbe2e0a64c3e126526027fce9f0649ba5352a0c/train/.gitkeep
--------------------------------------------------------------------------------
/train_resume.bat:
--------------------------------------------------------------------------------
1 | @echo off
2 | chcp 65001 > NUL
3 | 
4 | python src/train.py --data data/%1 --resume train/%2 ^
5 | %3 %4 %5 %6 %7 %8 %9
6 | 
7 | 
--------------------------------------------------------------------------------
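Both batch files forward their trailing arguments straight to src/train.py, so any trainer option can be appended; train_resume.bat additionally resolves its second argument inside the train/ directory. Hypothetical invocations (dataset and snapshot names are placeholders, and --kimg is assumed to be kept from the upstream stylegan2-ada trainer):

    train.bat mydata --kimg 1000
    train_resume.bat mydata 000-mydata/snapshot.pkl --kimg 1000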