├── .dockerignore ├── .gitignore ├── .gitmodules ├── Dockerfile ├── LICENSE ├── README.md ├── analysis ├── examine_noise_color_coding.py ├── render_noise_maps.py └── shapiro_wilk.py ├── analyze_latent_code.py ├── configs ├── autoencoder.yaml └── fine_tune_for_noise.yaml ├── data ├── __init__.py ├── autoencoder_dataset.py ├── demo_dataset.py └── denoising_eval_dataset.py ├── embeddings ├── __init__.py └── utils.py ├── evaluate_all_denoising_checkpoints.py ├── evaluate_checkpoints.py ├── evaluate_denoising.py ├── evaluation ├── __init__.py ├── autoencoder_evaluation.py ├── calculate_fid_for_dataset.py ├── create_denoise_eval_set.py ├── datasets.json ├── fid.py ├── find_all_saved_checkpoints.py └── psnr_ssim.py ├── extensions ├── __init__.py └── fid_score.py ├── file_based_simple_style_transfer.py ├── interpolate_between_embeddings.py ├── latent_projecting ├── __init__.py ├── losses.py ├── projector.py └── style_transfer.py ├── losses ├── __init__.py ├── lpips │ ├── __init__.py │ ├── base_model.py │ ├── dist_model.py │ ├── networks_basic.py │ ├── pretrained_networks.py │ └── weights │ │ ├── v0.0 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth │ │ └── v0.1 │ │ ├── alex.pth │ │ ├── squeeze.pth │ │ └── vgg.pth ├── perceptual_loss.py ├── perceptual_style_loss.py ├── psnr.py └── style_loss.py ├── networks ├── __init__.py ├── encoder │ ├── __init__.py │ ├── autoencoder.py │ ├── resnet_based_encoder.py │ └── u_net_like_encoder.py ├── stylegan1 │ ├── __init__.py │ └── model.py └── stylegan2 │ ├── __init__.py │ ├── model.py │ └── op │ ├── __init__.py │ ├── fused_act.py │ ├── fused_bias_act.cpp │ ├── fused_bias_act_kernel.cu │ ├── upfirdn2d.cpp │ ├── upfirdn2d.py │ └── upfirdn2d_kernel.cu ├── project.py ├── pytest.ini ├── reconstruct_image.py ├── requirements.txt ├── tests ├── __init__.py ├── test_argument_parsing.py ├── test_data_structures.py ├── test_image_utils.py ├── test_latent_projecting.py ├── test_losses.py ├── test_projecting_functions.py ├── test_projector.py ├── test_style_transfer.py └── testdata │ ├── config_stylegan_1.json │ └── config_stylegan_2.json ├── train_code_finder.py ├── updater ├── __init__.py ├── autoencoder_discriminator_updater.py └── autoencoder_updater.py └── utils ├── StyleImagePlotter.py ├── __init__.py ├── clean_runs.sh ├── command_line_args.py ├── config.py ├── convert_autoencoder_checkpoint.py ├── create_denoising_eval_set.py ├── data_loading.py ├── folder_to_json.py ├── image_utils.py └── strip_images_from_big_embedding_file.py /.dockerignore: -------------------------------------------------------------------------------- 1 | stylegan_code_finder/logs/**/* 2 | stylegan_code_finder/logs 3 | stylegan_code_finder/wandb/**/* 4 | stylegan_code_finder/wandb 5 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .idea 2 | .pyc 3 | __pycache__ 4 | build 5 | dist 6 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "training_tools"] 2 | path = training_tools 3 | url = https://github.com/Bartzi/pytorch-training.git 4 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.2-devel-ubuntu18.04 2 | 3 | ARG DEBIAN_FRONTEND=noninteractive 4 | 5 | RUN apt-get update && 
apt-get install -y \ 6 | build-essential \ 7 | git \ 8 | ninja-build \ 9 | software-properties-common \ 10 | pkg-config \ 11 | unzip \ 12 | wget \ 13 | libgl1-mesa-glx \ 14 | libgl1 \ 15 | zsh 16 | 17 | ARG PYTHON=python3.8 18 | ENV LANG C.UTF-8 19 | 20 | RUN add-apt-repository ppa:deadsnakes/ppa && apt-get update && apt-get install -y \ 21 | ${PYTHON} \ 22 | python3-distutils \ 23 | python3-apt \ 24 | python3-dev \ 25 | ${PYTHON}-dev 26 | 27 | RUN wget https://bootstrap.pypa.io/get-pip.py 28 | RUN ${PYTHON} get-pip.py 29 | RUN ln -sf /usr/bin/${PYTHON} /usr/local/bin/python3 30 | RUN ln -sf /usr/local/bin/pip /usr/local/bin/pip3 31 | 32 | RUN pip3 --no-cache-dir install --upgrade \ 33 | pip \ 34 | setuptools 35 | 36 | RUN ln -s $(which ${PYTHON}) /usr/local/bin/python 37 | 38 | ARG BASE=/app 39 | RUN mkdir ${BASE} 40 | RUN mkdir /data 41 | 42 | COPY requirements.txt ${BASE}/requirements.txt 43 | COPY training_tools ${BASE}/training_tools 44 | RUN cd ${BASE}/training_tools && pip3 install . 45 | 46 | WORKDIR ${BASE} 47 | RUN pip3 install -r requirements.txt 48 | 49 | # you can change these to the values of your user to avoid permission problems 50 | ARG UNAME=one_model 51 | ARG UID=1000 52 | ARG GID=100 53 | 54 | RUN groupadd -g $GID -o $UNAME 55 | RUN useradd -m -u $UID -g $GID -o -s /bin/zsh $UNAME 56 | 57 | USER $UNAME 58 | RUN git clone https://github.com/robbyrussell/oh-my-zsh.git ~/.oh-my-zsh 59 | RUN cp ~/.oh-my-zsh/templates/zshrc.zsh-template ~/.zshrc 60 | 61 | CMD ["/bin/zsh"] 62 | -------------------------------------------------------------------------------- /analysis/examine_noise_color_coding.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import sys 3 | from pathlib import Path 4 | 5 | import numpy 6 | import torch 7 | from PIL import Image 8 | from tqdm import tqdm 9 | from typing import List 10 | 11 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 12 | 13 | from latent_projecting import Latents 14 | from networks import get_stylegan_1_based_autoencoder, get_stylegan_2_based_autoencoder, load_weights, \ 15 | StyleganAutoencoder 16 | from pytorch_training.images import make_image 17 | from utils.config import load_config 18 | from utils.data_loading import build_data_loader, build_latent_and_noise_generator 19 | 20 | 21 | def render_color_grid(autoencoder: StyleganAutoencoder, latents: Latents, indices: List[int], grid_size: int, bounds: List[int]) -> List[List[torch.Tensor]]: 22 | 23 | def generate(latents: Latents) -> torch.Tensor: 24 | with torch.no_grad(): 25 | generated, _ = autoencoder.decoder([latents.latent], input_is_latent=autoencoder.is_wplus(latents), noise=latents.noise) 26 | return generated 27 | 28 | assert len(indices) == 2, "Render Color grid only supports the rendering of two indices at once!" 
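# the two selected noise maps are cloned and scaled independently: the first index is swept
# along the x-axis and the second along the y-axis, so each grid cell shows the decoder output
# for one (x_shift, y_shift) combination within the given bounds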
29 | assert len(bounds) == 2, "Render Color grid only supports the rendering with min and max bound" 30 | 31 | shift_factor = numpy.linspace(bounds[0], bounds[1], num=grid_size) 32 | x_shifts, y_shifts = map(numpy.squeeze, numpy.meshgrid(shift_factor, shift_factor, sparse=True)) 33 | 34 | x_noise_map = latents.noise[indices[0]].clone() 35 | y_noise_map = latents.noise[indices[1]].clone() 36 | 37 | grid = [] 38 | for y_shift in tqdm(y_shifts, leave=False): 39 | latents.noise[indices[1]] = y_noise_map.clone() * y_shift 40 | x_images = [] 41 | for x_shift in tqdm(x_shifts, leave=False): 42 | latents.noise[indices[0]] = x_noise_map.clone() * x_shift 43 | generated_image = generate(latents) 44 | generated_image = Image.fromarray(make_image(generated_image[0])) 45 | x_images.append(generated_image) 46 | grid.append(x_images) 47 | 48 | return grid 49 | 50 | 51 | def main(args): 52 | checkpoint_path = Path(args.model_checkpoint) 53 | 54 | config = load_config(checkpoint_path, None) 55 | if config['stylegan_variant'] == 1: 56 | autoencoder_func = get_stylegan_1_based_autoencoder(argparse.Namespace(**config)) 57 | else: 58 | autoencoder_func = get_stylegan_2_based_autoencoder(argparse.Namespace(**config)) 59 | 60 | autoencoder = autoencoder_func( 61 | config['image_size'], 62 | config['latent_size'], 63 | config['input_dim'], 64 | ).to(args.device) 65 | 66 | load_weights(autoencoder, checkpoint_path, key='autoencoder', strict=True) 67 | 68 | config['batch_size'] = 1 69 | if args.generate: 70 | data_loader = build_latent_and_noise_generator(autoencoder, config) 71 | else: 72 | data_loader = build_data_loader(args.images, config, args.absolute, shuffle_off=True) 73 | 74 | noise_dest_dir = checkpoint_path.parent.parent / "color_model_analysis" 75 | noise_dest_dir.mkdir(parents=True, exist_ok=True) 76 | 77 | num_images = 0 78 | for idx, batch in enumerate(tqdm(data_loader, total=args.num_images)): 79 | batch = batch.to(args.device) 80 | if args.generate: 81 | latents = batch 82 | image_name = Path(f"generate_{idx}.png") 83 | else: 84 | with torch.no_grad(): 85 | latents: Latents = autoencoder.encode(batch) 86 | 87 | image_name = Path(data_loader.dataset.image_data[idx]) 88 | 89 | color_grid = render_color_grid(autoencoder, latents, args.indices, args.grid_size, args.bounds) 90 | 91 | full_image = Image.new( 92 | 'RGB', 93 | (args.grid_size * config['image_size'], args.grid_size * config['image_size']) 94 | ) 95 | 96 | for y, x_images in enumerate(color_grid): 97 | for x, image in enumerate(x_images): 98 | full_image.paste(image, (x * config['image_size'], y * config['image_size'])) 99 | 100 | full_image.save(noise_dest_dir / f"{image_name.stem}_color_grid.png") 101 | 102 | num_images += 1 103 | if num_images >= args.num_images: 104 | break 105 | 106 | 107 | if __name__ == "__main__": 108 | parser = argparse.ArgumentParser(description="Render noise maps of given image", formatter_class=argparse.ArgumentDefaultsHelpFormatter) 109 | parser.add_argument("model_checkpoint", help="path to model checkpoint that is to be used to generate noise map") 110 | parser.add_argument("images", help="path to json with all images to be analyzed") 111 | parser.add_argument("-n", "--num-images", type=int, default=1, help="number of images where you want to have a look at the noise maps") 112 | parser.add_argument("--absolute", action='store_true', default=False, help="use this if the json contains absolute paths") 113 | parser.add_argument("--device", default='cuda', help="which device to use") 114 | 
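# example invocation (checkpoint and image-list paths below are placeholders):
#   python analysis/examine_noise_color_coding.py logs/run/checkpoints/100000.pt images.json -i 4 5 -g 10 -b -2 2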
parser.add_argument("-i", "--indices", type=int, nargs=2, default=[4, 5], help="indices to use for color space analysis") 115 | parser.add_argument("-g", "--grid-size", type=int, default=10, help="Size of rendered image grid (squared, so only one dim necessary)") 116 | parser.add_argument("-b", "--bounds", type=int, nargs=2, default=[-2, 2], help="interpolation bounds") 117 | parser.add_argument("--generate", action='store_true', default=False, help="Do not use images, but use unconditional generation instead") 118 | 119 | main(parser.parse_args()) 120 | -------------------------------------------------------------------------------- /analysis/render_noise_maps.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import sys 4 | from pathlib import Path 5 | from typing import List 6 | 7 | import numpy 8 | import torch 9 | from PIL import Image 10 | from tqdm import tqdm, trange 11 | 12 | sys.path.append(str(Path(__file__).resolve().parent.parent)) 13 | 14 | 15 | from latent_projecting import Latents 16 | from networks import load_weights, StyleganAutoencoder, get_autoencoder 17 | from pytorch_training.images import make_image 18 | from utils.config import load_config 19 | from utils.data_loading import build_data_loader, build_latent_and_noise_generator 20 | 21 | 22 | def noise_normalize(tensor: torch.Tensor) -> torch.Tensor: 23 | normalized_tensors = [] 24 | for sub_tensor in tensor: 25 | min_value = sub_tensor.min() 26 | sub_tensor = sub_tensor.sub(min_value) 27 | max_value = sub_tensor.max() 28 | normalized_tensors.append(sub_tensor.div(max_value + 1e-8)) 29 | 30 | return torch.stack(normalized_tensors) 31 | 32 | 33 | def render_with_shifted_noise(autoencoder: StyleganAutoencoder, latents: Latents, shifting_rounds: int) -> List[List[Image.Image]]: 34 | if shifting_rounds == 1: 35 | shift_factor = torch.tensor([random.random() * 4 - 2]) 36 | else: 37 | shift_factor = torch.tensor(numpy.linspace(-2, 2, num=shifting_rounds)) 38 | 39 | def generate(latents: Latents) -> torch.Tensor: 40 | with torch.no_grad(): 41 | generated, _ = autoencoder.decoder([latents.latent], input_is_latent=autoencoder.is_wplus(latents), noise=latents.noise) 42 | return generated 43 | 44 | shifted_images = [[Image.fromarray(make_image(generate(latents)[0]))] for _ in range(shifting_rounds)] 45 | 46 | for the_round in trange(shifting_rounds, leave=False): 47 | for i in range(len(latents.noise)): 48 | noise_copy = latents.noise[i].clone() 49 | latents.noise[i] = latents.noise[i] * shift_factor[the_round] 50 | generated_image = generate(latents) 51 | generated_image = Image.fromarray(make_image(generated_image[0])) 52 | shifted_images[the_round].append(generated_image) 53 | latents.noise[i] = noise_copy 54 | 55 | return shifted_images 56 | 57 | 58 | def main(args): 59 | checkpoint_path = Path(args.model_checkpoint) 60 | 61 | config = load_config(checkpoint_path, None) 62 | 63 | autoencoder = get_autoencoder(config).to(args.device) 64 | load_weights(autoencoder, checkpoint_path, key='autoencoder', strict=True) 65 | 66 | config['batch_size'] = 1 67 | if args.generate: 68 | data_loader = build_latent_and_noise_generator(autoencoder, config) 69 | else: 70 | data_loader = build_data_loader(args.images, config, args.absolute, shuffle_off=True) 71 | 72 | noise_dest_dir = checkpoint_path.parent.parent / "noise_maps" 73 | noise_dest_dir.mkdir(parents=True, exist_ok=True) 74 | 75 | num_images = 0 76 | for idx, batch in enumerate(tqdm(data_loader, 
total=args.num_images)): 77 | batch = batch.to(args.device) 78 | 79 | if args.generate: 80 | latents = batch 81 | image_names = [Path(f"generate_{idx}.png")] 82 | else: 83 | with torch.no_grad(): 84 | latents: Latents = autoencoder.encode(batch) 85 | 86 | image_names = [Path(data_loader.dataset.image_data[idx * config['batch_size'] + batch_idx]) for batch_idx in range(len(batch))] 87 | 88 | if args.shift_noise: 89 | noise_shifted_tensors = render_with_shifted_noise(autoencoder, latents, args.rounds) 90 | 91 | images = [] 92 | for noise_tensors in latents.noise: 93 | noise_images = make_image(noise_tensors, normalize_func=noise_normalize) 94 | images.append([Image.fromarray(im).resize((config['image_size'], config['image_size']), Image.NEAREST) for im in noise_images]) 95 | 96 | for batch_idx, (image, orig_file_name) in enumerate(zip(batch, image_names)): 97 | full_image = Image.new( 98 | 'RGB', 99 | ( 100 | (len(images) + 1) * config['image_size'], 101 | config['image_size'] if not args.shift_noise else config['image_size'] * (args.rounds + 1) 102 | ) 103 | ) 104 | if not args.generate: 105 | full_image.paste(Image.fromarray(make_image(image)), (0, 0)) 106 | for i, noise_images in enumerate(images): 107 | full_image.paste(noise_images[batch_idx], ((i + 1) * config['image_size'], 0)) 108 | 109 | if args.shift_noise: 110 | for i, shifted_images in enumerate(noise_shifted_tensors): 111 | for j, shifted_image in enumerate(shifted_images): 112 | full_image.paste(shifted_image, (j * config['image_size'], (i + 1) * config['image_size'])) 113 | 114 | full_image.save(noise_dest_dir / f"{orig_file_name.stem}_noise.png") 115 | 116 | num_images += len(image_names) 117 | if num_images >= args.num_images: 118 | break 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser(description="Render noise maps of given image") 123 | parser.add_argument("model_checkpoint", help="path to model checkpoint that is to be used to generate noise map") 124 | parser.add_argument("images", help="path to json with all images to be analyzed") 125 | parser.add_argument("-n", "--num-images", type=int, default=1, help="number of images where you want to have a look at the noise maps") 126 | parser.add_argument("--absolute", action='store_true', default=False, help="use this if the json contains absolute paths") 127 | parser.add_argument("--device", default='cuda', help="which device to use") 128 | parser.add_argument("-s", "--shift-noise", action='store_true', default=False, help="do not just render the noise maps, but also render some versions of the image with shifted noise vectors") 129 | parser.add_argument("-r", "--rounds", type=int, default=1, help="Number of shifting rounds to perform") 130 | parser.add_argument("--generate", action='store_true', default=False, help="Do not use images, but use unconditional generation instead") 131 | 132 | main(parser.parse_args()) 133 | -------------------------------------------------------------------------------- /analysis/shapiro_wilk.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | from scipy import stats 5 | from tqdm import tqdm 6 | 7 | 8 | def main(args): 9 | with np.load(args.latent_codes, mmap_mode='r') as data: 10 | # available arrays: 11 | # latent_codes, image_names, 12 | # noise_4_4, noise_8_8, noise_16_16, noise_32_32, noise_64_64, noise_128_128, noise_256_256 13 | latent_codes = data["latent_codes"] 14 | num_samples, slices, code_length = 
latent_codes.shape 15 | print("shape:", num_samples, slices, code_length) 16 | print("normal distribution can be assumed if second value is larger than 0.05") 17 | for i in range(slices): 18 | shapiro_test = stats.shapiro(latent_codes[:,i,:]) 19 | print("samples :", "slice:", i, "latent_space :", "result:", shapiro_test) 20 | for i in range(slices): 21 | shapiro_test = stats.shapiro(latent_codes[:, i, 130]) 22 | print("samples :", "slice:", i, "latent_space 0", "result:", shapiro_test) 23 | for i in range(slices): 24 | shapiro_test = stats.shapiro(latent_codes[0, i, :]) 25 | print("samples 0", "slice:", i, "latent_space :", "result:", shapiro_test) 26 | 27 | 28 | if __name__ == "__main__": 29 | parser = argparse.ArgumentParser( 30 | description="Test whether the latent codes are a normal distribution with shapiro wilk test", 31 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 32 | ) 33 | parser.add_argument("latent_codes", help='Path to file which contains latent codes') 34 | parser.add_argument("-d", "--device", default='cuda', help='Use CPU or GPU for embedding') 35 | 36 | main(parser.parse_args()) 37 | -------------------------------------------------------------------------------- /configs/autoencoder.yaml: -------------------------------------------------------------------------------- 1 | 2 | # logger options 3 | image_save_iter: 500 # How often do you want to save output images during training 4 | image_display_iter: 500 # How often do you want to display output images during training 5 | display_size: 16 # How many images do you want to display each time 6 | snapshot_save_iter: 10000 # How often do you want to save trained models 7 | log_iter: 10 # How often do you want to log the training stats 8 | #validation_iter: 2 # --> if you want to do evaluation in a fixed interval and not every epoch 9 | 10 | # optimization options 11 | max_iter: 100000 # maximum number of training iterations 12 | batch_size: 4 # batch size 13 | weight_decay: 0.0001 # weight decay 14 | beta1: 0.5 # Adam parameter 15 | beta2: 0.999 # Adam parameter 16 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 17 | lr: 1e-04 18 | lr_to_noise: 1e-04 # different learning rate for parameters leading to noise 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | 23 | loss_weights: 24 | reconstruction: 0.1 25 | discriminator: 1 26 | 27 | regularization: 28 | d_interval: 16 29 | r1_weight: 10 30 | 31 | # model options 32 | latent_size: 512 33 | 34 | 35 | # data options 36 | input_dim: 3 # number of image channels [1/3] 37 | num_workers: 20 # number of data loading threads 38 | image_size: 256 # first resize the shortest image side to this size 39 | downsample_size: 256 40 | extend_noise_with_random: True 41 | -------------------------------------------------------------------------------- /configs/fine_tune_for_noise.yaml: -------------------------------------------------------------------------------- 1 | 2 | # logger options 3 | image_save_iter: 500 # How often do you want to save output images during training 4 | image_display_iter: 500 # How often do you want to display output images during training 5 | display_size: 16 # How many images do you want to display each time 6 | snapshot_save_iter: 10000 # How often do you want to save trained models 7 | log_iter: 10 # How often do you want to log the training stats 8 | #validation_iter: 2 # --> if you want to do evaluation in a fixed 
interval and not every epoch 9 | 10 | # optimization options 11 | max_iter: 100000 # maximum number of training iterations 12 | batch_size: 4 # batch size 13 | weight_decay: 0.0001 # weight decay 14 | beta1: 0.5 # Adam parameter 15 | beta2: 0.999 # Adam parameter 16 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 17 | lr: 1e-07 18 | lr_to_noise: 1e-04 # different learning rate for parameters leading to noise 19 | lr_policy: step # learning rate scheduler 20 | step_size: 100000 # how often to decay learning rate 21 | gamma: 0.5 # how much to decay learning rate 22 | 23 | loss_weights: 24 | reconstruction: 0.1 25 | discriminator: 1 26 | 27 | regularization: 28 | d_interval: 16 29 | r1_weight: 10 30 | 31 | # model options 32 | latent_size: 512 33 | 34 | 35 | # data options 36 | input_dim: 3 # number of image channels [1/3] 37 | num_workers: 20 # number of data loading threads 38 | image_size: 256 # first resize the shortest image side to this size 39 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from typing import Type 3 | 4 | from data.autoencoder_dataset import AutoencoderDataset, DenoisingAutoencoderDataset, \ 5 | BlackAndWhiteDenoisingAutoencoderDataset 6 | 7 | 8 | def get_dataset_class(args: argparse.Namespace) -> Type[AutoencoderDataset]: 9 | if getattr(args, 'denoising', False): 10 | dataset_class = DenoisingAutoencoderDataset 11 | elif getattr(args, 'black_and_white_denoising', False): 12 | dataset_class = BlackAndWhiteDenoisingAutoencoderDataset 13 | else: 14 | dataset_class = AutoencoderDataset 15 | return dataset_class 16 | -------------------------------------------------------------------------------- /data/autoencoder_dataset.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import os 3 | from typing import Dict 4 | 5 | import imgaug 6 | import imgaug.augmenters as iaa 7 | import numpy 8 | from PIL import Image 9 | 10 | from pytorch_training.data.json_dataset import JSONDataset 11 | 12 | DENOISING_VARIANCES = [5, 10, 15, 25, 35, 50] 13 | imgaug.seed(666) 14 | 15 | 16 | class AutoencoderDataset(JSONDataset): 17 | 18 | def augment_image(self, image: Image) -> Image: 19 | return image 20 | 21 | def __getitem__(self, index: int) -> Dict[str, numpy.ndarray]: 22 | path = self.image_data[index] 23 | if self.root is not None: 24 | path = os.path.join(self.root, path) 25 | 26 | image = self.loader(path) 27 | augmented_image = self.augment_image(image) 28 | 29 | if self.transforms is not None: 30 | image = self.transforms(image) 31 | augmented_image = self.transforms(augmented_image) 32 | 33 | return { 34 | 'input_image': augmented_image, 35 | 'output_image': image 36 | } 37 | 38 | 39 | class DenoisingAutoencoderDataset(AutoencoderDataset): 40 | 41 | def __init__(self, *args, **kwargs): 42 | super().__init__(*args, **kwargs) 43 | 44 | gaussian_noise = functools.partial(iaa.AdditiveGaussianNoise, scale=DENOISING_VARIANCES) 45 | self.noise_augmenter = iaa.OneOf([ 46 | gaussian_noise(), 47 | gaussian_noise(per_channel=True) 48 | ]) 49 | 50 | def augment_image(self, image: Image) -> Image: 51 | image = numpy.array(image).copy() 52 | image = self.noise_augmenter(image=image) 53 | image = Image.fromarray(image) 54 | return image 55 | 56 | 57 | class BlackAndWhiteDenoisingAutoencoderDataset(DenoisingAutoencoderDataset): 58 | 59 | def __init__(self, *args, 
**kwargs): 60 | loader_func = kwargs['loader'] 61 | kwargs['loader'] = lambda path: loader_func(path).convert('L').convert('RGB') 62 | super().__init__(*args, **kwargs) 63 | 64 | def augment_image(self, image: Image) -> Image: 65 | image = super().augment_image(image) 66 | image = image.convert('L').convert("RGB") 67 | return image 68 | -------------------------------------------------------------------------------- /data/demo_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy 2 | 3 | from pytorch_training.data.utils import default_loader 4 | from torch.utils import data 5 | from typing import Callable, Dict 6 | 7 | 8 | class DemoDataset(data.Dataset): 9 | 10 | def __init__(self, image_file: str, root: str = None, transforms: Callable = None, loader: Callable = default_loader): 11 | self.image_file = image_file 12 | self.transforms = transforms 13 | self.loader = loader 14 | 15 | def __len__(self): 16 | return 1 17 | 18 | def __getitem__(self, index: int) -> Dict[str, numpy.ndarray]: 19 | image = self.loader(self.image_file) 20 | 21 | if self.transforms is not None: 22 | image = self.transforms(image) 23 | 24 | return { 25 | 'input_image': image, 26 | 'output_image': image 27 | } 28 | 29 | -------------------------------------------------------------------------------- /data/denoising_eval_dataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | from typing import Callable, Dict 3 | 4 | import numpy 5 | import os 6 | from torch.utils.data import Dataset 7 | 8 | from pytorch_training.data.utils import default_loader 9 | 10 | 11 | class DenoisingEvaluationDataset(Dataset): 12 | 13 | def __init__(self, json_file: str, root: str = None, transforms: Callable = None, loader: Callable = default_loader): 14 | with open(json_file) as f: 15 | self.image_data = json.load(f) 16 | 17 | self.root = root 18 | self.transforms = transforms 19 | self.loader = loader 20 | 21 | def __len__(self): 22 | return len(self.image_data) 23 | 24 | def __getitem__(self, index: int) -> Dict[str, numpy.ndarray]: 25 | paths = self.image_data[index] 26 | 27 | loaded_images = {} 28 | for image_type, path in paths.items(): 29 | if self.root is not None: 30 | path = os.path.join(self.root, path) 31 | 32 | image = self.loader(path) 33 | if self.transforms is not None: 34 | image = self.transforms(image) 35 | 36 | loaded_images[image_type] = image 37 | 38 | return loaded_images 39 | 40 | -------------------------------------------------------------------------------- /embeddings/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/embeddings/__init__.py -------------------------------------------------------------------------------- /embeddings/utils.py: -------------------------------------------------------------------------------- 1 | from typing import Dict, List 2 | 3 | import numpy 4 | import torch 5 | 6 | 7 | def noises_from_embedding(embedding_data: Dict[str, numpy.ndarray], index: int) -> List[torch.Tensor]: 8 | noise_keys = [key for key in embedding_data.keys() if 'noise' in key] 9 | noise_keys = sorted(noise_keys, key=lambda x: (int((s := x.split('_'))[-2]), int(s[-1]))) 10 | 11 | noises = [embedding_data[noise_key][index].astype(numpy.float32) for noise_key in noise_keys] 12 | noises = [torch.tensor(noise) for noise in noises] 13 | spreaded_noises = 
[] 14 | for noise in noises: 15 | if noise.shape[0] > 1: 16 | # we are dealing with multiple noise maps per resolution, as it happens in Stylegan2 17 | spreaded_noises.extend(noise.chunk(2, dim=0)) 18 | else: 19 | spreaded_noises.append(noise) 20 | return spreaded_noises 21 | 22 | 23 | def latent_from_embedding(embedding_data: Dict[str, numpy.ndarray], index:int) -> torch.Tensor: 24 | return torch.tensor(embedding_data['latent_codes'][index].astype(numpy.float32)) 25 | -------------------------------------------------------------------------------- /evaluate_all_denoising_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from tqdm import tqdm 5 | 6 | from evaluate_denoising import evaluate_denoising 7 | from utils.config import load_config 8 | 9 | 10 | if __name__ == "__main__": 11 | parser = argparse.ArgumentParser(description="Evaluate all denoising checkpoints in given project dir") 12 | parser.add_argument("project_dir", help="path to project dir") 13 | parser.add_argument("test_dataset_dir", help="path to test dataset dir with all test splits to run") 14 | parser.add_argument("dataset_name", help="name of evaluation dataset (e.g. BSD68 or Set12)") 15 | 16 | args = parser.parse_args() 17 | 18 | project_dir = Path(args.project_dir) 19 | all_checkpoints = list(project_dir.glob("**/*/checkpoints/100000.pt")) 20 | 21 | test_dataset_dir = Path(args.test_dataset_dir) 22 | test_datasets = list(test_dataset_dir.glob("*.json")) 23 | 24 | evaluate_args = argparse.Namespace() 25 | evaluate_args.device = 'cuda' 26 | evaluate_args.save = True 27 | evaluate_args.dataset_name = args.dataset_name 28 | 29 | for dataset in tqdm(test_datasets): 30 | evaluate_args.test_dataset = dataset 31 | 32 | for checkpoint in tqdm(all_checkpoints, leave=False): 33 | config = load_config(checkpoint, None) 34 | 35 | if not (config.get('denoising', False) or config.get('black_and_white_denoising', False)): 36 | # no denoising checkpoint 37 | continue 38 | 39 | evaluate_args.model_checkpoint = checkpoint 40 | evaluate_denoising(evaluate_args) 41 | -------------------------------------------------------------------------------- /evaluate_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import itertools 4 | import json 5 | import statistics 6 | from collections import defaultdict 7 | from pathlib import Path 8 | from typing import Dict 9 | 10 | import torch 11 | from tqdm import tqdm 12 | from tqdm.contrib import tenumerate 13 | 14 | from data import get_dataset_class 15 | from evaluation.fid import FID 16 | from evaluation.psnr_ssim import PSNRSSIMEvaluator 17 | from networks import get_autoencoder, load_weights, StyleganAutoencoder 18 | from utils.config import load_config 19 | from utils.data_loading import build_data_loader 20 | 21 | 22 | def save_eval_result(eval_result: dict, eval_type: str, dest_dir: Path, dataset_name: str, checkpoint_name: str): 23 | dest_file = dest_dir / f"{eval_type}.json" 24 | if dest_file.exists(): 25 | with dest_file.open("r") as f: 26 | json_data = json.load(f) 27 | else: 28 | json_data = {} 29 | 30 | checkpoint_results = json_data.get(checkpoint_name, {}) 31 | checkpoint_results[dataset_name] = eval_result 32 | json_data[checkpoint_name] = checkpoint_results 33 | 34 | with dest_file.open("w") as f: 35 | json.dump(json_data, f, indent='\t') 36 | 37 | 38 | def evaluate_reconstruction(autoencoder: 
StyleganAutoencoder, data_loaders: dict) -> dict: 39 | metrics = defaultdict(list) 40 | psnr_ssim_evaluator = PSNRSSIMEvaluator() 41 | 42 | for i, batch in tenumerate(data_loaders['test'], desc="psnr_ssim", leave=False): 43 | batch = {k: v.to('cuda') for k, v in batch.items()} 44 | with torch.no_grad(): 45 | denoised = autoencoder(batch['input_image']) 46 | 47 | psnr, ssim = psnr_ssim_evaluator.psnr_and_ssim(denoised, batch['output_image']) 48 | 49 | metrics['psnr'].append(float(psnr.cpu().numpy())) 50 | metrics['ssim'].append(float(ssim.cpu().numpy())) 51 | metrics = {k: statistics.mean(v) for k, v in metrics.items()} 52 | return metrics 53 | 54 | 55 | def evaluate_fid(autoencoder: StyleganAutoencoder, data_loaders: dict, dataset: dict) -> dict: 56 | fid_evaluator = FID(num_samples=50_000) 57 | 58 | test_dataset_len = len(data_loaders['test']) 59 | if test_dataset_len > 50_000: 60 | data_key = 'test' 61 | else: 62 | if len(data_loaders['train']) < 50_000: 63 | print("warning test and train dataset are smaller than 50.000 samples!") 64 | data_key = 'train' 65 | 66 | fid_scores = fid_evaluator(autoencoder, data_loaders[data_key], dataset[data_key]) 67 | return {"fid": fid_scores} 68 | 69 | 70 | def has_not_been_evaluated(checkpoint_name: str, dataset_name: str, evaluation_root: Path) -> Dict[str, bool]: 71 | already_done_map = {} 72 | for eval_type in ["fid", "reconstruction"]: 73 | dest_file = evaluation_root / f"{eval_type}.json" 74 | if not dest_file.exists(): 75 | already_done_map[eval_type] = True 76 | continue 77 | 78 | with dest_file.open() as f: 79 | evaluation_data = json.load(f) 80 | 81 | if checkpoint_name not in evaluation_data: 82 | already_done_map[eval_type] = True 83 | continue 84 | 85 | evaluation_data = evaluation_data[checkpoint_name] 86 | already_done_map[eval_type] = dataset_name not in evaluation_data 87 | 88 | return already_done_map 89 | 90 | 91 | def evaluate_checkpoint(checkpoint: str, dataset: dict, args: argparse.Namespace): 92 | checkpoint = Path(checkpoint) 93 | train_run_root_dir = checkpoint.parent.parent 94 | evaluation_root = train_run_root_dir / 'evaluation' 95 | evaluation_root.mkdir(exist_ok=True) 96 | 97 | dataset_name = dataset.pop('name') 98 | to_evaluate = has_not_been_evaluated(checkpoint.name, dataset_name, evaluation_root) 99 | if not args.fid: 100 | to_evaluate['fid'] = False 101 | if not args.reconstruction: 102 | to_evaluate['reconstruction'] = False 103 | 104 | if not any(to_evaluate.values()): 105 | # there is nothing to evaluate 106 | return 107 | 108 | config = load_config(checkpoint, None) 109 | 110 | dataset = {k: Path(v) for k, v in dataset.items()} 111 | 112 | autoencoder = get_autoencoder(config).to('cuda') 113 | autoencoder = load_weights(autoencoder, checkpoint, key='autoencoder', strict=True) 114 | 115 | config['batch_size'] = 1 116 | 117 | dataset_class = get_dataset_class(argparse.Namespace(**config)) 118 | data_loaders = { 119 | key: build_data_loader(value, config, config['absolute'], shuffle_off=True, dataset_class=dataset_class) 120 | for key, value in dataset.items() 121 | } 122 | 123 | if to_evaluate['fid']: 124 | fid_result = evaluate_fid(autoencoder, data_loaders, dataset) 125 | save_eval_result(fid_result, "fid", evaluation_root, dataset_name, checkpoint.name) 126 | 127 | if to_evaluate['reconstruction']: 128 | reconstruction_result = evaluate_reconstruction(autoencoder, data_loaders) 129 | save_eval_result(reconstruction_result, "reconstruction", evaluation_root, dataset_name, checkpoint.name) 130 | 131 | del autoencoder 
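# drop the model reference and clear the CUDA cache so the next checkpoint/dataset
# combination starts with as much free GPU memory as possible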
132 | torch.cuda.empty_cache() 133 | 134 | 135 | def main(args): 136 | checkpoint_file = Path(args.checkpoint_list) 137 | with checkpoint_file.open() as f: 138 | checkpoints = [line.rstrip() for line in f] 139 | 140 | dataset_file = Path(args.dataset_file) 141 | with dataset_file.open() as f: 142 | datasets = json.load(f) 143 | 144 | failed_combinations = [] 145 | try: 146 | for checkpoint, dataset in tqdm(itertools.product(checkpoints, datasets), total=len(checkpoints) * len(datasets)): 147 | try: 148 | evaluate_checkpoint(checkpoint, copy.deepcopy(dataset), args) 149 | except Exception as e: 150 | failed_combinations.append({"combination": (checkpoint, dataset), "reason": str(e)}) 151 | finally: 152 | for combination in failed_combinations: 153 | print(f"The following eval combination failed: {combination['combination']}, with reason: {combination['reason']}") 154 | 155 | 156 | if __name__ == "__main__": 157 | parser = argparse.ArgumentParser(description="Tool that takes a list of checkpoints and datasets and runs evaluation") 158 | parser.add_argument("checkpoint_list", help="path to file that contains the path to a trained checkpoint in each line") 159 | parser.add_argument("dataset_file", help="path to json file that contains the paths to datasets each model is to be evaluated on") 160 | parser.add_argument("--skip-fid", dest="fid", action='store_false', default=True, help="skip fid during evaluation") 161 | parser.add_argument("--skip-reconstruction", dest="reconstruction", action='store_false', default=True, help="skip calculation of reconstruction metrics such as PSNR and SSID") 162 | 163 | main(parser.parse_args()) 164 | -------------------------------------------------------------------------------- /evaluate_denoising.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import statistics 4 | from collections import defaultdict 5 | from pathlib import Path 6 | from typing import List 7 | 8 | import torch 9 | from PIL import Image 10 | from tqdm.contrib import tenumerate 11 | 12 | from data.denoising_eval_dataset import DenoisingEvaluationDataset 13 | from evaluation.psnr_ssim import PSNRSSIMEvaluator 14 | from networks import get_autoencoder, load_weights 15 | from pytorch_training.images.utils import clamp_and_unnormalize, make_image 16 | from utils.config import load_config 17 | from utils.data_loading import build_data_loader 18 | 19 | 20 | def save_images(images: List[torch.Tensor], save_dir: Path, index: int): 21 | dest_file_name = save_dir / f"{index}.png" 22 | 23 | images = [Image.fromarray(make_image(image, normalize_func=lambda x: x)) for image in images] 24 | 25 | dest_image = Image.new((im := images[0]).mode, (im.width * len(images), im.height)) 26 | for i, image in enumerate(images): 27 | dest_image.paste(image, (image.width * i, 0)) 28 | 29 | dest_image.save(str(dest_file_name)) 30 | 31 | 32 | def evaluate_denoising(args): 33 | config = load_config(args.model_checkpoint, None) 34 | args.test_dataset = Path(args.test_dataset) 35 | 36 | assert config['denoising'] is True or config['black_and_white_denoising'] is True, "you are supplying a train run that has not been trained for denoising! 
Aborting" 37 | 38 | autoencoder = get_autoencoder(config).to(args.device) 39 | autoencoder = load_weights(autoencoder, args.model_checkpoint, key='autoencoder', strict=True) 40 | 41 | config['batch_size'] = 1 42 | data_loader = build_data_loader(args.test_dataset, config, config['absolute'], shuffle_off=True, dataset_class=DenoisingEvaluationDataset) 43 | 44 | metrics = defaultdict(list) 45 | psnr_ssim_evaluator = PSNRSSIMEvaluator() 46 | 47 | train_run_root_dir = Path(args.model_checkpoint).parent.parent 48 | evaluation_root = train_run_root_dir / 'evaluation' / f"denoise_{args.dataset_name}" 49 | evaluation_root.mkdir(parents=True, exist_ok=True) 50 | 51 | for i, batch in tenumerate(data_loader, leave=False): 52 | batch = {k: v.to(args.device) for k, v in batch.items()} 53 | with torch.no_grad(): 54 | denoised = autoencoder(batch['noisy']) 55 | 56 | noisy = clamp_and_unnormalize(batch['noisy']) 57 | original = clamp_and_unnormalize(batch['original']) 58 | denoised = clamp_and_unnormalize(denoised) 59 | 60 | if args.save: 61 | save_dir = evaluation_root / "qualitative" / args.test_dataset.stem 62 | save_dir.mkdir(exist_ok=True, parents=True) 63 | save_images([original[0], noisy[0], denoised[0]], save_dir, i) 64 | 65 | psnr, ssim = psnr_ssim_evaluator.psnr_and_ssim(denoised, original) 66 | 67 | metrics['psnr'].append(float(psnr.cpu().numpy())) 68 | metrics['ssim'].append(float(ssim.cpu().numpy())) 69 | 70 | metrics = {k: statistics.mean(v) for k, v in metrics.items()} 71 | 72 | evaluation_file = evaluation_root / f'denoising_{args.test_dataset.stem}.json' 73 | with evaluation_file.open('w') as f: 74 | json.dump(metrics, f, indent='\t') 75 | 76 | 77 | if __name__ == "__main__": 78 | parser = argparse.ArgumentParser(description="Tool that takes a trained denoising model and an evaluation dataset and produces denoising eval results") 79 | parser.add_argument("model_checkpoint", help="Path to trained model that is to be evaluated") 80 | parser.add_argument("test_dataset", help="path to json holding pairs of noisy and clean image paths") 81 | parser.add_argument("dataset_name", help="name of evaluation dataset (e.g. 
BSD68 or Set12)") 82 | parser.add_argument("--device", default='cuda', help="device to use") 83 | parser.add_argument("--save", action='store_true', default=False, help="save reconstructed images together with real images for visual inspection") 84 | 85 | evaluate_denoising(parser.parse_args()) 86 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/evaluation/__init__.py -------------------------------------------------------------------------------- /evaluation/autoencoder_evaluation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from kornia import ssim as ssim_loss, psnr_loss 4 | 5 | from losses.lpips import PerceptualLoss 6 | from networks import StyleganAutoencoder 7 | from pytorch_training.images.utils import clamp_and_unnormalize 8 | from pytorch_training.reporter import get_current_reporter 9 | 10 | 11 | class AutoEncoderEvalFunc: 12 | 13 | def __init__(self, autoencoder: StyleganAutoencoder, device: int, use_perceptual_loss: bool = True): 14 | self.autoencoder = autoencoder 15 | self.perceptual_loss = PerceptualLoss(model='net-lin', net='vgg', use_gpu=True, gpu_ids=[device]) 16 | self.use_perceptual_loss = use_perceptual_loss 17 | 18 | def __call__(self, batch): 19 | reporter = get_current_reporter() 20 | 21 | with torch.no_grad(): 22 | reconstructed_images = self.autoencoder(batch['input_image']) 23 | original_image = batch['output_image'] 24 | 25 | mse_loss = F.mse_loss(original_image, reconstructed_images, reduction='none') 26 | loss = mse_loss.mean(dim=(1, 2, 3)).sum() 27 | reporter.add_observation({"reconstruction_loss": loss}, prefix='evaluation') 28 | if self.use_perceptual_loss: 29 | perceptual_loss = self.perceptual_loss(reconstructed_images, original_image).sum() 30 | reporter.add_observation({"perceptual_loss": perceptual_loss}, prefix='evaluation') 31 | loss += perceptual_loss 32 | 33 | original_image = clamp_and_unnormalize(original_image) 34 | reconstructed_images = clamp_and_unnormalize(reconstructed_images) 35 | psnr = psnr_loss(reconstructed_images, original_image, max_val=1) 36 | 37 | ssim = ssim_loss(original_image, reconstructed_images, 5, reduction='mean') 38 | # since we get a loss, we need to calculate/reconstruct the original ssim value 39 | ssim = 1 - 2 * ssim 40 | 41 | reporter.add_observation({"psnr": psnr, "ssim": ssim}, prefix='evaluation') 42 | 43 | reporter.add_observation({"autoencoder_loss": loss}, prefix='evaluation') 44 | -------------------------------------------------------------------------------- /evaluation/calculate_fid_for_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | from evaluation.fid import FID 6 | from networks import get_autoencoder, load_weights 7 | from utils.config import load_config 8 | from utils.data_loading import build_data_loader 9 | 10 | 11 | def save_fid_score(fid_score: float, dest_dir: Path, dataset_name: str): 12 | dest_file = dest_dir / "fid.json" 13 | if dest_file.exists(): 14 | with dest_file.open("r") as f: 15 | json_data = json.load(f) 16 | else: 17 | json_data = {} 18 | 19 | if dataset_name in json_data: 20 | print("WARNING: Already found an FID result 
for this dataset") 21 | while True: 22 | answer = input("Overwrite [y|N]? ") 23 | if len(answer) == 0 or answer.lower() == 'n': 24 | return 25 | elif answer.lower() == 'y': 26 | break 27 | print(f"Did not understand: {answer}") 28 | json_data[dataset_name] = fid_score 29 | 30 | with dest_file.open("w") as f: 31 | json.dump(json_data, f, indent='\t') 32 | 33 | 34 | def main(args: argparse.Namespace): 35 | dest_dir = Path(args.model_checkpoint).parent.parent / 'evaluation' 36 | dest_dir.mkdir(parents=True, exist_ok=True) 37 | 38 | config = load_config(args.model_checkpoint, None) 39 | dataset = Path(args.dataset) 40 | 41 | config['batch_size'] = args.batch_size 42 | data_loader = build_data_loader(dataset, config, config['absolute'], shuffle_off=True) 43 | fid_calculator = FID(args.num_samples, device=args.device) 44 | 45 | autoencoder = get_autoencoder(config).to(args.device) 46 | autoencoder = load_weights(autoencoder, args.model_checkpoint, key='autoencoder') 47 | 48 | fid_score = fid_calculator(autoencoder, data_loader, args.dataset) 49 | 50 | save_fid_score(fid_score, dest_dir, args.dataset_name) 51 | 52 | print(f"FID Score for {args.dataset_name} is {fid_score}.") 53 | 54 | 55 | if __name__ == "__main__": 56 | parser = argparse.ArgumentParser(description="Tool that calculates FID metric for a given model and dataset") 57 | parser.add_argument("model_checkpoint", help="path to the model that is to be analyzed") 58 | parser.add_argument("dataset", help="path to json holding dataset information") 59 | parser.add_argument("dataset_name", help="human readable name of dataset you are evaluating (used for saving the results)") 60 | parser.add_argument("-b", "--batch_size", type=int, default=1, help="Batch Size for forwarding through model") 61 | parser.add_argument("-n", "--num-samples", type=int, default=1000, help="number of samples to use for FID calculation") 62 | parser.add_argument("-d", "--device", default='cuda', help="device to use") 63 | 64 | main(parser.parse_args()) 65 | -------------------------------------------------------------------------------- /evaluation/create_denoise_eval_set.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import imgaug.augmenters as iaa 3 | from pathlib import Path 4 | 5 | import numpy 6 | from PIL import Image 7 | from tqdm import tqdm 8 | 9 | 10 | NOISE_SCALES = [5, 10, 15, 25, 35, 50] 11 | 12 | 13 | if __name__ == "__main__": 14 | parser = argparse.ArgumentParser(description="based on a given image dir, create noisy versions of images and save them in extra dirs + gt") 15 | parser.add_argument("image_dir") 16 | 17 | args = parser.parse_args() 18 | 19 | image_dir = Path(args.image_dir) 20 | 21 | for scale in tqdm(NOISE_SCALES): 22 | image_files = list(image_dir.glob('*.png')) 23 | with Image.open(image_files[0]) as test_image: 24 | per_channel = test_image.mode != 'L' 25 | 26 | augmenter = iaa.AdditiveGaussianNoise(scale=scale, per_channel=per_channel) 27 | dest_dir = image_dir.parent / f"noisy_{scale}" 28 | dest_dir.mkdir(exist_ok=True) 29 | 30 | for image_file in tqdm(image_files, leave=False): 31 | with Image.open(image_file) as the_image: 32 | image_array = numpy.array(the_image) 33 | noisy_array = augmenter(image=image_array) 34 | noisy_image = Image.fromarray(noisy_array) 35 | 36 | dest_name = dest_dir / image_file.name 37 | noisy_image.save(str(dest_name)) 38 | -------------------------------------------------------------------------------- /evaluation/datasets.json: 
-------------------------------------------------------------------------------- 1 | [ 2 | { 3 | "test": "/mnt/ssd1/christian/one_model_to_generate_them_all/lsun_data/church/val_images/images.json", 4 | "train": "/mnt/ssd1/christian/one_model_to_generate_them_all/lsun_data/church/train_images/images.json", 5 | "name": "lsun_church" 6 | }, 7 | { 8 | "test": "/mnt/ssd1/christian/one_model_to_generate_them_all/lsun_data/bedroom/val_images/images.json", 9 | "train": "/mnt/ssd1/christian/one_model_to_generate_them_all/lsun_data/bedroom/train_images/images.json", 10 | "name": "lsun_bedroom" 11 | }, 12 | { 13 | "test": "/mnt/ssd2/christian/one_model_to_generate_them_all/lsun_data/cat/cat_images/val.json", 14 | "train": "/mnt/ssd2/christian/one_model_to_generate_them_all/lsun_data/cat/cat_images/train.json", 15 | "name": "lsun_cat" 16 | }, 17 | { 18 | "test": "/mnt/ssd1/christian/flickr-faces/images1024x1024/val.json", 19 | "train": "/mnt/ssd1/christian/flickr-faces/images1024x1024/train.json", 20 | "name": "ffhq" 21 | } 22 | ] 23 | -------------------------------------------------------------------------------- /evaluation/fid.py: -------------------------------------------------------------------------------- 1 | import contextlib 2 | import hashlib 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Union, Tuple 6 | 7 | import numpy 8 | import pytorch_fid.fid_score 9 | import pytorch_fid.inception 10 | import torch 11 | import torch.distributed 12 | from torch.nn.functional import adaptive_avg_pool2d 13 | from torch.utils.data import DataLoader 14 | from tqdm.contrib import tenumerate 15 | 16 | from networks import StyleganAutoencoder 17 | from pytorch_training.data.utils import cache_or_load_file 18 | from pytorch_training.distributed import get_world_size, get_rank 19 | 20 | 21 | @dataclass 22 | class FIDStatistics: 23 | mu: numpy.ndarray 24 | sigma: numpy.ndarray 25 | 26 | 27 | class FID: 28 | 29 | def __init__(self, num_samples: int = 1000, dim: int = 2048, device: str = 'cuda'): 30 | self.num_samples = num_samples 31 | self.inception_dim = dim 32 | self.block_index = pytorch_fid.inception.InceptionV3.BLOCK_INDEX_BY_DIM[dim] 33 | self.device = device 34 | 35 | self.model = pytorch_fid.inception.InceptionV3([self.block_index]) 36 | self.model.eval() 37 | 38 | @contextlib.contextmanager 39 | def init_inception_model(self): 40 | torch.cuda.empty_cache() 41 | self.model = self.model.to(self.device) 42 | yield 43 | self.model = self.model.to('cpu') 44 | torch.cuda.empty_cache() 45 | 46 | def __call__(self, model: StyleganAutoencoder, data_loader: DataLoader, dataset_path: Union[str, Path] = None) -> float: 47 | with self.init_inception_model(): 48 | real_statistics, fake_statistics = self.calculate_statistics(model, data_loader, dataset_path) 49 | 50 | fid_score = pytorch_fid.fid_score.calculate_frechet_distance( 51 | real_statistics.mu, 52 | real_statistics.sigma, 53 | fake_statistics.mu, 54 | fake_statistics.sigma 55 | ) 56 | 57 | return fid_score 58 | 59 | @staticmethod 60 | def load_precalculated_mu_and_sigma(path: Path) -> FIDStatistics: 61 | data = numpy.load(str(path)) 62 | return FIDStatistics(data['mu'][:], data['sigma'][:]) 63 | 64 | @staticmethod 65 | def get_statistics(activations: numpy.ndarray) -> FIDStatistics: 66 | mu = numpy.mean(activations, axis=0) 67 | sigma = numpy.cov(activations, rowvar=False) 68 | return FIDStatistics(mu, sigma) 69 | 70 | def multiprocess_synchronize(self, activations: torch.Tensor) -> numpy.ndarray: 71 | if 
get_world_size() > 1: 72 | # we are running in distributed setting, so we will need to gather all predictions for each worker 73 | gathered_activations = [torch.empty(activations.shape, device=self.device) for _ in range(get_world_size())] 74 | torch.distributed.all_gather(gathered_activations, activations) 75 | activations = torch.cat(gathered_activations, dim=0) 76 | return activations.cpu().numpy() 77 | 78 | @contextlib.contextmanager 79 | def get_progress_bar(self, data_loader, total, description): 80 | if get_rank() == 0: 81 | pbar = tenumerate(data_loader, total=total // data_loader.batch_size + 1, desc=description, leave=False) 82 | else: 83 | pbar = enumerate(data_loader) 84 | 85 | yield pbar 86 | 87 | if hasattr(pbar, 'close'): 88 | pbar.close() 89 | 90 | def calculate_real_statistics(self, path: Union[Path, None], data_loader: DataLoader) -> FIDStatistics: 91 | total = min(self.num_samples, len(data_loader) * data_loader.batch_size) 92 | activations = torch.empty((total, self.inception_dim), device=self.device) 93 | 94 | with self.get_progress_bar(data_loader, total, 'fid real') as progress_bar: 95 | for i, batch in progress_bar: 96 | batch = batch['output_image'] 97 | batch = batch.to(self.device) 98 | inception_predictions = self.get_inception_predictions(batch) 99 | 100 | start = i * data_loader.batch_size 101 | end = min((i + 1) * data_loader.batch_size, total) 102 | activations[start:end] = inception_predictions[:end-start] 103 | 104 | if end >= total: 105 | break 106 | 107 | activations = self.multiprocess_synchronize(activations) 108 | statistics = self.get_statistics(activations) 109 | if path is not None: 110 | numpy.savez(str(path), mu=statistics.mu, sigma=statistics.sigma) 111 | 112 | return statistics 113 | 114 | def get_inception_predictions(self, batch: torch.Tensor) -> torch.Tensor: 115 | with torch.no_grad(): 116 | inception_predictions = self.model(batch)[0] 117 | 118 | # If model output is not scalar, apply global spatial average pooling. 119 | # This happens if you choose a dimensionality not equal 2048. 
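# (afterwards the (N, C, 1, 1) activations are flattened to (N, C) rows so they can be written
#  into the pre-allocated activation matrix used to compute the FID statistics)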
120 | if inception_predictions.size(2) != 1 or inception_predictions.size(3) != 1: 121 | inception_predictions = adaptive_avg_pool2d(inception_predictions, output_size=(1, 1)) 122 | inception_predictions = inception_predictions 123 | inception_predictions = inception_predictions.reshape(len(inception_predictions), -1) 124 | 125 | return inception_predictions 126 | 127 | def calculate_fake_statistics(self, autoencoder: StyleganAutoencoder, data_loader: DataLoader) -> FIDStatistics: 128 | total = min(self.num_samples, len(data_loader) * data_loader.batch_size) 129 | activations = torch.empty((total, self.inception_dim), device=self.device) 130 | 131 | with self.get_progress_bar(data_loader, total, 'fid fake') as progress_bar: 132 | for i, batch in progress_bar: 133 | batch = batch['input_image'] 134 | batch = batch.to(self.device) 135 | with torch.no_grad(): 136 | reconstructed_images = autoencoder(batch) 137 | 138 | inception_predictions = self.get_inception_predictions(reconstructed_images) 139 | 140 | start = i * data_loader.batch_size 141 | end = min((i + 1) * data_loader.batch_size, total) 142 | activations[start:end] = inception_predictions[:end-start] 143 | 144 | if end >= total: 145 | break 146 | 147 | activations = self.multiprocess_synchronize(activations) 148 | return self.get_statistics(activations) 149 | 150 | def calculate_statistics(self, autoencoder: StyleganAutoencoder, data_loader: DataLoader, dataset_path: Union[str, Path]) -> Tuple[FIDStatistics, FIDStatistics]: 151 | if dataset_path is not None: 152 | dataset_path = Path(dataset_path) 153 | 154 | hasher = hashlib.sha512(str(dataset_path).encode('utf-8')) 155 | hasher.update(str(self.num_samples).encode('utf-8')) 156 | fid_file_name = f"{hasher.hexdigest()}_fid.npz" 157 | real_statistics = cache_or_load_file( 158 | dataset_path.parent / fid_file_name, 159 | lambda x: self.calculate_real_statistics(x, data_loader), 160 | self.load_precalculated_mu_and_sigma 161 | ) 162 | else: 163 | real_statistics = self.calculate_real_statistics(None, data_loader) 164 | 165 | torch.cuda.empty_cache() 166 | 167 | fake_statistics = self.calculate_fake_statistics(autoencoder, data_loader) 168 | 169 | torch.cuda.empty_cache() 170 | 171 | return real_statistics, fake_statistics 172 | -------------------------------------------------------------------------------- /evaluation/find_all_saved_checkpoints.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | 5 | if __name__ == "__main__": 6 | parser = argparse.ArgumentParser(description="Finds all trained checkpoints in given project log dir") 7 | parser.add_argument("project_dir", help="path to project dir that is to be analyzed") 8 | 9 | args = parser.parse_args() 10 | project_dir = Path(args.project_dir) 11 | 12 | runs = [d for d in project_dir.iterdir() if d.is_dir()] 13 | all_checkpoints = [] 14 | for run in runs: 15 | checkpoints = list(run.glob('**/checkpoints/*.pt')) 16 | all_checkpoints.extend(checkpoints) 17 | 18 | with (project_dir / "trained_checkpoints.txt").open('w') as f: 19 | for checkpoint in all_checkpoints: 20 | print(checkpoint.resolve(), file=f) 21 | -------------------------------------------------------------------------------- /evaluation/psnr_ssim.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia import psnr_loss, ssim as ssim_loss 3 | from typing import Tuple 4 | 5 | from pytorch_training.images.utils import 
clamp_and_unnormalize 6 | 7 | 8 | class PSNRSSIMEvaluator: 9 | 10 | def __init__(self, max_value: int = 1, ssim_kernel_size: int = 5): 11 | self.max_value = max_value 12 | self.ssim_kernel_size = ssim_kernel_size 13 | 14 | def unnormalize(self, image: torch.Tensor) -> torch.Tensor: 15 | if image.min() < 0: 16 | image = clamp_and_unnormalize(image) 17 | return image 18 | 19 | def psnr(self, image: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 20 | image = self.unnormalize(image) 21 | target = self.unnormalize(target) 22 | 23 | assert len(image) == 1, "Batch size of images must be one in order to get a meaningful psnr result" 24 | psnr = psnr_loss(image, target, self.max_value) 25 | return psnr 26 | 27 | def ssim(self, image: torch.Tensor, target: torch.Tensor) -> torch.Tensor: 28 | image = self.unnormalize(image) 29 | target = self.unnormalize(target) 30 | 31 | assert len(image) == 1, "Batch size of images must be one in order to get a meaningful ssim result" 32 | ssim = ssim_loss(image, target, self.ssim_kernel_size, reduction='none') 33 | ssim = (1 - 2 * ssim).mean() 34 | return ssim 35 | 36 | def psnr_and_ssim(self, image: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 37 | psnr = self.psnr(image, target) 38 | ssim = self.ssim(image, target) 39 | return psnr, ssim 40 | -------------------------------------------------------------------------------- /extensions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/extensions/__init__.py -------------------------------------------------------------------------------- /extensions/fid_score.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | from pytorch_training.distributed import get_rank, synchronize 4 | from pytorch_training.reporter import get_current_reporter 5 | from torch.utils.data import DataLoader 6 | from typing import Union 7 | 8 | from evaluation.fid import FID 9 | from networks import StyleganAutoencoder 10 | 11 | from pytorch_training import Extension, Trainer 12 | 13 | 14 | class FIDScore(Extension): 15 | 16 | def __init__(self, autoencoder: StyleganAutoencoder, data_loader: DataLoader, *args, dataset_path: Union[str, Path] = None, device: str = 'cuda', **kwargs): 17 | super().__init__(*args, **kwargs) 18 | self.fid_calculator = FID(device=device) 19 | self.autoencoder = autoencoder 20 | self.data_loader = data_loader 21 | self.dataset_path = dataset_path 22 | 23 | def initialize(self, trainer: 'Trainer'): 24 | self.run(trainer) 25 | 26 | def finalize(self, trainer: 'Trainer'): 27 | self.run(trainer) 28 | 29 | def run(self, trainer: Trainer): 30 | fid = self.fid_calculator(self.autoencoder, self.data_loader, self.dataset_path) 31 | synchronize() 32 | if get_rank() == 0: 33 | with get_current_reporter() as reporter: 34 | reporter.add_observation({"fid_score": fid}, "evaluation") 35 | -------------------------------------------------------------------------------- /file_based_simple_style_transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from argparse import Namespace 3 | from pathlib import Path 4 | 5 | from PIL import Image 6 | 7 | from latent_projecting.style_transfer import StyleTransferer 8 | from utils.command_line_args import add_default_args_for_projecting 9 | from pytorch_training.images.utils 
import make_image 10 | 11 | 12 | def main(args: Namespace): 13 | transferer = StyleTransferer(args) 14 | content_latents, style_latents = transferer.get_latents(args.content_path, args.style_path) 15 | 16 | if args.mixing_index < 0: 17 | stylized_images = { 18 | i: transferer.do_style_transfer(content_latents, style_latents, i) 19 | for i in range(content_latents.latent.shape[1]) 20 | } 21 | else: 22 | stylized_images = { 23 | args.mixing_index: transferer.do_style_transfer(content_latents, style_latents, args.mixing_index) 24 | } 25 | 26 | destination_dir = Path(args.content_path).parent / "simple_style_transfer" / args.destination_dir 27 | destination_dir.mkdir(parents=True, exist_ok=True) 28 | 29 | for index, (image_array, optimization_path) in stylized_images.items(): 30 | content_base_name = args.content_path 31 | style_base_name = args.style_path 32 | 33 | content_name = Path(content_base_name).stem 34 | style_name = Path(style_base_name).stem 35 | 36 | image_name = f"{content_name}_{style_name}_{index}" 37 | destination_name = destination_dir / f"{image_name}.png" 38 | Image.fromarray(make_image(image_array)[0]).save(destination_name) 39 | 40 | if optimization_path is not None and args.gif: 41 | transferer.projector.create_gif(optimization_path.latent, optimization_path.noise, image_name, destination_dir) 42 | 43 | 44 | if __name__ == "__main__": 45 | parser = argparse.ArgumentParser(description="Tool that does style transfer described in Image2Stylegan Paper") 46 | parser.add_argument("--content", dest="content_path", required=True) 47 | parser.add_argument("--style", dest="style_path", required=True) 48 | parser.add_argument("--destination", dest="destination_dir", required=True) 49 | parser.add_argument("--mixing-index", type=int, default=-1) 50 | parser.add_argument("--post-optimize", action='store_true', default=False) 51 | parser.add_argument("--gif", action='store_true', default=False) 52 | parser = add_default_args_for_projecting(parser) 53 | 54 | main(parser.parse_args()) 55 | -------------------------------------------------------------------------------- /interpolate_between_embeddings.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | from pathlib import Path 4 | from typing import Union, Tuple, Dict, List 5 | 6 | import numpy 7 | import torch 8 | from PIL import Image 9 | from torch import nn 10 | from tqdm import trange 11 | 12 | from embeddings.utils import latent_from_embedding, noises_from_embedding 13 | from networks import get_autoencoder, load_weights 14 | from pytorch_training.images import make_image 15 | from utils.config import load_config 16 | 17 | 18 | def interpolate(start_array: Union[numpy.ndarray, torch.Tensor], end_array: Union[numpy.ndarray, torch.Tensor], fraction: float) -> Union[numpy.ndarray, torch.Tensor]: 19 | return (1 - fraction) * start_array + fraction * end_array 20 | 21 | 22 | def load_embeddings(embeddings: Dict[str, numpy.ndarray], index: int) -> Tuple[torch.Tensor, List[torch.Tensor]]: 23 | latent = latent_from_embedding(embeddings, index).unsqueeze(0) 24 | noises = [noise.unsqueeze(0) for noise in noises_from_embedding(embeddings, index)] 25 | return latent, noises 26 | 27 | 28 | def make_interpolation_image(steps: int, device: torch.device, autoencoder: nn.Module, is_w_plus: bool, 29 | start_latent: torch.Tensor, end_latent: torch.Tensor, 30 | start_noises: List[torch.Tensor], end_noises: List[torch.Tensor]): 31 | all_interpolation_images = [] 32 | for 
interpolation_strategy in ['all', 'latent', 'noise']: 33 | interpolation_images = [] 34 | 35 | start_image, _ = autoencoder.decoder([start_latent.to(device)], input_is_latent=is_w_plus, noise=[n.to(device) for n in start_noises]) 36 | interpolation_images.append(make_image(start_image.squeeze(0))) 37 | 38 | for i in trange(steps + 1): 39 | step_fraction = i / steps 40 | if interpolation_strategy in ['latent', 'all']: 41 | latent = interpolate(start_latent, end_latent, step_fraction) 42 | else: 43 | latent = start_latent 44 | latent = latent.to(device) 45 | 46 | if interpolation_strategy in ['noise', 'all']: 47 | noises = [interpolate(start_noise, end_noise, step_fraction) for start_noise, end_noise in zip(start_noises, end_noises)] 48 | else: 49 | noises = autoencoder.decoder.make_noise() 50 | noises = [noise.to(device) for noise in noises] 51 | 52 | image, _ = autoencoder.decoder([latent], input_is_latent=is_w_plus, noise=noises) 53 | image = make_image(image.squeeze(0)) 54 | interpolation_images.append(image) 55 | 56 | end_image, _ = autoencoder.decoder([end_latent.to(device)], input_is_latent=is_w_plus, noise=[n.to(device) for n in end_noises]) 57 | interpolation_images.append(make_image(end_image.squeeze(0))) 58 | 59 | all_images = numpy.concatenate(interpolation_images, axis=1) 60 | image = Image.fromarray(all_images) 61 | all_interpolation_images.append(image) 62 | 63 | dest_image = Image.new("RGB", (all_interpolation_images[0].width, all_interpolation_images[0].height * 3)) 64 | for i, image in enumerate(all_interpolation_images): 65 | dest_image.paste(image, (0, i * image.height)) 66 | 67 | return dest_image 68 | 69 | 70 | def main(args): 71 | embedding_dir = Path(args.embedding_file).parent 72 | embedded_data = numpy.load(args.embedding_file, mmap_mode='r') 73 | 74 | checkpoint_for_embedding = embedding_dir.parent / 'checkpoints' / f"{Path(args.embedding_file).stem.split('_')[-3]}.pt" 75 | 76 | config = load_config(checkpoint_for_embedding, None) 77 | autoencoder = get_autoencoder(config).to(args.device) 78 | autoencoder = load_weights(autoencoder, checkpoint_for_embedding, key='autoencoder', strict=True) 79 | 80 | num_images = len(embedded_data['image_names']) 81 | 82 | interpolation_dir = embedding_dir / 'interpolations' 83 | interpolation_dir.mkdir(parents=True, exist_ok=True) 84 | 85 | is_w_plus = not config['w_only'] 86 | 87 | for _ in range(args.num_images): 88 | start_image_idx, end_image_idx = random.sample(list(range(num_images)), k=2) 89 | 90 | start_latent, start_noises = load_embeddings(embedded_data, start_image_idx) 91 | end_latent, end_noises = load_embeddings(embedded_data, end_image_idx) 92 | 93 | for steps in args.steps: 94 | result = make_interpolation_image(steps, args.device, autoencoder, is_w_plus, 95 | start_latent, end_latent, start_noises, end_noises) 96 | result.save(str(interpolation_dir / f"{start_image_idx}_{end_image_idx}_all_{steps}_steps.png")) 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser( 101 | description="extract two embedding codes and interpolate between them based on a number of steps", 102 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 103 | ) 104 | parser.add_argument("embedding_file", help='Path to npz with embedding of latent codes + noise') 105 | parser.add_argument("--device", default='cuda', help="which device to use (cuda, or cpu)") 106 | parser.add_argument("-s", "--steps", type=int, default=[5, 20], nargs="+", 107 | help="number of interpolation steps to perform (multiple values will 
create multiple scales)") 108 | parser.add_argument("-n", "--num-images", type=int, default=1, 109 | help="perform interpolation or multiple images") 110 | 111 | main(parser.parse_args()) 112 | -------------------------------------------------------------------------------- /latent_projecting/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import annotations 2 | 3 | from argparse import Namespace 4 | from collections import defaultdict 5 | from dataclasses import dataclass 6 | from typing import Tuple, List, Callable 7 | 8 | import torch 9 | from torch import optim 10 | 11 | from latent_projecting.losses import w_plus_style_loss, noise_loss, w_plus_loss, naive_noise_loss 12 | from pytorch_training.optimizer.lr_scheduling import LambdaLRWithRamp 13 | 14 | 15 | @dataclass 16 | class Latents: 17 | latent: torch.Tensor 18 | noise: List[torch.Tensor] 19 | 20 | def to(self, device) -> Latents: 21 | self.latent = self.latent.to(device) 22 | self.noise = [noise.to(device) for noise in self.noise] 23 | return self 24 | 25 | def __getitem__(self, key: int) -> Latents: 26 | latent = self.latent[key].unsqueeze(0) 27 | noise = [noise[key].unsqueeze(0) for noise in self.noise] 28 | return Latents(latent, noise) 29 | 30 | def detach(self): 31 | self.latent = self.latent.detach() 32 | self.noise = [noise.detach() for noise in self.noise] 33 | 34 | def numpy(self) -> Latents: 35 | latent = self.latent.cpu().numpy() 36 | noises = [noise.cpu().numpy() for noise in self.noise] 37 | return Latents(latent, noises) 38 | 39 | 40 | @dataclass 41 | class CodeLatents(Latents): 42 | code: torch.Tensor 43 | 44 | def to(self, device) -> CodeLatents: 45 | super().to(device) 46 | self.code = self.code.to(device) 47 | return self 48 | 49 | def __getitem__(self, key: int) -> CodeLatents: 50 | latent = self.latent[key].unsqueeze(0) 51 | noise = [noise[key].unsqueeze(0) for noise in self.noise] 52 | code = self.code[key].unsqueeze(0) 53 | return CodeLatents(latent, noise, code) 54 | 55 | def detach(self): 56 | super().detach() 57 | self.code = self.code.detach() 58 | 59 | 60 | @dataclass 61 | class LatentPaths: 62 | latent: List[torch.Tensor] 63 | noise: List[List[torch.Tensor]] 64 | 65 | def to(self, device) -> LatentPaths: 66 | for i in range(len(self)): 67 | self.latent[i] = self.latent[i].to(device) 68 | self.noise[i] = [noise.to(device) for noise in self.noise[i]] 69 | return self 70 | 71 | def __len__(self): 72 | assert len(self.latent) == len(self.noise) 73 | return len(self.latent) 74 | 75 | def __iter__(self) -> Latents: 76 | for latent, noise in zip(self.latent, self.noise): 77 | yield Latents(latent, noise) 78 | 79 | def __add__(self, other: LatentPaths) -> LatentPaths: 80 | self.latent += other.latent 81 | self.noise += other.noise 82 | return self 83 | 84 | def split(self) -> List[LatentPaths]: 85 | latent_paths = torch.stack(self.latent, dim=0) 86 | latent_paths = torch.transpose(latent_paths, 0, 1) 87 | latent_paths = torch.split(latent_paths, 1) 88 | latent_paths = [torch.split(tensor[0], 1) for tensor in latent_paths] 89 | 90 | noise_paths = defaultdict(list) 91 | for path_element in self.noise: 92 | splitted_noises = defaultdict(list) 93 | for noise_element in path_element: 94 | splitted_batch = noise_element.split(1) 95 | for i in range(len(splitted_batch)): 96 | splitted_noises[i].append(splitted_batch[i]) 97 | 98 | for batch_index, noises in splitted_noises.items(): 99 | noise_paths[batch_index].append(noises) 100 | 101 | return 
[LatentPaths(list(latent), list(noises)) for latent, noises in zip(latent_paths, noise_paths.values())] 102 | 103 | 104 | def optimize_noise(args: Namespace, projector: "Projector", latents: Latents, images: torch.Tensor, loss_func: Callable) -> Tuple[LatentPaths, Latents]: 105 | latents.to(projector.device) 106 | projector.set_requires_grad(latents, False) 107 | optimizer = optim.Adam(latents.noise, lr=args.noise_lr) 108 | scheduling_function = LambdaLRWithRamp.get_lr_with_ramp(args.noise_step, args.noise_lr_rampdown, args.noise_lr_rampup) 109 | lr_scheduler = LambdaLRWithRamp(optimizer, scheduling_function) 110 | 111 | paths, best_latent = projector.project( 112 | latents, 113 | images, 114 | optimizer, 115 | args.noise_step, 116 | loss_func, 117 | lr_scheduler=lr_scheduler, 118 | ) 119 | 120 | return paths, best_latent 121 | 122 | 123 | def run_image_reconstruction(args: Namespace, projector: "Projector", latents: Latents, images: torch.Tensor, do_optimize_noise: bool = True, latent_abort_condition: Callable = None, noise_abort_condition: Callable = None) -> Tuple[LatentPaths, Latents]: 124 | latents.to(projector.device) 125 | projector.abort_condition = latent_abort_condition 126 | projector.set_requires_grad(latents, True) 127 | optimizer = optim.Adam([latents.latent], lr=args.lr) 128 | scheduling_function = LambdaLRWithRamp.get_lr_with_ramp(args.latent_step, args.lr_rampdown, args.lr_rampup) 129 | lr_scheduler = LambdaLRWithRamp(optimizer, scheduling_function) 130 | 131 | paths, best_latent = projector.project( 132 | latents, 133 | images, 134 | optimizer, 135 | args.latent_step, 136 | w_plus_loss({"l_percept": 1, "l_mse": args.mse}, args.device), 137 | lr_scheduler=lr_scheduler, 138 | ) 139 | 140 | if do_optimize_noise: 141 | projector.abort_condition = noise_abort_condition 142 | more_paths, best_latent = optimize_noise( 143 | args, 144 | projector, 145 | best_latent, 146 | images, 147 | naive_noise_loss({"l_mse": 1}) 148 | ) 149 | paths = LatentPaths(paths.latent + more_paths.latent, paths.noise + more_paths.noise) 150 | 151 | return paths, best_latent 152 | 153 | 154 | def run_local_style_transfer(args: Namespace, projector: "Projector", latents: Latents, content_image: torch.Tensor, style_image: torch.Tensor, mask_image: torch.Tensor) -> Tuple[LatentPaths, Latents]: 155 | latents.to(projector.device) 156 | projector.set_requires_grad(latents, True) 157 | optimizer = optim.Adam([latents.latent], lr=args.lr) 158 | scheduling_function = LambdaLRWithRamp.get_lr_with_ramp(args.style_latent_step, args.style_lr_rampdown, args.style_lr_rampup) 159 | lr_scheduler = LambdaLRWithRamp(optimizer, scheduling_function) 160 | 161 | latent_path, best_latent = projector.project( 162 | latents, 163 | content_image, 164 | optimizer, 165 | args.style_latent_step, 166 | w_plus_style_loss({"l_percept": 1, "l_mse": 1, "l_style": 1}, content_image, style_image, mask_image, args.device), 167 | lr_scheduler=lr_scheduler, 168 | ) 169 | 170 | # optimize noise 171 | projector.set_requires_grad(best_latent, False) 172 | optimizer = optim.Adam(best_latent.noise, lr=args.noise_lr) 173 | scheduling_function = LambdaLRWithRamp.get_lr_with_ramp(args.style_noise_step, args.noise_style_lr_rampdown, args.noise_style_lr_rampup) 174 | lr_scheduler = LambdaLRWithRamp(optimizer, scheduling_function) 175 | 176 | more_latent_path, more_noise_path = projector.project( 177 | best_latent.to(args.device), 178 | content_image, 179 | optimizer, 180 | args.style_noise_step, 181 | noise_loss( 182 | {"l_mse_1": 1, "l_mse_2": 1}, 
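# l_mse_1 matches the masked region against content_image; l_mse_2 matches the
# unmasked region against the image rendered from the current best_latent
# (the original style_image target stays commented out below).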
183 | content_image, 184 | projector.generate(best_latent.to(args.device))[0], 185 | # style_image, 186 | mask_image 187 | ), 188 | lr_scheduler=lr_scheduler, 189 | ) 190 | 191 | latent_path = latent_path + more_latent_path 192 | 193 | return latent_path, best_latent 194 | -------------------------------------------------------------------------------- /latent_projecting/losses.py: -------------------------------------------------------------------------------- 1 | import typing as tp 2 | 3 | import torch 4 | from torch import Tensor 5 | from torch.nn import functional as F 6 | 7 | from losses.lpips import PerceptualLoss 8 | from losses.perceptual_style_loss import FixedPerceptualAndStyleLoss 9 | 10 | 11 | def w_plus_loss(lambdas: tp.Dict[str, float], device: str) -> tp.Callable: 12 | perceptual_loss_function = PerceptualLoss( 13 | model='net-lin', net='vgg', use_gpu=device.startswith('cuda') 14 | ) 15 | 16 | def loss_impl(generated_image: Tensor, original_image: Tensor) -> tp.Tuple[torch.Tensor, dict]: 17 | perceptual_loss = lambdas['l_percept'] * perceptual_loss_function(generated_image, original_image).sum() 18 | mse_loss = lambdas['l_mse'] * F.mse_loss(generated_image, original_image, reduction='none') 19 | mse_loss = mse_loss.mean(dim=(1, 2, 3)).sum() 20 | loss = perceptual_loss + mse_loss 21 | loss_dict = { 22 | 'perceptual_loss': perceptual_loss.item(), 23 | 'mse_loss': mse_loss.item(), 24 | } 25 | return loss, loss_dict 26 | 27 | return loss_impl 28 | 29 | 30 | def naive_noise_loss(lambdas: tp.Dict[str, float]) -> tp.Callable: 31 | def loss_impl(generated_image: Tensor, original_image: Tensor) -> tp.Tuple[torch.Tensor, dict]: 32 | mse_loss = lambdas['l_mse'] * F.mse_loss(generated_image, original_image, reduction='none') 33 | mse_loss = mse_loss.mean(dim=(1, 2, 3)).sum() 34 | loss_dict = {'mse_loss': mse_loss.item()} 35 | return mse_loss, loss_dict 36 | 37 | return loss_impl 38 | 39 | 40 | def w_plus_style_loss(lambdas: tp.Dict[str, float], content_image: Tensor, style_image: Tensor, mask_image: Tensor, device: str) -> tp.Callable: 41 | perceptual_and_style_loss = FixedPerceptualAndStyleLoss(content_image, style_image, mask_image.detach(), (1 - mask_image).detach()) 42 | perceptual_and_style_loss.to(device) 43 | 44 | def loss_impl(generated_image: Tensor, original_image: Tensor) -> tp.Tuple[Tensor, dict]: 45 | style_loss, perceptual_loss = perceptual_and_style_loss(generated_image) 46 | style_loss = lambdas['l_style'] * style_loss 47 | perceptual_loss = lambdas['l_percept'] * perceptual_loss 48 | 49 | mse_loss = torch.square(mask_image * (generated_image - content_image)).mean() 50 | mse_loss = lambdas['l_mse'] * mse_loss 51 | loss_dict = { 52 | 'mse_loss': mse_loss.item(), 53 | 'style_loss': style_loss.item(), 54 | 'perceptual_loss': perceptual_loss.item(), 55 | } 56 | loss = mse_loss + style_loss + perceptual_loss 57 | 58 | return loss, loss_dict 59 | 60 | return loss_impl 61 | 62 | 63 | def noise_loss(lambdas: tp.Dict[str, float], content_image: Tensor, style_image: Tensor, mask_image: Tensor) -> tp.Callable: 64 | 65 | def loss_impl(generated_image: Tensor, original_image: Tensor) -> tp.Tuple[Tensor, dict]: 66 | mse_loss_1 = lambdas['l_mse_1'] * torch.square(mask_image * (generated_image - content_image.detach())).mean() 67 | mse_loss_2 = lambdas['l_mse_2'] * torch.square((1 - mask_image) * (generated_image - style_image.detach())).mean() 68 | 69 | loss_dict = { 70 | 'mse_1': mse_loss_1.item(), 71 | 'mse_2': mse_loss_2.item(), 72 | } 73 | loss = mse_loss_1 + mse_loss_2 74 | 75 | 
return loss, loss_dict 76 | 77 | return loss_impl 78 | -------------------------------------------------------------------------------- /latent_projecting/projector.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import subprocess 4 | import tempfile 5 | from collections import defaultdict 6 | from pathlib import Path 7 | from typing import Callable, Tuple, Union 8 | 9 | import torch 10 | from PIL import Image, ImageFilter 11 | from matplotlib import pyplot as plt 12 | from torch.optim import Optimizer 13 | from torch.optim.lr_scheduler import _LRScheduler 14 | from torchvision import transforms 15 | from tqdm import tqdm 16 | 17 | from latent_projecting import Latents, LatentPaths 18 | from utils.config import load_config 19 | from losses.psnr import PSNR 20 | from networks import get_stylegan1_generator, StyledGenerator, get_stylegan2_generator 21 | from networks.stylegan2.model import Generator 22 | from pytorch_training.data import Compose 23 | from pytorch_training.images.utils import make_image 24 | from utils.image_utils import render_text_on_image 25 | 26 | 27 | class Projector: 28 | 29 | def __init__(self, args, abort_condition=None): 30 | self.args = args 31 | self.abort_condition = abort_condition 32 | 33 | self.device = args.device 34 | self.config = load_config(args.ckpt, args.config) 35 | self.generator = self.load_generator() 36 | self.generator.eval() 37 | self.debug_step = args.debug_step 38 | 39 | self.psnr = PSNR() 40 | self.log = [] 41 | 42 | def reset(self): 43 | self.log.clear() 44 | 45 | def load_generator(self) -> Union[Generator, StyledGenerator]: 46 | if self.config['stylegan_variant'] == 2: 47 | generator = get_stylegan2_generator(self.config['image_size'], self.config['latent_size'], 48 | init_ckpt=self.config['stylegan_checkpoint']) 49 | else: 50 | generator = get_stylegan1_generator(self.config['image_size'], self.config['latent_size'], 51 | init_ckpt=self.config['stylegan_checkpoint']) 52 | 53 | generator.eval() 54 | generator = generator.to(self.device) 55 | return generator 56 | 57 | def get_transforms(self) -> Compose: 58 | return Compose( 59 | [ 60 | transforms.Resize((self.config['image_size'], self.config['image_size'])), 61 | transforms.ToTensor(), 62 | transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]), 63 | ] 64 | ) 65 | 66 | def get_blur_transform(self, from_tensor: bool = True) -> list: 67 | blur_transform = [transforms.Lambda(lambda image: image.filter(ImageFilter.GaussianBlur(radius=3)))] 68 | if from_tensor: 69 | blur_transform.insert(0, transforms.ToPILImage()) 70 | blur_transform.append(transforms.ToTensor()) 71 | return blur_transform 72 | 73 | def get_mask_transform(self, invert_mask: bool = False, mask_multiplier: float = 1) -> Compose: 74 | transformations = [ 75 | transforms.Resize((self.config['image_size'], self.config['image_size'])), 76 | transforms.Grayscale(num_output_channels=1), 77 | transforms.ToTensor(), 78 | ] 79 | if invert_mask: 80 | transformations.append(transforms.Lambda(lambda image: 1 - image)) 81 | 82 | if mask_multiplier < 1: 83 | multiplier_transform = transforms.Lambda(lambda image: image * mask_multiplier) 84 | transformations.append(multiplier_transform) 85 | 86 | transformations.extend(self.get_blur_transform()) 87 | 88 | return Compose(transformations) 89 | 90 | @staticmethod 91 | def sample_mean_latent(sample_size: int, latent_size: int, device: str, generator: Union[StyledGenerator, Generator]) -> Tuple[torch.Tensor, torch.Tensor]: 92 | 
with torch.no_grad(): 93 | noise_sample = torch.randn(sample_size, latent_size, device=device) 94 | latent_out = generator.style(noise_sample) 95 | 96 | latent_mean = latent_out.mean(0) 97 | latent_std = ((latent_out - latent_mean).pow(2).sum() / sample_size) ** 0.5 98 | return latent_mean, latent_std 99 | 100 | def get_mean_latent(self, sample_size: int) -> Tuple[torch.Tensor, torch.Tensor]: 101 | return self.sample_mean_latent(sample_size, self.config['latent_size'], self.device, self.generator) 102 | 103 | def set_requires_grad(self, latents: Latents, flag: bool): 104 | latents.latent.requires_grad = flag 105 | 106 | for noise in latents.noise: 107 | noise.requires_grad = not flag 108 | 109 | def create_initial_latent_and_noise(self) -> Latents: 110 | n_mean_latent = 10000 111 | latent_mean, latent_std = self.get_mean_latent(n_mean_latent) 112 | 113 | base_noises = self.generator.make_noise() 114 | noises = [noise.detach().clone() for noise in base_noises] 115 | 116 | if self.args.no_mean_latent: 117 | latent_in = torch.normal(0, latent_std.item(), size=(1, self.config['latent_size']), 118 | device=self.device) 119 | else: 120 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(1, 1) 121 | 122 | if self.args.w_plus: 123 | latent_in = latent_in.unsqueeze(1).repeat(1, self.generator.n_latent, 1) 124 | 125 | return Latents(latent_in, noises) 126 | 127 | def generate(self, latents: Latents) -> torch.Tensor: 128 | return self.generator([latents.latent], input_is_latent=True, noise=latents.noise) 129 | 130 | def project(self, latents: Latents, images: torch.Tensor, optimizer: Optimizer, num_steps: int, loss_function: Callable, lr_scheduler: _LRScheduler = None) -> Tuple[LatentPaths, Latents]: 131 | pbar = tqdm(range(num_steps), leave=False) 132 | latent_path = [] 133 | noise_path = [] 134 | 135 | best_latent = best_noise = best_psnr = None 136 | 137 | for i in pbar: 138 | img_gen, _ = self.generate(latents) 139 | 140 | batch, channel, height, width = img_gen.shape 141 | 142 | if height > 256: 143 | factor = height // 256 144 | 145 | img_gen = img_gen.reshape( 146 | batch, channel, height // factor, factor, width // factor, factor 147 | ) 148 | img_gen = img_gen.mean([3, 5]) 149 | 150 | # # n_loss = noise_regularize(noises) 151 | loss, loss_dict = loss_function(img_gen, images) 152 | 153 | optimizer.zero_grad() 154 | loss.backward() 155 | optimizer.step() 156 | 157 | loss_dict['psnr'] = self.psnr(img_gen, images).item() 158 | loss_dict['lr'] = optimizer.param_groups[0]["lr"] 159 | 160 | if lr_scheduler is not None: 161 | lr_scheduler.step() 162 | 163 | self.log.append(loss_dict) 164 | 165 | if best_psnr is None or best_psnr < loss_dict['psnr']: 166 | best_psnr = loss_dict['psnr'] 167 | best_latent = latents.latent.detach().clone().cpu() 168 | best_noise = [noise.detach().clone().cpu() for noise in latents.noise] 169 | 170 | if i % self.debug_step == 0: 171 | latent_path.append(latents.latent.detach().clone().cpu()) 172 | noise_path.append([noise.detach().clone().cpu() for noise in latents.noise]) 173 | 174 | loss_description = "; ".join(f"{key}: {value:.6f}" for key, value in loss_dict.items()) 175 | pbar.set_description(loss_description) 176 | 177 | loss_dict['iteration'] = i 178 | if self.abort_condition is not None and self.abort_condition(loss_dict): 179 | break 180 | 181 | latent_path.append(latents.latent.detach().clone().cpu()) 182 | noise_path.append([noise.detach().clone().cpu() for noise in latents.noise]) 183 | 184 | return LatentPaths(latent_path, noise_path), 
Latents(best_latent, best_noise) 185 | 186 | def create_gif(self, latent_paths: LatentPaths, image_name: str, destination_dir: Path) -> None: 187 | destination_dir = destination_dir / 'gifs' 188 | destination_dir.mkdir(parents=True, exist_ok=True) 189 | 190 | with tempfile.TemporaryDirectory() as temp_dir, tempfile.NamedTemporaryFile(mode='w') as temp_file: 191 | latent_paths = latent_paths.to(self.device) 192 | temp_dir = Path(temp_dir) 193 | 194 | progress_bar = tqdm(latent_paths) 195 | progress_bar.set_description("creating gif") 196 | for i, latent in enumerate(progress_bar): 197 | img_gen, _ = self.generator([latent.latent], input_is_latent=True, noise=latent.noise) 198 | 199 | image_index = i * self.debug_step 200 | temp_dest_name = temp_dir / f"{image_index}.png" 201 | img_ar = make_image(img_gen)[0] 202 | image = Image.fromarray(img_ar) 203 | image = render_text_on_image(f"{image_index:06}", image) 204 | image.save(temp_dest_name) 205 | 206 | print(temp_dest_name, file=temp_file) 207 | 208 | process_args = [ 209 | 'convert', 210 | '-delay 10', 211 | '-loop 0', 212 | f'@{temp_file.name}', 213 | str(destination_dir / f"{image_name}.gif") 214 | ] 215 | temp_file.flush() 216 | subprocess.run(' '.join(process_args), shell=True, check=True) 217 | 218 | def render_log(self, destination_dir: Union[str, Path], image_base_name: str) -> None: 219 | destination_dir = Path(destination_dir) 220 | destination_dir = destination_dir / 'log' 221 | os.makedirs(destination_dir, exist_ok=True) 222 | 223 | plot_data = defaultdict(list) 224 | for logged_data in self.log: 225 | for key, value in logged_data.items(): 226 | plot_data[key].append(value) 227 | 228 | for key in plot_data.keys(): 229 | plt.clf() 230 | fig, axis = plt.subplots() 231 | axis.plot(plot_data[key]) 232 | axis.set(ylabel=key) 233 | axis.grid() 234 | 235 | fig.savefig(destination_dir / f"{image_base_name}_{key}.png") 236 | plt.close(fig) 237 | 238 | with open(destination_dir / f"{image_base_name}_log.json", 'w') as f: 239 | json.dump(self.log, f, indent='\t') 240 | -------------------------------------------------------------------------------- /latent_projecting/style_transfer.py: -------------------------------------------------------------------------------- 1 | from argparse import Namespace 2 | from pathlib import Path 3 | from typing import Union, Tuple, Optional, Dict 4 | 5 | import torch 6 | from PIL import Image 7 | 8 | from latent_projecting import Latents, run_image_reconstruction, LatentPaths, noise_loss, optimize_noise 9 | from networks import UNetLikeEncoder 10 | from networks.stylegan1.model import Generator as StyleGan1Generator 11 | from networks.stylegan2.model import Generator as StyleGan2Generator 12 | from pytorch_training.images.utils import is_image, load_and_prepare_image, clamp_and_unnormalize, make_image 13 | from latent_projecting.projector import Projector 14 | 15 | 16 | class StyleTransferer: 17 | 18 | def __init__(self, opts: Namespace) -> None: 19 | self.args = opts 20 | self.projector = Projector(opts) 21 | 22 | def embed_image(self, image_path: Union[str, Path], is_content_image: bool = False) -> Latents: 23 | image = load_and_prepare_image(image_path, self.projector.get_transforms()).to(self.projector.device) 24 | latents_in = self.projector.create_initial_latent_and_noise() 25 | 26 | _, best_latent = run_image_reconstruction(self.projector.args, self.projector, latents_in, image) 27 | 28 | return best_latent 29 | 30 | def get_latents(self, content_path: Union[str, Path], style_path: Union[str, Path]) 
-> Tuple[Latents, Latents]: 31 | if is_image(content_path): 32 | content_latents = self.embed_image(content_path, True) 33 | else: 34 | embedded_data = torch.load(content_path) 35 | content_latents = Latents(embedded_data['latent'], embedded_data['noise']) 36 | 37 | if is_image(style_path): 38 | style_latents = self.embed_image(style_path, False) 39 | else: 40 | embedded_data = torch.load(style_path) 41 | style_latents = Latents(embedded_data['latent'], embedded_data['noise']) 42 | 43 | for latents in [content_latents, style_latents]: 44 | if len(latents.latent.shape) < 3: 45 | latents.latent = latents.latent.unsqueeze(0) 46 | 47 | return content_latents.to(self.projector.device), style_latents.to(self.projector.device) 48 | 49 | def post_noise_optimize(self, content_latent: Latents, transfer_latent: Latents) -> Tuple[LatentPaths, Latents]: 50 | content_latent = content_latent.to(self.projector.device) 51 | transfer_latent = transfer_latent.to(self.projector.device) 52 | 53 | content_image = self.projector.generate(content_latent)[0].detach() 54 | style_image = self.projector.generate(transfer_latent)[0].detach() 55 | content_mask = clamp_and_unnormalize(content_image.clone().detach()) 56 | loss_func = noise_loss( 57 | {"l_mse_1": 1, "l_mse_2": 1}, 58 | content_image, 59 | style_image, 60 | (1 - content_mask).detach() 61 | ) 62 | 63 | path, latent_and_noise = optimize_noise(self.args, self.projector, transfer_latent, content_image, loss_func) 64 | 65 | return path, latent_and_noise 66 | 67 | def do_style_transfer(self, content_latent: Latents, style_latent: Latents, layer_id: int) -> Tuple[torch.Tensor, Optional[LatentPaths]]: 68 | latent = torch.cat([content_latent.latent[:, :layer_id, :], style_latent.latent[:, layer_id:, :]], dim=1).detach().clone() 69 | latent = latent.to(self.projector.device) 70 | # noise = content_latent.noise[:layer_id] + style_latent.noise[layer_id:] 71 | noise = [n.detach().clone().to(self.projector.device) for n in content_latent.noise] 72 | latent_and_noise = Latents(latent, noise) 73 | 74 | path = None 75 | if self.args.post_optimize: 76 | path, latent_and_noise = self.post_noise_optimize(content_latent, latent_and_noise) 77 | 78 | latent_and_noise = latent_and_noise.to(self.projector.device) 79 | return self.projector.generate(latent_and_noise)[0], path 80 | 81 | def save_stylized_images(self, stylized_images: Dict[int, Tuple[torch.Tensor, Optional[LatentPaths]]], content_path: Path, style_path: Path, destination_dir: Path, create_gif: bool = False): 82 | for index, (image_array, optimization_path) in stylized_images.items(): 83 | content_name = content_path.stem 84 | style_name = style_path.stem 85 | 86 | image_name = f"{content_name}_{style_name}_{index}" 87 | destination_name = destination_dir / f"{image_name}.png" 88 | Image.fromarray(make_image(image_array)[0]).save(destination_name) 89 | 90 | if optimization_path is not None and create_gif: 91 | self.projector.create_gif(optimization_path, image_name, destination_dir) 92 | 93 | 94 | class EncoderBasedStyleTransferer(StyleTransferer): 95 | 96 | def __init__(self, *args, **kwargs): 97 | super().__init__(*args, **kwargs) 98 | self.content_encoder = self.build_encoder(self.args.content_checkpoint) 99 | self.style_encoder = self.build_encoder(self.args.style_checkpoint) 100 | 101 | def build_encoder(self, checkpoint) -> UNetLikeEncoder: 102 | if self.projector.config['stylegan_variant'] == 1: 103 | channel_map = StyleGan1Generator.get_channels() 104 | else: 105 | channel_map = 
StyleGan2Generator.get_channels() 106 | 107 | encoder = UNetLikeEncoder( 108 | self.projector.config['image_size'], 109 | self.projector.config['latent_size'], 110 | self.projector.config['input_dim'], 111 | channel_map 112 | ) 113 | encoder.eval() 114 | 115 | checkpoint = torch.load(checkpoint) 116 | 117 | if 'autoencoder' in checkpoint: 118 | # we need to adapt the tensors we actually want to load 119 | stripped_checkpoint = {key: value for key, value in checkpoint['autoencoder'].items() if 'encoder' in key} 120 | checkpoint = {'.'.join(key.split('.')[2:]): value for key, value in stripped_checkpoint.items()} 121 | 122 | encoder.load_state_dict(checkpoint) 123 | 124 | return encoder.to(self.projector.device) 125 | 126 | def embed_image(self, image_path: Union[str, Path], is_content_image: bool = True) -> Latents: 127 | image = load_and_prepare_image(image_path, self.projector.get_transforms()).to(self.projector.device) 128 | if is_content_image: 129 | encoder = self.content_encoder 130 | else: 131 | encoder = self.style_encoder 132 | 133 | with torch.no_grad(): 134 | return encoder(image) 135 | -------------------------------------------------------------------------------- /losses/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import Tensor 3 | 4 | 5 | def gram_matrix(features: Tensor, mask: Tensor) -> Tensor: 6 | batch_size, num_channels, height, width = features.shape 7 | 8 | if mask is not None: 9 | normalize_denominator = mask.square().sum(dim=(2, 3)).sqrt() 10 | normalize_denominator = normalize_denominator.expand(1, 1, -1, -1) 11 | normalize_denominator = normalize_denominator.permute((2, 3, 0, 1)) 12 | normalize_denominator = normalize_denominator.repeat((1,) + mask.shape[1:]) 13 | normalized_mask = mask / normalize_denominator 14 | features = normalized_mask * features 15 | 16 | features = features.view(batch_size * num_channels, height * width) 17 | features = features.permute((1, 0)) 18 | return torch.mm(features.T, features) 19 | 20 | 21 | def euclidean_distance(tensor_1: Tensor, tensor_2: Tensor, mask: Tensor = None) -> Tensor: 22 | difference = tensor_1 - tensor_2 23 | if mask is not None: 24 | difference = mask * difference 25 | 26 | distance = difference.square().sum().sqrt() / tensor_1.shape.numel() 27 | return distance 28 | -------------------------------------------------------------------------------- /losses/lpips/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import numpy as np 7 | from skimage.measure import compare_ssim 8 | import torch 9 | from torch.autograd import Variable 10 | 11 | from . 
import dist_model 12 | 13 | 14 | class PerceptualLoss(torch.nn.Module): 15 | def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) 16 | # def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss 17 | super(PerceptualLoss, self).__init__() 18 | print('Setting up Perceptual loss...') 19 | self.use_gpu = use_gpu 20 | self.spatial = spatial 21 | self.gpu_ids = gpu_ids 22 | self.model = dist_model.DistModel() 23 | self.model.initialize(model=model, net=net, use_gpu=use_gpu, colorspace=colorspace, spatial=self.spatial, gpu_ids=gpu_ids) 24 | print('...[%s] initialized'%self.model.name()) 25 | print('...Done') 26 | 27 | def forward(self, pred, target, normalize=False, ret_per_layer=False): 28 | """ 29 | Pred and target are Variables. 30 | If normalize is True, assumes the images are between [0,1] and then scales them between [-1,+1] 31 | If normalize is False, assumes the images are already between [-1,+1] 32 | 33 | Inputs pred and target are Nx3xHxW 34 | Output pytorch Variable N long 35 | """ 36 | 37 | if normalize: 38 | target = 2 * target - 1 39 | pred = 2 * pred - 1 40 | 41 | return self.model.forward(target, pred, retPerLayer=ret_per_layer) 42 | 43 | def normalize_tensor(in_feat,eps=1e-10): 44 | norm_factor = torch.sqrt(torch.sum(in_feat**2,dim=1,keepdim=True)) 45 | return in_feat/(norm_factor+eps) 46 | 47 | def l2(p0, p1, range=255.): 48 | return .5*np.mean((p0 / range - p1 / range)**2) 49 | 50 | def psnr(p0, p1, peak=255.): 51 | return 10*np.log10(peak**2/np.mean((1.*p0-1.*p1)**2)) 52 | 53 | def dssim(p0, p1, range=255.): 54 | return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2. 55 | 56 | def rgb2lab(in_img,mean_cent=False): 57 | from skimage import color 58 | img_lab = color.rgb2lab(in_img) 59 | if(mean_cent): 60 | img_lab[:,:,0] = img_lab[:,:,0]-50 61 | return img_lab 62 | 63 | def tensor2np(tensor_obj): 64 | # change dimension of a tensor object into a numpy array 65 | return tensor_obj[0].cpu().float().numpy().transpose((1,2,0)) 66 | 67 | def np2tensor(np_obj): 68 | # change dimenion of np array into tensor array 69 | return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 70 | 71 | def tensor2tensorlab(image_tensor,to_norm=True,mc_only=False): 72 | # image tensor to lab tensor 73 | from skimage import color 74 | 75 | img = tensor2im(image_tensor) 76 | img_lab = color.rgb2lab(img) 77 | if(mc_only): 78 | img_lab[:,:,0] = img_lab[:,:,0]-50 79 | if(to_norm and not mc_only): 80 | img_lab[:,:,0] = img_lab[:,:,0]-50 81 | img_lab = img_lab/100. 82 | 83 | return np2tensor(img_lab) 84 | 85 | def tensorlab2tensor(lab_tensor,return_inbnd=False): 86 | from skimage import color 87 | import warnings 88 | warnings.filterwarnings("ignore") 89 | 90 | lab = tensor2np(lab_tensor)*100. 91 | lab[:,:,0] = lab[:,:,0]+50 92 | 93 | rgb_back = 255.*np.clip(color.lab2rgb(lab.astype('float')),0,1) 94 | if(return_inbnd): 95 | # convert back to lab, see if we match 96 | lab_back = color.rgb2lab(rgb_back.astype('uint8')) 97 | mask = 1.*np.isclose(lab_back,lab,atol=2.) 98 | mask = np2tensor(np.prod(mask,axis=2)[:,:,np.newaxis]) 99 | return (im2tensor(rgb_back),mask) 100 | else: 101 | return im2tensor(rgb_back) 102 | 103 | def rgb2lab(input): 104 | from skimage import color 105 | return color.rgb2lab(input / 255.) 
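# note: this second rgb2lab definition shadows rgb2lab(in_img, mean_cent=False) above;
# tensor2im and im2tensor are likewise redefined further down in this module.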
106 | 107 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 108 | image_numpy = image_tensor[0].cpu().float().numpy() 109 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 110 | return image_numpy.astype(imtype) 111 | 112 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 113 | return torch.Tensor((image / factor - cent) 114 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 115 | 116 | def tensor2vec(vector_tensor): 117 | return vector_tensor.data.cpu().numpy()[:, :, 0, 0] 118 | 119 | def voc_ap(rec, prec, use_07_metric=False): 120 | """ ap = voc_ap(rec, prec, [use_07_metric]) 121 | Compute VOC AP given precision and recall. 122 | If use_07_metric is true, uses the 123 | VOC 07 11 point method (default:False). 124 | """ 125 | if use_07_metric: 126 | # 11 point metric 127 | ap = 0. 128 | for t in np.arange(0., 1.1, 0.1): 129 | if np.sum(rec >= t) == 0: 130 | p = 0 131 | else: 132 | p = np.max(prec[rec >= t]) 133 | ap = ap + p / 11. 134 | else: 135 | # correct AP calculation 136 | # first append sentinel values at the end 137 | mrec = np.concatenate(([0.], rec, [1.])) 138 | mpre = np.concatenate(([0.], prec, [0.])) 139 | 140 | # compute the precision envelope 141 | for i in range(mpre.size - 1, 0, -1): 142 | mpre[i - 1] = np.maximum(mpre[i - 1], mpre[i]) 143 | 144 | # to calculate area under PR curve, look for points 145 | # where X axis (recall) changes value 146 | i = np.where(mrec[1:] != mrec[:-1])[0] 147 | 148 | # and sum (\Delta recall) * prec 149 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1]) 150 | return ap 151 | 152 | def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.): 153 | # def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=1.): 154 | image_numpy = image_tensor[0].cpu().float().numpy() 155 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor 156 | return image_numpy.astype(imtype) 157 | 158 | def im2tensor(image, imtype=np.uint8, cent=1., factor=255./2.): 159 | # def im2tensor(image, imtype=np.uint8, cent=1., factor=1.): 160 | return torch.Tensor((image / factor - cent) 161 | [:, :, :, np.newaxis].transpose((3, 2, 0, 1))) 162 | -------------------------------------------------------------------------------- /losses/lpips/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch.autograd import Variable 4 | from pdb import set_trace as st 5 | from IPython import embed 6 | 7 | class BaseModel(): 8 | def __init__(self): 9 | pass; 10 | 11 | def name(self): 12 | return 'BaseModel' 13 | 14 | def initialize(self, use_gpu=True, gpu_ids=[0]): 15 | self.use_gpu = use_gpu 16 | self.gpu_ids = gpu_ids 17 | 18 | def forward(self): 19 | pass 20 | 21 | def get_image_paths(self): 22 | pass 23 | 24 | def optimize_parameters(self): 25 | pass 26 | 27 | def get_current_visuals(self): 28 | return self.input 29 | 30 | def get_current_errors(self): 31 | return {} 32 | 33 | def save(self, label): 34 | pass 35 | 36 | # helper saving function that can be used by subclasses 37 | def save_network(self, network, path, network_label, epoch_label): 38 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 39 | save_path = os.path.join(path, save_filename) 40 | torch.save(network.state_dict(), save_path) 41 | 42 | # helper loading function that can be used by subclasses 43 | def load_network(self, network, network_label, epoch_label): 44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label) 45 | 
save_path = os.path.join(self.save_dir, save_filename) 46 | print('Loading network from %s'%save_path) 47 | network.load_state_dict(torch.load(save_path)) 48 | 49 | def update_learning_rate(): 50 | pass 51 | 52 | def get_image_paths(self): 53 | return self.image_paths 54 | 55 | def save_done(self, flag=False): 56 | np.save(os.path.join(self.save_dir, 'done_flag'),flag) 57 | np.savetxt(os.path.join(self.save_dir, 'done_flag'),[flag,],fmt='%i') 58 | 59 | -------------------------------------------------------------------------------- /losses/lpips/networks_basic.py: -------------------------------------------------------------------------------- 1 | 2 | from __future__ import absolute_import 3 | 4 | import sys 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.init as init 8 | from torch.autograd import Variable 9 | import numpy as np 10 | from pdb import set_trace as st 11 | from skimage import color 12 | from IPython import embed 13 | from . import pretrained_networks as pn 14 | 15 | import losses.lpips as util 16 | 17 | 18 | def spatial_average(in_tens, keepdim=True): 19 | return in_tens.mean([2,3],keepdim=keepdim) 20 | 21 | 22 | def upsample(in_tens, out_H=64): # assumes scale factor is same for H and W 23 | in_H = in_tens.shape[2] 24 | scale_factor = 1.*out_H/in_H 25 | 26 | return nn.Upsample(scale_factor=scale_factor, mode='bilinear', align_corners=False)(in_tens) 27 | 28 | # Learned perceptual metric 29 | class PNetLin(nn.Module): 30 | def __init__(self, pnet_type='vgg', pnet_rand=False, pnet_tune=False, use_dropout=True, spatial=False, version='0.1', lpips=True): 31 | super(PNetLin, self).__init__() 32 | 33 | self.pnet_type = pnet_type 34 | self.pnet_tune = pnet_tune 35 | self.pnet_rand = pnet_rand 36 | self.spatial = spatial 37 | self.lpips = lpips 38 | self.version = version 39 | self.scaling_layer = ScalingLayer() 40 | 41 | if(self.pnet_type in ['vgg','vgg16']): 42 | net_type = pn.vgg16 43 | self.chns = [64,128,256,512,512] 44 | elif(self.pnet_type=='alex'): 45 | net_type = pn.alexnet 46 | self.chns = [64,192,384,256,256] 47 | elif(self.pnet_type=='squeeze'): 48 | net_type = pn.squeezenet 49 | self.chns = [64,128,256,384,384,512,512] 50 | self.L = len(self.chns) 51 | 52 | self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune) 53 | 54 | if(lpips): 55 | self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) 56 | self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) 57 | self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) 58 | self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) 59 | self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) 60 | self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4] 61 | if(self.pnet_type=='squeeze'): # 7 layers for squeezenet 62 | self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout) 63 | self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout) 64 | self.lins+=[self.lin5,self.lin6] 65 | 66 | def forward(self, in0, in1, retPerLayer=False): 67 | # v0.0 - original release had a bug, where input was not scaled 68 | in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1) 69 | outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input) 70 | feats0, feats1, diffs = {}, {}, {} 71 | 72 | for kk in range(self.L): 73 | feats0[kk], feats1[kk] = util.normalize_tensor(outs0[kk]), util.normalize_tensor(outs1[kk]) 74 | diffs[kk] = (feats0[kk]-feats1[kk])**2 75 | 76 | if(self.lpips): 77 | 
if(self.spatial): 78 | res = [upsample(self.lins[kk].model(diffs[kk]), out_H=in0.shape[2]) for kk in range(self.L)] 79 | else: 80 | res = [spatial_average(self.lins[kk].model(diffs[kk]), keepdim=True) for kk in range(self.L)] 81 | else: 82 | if(self.spatial): 83 | res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_H=in0.shape[2]) for kk in range(self.L)] 84 | else: 85 | res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)] 86 | 87 | val = res[0] 88 | for l in range(1,self.L): 89 | val += res[l] 90 | 91 | if(retPerLayer): 92 | return (val, res) 93 | else: 94 | return val 95 | 96 | class ScalingLayer(nn.Module): 97 | def __init__(self): 98 | super(ScalingLayer, self).__init__() 99 | self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None]) 100 | self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None]) 101 | 102 | def forward(self, inp): 103 | return (inp - self.shift) / self.scale 104 | 105 | 106 | class NetLinLayer(nn.Module): 107 | ''' A single linear layer which does a 1x1 conv ''' 108 | def __init__(self, chn_in, chn_out=1, use_dropout=False): 109 | super(NetLinLayer, self).__init__() 110 | 111 | layers = [nn.Dropout(),] if(use_dropout) else [] 112 | layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),] 113 | self.model = nn.Sequential(*layers) 114 | 115 | 116 | class Dist2LogitLayer(nn.Module): 117 | ''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) ''' 118 | def __init__(self, chn_mid=32, use_sigmoid=True): 119 | super(Dist2LogitLayer, self).__init__() 120 | 121 | layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),] 122 | layers += [nn.LeakyReLU(0.2,True),] 123 | layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),] 124 | layers += [nn.LeakyReLU(0.2,True),] 125 | layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),] 126 | if(use_sigmoid): 127 | layers += [nn.Sigmoid(),] 128 | self.model = nn.Sequential(*layers) 129 | 130 | def forward(self,d0,d1,eps=0.1): 131 | return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1)) 132 | 133 | class BCERankingLoss(nn.Module): 134 | def __init__(self, chn_mid=32): 135 | super(BCERankingLoss, self).__init__() 136 | self.net = Dist2LogitLayer(chn_mid=chn_mid) 137 | # self.parameters = list(self.net.parameters()) 138 | self.loss = torch.nn.BCELoss() 139 | 140 | def forward(self, d0, d1, judge): 141 | per = (judge+1.)/2. 
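# maps the judgement from [-1, 1] to a probability in [0, 1], used as the BCE target below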
142 | self.logit = self.net.forward(d0,d1) 143 | return self.loss(self.logit, per) 144 | 145 | # L2, DSSIM metrics 146 | class FakeNet(nn.Module): 147 | def __init__(self, use_gpu=True, colorspace='Lab'): 148 | super(FakeNet, self).__init__() 149 | self.use_gpu = use_gpu 150 | self.colorspace=colorspace 151 | 152 | class L2(FakeNet): 153 | 154 | def forward(self, in0, in1, retPerLayer=None): 155 | assert(in0.size()[0]==1) # currently only supports batchSize 1 156 | 157 | if(self.colorspace=='RGB'): 158 | (N,C,X,Y) = in0.size() 159 | value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N) 160 | return value 161 | elif(self.colorspace=='Lab'): 162 | value = util.l2(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 163 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 164 | ret_var = Variable( torch.Tensor((value,) ) ) 165 | if(self.use_gpu): 166 | ret_var = ret_var.cuda() 167 | return ret_var 168 | 169 | class DSSIM(FakeNet): 170 | 171 | def forward(self, in0, in1, retPerLayer=None): 172 | assert(in0.size()[0]==1) # currently only supports batchSize 1 173 | 174 | if(self.colorspace=='RGB'): 175 | value = util.dssim(1.*util.tensor2im(in0.data), 1.*util.tensor2im(in1.data), range=255.).astype('float') 176 | elif(self.colorspace=='Lab'): 177 | value = util.dssim(util.tensor2np(util.tensor2tensorlab(in0.data,to_norm=False)), 178 | util.tensor2np(util.tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float') 179 | ret_var = Variable( torch.Tensor((value,) ) ) 180 | if(self.use_gpu): 181 | ret_var = ret_var.cuda() 182 | return ret_var 183 | 184 | def print_network(net): 185 | num_params = 0 186 | for param in net.parameters(): 187 | num_params += param.numel() 188 | print('Network',net) 189 | print('Total number of parameters: %d' % num_params) 190 | -------------------------------------------------------------------------------- /losses/lpips/pretrained_networks.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import torch 3 | from torchvision import models as tv 4 | from IPython import embed 5 | 6 | class squeezenet(torch.nn.Module): 7 | def __init__(self, requires_grad=False, pretrained=True): 8 | super(squeezenet, self).__init__() 9 | pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features 10 | self.slice1 = torch.nn.Sequential() 11 | self.slice2 = torch.nn.Sequential() 12 | self.slice3 = torch.nn.Sequential() 13 | self.slice4 = torch.nn.Sequential() 14 | self.slice5 = torch.nn.Sequential() 15 | self.slice6 = torch.nn.Sequential() 16 | self.slice7 = torch.nn.Sequential() 17 | self.N_slices = 7 18 | for x in range(2): 19 | self.slice1.add_module(str(x), pretrained_features[x]) 20 | for x in range(2,5): 21 | self.slice2.add_module(str(x), pretrained_features[x]) 22 | for x in range(5, 8): 23 | self.slice3.add_module(str(x), pretrained_features[x]) 24 | for x in range(8, 10): 25 | self.slice4.add_module(str(x), pretrained_features[x]) 26 | for x in range(10, 11): 27 | self.slice5.add_module(str(x), pretrained_features[x]) 28 | for x in range(11, 12): 29 | self.slice6.add_module(str(x), pretrained_features[x]) 30 | for x in range(12, 13): 31 | self.slice7.add_module(str(x), pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h = self.slice1(X) 38 | h_relu1 = h 39 | h = 
self.slice2(h) 40 | h_relu2 = h 41 | h = self.slice3(h) 42 | h_relu3 = h 43 | h = self.slice4(h) 44 | h_relu4 = h 45 | h = self.slice5(h) 46 | h_relu5 = h 47 | h = self.slice6(h) 48 | h_relu6 = h 49 | h = self.slice7(h) 50 | h_relu7 = h 51 | vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7']) 52 | out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7) 53 | 54 | return out 55 | 56 | 57 | class alexnet(torch.nn.Module): 58 | def __init__(self, requires_grad=False, pretrained=True): 59 | super(alexnet, self).__init__() 60 | alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features 61 | self.slice1 = torch.nn.Sequential() 62 | self.slice2 = torch.nn.Sequential() 63 | self.slice3 = torch.nn.Sequential() 64 | self.slice4 = torch.nn.Sequential() 65 | self.slice5 = torch.nn.Sequential() 66 | self.N_slices = 5 67 | for x in range(2): 68 | self.slice1.add_module(str(x), alexnet_pretrained_features[x]) 69 | for x in range(2, 5): 70 | self.slice2.add_module(str(x), alexnet_pretrained_features[x]) 71 | for x in range(5, 8): 72 | self.slice3.add_module(str(x), alexnet_pretrained_features[x]) 73 | for x in range(8, 10): 74 | self.slice4.add_module(str(x), alexnet_pretrained_features[x]) 75 | for x in range(10, 12): 76 | self.slice5.add_module(str(x), alexnet_pretrained_features[x]) 77 | if not requires_grad: 78 | for param in self.parameters(): 79 | param.requires_grad = False 80 | 81 | def forward(self, X): 82 | h = self.slice1(X) 83 | h_relu1 = h 84 | h = self.slice2(h) 85 | h_relu2 = h 86 | h = self.slice3(h) 87 | h_relu3 = h 88 | h = self.slice4(h) 89 | h_relu4 = h 90 | h = self.slice5(h) 91 | h_relu5 = h 92 | alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5']) 93 | out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5) 94 | 95 | return out 96 | 97 | class vgg16(torch.nn.Module): 98 | def __init__(self, requires_grad=False, pretrained=True): 99 | super(vgg16, self).__init__() 100 | vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features 101 | self.slice1 = torch.nn.Sequential() 102 | self.slice2 = torch.nn.Sequential() 103 | self.slice3 = torch.nn.Sequential() 104 | self.slice4 = torch.nn.Sequential() 105 | self.slice5 = torch.nn.Sequential() 106 | self.N_slices = 5 107 | for x in range(4): 108 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 109 | for x in range(4, 9): 110 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 111 | for x in range(9, 16): 112 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 113 | for x in range(16, 23): 114 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 115 | for x in range(23, 30): 116 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 117 | if not requires_grad: 118 | for param in self.parameters(): 119 | param.requires_grad = False 120 | 121 | def forward(self, X): 122 | h = self.slice1(X) 123 | h_relu1_2 = h 124 | h = self.slice2(h) 125 | h_relu2_2 = h 126 | h = self.slice3(h) 127 | h_relu3_3 = h 128 | h = self.slice4(h) 129 | h_relu4_3 = h 130 | h = self.slice5(h) 131 | h_relu5_3 = h 132 | vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) 133 | out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3) 134 | 135 | return out 136 | 137 | 138 | 139 | class resnet(torch.nn.Module): 140 | def __init__(self, requires_grad=False, pretrained=True, num=18): 141 | super(resnet, self).__init__() 142 | 
if(num==18): 143 | self.net = tv.resnet18(pretrained=pretrained) 144 | elif(num==34): 145 | self.net = tv.resnet34(pretrained=pretrained) 146 | elif(num==50): 147 | self.net = tv.resnet50(pretrained=pretrained) 148 | elif(num==101): 149 | self.net = tv.resnet101(pretrained=pretrained) 150 | elif(num==152): 151 | self.net = tv.resnet152(pretrained=pretrained) 152 | self.N_slices = 5 153 | 154 | self.conv1 = self.net.conv1 155 | self.bn1 = self.net.bn1 156 | self.relu = self.net.relu 157 | self.maxpool = self.net.maxpool 158 | self.layer1 = self.net.layer1 159 | self.layer2 = self.net.layer2 160 | self.layer3 = self.net.layer3 161 | self.layer4 = self.net.layer4 162 | 163 | def forward(self, X): 164 | h = self.conv1(X) 165 | h = self.bn1(h) 166 | h = self.relu(h) 167 | h_relu1 = h 168 | h = self.maxpool(h) 169 | h = self.layer1(h) 170 | h_conv2 = h 171 | h = self.layer2(h) 172 | h_conv3 = h 173 | h = self.layer3(h) 174 | h_conv4 = h 175 | h = self.layer4(h) 176 | h_conv5 = h 177 | 178 | outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5']) 179 | out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5) 180 | 181 | return out 182 | -------------------------------------------------------------------------------- /losses/lpips/weights/v0.0/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.0/alex.pth -------------------------------------------------------------------------------- /losses/lpips/weights/v0.0/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.0/squeeze.pth -------------------------------------------------------------------------------- /losses/lpips/weights/v0.0/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.0/vgg.pth -------------------------------------------------------------------------------- /losses/lpips/weights/v0.1/alex.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.1/alex.pth -------------------------------------------------------------------------------- /losses/lpips/weights/v0.1/squeeze.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.1/squeeze.pth -------------------------------------------------------------------------------- /losses/lpips/weights/v0.1/vgg.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/losses/lpips/weights/v0.1/vgg.pth -------------------------------------------------------------------------------- /losses/perceptual_loss.py: -------------------------------------------------------------------------------- 1 | from torch import nn, Tensor 2 | from torch.nn import functional as F 3 | 4 | 5 | 
class PerceptualLoss(nn.Module): 6 | 7 | def __init__(self, target: Tensor, mask: Tensor = None): 8 | super().__init__() 9 | if mask is not None: 10 | target = mask * target 11 | self.target = target.detach() 12 | self.mask = mask 13 | 14 | def forward(self, x: Tensor) -> Tensor: 15 | if self.mask is not None: 16 | x = x * self.mask.detach() 17 | 18 | loss = F.mse_loss(x, self.target) 19 | return loss 20 | -------------------------------------------------------------------------------- /losses/perceptual_style_loss.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple, Union 2 | 3 | import torch 4 | import torchvision 5 | from torch import nn 6 | from torch.nn import functional as F 7 | 8 | from losses.perceptual_loss import PerceptualLoss 9 | from losses.style_loss import StyleLoss 10 | 11 | 12 | class AbstractPerceptualAndStyleLoss(nn.Module): 13 | 14 | def adapt_vgg_model(self, vgg: nn.Module) -> nn.Module: 15 | layers = [] 16 | for layer in vgg.children(): 17 | if isinstance(layer, nn.ReLU): 18 | layer = nn.ReLU(inplace=False) 19 | layers.append(layer) 20 | return nn.Sequential(*layers) 21 | 22 | def disable_update(self): 23 | for parameter in self.parameters(): 24 | parameter.requires_grad = False 25 | 26 | 27 | class FixedPerceptualAndStyleLoss(AbstractPerceptualAndStyleLoss): 28 | 29 | def __init__(self, perceptual_target, style_target, perceptual_mask=None, style_mask=None): 30 | super().__init__() 31 | self.perceptual_target = perceptual_target 32 | self.style_target = style_target 33 | self.perceptual_mask = perceptual_mask 34 | self.style_mask = style_mask 35 | 36 | vgg = torchvision.models.vgg16(pretrained=True) 37 | vgg_features = self.adapt_vgg_model(vgg.features) 38 | 39 | blocks = { 40 | 'conv1_1': vgg_features[:2], 41 | 'conv1_2': vgg_features[2:4], 42 | 'conv2_2': vgg_features[4:9], 43 | 'conv3_3': vgg_features[9:16], 44 | } 45 | self.vgg_blocks = nn.ModuleDict(blocks) 46 | self.vgg_blocks.eval() 47 | 48 | self.style_losses = None 49 | self.perceptual_losses = None 50 | 51 | self.disable_update() 52 | 53 | def create_losses(self, target, mask, steps, loss_class) -> dict: 54 | features = target.clone() 55 | if mask is not None: 56 | mask = mask.clone() 57 | losses = {} 58 | 59 | for name, block in self.vgg_blocks.items(): 60 | features = block(features) 61 | if mask is not None: 62 | mask = F.interpolate(mask, size=features.shape[-2:], mode='nearest') 63 | if name in steps: 64 | loss = loss_class(features, mask=mask) 65 | losses[name] = loss 66 | 67 | return losses 68 | 69 | def build_style_and_perceptual_losses(self): 70 | self.style_losses = nn.ModuleDict( 71 | self.create_losses(self.style_target, self.style_mask, ['conv3_3'], StyleLoss) 72 | ) 73 | self.perceptual_losses = nn.ModuleDict( 74 | self.create_losses(self.perceptual_target, self.perceptual_mask, list(self.vgg_blocks.keys()), PerceptualLoss) 75 | ) 76 | self.disable_update() 77 | 78 | def forward(self, generated_image: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 79 | if self.style_losses is None: 80 | self.build_style_and_perceptual_losses() 81 | 82 | features = generated_image 83 | style_losses = [] 84 | perceptual_losses = [] 85 | for name, block in self.vgg_blocks.items(): 86 | features = block(features) 87 | if name in self.perceptual_losses: 88 | perceptual_losses.append(self.perceptual_losses[name](features)) 89 | if name in self.style_losses: 90 | style_losses.append(self.style_losses[name](features)) 91 | 92 | return 
sum(style_losses), sum(perceptual_losses) 93 | 94 | 95 | class PerceptualAndStyleLoss(AbstractPerceptualAndStyleLoss): 96 | 97 | def __init__(self, use_perceptual_loss=True, use_style_loss=True): 98 | super().__init__() 99 | vgg = torchvision.models.vgg16(pretrained=True) 100 | vgg_features = self.adapt_vgg_model(vgg.features) 101 | 102 | blocks = { 103 | 'conv1_1': vgg_features[:2], 104 | 'conv1_2': vgg_features[2:4], 105 | 'conv2_2': vgg_features[4:9], 106 | 'conv3_3': vgg_features[9:16], 107 | } 108 | self.vgg_blocks = nn.ModuleDict(blocks) 109 | self.vgg_blocks.eval() 110 | 111 | if use_perceptual_loss: 112 | self.perceptual_blocks = list(self.vgg_blocks.keys()) 113 | else: 114 | self.perceptual_blocks = [] 115 | 116 | if use_style_loss: 117 | self.style_blocks = ['conv3_3'] 118 | else: 119 | self.style_blocks = [] 120 | 121 | def forward(self, image: torch.Tensor, target: torch.Tensor, mask: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: 122 | image_features = image 123 | target_features = target 124 | style_losses = [] 125 | perceptual_losses = [] 126 | for name, block in self.vgg_blocks.items(): 127 | image_features = block(image_features) 128 | target_features = block(target_features) 129 | 130 | if mask is not None: 131 | mask = F.interpolate(mask, target_features.shape[-2:], mode='bilinear') 132 | 133 | if name in self.perceptual_blocks: 134 | perceptual_losses.append(self.run_loss(image_features, target_features, mask, PerceptualLoss)) 135 | 136 | if name in self.style_blocks: 137 | style_losses.append(self.run_loss(image_features, target_features, mask, StyleLoss)) 138 | 139 | return sum(style_losses), sum(perceptual_losses) 140 | 141 | @staticmethod 142 | def run_loss(image_features: torch.Tensor, target_features: torch.Tensor, mask: torch.Tensor, loss_class: Union[PerceptualLoss, StyleLoss]) -> torch.Tensor: 143 | loss_func = loss_class(target_features, mask=mask) 144 | return loss_func(image_features) 145 | 146 | 147 | class StyleLossNetwork(AbstractPerceptualAndStyleLoss): 148 | 149 | def __init__(self): 150 | super().__init__() 151 | vgg = torchvision.models.vgg16(pretrained=True) 152 | vgg_features = self.adapt_vgg_model(vgg.features) 153 | 154 | blocks = { 155 | 'conv3_3': vgg_features[:16], 156 | } 157 | self.vgg_blocks = nn.ModuleDict(blocks) 158 | self.vgg_blocks.eval() 159 | 160 | self.disable_update() 161 | 162 | def forward(self, generated_image: torch.Tensor, style_image: torch.Tensor) -> torch.Tensor: 163 | generated_features = generated_image 164 | style_features = style_image 165 | style_losses = [] 166 | 167 | for name, block in self.vgg_blocks.items(): 168 | generated_features = block(generated_features) 169 | style_features = block(style_features) 170 | 171 | loss_func = StyleLoss(style_features) 172 | style_loss = loss_func(generated_features) 173 | style_losses.append(style_loss) 174 | 175 | return sum(style_losses) 176 | -------------------------------------------------------------------------------- /losses/psnr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | class PSNR: 6 | 7 | def __init__(self, max_value: int = 1): 8 | self.name = "PSNR" 9 | self.max_value = max_value 10 | 11 | def __call__(self, image_1: torch.Tensor, image_2: torch.Tensor) -> torch.Tensor: 12 | assert image_1.shape == image_2.shape, "For a meaningful PSNR calculation, the shape of image_1 and image_2 should be the same" 13 | 14 | if len(image_1.shape) == 4: 15 | # 
we are dealing with a batch of images 16 | reduction = 'none' 17 | mean_dims = (1, 2, 3) 18 | else: 19 | reduction = 'mean' 20 | mean_dims = None 21 | 22 | mse = F.mse_loss(image_1, image_2, reduction=reduction) 23 | if mean_dims is not None: 24 | mse = mse.mean(dim=mean_dims) 25 | 26 | psnr = 20 * torch.log10(self.max_value ** 2 / torch.sqrt(mse)) 27 | return psnr.mean() 28 | 29 | 30 | -------------------------------------------------------------------------------- /losses/style_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from losses import gram_matrix 5 | 6 | 7 | class StyleLoss(nn.Module): 8 | 9 | def __init__(self, target_feature: torch.Tensor, mask=None): 10 | super().__init__() 11 | self.target_gram_matrix = gram_matrix(target_feature, mask).detach() 12 | self.mask = mask 13 | 14 | def forward(self, x: torch.Tensor) -> torch.Tensor: 15 | batch_size, num_channels = x.shape[:2] 16 | 17 | G = gram_matrix(x, self.mask) 18 | loss = (G - self.target_gram_matrix).square().sum() / (4 * (batch_size * num_channels)**2) 19 | 20 | return loss 21 | -------------------------------------------------------------------------------- /networks/encoder/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/networks/encoder/__init__.py -------------------------------------------------------------------------------- /networks/encoder/autoencoder.py: -------------------------------------------------------------------------------- 1 | from itertools import chain 2 | from typing import Iterator, Sequence, List, Dict 3 | 4 | import random 5 | import torch 6 | import torch.nn.functional as F 7 | from torch import nn 8 | from torch.nn import Parameter 9 | 10 | from latent_projecting import Latents, CodeLatents 11 | 12 | 13 | class StyleganAutoencoder(nn.Module): 14 | 15 | def __init__(self, encoder, decoder): 16 | super().__init__() 17 | self.encoder = encoder 18 | self.decoder = decoder 19 | self.use_generated_noise = True 20 | 21 | def forward(self, x: torch.Tensor) -> torch.Tensor: 22 | latent_codes = self.encode(x) 23 | 24 | if not self.use_generated_noise: 25 | latent_codes.noise = self.decoder.make_noise() 26 | 27 | reconstructed_x, _ = self.decoder([latent_codes.latent], input_is_latent=self.is_wplus(latent_codes), noise=latent_codes.noise) 28 | return reconstructed_x 29 | 30 | def is_wplus(self, latents: Latents): 31 | return len(latents.latent.shape) == 3 32 | 33 | def trainable_parameters(self, recurse: bool = ..., as_groups: Sequence[Sequence[str]] = None) -> [Iterator[Parameter], List[Dict[str, list]]]: 34 | if as_groups is None: 35 | return self.encoder.parameters(recurse=recurse) 36 | 37 | main_params = [] 38 | filtered_params = [[] for _ in as_groups] 39 | for name, param in self.encoder.named_parameters(recurse=recurse): 40 | in_group = False 41 | for i, key_list in enumerate(as_groups): 42 | if any(key in name for key in key_list): 43 | filtered_params[i].append(param) 44 | in_group = True 45 | break 46 | if not in_group: 47 | main_params.append(param) 48 | 49 | return [{'params': params} for params in [main_params] + filtered_params] 50 | 51 | def encode(self, x: torch.Tensor) -> Latents: 52 | return self.encoder(x) 53 | 54 | 55 | class DropoutStyleganAutoencoder(StyleganAutoencoder): 56 | 57 | def __init__(self, *args, dropout_ratio=0.5, 
**kwargs): 58 | super().__init__(*args, **kwargs) 59 | self.dropout_ratio = dropout_ratio 60 | 61 | def forward(self, x: torch.Tensor) -> torch.Tensor: 62 | latent_codes = self.encode(x) 63 | 64 | random_noises = self.decoder.make_noise() 65 | mixed_noise = [predicted_noise if random.random() > self.dropout_ratio else generated_noise for predicted_noise, generated_noise in zip(latent_codes.noise, random_noises)] 66 | 67 | reconstructed_x, _ = self.decoder([latent_codes.latent], input_is_latent=self.is_wplus(latent_codes), noise=mixed_noise) 68 | return reconstructed_x 69 | 70 | 71 | class CodeStyleganAutoencoder(StyleganAutoencoder): 72 | 73 | def forward(self, x: torch.Tensor) -> torch.Tensor: 74 | latent_info_codes = self.encode(x) 75 | 76 | latent = torch.cat([latent_info_codes.latent, latent_info_codes.code], dim=1) 77 | reconstructed_x, _ = self.decoder([latent], input_is_latent=False, noise=latent_info_codes.noise) 78 | 79 | return reconstructed_x 80 | 81 | def encode(self, x: torch.Tensor) -> CodeLatents: 82 | return self.encoder(x) 83 | 84 | 85 | class ContentAndStyleStyleganAutoencoder(StyleganAutoencoder): 86 | 87 | def forward(self, content_images: torch.Tensor, style_images: torch.Tensor) -> torch.Tensor: 88 | encoder_input_image = torch.cat([content_images, style_images], dim=1) 89 | latents = self.encode(encoder_input_image) 90 | 91 | reconstructed_x, _ = self.decoder([latents.latent], input_is_latent=self.is_wplus(latents), noise=latents.noise) 92 | return reconstructed_x 93 | 94 | 95 | class SuperResolutionStyleganAutoencoder(StyleganAutoencoder): 96 | 97 | def __init__(self, *args, extend_noise_with_random: bool = True, **kwargs): 98 | super().__init__(*args, **kwargs) 99 | self.extend_noise_with_random = extend_noise_with_random 100 | 101 | def forward(self, x: torch.Tensor) -> torch.Tensor: 102 | x = F.interpolate(x, (self.encoder.image_size, self.encoder.image_size), mode='area').detach() 103 | latents = self.encode(x) 104 | 105 | if self.decoder.size > self.encoder.image_size: 106 | # we have to add some noise to perform super resolution 107 | noises = latents.noise 108 | num_predicted_noise_maps = len(latents.noise) 109 | 110 | random_noises = self.decoder.make_noise() 111 | if self.extend_noise_with_random: 112 | noises.extend(random_noises[num_predicted_noise_maps:]) 113 | else: 114 | noise_maps_to_add = len(random_noises) - num_predicted_noise_maps 115 | current_noise_map = noises[-1] 116 | for i in range(noise_maps_to_add): 117 | current_noise_map = F.interpolate( 118 | current_noise_map.clone().detach(), 119 | random_noises[num_predicted_noise_maps + i].shape[-2:], 120 | mode='bilinear' 121 | ) 122 | noises.append(current_noise_map) 123 | 124 | latents.noise = noises 125 | 126 | # we also have to add some latent code parts if we have a w_plus latent 127 | if self.is_wplus(latents): 128 | target_num_latents = self.decoder.n_latent 129 | last_latent = latents.latent[:, -1, ...].unsqueeze(1).detach() 130 | padded_latent = last_latent.repeat((1, target_num_latents - latents.latent.shape[1], 1)) 131 | latents.latent = torch.cat([latents.latent, padded_latent], dim=1) 132 | 133 | reconstructed_x, _ = self.decoder([latents.latent], input_is_latent=self.is_wplus(latents), noise=latents.noise) 134 | return reconstructed_x 135 | 136 | 137 | class TwoStemStyleganAutoencoder(nn.Module): 138 | 139 | def __init__(self, latent_encoder, noise_encoder, decoder, update_latent=True, update_noise=True): 140 | super().__init__() 141 | self.latent_encoder = latent_encoder 142 | 
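        # two stems share one decoder: the latent encoder predicts the latent code,
        # the noise encoder predicts the per-layer noise maps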
self.noise_encoder = noise_encoder 143 | self.decoder = decoder 144 | self.update_latent = update_latent 145 | self.update_noise = update_noise 146 | 147 | assert update_latent or update_noise, "'update_latent' or 'update_noise' must be true for Two Stem Autoencoder" 148 | 149 | @property 150 | def encoder(self): 151 | return self.latent_encoder 152 | 153 | def is_wplus(self, latents: Latents): 154 | return len(latents.latent.shape) == 3 155 | 156 | def forward(self, x: torch.Tensor) -> torch.Tensor: 157 | encoded = self.encode(x) 158 | 159 | reconstructed_x, _ = self.decoder([encoded.latent], input_is_latent=self.is_wplus(encoded), noise=encoded.noise) 160 | return reconstructed_x 161 | 162 | def trainable_parameters(self, recurse: bool = ..., as_groups: Sequence[Sequence[str]] = None) -> [Iterator[Parameter], List[Dict[str, list]]]: 163 | networks = [] 164 | if self.update_latent: 165 | networks.append(self.latent_encoder) 166 | if self.update_noise: 167 | networks.append(self.noise_encoder) 168 | 169 | if as_groups is None: 170 | return chain.from_iterable([network.parameters(recurse=recurse) for network in networks]) 171 | 172 | main_params = [] 173 | filtered_params = [[] for _ in as_groups] 174 | for network in networks: 175 | for name, param in network.named_parameters(recurse=recurse): 176 | in_group = False 177 | for i, key_list in enumerate(as_groups): 178 | if any(key in name for key in key_list): 179 | filtered_params[i].append(param) 180 | in_group = True 181 | break 182 | if not in_group: 183 | main_params.append(param) 184 | 185 | return [{'params': params} for params in [main_params] + filtered_params] 186 | 187 | def encode(self, x: torch.Tensor) -> Latents: 188 | with torch.set_grad_enabled(self.update_latent): 189 | latent_codes = self.latent_encoder(x).latent 190 | 191 | if self.update_noise: 192 | noise_codes = self.noise_encoder(x).noise 193 | else: 194 | noise_codes = self.decoder.make_noise() 195 | 196 | return Latents(latent=latent_codes, noise=noise_codes) 197 | -------------------------------------------------------------------------------- /networks/encoder/resnet_based_encoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import torch 4 | from torch import nn 5 | from torchvision.models.resnet import BasicBlock 6 | 7 | 8 | class Encoder(nn.Module): 9 | 10 | def __init__(self, image_size: int, latent_size: int, num_input_channels: int, size_channel_map: dict, target_size: int = 4): 11 | super().__init__() 12 | 13 | self.image_size = image_size 14 | self.latent_size = latent_size 15 | log_input_size = int(math.log(image_size, 2)) 16 | log_target_size = int(math.log(target_size, 2)) 17 | assert image_size > target_size, "Input size must be larger than target size" 18 | assert 2 ** log_input_size == image_size, "Input size must be a power of 2" 19 | assert 2 ** log_target_size == target_size, "Target size must be a power of 2" 20 | 21 | self.start_block = BasicBlock( 22 | num_input_channels, 23 | size_channel_map[image_size], 24 | downsample=nn.Sequential( 25 | nn.Conv2d(num_input_channels, size_channel_map[image_size], kernel_size=1, stride=1), 26 | nn.BatchNorm2d(size_channel_map[image_size]) 27 | ) 28 | ) 29 | 30 | self.resnet_blocks = [ 31 | BasicBlock( 32 | in_planes := size_channel_map[2 ** current_size], 33 | out_planes := size_channel_map[2 ** (current_size - 1)], 34 | stride=2, 35 | downsample=nn.Sequential( 36 | nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=2), 37 | 
nn.BatchNorm2d(out_planes) 38 | ) 39 | ) 40 | for current_size in range(log_input_size, log_target_size, -1) 41 | ] 42 | self.resnet_blocks = nn.ModuleList([self.start_block] + self.resnet_blocks) 43 | 44 | num_latents = (log_input_size - log_target_size) * 2 + 2 45 | self.to_latent = [ 46 | nn.Conv2d(size_channel_map[target_size], self.latent_size, kernel_size=(target_size, target_size), stride=1) 47 | for _ in range(num_latents) 48 | ] 49 | self.to_latent = nn.ModuleList(self.to_latent) 50 | 51 | def forward(self, x): 52 | h = x 53 | for resnet_block in self.resnet_blocks: 54 | h = resnet_block(h) 55 | 56 | latent_codes = [to_latent(h) for to_latent in self.to_latent] 57 | latent_codes = torch.stack(latent_codes, dim=1) 58 | latent_codes = latent_codes.squeeze(3).squeeze(3) 59 | 60 | return latent_codes 61 | -------------------------------------------------------------------------------- /networks/stylegan1/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/networks/stylegan1/__init__.py -------------------------------------------------------------------------------- /networks/stylegan2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/networks/stylegan2/__init__.py -------------------------------------------------------------------------------- /networks/stylegan2/op/__init__.py: -------------------------------------------------------------------------------- 1 | from .fused_act import FusedLeakyReLU, fused_leaky_relu 2 | from .upfirdn2d import upfirdn2d 3 | -------------------------------------------------------------------------------- /networks/stylegan2/op/fused_act.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Function 6 | from torch.utils.cpp_extension import load 7 | 8 | 9 | module_path = os.path.dirname(__file__) 10 | fused = load( 11 | 'fused', 12 | sources=[ 13 | os.path.join(module_path, 'fused_bias_act.cpp'), 14 | os.path.join(module_path, 'fused_bias_act_kernel.cu'), 15 | ], 16 | ) 17 | 18 | 19 | class FusedLeakyReLUFunctionBackward(Function): 20 | @staticmethod 21 | def forward(ctx, grad_output, out, negative_slope, scale): 22 | ctx.save_for_backward(out) 23 | ctx.negative_slope = negative_slope 24 | ctx.scale = scale 25 | 26 | empty = grad_output.new_empty(0) 27 | 28 | grad_input = fused.fused_bias_act( 29 | grad_output, empty, out, 3, 1, negative_slope, scale 30 | ) 31 | 32 | dim = [0] 33 | 34 | if grad_input.ndim > 2: 35 | dim += list(range(2, grad_input.ndim)) 36 | 37 | grad_bias = grad_input.sum(dim).detach() 38 | 39 | return grad_input, grad_bias 40 | 41 | @staticmethod 42 | def backward(ctx, gradgrad_input, gradgrad_bias): 43 | out, = ctx.saved_tensors 44 | gradgrad_out = fused.fused_bias_act( 45 | gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale 46 | ) 47 | 48 | return gradgrad_out, None, None, None 49 | 50 | 51 | class FusedLeakyReLUFunction(Function): 52 | @staticmethod 53 | def forward(ctx, input, bias, negative_slope, scale): 54 | empty = input.new_empty(0) 55 | out = fused.fused_bias_act(input, bias, empty, 3, 0, negative_slope, scale) 56 | ctx.save_for_backward(out) 57 | ctx.negative_slope = 
negative_slope 58 | ctx.scale = scale 59 | 60 | return out 61 | 62 | @staticmethod 63 | def backward(ctx, grad_output): 64 | out, = ctx.saved_tensors 65 | 66 | grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply( 67 | grad_output, out, ctx.negative_slope, ctx.scale 68 | ) 69 | 70 | return grad_input, grad_bias, None, None 71 | 72 | 73 | class FusedLeakyReLU(nn.Module): 74 | def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5): 75 | super().__init__() 76 | 77 | self.bias = nn.Parameter(torch.zeros(channel)) 78 | self.negative_slope = negative_slope 79 | self.scale = scale 80 | 81 | def forward(self, input): 82 | return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale) 83 | 84 | 85 | def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5): 86 | return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale) 87 | -------------------------------------------------------------------------------- /networks/stylegan2/op/fused_bias_act.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 5 | int act, int grad, float alpha, float scale); 6 | 7 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 8 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 9 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 10 | 11 | torch::Tensor fused_bias_act(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 12 | int act, int grad, float alpha, float scale) { 13 | CHECK_CUDA(input); 14 | CHECK_CUDA(bias); 15 | 16 | return fused_bias_act_op(input, bias, refer, act, grad, alpha, scale); 17 | } 18 | 19 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 20 | m.def("fused_bias_act", &fused_bias_act, "fused bias act (CUDA)"); 21 | } -------------------------------------------------------------------------------- /networks/stylegan2/op/fused_bias_act_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | template 19 | static __global__ void fused_bias_act_kernel(scalar_t* out, const scalar_t* p_x, const scalar_t* p_b, const scalar_t* p_ref, 20 | int act, int grad, scalar_t alpha, scalar_t scale, int loop_x, int size_x, int step_b, int size_b, int use_bias, int use_ref) { 21 | int xi = blockIdx.x * loop_x * blockDim.x + threadIdx.x; 22 | 23 | scalar_t zero = 0.0; 24 | 25 | for (int loop_idx = 0; loop_idx < loop_x && xi < size_x; loop_idx++, xi += blockDim.x) { 26 | scalar_t x = p_x[xi]; 27 | 28 | if (use_bias) { 29 | x += p_b[(xi / step_b) % size_b]; 30 | } 31 | 32 | scalar_t ref = use_ref ? p_ref[xi] : zero; 33 | 34 | scalar_t y; 35 | 36 | switch (act * 10 + grad) { 37 | default: 38 | case 10: y = x; break; 39 | case 11: y = x; break; 40 | case 12: y = 0.0; break; 41 | 42 | case 30: y = (x > 0.0) ? x : x * alpha; break; 43 | case 31: y = (ref > 0.0) ? 
x : x * alpha; break; 44 | case 32: y = 0.0; break; 45 | } 46 | 47 | out[xi] = y * scale; 48 | } 49 | } 50 | 51 | 52 | torch::Tensor fused_bias_act_op(const torch::Tensor& input, const torch::Tensor& bias, const torch::Tensor& refer, 53 | int act, int grad, float alpha, float scale) { 54 | int curDevice = -1; 55 | cudaGetDevice(&curDevice); 56 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 57 | 58 | auto x = input.contiguous(); 59 | auto b = bias.contiguous(); 60 | auto ref = refer.contiguous(); 61 | 62 | int use_bias = b.numel() ? 1 : 0; 63 | int use_ref = ref.numel() ? 1 : 0; 64 | 65 | int size_x = x.numel(); 66 | int size_b = b.numel(); 67 | int step_b = 1; 68 | 69 | for (int i = 1 + 1; i < x.dim(); i++) { 70 | step_b *= x.size(i); 71 | } 72 | 73 | int loop_x = 4; 74 | int block_size = 4 * 32; 75 | int grid_size = (size_x - 1) / (loop_x * block_size) + 1; 76 | 77 | auto y = torch::empty_like(x); 78 | 79 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "fused_bias_act_kernel", [&] { 80 | fused_bias_act_kernel<<>>( 81 | y.data_ptr(), 82 | x.data_ptr(), 83 | b.data_ptr(), 84 | ref.data_ptr(), 85 | act, 86 | grad, 87 | alpha, 88 | scale, 89 | loop_x, 90 | size_x, 91 | step_b, 92 | size_b, 93 | use_bias, 94 | use_ref 95 | ); 96 | }); 97 | 98 | return y; 99 | } -------------------------------------------------------------------------------- /networks/stylegan2/op/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | 4 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 5 | int up_x, int up_y, int down_x, int down_y, 6 | int pad_x0, int pad_x1, int pad_y0, int pad_y1); 7 | 8 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 9 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 10 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 11 | 12 | torch::Tensor upfirdn2d(const torch::Tensor& input, const torch::Tensor& kernel, 13 | int up_x, int up_y, int down_x, int down_y, 14 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 15 | CHECK_CUDA(input); 16 | CHECK_CUDA(kernel); 17 | 18 | return upfirdn2d_op(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1); 19 | } 20 | 21 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 22 | m.def("upfirdn2d", &upfirdn2d, "upfirdn2d (CUDA)"); 23 | } -------------------------------------------------------------------------------- /networks/stylegan2/op/upfirdn2d.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import torch 4 | from torch.autograd import Function 5 | from torch.utils.cpp_extension import load 6 | 7 | 8 | module_path = os.path.dirname(__file__) 9 | upfirdn2d_op = load( 10 | 'upfirdn2d', 11 | sources=[ 12 | os.path.join(module_path, 'upfirdn2d.cpp'), 13 | os.path.join(module_path, 'upfirdn2d_kernel.cu'), 14 | ], 15 | ) 16 | 17 | 18 | class UpFirDn2dBackward(Function): 19 | @staticmethod 20 | def forward( 21 | ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size 22 | ): 23 | 24 | up_x, up_y = up 25 | down_x, down_y = down 26 | g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad 27 | 28 | grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1) 29 | 30 | grad_input = upfirdn2d_op.upfirdn2d( 31 | grad_output, 32 | grad_kernel, 33 | down_x, 34 | down_y, 35 | up_x, 36 | up_y, 37 | g_pad_x0, 38 | g_pad_x1, 39 | g_pad_y0, 40 | g_pad_y1, 41 | ) 42 | 
grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3]) 43 | 44 | ctx.save_for_backward(kernel) 45 | 46 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 47 | 48 | ctx.up_x = up_x 49 | ctx.up_y = up_y 50 | ctx.down_x = down_x 51 | ctx.down_y = down_y 52 | ctx.pad_x0 = pad_x0 53 | ctx.pad_x1 = pad_x1 54 | ctx.pad_y0 = pad_y0 55 | ctx.pad_y1 = pad_y1 56 | ctx.in_size = in_size 57 | ctx.out_size = out_size 58 | 59 | return grad_input 60 | 61 | @staticmethod 62 | def backward(ctx, gradgrad_input): 63 | kernel, = ctx.saved_tensors 64 | 65 | gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1) 66 | 67 | gradgrad_out = upfirdn2d_op.upfirdn2d( 68 | gradgrad_input, 69 | kernel, 70 | ctx.up_x, 71 | ctx.up_y, 72 | ctx.down_x, 73 | ctx.down_y, 74 | ctx.pad_x0, 75 | ctx.pad_x1, 76 | ctx.pad_y0, 77 | ctx.pad_y1, 78 | ) 79 | # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0], ctx.out_size[1], ctx.in_size[3]) 80 | gradgrad_out = gradgrad_out.view( 81 | ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1] 82 | ) 83 | 84 | return gradgrad_out, None, None, None, None, None, None, None, None 85 | 86 | 87 | class UpFirDn2d(Function): 88 | @staticmethod 89 | def forward(ctx, input, kernel, up, down, pad): 90 | up_x, up_y = up 91 | down_x, down_y = down 92 | pad_x0, pad_x1, pad_y0, pad_y1 = pad 93 | 94 | kernel_h, kernel_w = kernel.shape 95 | batch, channel, in_h, in_w = input.shape 96 | ctx.in_size = input.shape 97 | 98 | input = input.reshape(-1, in_h, in_w, 1) 99 | 100 | ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1])) 101 | 102 | out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1 103 | out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1 104 | ctx.out_size = (out_h, out_w) 105 | 106 | ctx.up = (up_x, up_y) 107 | ctx.down = (down_x, down_y) 108 | ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1) 109 | 110 | g_pad_x0 = kernel_w - pad_x0 - 1 111 | g_pad_y0 = kernel_h - pad_y0 - 1 112 | g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1 113 | g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1 114 | 115 | ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1) 116 | 117 | out = upfirdn2d_op.upfirdn2d( 118 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 119 | ) 120 | # out = out.view(major, out_h, out_w, minor) 121 | out = out.view(-1, channel, out_h, out_w) 122 | 123 | return out 124 | 125 | @staticmethod 126 | def backward(ctx, grad_output): 127 | kernel, grad_kernel = ctx.saved_tensors 128 | 129 | grad_input = UpFirDn2dBackward.apply( 130 | grad_output, 131 | kernel, 132 | grad_kernel, 133 | ctx.up, 134 | ctx.down, 135 | ctx.pad, 136 | ctx.g_pad, 137 | ctx.in_size, 138 | ctx.out_size, 139 | ) 140 | 141 | return grad_input, None, None, None, None 142 | 143 | 144 | def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)): 145 | out = UpFirDn2d.apply( 146 | input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1]) 147 | ) 148 | 149 | return out 150 | 151 | 152 | def upfirdn2d_native( 153 | input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1 154 | ): 155 | _, in_h, in_w, minor = input.shape 156 | kernel_h, kernel_w = kernel.shape 157 | 158 | out = input.view(-1, in_h, 1, in_w, 1, minor) 159 | out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1]) 160 | out = out.view(-1, in_h * up_y, in_w * up_x, minor) 161 | 162 | out = F.pad( 163 | out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)] 164 | ) 165 | out = out[ 166 | :, 167 | 
max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0), 168 | max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0), 169 | :, 170 | ] 171 | 172 | out = out.permute(0, 3, 1, 2) 173 | out = out.reshape( 174 | [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1] 175 | ) 176 | w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w) 177 | out = F.conv2d(out, w) 178 | out = out.reshape( 179 | -1, 180 | minor, 181 | in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1, 182 | in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1, 183 | ) 184 | out = out.permute(0, 2, 3, 1) 185 | 186 | return out[:, ::down_y, ::down_x, :] 187 | 188 | -------------------------------------------------------------------------------- /networks/stylegan2/op/upfirdn2d_kernel.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | // 3 | // This work is made available under the Nvidia Source Code License-NC. 4 | // To view a copy of this license, visit 5 | // https://nvlabs.github.io/stylegan2/license.html 6 | 7 | #include 8 | 9 | #include 10 | #include 11 | #include 12 | #include 13 | 14 | #include 15 | #include 16 | 17 | 18 | static __host__ __device__ __forceinline__ int floor_div(int a, int b) { 19 | int c = a / b; 20 | 21 | if (c * b > a) { 22 | c--; 23 | } 24 | 25 | return c; 26 | } 27 | 28 | 29 | struct UpFirDn2DKernelParams { 30 | int up_x; 31 | int up_y; 32 | int down_x; 33 | int down_y; 34 | int pad_x0; 35 | int pad_x1; 36 | int pad_y0; 37 | int pad_y1; 38 | 39 | int major_dim; 40 | int in_h; 41 | int in_w; 42 | int minor_dim; 43 | int kernel_h; 44 | int kernel_w; 45 | int out_h; 46 | int out_w; 47 | int loop_major; 48 | int loop_x; 49 | }; 50 | 51 | 52 | template 53 | __global__ void upfirdn2d_kernel(scalar_t* out, const scalar_t* input, const scalar_t* kernel, const UpFirDn2DKernelParams p) { 54 | const int tile_in_h = ((tile_out_h - 1) * down_y + kernel_h - 1) / up_y + 1; 55 | const int tile_in_w = ((tile_out_w - 1) * down_x + kernel_w - 1) / up_x + 1; 56 | 57 | __shared__ volatile float sk[kernel_h][kernel_w]; 58 | __shared__ volatile float sx[tile_in_h][tile_in_w]; 59 | 60 | int minor_idx = blockIdx.x; 61 | int tile_out_y = minor_idx / p.minor_dim; 62 | minor_idx -= tile_out_y * p.minor_dim; 63 | tile_out_y *= tile_out_h; 64 | int tile_out_x_base = blockIdx.y * p.loop_x * tile_out_w; 65 | int major_idx_base = blockIdx.z * p.loop_major; 66 | 67 | if (tile_out_x_base >= p.out_w | tile_out_y >= p.out_h | major_idx_base >= p.major_dim) { 68 | return; 69 | } 70 | 71 | for (int tap_idx = threadIdx.x; tap_idx < kernel_h * kernel_w; tap_idx += blockDim.x) { 72 | int ky = tap_idx / kernel_w; 73 | int kx = tap_idx - ky * kernel_w; 74 | scalar_t v = 0.0; 75 | 76 | if (kx < p.kernel_w & ky < p.kernel_h) { 77 | v = kernel[(p.kernel_h - 1 - ky) * p.kernel_w + (p.kernel_w - 1 - kx)]; 78 | } 79 | 80 | sk[ky][kx] = v; 81 | } 82 | 83 | for (int loop_major = 0, major_idx = major_idx_base; loop_major < p.loop_major & major_idx < p.major_dim; loop_major++, major_idx++) { 84 | for (int loop_x = 0, tile_out_x = tile_out_x_base; loop_x < p.loop_x & tile_out_x < p.out_w; loop_x++, tile_out_x += tile_out_w) { 85 | int tile_mid_x = tile_out_x * down_x + up_x - 1 - p.pad_x0; 86 | int tile_mid_y = tile_out_y * down_y + up_y - 1 - p.pad_y0; 87 | int tile_in_x = floor_div(tile_mid_x, up_x); 88 | int tile_in_y = floor_div(tile_mid_y, up_y); 89 | 90 | __syncthreads(); 91 | 92 | for (int in_idx = threadIdx.x; in_idx < tile_in_h * tile_in_w; 
in_idx += blockDim.x) { 93 | int rel_in_y = in_idx / tile_in_w; 94 | int rel_in_x = in_idx - rel_in_y * tile_in_w; 95 | int in_x = rel_in_x + tile_in_x; 96 | int in_y = rel_in_y + tile_in_y; 97 | 98 | scalar_t v = 0.0; 99 | 100 | if (in_x >= 0 & in_y >= 0 & in_x < p.in_w & in_y < p.in_h) { 101 | v = input[((major_idx * p.in_h + in_y) * p.in_w + in_x) * p.minor_dim + minor_idx]; 102 | } 103 | 104 | sx[rel_in_y][rel_in_x] = v; 105 | } 106 | 107 | __syncthreads(); 108 | for (int out_idx = threadIdx.x; out_idx < tile_out_h * tile_out_w; out_idx += blockDim.x) { 109 | int rel_out_y = out_idx / tile_out_w; 110 | int rel_out_x = out_idx - rel_out_y * tile_out_w; 111 | int out_x = rel_out_x + tile_out_x; 112 | int out_y = rel_out_y + tile_out_y; 113 | 114 | int mid_x = tile_mid_x + rel_out_x * down_x; 115 | int mid_y = tile_mid_y + rel_out_y * down_y; 116 | int in_x = floor_div(mid_x, up_x); 117 | int in_y = floor_div(mid_y, up_y); 118 | int rel_in_x = in_x - tile_in_x; 119 | int rel_in_y = in_y - tile_in_y; 120 | int kernel_x = (in_x + 1) * up_x - mid_x - 1; 121 | int kernel_y = (in_y + 1) * up_y - mid_y - 1; 122 | 123 | scalar_t v = 0.0; 124 | 125 | #pragma unroll 126 | for (int y = 0; y < kernel_h / up_y; y++) 127 | #pragma unroll 128 | for (int x = 0; x < kernel_w / up_x; x++) 129 | v += sx[rel_in_y + y][rel_in_x + x] * sk[kernel_y + y * up_y][kernel_x + x * up_x]; 130 | 131 | if (out_x < p.out_w & out_y < p.out_h) { 132 | out[((major_idx * p.out_h + out_y) * p.out_w + out_x) * p.minor_dim + minor_idx] = v; 133 | } 134 | } 135 | } 136 | } 137 | } 138 | 139 | 140 | torch::Tensor upfirdn2d_op(const torch::Tensor& input, const torch::Tensor& kernel, 141 | int up_x, int up_y, int down_x, int down_y, 142 | int pad_x0, int pad_x1, int pad_y0, int pad_y1) { 143 | int curDevice = -1; 144 | cudaGetDevice(&curDevice); 145 | cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); 146 | 147 | UpFirDn2DKernelParams p; 148 | 149 | auto x = input.contiguous(); 150 | auto k = kernel.contiguous(); 151 | 152 | p.major_dim = x.size(0); 153 | p.in_h = x.size(1); 154 | p.in_w = x.size(2); 155 | p.minor_dim = x.size(3); 156 | p.kernel_h = k.size(0); 157 | p.kernel_w = k.size(1); 158 | p.up_x = up_x; 159 | p.up_y = up_y; 160 | p.down_x = down_x; 161 | p.down_y = down_y; 162 | p.pad_x0 = pad_x0; 163 | p.pad_x1 = pad_x1; 164 | p.pad_y0 = pad_y0; 165 | p.pad_y1 = pad_y1; 166 | 167 | p.out_h = (p.in_h * p.up_y + p.pad_y0 + p.pad_y1 - p.kernel_h + p.down_y) / p.down_y; 168 | p.out_w = (p.in_w * p.up_x + p.pad_x0 + p.pad_x1 - p.kernel_w + p.down_x) / p.down_x; 169 | 170 | auto out = at::empty({p.major_dim, p.out_h, p.out_w, p.minor_dim}, x.options()); 171 | 172 | int mode = -1; 173 | 174 | int tile_out_h; 175 | int tile_out_w; 176 | 177 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 178 | mode = 1; 179 | tile_out_h = 16; 180 | tile_out_w = 64; 181 | } 182 | 183 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 3 && p.kernel_w <= 3) { 184 | mode = 2; 185 | tile_out_h = 16; 186 | tile_out_w = 64; 187 | } 188 | 189 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 4 && p.kernel_w <= 4) { 190 | mode = 3; 191 | tile_out_h = 16; 192 | tile_out_w = 64; 193 | } 194 | 195 | if (p.up_x == 2 && p.up_y == 2 && p.down_x == 1 && p.down_y == 1 && p.kernel_h <= 2 && p.kernel_w <= 2) { 196 | mode = 4; 197 | tile_out_h = 16; 198 | tile_out_w = 64; 199 | } 200 | 201 | if (p.up_x == 1 && p.up_y == 1 && 
p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 4 && p.kernel_w <= 4) { 202 | mode = 5; 203 | tile_out_h = 8; 204 | tile_out_w = 32; 205 | } 206 | 207 | if (p.up_x == 1 && p.up_y == 1 && p.down_x == 2 && p.down_y == 2 && p.kernel_h <= 2 && p.kernel_w <= 2) { 208 | mode = 6; 209 | tile_out_h = 8; 210 | tile_out_w = 32; 211 | } 212 | 213 | dim3 block_size; 214 | dim3 grid_size; 215 | 216 | if (tile_out_h > 0 && tile_out_w) { 217 | p.loop_major = (p.major_dim - 1) / 16384 + 1; 218 | p.loop_x = 1; 219 | block_size = dim3(32 * 8, 1, 1); 220 | grid_size = dim3(((p.out_h - 1) / tile_out_h + 1) * p.minor_dim, 221 | (p.out_w - 1) / (p.loop_x * tile_out_w) + 1, 222 | (p.major_dim - 1) / p.loop_major + 1); 223 | } 224 | 225 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] { 226 | switch (mode) { 227 | case 1: 228 | upfirdn2d_kernel<<>>( 229 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 230 | ); 231 | 232 | break; 233 | 234 | case 2: 235 | upfirdn2d_kernel<<>>( 236 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 237 | ); 238 | 239 | break; 240 | 241 | case 3: 242 | upfirdn2d_kernel<<>>( 243 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 244 | ); 245 | 246 | break; 247 | 248 | case 4: 249 | upfirdn2d_kernel<<>>( 250 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 251 | ); 252 | 253 | break; 254 | 255 | case 5: 256 | upfirdn2d_kernel<<>>( 257 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 258 | ); 259 | 260 | break; 261 | 262 | case 6: 263 | upfirdn2d_kernel<<>>( 264 | out.data_ptr(), x.data_ptr(), k.data_ptr(), p 265 | ); 266 | 267 | break; 268 | } 269 | }); 270 | 271 | return out; 272 | } -------------------------------------------------------------------------------- /project.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import matplotlib 6 | import torch 7 | from PIL import Image 8 | from tqdm import trange 9 | 10 | from latent_projecting import run_image_reconstruction, Latents 11 | from latent_projecting.projector import Projector 12 | from utils.command_line_args import add_default_args_for_projecting 13 | from pytorch_training.images.utils import make_image 14 | 15 | matplotlib.use('AGG') 16 | 17 | Image.init() 18 | 19 | 20 | def latent_noise(latent, strength): 21 | noise = torch.randn_like(latent) * strength 22 | 23 | return latent + noise 24 | 25 | 26 | # def abort_condition(loss_dict): 27 | # if loss_dict['psnr'] > 10: 28 | # return True 29 | # return False 30 | 31 | 32 | def main(args): 33 | projector = Projector(args) 34 | 35 | transform = projector.get_transforms() 36 | 37 | imgs = [] 38 | image_names = [] 39 | 40 | for file_name in os.listdir(args.files): 41 | if os.path.splitext(file_name)[-1] not in Image.EXTENSION.keys(): 42 | continue 43 | 44 | image_name = os.path.join(args.files, file_name) 45 | img = transform(Image.open(image_name).convert('RGB')) 46 | image_names.append(image_name) 47 | imgs.append(img) 48 | 49 | imgs = torch.stack(imgs, 0).to(args.device) 50 | 51 | n_mean_latent = 10000 52 | latent_mean, latent_std = projector.get_mean_latent(n_mean_latent) 53 | 54 | for idx in trange(0, len(imgs), args.batch_size): 55 | images = imgs[idx:idx + args.batch_size] 56 | 57 | base_noises = projector.generator.make_noise() 58 | base_noises = [noise.repeat(len(images), 1, 1, 1) for noise in base_noises] 59 | 60 | noises = [noise.detach().clone() for noise in base_noises] 61 | 62 | if args.no_mean_latent: 63 | latent_in = torch.normal(0, 
latent_std.item(), size=(len(images), projector.config['latent_size']), device=args.device) 64 | else: 65 | latent_in = latent_mean.detach().clone().unsqueeze(0).repeat(len(images), 1) 66 | 67 | if args.w_plus: 68 | latent_in = latent_in.unsqueeze(1).repeat(1, projector.generator.n_latent, 1) 69 | 70 | # optimize latent vector 71 | paths, best_latent = run_image_reconstruction(args, projector, Latents(latent_in, noises), images, do_optimize_noise=args.optimize_noise) 72 | 73 | # result_file = {'noises': noises} 74 | 75 | img_gen, _ = projector.generator([best_latent.latent.cuda()], input_is_latent=True, noise=[noise.cuda() for noise in best_latent.noise]) 76 | 77 | img_ar = make_image(img_gen) 78 | 79 | destination_dir = Path(args.files) / 'projected' / args.destination 80 | destination_dir.mkdir(parents=True, exist_ok=True) 81 | 82 | path_per_image = paths.split() 83 | for i in range(len(images)): 84 | image_name = image_names[idx + i] 85 | image_latent = best_latent[i] 86 | result_file = { 87 | 'noise': image_latent.noise, 88 | 'latent': image_latent.latent, 89 | } 90 | image_base_name = os.path.splitext(os.path.basename(image_name))[0] 91 | img_name = image_base_name + '-project.png' 92 | pil_img = Image.fromarray(img_ar[i]) 93 | pil_img.save(destination_dir / img_name) 94 | torch.save(result_file, destination_dir / f'results_{image_base_name}.pth') 95 | if args.create_gif: 96 | projector.create_gif( 97 | path_per_image[i].to(args.device), 98 | image_base_name, 99 | destination_dir 100 | ) 101 | projector.render_log(destination_dir, image_base_name) 102 | 103 | # cleanup 104 | del paths 105 | del best_latent 106 | torch.cuda.empty_cache() 107 | projector.reset() 108 | 109 | 110 | if __name__ == "__main__": 111 | parser = argparse.ArgumentParser() 112 | parser.add_argument('files', metavar='FILES', help="Path to dir holding all images to embed") 113 | parser.add_argument('destination', help="name of the destination subdir where results will be saved") 114 | parser.add_argument('--noise', type=float, default=0.05) 115 | parser.add_argument('--create-gif', help='create gif showing the optimization process', action='store_true', default=False) 116 | parser.add_argument('--no-noise-optimize', action='store_false', default=True, dest='optimize_noise', help="do not perform noise optimization") 117 | 118 | parser = add_default_args_for_projecting(parser) 119 | 120 | args = parser.parse_args() 121 | assert not Path(args.destination).is_absolute(), "The destination path is supposed to be a relative path!" 
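    # Illustrative invocation (paths are placeholders; the remaining flags are added by
    # add_default_args_for_projecting):
    #   python project.py /path/to/images run_name --create-gif
    # For each image this writes <name>-project.png and results_<name>.pth
    # (optimized latent and noise) to <FILES>/projected/<destination>/.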
122 | main(args) 123 | -------------------------------------------------------------------------------- /pytest.ini: -------------------------------------------------------------------------------- 1 | [pytest] 2 | filterwarnings = 3 | ignore::UserWarning 4 | ignore::DeprecationWarning 5 | norecursedirs = logs .git 6 | -------------------------------------------------------------------------------- /reconstruct_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from PIL import Image 5 | from pytorch_training.images import make_image 6 | 7 | from data.demo_dataset import DemoDataset 8 | from networks import get_autoencoder, load_weights 9 | from utils.config import load_config 10 | from utils.data_loading import build_data_loader 11 | 12 | 13 | def main(args): 14 | root_dir = Path(args.autoencoder_checkpoint).parent.parent 15 | output_dir = root_dir / args.output_dir 16 | output_dir.mkdir(exist_ok=True, parents=True) 17 | 18 | config = load_config(args.autoencoder_checkpoint, None) 19 | config['batch_size'] = 1 20 | autoencoder = get_autoencoder(config).to(args.device) 21 | autoencoder = load_weights(autoencoder, args.autoencoder_checkpoint, key='autoencoder') 22 | 23 | input_image = Path(args.image) 24 | data_loader = build_data_loader(input_image, config, config['absolute'], shuffle_off=True, dataset_class=DemoDataset) 25 | 26 | image = next(iter(data_loader)) 27 | image = {k: v.to(args.device) for k,v in image.items()} 28 | 29 | reconstructed = Image.fromarray(make_image(autoencoder(image['input_image'])[0].squeeze(0))) 30 | 31 | output_name = Path(args.output_dir) / f"reconstructed_{input_image.stem}_stylegan_{config['stylegan_variant']}_{'w_only' if config['w_only'] else 'w_plus'}.png" 32 | reconstructed.save(output_name) 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser( 37 | description="reconstruct a given image", 38 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 39 | ) 40 | parser.add_argument("autoencoder_checkpoint", help='Path to autoencoder checkpoint which shall be used for embedding') 41 | parser.add_argument("image", help="image to reconstruct") 42 | parser.add_argument("--device", default='cuda', help="which device to use (cuda, or cpu)") 43 | parser.add_argument("--output-dir", default='.') 44 | 45 | main(parser.parse_args()) 46 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | kornia==0.3.1 2 | numpy==1.18.2 3 | opencv-python-headless==4.2.0.34 4 | Pillow==8.0.0 5 | PyYAML==5.3.1 6 | requests==2.23.0 7 | scikit-image==0.16.2 8 | scipy==1.4.1 9 | tensorboard-pytorch==0.7.1 10 | torch==1.5.0 11 | torchvision==0.6.0 12 | tqdm==4.44.1 13 | wandb==0.9.5 14 | ipython==7.16.1 15 | tensorboard==2.2.2 16 | statsmodels==0.11.1 17 | pytorch-fid==0.1.1 18 | imgaug==0.4.0 19 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/tests/__init__.py -------------------------------------------------------------------------------- /tests/test_argument_parsing.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from 
utils.command_line_args import add_default_args_for_projecting 4 | 5 | 6 | def test_default_args(): 7 | parser = argparse.ArgumentParser() 8 | parser = add_default_args_for_projecting(parser) 9 | 10 | assert isinstance(parser, argparse.ArgumentParser) 11 | -------------------------------------------------------------------------------- /tests/test_data_structures.py: -------------------------------------------------------------------------------- 1 | import copy 2 | 3 | import pytest 4 | import torch 5 | 6 | from latent_projecting import LatentPaths, Latents 7 | from tests.test_projector import possible_devices 8 | 9 | 10 | class TestLatentPaths: 11 | 12 | @pytest.fixture 13 | def path_length(self): 14 | return 20 15 | 16 | @pytest.fixture 17 | def path(self, path_length): 18 | latent = [torch.ones((1, 14, 512)) for _ in range(path_length)] 19 | noises = [[torch.ones((1, 1, 4, 4)) for _ in range(7)] for __ in range(path_length)] 20 | 21 | path = LatentPaths(latent, noises) 22 | return path 23 | 24 | def test_len(self, path, path_length): 25 | assert len(path) == path_length 26 | 27 | path = LatentPaths(path.latent, path.noise[:-2]) 28 | 29 | with pytest.raises(AssertionError): 30 | assert len(path) == path_length 31 | 32 | @pytest.mark.parametrize('device', possible_devices) 33 | def test_to(self, device, path): 34 | path = path.to(device) 35 | 36 | for element in path: 37 | assert device in str(element.latent.device) 38 | for noise in element.noise: 39 | assert device in str(noise.device) 40 | 41 | def test_add(self, path, path_length): 42 | second_path = copy.deepcopy(path) 43 | path = path + second_path 44 | 45 | assert len(path.latent) == 2 * path_length 46 | assert len(path.noise) == 2 * path_length 47 | 48 | @pytest.mark.parametrize('batch_size', [1, 2, 5]) 49 | def test_split(self, path_length, batch_size): 50 | latent = [torch.ones((batch_size, 14, 512)) for _ in range(path_length)] 51 | noises = [[torch.ones((batch_size, 1, 4, 4)) for _ in range(7)] for __ in range(path_length)] 52 | path = LatentPaths(latent, noises) 53 | 54 | splitted_path = path.split() 55 | 56 | assert len(splitted_path) == batch_size 57 | for split in splitted_path: 58 | assert len(split) == path_length 59 | for element in split: 60 | assert element.latent.shape == (1, 14, 512) 61 | for noise in element.noise: 62 | assert noise.shape == (1, 1, 4, 4) 63 | 64 | 65 | class TestLatents: 66 | 67 | @pytest.mark.parametrize('device', possible_devices) 68 | def test_to(self, device): 69 | def check_device(latents, dev): 70 | assert dev in str(latents.latent.device) 71 | for noise in latents.noise: 72 | assert dev in str(noise.device) 73 | 74 | latent = torch.ones((1, 14, 512)) 75 | noises = [torch.ones((1, 1, 4, 4)) for _ in range(7)] 76 | 77 | latents = Latents(latent, noises) 78 | check_device(latents, 'cpu') 79 | 80 | latents = latents.to(device) 81 | check_device(latents, device) 82 | 83 | @pytest.mark.parametrize('batch_size', [1, 2, 5]) 84 | def test_getitem(self, batch_size): 85 | latent = torch.ones((batch_size, 14, 512)) 86 | noises = [torch.ones((batch_size, 1, 4, 4)) for _ in range(7)] 87 | 88 | latents = Latents(latent, noises) 89 | 90 | for i in range(batch_size): 91 | sub_latent = latents[i] 92 | assert sub_latent.latent.shape == (1, 14, 512) 93 | for noise in sub_latent.noise: 94 | assert noise.shape == (1, 1, 4, 4) 95 | -------------------------------------------------------------------------------- /tests/test_image_utils.py: 
-------------------------------------------------------------------------------- 1 | import numpy 2 | import pytest 3 | import torch 4 | from PIL import Image 5 | 6 | from pytorch_training.images.utils import make_image 7 | from utils.image_utils import render_text_on_image 8 | 9 | 10 | class TestImageUtils: 11 | 12 | @pytest.fixture(params=[(1, 3, 256, 256), (3, 256, 256)]) 13 | def tensor(self, request): 14 | tensor = torch.rand(request.param) 15 | tensor[0, 0] = -1 16 | tensor[-1, -1] = 1 17 | 18 | return tensor 19 | 20 | def test_render_text(self, tensor): 21 | if len(tensor.shape) == 3: 22 | image = make_image(tensor) 23 | else: 24 | image = make_image(tensor)[0] 25 | image_with_text = render_text_on_image("test", Image.fromarray(image)) 26 | 27 | text_array = numpy.array(image_with_text) 28 | assert not numpy.allclose(image, text_array) 29 | assert numpy.allclose(image[:128, :128, :], text_array[:128, :128, :]) 30 | -------------------------------------------------------------------------------- /tests/test_latent_projecting.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | 4 | from latent_projecting import noise_loss, w_plus_style_loss, naive_noise_loss, w_plus_loss 5 | 6 | possible_devices = ['cpu'] 7 | if torch.cuda.is_available(): 8 | possible_devices += ['cuda'] 9 | 10 | 11 | class TestLosses: 12 | 13 | def run(self, loss_func, device='cpu'): 14 | losses = loss_func( 15 | torch.randn(self.shape).to(device), 16 | torch.randn(self.shape).to(device) 17 | ) 18 | 19 | assert isinstance(losses, tuple) 20 | assert len(losses) == 2 21 | assert isinstance(losses[0], torch.Tensor) 22 | assert isinstance(losses[1], dict) 23 | 24 | @pytest.fixture(autouse=True) 25 | def shape(self): 26 | self.shape = 1, 3, 256, 256 27 | self.mask_shape = 1, 1, 256, 256 28 | 29 | def test_noise_loss(self): 30 | lambdas = {"l_mse_1": 1, "l_mse_2": 1} 31 | loss_func = noise_loss(lambdas, torch.randn(self.shape), torch.randn(self.shape), torch.randn(self.mask_shape)) 32 | self.run(loss_func) 33 | 34 | @pytest.mark.parametrize("device", possible_devices) 35 | def test_w_pluss_style_loss(self, device): 36 | lambdas = {"l_style": 1, "l_percept": 1, "l_mse": 1} 37 | loss_func = w_plus_style_loss( 38 | lambdas, 39 | torch.randn(self.shape).to(device), 40 | torch.randn(self.shape).to(device), 41 | torch.randn(self.mask_shape).to(device), 42 | device 43 | ) 44 | self.run(loss_func, device=device) 45 | 46 | def test_naive_noise_loss(self): 47 | lambdas = {"l_mse": 1} 48 | loss_func = naive_noise_loss(lambdas) 49 | self.run(loss_func) 50 | 51 | @pytest.mark.parametrize("device", possible_devices) 52 | def test_w_plus_loss(self, device): 53 | lambdas = {"l_percept": 1, "l_mse": 1} 54 | loss_func = w_plus_loss(lambdas, device) 55 | self.run(loss_func, device=device) 56 | -------------------------------------------------------------------------------- /tests/test_losses.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import pytest 4 | import torch 5 | 6 | from losses.perceptual_loss import PerceptualLoss 7 | from losses.perceptual_style_loss import FixedPerceptualAndStyleLoss 8 | from losses.psnr import PSNR 9 | from losses.style_loss import StyleLoss 10 | 11 | 12 | class TestPSNR: 13 | 14 | @pytest.fixture(autouse=True) 15 | def psnr(self): 16 | self.psnr_func = PSNR() 17 | 18 | def test_psnr_same_input(self): 19 | image = torch.ones((5, 5)) 20 | 21 | psnr = 
self.psnr_func(image, image) 22 | assert float(psnr) == float('inf') 23 | 24 | def test_psnr_random_input(self): 25 | image_1 = np.random.random((5, 5)) 26 | image_2 = np.random.random((5, 5)) 27 | 28 | cv2_psnr = cv2.PSNR(image_1, image_2, R=1) 29 | psnr = self.psnr_func(torch.Tensor(image_1), torch.Tensor(image_2)) 30 | 31 | assert float(cv2_psnr) == pytest.approx(float(psnr)) 32 | 33 | 34 | class TestStyleLoss: 35 | 36 | @pytest.fixture(autouse=True) 37 | def shape(self): 38 | self.shape = 1, 3, 256, 256 39 | 40 | def test_return_type_style(self): 41 | loss_func = StyleLoss(torch.ones(self.shape)) 42 | loss = loss_func(torch.zeros(self.shape)) 43 | assert isinstance(loss, torch.Tensor) 44 | 45 | def test_return_type_perceptual(self): 46 | loss_func = PerceptualLoss(torch.ones(self.shape)) 47 | loss = loss_func(torch.zeros(self.shape)) 48 | assert isinstance(loss, torch.Tensor) 49 | 50 | def test_return_type_combined(self): 51 | loss_func = FixedPerceptualAndStyleLoss(torch.ones(self.shape), torch.ones(self.shape)) 52 | losses = loss_func(torch.zeros(self.shape)) 53 | assert isinstance(losses, tuple) 54 | assert len(losses) == 2 55 | for loss in losses: 56 | assert isinstance(loss, torch.Tensor) 57 | -------------------------------------------------------------------------------- /tests/test_projecting_functions.py: -------------------------------------------------------------------------------- 1 | import pytest 2 | import torch 3 | from torch import optim 4 | 5 | from latent_projecting import naive_noise_loss, optimize_noise, LatentPaths, Latents, run_image_reconstruction, \ 6 | run_local_style_transfer 7 | from tests.test_projector import ProjectorTests, possible_devices 8 | 9 | 10 | class TestProjectFunctions(ProjectorTests): 11 | 12 | def get_input_data(self, projector, device): 13 | latents = projector.create_initial_latent_and_noise().to(device) 14 | images = torch.ones(self.shape).to(device) 15 | lambdas = {"l_mse": 1} 16 | loss_func = naive_noise_loss(lambdas) 17 | return latents, images, loss_func 18 | 19 | def test_optimize_noise(self, projector, device): 20 | input_data = self.get_input_data(projector, device) 21 | result = optimize_noise(self.args, projector, *input_data) 22 | 23 | assert isinstance(result, tuple) 24 | assert len(result) == 2 25 | assert isinstance(result[0], LatentPaths) 26 | assert isinstance(result[1], Latents) 27 | 28 | @pytest.mark.parametrize("latent_abort_condition", [None, lambda loss_dict: loss_dict['psnr'] < 100]) 29 | @pytest.mark.parametrize("noise_abort_condition", [None, lambda loss_dict: loss_dict['psnr'] < 100]) 30 | @pytest.mark.parametrize("do_optimize_noise", [True, False]) 31 | def test_run_image_reconstruction(self, projector, device, do_optimize_noise, latent_abort_condition, noise_abort_condition): 32 | latents, images, _ = self.get_input_data(projector, device) 33 | result = run_image_reconstruction( 34 | self.args, 35 | projector, 36 | latents, 37 | images, 38 | do_optimize_noise=do_optimize_noise, 39 | latent_abort_condition=latent_abort_condition, 40 | noise_abort_condition=noise_abort_condition 41 | ) 42 | 43 | assert isinstance(result, tuple) 44 | assert len(result) == 2 45 | assert isinstance(result[0], LatentPaths) 46 | assert isinstance(result[1], Latents) 47 | 48 | def test_run_local_style_transfer(self, projector, device): 49 | latents = projector.create_initial_latent_and_noise().to(device) 50 | mask_shape = (1, 1,) + self.shape[-2:] 51 | 52 | self.args.style_latent_step = 5 53 | self.args.style_lr_rampdown = 0 54 | 
self.args.style_lr_rampup = 0 55 | 56 | self.args.style_noise_step = 5 57 | self.args.noise_style_lr_rampdown = 0 58 | self.args.noise_style_lr_rampup = 0 59 | 60 | result = run_local_style_transfer( 61 | self.args, 62 | projector, 63 | latents, 64 | torch.randn(self.shape).to(device), 65 | torch.randn(self.shape).to(device), 66 | torch.randn(mask_shape).to(device), 67 | ) 68 | 69 | assert isinstance(result, tuple) 70 | assert len(result) == 2 71 | assert isinstance(result[0], LatentPaths) 72 | assert isinstance(result[1], Latents) 73 | -------------------------------------------------------------------------------- /tests/test_projector.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import tempfile 3 | from pathlib import Path 4 | from shutil import which 5 | 6 | import numpy as np 7 | import pytest 8 | import torch 9 | from PIL import Image 10 | from torch import optim 11 | 12 | from latent_projecting import Latents, naive_noise_loss, LatentPaths 13 | from latent_projecting.projector import Projector 14 | from pytorch_training.data import Compose 15 | from utils.command_line_args import add_default_args_for_projecting 16 | 17 | 18 | possible_devices = ['cpu', 'cuda'] 19 | 20 | 21 | @pytest.mark.parametrize("stylegan_variant", ["1", "2"]) 22 | class ProjectorTests: 23 | 24 | @pytest.fixture(autouse=True) 25 | def set_up(self, stylegan_variant): 26 | parser = argparse.ArgumentParser() 27 | parser = add_default_args_for_projecting(parser) 28 | 29 | self.args = parser.parse_args(args=[]) 30 | self.args.latent_step = 5 31 | self.args.noise_step = 5 32 | self.shape = 1, 3, 256, 256 33 | self.args.config = Path(__file__).parent / "testdata" / f"config_stylegan_{stylegan_variant}.json" 34 | 35 | @pytest.fixture(params=possible_devices) 36 | def device(self, request): 37 | if request.param == 'cuda' and not torch.cuda.is_available(): 38 | pytest.skip("can not run cuda tests, since no GPU is available") 39 | return request.param 40 | 41 | @pytest.fixture(autouse=True) 42 | def maybe_skip(self, stylegan_variant, device): 43 | if stylegan_variant == '2' and device == 'cpu': 44 | pytest.skip("Stylegan2 does not work on CPU!") 45 | 46 | @pytest.fixture 47 | def projector(self, stylegan_variant, device): 48 | self.args.device = device 49 | return Projector(self.args) 50 | 51 | 52 | class TestProjector(ProjectorTests): 53 | 54 | def test_config(self, stylegan_variant, projector): 55 | config = projector.config 56 | assert config['stylegan_variant'] == int(stylegan_variant) 57 | for key in ['image_size', 'latent_size', 'stylegan_checkpoint']: 58 | assert config.get(key, None) is not None 59 | 60 | def test_config_without_input(self, stylegan_variant): 61 | self.args.config = None 62 | with pytest.raises(RuntimeError): 63 | Projector(self.args) 64 | 65 | def test_get_blur_transform(self, stylegan_variant, projector): 66 | transform = projector.get_blur_transform() 67 | assert len(transform) == 3 68 | 69 | transform = projector.get_blur_transform(from_tensor=False) 70 | assert len(transform) == 1 71 | 72 | def test_get_transforms(self, projector): 73 | transform = projector.get_transforms() 74 | assert len(transform.transforms) == 3 75 | assert isinstance(transform, Compose) 76 | 77 | image = Image.new('RGB', (512, 512), 'black') 78 | transformed = transform(image) 79 | 80 | assert transformed.min() == -1 81 | assert transformed.shape == (3, projector.config['image_size'], projector.config['image_size']) 82 | 83 | image = Image.new('RGB', (128, 
128), 'white') 84 | transformed = transform(image) 85 | 86 | assert transformed.max() == 1 87 | assert transformed.shape == (3, projector.config['image_size'], projector.config['image_size']) 88 | 89 | def test_get_mask_transform(self, projector): 90 | def transform_image(max_val, transformation): 91 | image = np.random.random((224, 224, 3)) 92 | image[0:20, 0:20, :] = max_val 93 | image[20:40, 20:40, :] = 0 94 | image = (image * 255).astype('uint8') 95 | image = Image.fromarray(image, 'RGB') 96 | return transformation(image) 97 | 98 | transform = projector.get_mask_transform() 99 | assert len(transform.transforms) == 6 100 | transformed = transform_image(1, transform) 101 | 102 | assert transformed.min() == 0 103 | assert transformed.max() == 1 104 | assert transformed.shape == (1, projector.config['image_size'], projector.config['image_size']) 105 | 106 | transform = projector.get_mask_transform(invert_mask=True) 107 | assert len(transform.transforms) == 7 108 | transformed = transform_image(1, transform) 109 | 110 | assert transformed[0, 0, 0] == 0 111 | assert transformed[0, 30, 30] == 1 112 | assert transformed.shape == (1, projector.config['image_size'], projector.config['image_size']) 113 | 114 | transform = projector.get_mask_transform(mask_multiplier=0.9) 115 | assert len(transform.transforms) == 7 116 | transformed = transform_image(0.9, transform) 117 | 118 | assert transformed[0, 0, 0] < 1 119 | assert transformed[0, 30, 30] == 0 120 | assert transformed.min() == 0 121 | assert transformed.shape == (1, projector.config['image_size'], projector.config['image_size']) 122 | 123 | def test_get_mean_latent(self, projector): 124 | mean_latent = projector.get_mean_latent(1000) 125 | 126 | assert isinstance(mean_latent, tuple) 127 | assert len(mean_latent) == 2 128 | 129 | assert mean_latent[0].numel() == projector.config['latent_size'] 130 | assert mean_latent[1].numel() == 1 131 | 132 | def test_requires_grad(self, projector): 133 | latents = projector.create_initial_latent_and_noise() 134 | assert latents.latent.requires_grad is False 135 | for noise in latents.noise: 136 | assert noise.requires_grad is False 137 | 138 | projector.set_requires_grad(latents, True) 139 | assert latents.latent.requires_grad is True 140 | for noise in latents.noise: 141 | assert noise.requires_grad is False 142 | 143 | projector.set_requires_grad(latents, False) 144 | assert latents.latent.requires_grad is False 145 | for noise in latents.noise: 146 | assert noise.requires_grad is True 147 | 148 | def test_generate(self, projector, device): 149 | latents = projector.create_initial_latent_and_noise() 150 | generated = projector.generate(latents)[0] 151 | 152 | assert isinstance(generated, torch.Tensor) 153 | assert generated.shape[-2:] == (projector.config['image_size'], projector.config['image_size']) 154 | 155 | def run_project(self, projector, device): 156 | latents = projector.create_initial_latent_and_noise().to(device) 157 | images = torch.ones(self.shape).to(device) 158 | optimizer = optim.Adam([latents.latent], lr=projector.args.lr) 159 | lambdas = {"l_mse": 1} 160 | loss_func = naive_noise_loss(lambdas) 161 | 162 | return projector.project(latents, images, optimizer, 5, loss_func) 163 | 164 | def test_project(self, projector, device): 165 | result = self.run_project(projector, device) 166 | assert isinstance(result, tuple) 167 | assert len(result) == 2 168 | 169 | assert isinstance(result[0], LatentPaths) 170 | assert isinstance(result[1], Latents) 171 | 172 | assert result[1].latent.shape[2] == 
projector.config['latent_size'] 173 | 174 | @pytest.mark.skipif(which('convert') is None, reason="Convert not installed on system and necessary for gif creation") 175 | def test_create_gif(self, projector, device): 176 | latent_paths, _ = self.run_project(projector, device) 177 | with tempfile.TemporaryDirectory() as temp_dir: 178 | temp_dir = Path(temp_dir) 179 | file_name = 'test' 180 | projector.create_gif(latent_paths, file_name, temp_dir) 181 | assert (temp_dir / 'gifs' / f"{file_name}.gif").exists() 182 | 183 | def test_render_log(self, projector, device): 184 | latent_paths, _ = self.run_project(projector, device) 185 | with tempfile.TemporaryDirectory() as temp_dir: 186 | temp_dir = Path(temp_dir) 187 | base_name = 'test' 188 | projector.render_log(temp_dir, base_name) 189 | 190 | assert (temp_dir / 'log' / f"{base_name}_log.json").exists() 191 | 192 | possible_keys = projector.log[0].keys() 193 | for key in possible_keys: 194 | assert (temp_dir / 'log' / f"{base_name}_{key}.png").exists() 195 | -------------------------------------------------------------------------------- /tests/test_style_transfer.py: -------------------------------------------------------------------------------- 1 | import tempfile 2 | from pathlib import Path 3 | from unittest.mock import patch 4 | 5 | import pytest 6 | import torch 7 | from PIL import Image 8 | 9 | from latent_projecting import Latents, LatentPaths 10 | from latent_projecting.style_transfer import StyleTransferer 11 | from tests.test_projector import ProjectorTests 12 | 13 | original_load = torch.load 14 | 15 | 16 | def load_patch_stylegan_1(file_name, *args, **kwargs): 17 | file_name = str(file_name) 18 | if 'content' in file_name or 'style' in file_name: 19 | return {'latent': torch.randn(1, 14, 512), 'noise': [torch.randn((1, 1, 2**(i + 2))) for i in range(7)]} 20 | else: 21 | return original_load(file_name) 22 | 23 | 24 | def load_patch_stylegan_2(file_name, *args, **kwargs): 25 | file_name = str(file_name) 26 | if 'content' in file_name or 'style' in file_name: 27 | shape_filler = lambda i: 2**(max(2, i - (i % 2) + 1)) 28 | return {'latent': torch.randn(1, 14, 512), 'noise': [torch.randn((1, 1, shape_filler(i), shape_filler(i))) for i in range(1, 14)]} 29 | else: 30 | return original_load(file_name) 31 | 32 | 33 | class TestStyleTransfer(ProjectorTests): 34 | 35 | @pytest.fixture() 36 | def transferer(self, device): 37 | self.args.device = device 38 | return StyleTransferer(self.args) 39 | 40 | def test_embed_image(self, transferer, device, stylegan_variant): 41 | image = Image.new("RGB", (256, 256), "black") 42 | 43 | with tempfile.TemporaryDirectory() as temp_dir: 44 | temp_dir = Path(temp_dir) 45 | image_name = temp_dir / "image.png" 46 | image.save(image_name) 47 | 48 | latents = transferer.embed_image(image_name) 49 | 50 | assert isinstance(latents, Latents) 51 | 52 | assert latents.latent.shape == (1, 14, 512) 53 | if stylegan_variant == "1": 54 | assert len(latents.noise) == 7 55 | else: 56 | assert len(latents.noise) == 13 57 | 58 | @pytest.mark.parametrize('style_image', [True, False]) 59 | @pytest.mark.parametrize('content_image', [True, False]) 60 | @patch('latent_projecting.style_transfer.torch.load') 61 | def test_get_latents(self, load, transferer, device, content_image, style_image, stylegan_variant): 62 | if stylegan_variant == '1': 63 | load.side_effect = load_patch_stylegan_1 64 | else: 65 | load.side_effect = load_patch_stylegan_2 66 | 67 | blank_image = Image.new("RGB", (256, 256), 'black') 68 | 69 | with 
tempfile.TemporaryDirectory() as temp_dir: 70 | temp_dir = Path(temp_dir) 71 | 72 | if content_image: 73 | content_name = temp_dir / 'content.png' 74 | blank_image.save(content_name) 75 | else: 76 | content_name = temp_dir / 'content.pth' 77 | 78 | if style_image: 79 | style_name = temp_dir / 'style.png' 80 | blank_image.save(style_name) 81 | else: 82 | style_name = temp_dir / 'style.pth' 83 | 84 | content_latent, style_latent = transferer.get_latents(content_name, style_name) 85 | 86 | assert isinstance(content_latent, Latents) 87 | assert isinstance(style_latent, Latents) 88 | 89 | for latent in [content_latent, style_latent]: 90 | assert latent.latent.shape == (1, 14, 512) 91 | if stylegan_variant == '1': 92 | assert len(latent.noise) == 7 93 | else: 94 | assert len(latent.noise) == 13 95 | 96 | assert device in str(latent.latent.device) 97 | for noise in latent.noise: 98 | assert device in str(noise.device) 99 | 100 | def test_post_noise_optimize(self, transferer): 101 | content_latent = transferer.projector.create_initial_latent_and_noise() 102 | transfer_latent = transferer.projector.create_initial_latent_and_noise() 103 | 104 | result = transferer.post_noise_optimize(content_latent, transfer_latent) 105 | 106 | assert isinstance(result, tuple) 107 | assert len(result) == 2 108 | assert isinstance(result[0], LatentPaths) 109 | assert isinstance(result[1], Latents) 110 | 111 | @pytest.mark.parametrize('post_optimize', [True, False]) 112 | def test_do_style_transfer(self, post_optimize, device): 113 | self.args.device = device 114 | self.args.post_optimize = post_optimize 115 | transferer = StyleTransferer(self.args) 116 | 117 | content_latent = transferer.projector.create_initial_latent_and_noise() 118 | transfer_latent = transferer.projector.create_initial_latent_and_noise() 119 | 120 | results = [transferer.do_style_transfer(content_latent, transfer_latent, i) for i in range(14)] 121 | 122 | for result in results: 123 | assert isinstance(result, tuple) 124 | assert len(result) == 2 125 | assert isinstance(result[0], torch.Tensor) 126 | if post_optimize: 127 | assert isinstance(result[1], LatentPaths) 128 | else: 129 | assert result[1] is None 130 | -------------------------------------------------------------------------------- /tests/testdata/config_stylegan_1.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_save_iter": 500, 3 | "image_display_iter": 500, 4 | "display_size": 16, 5 | "snapshot_save_iter": 10000, 6 | "log_iter": 10, 7 | "max_iter": 1000000, 8 | "batch_size": 8, 9 | "weight_decay": 0.0001, 10 | "beta1": 0.5, 11 | "beta2": 0.999, 12 | "init": "kaiming", 13 | "lr": 0.0001, 14 | "lr_policy": "step", 15 | "step_size": 100000, 16 | "gamma": 0.5, 17 | "loss_weights": { 18 | "reconstruction": 0.1, 19 | "discriminator": 1 20 | }, 21 | "regularization": { 22 | "d_interval": 16, 23 | "r1_weight": 10 24 | }, 25 | "latent_size": 512, 26 | "input_dim": 3, 27 | "num_workers": 0, 28 | "image_size": 256, 29 | "config": "configs/autoencoder.yaml", 30 | "stylegan_checkpoint": "/home/christian/workspace/pycharm-upload/wpi/semantic_segmentation/style-based-gan-pytorch/checkpoint_256/820000.model", 31 | "images": "/mnt/ssd1/christian/wpi/segmentation/generation_with_gan/without_vollard_no_border/images.json", 32 | "stylegan_variant": 1, 33 | "absolute": false, 34 | "device": "cuda", 35 | "log_dir": "logs/wpi_only_autoencoder/unet_encoder_lr_shift/2020-05-20T11:19:17.498041", 36 | "log_name": "unet_encoder_lr_shift" 37 | } 38 | 
-------------------------------------------------------------------------------- /tests/testdata/config_stylegan_2.json: -------------------------------------------------------------------------------- 1 | { 2 | "image_save_iter": 500, 3 | "image_display_iter": 500, 4 | "display_size": 16, 5 | "snapshot_save_iter": 10000, 6 | "log_iter": 10, 7 | "max_iter": 1000000, 8 | "batch_size": 8, 9 | "weight_decay": 0.0001, 10 | "beta1": 0.5, 11 | "beta2": 0.999, 12 | "init": "kaiming", 13 | "lr": 0.0001, 14 | "lr_policy": "step", 15 | "step_size": 100000, 16 | "gamma": 0.5, 17 | "loss_weights": { 18 | "reconstruction": 0.1, 19 | "discriminator": 1 20 | }, 21 | "regularization": { 22 | "d_interval": 16, 23 | "r1_weight": 10 24 | }, 25 | "latent_size": 512, 26 | "input_dim": 3, 27 | "num_workers": 0, 28 | "image_size": 256, 29 | "config": "configs/autoencoder.yaml", 30 | "stylegan_checkpoint": "/home/christian/workspace/pycharm-upload/wpi/stylegan-pytorch/logs/2020-05-08T13:07:44.825280_latent_512_no_mirror/checkpoint/250000.pt", 31 | "images": "/mnt/ssd1/christian/wpi/segmentation/generation_with_gan/without_vollard_no_border/images.json", 32 | "stylegan_variant": 2, 33 | "absolute": false, 34 | "device": "cuda", 35 | "log_dir": "logs/wpi_only_autoencoder/unet_encoder_lr_shift/2020-05-20T11:19:17.498041", 36 | "log_name": "unet_encoder_lr_shift" 37 | } 38 | -------------------------------------------------------------------------------- /updater/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/updater/__init__.py -------------------------------------------------------------------------------- /updater/autoencoder_discriminator_updater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | from torch.nn.parallel import DistributedDataParallel 5 | 6 | from pytorch_training.reporter import get_current_reporter 7 | from pytorch_training.updater import UpdateDisabler, GradientApplier 8 | 9 | from updater.autoencoder_updater import AutoencoderUpdater 10 | 11 | 12 | class AutoencoderDiscriminatorUpdater(AutoencoderUpdater): 13 | 14 | def __init__(self, *args, **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.regularization_settings = { 17 | "d_interval": 16, 18 | "r1_weight": 10, 19 | } 20 | 21 | def get_discriminator(self) -> nn.Module: 22 | discriminator = self.networks['discriminator'] 23 | if isinstance(discriminator, DistributedDataParallel): 24 | discriminator_module = discriminator.module 25 | else: 26 | discriminator_module = discriminator 27 | return discriminator_module 28 | 29 | def update_core(self): 30 | reporter = get_current_reporter() 31 | image_batch = next(self.iterators['images']) 32 | image_batch = {k: v.to(self.device) for k, v in image_batch.items()} 33 | 34 | discriminator_observations = self.update_discriminator( 35 | image_batch['input_image'].clone().detach(), 36 | image_batch['output_image'].clone().detach(), 37 | ) 38 | reporter.add_observation(discriminator_observations, 'discriminator') 39 | 40 | generator_observations = self.update_generator( 41 | image_batch['input_image'].clone().detach(), 42 | image_batch['output_image'].clone().detach(), 43 | ) 44 | reporter.add_observation(generator_observations, 'generator') 45 | 46 | def update_discriminator(self, input_images: torch.Tensor, output_images: 
torch.Tensor) -> dict: 47 | autoencoder = self.get_autoencoder() 48 | discriminator = self.get_discriminator() 49 | discriminator_optimizer = self.optimizers['discriminator'] 50 | 51 | with UpdateDisabler(autoencoder), GradientApplier([discriminator], [discriminator_optimizer]): 52 | reconstructed_image = autoencoder(input_images) 53 | fake_prediction = discriminator(reconstructed_image) 54 | fake_loss = F.softplus(fake_prediction).mean() 55 | fake_loss.backward() 56 | 57 | real_prediction = discriminator(output_images.detach()) 58 | real_loss = F.softplus(-real_prediction).mean() 59 | real_loss.backward() 60 | 61 | discriminator_loss = real_loss.detach() + fake_loss.detach() 62 | 63 | loss_data = { 64 | 'loss': discriminator_loss, 65 | 'real_score': real_prediction.mean(), 66 | 'fake_score': fake_prediction.mean() 67 | } 68 | 69 | if self.iteration % self.regularization_settings['d_interval'] == 0: 70 | output_images.requires_grad = True  # the R1 penalty needs gradients w.r.t. the real images 71 | real_prediction = discriminator(output_images) 72 | grad_of_reference_image, = torch.autograd.grad(outputs=real_prediction.sum(), inputs=output_images, create_graph=True) 73 | gradient_penalty = grad_of_reference_image.pow(2).view(grad_of_reference_image.shape[0], -1).sum(1).mean() 74 | 75 | discriminator.zero_grad() 76 | (self.regularization_settings['r1_weight'] / 2 * gradient_penalty * self.regularization_settings['d_interval'] + 0 * real_prediction[0]).backward() 77 | discriminator_optimizer.step() 78 | 79 | loss_data['gradient_penalty'] = self.regularization_settings['r1_weight'] / 2 * gradient_penalty.detach().cpu() * self.regularization_settings['d_interval'] 80 | 81 | torch.cuda.empty_cache() 82 | 83 | return loss_data 84 | 85 | def update_generator(self, input_images: torch.Tensor, output_images: torch.Tensor) -> dict: 86 | autoencoder = self.get_autoencoder() 87 | discriminator = self.get_discriminator() 88 | 89 | reporter = get_current_reporter() 90 | 91 | autoencoder_optimizer = self.optimizers['main'] 92 | log_data = {} 93 | 94 | with UpdateDisabler(autoencoder.decoder), GradientApplier([autoencoder], [autoencoder_optimizer]): 95 | reconstructed_images = autoencoder(input_images) 96 | 97 | mse_loss = F.mse_loss(output_images, reconstructed_images, reduction='none') 98 | loss = mse_loss.mean(dim=(1, 2, 3)).sum() 99 | reporter.add_observation({"reconstruction_loss": loss}, prefix='loss') 100 | if self.use_perceptual_loss: 101 | perceptual_loss = self.perceptual_loss(reconstructed_images, output_images).sum() 102 | loss += perceptual_loss 103 | reporter.add_observation( 104 | {"autoencoder_loss": loss, "perceptual_loss": perceptual_loss}, 105 | prefix='loss' 106 | ) 107 | 108 | discriminator_prediction = discriminator(reconstructed_images) 109 | discriminator_loss = F.softplus(-discriminator_prediction).mean() 110 | 111 | loss += discriminator_loss 112 | loss.backward() 113 | 114 | log_data.update({ 115 | "loss": loss, 116 | "discriminator_loss": discriminator_loss, 117 | }) 118 | torch.cuda.empty_cache() 119 | 120 | return log_data 121 | 122 | -------------------------------------------------------------------------------- /updater/autoencoder_updater.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.nn import functional as F 4 | from torch.nn.parallel import DistributedDataParallel 5 | 6 | from losses.lpips import PerceptualLoss 7 | from pytorch_training import Updater 8 | from pytorch_training.reporter import get_current_reporter 9 | from pytorch_training.updater import 
UpdateDisabler, GradientApplier 10 | 11 | 12 | class AutoencoderUpdater(Updater): 13 | 14 | def __init__(self, *args, use_perceptual_loss: bool = True, disable_update_for: str = 'none', **kwargs): 15 | super().__init__(*args, **kwargs) 16 | self.perceptual_loss = PerceptualLoss(model='net-lin', net='vgg', use_gpu=True, gpu_ids=[self.device]) 17 | self.use_perceptual_loss = use_perceptual_loss 18 | self.disable_update(disable_update_for) 19 | 20 | def get_autoencoder(self) -> nn.Module: 21 | autoencoder = self.networks['autoencoder'] 22 | if isinstance(autoencoder, DistributedDataParallel): 23 | autoencoder_module = autoencoder.module 24 | else: 25 | autoencoder_module = autoencoder 26 | return autoencoder_module 27 | 28 | def disable_update(self, disable_update_for: str): 29 | if disable_update_for == 'none': 30 | return 31 | 32 | autoencoder = self.get_autoencoder() 33 | 34 | disable_noise = disable_update_for == 'noise' 35 | for name, parameter in autoencoder.encoder.named_parameters(): 36 | if 'noise' in name: 37 | parameter.requires_grad = not disable_noise 38 | else: 39 | parameter.requires_grad = disable_noise 40 | 41 | if disable_noise: 42 | autoencoder.use_generated_noise = False 43 | 44 | def calculate_loss(self, input_images: torch.Tensor, reconstructed_images: torch.Tensor): 45 | reporter = get_current_reporter() 46 | 47 | mse_loss = F.mse_loss(input_images, reconstructed_images, reduction='none') 48 | loss = mse_loss.mean(dim=(1, 2, 3)).sum() 49 | reporter.add_observation({"reconstruction_loss": loss}, prefix='loss') 50 | if self.use_perceptual_loss: 51 | perceptual_loss = self.perceptual_loss(reconstructed_images, input_images).sum() 52 | reporter.add_observation({"perceptual_loss": perceptual_loss}, prefix='loss') 53 | loss += perceptual_loss 54 | 55 | loss.backward() 56 | reporter.add_observation({"autoencoder_loss": loss}, prefix='loss') 57 | 58 | def update_core(self): 59 | autoencoder = self.get_autoencoder() 60 | 61 | with UpdateDisabler(autoencoder.decoder), GradientApplier([autoencoder], [self.optimizers['main']]): 62 | image_batch = next(self.iterators['images']) 63 | image_batch = {k: v.to(self.device) for k, v in image_batch.items()} 64 | 65 | reconstructed_images = autoencoder(image_batch['input_image']) 66 | 67 | self.calculate_loss(image_batch['output_image'], reconstructed_images) 68 | 69 | -------------------------------------------------------------------------------- /utils/StyleImagePlotter.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | 5 | from pytorch_training.extensions import ImagePlotter 6 | 7 | 8 | class StyleImagePlotter(ImagePlotter): 9 | 10 | def __init__(self, *args, style_images: list = None, **kwargs): 11 | super().__init__(*args, **kwargs) 12 | assert style_images is not None, "You have to supply style images in order to use StyleImagePlotter" 13 | 14 | self.style_images = torch.stack(style_images).cuda() 15 | 16 | def get_predictions(self) -> List[torch.Tensor]: 17 | assert len(self.networks) == 2, f"StyleImagePlotter assumes that there are two networks for plotting, but there is/are {len(self.networks)}" 18 | 19 | predictions = [self.input_images, self.style_images] 20 | generated_images = self.networks[0](self.input_images, self.style_images) 21 | predictions.append(generated_images) 22 | 23 | reconstructed_images = self.networks[1](generated_images) 24 | predictions.append(reconstructed_images) 25 | return predictions 26 | 
-------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Bartzi/one-model-to-reconstruct-them-all/7c099a310e1e206be65d283fa2012bc8afa1a388/utils/__init__.py -------------------------------------------------------------------------------- /utils/clean_runs.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | function usage_and_exit() { 4 | echo "usage: clean_runs.sh [-y] experiment_folder [experiment_folder...]" 5 | echo "" 6 | echo "Options:" 7 | echo " -y : confirm yes to deleting old models (you still need to confirm deletion of empty runs)" 8 | exit 1 9 | } 10 | 11 | function check_yes() { 12 | read -p "$1 [y/N]? " REPLY 13 | echo "" 14 | if [[ ! "$REPLY" =~ ^[Yy]$ ]]; then 15 | return 1 16 | fi 17 | return 0 18 | } 19 | 20 | function clean_checkpoints() 21 | { 22 | if ! [ -d "code/wandb" ]; then 23 | echo "Removing code/wandb..." 24 | rm -rf "code/wandb" 25 | fi 26 | 27 | if ! [ -d "checkpoints" ] || [ "$(ls "checkpoints" | wc -l)" = "0" ]; then 28 | echo "The run does not seem to have any checkpoints. Scanning its size..." 29 | echo "The folder contains: $(du -sh .)" 30 | if check_yes "> Delete it COMPLETELY?"; then 31 | local current_folder="$(basename "$(pwd)")" 32 | cd .. 33 | rm -rf "${current_folder}" 34 | fi 35 | return 36 | fi 37 | 38 | # categorize all models 39 | local all_models=$( cd "checkpoints" && ls ) 40 | local last_snapshot=$( tail -n -1 <<< "${all_models}" | tr '\n' ' ' ) 41 | 42 | if [ "${last_snapshot}" = "100000.pt" ] || [ "$((10#${last_snapshot/.pt/}+0))" -le 100000 ]; then 43 | local old_models=$( head -n -1 <<< "${all_models}" | tr '\n' ' ' ) 44 | local keep_models="${last_snapshot}" 45 | else 46 | local keep_models="100000.pt ${last_snapshot}" 47 | local old_models=$( grep -P -v 100000.pt <<< "${all_models}" | grep -P -v ${last_snapshot} | tr '\n' ' ' ) 48 | fi 49 | 50 | old_models=$( echo "${old_models}" | xargs ) 51 | if [ ! -z "${old_models}" ]; then 52 | echo "> Deleting the following obsolete model files:" 53 | echo " ${old_models}" 54 | echo "> The following model files are kept:" 55 | echo " ${keep_models}" 56 | if [ "${confirm}" = false ] || check_yes "> Delete the files listed under deletion?"; then 57 | ( cd "checkpoints" && rm -f -- ${old_models} ) 58 | fi 59 | fi 60 | } 61 | 62 | confirm=true 63 | 64 | for arg in "$@"; do 65 | if [ "${arg}" = "-y" ]; then 66 | confirm=false 67 | continue 68 | fi 69 | done 70 | 71 | for argument in "$@"; do 72 | if [ "${argument}" = "-y" ]; then 73 | continue 74 | fi 75 | ( 76 | echo "Processing ${argument}..." 
77 | cd "${argument}" 78 | clean_checkpoints 79 | ) 80 | done 81 | -------------------------------------------------------------------------------- /utils/command_line_args.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | 4 | def add_default_args_for_projecting(parser: ArgumentParser) -> ArgumentParser: 5 | parser.add_argument('--ckpt', type=str) 6 | parser.add_argument('--size', type=int, default=256) 7 | parser.add_argument('--lr-rampup', type=float, default=0.0) 8 | parser.add_argument('--lr-rampdown', type=float, default=0.0) 9 | parser.add_argument('--lr', type=float, default=0.01) 10 | parser.add_argument('--noise-lr', type=float, default=5) 11 | parser.add_argument('--noise-lr-rampup', type=float, default=0.0) 12 | parser.add_argument('--noise-lr-rampdown', type=float, default=0.0) 13 | parser.add_argument('--latent-step', type=int, default=5000) 14 | parser.add_argument('--noise-step', type=int, default=3000) 15 | parser.add_argument('--mse', type=float, default=1) 16 | parser.add_argument('--no-w-plus', dest='w_plus', action='store_false', default=True) 17 | parser.add_argument('-b', '--batch-size', type=int, default=16, help='batch size for projecting') 18 | parser.add_argument('--no-mean-latent', action='store_true', 19 | help="use pure random latent for start") 20 | parser.add_argument('--device', default='cuda', help="which device to use") 21 | parser.add_argument('--config', 22 | help='path to a config file if the ckpt is not saved in a location that also contains config info for the train run') 23 | parser.add_argument('--debug-step', type=int, default=50, 24 | help='number of iterations after which to save a debug image') 25 | 26 | return parser 27 | -------------------------------------------------------------------------------- /utils/config.py: -------------------------------------------------------------------------------- 1 | import json 2 | from pathlib import Path 3 | 4 | import yaml 5 | 6 | 7 | def load_config(checkpoint_path: str = None, config: str = None) -> dict: 8 | if checkpoint_path is None and config is None: 9 | raise RuntimeError("You have to supply either checkpoint path or path to a config file!") 10 | 11 | if config is not None: 12 | with open(config) as f: 13 | config = json.load(f) 14 | if checkpoint_path is not None: 15 | config['stylegan_checkpoint'] = checkpoint_path 16 | assert config.get('stylegan_checkpoint', None) is not None 17 | 18 | return config 19 | 20 | config_dir = Path(checkpoint_path).parent.parent / 'config' 21 | original_config = config_dir / 'config.json' 22 | with open(original_config) as f: 23 | original_config = json.load(f) 24 | 25 | original_args = config_dir / 'args.json' 26 | with open(original_args) as f: 27 | original_args = json.load(f) 28 | 29 | original_config.update(original_args) 30 | 31 | return original_config 32 | 33 | 34 | def load_yaml_config(config_path): 35 | with open(config_path) as f: 36 | return yaml.safe_load(f) 37 | -------------------------------------------------------------------------------- /utils/convert_autoencoder_checkpoint.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import torch 5 | 6 | 7 | def convert_autoencoder_checkpoint(checkpoint): 8 | encoder_weights = {} 9 | decoder_weights = {} 10 | autoencoder_weights = {} 11 | 12 | if all(key in checkpoint for key in ['encoder', 'decoder']): 13 | # already converted, no 
need for further conversion 14 | return checkpoint 15 | 16 | for name, weight in checkpoint['autoencoder'].items(): 17 | name = name.split('.') 18 | 19 | if name[0] == 'module': 20 | name = name[1:] 21 | 22 | for name_part in ['encoder', 'decoder']: 23 | name_part_in_name = [n == name_part for n in name] 24 | if any(name_part_in_name): 25 | new_name = '.'.join(name[name_part_in_name.index(True) + 1:]) 26 | eval(f"{name_part}_weights")[new_name] = weight 27 | break 28 | 29 | autoencoder_weights['.'.join(name)] =weight 30 | 31 | checkpoint['autoencoder'] = autoencoder_weights 32 | checkpoint['encoder'] = encoder_weights 33 | checkpoint['decoder'] = decoder_weights 34 | 35 | return checkpoint 36 | 37 | 38 | if __name__ == "__main__": 39 | parser = argparse.ArgumentParser(description="convert autoencoder checkpoint without decoder and encoder keys") 40 | parser.add_argument('checkpoint', help='path to checkpoint to convert') 41 | parser.add_argument('--destination', help='if you want to save it in another file (relative to original checkpoint)') 42 | 43 | args = parser.parse_args() 44 | 45 | checkpoint = torch.load(args.checkpoint) 46 | checkpoint = convert_autoencoder_checkpoint(checkpoint) 47 | 48 | dest_file_name = Path(args.checkpoint) 49 | if args.destination is not None: 50 | dest_file_name = dest_file_name.parent / args.destination 51 | 52 | torch.save( 53 | checkpoint, 54 | str(dest_file_name) 55 | ) 56 | -------------------------------------------------------------------------------- /utils/create_denoising_eval_set.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | from pathlib import Path 4 | 5 | 6 | if __name__ == "__main__": 7 | parser = argparse.ArgumentParser(description="Take CBSD68 dir and create multiple evaluation dataset") 8 | parser.add_argument("cbs_path", help="path to root dir of cbs68 dataset") 9 | 10 | args = parser.parse_args() 11 | 12 | cbs_root = Path(args.cbs_path) 13 | original_image_dir = cbs_root / 'original_png' 14 | noisy_image_dirs = cbs_root.glob('noisy*') 15 | 16 | original_pngs = list(sorted(original_image_dir.glob("*.png"), key=lambda x: int(x.stem))) 17 | for noisy_image_dir in noisy_image_dirs: 18 | if not noisy_image_dir.is_dir(): 19 | continue 20 | 21 | noisy_pngs = list(sorted(noisy_image_dir.glob("*.png"), key=lambda x: int(x.stem))) 22 | 23 | assert len(original_pngs) == len(noisy_pngs), f"number of original and noisy images is not equal!!, {len(original_pngs)} vs {len(noisy_pngs)}" 24 | 25 | gt = [ 26 | {'original': str(original_png), 'noisy': str(noisy_png)} 27 | for original_png, noisy_png in zip(original_pngs, noisy_pngs) 28 | ] 29 | 30 | gt_file_name = cbs_root / f"{noisy_image_dir.parts[-1]}.json" 31 | 32 | with gt_file_name.open('w') as f: 33 | json.dump(gt, f, indent='\t') 34 | -------------------------------------------------------------------------------- /utils/data_loading.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import Union, Dict, Iterable, Type 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import DataLoader, DistributedSampler 8 | from torchvision import transforms 9 | 10 | from data.autoencoder_dataset import AutoencoderDataset 11 | from latent_projecting import Latents 12 | from networks import StyleganAutoencoder 13 | from pytorch_training.data.utils import default_loader 14 | from pytorch_training.distributed import 
get_world_size, get_rank 15 | 16 | 17 | def resilient_loader(path): 18 | try: 19 | return default_loader(path) 20 | except Exception: 21 | return Image.new('RGB', (256, 256)) 22 | 23 | 24 | def build_data_loader(image_path: Union[str, Path], config: dict, uses_absolute_paths: bool, shuffle_off: bool = False, dataset_class: Type[AutoencoderDataset] = AutoencoderDataset) -> DataLoader: 25 | transform_list = [ 26 | transforms.Resize((config['image_size'], config['image_size'])), 27 | transforms.ToTensor(), 28 | transforms.Normalize((0.5,) * config['input_dim'], (0.5,) * config['input_dim']) 29 | ] 30 | transform_list = transforms.Compose(transform_list) 31 | 32 | dataset = dataset_class( 33 | image_path, 34 | root=os.path.dirname(image_path) if not uses_absolute_paths else None, 35 | transforms=transform_list, 36 | loader=resilient_loader, 37 | ) 38 | 39 | sampler = None 40 | if get_world_size() > 1: 41 | sampler = DistributedSampler(dataset, shuffle=not shuffle_off) 42 | sampler.set_epoch(get_rank()) 43 | 44 | if shuffle_off: 45 | shuffle = False 46 | else: 47 | shuffle = sampler is None 48 | 49 | loader = DataLoader( 50 | dataset, 51 | config['batch_size'], 52 | shuffle=shuffle, 53 | drop_last=True, 54 | sampler=sampler, 55 | ) 56 | return loader 57 | 58 | 59 | def build_latent_and_noise_generator(autoencoder: StyleganAutoencoder, config: Dict) -> Iterable: 60 | torch.random.manual_seed(1) 61 | while True: 62 | latent_code = torch.randn(config['batch_size'], config['latent_size']) 63 | noise = autoencoder.decoder.make_noise() 64 | yield Latents(latent_code, noise) 65 | -------------------------------------------------------------------------------- /utils/folder_to_json.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import os 4 | from pathlib import Path 5 | 6 | from tqdm import tqdm 7 | 8 | from pytorch_training.images import is_image 9 | 10 | 11 | def get_file_name(root: str) -> str: 12 | for dir, _, files in os.walk(root): 13 | for file_name in files: 14 | if is_image(file_name) and not "embed-test" in dir: 15 | yield os.path.relpath(os.path.join(dir, file_name), start=root) 16 | 17 | 18 | if __name__ == "__main__": 19 | parser = argparse.ArgumentParser(description="Convert files in dir to json for training") 20 | parser.add_argument("dir") 21 | parser.add_argument("--split", action='store_true', default=False, help='create train and val split') 22 | 23 | args = parser.parse_args() 24 | 25 | files = [name for name in tqdm(get_file_name(args.dir))] 26 | 27 | dest_dir = Path(args.dir) 28 | if args.split: 29 | split_index = int(len(files) * 0.9) 30 | train_data = files[:split_index] 31 | val_data = files[split_index:] 32 | 33 | with (dest_dir / 'train.json').open('w') as f: 34 | json.dump(train_data, f, indent='\t') 35 | 36 | with (dest_dir / 'val.json').open('w') as f: 37 | json.dump(val_data, f, indent='\t') 38 | else: 39 | with open(os.path.join(args.dir, "images.json"), 'w') as f: 40 | json.dump(files, f, indent='\t') 41 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageDraw 2 | 3 | 4 | def render_text_on_image(text: str, image: Image) -> Image: 5 | draw = ImageDraw.Draw(image) 6 | 7 | font = draw.getfont() 8 | text_size = draw.textsize(text, font=font) 9 | text_location = (image.width - text_size[0], image.height - text_size[1], image.width, 
image.height) 10 | draw.rectangle(text_location, fill=(255, 255, 255, 128)) 11 | draw.text(text_location[:2], text, font=font, fill=(0, 255, 0)) 12 | 13 | return image 14 | -------------------------------------------------------------------------------- /utils/strip_images_from_big_embedding_file.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | import numpy as np 5 | from tqdm import tqdm 6 | 7 | 8 | if __name__ == "__main__": 9 | parser = argparse.ArgumentParser(description="Take a file with many image embeddings and extract a given number of image embeddings from it") 10 | parser.add_argument("embedding_file", help="path to embedding file") 11 | parser.add_argument("-n", "--num-samples", type=int, default=100, help="number of embeddings to extract") 12 | 13 | args = parser.parse_args() 14 | 15 | embedded_data = np.load(args.embedding_file, mmap_mode='r') 16 | 17 | image_data = {key: embedded_data[key][:args.num_samples] for key in tqdm(list(embedded_data.keys()))} 18 | 19 | embedding_path = Path(args.embedding_file) 20 | embedding_name_parts = embedding_path.stem.split('_') 21 | embedding_name_parts[0] = 'small_embedding' 22 | new_embedding_name = '_'.join(embedding_name_parts) 23 | with (embedding_path.parent / f"{new_embedding_name}.npz").open('wb') as f: 24 | np.savez(f, **image_data) 25 | --------------------------------------------------------------------------------