├── LICENSE ├── README.md ├── taki0112_reshape_progressive.py ├── nvlabs_to_taki0112.py ├── Convert_NVlabs_StyleGAN_pkl_to_taki0112_checkpoint.ipynb └── taki0112_to_nvlabs.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 aydao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # stylegan-convert-architecture 2 | 3 | This repository contains a script that converts any @NVlabs StyleGAN pkl into the analogous architecture that uses a vanilla TensorFlow checkpoint, courtesy of @taki0112 ([in this excellent repository](https://github.com/taki0112/StyleGAN-Tensorflow)). Note: this script assumes that the input pkl is a model that has finished progressive growing and will not grow any further, which is the typical case for transfer learning on a StyleGAN. 4 | 5 | **Why would anyone be interested in converting this?** For one, the NVlabs code requires using their dataset_tool script, which expands the source dataset by at least 10x. The taki0112 implementation, however, works directly on raw images. Secondly, the NVlabs code requires their "dnnlib" codebase, while taki0112 simply uses a single clean and readable TensorFlow file. Finally, NVlabs released their StyleGAN code and trained models under an [Attribution-NonCommercial 4.0 International](https://github.com/NVlabs/stylegan/blob/master/LICENSE.txt) license, while taki0112 uses the more permissive [MIT](https://github.com/taki0112/StyleGAN-Tensorflow/blob/master/LICENSE) license. 6 | 7 | **Recommended usage:** let's say you've trained an NVlabs StyleGAN and want to transfer/retrain on a considerably larger dataset. Using this script, you can copy over your learned weights and begin training using [this code](https://github.com/taki0112/StyleGAN-Tensorflow) on a dataset up to ~10x larger; example commands are shown below.
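For example, converting the public FFHQ model and then sanity-checking the converted checkpoint with the taki0112 drawing phase looks roughly like this (adapted from the Colab notebook in this repository; it assumes the NVlabs, taki0112, and aydao code have all been merged into one directory and the pkl has been downloaded to `./cache/`):

```
# copy the NVlabs weights into a taki0112 checkpoint under ./checkpoint/
python nvlabs_to_taki0112.py --nvlabs ./cache/karras2019stylegan-ffhq-1024x1024.pkl --dataset FFHQ --img_size 1024 --gpu_num 1

# the taki0112 code expects a dataset folder to exist, even just for drawing
mkdir -p ./dataset/FFHQ/ && touch ./dataset/FFHQ/temp.png

# render style-mixing and truncation-trick figures from the converted checkpoint
python main.py --dataset FFHQ --img_size 1024 --start_res 1024 --progressive False --phase draw --draw style_mix
python main.py --dataset FFHQ --img_size 1024 --start_res 1024 --progressive False --phase draw --draw truncation_trick
```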
8 | -------------------------------------------------------------------------------- /taki0112_reshape_progressive.py: -------------------------------------------------------------------------------- 1 | # 2 | # aydao ~ 2019 3 | # 4 | # Use this if you started training a taki0112 model from scratch 5 | # First, run this script (example usage below) 6 | # This will copy over the layers from your original model into another suited for the next script 7 | # The script will drop a taki0112 checkpoint with its counter set to "2" (chosen arbitrarily by me) 8 | # At that point, you can run the taki0112_to_nvlabs script to convert it to an NVlabs .pkl 9 | # 10 | # 11 | from StyleGAN import StyleGAN 12 | import argparse 13 | from utils import * 14 | import os 15 | import pickle 16 | import numpy as np 17 | import PIL.Image 18 | import copy 19 | import dnnlib 20 | import dnnlib.tflib as tflib 21 | from dnnlib import EasyDict 22 | import config 23 | from metrics import metric_base 24 | from training.training_loop import process_reals 25 | from training import misc 26 | import tensorflow as tf 27 | import sys, time 28 | 29 | """parsing and configuration""" 30 | def parse_args(): 31 | desc = "Reshape a taki0112 StyleGAN checkpoint trained progressively from scratch so it can be converted by taki0112_to_nvlabs.py (copies over the model weights)" 32 | parser = argparse.ArgumentParser(description=desc) 33 | 34 | parser.add_argument("--cache_dir", type=str, default="./", 35 | help="The cache directory to save the NVlabs pkl in") 36 | 37 | parser.add_argument("--dataset", type=str, default="FFHQ", 38 | help="The name of the dataset to use") 39 | 40 | parser.add_argument("--gpu_num", type=int, default=1, help="The number of GPUs") 41 | 42 | parser.add_argument("--start_res", type=int, default=8, help="The starting resolution") 43 | 44 | parser.add_argument("--img_size", type=int, default=1024, help="The target image size") 45 | parser.add_argument("--progressive", type=str2bool, default=True, help="use progressive training") 46 | 47 | parser.add_argument("--checkpoint_dir", type=str, default="./checkpoint/", 48 | help="Directory name to save the checkpoints") 49 | 50 | parser.add_argument("--result_dir", type=str, default="./results/", 51 | help="Directory name to save the generated images") 52 | 53 | # Extra args for taki0112 code 54 | # Do not change these args, the code will automatically set them for you 55 | # parser.add_argument("--progressive", help=argparse.SUPPRESS) 56 | parser.add_argument("--phase", help=argparse.SUPPRESS) 57 | parser.add_argument("--sn", help=argparse.SUPPRESS) 58 | 59 | args = parser.parse_args() 60 | args = check_args(args) 61 | args = handle_extra_args(args) 62 | return args 63 | 64 | """checking arguments""" 65 | def check_args(args): 66 | 67 | # --checkpoint_dir 68 | check_folder(args.checkpoint_dir) 69 | 70 | # --result_dir 71 | check_folder(args.result_dir) 72 | 73 | return args 74 | 75 | """handle the extra args based on required args""" 76 | def handle_extra_args(args): 77 | args.phase = "train" 78 | # assuming you are not using spectral normalization since NVlabs does not use it 79 | # you probably *could* figure out how to convert with sn, but it'd take some more tinkering 80 | args.sn = False 81 | # preempt the taki0112 code from making needless directories...
82 | args.sample_dir = args.result_dir 83 | args.log_dir = args.result_dir 84 | # magic values not needed for this script 85 | args.iteration = 0 86 | args.max_iteration = 2500 87 | args.batch_size = 1 88 | args.test_num = 1 89 | args.seed = True 90 | return args 91 | 92 | def make_temp_dataset_file(dataset_dir): 93 | filename = dataset_dir + "/nvlabs_to_taki0112_tempfile.png" 94 | os.makedirs(os.path.dirname(filename), exist_ok=True) 95 | with open(filename, "w") as f: 96 | f.write("this file is safe to delete") 97 | # give it a second 98 | time.sleep(1) 99 | return filename 100 | 101 | def delete_temp_dataset_file(args, dataset_dir, filename): 102 | try: 103 | os.remove(filename) 104 | if len(os.listdir(dataset_dir)) == 0: 105 | os.rmdir(dataset_dir) 106 | except OSError as e: 107 | print("Error: %s - %s." % (e.filename, e.strerror)) 108 | 109 | """main""" 110 | def main(): 111 | # 112 | # 113 | # Example usage: python taki0112_reshape_progressive.py --dataset FFHQ --start_res 8 --img_size 512 114 | # 115 | # 116 | # parse arguments 117 | args = parse_args() 118 | if args is None: 119 | exit() 120 | 121 | checkpoint_dir = args.checkpoint_dir 122 | 123 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 124 | 125 | # this is a hack since the taki0112 code expects a dataset folder which may not exist 126 | dataset = args.dataset 127 | dataset_dir = "./dataset/" + dataset 128 | temp_dataset_file = make_temp_dataset_file(dataset_dir) 129 | 130 | # build the taki0112 StyleGAN architecture (vanilla Tensorflow) 131 | gan = StyleGAN(sess, args) 132 | 133 | # you have to go through this process to initialize everything needed to load the checkpoint... 134 | gan.build_model() 135 | 136 | # remove the temp file and the directory if it is empty 137 | delete_temp_dataset_file(args, dataset_dir, temp_dataset_file) 138 | 139 | # Initialize TensorFlow.
140 | tflib.init_tf() 141 | 142 | tf.global_variables_initializer().run() 143 | gan.saver = tf.train.Saver(max_to_keep=10) 144 | gan.load(checkpoint_dir) 145 | 146 | copy_layers = [] 147 | vars = tf.trainable_variables("discriminator") 148 | vars_vals = sess.run(vars) 149 | for var, val in zip(vars, vars_vals): 150 | copy_layers.append((var.name,val)) 151 | vars = tf.trainable_variables("generator") 152 | vars_vals = sess.run(vars) 153 | for var, val in zip(vars, vars_vals): 154 | copy_layers.append((var.name,val)) 155 | 156 | return args, copy_layers 157 | 158 | def copy_over(args, copy_layers): 159 | tf.reset_default_graph() 160 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 161 | checkpoint_dir = args.checkpoint_dir 162 | args.progressive = True 163 | dataset = args.dataset 164 | dataset_dir = "./dataset/" + dataset 165 | temp_dataset_file = make_temp_dataset_file(dataset_dir) 166 | gan2 = StyleGAN(sess, args) 167 | gan2.build_model() 168 | delete_temp_dataset_file(args, dataset_dir, temp_dataset_file) 169 | tflib.init_tf() 170 | tf.global_variables_initializer().run() 171 | gan2.saver = tf.train.Saver(max_to_keep=10) 172 | gan2.load(checkpoint_dir) 173 | 174 | update_layers = [] 175 | variables = [] 176 | variables += tf.trainable_variables("discriminator") 177 | variables += tf.trainable_variables("generator") 178 | 179 | variable_dict = {} 180 | for variable in variables: 181 | variable_name = variable.name 182 | variable_dict[variable_name] = variable 183 | 184 | for copy_layer in copy_layers: 185 | copy_name, copy_value = copy_layer 186 | variable = variable_dict[copy_name] 187 | update_layer = tf.assign(variable, copy_value) 188 | update_layers.append(update_layer) 189 | 190 | sess.run(update_layers) 191 | 192 | # just picking 2 as the counter for the taki model 193 | counter = 2 194 | gan2.save(checkpoint_dir, counter) 195 | 196 | if __name__ == "__main__": 197 | args, copy_layers = main() 198 | copy_over(args, copy_layers) 199 | -------------------------------------------------------------------------------- /nvlabs_to_taki0112.py: -------------------------------------------------------------------------------- 1 | # 2 | # aydao ~ 2019 3 | # 4 | # Convert network-snapshot-######.pkl StyleGAN models to more general tensorflow checkpoints 5 | # 6 | # This script relies on both the NVlabs StyleGAN GitHub repository and the taki0112 StyleGAN repo 7 | # It assumes both are in the same directory as this script 8 | # 9 | from StyleGAN import StyleGAN 10 | import argparse 11 | from utils import * 12 | import os 13 | import pickle 14 | import numpy as np 15 | import PIL.Image 16 | import dnnlib 17 | import dnnlib.tflib as tflib 18 | import config 19 | import tensorflow as tf 20 | import sys, time 21 | 22 | """parsing and configuration""" 23 | def parse_args(): 24 | desc = "Convert NVlabs StyleGAN pkl to taki0112 StyleGAN checkpoint (copies over the model weights)" 25 | parser = argparse.ArgumentParser(description=desc) 26 | 27 | parser.add_argument("--nvlabs", type=str, help="The source NVlabs StyleGAN, a network .pkl file") 28 | 29 | parser.add_argument("--dataset", type=str, default="FFHQ", 30 | help="The dataset name what you want to generate") 31 | 32 | parser.add_argument("--gpu_num", type=int, default=1, help="The number of gpu") 33 | 34 | # parser.add_argument("--sn", type=str2bool, default=False, help="use spectral normalization") 35 | 36 | parser.add_argument("--img_size", type=int, default=1024, help="The target size of image") 37 | 38 | 
parser.add_argument("--checkpoint_dir", type=str, default="./checkpoint/", 39 | help="Directory name to save the checkpoints") 40 | 41 | parser.add_argument("--result_dir", type=str, default="./results/", 42 | help="Directory name to save the generated images") 43 | 44 | # Extra args for taki0112 code 45 | # Do not change these args, the code will automatically set them for you 46 | parser.add_argument("--start_res", help=argparse.SUPPRESS) 47 | parser.add_argument("--sn", help=argparse.SUPPRESS) 48 | parser.add_argument("--progressive", help=argparse.SUPPRESS) 49 | parser.add_argument("--phase", help=argparse.SUPPRESS) 50 | 51 | args = parser.parse_args() 52 | args = check_args(args) 53 | args = handle_extra_args(args) 54 | return args 55 | 56 | """checking arguments""" 57 | def check_args(args): 58 | 59 | # --checkpoint_dir 60 | check_folder(args.checkpoint_dir) 61 | 62 | # --result_dir 63 | check_folder(args.result_dir) 64 | 65 | return args 66 | 67 | """handle the extra args based on required args""" 68 | def handle_extra_args(args): 69 | # these are things you either do not need for this script, or can be set automatically 70 | args.start_res = args.img_size 71 | # assuming you are transferring from a progressive NVlabs StyleGAN model 72 | # though when you continue training you likely do not want it progressive any more 73 | args.progressive = True 74 | # NVlabs didn't use spectral normalization. 75 | # Arguably you could turn it on here, but it might require much more training... 76 | args.sn = False 77 | args.phase = "train" 78 | # preempt the taki0112 code from making needless directories... 79 | args.sample_dir = args.result_dir 80 | args.log_dir = args.result_dir 81 | # magic values not needed for this script 82 | args.iteration = 0 83 | args.max_iteration = 2500 84 | args.batch_size = 1 85 | args.test_num = 1 86 | args.seed = True 87 | return args 88 | 89 | def make_temp_dataset_file(dataset_dir): 90 | filename = dataset_dir + "/nvlabs_to_taki0112_tempfile.png" 91 | os.makedirs(os.path.dirname(filename), exist_ok=True) 92 | with open(filename, "w") as f: 93 | f.write("this file is safe to delete") 94 | # give it a second 95 | time.sleep(1) 96 | return filename 97 | 98 | def delete_temp_dataset_file(args, dataset_dir, filename): 99 | try: 100 | os.remove(filename) 101 | if len(os.listdir(dataset_dir)) == 0: 102 | os.rmdir(dataset_dir) 103 | except OSError as e: 104 | print ("Error: %s - %s." 
% (e.filename, e.strerror)) 105 | 106 | """main""" 107 | def main(): 108 | # 109 | # Usage example: 110 | # python [this_file].py --nvlabs ./cache/karras2019stylegan-ffhq-1024x1024.pkl 111 | # --dataset FFHQ --img_size 1024 --gpu_num 2 112 | # 113 | # parse arguments 114 | args = parse_args() 115 | if args is None: 116 | exit() 117 | 118 | checkpoint_dir = args.checkpoint_dir 119 | nvlabs_stylegan_pkl_name = args.nvlabs 120 | # the taki0112 StyleGAN models expect the following naming prefix 121 | model_name = "StyleGAN.model" 122 | 123 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 124 | 125 | 126 | # this is a hack since the taki0112 expects a dataset folder which may not exist 127 | dataset = args.dataset 128 | dataset_dir = "./dataset/" + dataset 129 | temp_dataset_file = make_temp_dataset_file(dataset_dir) 130 | 131 | # build the taki0112 StyleGAN architecture (vanilla Tensorflow) 132 | gan = StyleGAN(sess, args) 133 | gan.build_model() 134 | 135 | # remove the temp file and the directory if it is empty 136 | delete_temp_dataset_file(args, dataset_dir, temp_dataset_file) 137 | 138 | # now that a progressive model is built, turn off progressive 139 | # the progressive structure is needed to copy over from NVlabs, 140 | # but you won't need this in training later 141 | # basically, this just saves this out in a folder without 142 | # the _progressive tag in the directory name 143 | # this shouldn't cause any weird side effects. Probably. 144 | args.progressive = False 145 | gan.progressive = False 146 | 147 | tflib.init_tf() 148 | 149 | gan.sess = sess 150 | tf.global_variables_initializer().run() 151 | gan.saver = tf.train.Saver(max_to_keep=10) 152 | gan.writer = tf.summary.FileWriter(gan.log_dir + "/" + gan.model_dir, gan.sess.graph) 153 | 154 | counter = 0 155 | gan.save(checkpoint_dir, counter) 156 | # Or, you can use this instead if you want to base the taki0112 network on an existing one 157 | # gan.load(checkpoint_dir) 158 | 159 | # Now moving to NVlabs code 160 | src_d = "D" 161 | dst_d = "discriminator" 162 | src_gs = "G_synthesis_1" # "G_synthesis" 163 | dst_gs = "generator/g_synthesis" 164 | src_gm = "G_mapping_1" # "G_mapping" 165 | dst_gm = "generator/g_mapping" 166 | 167 | # Load the existing NVlabs StyleGAN network 168 | G, D, Gs = pickle.load(open(nvlabs_stylegan_pkl_name, "rb")) 169 | 170 | vars = tf.trainable_variables(src_gm) 171 | vars_vals = sess.run(vars) 172 | 173 | # Copy over the discriminator weights 174 | for (new, old) in zip(tf.trainable_variables(dst_d), tf.trainable_variables(src_d)): 175 | update_weight = [tf.assign(new, old)] 176 | sess.run(update_weight) 177 | temp_vals = sess.run([new, old]) 178 | 179 | # Copy over the Generator's mapping network weights 180 | for (new, old) in zip(tf.trainable_variables(dst_gm), tf.trainable_variables(src_gm)): 181 | update_weight = [tf.assign(new, old)] 182 | sess.run(update_weight) 183 | temp_vals = sess.run([new, old]) 184 | 185 | # Because the two network architectures use slightly different columns on one variable, 186 | # you must set up code to handle the edge case transpose of the first case 187 | first = True 188 | for (new, old) in zip(tf.trainable_variables(dst_gs), tf.trainable_variables(src_gs)): 189 | temp_vals = sess.run([new, old]) 190 | if new.shape != old.shape: 191 | # you need a transpose with perm # old = tf.reshape(old, tf.shape(new)) 192 | # DO NOT USE RESHAPE 193 | # (made this mistake here and the results work but are quite terrifying) 194 | if (first): 195 | first 
= False 196 | old = tf.transpose(old, perm=[0, 2, 3, 1]) 197 | else: 198 | old = tf.transpose(old, perm=[0, 1, 3, 2]) 199 | update_weight = [tf.assign(new, old)] 200 | sess.run(update_weight) 201 | 202 | # Also, assign the NVlabs Gs dlatent_avg to the w_avg in the taki0112 network 203 | new = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="generator") 204 | if "avg" in str(x)][0] 205 | old = G.get_var("dlatent_avg") 206 | update_weight = [tf.assign(new, old)] 207 | sess.run(update_weight) 208 | vars = [new] 209 | vars_vals = gan.sess.run(vars) 210 | vars_vals = sess.run(vars) 211 | 212 | # Save the new taki0112 StyleGAN checkpoint 213 | # I elect to set the counter to 1 here to differentiate from the souce checkpoint (at 0) 214 | gan.saver = tf.train.Saver(max_to_keep=10) 215 | counter = 1 216 | gan.save(checkpoint_dir, counter) 217 | 218 | 219 | if __name__ == "__main__": 220 | main() 221 | -------------------------------------------------------------------------------- /Convert_NVlabs_StyleGAN_pkl_to_taki0112_checkpoint.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "nbformat": 4, 3 | "nbformat_minor": 0, 4 | "metadata": { 5 | "colab": { 6 | "name": "Convert NVlabs StyleGAN pkl to taki0112 checkpoint.ipynb", 7 | "version": "0.3.2", 8 | "provenance": [] 9 | }, 10 | "kernelspec": { 11 | "name": "python3", 12 | "display_name": "Python 3" 13 | }, 14 | "accelerator": "GPU" 15 | }, 16 | "cells": [ 17 | { 18 | "cell_type": "markdown", 19 | "metadata": { 20 | "id": "view-in-github", 21 | "colab_type": "text" 22 | }, 23 | "source": [ 24 | "\"Open" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": { 30 | "id": "BJJVP1hkenNX", 31 | "colab_type": "text" 32 | }, 33 | "source": [ 34 | "# 0. Convert a NVlabs .pkl to a taki0112 checkpoint\n", 35 | "\n", 36 | "This notebook will let you use a simple script to copy over weights from a StyleGAN network in the idiosyncratic dnnlib architecture from Nvidia to a more general Tensorflow one courtesy of taki0112 on GitHub" 37 | ] 38 | }, 39 | { 40 | "cell_type": "markdown", 41 | "metadata": { 42 | "id": "LfGxMBX_NdRL", 43 | "colab_type": "text" 44 | }, 45 | "source": [ 46 | "# 1. Use a GPU\n", 47 | "\n", 48 | "**Make sure that the notebook is running on a GPU**\n", 49 | "\n", 50 | "Edit -> Notebook Settings -> Hardware Accelerator -> GPU" 51 | ] 52 | }, 53 | { 54 | "cell_type": "markdown", 55 | "metadata": { 56 | "id": "F_qy4XWNOQSe", 57 | "colab_type": "text" 58 | }, 59 | "source": [ 60 | "# 2. Get both StyleGAN repositories" 61 | ] 62 | }, 63 | { 64 | "cell_type": "code", 65 | "metadata": { 66 | "id": "nPsxy0m1OcDl", 67 | "colab_type": "code", 68 | "colab": {} 69 | }, 70 | "source": [ 71 | "!git clone https://github.com/NVlabs/stylegan.git\n", 72 | "!git clone https://github.com/taki0112/stylegan-tensorflow.git\n", 73 | "!git clone https://github.com/aydao/stylegan-convert-architecture.git" 74 | ], 75 | "execution_count": 0, 76 | "outputs": [] 77 | }, 78 | { 79 | "cell_type": "markdown", 80 | "metadata": { 81 | "id": "ZMbDca0bO4hS", 82 | "colab_type": "text" 83 | }, 84 | "source": [ 85 | "Move the taki0112 and aydao code into the same directory as the NVlabs code." 
86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "metadata": { 91 | "id": "Bmh7_paHO9Tv", 92 | "colab_type": "code", 93 | "colab": {} 94 | }, 95 | "source": [ 96 | "!mv ./stylegan-tensorflow/* ./stylegan/\n", 97 | "!rm -rf ./stylegan-tensorflow/\n", 98 | "!mv ./stylegan-convert-architecture/* ./stylegan/\n", 99 | "!rm -rf ./stylegan-convert-architecture/" 100 | ], 101 | "execution_count": 0, 102 | "outputs": [] 103 | }, 104 | { 105 | "cell_type": "markdown", 106 | "metadata": { 107 | "id": "zSlTXr1ad371", 108 | "colab_type": "text" 109 | }, 110 | "source": [ 111 | "# 3. Example conversion using the NVlabs FFHQ model\n" 112 | ] 113 | }, 114 | { 115 | "cell_type": "code", 116 | "metadata": { 117 | "id": "sU0vHwZnPF09", 118 | "colab_type": "code", 119 | "colab": {} 120 | }, 121 | "source": [ 122 | "cd stylegan" 123 | ], 124 | "execution_count": 0, 125 | "outputs": [] 126 | }, 127 | { 128 | "cell_type": "code", 129 | "metadata": { 130 | "id": "9jzaXYzyeSI9", 131 | "colab_type": "code", 132 | "colab": {} 133 | }, 134 | "source": [ 135 | "!mkdir ./cache/" 136 | ], 137 | "execution_count": 0, 138 | "outputs": [] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "metadata": { 143 | "id": "rGkWqOkTd1Y9", 144 | "colab_type": "code", 145 | "colab": {} 146 | }, 147 | "source": [ 148 | "!gdown https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ" 149 | ], 150 | "execution_count": 0, 151 | "outputs": [] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "metadata": { 156 | "id": "7_2PmAB4d1vp", 157 | "colab_type": "code", 158 | "colab": {} 159 | }, 160 | "source": [ 161 | "mv ./karras2019stylegan-ffhq-1024x1024.pkl ./cache" 162 | ], 163 | "execution_count": 0, 164 | "outputs": [] 165 | }, 166 | { 167 | "cell_type": "code", 168 | "metadata": { 169 | "id": "a6NfGTc5d1xy", 170 | "colab_type": "code", 171 | "colab": {} 172 | }, 173 | "source": [ 174 | "!python nvlabs_to_taki0112.py --nvlabs ./cache/karras2019stylegan-ffhq-1024x1024.pkl --dataset FFHQ --img_size 1024 --gpu_num 1" 175 | ], 176 | "execution_count": 0, 177 | "outputs": [] 178 | }, 179 | { 180 | "cell_type": "code", 181 | "metadata": { 182 | "id": "7zV73mr2uwzk", 183 | "colab_type": "code", 184 | "colab": {} 185 | }, 186 | "source": [ 187 | "# create a temp dataset directory/file since the taki0112 expects it \n", 188 | "!mkdir ./dataset/FFHQ/\n", 189 | "!touch ./dataset/FFHQ/temp.png" 190 | ], 191 | "execution_count": 0, 192 | "outputs": [] 193 | }, 194 | { 195 | "cell_type": "code", 196 | "metadata": { 197 | "id": "4rcocpDUvKDQ", 198 | "colab_type": "code", 199 | "colab": {} 200 | }, 201 | "source": [ 202 | "!python main.py --dataset FFHQ --img_size 1024 --start_res 1024 --progressive False --phase draw --draw style_mix\n", 203 | "!python main.py --dataset FFHQ --img_size 1024 --start_res 1024 --progressive False --phase draw --draw truncation_trick" 204 | ], 205 | "execution_count": 0, 206 | "outputs": [] 207 | }, 208 | { 209 | "cell_type": "markdown", 210 | "metadata": { 211 | "id": "IqnWJgAMvlXg", 212 | "colab_type": "text" 213 | }, 214 | "source": [ 215 | "The images are in ./results/StyleGAN_FFHQ_1024to1024/paper_figure/" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "metadata": { 221 | "id": "p7_m5yMTxuFs", 222 | "colab_type": "code", 223 | "colab": {} 224 | }, 225 | "source": [ 226 | "from matplotlib.pyplot import imshow\n", 227 | "import numpy as np\n", 228 | "from PIL import Image\n", 229 | "%matplotlib inline" 230 | ], 231 | "execution_count": 0, 232 | "outputs": [] 233 | }, 234 | { 235 | "cell_type": 
"code", 236 | "metadata": { 237 | "id": "yh5fnmdEzrXF", 238 | "colab_type": "code", 239 | "colab": {} 240 | }, 241 | "source": [ 242 | "style_mix = Image.open(\"./results/StyleGAN_FFHQ_1024to1024/paper_figure/figure03-style-mixing.jpg\", \"r\")\n", 243 | "imshow(np.asarray(style_mix))" 244 | ], 245 | "execution_count": 0, 246 | "outputs": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "metadata": { 251 | "id": "t3NW8eopztjf", 252 | "colab_type": "code", 253 | "colab": {} 254 | }, 255 | "source": [ 256 | "truncation_trick = Image.open(\"./results/StyleGAN_FFHQ_1024to1024/paper_figure/figure08-truncation-trick.jpg\", \"r\")\n", 257 | "imshow(np.asarray(truncation_trick))" 258 | ], 259 | "execution_count": 0, 260 | "outputs": [] 261 | }, 262 | { 263 | "cell_type": "markdown", 264 | "metadata": { 265 | "id": "gwLo4zV4NyGU", 266 | "colab_type": "text" 267 | }, 268 | "source": [ 269 | "# N. Upload your network-snapshot-######.pkl" 270 | ] 271 | }, 272 | { 273 | "cell_type": "markdown", 274 | "metadata": { 275 | "id": "UqbnQh-A4ymj", 276 | "colab_type": "text" 277 | }, 278 | "source": [ 279 | "Either use the integrated browser (Files -> UPLOAD) to get your pkl uploaded or put it in Google drive and mount it to this instance." 280 | ] 281 | }, 282 | { 283 | "cell_type": "code", 284 | "metadata": { 285 | "id": "eHHrxJKPNPmS", 286 | "colab_type": "code", 287 | "colab": {} 288 | }, 289 | "source": [ 290 | "# for me, mine is ./cache/network-snapshot-011185.pkl with a max resolution (lod0) of 512 pixels, and I'll just call the dataset \"mine\"" 291 | ], 292 | "execution_count": 0, 293 | "outputs": [] 294 | }, 295 | { 296 | "cell_type": "code", 297 | "metadata": { 298 | "id": "txwcTOLW1V6m", 299 | "colab_type": "code", 300 | "colab": {} 301 | }, 302 | "source": [ 303 | "!python nvlabs_to_taki0112.py --nvlabs ./cache/network-snapshot-011185.pkl --dataset mine --img_size 512 --gpu_num 1" 304 | ], 305 | "execution_count": 0, 306 | "outputs": [] 307 | }, 308 | { 309 | "cell_type": "code", 310 | "metadata": { 311 | "id": "Ky7GLG9h8XhT", 312 | "colab_type": "code", 313 | "colab": {} 314 | }, 315 | "source": [ 316 | "# create a temp dataset directory/file since the taki0112 expects it \n", 317 | "!mkdir ./dataset/mine/\n", 318 | "!touch ./dataset/mine/temp.png\n", 319 | "!python main.py --dataset mine --img_size 512 --start_res 512 --progressive False --phase draw --draw style_mix\n", 320 | "!python main.py --dataset mine --img_size 512 --start_res 512 --progressive False --phase draw --draw truncation_trick" 321 | ], 322 | "execution_count": 0, 323 | "outputs": [] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "metadata": { 328 | "id": "qzhThkT5-wKE", 329 | "colab_type": "code", 330 | "colab": {} 331 | }, 332 | "source": [ 333 | "from matplotlib.pyplot import imshow\n", 334 | "import numpy as np\n", 335 | "from PIL import Image\n", 336 | "%matplotlib inline" 337 | ], 338 | "execution_count": 0, 339 | "outputs": [] 340 | }, 341 | { 342 | "cell_type": "code", 343 | "metadata": { 344 | "id": "k7EvlQIp93gA", 345 | "colab_type": "code", 346 | "colab": {} 347 | }, 348 | "source": [ 349 | "style_mix = Image.open(\"./results/StyleGAN_mine_512to512/paper_figure/figure03-style-mixing.jpg\", \"r\")\n", 350 | "imshow(np.asarray(style_mix))" 351 | ], 352 | "execution_count": 0, 353 | "outputs": [] 354 | }, 355 | { 356 | "cell_type": "code", 357 | "metadata": { 358 | "id": "Zs4WGLCQ-zlc", 359 | "colab_type": "code", 360 | "colab": {} 361 | }, 362 | "source": [ 363 | "truncation_trick = 
Image.open(\"./results//StyleGAN_mine_512to512/paper_figure/figure08-truncation-trick.jpg\", \"r\")\n", 364 | "imshow(np.asarray(truncation_trick))" 365 | ], 366 | "execution_count": 0, 367 | "outputs": [] 368 | }, 369 | { 370 | "cell_type": "code", 371 | "metadata": { 372 | "id": "bXwtNXLZ-0PF", 373 | "colab_type": "code", 374 | "colab": {} 375 | }, 376 | "source": [ 377 | "" 378 | ], 379 | "execution_count": 0, 380 | "outputs": [] 381 | } 382 | ] 383 | } 384 | -------------------------------------------------------------------------------- /taki0112_to_nvlabs.py: -------------------------------------------------------------------------------- 1 | # 2 | # aydao ~ 2019 3 | # 4 | # Convert taki0112 StyleGAN checkpoints to network-snapshot-######.pkl StyleGAN models 5 | # 6 | # This script relies on both the NVlabs StyleGAN GitHub repository and the taki0112 StyleGAN repo 7 | # It assumes both are in the same directory as this script 8 | # 9 | from StyleGAN import StyleGAN 10 | import argparse 11 | from utils import * 12 | import os 13 | import pickle 14 | import numpy as np 15 | import PIL.Image 16 | import copy 17 | import dnnlib 18 | import dnnlib.tflib as tflib 19 | from dnnlib import EasyDict 20 | import config 21 | from metrics import metric_base 22 | from training.training_loop import process_reals 23 | from training import misc 24 | import tensorflow as tf 25 | import sys, time 26 | 27 | """parsing and configuration""" 28 | def parse_args(): 29 | desc = "Convert taki0112 StyleGAN checkpoint to NVlabs StyleGAN pkl (copies over the model weights)" 30 | parser = argparse.ArgumentParser(description=desc) 31 | 32 | parser.add_argument("--kimg", type=str, help="kimg/iteration of the NVlabs pkl, use format ######") 33 | parser.add_argument("--cache_dir", type=str, default="./", 34 | help="The cache directory to save the NVlabs pkl in") 35 | 36 | parser.add_argument("--dataset", type=str, default="FFHQ", 37 | help="The dataset name what you want to generate") 38 | 39 | parser.add_argument("--gpu_num", type=int, default=1, help="The number of gpu") 40 | 41 | parser.add_argument("--start_res", type=int, default=8, help="The number of starting resolution") 42 | parser.add_argument("--img_size", type=int, default=1024, help="The target size of image") 43 | parser.add_argument("--progressive", type=str2bool, default=True, help="use progressive training") 44 | 45 | parser.add_argument("--checkpoint_dir", type=str, default="./checkpoint/", 46 | help="Directory name to save the checkpoints") 47 | 48 | parser.add_argument("--result_dir", type=str, default="./results/", 49 | help="Directory name to save the generated images") 50 | 51 | # Extra args for taki0112 code 52 | # Do not change these args, the code will automatically set them for you 53 | # parser.add_argument("--progressive", help=argparse.SUPPRESS) 54 | parser.add_argument("--phase", help=argparse.SUPPRESS) 55 | parser.add_argument("--sn", help=argparse.SUPPRESS) 56 | 57 | args = parser.parse_args() 58 | args = check_args(args) 59 | args = handle_extra_args(args) 60 | return args 61 | 62 | """checking arguments""" 63 | def check_args(args): 64 | 65 | # --checkpoint_dir 66 | check_folder(args.checkpoint_dir) 67 | 68 | # --result_dir 69 | check_folder(args.result_dir) 70 | 71 | return args 72 | 73 | """handle the extra args based on required args""" 74 | def handle_extra_args(args): 75 | # these are things you either do not need for this script, or can be set automatically 76 | args.phase = "train" 77 | # assuming you are not using spectral 
normalization since NVlabs does not use it 78 | # you probably *could* figure out how to convert with sn, but it'd take some more tinkering 79 | args.sn = False 80 | # preempt the taki0112 code from making needless directories... 81 | args.sample_dir = args.result_dir 82 | args.log_dir = args.result_dir 83 | # magic values not needed for this script 84 | args.iteration = 0 85 | args.max_iteration = 2500 86 | args.batch_size = 1 87 | args.test_num = 1 88 | args.seed = True 89 | return args 90 | 91 | def make_temp_dataset_file(dataset_dir): 92 | filename = dataset_dir + "/nvlabs_to_taki0112_tempfile.png" 93 | os.makedirs(os.path.dirname(filename), exist_ok=True) 94 | with open(filename, "w") as f: 95 | f.write("this file is safe to delete") 96 | # give it a second 97 | time.sleep(1) 98 | return filename 99 | 100 | def delete_temp_dataset_file(args, dataset_dir, filename): 101 | try: 102 | os.remove(filename) 103 | if len(os.listdir(dataset_dir)) == 0: 104 | os.rmdir(dataset_dir) 105 | except OSError as e: 106 | print ("Error: %s - %s." % (e.filename, e.strerror)) 107 | 108 | """main""" 109 | def main(): 110 | # 111 | # Usage example: 112 | # python [this_file].py --kimg ###### --dataset [your data] --gpu_num 1 113 | # --start_res 8 --img_size 512 --progressive True 114 | # 115 | # 116 | # parse arguments 117 | args = parse_args() 118 | if args is None: 119 | exit() 120 | 121 | checkpoint_dir = args.checkpoint_dir 122 | nvlabs_stylegan_pkl_kimg = args.kimg 123 | nvlabs_stylegan_pkl_name = "network-snapshot-"+nvlabs_stylegan_pkl_kimg+".pkl" 124 | 125 | 126 | with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess: 127 | 128 | # this is a hack since the taki0112 expects a dataset folder which may not exist 129 | dataset = args.dataset 130 | dataset_dir = "./dataset/" + dataset 131 | temp_dataset_file = make_temp_dataset_file(dataset_dir) 132 | 133 | 134 | # build the taki0112 StyleGAN architecture (vanilla Tensorflow) 135 | gan = StyleGAN(sess, args) 136 | 137 | 138 | # you have to go through this process to initialize everything needed to load the checkpoint... 139 | original_start_res = args.start_res 140 | args.start_res = args.img_size 141 | gan.start_res = args.img_size 142 | gan.build_model() 143 | args.start_res = original_start_res 144 | gan.start_res = original_start_res 145 | 146 | # remove the temp file and the directory if it is empty 147 | delete_temp_dataset_file(args, dataset_dir, temp_dataset_file) 148 | 149 | # Initialize TensorFlow. 
150 | tflib.init_tf() 151 | 152 | tf.global_variables_initializer().run() 153 | 154 | 155 | vars = tf.trainable_variables("discriminator") 156 | vars_vals = sess.run(vars) 157 | for var, val in zip(vars, vars_vals): 158 | print(var.name) 159 | 160 | gan.saver = tf.train.Saver(max_to_keep=10) 161 | gan.load(checkpoint_dir) 162 | 163 | # 164 | # 165 | # Make an NVlabs StyleGAN network (default initialization) 166 | # 167 | # 168 | 169 | # StyleGAN initialization parameters and options, if you care to change them, do so here 170 | desc = "sgan" 171 | train = EasyDict(run_func_name="training.training_loop.training_loop") 172 | G = EasyDict(func_name="training.networks_stylegan.G_style") 173 | D = EasyDict(func_name="training.networks_stylegan.D_basic") 174 | G_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) 175 | D_opt = EasyDict(beta1=0.0, beta2=0.99, epsilon=1e-8) 176 | G_loss = EasyDict(func_name="training.loss.G_logistic_nonsaturating") 177 | D_loss = EasyDict(func_name="training.loss.D_logistic_simplegp", r1_gamma=10.0) 178 | dataset = EasyDict() 179 | sched = EasyDict() 180 | grid = EasyDict(size="4k", layout="random") 181 | metrics = [metric_base.fid50k] 182 | submit_config = dnnlib.SubmitConfig() 183 | tf_config = {"rnd.np_random_seed": 1000} 184 | drange_net = [-1,1] 185 | G_smoothing_kimg = 10.0 186 | 187 | # Dataset. 188 | desc += "-"+args.dataset 189 | dataset = EasyDict(tfrecord_dir=args.dataset) 190 | train.mirror_augment = True 191 | 192 | # Number of GPUs. 193 | gpu_num = args.gpu_num 194 | if gpu_num == 1: 195 | desc += "-1gpu"; submit_config.num_gpus = 1 196 | sched.minibatch_base = 4 197 | sched.minibatch_dict = {4: 128, 8: 128, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8, 512: 4} 198 | elif gpu_num == 2: 199 | desc += "-2gpu"; submit_config.num_gpus = 2 200 | sched.minibatch_base = 8 201 | sched.minibatch_dict = {4: 256, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16, 256: 8} 202 | elif gpu_num == 4: 203 | desc += "-4gpu"; submit_config.num_gpus = 4 204 | sched.minibatch_base = 16 205 | sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32, 128: 16} 206 | elif gpu_num == 8: 207 | desc += "-8gpu"; submit_config.num_gpus = 8 208 | sched.minibatch_base = 32 209 | sched.minibatch_dict = {4: 512, 8: 256, 16: 128, 32: 64, 64: 32} 210 | else: 211 | print("ERROR: invalid number of gpus:",gpu_num) 212 | sys.exit(-1) 213 | 214 | # Default options. 215 | train.total_kimg = 0 216 | sched.lod_initial_resolution = 8 217 | sched.G_lrate_dict = {128: 0.0015, 256: 0.002, 512: 0.003, 1024: 0.003} 218 | sched.D_lrate_dict = EasyDict(sched.G_lrate_dict) 219 | 220 | # Initialize dnnlib and TensorFlow. 221 | # ctx = dnnlib.RunContext(submit_config, train) 222 | tflib.init_tf(tf_config) 223 | 224 | # Construct networks. 225 | with tf.device('/gpu:0'): 226 | print('Constructing networks...') 227 | dataset_resolution = args.img_size 228 | dataset_channels = 3 # fairly sure everyone is using 3 channels ... 
# training_set.shape[0], 229 | dataset_label_size = 0 # training_set.label_size, 230 | G = tflib.Network('G', 231 | num_channels=dataset_channels, 232 | resolution=dataset_resolution, 233 | label_size=dataset_label_size, 234 | **G) 235 | D = tflib.Network('D', 236 | num_channels=dataset_channels, 237 | resolution=dataset_resolution, 238 | label_size=dataset_label_size, 239 | **D) 240 | Gs = G.clone('Gs') 241 | G.print_layers(); D.print_layers() 242 | 243 | print('Building TensorFlow graph...') 244 | with tf.name_scope('Inputs'), tf.device('/cpu:0'): 245 | lod_in = tf.placeholder(tf.float32, name='lod_in', shape=[]) 246 | lrate_in = tf.placeholder(tf.float32, name='lrate_in', shape=[]) 247 | minibatch_in = tf.placeholder(tf.int32, name='minibatch_in', shape=[]) 248 | minibatch_split = minibatch_in // submit_config.num_gpus 249 | Gs_beta = 0.5 ** tf.div(tf.cast(minibatch_in, tf.float32), 250 | G_smoothing_kimg * 1000.0) if G_smoothing_kimg > 0.0 else 0.0 251 | 252 | 253 | src_d = "discriminator" 254 | dst_d = "D" 255 | src_gs = "generator/g_synthesis" 256 | dst_gs = "G_synthesis" # "G_synthesis_1" <<<< this is handled later 257 | src_gm = "generator/g_mapping" 258 | dst_gm = "G_mapping" # "G_mapping_1" <<<< this is handled later 259 | 260 | 261 | vars = tf.trainable_variables(src_gm) 262 | vars_vals = sess.run(vars) 263 | 264 | 265 | # Copy over the discriminator weights 266 | for (new, old) in zip(tf.trainable_variables(dst_d), tf.trainable_variables(src_d)): 267 | update_weight = [tf.assign(new, old)] 268 | sess.run(update_weight) 269 | temp_vals = sess.run([new, old]) 270 | 271 | # Copy over the Generator's mapping network weights 272 | for (new, old) in zip(tf.trainable_variables(dst_gm), tf.trainable_variables(src_gm)): 273 | update_weight = [tf.assign(new, old)] 274 | sess.run(update_weight) 275 | temp_vals = sess.run([new, old]) 276 | 277 | # Because the two network architectures use slightly different columns on one variable, 278 | # you must set up code to handle the edge case transpose of the first case 279 | first = True 280 | for (new, old) in zip(tf.trainable_variables(dst_gs), tf.trainable_variables(src_gs)): 281 | temp_vals = sess.run([new, old]) 282 | if new.shape != old.shape: 283 | # you need a transpose with perm # old = tf.reshape(old, tf.shape(new)) 284 | # DO NOT USE RESHAPE 285 | # (made this mistake here and the results work but are quite terrifying) 286 | if (first): 287 | first = False 288 | old = tf.transpose(old, perm=[0, 3, 1, 2]) 289 | else: 290 | old = tf.transpose(old, perm=[0, 1, 3, 2]) 291 | update_weight = [tf.assign(new, old)] 292 | sess.run(update_weight) 293 | 294 | # also update the running average network (not 100% sure this is necessary) 295 | dst_gs = "G_synthesis_1" 296 | dst_gm = "G_mapping_1" 297 | for (new, old) in zip(tf.trainable_variables(dst_gm), tf.trainable_variables(src_gm)): 298 | update_weight = [tf.assign(new, old)] 299 | sess.run(update_weight) 300 | temp_vals = sess.run([new, old]) 301 | first = True 302 | for (new, old) in zip(tf.trainable_variables(dst_gs), tf.trainable_variables(src_gs)): 303 | temp_vals = sess.run([new, old]) 304 | if new.shape != old.shape: 305 | # you need a transpose with perm # old = tf.reshape(old, tf.shape(new)) 306 | # DO NOT USE RESHAPE 307 | # (made this mistake here and the results work but are quite terrifying) 308 | if (first): 309 | first = False 310 | old = tf.transpose(old, perm=[0, 3, 1, 2]) 311 | else: 312 | old = tf.transpose(old, perm=[0, 1, 3, 2]) 313 | update_weight = [tf.assign(new, old)] 314 
| sess.run(update_weight) 315 | 316 | # Also, assign the w_avg in the taki0112 network to the NVlabs Gs dlatent_avg 317 | new = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="G") 318 | if "dlatent_avg" in str(x)][0] # G.get_var("dlatent_avg") 319 | old = [x for x in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope="generator") 320 | if "avg" in str(x)][0] 321 | update_weight = [tf.assign(new, old)] 322 | sess.run(update_weight) 323 | vars = [new] 324 | vars_vals = gan.sess.run(vars) 325 | vars_vals = sess.run(vars) 326 | 327 | misc.save_pkl((G, D, Gs), "./"+nvlabs_stylegan_pkl_name) 328 | 329 | if __name__ == "__main__": 330 | main() 331 | --------------------------------------------------------------------------------
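A note on the shape fix-up that appears in both nvlabs_to_taki0112.py and taki0112_to_nvlabs.py above: when a generator variable has a different shape in the two graphs, the comments warn "DO NOT USE RESHAPE" and the scripts use tf.transpose with an explicit perm instead. The reason is that reshape only reinterprets the flat weight buffer, while transpose actually moves values between axes, so only transpose keeps each filter's values together. A minimal NumPy sketch of the difference (the toy axis order here is illustrative, not the exact StyleGAN layout):

```python
import numpy as np

# a toy 4-D "conv weight" with distinct axes, e.g. (height, width, in_ch, out_ch)
w = np.arange(3 * 3 * 2 * 4).reshape(3, 3, 2, 4)

# swapping the channel axes with transpose keeps filter values intact:
# w_t[..., i, j] == w[..., j, i]
w_t = np.transpose(w, (0, 1, 3, 2))   # shape (3, 3, 4, 2)
assert w_t[0, 0, 1, 0] == w[0, 0, 0, 1]

# reshaping to the same target shape just re-chops the flat buffer,
# scrambling which value belongs to which filter
w_r = w.reshape(3, 3, 4, 2)
assert not np.array_equal(w_t, w_r)
```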