├── .gitignore ├── .gitmodules ├── README.md ├── assets └── images │ ├── commands.txt │ ├── feature_optimization.png │ ├── image_to_latent_predictions.png │ ├── initial_method.png │ ├── latent_benefits.png │ ├── latent_difference.png │ ├── male_female_actual.png │ ├── male_female_gaussian.png │ ├── male_female_mapping.png │ ├── point_female.png │ ├── point_male.png │ ├── test_01 │ ├── gender │ │ ├── 000_003.jpg │ │ ├── 000_004.jpg │ │ ├── 000_005.jpg │ │ ├── 000_006.jpg │ │ ├── 000_007.jpg │ │ └── test_01_w_to_m.gif │ ├── pose │ │ ├── 000_002.jpg │ │ ├── 000_003.jpg │ │ ├── 000_004.jpg │ │ └── test_01_pose.gif │ ├── test_01.png │ ├── test_01_optimization.gif │ └── test_01_optimized.png │ └── test_02 │ ├── age │ ├── old.jpg │ ├── test_02_age.gif │ └── young.jpg │ ├── gender │ ├── female.jpg │ ├── male.jpg │ └── test_02_gender.gif │ ├── glasses │ ├── 000_010.jpg │ ├── 000_018.jpg │ └── test_02_glasses.gif │ ├── test_02.jpg │ └── test_02.npy ├── encode_image.py ├── models ├── image_to_latent.py ├── latent_optimizer.py └── losses.py ├── train_image_to_latent_model.ipynb └── utilities ├── files.py ├── hooks.py └── images.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | .pt 3 | .npy 4 | .ipynb_checkpoints 5 | image_to_latent.pt 6 | dlatents.npy 7 | notes.txt 8 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "InterFaceGAN"] 2 | path = InterFaceGAN 3 | url = git://github.com/ShenYujun/InterFaceGAN.git 4 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleGAN Encoder - Pytorch Implementation 2 | | Reference Image | Latent Optimization | Gender Transformation | Pose Transformation | 3 | |---|---|---|---| 4 | | | | | | 5 | |Reference Image | Age Transformation | Gender Transformation | Glasses Transformation | 6 | | | | | | 7 | 8 | ## Contents 9 | - [Setup](#setup) 10 | - [Usage](#usage) 11 | - [The Image To Latent Model](#the-image-to-latent-model) 12 | - [Explanations](#explanations) 13 | 14 | ## Setup 15 | 1. Submodule Update 16 | ```bash 17 | git submodule update --init --recursive 18 | ``` 19 | 20 | Download the Image To Latent and StyleGAN models from the release on this repo. Place the Image To Latent model at the root of the directory and place the StyleGAN model at ./InterFaceGAN/models/pretrain/. 21 | 22 | ## Usage 23 | Take an image of a face you'd like to modify and align the face by using an align face script. I'd recommend the one [here](https://github.com/Puzer/stylegan-encoder/blob/master/align_images.py). 24 | 25 | Then find the latents for the aligned face by using the encode_image.py script. 26 | ```bash 27 | python encode_image.py 28 | aligned_image.jpg 29 | dlatents.npy # The filepath to save the latents at. 30 | --save_optimized_image true 31 | ``` 32 | 33 | The script will generate a numpy array file with the latents that can then be passed to the edit.py script located in the InterFaceGAN repo. Edit an image by running the edit.py script. 34 | ```bash 35 | python InterFaceGAN/edit.py 36 | -m stylegan_ffhq 37 | -o results 38 | -b InterFaceGAN/boundaries/stylegan_ffhq_pose_boundary.npy # Use any of the boundaries found in the InterFaceGAN repo. 39 | -i dlatents.npy 40 | -s WP 41 | --steps 20 42 | ``` 43 | The resulting script will modify the latents and correspondingly the aligned face with the boundary that you select (pose in the above example). It will save all of the transformed images in the -o directory (./results in the above example). 44 | 45 | If the resulting image is not to your liking, play around with the vgg16 layer that is used and also modify/add loss functions. 46 | 47 | ### Encode Image 48 | ```bash 49 | python encode_image.py -h 50 | 51 | usage: encode_image.py [-h] [--save_optimized_image SAVE_OPTIMIZED_IMAGE] 52 | [--optimized_image_path OPTIMIZED_IMAGE_PATH] 53 | [--video VIDEO] [--video_path VIDEO_PATH] 54 | [--save_frequency SAVE_FREQUENCY] 55 | [--iterations ITERATIONS] [--model_type MODEL_TYPE] 56 | [--learning_rate LEARNING_RATE] [--vgg_layer VGG_LAYER] 57 | [--use_latent_finder USE_LATENT_FINDER] 58 | [--image_to_latent_path IMAGE_TO_LATENT_PATH] 59 | image_path dlatent_path 60 | 61 | Find the latent space representation of an input image. 62 | 63 | positional arguments: 64 | image_path Filepath of the image to be encoded. 65 | dlatent_path Filepath to save the dlatent (WP) at. 66 | 67 | optional arguments: 68 | -h, --help show this help message and exit 69 | --save_optimized_image SAVE_OPTIMIZED_IMAGE 70 | Whether or not to save the image created with the 71 | optimized latents. 72 | --optimized_image_path OPTIMIZED_IMAGE_PATH 73 | The path to save the image created with the optimized 74 | latents. 75 | --video VIDEO Whether or not to save a video of the encoding 76 | process. 77 | --video_path VIDEO_PATH 78 | Where to save the video at. 79 | --save_frequency SAVE_FREQUENCY 80 | How often to save the images to video. Smaller = 81 | Faster. 82 | --iterations ITERATIONS 83 | Number of optimizations steps. 84 | --model_type MODEL_TYPE 85 | The model to use from InterFaceGAN repo. 86 | --learning_rate LEARNING_RATE 87 | Learning rate for SGD. 88 | --vgg_layer VGG_LAYER 89 | The VGG network layer number to extract features from. 90 | --use_latent_finder USE_LATENT_FINDER 91 | Whether or not to use a latent finder to find the 92 | starting latents to optimize from. 93 | --image_to_latent_path IMAGE_TO_LATENT_PATH 94 | The path to the .pt (Pytorch) latent finder model. 95 | 96 | 97 | ``` 98 | 99 | ## The Image To Latent Model 100 | The process of optimizing the latents with strictly just the features extracted by the VGG16 model can be timely and possibly prone to local minima. To combat this problem, we can use another model thats sole goal is to predict the latents of an image. This gives the latent optimizer model a better initilization point to optimize from and helps reduce the amount of time needed for optimization and the likelyhood of getting stuck in a far away minima. 101 | 102 | Here you can see the the images generated with the predicted latents from the Image To Latent Model. 103 | 104 | 105 | ### Usage 106 | The encode_image.py script by default does not use the Image To Latent model, but you can activate it by specifiying the following args when running encode_image.py. Without using an Image To Latent model the encode_image.py script defaults to optimize latents initialized with all zeros. 107 | ```bash 108 | python encode_image.py 109 | aligned_image.jpg 110 | dlatents.npy 111 | --use_latent_finder true # Activates model. 112 | --image_to_latent_path ./image_to_latent.pt # Specifies path to model. 113 | ``` 114 | 115 | ### Training 116 | All of the training is located in the [train_image_to_latent_model.ipynb notebook](https://github.com/jacobhallberg/pytorch_stylegan_encoder/blob/master/train_image_to_latent_model.ipynb). To generate a dataset use the following command. 117 | ```bash 118 | python InterFaceGAN/generate_data.py 119 | -m stylegan_ffhq 120 | -o dataset_directory 121 | -n 50000 122 | -s WP 123 | ``` 124 | This will populate a directory at ./dataset_directory with 50,000 generated faces and a numpy file called wp.npy. You can then load these into the notebook to train a new model. Using more than 50,000 will train a better latent predictor. 125 | 126 | ## Explanations 127 | 128 | ### What is a StyleGAN? 129 | [StyleGAN](https://github.com/NVlabs/stylegan) is a NVIDIA based work that enables the generation of high-quality images representing the image dataset that it was trained on with the ability to control aspects of the image synthesis. 130 | 131 | ### What are latent codes (latents)? 132 | Typically with generative models the latent code acts as input into the generative model and modifying the latent code modifies the output image. StyleGAN uses latent codes, but applies a non-linear transformation to the input latent codes z, creating a learned latent space W which governs the features of the generated output images. 133 | 134 | If you can control the latent space you can control the features of the generated output image. This is the underlying principal that is used to transform faces in this repo. 135 | 136 | 137 | 138 | ###### Image from [StyleGan Paper](https://arxiv.org/pdf/1812.04948.pdf) figure 1. 139 | 140 | ### What are the benifits of using a mapping network to create a latent space? 141 | In traditional GAN architecture, the input vector z is sampled from a uniform gaussian distribution. The issue with sampling from a uniform guassian distribution and then generating images from the sampled vectors z, is that if the features of your data distrubtion do not follow a guassian distribution the sampled vectors z contain features that never existed in your data distrubtion. This causes the generator to generate images with features never seen in your data distribution. 142 | 143 | | Actual Feature Distribution | Gaussian Feature Distribution | 144 | |---|---| 145 | | | | 146 | 147 | ###### Latent colored blocks from [StyleGan Paper](https://arxiv.org/pdf/1812.04948.pdf) figure 6. 148 | 149 | 150 | For example, above shows the actual feature distribution of some data and the feature distribtuion of data sampled from a uniform gaussian distribution. In the above case, the actual distribution of data does not contain males with long hair, but the sampled vector z from a gaussian distribution will generate images of males with long hair. 151 | 152 | This is where StyleGAN shines. The mapping network doesn't have to map the vectors z into a gaussian distribution because the mapping network is learned through the data itself. Meaning the mapping network is able to produce a latent space W that can better represent the features seen in the data by taking in a uniform distribution of data z and mapping it to a distribution that contains gaps w. 153 | 154 | | Actual Feature Distribution | Mapping Network Feature Distribution | 155 | |---|---| 156 | | | | 157 | 158 | Additionally, with StyleGAN the image creation starts from a constant vector that is optimized during the training process. This constant vector acts as a seed for the GAN and the mapped vectors w are passed into the convolutional layers within the GAN through adaptive instance normalization (AdaIN). This takes away the responsiblity of the GAN having to learn how to warp a uniform distribution into one that represents the data and allows it to simply focus on generating images. All of these aspects together allow for very high quality image generation. 159 | 160 | 161 | ### How do latents (latent space) make it easier to modify an image? 162 | A vector within the latent space W from the mapping network represents a fixed image with fixed features. If you take that vector and shift it across an axis, you modify the features of the image. If modified in soley the direction of a specific feature within the latent space W, everything about the image stays the same besides the feature that the vector (latent) is being shifted towards. 163 | 164 | To make this more clear, imagine a vector that represents a male with short hair within the latent space W. If you'd like to keep the short hair, but generate a female version of a male, all you need to do is shift the vector in the direction of female without changing the direction of the length of hair. 165 | 166 | | Male with Short Hair | Male Transformed To Female with Short Hair | 167 | |---|---| 168 | | | | 169 | 170 | This can be done with any discoverable feature within the latent space. For example, age and gender. 171 | 172 | | Reference Image | Younger | Older | Transformation | 173 | |---|---|---|---| 174 | | | | | | 175 | 176 | | Reference Image | More Feminine | More Masculine | Transformation | 177 | |---|---|---|---| 178 | | | | | | 179 | 180 | What you may notice from these transformations is that features are not completely independent. When changing one feature you often change many other dependent features. 181 | 182 | ### Okay, we have a query image we want to modify. How do we get the latent representation of that query image so that we can modify it? 183 | The first step that you may think of is to just compare a random generated image from the GAN with your query image with a loss function like mean squared error (MSE). Afterwards, use gradient decent to optimize the latent values of the random image until the generated image matches your query image. 184 | 185 | 186 | 187 | The issue with this is that it turns out to be really difficult to optimize from pixel differences between images without a specialised loss function. 188 | 189 | To get around this issue, instead of comparing pixel-wise you can compare feature-wise by extracting the features of both images through a pretrained feature extractor like VGG16 and forgoing the use of the final fully-connected classification layers. Featurewise optimization works much better in practice with simple loss functions like MSE. 190 | 191 | 192 | 193 | What if instead of a random latent vector as a starting point, we could speed up the optimization process by making a really good guess as to what the query image's latent vector is? This is where a machine learning model called the Image To Latent model comes in. I've talked about it briefly [here](#the-image-to-latent-model) 194 | 195 | ### How do we discover features within a latent space to modify latent representations? 196 | Simply put, one can use support vector machines or other classifiers to discover a seperating hyperplane that within the latent space seperates the features of interest from other features. To edit the face you then can take the normal of the hyperplane and travel in the direction of the normal. This in returns modifies the latent code of the query image and in return modifies the generated image. 197 | 198 | This is all talked about [here](https://github.com/ShenYujun/InterFaceGAN). 199 | 200 | The purpose of this reposity is to create the latent space representation of a query image and not the editing as the editing is a relatively simple process. 201 | -------------------------------------------------------------------------------- /assets/images/commands.txt: -------------------------------------------------------------------------------- 1 | convert -delay 50 -loop 0 *.jpg -morph 5 test_01_w_to_m.gif 2 | -------------------------------------------------------------------------------- /assets/images/feature_optimization.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/feature_optimization.png -------------------------------------------------------------------------------- /assets/images/image_to_latent_predictions.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/image_to_latent_predictions.png -------------------------------------------------------------------------------- /assets/images/initial_method.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/initial_method.png -------------------------------------------------------------------------------- /assets/images/latent_benefits.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/latent_benefits.png -------------------------------------------------------------------------------- /assets/images/latent_difference.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/latent_difference.png -------------------------------------------------------------------------------- /assets/images/male_female_actual.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/male_female_actual.png -------------------------------------------------------------------------------- /assets/images/male_female_gaussian.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/male_female_gaussian.png -------------------------------------------------------------------------------- /assets/images/male_female_mapping.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/male_female_mapping.png -------------------------------------------------------------------------------- /assets/images/point_female.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/point_female.png -------------------------------------------------------------------------------- /assets/images/point_male.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/point_male.png -------------------------------------------------------------------------------- /assets/images/test_01/gender/000_003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/000_003.jpg -------------------------------------------------------------------------------- /assets/images/test_01/gender/000_004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/000_004.jpg -------------------------------------------------------------------------------- /assets/images/test_01/gender/000_005.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/000_005.jpg -------------------------------------------------------------------------------- /assets/images/test_01/gender/000_006.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/000_006.jpg -------------------------------------------------------------------------------- /assets/images/test_01/gender/000_007.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/000_007.jpg -------------------------------------------------------------------------------- /assets/images/test_01/gender/test_01_w_to_m.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/gender/test_01_w_to_m.gif -------------------------------------------------------------------------------- /assets/images/test_01/pose/000_002.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/pose/000_002.jpg -------------------------------------------------------------------------------- /assets/images/test_01/pose/000_003.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/pose/000_003.jpg -------------------------------------------------------------------------------- /assets/images/test_01/pose/000_004.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/pose/000_004.jpg -------------------------------------------------------------------------------- /assets/images/test_01/pose/test_01_pose.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/pose/test_01_pose.gif -------------------------------------------------------------------------------- /assets/images/test_01/test_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/test_01.png -------------------------------------------------------------------------------- /assets/images/test_01/test_01_optimization.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/test_01_optimization.gif -------------------------------------------------------------------------------- /assets/images/test_01/test_01_optimized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_01/test_01_optimized.png -------------------------------------------------------------------------------- /assets/images/test_02/age/old.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/age/old.jpg -------------------------------------------------------------------------------- /assets/images/test_02/age/test_02_age.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/age/test_02_age.gif -------------------------------------------------------------------------------- /assets/images/test_02/age/young.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/age/young.jpg -------------------------------------------------------------------------------- /assets/images/test_02/gender/female.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/gender/female.jpg -------------------------------------------------------------------------------- /assets/images/test_02/gender/male.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/gender/male.jpg -------------------------------------------------------------------------------- /assets/images/test_02/gender/test_02_gender.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/gender/test_02_gender.gif -------------------------------------------------------------------------------- /assets/images/test_02/glasses/000_010.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/glasses/000_010.jpg -------------------------------------------------------------------------------- /assets/images/test_02/glasses/000_018.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/glasses/000_018.jpg -------------------------------------------------------------------------------- /assets/images/test_02/glasses/test_02_glasses.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/glasses/test_02_glasses.gif -------------------------------------------------------------------------------- /assets/images/test_02/test_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/test_02.jpg -------------------------------------------------------------------------------- /assets/images/test_02/test_02.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jacobhallberg/pytorch_stylegan_encoder/0a300de6ce8daa8279d47cc283ba490faacdf960/assets/images/test_02/test_02.npy -------------------------------------------------------------------------------- /encode_image.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from tqdm import tqdm 3 | import numpy as np 4 | import torch 5 | from InterFaceGAN.models.stylegan_generator import StyleGANGenerator 6 | from models.latent_optimizer import LatentOptimizer 7 | from models.image_to_latent import ImageToLatent 8 | from models.losses import LatentLoss 9 | from utilities.hooks import GeneratedImageHook 10 | from utilities.images import load_images, images_to_video, save_image 11 | from utilities.files import validate_path 12 | 13 | parser = argparse.ArgumentParser(description="Find the latent space representation of an input image.") 14 | parser.add_argument("image_path", help="Filepath of the image to be encoded.") 15 | parser.add_argument("dlatent_path", help="Filepath to save the dlatent (WP) at.") 16 | 17 | parser.add_argument("--save_optimized_image", default=False, help="Whether or not to save the image created with the optimized latents.", type=bool) 18 | parser.add_argument("--optimized_image_path", default="optimized.png", help="The path to save the image created with the optimized latents.", type=str) 19 | parser.add_argument("--video", default=False, help="Whether or not to save a video of the encoding process.", type=bool) 20 | parser.add_argument("--video_path", default="video.avi", help="Where to save the video at.", type=str) 21 | parser.add_argument("--save_frequency", default=10, help="How often to save the images to video. Smaller = Faster.", type=int) 22 | parser.add_argument("--iterations", default=1000, help="Number of optimizations steps.", type=int) 23 | parser.add_argument("--model_type", default="stylegan_ffhq", help="The model to use from InterFaceGAN repo.", type=str) 24 | parser.add_argument("--learning_rate", default=1, help="Learning rate for SGD.", type=int) 25 | parser.add_argument("--vgg_layer", default=12, help="The VGG network layer number to extract features from.", type=int) 26 | parser.add_argument("--use_latent_finder", default=False, help="Whether or not to use a latent finder to find the starting latents to optimize from.", type=bool) 27 | parser.add_argument("--image_to_latent_path", default="image_to_latent.pt", help="The path to the .pt (Pytorch) latent finder model.", type=str) 28 | 29 | args, other = parser.parse_known_args() 30 | 31 | def optimize_latents(): 32 | print("Optimizing Latents.") 33 | synthesizer = StyleGANGenerator(args.model_type).model.synthesis 34 | latent_optimizer = LatentOptimizer(synthesizer, args.vgg_layer) 35 | 36 | # Optimize only the dlatents. 37 | for param in latent_optimizer.parameters(): 38 | param.requires_grad_(False) 39 | 40 | if args.video or args.save_optimized_image: 41 | # Hook, saves an image during optimization to be used to create video. 42 | generated_image_hook = GeneratedImageHook(latent_optimizer.post_synthesis_processing, args.save_frequency) 43 | 44 | reference_image = load_images([args.image_path]) 45 | reference_image = torch.from_numpy(reference_image).cuda() 46 | reference_image = latent_optimizer.vgg_processing(reference_image) 47 | reference_features = latent_optimizer.vgg16(reference_image).detach() 48 | reference_image = reference_image.detach() 49 | 50 | if args.use_latent_finder: 51 | image_to_latent = ImageToLatent().cuda() 52 | image_to_latent.load_state_dict(torch.load(args.image_to_latent_path)) 53 | image_to_latent.eval() 54 | 55 | latents_to_be_optimized = image_to_latent(reference_image) 56 | latents_to_be_optimized = latents_to_be_optimized.detach().cuda().requires_grad_(True) 57 | else: 58 | latents_to_be_optimized = torch.zeros((1,18,512)).cuda().requires_grad_(True) 59 | 60 | criterion = LatentLoss() 61 | optimizer = torch.optim.SGD([latents_to_be_optimized], lr=args.learning_rate) 62 | 63 | progress_bar = tqdm(range(args.iterations)) 64 | for step in progress_bar: 65 | optimizer.zero_grad() 66 | 67 | generated_image_features = latent_optimizer(latents_to_be_optimized) 68 | 69 | loss = criterion(generated_image_features, reference_features) 70 | loss.backward() 71 | loss = loss.item() 72 | 73 | optimizer.step() 74 | progress_bar.set_description("Step: {}, Loss: {}".format(step, loss)) 75 | 76 | optimized_dlatents = latents_to_be_optimized.detach().cpu().numpy() 77 | np.save(args.dlatent_path, optimized_dlatents) 78 | 79 | if args.video: 80 | images_to_video(generated_image_hook.get_images(), args.video_path) 81 | if args.save_optimized_image: 82 | save_image(generated_image_hook.last_image, args.optimized_image_path) 83 | 84 | def main(): 85 | assert(validate_path(args.image_path, "r")) 86 | assert(validate_path(args.dlatent_path, "w")) 87 | assert(1 <= args.vgg_layer <= 16) 88 | if args.video: assert(validate_path(args.video_path, "w")) 89 | if args.save_optimized_image: assert(validate_path(args.optimized_image_path, "w")) 90 | if args.use_latent_finder: assert(validate_path(args.image_to_latent_path, "r")) 91 | 92 | optimize_latents() 93 | 94 | if __name__ == "__main__": 95 | main() 96 | 97 | 98 | 99 | 100 | 101 | 102 | -------------------------------------------------------------------------------- /models/image_to_latent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.models import resnet50 3 | from PIL import Image 4 | import numpy as np 5 | 6 | class ImageToLatent(torch.nn.Module): 7 | def __init__(self, image_size=256): 8 | super().__init__() 9 | 10 | self.image_size = image_size 11 | self.activation = torch.nn.ELU() 12 | 13 | self.resnet = list(resnet50(pretrained=True).children())[:-2] 14 | self.resnet = torch.nn.Sequential(*self.resnet) 15 | self.conv2d = torch.nn.Conv2d(2048, 256, kernel_size=1) 16 | self.flatten = torch.nn.Flatten() 17 | self.dense1 = torch.nn.Linear(16384, 256) 18 | self.dense2 = torch.nn.Linear(256, (18 * 512)) 19 | 20 | def forward(self, image): 21 | x = self.resnet(image) 22 | x = self.conv2d(x) 23 | x = self.flatten(x) 24 | x = self.dense1(x) 25 | x = self.dense2(x) 26 | x = x.view((-1, 18, 512)) 27 | 28 | return x 29 | 30 | class ImageLatentDataset(torch.utils.data.Dataset): 31 | def __init__(self, filenames, dlatents, image_size=256, transforms = None): 32 | self.filenames = filenames 33 | self.dlatents = dlatents 34 | self.image_size = image_size 35 | self.transforms = transforms 36 | 37 | def __len__(self): 38 | return len(self.filenames) 39 | 40 | def __getitem__(self, index): 41 | filename = self.filenames[index] 42 | dlatent = self.dlatents[index] 43 | 44 | image = self.load_image(filename) 45 | image = Image.fromarray(np.uint8(image)) 46 | 47 | if self.transforms: 48 | image = self.transforms(image) 49 | 50 | return image, dlatent 51 | 52 | def load_image(self, filename): 53 | image = np.asarray(Image.open(filename)) 54 | 55 | return image 56 | -------------------------------------------------------------------------------- /models/latent_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from torchvision.models import vgg16 3 | import torch 4 | 5 | class PostSynthesisProcessing(torch.nn.Module): 6 | def __init__(self): 7 | super().__init__() 8 | 9 | self.min_value = -1 10 | self.max_value = 1 11 | 12 | def forward(self, synthesized_image): 13 | synthesized_image = (synthesized_image - self.min_value) * torch.tensor(255).float() / (self.max_value - self.min_value) 14 | synthesized_image = torch.clamp(synthesized_image + 0.5, min=0, max=255) 15 | 16 | return synthesized_image 17 | 18 | class VGGProcessing(torch.nn.Module): 19 | def __init__(self): 20 | super().__init__() 21 | 22 | self.image_size = 256 23 | self.mean = torch.tensor([0.485, 0.456, 0.406], device="cuda").view(-1, 1, 1) 24 | self.std = torch.tensor([0.229, 0.224, 0.225], device="cuda").view(-1, 1, 1) 25 | 26 | def forward(self, image): 27 | image = image / torch.tensor(255).float() 28 | image = F.adaptive_avg_pool2d(image, self.image_size) 29 | 30 | image = (image - self.mean) / self.std 31 | 32 | return image 33 | 34 | 35 | class LatentOptimizer(torch.nn.Module): 36 | def __init__(self, synthesizer, layer=12): 37 | super().__init__() 38 | 39 | self.synthesizer = synthesizer.cuda().eval() 40 | self.post_synthesis_processing = PostSynthesisProcessing() 41 | self.vgg_processing = VGGProcessing() 42 | self.vgg16 = vgg16(pretrained=True).features[:layer].cuda().eval() 43 | 44 | 45 | def forward(self, dlatents): 46 | generated_image = self.synthesizer(dlatents) 47 | generated_image = self.post_synthesis_processing(generated_image) 48 | generated_image = self.vgg_processing(generated_image) 49 | features = self.vgg16(generated_image) 50 | 51 | return features 52 | -------------------------------------------------------------------------------- /models/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | class LatentLoss(torch.nn.Module): 4 | def __init__(self): 5 | super().__init__() 6 | self.l1_loss = L1Loss() 7 | self.log_cosh_loss = LogCoshLoss() 8 | self.l2_loss = torch.nn.MSELoss() 9 | 10 | def forward(self, real_features, generated_features, average_dlatents = None, dlatents = None): 11 | # Take a look at: 12 | # https://github.com/pbaylies/stylegan-encoder/blob/master/encoder/perceptual_model.py 13 | # For additional losses and practical scaling factors. 14 | 15 | loss = 0 16 | # Possible TODO: Add more feature based loss functions to create better optimized latents. 17 | 18 | # Modify scaling factors or disable losses to get best result (Image dependent). 19 | 20 | # VGG16 Feature Loss 21 | # Absolute vs MSE Loss 22 | # loss += 1 * self.l1_loss(real_features, generated_features) 23 | loss += 1 * self.l2_loss(real_features, generated_features) 24 | 25 | # Pixel Loss 26 | # loss += 1.5 * self.log_cosh_loss(real_image, generated_image) 27 | 28 | # Dlatent Loss - Forces latents to stay near the space the model uses for faces. 29 | if average_dlatents is not None and dlatents is not None: 30 | loss += 1 * 512 * self.l1_loss(average_dlatents, dlatents) 31 | 32 | return loss 33 | 34 | class LogCoshLoss(torch.nn.Module): 35 | def __init__(self): 36 | super().__init__() 37 | 38 | def forward(self, true, pred): 39 | loss = true - pred 40 | return torch.mean(torch.log(torch.cosh(loss + 1e-12))) 41 | 42 | class L1Loss(torch.nn.Module): 43 | def __init__(self): 44 | super().__init__() 45 | 46 | def forward(self, true, pred): 47 | return torch.mean(torch.abs(true - pred)) 48 | -------------------------------------------------------------------------------- /utilities/files.py: -------------------------------------------------------------------------------- 1 | def validate_path(path, op): 2 | try: 3 | open(path, op) 4 | return True 5 | except: 6 | return False 7 | -------------------------------------------------------------------------------- /utilities/hooks.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | class GeneratedImageHook: 4 | # Pytorch forward pass module hook. 5 | 6 | def __init__(self, module, every_n=10): 7 | self.generated_images = [] 8 | self.count = 1 9 | self.every_n = every_n 10 | self.last_image = None 11 | 12 | self.hook = module.register_forward_hook(self.save_generated_image) 13 | 14 | def save_generated_image(self, module, input, output): 15 | image = output.detach().cpu().numpy()[0] 16 | if self.count % self.every_n == 0: 17 | self.generated_images.append(image) 18 | self.count = 0 19 | 20 | self.last_image = image 21 | self.count += 1 22 | 23 | def close(self): 24 | self.hook.remove() 25 | 26 | def get_images(self): 27 | return self.generated_images 28 | -------------------------------------------------------------------------------- /utilities/images.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | from PIL import Image 4 | 5 | def load_images(filenames): 6 | # Images must all be of same shape. 7 | images = [] 8 | for filename in filenames: 9 | temp_image = np.asarray(Image.open(filename)) 10 | 11 | # Adjust channel dimension to work with torch. 12 | temp_image = np.transpose(temp_image, (2,0,1)) 13 | images.append(temp_image) 14 | 15 | return np.array(images) 16 | 17 | def images_to_video(images, save_path, image_size = 1024): 18 | size = (image_size, image_size) 19 | fps = 10 20 | video = cv2.VideoWriter(save_path, cv2.VideoWriter_fourcc(*'DIVX'), fps, size) 21 | 22 | for i, image in enumerate(images): 23 | # Channel, width, height -> width, height, channel, then RGB to BGR 24 | image = np.transpose(image, (1,2,0)) 25 | image = image[:,:,::-1] 26 | video.write(image.astype(np.uint8)) 27 | 28 | video.release() 29 | 30 | def save_image(image, save_path): 31 | image = np.transpose(image, (1,2,0)).astype(np.uint8) 32 | image = Image.fromarray(image) 33 | image.save(save_path) 34 | --------------------------------------------------------------------------------