├── LICENSE
├── README.md
├── assets
    ├── Fig5-1.png
    ├── cglo-1.png
    └── overview.png
├── config.py
├── data_loader.py
├── download.py
├── main.py
├── models.py
├── requirements.txt
├── trainer.py
└── utils.py


/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Mixed-Bg-Fg-Synthesis via Conditional GLO
Official implementation of [Stingray Detection of Aerial Images Using Augmented Training Images Generated by A Conditional Generative Model](https://arxiv.org/abs/1805.04262)

Created by [Yi-Min Chou](https://github.com/yyyjoe), Chien-Hung Chen, Keng-Hao Liu, Chu-Song Chen

## Overview of data augmentation

![alt tag](./assets/overview.png)

1. Crop object patches (rotating and flipping them to augment the data) and randomly crop background patches to build the training set for conditional GLO.
2. Train conditional GLO and use the trained model to generate fake stingray images.
3. Paste the generated stingray patches back at their original positions.
4. Use the augmented data generated by C-GLO to train detection models.
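
Steps 1 and 3 come down to array slicing. The sketch below illustrates them with NumPy. It is not code from this repository: the helper names `augment_patch`, `random_background_crop` and `paste_patch` are hypothetical, and the sketch assumes each generated patch has already been resized to the size of the crop it replaces (the C-GLO generator in this repository outputs 64 x 64 images).

```python
import numpy as np


def augment_patch(patch):
    """Return the 8 rotation/flip variants of an H x W x C patch (hypothetical helper)."""
    variants = []
    for k in range(4):
        rotated = np.rot90(patch, k)  # rotate by k * 90 degrees in the (H, W) plane
        variants.append(rotated)
        variants.append(np.fliplr(rotated))  # mirror left-right
    return variants


def random_background_crop(image, size, rng):
    """Crop a size x size patch at a random position and return it with its top-left corner."""
    h, w = image.shape[:2]
    top = rng.randint(0, h - size + 1)  # upper bound is exclusive
    left = rng.randint(0, w - size + 1)
    return image[top:top + size, left:left + size], (top, left)


def paste_patch(image, patch, top, left):
    """Return a copy of `image` with `patch` written back at (top, left)."""
    out = image.copy()
    ph, pw = patch.shape[:2]
    out[top:top + ph, left:left + pw] = patch
    return out


# Usage: crop a background patch, replace it with a stand-in "generated" patch,
# and paste it back at the same position.
rng = np.random.RandomState(123)
scene = np.zeros((100, 120, 3), dtype=np.uint8)
bg_patch, (top, left) = random_background_crop(scene, 64, rng)
fake = np.full_like(bg_patch, 255)  # stand-in for a C-GLO output
augmented = paste_patch(scene, fake, top, left)
print(len(augment_patch(fake)), augmented[top:top + 64, left:left + 64].min())
```

In the real pipeline, `fake` would come from the trained generator applied to a background patch with its condition switched. Here it is a constant array, so only the geometry is exercised.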

## Conditional-Generative-Latent-Optimization

![alt tag](./assets/cglo-1.png)

## Prerequisites
- Python 2.7
- [TensorFlow 1.x (1.4.0 or higher)](https://github.com/tensorflow/tensorflow)

## How to Run

- Clone the Mixed-Bg-Fg-Synthesis repository:
```bash
$ git clone --recursive https://github.com/ivclab/ConditionalGLO.git
```

- Install the required packages:
```bash
$ pip install -r requirements.txt
```

- Download the stingray data:
```bash
$ python download.py
```

- Run the training code:
```bash
# Training results are saved in `./logs/FOLDER_NAME/`
$ python main.py --is_train=True
```

- Run the testing code:
```bash
# Test results are saved in `./logs/FOLDER_NAME_test/`
$ python main.py --is_train=False --load_path=FOLDER_NAME
```

## Experimental Results
Original background images (top) and mixed background-foreground images synthesized by C-GLO (bottom):
![alt tag](./assets/Fig5-1.png)


## Citation
Please cite the following papers if this code helps your research:

    @inproceedings{chou2018stingray,
      title={Stingray Detection of Aerial Images Using Augmented Training Images Generated by a Conditional Generative Model},
      author={Chou, Yi-Min and Chen, Chien-Hung and Liu, Keng-Hao and Chen, Chu-Song},
      booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops},
      pages={1403--1409},
      year={2018}
    }

    @inproceedings{
      title = {Changing Background to Foreground: An Augmentation Method Based on Conditional Generative Network for Stingray Detection},
      Author = {Chou, Yi-Min and Chen, Chien-Hung and Liu, Keng-Hao and Chen, Chu-Song},
      booktitle = {IEEE International Conference on Image Processing, ICIP},
      year = {2018}
    }

## Contact
Please feel free to send suggestions or comments to [Yi-Min Chou](https://github.com/yyyjoe), Chien-Hung Chen (redsword26@iis.sinica.edu.tw), Keng-Hao Liu (keng3@mail.nsysu.edu.tw), Chu-Song Chen (song@iis.sinica.edu.tw).

--------------------------------------------------------------------------------
/assets/Fig5-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivclab/ConditionalGLO/0c9f6c95e479884fab4e04505615e24f7eec6185/assets/Fig5-1.png

--------------------------------------------------------------------------------
/assets/cglo-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivclab/ConditionalGLO/0c9f6c95e479884fab4e04505615e24f7eec6185/assets/cglo-1.png

--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ivclab/ConditionalGLO/0c9f6c95e479884fab4e04505615e24f7eec6185/assets/overview.png

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# MIT License
#
# Modifications copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
import argparse

def str2bool(v):
    return v.lower() in ('true', '1')

arg_lists = []
parser = argparse.ArgumentParser()

def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg

# Network
net_arg = add_argument_group('Network')
net_arg.add_argument('--input_scale_size', type=int, default=64,
                     help='input image will be resized with the given value as width and height')
net_arg.add_argument('--z_dim', type=int, default=128)

# Data
data_arg = add_argument_group('Data')
data_arg.add_argument('--dataset', type=str, default='marine')
data_arg.add_argument('--split', type=str, default='train')
data_arg.add_argument('--batch_size', type=int, default=16)


# Training / test parameters
train_arg = add_argument_group('Training')
train_arg.add_argument('--is_train', type=str2bool, default=True)
train_arg.add_argument('--max_step', type=int, default=500000)
train_arg.add_argument('--lr_update_step', type=int, default=200000)
train_arg.add_argument('--lr_lower_boundary', type=float, default=0.000001)
train_arg.add_argument('--z_lr', type=float, default=0.0008)
train_arg.add_argument('--g_lr', type=float, default=0.00008)
train_arg.add_argument('--beta1', type=float, default=0.5)
train_arg.add_argument('--beta2', type=float, default=0.999)


# Misc
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--load_path', type=str, default='')
misc_arg.add_argument('--log_step', type=int, default=100)
misc_arg.add_argument('--log_dir', type=str, default='logs')
misc_arg.add_argument('--data_dir', type=str, default='data')
misc_arg.add_argument('--random_seed', type=int, default=123)
misc_arg.add_argument('--p_data_dir', type=str, default='./data/string_data/training_p')
misc_arg.add_argument('--n_data_dir', type=str, default='./data/string_data/n_patches')

def get_config():
    config, unparsed = parser.parse_known_args()
    data_format = 'NCHW'
    setattr(config, 'data_format', data_format)
    return config, unparsed

--------------------------------------------------------------------------------
/data_loader.py:
--------------------------------------------------------------------------------
# MIT License
#
# Copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
import numpy as np
import tensorflow as tf
# Note: scipy.misc.imresize was removed in SciPy 1.3, so an older SciPy (with Pillow) is required.
from scipy.misc import imresize
import os
import skimage.io as io
import warnings
from tqdm import trange
warnings.simplefilter(action='ignore', category=FutureWarning)

def data_loader(config):
    dir_path_p = config.p_data_dir
    dir_p = os.listdir(dir_path_p)
    dir_p.sort()

    dir_path_n = config.n_data_dir
    dir_n = os.listdir(dir_path_n)
    dir_n.sort()

    p_num = len(dir_p) #30496
    n_num = len(dir_n) #7664
    size = config.input_scale_size
    img = []
    label = []
    print('Loading data...')
    # Positive (stingray) patches come first and get label 1.
    print('Loading positive data')
    for i in trange(p_num):
        data = io.imread(dir_path_p + os.sep + dir_p[i])
        data = imresize(data,[size,size,data.shape[2]])
        img.append(data)
        label.append(1)

    # Negative (background) patches follow and get label 0.
    print('Loading negative data')
    for i in trange(n_num):
        data = io.imread(dir_path_n + os.sep + dir_n[i])
        data = imresize(data,[size,size,data.shape[2]])
        img.append(data)
        label.append(0)

    label = np.array(label,dtype = np.float32).reshape([-1,1])
    img = nhwc_to_nchw(np.array(img,dtype = np.float32))

    return img,label,p_num,n_num



def nhwc_to_nchw(x):
    return np.transpose(x, [0, 3, 1, 2])

--------------------------------------------------------------------------------
/download.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# MIT License
#
# Modifications copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
import requests
import zipfile
import os
def download_file_from_google_drive(id, destination):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params = { 'id' : id }, stream = True)
    token = get_confirm_token(response)

    if token:
        params = { 'id' : id, 'confirm' : token }
        response = session.get(URL, params = params, stream = True)

    save_response_content(response, destination)

def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk: # filter out keep-alive new chunks
                f.write(chunk)

if __name__ == "__main__":
    path = './data'
    if not os.path.exists(path):
        os.makedirs(path)

    path = './logs'
    if not os.path.exists(path):
        os.makedirs(path)

    file_id = '1sCkWbk3RypXNIZqGaSRwb8BfYDYG6_Do'
    destination = './data/data.zip'
    download_file_from_google_drive(file_id, destination)
    with zipfile.ZipFile(destination) as zf:
        zip_dir = zf.namelist()[0]
        zf.extractall('./data')
    os.remove(destination)

--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# MIT License
#
# Modifications copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
import numpy as np
import tensorflow as tf
from trainer import Trainer
from config import get_config
from utils import prepare_dirs_and_logger, save_config

def main(config):
    prepare_dirs_and_logger(config)

    rng = np.random.RandomState(config.random_seed)
    tf.set_random_seed(config.random_seed)

    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size

    else:
        setattr(config, 'batch_size', 1)


    trainer = Trainer(config)

    if config.is_train:
        save_config(config)
        trainer.train()
    else:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()

if __name__ == "__main__":
    config, unparsed = get_config()
    main(config)

--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# MIT License
#
# Modifications copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim

def Glo_Generator(x,condition,batch_size,z_dim,reuse):

    x_dim = z_dim
    condition_dim = 1


    with tf.variable_scope('G1',reuse=reuse) as g1:
        weights_g1 = {

            "w1_g1" : tf.get_variable("w1_g1",[x_dim + condition_dim , 4*4*1024]), #4*4*1024
            "w2_g1" : tf.get_variable("w2_g1",[5, 5, 512, 1024]), #8*8*512
            "w3_g1" : tf.get_variable("w3_g1",[5, 5, 256, 512]), #16*16*256
            "w4_g1" : tf.get_variable("w4_g1",[5, 5, 128, 256]), #32*32*128
            "w5_g1" : tf.get_variable("w5_g1",[5, 5, 3, 128]), #64*64*3
        }

        biases_g1 = {

            "b1_g1" : tf.get_variable("b1_g1", [4*4*1024]),
            "b2_g1" : tf.get_variable("b2_g1", [512]),
            "b3_g1" : tf.get_variable("b3_g1", [256]),
            "b4_g1" : tf.get_variable("b4_g1", [128]),
            "b5_g1" : tf.get_variable("b5_g1", [3]),
        }


        # Concatenate latent code and condition, then project: (batch_size, 4*4*1024)
        g1_out1 = tf.add(tf.matmul(tf.concat([x,condition],1), weights_g1["w1_g1"]), biases_g1["b1_g1"])
        g1_out1 = tf.reshape(g1_out1,[batch_size,4,4,1024])

        output_shape_g2 = tf.stack([batch_size, 8, 8, 512])
        g1_out2 = tf.nn.relu(slim.batch_norm(tf.add(deconv2d(g1_out1, weights_g1["w2_g1"], output_shape_g2), biases_g1["b2_g1"])))

        output_shape_g3 = tf.stack([batch_size, 16, 16, 256])
        g1_out3 = tf.nn.relu(slim.batch_norm(tf.add(deconv2d(g1_out2, weights_g1["w3_g1"], output_shape_g3), biases_g1["b3_g1"])))

        output_shape_g4 = tf.stack([batch_size, 32, 32, 128])
        g1_out4 = tf.nn.relu(slim.batch_norm(tf.add(deconv2d(g1_out3, weights_g1["w4_g1"], output_shape_g4), biases_g1["b4_g1"])))

        output_shape_g5 = tf.stack([batch_size, 64, 64, 3])
        g1_out5 = tf.nn.tanh(tf.add(deconv2d(g1_out4, weights_g1["w5_g1"], output_shape_g5), biases_g1["b5_g1"]))
        # NHWC -> NCHW, matching the layout of the real images
        g1_out5 = tf.transpose(g1_out5,[0,3,1,2])

    variables_g1 = tf.contrib.framework.get_trainable_variables(g1)
    return g1_out5,variables_g1

def deconv2d(x, W, output_shape):
    return tf.nn.conv2d_transpose(x, W, output_shape, strides = [1, 2, 2, 1], padding = 'SAME')


def nchw_to_nhwc(x):
    return tf.transpose(x, [0, 2, 3, 1])

def nhwc_to_nchw(x):
    return tf.transpose(x, [0, 3, 1, 2])


--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
Pillow==3.1.2
scikit-image==0.10.1
tqdm==4.11.2
requests==2.18.1


--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# MIT License
#
# Modifications copyright (c) 2018 Image & Vision Computing Lab, Institute of Information Science, Academia Sinica
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
# ==============================================================================
from __future__ import print_function

import os
import numpy as np
from models import *
from utils import save_image
from tqdm import trange
from data_loader import data_loader

def to_nhwc(image, data_format):
    if data_format == 'NCHW':
        new_image = nchw_to_nhwc(image)
    else:
        new_image = image
    return new_image

def norm_img(image, data_format=None):
    image = image/127.5 - 1.
    if data_format:
        image = to_nhwc(image, data_format)
    return image

def denorm_img(norm, data_format):
    return tf.clip_by_value(to_nhwc((norm + 1)*127.5, data_format), 0, 255)

def switch_condition(condition):
    # Flips the conditional label. Algebraically this is 1.5 - condition,
    # so 0 (background) maps to 1.5 and 1 (stingray) maps to 0.5.
    out = (-1*((2*condition)-1)+2)/2
    return out

class Trainer(object):
    def __init__(self, config):
        self.config = config
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.batch_size = config.batch_size

        self.step = tf.Variable(0, name='step', trainable=False)
        self.g_lr = tf.Variable(config.g_lr, name='g_lr')
        self.z_lr = tf.Variable(config.z_lr, name='d_lr')
        self.g_lr_update = tf.assign(self.g_lr, tf.maximum(self.g_lr * 0.5, config.lr_lower_boundary), name='g_lr_update')
        self.z_lr_update = tf.assign(self.z_lr, tf.maximum(self.z_lr * 0.5, config.lr_lower_boundary), name='z_lr_update')
        self.img,self.condition,self.p_num,self.n_num = data_loader(config)
        self.data_num = self.p_num + self.n_num
        self.z_dim = config.z_dim
        self.model_dir = config.model_dir
        self.load_path = config.load_path
        self.data_format = config.data_format
        self.start_step = 0
        self.log_step = config.log_step
        self.max_step = config.max_step
        self.lr_update_step = config.lr_update_step
        self.build_model()
        self.saver = tf.train.Saver()
        self.summary_writer = tf.summary.FileWriter(self.model_dir)
        sv = tf.train.Supervisor(logdir=self.model_dir,
                                 is_chief=True,
                                 saver=self.saver,
                                 summary_op=None,
                                 summary_writer=self.summary_writer,
                                 save_model_secs=300,
                                 global_step=self.step,
                                 ready_for_local_init_op=None)

        gpu_options = tf.GPUOptions(allow_growth=True)
        sess_config = tf.ConfigProto(allow_soft_placement=True,
                                     gpu_options=gpu_options)

        self.sess = sv.prepare_or_wait_for_session(config=sess_config)

    def train(self):

        for step in trange(self.start_step, self.max_step):
            # Re-normalize all latent codes roughly once per epoch
            # (floor division keeps this integer under both Python 2 and 3).
            if step%(self.data_num//self.batch_size)==0:
                self.sess.run(self.update_latent)

            batch = np.random.randint(0,self.data_num,size=[self.batch_size])
            self.sess.run(self.g_optim,feed_dict={self.INDEX:batch,self.X_REAL:self.img[batch],self.CONDITION:self.condition[batch]})
            self.sess.run(self.z_optim,feed_dict={self.INDEX:batch,self.X_REAL:self.img[batch],self.CONDITION:self.condition[batch]})

            if step % self.log_step == 0 :
                loss,result = self.sess.run([self.loss,self.summary_op],feed_dict={self.INDEX:batch,self.X_REAL:self.img[batch],self.CONDITION:self.condition[batch]})
                print("[{:6d}/{}] Loss: {:.6f} ". \
                      format(step, self.max_step,loss))
                self.summary_writer.add_summary(result, step)
                self.summary_writer.flush()


            if step % (self.log_step * 100 ) == 0:
                x = self.generate(batch,self.model_dir, idx=step)
                x_real = self.generate_real(batch,self.model_dir, idx=step)

            if step % self.lr_update_step == self.lr_update_step - 1:
                self.sess.run([self.g_lr_update, self.z_lr_update])


    def build_model(self):
        self.INDEX = tf.placeholder(tf.int32,[self.batch_size])
        self.X_REAL = tf.placeholder(tf.float32,shape=(self.batch_size,3,self.config.input_scale_size,self.config.input_scale_size))
        x_real = norm_img(self.X_REAL)

        scale = tf.sqrt(float(self.z_dim))
        epsilon = 1e-07
        # One learnable latent code per training image (GLO).
        self.latent_z = tf.Variable(tf.random_normal([self.data_num,self.z_dim])/scale, dtype=tf.float32)
        # Normalize all codes jointly to zero mean and standard deviation 1/sqrt(z_dim).
        mean,variance = tf.nn.moments(self.latent_z,[0,1])
        self.update_latent = tf.assign(self.latent_z,tf.nn.batch_normalization(self.latent_z, mean, variance, 0, 1/scale, epsilon))

        self.CONDITION = tf.placeholder(tf.float32,[None,1])
        look_up = tf.gather(self.latent_z,self.INDEX,axis=0)
        G,self.G_var = Glo_Generator(look_up,self.CONDITION,self.batch_size,self.z_dim,reuse=False)

        optimizer = tf.train.AdamOptimizer
        g_optimizer, z_optimizer = optimizer(self.g_lr,beta1=self.beta1,beta2=self.beta2), optimizer(self.z_lr,beta1=self.beta1,beta2=self.beta2)

        # L1 reconstruction loss, minimized separately over generator weights and latent codes.
        self.loss = tf.reduce_mean(tf.abs(x_real - G))

        self.g_optim = g_optimizer.minimize(self.loss, var_list=self.G_var)
        self.z_optim = z_optimizer.minimize(self.loss, var_list=self.latent_z)

        self.G = denorm_img(G, self.data_format)
        self.y = denorm_img(x_real, self.data_format)

        self.summary_op = tf.summary.merge([
            tf.summary.scalar('loss', self.loss)
        ])

    def test(self):
        dir_path_n = self.config.n_data_dir
        dir_n = os.listdir(dir_path_n)
        dir_n.sort()
        # Strip the '.jpg' extension so each output reuses its background filename.
        for i in range(len(dir_n)):
            dir_n[i]=dir_n[i].split('.jpg')
        path = self.model_dir+'_test/'
        if not os.path.exists(path):
            os.makedirs(path)

        # Negative (background) patches follow the positive ones in self.img;
        # floor division keeps the indices integral under both Python 2 and 3.
        for step in range(self.p_num//self.batch_size, self.data_num//self.batch_size):
            batch = np.arange(self.batch_size) + (self.batch_size*step)
            x_condi = self.generate_condition(batch,path, idx=step,is_train=False,name=dir_n[step-self.p_num//self.batch_size][0])
            x_real = self.generate_real(batch,path, idx=step,is_train=False,name=dir_n[step-self.p_num//self.batch_size][0])

    def generate(self,input,root_path=None, path=None, idx=None, is_train=True,name=None):

        x = self.sess.run(self.G,feed_dict={self.CONDITION:self.condition[input],self.INDEX:input})
        if path is None and is_train:
            path = os.path.join(root_path, '{}_G.png'.format(idx))
            save_image(x, path)
            print("[*] Samples saved: {}".format(path))
        return x

    def generate_condition(self,input,root_path=None, path=None, idx=None, is_train=True,name=None):
        # change conditional label
        x = self.sess.run(self.G,feed_dict={self.CONDITION:switch_condition(self.condition[input]),self.INDEX:input})
        if path is None and is_train:
            path = os.path.join(root_path, '{}_G_condition.png'.format(idx))
            save_image(x, path)
            print("[*] Samples saved: {}".format(path))
        else:
            path = os.path.join(root_path, '{}.png'.format(name))
            save_image(x, path, nrow=1, padding=0,is_train=is_train)
            print("[*] Samples saved: {}".format(path))
        return x


    def generate_real(self,input, root_path=None, path=None, idx=None, is_train=True,name=None):
        x = self.sess.run(self.y,feed_dict={self.X_REAL:self.img[input]})
        if path is None and is_train:
            path = os.path.join(root_path, '{}_G_real.png'.format(idx))
            save_image(x, path)
            print("[*] Samples saved: {}".format(path))
        else:
            path = os.path.join(root_path, '{}_real.png'.format(name))
            save_image(x, path, nrow=1, padding=0,is_train=is_train)
            print("[*] Samples saved: {}".format(path))
        return x


--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
# Copyright 2017 The BEGAN-tensorflow Authors(Taehoon Kim). All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from __future__ import print_function

import os
import math
import json
import logging
import numpy as np
from PIL import Image
from datetime import datetime

def prepare_dirs_and_logger(config):
    formatter = logging.Formatter("%(asctime)s:%(levelname)s::%(message)s")
    logger = logging.getLogger()

    # Iterate over a copy: removing handlers from the list being iterated skips some.
    for hdlr in list(logger.handlers):
        logger.removeHandler(hdlr)

    handler = logging.StreamHandler()
    handler.setFormatter(formatter)

    logger.addHandler(handler)

    # Reuse an existing run directory when load_path is given; otherwise name
    # the run after the dataset and the current timestamp.
    if config.load_path:
        if config.load_path.startswith(config.log_dir):
            config.model_dir = config.load_path
        else:
            if config.load_path.startswith(config.dataset):
                config.model_name = config.load_path
            else:
                config.model_name = "{}_{}".format(config.dataset, config.load_path)
    else:
        config.model_name = "{}_{}".format(config.dataset, get_time())

    if not hasattr(config, 'model_dir'):
        config.model_dir = os.path.join(config.log_dir, config.model_name)
    config.data_path = os.path.join(config.data_dir, config.dataset)

    for path in [config.log_dir, config.data_dir, config.model_dir]:
        if not os.path.exists(path):
            os.makedirs(path)

def get_time():
    return datetime.now().strftime("%m%d_%H%M%S")

def save_config(config):
    param_path = os.path.join(config.model_dir, "params.json")

    print("[*] MODEL dir: %s" % config.model_dir)
    print("[*] PARAM path: %s" % param_path)

    with open(param_path, 'w') as fp:
        json.dump(config.__dict__, fp, indent=4, sort_keys=True)

def rank(array):
    return len(array.shape)

def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False, is_train=True):
    """Code based on https://github.com/pytorch/vision/blob/master/torchvision/utils.py"""
    # Tile a batch of HxWx3 uint8 images into one image, nrow images per row.
    # normalize and scale_each are accepted but not used.
    # is_train adds a 1-pixel offset to the border; with is_train=False,
    # padding=0 and nrow=1 each image is copied without any border.
    if is_train:
        scale = 1
    else:
        scale = 0
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding)

    grid = np.zeros([height * ymaps + scale + padding // 2, width * xmaps + scale + padding // 2, 3], dtype=np.uint8)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            h, h_width = y * height + scale + padding // 2, height - padding
            w, w_width = x * width + scale + padding // 2, width - padding

            grid[h:h+h_width, w:w+w_width] = tensor[k]
            k = k + 1
    return grid

def save_image(tensor, filename, nrow=8, padding=2,
               normalize=False, scale_each=False, is_train=True):
    # Tile the batch with make_grid and write it with Pillow.
    ndarr = make_grid(tensor, nrow=nrow, padding=padding,
                      normalize=normalize, scale_each=scale_each, is_train=is_train)
    im = Image.fromarray(ndarr)
    im.save(filename)

--------------------------------------------------------------------------------
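The `update_latent` op in `Trainer.build_model` standardizes the whole latent table jointly (`tf.nn.moments` over axes `[0, 1]`) and then rescales it by `1/sqrt(z_dim)`. The NumPy sketch below reproduces that arithmetic outside TensorFlow. `renormalize_latents` is a hypothetical helper, not a function in this repository, and the random table is only illustrative. The sketch suggests why the codes end up with an average squared norm of roughly 1:

```python
import numpy as np


def renormalize_latents(z, eps=1e-7):
    """NumPy mirror of update_latent (hypothetical helper, not in the repo).

    tf.nn.moments(z, [0, 1]) yields one mean and one (population) variance
    over all entries; tf.nn.batch_normalization(z, mean, var, 0, 1/sqrt(d), eps)
    then computes (z - mean) / sqrt(var + eps) * (1 / sqrt(d)).
    """
    z_dim = z.shape[1]
    mean = z.mean()
    var = z.var()  # ddof=0, the same population variance tf.nn.moments returns
    return (z - mean) / np.sqrt(var + eps) / np.sqrt(z_dim)


# Illustrative table of 1,000 codes with z_dim = 100 that have drifted away
# from zero mean and unit scale during training.
rng = np.random.default_rng(0)
z = rng.normal(loc=3.0, scale=5.0, size=(1000, 100))
z_hat = renormalize_latents(z)

print(np.isclose(z_hat.mean(), 0.0, atol=1e-9))          # entries are centered
print(np.isclose((z_hat ** 2).sum(axis=1).mean(), 1.0))  # mean squared norm is about 1
```

Every entry is shifted and rescaled by the same amounts, so this step only removes global drift of the table. It does not project each code onto the unit sphere, and individual code norms can still differ.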
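`make_grid` in `utils.py` produces two kinds of output image. Training previews use the defaults, which leave a 2-pixel gap and border around every tile. `Trainer.test()` instead passes `nrow=1, padding=0, is_train=False`, so a single patch is written at exactly its own size. The check below copies `make_grid` from `utils.py` with unchanged logic and feeds it a hypothetical batch of four 64x64 RGB patches. The patch size is an assumption and is not read from `config.py`:

```python
import math

import numpy as np


def make_grid(tensor, nrow=8, padding=2,
              normalize=False, scale_each=False, is_train=True):
    # Same logic as make_grid in utils.py (normalize/scale_each unused there too).
    scale = 1 if is_train else 0
    nmaps = tensor.shape[0]
    xmaps = min(nrow, nmaps)
    ymaps = int(math.ceil(float(nmaps) / xmaps))
    height, width = int(tensor.shape[1] + padding), int(tensor.shape[2] + padding)
    grid = np.zeros([height * ymaps + scale + padding // 2,
                     width * xmaps + scale + padding // 2, 3], dtype=np.uint8)
    k = 0
    for y in range(ymaps):
        for x in range(xmaps):
            if k >= nmaps:
                break
            h, h_width = y * height + scale + padding // 2, height - padding
            w, w_width = x * width + scale + padding // 2, width - padding
            grid[h:h + h_width, w:w + w_width] = tensor[k]
            k += 1
    return grid


# Hypothetical batch of four 64x64 RGB patches in HWC uint8 layout.
batch = np.full((4, 64, 64, 3), 200, dtype=np.uint8)

train_grid = make_grid(batch)  # training-style preview
test_grid = make_grid(batch[:1], nrow=1, padding=0, is_train=False)  # test() style
print(train_grid.shape)  # (68, 266, 3)
print(test_grid.shape)   # (64, 64, 3)
```

With `nrow=1`, the whole batch is stacked vertically into one file. A test batch holding more than one patch would therefore be saved as a single tall image rather than as one image per patch.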
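The loop bounds in `Trainer.test()` can be checked in isolation. `replay_test_loop` below is a hypothetical helper, not part of the repository. It replays the same index arithmetic, using floor division for the range bounds, and returns two things for each step: the sample indices generated and the `n_data_dir` file name the output is saved under. The file names and counts are made up for illustration:

```python
import numpy as np


def replay_test_loop(p_num, data_num, batch_size, file_names):
    """Replay the index arithmetic of Trainer.test() (hypothetical helper).

    Returns a list of (sample indices, output name) pairs, one per step.
    """
    # Sorted, extension-stripped names, as test() builds them from n_data_dir.
    names = [f.split('.jpg')[0] for f in sorted(file_names)]
    first = p_num // batch_size
    schedule = []
    for step in range(first, data_num // batch_size):
        batch = np.arange(batch_size) + batch_size * step
        schedule.append((batch.tolist(), names[step - first]))
    return schedule


files = ["bg_002.jpg", "bg_000.jpg", "bg_001.jpg"]
print(replay_test_loop(p_num=2, data_num=5, batch_size=1, file_names=files))
# [([2], 'bg_000'), ([3], 'bg_001'), ([4], 'bg_002')]
```

Now take `batch_size=2` with `p_num=2`, `data_num=6` and four files `bg_000` to `bg_003`. The same helper returns `[([2, 3], 'bg_000'), ([4, 5], 'bg_001')]`. Samples 2 and 3 are written together into one image named `bg_000`, and the names `bg_002` and `bg_003` are never used. The outputs therefore line up one-to-one with the source files only when `batch_size == 1`.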