├── directions
│   ├── age.pt
│   ├── pose.pt
│   └── smile.pt
├── images
│   └── headline-large.jpeg
├── clip2latent
│   ├── stylegan3
│   │   ├── docs
│   │   │   ├── avg_spectra_screen0.png
│   │   │   ├── visualizer_screen0.png
│   │   │   ├── avg_spectra_screen0_half.png
│   │   │   ├── visualizer_screen0_half.png
│   │   │   ├── stylegan3-teaser-1920x1006.png
│   │   │   ├── troubleshooting.md
│   │   │   ├── dataset-tool-help.txt
│   │   │   └── train-help.txt
│   │   ├── viz
│   │   │   ├── __init__.py
│   │   │   ├── stylemix_widget.py
│   │   │   ├── performance_widget.py
│   │   │   ├── capture_widget.py
│   │   │   ├── trunc_noise_widget.py
│   │   │   ├── latent_widget.py
│   │   │   ├── equivariance_widget.py
│   │   │   └── pickle_widget.py
│   │   ├── gui_utils
│   │   │   ├── __init__.py
│   │   │   ├── imgui_window.py
│   │   │   ├── text_utils.py
│   │   │   └── imgui_utils.py
│   │   ├── metrics
│   │   │   ├── __init__.py
│   │   │   ├── inception_score.py
│   │   │   ├── frechet_inception_distance.py
│   │   │   ├── kernel_inception_distance.py
│   │   │   ├── precision_recall.py
│   │   │   ├── perceptual_path_length.py
│   │   │   └── metric_main.py
│   │   ├── training
│   │   │   └── __init__.py
│   │   ├── torch_utils
│   │   │   ├── __init__.py
│   │   │   ├── ops
│   │   │   │   ├── __init__.py
│   │   │   │   ├── bias_act.h
│   │   │   │   ├── filtered_lrelu_rd.cu
│   │   │   │   ├── filtered_lrelu_wr.cu
│   │   │   │   ├── filtered_lrelu_ns.cu
│   │   │   │   ├── upfirdn2d.h
│   │   │   │   ├── fma.py
│   │   │   │   ├── grid_sample_gradfix.py
│   │   │   │   ├── filtered_lrelu.h
│   │   │   │   ├── bias_act.cpp
│   │   │   │   ├── upfirdn2d.cpp
│   │   │   │   ├── bias_act.cu
│   │   │   │   └── conv2d_resample.py
│   │   │   └── custom_ops.py
│   │   ├── environment.yml
│   │   ├── dnnlib
│   │   │   └── __init__.py
│   │   ├── Dockerfile
│   │   ├── .github
│   │   │   └── ISSUE_TEMPLATE
│   │   │       └── bug_report.md
│   │   ├── LICENSE.txt
│   │   ├── gen_images.py
│   │   └── gen_video.py
│   ├── data.py
│   ├── train_utils.py
│   └── models.py
├── requirements-colab.txt
├── setup.py
├── requirements.txt
├── config
│   ├── model
│   │   ├── mid.yaml
│   │   └── small.yaml
│   ├── config.yaml
│   └── data
│       ├── sg3-lhq.yaml
│       ├── sg2-ffhq.yaml
│       ├── sg2-ffhq-w3.yaml
│       └── sg2-encoded-ffhq.yaml
├── text
│   ├── landscape-val.txt
│   ├── face-val.txt
│   └── face-test.txt
├── scripts
│   ├── timings.py
│   ├── check_dalle.py
│   ├── eval_fid.py
│   ├── eval.py
│   ├── train.py
│   └── generate_dataset.py
├── tests
│   ├── train_test.py
│   ├── data_test.py
│   └── model_test.py
├── LICENSE.md
├── .gitignore
├── colab-gradio.ipynb
└── README.md
/directions/age.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpinkney/clip2latent/main/directions/age.pt
--------------------------------------------------------------------------------
/directions/pose.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpinkney/clip2latent/main/directions/pose.pt
--------------------------------------------------------------------------------
/directions/smile.pt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpinkney/clip2latent/main/directions/smile.pt
--------------------------------------------------------------------------------
/images/headline-large.jpeg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpinkney/clip2latent/main/images/headline-large.jpeg
--------------------------------------------------------------------------------
/clip2latent/stylegan3/docs/avg_spectra_screen0.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/justinpinkney/clip2latent/main/clip2latent/stylegan3/docs/avg_spectra_screen0.png
-------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/visualizer_screen0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/clip2latent/main/clip2latent/stylegan3/docs/visualizer_screen0.png -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/avg_spectra_screen0_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/clip2latent/main/clip2latent/stylegan3/docs/avg_spectra_screen0_half.png -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/visualizer_screen0_half.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/clip2latent/main/clip2latent/stylegan3/docs/visualizer_screen0_half.png -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/stylegan3-teaser-1920x1006.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/justinpinkney/clip2latent/main/clip2latent/stylegan3/docs/stylegan3-teaser-1920x1006.png -------------------------------------------------------------------------------- /requirements-colab.txt: -------------------------------------------------------------------------------- 1 | wandb==0.12.16 2 | ninja==1.10.2.3 3 | dalle2-pytorch==0.2.38 4 | hydra-core==1.1.2 5 | typer==0.4.1 6 | joblib==1.1.0 7 | webdataset==0.2.5 8 | gradio==3.4 9 | protobuf==3.20.1 10 | -e . 11 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='clip2latent', 5 | packages=['clip2latent'], 6 | version='1.0', 7 | description='Official code for clip2latent', 8 | author='Justin Pinkney', 9 | package_data = {'clip2latent': ['stylegan3/**/*']} 10 | ) 11 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --extra-index-url https://download.pytorch.org/whl/cu113 2 | torch==1.11.0 3 | torchvision==0.12.0 4 | wandb==0.12.16 5 | ninja==1.10.2.3 6 | dalle2-pytorch==0.2.38 7 | hydra-core==1.1.2 8 | typer==0.4.1 9 | joblib==1.1.0 10 | webdataset==0.2.5 11 | gradio==3.4 12 | protobuf==3.20.1 13 | scipy==1.9.1 14 | pytest 15 | -e . 
-------------------------------------------------------------------------------- /config/model/mid.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | dim: 512 3 | num_timesteps: 1000 4 | depth: 12 5 | dim_head: 64 6 | heads: 12 7 | diffusion: 8 | image_embed_dim: ${model.network.dim} 9 | timesteps: ${model.network.num_timesteps} 10 | cond_drop_prob: 0.2 11 | image_embed_scale: 1.0 12 | text_embed_scale: 1.0 13 | beta_schedule: "cosine" 14 | predict_x_start: True -------------------------------------------------------------------------------- /config/model/small.yaml: -------------------------------------------------------------------------------- 1 | network: 2 | dim: 512 3 | num_timesteps: 1000 4 | depth: 6 5 | dim_head: 64 6 | heads: 8 7 | diffusion: 8 | image_embed_dim: ${model.network.dim} 9 | timesteps: ${model.network.num_timesteps} 10 | cond_drop_prob: 0.2 11 | image_embed_scale: 1.0 12 | text_embed_scale: 1.0 13 | beta_schedule: "cosine" 14 | predict_x_start: True -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - model: mid 3 | - data: sg2-ffhq 4 | 5 | logging: wandb 6 | wandb_project: clip2latent 7 | wandb_entity: null 8 | name: null 9 | device: "cuda:0" 10 | resume: null 11 | train: 12 | znorm_embed: false 13 | znorm_latent: true 14 | max_it: 1_000_000 15 | val_it: 10_000 16 | lr: 1.0e-4 17 | weight_decay: 1.0e-2 18 | ema_update_every: 10 19 | ema_beta: 0.9999 20 | ema_power: 0.75 21 | -------------------------------------------------------------------------------- /config/data/sg3-lhq.yaml: -------------------------------------------------------------------------------- 1 | bs: 512 2 | format: webdataset 3 | path: data/webdataset/sg3-lhq-256-clip/{00000..99}.tar 4 | embed_noise_scale: 1.0 5 | sg_pkl: 'https://huggingface.co/justinpinkney/stylegan3-t-lhq-256/resolve/main/lhq-256-stylegan3-t-25Mimg.pkl' 6 | clip_variant: "ViT-B/32" 7 | n_latents: 1 8 | latent_dim: 512 9 | latent_repeats: [16] 10 | val_im_samples: 64 11 | val_text_samples: text/landscape-val.txt 12 | val_samples_per_text: 4 13 | -------------------------------------------------------------------------------- /config/data/sg2-ffhq.yaml: -------------------------------------------------------------------------------- 1 | bs: 512 2 | format: webdataset 3 | path: data/webdataset/sg2-ffhq-1024-clip/{00000..99}.tar 4 | embed_noise_scale: 1.0 5 | sg_pkl: 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl' 6 | clip_variant: "ViT-B/32" 7 | n_latents: 1 8 | latent_dim: 512 9 | latent_repeats: [18] 10 | val_im_samples: 64 11 | val_text_samples: text/face-val.txt 12 | val_samples_per_text: 4 13 | -------------------------------------------------------------------------------- /config/data/sg2-ffhq-w3.yaml: -------------------------------------------------------------------------------- 1 | bs: 512 2 | format: webdataset 3 | path: data/webdataset/sg2-ffhq-1024-w3/{00000..99}.tar 4 | embed_noise_scale: 0.75 5 | sg_pkl: 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl' 6 | clip_variant: "ViT-B/32" 7 | n_latents: 3 8 | latent_dim: 512 9 | latent_repeats: [4,4,10] 10 | val_im_samples: 64 11 | val_text_samples: text/face-val.txt 12 | val_samples_per_text: 4 13 | 
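(Aside on how the model/data YAMLs above are consumed: they are Hydra config groups selected through the defaults list in config/config.yaml, so they can be composed programmatically the same way the tests do. A minimal sketch — the override values and the relative config_path are illustrative, not prescribed by the repo:)

```python
# Compose the full training config with a chosen model/data group (example overrides).
from hydra import initialize, compose

with initialize(config_path="config"):  # path is relative to the calling file
    cfg = compose(
        config_name="config",
        overrides=["model=small", "data=sg2-ffhq-w3"],  # group names match the YAML filenames
    )

print(cfg.data.n_latents)       # 3
print(cfg.data.latent_repeats)  # [4, 4, 10] -> 4 + 4 + 10 = 18 rows of the w+ latent
```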
-------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gui_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/training/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/environment.yml: -------------------------------------------------------------------------------- 1 | name: stylegan3 2 | channels: 3 | - pytorch 4 | - nvidia 5 | dependencies: 6 | - python >= 3.8 7 | - pip 8 | - numpy>=1.20 9 | - click>=8.0 10 | - pillow=8.3.1 11 | - scipy=1.7.1 12 | - pytorch=1.9.1 13 | - cudatoolkit=11.1 14 | - requests=2.26.0 15 | - tqdm=4.62.2 16 | - ninja=1.10.2 17 | - matplotlib=3.4.2 18 | - imageio=2.9.0 19 | - pip: 20 | - imgui==1.3.0 21 | - glfw==2.2.0 22 | - pyopengl==3.1.5 23 | - imageio-ffmpeg==0.4.3 24 | - pyspng 25 | -------------------------------------------------------------------------------- /config/data/sg2-encoded-ffhq.yaml: -------------------------------------------------------------------------------- 1 | bs: 512 2 | format: webdataset 3 | path: 4 | - data/webdataset/encoded/ffhq-encoded-clip.tar 5 | - data/webdataset/encoded/celeba-encoded-clip.tar 6 | embed_noise_scale: 0.75 7 | sg_pkl: 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl' 8 | clip_variant: "ViT-B/32" 9 | n_latents: 18 10 | latent_dim: 512 11 | latent_repeats: [1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1] 12 | val_im_samples: 64 13 | val_text_samples: text/face-val.txt 14 | val_samples_per_text: 4 15 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /text/landscape-val.txt: -------------------------------------------------------------------------------- 1 | A photograph of a beautiful beach in paradise 2 | Snowy mountains against a blue sky 3 | a landscape painting by J M W Turner 4 | A lighthouse in stormy seas 5 | A wonderful sunset 6 | A photograph of the grand canyon 7 | A moody haunted forest 8 | The land of Mordor 9 | A grassy field with flowers in the foreground and a hut in the distance 10 | Pleasant british countryside 11 | windows xp desktop background 12 | An alien landscape full of strange plants 13 | A snowscape with the northern lights above 14 | Mountains on the left and ocean on the right 15 | Cherry blossom trees in Japan with Mount Fuji in the background 16 | A boring landscape -------------------------------------------------------------------------------- /text/face-val.txt: -------------------------------------------------------------------------------- 1 | A photograph of a young man with a beard 2 | A photograph of a old woman's face with grey hair 3 | A photograph of a child at a birthday party 4 | A picture of a face outside in bright sun in front of green grass 5 | This man has bangs arched eyebrows curly hair and a small nose 6 | A photo of Barack Obama 7 | An arctic explorer 8 | A clown's face covered in make up 9 | A photo of an old person by the beach 10 | A portrait of Angela Merkel 11 | Young boy with sunglasses and an angry face 12 | Middle aged man with a moustache and a happy expression 13 | The face of a person who just finished running a marathon 14 | A vampire's face 15 | An office worker wearing a shirt looking stressed 16 | The Mona Lisa -------------------------------------------------------------------------------- /clip2latent/stylegan3/Dockerfile: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | FROM nvcr.io/nvidia/pytorch:21.08-py3 10 | 11 | ENV PYTHONDONTWRITEBYTECODE 1 12 | ENV PYTHONUNBUFFERED 1 13 | 14 | RUN pip install imageio imageio-ffmpeg==0.4.4 pyspng==0.1.0 15 | 16 | WORKDIR /workspace 17 | 18 | RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh 19 | ENTRYPOINT ["/entry.sh"] 20 | -------------------------------------------------------------------------------- /scripts/timings.py: -------------------------------------------------------------------------------- 1 | import math 2 | from clip2latent.models import Clip2StyleGAN 3 | from PIL import Image 4 | import torch 5 | from datetime import datetime 6 | import numpy as np 7 | 8 | device = "cuda:0" 9 | skips = 250 10 | model = Clip2StyleGAN("best.yaml", device=device, checkpoint="best.ckpt") 11 | inp = ["a photo",] 12 | warmup = 5 13 | measure = 10 14 | 15 | args = { 16 | "skips": skips, 17 | "n_samples_per_txt": 16, 18 | "clip_sort": True, 19 | "cond_scale": 2.0, 20 | "show_progress": False, 21 | } 22 | 23 | for i in range(warmup): 24 | out = model(inp, **args) 25 | torch.cuda.synchronize() 26 | 27 | times = [] 28 | for i in range(measure): 29 | start = datetime.now() 30 | out = model(inp, **args) 31 | torch.cuda.synchronize() 32 | taken = datetime.now() - start 33 | times.append(taken) 34 | print(taken) 35 | 36 | print("-------------") 37 | print(np.mean(times)) -------------------------------------------------------------------------------- /tests/train_test.py: -------------------------------------------------------------------------------- 1 | from hydra import initialize, compose 2 | from scripts import train, generate_dataset 3 | 4 | def test_end_to_end(tmp_path): 5 | """minimal training test""" 6 | 7 | n_samples = per_folder = 100 8 | gen_dir = tmp_path/"out" 9 | generate_dataset.main( 10 | out_dir=gen_dir, 11 | n_samples=n_samples, 12 | samples_per_folder=per_folder, 13 | ) 14 | 15 | tar_dir = tmp_path/"webdataset" 16 | generate_dataset.make_webdataset(gen_dir, tar_dir) 17 | with initialize(config_path="../config"): 18 | cfg = compose( 19 | config_name="config", 20 | overrides=[ 21 | "logging=null", 22 | f"data.path={tar_dir}/00000.tar", 23 | "model=small", 24 | "model.network.num_timesteps=10", 25 | "train.max_it=10", 26 | "train.val_it=9", 27 | ], 28 | ) 29 | train.main(cfg) 30 | 31 | -------------------------------------------------------------------------------- /scripts/check_dalle.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import torch 3 | from pathlib import Path 4 | from transformers import CLIPProcessor, CLIPModel 5 | from tqdm import tqdm 6 | 7 | with open("text/dalle.txt", 'rt') as f: 8 | data = f.readlines() 9 | data = [x.strip('\n') for x in data] 10 | 11 | dalle_data = Path("all-dalle-test") 12 | ims = list(dalle_data.glob("*.png")) 13 | 14 | model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") 15 | processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") 16 | similarities = [] 17 | 18 | with torch.no_grad(): 19 | for im in tqdm(ims): 20 | image = Image.open(im) 21 | matches = [x for x in data if x == str(im.name)[29:-4]] 22 | 23 | inputs = processor(text=matches, images=image, return_tensors="pt", padding=True) 24 | 25 | outputs = model(**inputs) 26 | sim = outputs.logits_per_image/100 27 | print(sim) 28 | similarities.append(sim) 29 | 30 | print(torch.cat(similarities).mean()) -------------------------------------------------------------------------------- 
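(Note on the `/100` in check_dalle.py above: the transformers CLIPModel returns `logits_per_image` as the image-text cosine similarity multiplied by its learned logit scale, which is roughly 100 for the pretrained ViT-B/32 weights, so dividing by 100 only approximately recovers the cosine similarity. A sketch of the explicit computation, reusing the objects defined in that script, shown purely for illustration:)

```python
# Explicit cosine similarity, avoiding the hard-coded /100 logit scale.
# Reuses `model`, `processor`, `image` and `matches` exactly as defined in check_dalle.py above.
import torch

inputs = processor(text=matches, images=image, return_tensors="pt", padding=True)
with torch.no_grad():
    img_emb = model.get_image_features(pixel_values=inputs["pixel_values"])
    txt_emb = model.get_text_features(input_ids=inputs["input_ids"],
                                      attention_mask=inputs["attention_mask"])
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    cos_sim = img_emb @ txt_emb.T  # logits_per_image == cos_sim * model.logit_scale.exp()
```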
/tests/data_test.py: -------------------------------------------------------------------------------- 1 | from scripts.generate_dataset import main 2 | 3 | import numpy as np 4 | import pytest 5 | 6 | @pytest.mark.parametrize("save_im", [True, False]) 7 | def test_generate(tmp_path, save_im): 8 | n_samples = 100 9 | per_folder = 50 10 | out_dir = tmp_path/"out" 11 | main(out_dir=out_dir, n_samples=n_samples, samples_per_folder=per_folder, save_im=save_im) 12 | 13 | created_folders = list(out_dir.glob("*")) 14 | assert len(created_folders) == n_samples // per_folder 15 | for d in created_folders: 16 | ims = list(d.glob("*.jpg")) 17 | npy = list(d.glob("*.npy")) 18 | 19 | if save_im: 20 | assert len(ims) == per_folder 21 | assert len(npy) == 2*per_folder # latent and embedding 22 | 23 | # Defaults for sg2 and clip vitb/32 24 | embed = np.load(list(d.glob("*.img_feat.npy"))[0]) 25 | assert embed.shape == (512,) 26 | 27 | latent = np.load(list(d.glob("*.latent.npy"))[0]) 28 | assert latent.shape == (512,) 29 | 30 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Justin Pinkney 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /clip2latent/stylegan3/.github/ISSUE_TEMPLATE/bug_report.md: -------------------------------------------------------------------------------- 1 | --- 2 | name: Bug report 3 | about: Create a report to help us improve 4 | title: '' 5 | labels: '' 6 | assignees: '' 7 | 8 | --- 9 | 10 | **Describe the bug** 11 | A clear and concise description of what the bug is. 12 | 13 | **To Reproduce** 14 | Steps to reproduce the behavior: 15 | 1. In '...' directory, run command '...' 16 | 2. See error (copy&paste full log, including exceptions and **stacktraces**). 17 | 18 | Please copy&paste text instead of screenshots for better searchability. 19 | 20 | **Expected behavior** 21 | A clear and concise description of what you expected to happen. 22 | 23 | **Screenshots** 24 | If applicable, add screenshots to help explain your problem. 25 | 26 | **Desktop (please complete the following information):** 27 | - OS: [e.g. 
Linux Ubuntu 20.04, Windows 10] 28 | - PyTorch version (e.g., pytorch 1.9.0) 29 | - CUDA toolkit version (e.g., CUDA 11.4) 30 | - NVIDIA driver version 31 | - GPU [e.g., Titan V, RTX 3090] 32 | - Docker: did you use Docker? If yes, specify docker image URL (e.g., nvcr.io/nvidia/pytorch:21.08-py3) 33 | 34 | **Additional context** 35 | Add any other context about the problem here. 36 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /tests/model_test.py: -------------------------------------------------------------------------------- 1 | from omegaconf import OmegaConf 2 | from clip2latent.latent_prior import LatentPrior, WPlusPriorNetwork 3 | from dalle2_pytorch import DiffusionPriorNetwork 4 | import torch 5 | import pytest 6 | from hydra import initialize, compose 7 | 8 | from clip2latent.models import Clip2StyleGAN, load_models 9 | 10 | @pytest.mark.parametrize( 11 | ["n_latents", "repeats"], 12 | [(1, (18,)), (3, (1, 4, 13)), (18, 18*(1,))], 13 | ) 14 | def test_prior_sample(n_latents, repeats): 15 | dim = 8 16 | bs = 2 17 | t = 4 18 | out_latents = 18 19 | if n_latents == 1: 20 | net = DiffusionPriorNetwork(dim=dim, num_timesteps=t, depth=2, dim_head=8, heads=2) 21 | else: 22 | net = WPlusPriorNetwork(n_latents=n_latents, dim=dim, num_timesteps=t, depth=2, dim_head=8, heads=2) 23 | prior = LatentPrior(net, image_embed_dim=dim, timesteps=t, latent_repeats=repeats, num_latents=n_latents) 24 | 25 | inp = torch.ones(bs, dim) 26 | out = prior.sample(inp) 27 | 28 | assert out.shape == (bs, out_latents, dim) 29 | 30 | # TODO clipper test 31 | 32 | def test_end_to_end(): 33 | with initialize(config_path="../config"): 34 | cfg = compose(config_name="config") 35 | 36 | device = "cuda" 37 | model = Clip2StyleGAN(cfg, device) 38 | inp = ["a text prompt", "a different prompt"] 39 | n_samples = 3 40 | 41 | out, sim = model(inp, n_samples_per_txt=n_samples, clip_sort=True) 42 | 43 | assert out.shape == (n_samples*len(inp), 3, 1024, 1024) 44 | assert sim.shape == 
(n_samples*len(inp),) 45 | assert sim[0] >= sim[-1] -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 
27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /text/face-test.txt: -------------------------------------------------------------------------------- 1 | a person with glasses 2 | a person with brown hair 3 | a person with curly blonde hair 4 | a person with a hat 5 | a person with bushy eyebrows and a small mouth 6 | a person smiling 7 | a person who is angry 8 | a person looking up at the sky 9 | a person with their eyes closed 10 | a person talking 11 | a man with a beard 12 | a happy man with a moustache 13 | a young man 14 | an old man 15 | a middle aged man 16 | a youthful man with a bored expression 17 | a woman with a hat 18 | a happy woman with glasses 19 | a young woman 20 | an old woman 21 | a middle aged woman 22 | a baby crying in a red bouncer 23 | a child with blue eyes and straight brown hair in the sunshine 24 | an old woman with large sunglasses and ear rings 25 | a young man with a bald head who is wearing necklace in the city at night 26 | a youthful woman with a bored expression 27 | President Xi Jinping 28 | Prime Minister Boris Johnson 29 | President Joe Biden 30 | President Barack Obama 31 | Chancellor Angela Merkel 32 | President Emmanuel Macron 33 | Prime Minister Shinzo Abe 34 | Robert De Niro 35 | Danny Devito 36 | Denzel Washington 37 | Meryl Streep 38 | Cate Blanchett 39 | Morgan Freeman 40 | Whoopi Goldberg 41 | Usain Bolt 42 | Muhammad Ali 43 | Serena Williams 44 | Roger Federer 45 | Martina Navratilova 46 | Jessica Ennis-Hill 47 | Cathy Freeman 48 | Christiano Ronaldo 49 | Elsa from Frozen 50 | Eric Cartman from South Park 51 | Chihiro from Spirited Away 52 | Bart from the Simpsons 53 | Woody from Toy Story 54 | a university graduate 55 | a firefighter 56 | a police officer 57 | a butcher 58 | a scientist 59 | a gardener 60 | a hairdresser 61 | a man visiting the beach 62 | a woman giving a TED talk 63 | a child playing with friends 64 | a person watching birds in the forest -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 
22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /clip2latent/data.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | import webdataset as wds 3 | import torch 4 | import hydra 5 | 6 | from clip2latent import train_utils 7 | 8 | def identity(x): 9 | return x 10 | 11 | def add_noise(x, scale=0.75): 12 | orig_norm = x.norm(dim=-1, keepdim=True) 13 | x = x/orig_norm 14 | noise = torch.randn_like(x) 15 | noise /= noise.norm(dim=-1, keepdim=True) 16 | x += scale*noise 17 | x /= x.norm(dim=-1, keepdim=True) 18 | x *= orig_norm 19 | return x 20 | 21 | def load_data(cfg, n_stats=10_000, shuffle=5000, n_workers=16): 22 | """Create train and validation data from a config""" 23 | 24 | if cfg.format != "webdataset": 25 | raise NotImplementedError() 26 | 27 | try: 28 | data_path = hydra.utils.to_absolute_path(cfg.path) 29 | except TypeError: 30 | # We might specify multiple paths 31 | data_path = [hydra.utils.to_absolute_path(x) for x in cfg.path] 32 | stats_ds = wds.WebDataset(data_path).decode().to_tuple('img_feat.npy', 'latent.npy').shuffle(shuffle).batched(n_stats) 33 | stats_data = next(stats_ds.__iter__()) 34 | 35 | stats = { 36 | "clip_features": train_utils.make_data_stats(torch.tensor(stats_data[0])), 37 | "w": train_utils.make_data_stats(torch.tensor(stats_data[1])), 38 | } 39 | 40 | ds = ( 41 | wds.WebDataset(data_path) 42 | .shuffle(shuffle) 43 | .decode() 44 | .to_tuple('img_feat.npy', 'latent.npy') 45 | .batched(cfg.bs) 46 | .map_tuple(torch.tensor, torch.tensor) 47 | ) 48 | if cfg.embed_noise_scale > 0: 49 | ds = ds.map_tuple(partial(add_noise, scale=cfg.embed_noise_scale), identity) 50 | ds = ds.map_tuple(identity, partial(train_utils.normalise_data, w_mean=stats["w"][0], w_std=stats["w"][1])) 51 | 52 | loader = wds.WebLoader(ds, num_workers=n_workers, batch_size=None) 53 | return stats,loader 54 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/inception_score.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Inception Score (IS) from the paper "Improved techniques for training 10 | GANs". Matches the original implementation by Salimans et al. at 11 | https://github.com/openai/improved-gan/blob/master/inception_score/model.py""" 12 | 13 | import numpy as np 14 | from . 
import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_is(opts, num_gen, num_splits): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(no_output_bias=True) # Match the original implementation by not applying bias in the softmax layer. 22 | 23 | gen_probs = metric_utils.compute_feature_stats_for_generator( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | capture_all=True, max_items=num_gen).get_all() 26 | 27 | if opts.rank != 0: 28 | return float('nan'), float('nan') 29 | 30 | scores = [] 31 | for i in range(num_splits): 32 | part = gen_probs[i * num_gen // num_splits : (i + 1) * num_gen // num_splits] 33 | kl = part * (np.log(part) - np.log(np.mean(part, axis=0, keepdims=True))) 34 | kl = np.mean(np.sum(kl, axis=1)) 35 | scores.append(np.exp(kl)) 36 | return float(np.mean(scores)), float(np.std(scores)) 37 | 38 | #---------------------------------------------------------------------------- 39 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/frechet_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Frechet Inception Distance (FID) from the paper 10 | "GANs trained by a two time-scale update rule converge to a local Nash 11 | equilibrium". Matches the original implementation by Heusel et al. at 12 | https://github.com/bioinf-jku/TTUR/blob/master/fid.py""" 13 | 14 | import numpy as np 15 | import scipy.linalg 16 | from . import metric_utils 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | def compute_fid(opts, max_real, num_gen): 21 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 22 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 23 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 24 | 25 | mu_real, sigma_real = metric_utils.compute_feature_stats_for_dataset( 26 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 27 | rel_lo=0, rel_hi=0, capture_mean_cov=True, max_items=max_real).get_mean_cov() 28 | 29 | mu_gen, sigma_gen = metric_utils.compute_feature_stats_for_generator( 30 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 31 | rel_lo=0, rel_hi=1, capture_mean_cov=True, max_items=num_gen).get_mean_cov() 32 | 33 | if opts.rank != 0: 34 | return float('nan') 35 | 36 | m = np.square(mu_gen - mu_real).sum() 37 | s, _ = scipy.linalg.sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member 38 | fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2)) 39 | return float(fid) 40 | 41 | #---------------------------------------------------------------------------- 42 | -------------------------------------------------------------------------------- /scripts/eval_fid.py: -------------------------------------------------------------------------------- 1 | from textwrap import wrap 2 | from typing import List 3 | from omegaconf import OmegaConf 4 | import torch 5 | from pathlib import Path 6 | import numpy as np 7 | from PIL import Image 8 | import train 9 | from joblib import Parallel, delayed 10 | from tqdm import tqdm 11 | from datetime import datetime 12 | import typer 13 | 14 | 15 | 16 | class FidWrapper(): 17 | def __init__(self, samples, model, cond_scale=1.0) -> None: 18 | self.samples = samples 19 | self.count = 0 20 | self.model = model 21 | self.cond_scale = cond_scale 22 | 23 | def forward(self, z): 24 | bs = z.shape[0] 25 | inp = self.samples[self.count:(self.count+bs)] 26 | self.count += bs 27 | images, _ = self.model(inp, cond_scale=self.cond_scale) 28 | images = (255*(images.clamp(-1,1)*0.5 + 0.5)).to(torch.uint8) 29 | 30 | # im = Image.fromarray(images[0].detach().cpu().permute(1,2,0).numpy()).save("temp.jpg") 31 | return images 32 | 33 | def __len__(self): 34 | return len(self.samples) 35 | 36 | 37 | def main( 38 | skips:List[int]=[1,100,250], 39 | cond_scales:List[float]=[1,1.05,1.1,1.2,1.3,1.5,1.75,2,2.5,3,4,5,10], 40 | write_results:bool = True, 41 | checkpoint:str = "best.ckpt", 42 | cfg_file:str= "best.yaml", 43 | device:str= "cuda:0", 44 | n_samples:int= 16, 45 | 
truncation:float=1.0, 46 | ): 47 | 48 | typer.echo(f"Running skips: {skips}") 49 | typer.echo(f"Running cond scales: {cond_scales}") 50 | 51 | with open("celeba-samples.txt", 'rt') as f: 52 | text_samples = f.read().splitlines() 53 | 54 | model = Clip2StyleGAN(cfg_file, device, checkpoint, skips=100) 55 | model.to(device) 56 | model.eval() 57 | 58 | 59 | wrapper = FidWrapper(text_samples, model, cond_scale=1.0) 60 | print(len(wrapper)) 61 | 62 | from cleanfid import fid 63 | # function that accepts a latent and returns an image in range[0,255] 64 | gen = lambda z: wrapper.forward(z) 65 | score = fid.compute_fid(fdir2="/mnt/data_rome/laion/tedigan-data/CelebAMask-HQ/CelebA-HQ-img", gen=gen, 66 | batch_size=40, num_gen=len(wrapper)) 67 | 68 | print(score) 69 | 70 | 71 | if __name__ == "__main__": 72 | typer.run(main) -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/troubleshooting.md: -------------------------------------------------------------------------------- 1 | # Troubleshooting 2 | 3 | Our PyTorch code uses custom [CUDA extensions](https://pytorch.org/tutorials/advanced/cpp_extension.html) to speed up some of the network layers. Getting these to run can sometimes be a hassle. 4 | 5 | This page aims to give guidance on how to diagnose and fix run-time problems related to these extensions. 6 | 7 | ## Before you start 8 | 9 | 1. Try Docker first! Ensure you can successfully run our models using the recommended Docker image. Follow the instructions in [README.md](/README.md) to get it running. 10 | 2. Can't use Docker? Read on.. 11 | 12 | ## Installing dependencies 13 | 14 | Make sure you've installed everything listed on the requirements section in the [README.md](/README.md). The key components w.r.t. custom extensions are: 15 | 16 | - **[CUDA toolkit 11.1](https://developer.nvidia.com/cuda-toolkit)** or later (this is not the same as `cudatoolkit` from Conda). 17 | - PyTorch invokes `nvcc` to compile our CUDA kernels. 18 | - **ninja** 19 | - PyTorch uses [Ninja](https://ninja-build.org/) as its build system. 20 | - **GCC** (Linux) or **Visual Studio** (Windows) 21 | - GCC 7.x or later is required. Earlier versions such as GCC 6.3 [are known not to work](https://github.com/NVlabs/stylegan3/issues/2). 22 | 23 | #### Why is CUDA toolkit installation necessary? 24 | 25 | The PyTorch package contains the required CUDA toolkit libraries needed to run PyTorch, so why is a separate CUDA toolkit installation required? Our models use custom CUDA kernels to implement operations such as efficient resampling of 2D images. PyTorch code invokes the CUDA compiler at run-time to compile these kernels on first-use. The tools and libraries required for this compilation are not bundled in PyTorch and thus a host CUDA toolkit installation is required. 26 | 27 | ## Things to try 28 | 29 | - Completely remove: `$HOME/.cache/torch_extensions` (Linux) or `C:\Users\\AppData\Local\torch_extensions\torch_extensions\Cache` (Windows) and re-run StyleGAN3 python code. 30 | - Run ninja in `$HOME/.cache/torch_extensions` to see that it builds. 31 | - Inspect the `build.ninja` in the build directories under `$HOME/.cache/torch_extensions` and check CUDA tools and versions are consistent with what you intended to use. 
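Before digging into the cache directories, it can also help to confirm that PyTorch can see a full CUDA toolkit and the build tools it needs. A small diagnostic sketch using only standard PyTorch and standard-library calls (not part of the StyleGAN3 code itself):

```python
# Check that the pieces needed for run-time kernel compilation are visible.
import shutil
import torch
from torch.utils import cpp_extension

print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("CUDA_HOME:", cpp_extension.CUDA_HOME)   # should point at a full toolkit, not just the conda runtime libs
print("nvcc on PATH:", shutil.which("nvcc"))
print("ninja on PATH:", shutil.which("ninja"))
```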
32 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/kernel_inception_distance.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Kernel Inception Distance (KID) from the paper "Demystifying MMD 10 | GANs". Matches the original implementation by Binkowski et al. at 11 | https://github.com/mbinkowski/MMD-GAN/blob/master/gan/compute_scores.py""" 12 | 13 | import numpy as np 14 | from . 
import metric_utils 15 | 16 | #---------------------------------------------------------------------------- 17 | 18 | def compute_kid(opts, max_real, num_gen, num_subsets, max_subset_size): 19 | # Direct TorchScript translation of http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz 20 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/inception-2015-12-05.pkl' 21 | detector_kwargs = dict(return_features=True) # Return raw features before the softmax layer. 22 | 23 | real_features = metric_utils.compute_feature_stats_for_dataset( 24 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 25 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all() 26 | 27 | gen_features = metric_utils.compute_feature_stats_for_generator( 28 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 29 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all() 30 | 31 | if opts.rank != 0: 32 | return float('nan') 33 | 34 | n = real_features.shape[1] 35 | m = min(min(real_features.shape[0], gen_features.shape[0]), max_subset_size) 36 | t = 0 37 | for _subset_idx in range(num_subsets): 38 | x = gen_features[np.random.choice(gen_features.shape[0], m, replace=False)] 39 | y = real_features[np.random.choice(real_features.shape[0], m, replace=False)] 40 | a = (x @ x.T / n + 1) ** 3 + (y @ y.T / n + 1) ** 3 41 | b = (x @ y.T / n + 1) ** 3 42 | t += (a.sum() - np.diag(a).sum()) / (m - 1) - b.sum() * 2 / m 43 | kid = t / num_subsets / m 44 | return float(kid) 45 | 46 | #---------------------------------------------------------------------------- 47 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | runs/ 2 | checkpoints/ 3 | data/ 4 | wandb/ 5 | outputs 6 | .vscode/ 7 | data 8 | 9 | # Byte-compiled / optimized / DLL files 10 | __pycache__/ 11 | *.py[cod] 12 | *$py.class 13 | 14 | # C extensions 15 | *.so 16 | 17 | # Distribution / packaging 18 | .Python 19 | build/ 20 | develop-eggs/ 21 | dist/ 22 | downloads/ 23 | eggs/ 24 | .eggs/ 25 | lib/ 26 | lib64/ 27 | parts/ 28 | sdist/ 29 | var/ 30 | wheels/ 31 | pip-wheel-metadata/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | 62 | # Translations 63 | *.mo 64 | *.pot 65 | 66 | # Django stuff: 67 | *.log 68 | local_settings.py 69 | db.sqlite3 70 | db.sqlite3-journal 71 | 72 | # Flask stuff: 73 | instance/ 74 | .webassets-cache 75 | 76 | # Scrapy stuff: 77 | .scrapy 78 | 79 | # Sphinx documentation 80 | docs/_build/ 81 | 82 | # PyBuilder 83 | target/ 84 | 85 | # Jupyter Notebook 86 | .ipynb_checkpoints 87 | 88 | # IPython 89 | profile_default/ 90 | ipython_config.py 91 | 92 | # pyenv 93 | .python-version 94 | 95 | # pipenv 96 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
97 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 98 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 99 | # install all needed dependencies. 100 | #Pipfile.lock 101 | 102 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 103 | __pypackages__/ 104 | 105 | # Celery stuff 106 | celerybeat-schedule 107 | celerybeat.pid 108 | 109 | # SageMath parsed files 110 | *.sage.py 111 | 112 | # Environments 113 | .env 114 | .venv 115 | env/ 116 | venv/ 117 | ENV/ 118 | env.bak/ 119 | venv.bak/ 120 | 121 | # Spyder project settings 122 | .spyderproject 123 | .spyproject 124 | 125 | # Rope project settings 126 | .ropeproject 127 | 128 | # mkdocs documentation 129 | /site 130 | 131 | # mypy 132 | .mypy_cache/ 133 | .dmypy.json 134 | dmypy.json 135 | 136 | # Pyre type checker 137 | .pyre/ 138 | -------------------------------------------------------------------------------- /clip2latent/train_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from tqdm.auto import tqdm 4 | from PIL import Image 5 | import torchvision 6 | 7 | def make_data_stats(w): 8 | w_mean = w.mean(dim=0) 9 | w_std = w.std(dim=0) 10 | return w_mean, w_std 11 | 12 | def normalise_data(w, w_mean, w_std): 13 | device = w.device 14 | w = w - w_mean.to(device) 15 | w = w / w_std.to(device) 16 | return w 17 | 18 | def denormalise_data(w, w_mean, w_std): 19 | device = w.device 20 | w = w * w_std.to(device) 21 | w = w + w_mean.to(device) 22 | return w 23 | 24 | def make_grid(ims, pil=True, resize=True): 25 | if resize: 26 | ims = F.interpolate(ims, size=(256,256)) 27 | grid = torchvision.utils.make_grid( 28 | ims.clamp(-1,1), 29 | normalize=True, 30 | value_range=(-1,1), 31 | nrow=4, 32 | ) 33 | if pil: 34 | grid = Image.fromarray((255*grid).to(torch.uint8).permute(1,2,0).detach().cpu().numpy()) 35 | return grid 36 | 37 | @torch.no_grad() 38 | def make_image_val_data(G, clip_model, n_im_val_samples, device, latent_dim=512): 39 | clip_features = [] 40 | 41 | zs = torch.randn((n_im_val_samples, latent_dim), device=device) 42 | ws = G.mapping(zs, c=None) 43 | for w in tqdm(ws): 44 | out = G.synthesis(w.unsqueeze(0)) 45 | image_features = clip_model.embed_image(out) 46 | clip_features.append(image_features) 47 | 48 | clip_features = torch.cat(clip_features, dim=0) 49 | val_data = { 50 | "clip_features": clip_features, 51 | "z": zs, 52 | "w": ws, 53 | } 54 | return val_data 55 | 56 | 57 | @torch.no_grad() 58 | def make_text_val_data(G, clip_model, text_samples_file): 59 | """Load text samples from file""" 60 | with open(text_samples_file, 'rt') as f: 61 | text_samples = f.read().splitlines() 62 | text_features = clip_model.embed_text(text_samples) 63 | val_data = {"clip_features": text_features,} 64 | return val_data, text_samples 65 | 66 | @torch.no_grad() 67 | def compute_val(diffusion, input_embed, G, clip_model, device, cond_scale=1.0, bs=8): 68 | 69 | diffusion.eval() 70 | images = [] 71 | inp = input_embed.to(device) 72 | out = diffusion.sample(inp, cond_scale=cond_scale) 73 | 74 | pred_w_clip_features = [] 75 | # batch in 1s to not worry about memory 76 | for w in out.chunk(bs): 77 | out = G.synthesis(w) 78 | images.append(out) 79 | image_features = clip_model.embed_image(out) 80 | pred_w_clip_features.append(image_features) 81 | 82 | pred_w_clip_features = torch.cat(pred_w_clip_features, dim=0) 83 | images = torch.cat(images, dim=0) 84 | 85 | y = 
input_embed/input_embed.norm(dim=1, keepdim=True) 86 | y_hat = pred_w_clip_features/pred_w_clip_features.norm(dim=1, keepdim=True) 87 | return torch.cosine_similarity(y, y_hat), images 88 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/dataset-tool-help.txt: -------------------------------------------------------------------------------- 1 | Usage: dataset_tool.py [OPTIONS] 2 | 3 | Convert an image dataset into a dataset archive usable with StyleGAN2 ADA 4 | PyTorch. 5 | 6 | The input dataset format is guessed from the --source argument: 7 | 8 | --source *_lmdb/ Load LSUN dataset 9 | --source cifar-10-python.tar.gz Load CIFAR-10 dataset 10 | --source train-images-idx3-ubyte.gz Load MNIST dataset 11 | --source path/ Recursively load all images from path/ 12 | --source dataset.zip Recursively load all images from dataset.zip 13 | 14 | Specifying the output format and path: 15 | 16 | --dest /path/to/dir Save output files under /path/to/dir 17 | --dest /path/to/dataset.zip Save output files into /path/to/dataset.zip 18 | 19 | The output dataset format can be either an image folder or an uncompressed 20 | zip archive. Zip archives makes it easier to move datasets around file 21 | servers and clusters, and may offer better training performance on network 22 | file systems. 23 | 24 | Images within the dataset archive will be stored as uncompressed PNG. 25 | Uncompresed PNGs can be efficiently decoded in the training loop. 26 | 27 | Class labels are stored in a file called 'dataset.json' that is stored at 28 | the dataset root folder. This file has the following structure: 29 | 30 | { 31 | "labels": [ 32 | ["00000/img00000000.png",6], 33 | ["00000/img00000001.png",9], 34 | ... repeated for every image in the datase 35 | ["00049/img00049999.png",1] 36 | ] 37 | } 38 | 39 | If the 'dataset.json' file cannot be found, the dataset is interpreted as 40 | not containing class labels. 41 | 42 | Image scale/crop and resolution requirements: 43 | 44 | Output images must be square-shaped and they must all have the same power- 45 | of-two dimensions. 46 | 47 | To scale arbitrary input image size to a specific width and height, use 48 | the --resolution option. Output resolution will be either the original 49 | input resolution (if resolution was not specified) or the one specified 50 | with --resolution option. 51 | 52 | Use the --transform=center-crop or --transform=center-crop-wide options to 53 | apply a center crop transform on the input image. These options should be 54 | used with the --resolution option. For example: 55 | 56 | python dataset_tool.py --source LSUN/raw/cat_lmdb --dest /tmp/lsun_cat \ 57 | --transform=center-crop-wide --resolution=512x384 58 | 59 | Options: 60 | --source PATH Directory or archive name for input dataset 61 | [required] 62 | 63 | --dest PATH Output directory or archive name for output 64 | dataset [required] 65 | 66 | --max-images INTEGER Output only up to `max-images` images 67 | --transform [center-crop|center-crop-wide] 68 | Input crop/resize mode 69 | --resolution WxH Output resolution (e.g., '512x512') 70 | --help Show this message and exit. 71 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/stylemix_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class StyleMixingWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.seed_def = 1000 18 | self.seed = self.seed_def 19 | self.animate = False 20 | self.enables = [] 21 | 22 | @imgui_utils.scoped_by_object_id 23 | def __call__(self, show=True): 24 | viz = self.viz 25 | num_ws = viz.result.get('num_ws', 0) 26 | num_enables = viz.result.get('num_ws', 18) 27 | self.enables += [False] * max(num_enables - len(self.enables), 0) 28 | 29 | if show: 30 | imgui.text('Stylemix') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8), imgui_utils.grayed_out(num_ws == 0): 33 | _changed, self.seed = imgui.input_int('##seed', self.seed) 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | with imgui_utils.grayed_out(num_ws == 0): 36 | _clicked, self.animate = imgui.checkbox('Anim', self.animate) 37 | 38 | pos2 = imgui.get_content_region_max()[0] - 1 - viz.button_w 39 | pos1 = pos2 - imgui.get_text_line_height() - viz.spacing 40 | pos0 = viz.label_w + viz.font_size * 12 41 | imgui.push_style_var(imgui.STYLE_FRAME_PADDING, [0, 0]) 42 | for idx in range(num_enables): 43 | imgui.same_line(round(pos0 + (pos1 - pos0) * (idx / (num_enables - 1)))) 44 | if idx == 0: 45 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() + 3) 46 | with imgui_utils.grayed_out(num_ws == 0): 47 | _clicked, self.enables[idx] = imgui.checkbox(f'##{idx}', self.enables[idx]) 48 | if imgui.is_item_hovered(): 49 | imgui.set_tooltip(f'{idx}') 50 | imgui.pop_style_var(1) 51 | 52 | imgui.same_line(pos2) 53 | imgui.set_cursor_pos_y(imgui.get_cursor_pos_y() - 3) 54 | with imgui_utils.grayed_out(num_ws == 0): 55 | if imgui_utils.button('Reset', width=-1, enabled=(self.seed != self.seed_def or self.animate or any(self.enables[:num_enables]))): 56 | self.seed = self.seed_def 57 | self.animate = False 58 | self.enables = [False] * num_enables 59 | 60 | if any(self.enables[:num_ws]): 61 | viz.args.stylemix_idx = [idx for idx, enable in enumerate(self.enables) if enable] 62 | viz.args.stylemix_seed = self.seed & ((1 << 32) - 1) 63 | if self.animate: 64 | self.seed += 1 65 | 66 | #---------------------------------------------------------------------------- 67 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/docs/train-help.txt: -------------------------------------------------------------------------------- 1 | Usage: train.py [OPTIONS] 2 | 3 | Train a GAN using the techniques described in the paper "Alias-Free 4 | Generative Adversarial Networks". 5 | 6 | Examples: 7 | 8 | # Train StyleGAN3-T for AFHQv2 using 8 GPUs. 9 | python train.py --outdir=~/training-runs --cfg=stylegan3-t --data=~/datasets/afhqv2-512x512.zip \ 10 | --gpus=8 --batch=32 --gamma=8.2 --mirror=1 11 | 12 | # Fine-tune StyleGAN3-R for MetFaces-U using 1 GPU, starting from the pre-trained FFHQ-U pickle. 
13 | python train.py --outdir=~/training-runs --cfg=stylegan3-r --data=~/datasets/metfacesu-1024x1024.zip \ 14 | --gpus=8 --batch=32 --gamma=6.6 --mirror=1 --kimg=5000 --snap=5 \ 15 | --resume=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-ffhqu-1024x1024.pkl 16 | 17 | # Train StyleGAN2 for FFHQ at 1024x1024 resolution using 8 GPUs. 18 | python train.py --outdir=~/training-runs --cfg=stylegan2 --data=~/datasets/ffhq-1024x1024.zip \ 19 | --gpus=8 --batch=32 --gamma=10 --mirror=1 --aug=noaug 20 | 21 | Options: 22 | --outdir DIR Where to save the results [required] 23 | --cfg [stylegan3-t|stylegan3-r|stylegan2] 24 | Base configuration [required] 25 | --data [ZIP|DIR] Training data [required] 26 | --gpus INT Number of GPUs to use [required] 27 | --batch INT Total batch size [required] 28 | --gamma FLOAT R1 regularization weight [required] 29 | --cond BOOL Train conditional model [default: False] 30 | --mirror BOOL Enable dataset x-flips [default: False] 31 | --aug [noaug|ada|fixed] Augmentation mode [default: ada] 32 | --resume [PATH|URL] Resume from given network pickle 33 | --freezed INT Freeze first layers of D [default: 0] 34 | --p FLOAT Probability for --aug=fixed [default: 0.2] 35 | --target FLOAT Target value for --aug=ada [default: 0.6] 36 | --batch-gpu INT Limit batch size per GPU 37 | --cbase INT Capacity multiplier [default: 32768] 38 | --cmax INT Max. feature maps [default: 512] 39 | --glr FLOAT G learning rate [default: varies] 40 | --dlr FLOAT D learning rate [default: 0.002] 41 | --map-depth INT Mapping network depth [default: varies] 42 | --mbstd-group INT Minibatch std group size [default: 4] 43 | --desc STR String to include in result dir name 44 | --metrics [NAME|A,B,C|none] Quality metrics [default: fid50k_full] 45 | --kimg KIMG Total training duration [default: 25000] 46 | --tick KIMG How often to print progress [default: 4] 47 | --snap TICKS How often to save snapshots [default: 50] 48 | --seed INT Random seed [default: 0] 49 | --fp32 BOOL Disable mixed-precision [default: False] 50 | --nobench BOOL Disable cuDNN benchmarking [default: False] 51 | --workers INT DataLoader worker processes [default: 3] 52 | -n, --dry-run Print training options and exit 53 | --help Show this message and exit. 54 | -------------------------------------------------------------------------------- /colab-gradio.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": { 6 | "id": "view-in-github", 7 | "colab_type": "text" 8 | }, 9 | "source": [ 10 | "\"Open" 11 | ] 12 | }, 13 | { 14 | "cell_type": "code", 15 | "source": [ 16 | "!git clone https://github.com/justinpinkney/clip2latent.git\n", 17 | "%cd /content/clip2latent" 18 | ], 19 | "metadata": { 20 | "id": "HrChV-UDSihz" 21 | }, 22 | "execution_count": null, 23 | "outputs": [] 24 | }, 25 | { 26 | "cell_type": "code", 27 | "source": [ 28 | "!pip install -r requirements-colab.txt\n", 29 | "!pip install -e ." 
30 | ], 31 | "metadata": { 32 | "id": "wUjSFQ9DTf_P" 33 | }, 34 | "execution_count": null, 35 | "outputs": [] 36 | }, 37 | { 38 | "cell_type": "code", 39 | "source": [ 40 | "from scripts import demo\n", 41 | "demo.main()" 42 | ], 43 | "metadata": { 44 | "colab": { 45 | "base_uri": "https://localhost:8080/", 46 | "height": 592 47 | }, 48 | "id": "7bI6NuX-T4BZ", 49 | "outputId": "2d97eea3-c6a0-49f2-9e3c-c48bdd7410df" 50 | }, 51 | "execution_count": null, 52 | "outputs": [ 53 | { 54 | "output_type": "stream", 55 | "name": "stdout", 56 | "text": [ 57 | "Colab notebook detected. To show errors in colab notebook, set `debug=True` in `launch()`\n", 58 | "Running on public URL: https://21364.gradio.app\n", 59 | "\n", 60 | "This share link expires in 72 hours. For free permanent hosting, check out Spaces: https://huggingface.co/spaces\n" 61 | ] 62 | }, 63 | { 64 | "output_type": "display_data", 65 | "data": { 66 | "text/plain": [ 67 | "" 68 | ], 69 | "text/html": [ 70 | "
" 71 | ] 72 | }, 73 | "metadata": {} 74 | } 75 | ] 76 | } 77 | ], 78 | "metadata": { 79 | "kernelspec": { 80 | "display_name": "Python 3.9.12 ('dalleprior')", 81 | "language": "python", 82 | "name": "python3" 83 | }, 84 | "language_info": { 85 | "codemirror_mode": { 86 | "name": "ipython", 87 | "version": 3 88 | }, 89 | "file_extension": ".py", 90 | "mimetype": "text/x-python", 91 | "name": "python", 92 | "nbconvert_exporter": "python", 93 | "pygments_lexer": "ipython3", 94 | "version": "3.9.12" 95 | }, 96 | "orig_nbformat": 4, 97 | "vscode": { 98 | "interpreter": { 99 | "hash": "32436db2375e982882eadc95defe3b8ca59bd5ef9d38e4f83b7c12d3cf0b3436" 100 | } 101 | }, 102 | "colab": { 103 | "provenance": [], 104 | "include_colab_link": true 105 | }, 106 | "accelerator": "GPU", 107 | "gpuClass": "standard" 108 | }, 109 | "nbformat": 4, 110 | "nbformat_minor": 0 111 | } -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | from pkg_resources import parse_version 16 | 17 | # pylint: disable=redefined-builtin 18 | # pylint: disable=arguments-differ 19 | # pylint: disable=protected-access 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | enabled = False # Enable the custom op by setting this to true. 
24 | _use_pytorch_1_11_api = parse_version(torch.__version__) >= parse_version('1.11.0a') # Allow prerelease builds of 1.11 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def grid_sample(input, grid): 29 | if _should_use_custom_op(): 30 | return _GridSample2dForward.apply(input, grid) 31 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def _should_use_custom_op(): 36 | return enabled 37 | 38 | #---------------------------------------------------------------------------- 39 | 40 | class _GridSample2dForward(torch.autograd.Function): 41 | @staticmethod 42 | def forward(ctx, input, grid): 43 | assert input.ndim == 4 44 | assert grid.ndim == 4 45 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 46 | ctx.save_for_backward(input, grid) 47 | return output 48 | 49 | @staticmethod 50 | def backward(ctx, grad_output): 51 | input, grid = ctx.saved_tensors 52 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 53 | return grad_input, grad_grid 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | class _GridSample2dBackward(torch.autograd.Function): 58 | @staticmethod 59 | def forward(ctx, grad_output, input, grid): 60 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 61 | if _use_pytorch_1_11_api: 62 | output_mask = (ctx.needs_input_grad[1], ctx.needs_input_grad[2]) 63 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False, output_mask) 64 | else: 65 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 66 | ctx.save_for_backward(grid) 67 | return grad_input, grad_grid 68 | 69 | @staticmethod 70 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 71 | _ = grad2_grad_grid # unused 72 | grid, = ctx.saved_tensors 73 | grad2_grad_output = None 74 | grad2_input = None 75 | grad2_grid = None 76 | 77 | if ctx.needs_input_grad[0]: 78 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 79 | 80 | assert not ctx.needs_input_grad[2] 81 | return grad2_grad_output, grad2_input, grad2_grid 82 | 83 | #---------------------------------------------------------------------------- 84 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/performance_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import array 10 | import numpy as np 11 | import imgui 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class PerformanceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.gui_times = [float('nan')] * 60 20 | self.render_times = [float('nan')] * 30 21 | self.fps_limit = 60 22 | self.use_vsync = False 23 | self.is_async = False 24 | self.force_fp32 = False 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | self.gui_times = self.gui_times[1:] + [viz.frame_delta] 30 | if 'render_time' in viz.result: 31 | self.render_times = self.render_times[1:] + [viz.result.render_time] 32 | del viz.result.render_time 33 | 34 | if show: 35 | imgui.text('GUI') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 8): 38 | imgui.plot_lines('##gui_times', array.array('f', self.gui_times), scale_min=0) 39 | imgui.same_line(viz.label_w + viz.font_size * 9) 40 | t = [x for x in self.gui_times if x > 0] 41 | t = np.mean(t) if len(t) > 0 else 0 42 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 43 | imgui.same_line(viz.label_w + viz.font_size * 14) 44 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 45 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 46 | with imgui_utils.item_width(viz.font_size * 6): 47 | _changed, self.fps_limit = imgui.input_int('FPS limit', self.fps_limit, flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 48 | self.fps_limit = min(max(self.fps_limit, 5), 1000) 49 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 50 | _clicked, self.use_vsync = imgui.checkbox('Vertical sync', self.use_vsync) 51 | 52 | if show: 53 | imgui.text('Render') 54 | imgui.same_line(viz.label_w) 55 | with imgui_utils.item_width(viz.font_size * 8): 56 | imgui.plot_lines('##render_times', array.array('f', self.render_times), scale_min=0) 57 | imgui.same_line(viz.label_w + viz.font_size * 9) 58 | t = [x for x in self.render_times if x > 0] 59 | t = np.mean(t) if len(t) > 0 else 0 60 | imgui.text(f'{t*1e3:.1f} ms' if t > 0 else 'N/A') 61 | imgui.same_line(viz.label_w + viz.font_size * 14) 62 | imgui.text(f'{1/t:.1f} FPS' if t > 0 else 'N/A') 63 | imgui.same_line(viz.label_w + viz.font_size * 18 + viz.spacing * 3) 64 | _clicked, self.is_async = imgui.checkbox('Separate process', self.is_async) 65 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w * 2 - viz.spacing) 66 | _clicked, self.force_fp32 = imgui.checkbox('Force FP32', self.force_fp32) 67 | 68 | viz.set_fps_limit(self.fps_limit) 69 | viz.set_vsync(self.use_vsync) 70 | viz.set_async(self.is_async) 71 | viz.args.force_fp32 = self.force_fp32 72 | 73 | #---------------------------------------------------------------------------- 74 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/precision_recall.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Precision/Recall (PR) from the paper "Improved Precision and Recall 10 | Metric for Assessing Generative Models". Matches the original implementation 11 | by Kynkaanniemi et al. at 12 | https://github.com/kynkaat/improved-precision-and-recall-metric/blob/master/precision_recall.py""" 13 | 14 | import torch 15 | from . import metric_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | def compute_distances(row_features, col_features, num_gpus, rank, col_batch_size): 20 | assert 0 <= rank < num_gpus 21 | num_cols = col_features.shape[0] 22 | num_batches = ((num_cols - 1) // col_batch_size // num_gpus + 1) * num_gpus 23 | col_batches = torch.nn.functional.pad(col_features, [0, 0, 0, -num_cols % num_batches]).chunk(num_batches) 24 | dist_batches = [] 25 | for col_batch in col_batches[rank :: num_gpus]: 26 | dist_batch = torch.cdist(row_features.unsqueeze(0), col_batch.unsqueeze(0))[0] 27 | for src in range(num_gpus): 28 | dist_broadcast = dist_batch.clone() 29 | if num_gpus > 1: 30 | torch.distributed.broadcast(dist_broadcast, src=src) 31 | dist_batches.append(dist_broadcast.cpu() if rank == 0 else None) 32 | return torch.cat(dist_batches, dim=1)[:, :num_cols] if rank == 0 else None 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def compute_pr(opts, max_real, num_gen, nhood_size, row_batch_size, col_batch_size): 37 | detector_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 38 | detector_kwargs = dict(return_features=True) 39 | 40 | real_features = metric_utils.compute_feature_stats_for_dataset( 41 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 42 | rel_lo=0, rel_hi=0, capture_all=True, max_items=max_real).get_all_torch().to(torch.float16).to(opts.device) 43 | 44 | gen_features = metric_utils.compute_feature_stats_for_generator( 45 | opts=opts, detector_url=detector_url, detector_kwargs=detector_kwargs, 46 | rel_lo=0, rel_hi=1, capture_all=True, max_items=num_gen).get_all_torch().to(torch.float16).to(opts.device) 47 | 48 | results = dict() 49 | for name, manifold, probes in [('precision', real_features, gen_features), ('recall', gen_features, real_features)]: 50 | kth = [] 51 | for manifold_batch in manifold.split(row_batch_size): 52 | dist = compute_distances(row_features=manifold_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 53 | kth.append(dist.to(torch.float32).kthvalue(nhood_size + 1).values.to(torch.float16) if opts.rank == 0 else None) 54 | kth = torch.cat(kth) if opts.rank == 0 else None 55 | pred = [] 56 | for probes_batch in probes.split(row_batch_size): 57 | dist = compute_distances(row_features=probes_batch, col_features=manifold, num_gpus=opts.num_gpus, rank=opts.rank, col_batch_size=col_batch_size) 58 | pred.append((dist <= kth).any(dim=1) if opts.rank == 0 else None) 59 | results[name] = float(torch.cat(pred).to(torch.float32).mean() if opts.rank == 0 else 'nan') 60 | return results['precision'], results['recall'] 61 | 62 | #---------------------------------------------------------------------------- 63 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/capture_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import re 11 | import numpy as np 12 | import imgui 13 | import PIL.Image 14 | from gui_utils import imgui_utils 15 | from . import renderer 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class CaptureWidget: 20 | def __init__(self, viz): 21 | self.viz = viz 22 | self.path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '_screenshots')) 23 | self.dump_image = False 24 | self.dump_gui = False 25 | self.defer_frames = 0 26 | self.disabled_time = 0 27 | 28 | def dump_png(self, image): 29 | viz = self.viz 30 | try: 31 | _height, _width, channels = image.shape 32 | assert channels in [1, 3] 33 | assert image.dtype == np.uint8 34 | os.makedirs(self.path, exist_ok=True) 35 | file_id = 0 36 | for entry in os.scandir(self.path): 37 | if entry.is_file(): 38 | match = re.fullmatch(r'(\d+).*', entry.name) 39 | if match: 40 | file_id = max(file_id, int(match.group(1)) + 1) 41 | if channels == 1: 42 | pil_image = PIL.Image.fromarray(image[:, :, 0], 'L') 43 | else: 44 | pil_image = PIL.Image.fromarray(image, 'RGB') 45 | pil_image.save(os.path.join(self.path, f'{file_id:05d}.png')) 46 | except: 47 | viz.result.error = renderer.CapturedException() 48 | 49 | @imgui_utils.scoped_by_object_id 50 | def __call__(self, show=True): 51 | viz = self.viz 52 | if show: 53 | with imgui_utils.grayed_out(self.disabled_time != 0): 54 | imgui.text('Capture') 55 | imgui.same_line(viz.label_w) 56 | _changed, self.path = imgui_utils.input_text('##path', self.path, 1024, 57 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 58 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 59 | help_text='PATH') 60 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.path != '': 61 | imgui.set_tooltip(self.path) 62 | imgui.same_line() 63 | if imgui_utils.button('Save image', width=viz.button_w, enabled=(self.disabled_time == 0 and 'image' in viz.result)): 64 | self.dump_image = True 65 | self.defer_frames = 2 66 | self.disabled_time = 0.5 67 | imgui.same_line() 68 | if imgui_utils.button('Save GUI', width=-1, enabled=(self.disabled_time == 0)): 69 | self.dump_gui = True 70 | self.defer_frames = 2 71 | self.disabled_time = 0.5 72 | 73 | self.disabled_time = max(self.disabled_time - viz.frame_delta, 0) 74 | if self.defer_frames > 0: 75 | self.defer_frames -= 1 76 | elif self.dump_image: 77 | if 'image' in viz.result: 78 | self.dump_png(viz.result.image) 79 | self.dump_image = False 80 | elif self.dump_gui: 81 | viz.capture_next_frame() 82 | self.dump_gui = False 83 | captured_frame = viz.pop_captured_frame() 84 | if captured_frame is not None: 85 | self.dump_png(captured_frame) 86 | 87 | #---------------------------------------------------------------------------- 88 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/trunc_noise_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import imgui 10 | from gui_utils import imgui_utils 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | class TruncationNoiseWidget: 15 | def __init__(self, viz): 16 | self.viz = viz 17 | self.prev_num_ws = 0 18 | self.trunc_psi = 1 19 | self.trunc_cutoff = 0 20 | self.noise_enable = True 21 | self.noise_seed = 0 22 | self.noise_anim = False 23 | 24 | @imgui_utils.scoped_by_object_id 25 | def __call__(self, show=True): 26 | viz = self.viz 27 | num_ws = viz.result.get('num_ws', 0) 28 | has_noise = viz.result.get('has_noise', False) 29 | if num_ws > 0 and num_ws != self.prev_num_ws: 30 | if self.trunc_cutoff > num_ws or self.trunc_cutoff == self.prev_num_ws: 31 | self.trunc_cutoff = num_ws 32 | self.prev_num_ws = num_ws 33 | 34 | if show: 35 | imgui.text('Truncate') 36 | imgui.same_line(viz.label_w) 37 | with imgui_utils.item_width(viz.font_size * 10), imgui_utils.grayed_out(num_ws == 0): 38 | _changed, self.trunc_psi = imgui.slider_float('##psi', self.trunc_psi, -1, 2, format='Psi %.2f') 39 | imgui.same_line() 40 | if num_ws == 0: 41 | imgui_utils.button('Cutoff 0', width=(viz.font_size * 8 + viz.spacing), enabled=False) 42 | else: 43 | with imgui_utils.item_width(viz.font_size * 8 + viz.spacing): 44 | changed, new_cutoff = imgui.slider_int('##cutoff', self.trunc_cutoff, 0, num_ws, format='Cutoff %d') 45 | if changed: 46 | self.trunc_cutoff = min(max(new_cutoff, 0), num_ws) 47 | 48 | with imgui_utils.grayed_out(not has_noise): 49 | imgui.same_line() 50 | _clicked, self.noise_enable = imgui.checkbox('Noise##enable', self.noise_enable) 51 | imgui.same_line(round(viz.font_size * 27.7)) 52 | with imgui_utils.grayed_out(not self.noise_enable): 53 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing - viz.font_size * 4): 54 | _changed, self.noise_seed = imgui.input_int('##seed', self.noise_seed) 55 | imgui.same_line(spacing=0) 56 | _clicked, self.noise_anim = imgui.checkbox('Anim##noise', self.noise_anim) 57 | 58 | is_def_trunc = (self.trunc_psi == 1 and self.trunc_cutoff == num_ws) 59 | is_def_noise = (self.noise_enable and self.noise_seed == 0 and not self.noise_anim) 60 | with imgui_utils.grayed_out(is_def_trunc and not has_noise): 61 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 62 | if imgui_utils.button('Reset', width=-1, enabled=(not is_def_trunc or not is_def_noise)): 63 | self.prev_num_ws = num_ws 64 | self.trunc_psi = 1 65 | self.trunc_cutoff = num_ws 66 | self.noise_enable = True 67 | self.noise_seed = 0 68 | self.noise_anim = False 69 | 70 | if self.noise_anim: 71 | self.noise_seed += 1 72 | viz.args.update(trunc_psi=self.trunc_psi, trunc_cutoff=self.trunc_cutoff, random_seed=self.noise_seed) 73 | viz.args.noise_mode = ('none' if not self.noise_enable else 'const' if self.noise_seed == 0 else 'random') 74 | 75 | #---------------------------------------------------------------------------- 76 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/latent_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 
2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class LatentWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.latent = dnnlib.EasyDict(x=0, y=0, anim=False, speed=0.25) 20 | self.latent_def = dnnlib.EasyDict(self.latent) 21 | self.step_y = 100 22 | 23 | def drag(self, dx, dy): 24 | viz = self.viz 25 | self.latent.x += dx / viz.font_size * 4e-2 26 | self.latent.y += dy / viz.font_size * 4e-2 27 | 28 | @imgui_utils.scoped_by_object_id 29 | def __call__(self, show=True): 30 | viz = self.viz 31 | if show: 32 | imgui.text('Latent') 33 | imgui.same_line(viz.label_w) 34 | seed = round(self.latent.x) + round(self.latent.y) * self.step_y 35 | with imgui_utils.item_width(viz.font_size * 8): 36 | changed, seed = imgui.input_int('##seed', seed) 37 | if changed: 38 | self.latent.x = seed 39 | self.latent.y = 0 40 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 41 | frac_x = self.latent.x - round(self.latent.x) 42 | frac_y = self.latent.y - round(self.latent.y) 43 | with imgui_utils.item_width(viz.font_size * 5): 44 | changed, (new_frac_x, new_frac_y) = imgui.input_float2('##frac', frac_x, frac_y, format='%+.2f', flags=imgui.INPUT_TEXT_ENTER_RETURNS_TRUE) 45 | if changed: 46 | self.latent.x += new_frac_x - frac_x 47 | self.latent.y += new_frac_y - frac_y 48 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.spacing * 2) 49 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag', width=viz.button_w) 50 | if dragging: 51 | self.drag(dx, dy) 52 | imgui.same_line(viz.label_w + viz.font_size * 13 + viz.button_w + viz.spacing * 3) 53 | _clicked, self.latent.anim = imgui.checkbox('Anim', self.latent.anim) 54 | imgui.same_line(round(viz.font_size * 27.7)) 55 | with imgui_utils.item_width(-1 - viz.button_w * 2 - viz.spacing * 2), imgui_utils.grayed_out(not self.latent.anim): 56 | changed, speed = imgui.slider_float('##speed', self.latent.speed, -5, 5, format='Speed %.3f', power=3) 57 | if changed: 58 | self.latent.speed = speed 59 | imgui.same_line() 60 | snapped = dnnlib.EasyDict(self.latent, x=round(self.latent.x), y=round(self.latent.y)) 61 | if imgui_utils.button('Snap', width=viz.button_w, enabled=(self.latent != snapped)): 62 | self.latent = snapped 63 | imgui.same_line() 64 | if imgui_utils.button('Reset', width=-1, enabled=(self.latent != self.latent_def)): 65 | self.latent = dnnlib.EasyDict(self.latent_def) 66 | 67 | if self.latent.anim: 68 | self.latent.x += viz.frame_delta * self.latent.speed 69 | viz.args.w0_seeds = [] # [[seed, weight], ...] 
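        # The loop below visits the four integer lattice seeds that surround the
        # fractional latent position (x, y) and assigns each a bilinear weight, so
        # dragging the latent smoothly cross-fades between neighbouring seeds.
        # For example, x = 10.25, y = 0 yields approximately
        # [[10, 0.75], [11, 0.25]]; candidates whose weight is 0 are dropped.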
70 | for ofs_x, ofs_y in [[0, 0], [1, 0], [0, 1], [1, 1]]: 71 | seed_x = np.floor(self.latent.x) + ofs_x 72 | seed_y = np.floor(self.latent.y) + ofs_y 73 | seed = (int(seed_x) + int(seed_y) * self.step_y) & ((1 << 32) - 1) 74 | weight = (1 - abs(self.latent.x - seed_x)) * (1 - abs(self.latent.y - seed_y)) 75 | if weight > 0: 76 | viz.args.w0_seeds.append([seed, weight]) 77 | 78 | #---------------------------------------------------------------------------- 79 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gui_utils/imgui_window.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import os 10 | import imgui 11 | import imgui.integrations.glfw 12 | 13 | from . import glfw_window 14 | from . import imgui_utils 15 | from . import text_utils 16 | 17 | #---------------------------------------------------------------------------- 18 | 19 | class ImguiWindow(glfw_window.GlfwWindow): 20 | def __init__(self, *, title='ImguiWindow', font=None, font_sizes=range(14,24), **glfw_kwargs): 21 | if font is None: 22 | font = text_utils.get_default_font() 23 | font_sizes = {int(size) for size in font_sizes} 24 | super().__init__(title=title, **glfw_kwargs) 25 | 26 | # Init fields. 27 | self._imgui_context = None 28 | self._imgui_renderer = None 29 | self._imgui_fonts = None 30 | self._cur_font_size = max(font_sizes) 31 | 32 | # Delete leftover imgui.ini to avoid unexpected behavior. 33 | if os.path.isfile('imgui.ini'): 34 | os.remove('imgui.ini') 35 | 36 | # Init ImGui. 37 | self._imgui_context = imgui.create_context() 38 | self._imgui_renderer = _GlfwRenderer(self._glfw_window) 39 | self._attach_glfw_callbacks() 40 | imgui.get_io().ini_saving_rate = 0 # Disable creating imgui.ini at runtime. 41 | imgui.get_io().mouse_drag_threshold = 0 # Improve behavior with imgui_utils.drag_custom(). 42 | self._imgui_fonts = {size: imgui.get_io().fonts.add_font_from_file_ttf(font, size) for size in font_sizes} 43 | self._imgui_renderer.refresh_font_texture() 44 | 45 | def close(self): 46 | self.make_context_current() 47 | self._imgui_fonts = None 48 | if self._imgui_renderer is not None: 49 | self._imgui_renderer.shutdown() 50 | self._imgui_renderer = None 51 | if self._imgui_context is not None: 52 | #imgui.destroy_context(self._imgui_context) # Commented out to avoid creating imgui.ini at the end. 53 | self._imgui_context = None 54 | super().close() 55 | 56 | def _glfw_key_callback(self, *args): 57 | super()._glfw_key_callback(*args) 58 | self._imgui_renderer.keyboard_callback(*args) 59 | 60 | @property 61 | def font_size(self): 62 | return self._cur_font_size 63 | 64 | @property 65 | def spacing(self): 66 | return round(self._cur_font_size * 0.4) 67 | 68 | def set_font_size(self, target): # Applied on next frame. 69 | self._cur_font_size = min((abs(key - target), key) for key in self._imgui_fonts.keys())[1] 70 | 71 | def begin_frame(self): 72 | # Begin glfw frame. 73 | super().begin_frame() 74 | 75 | # Process imgui events. 
76 | self._imgui_renderer.mouse_wheel_multiplier = self._cur_font_size / 10 77 | if self.content_width > 0 and self.content_height > 0: 78 | self._imgui_renderer.process_inputs() 79 | 80 | # Begin imgui frame. 81 | imgui.new_frame() 82 | imgui.push_font(self._imgui_fonts[self._cur_font_size]) 83 | imgui_utils.set_default_style(spacing=self.spacing, indent=self.font_size, scrollbar=self.font_size+4) 84 | 85 | def end_frame(self): 86 | imgui.pop_font() 87 | imgui.render() 88 | imgui.end_frame() 89 | self._imgui_renderer.render(imgui.get_draw_data()) 90 | super().end_frame() 91 | 92 | #---------------------------------------------------------------------------- 93 | # Wrapper class for GlfwRenderer to fix a mouse wheel bug on Linux. 94 | 95 | class _GlfwRenderer(imgui.integrations.glfw.GlfwRenderer): 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.mouse_wheel_multiplier = 1 99 | 100 | def scroll_callback(self, window, x_offset, y_offset): 101 | self.io.mouse_wheel += y_offset * self.mouse_wheel_multiplier 102 | 103 | #---------------------------------------------------------------------------- 104 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/LICENSE.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, NVIDIA Corporation & affiliates. All rights reserved. 2 | 3 | 4 | NVIDIA Source Code License for StyleGAN3 5 | 6 | 7 | ======================================================================= 8 | 9 | 1. Definitions 10 | 11 | "Licensor" means any person or entity that distributes its Work. 12 | 13 | "Software" means the original work of authorship made available under 14 | this License. 15 | 16 | "Work" means the Software and any additions to or derivative works of 17 | the Software that are made available under this License. 18 | 19 | The terms "reproduce," "reproduction," "derivative works," and 20 | "distribution" have the meaning as provided under U.S. copyright law; 21 | provided, however, that for the purposes of this License, derivative 22 | works shall not include works that remain separable from, or merely 23 | link (or bind by name) to the interfaces of, the Work. 24 | 25 | Works, including the Software, are "made available" under this License 26 | by including in or with the Work either (a) a copyright notice 27 | referencing the applicability of this License to the Work, or (b) a 28 | copy of this License. 29 | 30 | 2. License Grants 31 | 32 | 2.1 Copyright Grant. Subject to the terms and conditions of this 33 | License, each Licensor grants to you a perpetual, worldwide, 34 | non-exclusive, royalty-free, copyright license to reproduce, 35 | prepare derivative works of, publicly display, publicly perform, 36 | sublicense and distribute its Work and any resulting derivative 37 | works in any form. 38 | 39 | 3. Limitations 40 | 41 | 3.1 Redistribution. You may reproduce or distribute the Work only 42 | if (a) you do so under this License, (b) you include a complete 43 | copy of this License with your distribution, and (c) you retain 44 | without modification any copyright, patent, trademark, or 45 | attribution notices that are present in the Work. 46 | 47 | 3.2 Derivative Works. 
You may specify that additional or different 48 | terms apply to the use, reproduction, and distribution of your 49 | derivative works of the Work ("Your Terms") only if (a) Your Terms 50 | provide that the use limitation in Section 3.3 applies to your 51 | derivative works, and (b) you identify the specific derivative 52 | works that are subject to Your Terms. Notwithstanding Your Terms, 53 | this License (including the redistribution requirements in Section 54 | 3.1) will continue to apply to the Work itself. 55 | 56 | 3.3 Use Limitation. The Work and any derivative works thereof only 57 | may be used or intended for use non-commercially. Notwithstanding 58 | the foregoing, NVIDIA and its affiliates may use the Work and any 59 | derivative works commercially. As used herein, "non-commercially" 60 | means for research or evaluation purposes only. 61 | 62 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim 63 | against any Licensor (including any claim, cross-claim or 64 | counterclaim in a lawsuit) to enforce any patents that you allege 65 | are infringed by any Work, then your rights under this License from 66 | such Licensor (including the grant in Section 2.1) will terminate 67 | immediately. 68 | 69 | 3.5 Trademarks. This License does not grant any rights to use any 70 | Licensor’s or its affiliates’ names, logos, or trademarks, except 71 | as necessary to reproduce the notices described in this License. 72 | 73 | 3.6 Termination. If you violate any term of this License, then your 74 | rights under this License (including the grant in Section 2.1) will 75 | terminate immediately. 76 | 77 | 4. Disclaimer of Warranty. 78 | 79 | THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY 80 | KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF 81 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR 82 | NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER 83 | THIS LICENSE. 84 | 85 | 5. Limitation of Liability. 86 | 87 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL 88 | THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE 89 | SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, 90 | INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF 91 | OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK 92 | (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, 93 | LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER 94 | COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF 95 | THE POSSIBILITY OF SUCH DAMAGES. 96 | 97 | ======================================================================= 98 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 
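// Editorial gloss on the declarations below (inferred from this header alone; the
// host-side launch code lives in the accompanying sources): choose_filtered_lrelu_kernel
// returns the filtered_lrelu_kernel_spec best matching a given parameter struct and
// shared-memory budget, choose_filtered_lrelu_act_kernel covers the in-place
// activation-only variant, and copy_filters presumably stages the up/down FIR filter
// taps on the device before the main kernel is launched.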
85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 
54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 
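    // Note on the branch below: the generic "large" path (spec.tileOutW < 0) launches
    // 4x32 thread blocks sized from the output extents, while the specialized
    // small-filter kernels use 256-thread blocks tiled by tileOutW x tileOutH; in both
    // cases the launchMinor/launchMajor factors computed above are folded into grid.x
    // and grid.z.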
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /scripts/eval.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | from omegaconf import OmegaConf 3 | import torch 4 | from pathlib import Path 5 | import numpy as np 6 | from PIL import Image 7 | import train 8 | from joblib import Parallel, delayed 9 | from tqdm import tqdm 10 | from datetime import datetime 11 | import typer 12 | 13 | @torch.no_grad() 14 | def run_diffusion(diffusion_prior, clip_model, text_samples, cond_scale, n_samples, truncation=1.0): 15 | diffusion_prior.eval() 16 | text_features = clip_model.embed_text(text_samples) 17 | text_features = text_features.repeat_interleave(n_samples, dim=0) 18 | pred_w = diffusion_prior.sample(text_features, cond_scale=cond_scale, show_progress=True, truncation=truncation) 19 | return pred_w 20 | 21 | @torch.no_grad() 22 | def run_synthesis(pred_w, sample, G, clip_model, sort_by_clip=True): 23 | 24 | text_features = clip_model.embed_text(sample) 25 | text_features = text_features.tile(pred_w.shape[0], 1) 26 | 27 | images = G.synthesis(pred_w) 28 | pred_w_clip_features = clip_model.embed_image(images) 29 | 30 | similarity = torch.cosine_similarity(pred_w_clip_features, text_features) 31 | 32 | if sort_by_clip: 33 | similarity, idxs = torch.sort(similarity, descending=True) 34 | images = images[idxs, ...] 
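        # With sort_by_clip enabled the returned batch is ordered best-to-worst by CLIP
        # similarity to the prompt, so callers such as run_eval below can treat
        # images[0] / similarity[0] as the "best" sample for a given text prompt.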
35 | 36 | return images, similarity 37 | 38 | 39 | def run_eval(cond_scale, skips, text_samples, checkpoint, cfg_file, device, n_samples, write_results, truncation): 40 | 41 | if write_results: 42 | output_dir = Path(f"eval_results/{skips}_{cond_scale}") 43 | output_dir.mkdir(exist_ok=True, parents=True) 44 | best_dir = output_dir/"best" 45 | best_dir.mkdir(exist_ok=True) 46 | best_file = best_dir/"similarity.csv" 47 | best_file.unlink(missing_ok=True) 48 | 49 | cfg = OmegaConf.load(cfg_file) 50 | 51 | G, clip_model, trainer = train.load_models(cfg, device) 52 | trainer.load_state_dict(torch.load(checkpoint, map_location="cpu")["state_dict"], strict=False) 53 | diffusion_prior = trainer.ema_diffusion_prior.ema_model 54 | diffusion_prior.set_timestep_skip(skips) 55 | 56 | def save_im(im, f_name): 57 | Image.fromarray(im).save(f_name) 58 | 59 | best_scores = [] 60 | 61 | w_pred = run_diffusion(diffusion_prior, clip_model, text_samples, cond_scale, n_samples, truncation=truncation) 62 | w_preds = w_pred.split(n_samples) 63 | 64 | with Parallel(n_jobs=n_samples) as parallel: 65 | for sample, w_pred in tqdm(zip(text_samples, w_preds), total=len(text_samples)): 66 | if write_results: 67 | label = sample.lower().replace(' ', '-') 68 | sample_dir = output_dir/"samples"/label 69 | sample_dir.mkdir(exist_ok=True, parents=True) 70 | similarity_file = sample_dir/"similarity.csv" 71 | 72 | images, similarity = run_synthesis(w_pred, sample, G, clip_model) 73 | images = 255*(images.clamp(-1,1)*0.5 + 0.5).permute(0, 2, 3, 1).cpu() 74 | images = images.numpy().astype(np.uint8) 75 | 76 | if write_results: 77 | filenames = [sample_dir/f'{idx:03}.png' for idx in range(images.shape[0])] 78 | 79 | with open(similarity_file, 'wt') as f: 80 | for idx, (s, filename) in enumerate(zip(similarity, filenames)): 81 | f.write(f'{filename.stem}, {s}\n') 82 | 83 | parallel(delayed(save_im)(im, f_name) for im, f_name in zip(images, filenames)) 84 | 85 | best_im = images[0] 86 | best_score = similarity[0] 87 | mean_score = similarity.mean() 88 | # Is std worth recording? 
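# When write_results is set, each caption produces a directory layout like the following
# (illustrative, for skips=250, cond_scale=2.0, n_samples=16 and the caption
# "a photograph of a man with a beard"):
#   eval_results/250_2.0/samples/a-photograph-of-a-man-with-a-beard/000.png ... 015.png
#   eval_results/250_2.0/samples/a-photograph-of-a-man-with-a-beard/similarity.csv
#   eval_results/250_2.0/best/a-photograph-of-a-man-with-a-beard.png
#   eval_results/250_2.0/best/similarity.csv   (one '"caption", best_score' row per caption)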
89 | if write_results: 90 | best_filename = best_dir/f'{label}.png' 91 | Image.fromarray(best_im).save(best_filename) 92 | with open(best_file, 'at+') as f: 93 | f.write(f'"{sample}", {best_score}\n') 94 | best_scores.append(best_score.unsqueeze(0)) 95 | 96 | mean_score = torch.cat(best_scores, dim=0).mean().cpu() 97 | 98 | print('------------') 99 | print(f'Timestep skips: {skips}, condition_scale: {cond_scale}') 100 | print(f'Mean CLIP score: {mean_score}') 101 | print('------------') 102 | 103 | return mean_score 104 | 105 | def main( 106 | skips:List[int]=[1,100,250], 107 | cond_scales:List[float]=[1,1.05,1.1,1.2,1.3,1.5,1.75,2,2.5,3,4,5,10], 108 | write_results:bool = True, 109 | checkpoint:str = "best.ckpt", 110 | cfg_file:str= "best.yaml", 111 | device:str= "cuda:0", 112 | n_samples:int= 16, 113 | truncation:float=1.0, 114 | ): 115 | 116 | typer.echo(f"Running skips: {skips}") 117 | typer.echo(f"Running cond scales: {cond_scales}") 118 | 119 | with open("test.txt", 'rt') as f: 120 | text_samples = f.read().splitlines() 121 | text_samples = ["a photograph of " + x for x in text_samples] 122 | 123 | score_filename = f"all_scores_{datetime.now()}.csv" 124 | if write_results: 125 | with open(score_filename, 'wt') as f: 126 | f.write(f"skips, cond_scale, score\n") 127 | 128 | for s in skips: 129 | for c in cond_scales: 130 | score = run_eval(c, s, text_samples, checkpoint, cfg_file, device, n_samples, write_results, truncation=truncation) 131 | if write_results: 132 | with open(score_filename, 'at+') as f: 133 | f.write(f"{s}, {c}, {score}\n") 134 | 135 | 136 | if __name__ == "__main__": 137 | typer.run(main) -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/perceptual_path_length.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Perceptual Path Length (PPL) from the paper "A Style-Based Generator 10 | Architecture for Generative Adversarial Networks". Matches the original 11 | implementation by Karras et al. at 12 | https://github.com/NVlabs/stylegan/blob/master/metrics/perceptual_path_length.py""" 13 | 14 | import copy 15 | import numpy as np 16 | import torch 17 | from . import metric_utils 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | # Spherical interpolation of a batch of vectors. 
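# The implementation below is the standard slerp identity written in an orthonormal
# basis of span{a, b}: with theta = arccos(a.b) for unit vectors a, b, and c the unit
# component of b orthogonal to a,
#   slerp(a, b, t) = a*cos(t*theta) + c*sin(t*theta)
# which equals the usual sin((1-t)*theta)/sin(theta) * a + sin(t*theta)/sin(theta) * b.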
22 | def slerp(a, b, t): 23 | a = a / a.norm(dim=-1, keepdim=True) 24 | b = b / b.norm(dim=-1, keepdim=True) 25 | d = (a * b).sum(dim=-1, keepdim=True) 26 | p = t * torch.acos(d) 27 | c = b - d * a 28 | c = c / c.norm(dim=-1, keepdim=True) 29 | d = a * torch.cos(p) + c * torch.sin(p) 30 | d = d / d.norm(dim=-1, keepdim=True) 31 | return d 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | class PPLSampler(torch.nn.Module): 36 | def __init__(self, G, G_kwargs, epsilon, space, sampling, crop, vgg16): 37 | assert space in ['z', 'w'] 38 | assert sampling in ['full', 'end'] 39 | super().__init__() 40 | self.G = copy.deepcopy(G) 41 | self.G_kwargs = G_kwargs 42 | self.epsilon = epsilon 43 | self.space = space 44 | self.sampling = sampling 45 | self.crop = crop 46 | self.vgg16 = copy.deepcopy(vgg16) 47 | 48 | def forward(self, c): 49 | # Generate random latents and interpolation t-values. 50 | t = torch.rand([c.shape[0]], device=c.device) * (1 if self.sampling == 'full' else 0) 51 | z0, z1 = torch.randn([c.shape[0] * 2, self.G.z_dim], device=c.device).chunk(2) 52 | 53 | # Interpolate in W or Z. 54 | if self.space == 'w': 55 | w0, w1 = self.G.mapping(z=torch.cat([z0,z1]), c=torch.cat([c,c])).chunk(2) 56 | wt0 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2)) 57 | wt1 = w0.lerp(w1, t.unsqueeze(1).unsqueeze(2) + self.epsilon) 58 | else: # space == 'z' 59 | zt0 = slerp(z0, z1, t.unsqueeze(1)) 60 | zt1 = slerp(z0, z1, t.unsqueeze(1) + self.epsilon) 61 | wt0, wt1 = self.G.mapping(z=torch.cat([zt0,zt1]), c=torch.cat([c,c])).chunk(2) 62 | 63 | # Randomize noise buffers. 64 | for name, buf in self.G.named_buffers(): 65 | if name.endswith('.noise_const'): 66 | buf.copy_(torch.randn_like(buf)) 67 | 68 | # Generate images. 69 | img = self.G.synthesis(ws=torch.cat([wt0,wt1]), noise_mode='const', force_fp32=True, **self.G_kwargs) 70 | 71 | # Center crop. 72 | if self.crop: 73 | assert img.shape[2] == img.shape[3] 74 | c = img.shape[2] // 8 75 | img = img[:, :, c*3 : c*7, c*2 : c*6] 76 | 77 | # Downsample to 256x256. 78 | factor = self.G.img_resolution // 256 79 | if factor > 1: 80 | img = img.reshape([-1, img.shape[1], img.shape[2] // factor, factor, img.shape[3] // factor, factor]).mean([3, 5]) 81 | 82 | # Scale dynamic range from [-1,1] to [0,255]. 83 | img = (img + 1) * (255 / 2) 84 | if self.G.img_channels == 1: 85 | img = img.repeat([1, 3, 1, 1]) 86 | 87 | # Evaluate differential LPIPS. 88 | lpips_t0, lpips_t1 = self.vgg16(img, resize_images=False, return_lpips=True).chunk(2) 89 | dist = (lpips_t0 - lpips_t1).square().sum(1) / self.epsilon ** 2 90 | return dist 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def compute_ppl(opts, num_samples, epsilon, space, sampling, crop, batch_size): 95 | vgg16_url = 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/metrics/vgg16.pkl' 96 | vgg16 = metric_utils.get_feature_detector(vgg16_url, num_gpus=opts.num_gpus, rank=opts.rank, verbose=opts.progress.verbose) 97 | 98 | # Setup sampler and labels. 99 | sampler = PPLSampler(G=opts.G, G_kwargs=opts.G_kwargs, epsilon=epsilon, space=space, sampling=sampling, crop=crop, vgg16=vgg16) 100 | sampler.eval().requires_grad_(False).to(opts.device) 101 | c_iter = metric_utils.iterate_random_labels(opts=opts, batch_size=batch_size) 102 | 103 | # Sampling loop. 
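# Each call to the sampler below yields one differential-LPIPS estimate per latent pair,
#   dist_i = lpips(g(w_t), g(w_{t+eps}))^2 / eps^2,
# the broadcast loop makes every rank accumulate all ranks' samples, and rank 0 finally
# trims to the 1st-99th percentile before taking the mean.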
104 | dist = [] 105 | progress = opts.progress.sub(tag='ppl sampling', num_items=num_samples) 106 | for batch_start in range(0, num_samples, batch_size * opts.num_gpus): 107 | progress.update(batch_start) 108 | x = sampler(next(c_iter)) 109 | for src in range(opts.num_gpus): 110 | y = x.clone() 111 | if opts.num_gpus > 1: 112 | torch.distributed.broadcast(y, src=src) 113 | dist.append(y) 114 | progress.update(num_samples) 115 | 116 | # Compute PPL. 117 | if opts.rank != 0: 118 | return float('nan') 119 | dist = torch.cat(dist)[:num_samples].cpu().numpy() 120 | lo = np.percentile(dist, 1, interpolation='lower') 121 | hi = np.percentile(dist, 99, interpolation='higher') 122 | ppl = np.extract(np.logical_and(dist >= lo, dist <= hi), dist).mean() 123 | return float(ppl) 124 | 125 | #---------------------------------------------------------------------------- 126 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gui_utils/text_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import functools 10 | from typing import Optional 11 | 12 | import dnnlib 13 | import numpy as np 14 | import PIL.Image 15 | import PIL.ImageFont 16 | import scipy.ndimage 17 | 18 | from . import gl_utils 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def get_default_font(): 23 | url = 'http://fonts.gstatic.com/s/opensans/v17/mem8YaGs126MiZpBA-U1UpcaXcl0Aw.ttf' # Open Sans regular 24 | return dnnlib.util.open_url(url, return_filename=True) 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | @functools.lru_cache(maxsize=None) 29 | def get_pil_font(font=None, size=32): 30 | if font is None: 31 | font = get_default_font() 32 | return PIL.ImageFont.truetype(font=font, size=size) 33 | 34 | #---------------------------------------------------------------------------- 35 | 36 | def get_array(string, *, dropshadow_radius: int=None, **kwargs): 37 | if dropshadow_radius is not None: 38 | offset_x = int(np.ceil(dropshadow_radius*2/3)) 39 | offset_y = int(np.ceil(dropshadow_radius*2/3)) 40 | return _get_array_priv(string, dropshadow_radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 41 | else: 42 | return _get_array_priv(string, **kwargs) 43 | 44 | @functools.lru_cache(maxsize=10000) 45 | def _get_array_priv( 46 | string: str, *, 47 | size: int = 32, 48 | max_width: Optional[int]=None, 49 | max_height: Optional[int]=None, 50 | min_size=10, 51 | shrink_coef=0.8, 52 | dropshadow_radius: int=None, 53 | offset_x: int=None, 54 | offset_y: int=None, 55 | **kwargs 56 | ): 57 | cur_size = size 58 | array = None 59 | while True: 60 | if dropshadow_radius is not None: 61 | # separate implementation for dropshadow text rendering 62 | array = _get_array_impl_dropshadow(string, size=cur_size, radius=dropshadow_radius, offset_x=offset_x, offset_y=offset_y, **kwargs) 63 | else: 64 | array = _get_array_impl(string, size=cur_size, **kwargs) 65 | height, width, _ = 
array.shape 66 | if (max_width is None or width <= max_width) and (max_height is None or height <= max_height) or (cur_size <= min_size): 67 | break 68 | cur_size = max(int(cur_size * shrink_coef), min_size) 69 | return array 70 | 71 | #---------------------------------------------------------------------------- 72 | 73 | @functools.lru_cache(maxsize=10000) 74 | def _get_array_impl(string, *, font=None, size=32, outline=0, outline_pad=3, outline_coef=3, outline_exp=2, line_pad: int=None): 75 | pil_font = get_pil_font(font=font, size=size) 76 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 77 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 78 | width = max(line.shape[1] for line in lines) 79 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 80 | line_spacing = line_pad if line_pad is not None else size // 2 81 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 82 | mask = np.concatenate(lines, axis=0) 83 | alpha = mask 84 | if outline > 0: 85 | mask = np.pad(mask, int(np.ceil(outline * outline_pad)), mode='constant', constant_values=0) 86 | alpha = mask.astype(np.float32) / 255 87 | alpha = scipy.ndimage.gaussian_filter(alpha, outline) 88 | alpha = 1 - np.maximum(1 - alpha * outline_coef, 0) ** outline_exp 89 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 90 | alpha = np.maximum(alpha, mask) 91 | return np.stack([mask, alpha], axis=-1) 92 | 93 | #---------------------------------------------------------------------------- 94 | 95 | @functools.lru_cache(maxsize=10000) 96 | def _get_array_impl_dropshadow(string, *, font=None, size=32, radius: int, offset_x: int, offset_y: int, line_pad: int=None, **kwargs): 97 | assert (offset_x > 0) and (offset_y > 0) 98 | pil_font = get_pil_font(font=font, size=size) 99 | lines = [pil_font.getmask(line, 'L') for line in string.split('\n')] 100 | lines = [np.array(line, dtype=np.uint8).reshape([line.size[1], line.size[0]]) for line in lines] 101 | width = max(line.shape[1] for line in lines) 102 | lines = [np.pad(line, ((0, 0), (0, width - line.shape[1])), mode='constant') for line in lines] 103 | line_spacing = line_pad if line_pad is not None else size // 2 104 | lines = [np.pad(line, ((0, line_spacing), (0, 0)), mode='constant') for line in lines[:-1]] + lines[-1:] 105 | mask = np.concatenate(lines, axis=0) 106 | alpha = mask 107 | 108 | mask = np.pad(mask, 2*radius + max(abs(offset_x), abs(offset_y)), mode='constant', constant_values=0) 109 | alpha = mask.astype(np.float32) / 255 110 | alpha = scipy.ndimage.gaussian_filter(alpha, radius) 111 | alpha = 1 - np.maximum(1 - alpha * 1.5, 0) ** 1.4 112 | alpha = (alpha * 255 + 0.5).clip(0, 255).astype(np.uint8) 113 | alpha = np.pad(alpha, [(offset_y, 0), (offset_x, 0)], mode='constant')[:-offset_y, :-offset_x] 114 | alpha = np.maximum(alpha, mask) 115 | return np.stack([mask, alpha], axis=-1) 116 | 117 | #---------------------------------------------------------------------------- 118 | 119 | @functools.lru_cache(maxsize=10000) 120 | def get_texture(string, bilinear=True, mipmap=True, **kwargs): 121 | return gl_utils.Texture(image=get_array(string, **kwargs), bilinear=bilinear, mipmap=mipmap) 122 | 123 | #---------------------------------------------------------------------------- 124 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/equivariance_widget.py: 
-------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import numpy as np 10 | import imgui 11 | import dnnlib 12 | from gui_utils import imgui_utils 13 | 14 | #---------------------------------------------------------------------------- 15 | 16 | class EquivarianceWidget: 17 | def __init__(self, viz): 18 | self.viz = viz 19 | self.xlate = dnnlib.EasyDict(x=0, y=0, anim=False, round=False, speed=1e-2) 20 | self.xlate_def = dnnlib.EasyDict(self.xlate) 21 | self.rotate = dnnlib.EasyDict(val=0, anim=False, speed=5e-3) 22 | self.rotate_def = dnnlib.EasyDict(self.rotate) 23 | self.opts = dnnlib.EasyDict(untransform=False) 24 | self.opts_def = dnnlib.EasyDict(self.opts) 25 | 26 | @imgui_utils.scoped_by_object_id 27 | def __call__(self, show=True): 28 | viz = self.viz 29 | if show: 30 | imgui.text('Translate') 31 | imgui.same_line(viz.label_w) 32 | with imgui_utils.item_width(viz.font_size * 8): 33 | _changed, (self.xlate.x, self.xlate.y) = imgui.input_float2('##xlate', self.xlate.x, self.xlate.y, format='%.4f') 34 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 35 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag fast##xlate', width=viz.button_w) 36 | if dragging: 37 | self.xlate.x += dx / viz.font_size * 2e-2 38 | self.xlate.y += dy / viz.font_size * 2e-2 39 | imgui.same_line() 40 | _clicked, dragging, dx, dy = imgui_utils.drag_button('Drag slow##xlate', width=viz.button_w) 41 | if dragging: 42 | self.xlate.x += dx / viz.font_size * 4e-4 43 | self.xlate.y += dy / viz.font_size * 4e-4 44 | imgui.same_line() 45 | _clicked, self.xlate.anim = imgui.checkbox('Anim##xlate', self.xlate.anim) 46 | imgui.same_line() 47 | _clicked, self.xlate.round = imgui.checkbox('Round##xlate', self.xlate.round) 48 | imgui.same_line() 49 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), imgui_utils.grayed_out(not self.xlate.anim): 50 | changed, speed = imgui.slider_float('##xlate_speed', self.xlate.speed, 0, 0.5, format='Speed %.5f', power=5) 51 | if changed: 52 | self.xlate.speed = speed 53 | imgui.same_line() 54 | if imgui_utils.button('Reset##xlate', width=-1, enabled=(self.xlate != self.xlate_def)): 55 | self.xlate = dnnlib.EasyDict(self.xlate_def) 56 | 57 | if show: 58 | imgui.text('Rotate') 59 | imgui.same_line(viz.label_w) 60 | with imgui_utils.item_width(viz.font_size * 8): 61 | _changed, self.rotate.val = imgui.input_float('##rotate', self.rotate.val, format='%.4f') 62 | imgui.same_line(viz.label_w + viz.font_size * 8 + viz.spacing) 63 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag fast##rotate', width=viz.button_w) 64 | if dragging: 65 | self.rotate.val += dx / viz.font_size * 2e-2 66 | imgui.same_line() 67 | _clicked, dragging, dx, _dy = imgui_utils.drag_button('Drag slow##rotate', width=viz.button_w) 68 | if dragging: 69 | self.rotate.val += dx / viz.font_size * 4e-4 70 | imgui.same_line() 71 | _clicked, self.rotate.anim = imgui.checkbox('Anim##rotate', self.rotate.anim) 72 | imgui.same_line() 73 | with imgui_utils.item_width(-1 - viz.button_w - viz.spacing), 
imgui_utils.grayed_out(not self.rotate.anim): 74 | changed, speed = imgui.slider_float('##rotate_speed', self.rotate.speed, -1, 1, format='Speed %.4f', power=3) 75 | if changed: 76 | self.rotate.speed = speed 77 | imgui.same_line() 78 | if imgui_utils.button('Reset##rotate', width=-1, enabled=(self.rotate != self.rotate_def)): 79 | self.rotate = dnnlib.EasyDict(self.rotate_def) 80 | 81 | if show: 82 | imgui.set_cursor_pos_x(imgui.get_content_region_max()[0] - 1 - viz.button_w*1 - viz.font_size*16) 83 | _clicked, self.opts.untransform = imgui.checkbox('Untransform', self.opts.untransform) 84 | imgui.same_line(imgui.get_content_region_max()[0] - 1 - viz.button_w) 85 | if imgui_utils.button('Reset##opts', width=-1, enabled=(self.opts != self.opts_def)): 86 | self.opts = dnnlib.EasyDict(self.opts_def) 87 | 88 | if self.xlate.anim: 89 | c = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 90 | t = c.copy() 91 | if np.max(np.abs(t)) < 1e-4: 92 | t += 1 93 | t *= 0.1 / np.hypot(*t) 94 | t += c[::-1] * [1, -1] 95 | d = t - c 96 | d *= (viz.frame_delta * self.xlate.speed) / np.hypot(*d) 97 | self.xlate.x += d[0] 98 | self.xlate.y += d[1] 99 | 100 | if self.rotate.anim: 101 | self.rotate.val += viz.frame_delta * self.rotate.speed 102 | 103 | pos = np.array([self.xlate.x, self.xlate.y], dtype=np.float64) 104 | if self.xlate.round and 'img_resolution' in viz.result: 105 | pos = np.rint(pos * viz.result.img_resolution) / viz.result.img_resolution 106 | angle = self.rotate.val * np.pi * 2 107 | 108 | viz.args.input_transform = [ 109 | [np.cos(angle), np.sin(angle), pos[0]], 110 | [-np.sin(angle), np.cos(angle), pos[1]], 111 | [0, 0, 1]] 112 | 113 | viz.args.update(untransform=self.opts.untransform) 114 | 115 | #---------------------------------------------------------------------------- 116 | -------------------------------------------------------------------------------- /clip2latent/models.py: -------------------------------------------------------------------------------- 1 | import clip 2 | from omegaconf import DictConfig, OmegaConf 3 | import torch 4 | import torch.nn.functional as F 5 | from dalle2_pytorch import DiffusionPriorNetwork 6 | from dalle2_pytorch.train import DiffusionPriorTrainer 7 | from torchvision import transforms 8 | from io import BytesIO 9 | import requests 10 | from pathlib import Path 11 | 12 | from clip2latent.latent_prior import LatentPrior, WPlusPriorNetwork 13 | 14 | 15 | def load_sg(network_pkl): 16 | import sys 17 | code_folder = Path(__file__).parent 18 | sg3_path = str(code_folder/"stylegan3") 19 | sys.path.append(sg3_path) 20 | import dnnlib 21 | import legacy 22 | 23 | with dnnlib.util.open_url(network_pkl) as f: 24 | G = legacy.load_network_pkl(f)['G_ema'] # type: ignore 25 | return G 26 | 27 | def is_url(path): 28 | if isinstance(path, str) and path.startswith("http"): 29 | return True 30 | else: 31 | return False 32 | 33 | def load_remote_cfg(cfg): 34 | r = requests.get(cfg) 35 | r.raise_for_status() 36 | f = BytesIO(r.content) 37 | return OmegaConf.load(f) 38 | 39 | class Clipper(torch.nn.Module): 40 | def __init__(self, clip_variant): 41 | super().__init__() 42 | clip_model, _ = clip.load(clip_variant, device="cpu") 43 | self.clip = clip_model 44 | self.normalize = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 45 | self.clip_size = (224,224) 46 | 47 | def embed_image(self, image): 48 | """Expects images in -1 to 1 range""" 49 | clip_in = F.interpolate(image, self.clip_size, mode="area") 50 | clip_in = 
(0.5*clip_in + 0.5).clamp(0,1) # rescale from [-1,1] to [0,1]; CLIP normalisation is applied once below 51 | return self.clip.encode_image(self.normalize(clip_in)) 52 | 53 | def embed_text(self, text_samples): 54 | text = clip.tokenize(text_samples).to(self._get_device()) 55 | return self.clip.encode_text(text) 56 | 57 | def _get_device(self): 58 | for p in self.clip.parameters(): 59 | return p.device 60 | 61 | class Clip2StyleGAN(torch.nn.Module): 62 | """A wrapper around the component models to create an end-to-end text2image model""" 63 | def __init__(self, cfg, device, checkpoint=None) -> None: 64 | super().__init__() 65 | 66 | if not isinstance(cfg, DictConfig): 67 | if is_url(cfg): 68 | cfg = load_remote_cfg(cfg) 69 | else: 70 | cfg = OmegaConf.load(cfg) 71 | 72 | G, clip_model, trainer = load_models(cfg, device) 73 | if checkpoint is not None: 74 | if is_url(checkpoint): 75 | state_dict = torch.hub.load_state_dict_from_url(checkpoint, map_location="cpu") 76 | else: 77 | state_dict = torch.load(checkpoint, map_location="cpu") 78 | trainer.load_state_dict(state_dict["state_dict"], strict=False) 79 | diffusion_prior = trainer.ema_diffusion_prior.ema_model 80 | self.G = G 81 | self.clip_model = clip_model 82 | self.diffusion_prior = diffusion_prior 83 | 84 | def forward(self, text_samples, n_samples_per_txt=1, cond_scale=1.0, truncation=1.0, skips=1, clip_sort=False, edit=None, show_progress=True): 85 | self.diffusion_prior.set_timestep_skip(skips) 86 | text_features = self.clip_model.embed_text(text_samples) 87 | if n_samples_per_txt > 1: 88 | text_features = text_features.repeat_interleave(n_samples_per_txt, dim=0) 89 | pred_w = self.diffusion_prior.sample(text_features, cond_scale=cond_scale, show_progress=show_progress, truncation=truncation) 90 | 91 | if edit is not None: 92 | pred_w = pred_w + edit.to(pred_w.device) 93 | images = self.G.synthesis(pred_w) 94 | 95 | pred_w_clip_features = self.clip_model.embed_image(images) 96 | similarity = torch.cosine_similarity(pred_w_clip_features, text_features) 97 | if clip_sort: 98 | similarity, idxs = torch.sort(similarity, descending=True) 99 | images = images[idxs, ...] 100 | 101 | return images, similarity 102 | 103 | def load_models(cfg, device, stats=None): 104 | """Load the diffusion trainer and eval models based on a config 105 | 106 | If the model requires statistics for embed or latent normalisation 107 | then these should be passed into this function, unless the state of 108 | the model is to be loaded from a state_dict (which will contain these 109 | statistics), in which case the stats will be filled with dummy values.
110 | """ 111 | if cfg.data.n_latents > 1: 112 | prior_network = WPlusPriorNetwork(n_latents=cfg.data.n_latents, **cfg.model.network).to(device) 113 | else: 114 | prior_network = DiffusionPriorNetwork(**cfg.model.network).to(device) 115 | 116 | embed_stats = latent_stats = (None, None) 117 | if stats is None: 118 | # Make dummy stats assuming they will be loaded from the state dict 119 | clip_dummy_stat = torch.zeros(cfg.model.network.dim,1) 120 | w_dummy_stat = torch.zeros(cfg.model.network.dim) 121 | if cfg.data.n_latents > 1: 122 | w_dummy_stat = w_dummy_stat.unsqueeze(0).tile(1, cfg.data.n_latents) 123 | stats = {"clip_features": (clip_dummy_stat, clip_dummy_stat), "w": (w_dummy_stat, w_dummy_stat)} 124 | 125 | if cfg.train.znorm_embed: 126 | embed_stats = stats["clip_features"] 127 | if cfg.train.znorm_latent: 128 | latent_stats = stats["w"] 129 | 130 | diffusion_prior = LatentPrior( 131 | prior_network, 132 | num_latents=cfg.data.n_latents, 133 | latent_repeats=cfg.data.latent_repeats, 134 | latent_stats=latent_stats, 135 | embed_stats=embed_stats, 136 | **cfg.model.diffusion).to(device) 137 | diffusion_prior.cfg = cfg 138 | 139 | # Load eval models 140 | G = load_sg(cfg.data.sg_pkl).to(device) 141 | clip_model = Clipper(cfg.data.clip_variant).to(device) 142 | 143 | trainer = DiffusionPriorTrainer( 144 | diffusion_prior=diffusion_prior, 145 | lr=cfg.train.lr, 146 | wd=cfg.train.weight_decay, 147 | ema_beta=cfg.train.ema_beta, 148 | ema_update_every=cfg.train.ema_update_every, 149 | ).to(device) 150 | 151 | return G, clip_model, trainer 152 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gen_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Generate images using pretrained network pickle.""" 10 | 11 | import os 12 | import re 13 | from typing import List, Optional, Tuple, Union 14 | 15 | import click 16 | import dnnlib 17 | import numpy as np 18 | import PIL.Image 19 | import torch 20 | 21 | import legacy 22 | 23 | #---------------------------------------------------------------------------- 24 | 25 | def parse_range(s: Union[str, List]) -> List[int]: 26 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 27 | 28 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 29 | ''' 30 | if isinstance(s, list): return s 31 | ranges = [] 32 | range_re = re.compile(r'^(\d+)-(\d+)$') 33 | for p in s.split(','): 34 | m = range_re.match(p) 35 | if m: 36 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 37 | else: 38 | ranges.append(int(p)) 39 | return ranges 40 | 41 | #---------------------------------------------------------------------------- 42 | 43 | def parse_vec2(s: Union[str, Tuple[float, float]]) -> Tuple[float, float]: 44 | '''Parse a floating point 2-vector of syntax 'a,b'. 
45 | 46 | Example: 47 | '0,1' returns (0,1) 48 | ''' 49 | if isinstance(s, tuple): return s 50 | parts = s.split(',') 51 | if len(parts) == 2: 52 | return (float(parts[0]), float(parts[1])) 53 | raise ValueError(f'cannot parse 2-vector {s}') 54 | 55 | #---------------------------------------------------------------------------- 56 | 57 | def make_transform(translate: Tuple[float,float], angle: float): 58 | m = np.eye(3) 59 | s = np.sin(angle/360.0*np.pi*2) 60 | c = np.cos(angle/360.0*np.pi*2) 61 | m[0][0] = c 62 | m[0][1] = s 63 | m[0][2] = translate[0] 64 | m[1][0] = -s 65 | m[1][1] = c 66 | m[1][2] = translate[1] 67 | return m 68 | 69 | #---------------------------------------------------------------------------- 70 | 71 | @click.command() 72 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 73 | @click.option('--seeds', type=parse_range, help='List of random seeds (e.g., \'0,1,4-6\')', required=True) 74 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 75 | @click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)') 76 | @click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True) 77 | @click.option('--translate', help='Translate XY-coordinate (e.g. \'0.3,1\')', type=parse_vec2, default='0,0', show_default=True, metavar='VEC2') 78 | @click.option('--rotate', help='Rotation angle in degrees', type=float, default=0, show_default=True, metavar='ANGLE') 79 | @click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR') 80 | def generate_images( 81 | network_pkl: str, 82 | seeds: List[int], 83 | truncation_psi: float, 84 | noise_mode: str, 85 | outdir: str, 86 | translate: Tuple[float,float], 87 | rotate: float, 88 | class_idx: Optional[int] 89 | ): 90 | """Generate images using pretrained network pickle. 91 | 92 | Examples: 93 | 94 | \b 95 | # Generate an image using pre-trained AFHQv2 model ("Ours" in Figure 1, left). 96 | python gen_images.py --outdir=out --trunc=1 --seeds=2 \\ 97 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 98 | 99 | \b 100 | # Generate uncurated images with truncation using the MetFaces-U dataset 101 | python gen_images.py --outdir=out --trunc=0.7 --seeds=600-605 \\ 102 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-t-metfacesu-1024x1024.pkl 103 | """ 104 | 105 | print('Loading networks from "%s"...' % network_pkl) 106 | device = torch.device('cuda') 107 | with dnnlib.util.open_url(network_pkl) as f: 108 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 109 | 110 | os.makedirs(outdir, exist_ok=True) 111 | 112 | # Labels. 113 | label = torch.zeros([1, G.c_dim], device=device) 114 | if G.c_dim != 0: 115 | if class_idx is None: 116 | raise click.ClickException('Must specify class label with --class when using a conditional network') 117 | label[:, class_idx] = 1 118 | else: 119 | if class_idx is not None: 120 | print ('warn: --class=lbl ignored when running on an unconditional network') 121 | 122 | # Generate images. 123 | for seed_idx, seed in enumerate(seeds): 124 | print('Generating image for seed %d (%d/%d) ...' 
% (seed, seed_idx, len(seeds))) 125 | z = torch.from_numpy(np.random.RandomState(seed).randn(1, G.z_dim)).to(device) 126 | 127 | # Construct an inverse rotation/translation matrix and pass to the generator. The 128 | # generator expects this matrix as an inverse to avoid potentially failing numerical 129 | # operations in the network. 130 | if hasattr(G.synthesis, 'input'): 131 | m = make_transform(translate, rotate) 132 | m = np.linalg.inv(m) 133 | G.synthesis.input.transform.copy_(torch.from_numpy(m)) 134 | 135 | img = G(z, label, truncation_psi=truncation_psi, noise_mode=noise_mode) 136 | img = (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8) 137 | PIL.Image.fromarray(img[0].cpu().numpy(), 'RGB').save(f'{outdir}/seed{seed:04d}.png') 138 | 139 | 140 | #---------------------------------------------------------------------------- 141 | 142 | if __name__ == "__main__": 143 | generate_images() # pylint: disable=no-value-for-parameter 144 | 145 | #---------------------------------------------------------------------------- 146 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/metrics/metric_main.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Main API for computing and reporting quality metrics.""" 10 | 11 | import os 12 | import time 13 | import json 14 | import torch 15 | import dnnlib 16 | 17 | from . import metric_utils 18 | from . import frechet_inception_distance 19 | from . import kernel_inception_distance 20 | from . import precision_recall 21 | from . import perceptual_path_length 22 | from . import inception_score 23 | from . import equivariance 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _metric_dict = dict() # name => fn 28 | 29 | def register_metric(fn): 30 | assert callable(fn) 31 | _metric_dict[fn.__name__] = fn 32 | return fn 33 | 34 | def is_valid_metric(metric): 35 | return metric in _metric_dict 36 | 37 | def list_valid_metrics(): 38 | return list(_metric_dict.keys()) 39 | 40 | #---------------------------------------------------------------------------- 41 | 42 | def calc_metric(metric, **kwargs): # See metric_utils.MetricOptions for the full list of arguments. 43 | assert is_valid_metric(metric) 44 | opts = metric_utils.MetricOptions(**kwargs) 45 | 46 | # Calculate. 47 | start_time = time.time() 48 | results = _metric_dict[metric](opts) 49 | total_time = time.time() - start_time 50 | 51 | # Broadcast results. 52 | for key, value in list(results.items()): 53 | if opts.num_gpus > 1: 54 | value = torch.as_tensor(value, dtype=torch.float64, device=opts.device) 55 | torch.distributed.broadcast(tensor=value, src=0) 56 | value = float(value.cpu()) 57 | results[key] = value 58 | 59 | # Decorate with metadata. 
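# The decorated result looks like the following (values are illustrative):
#   {"results": {"fid50k_full": 2.79}, "metric": "fid50k_full", "total_time": 1234.5,
#    "total_time_str": "20m 34s", "num_gpus": 8}
# report_metric() below then adds snapshot_pkl and a timestamp and appends the dict as
# one JSON line of metric-fid50k_full.jsonl in the run directory.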
60 | return dnnlib.EasyDict( 61 | results = dnnlib.EasyDict(results), 62 | metric = metric, 63 | total_time = total_time, 64 | total_time_str = dnnlib.util.format_time(total_time), 65 | num_gpus = opts.num_gpus, 66 | ) 67 | 68 | #---------------------------------------------------------------------------- 69 | 70 | def report_metric(result_dict, run_dir=None, snapshot_pkl=None): 71 | metric = result_dict['metric'] 72 | assert is_valid_metric(metric) 73 | if run_dir is not None and snapshot_pkl is not None: 74 | snapshot_pkl = os.path.relpath(snapshot_pkl, run_dir) 75 | 76 | jsonl_line = json.dumps(dict(result_dict, snapshot_pkl=snapshot_pkl, timestamp=time.time())) 77 | print(jsonl_line) 78 | if run_dir is not None and os.path.isdir(run_dir): 79 | with open(os.path.join(run_dir, f'metric-{metric}.jsonl'), 'at') as f: 80 | f.write(jsonl_line + '\n') 81 | 82 | #---------------------------------------------------------------------------- 83 | # Recommended metrics. 84 | 85 | @register_metric 86 | def fid50k_full(opts): 87 | opts.dataset_kwargs.update(max_size=None, xflip=False) 88 | fid = frechet_inception_distance.compute_fid(opts, max_real=None, num_gen=50000) 89 | return dict(fid50k_full=fid) 90 | 91 | @register_metric 92 | def kid50k_full(opts): 93 | opts.dataset_kwargs.update(max_size=None, xflip=False) 94 | kid = kernel_inception_distance.compute_kid(opts, max_real=1000000, num_gen=50000, num_subsets=100, max_subset_size=1000) 95 | return dict(kid50k_full=kid) 96 | 97 | @register_metric 98 | def pr50k3_full(opts): 99 | opts.dataset_kwargs.update(max_size=None, xflip=False) 100 | precision, recall = precision_recall.compute_pr(opts, max_real=200000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 101 | return dict(pr50k3_full_precision=precision, pr50k3_full_recall=recall) 102 | 103 | @register_metric 104 | def ppl2_wend(opts): 105 | ppl = perceptual_path_length.compute_ppl(opts, num_samples=50000, epsilon=1e-4, space='w', sampling='end', crop=False, batch_size=2) 106 | return dict(ppl2_wend=ppl) 107 | 108 | @register_metric 109 | def eqt50k_int(opts): 110 | opts.G_kwargs.update(force_fp32=True) 111 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_int=True) 112 | return dict(eqt50k_int=psnr) 113 | 114 | @register_metric 115 | def eqt50k_frac(opts): 116 | opts.G_kwargs.update(force_fp32=True) 117 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqt_frac=True) 118 | return dict(eqt50k_frac=psnr) 119 | 120 | @register_metric 121 | def eqr50k(opts): 122 | opts.G_kwargs.update(force_fp32=True) 123 | psnr = equivariance.compute_equivariance_metrics(opts, num_samples=50000, batch_size=4, compute_eqr=True) 124 | return dict(eqr50k=psnr) 125 | 126 | #---------------------------------------------------------------------------- 127 | # Legacy metrics. 
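# A new metric would follow the same pattern as the entries above: decorate a function
# taking `opts` with @register_metric and return a dict of named values. A minimal,
# purely illustrative sketch (fid10k is not part of this module):
#
#   @register_metric
#   def fid10k(opts):
#       opts.dataset_kwargs.update(max_size=None)
#       fid = frechet_inception_distance.compute_fid(opts, max_real=10000, num_gen=10000)
#       return dict(fid10k=fid)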
128 | 129 | @register_metric 130 | def fid50k(opts): 131 | opts.dataset_kwargs.update(max_size=None) 132 | fid = frechet_inception_distance.compute_fid(opts, max_real=50000, num_gen=50000) 133 | return dict(fid50k=fid) 134 | 135 | @register_metric 136 | def kid50k(opts): 137 | opts.dataset_kwargs.update(max_size=None) 138 | kid = kernel_inception_distance.compute_kid(opts, max_real=50000, num_gen=50000, num_subsets=100, max_subset_size=1000) 139 | return dict(kid50k=kid) 140 | 141 | @register_metric 142 | def pr50k3(opts): 143 | opts.dataset_kwargs.update(max_size=None) 144 | precision, recall = precision_recall.compute_pr(opts, max_real=50000, num_gen=50000, nhood_size=3, row_batch_size=10000, col_batch_size=10000) 145 | return dict(pr50k3_precision=precision, pr50k3_recall=recall) 146 | 147 | @register_metric 148 | def is50k(opts): 149 | opts.dataset_kwargs.update(max_size=None, xflip=False) 150 | mean, std = inception_score.compute_is(opts, num_gen=50000, num_splits=10) 151 | return dict(is50k_mean=mean, is50k_std=std) 152 | 153 | #---------------------------------------------------------------------------- 154 | -------------------------------------------------------------------------------- /scripts/train.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from datetime import datetime 3 | from functools import partial 4 | from pathlib import Path 5 | 6 | import hydra 7 | import numpy as np 8 | import torch 9 | from omegaconf import OmegaConf 10 | from tqdm.auto import tqdm 11 | 12 | import wandb 13 | from clip2latent.data import load_data 14 | from clip2latent.models import load_models 15 | from clip2latent.train_utils import (compute_val, make_grid, 16 | make_image_val_data, make_text_val_data) 17 | 18 | logger = logging.getLogger(__name__) 19 | noop = lambda *args, **kwargs: None 20 | logfun = noop 21 | 22 | class Checkpointer(): 23 | """A small class to take care of saving checkpoints""" 24 | def __init__(self, directory, checkpoint_its): 25 | directory = Path(directory) 26 | self.directory = directory 27 | self.checkpoint_its = checkpoint_its 28 | if not directory.exists(): 29 | directory.mkdir(parents=True) 30 | 31 | def save_checkpoint(self, model, iteration): 32 | if iteration % self.checkpoint_its: 33 | return 34 | 35 | k_it = iteration // 1000 36 | filename = self.directory/f"{k_it:06}.ckpt" 37 | checkpoint = {"state_dict": model.state_dict()} 38 | if hasattr(model, "cfg"): 39 | checkpoint["cfg"] = model.cfg 40 | 41 | print(f"Saving checkpoint to {filename}") 42 | torch.save(checkpoint, filename) 43 | 44 | 45 | 46 | def validation(current_it, device, diffusion_prior, G, clip_model, val_data, samples_per_text): 47 | single_im = {"clip_features": val_data["val_im"]["clip_features"][0].unsqueeze(0)} 48 | captions = val_data["val_caption"] 49 | 50 | for input_data, key, cond_scale, repeats in zip( 51 | [val_data["val_im"], single_im, val_data["val_text"], val_data["val_text"]], 52 | ["image-similarity", "image-vars", "text2im", "text2im-super2"], 53 | [1.0, 1.0, 1.0, 2.0], 54 | [1, 8, samples_per_text, samples_per_text], 55 | ): 56 | tiled_data = input_data["clip_features"].repeat_interleave(repeats, dim=0) 57 | cos_sim, ims = compute_val(diffusion_prior, tiled_data, G, clip_model, device, cond_scale=cond_scale) 58 | logfun({f'val/{key}':cos_sim.mean()}, step=current_it) 59 | 60 | 61 | if key.startswith("text"): 62 | num_chunks = int(np.ceil(ims.shape[0]//repeats)) 63 | for idx, (sim, im_chunk) in enumerate(zip( 64 | 
cos_sim.chunk(num_chunks), 65 | ims.chunk(num_chunks) 66 | )): 67 | 68 | caption = captions[idx] 69 | im = wandb.Image(make_grid(im_chunk), caption=f'{sim.mean():.2f} - {caption}') 70 | logfun({f'val/image/{key}/{idx}': im}, step=current_it) 71 | else: 72 | for idx, im in enumerate(ims.chunk(int(np.ceil(ims.shape[0]/16)))): 73 | logfun({f'val/image/{key}/{idx}': wandb.Image(make_grid(im))}, step=current_it) 74 | 75 | logger.info("Validation done.") 76 | 77 | def train_step(diffusion_prior, device, batch): 78 | diffusion_prior.train() 79 | batch_z, batch_w = batch 80 | batch_z = batch_z.to(device) 81 | batch_w = batch_w.to(device) 82 | 83 | loss = diffusion_prior(batch_z, batch_w) 84 | loss.backward() 85 | return loss 86 | 87 | 88 | def train(trainer, loader, device, val_it, validate, save_checkpoint, max_it, print_it=50): 89 | 90 | current_it = 0 91 | current_epoch = 0 92 | 93 | while current_it < max_it: 94 | 95 | logfun({'epoch': current_epoch}, step=current_it) 96 | pbar = tqdm(loader) 97 | for batch in pbar: 98 | if current_it % val_it == 0: 99 | validate(current_it, device, trainer) 100 | 101 | trainer.train() 102 | batch_clip, batch_latent = batch 103 | 104 | input_args = { 105 | "image_embed": batch_latent.to(device), 106 | "text_embed": batch_clip.to(device) 107 | } 108 | loss = trainer(**input_args) 109 | 110 | if (current_it % print_it == 0): 111 | logfun({'loss': loss}, step=current_it) 112 | 113 | trainer.update() 114 | current_it += 1 115 | pbar.set_postfix({"loss": loss, "epoch": current_epoch, "it": current_it}) 116 | 117 | save_checkpoint(trainer, current_it) 118 | 119 | current_epoch += 1 120 | 121 | 122 | @hydra.main(config_path="../config", config_name="config") 123 | def main(cfg): 124 | 125 | if cfg.logging == "wandb": 126 | wandb.init( 127 | project=cfg.wandb_project, 128 | config=OmegaConf.to_container(cfg), 129 | entity=cfg.wandb_entity, 130 | name=cfg.name, 131 | ) 132 | global logfun 133 | logfun = wandb.log 134 | elif cfg.logging is None: 135 | logger.info("Not logging") 136 | else: 137 | raise NotImplementedError(f"Logging type {cfg.logging} not implemented") 138 | 139 | device = cfg.device 140 | stats, loader = load_data(cfg.data) 141 | G, clip_model, trainer = load_models(cfg, device, stats) 142 | 143 | text_embed, text_samples = make_text_val_data(G, clip_model, hydra.utils.to_absolute_path(cfg.data.val_text_samples)) 144 | val_data = { 145 | "val_im": make_image_val_data(G, clip_model, cfg.data.val_im_samples, device), 146 | "val_text": text_embed, 147 | "val_caption": text_samples, 148 | } 149 | 150 | if 'resume' in cfg and cfg.resume is not None: 151 | # Does not load previous iteration count 152 | logger.info(f"Resuming from {cfg.resume}") 153 | trainer.load_state_dict(torch.load(cfg.resume, map_location="cpu")["state_dict"]) 154 | 155 | checkpoint_dir = f"checkpoints/{datetime.now():%Y%m%d-%H%M%S}" 156 | checkpointer = Checkpointer(checkpoint_dir, cfg.train.val_it) 157 | validate = partial(validation, 158 | G=G, 159 | clip_model=clip_model, 160 | val_data=val_data, 161 | samples_per_text=cfg.data.val_samples_per_text, 162 | ) 163 | 164 | train(trainer, loader, device, 165 | val_it=cfg.train.val_it, 166 | max_it=cfg.train.max_it, 167 | validate=validate, 168 | save_checkpoint=checkpointer.save_checkpoint, 169 | ) 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /scripts/generate_dataset.py: 
-------------------------------------------------------------------------------- 1 | # Generate datasets 2 | from multiprocessing import Process 3 | import multiprocessing as mp 4 | import math 5 | from functools import partial 6 | from pathlib import Path 7 | 8 | import numpy as np 9 | import torch 10 | import typer 11 | from joblib import Parallel, delayed 12 | from PIL import Image 13 | from tqdm import tqdm 14 | import torch.nn.functional as F 15 | 16 | from clip2latent.models import Clipper, load_sg 17 | 18 | 19 | import multiprocessing as mp 20 | try: 21 | mp.set_start_method('spawn') 22 | except: 23 | pass 24 | 25 | generators = { 26 | "sg2-ffhq-1024": partial(load_sg, 'https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl'), 27 | "sg3-lhq-256": partial(load_sg, 'https://huggingface.co/justinpinkney/stylegan3-t-lhq-256/resolve/main/lhq-256-stylegan3-t-25Mimg.pkl'), 28 | } 29 | 30 | def mix_styles(w_batch, space): 31 | """Defines a style mixing procedure""" 32 | space_spec = { 33 | "w3": (4, 4, 10), 34 | } 35 | latent_mix = space_spec[space] 36 | 37 | bs = w_batch.shape[0] 38 | spec = torch.tensor(latent_mix).to(w_batch.device) 39 | 40 | index = torch.randint(0,bs, (len(spec),bs)).to(w_batch.device) 41 | return w_batch[index, 0, :].permute(1,0,2).repeat_interleave(spec, dim=1), spec 42 | 43 | @torch.no_grad() 44 | def run_folder_list( 45 | device_index, 46 | out_dir, 47 | generator_name, 48 | feature_extractor_name, 49 | out_image_size, 50 | batch_size, 51 | n_save_workers, 52 | samples_per_folder, 53 | folder_indexes, 54 | space="w", 55 | save_im=True, 56 | ): 57 | """Generate a directory of generated images and correspdonding embeddings and latents""" 58 | latent_dim = 512 59 | device = f"cuda:{device_index}" 60 | typer.echo(device_index) 61 | 62 | typer.echo("Loading generator") 63 | G = generators[generator_name]().to(device).eval() 64 | 65 | typer.echo("Loading feature extractor") 66 | feature_extractor = Clipper(feature_extractor_name).to(device) 67 | 68 | typer.echo("Generating samples") 69 | typer.echo(f"using space {space}") 70 | 71 | with Parallel(n_jobs=n_save_workers, prefer="threads") as parallel: 72 | for i_folder in folder_indexes: 73 | folder_name = out_dir/f"{i_folder:05d}" 74 | folder_name.mkdir(exist_ok=True) 75 | 76 | z = torch.randn(samples_per_folder, latent_dim, device=device) 77 | w = G.mapping(z, c=None) 78 | ds = torch.utils.data.TensorDataset(w) 79 | loader = torch.utils.data.DataLoader(ds, batch_size=batch_size, shuffle=False, drop_last=False) 80 | for batch_idx, batch in enumerate(tqdm(loader, position=device_index)): 81 | if space == "w": 82 | this_w = batch[0].to(device) 83 | latents = this_w[:,0,:].cpu().numpy() 84 | else: 85 | this_w, select_idxs = mix_styles(batch[0].to(device), space) 86 | latents = this_w[:,select_idxs,:].cpu().numpy() 87 | 88 | out = G.synthesis(this_w) 89 | 90 | out = F.interpolate(out, (out_image_size,out_image_size), mode="area") 91 | image_features = feature_extractor.embed_image(out) 92 | image_features = image_features.cpu().numpy() 93 | 94 | if save_im: 95 | out = out.permute(0,2,3,1).clamp(-1,1) 96 | out = (255*(out*0.5 + 0.5).cpu().numpy()).astype(np.uint8) 97 | else: 98 | out = [None]*len(latents) 99 | parallel( 100 | delayed(process_and_save)(batch_size, folder_name, batch_idx, idx, latent, im, image_feature, save_im) 101 | for idx, (latent, im, image_feature) in enumerate(zip(latents, out, image_features)) 102 | ) 103 | 104 | typer.echo("finished folder") 105 | 106 
| 107 | def process_and_save(batch_size, folder_name, batch_idx, idx, latent, im, image_feature, save_im): 108 | count = batch_idx*batch_size + idx 109 | basename = folder_name/f"{folder_name.stem}{count:04}" 110 | np.save(basename.with_suffix(".latent.npy"), latent) 111 | np.save(basename.with_suffix(".img_feat.npy"), image_feature) 112 | if save_im: 113 | im = Image.fromarray(im) 114 | im.save(basename.with_suffix(".gen.jpg"), quality=95) 115 | 116 | def make_webdataset(in_dir, out_dir): 117 | import tarfile 118 | in_folders = [x for x in Path(in_dir).glob("*") if x.is_dir] 119 | out_dir = Path(out_dir) 120 | out_dir.mkdir() 121 | for folder in in_folders: 122 | filename = out_dir/f"{folder.stem}.tar" 123 | files_to_add = sorted(list(folder.rglob("*"))) 124 | 125 | with tarfile.open(filename, "w") as tar: 126 | for f in files_to_add: 127 | tar.add(f) 128 | 129 | 130 | def main( 131 | out_dir:Path, 132 | n_samples:int=1_000_000, 133 | generator_name:str="sg2-ffhq-1024", # Key into `generators` dict` 134 | feature_extractor_name:str="ViT-B/32", 135 | n_gpus:int=2, 136 | out_image_size:int=256, 137 | batch_size:int=32, 138 | n_save_workers:int=16, 139 | space:str="w", 140 | samples_per_folder:int=10_000, 141 | save_im:bool=False, # Save the generated images? 142 | ): 143 | typer.echo("starting") 144 | 145 | out_dir.mkdir(parents=True) 146 | 147 | n_folders = math.ceil(n_samples/samples_per_folder) 148 | folder_indexes = range(n_folders) 149 | 150 | sub_indexes = np.array_split(folder_indexes, n_gpus) 151 | 152 | processes = [] 153 | for dev_idx, folder_list in enumerate(sub_indexes): 154 | p = Process( 155 | target=run_folder_list, 156 | args=( 157 | dev_idx, 158 | out_dir, 159 | generator_name, 160 | feature_extractor_name, 161 | out_image_size, 162 | batch_size, 163 | n_save_workers, 164 | samples_per_folder, 165 | folder_list, 166 | space, 167 | save_im, 168 | ), 169 | ) 170 | p.start() 171 | processes.append(p) 172 | 173 | for p in processes: 174 | p.join() 175 | 176 | typer.echo("finished all") 177 | 178 | if __name__ == "__main__": 179 | # mp.set_start_method('spawn') 180 | typer.run(main) 181 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 
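// Reading guide for the kernel below (informal): template parameter A selects the
// activation (1=linear, 2=relu, 3=lrelu, 4=tanh, 5=sigmoid, 6=elu, 7=selu, 8=softplus,
// 9=swish) and G = p.grad selects the pass: G==0 is the forward activation of x + b,
// G==1 its first derivative and G==2 its second derivative, both written in terms of
// the saved reference output yy = yref / gain. For relu (A==2), for example:
//   forward  (G==0): y = max(x + b, 0)
//   backward (G==1): nonzero only where the saved output yy > 0, i.e. the usual
//                    relu mask applied to the incoming gradient.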
22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? 
clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template <class T> void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel<T, 1>; 155 | if (p.act == 2) return (void*)bias_act_kernel<T, 2>; 156 | if (p.act == 3) return (void*)bias_act_kernel<T, 3>; 157 | if (p.act == 4) return (void*)bias_act_kernel<T, 4>; 158 | if (p.act == 5) return (void*)bias_act_kernel<T, 5>; 159 | if (p.act == 6) return (void*)bias_act_kernel<T, 6>; 160 | if (p.act == 7) return (void*)bias_act_kernel<T, 7>; 161 | if (p.act == 8) return (void*)bias_act_kernel<T, 8>; 162 | if (p.act == 9) return (void*)bias_act_kernel<T, 9>; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel<double> (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel<float> (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel<c10::Half> (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs.
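# The two helpers below are only used by get_plugin(): _find_compiler_bindir() locates
# the MSVC compiler directory on Windows, and _get_mangled_gpu_name() turns the CUDA
# device name into a filesystem-safe tag, e.g. "NVIDIA GeForce RTX 3090" ->
# "nvidia-geforce-rtx-3090" (device name is illustrative), which is appended to the
# cached build directory name.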
28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files*/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files*/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files*/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files*/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 
106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gui_utils/imgui_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | import contextlib 10 | import imgui 11 | 12 | #---------------------------------------------------------------------------- 13 | 14 | def set_default_style(color_scheme='dark', spacing=9, indent=23, scrollbar=27): 15 | s = imgui.get_style() 16 | s.window_padding = [spacing, spacing] 17 | s.item_spacing = [spacing, spacing] 18 | s.item_inner_spacing = [spacing, spacing] 19 | s.columns_min_spacing = spacing 20 | s.indent_spacing = indent 21 | s.scrollbar_size = scrollbar 22 | s.frame_padding = [4, 3] 23 | s.window_border_size = 1 24 | s.child_border_size = 1 25 | s.popup_border_size = 1 26 | s.frame_border_size = 1 27 | s.window_rounding = 0 28 | s.child_rounding = 0 29 | s.popup_rounding = 3 30 | s.frame_rounding = 3 31 | s.scrollbar_rounding = 3 32 | s.grab_rounding = 3 33 | 34 | getattr(imgui, f'style_colors_{color_scheme}')(s) 35 | c0 = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 36 | c1 = s.colors[imgui.COLOR_FRAME_BACKGROUND] 37 | s.colors[imgui.COLOR_POPUP_BACKGROUND] = [x * 0.7 + y * 0.3 for x, y in zip(c0, c1)][:3] + [1] 38 | 39 | #---------------------------------------------------------------------------- 40 | 41 | @contextlib.contextmanager 42 | def grayed_out(cond=True): 43 | if cond: 44 | s = imgui.get_style() 45 | text = s.colors[imgui.COLOR_TEXT_DISABLED] 46 | grab = s.colors[imgui.COLOR_SCROLLBAR_GRAB] 47 | back = s.colors[imgui.COLOR_MENUBAR_BACKGROUND] 48 | imgui.push_style_color(imgui.COLOR_TEXT, *text) 49 | imgui.push_style_color(imgui.COLOR_CHECK_MARK, *grab) 50 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB, *grab) 51 | imgui.push_style_color(imgui.COLOR_SLIDER_GRAB_ACTIVE, *grab) 52 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND, *back) 53 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_HOVERED, *back) 54 | imgui.push_style_color(imgui.COLOR_FRAME_BACKGROUND_ACTIVE, *back) 55 | imgui.push_style_color(imgui.COLOR_BUTTON, *back) 56 | imgui.push_style_color(imgui.COLOR_BUTTON_HOVERED, *back) 57 | imgui.push_style_color(imgui.COLOR_BUTTON_ACTIVE, *back) 58 | imgui.push_style_color(imgui.COLOR_HEADER, *back) 59 | imgui.push_style_color(imgui.COLOR_HEADER_HOVERED, *back) 60 | imgui.push_style_color(imgui.COLOR_HEADER_ACTIVE, *back) 61 | imgui.push_style_color(imgui.COLOR_POPUP_BACKGROUND, *back) 62 | yield 63 | imgui.pop_style_color(14) 64 | else: 65 | yield 66 | 67 | #---------------------------------------------------------------------------- 68 | 69 | @contextlib.contextmanager 70 | def item_width(width=None): 71 | if width is not None: 72 | imgui.push_item_width(width) 73 | yield 74 | imgui.pop_item_width() 75 | else: 76 | yield 77 | 78 | #---------------------------------------------------------------------------- 79 | 80 | def scoped_by_object_id(method): 81 | def decorator(self, *args, **kwargs): 82 | imgui.push_id(str(id(self))) 83 | res = method(self, *args, **kwargs) 84 | imgui.pop_id() 85 | return res 86 | return decorator 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | def button(label, width=0, enabled=True): 91 | with grayed_out(not enabled): 92 | clicked = imgui.button(label, width=width) 93 | clicked = clicked and enabled 94 | return clicked 95 | 96 | #---------------------------------------------------------------------------- 97 | 98 | def collapsing_header(text, visible=None, flags=0, default=False, enabled=True, show=True): 99 | expanded = False 100 | if show: 101 | if default: 102 | flags |= imgui.TREE_NODE_DEFAULT_OPEN 103 | if not enabled: 104 | flags |= imgui.TREE_NODE_LEAF 105 | 
with grayed_out(not enabled): 106 | expanded, visible = imgui.collapsing_header(text, visible=visible, flags=flags) 107 | expanded = expanded and enabled 108 | return expanded, visible 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def popup_button(label, width=0, enabled=True): 113 | if button(label, width, enabled): 114 | imgui.open_popup(label) 115 | opened = imgui.begin_popup(label) 116 | return opened 117 | 118 | #---------------------------------------------------------------------------- 119 | 120 | def input_text(label, value, buffer_length, flags, width=None, help_text=''): 121 | old_value = value 122 | color = list(imgui.get_style().colors[imgui.COLOR_TEXT]) 123 | if value == '': 124 | color[-1] *= 0.5 125 | with item_width(width): 126 | imgui.push_style_color(imgui.COLOR_TEXT, *color) 127 | value = value if value != '' else help_text 128 | changed, value = imgui.input_text(label, value, buffer_length, flags) 129 | value = value if value != help_text else '' 130 | imgui.pop_style_color(1) 131 | if not flags & imgui.INPUT_TEXT_ENTER_RETURNS_TRUE: 132 | changed = (value != old_value) 133 | return changed, value 134 | 135 | #---------------------------------------------------------------------------- 136 | 137 | def drag_previous_control(enabled=True): 138 | dragging = False 139 | dx = 0 140 | dy = 0 141 | if imgui.begin_drag_drop_source(imgui.DRAG_DROP_SOURCE_NO_PREVIEW_TOOLTIP): 142 | if enabled: 143 | dragging = True 144 | dx, dy = imgui.get_mouse_drag_delta() 145 | imgui.reset_mouse_drag_delta() 146 | imgui.end_drag_drop_source() 147 | return dragging, dx, dy 148 | 149 | #---------------------------------------------------------------------------- 150 | 151 | def drag_button(label, width=0, enabled=True): 152 | clicked = button(label, width=width, enabled=enabled) 153 | dragging, dx, dy = drag_previous_control(enabled=enabled) 154 | return clicked, dragging, dx, dy 155 | 156 | #---------------------------------------------------------------------------- 157 | 158 | def drag_hidden_window(label, x, y, width, height, enabled=True): 159 | imgui.push_style_color(imgui.COLOR_WINDOW_BACKGROUND, 0, 0, 0, 0) 160 | imgui.push_style_color(imgui.COLOR_BORDER, 0, 0, 0, 0) 161 | imgui.set_next_window_position(x, y) 162 | imgui.set_next_window_size(width, height) 163 | imgui.begin(label, closable=False, flags=(imgui.WINDOW_NO_TITLE_BAR | imgui.WINDOW_NO_RESIZE | imgui.WINDOW_NO_MOVE)) 164 | dragging, dx, dy = drag_previous_control(enabled=enabled) 165 | imgui.end() 166 | imgui.pop_style_color(2) 167 | return dragging, dx, dy 168 | 169 | #---------------------------------------------------------------------------- 170 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . 
import conv2d_gradfix 15 | from . import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/viz/pickle_widget.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import os 11 | import re 12 | 13 | import dnnlib 14 | import imgui 15 | import numpy as np 16 | from gui_utils import imgui_utils 17 | 18 | from . 
import renderer 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | def _locate_results(pattern): 23 | return pattern 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | class PickleWidget: 28 | def __init__(self, viz): 29 | self.viz = viz 30 | self.search_dirs = [] 31 | self.cur_pkl = None 32 | self.user_pkl = '' 33 | self.recent_pkls = [] 34 | self.browse_cache = dict() # {tuple(path, ...): [dnnlib.EasyDict(), ...], ...} 35 | self.browse_refocus = False 36 | self.load('', ignore_errors=True) 37 | 38 | def add_recent(self, pkl, ignore_errors=False): 39 | try: 40 | resolved = self.resolve_pkl(pkl) 41 | if resolved not in self.recent_pkls: 42 | self.recent_pkls.append(resolved) 43 | except: 44 | if not ignore_errors: 45 | raise 46 | 47 | def load(self, pkl, ignore_errors=False): 48 | viz = self.viz 49 | viz.clear_result() 50 | viz.skip_frame() # The input field will change on next frame. 51 | try: 52 | resolved = self.resolve_pkl(pkl) 53 | name = resolved.replace('\\', '/').split('/')[-1] 54 | self.cur_pkl = resolved 55 | self.user_pkl = resolved 56 | viz.result.message = f'Loading {name}...' 57 | viz.defer_rendering() 58 | if resolved in self.recent_pkls: 59 | self.recent_pkls.remove(resolved) 60 | self.recent_pkls.insert(0, resolved) 61 | except: 62 | self.cur_pkl = None 63 | self.user_pkl = pkl 64 | if pkl == '': 65 | viz.result = dnnlib.EasyDict(message='No network pickle loaded') 66 | else: 67 | viz.result = dnnlib.EasyDict(error=renderer.CapturedException()) 68 | if not ignore_errors: 69 | raise 70 | 71 | @imgui_utils.scoped_by_object_id 72 | def __call__(self, show=True): 73 | viz = self.viz 74 | recent_pkls = [pkl for pkl in self.recent_pkls if pkl != self.user_pkl] 75 | if show: 76 | imgui.text('Pickle') 77 | imgui.same_line(viz.label_w) 78 | changed, self.user_pkl = imgui_utils.input_text('##pkl', self.user_pkl, 1024, 79 | flags=(imgui.INPUT_TEXT_AUTO_SELECT_ALL | imgui.INPUT_TEXT_ENTER_RETURNS_TRUE), 80 | width=(-1 - viz.button_w * 2 - viz.spacing * 2), 81 | help_text=' | | | | /.pkl') 82 | if changed: 83 | self.load(self.user_pkl, ignore_errors=True) 84 | if imgui.is_item_hovered() and not imgui.is_item_active() and self.user_pkl != '': 85 | imgui.set_tooltip(self.user_pkl) 86 | imgui.same_line() 87 | if imgui_utils.button('Recent...', width=viz.button_w, enabled=(len(recent_pkls) != 0)): 88 | imgui.open_popup('recent_pkls_popup') 89 | imgui.same_line() 90 | if imgui_utils.button('Browse...', enabled=len(self.search_dirs) > 0, width=-1): 91 | imgui.open_popup('browse_pkls_popup') 92 | self.browse_cache.clear() 93 | self.browse_refocus = True 94 | 95 | if imgui.begin_popup('recent_pkls_popup'): 96 | for pkl in recent_pkls: 97 | clicked, _state = imgui.menu_item(pkl) 98 | if clicked: 99 | self.load(pkl, ignore_errors=True) 100 | imgui.end_popup() 101 | 102 | if imgui.begin_popup('browse_pkls_popup'): 103 | def recurse(parents): 104 | key = tuple(parents) 105 | items = self.browse_cache.get(key, None) 106 | if items is None: 107 | items = self.list_runs_and_pkls(parents) 108 | self.browse_cache[key] = items 109 | for item in items: 110 | if item.type == 'run' and imgui.begin_menu(item.name): 111 | recurse([item.path]) 112 | imgui.end_menu() 113 | if item.type == 'pkl': 114 | clicked, _state = imgui.menu_item(item.name) 115 | if clicked: 116 | self.load(item.path, ignore_errors=True) 117 | if len(items) == 0: 118 | with imgui_utils.grayed_out(): 119 | imgui.menu_item('No results found') 120 | 
recurse(self.search_dirs) 121 | if self.browse_refocus: 122 | imgui.set_scroll_here() 123 | viz.skip_frame() # Focus will change on next frame. 124 | self.browse_refocus = False 125 | imgui.end_popup() 126 | 127 | paths = viz.pop_drag_and_drop_paths() 128 | if paths is not None and len(paths) >= 1: 129 | self.load(paths[0], ignore_errors=True) 130 | 131 | viz.args.pkl = self.cur_pkl 132 | 133 | def list_runs_and_pkls(self, parents): 134 | items = [] 135 | run_regex = re.compile(r'\d+-.*') 136 | pkl_regex = re.compile(r'network-snapshot-\d+\.pkl') 137 | for parent in set(parents): 138 | if os.path.isdir(parent): 139 | for entry in os.scandir(parent): 140 | if entry.is_dir() and run_regex.fullmatch(entry.name): 141 | items.append(dnnlib.EasyDict(type='run', name=entry.name, path=os.path.join(parent, entry.name))) 142 | if entry.is_file() and pkl_regex.fullmatch(entry.name): 143 | items.append(dnnlib.EasyDict(type='pkl', name=entry.name, path=os.path.join(parent, entry.name))) 144 | 145 | items = sorted(items, key=lambda item: (item.name.replace('_', ' '), item.path)) 146 | return items 147 | 148 | def resolve_pkl(self, pattern): 149 | assert isinstance(pattern, str) 150 | assert pattern != '' 151 | 152 | # URL => return as is. 153 | if dnnlib.util.is_url(pattern): 154 | return pattern 155 | 156 | # Short-hand pattern => locate. 157 | path = _locate_results(pattern) 158 | 159 | # Run dir => pick the last saved snapshot. 160 | if os.path.isdir(path): 161 | pkl_files = sorted(glob.glob(os.path.join(path, 'network-snapshot-*.pkl'))) 162 | if len(pkl_files) == 0: 163 | raise IOError(f'No network pickle found in "{path}"') 164 | path = pkl_files[-1] 165 | 166 | # Normalize. 167 | path = os.path.abspath(path) 168 | return path 169 | 170 | #---------------------------------------------------------------------------- 171 | -------------------------------------------------------------------------------- /clip2latent/stylegan3/gen_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | """Generate lerp videos using pretrained network pickle.""" 10 | 11 | import copy 12 | import os 13 | import re 14 | from typing import List, Optional, Tuple, Union 15 | 16 | import click 17 | import dnnlib 18 | import imageio 19 | import numpy as np 20 | import scipy.interpolate 21 | import torch 22 | from tqdm import tqdm 23 | 24 | import legacy 25 | 26 | #---------------------------------------------------------------------------- 27 | 28 | def layout_grid(img, grid_w=None, grid_h=1, float_to_uint8=True, chw_to_hwc=True, to_numpy=True): 29 | batch_size, channels, img_h, img_w = img.shape 30 | if grid_w is None: 31 | grid_w = batch_size // grid_h 32 | assert batch_size == grid_w * grid_h 33 | if float_to_uint8: 34 | img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8) 35 | img = img.reshape(grid_h, grid_w, channels, img_h, img_w) 36 | img = img.permute(2, 0, 3, 1, 4) 37 | img = img.reshape(channels, grid_h * img_h, grid_w * img_w) 38 | if chw_to_hwc: 39 | img = img.permute(1, 2, 0) 40 | if to_numpy: 41 | img = img.cpu().numpy() 42 | return img 43 | 44 | #---------------------------------------------------------------------------- 45 | 46 | def gen_interp_video(G, mp4: str, seeds, shuffle_seed=None, w_frames=60*4, kind='cubic', grid_dims=(1,1), num_keyframes=None, wraps=2, psi=1, device=torch.device('cuda'), **video_kwargs): 47 | grid_w = grid_dims[0] 48 | grid_h = grid_dims[1] 49 | 50 | if num_keyframes is None: 51 | if len(seeds) % (grid_w*grid_h) != 0: 52 | raise ValueError('Number of input seeds must be divisible by grid W*H') 53 | num_keyframes = len(seeds) // (grid_w*grid_h) 54 | 55 | all_seeds = np.zeros(num_keyframes*grid_h*grid_w, dtype=np.int64) 56 | for idx in range(num_keyframes*grid_h*grid_w): 57 | all_seeds[idx] = seeds[idx % len(seeds)] 58 | 59 | if shuffle_seed is not None: 60 | rng = np.random.RandomState(seed=shuffle_seed) 61 | rng.shuffle(all_seeds) 62 | 63 | zs = torch.from_numpy(np.stack([np.random.RandomState(seed).randn(G.z_dim) for seed in all_seeds])).to(device) 64 | ws = G.mapping(z=zs, c=None, truncation_psi=psi) 65 | _ = G.synthesis(ws[:1]) # warm up 66 | ws = ws.reshape(grid_h, grid_w, num_keyframes, *ws.shape[1:]) 67 | 68 | # Interpolation. 69 | grid = [] 70 | for yi in range(grid_h): 71 | row = [] 72 | for xi in range(grid_w): 73 | x = np.arange(-num_keyframes * wraps, num_keyframes * (wraps + 1)) 74 | y = np.tile(ws[yi][xi].cpu().numpy(), [wraps * 2 + 1, 1, 1]) 75 | interp = scipy.interpolate.interp1d(x, y, kind=kind, axis=0) 76 | row.append(interp) 77 | grid.append(row) 78 | 79 | # Render video. 80 | video_out = imageio.get_writer(mp4, mode='I', fps=60, codec='libx264', **video_kwargs) 81 | for frame_idx in tqdm(range(num_keyframes * w_frames)): 82 | imgs = [] 83 | for yi in range(grid_h): 84 | for xi in range(grid_w): 85 | interp = grid[yi][xi] 86 | w = torch.from_numpy(interp(frame_idx / w_frames)).to(device) 87 | img = G.synthesis(ws=w.unsqueeze(0), noise_mode='const')[0] 88 | imgs.append(img) 89 | video_out.append_data(layout_grid(torch.stack(imgs), grid_w=grid_w, grid_h=grid_h)) 90 | video_out.close() 91 | 92 | #---------------------------------------------------------------------------- 93 | 94 | def parse_range(s: Union[str, List[int]]) -> List[int]: 95 | '''Parse a comma separated list of numbers or ranges and return a list of ints. 
96 | 97 | Example: '1,2,5-10' returns [1, 2, 5, 6, 7] 98 | ''' 99 | if isinstance(s, list): return s 100 | ranges = [] 101 | range_re = re.compile(r'^(\d+)-(\d+)$') 102 | for p in s.split(','): 103 | m = range_re.match(p) 104 | if m: 105 | ranges.extend(range(int(m.group(1)), int(m.group(2))+1)) 106 | else: 107 | ranges.append(int(p)) 108 | return ranges 109 | 110 | #---------------------------------------------------------------------------- 111 | 112 | def parse_tuple(s: Union[str, Tuple[int,int]]) -> Tuple[int, int]: 113 | '''Parse a 'M,N' or 'MxN' integer tuple. 114 | 115 | Example: 116 | '4x2' returns (4,2) 117 | '0,1' returns (0,1) 118 | ''' 119 | if isinstance(s, tuple): return s 120 | m = re.match(r'^(\d+)[x,](\d+)$', s) 121 | if m: 122 | return (int(m.group(1)), int(m.group(2))) 123 | raise ValueError(f'cannot parse tuple {s}') 124 | 125 | #---------------------------------------------------------------------------- 126 | 127 | @click.command() 128 | @click.option('--network', 'network_pkl', help='Network pickle filename', required=True) 129 | @click.option('--seeds', type=parse_range, help='List of random seeds', required=True) 130 | @click.option('--shuffle-seed', type=int, help='Random seed to use for shuffling seed order', default=None) 131 | @click.option('--grid', type=parse_tuple, help='Grid width/height, e.g. \'4x3\' (default: 1x1)', default=(1,1)) 132 | @click.option('--num-keyframes', type=int, help='Number of seeds to interpolate through. If not specified, determine based on the length of the seeds array given by --seeds.', default=None) 133 | @click.option('--w-frames', type=int, help='Number of frames to interpolate between latents', default=120) 134 | @click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True) 135 | @click.option('--output', help='Output .mp4 filename', type=str, required=True, metavar='FILE') 136 | def generate_images( 137 | network_pkl: str, 138 | seeds: List[int], 139 | shuffle_seed: Optional[int], 140 | truncation_psi: float, 141 | grid: Tuple[int,int], 142 | num_keyframes: Optional[int], 143 | w_frames: int, 144 | output: str 145 | ): 146 | """Render a latent vector interpolation video. 147 | 148 | Examples: 149 | 150 | \b 151 | # Render a 4x2 grid of interpolations for seeds 0 through 31. 152 | python gen_video.py --output=lerp.mp4 --trunc=1 --seeds=0-31 --grid=4x2 \\ 153 | --network=https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan3/versions/1/files/stylegan3-r-afhqv2-512x512.pkl 154 | 155 | Animation length and seed keyframes: 156 | 157 | The animation length is either determined based on the --seeds value or explicitly 158 | specified using the --num-keyframes option. 159 | 160 | When num keyframes is specified with --num-keyframes, the output video length 161 | will be 'num_keyframes*w_frames' frames. 162 | 163 | If --num-keyframes is not specified, the number of seeds given with 164 | --seeds must be divisible by grid size W*H (--grid). In this case the 165 | output video length will be '# seeds/(w*h)*w_frames' frames. 166 | """ 167 | 168 | print('Loading networks from "%s"...' 
% network_pkl) 169 | device = torch.device('cuda') 170 | with dnnlib.util.open_url(network_pkl) as f: 171 | G = legacy.load_network_pkl(f)['G_ema'].to(device) # type: ignore 172 | 173 | gen_interp_video(G=G, mp4=output, bitrate='12M', grid_dims=grid, num_keyframes=num_keyframes, w_frames=w_frames, seeds=seeds, shuffle_seed=shuffle_seed, psi=truncation_psi) 174 | 175 | #---------------------------------------------------------------------------- 176 | 177 | if __name__ == "__main__": 178 | generate_images() # pylint: disable=no-value-for-parameter 179 | 180 | #---------------------------------------------------------------------------- 181 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # clip2latent - Official PyTorch Code 2 | 3 | [![Open Arxiv](https://img.shields.io/badge/arXiv-2210.02347-b31b1b.svg)](https://arxiv.org/abs/2210.02347) 4 | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/justinpinkney/clip2latent/blob/main/demo.ipynb) 5 | [![Open in Spaces](https://img.shields.io/badge/%F0%9F%A4%97-Open%20in%20Spaces-orange)](https://huggingface.co/spaces/lambdalabs/clip2latent-demo) 6 | 7 | ![](images/headline-large.jpeg) 8 | 9 | > ## _clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP_ 10 | > 11 | > Justin N. M. Pinkney and Chuan Li 12 | > @ Lambda Inc. 13 | > 14 | > We introduce a new method to efficiently create text-to-image models from a pre-trained CLIP and StyleGAN. It enables text driven sampling with an existing generative model without any external data or fine-tuning. This is achieved by training a diffusion model conditioned on CLIP embeddings to sample latent vectors of a pre-trained StyleGAN, which we call *clip2latent*. We leverage the alignment between CLIP’s image and text embeddings to avoid the need for any text labelled data for training the conditional diffusion model. We demonstrate that clip2latent allows us to generate high-resolution (1024x1024 pixels) images based on text prompts with fast sampling, high image quality, and low training compute and data requirements. We also show that the use of the well studied StyleGAN architecture, without further fine-tuning, allows us to directly apply existing methods to control and modify the generated images adding a further layer of control to our text-to-image pipeline. 15 | 16 | 17 | ## Installation 18 | 19 | ```bash 20 | git clone https://github.com/justinpinkney/clip2latent.git 21 | cd clip2latent 22 | python -m venv .venv --prompt clip2latent 23 | . .venv/bin/activate 24 | pip install --upgrade pip 25 | pip install -r requirements.txt 26 | ``` 27 | 28 | ## Usage 29 | 30 | ### Inference 31 | 32 | The simplest way to run the models for inference is to start the gradio demo (or run it [in Colab](https://colab.research.google.com/github/justinpinkney/clip2latent/blob/demo/colab-gradio.ipynb)): 33 | 34 | ```bash 35 | python scripts/demo.py 36 | ``` 37 | 38 | This will fetch the required models from the Hugging Face Hub and start the gradio demo, which can be accessed via a web browser.
39 | 40 | To run a model via Python: 41 | 42 | ```python 43 | from clip2latent import models 44 | 45 | prompt = "a hairy man" 46 | device = "cuda:0" 47 | cfg_file = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.yaml" 48 | checkpoint = "https://huggingface.co/lambdalabs/clip2latent/resolve/main/ffhq-sg2-510.ckpt" 49 | 50 | model = models.Clip2StyleGAN(cfg_file, device, checkpoint) 51 | images, clip_score = model(prompt) 52 | # images are tensors of shape: bchw, range: -1..1 53 | ``` 54 | 55 | Or take a look at the example notebook `demo.ipynb`. 56 | 57 | ### Training 58 | 59 | #### Generate data 60 | 61 | To train a model of your own, you first need to generate some data. We provide a command line interface which will run a StyleGAN model and pass the generated images to CLIP. The W latent vectors and the CLIP image embeddings will be stored as npy files, packed into tar files ready for use as a webdataset. To generate the data used to train the FFHQ model in the paper, run: 62 | 63 | ```bash 64 | python scripts/generate_dataset.py 65 | ``` 66 | 67 | For more details of the dataset generation options, see the help for `generate_dataset.py`: 68 | 69 | ``` 70 | Usage: generate_dataset.py [OPTIONS] OUT_DIR 71 | 72 | Arguments: 73 | OUT_DIR Location to save dataset [required] 74 | 75 | Options: 76 | --n-samples INTEGER Number of samples to generate [default: 1000000] 77 | --generator-name TEXT Name of predefined generator loader [default: sg2-ffhq-1024] 78 | --feature-extractor-name TEXT CLIP model to use for image embedding [default: ViT-B/32] 79 | --n-gpus INTEGER Number of GPUs to use [default: 2] 80 | --out-image-size INTEGER If saving generated images, resize to this dimension [default: 256] 81 | --batch-size INTEGER Batch size [default: 32] 82 | --n-save-workers INTEGER Number of workers to use while saving [default: 16] 83 | --space TEXT Latent space to use [default: w] 84 | --samples-per-folder INTEGER Number of samples per tar file [default: 10000] 85 | --save-im / --no-save-im Save images? [default: no-save-im] 86 | ``` 87 | 88 | To use a different StyleGAN generator, add the required loading function to the `generators` dict in `generate_dataset.py`, then use that key as the `generator_name` (a sketch of such a loader is shown below). Using non-StyleGAN generators should be possible but would require additional modifications.
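The exact loader interface is defined by `generate_dataset.py` itself, so treat the following as an illustrative sketch rather than the script's actual API: the function name, the placeholder pickle path, and the registration key below are assumptions, while the pickle-loading pattern mirrors the one used in `gen_video.py` above.

```python
# Hypothetical loader sketch - check the `generators` dict in generate_dataset.py
# for the real signature before relying on this.
import torch
import dnnlib
import legacy

def load_my_sg2(device=torch.device("cuda")):
    # Path or URL to your own StyleGAN pickle (placeholder value).
    network_pkl = "pretrained/my-stylegan2-model.pkl"
    with dnnlib.util.open_url(network_pkl) as f:
        # Same loading pattern as gen_video.py: use the EMA generator weights.
        G = legacy.load_network_pkl(f)["G_ema"].to(device)
    G.eval()
    return G

# Register the loader under a new key (assumed registration mechanism), e.g.
#   generators["my-sg2-1024"] = load_my_sg2
# and then generate data with:
#   python scripts/generate_dataset.py out_dir --generator-name my-sg2-1024
```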
89 | 90 | #### Train 91 | 92 | To manage configuration for the model and training parameters we use [hydra](https://hydra.cc/), to train with default configuration simply run: 93 | 94 | ```bash 95 | python scripts/train.py 96 | ``` 97 | 98 | This will run the model with the default configuration as follows: 99 | 100 | ```yaml 101 | model: 102 | network: 103 | dim: 512 104 | num_timesteps: 1000 105 | depth: 12 106 | dim_head: 64 107 | heads: 12 108 | diffusion: 109 | image_embed_dim: 512 110 | timesteps: 1000 111 | cond_drop_prob: 0.2 112 | image_embed_scale: 1.0 113 | text_embed_scale: 1.0 114 | beta_schedule: cosine 115 | predict_x_start: true 116 | data: 117 | bs: 512 118 | format: webdataset 119 | path: data/webdataset/sg2-ffhq-1024-clip/{00000..99}.tar 120 | embed_noise_scale: 1.0 121 | sg_pkl: https://api.ngc.nvidia.com/v2/models/nvidia/research/stylegan2/versions/1/files/stylegan2-ffhq-1024x1024.pkl 122 | clip_variant: ViT-B/32 123 | n_latents: 1 124 | latent_dim: 512 125 | latent_repeats: 126 | - 18 127 | val_im_samples: 64 128 | val_text_samples: text/face-val.txt 129 | val_samples_per_text: 4 130 | logging: wandb 131 | wandb_project: clip2latent 132 | wandb_entity: null 133 | name: null 134 | device: cuda:0 135 | resume: null 136 | train: 137 | znorm_embed: false 138 | znorm_latent: true 139 | max_it: 1000000 140 | val_it: 10000 141 | lr: 0.0001 142 | weight_decay: 0.01 143 | ema_update_every: 10 144 | ema_beta: 0.9999 145 | ema_power: 0.75 146 | ``` 147 | 148 | To train with a different configuration you can either change individual parameters using the following command line override syntax: 149 | 150 | ```bash 151 | python scripts/train.py data.bs=128 152 | ``` 153 | 154 | which would set the batch size to 128. 155 | 156 | Alternatively you can create your own yaml configuration files and switch between them, e.g. we also provide an example 'small' model configuration at `config/model/small.yaml`, to train using this simply run 157 | 158 | ```bash 159 | python scripts/train.py model=small 160 | ``` 161 | 162 | For more details please refer to the [hydra documentation](https://hydra.cc/docs/intro/). 163 | 164 | Training is set up to run on a single GPU and does not currently support multigpu training. The default settings will take around 18 hours to train on a single A100-80GB, although the best checkpoint is likely to occur within 10 hours of training. 165 | 166 | ## Acknowledgements 167 | 168 | - This code uses [lucidrains](https://github.com/lucidrains)' implementation of the [dalle2 prior](https://github.com/lucidrains/DALLE2-pytorch). 169 | - Compute for training was provided by [Lambda GPU Cloud](https://lambdalabs.com/service/gpu-cloud). 170 | 171 | ## Citation 172 | 173 | ``` 174 | @misc{https://doi.org/10.48550/arxiv.2210.02347, 175 | doi = {10.48550/ARXIV.2210.02347}, 176 | url = {https://arxiv.org/abs/2210.02347}, 177 | author = {Pinkney, Justin N. M. and Li, Chuan}, 178 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 179 | title = {clip2latent: Text driven sampling of a pre-trained StyleGAN using denoising diffusion and CLIP}, 180 | publisher = {arXiv}, 181 | year = {2022}, 182 | copyright = {Creative Commons Attribution 4.0 International} 183 | } 184 | ``` 185 | --------------------------------------------------------------------------------