├── .gitignore ├── Downloads ├── SR_BSD100.tar.xz ├── SR_BSD100.zip ├── user study.tar.xz └── user study.zip ├── Input ├── Editing │ ├── stone_edit.png │ ├── stone_edit_mask.png │ ├── tree_edit.png │ └── tree_edit_mask.png ├── Harmonization │ ├── starry_night_naive.png │ ├── starry_night_naive_mask.png │ ├── tree.jpg │ └── tree_mask.jpg ├── Images │ ├── 33039_LR.png │ ├── balloons.png │ ├── birds.png │ ├── colusseum.png │ ├── cows.png │ ├── lightning1.png │ ├── mountains.jpg │ ├── mountains3.png │ ├── seascape.png │ ├── starry_night.png │ ├── stone.png │ ├── tree.png │ ├── trees3.png │ ├── volacano.png │ ├── wild_bush.jpg │ └── zebra.png └── Paint │ ├── cows.png │ └── trees1.png ├── LICENSE.txt ├── README.md ├── SIFID.npy ├── SIFID ├── inception.py └── sifid_score.py ├── SR.py ├── SinGAN ├── __pycache__ │ ├── functions.cpython-36.pyc │ ├── imresize.cpython-36.pyc │ ├── manipulate.cpython-36.pyc │ ├── models.cpython-36.pyc │ └── training.cpython-36.pyc ├── functions.py ├── imresize.py ├── manipulate.py ├── models.py └── training.py ├── animation.py ├── config.py ├── config.pyc ├── editing.py ├── harmonization.py ├── imgs ├── manipulation.PNG └── teaser.PNG ├── main_train.py ├── paint2image.py ├── random_samples.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't track content of these folders 2 | Output/ 3 | TrainedModels/ 4 | __pycache__/ 5 | .idea 6 | -------------------------------------------------------------------------------- /Downloads/SR_BSD100.tar.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Downloads/SR_BSD100.tar.xz -------------------------------------------------------------------------------- /Downloads/SR_BSD100.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Downloads/SR_BSD100.zip -------------------------------------------------------------------------------- /Downloads/user study.tar.xz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Downloads/user study.tar.xz -------------------------------------------------------------------------------- /Downloads/user study.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Downloads/user study.zip -------------------------------------------------------------------------------- /Input/Editing/stone_edit.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Editing/stone_edit.png -------------------------------------------------------------------------------- /Input/Editing/stone_edit_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Editing/stone_edit_mask.png -------------------------------------------------------------------------------- /Input/Editing/tree_edit.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Editing/tree_edit.png -------------------------------------------------------------------------------- /Input/Editing/tree_edit_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Editing/tree_edit_mask.png -------------------------------------------------------------------------------- /Input/Harmonization/starry_night_naive.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Harmonization/starry_night_naive.png -------------------------------------------------------------------------------- /Input/Harmonization/starry_night_naive_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Harmonization/starry_night_naive_mask.png -------------------------------------------------------------------------------- /Input/Harmonization/tree.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Harmonization/tree.jpg -------------------------------------------------------------------------------- /Input/Harmonization/tree_mask.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Harmonization/tree_mask.jpg -------------------------------------------------------------------------------- /Input/Images/33039_LR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/33039_LR.png -------------------------------------------------------------------------------- /Input/Images/balloons.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/balloons.png -------------------------------------------------------------------------------- /Input/Images/birds.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/birds.png -------------------------------------------------------------------------------- /Input/Images/colusseum.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/colusseum.png -------------------------------------------------------------------------------- /Input/Images/cows.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/cows.png -------------------------------------------------------------------------------- /Input/Images/lightning1.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/lightning1.png -------------------------------------------------------------------------------- /Input/Images/mountains.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/mountains.jpg -------------------------------------------------------------------------------- /Input/Images/mountains3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/mountains3.png -------------------------------------------------------------------------------- /Input/Images/seascape.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/seascape.png -------------------------------------------------------------------------------- /Input/Images/starry_night.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/starry_night.png -------------------------------------------------------------------------------- /Input/Images/stone.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/stone.png -------------------------------------------------------------------------------- /Input/Images/tree.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/tree.png -------------------------------------------------------------------------------- /Input/Images/trees3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/trees3.png -------------------------------------------------------------------------------- /Input/Images/volacano.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/volacano.png -------------------------------------------------------------------------------- /Input/Images/wild_bush.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/wild_bush.jpg -------------------------------------------------------------------------------- /Input/Images/zebra.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Images/zebra.png -------------------------------------------------------------------------------- /Input/Paint/cows.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Paint/cows.png 
-------------------------------------------------------------------------------- /Input/Paint/trees1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/Input/Paint/trees1.png -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | SinGAN: Learning a Generative Model from a Single Natural Image 4 | ICCV 2019 5 | Tamar Rott Shaham, Tali Dekel, Tomer Michaeli 6 | 7 | Copyright (c) 2019 Tamar Rott Shaham, Tali Dekel, Tomer Michaeli 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SinGAN 2 | 3 | [Project](https://tamarott.github.io/SinGAN.htm) | [Arxiv](https://arxiv.org/pdf/1905.01164.pdf) | [CVF](http://openaccess.thecvf.com/content_ICCV_2019/papers/Shaham_SinGAN_Learning_a_Generative_Model_From_a_Single_Natural_Image_ICCV_2019_paper.pdf) | [Supplementary materials](https://openaccess.thecvf.com/content_ICCV_2019/supplemental/Shaham_SinGAN_Learning_a_ICCV_2019_supplemental.pdf) | [Talk (ICCV`19)](https://youtu.be/mdAcPe74tZI?t=3191) 4 | ### Official PyTorch implementation of the paper: "SinGAN: Learning a Generative Model from a Single Natural Image" 5 | #### ICCV 2019 Best paper award (Marr prize) 6 | 7 | 8 | ## Random samples from a *single* image 9 | With SinGAN, you can train a generative model from a single natural image, and then generate random samples from that image, for example: 10 | 11 | ![](imgs/teaser.PNG) 12 | 13 | 14 | ## SinGAN's applications 15 | SinGAN can also be used for a range of image manipulation tasks, for example: 16 | ![](imgs/manipulation.PNG) 17 | This is done by injecting an image into the already trained model. See section 4 in our [paper](https://arxiv.org/pdf/1905.01164.pdf) for more details.
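For example, a minimal end-to-end run (training plus random-sample generation, both described in detail under "Code" below) using `balloons.png`, one of the images shipped under `Input/Images`, might look like:

```
python main_train.py --input_name balloons.png
python random_samples.py --input_name balloons.png --mode random_samples --gen_start_scale 0
```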
18 | 19 | 20 | ### Citation 21 | If you use this code for your research, please cite our paper: 22 | 23 | ``` 24 | @inproceedings{rottshaham2019singan, 25 | title={SinGAN: Learning a Generative Model from a Single Natural Image}, 26 | author={Rott Shaham, Tamar and Dekel, Tali and Michaeli, Tomer}, 27 | booktitle={Computer Vision (ICCV), IEEE International Conference on}, 28 | year={2019} 29 | } 30 | ``` 31 | 32 | ## Code 33 | 34 | ### Install dependencies 35 | 36 | ``` 37 | python -m pip install -r requirements.txt 38 | ``` 39 | 40 | This code was tested with Python 3.6 and torch 1.4. 41 | 42 | Please note: the code currently only supports torch 1.4 or earlier because of the optimization scheme. 43 | 44 | For later torch versions, you may try this repository: https://github.com/kligvasser/SinGAN (results won't necessarily be identical to the official implementation). 45 | 46 | 47 | ### Train 48 | To train a SinGAN model on your own image, put the desired training image under Input/Images, and run 49 | 50 | ``` 51 | python main_train.py --input_name <input_file_name> 52 | ``` 53 | 54 | This will also use the resulting trained model to generate random samples starting from the coarsest scale (n=0). 55 | 56 | To run this code on a CPU machine, specify `--not_cuda` when calling `main_train.py` 57 | 58 | ### Random samples 59 | To generate random samples from any starting generation scale, please first train a SinGAN model on the desired image (as described above), then run 60 | 61 | ``` 62 | python random_samples.py --input_name <training_image_file_name> --mode random_samples --gen_start_scale <generation_start_scale_number> 63 | ``` 64 | 65 | Note: to use the full model, specify the generation start scale to be 0; to start the generation from the second scale, specify it to be 1; and so on. 66 | 67 | ### Random samples of arbitrary sizes 68 | To generate random samples of arbitrary sizes, please first train a SinGAN model on the desired image (as described above), then run 69 | 70 | ``` 71 | python random_samples.py --input_name <training_image_file_name> --mode random_samples_arbitrary_sizes --scale_h <horizontal_scaling_factor> --scale_v <vertical_scaling_factor> 72 | ``` 73 | 74 | ### Animation from a single image 75 | 76 | To generate a short animation from a single image, run 77 | 78 | ``` 79 | python animation.py --input_name <input_file_name> 80 | ``` 81 | 82 | This will automatically start a new training phase with noise padding mode. 83 | 84 | ### Harmonization 85 | 86 | To harmonize a pasted object into an image (see the example in Fig. 13 in [our paper](https://arxiv.org/pdf/1905.01164.pdf)), please first train a SinGAN model on the desired background image (as described above), then save the naively pasted reference image and its binary mask under "Input/Harmonization" (see the saved images for an example). Run the command 87 | 88 | ``` 89 | python harmonization.py --input_name <training_image_file_name> --ref_name <naively_pasted_reference_image_file_name> --harmonization_start_scale <scale_to_inject> 90 | 91 | ``` 92 | 93 | Please note that different injection scales will produce different harmonization effects. The coarsest injection scale equals 1. 94 | 95 | ### Editing 96 | 97 | To edit an image (see the example in Fig. 12 in [our paper](https://arxiv.org/pdf/1905.01164.pdf)), please first train a SinGAN model on the desired non-edited image (as described above), then save the naive edit as a reference image under "Input/Editing" with a corresponding binary mask (see the saved images for an example). Run the command 98 | 99 | ``` 100 | python editing.py --input_name <training_image_file_name> --ref_name <edited_image_file_name> --editing_start_scale <scale_to_inject> 101 | 102 | ``` 103 | Both the masked and unmasked outputs will be saved. 104 | Here as well, different injection scales will produce different editing effects. The coarsest injection scale equals 1. 105 |
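For example, using the provided `stone.png` (under `Input/Images`) and its naive edit `stone_edit.png` (under `Input/Editing`, with its mask `stone_edit_mask.png` alongside it), and assuming a model was first trained on `stone.png` as described above, the command might look like:

```
python editing.py --input_name stone.png --ref_name stone_edit.png --editing_start_scale 1
```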
106 | ### Paint to Image 107 | 108 | To transfer a painting into a realistic image (see the example in Fig. 11 in [our paper](https://arxiv.org/pdf/1905.01164.pdf)), please first train a SinGAN model on the desired image (as described above), then save your painting under "Input/Paint", and run the command 109 | 110 | ``` 111 | python paint2image.py --input_name <training_image_file_name> --ref_name <paint_image_file_name> --paint_start_scale <scale_to_inject> 112 | 113 | ``` 114 | Here as well, different injection scales will produce different results. The coarsest injection scale equals 1. 115 | 116 | Advanced option: specify `quantization_flag` to be True in order to re-train *only* the injection level of the model, so that it is trained on a color-quantized version of the upsampled generated images from the previous scale. For some images, this might lead to more realistic results. 117 | 118 | ### Super Resolution 119 | To super-resolve an image, please run: 120 | ``` 121 | python SR.py --input_name <LR_image_file_name> 122 | ``` 123 | This will automatically train a SinGAN model corresponding to a 4x upsampling factor (if one does not already exist). 124 | For a different SR factor, please specify it using the parameter `--sr_factor` when calling the function. 125 | SinGAN's results on the BSD100 dataset can be downloaded from the 'Downloads' folder. 126 | 127 | ## Additional Data and Functions 128 | 129 | ### Single Image Fréchet Inception Distance (SIFID score) 130 | To calculate the SIFID between real images and their corresponding fake samples, please run: 131 | ``` 132 | python SIFID/sifid_score.py --path2real <real_images_path> --path2fake <fake_images_path> 133 | ``` 134 | Make sure that each fake image's file name is identical to that of its corresponding real image. Images should be saved in `.jpg` format. 135 |
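`sifid_score.py` saves its per-image scores to `SIFID.npy` in the working directory (an example output file is included at the repository root). A minimal sketch for inspecting that file, assuming the save format produced by the script above:

```python
import numpy as np

# sifid_score.py stores one SIFID value per real/fake image pair
scores = np.load('SIFID.npy')
print('per-image SIFID:', scores)
print('mean SIFID:', scores.mean())
```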
136 | ### Super Resolution Results 137 | SinGAN's SR results on the BSD100 dataset can be downloaded from the 'Downloads' folder. 138 | 139 | ### User Study 140 | The data used for the user study can be found in the Downloads folder. 141 | 142 | real folder: 50 real images, randomly picked from the [places database](http://places.csail.mit.edu/) 143 | 144 | fake_high_variance folder: random samples starting from n=N for each of the real images 145 | 146 | fake_mid_variance folder: random samples starting from n=N-1 for each of the real images 147 | 148 | For additional details, please see section 3.1 in our [paper](https://arxiv.org/pdf/1905.01164.pdf) 149 | -------------------------------------------------------------------------------- /SIFID.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SIFID.npy -------------------------------------------------------------------------------- /SIFID/inception.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | from torchvision import models 4 | 5 | 6 | class InceptionV3(nn.Module): 7 | """Pretrained InceptionV3 network returning feature maps""" 8 | 9 | # Index of default block of inception to return, 10 | # corresponds to output of final average pooling 11 | DEFAULT_BLOCK_INDEX = 3 12 | 13 | # Maps feature dimensionality to their output blocks indices 14 | BLOCK_INDEX_BY_DIM = { 15 | 64: 0, # First max pooling features 16 | 192: 1, # Second max pooling features 17 | 768: 2, # Pre-aux classifier features 18 | 2048: 3 # Final average pooling features 19 | } 20 | 21 | def __init__(self, 22 | output_blocks=[DEFAULT_BLOCK_INDEX], 23 | resize_input=False, 24 | normalize_input=True, 25 | requires_grad=False): 26 | """Build pretrained InceptionV3 27 | 28 | Parameters 29 | ---------- 30 | output_blocks : list of int 31 | Indices of blocks to return features of. Possible values are: 32 | - 0: corresponds to output of first max pooling 33 | - 1: corresponds to output of second max pooling 34 | - 2: corresponds to output which is fed to aux classifier 35 | - 3: corresponds to output of final average pooling 36 | resize_input : bool 37 | If true, bilinearly resizes input to width and height 299 before 38 | feeding input to model. As the network without fully connected 39 | layers is fully convolutional, it should be able to handle inputs 40 | of arbitrary size, so resizing might not be strictly needed 41 | normalize_input : bool 42 | If true, scales the input from range (0, 1) to the range the 43 | pretrained Inception network expects, namely (-1, 1) 44 | requires_grad : bool 45 | If true, parameters of the model require gradient.
Possibly useful 46 | for finetuning the network 47 | """ 48 | super(InceptionV3, self).__init__() 49 | 50 | self.resize_input = resize_input 51 | self.normalize_input = normalize_input 52 | self.output_blocks = sorted(output_blocks) 53 | self.last_needed_block = max(output_blocks) 54 | 55 | assert self.last_needed_block <= 3, \ 56 | 'Last possible output block index is 3' 57 | 58 | self.blocks = nn.ModuleList() 59 | 60 | inception = models.inception_v3(pretrained=True) 61 | 62 | # Block 0: input to maxpool1 63 | block0 = [ 64 | inception.Conv2d_1a_3x3, 65 | inception.Conv2d_2a_3x3, 66 | inception.Conv2d_2b_3x3, 67 | ] 68 | 69 | 70 | self.blocks.append(nn.Sequential(*block0)) 71 | 72 | # Block 1: maxpool1 to maxpool2 73 | if self.last_needed_block >= 1: 74 | block1 = [ 75 | nn.MaxPool2d(kernel_size=3, stride=2), 76 | inception.Conv2d_3b_1x1, 77 | inception.Conv2d_4a_3x3, 78 | ] 79 | self.blocks.append(nn.Sequential(*block1)) 80 | 81 | # Block 2: maxpool2 to aux classifier 82 | if self.last_needed_block >= 2: 83 | block2 = [ 84 | nn.MaxPool2d(kernel_size=3, stride=2), 85 | inception.Mixed_5b, 86 | inception.Mixed_5c, 87 | inception.Mixed_5d, 88 | inception.Mixed_6a, 89 | inception.Mixed_6b, 90 | inception.Mixed_6c, 91 | inception.Mixed_6d, 92 | inception.Mixed_6e, 93 | ] 94 | self.blocks.append(nn.Sequential(*block2)) 95 | 96 | # Block 3: aux classifier to final avgpool 97 | if self.last_needed_block >= 3: 98 | block3 = [ 99 | inception.Mixed_7a, 100 | inception.Mixed_7b, 101 | inception.Mixed_7c, 102 | ] 103 | self.blocks.append(nn.Sequential(*block3)) 104 | 105 | if self.last_needed_block >= 4: 106 | block4 = [ 107 | nn.AdaptiveAvgPool2d(output_size=(1, 1)) 108 | ] 109 | self.blocks.append(nn.Sequential(*block4)) 110 | 111 | for param in self.parameters(): 112 | param.requires_grad = requires_grad 113 | 114 | def forward(self, inp): 115 | """Get Inception feature maps 116 | 117 | Parameters 118 | ---------- 119 | inp : torch.autograd.Variable 120 | Input tensor of shape Bx3xHxW. Values are expected to be in 121 | range (0, 1) 122 | 123 | Returns 124 | ------- 125 | List of torch.autograd.Variable, corresponding to the selected output 126 | block, sorted ascending by index 127 | """ 128 | outp = [] 129 | x = inp 130 | 131 | if self.resize_input: 132 | x = F.interpolate(x, 133 | size=(299, 299), 134 | mode='bilinear', 135 | align_corners=False) 136 | 137 | if self.normalize_input: 138 | x = 2 * x - 1 # Scale from range (0, 1) to range (-1, 1) 139 | 140 | for idx, block in enumerate(self.blocks): 141 | x = block(x) 142 | if idx in self.output_blocks: 143 | outp.append(x) 144 | 145 | if idx == self.last_needed_block: 146 | break 147 | 148 | return outp 149 | -------------------------------------------------------------------------------- /SIFID/sifid_score.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | """Calculates ***Single Image*** Frechet Inception Distance (SIFID) to evaluate Single-Image-GANs 3 | Code was adapted from: 4 | https://github.com/mseitzer/pytorch-fid.git 5 | Which was adapted from the TensorFlow implementation of: 6 | 7 | 8 | https://github.com/bioinf-jku/TTUR 9 | 10 | The FID metric calculates the distance between two distributions of images. 11 | The SIFID calculates the distance between the distribution of deep features of a single real image and a single fake image.
12 | Copyright 2018 Institute of Bioinformatics, JKU Linz 13 | Licensed under the Apache License, Version 2.0 (the "License"); 14 | you may not use this file except in compliance with the License. 15 | You may obtain a copy of the License at 16 | http://www.apache.org/licenses/LICENSE-2.0 17 | Unless required by applicable law or agreed to in writing, software 18 | distributed under the License is distributed on an "AS IS" BASIS, 19 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 20 | See the License for the specific language governing permissions and 21 | limitations under the License. 22 | """ 23 | 24 | import os 25 | import pathlib 26 | from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter 27 | 28 | import numpy as np 29 | import torch 30 | from scipy import linalg 31 | #from scipy.misc import imread 32 | from matplotlib.pyplot import imread 33 | from torch.nn.functional import adaptive_avg_pool2d 34 | 35 | try: 36 | from tqdm import tqdm 37 | except ImportError: 38 | # If tqdm is not available, provide a mock version of it 39 | def tqdm(x): return x 40 | 41 | from inception import InceptionV3 42 | import torchvision 43 | import numpy 44 | import scipy 45 | import pickle 46 | 47 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 48 | parser.add_argument('--path2real', type=str, help=('Path to the real images')) 49 | parser.add_argument('--path2fake', type=str, help=('Path to generated images')) 50 | parser.add_argument('-c', '--gpu', default='', type=str, help='GPU to use (leave blank for CPU only)') 51 | parser.add_argument('--images_suffix', default='jpg', type=str, help='image file suffix') 52 | 53 | 54 | def get_activations(files, model, batch_size=1, dims=64, 55 | cuda=False, verbose=False): 56 | """Calculates the activations of the pool_3 layer for all images. 57 | 58 | Params: 59 | -- files : List of image file paths 60 | -- model : Instance of inception model 61 | -- batch_size : Batch size of images for the model to process at once. 62 | Make sure that the number of samples is a multiple of 63 | the batch size, otherwise some samples are ignored. This 64 | behavior is retained to match the original FID score 65 | implementation. 66 | -- dims : Dimensionality of features returned by Inception 67 | -- cuda : If set to True, use GPU 68 | -- verbose : If set to True and parameter out_step is given, the number 69 | of calculated batches is reported. 70 | Returns: 71 | -- A numpy array of dimension (num images, dims) that contains the 72 | activations of the given tensor when feeding inception with the 73 | query tensor. 74 | """ 75 | model.eval() 76 | 77 | if len(files) % batch_size != 0: 78 | print(('Warning: number of images is not a multiple of the ' 79 | 'batch size. Some samples are going to be ignored.')) 80 | if batch_size > len(files): 81 | print(('Warning: batch size is bigger than the data size.
' 'Setting batch size to data size')) 82 | batch_size = len(files) 83 | 84 | 85 | n_batches = len(files) // batch_size 86 | n_used_imgs = n_batches * batch_size 87 | 88 | pred_arr = np.empty((n_used_imgs, dims)) 89 | 90 | for i in tqdm(range(n_batches)): 91 | if verbose: 92 | print('\rPropagating batch %d/%d' % (i + 1, n_batches), 93 | end='', flush=True) 94 | start = i * batch_size 95 | end = start + batch_size 96 | 97 | images = np.array([imread(str(f)).astype(np.float32) 98 | for f in files[start:end]]) 99 | 100 | images = images[:,:,:,0:3] 101 | # Reshape to (n_images, 3, height, width) 102 | images = images.transpose((0, 3, 1, 2)) 103 | #images = images[0,:,:,:] 104 | images /= 255 105 | 106 | batch = torch.from_numpy(images).type(torch.FloatTensor) 107 | if cuda: 108 | batch = batch.cuda() 109 | 110 | pred = model(batch)[0] 111 | 112 | # If model output is not scalar, apply global spatial average pooling. 113 | # This happens if you choose a dimensionality not equal 2048. 114 | 115 | #if pred.shape[2] != 1 or pred.shape[3] != 1: 116 | # pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 117 | 118 | 119 | pred_arr = pred.cpu().data.numpy().transpose(0, 2, 3, 1).reshape(batch_size*pred.shape[2]*pred.shape[3],-1) 120 | 121 | 122 | if verbose: 123 | print(' done') 124 | 125 | return pred_arr 126 | 127 | 128 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 129 | """Numpy implementation of the Frechet Distance. 130 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 131 | and X_2 ~ N(mu_2, C_2) is 132 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 133 | 134 | Stable version by Dougal J. Sutherland. 135 | 136 | Params: 137 | -- mu1 : Numpy array containing the activations of a layer of the 138 | inception net (as returned by the function 'get_activations') 139 | for generated samples. 140 | -- mu2 : The sample mean over activations, precalculated on a 141 | representative data set. 142 | -- sigma1: The covariance matrix over activations for generated samples. 143 | -- sigma2: The covariance matrix over activations, precalculated on a 144 | representative data set. 145 | 146 | Returns: 147 | -- : The Frechet Distance.
148 | """ 149 | 150 | mu1 = np.atleast_1d(mu1) 151 | mu2 = np.atleast_1d(mu2) 152 | 153 | sigma1 = np.atleast_2d(sigma1) 154 | sigma2 = np.atleast_2d(sigma2) 155 | 156 | assert mu1.shape == mu2.shape, \ 157 | 'Training and test mean vectors have different lengths' 158 | assert sigma1.shape == sigma2.shape, \ 159 | 'Training and test covariances have different dimensions' 160 | 161 | diff = mu1 - mu2 162 | 163 | # Product might be almost singular 164 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 165 | if not np.isfinite(covmean).all(): 166 | msg = ('fid calculation produces singular product; ' 167 | 'adding %s to diagonal of cov estimates') % eps 168 | print(msg) 169 | offset = np.eye(sigma1.shape[0]) * eps 170 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 171 | 172 | # Numerical error might give slight imaginary component 173 | if np.iscomplexobj(covmean): 174 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 175 | m = np.max(np.abs(covmean.imag)) 176 | raise ValueError('Imaginary component {}'.format(m)) 177 | covmean = covmean.real 178 | 179 | tr_covmean = np.trace(covmean) 180 | 181 | return (diff.dot(diff) + np.trace(sigma1) + 182 | np.trace(sigma2) - 2 * tr_covmean) 183 | 184 | 185 | def calculate_activation_statistics(files, model, batch_size=1, 186 | dims=64, cuda=False, verbose=False): 187 | """Calculation of the statistics used by the FID. 188 | Params: 189 | -- files : List of image files paths 190 | -- model : Instance of inception model 191 | -- batch_size : The images numpy array is split into batches with 192 | batch size batch_size. A reasonable batch size 193 | depends on the hardware. 194 | -- dims : Dimensionality of features returned by Inception 195 | -- cuda : If set to True, use GPU 196 | -- verbose : If set to True and parameter out_step is given, the 197 | number of calculated batches is reported. 198 | Returns: 199 | -- mu : The mean over samples of the activations of the inception model. 200 | -- sigma : The covariance matrix of the activations of the inception model. 
201 | """ 202 | act = get_activations(files, model, batch_size, dims, cuda, verbose) 203 | mu = np.mean(act, axis=0) 204 | sigma = np.cov(act, rowvar=False) 205 | return mu, sigma 206 | 207 | 208 | def _compute_statistics_of_path(files, model, batch_size, dims, cuda): 209 | if path.endswith('.npz'): 210 | f = np.load(path) 211 | m, s = f['mu'][:], f['sigma'][:] 212 | f.close() 213 | else: 214 | path = pathlib.Path(path) 215 | files = sorted(list(path.glob('*.jpg'))+ list(path.glob('*.png'))) 216 | m, s = calculate_activation_statistics(files, model, batch_size, 217 | dims, cuda) 218 | 219 | return m, s 220 | 221 | 222 | def calculate_sifid_given_paths(path1, path2, batch_size, cuda, dims, suffix): 223 | """Calculates the SIFID of two paths""" 224 | 225 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 226 | 227 | model = InceptionV3([block_idx]) 228 | if cuda: 229 | model.cuda() 230 | 231 | path1 = pathlib.Path(path1) 232 | files1 = sorted(list(path1.glob('*.%s' %suffix))) 233 | 234 | path2 = pathlib.Path(path2) 235 | files2 = sorted(list(path2.glob('*.%s' %suffix))) 236 | 237 | fid_values = [] 238 | Im_ind = [] 239 | for i in range(len(files2)): 240 | m1, s1 = calculate_activation_statistics([files1[i]], model, batch_size, dims, cuda) 241 | m2, s2 = calculate_activation_statistics([files2[i]], model, batch_size, dims, cuda) 242 | fid_values.append(calculate_frechet_distance(m1, s1, m2, s2)) 243 | file_num1 = files1[i].name 244 | file_num2 = files2[i].name 245 | Im_ind.append(int(file_num1[:-4])) 246 | Im_ind.append(int(file_num2[:-4])) 247 | return fid_values 248 | 249 | 250 | if __name__ == '__main__': 251 | args = parser.parse_args() 252 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 253 | 254 | path1 = args.path2real 255 | path2 = args.path2fake 256 | suffix = args.images_suffix 257 | 258 | sifid_values = calculate_sifid_given_paths(path1,path2,1,args.gpu!='',64,suffix) 259 | 260 | sifid_values = np.asarray(sifid_values,dtype=np.float32) 261 | numpy.save('SIFID', sifid_values) 262 | print('SIFID: ', sifid_values.mean()) 263 | -------------------------------------------------------------------------------- /SR.py: -------------------------------------------------------------------------------- 1 | from config import get_arguments 2 | from SinGAN.manipulate import * 3 | from SinGAN.training import * 4 | from SinGAN.imresize import imresize 5 | import SinGAN.functions as functions 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = get_arguments() 10 | parser.add_argument('--input_dir', help='input image dir', default='Input/Images') 11 | parser.add_argument('--input_name', help='training image name', default="33039_LR.png")#required=True) 12 | parser.add_argument('--sr_factor', help='super resolution factor', type=float, default=4) 13 | parser.add_argument('--mode', help='task to be done', default='SR') 14 | opt = parser.parse_args() 15 | opt = functions.post_config(opt) 16 | Gs = [] 17 | Zs = [] 18 | reals = [] 19 | NoiseAmp = [] 20 | dir2save = functions.generate_dir2save(opt) 21 | if dir2save is None: 22 | print('task does not exist') 23 | #elif (os.path.exists(dir2save)): 24 | # print("output already exist") 25 | else: 26 | try: 27 | os.makedirs(dir2save) 28 | except OSError: 29 | pass 30 | 31 | mode = opt.mode 32 | in_scale, iter_num = functions.calc_init_scale(opt) 33 | opt.scale_factor = 1 / in_scale 34 | opt.scale_factor_init = 1 / in_scale 35 | opt.mode = 'train' 36 | dir2trained_model = functions.generate_dir2save(opt) 37 | if (os.path.exists(dir2trained_model)): 38 | Gs, Zs, 
reals, NoiseAmp = functions.load_trained_pyramid(opt) 39 | opt.mode = mode 40 | else: 41 | print('*** Train SinGAN for SR ***') 42 | real = functions.read_image(opt) 43 | opt.min_size = 18 44 | real = functions.adjust_scales2image_SR(real, opt) 45 | train(opt, Gs, Zs, reals, NoiseAmp) 46 | opt.mode = mode 47 | print('%f' % pow(in_scale, iter_num)) 48 | Zs_sr = [] 49 | reals_sr = [] 50 | NoiseAmp_sr = [] 51 | Gs_sr = [] 52 | real = reals[-1] # read_image(opt) 53 | real_ = real 54 | opt.scale_factor = 1 / in_scale 55 | opt.scale_factor_init = 1 / in_scale 56 | for j in range(1, iter_num + 1, 1): 57 | real_ = imresize(real_, pow(1 / opt.scale_factor, 1), opt) 58 | reals_sr.append(real_) 59 | Gs_sr.append(Gs[-1]) 60 | NoiseAmp_sr.append(NoiseAmp[-1]) 61 | z_opt = torch.full(real_.shape, 0, device=opt.device) 62 | m = nn.ZeroPad2d(5) 63 | z_opt = m(z_opt) 64 | Zs_sr.append(z_opt) 65 | out = SinGAN_generate(Gs_sr, Zs_sr, reals_sr, NoiseAmp_sr, opt, in_s=reals_sr[0], num_samples=1) 66 | out = out[:, :, 0:int(opt.sr_factor * reals[-1].shape[2]), 0:int(opt.sr_factor * reals[-1].shape[3])] 67 | dir2save = functions.generate_dir2save(opt) 68 | plt.imsave('%s/%s_HR.png' % (dir2save,opt.input_name[:-4]), functions.convert_image_np(out.detach()), vmin=0, vmax=1) 69 | 70 | 71 | 72 | 73 | -------------------------------------------------------------------------------- /SinGAN/__pycache__/functions.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SinGAN/__pycache__/functions.cpython-36.pyc -------------------------------------------------------------------------------- /SinGAN/__pycache__/imresize.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SinGAN/__pycache__/imresize.cpython-36.pyc -------------------------------------------------------------------------------- /SinGAN/__pycache__/manipulate.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SinGAN/__pycache__/manipulate.cpython-36.pyc -------------------------------------------------------------------------------- /SinGAN/__pycache__/models.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SinGAN/__pycache__/models.cpython-36.pyc -------------------------------------------------------------------------------- /SinGAN/__pycache__/training.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/SinGAN/__pycache__/training.cpython-36.pyc -------------------------------------------------------------------------------- /SinGAN/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import matplotlib.pyplot as plt 3 | import matplotlib.patches as patches 4 | import numpy as np 5 | import torch.nn as nn 6 | import scipy.io as sio 7 | import math 8 | from skimage import io as img 9 | from skimage import color, morphology, filters 10 | #from skimage import morphology 11 | #from skimage import filters 12 | from SinGAN.imresize 
import imresize 13 | import os 14 | import random 15 | from sklearn.cluster import KMeans 16 | 17 | 18 | # custom weights initialization called on netG and netD 19 | 20 | def read_image(opt): 21 | x = img.imread('%s%s' % (opt.input_img,opt.ref_image)) 22 | return np2torch(x,opt) 23 | 24 | def denorm(x): 25 | out = (x + 1) / 2 26 | return out.clamp(0, 1) 27 | 28 | def norm(x): 29 | out = (x -0.5) *2 30 | return out.clamp(-1, 1) 31 | 32 | #def denorm2image(I1,I2): 33 | # out = (I1-I1.mean())/(I1.max()-I1.min()) 34 | # out = out*(I2.max()-I2.min())+I2.mean() 35 | # return out#.clamp(I2.min(), I2.max()) 36 | 37 | #def norm2image(I1,I2): 38 | # out = (I1-I2.mean())*2 39 | # return out#.clamp(I2.min(), I2.max()) 40 | 41 | def convert_image_np(inp): 42 | if inp.shape[1]==3: 43 | inp = denorm(inp) 44 | inp = move_to_cpu(inp[-1,:,:,:]) 45 | inp = inp.numpy().transpose((1,2,0)) 46 | else: 47 | inp = denorm(inp) 48 | inp = move_to_cpu(inp[-1,-1,:,:]) 49 | inp = inp.numpy().transpose((0,1)) 50 | # mean = np.array([x/255.0 for x in [125.3,123.0,113.9]]) 51 | # std = np.array([x/255.0 for x in [63.0,62.1,66.7]]) 52 | 53 | inp = np.clip(inp,0,1) 54 | return inp 55 | 56 | def save_image(real_cpu,receptive_field,ncs,epoch_num,file_name): 57 | fig,ax = plt.subplots(1) 58 | if ncs==1: 59 | ax.imshow(real_cpu.view(real_cpu.size(2),real_cpu.size(3)),cmap='gray') 60 | else: 61 | #ax.imshow(convert_image_np(real_cpu[0,:,:,:].cpu())) 62 | ax.imshow(convert_image_np(real_cpu.cpu())) 63 | rect = patches.Rectangle((0,0),receptive_field,receptive_field,linewidth=5,edgecolor='r',facecolor='none') 64 | ax.add_patch(rect) 65 | ax.axis('off') 66 | plt.savefig(file_name) 67 | plt.close(fig) 68 | 69 | def convert_image_np_2d(inp): 70 | inp = denorm(inp) 71 | inp = inp.numpy() 72 | # mean = np.array([x/255.0 for x in [125.3,123.0,113.9]]) 73 | # std = np.array([x/255.0 for x in [63.0,62.1,66.7]]) 74 | # inp = std* 75 | return inp 76 | 77 | def generate_noise(size,num_samp=1,device='cuda',type='gaussian', scale=1): 78 | if type == 'gaussian': 79 | noise = torch.randn(num_samp, size[0], round(size[1]/scale), round(size[2]/scale), device=device) 80 | noise = upsampling(noise,size[1], size[2]) 81 | if type =='gaussian_mixture': 82 | noise1 = torch.randn(num_samp, size[0], size[1], size[2], device=device)+5 83 | noise2 = torch.randn(num_samp, size[0], size[1], size[2], device=device) 84 | noise = noise1+noise2 85 | if type == 'uniform': 86 | noise = torch.rand(num_samp, size[0], size[1], size[2], device=device) 87 | return noise 88 | 89 | def plot_learning_curves(G_loss,D_loss,epochs,label1,label2,name): 90 | fig,ax = plt.subplots(1) 91 | n = np.arange(0,epochs) 92 | plt.plot(n,G_loss,n,D_loss) 93 | #plt.title('loss') 94 | #plt.ylabel('loss') 95 | plt.xlabel('epochs') 96 | plt.legend([label1,label2],loc='upper right') 97 | plt.savefig('%s.png' % name) 98 | plt.close(fig) 99 | 100 | def plot_learning_curve(loss,epochs,name): 101 | fig,ax = plt.subplots(1) 102 | n = np.arange(0,epochs) 103 | plt.plot(n,loss) 104 | plt.ylabel('loss') 105 | plt.xlabel('epochs') 106 | plt.savefig('%s.png' % name) 107 | plt.close(fig) 108 | 109 | def upsampling(im,sx,sy): 110 | m = nn.Upsample(size=[round(sx),round(sy)],mode='bilinear',align_corners=True) 111 | return m(im) 112 | 113 | def reset_grads(model,require_grad): 114 | for p in model.parameters(): 115 | p.requires_grad_(require_grad) 116 | return model 117 | 118 | def move_to_gpu(t): 119 | if (torch.cuda.is_available()): 120 | t = t.to(torch.device('cuda')) 121 | return t 122 | 123 | def
move_to_cpu(t): 124 | t = t.to(torch.device('cpu')) 125 | return t 126 | 127 | def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA, device): 128 | #print real_data.size() 129 | alpha = torch.rand(1, 1) 130 | alpha = alpha.expand(real_data.size()) 131 | alpha = alpha.to(device)#cuda() #gpu) #if use_cuda else alpha 132 | 133 | interpolates = alpha * real_data + ((1 - alpha) * fake_data) 134 | 135 | 136 | interpolates = interpolates.to(device)#.cuda() 137 | interpolates = torch.autograd.Variable(interpolates, requires_grad=True) 138 | 139 | disc_interpolates = netD(interpolates) 140 | 141 | gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates, 142 | grad_outputs=torch.ones(disc_interpolates.size()).to(device),#.cuda(), #if use_cuda else torch.ones( 143 | #disc_interpolates.size()), 144 | create_graph=True, retain_graph=True, only_inputs=True)[0] 145 | #LAMBDA = 1 146 | gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA 147 | return gradient_penalty 148 | 149 | def read_image(opt): 150 | x = img.imread('%s/%s' % (opt.input_dir,opt.input_name)) 151 | x = np2torch(x,opt) 152 | x = x[:,0:3,:,:] 153 | return x 154 | 155 | def read_image_dir(dir,opt): 156 | x = img.imread('%s' % (dir)) 157 | x = np2torch(x,opt) 158 | x = x[:,0:3,:,:] 159 | return x 160 | 161 | def np2torch(x,opt): 162 | if opt.nc_im == 3: 163 | x = x[:,:,:,None] 164 | x = x.transpose((3, 2, 0, 1))/255 165 | else: 166 | x = color.rgb2gray(x) 167 | x = x[:,:,None,None] 168 | x = x.transpose(3, 2, 0, 1) 169 | x = torch.from_numpy(x) 170 | if not(opt.not_cuda): 171 | x = move_to_gpu(x) 172 | x = x.type(torch.cuda.FloatTensor) if not(opt.not_cuda) else x.type(torch.FloatTensor) 173 | #x = x.type(torch.FloatTensor) 174 | x = norm(x) 175 | return x 176 | 177 | def torch2uint8(x): 178 | x = x[0,:,:,:] 179 | x = x.permute((1,2,0)) 180 | x = 255*denorm(x) 181 | x = x.cpu().numpy() 182 | x = x.astype(np.uint8) 183 | return x 184 | 185 | def read_image2np(opt): 186 | x = img.imread('%s/%s' % (opt.input_dir,opt.input_name)) 187 | x = x[:, :, 0:3] 188 | return x 189 | 190 | def save_networks(netG,netD,z,opt): 191 | torch.save(netG.state_dict(), '%s/netG.pth' % (opt.outf)) 192 | torch.save(netD.state_dict(), '%s/netD.pth' % (opt.outf)) 193 | torch.save(z, '%s/z_opt.pth' % (opt.outf)) 194 | 195 | def adjust_scales2image(real_,opt): 196 | #opt.num_scales = int((math.log(math.pow(opt.min_size / (real_.shape[2]), 1), opt.scale_factor_init))) + 1 197 | opt.num_scales = math.ceil((math.log(math.pow(opt.min_size / (min(real_.shape[2], real_.shape[3])), 1), opt.scale_factor_init))) + 1 198 | scale2stop = math.ceil(math.log(min([opt.max_size, max([real_.shape[2], real_.shape[3]])]) / max([real_.shape[2], real_.shape[3]]),opt.scale_factor_init)) 199 | opt.stop_scale = opt.num_scales - scale2stop 200 | opt.scale1 = min(opt.max_size / max([real_.shape[2], real_.shape[3]]),1) # min(250/max([real_.shape[0],real_.shape[1]]),1) 201 | real = imresize(real_, opt.scale1, opt) 202 | #opt.scale_factor = math.pow(opt.min_size / (real.shape[2]), 1 / (opt.stop_scale)) 203 | opt.scale_factor = math.pow(opt.min_size/(min(real.shape[2],real.shape[3])),1/(opt.stop_scale)) 204 | scale2stop = math.ceil(math.log(min([opt.max_size, max([real_.shape[2], real_.shape[3]])]) / max([real_.shape[2], real_.shape[3]]),opt.scale_factor_init)) 205 | opt.stop_scale = opt.num_scales - scale2stop 206 | return real 207 | 208 | def adjust_scales2image_SR(real_,opt): 209 | opt.min_size = 18 210 | opt.num_scales = 
int((math.log(opt.min_size / min(real_.shape[2], real_.shape[3]), opt.scale_factor_init))) + 1 211 | scale2stop = int(math.log(min(opt.max_size , max(real_.shape[2], real_.shape[3])) / max(real_.shape[2], real_.shape[3]), opt.scale_factor_init)) 212 | opt.stop_scale = opt.num_scales - scale2stop 213 | opt.scale1 = min(opt.max_size / max([real_.shape[2], real_.shape[3]]), 1) # min(250/max([real_.shape[0],real_.shape[1]]),1) 214 | real = imresize(real_, opt.scale1, opt) 215 | #opt.scale_factor = math.pow(opt.min_size / (real.shape[2]), 1 / (opt.stop_scale)) 216 | opt.scale_factor = math.pow(opt.min_size/(min(real.shape[2],real.shape[3])),1/(opt.stop_scale)) 217 | scale2stop = int(math.log(min(opt.max_size, max(real_.shape[2], real_.shape[3])) / max(real_.shape[2], real_.shape[3]), opt.scale_factor_init)) 218 | opt.stop_scale = opt.num_scales - scale2stop 219 | return real 220 | 221 | def creat_reals_pyramid(real,reals,opt): 222 | real = real[:,0:3,:,:] 223 | for i in range(0,opt.stop_scale+1,1): 224 | scale = math.pow(opt.scale_factor,opt.stop_scale-i) 225 | curr_real = imresize(real,scale,opt) 226 | reals.append(curr_real) 227 | return reals 228 | 229 | 230 | def load_trained_pyramid(opt, mode_='train'): 231 | #dir = 'TrainedModels/%s/scale_factor=%f' % (opt.input_name[:-4], opt.scale_factor_init) 232 | mode = opt.mode 233 | opt.mode = 'train' 234 | if (mode == 'animation_train') | (mode == 'SR_train') | (mode == 'paint_train'): 235 | opt.mode = mode 236 | dir = generate_dir2save(opt) 237 | if(os.path.exists(dir)): 238 | Gs = torch.load('%s/Gs.pth' % dir) 239 | Zs = torch.load('%s/Zs.pth' % dir) 240 | reals = torch.load('%s/reals.pth' % dir) 241 | NoiseAmp = torch.load('%s/NoiseAmp.pth' % dir) 242 | else: 243 | print('no appropriate trained model exists, please train first') 244 | opt.mode = mode 245 | return Gs,Zs,reals,NoiseAmp 246 | 247 | def generate_in2coarsest(reals,scale_v,scale_h,opt): 248 | real = reals[opt.gen_start_scale] 249 | real_down = upsampling(real, scale_v * real.shape[2], scale_h * real.shape[3]) 250 | if opt.gen_start_scale == 0: 251 | in_s = torch.full(real_down.shape, 0, device=opt.device) 252 | else: #if n!=0 253 | in_s = upsampling(real_down, real_down.shape[2], real_down.shape[3]) 254 | return in_s 255 | 256 | def generate_dir2save(opt): 257 | dir2save = None 258 | if (opt.mode == 'train') | (opt.mode == 'SR_train'): 259 | dir2save = 'TrainedModels/%s/scale_factor=%f,alpha=%d' % (opt.input_name[:-4], opt.scale_factor_init,opt.alpha) 260 | elif (opt.mode == 'animation_train') : 261 | dir2save = 'TrainedModels/%s/scale_factor=%f_noise_padding' % (opt.input_name[:-4], opt.scale_factor_init) 262 | elif (opt.mode == 'paint_train') : 263 | dir2save = 'TrainedModels/%s/scale_factor=%f_paint/start_scale=%d' % (opt.input_name[:-4], opt.scale_factor_init,opt.paint_start_scale) 264 | elif opt.mode == 'random_samples': 265 | dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out,opt.input_name[:-4], opt.gen_start_scale) 266 | elif opt.mode == 'random_samples_arbitrary_sizes': 267 | dir2save = '%s/RandomSamples_ArbitrerySizes/%s/scale_v=%f_scale_h=%f' % (opt.out,opt.input_name[:-4], opt.scale_v, opt.scale_h) 268 | elif opt.mode == 'animation': 269 | dir2save = '%s/Animation/%s' % (opt.out, opt.input_name[:-4]) 270 | elif opt.mode == 'SR': 271 | dir2save = '%s/SR/%s' % (opt.out, opt.sr_factor) 272 | elif opt.mode == 'harmonization': 273 | dir2save = '%s/Harmonization/%s/%s_out' % (opt.out, opt.input_name[:-4],opt.ref_name[:-4]) 274 | elif opt.mode == 'editing': 275 |
dir2save = '%s/Editing/%s/%s_out' % (opt.out, opt.input_name[:-4],opt.ref_name[:-4]) 276 | elif opt.mode == 'paint2image': 277 | dir2save = '%s/Paint2image/%s/%s_out' % (opt.out, opt.input_name[:-4],opt.ref_name[:-4]) 278 | if opt.quantization_flag: 279 | dir2save = '%s_quantized' % dir2save 280 | return dir2save 281 | 282 | def post_config(opt): 283 | # init fixed parameters 284 | opt.device = torch.device("cpu" if opt.not_cuda else "cuda:0") 285 | opt.niter_init = opt.niter 286 | opt.noise_amp_init = opt.noise_amp 287 | opt.nfc_init = opt.nfc 288 | opt.min_nfc_init = opt.min_nfc 289 | opt.scale_factor_init = opt.scale_factor 290 | opt.out_ = 'TrainedModels/%s/scale_factor=%f/' % (opt.input_name[:-4], opt.scale_factor) 291 | if opt.mode == 'SR': 292 | opt.alpha = 100 293 | 294 | if opt.manualSeed is None: 295 | opt.manualSeed = random.randint(1, 10000) 296 | print("Random Seed: ", opt.manualSeed) 297 | random.seed(opt.manualSeed) 298 | torch.manual_seed(opt.manualSeed) 299 | if torch.cuda.is_available() and opt.not_cuda: 300 | print("WARNING: You have a CUDA device, so you should probably run without --not_cuda") 301 | return opt 302 | 303 | def calc_init_scale(opt): 304 | in_scale = math.pow(1/2,1/3) 305 | iter_num = round(math.log(1 / opt.sr_factor, in_scale)) 306 | in_scale = pow(opt.sr_factor, 1 / iter_num) 307 | return in_scale,iter_num 308 | 309 | def quant(prev,device): 310 | arr = prev.reshape((-1, 3)).cpu() 311 | kmeans = KMeans(n_clusters=5, random_state=0).fit(arr) 312 | labels = kmeans.labels_ 313 | centers = kmeans.cluster_centers_ 314 | x = centers[labels] 315 | x = torch.from_numpy(x) 316 | x = move_to_gpu(x) 317 | x = x.type(torch.cuda.FloatTensor) if torch.cuda.is_available() else x.type(torch.FloatTensor) 318 | #x = x.type(torch.FloatTensor.to(device)) 319 | x = x.view(prev.shape) 320 | return x,centers 321 | 322 | def quant2centers(paint, centers): 323 | arr = paint.reshape((-1, 3)).cpu() 324 | kmeans = KMeans(n_clusters=5, init=centers, n_init=1).fit(arr) 325 | labels = kmeans.labels_ 326 | #centers = kmeans.cluster_centers_ 327 | x = centers[labels] 328 | x = torch.from_numpy(x) 329 | x = move_to_gpu(x) 330 | x = x.type(torch.cuda.FloatTensor) if torch.cuda.is_available() else x.type(torch.FloatTensor) 331 | #x = x.type(torch.cuda.FloatTensor) 332 | x = x.view(paint.shape) 333 | return x 334 | 335 | 336 | 337 | 338 | def dilate_mask(mask,opt): 339 | if opt.mode == "harmonization": 340 | element = morphology.disk(radius=7) 341 | if opt.mode == "editing": 342 | element = morphology.disk(radius=20) 343 | mask = torch2uint8(mask) 344 | mask = mask[:,:,0] 345 | mask = morphology.binary_dilation(mask,selem=element) 346 | mask = filters.gaussian(mask, sigma=5) 347 | nc_im = opt.nc_im 348 | opt.nc_im = 1 349 | mask = np2torch(mask,opt) 350 | opt.nc_im = nc_im 351 | mask = mask.expand(1, 3, mask.shape[2], mask.shape[3]) 352 | plt.imsave('%s/%s_mask_dilated.png' % (opt.ref_dir, opt.ref_name[:-4]), convert_image_np(mask), vmin=0,vmax=1) 353 | mask = (mask-mask.min())/(mask.max()-mask.min()) 354 | return mask 355 | 356 | 357 | -------------------------------------------------------------------------------- /SinGAN/imresize.py: -------------------------------------------------------------------------------- 1 | # This code was taken from: https://github.com/assafshocher/resizer by Assaf Shocher 2 | 3 | import numpy as np 4 | from scipy.ndimage import filters, measurements, interpolation 5 | from skimage import color 6 | from math import pi 7 | #from SinGAN.functions import torch2uint8, np2torch 8
| import torch 9 | 10 | 11 | def denorm(x): 12 | out = (x + 1) / 2 13 | return out.clamp(0, 1) 14 | 15 | def norm(x): 16 | out = (x - 0.5) * 2 17 | return out.clamp(-1, 1) 18 | 19 | def move_to_gpu(t): 20 | if (torch.cuda.is_available()): 21 | t = t.to(torch.device('cuda')) 22 | return t 23 | 24 | def np2torch(x,opt): 25 | if opt.nc_im == 3: 26 | x = x[:,:,:,None] 27 | x = x.transpose((3, 2, 0, 1))/255 28 | else: 29 | x = color.rgb2gray(x) 30 | x = x[:,:,None,None] 31 | x = x.transpose(3, 2, 0, 1) 32 | x = torch.from_numpy(x) 33 | if not (opt.not_cuda): 34 | x = move_to_gpu(x) 35 | x = x.type(torch.cuda.FloatTensor) if not(opt.not_cuda) else x.type(torch.FloatTensor) 36 | #x = x.type(torch.cuda.FloatTensor) 37 | x = norm(x) 38 | return x 39 | 40 | def torch2uint8(x): 41 | x = x[0,:,:,:] 42 | x = x.permute((1,2,0)) 43 | x = 255*denorm(x) 44 | x = x.cpu().numpy() 45 | x = x.astype(np.uint8) 46 | return x 47 | 48 | 49 | def imresize(im,scale,opt): 50 | #s = im.shape 51 | im = torch2uint8(im) 52 | im = imresize_in(im, scale_factor=scale) 53 | im = np2torch(im,opt) 54 | #im = im[:, :, 0:int(scale * s[2]), 0:int(scale * s[3])] 55 | return im 56 | 57 | def imresize_to_shape(im,output_shape,opt): 58 | #s = im.shape 59 | im = torch2uint8(im) 60 | im = imresize_in(im, output_shape=output_shape) 61 | im = np2torch(im,opt) 62 | #im = im[:, :, 0:int(scale * s[2]), 0:int(scale * s[3])] 63 | return im 64 | 65 | 66 | def imresize_in(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 67 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 68 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 69 | 70 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 71 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 72 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 73 | 74 | # Choose interpolation method, each method has the matching kernel size 75 | method, kernel_width = { 76 | "cubic": (cubic, 4.0), 77 | "lanczos2": (lanczos2, 4.0), 78 | "lanczos3": (lanczos3, 6.0), 79 | "box": (box, 1.0), 80 | "linear": (linear, 2.0), 81 | None: (cubic, 4.0) # set default interpolation method as cubic 82 | }.get(kernel) 83 | 84 | # Antialiasing is only used when downscaling 85 | antialiasing *= (scale_factor[0] < 1) 86 | 87 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient 88 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 89 | 90 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 91 | out_im = np.copy(im) 92 | for dim in sorted_dims: 93 | # No point doing calculations for scale-factor 1. nothing will happen anyway 94 | if scale_factor[dim] == 1.0: 95 | continue 96 | 97 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 98 | # weights that multiply the values there to get its result. 
weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 100 | method, kernel_width, antialiasing) 101 | 102 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 103 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 104 | 105 | return out_im 106 | 107 | 108 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 109 | # First fix the scale-factor (if given) to the standardized form the function expects (a list of scale factors, 110 | # one per input dimension) 111 | if scale_factor is not None: 112 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 113 | if np.isscalar(scale_factor): 114 | scale_factor = [scale_factor, scale_factor] 115 | 116 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 117 | scale_factor = list(scale_factor) 118 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 119 | 120 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 121 | # to all the unspecified dimensions 122 | if output_shape is not None: 123 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 124 | 125 | # Dealing with the case of a non-given scale-factor, calculating according to output-shape. note that this is 126 | # sub-optimal, because there can be different scales to the same output-shape. 127 | if scale_factor is None: 128 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 129 | 130 | # Dealing with missing output-shape. calculating according to scale-factor 131 | if output_shape is None: 132 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 133 | 134 | return scale_factor, output_shape 135 | 136 | 137 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 138 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 139 | # such that each position from the field_of_view will be multiplied with a matching filter from the 140 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 141 | # around it. This is only done for one dimension of the image. 142 | 143 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 144 | # 1/sf. this means filtering is more 'low-pass filter'. 145 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 146 | kernel_width *= 1.0 / scale if antialiasing else 1.0 147 | 148 | # These are the coordinates of the output image 149 | out_coordinates = np.arange(1, out_length+1) 150 | 151 | # These are the matching positions of the output-coordinates on the input image coordinates. 152 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 153 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 154 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 155 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 156 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
157 | # So if we measure distance from the left border, the middle of pixel 1 is at distance d=0.5, the border between 1 and 2 is
158 | # at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf), which means:
159 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
160 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)
161 | 
162 | # This is the left boundary to start multiplying the filter from; it depends on the size of the filter
163 | left_boundary = np.floor(match_coordinates - kernel_width / 2)
164 | 
165 | # The kernel width needs to be enlarged because, when the coverage has sub-pixel borders, it must 'see' the pixel
166 | # centers of pixels it only partly covers. So we add one pixel at each side (their weights can be zeroed)
167 | expanded_kernel_width = np.ceil(kernel_width) + 2
168 | 
169 | # Determine a field_of_view for each output position: these are the pixels in the input image
170 | # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and
171 | # whose vertical dim is the pixels it 'sees' (kernel_size + 2)
172 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
173 | 
174 | # Assign a weight to each pixel in the field of view. A matrix whose horizontal dim is the output pixels and whose
175 | # vertical dim is the list of weights matching the pixels in the field of view (those specified in
176 | # 'field_of_view')
177 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
178 | 
179 | # Normalize the weights to sum to 1, taking care not to divide by 0
180 | sum_weights = np.sum(weights, axis=1)
181 | sum_weights[sum_weights == 0] = 1.0
182 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
183 | 
184 | # We use this mirror structure as a trick for reflection padding at the boundaries
185 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
186 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
187 | 
188 | # Get rid of weight columns and pixel positions that hold zero weight for every output pixel
189 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
190 | weights = np.squeeze(weights[:, non_zero_out_pixels])
191 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
192 | 
193 | # The final products are the relative positions and the matching weights, both of size output_size X fixed_kernel_size
194 | return weights, field_of_view
195 | 
196 | 
197 | def resize_along_dim(im, dim, weights, field_of_view):
198 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
199 | tmp_im = np.swapaxes(im, dim, 0)
200 | 
201 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
202 | # tmp_im[field_of_view.T], (bsxfun style)
203 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
204 | 
205 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
206 | # For each pixel in the output image it gathers the positions that influence it from the input image (along one dim
207 | # only, which is why it only adds one dim to the shape). We then multiply, for each pixel, its set of positions by
208 | # the matching set of weights. We do this with one big element-wise tensor multiplication (MATLAB bsxfun style:
209 | # matching dims are multiplied element-wise, while singleton dims mean the matching dim is all multiplied by the
210 | # same number)
211 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
212 | 
213 | # Finally we swap back the axes to the original order
214 | return np.swapaxes(tmp_out_im, dim, 0)
215 | 
216 | 
217 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
218 | # See the kernel_shift function to understand what this is
219 | if kernel_shift_flag:
220 | kernel = kernel_shift(kernel, scale_factor)
221 | 
222 | # First run a correlation (convolution with a flipped kernel)
223 | out_im = np.zeros_like(im)
224 | for channel in range(im.shape[2]): # iterate over the color channels
225 | out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel)
226 | 
227 | # Then subsample and return
228 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
229 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
230 | 
231 | 
232 | def kernel_shift(kernel, sf):
233 | # There are two reasons for shifting the kernel:
234 | # 1. The center of mass is not in the center of the kernel, which creates ambiguity. There is no way to know whether
235 | # the degradation process included shifting, so we always assume the center of mass is the center of the kernel.
236 | # 2. We further shift the kernel center so that the top-left result pixel corresponds to the middle of the first
237 | # sf X sf pixels. The default is for an odd-sized kernel to be in the middle of the first pixel and for an even-sized
238 | # kernel to be at the top-left corner of the first pixel; that is why a different shift size is needed for odd and even sizes.
239 | # Given that these two conditions are fulfilled, we are aligned. The way to test this is as follows:
240 | # the input image, when interpolated (regular bicubic), is exactly aligned with the ground truth.
241 | 
242 | # First calculate the current center of mass of the kernel
243 | current_center_of_mass = measurements.center_of_mass(kernel)
244 | 
245 | # The second term ("+ 0.5 * ....") applies condition 2 from the comments above
246 | wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2))
247 | 
248 | # Define the shift vector for the kernel shifting (x,y)
249 | shift_vec = wanted_center_of_mass - current_center_of_mass
250 | 
251 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
252 | # (the biggest shift among dims, + 1 for safety)
253 | kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')
254 | 
255 | # Finally shift the kernel and return
256 | return interpolation.shift(kernel, shift_vec)
257 | 
258 | 
259 | # These next functions are all interpolation methods.
x is the distance from the left pixel center 260 | 261 | 262 | def cubic(x): 263 | absx = np.abs(x) 264 | absx2 = absx ** 2 265 | absx3 = absx ** 3 266 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) + 267 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2))) 268 | 269 | 270 | def lanczos2(x): 271 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) / 272 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) 273 | * (abs(x) < 2)) 274 | 275 | 276 | def box(x): 277 | return ((-0.5 <= x) & (x < 0.5)) * 1.0 278 | 279 | 280 | def lanczos3(x): 281 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) / 282 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) 283 | * (abs(x) < 3)) 284 | 285 | 286 | def linear(x): 287 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) 288 | -------------------------------------------------------------------------------- /SinGAN/manipulate.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import SinGAN.functions 3 | import SinGAN.models 4 | import argparse 5 | import os 6 | import random 7 | from SinGAN.imresize import imresize 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | import torch.utils.data 11 | import torchvision.datasets as dset 12 | import torchvision.transforms as transforms 13 | import torchvision.utils as vutils 14 | from skimage import io as img 15 | import numpy as np 16 | from skimage import color 17 | import math 18 | import imageio 19 | import matplotlib.pyplot as plt 20 | from SinGAN.training import * 21 | from config import get_arguments 22 | 23 | def generate_gif(Gs,Zs,reals,NoiseAmp,opt,alpha=0.1,beta=0.9,start_scale=2,fps=10): 24 | 25 | in_s = torch.full(Zs[0].shape, 0, device=opt.device) 26 | images_cur = [] 27 | count = 0 28 | 29 | for G,Z_opt,noise_amp,real in zip(Gs,Zs,NoiseAmp,reals): 30 | pad_image = int(((opt.ker_size - 1) * opt.num_layer) / 2) 31 | nzx = Z_opt.shape[2] 32 | nzy = Z_opt.shape[3] 33 | #pad_noise = 0 34 | #m_noise = nn.ZeroPad2d(int(pad_noise)) 35 | m_image = nn.ZeroPad2d(int(pad_image)) 36 | images_prev = images_cur 37 | images_cur = [] 38 | if count == 0: 39 | z_rand = functions.generate_noise([1,nzx,nzy], device=opt.device) 40 | z_rand = z_rand.expand(1,3,Z_opt.shape[2],Z_opt.shape[3]) 41 | z_prev1 = 0.95*Z_opt +0.05*z_rand 42 | z_prev2 = Z_opt 43 | else: 44 | z_prev1 = 0.95*Z_opt +0.05*functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device) 45 | z_prev2 = Z_opt 46 | 47 | for i in range(0,100,1): 48 | if count == 0: 49 | z_rand = functions.generate_noise([1,nzx,nzy], device=opt.device) 50 | z_rand = z_rand.expand(1,3,Z_opt.shape[2],Z_opt.shape[3]) 51 | diff_curr = beta*(z_prev1-z_prev2)+(1-beta)*z_rand 52 | else: 53 | diff_curr = beta*(z_prev1-z_prev2)+(1-beta)*(functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device)) 54 | 55 | z_curr = alpha*Z_opt+(1-alpha)*(z_prev1+diff_curr) 56 | z_prev2 = z_prev1 57 | z_prev1 = z_curr 58 | 59 | if images_prev == []: 60 | I_prev = in_s 61 | else: 62 | I_prev = images_prev[i] 63 | I_prev = imresize(I_prev, 1 / opt.scale_factor, opt) 64 | I_prev = I_prev[:, :, 0:real.shape[2], 0:real.shape[3]] 65 | #I_prev = functions.upsampling(I_prev,reals[count].shape[2],reals[count].shape[3]) 66 | I_prev = m_image(I_prev) 67 | if count < start_scale: 68 | z_curr = Z_opt 69 | 70 | z_in = noise_amp*z_curr+I_prev 71 | I_curr = G(z_in.detach(),I_prev) 72 | 73 | if (count == len(Gs)-1): 74 | I_curr = 
functions.denorm(I_curr).detach() 75 | I_curr = I_curr[0,:,:,:].cpu().numpy() 76 | I_curr = I_curr.transpose(1, 2, 0)*255 77 | I_curr = I_curr.astype(np.uint8) 78 | 79 | images_cur.append(I_curr) 80 | count += 1 81 | dir2save = functions.generate_dir2save(opt) 82 | try: 83 | os.makedirs('%s/start_scale=%d' % (dir2save,start_scale) ) 84 | except OSError: 85 | pass 86 | imageio.mimsave('%s/start_scale=%d/alpha=%f_beta=%f.gif' % (dir2save,start_scale,alpha,beta),images_cur,fps=fps) 87 | del images_cur 88 | 89 | def SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt,in_s=None,scale_v=1,scale_h=1,n=0,gen_start_scale=0,num_samples=50): 90 | #if torch.is_tensor(in_s) == False: 91 | if in_s is None: 92 | in_s = torch.full(reals[0].shape, 0, device=opt.device) 93 | images_cur = [] 94 | for G,Z_opt,noise_amp in zip(Gs,Zs,NoiseAmp): 95 | pad1 = ((opt.ker_size-1)*opt.num_layer)/2 96 | m = nn.ZeroPad2d(int(pad1)) 97 | nzx = (Z_opt.shape[2]-pad1*2)*scale_v 98 | nzy = (Z_opt.shape[3]-pad1*2)*scale_h 99 | 100 | images_prev = images_cur 101 | images_cur = [] 102 | 103 | for i in range(0,num_samples,1): 104 | if n == 0: 105 | z_curr = functions.generate_noise([1,nzx,nzy], device=opt.device) 106 | z_curr = z_curr.expand(1,3,z_curr.shape[2],z_curr.shape[3]) 107 | z_curr = m(z_curr) 108 | else: 109 | z_curr = functions.generate_noise([opt.nc_z,nzx,nzy], device=opt.device) 110 | z_curr = m(z_curr) 111 | 112 | if images_prev == []: 113 | I_prev = m(in_s) 114 | #I_prev = m(I_prev) 115 | #I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]] 116 | #I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3]) 117 | else: 118 | I_prev = images_prev[i] 119 | I_prev = imresize(I_prev,1/opt.scale_factor, opt) 120 | if opt.mode != "SR": 121 | I_prev = I_prev[:, :, 0:round(scale_v * reals[n].shape[2]), 0:round(scale_h * reals[n].shape[3])] 122 | I_prev = m(I_prev) 123 | I_prev = I_prev[:,:,0:z_curr.shape[2],0:z_curr.shape[3]] 124 | I_prev = functions.upsampling(I_prev,z_curr.shape[2],z_curr.shape[3]) 125 | else: 126 | I_prev = m(I_prev) 127 | 128 | if n < gen_start_scale: 129 | z_curr = Z_opt 130 | 131 | z_in = noise_amp*(z_curr)+I_prev 132 | I_curr = G(z_in.detach(),I_prev) 133 | 134 | if n == len(reals)-1: 135 | if opt.mode == 'train': 136 | dir2save = '%s/RandomSamples/%s/gen_start_scale=%d' % (opt.out, opt.input_name[:-4], gen_start_scale) 137 | else: 138 | dir2save = functions.generate_dir2save(opt) 139 | try: 140 | os.makedirs(dir2save) 141 | except OSError: 142 | pass 143 | if (opt.mode != "harmonization") & (opt.mode != "editing") & (opt.mode != "SR") & (opt.mode != "paint2image"): 144 | plt.imsave('%s/%d.png' % (dir2save, i), functions.convert_image_np(I_curr.detach()), vmin=0,vmax=1) 145 | #plt.imsave('%s/%d_%d.png' % (dir2save,i,n),functions.convert_image_np(I_curr.detach()), vmin=0, vmax=1) 146 | #plt.imsave('%s/in_s.png' % (dir2save), functions.convert_image_np(in_s), vmin=0,vmax=1) 147 | images_cur.append(I_curr) 148 | n+=1 149 | return I_curr.detach() 150 | 151 | -------------------------------------------------------------------------------- /SinGAN/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | import torch.nn.functional as F 5 | 6 | 7 | class ConvBlock(nn.Sequential): 8 | def __init__(self, in_channel, out_channel, ker_size, padd, stride): 9 | super(ConvBlock,self).__init__() 10 | self.add_module('conv',nn.Conv2d(in_channel ,out_channel,kernel_size=ker_size,stride=stride,padding=padd)), 
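# the conv layer is followed by batch-norm and a LeakyReLU activation below, so every ConvBlock is the
# Conv2d -> BatchNorm2d -> LeakyReLU(0.2) unit that both the generator and the discriminator stack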
11 | self.add_module('norm',nn.BatchNorm2d(out_channel)), 12 | self.add_module('LeakyRelu',nn.LeakyReLU(0.2, inplace=True)) 13 | 14 | def weights_init(m): 15 | classname = m.__class__.__name__ 16 | if classname.find('Conv2d') != -1: 17 | m.weight.data.normal_(0.0, 0.02) 18 | elif classname.find('Norm') != -1: 19 | m.weight.data.normal_(1.0, 0.02) 20 | m.bias.data.fill_(0) 21 | 22 | class WDiscriminator(nn.Module): 23 | def __init__(self, opt): 24 | super(WDiscriminator, self).__init__() 25 | self.is_cuda = torch.cuda.is_available() 26 | N = int(opt.nfc) 27 | self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1) 28 | self.body = nn.Sequential() 29 | for i in range(opt.num_layer-2): 30 | N = int(opt.nfc/pow(2,(i+1))) 31 | block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1) 32 | self.body.add_module('block%d'%(i+1),block) 33 | self.tail = nn.Conv2d(max(N,opt.min_nfc),1,kernel_size=opt.ker_size,stride=1,padding=opt.padd_size) 34 | 35 | def forward(self,x): 36 | x = self.head(x) 37 | x = self.body(x) 38 | x = self.tail(x) 39 | return x 40 | 41 | 42 | class GeneratorConcatSkip2CleanAdd(nn.Module): 43 | def __init__(self, opt): 44 | super(GeneratorConcatSkip2CleanAdd, self).__init__() 45 | self.is_cuda = torch.cuda.is_available() 46 | N = opt.nfc 47 | self.head = ConvBlock(opt.nc_im,N,opt.ker_size,opt.padd_size,1) #GenConvTransBlock(opt.nc_z,N,opt.ker_size,opt.padd_size,opt.stride) 48 | self.body = nn.Sequential() 49 | for i in range(opt.num_layer-2): 50 | N = int(opt.nfc/pow(2,(i+1))) 51 | block = ConvBlock(max(2*N,opt.min_nfc),max(N,opt.min_nfc),opt.ker_size,opt.padd_size,1) 52 | self.body.add_module('block%d'%(i+1),block) 53 | self.tail = nn.Sequential( 54 | nn.Conv2d(max(N,opt.min_nfc),opt.nc_im,kernel_size=opt.ker_size,stride =1,padding=opt.padd_size), 55 | nn.Tanh() 56 | ) 57 | def forward(self,x,y): 58 | x = self.head(x) 59 | x = self.body(x) 60 | x = self.tail(x) 61 | ind = int((y.shape[2]-x.shape[2])/2) 62 | y = y[:,:,ind:(y.shape[2]-ind),ind:(y.shape[3]-ind)] 63 | return x+y 64 | -------------------------------------------------------------------------------- /SinGAN/training.py: -------------------------------------------------------------------------------- 1 | import SinGAN.functions as functions 2 | import SinGAN.models as models 3 | import os 4 | import torch.nn as nn 5 | import torch.optim as optim 6 | import torch.utils.data 7 | import math 8 | import matplotlib.pyplot as plt 9 | from SinGAN.imresize import imresize 10 | 11 | def train(opt,Gs,Zs,reals,NoiseAmp): 12 | real_ = functions.read_image(opt) 13 | in_s = 0 14 | scale_num = 0 15 | real = imresize(real_,opt.scale1,opt) 16 | reals = functions.creat_reals_pyramid(real,reals,opt) 17 | nfc_prev = 0 18 | 19 | while scale_num 0: 224 | if mode == 'rand': 225 | count = 0 226 | pad_noise = int(((opt.ker_size-1)*opt.num_layer)/2) 227 | if opt.mode == 'animation_train': 228 | pad_noise = 0 229 | for G,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Zs,reals,reals[1:],NoiseAmp): 230 | if count == 0: 231 | z = functions.generate_noise([1, Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise], device=opt.device) 232 | z = z.expand(1, 3, z.shape[2], z.shape[3]) 233 | else: 234 | z = functions.generate_noise([opt.nc_z,Z_opt.shape[2] - 2 * pad_noise, Z_opt.shape[3] - 2 * pad_noise], device=opt.device) 235 | z = m_noise(z) 236 | G_z = G_z[:,:,0:real_curr.shape[2],0:real_curr.shape[3]] 237 | G_z = m_image(G_z) 238 | z_in = noise_amp*z+G_z 239 | G_z = G(z_in.detach(),G_z) 240 | G_z = 
imresize(G_z,1/opt.scale_factor,opt) 241 | G_z = G_z[:,:,0:real_next.shape[2],0:real_next.shape[3]] 242 | count += 1 243 | if mode == 'rec': 244 | count = 0 245 | for G,Z_opt,real_curr,real_next,noise_amp in zip(Gs,Zs,reals,reals[1:],NoiseAmp): 246 | G_z = G_z[:, :, 0:real_curr.shape[2], 0:real_curr.shape[3]] 247 | G_z = m_image(G_z) 248 | z_in = noise_amp*Z_opt+G_z 249 | G_z = G(z_in.detach(),G_z) 250 | G_z = imresize(G_z,1/opt.scale_factor,opt) 251 | G_z = G_z[:,:,0:real_next.shape[2],0:real_next.shape[3]] 252 | #if count != (len(Gs)-1): 253 | # G_z = m_image(G_z) 254 | count += 1 255 | return G_z 256 | 257 | def train_paint(opt,Gs,Zs,reals,NoiseAmp,centers,paint_inject_scale): 258 | in_s = torch.full(reals[0].shape, 0, device=opt.device) 259 | scale_num = 0 260 | nfc_prev = 0 261 | 262 | while scale_num (len(Gs)-1)): 37 | print("injection scale should be between 1 and %d" % (len(Gs)-1)) 38 | else: 39 | ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt) 40 | mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir,opt.ref_name[:-4],opt.ref_name[-4:]), opt) 41 | if ref.shape[3] != real.shape[3]: 42 | ''' 43 | mask = imresize(mask, real.shape[3]/ref.shape[3], opt) 44 | mask = mask[:, :, :real.shape[2], :real.shape[3]] 45 | ref = imresize(ref, real.shape[3] / ref.shape[3], opt) 46 | ref = ref[:, :, :real.shape[2], :real.shape[3]] 47 | ''' 48 | mask = imresize_to_shape(mask, [real.shape[2],real.shape[3]], opt) 49 | mask = mask[:, :, :real.shape[2], :real.shape[3]] 50 | ref = imresize_to_shape(ref, [real.shape[2],real.shape[3]], opt) 51 | ref = ref[:, :, :real.shape[2], :real.shape[3]] 52 | 53 | mask = functions.dilate_mask(mask, opt) 54 | 55 | N = len(reals) - 1 56 | n = opt.editing_start_scale 57 | in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt) 58 | in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]] 59 | in_s = imresize(in_s, 1 / opt.scale_factor, opt) 60 | in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] 61 | out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1) 62 | plt.imsave('%s/start_scale=%d.png' % (dir2save, opt.editing_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1) 63 | out = (1-mask)*real+mask*out 64 | plt.imsave('%s/start_scale=%d_masked.png' % (dir2save, opt.editing_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1) 65 | 66 | 67 | 68 | 69 | -------------------------------------------------------------------------------- /harmonization.py: -------------------------------------------------------------------------------- 1 | from config import get_arguments 2 | from SinGAN.manipulate import * 3 | from SinGAN.training import * 4 | from SinGAN.imresize import imresize 5 | from SinGAN.imresize import imresize_to_shape 6 | import SinGAN.functions as functions 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = get_arguments() 11 | parser.add_argument('--input_dir', help='input image dir', default='Input/Images') 12 | parser.add_argument('--input_name', help='training image name', required=True) 13 | parser.add_argument('--ref_dir', help='input reference dir', default='Input/Harmonization') 14 | parser.add_argument('--ref_name', help='reference image name', required=True) 15 | parser.add_argument('--harmonization_start_scale', help='harmonization injection scale', type=int, required=True) 16 | parser.add_argument('--mode', help='task to be done', default='harmonization') 17 | opt = parser.parse_args() 18 | opt = functions.post_config(opt) 
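# Illustrative invocation (assumes a SinGAN pyramid was already trained on the input image with main_train.py;
# the injection scale must pass the 1..len(Gs)-1 range check below, and the reference mask is looked up as
# <ref_name>_mask.<ext> in the same directory; the start-scale value here is only an example):
#   python harmonization.py --input_name starry_night.png --ref_name starry_night_naive.png --harmonization_start_scale 5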
19 | Gs = [] 20 | Zs = [] 21 | reals = [] 22 | NoiseAmp = [] 23 | dir2save = functions.generate_dir2save(opt) 24 | if dir2save is None: 25 | print('task does not exist') 26 | #elif (os.path.exists(dir2save)): 27 | # print("output already exist") 28 | else: 29 | try: 30 | os.makedirs(dir2save) 31 | except OSError: 32 | pass 33 | real = functions.read_image(opt) 34 | real = functions.adjust_scales2image(real, opt) 35 | Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) 36 | if (opt.harmonization_start_scale < 1) | (opt.harmonization_start_scale > (len(Gs)-1)): 37 | print("injection scale should be between 1 and %d" % (len(Gs)-1)) 38 | else: 39 | ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt) 40 | mask = functions.read_image_dir('%s/%s_mask%s' % (opt.ref_dir,opt.ref_name[:-4],opt.ref_name[-4:]), opt) 41 | if ref.shape[3] != real.shape[3]: 42 | mask = imresize_to_shape(mask, [real.shape[2], real.shape[3]], opt) 43 | mask = mask[:, :, :real.shape[2], :real.shape[3]] 44 | ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt) 45 | ref = ref[:, :, :real.shape[2], :real.shape[3]] 46 | mask = functions.dilate_mask(mask, opt) 47 | 48 | N = len(reals) - 1 49 | n = opt.harmonization_start_scale 50 | in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt) 51 | in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]] 52 | in_s = imresize(in_s, 1 / opt.scale_factor, opt) 53 | in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] 54 | out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1) 55 | out = (1-mask)*real+mask*out 56 | plt.imsave('%s/start_scale=%d.png' % (dir2save,opt.harmonization_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1) 57 | 58 | 59 | 60 | 61 | -------------------------------------------------------------------------------- /imgs/manipulation.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/imgs/manipulation.PNG -------------------------------------------------------------------------------- /imgs/teaser.PNG: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/tamarott/SinGAN/df38a4214af95462fa97a613d6ba53eb441509dd/imgs/teaser.PNG -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | from config import get_arguments 2 | from SinGAN.manipulate import * 3 | from SinGAN.training import * 4 | import SinGAN.functions as functions 5 | 6 | 7 | if __name__ == '__main__': 8 | parser = get_arguments() 9 | parser.add_argument('--input_dir', help='input image dir', default='Input/Images') 10 | parser.add_argument('--input_name', help='input image name', required=True) 11 | parser.add_argument('--mode', help='task to be done', default='train') 12 | opt = parser.parse_args() 13 | opt = functions.post_config(opt) 14 | Gs = [] 15 | Zs = [] 16 | reals = [] 17 | NoiseAmp = [] 18 | dir2save = functions.generate_dir2save(opt) 19 | 20 | if (os.path.exists(dir2save)): 21 | print('trained model already exist') 22 | else: 23 | try: 24 | os.makedirs(dir2save) 25 | except OSError: 26 | pass 27 | real = functions.read_image(opt) 28 | functions.adjust_scales2image(real, opt) 29 | train(opt, Gs, Zs, reals, NoiseAmp) 30 | SinGAN_generate(Gs,Zs,reals,NoiseAmp,opt) 31 | 
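A minimal way to run the training script above (illustrative; the image name is resolved against the Input/Images default of --input_dir):

    python main_train.py --input_name birds.png

Once training finishes, the closing SinGAN_generate call writes random samples from the full pyramid under '<opt.out>/RandomSamples/<image name>/gen_start_scale=0' (the 'train'-mode branch of SinGAN_generate in SinGAN/manipulate.py above).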
-------------------------------------------------------------------------------- /paint2image.py: -------------------------------------------------------------------------------- 1 | from config import get_arguments 2 | from SinGAN.manipulate import * 3 | from SinGAN.training import * 4 | from SinGAN.imresize import imresize 5 | from SinGAN.imresize import imresize_to_shape 6 | import SinGAN.functions as functions 7 | 8 | 9 | if __name__ == '__main__': 10 | parser = get_arguments() 11 | parser.add_argument('--input_dir', help='input image dir', default='Input/Images') 12 | parser.add_argument('--input_name', help='training image name', required=True) 13 | parser.add_argument('--ref_dir', help='input reference dir', default='Input/Paint') 14 | parser.add_argument('--ref_name', help='reference image name', required=True) 15 | parser.add_argument('--paint_start_scale', help='paint injection scale', type=int, required=True) 16 | parser.add_argument('--quantization_flag', help='specify if to perform color quantization training', type=bool, default=False) 17 | parser.add_argument('--mode', help='task to be done', default='paint2image') 18 | opt = parser.parse_args() 19 | opt = functions.post_config(opt) 20 | Gs = [] 21 | Zs = [] 22 | reals = [] 23 | NoiseAmp = [] 24 | dir2save = functions.generate_dir2save(opt) 25 | if dir2save is None: 26 | print('task does not exist') 27 | #elif (os.path.exists(dir2save)): 28 | # print("output already exist") 29 | else: 30 | try: 31 | os.makedirs(dir2save) 32 | except OSError: 33 | pass 34 | real = functions.read_image(opt) 35 | real = functions.adjust_scales2image(real, opt) 36 | Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) 37 | if (opt.paint_start_scale < 1) | (opt.paint_start_scale > (len(Gs)-1)): 38 | print("injection scale should be between 1 and %d" % (len(Gs)-1)) 39 | else: 40 | ref = functions.read_image_dir('%s/%s' % (opt.ref_dir, opt.ref_name), opt) 41 | if ref.shape[3] != real.shape[3]: 42 | ref = imresize_to_shape(ref, [real.shape[2], real.shape[3]], opt) 43 | ref = ref[:, :, :real.shape[2], :real.shape[3]] 44 | 45 | N = len(reals) - 1 46 | n = opt.paint_start_scale 47 | in_s = imresize(ref, pow(opt.scale_factor, (N - n + 1)), opt) 48 | in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]] 49 | in_s = imresize(in_s, 1 / opt.scale_factor, opt) 50 | in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] 51 | if opt.quantization_flag: 52 | opt.mode = 'paint_train' 53 | dir2trained_model = functions.generate_dir2save(opt) 54 | # N = len(reals) - 1 55 | # n = opt.paint_start_scale 56 | real_s = imresize(real, pow(opt.scale_factor, (N - n)), opt) 57 | real_s = real_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] 58 | real_quant, centers = functions.quant(real_s, opt.device) 59 | plt.imsave('%s/real_quant.png' % dir2save, functions.convert_image_np(real_quant), vmin=0, vmax=1) 60 | plt.imsave('%s/in_paint.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1) 61 | in_s = functions.quant2centers(ref, centers) 62 | in_s = imresize(in_s, pow(opt.scale_factor, (N - n)), opt) 63 | # in_s = in_s[:, :, :reals[n - 1].shape[2], :reals[n - 1].shape[3]] 64 | # in_s = imresize(in_s, 1 / opt.scale_factor, opt) 65 | in_s = in_s[:, :, :reals[n].shape[2], :reals[n].shape[3]] 66 | plt.imsave('%s/in_paint_quant.png' % dir2save, functions.convert_image_np(in_s), vmin=0, vmax=1) 67 | if (os.path.exists(dir2trained_model)): 68 | # print('Trained model does not exist, training SinGAN for SR') 69 | Gs, Zs, reals, NoiseAmp = 
functions.load_trained_pyramid(opt) 70 | opt.mode = 'paint2image' 71 | else: 72 | train_paint(opt, Gs, Zs, reals, NoiseAmp, centers, opt.paint_start_scale) 73 | opt.mode = 'paint2image' 74 | out = SinGAN_generate(Gs[n:], Zs[n:], reals, NoiseAmp[n:], opt, in_s, n=n, num_samples=1) 75 | plt.imsave('%s/start_scale=%d.png' % (dir2save, opt.paint_start_scale), functions.convert_image_np(out.detach()), vmin=0, vmax=1) 76 | 77 | 78 | 79 | 80 | 81 | -------------------------------------------------------------------------------- /random_samples.py: -------------------------------------------------------------------------------- 1 | from config import get_arguments 2 | from SinGAN.manipulate import * 3 | from SinGAN.training import * 4 | from SinGAN.imresize import imresize 5 | import SinGAN.functions as functions 6 | 7 | 8 | if __name__ == '__main__': 9 | parser = get_arguments() 10 | parser.add_argument('--input_dir', help='input image dir', default='Input/Images') 11 | parser.add_argument('--input_name', help='input image name', required=True) 12 | parser.add_argument('--mode', help='random_samples | random_samples_arbitrary_sizes', default='train', required=True) 13 | # for random_samples: 14 | parser.add_argument('--gen_start_scale', type=int, help='generation start scale', default=0) 15 | # for random_samples_arbitrary_sizes: 16 | parser.add_argument('--scale_h', type=float, help='horizontal resize factor for random samples', default=1.5) 17 | parser.add_argument('--scale_v', type=float, help='vertical resize factor for random samples', default=1) 18 | opt = parser.parse_args() 19 | opt = functions.post_config(opt) 20 | Gs = [] 21 | Zs = [] 22 | reals = [] 23 | NoiseAmp = [] 24 | dir2save = functions.generate_dir2save(opt) 25 | if dir2save is None: 26 | print('task does not exist') 27 | elif (os.path.exists(dir2save)): 28 | if opt.mode == 'random_samples': 29 | print('random samples for image %s, start scale=%d, already exist' % (opt.input_name, opt.gen_start_scale)) 30 | elif opt.mode == 'random_samples_arbitrary_sizes': 31 | print('random samples for image %s at size: scale_h=%f, scale_v=%f, already exist' % (opt.input_name, opt.scale_h, opt.scale_v)) 32 | else: 33 | try: 34 | os.makedirs(dir2save) 35 | except OSError: 36 | pass 37 | if opt.mode == 'random_samples': 38 | real = functions.read_image(opt) 39 | functions.adjust_scales2image(real, opt) 40 | Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) 41 | in_s = functions.generate_in2coarsest(reals,1,1,opt) 42 | SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, gen_start_scale=opt.gen_start_scale) 43 | 44 | elif opt.mode == 'random_samples_arbitrary_sizes': 45 | real = functions.read_image(opt) 46 | functions.adjust_scales2image(real, opt) 47 | Gs, Zs, reals, NoiseAmp = functions.load_trained_pyramid(opt) 48 | in_s = functions.generate_in2coarsest(reals,opt.scale_v,opt.scale_h,opt) 49 | SinGAN_generate(Gs, Zs, reals, NoiseAmp, opt, in_s, scale_v=opt.scale_v, scale_h=opt.scale_h) 50 | 51 | 52 | 53 | 54 | 55 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | matplotlib 2 | scikit-image 3 | scikit-learn 4 | scipy 5 | numpy 6 | torch 7 | torchvision 8 | --------------------------------------------------------------------------------
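The comments in SinGAN/imresize.py derive the output-to-input coordinate mapping p_new = p_old/sf + 0.5 * (1 - 1/sf). A minimal standalone check of that formula (a sketch, not part of the repo; it assumes only numpy from requirements.txt):

    import numpy as np

    # Downscale a 4-pixel row to 2 pixels, i.e. scale = 0.5, as in the worked example
    # inside contributions(): each output pixel center should land on the border
    # between two input pixel centers, at input coordinates 1.5 and 3.5.
    scale = 0.5
    out_coordinates = np.arange(1, 3)
    match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)
    print(match_coordinates)  # [1.5 3.5]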