├── README.md ├── config.py ├── data.py ├── download.sh ├── model.py ├── roou.py ├── test.py ├── train.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # GalaxyGAN_python 2 | This project is a Python implementation of the paper "Generative Adversarial Networks recover features in astrophysical images of galaxies beyond the deconvolution limit". It is the Python version of https://github.com/SpaceML/GalaxyGAN. This Python version doesn't include the deconvolution part of the paper. 3 | 4 | ## Amazon EC2 Setup 5 | 6 | ### EC2 Public AMI 7 | We provide an EC2 AMI with the following pre-installed packages: 8 | 9 | * CUDA 10 | * cuDNN 11 | * Tensorflow r0.12 12 | * python 13 | 14 | as well as the FITS files we used in the paper (saved in ~/fits_train and ~/fits_test). 15 | 16 | AMI Id: ami-96a97f80 17 | (It can be launched using a p2.xlarge instance in the GPU compute category.) 18 | 19 | [Launch](https://console.aws.amazon.com/ec2/v2/home?region=us-east-1#Images:sort=visibility) an instance. 20 | ### Connect to Amazon EC2 Machine 21 | 22 | Please follow the Amazon EC2 instructions. 23 | 24 | ## Prerequisites 25 | 26 | Linux or OSX 27 | 28 | NVIDIA GPU + CUDA CuDNN (CPU mode and CUDA without CuDNN may work with minimal modification, but untested) 29 | 30 | ## Dependencies 31 | 32 | We need the following python packages: 33 | `tensorflow`, `cv2`, `numpy`, `scipy`, `matplotlib`, `pyfits`, and `ipython` 34 | 35 | ## Get Our Code 36 | Clone this repo: 37 | 38 | ```bash 39 | git clone https://github.com/SpaceML/GalaxyGAN_python.git 40 | cd GalaxyGAN_python/ 41 | ``` 42 | 43 | ## Get Our FITS Files 44 | The download is about 5GB; after unzipping it takes up about 16GB. Download this file from Google Drive: https://drive.google.com/open?id=1GCs02NBnr7X3skA04hyuXh6cUMZQLzVe 45 | 46 | 47 | ## Run Our Code 48 | 49 | 50 | ### Preprocess the .FITS files 51 | If the mode equals zero, the data is used for training.
If the mode equals one, the data is used for testing. 52 | 53 | ```bash 54 | python roou.py --input fitsdata/fits_train --fwhm 1.4 --sig 1.2 --mode 0 55 | python roou.py --input fitsdata/fits_test --fwhm 1.4 --sig 1.2 --mode 1 56 | ``` 57 | XXX is your local address. On our AMI, you can skip this step due to all these have default values. 58 | 59 | 60 | ### Train the model 61 | 62 | If you need, you can modify the constants in the Config.py. 63 | 64 | ```bash 65 | python train.py gpu=1 66 | ``` 67 | You can appoint which gpu to run the code by changing "gpu=1". 68 | 69 | This will start the training process. If you want to load the model which already exists, you can modify the model_path in the config.py. 70 | 71 | ### Test 72 | 73 | Before you try to test your model, you should modify the model path in the config.py. 74 | 75 | ```bash 76 | python test.py gpu=1 77 | ``` 78 | The results can be seen in the folder "test". 79 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | class Config: 2 | data_path = "figures" 3 | model_path_train = "" 4 | model_path_test = "figures/checkpoint/model_20.ckpt" 5 | output_path = "results" 6 | 7 | img_size = 424 8 | adjust_size = 424 9 | train_size = 424 10 | img_channel = 3 11 | conv_channel_base = 64 12 | 13 | learning_rate = 0.0002 14 | beta1 = 0.5 15 | max_epoch = 20 16 | L1_lambda = 100 17 | save_per_epoch=1 18 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | from config import Config as conf 2 | from utils import imread 3 | import os 4 | 5 | def load(path): 6 | for i in os.listdir(path): 7 | all = imread(path + "/" + i) 8 | img, cond = all[:,:conf.img_size], all[:,conf.img_size:] 9 | yield (img, cond, i) 10 | 11 | def load_data(): 12 | data = dict() 13 | data["train"] = 
lambda: load(conf.data_path + "/train") 14 | data["test"] = lambda: load(conf.data_path + "/test") 15 | return data 16 | -------------------------------------------------------------------------------- /download.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget https://share.phys.ethz.ch/~blackhole/spaceml/GalaxyGAN/fitsdata.tar.gz 3 | tar -xvzf fitsdata.tar.gz -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | from config import Config as conf 2 | from utils import conv2d, deconv2d, linear, batch_norm, lrelu 3 | import tensorflow as tf 4 | 5 | class CGAN(object): 6 | 7 | def __init__(self): 8 | self.image = tf.placeholder(tf.float32, shape=(1,conf.img_size, conf.img_size, conf.img_channel)) 9 | self.cond = tf.placeholder(tf.float32, shape=(1,conf.img_size, conf.img_size, conf.img_channel)) 10 | 11 | self.gen_img = self.generator(self.cond) 12 | 13 | pos = self.discriminator(self.image, self.cond, False) 14 | neg = self.discriminator(self.gen_img, self.cond, True) 15 | pos_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=pos, labels=tf.ones_like(pos))) 16 | neg_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.zeros_like(neg))) 17 | 18 | self.d_loss = pos_loss + neg_loss 19 | self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=neg, labels=tf.ones_like(neg))) + \ 20 | conf.L1_lambda * tf.reduce_mean(tf.abs(self.image - self.gen_img)) 21 | 22 | t_vars = tf.trainable_variables() 23 | self.d_vars = [var for var in t_vars if 'disc' in var.name] 24 | self.g_vars = [var for var in t_vars if 'gen' in var.name] 25 | 26 | def discriminator(self, img, cond, reuse): 27 | dim = len(img.get_shape()) 28 | with tf.variable_scope("disc", reuse=reuse): 29 | image = tf.concat([img, cond], dim - 1) 30 | feature = 
conf.conv_channel_base 31 | h0 = lrelu(conv2d(image, feature, name="h0")) 32 | h1 = lrelu(batch_norm(conv2d(h0, feature*2, name="h1"), "h1")) 33 | h2 = lrelu(batch_norm(conv2d(h1, feature*4, name="h2"), "h2")) 34 | h3 = lrelu(batch_norm(conv2d(h2, feature*8, name="h3"), "h3")) 35 | h4 = linear(tf.reshape(h3, [1,-1]), 1, "linear") 36 | return h4 37 | 38 | def generator(self, cond): 39 | with tf.variable_scope("gen"): 40 | feature = conf.conv_channel_base 41 | e1 = conv2d(cond, feature, name="e1") 42 | e2 = batch_norm(conv2d(lrelu(e1), feature*2, name="e2"), "e2") 43 | e3 = batch_norm(conv2d(lrelu(e2), feature*4, name="e3"), "e3") 44 | e4 = batch_norm(conv2d(lrelu(e3), feature*8, name="e4"), "e4") 45 | e5 = batch_norm(conv2d(lrelu(e4), feature*8, name="e5"), "e5") 46 | e6 = batch_norm(conv2d(lrelu(e5), feature*8, name="e6"), "e6") 47 | e7 = batch_norm(conv2d(lrelu(e6), feature*8, name="e7"), "e7") 48 | e8 = batch_norm(conv2d(lrelu(e7), feature*8, name="e8"), "e8") 49 | 50 | size = conf.img_size 51 | num = [0] * 9 52 | for i in range(1,9): 53 | num[9-i]=size 54 | size =(size+1)/2 55 | 56 | d1 = deconv2d(tf.nn.relu(e8), [1,num[1],num[1],feature*8], name="d1") 57 | d1 = tf.concat([tf.nn.dropout(batch_norm(d1, "d1"), 0.5), e7], 3) 58 | d2 = deconv2d(tf.nn.relu(d1), [1,num[2],num[2],feature*8], name="d2") 59 | d2 = tf.concat([tf.nn.dropout(batch_norm(d2, "d2"), 0.5), e6], 3) 60 | d3 = deconv2d(tf.nn.relu(d2), [1,num[3],num[3],feature*8], name="d3") 61 | d3 = tf.concat([tf.nn.dropout(batch_norm(d3, "d3"), 0.5), e5], 3) 62 | d4 = deconv2d(tf.nn.relu(d3), [1,num[4],num[4],feature*8], name="d4") 63 | d4 = tf.concat([batch_norm(d4, "d4"), e4], 3) 64 | d5 = deconv2d(tf.nn.relu(d4), [1,num[5],num[5],feature*4], name="d5") 65 | d5 = tf.concat([batch_norm(d5, "d5"), e3], 3) 66 | d6 = deconv2d(tf.nn.relu(d5), [1,num[6],num[6],feature*2], name="d6") 67 | d6 = tf.concat([batch_norm(d6, "d6"), e2], 3) 68 | d7 = deconv2d(tf.nn.relu(d6), [1,num[7],num[7],feature], name="d7") 69 | d7 = 
tf.concat([batch_norm(d7, "d7"), e1], 3) 70 | d8 = deconv2d(tf.nn.relu(d7), [1,num[8],num[8],conf.img_channel], name="d8") 71 | 72 | return tf.nn.tanh(d8) 73 | -------------------------------------------------------------------------------- /roou.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: UTF-8 -*- 3 | import argparse 4 | import numpy as np 5 | import cv2 6 | import math 7 | import random 8 | from scipy.stats import norm 9 | import matplotlib.pyplot as plt 10 | import os 11 | import pyfits 12 | import glob 13 | from IPython import embed 14 | # mode : 0 training : 1 testing 15 | 16 | parser = argparse.ArgumentParser() 17 | 18 | def fspecial_gauss(size, sigma): 19 | x, y = np.mgrid[-size//2 + 1:size//2 + 1, -size//2 + 1:size//2 + 1] 20 | g = np.exp(-((x**2 + y**2)/(2.0*sigma**2))) 21 | return g/g.sum() 22 | 23 | def adjust(origin): 24 | img = origin.copy() 25 | img[img>4] = 4 26 | img[img < -0.1] = -0.1 27 | MIN = np.min(img) 28 | MAX = np.max(img) 29 | img = np.arcsinh(10*(img - MIN)/(MAX-MIN))/3 30 | return img 31 | 32 | def roou(): 33 | is_demo = 0 34 | 35 | parser.add_argument("--fwhm", default="1.4") 36 | parser.add_argument("--sig", default="1.2") 37 | parser.add_argument("--input", default="/home/ubuntu/GalaxyGAN_python/fits_train") #./fits_0.01_0.02 38 | parser.add_argument("--figure", default="figures") #./figures/test 39 | parser.add_argument("--gpu", default = "1") 40 | parser.add_argument("--model", default = "models") 41 | parser.add_argument("--mode", default="0") 42 | args = parser.parse_args() 43 | 44 | fwhm = float(args.fwhm) 45 | sig = float(args.sig) 46 | input = args.input 47 | figure = args.figure 48 | mode = int(args.mode) 49 | 50 | train_folder = '%s/train'%(args.figure) 51 | test_folder = '%s/test'%(args.figure) 52 | deconv_folder = '%s/deconv'%(args.figure) 53 | 54 | if not os.path.exists('./' + args.figure): 55 | os.makedirs("./" + args.figure) 56 | if not 
os.path.exists("./" + train_folder): 57 | os.makedirs("./" + train_folder) 58 | if not os.path.exists("./" + test_folder): 59 | os.makedirs("./" + test_folder) 60 | #if not os.path.exists("./" + deconv_folder): 61 | #os.makedirs("./"+ deconv_folder) 62 | 63 | fits = '%s/*/*-g.fits'%(input) 64 | files = glob.iglob(fits) 65 | 66 | for i in files: 67 | #print i 68 | file_name = os.path.basename(i) 69 | 70 | #readfiles 71 | if False: 72 | file_name = '587725489986928743' 73 | fwhm = 1.4 74 | sig = 1.2 75 | mode = 1 76 | input_folder='fits_0.01_0.02' 77 | figure_folder='figures' 78 | 79 | filename = file_name.replace("-g.fits", '') 80 | filename_g = '%s/%s/%s-g.fits'%(input,filename,filename) 81 | filename_r = '%s/%s/%s-r.fits'%(input,filename,filename) 82 | filename_i = '%s/%s/%s-i.fits'%(input,filename,filename) 83 | 84 | gfits = pyfits.open(filename_g) 85 | rfits = pyfits.open(filename_r) 86 | ifits = pyfits.open(filename_i) 87 | data_g = gfits[0].data 88 | data_r = rfits[0].data 89 | data_i = ifits[0].data 90 | 91 | figure_original = np.ones((data_g.shape[0],data_g.shape[1],3)) 92 | figure_original[:,:,0] = data_g 93 | figure_original[:,:,1] = data_r 94 | figure_original[:,:,2] = data_i 95 | 96 | #print figure_original 97 | 98 | if is_demo: 99 | cv2.imshow("img", adjust(figure_original)) 100 | cv2.waitKey(0) 101 | 102 | # gaussian filter 103 | fwhm_use = fwhm/0.396 104 | gaussian_sigma = fwhm_use / 2.355 105 | 106 | # with problem 107 | figure_blurred = cv2.GaussianBlur(figure_original, (5,5), gaussian_sigma) 108 | 109 | #print "IIIIII" 110 | 111 | if is_demo: 112 | cv2.imshow("i", figure_blurred) 113 | cv2.waitKey(0) 114 | 115 | # add white noise 116 | figure_original_nz = figure_original[figure_original<0.1] 117 | figure_original_nearzero = figure_original_nz[figure_original_nz>-0.1] 118 | figure_blurred_nz = figure_blurred[figure_blurred<0.1] 119 | figure_blurred_nearzero = figure_blurred_nz[figure_blurred_nz>-0.1] 120 | [m,s] = norm.fit(figure_original_nearzero) 
121 | [m2,s2] = norm.fit(figure_blurred_nearzero) 122 | 123 | whitenoise_var = (sig*s)**2-s2**2 124 | 125 | if whitenoise_var < 0: 126 | whitenoise_var = 0.00000001 127 | 128 | whitenoise = np.random.normal(0, np.sqrt(whitenoise_var) , (data_g.shape[0], data_g.shape[1])) 129 | 130 | figure_blurred[:,:,0] = figure_blurred[:,:,0] + whitenoise 131 | figure_blurred[:,:,1] = figure_blurred[:,:,1] + whitenoise 132 | figure_blurred[:,:,2] = figure_blurred[:,:,2] + whitenoise 133 | 134 | if is_demo: 135 | cv2.imshow('k',figure_blurred) 136 | cv2.waitKey(0) 137 | 138 | # deconvolution 139 | if mode>2: 140 | hsize = 2*np.ceil(2*gaussian_sigma)+1 141 | PSF = fspecial_gauss(hsize, gaussian_sigma) 142 | #figure_deconv = deconvblind(figure_blurred, PSF) 143 | if is_demo: 144 | cv2.imshow(figure_deconv) 145 | cv2.waitKey(0) 146 | 147 | # thresold 148 | MAX = 4 149 | MIN = -0.1 150 | 151 | figure_original[figure_originalMAX]=MAX 153 | 154 | figure_blurred[figure_blurredMAX]=MAX 156 | if mode>2: 157 | figure_deconv[figure_deconvMAX]=MAX 159 | 160 | # normalize figures 161 | figure_original = (figure_original-MIN)/(MAX-MIN) 162 | figure_blurred = (figure_blurred-MIN)/(MAX-MIN) 163 | 164 | '''if mode: 165 | figure_deconv = (figure_deconv-MIN)/(MAX-MIN) 166 | if is_demo: 167 | plt.subplot(1,3,1), plt.subimage(figure_original), plt.subplot(1,3,2), plt.subimage(figure_blurred),plt.subplot(1,3,3), plt.subimage(figure_deconv) 168 | elif is_demo: 169 | plt.subplot(1,2,1), plt.subimage(figure_original), plt.subplot(1,2,2), plt.subimage(figure_blurred) 170 | ''' 171 | 172 | # asinh scaling 173 | figure_original = np.arcsinh(10*figure_original)/3 174 | figure_blurred = np.arcsinh(10*figure_blurred)/3 175 | 176 | if mode>2: 177 | figure_deconv = np.arcsinh(10*figure_deconv)/3 178 | 179 | # output result to pix2pix format 180 | figure_combined = np.zeros((data_g.shape[0], data_g.shape[1]*2,3)) 181 | figure_combined[:,: data_g.shape[1],:] = figure_original[:,:,:] 182 | figure_combined[:, 
data_g.shape[1]:2*data_g.shape[1],:] = figure_blurred[:,:,:] 183 | 184 | if is_demo: 185 | cv2.imshow(figure_combined) 186 | cv2.waitKey(0) 187 | 188 | if mode: 189 | jpg_path = '%s/test/%s.jpg'% (figure,filename) 190 | else: 191 | jpg_path = '%s/train/%s.jpg'% (figure,filename) 192 | 193 | if mode == 2: 194 | figure_combined_no_ori = np.zeros(data_g.shape[0], data_g.shape[1]*2,3) 195 | figure_combined_no_ori[:, data_g.shape[1]:2*data_g.shape[1],:] = figure_blurred[:,:,:] 196 | cv2.imwrite(figure_combined_no_ori,jpg_path) 197 | else: 198 | image = (figure_combined*256).astype(np.int) 199 | cv2.imwrite(jpg_path, image) 200 | 201 | if mode>2: 202 | deconv_path = '%s/deconv/%s.jpg'% (figure_folder,filename) 203 | cv2.imwrite(figure_deconv,deconv_path) 204 | 205 | roou() 206 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | from config import Config as conf 2 | from data import * 3 | import scipy.misc 4 | from model import CGAN 5 | from utils import imsave 6 | import tensorflow as tf 7 | import numpy as np 8 | import time 9 | import sys 10 | from IPython import embed 11 | 12 | def prepocess_test(img, cond): 13 | 14 | img = scipy.misc.imresize(img, [conf.train_size, conf.train_size]) 15 | cond = scipy.misc.imresize(cond, [conf.train_size, conf.train_size]) 16 | img = img.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 17 | cond = cond.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 18 | img = img/127.5 - 1. 19 | cond = cond/127.5 - 1. 
20 | return img,cond 21 | 22 | def test(): 23 | 24 | if not os.path.exists("test"): 25 | os.makedirs("test") 26 | data = load_data() 27 | model = CGAN() 28 | 29 | d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.d_loss, var_list=model.d_vars) 30 | g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.g_loss, var_list=model.g_vars) 31 | 32 | saver = tf.train.Saver() 33 | 34 | counter = 0 35 | start_time = time.time() 36 | 37 | with tf.Session() as sess: 38 | saver.restore(sess, conf.model_path_test) 39 | test_data = data["test"]() 40 | for img, cond, name in test_data: 41 | pimg, pcond = prepocess_test(img, cond) 42 | gen_img = sess.run(model.gen_img, feed_dict={model.image:pimg, model.cond:pcond}) 43 | gen_img = gen_img.reshape(gen_img.shape[1:]) 44 | gen_img = (gen_img + 1.) * 127.5 45 | image = np.concatenate((gen_img, cond), axis=1).astype(np.int) 46 | imsave(image, "./test" + "/%s" % name) 47 | 48 | if __name__ == "__main__": 49 | if len(sys.argv) > 1 and sys.argv[1] == 'gpu=': 50 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 51 | os.environ["CUDA_VISIBLE_DEVICES"]=str(sys.argv[1][4:]) 52 | else: 53 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 54 | os.environ["CUDA_VISIBLE_DEVICES"]=str(0) 55 | test() 56 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from config import Config as conf 2 | from data import * 3 | import scipy.misc 4 | from model import CGAN 5 | from utils import imsave 6 | import tensorflow as tf 7 | import numpy as np 8 | import time 9 | import sys 10 | 11 | def prepocess_train(img, cond,): 12 | img = scipy.misc.imresize(img, [conf.adjust_size, conf.adjust_size]) 13 | cond = scipy.misc.imresize(cond, [conf.adjust_size, conf.adjust_size]) 14 | h1 = int(np.ceil(np.random.uniform(1e-2, conf.adjust_size - conf.train_size))) 15 | 
w1 = int(np.ceil(np.random.uniform(1e-2, conf.adjust_size - conf.train_size))) 16 | img = img[h1:h1 + conf.train_size, w1:w1 + conf.train_size] 17 | cond = cond[h1:h1 + conf.train_size, w1:w1 + conf.train_size] 18 | if np.random.random() > 0.5: 19 | img = np.fliplr(img) 20 | cond = np.fliplr(cond) 21 | img = img/127.5 - 1. 22 | cond = cond/127.5 - 1. 23 | img = img.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 24 | cond = cond.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 25 | return img,cond 26 | 27 | def prepocess_test(img, cond): 28 | 29 | img = scipy.misc.imresize(img, [conf.train_size, conf.train_size]) 30 | cond = scipy.misc.imresize(cond, [conf.train_size, conf.train_size]) 31 | img = img.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 32 | cond = cond.reshape(1, conf.img_size, conf.img_size, conf.img_channel) 33 | img = img/127.5 - 1. 34 | cond = cond/127.5 - 1. 35 | return img,cond 36 | 37 | def train(): 38 | data = load_data() 39 | model = CGAN() 40 | 41 | d_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.d_loss, var_list=model.d_vars) 42 | g_opt = tf.train.AdamOptimizer(learning_rate=conf.learning_rate).minimize(model.g_loss, var_list=model.g_vars) 43 | 44 | saver = tf.train.Saver() 45 | 46 | start_time = time.time() 47 | if not os.path.exists(conf.data_path + "/checkpoint"): 48 | os.makedirs(conf.data_path + "/checkpoint") 49 | if not os.path.exists(conf.output_path): 50 | os.makedirs(conf.output_path) 51 | 52 | config = tf.ConfigProto() 53 | config.gpu_options.allow_growth = True 54 | with tf.Session(config=config) as sess: 55 | if conf.model_path_train == "": 56 | sess.run(tf.global_variables_initializer()) 57 | else: 58 | saver.restore(sess, conf.model_path_train) 59 | for epoch in xrange(conf.max_epoch): 60 | counter = 0 61 | train_data = data["train"]() 62 | for img, cond, name in train_data: 63 | img, cond = prepocess_train(img, cond) 64 | _, m = sess.run([d_opt, model.d_loss], 
feed_dict={model.image:img, model.cond:cond}) 65 | _, m = sess.run([d_opt, model.d_loss], feed_dict={model.image:img, model.cond:cond}) 66 | _, M = sess.run([g_opt, model.g_loss], feed_dict={model.image:img, model.cond:cond}) 67 | counter += 1 68 | if counter % 50 ==0: 69 | print "Epoch [%d], Iteration [%d]: time: %4.4f, d_loss: %.8f, g_loss: %.8f" \ 70 | % (epoch, counter, time.time() - start_time, m, M) 71 | if (epoch + 1) % conf.save_per_epoch == 0: 72 | save_path = saver.save(sess, conf.data_path + "/checkpoint/" + "model_%d.ckpt" % (epoch+1)) 73 | print "Model saved in file: %s" % save_path 74 | test_data = data["test"]() 75 | for img, cond, name in test_data: 76 | pimg, pcond = prepocess_test(img, cond) 77 | gen_img = sess.run(model.gen_img, feed_dict={model.image:pimg, model.cond:pcond}) 78 | gen_img = gen_img.reshape(gen_img.shape[1:]) 79 | gen_img = (gen_img + 1.) * 127.5 80 | image = np.concatenate((gen_img, cond), axis=1).astype(np.int) 81 | imsave(image, conf.output_path + "/%s" % name) 82 | 83 | if __name__ == "__main__": 84 | if len(sys.argv) > 1 and sys.argv[1] == 'gpu=': 85 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 86 | os.environ["CUDA_VISIBLE_DEVICES"]=str(sys.argv[1][4:]) 87 | else: 88 | os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID" # see issue #152 89 | os.environ["CUDA_VISIBLE_DEVICES"]=str(0) 90 | train() 91 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import tensorflow as tf 3 | import scipy.misc 4 | 5 | def batch_norm(x, scope): 6 | return tf.contrib.layers.batch_norm(x, decay=0.9, updates_collections=None, epsilon=1e-5, scale=True, scope=scope) 7 | 8 | def conv2d(input, output_dim, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="conv2d"): 9 | with tf.variable_scope(name): 10 | weight = tf.get_variable('weight', [k_h, k_w, input.get_shape()[-1], output_dim], 11 | 
initializer=tf.truncated_normal_initializer(stddev=stddev)) 12 | bias = tf.get_variable('bias', [output_dim], initializer=tf.constant_initializer(0.0)) 13 | conv = tf.nn.bias_add(tf.nn.conv2d(input, weight, strides=[1, d_h, d_w, 1], padding='SAME'), bias) 14 | return conv 15 | 16 | def deconv2d(input, output_shape, k_h=4, k_w=4, d_h=2, d_w=2, stddev=0.02, name="deconv2d"): 17 | with tf.variable_scope(name): 18 | weight = tf.get_variable('weight', [k_h, k_w, output_shape[-1], input.get_shape()[-1]], 19 | initializer=tf.random_normal_initializer(stddev=stddev)) 20 | bias = tf.get_variable('bias', [output_shape[-1]], initializer=tf.constant_initializer(0.0)) 21 | deconv = tf.nn.bias_add(tf.nn.conv2d_transpose(input, weight, output_shape=output_shape, strides=[1, d_h, d_w, 1]), bias) 22 | return deconv 23 | 24 | def lrelu(x, leak=0.2): 25 | return tf.maximum(x, leak * x) 26 | 27 | def linear(input, output_size, scope=None, stddev=0.02, bias_start=0.0): 28 | shape = input.get_shape().as_list() 29 | with tf.variable_scope(scope or "Linear"): 30 | weight = tf.get_variable("weight", [shape[1], output_size], tf.float32, 31 | tf.random_normal_initializer(stddev=stddev)) 32 | bias = tf.get_variable("bias", [output_size], 33 | initializer=tf.constant_initializer(bias_start)) 34 | return tf.matmul(input, weight) + bias 35 | 36 | def imread(path): 37 | return scipy.misc.imread(path) 38 | 39 | def imsave(image, path): 40 | return scipy.misc.imsave(path, image) 41 | --------------------------------------------------------------------------------