├── .gitignore ├── F-LSeSim ├── LICENSE ├── README.md ├── combine.py ├── data │ ├── __init__.py │ ├── aligned_dataset.py │ ├── base_dataset.py │ ├── colorization_dataset.py │ ├── dataset.py │ ├── image_folder.py │ ├── single_dataset.py │ ├── singleimage_dataset.py │ ├── template_dataset.py │ └── unaligned_dataset.py ├── datasets │ ├── bibtex │ │ ├── cityscapes.tex │ │ ├── facades.tex │ │ ├── handbags.tex │ │ ├── shoes.tex │ │ └── transattr.tex │ ├── combine_A_and_B.py │ ├── download_cyclegan_dataset.sh │ ├── download_pix2pix_dataset.sh │ ├── make_dataset_aligned.py │ └── prepare_cityscapes_dataset.py ├── evaluations │ ├── DC.py │ ├── __init__.py │ ├── fid_score.py │ └── inception.py ├── inference.py ├── models │ ├── __init__.py │ ├── base_model.py │ ├── colorization_model.py │ ├── cycle_gan_model.py │ ├── cyclegan_networks.py │ ├── discriminator.py │ ├── downsample.py │ ├── generator.py │ ├── kin.py │ ├── losses.py │ ├── networks.py │ ├── normalization.py │ ├── pix2pix_model.py │ ├── sc_model.py │ ├── sinsc_model.py │ ├── stylegan_networks.py │ ├── template_model.py │ ├── test_model.py │ ├── tin.py │ ├── upsample.py │ └── util.py ├── options │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py ├── scripts │ ├── conda_deps.sh │ ├── download_cyclegan_model.sh │ ├── download_pix2pix_model.sh │ ├── edges │ │ ├── PostprocessHED.m │ │ └── batch_hed.py │ ├── eval_cityscapes │ │ ├── caffemodel │ │ │ └── deploy.prototxt │ │ ├── cityscapes.py │ │ ├── download_fcn8s.sh │ │ ├── evaluate.py │ │ └── util.py │ ├── install_deps.sh │ ├── test_before_push.py │ ├── test_colorization.sh │ ├── test_cyclegan.sh │ ├── test_fid.sh │ ├── test_pix2pix.sh │ ├── test_sc.sh │ ├── test_single.sh │ ├── test_sinsc.sh │ ├── train_colorization.sh │ ├── train_cyclegan.sh │ ├── train_pix2pix.sh │ ├── train_sc.sh │ ├── train_sinsc.sh │ └── transfer_sc.sh ├── test.py ├── test_fid.py ├── train.py └── util │ ├── __init__.py │ ├── get_data.py │ ├── html.py │ ├── image_pool.py │ ├── util.py │ └── visualizer.py ├── LICENSE ├── README.md ├── appendix ├── gpu_memory.py ├── inference_usage.png ├── proof_of_concept.py └── training_usage.png ├── combine.py ├── crop.py ├── crop_pipeline.py ├── data └── example │ ├── .gitkeep │ └── config.yaml ├── evaluation.py ├── experiments └── .gitkeep ├── imgs ├── Figure_KIN.jpg ├── Figure_patch_with_patch_lineplot_block1_mean.jpg ├── Figure_patch_with_patch_lineplot_blocks_mean.jpg └── URUST_anime.gif ├── inference.py ├── metric_images_with_ref.py ├── metric_whole_image_no_ref.py ├── metric_whole_image_with_ref.py ├── metrics ├── __init__.py ├── calculate_fid.py ├── histogram.py ├── inception.py ├── niqe.py ├── niqe_image_params.mat ├── piqe.py └── sobel.py ├── models ├── __init__.py ├── base.py ├── cut.py ├── cyclegan.py ├── discriminator.py ├── downsample.py ├── generator.py ├── kin.py ├── lsesim.py ├── lsesim_loss.py ├── model.py ├── normalization.py ├── projector.py ├── tests │ ├── __init__.py │ ├── test_cut.py │ ├── test_cyclegan.py │ ├── test_data │ │ └── configs │ │ │ ├── config_lung_lesion_for_test_cut.yaml │ │ │ └── config_lung_lesion_for_test_cyclegan.yaml │ └── test_kin.py ├── tin.py └── upsample.py ├── requirements.txt ├── train.py ├── transfer.py └── utils ├── __init__.py ├── dataset.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .ipynb_checkpoints/ 3 | *.out 4 | experiments/ 5 | checkpoints/ 6 | .pytest_cache 7 | test_dir_x/ 8 | test_dir_y/ 
-------------------------------------------------------------------------------- /F-LSeSim/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Chuanxia Zheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /F-LSeSim/README.md: -------------------------------------------------------------------------------- 1 | 2 | # Spatially-Correlative Loss 3 | 4 | [arXiv](https://arxiv.org/abs/2104.00854) | [website](http://www.chuanxiaz.com/publication/flsesim/) 5 |
6 | 7 | 8 |
9 | 10 | We provide the PyTorch implementation of "The Spatially-Correlative Loss for Various Image Translation Tasks". Based on the inherent self-similarity of objects, we propose a new structure-preserving loss for one-sided unsupervised I2I networks. The new loss *deals only with the spatial relationships of repeated signals, regardless of their original absolute values*. 11 | 12 | [The Spatially-Correlative Loss for Various Image Translation Tasks](https://arxiv.org/abs/2104.00854)<br>
13 | [Chuanxia Zheng](http://www.chuanxiaz.com), [Tat-Jen Cham](http://www.ntu.edu.sg/home/astjcham/), [Jianfei Cai](https://research.monash.edu/en/persons/jianfei-cai)
14 | NTU and Monash University
15 | In CVPR 2021<br>
16 | 17 | ## ToDo 18 | - a simple example to use the proposed loss 19 | 20 | ## Example Results 21 | 22 | ### Unpaired Image-to-Image Translation 23 | 24 | 25 | 26 | ### Single Image Translation 27 | 28 | 29 | 30 | ### [More results on project page](http://www.chuanxiaz.com/publication/flsesim/) 31 | 32 | ## Getting Started 33 | 34 | ### Installation 35 | This code was tested with PyTorch 1.7.0, CUDA 10.2, and Python 3.7. 36 | 37 | - Install PyTorch 1.7.0, torchvision, and other dependencies from [http://pytorch.org](http://pytorch.org) 38 | - Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate) for visualization 39 | 40 | ``` 41 | pip install visdom dominate 42 | ``` 43 | - Clone this repo: 44 | 45 | ``` 46 | git clone https://github.com/lyndonzheng/F-LSeSim 47 | cd F-LSeSim 48 | ``` 49 | 50 | ### [Datasets](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/docs/datasets.md) 51 | Please refer to the original [CUT](https://github.com/taesungp/contrastive-unpaired-translation) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) repositories to download datasets and learn how to create your own. 52 | 53 | ### Training 54 | 55 | - Train the *single-modal* I2I translation model: 56 | 57 | ``` 58 | sh ./scripts/train_sc.sh 59 | ``` 60 | 61 | - Set ```--use_norm``` to use a cosine-similarity map; the default similarity is a dot-product attention score. Add ```--learned_attn --augment``` for the learned self-similarity (a rough sketch of the fixed self-similarity computation is given below). 62 | - To view training results and loss plots, run ```python -m visdom.server``` and copy the URL [http://localhost:port](http://localhost:port). 63 | - Trained models will be saved under the **checkpoints** folder. 64 | - More training options can be found in the **options** folder. 65 |
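The fixed spatially-correlative loss essentially compares pairwise feature self-similarity maps of the input and the translated image, so only the relative structure is penalized and the absolute appearance of the output domain is free to change. The snippet below is a minimal, self-contained sketch of that idea only: the function names are illustrative rather than the repository's API, and it omits patch sampling, the multi-layer feature extractor, and the learned (attention) variant enabled by ```--learned_attn```.

```
# Rough sketch of the fixed spatially-correlative (self-similarity) loss idea.
# Not the repository's implementation: patch sampling, the feature extractor,
# and the learned variant are omitted.
import torch
import torch.nn.functional as F


def self_similarity_map(feat, use_norm=False):
    """Pairwise similarity of every spatial position with every other position.

    feat: (B, C, H, W) feature map. Returns a (B, H*W, H*W) similarity map.
    """
    b, c, h, w = feat.shape
    f = feat.view(b, c, h * w)              # (B, C, N)
    if use_norm:                            # cosine similarity (cf. --use_norm)
        f = F.normalize(f, dim=1)
    return torch.bmm(f.transpose(1, 2), f)  # otherwise a dot-product attention score


def spatially_correlative_loss(feat_src, feat_fake, use_norm=False):
    """Penalize changes in self-similarity structure between source and output."""
    sim_src = self_similarity_map(feat_src, use_norm)
    sim_fake = self_similarity_map(feat_fake, use_norm)
    return F.l1_loss(sim_src, sim_fake)


if __name__ == "__main__":
    f_x = torch.randn(2, 64, 16, 16)   # encoder features of the input image
    f_y = torch.randn(2, 64, 16, 16)   # encoder features of the translated image
    print(spatially_correlative_loss(f_x, f_y, use_norm=True))
```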

66 | 67 | 68 | - Train the *single-image* translation model: 69 | 70 | ``` 71 | sh ./scripts/train_sinsc.sh 72 | ``` 73 | 74 | As the *multi-modal* I2I translation model was trained on [MUNIT](https://github.com/NVlabs/MUNIT), we do not plan to merge that code into this repository. If you wish to obtain multi-modal results, please contact us at chuanxia001@e.ntu.edu.sg. 75 | 76 | ### Testing 77 | 78 | - Test the *single-modal* I2I translation model: 79 | 80 | ``` 81 | sh ./scripts/test_sc.sh 82 | ``` 83 | 84 | - Test the *single-image* translation model: 85 | 86 | ``` 87 | sh ./scripts/test_sinsc.sh 88 | ``` 89 | 90 | - Compute the FID score for all training epochs: 91 | 92 | ``` 93 | sh ./scripts/test_fid.sh 94 | ``` 95 | 96 | ### Pretrained Models 97 | 98 | Download the pre-trained models (to be released soon) from the following links and put them under the ```checkpoints/``` directory. 99 | 100 | - ```Single-modal translation model```: [horse2zebra](https://drive.google.com/drive/folders/1k8Y5R6CnaDwfkha_lD5_yQTvoajoU6GR?usp=sharing), [semantic2image](https://drive.google.com/drive/folders/1xnF6wLTPhD35-2It8IIomJRhFZdr2qXp?usp=sharing), [apple2orange](https://drive.google.com/drive/folders/1Z9PwxkWlakDdv12Jha6WJRgO6cSfEZGs?usp=sharing) 101 | - ```Single-image translation model```: [image2monet](https://drive.google.com/drive/folders/1QcGY9H0USWHJtcifRMWh_KHOJszME6-U?usp=sharing) 102 | 103 | ## Citation 104 | ``` 105 | @inproceedings{zheng2021spatiallycorrelative, 106 | title={The Spatially-Correlative Loss for Various Image Translation Tasks}, 107 | author={Zheng, Chuanxia and Cham, Tat-Jen and Cai, Jianfei}, 108 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition}, 109 | year={2021} 110 | } 111 | ``` 112 | 113 | ## Acknowledgments 114 | Our code is developed based on [CUT](https://github.com/taesungp/contrastive-unpaired-translation) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). We also thank [pytorch-fid](https://github.com/mseitzer/pytorch-fid) for FID computation, [LPIPS](https://github.com/richzhang/PerceptualSimilarity) for the diversity score, and [D&C](https://github.com/clovaai/generative-evaluation-prdc) for density and coverage evaluation.
115 | 116 | 117 | 118 | 119 | -------------------------------------------------------------------------------- /F-LSeSim/combine.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | import yaml 7 | from PIL import Image 8 | from yaml.loader import SafeLoader 9 | 10 | 11 | def read_yaml_config(config_path): 12 | with open(config_path) as f: 13 | config = yaml.load(f, Loader=SafeLoader) 14 | return config 15 | 16 | 17 | def main(): 18 | parser = argparse.ArgumentParser("Combined transferred images") 19 | parser.add_argument( 20 | "-c", 21 | "--config", 22 | type=str, 23 | default="./config.yaml", 24 | help="Path to the config file.", 25 | ) 26 | parser.add_argument("--patch_size", type=int, help="Patch size", default=512) 27 | parser.add_argument("--resize_h", type=int, help="Resize H", default=-1) 28 | parser.add_argument("--resize_w", type=int, help="Resize W", default=-1) 29 | parser.add_argument("--read_original", action="store_true") 30 | args = parser.parse_args() 31 | 32 | config = read_yaml_config(args.config) 33 | 34 | basename = os.path.basename(config["INFERENCE_SETTING"]["TEST_X"]) 35 | filename = os.path.splitext(basename)[0] 36 | path_root = os.path.join( 37 | config["EXPERIMENT_ROOT_PATH"], config["EXPERIMENT_NAME"], "test", filename 38 | ) 39 | if ( 40 | "OVERWRITE_OUTPUT_PATH" in config["INFERENCE_SETTING"] 41 | and config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] != "" 42 | ): 43 | path_root = config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] 44 | 45 | path_base = os.path.join( 46 | path_root, 47 | config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"], 48 | config["INFERENCE_SETTING"]["MODEL_VERSION"], 49 | ) 50 | 51 | combined_image_name = f"combined_{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_{config['INFERENCE_SETTING']['MODEL_VERSION']}.png" 52 | 53 | if config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"] == "kin": 54 | path_base = os.path.join( 55 | path_base, 56 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}", 57 | ) 58 | combined_image_name = f"combined_{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_{config['INFERENCE_SETTING']['MODEL_VERSION']}_{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}.png" 59 | 60 | filenames = os.listdir(path_base) 61 | try: 62 | filenames.remove("thumbnail_Y_fake.png") 63 | except: 64 | pass 65 | 66 | y_anchor_max = 0 67 | x_anchor_max = 0 68 | for filename in filenames: 69 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4) 70 | y_anchor_max = max(y_anchor_max, int(y_anchor)) 71 | x_anchor_max = max(x_anchor_max, int(x_anchor)) 72 | 73 | matrix = np.zeros( 74 | (y_anchor_max + args.patch_size, x_anchor_max + args.patch_size, 3), 75 | dtype=np.uint8, 76 | ) 77 | 78 | for filename in sorted(filenames): 79 | print(f"Combine {filename} ", end="\r") 80 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4) 81 | y_anchor = int(y_anchor) 82 | x_anchor = int(x_anchor) 83 | image = cv2.imread(os.path.join(path_base, filename)) 84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 85 | matrix[y_anchor : y_anchor + 512, x_anchor : x_anchor + 512, :] = image 86 | 87 | if (args.resize_h != -1) and (args.resize_w != -1): 88 | matrix = cv2.resize(matrix, (args.resize_w, args.resize_h), cv2.INTER_CUBIC) 89 | 90 | if args.read_original: 91 | H, W, _ = 
cv2.imread(config["INFERENCE_SETTING"]["TEST_X"]).shape 92 | matrix = cv2.resize(matrix, (W, H), cv2.INTER_CUBIC) 93 | 94 | matrix_image = Image.fromarray(matrix) 95 | matrix_image.save(os.path.join(path_root, combined_image_name)) 96 | 97 | 98 | if __name__ == "__main__": 99 | main() 100 | -------------------------------------------------------------------------------- /F-LSeSim/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | 15 | import torch.utils.data 16 | 17 | from data.base_dataset import BaseDataset 18 | 19 | 20 | def find_dataset_using_name(dataset_name): 21 | """Import the module "data/[dataset_name]_dataset.py". 22 | 23 | In the file, the class called DatasetNameDataset() will 24 | be instantiated. It has to be a subclass of BaseDataset, 25 | and it is case-insensitive. 26 | """ 27 | dataset_filename = "data." + dataset_name + "_dataset" 28 | datasetlib = importlib.import_module(dataset_filename) 29 | 30 | dataset = None 31 | target_dataset_name = dataset_name.replace("_", "") + "dataset" 32 | for name, cls in datasetlib.__dict__.items(): 33 | if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset): 34 | dataset = cls 35 | 36 | if dataset is None: 37 | raise NotImplementedError( 38 | "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." 39 | % (dataset_filename, target_dataset_name) 40 | ) 41 | 42 | return dataset 43 | 44 | 45 | def get_option_setter(dataset_name): 46 | """Return the static method of the dataset class.""" 47 | dataset_class = find_dataset_using_name(dataset_name) 48 | return dataset_class.modify_commandline_options 49 | 50 | 51 | def create_dataset(opt): 52 | """Create a dataset given the option. 53 | 54 | This function wraps the class CustomDatasetDataLoader. 55 | This is the main interface between this package and 'train.py'/'test.py' 56 | 57 | Example: 58 | >>> from data import create_dataset 59 | >>> dataset = create_dataset(opt) 60 | """ 61 | data_loader = CustomDatasetDataLoader(opt) 62 | dataset = data_loader.load_data() 63 | return dataset 64 | 65 | 66 | class CustomDatasetDataLoader: 67 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 68 | 69 | def __init__(self, opt): 70 | """Initialize this class 71 | 72 | Step 1: create a dataset instance given the name [dataset_mode] 73 | Step 2: create a multi-threaded data loader. 
74 | """ 75 | self.opt = opt 76 | dataset_class = find_dataset_using_name(opt.dataset_mode) 77 | self.dataset = dataset_class(opt) 78 | print("dataset [%s] was created" % type(self.dataset).__name__) 79 | self.dataloader = torch.utils.data.DataLoader( 80 | self.dataset, 81 | batch_size=opt.batch_size, 82 | shuffle=not opt.serial_batches, 83 | num_workers=int(opt.num_threads), 84 | ) 85 | 86 | def load_data(self): 87 | return self 88 | 89 | def __len__(self): 90 | """Return the number of data in the dataset""" 91 | return min(len(self.dataset), self.opt.max_dataset_size) 92 | 93 | def __iter__(self): 94 | """Return a batch of data""" 95 | for i, data in enumerate(self.dataloader): 96 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 97 | break 98 | yield data 99 | -------------------------------------------------------------------------------- /F-LSeSim/data/aligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | from data.base_dataset import BaseDataset, get_params, get_transform 6 | from data.image_folder import make_dataset 7 | 8 | 9 | class AlignedDataset(BaseDataset): 10 | """A dataset class for paired image dataset. 11 | 12 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}. 13 | During test time, you need to prepare a directory '/path/to/data/test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory 24 | self.AB_paths = sorted( 25 | make_dataset(self.dir_AB, opt.max_dataset_size) 26 | ) # get image paths 27 | assert ( 28 | self.opt.load_size >= self.opt.crop_size 29 | ) # crop_size should be smaller than the size of loaded image 30 | self.input_nc = ( 31 | self.opt.output_nc if self.opt.direction == "BtoA" else self.opt.input_nc 32 | ) 33 | self.output_nc = ( 34 | self.opt.input_nc if self.opt.direction == "BtoA" else self.opt.output_nc 35 | ) 36 | 37 | def __getitem__(self, index): 38 | """Return a data point and its metadata information. 
39 | 40 | Parameters: 41 | index - - a random integer for data indexing 42 | 43 | Returns a dictionary that contains A, B, A_paths and B_paths 44 | A (tensor) - - an image in the input domain 45 | B (tensor) - - its corresponding image in the target domain 46 | A_paths (str) - - image paths 47 | B_paths (str) - - image paths (same as A_paths) 48 | """ 49 | # read a image given a random integer index 50 | AB_path = self.AB_paths[index] 51 | AB = Image.open(AB_path).convert("RGB") 52 | # split AB image into A and B 53 | w, h = AB.size 54 | w2 = int(w / 2) 55 | A = AB.crop((0, 0, w2, h)) 56 | B = AB.crop((w2, 0, w, h)) 57 | 58 | # apply the same transform to both A and B 59 | transform_params = get_params(self.opt, A.size) 60 | A_transform = get_transform( 61 | self.opt, transform_params, grayscale=(self.input_nc == 1) 62 | ) 63 | B_transform = get_transform( 64 | self.opt, transform_params, grayscale=(self.output_nc == 1) 65 | ) 66 | 67 | A = A_transform(A) 68 | B = B_transform(B) 69 | 70 | return {"A": A, "B": B, "A_paths": AB_path, "B_paths": AB_path} 71 | 72 | def __len__(self): 73 | """Return the total number of images in the dataset.""" 74 | return len(self.AB_paths) 75 | -------------------------------------------------------------------------------- /F-LSeSim/data/colorization_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | from skimage import color # require skimage 7 | 8 | from data.base_dataset import BaseDataset, get_transform 9 | from data.image_folder import make_dataset 10 | 11 | 12 | class ColorizationDataset(BaseDataset): 13 | """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space. 14 | 15 | This dataset is required by pix2pix-based colorization model ('--model colorization') 16 | """ 17 | 18 | @staticmethod 19 | def modify_commandline_options(parser, is_train): 20 | """Add new dataset-specific options, and rewrite default values for existing options. 21 | 22 | Parameters: 23 | parser -- original option parser 24 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 25 | 26 | Returns: 27 | the modified parser. 28 | 29 | By default, the number of channels for input image is 1 (L) and 30 | the number of channels for output image is 2 (ab). The direction is from A to B 31 | """ 32 | parser.set_defaults(input_nc=1, output_nc=2, direction="AtoB") 33 | return parser 34 | 35 | def __init__(self, opt): 36 | """Initialize this dataset class. 37 | 38 | Parameters: 39 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 40 | """ 41 | BaseDataset.__init__(self, opt) 42 | self.dir = os.path.join(opt.dataroot, opt.phase) 43 | self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size)) 44 | assert opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == "AtoB" 45 | self.transform = get_transform(self.opt, convert=False) 46 | 47 | def __getitem__(self, index): 48 | """Return a data point and its metadata information. 
49 | 50 | Parameters: 51 | index - - a random integer for data indexing 52 | 53 | Returns a dictionary that contains A, B, A_paths and B_paths 54 | A (tensor) - - the L channel of an image 55 | B (tensor) - - the ab channels of the same image 56 | A_paths (str) - - image paths 57 | B_paths (str) - - image paths (same as A_paths) 58 | """ 59 | path = self.AB_paths[index] 60 | im = Image.open(path).convert("RGB") 61 | im = self.transform(im) 62 | im = np.array(im) 63 | lab = color.rgb2lab(im).astype(np.float32) 64 | lab_t = transforms.ToTensor()(lab) 65 | A = lab_t[[0], ...] / 50.0 - 1.0 66 | B = lab_t[[1, 2], ...] / 110.0 67 | return {"A": A, "B": B, "A_paths": path, "B_paths": path} 68 | 69 | def __len__(self): 70 | """Return the total number of images in the dataset.""" 71 | return len(self.AB_paths) 72 | -------------------------------------------------------------------------------- /F-LSeSim/data/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from pathlib import Path 4 | 5 | import albumentations as A 6 | import numpy as np 7 | from albumentations.pytorch import ToTensorV2 8 | from PIL import Image 9 | from torch.utils.data import Dataset 10 | 11 | test_transforms = A.Compose( 12 | [ 13 | A.Resize(width=512, height=512), 14 | A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255), 15 | ToTensorV2(), 16 | ], 17 | additional_targets={"image0": "image"}, 18 | ) 19 | 20 | 21 | def remove_file(files, file_name): 22 | try: 23 | files.remove(file_name) 24 | except Exception: 25 | pass 26 | 27 | 28 | class XInferenceDataset(Dataset): 29 | def __init__(self, root_X, transform=None, return_anchor=False, thumbnail=None): 30 | self.root_X = root_X 31 | self.transform = transform 32 | self.return_anchor = return_anchor 33 | self.thumbnail = thumbnail 34 | 35 | self.X_images = os.listdir(root_X) 36 | 37 | remove_file(self.X_images, "thumbnail.png") 38 | remove_file(self.X_images, "blank_patches_list.csv") 39 | 40 | if self.return_anchor: 41 | self.__get_boundary() 42 | 43 | self.length_dataset = len(self.X_images) 44 | 45 | def __get_boundary(self): 46 | self.y_anchor_num = 0 47 | self.x_anchor_num = 0 48 | for X_image in self.X_images: 49 | y_idx, x_idx, _, _ = Path(X_image).stem.split("_")[:4] 50 | y_idx = int(y_idx) 51 | x_idx = int(x_idx) 52 | self.y_anchor_num = max(self.y_anchor_num, y_idx) 53 | self.x_anchor_num = max(self.x_anchor_num, x_idx) 54 | 55 | def get_boundary(self): 56 | assert self.return_anchor == True 57 | return (self.y_anchor_num, self.x_anchor_num) 58 | 59 | def __len__(self): 60 | return self.length_dataset 61 | 62 | def __getitem__(self, index): 63 | X_img_name = self.X_images[index] 64 | 65 | X_path = os.path.join(self.root_X, X_img_name) 66 | 67 | X_img = np.array(Image.open(X_path).convert("RGB")) 68 | 69 | if self.transform: 70 | augmentations = self.transform(image=X_img) 71 | X_img = augmentations["image"] 72 | 73 | if self.return_anchor: 74 | y_idx, x_idx, y_anchor, x_anchor = Path(X_img_name).stem.split("_")[:4] 75 | y_idx = int(y_idx) 76 | x_idx = int(x_idx) 77 | return { 78 | "X_img": X_img, 79 | "X_path": X_path, 80 | "y_idx": y_idx, 81 | "x_idx": x_idx, 82 | "y_anchor": y_anchor, 83 | "x_anchor": x_anchor, 84 | } 85 | 86 | else: 87 | return {"X_img": X_img, "X_path": X_path} 88 | 89 | def get_thumbnail(self): 90 | thumbnail_img = np.array(Image.open(self.thumbnail).convert("RGB")) 91 | if self.transform: 92 | augmentations = self.transform(image=thumbnail_img) 
93 | thumbnail_img = augmentations["image"] 94 | return thumbnail_img.unsqueeze(0) 95 | -------------------------------------------------------------------------------- /F-LSeSim/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import os 8 | 9 | import torch.utils.data as data 10 | from PIL import Image 11 | 12 | IMG_EXTENSIONS = [ 13 | ".jpg", 14 | ".JPG", 15 | ".jpeg", 16 | ".JPEG", 17 | ".png", 18 | ".PNG", 19 | ".ppm", 20 | ".PPM", 21 | ".bmp", 22 | ".BMP", 23 | ".tif", 24 | ".TIF", 25 | ".tiff", 26 | ".TIFF", 27 | ] 28 | 29 | 30 | def is_image_file(filename): 31 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 32 | 33 | 34 | def make_dataset(dir, max_dataset_size=float("inf")): 35 | images = [] 36 | assert os.path.isdir(dir), "%s is not a valid directory" % dir 37 | 38 | for root, _, fnames in sorted(os.walk(dir)): 39 | for fname in fnames: 40 | if is_image_file(fname): 41 | path = os.path.join(root, fname) 42 | images.append(path) 43 | return images[: min(max_dataset_size, len(images))] 44 | 45 | 46 | def default_loader(path): 47 | return Image.open(path).convert("RGB") 48 | 49 | 50 | class ImageFolder(data.Dataset): 51 | def __init__(self, root, transform=None, return_paths=False, loader=default_loader): 52 | imgs = make_dataset(root) 53 | if len(imgs) == 0: 54 | raise ( 55 | RuntimeError( 56 | "Found 0 images in: " + root + "\n" 57 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS) 58 | ) 59 | ) 60 | 61 | self.root = root 62 | self.imgs = imgs 63 | self.transform = transform 64 | self.return_paths = return_paths 65 | self.loader = loader 66 | 67 | def __getitem__(self, index): 68 | path = self.imgs[index] 69 | img = self.loader(path) 70 | if self.transform is not None: 71 | img = self.transform(img) 72 | if self.return_paths: 73 | return img, path 74 | else: 75 | return img 76 | 77 | def __len__(self): 78 | return len(self.imgs) 79 | -------------------------------------------------------------------------------- /F-LSeSim/data/single_dataset.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from data.base_dataset import BaseDataset, get_transform 4 | from data.image_folder import make_dataset 5 | 6 | 7 | class SingleDataset(BaseDataset): 8 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 9 | 10 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 11 | """ 12 | 13 | def __init__(self, opt): 14 | """Initialize this dataset class. 15 | 16 | Parameters: 17 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 18 | """ 19 | BaseDataset.__init__(self, opt) 20 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 21 | input_nc = ( 22 | self.opt.output_nc if self.opt.direction == "BtoA" else self.opt.input_nc 23 | ) 24 | self.transform = get_transform(opt, grayscale=(input_nc == 1)) 25 | 26 | def __getitem__(self, index): 27 | """Return a data point and its metadata information. 
28 | 29 | Parameters: 30 | index - - a random integer for data indexing 31 | 32 | Returns a dictionary that contains A and A_paths 33 | A(tensor) - - an image in one domain 34 | A_paths(str) - - the path of the image 35 | """ 36 | A_path = self.A_paths[index] 37 | A_img = Image.open(A_path).convert("RGB") 38 | A = self.transform(A_img) 39 | return {"A": A, "A_paths": A_path} 40 | 41 | def __len__(self): 42 | """Return the total number of images in the dataset.""" 43 | return len(self.A_paths) 44 | -------------------------------------------------------------------------------- /F-LSeSim/data/singleimage_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import random 3 | 4 | import numpy as np 5 | from PIL import Image 6 | 7 | from data.base_dataset import BaseDataset, get_transform 8 | from data.image_folder import make_dataset 9 | 10 | 11 | class SingleImageDataset(BaseDataset): 12 | """ 13 | This dataset class can load unaligned/unpaired datasets. 14 | 15 | It requires two directories to host training images from domain A '/path/to/data/trainA' 16 | and from domain B '/path/to/data/trainB' respectively. 17 | You can train the model with the dataset flag '--dataroot /path/to/data'. 18 | Similarly, you need to prepare two directories: 19 | '/path/to/data/testA' and '/path/to/data/testB' during test time. 20 | """ 21 | 22 | def __init__(self, opt): 23 | """Initialize this dataset class. 24 | 25 | Parameters: 26 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 27 | """ 28 | BaseDataset.__init__(self, opt) 29 | 30 | self.dir_A = os.path.join( 31 | opt.dataroot, "trainA" 32 | ) # create a path '/path/to/data/trainA' 33 | self.dir_B = os.path.join( 34 | opt.dataroot, "trainB" 35 | ) # create a path '/path/to/data/trainB' 36 | 37 | if os.path.exists(self.dir_A) and os.path.exists(self.dir_B): 38 | self.A_paths = sorted( 39 | make_dataset(self.dir_A, opt.max_dataset_size) 40 | ) # load images from '/path/to/data/trainA' 41 | self.B_paths = sorted( 42 | make_dataset(self.dir_B, opt.max_dataset_size) 43 | ) # load images from '/path/to/data/trainB' 44 | self.A_size = len(self.A_paths) # get the size of dataset A 45 | self.B_size = len(self.B_paths) # get the size of dataset B 46 | 47 | assert ( 48 | len(self.A_paths) == 1 and len(self.B_paths) == 1 49 | ), "SingleImageDataset class should be used with one image in each domain" 50 | A_img = Image.open(self.A_paths[0]).convert("RGB") 51 | B_img = Image.open(self.B_paths[0]).convert("RGB") 52 | print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size))) 53 | 54 | self.A_img = A_img 55 | self.B_img = B_img 56 | 57 | # In single-image translation, we augment the data loader by applying 58 | # random scaling. Still, we design the data loader such that the 59 | # amount of scaling is the same within a minibatch. To do this, 60 | # we precompute the random scaling values, and repeat them by |batch_size|. 
61 | A_zoom = 1 / self.opt.random_scale_max 62 | zoom_levels_A = np.random.uniform( 63 | A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2) 64 | ) 65 | self.zoom_levels_A = np.reshape( 66 | np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2] 67 | ) 68 | 69 | B_zoom = 1 / self.opt.random_scale_max 70 | zoom_levels_B = np.random.uniform( 71 | B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2) 72 | ) 73 | self.zoom_levels_B = np.reshape( 74 | np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2] 75 | ) 76 | 77 | # While the crop locations are randomized, the negative samples should 78 | # not come from the same location. To do this, we precompute the 79 | # crop locations with no repetition. 80 | self.patch_indices_A = list(range(len(self))) 81 | random.shuffle(self.patch_indices_A) 82 | self.patch_indices_B = list(range(len(self))) 83 | random.shuffle(self.patch_indices_B) 84 | 85 | def __getitem__(self, index): 86 | """Return a data point and its metadata information. 87 | 88 | Parameters: 89 | index (int) -- a random integer for data indexing 90 | 91 | Returns a dictionary that contains A, B, A_paths and B_paths 92 | A (tensor) -- an image in the input domain 93 | B (tensor) -- its corresponding image in the target domain 94 | A_paths (str) -- image paths 95 | B_paths (str) -- image paths 96 | """ 97 | A_path = self.A_paths[0] 98 | B_path = self.B_paths[0] 99 | A_img = self.A_img 100 | B_img = self.B_img 101 | 102 | # apply image transformation 103 | if self.opt.phase == "train": 104 | param = { 105 | "scale_factor": self.zoom_levels_A[index], 106 | "patch_index": self.patch_indices_A[index], 107 | "flip": random.random() > 0.5, 108 | } 109 | 110 | transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR) 111 | A = transform_A(A_img) 112 | 113 | param = { 114 | "scale_factor": self.zoom_levels_B[index], 115 | "patch_index": self.patch_indices_B[index], 116 | "flip": random.random() > 0.5, 117 | } 118 | transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR) 119 | B = transform_B(B_img) 120 | else: 121 | transform = get_transform(self.opt, method=Image.BILINEAR) 122 | A = transform(A_img) 123 | B = transform(B_img) 124 | 125 | return {"A": A, "B": B, "A_paths": A_path, "B_paths": B_path} 126 | 127 | def __len__(self): 128 | """Let's pretend the single image contains 100,000 crops for convenience.""" 129 | return 100000 130 | -------------------------------------------------------------------------------- /F-LSeSim/data/template_dataset.py: -------------------------------------------------------------------------------- 1 | """Dataset class template 2 | 3 | This module provides a template for users to implement custom datasets. 4 | You can specify '--dataset_mode template' to use this dataset. 5 | The class name should be consistent with both the filename and its dataset_mode option. 6 | The filename should be _dataset.py 7 | The class name should be Dataset.py 8 | You need to implement the following functions: 9 | -- : Add dataset-specific options and rewrite default values for existing options. 10 | -- <__init__>: Initialize this dataset class. 11 | -- <__getitem__>: Return a data point and its metadata information. 12 | -- <__len__>: Return the number of images. 
13 | """ 14 | from data.base_dataset import BaseDataset, get_transform 15 | 16 | # from data.image_folder import make_dataset 17 | # from PIL import Image 18 | 19 | 20 | class TemplateDataset(BaseDataset): 21 | """A template dataset class for you to implement custom datasets.""" 22 | 23 | @staticmethod 24 | def modify_commandline_options(parser, is_train): 25 | """Add new dataset-specific options, and rewrite default values for existing options. 26 | 27 | Parameters: 28 | parser -- original option parser 29 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 30 | 31 | Returns: 32 | the modified parser. 33 | """ 34 | parser.add_argument( 35 | "--new_dataset_option", type=float, default=1.0, help="new dataset option" 36 | ) 37 | parser.set_defaults( 38 | max_dataset_size=10, new_dataset_option=2.0 39 | ) # specify dataset-specific default values 40 | return parser 41 | 42 | def __init__(self, opt): 43 | """Initialize this dataset class. 44 | 45 | Parameters: 46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 47 | 48 | A few things can be done here. 49 | - save the options (have been done in BaseDataset) 50 | - get image paths and meta information of the dataset. 51 | - define the image transformation. 52 | """ 53 | # save the option and dataset root 54 | BaseDataset.__init__(self, opt) 55 | # get the image paths of your dataset; 56 | self.image_paths = ( 57 | [] 58 | ) # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root 59 | # define the default transform function. You can use ; You can also define your custom transform function 60 | self.transform = get_transform(opt) 61 | 62 | def __getitem__(self, index): 63 | """Return a data point and its metadata information. 64 | 65 | Parameters: 66 | index -- a random integer for data indexing 67 | 68 | Returns: 69 | a dictionary of data with their names. It usually contains the data itself and its metadata information. 70 | 71 | Step 1: get a random image path: e.g., path = self.image_paths[index] 72 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB'). 73 | Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image) 74 | Step 4: return a data point as a dictionary. 75 | """ 76 | path = "temp" # needs to be a string 77 | data_A = None # needs to be a tensor 78 | data_B = None # needs to be a tensor 79 | return {"data_A": data_A, "data_B": data_B, "path": path} 80 | 81 | def __len__(self): 82 | """Return the total number of images.""" 83 | return len(self.image_paths) 84 | -------------------------------------------------------------------------------- /F-LSeSim/data/unaligned_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import torchvision.transforms as transforms 5 | from PIL import Image 6 | 7 | import util.util as util 8 | from data.base_dataset import BaseDataset, get_transform 9 | from data.image_folder import make_dataset 10 | 11 | 12 | def remove_file(files, file_name): 13 | try: 14 | files.remove(file_name) 15 | except Exception: 16 | pass 17 | 18 | 19 | class UnalignedDataset(BaseDataset): 20 | """ 21 | This dataset class can load unaligned/unpaired datasets. 
22 | 23 | It requires two directories to host training images from domain A '/path/to/data/trainA' 24 | and from domain B '/path/to/data/trainB' respectively. 25 | You can train the model with the dataset flag '--dataroot /path/to/data'. 26 | Similarly, you need to prepare two directories: 27 | '/path/to/data/testA' and '/path/to/data/testB' during test time. 28 | """ 29 | 30 | def __init__(self, opt): 31 | """Initialize this dataset class. 32 | 33 | Parameters: 34 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 35 | """ 36 | BaseDataset.__init__(self, opt) 37 | self.dir_A = os.path.join( 38 | opt.dataroot, opt.phase + "X" 39 | ) # create a path '/path/to/data/trainX' 40 | self.dir_B = os.path.join( 41 | opt.dataroot, opt.phase + "Y" 42 | ) # create a path '/path/to/data/trainY' 43 | 44 | self.A_paths = sorted( 45 | make_dataset(self.dir_A, opt.max_dataset_size) 46 | ) # load images from '/path/to/data/trainA' 47 | self.B_paths = sorted( 48 | make_dataset(self.dir_B, opt.max_dataset_size) 49 | ) # load images from '/path/to/data/trainB' 50 | self.A_size = len(self.A_paths) # get the size of dataset A 51 | self.B_size = len(self.B_paths) # get the size of dataset B 52 | # apply image transformation 53 | # For FastCUT mode, if in finetuning phase (learning rate is decaying) 54 | # do not perform resize-crop data augmentation of CycleGAN. 55 | # print('current_epoch', self.current_epoch) 56 | self.transform = get_transform(opt, convert=False) 57 | if self.opt.isTrain and opt.augment: 58 | self.transform_aug = transforms.Compose( 59 | [ 60 | transforms.ColorJitter( 61 | brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3 62 | ), 63 | transforms.ToTensor(), 64 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 65 | ] 66 | ) 67 | else: 68 | self.transform_aug = None 69 | self.transform_tensor = transforms.Compose( 70 | [ 71 | transforms.ToTensor(), 72 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), 73 | ] 74 | ) 75 | 76 | def __getitem__(self, index): 77 | """Return a data point and its metadata information. 78 | 79 | Parameters: 80 | index (int) -- a random integer for data indexing 81 | 82 | Returns a dictionary that contains A, B, A_paths and B_paths 83 | A (tensor) -- an image in the input domain 84 | B (tensor) -- its corresponding image in the target domain 85 | A_paths (str) -- image paths 86 | B_paths (str) -- image paths 87 | """ 88 | A_path = self.A_paths[ 89 | index % self.A_size 90 | ] # make sure index is within then range 91 | if self.opt.serial_batches: # make sure index is within then range 92 | index_B = index % self.B_size 93 | else: # randomize the index for domain B to avoid fixed pairs. 94 | index_B = random.randint(0, self.B_size - 1) 95 | B_path = self.B_paths[index_B] 96 | A_img = Image.open(A_path).convert("RGB") 97 | B_img = Image.open(B_path).convert("RGB") 98 | A_pil = self.transform(A_img) 99 | B_pil = self.transform(B_img) 100 | A = self.transform_tensor(A_pil) 101 | B = self.transform_tensor(B_pil) 102 | if self.opt.isTrain and self.transform_aug is not None: 103 | A_aug = self.transform_aug(A_pil) 104 | B_aug = self.transform_aug(B_pil) 105 | return { 106 | "A": A, 107 | "B": B, 108 | "A_paths": A_path, 109 | "B_paths": B_path, 110 | "A_aug": A_aug, 111 | "B_aug": B_aug, 112 | } 113 | else: 114 | return {"A": A, "B": B, "A_paths": A_path, "B_paths": B_path} 115 | 116 | def __len__(self): 117 | """Return the total number of images in the dataset. 
118 | 119 | As we have two datasets with potentially different numbers of images, 120 | we take the maximum of the two. 121 | """ 122 | return max(self.A_size, self.B_size) 123 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/bibtex/cityscapes.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{Cordts2016Cityscapes, 2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding}, 3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt}, 4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 5 | year={2016} 6 | } 7 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/bibtex/facades.tex: -------------------------------------------------------------------------------- 1 | @INPROCEEDINGS{Tylecek13, 2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra}, 3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure}, 4 | booktitle = {Proc. GCPR}, 5 | year = {2013}, 6 | address = {Saarbrucken, Germany}, 7 | } 8 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/bibtex/handbags.tex: -------------------------------------------------------------------------------- 1 | @inproceedings{zhu2016generative, 2 | title={Generative Visual Manipulation on the Natural Image Manifold}, 3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.}, 4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)}, 5 | year={2016} 6 | } 7 | 8 | @InProceedings{xie15hed, 9 | author = {"Xie, Saining and Tu, Zhuowen"}, 10 | Title = {Holistically-Nested Edge Detection}, 11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 12 | Year = {2015}, 13 | } 14 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/bibtex/shoes.tex: -------------------------------------------------------------------------------- 1 | @InProceedings{fine-grained, 2 | author = {A. Yu and K.
Grauman}, 3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning}, 4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)}, 5 | month = {June}, 6 | year = {2014} 7 | } 8 | 9 | @InProceedings{xie15hed, 10 | author = {"Xie, Saining and Tu, Zhuowen"}, 11 | Title = {Holistically-Nested Edge Detection}, 12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision", 13 | Year = {2015}, 14 | } 15 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/bibtex/transattr.tex: -------------------------------------------------------------------------------- 1 | @article {Laffont14, 2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes}, 3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays}, 4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)}, 5 | volume = {33}, 6 | number = {4}, 7 | year = {2014} 8 | } 9 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/combine_A_and_B.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from multiprocessing import Pool 4 | 5 | import cv2 6 | import numpy as np 7 | 8 | 9 | def image_write(path_A, path_B, path_AB): 10 | im_A = cv2.imread( 11 | path_A, 1 12 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 13 | im_B = cv2.imread( 14 | path_B, 1 15 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 16 | im_AB = np.concatenate([im_A, im_B], 1) 17 | cv2.imwrite(path_AB, im_AB) 18 | 19 | 20 | parser = argparse.ArgumentParser("create image pairs") 21 | parser.add_argument( 22 | "--fold_A", 23 | dest="fold_A", 24 | help="input directory for image A", 25 | type=str, 26 | default="../dataset/50kshoes_edges", 27 | ) 28 | parser.add_argument( 29 | "--fold_B", 30 | dest="fold_B", 31 | help="input directory for image B", 32 | type=str, 33 | default="../dataset/50kshoes_jpg", 34 | ) 35 | parser.add_argument( 36 | "--fold_AB", 37 | dest="fold_AB", 38 | help="output directory", 39 | type=str, 40 | default="../dataset/test_AB", 41 | ) 42 | parser.add_argument( 43 | "--num_imgs", dest="num_imgs", help="number of images", type=int, default=1000000 44 | ) 45 | parser.add_argument( 46 | "--use_AB", 47 | dest="use_AB", 48 | help="if true: (0001_A, 0001_B) to (0001_AB)", 49 | action="store_true", 50 | ) 51 | parser.add_argument( 52 | "--no_multiprocessing", 53 | dest="no_multiprocessing", 54 | help="If used, chooses single CPU execution instead of parallel execution", 55 | action="store_true", 56 | default=False, 57 | ) 58 | args = parser.parse_args() 59 | 60 | for arg in vars(args): 61 | print("[%s] = " % arg, getattr(args, arg)) 62 | 63 | splits = os.listdir(args.fold_A) 64 | 65 | if not args.no_multiprocessing: 66 | pool = Pool() 67 | 68 | for sp in splits: 69 | img_fold_A = os.path.join(args.fold_A, sp) 70 | img_fold_B = os.path.join(args.fold_B, sp) 71 | img_list = os.listdir(img_fold_A) 72 | if args.use_AB: 73 | img_list = [img_path for img_path in img_list if "_A." 
in img_path] 74 | 75 | num_imgs = min(args.num_imgs, len(img_list)) 76 | print("split = %s, use %d/%d images" % (sp, num_imgs, len(img_list))) 77 | img_fold_AB = os.path.join(args.fold_AB, sp) 78 | if not os.path.isdir(img_fold_AB): 79 | os.makedirs(img_fold_AB) 80 | print("split = %s, number of images = %d" % (sp, num_imgs)) 81 | for n in range(num_imgs): 82 | name_A = img_list[n] 83 | path_A = os.path.join(img_fold_A, name_A) 84 | if args.use_AB: 85 | name_B = name_A.replace("_A.", "_B.") 86 | else: 87 | name_B = name_A 88 | path_B = os.path.join(img_fold_B, name_B) 89 | if os.path.isfile(path_A) and os.path.isfile(path_B): 90 | name_AB = name_A 91 | if args.use_AB: 92 | name_AB = name_AB.replace("_A.", ".") # remove _A 93 | path_AB = os.path.join(img_fold_AB, name_AB) 94 | if not args.no_multiprocessing: 95 | pool.apply_async(image_write, args=(path_A, path_B, path_AB)) 96 | else: 97 | im_A = cv2.imread( 98 | path_A, 1 99 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 100 | im_B = cv2.imread( 101 | path_B, 1 102 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR 103 | im_AB = np.concatenate([im_A, im_B], 1) 104 | cv2.imwrite(path_AB, im_AB) 105 | if not args.no_multiprocessing: 106 | pool.close() 107 | pool.join() 108 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/download_cyclegan_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" ]]; then 4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 5 | exit 1 6 | fi 7 | 8 | if [[ $FILE == "cityscapes" ]]; then 9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py." 10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py" 11 | exit 1 12 | fi 13 | 14 | echo "Specified [$FILE]" 15 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 16 | ZIP_FILE=./datasets/$FILE.zip 17 | TARGET_DIR=./datasets/$FILE/ 18 | wget -N $URL -O $ZIP_FILE 19 | mkdir $TARGET_DIR 20 | unzip $ZIP_FILE -d ./datasets/ 21 | rm $ZIP_FILE 22 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/download_pix2pix_dataset.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then 4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps" 5 | exit 1 6 | fi 7 | 8 | if [[ $FILE == "cityscapes" ]]; then 9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. 
Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py." 10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py" 11 | exit 1 12 | fi 13 | 14 | echo "Specified [$FILE]" 15 | 16 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz 17 | TAR_FILE=./datasets/$FILE.tar.gz 18 | TARGET_DIR=./datasets/$FILE/ 19 | wget -N $URL -O $TAR_FILE 20 | mkdir -p $TARGET_DIR 21 | tar -zxvf $TAR_FILE -C ./datasets/ 22 | rm $TAR_FILE 23 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/make_dataset_aligned.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from PIL import Image 4 | 5 | 6 | def get_file_paths(folder): 7 | image_file_paths = [] 8 | for root, dirs, filenames in os.walk(folder): 9 | filenames = sorted(filenames) 10 | for filename in filenames: 11 | input_path = os.path.abspath(root) 12 | file_path = os.path.join(input_path, filename) 13 | if filename.endswith(".png") or filename.endswith(".jpg"): 14 | image_file_paths.append(file_path) 15 | 16 | break # prevent descending into subfolders 17 | return image_file_paths 18 | 19 | 20 | def align_images(a_file_paths, b_file_paths, target_path): 21 | if not os.path.exists(target_path): 22 | os.makedirs(target_path) 23 | 24 | for i in range(len(a_file_paths)): 25 | img_a = Image.open(a_file_paths[i]) 26 | img_b = Image.open(b_file_paths[i]) 27 | assert img_a.size == img_b.size 28 | 29 | aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1])) 30 | aligned_image.paste(img_a, (0, 0)) 31 | aligned_image.paste(img_b, (img_a.size[0], 0)) 32 | aligned_image.save(os.path.join(target_path, "{:04d}.jpg".format(i))) 33 | 34 | 35 | if __name__ == "__main__": 36 | import argparse 37 | 38 | parser = argparse.ArgumentParser() 39 | parser.add_argument( 40 | "--dataset-path", 41 | dest="dataset_path", 42 | help="Which folder to process (it should have subfolders testA, testB, trainA and trainB", 43 | ) 44 | args = parser.parse_args() 45 | 46 | dataset_folder = args.dataset_path 47 | print(dataset_folder) 48 | 49 | test_a_path = os.path.join(dataset_folder, "testA") 50 | test_b_path = os.path.join(dataset_folder, "testB") 51 | test_a_file_paths = get_file_paths(test_a_path) 52 | test_b_file_paths = get_file_paths(test_b_path) 53 | assert len(test_a_file_paths) == len(test_b_file_paths) 54 | test_path = os.path.join(dataset_folder, "test") 55 | 56 | train_a_path = os.path.join(dataset_folder, "trainA") 57 | train_b_path = os.path.join(dataset_folder, "trainB") 58 | train_a_file_paths = get_file_paths(train_a_path) 59 | train_b_file_paths = get_file_paths(train_b_path) 60 | assert len(train_a_file_paths) == len(train_b_file_paths) 61 | train_path = os.path.join(dataset_folder, "train") 62 | 63 | align_images(test_a_file_paths, test_b_file_paths, test_path) 64 | align_images(train_a_file_paths, train_b_file_paths, train_path) 65 | -------------------------------------------------------------------------------- /F-LSeSim/datasets/prepare_cityscapes_dataset.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | 4 | from PIL import Image 5 | 6 | help_msg = """ 7 | The dataset can be downloaded from https://cityscapes-dataset.com. 
8 | Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them. 9 | gtFine contains the semantics segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory. 10 | leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory. 11 | The processed images will be placed at --output_dir. 12 | 13 | Example usage: 14 | 15 | python prepare_cityscapes_dataset.py --gitFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/ 16 | """ 17 | 18 | 19 | def load_resized_img(path): 20 | return Image.open(path).convert("RGB").resize((256, 256)) 21 | 22 | 23 | def check_matching_pair(segmap_path, photo_path): 24 | segmap_identifier = os.path.basename(segmap_path).replace("_gtFine_color", "") 25 | photo_identifier = os.path.basename(photo_path).replace("_leftImg8bit", "") 26 | 27 | assert ( 28 | segmap_identifier == photo_identifier 29 | ), "[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path) 30 | 31 | 32 | def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase): 33 | save_phase = "test" if phase == "val" else "train" 34 | savedir = os.path.join(output_dir, save_phase) 35 | os.makedirs(savedir, exist_ok=True) 36 | os.makedirs(savedir + "A", exist_ok=True) 37 | os.makedirs(savedir + "B", exist_ok=True) 38 | print("Directory structure prepared at %s" % output_dir) 39 | 40 | segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png" 41 | segmap_paths = glob.glob(segmap_expr) 42 | segmap_paths = sorted(segmap_paths) 43 | 44 | photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png" 45 | photo_paths = glob.glob(photo_expr) 46 | photo_paths = sorted(photo_paths) 47 | 48 | assert len(segmap_paths) == len( 49 | photo_paths 50 | ), "%d images that match [%s], and %d images that match [%s]. Aborting." 
% ( 51 | len(segmap_paths), 52 | segmap_expr, 53 | len(photo_paths), 54 | photo_expr, 55 | ) 56 | 57 | for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)): 58 | check_matching_pair(segmap_path, photo_path) 59 | segmap = load_resized_img(segmap_path) 60 | photo = load_resized_img(photo_path) 61 | 62 | # data for pix2pix where the two images are placed side-by-side 63 | sidebyside = Image.new("RGB", (512, 256)) 64 | sidebyside.paste(segmap, (256, 0)) 65 | sidebyside.paste(photo, (0, 0)) 66 | savepath = os.path.join(savedir, "%d.jpg" % i) 67 | sidebyside.save(savepath, format="JPEG", subsampling=0, quality=100) 68 | 69 | # data for cyclegan where the two images are stored at two distinct directories 70 | savepath = os.path.join(savedir + "A", "%d_A.jpg" % i) 71 | photo.save(savepath, format="JPEG", subsampling=0, quality=100) 72 | savepath = os.path.join(savedir + "B", "%d_B.jpg" % i) 73 | segmap.save(savepath, format="JPEG", subsampling=0, quality=100) 74 | 75 | if i % (len(segmap_paths) // 10) == 0: 76 | print( 77 | "%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath) 78 | ) 79 | 80 | 81 | if __name__ == "__main__": 82 | import argparse 83 | 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument( 86 | "--gtFine_dir", 87 | type=str, 88 | required=True, 89 | help="Path to the Cityscapes gtFine directory.", 90 | ) 91 | parser.add_argument( 92 | "--leftImg8bit_dir", 93 | type=str, 94 | required=True, 95 | help="Path to the Cityscapes leftImg8bit_trainvaltest directory.", 96 | ) 97 | parser.add_argument( 98 | "--output_dir", 99 | type=str, 100 | required=True, 101 | default="./datasets/cityscapes", 102 | help="Directory the output images will be written to.", 103 | ) 104 | opt = parser.parse_args() 105 | 106 | print(help_msg) 107 | 108 | print("Preparing Cityscapes Dataset for val phase") 109 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val") 110 | print("Preparing Cityscapes Dataset for train phase") 111 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train") 112 | 113 | print("Done") 114 | -------------------------------------------------------------------------------- /F-LSeSim/evaluations/DC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import pathlib 3 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from scipy import linalg 9 | from torch.nn.functional import adaptive_avg_pool2d 10 | 11 | try: 12 | from tqdm import tqdm 13 | except ImportError: 14 | # If not tqdm is not available, provide a mock version of it 15 | def tqdm(x): 16 | return x 17 | 18 | 19 | from prdc import compute_prdc 20 | 21 | from evaluations.inception import InceptionV3 22 | 23 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 24 | parser.add_argument("--batch-size", type=int, default=100, help="Batch size to use") 25 | parser.add_argument( 26 | "--dims", 27 | type=int, 28 | default=2048, 29 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 30 | help=( 31 | "Dimensionality of Inception features to use. 
" 32 | "By default, uses pool3 features" 33 | ), 34 | ) 35 | parser.add_argument( 36 | "-c", "--gpu", default="", type=str, help="GPU to use (leave blank for CPU only)" 37 | ) 38 | parser.add_argument( 39 | "path", 40 | type=str, 41 | nargs=2, 42 | help=("Paths to the generated images or " "to .npz statistic files"), 43 | ) 44 | 45 | 46 | def imread(filename): 47 | """ 48 | Loads an image file into a (height, width, 3) uint8 ndarray. .resize((229, 229), Image.BILINEAR) 49 | """ 50 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3] 51 | 52 | 53 | def get_activations(files, model, batch_size=50, dims=2048, cuda=False): 54 | """Calculates the activations of the pool_3 layer for all images. 55 | Params: 56 | -- files : List of image files paths 57 | -- model : Instance of inception model 58 | -- batch_size : Batch size of images for the model to process at once. 59 | Make sure that the number of samples is a multiple of 60 | the batch size, otherwise some samples are ignored. This 61 | behavior is retained to match the original FID score 62 | implementation. 63 | -- dims : Dimensionality of features returned by Inception 64 | -- cuda : If set to True, use GPU 65 | Returns: 66 | -- A numpy array of dimension (num images, dims) that contains the 67 | activations of the given tensor when feeding inception with the 68 | query tensor. 69 | """ 70 | model.eval() 71 | 72 | if batch_size > len(files): 73 | print( 74 | ( 75 | "Warning: batch size is bigger than the data size. " 76 | "Setting batch size to data size" 77 | ) 78 | ) 79 | batch_size = len(files) 80 | 81 | pred_arr = np.empty((len(files), dims)) 82 | 83 | for i in tqdm(range(0, len(files), batch_size)): 84 | start = i 85 | end = i + batch_size 86 | 87 | images = np.array([imread(str(f)).astype(np.float32) for f in files[start:end]]) 88 | 89 | # Reshape to (n_images, 3, height, width) 90 | images = images.transpose((0, 3, 1, 2)) 91 | images /= 255 92 | 93 | batch = torch.from_numpy(images).type(torch.FloatTensor) 94 | if cuda: 95 | batch = batch.cuda() 96 | 97 | pred = model(batch)[0] 98 | 99 | # If model output is not scalar, apply global spatial average pooling. 100 | # This happens if you choose a dimensionality not equal 2048. 
101 | if pred.size(2) != 1 or pred.size(3) != 1: 102 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) 103 | 104 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1) 105 | 106 | return pred_arr 107 | 108 | 109 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda): 110 | if path.endswith(".npz"): 111 | f = np.load(path) 112 | m, s = f["mu"][:], f["sigma"][:] 113 | f.close() 114 | else: 115 | path = pathlib.Path(path) 116 | files = list(path.glob("*.jpg")) + list(path.glob("*.png")) 117 | f = get_activations(files, model, batch_size, dims, cuda) 118 | 119 | return f 120 | 121 | 122 | def calculate_DC_given_paths(paths, batch_size, cuda, dims): 123 | 124 | for p in paths: 125 | if not os.path.exists(p): 126 | raise RuntimeError("Invalid path: %s" % p) 127 | 128 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] 129 | 130 | model = InceptionV3([block_idx]) 131 | if cuda: 132 | model.cuda() 133 | 134 | f0 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda) 135 | 136 | f1 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda) 137 | 138 | dc_value = compute_prdc(real_features=f0, fake_features=f1, nearest_k=95) 139 | 140 | return dc_value 141 | 142 | 143 | def main(): 144 | args = parser.parse_args() 145 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 146 | 147 | dc_value = calculate_DC_given_paths( 148 | args.path, args.batch_size, args.gpu != "", args.dims 149 | ) 150 | print(dc_value) 151 | 152 | 153 | if __name__ == "__main__": 154 | main() 155 | -------------------------------------------------------------------------------- /F-LSeSim/evaluations/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/F-LSeSim/evaluations/__init__.py -------------------------------------------------------------------------------- /F-LSeSim/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 
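As a rough illustration only (this skeleton is not part of the repository), a minimal 'dummy' model following the recipe above might look like:

    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            return parser  # optionally add or rewrite model-specific options here

        def __init__(self, opt):
            BaseModel.__init__(self, opt)
            self.loss_names, self.model_names, self.visual_names = [], [], []
            self.optimizers = []

        def set_input(self, input):
            pass  # unpack data from the dataset and apply preprocessing

        def forward(self):
            pass  # produce intermediate results

        def optimize_parameters(self):
            pass  # calculate losses and gradients, then update network weights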
19 | """ 20 | 21 | import importlib 22 | 23 | from models.base_model import BaseModel 24 | 25 | 26 | def find_model_using_name(model_name): 27 | """Import the module "models/[model_name]_model.py". 28 | 29 | In the file, the class called DatasetNameModel() will 30 | be instantiated. It has to be a subclass of BaseModel, 31 | and it is case-insensitive. 32 | """ 33 | model_filename = "models." + model_name + "_model" 34 | modellib = importlib.import_module(model_filename) 35 | model = None 36 | target_model_name = model_name.replace("_", "") + "model" 37 | for name, cls in modellib.__dict__.items(): 38 | if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print( 43 | "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." 44 | % (model_filename, target_model_name) 45 | ) 46 | exit(0) 47 | 48 | return model 49 | 50 | 51 | def get_option_setter(model_name): 52 | """Return the static method of the model class.""" 53 | model_class = find_model_using_name(model_name) 54 | return model_class.modify_commandline_options 55 | 56 | 57 | def create_model(opt, norm_cfg): 58 | """Create a model given the option. 59 | 60 | This function warps the class CustomDatasetDataLoader. 61 | This is the main interface between this package and 'train.py'/'test.py' 62 | 63 | Example: 64 | >>> from models import create_model 65 | >>> model = create_model(opt) 66 | """ 67 | model = find_model_using_name(opt.model) 68 | instance = model(opt, norm_cfg=norm_cfg) 69 | print("model [%s] was created" % type(instance).__name__) 70 | return instance 71 | -------------------------------------------------------------------------------- /F-LSeSim/models/colorization_model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from skimage import color # used for lab2rgb 4 | 5 | from .pix2pix_model import Pix2PixModel 6 | 7 | 8 | class ColorizationModel(Pix2PixModel): 9 | """This is a subclass of Pix2PixModel for image colorization (black & white image -> colorful images). 10 | 11 | The model training requires '-dataset_model colorization' dataset. 12 | It trains a pix2pix model, mapping from L channel to ab channels in Lab color space. 13 | By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'. 14 | """ 15 | 16 | @staticmethod 17 | def modify_commandline_options(parser, is_train=True): 18 | """Add new dataset-specific options, and rewrite default values for existing options. 19 | 20 | Parameters: 21 | parser -- original option parser 22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 23 | 24 | Returns: 25 | the modified parser. 26 | 27 | By default, we use 'colorization' dataset for this model. 28 | See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and colorization results (Figure 9 in the paper) 29 | """ 30 | Pix2PixModel.modify_commandline_options(parser, is_train) 31 | parser.set_defaults(dataset_mode="colorization") 32 | return parser 33 | 34 | def __init__(self, opt): 35 | """Initialize the class. 
36 | 37 | Parameters: 38 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 39 | 40 | For visualization, we set 'visual_names' as 'real_A' (input real image), 41 | 'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image) 42 | We convert the Lab image 'real_B' (inherited from Pix2pixModel) to a RGB image 'real_B_rgb'. 43 | we convert the Lab image 'fake_B' (inherited from Pix2pixModel) to a RGB image 'fake_B_rgb'. 44 | """ 45 | # reuse the pix2pix model 46 | Pix2PixModel.__init__(self, opt) 47 | # specify the images to be visualized. 48 | self.visual_names = ["real_A", "real_B_rgb", "fake_B_rgb"] 49 | 50 | def lab2rgb(self, L, AB): 51 | """Convert an Lab tensor image to a RGB numpy output 52 | Parameters: 53 | L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array) 54 | AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array) 55 | 56 | Returns: 57 | rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array) 58 | """ 59 | AB2 = AB * 110.0 60 | L2 = (L + 1.0) * 50.0 61 | Lab = torch.cat([L2, AB2], dim=1) 62 | Lab = Lab[0].data.cpu().float().numpy() 63 | Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0)) 64 | rgb = color.lab2rgb(Lab) * 255 65 | return rgb 66 | 67 | def compute_visuals(self): 68 | """Calculate additional output images for visdom and HTML visualization""" 69 | self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B) 70 | self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B) 71 | -------------------------------------------------------------------------------- /F-LSeSim/models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.downsample import Downsample 6 | from models.normalization import make_norm_layer 7 | 8 | 9 | class DiscriminatorBasicBlock(nn.Module): 10 | def __init__( 11 | self, 12 | in_features, 13 | out_features, 14 | do_downsample=True, 15 | do_instancenorm=True, 16 | norm_cfg=None, 17 | ): 18 | super().__init__() 19 | 20 | self.do_downsample = do_downsample 21 | self.do_instancenorm = do_instancenorm 22 | self.norm_cfg = norm_cfg or {'type': 'in'} 23 | self.norm_cfg = {k.lower(): v for k, v in self.norm_cfg.items()} 24 | 25 | self.conv = nn.Conv2d( 26 | in_features, out_features, kernel_size=4, stride=1, padding=1 27 | ) 28 | self.leakyrelu = nn.LeakyReLU(0.2, True) 29 | 30 | if do_instancenorm: 31 | self.instancenorm = make_norm_layer(self.norm_cfg, num_features=out_features) 32 | 33 | if do_downsample: 34 | self.downsample = Downsample(out_features) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | if self.do_instancenorm: 39 | x = self.instancenorm(x) 40 | x = self.leakyrelu(x) 41 | if self.do_downsample: 42 | x = self.downsample(x) 43 | return x 44 | 45 | 46 | class Discriminator(nn.Module): 47 | def __init__(self, in_channels=3, features=64, avg_pooling=False): 48 | super().__init__() 49 | self.block1 = DiscriminatorBasicBlock( 50 | in_channels, features, do_downsample=True, do_instancenorm=False 51 | ) 52 | self.block2 = DiscriminatorBasicBlock( 53 | features, features * 2, do_downsample=True, do_instancenorm=True 54 | ) 55 | self.block3 = DiscriminatorBasicBlock( 56 | features * 2, features * 4, do_downsample=True, do_instancenorm=True 57 | ) 58 | self.block4 = DiscriminatorBasicBlock( 59 | features * 4, features * 8, do_downsample=False, do_instancenorm=True 60 | ) 61 | 
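        # The final convolution maps the deepest features to a single-channel map of patch-level real/fake scores.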
self.conv = nn.Conv2d(features * 8, 1, kernel_size=4, stride=1, padding=1) 62 | self.avg_pooling = avg_pooling 63 | 64 | def forward(self, x): 65 | x = self.block1(x) 66 | x = self.block2(x) 67 | x = self.block3(x) 68 | x = self.block4(x) 69 | x = self.conv(x) 70 | if self.avg_pooling: 71 | x = F.avg_pool2d(x, x.size()[2:]) 72 | x = torch.flatten(x, 1) 73 | return x 74 | 75 | def set_requires_grad(self, requires_grad=False): 76 | for param in self.parameters(): 77 | param.requires_grad = requires_grad 78 | 79 | 80 | if __name__ == "__main__": 81 | x = torch.randn((5, 3, 256, 256)) 82 | print(x.shape) 83 | model = Discriminator(in_channels=3, avg_pooling=True) 84 | preds = model(x) 85 | print(preds.shape) 86 | -------------------------------------------------------------------------------- /F-LSeSim/models/downsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Downsample(nn.Module): 5 | def __init__(self, features): 6 | super().__init__() 7 | self.reflectionpad = nn.ReflectionPad2d(1) 8 | self.conv = nn.Conv2d(features, features, kernel_size=3, stride=2) 9 | 10 | def forward(self, x): 11 | x = self.reflectionpad(x) 12 | x = self.conv(x) 13 | return x 14 | -------------------------------------------------------------------------------- /F-LSeSim/models/networks.py: -------------------------------------------------------------------------------- 1 | from torch.optim import lr_scheduler 2 | 3 | from models.generator import Generator 4 | 5 | from . import cyclegan_networks, stylegan_networks 6 | 7 | 8 | ################################################################################## 9 | # Networks 10 | ################################################################################## 11 | def define_G( 12 | input_nc, 13 | output_nc, 14 | ngf, 15 | netG, 16 | norm="batch", 17 | use_dropout=False, 18 | init_type="normal", 19 | init_gain=0.02, 20 | no_antialias=False, 21 | no_antialias_up=False, 22 | gpu_ids=[], 23 | opt=None, 24 | norm_cfg=None, 25 | ): 26 | """ 27 | Create a generator 28 | :param input_nc: the number of channels in input images 29 | :param output_nc: the number of channels in output images 30 | :param ngf: the number of filters in the first conv layer 31 | :param netG: the architecture's name: resnet_9blocks | munit | stylegan2 32 | :param norm: the name of normalization layers used in the network: batch | instance | none 33 | :param use_dropout: if use dropout layers. 34 | :param init_type: the name of our initialization method. 35 | :param init_gain: scaling factor for normal, xavier and orthogonal. 
36 | :param no_antialias: use learned down sampling layer or not 37 | :param no_antialias_up: use learned up sampling layer or not 38 | :param gpu_ids: which GPUs the network runs on: e.g., 0,1,2 39 | :param opt: options 40 | :return: 41 | """ 42 | 43 | if netG == "resnet_9blocks": 44 | net = Generator(norm_cfg=norm_cfg or {'type': 'in'}) 45 | elif netG == "stylegan2": 46 | net = stylegan_networks.StyleGAN2Generator(input_nc, output_nc, ngf, opt=opt) 47 | else: 48 | raise NotImplementedError("Generator model name [%s] is not recognized" % netG) 49 | return cyclegan_networks.init_net( 50 | net, init_type, init_gain, gpu_ids, initialize_weights=("stylegan2" not in netG) 51 | ) 52 | 53 | 54 | def define_D( 55 | input_nc, 56 | ndf, 57 | netD, 58 | n_layers_D=3, 59 | norm="batch", 60 | init_type="normal", 61 | init_gain=0.02, 62 | no_antialias=False, 63 | gpu_ids=[], 64 | opt=None, 65 | ): 66 | """ 67 | Create a discriminator 68 | :param input_nc: the number of channels in input images 69 | :param ndf: the number of filters in the first conv layer 70 | :param netD: the architecture's name 71 | :param n_layers_D: the number of conv layers in the discriminator; effective when netD=='n_layers' 72 | :param norm: the type of normalization layers used in the network 73 | :param init_type: the name of the initialization method 74 | :param init_gain: scaling factor for normal, xavier and orthogonal 75 | :param no_antialias: use learned down sampling layer or not 76 | :param gpu_ids: which GPUs the network runs on: e.g., 0,1,2 77 | :param opt: options 78 | :return: 79 | """ 80 | norm_value = cyclegan_networks.get_norm_layer(norm) 81 | if netD == "basic": 82 | net = cyclegan_networks.NLayerDiscriminator( 83 | input_nc, ndf, n_layers_D, norm_value, no_antialias 84 | ) 85 | elif netD == "bimulti": 86 | net = cyclegan_networks.D_NLayersMulti( 87 | input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_value, num_D=2 88 | ) 89 | elif "stylegan2" in netD: 90 | net = stylegan_networks.StyleGAN2Discriminator(input_nc, ndf, opt=opt) 91 | else: 92 | raise NotImplementedError( 93 | "Discriminator model name [%s] is not recognized" % netD 94 | ) 95 | return cyclegan_networks.init_net( 96 | net, init_type, init_gain, gpu_ids, initialize_weights=("stylegan2" not in netD) 97 | ) 98 | 99 | 100 | ############################################################################### 101 | # Helper Functions 102 | ############################################################################### 103 | def get_scheduler(optimizer, opt): 104 | """Return a learning rate scheduler 105 | 106 | Parameters: 107 | optimizer -- the optimizer of the network 108 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 109 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 110 | 111 | For 'linear', we keep the same learning rate for the first epochs 112 | and linearly decay the rate to zero over the next epochs. 113 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 114 | See https://pytorch.org/docs/stable/optim.html for more details. 
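    As a rough worked example of the linear policy (assuming the default epoch_count=1, n_epochs=100, n_epochs_decay=100): the multiplier stays at 1.0 for roughly the first 100 epochs, and at epoch 150 it is 1 - (150 + 1 - 100) / 101, which is about 0.5, decaying approximately linearly towards 0 by the end of training (the exact endpoints depend on how LambdaLR indexes epochs).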
115 | """ 116 | if opt.lr_policy == "linear": 117 | 118 | def lambda_rule(epoch): 119 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float( 120 | opt.n_epochs_decay + 1 121 | ) 122 | return lr_l 123 | 124 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 125 | elif opt.lr_policy == "step": 126 | scheduler = lr_scheduler.StepLR( 127 | optimizer, step_size=opt.lr_decay_iters, gamma=0.1 128 | ) 129 | elif opt.lr_policy == "plateau": 130 | scheduler = lr_scheduler.ReduceLROnPlateau( 131 | optimizer, mode="min", factor=0.2, threshold=0.01, patience=5 132 | ) 133 | elif opt.lr_policy == "cosine": 134 | scheduler = lr_scheduler.CosineAnnealingLR( 135 | optimizer, T_max=opt.n_epochs, eta_min=0 136 | ) 137 | else: 138 | return NotImplementedError( 139 | "learning rate policy [%s] is not implemented", opt.lr_policy 140 | ) 141 | return scheduler 142 | -------------------------------------------------------------------------------- /F-LSeSim/models/normalization.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict 3 | 4 | import torch.nn as nn 5 | 6 | from models.kin import KernelizedInstanceNorm 7 | from models.tin import ThumbInstanceNorm 8 | 9 | 10 | # TODO: To be deprecated 11 | def get_normalization_layer(num_features, normalization="kin"): 12 | if normalization == "kin": 13 | return KernelizedInstanceNorm(num_features=num_features) 14 | elif normalization == "tin": 15 | return ThumbInstanceNorm(num_features=num_features) 16 | elif normalization == "in": 17 | return nn.InstanceNorm2d(num_features) 18 | else: 19 | raise NotImplementedError 20 | 21 | 22 | def make_norm_layer(norm_cfg: Dict[str, Any], **kwargs: Any): 23 | """ 24 | Create normalization layer based on given config and arguments. 25 | 26 | Args: 27 | norm_cfg (Dict[str, Any]): A dict of keyword arguments of normalization layer. 28 | It must have a key 'type' to specify which normalization layers will be used. 29 | It accepts upper case argument. 30 | **kwargs (Any): The keyword arguments are used to overwrite `norm_cfg`. 31 | 32 | Returns: 33 | nn.Module: A layer object. 34 | """ 35 | norm_cfg = deepcopy(norm_cfg) 36 | norm_cfg = {k.lower(): v for k, v in norm_cfg.items()} 37 | 38 | norm_cfg.update(kwargs) 39 | 40 | if 'type' not in norm_cfg: 41 | raise ValueError('"type" wasn\'t specified.') 42 | 43 | norm_type = norm_cfg['type'] 44 | del norm_cfg['type'] 45 | 46 | if norm_type == 'in': 47 | return nn.InstanceNorm2d(**norm_cfg) 48 | elif norm_type == 'tin': 49 | return ThumbInstanceNorm(**norm_cfg) 50 | elif norm_type == 'kin': 51 | return KernelizedInstanceNorm(**norm_cfg) 52 | else: 53 | raise ValueError(f'Unknown norm type: {norm_type}.') 54 | -------------------------------------------------------------------------------- /F-LSeSim/models/sinsc_model.py: -------------------------------------------------------------------------------- 1 | from .sc_model import SCModel 2 | 3 | 4 | class SinSCModel(SCModel): 5 | """ 6 | This class implements the single image translation 7 | """ 8 | 9 | @staticmethod 10 | def modify_commandline_options(parser, is_train=True): 11 | """ 12 | :param parser: original options parser 13 | :param is_train: whether training phase or test phase. 
You can use this flag to add training-specific or test-specific options. 14 | :return: the modified parser 15 | """ 16 | parser = SCModel.modify_commandline_options(parser, is_train) 17 | 18 | parser.set_defaults( 19 | dataset_mode="singleimage", 20 | netG="stylegan2", 21 | stylegan2_G_num_downsampling=2, 22 | netD="stylegan2", 23 | gan_mode="nonsaturating", 24 | num_patches=1, 25 | attn_layers="4,7,9", 26 | lambda_spatial=10.0, 27 | lambda_identity=0.0, 28 | lambda_gradient=1.0, 29 | lambda_spatial_idt=0.0, 30 | ngf=8, 31 | ndf=8, 32 | lr=0.001, 33 | beta1=0.0, 34 | beta2=0.99, 35 | load_size=1024, 36 | crop_size=128, 37 | preprocess="zoom_and_patch", 38 | D_patch_size=None, 39 | ) 40 | 41 | if is_train: 42 | parser.set_defaults( 43 | preprocess="zoom_and_patch", 44 | batch_size=16, 45 | save_epoch_freq=1, 46 | save_latest_freq=20000, 47 | n_epochs=4, 48 | n_epochs_decay=4, 49 | ) 50 | else: 51 | parser.set_defaults( 52 | preprocess="none", # load the whole image as it is 53 | batch_size=1, 54 | num_test=1, 55 | ) 56 | 57 | return parser 58 | 59 | def __init__(self, opt): 60 | super().__init__(opt) 61 | -------------------------------------------------------------------------------- /F-LSeSim/models/test_model.py: -------------------------------------------------------------------------------- 1 | from . import networks 2 | from .base_model import BaseModel 3 | 4 | 5 | class TestModel(BaseModel): 6 | """This TestModel can be used to generate CycleGAN results for only one direction. 7 | This model will automatically set '--dataset_mode single', which only loads the images from one collection. 8 | 9 | See the test instruction for more details. 10 | """ 11 | 12 | @staticmethod 13 | def modify_commandline_options(parser, is_train=True): 14 | """Add new dataset-specific options, and rewrite default values for existing options. 15 | 16 | Parameters: 17 | parser -- original option parser 18 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 19 | 20 | Returns: 21 | the modified parser. 22 | 23 | The model can only be used during test time. It requires '--dataset_mode single'. 24 | You need to specify the network using the option '--model_suffix'. 25 | """ 26 | assert not is_train, "TestModel cannot be used during training time" 27 | parser.set_defaults(dataset_mode="single") 28 | parser.add_argument( 29 | "--model_suffix", 30 | type=str, 31 | default="", 32 | help="In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.", 33 | ) 34 | 35 | return parser 36 | 37 | def __init__(self, opt): 38 | """Initialize the TestModel class. 39 | 40 | Parameters: 41 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 42 | """ 43 | assert not opt.isTrain 44 | BaseModel.__init__(self, opt) 45 | # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses> 46 | self.loss_names = [] 47 | # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals> 48 | self.visual_names = ["real", "fake"] 49 | # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks> 50 | self.model_names = ["G" + opt.model_suffix] # only generator is needed.
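        # Build the generator from the test-time options; its weights are not loaded here -- the checkpoint [epoch]_net_G[model_suffix].pth (see the --model_suffix help above) is restored separately when the base model loads the networks.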
51 | self.netG = networks.define_G( 52 | opt.input_nc, 53 | opt.output_nc, 54 | opt.ngf, 55 | opt.netG, 56 | opt.norm, 57 | not opt.no_dropout, 58 | opt.init_type, 59 | opt.init_gain, 60 | self.gpu_ids, 61 | ) 62 | 63 | # assigns the model to self.netG_[suffix] so that it can be loaded 64 | # please see 65 | setattr(self, "netG" + opt.model_suffix, self.netG) # store netG in self. 66 | 67 | def set_input(self, input): 68 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 69 | 70 | Parameters: 71 | input: a dictionary that contains the data itself and its metadata information. 72 | 73 | We need to use 'single_dataset' dataset mode. It only load images from one domain. 74 | """ 75 | self.real = input["A"].to(self.device) 76 | self.image_paths = input["A_paths"] 77 | 78 | def forward(self): 79 | """Run forward pass.""" 80 | self.fake = self.netG(self.real) # G(real) 81 | 82 | def optimize_parameters(self): 83 | """No optimization for test model.""" 84 | pass 85 | -------------------------------------------------------------------------------- /F-LSeSim/models/tin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ThumbInstanceNorm(nn.Module): 6 | def __init__(self, num_features, affine=True): 7 | super(ThumbInstanceNorm, self).__init__() 8 | self.thumb_mean = None 9 | self.thumb_std = None 10 | self.normal_instance_normalization = False 11 | self.collection_mode = False 12 | if affine == True: 13 | self.weight = nn.Parameter( 14 | torch.ones(size=(1, num_features, 1, 1), requires_grad=True) 15 | ) 16 | self.bias = nn.Parameter( 17 | torch.zeros(size=(1, num_features, 1, 1), requires_grad=True) 18 | ) 19 | 20 | def calc_mean_std(self, feat, eps=1e-5): 21 | size = feat.size() 22 | assert len(size) == 4 23 | N, C = size[:2] 24 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 25 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 26 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 27 | return feat_mean, feat_std 28 | 29 | def forward(self, x): 30 | if self.training or self.normal_instance_normalization: 31 | x_mean, x_std = self.calc_mean_std(x) 32 | x = (x - x_mean) / x_std * self.weight + self.bias 33 | return x 34 | else: 35 | if self.collection_mode: 36 | assert x.shape[0] == 1 37 | x_mean, x_std = self.calc_mean_std(x) 38 | self.thumb_mean = x_mean 39 | self.thumb_std = x_std 40 | 41 | x = (x - self.thumb_mean) / self.thumb_std * self.weight + self.bias 42 | return x 43 | 44 | 45 | def not_use_thumbnail_instance_norm(model): 46 | for _, layer in model.named_modules(): 47 | if isinstance(layer, ThumbInstanceNorm): 48 | layer.collection_mode = False 49 | layer.normal_instance_normalization = True 50 | 51 | 52 | def init_thumbnail_instance_norm(model): 53 | for _, layer in model.named_modules(): 54 | if isinstance(layer, ThumbInstanceNorm): 55 | layer.collection_mode = True 56 | 57 | 58 | def use_thumbnail_instance_norm(model): 59 | for _, layer in model.named_modules(): 60 | if isinstance(layer, ThumbInstanceNorm): 61 | layer.collection_mode = False 62 | layer.normal_instance_normalization = False 63 | -------------------------------------------------------------------------------- /F-LSeSim/models/upsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Upsample(nn.Module): 5 | def __init__(self, features): 6 | super().__init__() 7 | layers = [ 8 | nn.ReplicationPad2d(1), 9 | 
nn.ConvTranspose2d(features, features, kernel_size=4, stride=2, padding=3), 10 | ] 11 | self.model = nn.Sequential(*layers) 12 | 13 | def forward(self, input): 14 | return self.model(input) 15 | -------------------------------------------------------------------------------- /F-LSeSim/models/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from scipy import signal 4 | 5 | 6 | def gkern(kernlen=1, std=3): 7 | """Returns a 2D Gaussian kernel array.""" 8 | gkern1d = signal.gaussian(kernlen, std=std).reshape(kernlen, 1) 9 | gkern2d = np.outer(gkern1d, gkern1d) 10 | return gkern2d 11 | 12 | 13 | def get_kernel(padding=1, gaussian_std=3, mode="constant"): 14 | kernel_size = padding * 2 + 1 15 | if mode == "constant": 16 | kernel = torch.ones(kernel_size, kernel_size) 17 | kernel = kernel / (kernel_size * kernel_size) 18 | 19 | elif mode == "gaussian": 20 | kernel = gkern(kernel_size, std=gaussian_std) 21 | kernel = kernel / kernel.sum() 22 | kernel = torch.from_numpy(kernel.astype(np.float32)) 23 | 24 | else: 25 | raise NotImplementedError 26 | 27 | return kernel -------------------------------------------------------------------------------- /F-LSeSim/options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /F-LSeSim/options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument( 13 | "--results_dir", type=str, default="./results/", help="saves results here." 14 | ) 15 | parser.add_argument( 16 | "--aspect_ratio", 17 | type=float, 18 | default=1.0, 19 | help="aspect ratio of result images", 20 | ) 21 | parser.add_argument( 22 | "--phase", type=str, default="test", help="train, val, test, etc" 23 | ) 24 | # Dropout and Batchnorm have different behaviors during training and test. 25 | parser.add_argument( 26 | "--eval", action="store_true", help="use eval mode during test time." 27 | ) 28 | parser.add_argument( 29 | "--num_test", type=int, default=50, help="how many test images to run" 30 | ) 31 | # rewrite default values 32 | parser.set_defaults(model="test") 33 | # To avoid cropping, the load_size should be the same as crop_size 34 | parser.set_defaults(load_size=parser.get_default("crop_size")) 35 | self.isTrain = False 36 | return parser 37 | -------------------------------------------------------------------------------- /F-LSeSim/options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions.
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument( 14 | "--display_freq", 15 | type=int, 16 | default=250, 17 | help="frequency of showing training results on screen", 18 | ) 19 | parser.add_argument( 20 | "--display_ncols", 21 | type=int, 22 | default=4, 23 | help="if positive, display all images in a single visdom web panel with certain number of images per row.", 24 | ) 25 | parser.add_argument( 26 | "--display_id", type=int, default=None, help="window id of the web display" 27 | ) 28 | parser.add_argument( 29 | "--display_server", 30 | type=str, 31 | default="http://localhost", 32 | help="visdom server of the web display", 33 | ) 34 | parser.add_argument( 35 | "--display_env", 36 | type=str, 37 | default="main", 38 | help='visdom display environment name (default is "main")', 39 | ) 40 | parser.add_argument( 41 | "--display_port", 42 | type=int, 43 | default=8097, 44 | help="visdom port of the web display", 45 | ) 46 | parser.add_argument( 47 | "--update_html_freq", 48 | type=int, 49 | default=1000, 50 | help="frequency of saving training results to html", 51 | ) 52 | parser.add_argument( 53 | "--print_freq", 54 | type=int, 55 | default=100, 56 | help="frequency of showing training results on console", 57 | ) 58 | parser.add_argument( 59 | "--no_html", 60 | action="store_true", 61 | help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/", 62 | ) 63 | # network saving and loading parameters 64 | parser.add_argument( 65 | "--save_latest_freq", 66 | type=int, 67 | default=5000, 68 | help="frequency of saving the latest results", 69 | ) 70 | parser.add_argument( 71 | "--save_epoch_freq", 72 | type=int, 73 | default=10, 74 | help="frequency of saving checkpoints at the end of epochs", 75 | ) 76 | parser.add_argument( 77 | "--save_by_iter", 78 | action="store_true", 79 | help="whether saves model by iteration", 80 | ) 81 | parser.add_argument( 82 | "--continue_train", 83 | action="store_true", 84 | help="continue training: load the latest model", 85 | ) 86 | parser.add_argument( 87 | "--epoch_count", 88 | type=int, 89 | default=1, 90 | help="the starting epoch count, we save the model by , +, ...", 91 | ) 92 | parser.add_argument( 93 | "--phase", type=str, default="train", help="train, val, test, etc" 94 | ) 95 | # training parameters 96 | parser.add_argument( 97 | "--n_epochs", 98 | type=int, 99 | default=100, 100 | help="number of epochs with the initial learning rate", 101 | ) 102 | parser.add_argument( 103 | "--n_epochs_decay", 104 | type=int, 105 | default=100, 106 | help="number of epochs to linearly decay learning rate to zero", 107 | ) 108 | parser.add_argument( 109 | "--beta1", type=float, default=0.5, help="momentum term of adam" 110 | ) 111 | parser.add_argument( 112 | "--beta2", type=float, default=0.999, help="momentum term of adam" 113 | ) 114 | parser.add_argument( 115 | "--lr", type=float, default=0.0001, help="initial learning rate for adam" 116 | ) 117 | parser.add_argument( 118 | "--gan_mode", 119 | type=str, 120 | default="lsgan", 121 | help="the type of GAN objective. [vanilla| lsgan | wgangp]. 
vanilla GAN loss is the cross-entropy objective used in the original GAN paper.", 122 | ) 123 | parser.add_argument( 124 | "--pool_size", 125 | type=int, 126 | default=50, 127 | help="the size of image buffer that stores previously generated images", 128 | ) 129 | parser.add_argument( 130 | "--lr_policy", 131 | type=str, 132 | default="linear", 133 | help="learning rate policy. [linear | step | plateau | cosine]", 134 | ) 135 | parser.add_argument( 136 | "--lr_decay_iters", 137 | type=int, 138 | default=50, 139 | help="multiply by a gamma every lr_decay_iters iterations", 140 | ) 141 | parser.add_argument( 142 | "-c", "--config", type=str, help="extra config for compatibility" 143 | ) 144 | self.isTrain = True 145 | return parser 146 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/conda_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing 3 | conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9 4 | conda install visdom dominate -c conda-forge # install visdom and dominate 5 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/download_cyclegan_model.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available models are apple2orange, orange2apple, summer2winter_yosemite, winter2summer_yosemite, horse2zebra, zebra2horse, monet2photo, style_monet, style_cezanne, style_ukiyoe, style_vangogh, sat2map, map2sat, cityscapes_photo2label, cityscapes_label2photo, facades_photo2label, facades_label2photo, iphone2dslr_flower" 4 | 5 | echo "Specified [$FILE]" 6 | 7 | mkdir -p ./checkpoints/${FILE}_pretrained 8 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth 9 | URL=http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/$FILE.pth 10 | 11 | wget -N $URL -O $MODEL_FILE 12 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/download_pix2pix_model.sh: -------------------------------------------------------------------------------- 1 | FILE=$1 2 | 3 | echo "Note: available models are edges2shoes, sat2map, map2sat, facades_label2photo, and day2night" 4 | echo "Specified [$FILE]" 5 | 6 | mkdir -p ./checkpoints/${FILE}_pretrained 7 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth 8 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/models-pytorch/$FILE.pth 9 | 10 | wget -N $URL -O $MODEL_FILE 11 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/edges/PostprocessHED.m: -------------------------------------------------------------------------------- 1 | %%% Prerequisites 2 | % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp 3 | % and compile it in Matlab: mex edgesNmsMex.cpp 4 | % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/ 5 | 6 | %%% parameters 7 | % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py') 8 | % edge_dir: the output HED edges directory 9 | % image_width: resize the edge map to [image_width, image_width] 10 | % threshold: threshold for image binarization (default 25.0/255.0) 11 | % small_edge: remove small edges (default 5) 12 | 13 | function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, 
threshold, small_edge) 14 | 15 | if ~exist(edge_dir, 'dir') 16 | mkdir(edge_dir); 17 | end 18 | fileList = dir(fullfile(hed_mat_dir, '*.mat')); 19 | nFiles = numel(fileList); 20 | fprintf('find %d mat files\n', nFiles); 21 | 22 | for n = 1 : nFiles 23 | if mod(n, 1000) == 0 24 | fprintf('process %d/%d images\n', n, nFiles); 25 | end 26 | fileName = fileList(n).name; 27 | filePath = fullfile(hed_mat_dir, fileName); 28 | jpgName = strrep(fileName, '.mat', '.jpg'); 29 | edge_path = fullfile(edge_dir, jpgName); 30 | 31 | if ~exist(edge_path, 'file') 32 | E = GetEdge(filePath); 33 | E = imresize(E,[image_width,image_width]); 34 | E_simple = SimpleEdge(E, threshold, small_edge); 35 | E_simple = uint8(E_simple*255); 36 | imwrite(E_simple, edge_path, 'Quality',100); 37 | end 38 | end 39 | end 40 | 41 | 42 | 43 | 44 | function [E] = GetEdge(filePath) 45 | load(filePath); 46 | E = 1-edge_predict; 47 | end 48 | 49 | function [E4] = SimpleEdge(E, threshold, small_edge) 50 | if nargin <= 1 51 | threshold = 25.0/255.0; 52 | end 53 | 54 | if nargin <= 2 55 | small_edge = 5; 56 | end 57 | 58 | if ndims(E) == 3 59 | E = E(:,:,1); 60 | end 61 | 62 | E1 = 1 - E; 63 | E2 = EdgeNMS(E1); 64 | E3 = double(E2>=max(eps,threshold)); 65 | E3 = bwmorph(E3,'thin',inf); 66 | E4 = bwareaopen(E3, small_edge); 67 | E4=1-E4; 68 | end 69 | 70 | function [E_nms] = EdgeNMS( E ) 71 | E=single(E); 72 | [Ox,Oy] = gradient2(convTri(E,4)); 73 | [Oxx,~] = gradient2(Ox); 74 | [Oxy,Oyy] = gradient2(Oy); 75 | O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi); 76 | E_nms = edgesNmsMex(E,O,1,5,1.01,1); 77 | end 78 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/edges/batch_hed.py: -------------------------------------------------------------------------------- 1 | # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb 2 | # Step 1: download the hed repo: https://github.com/s9xie/hed 3 | # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/ 4 | # Step 3: put this script under {caffe_root}/examples/hed/ 5 | # Step 4: run the following script: 6 | # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/ 7 | # The code sometimes crashes after computation is done. Error looks like "Check failed: ... driver shutting down". You can just kill the job. 8 | # For large images, it will produce gpu memory issue. Therefore, you better resize the images before running this script. 
9 | # Step 5: run the MATLAB post-processing script "PostprocessHED.m" 10 | 11 | 12 | import argparse 13 | import os 14 | import sys 15 | 16 | import caffe 17 | import numpy as np 18 | import scipy.io as sio 19 | from PIL import Image 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="batch proccesing: photos->edges") 24 | parser.add_argument( 25 | "--caffe_root", dest="caffe_root", help="caffe root", default="../../", type=str 26 | ) 27 | parser.add_argument( 28 | "--caffemodel", 29 | dest="caffemodel", 30 | help="caffemodel", 31 | default="./hed_pretrained_bsds.caffemodel", 32 | type=str, 33 | ) 34 | parser.add_argument( 35 | "--prototxt", 36 | dest="prototxt", 37 | help="caffe prototxt file", 38 | default="./deploy.prototxt", 39 | type=str, 40 | ) 41 | parser.add_argument( 42 | "--images_dir", 43 | dest="images_dir", 44 | help="directory to store input photos", 45 | type=str, 46 | ) 47 | parser.add_argument( 48 | "--hed_mat_dir", 49 | dest="hed_mat_dir", 50 | help="directory to store output hed edges in mat file", 51 | type=str, 52 | ) 53 | parser.add_argument( 54 | "--border", dest="border", help="padding border", type=int, default=128 55 | ) 56 | parser.add_argument("--gpu_id", dest="gpu_id", help="gpu id", type=int, default=1) 57 | args = parser.parse_args() 58 | return args 59 | 60 | 61 | args = parse_args() 62 | for arg in vars(args): 63 | print("[%s] =" % arg, getattr(args, arg)) 64 | # Make sure that caffe is on the python path: 65 | caffe_root = ( 66 | args.caffe_root 67 | ) # this file is expected to be in {caffe_root}/examples/hed/ 68 | sys.path.insert(0, caffe_root + "python") 69 | 70 | 71 | if not os.path.exists(args.hed_mat_dir): 72 | print("create output directory %s" % args.hed_mat_dir) 73 | os.makedirs(args.hed_mat_dir) 74 | 75 | imgList = os.listdir(args.images_dir) 76 | nImgs = len(imgList) 77 | print("#images = %d" % nImgs) 78 | 79 | caffe.set_mode_gpu() 80 | caffe.set_device(args.gpu_id) 81 | # load net 82 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST) 83 | # pad border 84 | border = args.border 85 | 86 | for i in range(nImgs): 87 | if i % 500 == 0: 88 | print("processing image %d/%d" % (i, nImgs)) 89 | im = Image.open(os.path.join(args.images_dir, imgList[i])) 90 | 91 | in_ = np.array(im, dtype=np.float32) 92 | in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), "reflect") 93 | 94 | in_ = in_[:, :, 0:3] 95 | in_ = in_[:, :, ::-1] 96 | in_ -= np.array((104.00698793, 116.66876762, 122.67891434)) 97 | in_ = in_.transpose((2, 0, 1)) 98 | # remove the following two lines if testing with cpu 99 | 100 | # shape for input (data blob is N x C x H x W), set data 101 | net.blobs["data"].reshape(1, *in_.shape) 102 | net.blobs["data"].data[...] 
= in_ 103 | # run net and take argmax for prediction 104 | net.forward() 105 | fuse = net.blobs["sigmoid-fuse"].data[0][0, :, :] 106 | # get rid of the border 107 | fuse = fuse[(border + 35) : (-border + 35), (border + 35) : (-border + 35)] 108 | # save hed file to the disk 109 | name, ext = os.path.splitext(imgList[i]) 110 | sio.savemat(os.path.join(args.hed_mat_dir, name + ".mat"), {"edge_predict": fuse}) 111 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/eval_cityscapes/download_fcn8s.sh: -------------------------------------------------------------------------------- 1 | URL=http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/fcn-8s-cityscapes/fcn-8s-cityscapes.caffemodel 2 | OUTPUT_FILE=./scripts/eval_cityscapes/caffemodel/fcn-8s-cityscapes.caffemodel 3 | wget -N $URL -O $OUTPUT_FILE 4 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/eval_cityscapes/evaluate.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import caffe 5 | import numpy as np 6 | import scipy.misc 7 | from cityscapes import cityscapes 8 | from PIL import Image 9 | 10 | from util import fast_hist, get_scores, segrun 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--cityscapes_dir", 15 | type=str, 16 | required=True, 17 | help="Path to the original cityscapes dataset", 18 | ) 19 | parser.add_argument( 20 | "--result_dir", 21 | type=str, 22 | required=True, 23 | help="Path to the generated images to be evaluated", 24 | ) 25 | parser.add_argument( 26 | "--output_dir", type=str, required=True, help="Where to save the evaluation results" 27 | ) 28 | parser.add_argument( 29 | "--caffemodel_dir", 30 | type=str, 31 | default="./scripts/eval_cityscapes/caffemodel/", 32 | help="Where the FCN-8s caffemodel stored", 33 | ) 34 | parser.add_argument("--gpu_id", type=int, default=0, help="Which gpu id to use") 35 | parser.add_argument( 36 | "--split", type=str, default="val", help="Data split to be evaluated" 37 | ) 38 | parser.add_argument( 39 | "--save_output_images", 40 | type=int, 41 | default=0, 42 | help="Whether to save the FCN output images", 43 | ) 44 | args = parser.parse_args() 45 | 46 | 47 | def main(): 48 | if not os.path.isdir(args.output_dir): 49 | os.makedirs(args.output_dir) 50 | if args.save_output_images > 0: 51 | output_image_dir = args.output_dir + "image_outputs/" 52 | if not os.path.isdir(output_image_dir): 53 | os.makedirs(output_image_dir) 54 | CS = cityscapes(args.cityscapes_dir) 55 | n_cl = len(CS.classes) 56 | label_frames = CS.list_label_frames(args.split) 57 | caffe.set_device(args.gpu_id) 58 | caffe.set_mode_gpu() 59 | net = caffe.Net( 60 | args.caffemodel_dir + "/deploy.prototxt", 61 | args.caffemodel_dir + "fcn-8s-cityscapes.caffemodel", 62 | caffe.TEST, 63 | ) 64 | 65 | hist_perframe = np.zeros((n_cl, n_cl)) 66 | for i, idx in enumerate(label_frames): 67 | if i % 10 == 0: 68 | print("Evaluating: %d/%d" % (i, len(label_frames))) 69 | city = idx.split("_")[0] 70 | # idx is city_shot_frame 71 | label = CS.load_label(args.split, city, idx) 72 | im_file = args.result_dir + "/" + idx + "_leftImg8bit.png" 73 | im = np.array(Image.open(im_file)) 74 | im = scipy.misc.imresize(im, (label.shape[1], label.shape[2])) 75 | out = segrun(net, CS.preprocess(im)) 76 | hist_perframe += fast_hist(label.flatten(), out.flatten(), n_cl) 77 | if args.save_output_images > 0: 78 | label_im = CS.palette(label) 
79 | pred_im = CS.palette(out) 80 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_pred.jpg", pred_im) 81 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_gt.jpg", label_im) 82 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_input.jpg", im) 83 | 84 | ( 85 | mean_pixel_acc, 86 | mean_class_acc, 87 | mean_class_iou, 88 | per_class_acc, 89 | per_class_iou, 90 | ) = get_scores(hist_perframe) 91 | with open(args.output_dir + "/evaluation_results.txt", "w") as f: 92 | f.write("Mean pixel accuracy: %f\n" % mean_pixel_acc) 93 | f.write("Mean class accuracy: %f\n" % mean_class_acc) 94 | f.write("Mean class IoU: %f\n" % mean_class_iou) 95 | f.write("************ Per class numbers below ************\n") 96 | for i, cl in enumerate(CS.classes): 97 | while len(cl) < 15: 98 | cl = cl + " " 99 | f.write( 100 | "%s: acc = %f, iou = %f\n" % (cl, per_class_acc[i], per_class_iou[i]) 101 | ) 102 | 103 | 104 | main() 105 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/eval_cityscapes/util.py: -------------------------------------------------------------------------------- 1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn 2 | import numpy as np 3 | 4 | 5 | def get_out_scoremap(net): 6 | return net.blobs["score"].data[0].argmax(axis=0).astype(np.uint8) 7 | 8 | 9 | def feed_net(net, in_): 10 | """ 11 | Load prepared input into net. 12 | """ 13 | net.blobs["data"].reshape(1, *in_.shape) 14 | net.blobs["data"].data[...] = in_ 15 | 16 | 17 | def segrun(net, in_): 18 | feed_net(net, in_) 19 | net.forward() 20 | return get_out_scoremap(net) 21 | 22 | 23 | def fast_hist(a, b, n): 24 | k = np.where((a >= 0) & (a < n))[0] 25 | bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2) 26 | if len(bc) != n**2: 27 | # ignore this example if dimension mismatch 28 | return 0 29 | return bc.reshape(n, n) 30 | 31 | 32 | def get_scores(hist): 33 | # Mean pixel accuracy 34 | acc = np.diag(hist).sum() / (hist.sum() + 1e-12) 35 | 36 | # Per class accuracy 37 | cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12) 38 | 39 | # Per class IoU 40 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12) 41 | 42 | return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu 43 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/install_deps.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | pip install visdom 3 | pip install dominate 4 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_before_push.py: -------------------------------------------------------------------------------- 1 | # Simple script to make sure basic usage 2 | # such as training, testing, saving and loading 3 | # runs without errors. 
4 | import os 5 | 6 | 7 | def run(command): 8 | print(command) 9 | exit_status = os.system(command) 10 | if exit_status > 0: 11 | exit(1) 12 | 13 | 14 | if __name__ == "__main__": 15 | # download mini datasets 16 | if not os.path.exists("./datasets/mini"): 17 | run("bash ./datasets/download_cyclegan_dataset.sh mini") 18 | 19 | if not os.path.exists("./datasets/mini_pix2pix"): 20 | run("bash ./datasets/download_cyclegan_dataset.sh mini_pix2pix") 21 | 22 | # pretrained cyclegan model 23 | if not os.path.exists("./checkpoints/horse2zebra_pretrained/latest_net_G.pth"): 24 | run("bash ./scripts/download_cyclegan_model.sh horse2zebra") 25 | run( 26 | "python test.py --model test --dataroot ./datasets/mini --name horse2zebra_pretrained --no_dropout --num_test 1 --no_dropout" 27 | ) 28 | 29 | # pretrained pix2pix model 30 | if not os.path.exists( 31 | "./checkpoints/facades_label2photo_pretrained/latest_net_G.pth" 32 | ): 33 | run("bash ./scripts/download_pix2pix_model.sh facades_label2photo") 34 | if not os.path.exists("./datasets/facades"): 35 | run("bash ./datasets/download_pix2pix_dataset.sh facades") 36 | run( 37 | "python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --num_test 1" 38 | ) 39 | 40 | # cyclegan train/test 41 | run( 42 | "python train.py --model cycle_gan --name temp_cyclegan --dataroot ./datasets/mini --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 10 --print_freq 1 --display_id -1" 43 | ) 44 | run( 45 | 'python test.py --model test --name temp_cyclegan --dataroot ./datasets/mini --num_test 1 --model_suffix "_A" --no_dropout' 46 | ) 47 | 48 | # pix2pix train/test 49 | run( 50 | "python train.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --n_epochs 1 --n_epochs_decay 5 --save_latest_freq 10 --display_id -1" 51 | ) 52 | run( 53 | "python test.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --num_test 1" 54 | ) 55 | 56 | # template train/test 57 | run( 58 | "python train.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 10 --display_id -1" 59 | ) 60 | run( 61 | "python test.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --num_test 1" 62 | ) 63 | 64 | # colorization train/test (optional) 65 | if not os.path.exists("./datasets/mini_colorization"): 66 | run("bash ./datasets/download_cyclegan_dataset.sh mini_colorization") 67 | 68 | run( 69 | "python train.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 5 --display_id -1" 70 | ) 71 | run( 72 | "python test.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --num_test 1" 73 | ) 74 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_colorization.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_cyclegan.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout 3 | -------------------------------------------------------------------------------- 
/F-LSeSim/scripts/test_fid.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test_fid.py \ 3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/horse2zebra \ 4 | --checkpoints_dir ./checkpoints \ 5 | --name horse2zebra \ 6 | --gpu_ids 1 \ 7 | --model sc \ 8 | --num_test 0 -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_pix2pix.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --dataset_mode aligned --norm batch 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_sc.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py \ 3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/horse2zebra \ 4 | --checkpoints_dir ./checkpoints \ 5 | --name horse2zebra \ 6 | --model sc \ 7 | --num_test 0 -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_single.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --netG unet_256 --direction BtoA --dataset_mode single --norm batch 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/test_sinsc.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python test.py \ 3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/single_image_monet_etretat \ 4 | --name image2monet \ 5 | --model sinsc -------------------------------------------------------------------------------- /F-LSeSim/scripts/train_colorization.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/train_cyclegan.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --pool_size 50 --no_dropout 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/train_pix2pix.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --lambda_L1 100 --dataset_mode aligned --norm batch --pool_size 0 3 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/train_sc.sh: -------------------------------------------------------------------------------- 1 | echo "Read config file: $1"; 2 | mode=$(grep -A0 'MODEL_NAME:' $1 | cut -c 13- | cut -d# -f1 | tr -d '"' | sed "s/ //g"); echo "$mode" 3 | if [ "$mode" = "LSeSim" ]; then 4 | exp=$(grep -A0 'EXPERIMENT_NAME:' $1 | cut -c 18- | tr -d '"') 5 | data=$(grep -A0 'TRAIN_ROOT:' $1 | cut -c 15- | tr -d '"') 6 | augment=$(grep -A0 'Augment:' $1 | cut -c 12- | cut -d# -f1 | sed "s/ 
//g") 7 | if [ "$augment" = "True" ]; then 8 | set -ex 9 | python train.py \ 10 | --dataroot $data \ 11 | --name $exp \ 12 | --model sc \ 13 | --gpu_ids 1 \ 14 | --lambda_spatial 10 \ 15 | --lambda_gradient 0 \ 16 | --attn_layers 4,7,9 \ 17 | --loss_mode cos \ 18 | --gan_mode lsgan \ 19 | --display_port 8093 \ 20 | --direction AtoB \ 21 | --patch_size 64 \ 22 | --learned_attn \ 23 | --augment 24 | else 25 | set -ex 26 | python train.py \ 27 | --dataroot $data \ 28 | --name $exp \ 29 | --model sc \ 30 | --gpu_ids 1 \ 31 | --lambda_spatial 10 \ 32 | --lambda_gradient 0 \ 33 | --attn_layers 4,7,9 \ 34 | --loss_mode cos \ 35 | --gan_mode lsgan \ 36 | --display_port 8093 \ 37 | --direction AtoB \ 38 | --patch_size 64 39 | fi 40 | else 41 | echo "Not LSeSim model" 42 | fi 43 | -------------------------------------------------------------------------------- /F-LSeSim/scripts/train_sinsc.sh: -------------------------------------------------------------------------------- 1 | set -ex 2 | python train.py \ 3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/single_image_monet_etretat \ 4 | --name image2monet \ 5 | --model sinsc \ 6 | --gpu_ids 1 \ 7 | --display_port 8093 \ 8 | --pool_size 0 -------------------------------------------------------------------------------- /F-LSeSim/scripts/transfer_sc.sh: -------------------------------------------------------------------------------- 1 | echo "Read config file: $1"; 2 | mode=$(grep -A0 'MODEL_NAME:' $1 | cut -c 13- | cut -d# -f1 | tr -d '"' | sed "s/ //g"); echo "$mode" 3 | if [ "$mode" = "LSeSim" ]; then 4 | exp=$(grep -A0 'EXPERIMENT_NAME:' $1 | cut -c 18- | tr -d '"') 5 | python3 inference.py \ 6 | --name $exp \ 7 | --model sc \ 8 | -c $1 9 | 10 | python3 combine.py \ 11 | --read_original \ 12 | -c $1 13 | else 14 | echo "Not LSeSim model" 15 | fi 16 | -------------------------------------------------------------------------------- /F-LSeSim/test.py: -------------------------------------------------------------------------------- 1 | """General-purpose test script for image-to-image translation. 2 | 3 | Once you have trained your model with train.py, you can use this script to test the model. 4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. 5 | 6 | It first creates model and dataset given the option. It will hard-code some parameters. 7 | It then runs inference for '--num_test' images and save results to an HTML file. 8 | 9 | Example (You need to train models first or download pre-trained models from our website): 10 | Test a CycleGAN model (both sides): 11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 12 | 13 | Test a CycleGAN model (one side only): 14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout 15 | 16 | The option '--model test' is used for generating CycleGAN results only for one side. 17 | This option will automatically set '--dataset_mode single', which only loads the images from one set. 18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, 19 | which is sometimes unnecessary. The results will be saved at ./results/. 20 | Use '--results_dir ' to specify the results directory. 21 | 22 | Test a pix2pix model: 23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 24 | 25 | See options/base_options.py and options/test_options.py for more test options. 
26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 28 | """ 29 | import os 30 | 31 | from data import create_dataset 32 | from models import create_model 33 | from options.test_options import TestOptions 34 | from util import html 35 | from util.visualizer import save_images 36 | 37 | if __name__ == "__main__": 38 | opt = TestOptions().parse() # get test options 39 | # hard-code some parameters for test 40 | opt.num_threads = 0 # test code only supports num_threads = 1 41 | opt.batch_size = 1 # test code only supports batch_size = 1 42 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 43 | opt.no_flip = ( 44 | True # no flip; comment this line if results on flipped images are needed. 45 | ) 46 | opt.display_id = ( 47 | -1 48 | ) # no visdom display; the test code saves the results to a HTML file. 49 | dataset = create_dataset( 50 | opt 51 | ) # create a dataset given opt.dataset_mode and other options 52 | model = create_model(opt) # create a model given opt.model and other options 53 | # create a website 54 | web_dir = os.path.join( 55 | opt.results_dir, opt.name, "{}_{}".format(opt.phase, opt.epoch) 56 | ) # define the website directory 57 | print("creating web directory", web_dir) 58 | webpage = html.HTML( 59 | web_dir, 60 | "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch), 61 | ) 62 | # test with eval mode. This only affects layers like batchnorm and dropout. 63 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 64 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 65 | opt.num_test = opt.num_test if opt.num_test > 0 else float("inf") 66 | for i, data in enumerate(dataset): 67 | if i == 0: 68 | model.data_dependent_initialize(data) 69 | model.setup(opt) 70 | model.parallelize() 71 | if opt.eval: 72 | model.eval() 73 | if i >= opt.num_test: # only apply our model to opt.num_test images. 74 | break 75 | model.set_input(data) # unpack data from data loader 76 | model.test() # run inference 77 | visuals = model.get_current_visuals() # get image results 78 | img_path = model.get_image_paths() # get image paths 79 | if i % 5 == 0: # save images to an HTML file 80 | print("processing (%04d)-th image... %s" % (i, img_path)) 81 | save_images( 82 | webpage, 83 | visuals, 84 | img_path, 85 | aspect_ratio=opt.aspect_ratio, 86 | width=opt.display_winsize, 87 | ) 88 | webpage.save() # save the HTML 89 | -------------------------------------------------------------------------------- /F-LSeSim/test_fid.py: -------------------------------------------------------------------------------- 1 | """General-purpose test script for image-to-image translation. 2 | 3 | Once you have trained your model with train.py, you can use this script to test the model. 4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. 5 | 6 | It first creates model and dataset given the option. It will hard-code some parameters. 7 | It then runs inference for '--num_test' images and save results to an HTML file. 
8 | 9 | Example (You need to train models first or download pre-trained models from our website): 10 | Test a CycleGAN model (both sides): 11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 12 | 13 | Test a CycleGAN model (one side only): 14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout 15 | 16 | The option '--model test' is used for generating CycleGAN results only for one side. 17 | This option will automatically set '--dataset_mode single', which only loads the images from one set. 18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions, 19 | which is sometimes unnecessary. The results will be saved at ./results/. 20 | Use '--results_dir ' to specify the results directory. 21 | 22 | Test a pix2pix model: 23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 24 | 25 | See options/base_options.py and options/test_options.py for more test options. 26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 28 | """ 29 | import os 30 | 31 | import matplotlib.pyplot as plt 32 | 33 | import util.util as util 34 | from data import create_dataset 35 | from evaluations.fid_score import calculate_fid_given_paths 36 | from models import create_model 37 | from options.test_options import TestOptions 38 | from util import html 39 | from util.visualizer import save_images 40 | 41 | if __name__ == "__main__": 42 | opt = TestOptions().parse() # get test options 43 | # hard-code some parameters for test 44 | opt.num_threads = 0 # test code only supports num_threads = 1 45 | opt.batch_size = 1 # test code only supports batch_size = 1 46 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 47 | opt.no_flip = ( 48 | True # no flip; comment this line if results on flipped images are needed. 49 | ) 50 | opt.display_id = ( 51 | -1 52 | ) # no visdom display; the test code saves the results to a HTML file. 53 | opt.num_test = opt.num_test if opt.num_test > 0 else float("inf") 54 | dataset = create_dataset( 55 | opt 56 | ) # create a dataset given opt.dataset_mode and other options 57 | train_dataset = create_dataset(util.copyconf(opt, phase="train")) 58 | # traverse all epoch for the evaluation 59 | files_list = os.listdir(opt.checkpoints_dir + "/" + opt.name) 60 | epoches = [] 61 | fid_values = {} 62 | for file in files_list: 63 | if "net_G" in file and "latest" not in file: 64 | name = file.split("_") 65 | epoches.append(name[0]) 66 | for epoch in epoches: 67 | opt.epoch = epoch 68 | model = create_model(opt) # create a model given opt.model and other options 69 | # create a website 70 | web_dir = os.path.join( 71 | opt.results_dir, opt.name, "{}_{}".format(opt.phase, opt.epoch) 72 | ) # define the website directory 73 | print("creating web directory", web_dir) 74 | webpage = html.HTML( 75 | web_dir, 76 | "Experiment = %s, Phase = %s, Epoch = %s" 77 | % (opt.name, opt.phase, opt.epoch), 78 | ) 79 | # test with eval mode. This only affects layers like batchnorm and dropout. 80 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 
81 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 82 | for i, data in enumerate(dataset): 83 | if i == 0: 84 | model.data_dependent_initialize(data) 85 | model.setup( 86 | opt 87 | ) # regular setup: load and print networks; create schedulers 88 | model.parallelize() 89 | if opt.eval: 90 | model.eval() 91 | if i >= opt.num_test: # only apply our model to opt.num_test images. 92 | break 93 | model.set_input(data) # unpack data from data loader 94 | model.test() # run inference 95 | visuals = model.get_current_visuals() # get image results 96 | img_path = model.get_image_paths() # get image paths 97 | if i % 5 == 0: # save images to an HTML file 98 | print("processing (%04d)-th image... %s" % (i, img_path)) 99 | save_images( 100 | webpage, 101 | visuals, 102 | img_path, 103 | aspect_ratio=opt.aspect_ratio, 104 | width=opt.display_winsize, 105 | ) 106 | paths = [ 107 | os.path.join(web_dir, "images", "fake_B"), 108 | os.path.join(web_dir, "images", "real_B"), 109 | ] 110 | fid_value = calculate_fid_given_paths(paths, 50, True, 2048) 111 | fid_values[int(epoch)] = fid_value 112 | webpage.save() # save the HTML 113 | print(fid_values) 114 | x = [] 115 | y = [] 116 | for key in sorted(fid_values.keys()): 117 | x.append(key) 118 | y.append(fid_values[key]) 119 | plt.figure() 120 | plt.plot(x, y) 121 | for a, b in zip(x, y): 122 | plt.text(a, b, str(round(b, 2))) 123 | plt.xlabel("Epoch") 124 | plt.ylabel("FID on test set") 125 | plt.title(opt.name) 126 | plt.savefig(os.path.join(opt.results_dir, opt.name, "fid.jpg")) 127 | -------------------------------------------------------------------------------- /F-LSeSim/train.py: -------------------------------------------------------------------------------- 1 | """General-purpose training script for image-to-image translation. 2 | 3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and 4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization). 5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model'). 6 | 7 | It first creates model, dataset, and visualizer given the option. 8 | It then does standard network training. During the training, it also visualize/save the images, print/save the loss plot, and save models. 9 | The script supports continue/resume training. Use '--continue_train' to resume your previous training. 10 | 11 | Example: 12 | Train a CycleGAN model: 13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan 14 | Train a pix2pix model: 15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA 16 | 17 | See options/base_options.py and options/train_options.py for more training options. 
18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md 19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md 20 | """ 21 | import os 22 | import time 23 | from collections import defaultdict 24 | 25 | from torchvision.utils import save_image 26 | 27 | from data import create_dataset 28 | from models import create_model 29 | from options.train_options import TrainOptions 30 | 31 | 32 | # from util.visualizer import Visualizer 33 | def reverse_image_normalize(img, mean=0.5, std=0.5): 34 | return img * std + mean 35 | 36 | 37 | if __name__ == "__main__": 38 | opt = TrainOptions().parse() # get training options 39 | dataset = create_dataset( 40 | opt 41 | ) # create a dataset given opt.dataset_mode and other options 42 | dataset_size = len(dataset) # get the number of images in the dataset. 43 | print("The number of training images = %d" % dataset_size) 44 | 45 | model = create_model( 46 | opt, norm_cfg={'type': 'in'} 47 | ) # create a model given opt.model and other options 48 | 49 | total_iters = 0 # the total number of training iterations 50 | 51 | os.makedirs(os.path.join("./checkpoints/", opt.name, "test"), exist_ok=True) 52 | for epoch in range( 53 | opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1 54 | ): # outer loop for different epochs; we save the model by , + 55 | epoch_start_time = time.time() # timer for entire epoch 56 | iter_data_time = time.time() # timer for data loading per iteration 57 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 58 | out = defaultdict(int) 59 | for i, data in enumerate(dataset): # inner loop within one epoch 60 | iter_start_time = time.time() # timer for computation per iteration 61 | print(f"[Epoch {epoch}][Iter {i}] Processing ...", end="\r") 62 | if total_iters % opt.print_freq == 0: 63 | t_data = iter_start_time - iter_data_time 64 | 65 | total_iters += opt.batch_size 66 | epoch_iter += opt.batch_size 67 | if epoch == opt.epoch_count and i == 0: 68 | model.data_dependent_initialize(data) 69 | model.setup(opt) 70 | model.parallelize() 71 | model.print_networks(True) 72 | model.set_input(data) # unpack data from dataset and apply preprocessing 73 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 74 | 75 | if ( 76 | total_iters % opt.display_freq == 0 77 | ): # display images on visdom and save images to a HTML file 78 | save_result = total_iters % opt.update_html_freq == 0 79 | model.compute_visuals() 80 | results = model.get_current_visuals() 81 | 82 | for img_name, img in results.items(): 83 | save_image( 84 | reverse_image_normalize(img), 85 | os.path.join( 86 | "./checkpoints/", 87 | opt.name, 88 | "test", 89 | f"{epoch}_{img_name}_{i}.png", 90 | ), 91 | ) 92 | 93 | for k, v in out.items(): 94 | out[k] /= opt.display_freq 95 | 96 | print(f"[Epoch {epoch}][Iter {i}] {out}", flush=True) 97 | for k, v in out.items(): 98 | out[k] = 0 99 | 100 | losses = model.get_current_losses() 101 | for k, v in losses.items(): 102 | out[k] += v 103 | 104 | if ( 105 | total_iters % opt.save_latest_freq == 0 106 | ): # cache our latest model every iterations 107 | print( 108 | "saving the latest model (epoch %d, total_iters %d)" 109 | % (epoch, total_iters) 110 | ) 111 | save_suffix = "iter_%d" % total_iters if opt.save_by_iter else "latest" 112 | model.save_networks(save_suffix) 113 | 114 | iter_data_time = time.time() 115 | if ( 116 | epoch % 
opt.save_epoch_freq == 0 117 | ): # cache our model every epochs 118 | print( 119 | "saving the model at the end of epoch %d, iters %d" 120 | % (epoch, total_iters) 121 | ) 122 | model.save_networks("latest") 123 | model.save_networks(epoch) 124 | 125 | print( 126 | "End of epoch %d / %d \t Time Taken: %d sec" 127 | % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time) 128 | ) 129 | model.update_learning_rate() # update learning rates in the beginning of every epoch. 130 | -------------------------------------------------------------------------------- /F-LSeSim/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /F-LSeSim/util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import os 4 | import tarfile 5 | from os.path import abspath, basename, isdir, join 6 | from warnings import warn 7 | from zipfile import ZipFile 8 | 9 | import requests 10 | from bs4 import BeautifulSoup 11 | 12 | 13 | class GetData(object): 14 | """A Python script for downloading CycleGAN or pix2pix datasets. 15 | 16 | Parameters: 17 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 18 | verbose (bool) -- If True, print additional information. 19 | 20 | Examples: 21 | >>> from util.get_data import GetData 22 | >>> gd = GetData(technique='cyclegan') 23 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 24 | 25 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 26 | and 'scripts/download_cyclegan_model.sh'. 27 | """ 28 | 29 | def __init__(self, technique="cyclegan", verbose=True): 30 | url_dict = { 31 | "pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/", 32 | "cyclegan": "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets", 33 | } 34 | self.url = url_dict.get(technique.lower()) 35 | self._verbose = verbose 36 | 37 | def _print(self, text): 38 | if self._verbose: 39 | print(text) 40 | 41 | @staticmethod 42 | def _get_options(r): 43 | soup = BeautifulSoup(r.text, "lxml") 44 | options = [ 45 | h.text 46 | for h in soup.find_all("a", href=True) 47 | if h.text.endswith((".zip", "tar.gz")) 48 | ] 49 | return options 50 | 51 | def _present_options(self): 52 | r = requests.get(self.url) 53 | options = self._get_options(r) 54 | print("Options:\n") 55 | for i, o in enumerate(options): 56 | print("{0}: {1}".format(i, o)) 57 | choice = input( 58 | "\nPlease enter the number of the " "dataset above you wish to download:" 59 | ) 60 | return options[int(choice)] 61 | 62 | def _download_data(self, dataset_url, save_path): 63 | if not isdir(save_path): 64 | os.makedirs(save_path) 65 | 66 | base = basename(dataset_url) 67 | temp_save_path = join(save_path, base) 68 | 69 | with open(temp_save_path, "wb") as f: 70 | r = requests.get(dataset_url) 71 | f.write(r.content) 72 | 73 | if base.endswith(".tar.gz"): 74 | obj = tarfile.open(temp_save_path) 75 | elif base.endswith(".zip"): 76 | obj = ZipFile(temp_save_path, "r") 77 | else: 78 | raise ValueError("Unknown File Type: {0}.".format(base)) 79 | 80 | self._print("Unpacking Data...") 81 | obj.extractall(save_path) 82 | obj.close() 83 | os.remove(temp_save_path) 84 | 85 | def get(self, save_path, dataset=None): 86 | """ 87 | 88 | Download a dataset. 
89 | 90 | Parameters: 91 | save_path (str) -- A directory to save the data to. 92 | dataset (str) -- (optional). A specific dataset to download. 93 | Note: this must include the file extension. 94 | If None, options will be presented for you 95 | to choose from. 96 | 97 | Returns: 98 | save_path_full (str) -- the absolute path to the downloaded data. 99 | 100 | """ 101 | if dataset is None: 102 | selected_dataset = self._present_options() 103 | else: 104 | selected_dataset = dataset 105 | 106 | save_path_full = join(save_path, selected_dataset.split(".")[0]) 107 | 108 | if isdir(save_path_full): 109 | warn("\n'{0}' already exists. Voiding Download.".format(save_path_full)) 110 | else: 111 | self._print("Downloading Data...") 112 | url = "{0}/{1}".format(self.url, selected_dataset) 113 | self._download_data(url, save_path=save_path) 114 | 115 | return abspath(save_path_full) 116 | -------------------------------------------------------------------------------- /F-LSeSim/util/html.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import dominate 4 | from dominate.tags import a, br, h3, img, meta, p, table, td, tr 5 | 6 | 7 | class HTML: 8 | """This HTML class allows us to save images and write texts into a single HTML file. 9 | 10 | It consists of functions such as (add a text header to the HTML file), 11 | (add a row of images to the HTML file), and (save the HTML to the disk). 12 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 13 | """ 14 | 15 | def __init__(self, web_dir, title, refresh=0): 16 | """Initialize the HTML classes 17 | 18 | Parameters: 19 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at /index.html; images will be saved at 0: 33 | with self.doc.head: 34 | meta(http_equiv="refresh", content=str(refresh)) 35 | 36 | def get_image_dir(self): 37 | """Return the directory that stores images""" 38 | return self.img_dir 39 | 40 | def add_header(self, text): 41 | """Insert a header to the HTML file 42 | 43 | Parameters: 44 | text (str) -- the header text 45 | """ 46 | with self.doc: 47 | h3(text) 48 | 49 | def add_images(self, ims, txts, links, width=400): 50 | """add images to the HTML file 51 | 52 | Parameters: 53 | ims (str list) -- a list of image paths 54 | txts (str list) -- a list of image names shown on the website 55 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 56 | """ 57 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 58 | self.doc.add(self.t) 59 | with self.t: 60 | with tr(): 61 | for im, txt, link in zip(ims, txts, links): 62 | with td( 63 | style="word-wrap: break-word;", halign="center", valign="top" 64 | ): 65 | with p(): 66 | with a(href=os.path.join("images", link)): 67 | img( 68 | style="width:%dpx" % width, 69 | src=os.path.join("images", im), 70 | ) 71 | br() 72 | p(txt) 73 | 74 | def save(self): 75 | """save the current content to the HMTL file""" 76 | html_file = "%s/index.html" % self.web_dir 77 | f = open(html_file, "wt") 78 | f.write(self.doc.render()) 79 | f.close() 80 | 81 | 82 | if __name__ == "__main__": # we show an example usage here. 
83 | html = HTML("web/", "test_html") 84 | html.add_header("hello world") 85 | 86 | ims, txts, links = [], [], [] 87 | for n in range(4): 88 | ims.append("image_%d.png" % n) 89 | txts.append("text_%d" % n) 90 | links.append("image_%d.png" % n) 91 | html.add_images(ims, txts, links) 92 | html.save() 93 | -------------------------------------------------------------------------------- /F-LSeSim/util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import torch 4 | 5 | 6 | class ImagePool: 7 | """This class implements an image buffer that stores previously generated images. 8 | 9 | This buffer enables us to update discriminators using a history of generated images 10 | rather than the ones produced by the latest generators. 11 | """ 12 | 13 | def __init__(self, pool_size): 14 | """Initialize the ImagePool class 15 | 16 | Parameters: 17 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created 18 | """ 19 | self.pool_size = pool_size 20 | if self.pool_size > 0: # create an empty pool 21 | self.num_imgs = 0 22 | self.images = [] 23 | 24 | def query(self, images): 25 | """Return an image from the pool. 26 | 27 | Parameters: 28 | images: the latest generated images from the generator 29 | 30 | Returns images from the buffer. 31 | 32 | By 50/100, the buffer will return input images. 33 | By 50/100, the buffer will return images previously stored in the buffer, 34 | and insert the current images to the buffer. 35 | """ 36 | if self.pool_size == 0: # if the buffer size is 0, do nothing 37 | return images 38 | return_images = [] 39 | for image in images: 40 | image = torch.unsqueeze(image.data, 0) 41 | if ( 42 | self.num_imgs < self.pool_size 43 | ): # if the buffer is not full; keep inserting current images to the buffer 44 | self.num_imgs = self.num_imgs + 1 45 | self.images.append(image) 46 | return_images.append(image) 47 | else: 48 | p = random.uniform(0, 1) 49 | if ( 50 | p > 0.5 51 | ): # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 52 | random_id = random.randint( 53 | 0, self.pool_size - 1 54 | ) # randint is inclusive 55 | tmp = self.images[random_id].clone() 56 | self.images[random_id] = image 57 | return_images.append(tmp) 58 | else: # by another 50% chance, the buffer will return the current image 59 | return_images.append(image) 60 | return_images = torch.cat(return_images, 0) # collect all the images and return 61 | return return_images 62 | -------------------------------------------------------------------------------- /F-LSeSim/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | 4 | import argparse 5 | import importlib 6 | import os 7 | from argparse import Namespace 8 | 9 | import cv2 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | import torchvision 14 | from PIL import Image 15 | 16 | 17 | def str2bool(v): 18 | if isinstance(v, bool): 19 | return v 20 | if v.lower() in ("yes", "true", "t", "y", "1"): 21 | return True 22 | elif v.lower() in ("no", "false", "f", "n", "0"): 23 | return False 24 | else: 25 | raise argparse.ArgumentTypeError("Boolean value expected.") 26 | 27 | 28 | def copyconf(default_opt, **kwargs): 29 | conf = Namespace(**vars(default_opt)) 30 | for key in kwargs: 31 | setattr(conf, key, kwargs[key]) 32 | return conf 33 
| 34 | 35 | def find_class_in_module(target_cls_name, module): 36 | target_cls_name = target_cls_name.replace("_", "").lower() 37 | clslib = importlib.import_module(module) 38 | cls = None 39 | for name, clsobj in clslib.__dict__.items(): 40 | if name.lower() == target_cls_name: 41 | cls = clsobj 42 | 43 | assert cls is not None, ( 44 | "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" 45 | % (module, target_cls_name) 46 | ) 47 | 48 | return cls 49 | 50 | 51 | def tensor2im(input_image, imtype=np.uint8): 52 | """ "Converts a Tensor array into a numpy image array. 53 | 54 | Parameters: 55 | input_image (tensor) -- the input image tensor array 56 | imtype (type) -- the desired type of the converted numpy array 57 | """ 58 | if not isinstance(input_image, np.ndarray): 59 | if isinstance(input_image, torch.Tensor): # get the data from a variable 60 | image_tensor = input_image.data 61 | else: 62 | return input_image 63 | image_numpy = ( 64 | image_tensor[0].cpu().float().numpy() 65 | ) # convert it into a numpy array 66 | if image_numpy.shape[0] == 1: # grayscale to RGB 67 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 68 | image_numpy = ( 69 | (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 70 | ) # post-processing: tranpose and scaling 71 | else: # if it is a numpy array, do nothing 72 | image_numpy = input_image 73 | return image_numpy.astype(imtype) 74 | 75 | 76 | def diagnose_network(net, name="network"): 77 | """Calculate and print the mean of average absolute(gradients) 78 | 79 | Parameters: 80 | net (torch network) -- Torch network 81 | name (str) -- the name of the network 82 | """ 83 | mean = 0.0 84 | count = 0 85 | for param in net.parameters(): 86 | if param.grad is not None: 87 | mean += torch.mean(torch.abs(param.grad.data)) 88 | count += 1 89 | if count > 0: 90 | mean = mean / count 91 | print(name) 92 | print(mean) 93 | 94 | 95 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 96 | """Save a numpy image to the disk 97 | 98 | Parameters: 99 | image_numpy (numpy array) -- input numpy array 100 | image_path (str) -- the path of the image 101 | """ 102 | 103 | image_pil = Image.fromarray(image_numpy) 104 | h, w, _ = image_numpy.shape 105 | 106 | if aspect_ratio > 1.0: 107 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 108 | if aspect_ratio < 1.0: 109 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 110 | image_pil.save(image_path) 111 | 112 | 113 | def print_numpy(x, val=True, shp=False): 114 | """Print the mean, min, max, median, std, and size of a numpy array 115 | 116 | Parameters: 117 | val (bool) -- if print the values of the numpy array 118 | shp (bool) -- if print the shape of the numpy array 119 | """ 120 | x = x.astype(np.float64) 121 | if shp: 122 | print("shape,", x.shape) 123 | if val: 124 | x = x.flatten() 125 | print( 126 | "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f" 127 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)) 128 | ) 129 | 130 | 131 | def mkdirs(paths): 132 | """create empty directories if they don't exist 133 | 134 | Parameters: 135 | paths (str list) -- a list of directory paths 136 | """ 137 | if isinstance(paths, list) and not isinstance(paths, str): 138 | for path in paths: 139 | mkdir(path) 140 | else: 141 | mkdir(paths) 142 | 143 | 144 | def mkdir(path): 145 | """create a single empty directory if it didn't exist 146 | 147 | Parameters: 148 | path (str) -- a single directory path 149 | """ 150 | if not 
os.path.exists(path): 151 | os.makedirs(path) 152 | 153 | 154 | def correct_resize_label(t, size): 155 | device = t.device 156 | t = t.detach().cpu() 157 | resized = [] 158 | for i in range(t.size(0)): 159 | one_t = t[i, :1] 160 | one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0)) 161 | one_np = one_np[:, :, 0] 162 | one_image = Image.fromarray(one_np).resize(size, Image.NEAREST) 163 | resized_t = torch.from_numpy(np.array(one_image)).long() 164 | resized.append(resized_t) 165 | return torch.stack(resized, dim=0).to(device) 166 | 167 | 168 | def correct_resize(t, size, mode=Image.BICUBIC): 169 | device = t.device 170 | t = t.detach().cpu() 171 | resized = [] 172 | for i in range(t.size(0)): 173 | one_t = t[i : i + 1] 174 | one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC) 175 | resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0 176 | resized.append(resized_t) 177 | return torch.stack(resized, dim=0).to(device) 178 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ming-Yang Ho 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
-------------------------------------------------------------------------------- /appendix/gpu_memory.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | import matplotlib 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import seaborn as sns 7 | 8 | random.seed(20222022) 9 | 10 | if __name__ == "__main__": 11 | memory_dict = { 12 | "training": { 13 | "CycleGAN": {"x": [128, 256, 512], "y": [2939, 5517, 16357]}, 14 | "CUT": {"x": [128, 256, 512], "y": [2365, 4133, 11123]}, 15 | "F-LSeSim": {"x": [256, 512, 1024], "y": [3585, 8813, 23023]}, 16 | "L-LSeSim": {"x": [256, 512, 1024], "y": [3605, 9331, 29289]}, 17 | }, 18 | "inference": { 19 | "CycleGAN": { 20 | "x": [128, 256, 512, 1024, 2048], 21 | "y": [1697, 1795, 2211, 3749, 12073], 22 | }, 23 | "CUT": { 24 | "x": [128, 256, 512, 1024, 2048], 25 | "y": [1602, 1695, 2111, 3654, 11987], 26 | }, 27 | "F-LSeSim": {"x": [512, 1024], "y": [8571, 19269]}, 28 | "L-LSeSim": {"x": [512, 1024], "y": [8571, 19269]}, 29 | "CycleGAN+KIN (ours)": { 30 | "x": [128, 256, 512, 1024, 2048, 4096, 9192], 31 | "y": [2307, 2307, 2307, 2307, 2307, 2307, 2307], 32 | }, 33 | "CUT+KIN (ours)": { 34 | "x": [128, 256, 512, 1024, 2048, 4096, 9192], 35 | "y": [2307, 2307, 2307, 2307, 2307, 2307, 2307], 36 | }, 37 | "F/L-LSeSim+KIN (ours)": { 38 | "x": [128, 256, 512, 1024, 2048, 4096, 9192], 39 | "y": [8581, 8581, 8581, 8581, 8581, 8581, 8581], 40 | }, 41 | }, 42 | } 43 | 44 | pal = sns.color_palette("husl", 8) 45 | hex_colors = list(map(matplotlib.colors.rgb2hex, pal)) 46 | pal 47 | 48 | marker_list = ["o", "X", "v", "s", "^", "p", "D"] 49 | plt.figure(figsize=(9, 4)) 50 | for idx, model_name in enumerate(memory_dict["training"]): 51 | x = memory_dict["training"][model_name]["x"] 52 | y = memory_dict["training"][model_name]["y"] 53 | p2 = np.poly1d(np.polyfit(x, y, 2)) 54 | 55 | xp = np.linspace(128, x[-1], 100) 56 | xp_extra = np.linspace(x[-1], 2000, 100) 57 | marker_on = [] 58 | for x_ in x: 59 | marker_on.append(np.searchsorted(xp, x_, side="left")) 60 | # add rnd to avoid overlapping 61 | rnd = random.randint(0, 200) 62 | 63 | plt.plot( 64 | xp, 65 | p2(xp) + rnd, 66 | color=hex_colors[idx], 67 | linewidth=3.0, 68 | linestyle="-", 69 | markersize=8, 70 | marker=marker_list[idx], 71 | markevery=marker_on, 72 | label=model_name, 73 | ) 74 | plt.plot( 75 | xp_extra, 76 | p2(xp_extra) + rnd, 77 | color=hex_colors[idx], 78 | linewidth=3.0, 79 | linestyle="--", 80 | ) 81 | 82 | plt.xticks( 83 | [128, 256, 512, 1024, 2048, 4096], 84 | [128, 256, 512, 1024, 2048, 4096], 85 | rotation=45, 86 | ) 87 | plt.yticks( 88 | [5000, 10000, 15000, 20000, 25000, 30000], 89 | [5, 10, 15, 20, 25, 30], 90 | ) 91 | plt.ylim(0, 32000) 92 | plt.title("Training") 93 | plt.ylabel("GPU Memory (GB)") 94 | plt.xlabel("Resolution (√x)") 95 | plt.legend() 96 | plt.grid() 97 | # plt.show() 98 | plt.savefig("./training_usage.png", bbox_inches="tight") 99 | 100 | plt.figure(figsize=(9, 4)) 101 | for idx, model_name in enumerate(memory_dict["inference"]): 102 | x = memory_dict["inference"][model_name]["x"] 103 | y = memory_dict["inference"][model_name]["y"] 104 | if "KIN" not in model_name: 105 | p2 = np.poly1d(np.polyfit(x, y, 2)) 106 | else: 107 | p2 = np.poly1d(np.polyfit(x, y, 0)) 108 | 109 | xp = np.linspace(128, x[-1], 100) 110 | xp_extra = np.linspace(x[-1], 10000, 100) 111 | marker_on = [] 112 | 113 | # add rnd to avoid overlapping 114 | rnd = random.randint(0, 2000) 115 | 116 | for x_ in x: 117 | 
marker_on.append(np.searchsorted(xp, x_, side="left")) 118 | plt.plot( 119 | xp, 120 | p2(xp) + rnd, 121 | color=hex_colors[idx], 122 | linewidth=3.0, 123 | linestyle="-", 124 | markersize=8, 125 | marker=marker_list[idx], 126 | markevery=marker_on, 127 | label=model_name, 128 | ) 129 | if "KIN" in model_name: 130 | plt.plot( 131 | xp_extra, 132 | p2(xp_extra) + rnd, 133 | color=hex_colors[idx], 134 | linewidth=3.0, 135 | linestyle="-", 136 | ) 137 | else: 138 | plt.plot( 139 | xp_extra, 140 | p2(xp_extra) + rnd, 141 | color=hex_colors[idx], 142 | linewidth=3.0, 143 | linestyle="--", 144 | ) 145 | 146 | plt.xticks( 147 | [256, 512, 1024, 2048, 4096, 9192], 148 | [256, 512, 1024, 2048, 4096, 9192], 149 | rotation=45, 150 | ) 151 | plt.yticks( 152 | [5000, 10000, 15000, 20000, 25000, 30000], 153 | [5, 10, 15, 20, 25, 30], 154 | ) 155 | plt.ylim(0, 32000) 156 | plt.ylabel("GPU Memory (GB)") 157 | plt.xlabel("Resolution (√x)") 158 | plt.title("Inference") 159 | plt.legend() 160 | plt.grid() 161 | # plt.show() 162 | plt.savefig("./inference_usage.jpg", bbox_inches="tight") 163 | -------------------------------------------------------------------------------- /appendix/inference_usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/appendix/inference_usage.png -------------------------------------------------------------------------------- /appendix/training_usage.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/appendix/training_usage.png -------------------------------------------------------------------------------- /combine.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | import numpy as np 6 | from PIL import Image 7 | 8 | from utils.util import read_yaml_config 9 | 10 | 11 | def main(): 12 | parser = argparse.ArgumentParser("Combined transferred images") 13 | parser.add_argument( 14 | "-c", 15 | "--config", 16 | type=str, 17 | default="./config.yaml", 18 | help="Path to the config file.", 19 | ) 20 | parser.add_argument( 21 | "--patch_size", type=int, help="Patch size", default=512 22 | ) 23 | parser.add_argument( 24 | "--resize_h", type=int, help="Resize H", default=-1 25 | ) 26 | parser.add_argument( 27 | "--resize_w", type=int, help="Resize W", default=-1 28 | ) 29 | args = parser.parse_args() 30 | 31 | config = read_yaml_config(args.config) 32 | 33 | basename = os.path.basename(config["INFERENCE_SETTING"]["TEST_X"]) 34 | filename = os.path.splitext(basename)[0] 35 | path_root = os.path.join( 36 | config["EXPERIMENT_ROOT_PATH"], 37 | config["EXPERIMENT_NAME"], 38 | "test", 39 | filename, 40 | ) 41 | 42 | if ( 43 | "OVERWRITE_OUTPUT_PATH" in config["INFERENCE_SETTING"] 44 | and config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] != "" 45 | ): 46 | path_root = config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] 47 | 48 | path_base = os.path.join( 49 | path_root, 50 | config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"], 51 | config["INFERENCE_SETTING"]["MODEL_VERSION"], 52 | ) 53 | 54 | combined_image_name = f"combined_" \ 55 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_" \ 56 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}.png" 57 | 58 | if 
config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"] == "kin": 59 | path_base = os.path.join( 60 | path_base, 61 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}" 62 | f"_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}", 63 | ) 64 | combined_image_name = f"combined_" \ 65 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}" \ 66 | f"_{config['INFERENCE_SETTING']['MODEL_VERSION']}_" \ 67 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_" \ 68 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}.png" 69 | 70 | filenames = os.listdir(path_base) 71 | try: 72 | filenames.remove("thumbnail_Y_fake.png") 73 | except Exception: 74 | pass 75 | 76 | y_anchor_max = 0 77 | x_anchor_max = 0 78 | for filename in filenames: 79 | try: 80 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4) 81 | except Exception as e: 82 | raise ValueError(f"{filename} is not valid") from e 83 | y_anchor_max = max(y_anchor_max, int(y_anchor)) 84 | x_anchor_max = max(x_anchor_max, int(x_anchor)) 85 | 86 | matrix = np.zeros( 87 | (y_anchor_max + args.patch_size, x_anchor_max + args.patch_size, 3), 88 | dtype=np.uint8, 89 | ) 90 | 91 | for filename in sorted(filenames): 92 | print(f"Combine {filename} ", end="\r") 93 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4) 94 | y_anchor = int(y_anchor) 95 | x_anchor = int(x_anchor) 96 | image = cv2.imread(os.path.join(path_base, filename)) 97 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 98 | matrix[y_anchor:y_anchor + 512, x_anchor:x_anchor + 512, :] = image 99 | 100 | if (args.resize_h != -1) and (args.resize_w != -1): 101 | matrix = cv2.resize( 102 | matrix, (args.resize_w, args.resize_h), cv2.INTER_CUBIC 103 | ) 104 | 105 | matrix_image = Image.fromarray(matrix) 106 | matrix_image.save(os.path.join(path_root, combined_image_name)) 107 | 108 | 109 | if __name__ == "__main__": 110 | main() 111 | -------------------------------------------------------------------------------- /crop.py: -------------------------------------------------------------------------------- 1 | # Author: Kaminyou (https://github.com/Kaminyou) 2 | import argparse 3 | import csv 4 | import os 5 | from pathlib import Path 6 | 7 | import cv2 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from utils.util import extend_size, is_blank_patch, reduce_size 12 | 13 | if __name__ == "__main__": 14 | """ 15 | USAGE 16 | 1. prepare data belongs to domain X 17 | python3 crop.py -i ./data/example/HE_cropped.jpg \ 18 | -o ./data/example/trainX/ --thumbnail \ 19 | --thumbnail_output ./data/example/trainX/ 20 | 21 | 2. prepare data belongs to domain Y 22 | python3 crop.py -i ./data/example/ER_cropped.jpg \ 23 | -o ./data/example/trainY/ --thumbnail \ 24 | --thumbnail_output ./data/example/trainY/ 25 | 26 | 3. prepare data belongs to domain X required to be transferred to domain Y 27 | python3 crop.py -i ./data/example/HE_cropped.jpg \ 28 | -o ./data/example/testX/ \ 29 | --stride 512 --thumbnail \ 30 | --thumbnail_output ./data/example/testX/ 31 | """ 32 | parser = argparse.ArgumentParser( 33 | description="Crop a large image into patches." 
34 | ) 35 | parser.add_argument( 36 | "-i", 37 | "--input", 38 | help="Input image path", 39 | required=True, 40 | ) 41 | parser.add_argument( 42 | "-o", 43 | "--output", 44 | help="Output image path", 45 | default="data/initial/trainX/", 46 | ) 47 | parser.add_argument( 48 | "--thumbnail", 49 | help="If crop a thumbnail or not", 50 | action="store_true", 51 | ) 52 | parser.add_argument( 53 | "--thumbnail_output", 54 | help="Output image path", 55 | default="data/initial/", 56 | ) 57 | parser.add_argument( 58 | "--patch_size", 59 | type=int, 60 | help="Patch size", 61 | default=512, 62 | ) 63 | parser.add_argument( 64 | "--stride", 65 | type=int, 66 | help="Stride to crop patch", 67 | default=256, 68 | ) 69 | parser.add_argument( 70 | "--mode", 71 | type=str, 72 | help="reduce or extend", 73 | default="reduce", 74 | ) 75 | args = parser.parse_args() 76 | 77 | os.makedirs(args.output, exist_ok=True) 78 | 79 | image = cv2.imread(args.input) 80 | image_name = Path(args.input).stem 81 | try: 82 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) 83 | except Exception: 84 | raise ValueError 85 | 86 | if args.thumbnail: 87 | thumbnail = cv2.resize( 88 | image, (args.patch_size, args.patch_size), cv2.INTER_AREA 89 | ) 90 | thumbnail_instance = Image.fromarray(thumbnail) 91 | thumbnail_instance.save( 92 | os.path.join(args.thumbnail_output, "thumbnail.png") 93 | ) 94 | 95 | h, w, c = image.shape 96 | 97 | if args.mode == "reduce": 98 | resize_fn = reduce_size 99 | resize_code = cv2.INTER_AREA 100 | elif args.mode == "extend": 101 | resize_fn = extend_size 102 | resize_code = cv2.INTER_CUBIC 103 | else: 104 | raise NotImplementedError 105 | 106 | h_resize = resize_fn(h, args.patch_size) 107 | w_resize = resize_fn(w, args.patch_size) 108 | print(f"Original size: h={h} w={w}") 109 | print(f"Resize to: h={h_resize} w={w_resize}") 110 | 111 | image = cv2.resize(image, (w_resize, h_resize), resize_code) 112 | 113 | h_anchors = np.arange(0, h_resize, args.stride) 114 | w_anchors = np.arange(0, w_resize, args.stride) 115 | output_num = len(h_anchors) * len(w_anchors) 116 | max_idx_digits = max(len(str(len(h_anchors))), len(str(len(w_anchors)))) 117 | max_anchor_digits = max(len(str(h_anchors[-1])), len(str(w_anchors[-1]))) 118 | 119 | curr_idx = 1 120 | blank_patches_list = [] 121 | for y_idx, h_anchor in enumerate(h_anchors): 122 | for x_idx, w_anchor in enumerate(w_anchors): 123 | print(f"[{curr_idx} / {output_num}] Processing ...", end="\r") 124 | image_crop = image[ 125 | h_anchor:h_anchor + args.patch_size, 126 | w_anchor:w_anchor + args.patch_size, 127 | :, 128 | ] 129 | 130 | # if stride < patch_size, some images will be cropped at the margin 131 | # e.g., stride = 256, patch_size = 512, image_size = 600 132 | # => [0, 512], [256, 600] 133 | # thus the output size should be double checked 134 | if image_crop.shape[0] != args.patch_size: 135 | continue 136 | if image_crop.shape[1] != args.patch_size: 137 | continue 138 | 139 | image_crop_instance = Image.fromarray(image_crop) 140 | 141 | # filename: {y-idx}_{x-idx}_{h-anchor}_{w-anchor}.png 142 | filename = f"{str(y_idx).zfill(max_idx_digits)}_" \ 143 | f"{str(x_idx).zfill(max_idx_digits)}_" \ 144 | f"{str(h_anchor).zfill(max_anchor_digits)}_" \ 145 | f"{str(w_anchor).zfill(max_anchor_digits)}_" \ 146 | f"{image_name}.png" 147 | 148 | image_crop_instance.save(os.path.join(args.output, filename)) 149 | blank_patches_list.append((filename, is_blank_patch(image_crop))) 150 | curr_idx += 1 151 | 152 | with open( 153 | os.path.join(args.output, 
"blank_patches_list.csv"), 154 | "w", 155 | encoding="UTF8", 156 | newline="", 157 | ) as f: 158 | writer = csv.writer(f) 159 | writer.writerows(blank_patches_list) 160 | -------------------------------------------------------------------------------- /crop_pipeline.py: -------------------------------------------------------------------------------- 1 | # Author: Kaminyou (https://github.com/Kaminyou) 2 | import argparse 3 | import os 4 | 5 | from utils.util import read_yaml_config 6 | 7 | if __name__ == "__main__": 8 | parser = argparse.ArgumentParser("Model inference") 9 | parser.add_argument( 10 | "-c", 11 | "--config", 12 | type=str, 13 | default="./data/example/config.yaml", 14 | help="Path to the config file.", 15 | ) 16 | args = parser.parse_args() 17 | 18 | config = read_yaml_config(args.config)["CROPPING_SETTING"] 19 | # Prepare the training set 20 | for train_x, train_y in zip(config["TRAIN_X"], config["TRAIN_Y"]): 21 | os.system( 22 | f"python3 crop.py -i {train_x} -o {config['TRAIN_DIR_X']} " 23 | f"--patch_size {config['PATCH_SIZE']} --stride {config['STRIDE']}" 24 | ) 25 | os.system( 26 | f"python3 crop.py -i {train_y} -o {config['TRAIN_DIR_Y']} " 27 | f"--patch_size {config['PATCH_SIZE']} --stride {config['STRIDE']}" 28 | ) 29 | 30 | # Prepare the testing set 31 | for test_x in config["TEST_X"]: 32 | basename = os.path.basename(test_x) 33 | filename = os.path.splitext(basename)[0] 34 | output_path = os.path.join(config['TEST_DIR_X'], filename) 35 | os.system( 36 | f"python3 crop.py -i {test_x} -o {output_path} " 37 | f"--thumbnail --thumbnail_output {output_path} " 38 | f"--patch_size {config['PATCH_SIZE']} " 39 | f"--stride {config['PATCH_SIZE']}" 40 | ) 41 | -------------------------------------------------------------------------------- /data/example/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/data/example/.gitkeep -------------------------------------------------------------------------------- /data/example/config.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT_ROOT_PATH: "./experiments/" 2 | EXPERIMENT_NAME: "example_CUT" 3 | MODEL_NAME: "CUT" # currently only CUT, cycleGAN, and LSeSim are support 4 | DEVICE: "cuda" 5 | 6 | CROPPING_SETTING: 7 | TRAIN_X: 8 | - "./data/example/HE_cropped.jpg" 9 | TRAIN_DIR_X: "./data/example/trainX/" 10 | TRAIN_Y: 11 | - "./data/example/ER_cropped.jpg" 12 | TRAIN_DIR_Y: "./data/example/trainY/" 13 | TEST_X: 14 | - "./data/example/HE_cropped.jpg" 15 | TEST_DIR_X: "./data/example/testX/" 16 | PATCH_SIZE: 512 17 | STRIDE: 256 18 | 19 | TRAINING_SETTING: 20 | PAIRED_TRAINING: False 21 | RANDOM_CROP_AUG: False 22 | TRAIN_ROOT: "./data/example/" 23 | TRAIN_DIR_X: "./data/example/trainX/" 24 | TRAIN_DIR_Y: "./data/example/trainY/" 25 | NUM_EPOCHS: 100 26 | LAMBDA_Y: 1 27 | LEARNING_RATE: 0.0002 28 | BATCH_SIZE: 1 29 | NUM_WORKERS: 8 30 | SAVE_MODEL: true 31 | SAVE_MODEL_EPOCH_STEP: 10 32 | VISUALIZATION_STEP: 250 33 | LOAD_MODEL: false 34 | LOAD_EPOCH: 0 35 | Augment: True #LSeSim 36 | 37 | INFERENCE_SETTING: 38 | TEST_X: "./data/example/HE_cropped.jpg" 39 | TEST_DIR_X: "./data/example/testX/HE_cropped/" 40 | MODEL_VERSION: "10" 41 | SAVE_ORIGINAL_IMAGE: False 42 | NORMALIZATION: 43 | TYPE: 'kin' 44 | PADDING: 1 45 | KERNEL_TYPE: 'constant' 46 | KERNEL_SIZE: 3 47 | THUMBNAIL: "./data/example/testX/HE_cropped/thumbnail.png" 
#"None" # set to "None" if it's not required 48 | -------------------------------------------------------------------------------- /evaluation.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | exps = ["lung_lesion"] 4 | models = ["cycleGAN", "CUT", "LSeSim"] 5 | model_epochs = [60, "latest", 100] 6 | norm_types = ["in", "tin", "kin"] 7 | kernel_types = [ 8 | None, 9 | None, 10 | [ 11 | "constant_1", 12 | "gaussian_1", 13 | "constant_3", 14 | "gaussian_3", 15 | "constant_5", 16 | "gaussian_5", 17 | ], 18 | ] 19 | 20 | for exp in exps: 21 | for i, model in enumerate(models): 22 | model_epoch = model_epochs[i] 23 | for j, norm_type in enumerate(norm_types): 24 | kernel_type = kernel_types[j] 25 | if kernel_type is None: 26 | command = f"python3 metric.py --exp_name {model}_{model_epoch}_{norm_type} --path-A ./experiments/{exp}/{exp}_{model}/test/{norm_type}/{model_epoch}/ --path-B ./data/lung_lesion/testX --blank_patches_list ./data/lung_lesion/testX/blank_patches_list.csv >> FID.out" 27 | os.system(command) 28 | else: 29 | for kernel in kernel_type: 30 | command = f"python3 metric.py --exp_name {model}_{model_epoch}_{norm_type}_{kernel} --path-A ./experiments/{exp}/{exp}_{model}/test/{norm_type}/{model_epoch}/{kernel}/ --path-B ./data/lung_lesion/testX --blank_patches_list ./data/lung_lesion/testX/blank_patches_list.csv >> FID.out" 31 | os.system(command) 32 | -------------------------------------------------------------------------------- /experiments/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/experiments/.gitkeep -------------------------------------------------------------------------------- /imgs/Figure_KIN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_KIN.jpg -------------------------------------------------------------------------------- /imgs/Figure_patch_with_patch_lineplot_block1_mean.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_patch_with_patch_lineplot_block1_mean.jpg -------------------------------------------------------------------------------- /imgs/Figure_patch_with_patch_lineplot_blocks_mean.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_patch_with_patch_lineplot_blocks_mean.jpg -------------------------------------------------------------------------------- /imgs/URUST_anime.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/URUST_anime.gif -------------------------------------------------------------------------------- /metric_images_with_ref.py: -------------------------------------------------------------------------------- 1 | import os 2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser 3 | 4 | import torch 5 | 6 | from metrics.calculate_fid import 
calculate_fid_given_two_paths 7 | from metrics.inception import InceptionV3 8 | 9 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) 10 | parser.add_argument( 11 | "--exp_name", 12 | type=str, 13 | help="Experiment name" 14 | ) 15 | parser.add_argument( 16 | "--batch-size", 17 | type=int, 18 | default=50, 19 | help="Batch size to use" 20 | ) 21 | parser.add_argument( 22 | "--num-workers", 23 | type=int, 24 | help=( 25 | "Number of processes to use for data loading. " 26 | "Defaults to `min(8, num_cpus)`" 27 | ), 28 | ) 29 | parser.add_argument( 30 | "--device", 31 | type=str, 32 | default=None, 33 | help="Device to use. Like cuda, cuda:0 or cpu", 34 | ) 35 | parser.add_argument( 36 | "--dims", 37 | type=int, 38 | default=2048, 39 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM), 40 | help=( 41 | "Dimensionality of Inception features to use. " 42 | "By default, uses pool3 features" 43 | ), 44 | ) 45 | parser.add_argument( 46 | "--path-A", 47 | type=str, 48 | help=( 49 | "Paths to the original images or " 50 | "to .npz statistic files. " 51 | "Support multiple paths by using:" 52 | 'path_a1,path_a2,path_a3 ... seperated by ",". ' 53 | ), 54 | ) 55 | parser.add_argument( 56 | "--path-B", 57 | type=str, 58 | help=( 59 | "Paths to the generated images or " 60 | "to .npz statistic files. " 61 | "Support multiple paths by using:" 62 | 'path_a1,path_a2,path_a3 ... seperated by ",". ' 63 | ), 64 | ) 65 | parser.add_argument( 66 | "--blank_patches_list_A", 67 | type=str, 68 | default=None, 69 | required=False, 70 | help="Paths to the lsit of blank patches", 71 | ) 72 | parser.add_argument( 73 | "--blank_patches_list_B", 74 | type=str, 75 | default=None, 76 | required=False, 77 | help="Paths to the lsit of blank patches", 78 | ) 79 | 80 | 81 | def main(): 82 | args = parser.parse_args() 83 | 84 | if args.device is None: 85 | device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu") 86 | else: 87 | device = torch.device(args.device) 88 | 89 | if args.num_workers is None: 90 | num_avail_cpus = len(os.sched_getaffinity(0)) 91 | num_workers = min(num_avail_cpus, 8) 92 | else: 93 | num_workers = args.num_workers 94 | 95 | path_As = args.path_A.split(",") 96 | path_Bs = args.path_B.split(",") 97 | 98 | fid_value = calculate_fid_given_two_paths( 99 | path_As, 100 | path_Bs, 101 | args.batch_size, 102 | device, 103 | args.dims, 104 | num_workers, 105 | args.blank_patches_list_A, 106 | args.blank_patches_list_B, 107 | ) 108 | print(f"Exp::{args.exp_name} || FID: {fid_value:.4f}") 109 | 110 | 111 | if __name__ == "__main__": 112 | main() 113 | -------------------------------------------------------------------------------- /metric_whole_image_no_ref.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from pathlib import Path 4 | 5 | import cv2 6 | from PIL import Image 7 | 8 | from metrics.niqe import niqe 9 | from metrics.piqe import piqe 10 | from metrics.sobel import calculate_sobel_gradient_pipeline 11 | 12 | parser = argparse.ArgumentParser() 13 | parser.add_argument( 14 | "--exp_name", 15 | type=str, 16 | default="", 17 | help="Experiment name", 18 | ) 19 | parser.add_argument( 20 | "--path", 21 | type=str, 22 | required=True, 23 | help="Path to the image", 24 | ) 25 | parser.add_argument( 26 | "--save_grad", 27 | action="store_true", 28 | help="Whether to save the gardient image", 29 | ) 30 | 31 | if __name__ == "__main__": 32 | args = parser.parse_args() 33 | sobel_gradient, sobel_gradient_avg = 
calculate_sobel_gradient_pipeline( 34 | args.path 35 | ) 36 | 37 | im_bgr = cv2.imread(args.path) 38 | piqe_score, _, _, _ = piqe(im_bgr) 39 | niqe_score = niqe(im_bgr) 40 | 41 | img_path = Path(args.path) 42 | 43 | if args.save_grad: 44 | parent_path = img_path.parents[0] 45 | save_name = img_path.stem + "_grad.png" 46 | im = Image.fromarray(sobel_gradient) 47 | im.save(os.path.join(parent_path, save_name)) 48 | 49 | print( 50 | f"Exp::{args.exp_name}::{img_path.stem} || " 51 | f"Grad = {sobel_gradient_avg:.4f} " 52 | f"PIQE = {piqe_score:.4f} NIQE = {niqe_score:.4f}" 53 | ) 54 | -------------------------------------------------------------------------------- /metric_whole_image_with_ref.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | 4 | from metrics.histogram import compare_images_histogram_pipeline 5 | 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument( 8 | "--exp_name", 9 | type=str, 10 | default="", 11 | help="Experiment name", 12 | ) 13 | parser.add_argument( 14 | "--image_A_path", 15 | type=str, 16 | required=True, 17 | help="Path to the reference image", 18 | ) 19 | parser.add_argument( 20 | "--image_B_path", 21 | type=str, 22 | required=True, 23 | help="Path to the compared image", 24 | ) 25 | 26 | if __name__ == "__main__": 27 | args = parser.parse_args() 28 | 29 | similarity = compare_images_histogram_pipeline( 30 | args.image_A_path, 31 | args.image_B_path, 32 | ) 33 | image_A_name = Path(args.image_A_path).stem 34 | image_B_name = Path(args.image_B_path).stem 35 | print( 36 | f"Exp::{args.exp_name}::{image_A_name} {image_B_name} || " 37 | f"Histogram corr = {similarity:.4f}" 38 | ) 39 | -------------------------------------------------------------------------------- /metrics/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/metrics/__init__.py -------------------------------------------------------------------------------- /metrics/histogram.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def read_img_toRGB(img_path): 5 | img = cv2.imread(img_path) 6 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 7 | return img 8 | 9 | 10 | def calculate_histogram( 11 | image, channels=[0], hist_size=[10], hist_range=[0, 256] 12 | ): 13 | # convert to different color space if needed 14 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV) 15 | 16 | image_hist = cv2.calcHist([image], channels, None, hist_size, hist_range) 17 | image_hist = cv2.normalize(image_hist, image_hist).flatten() 18 | return image_hist 19 | 20 | 21 | def compare_images_histogram(img_base, img_compare, method="correlation"): 22 | hist_1 = calculate_histogram(img_base) 23 | hist_2 = calculate_histogram(img_compare) 24 | 25 | if method == "intersection": 26 | comparison = cv2.compareHist(hist_1, hist_2, cv2.HISTCMP_INTERSECT) 27 | else: 28 | comparison = cv2.compareHist(hist_1, hist_2, cv2.HISTCMP_CORREL) 29 | return comparison 30 | 31 | 32 | def compare_images_histogram_pipeline(img_base_path, img_compare_path): 33 | img_base = read_img_toRGB(img_base_path) 34 | img_compare = read_img_toRGB(img_compare_path) 35 | similarity = compare_images_histogram(img_base, img_compare) 36 | return similarity 37 | -------------------------------------------------------------------------------- 
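A minimal sketch of calling the histogram helper above directly, instead of going through metric_whole_image_with_ref.py; the file names here are placeholders:

from metrics.histogram import compare_images_histogram_pipeline

# Correlation of 10-bin hue histograms (HSV channel 0), as implemented above.
# Values closer to 1.0 indicate more similar global color distributions.
similarity = compare_images_histogram_pipeline("reference_whole.png", "translated_whole.png")
print(f"Histogram corr = {similarity:.4f}")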
/metrics/niqe_image_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/metrics/niqe_image_params.mat -------------------------------------------------------------------------------- /metrics/sobel.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | def read_img_toYCRCB(img_path): 5 | img = cv2.imread(img_path) 6 | img = cv2.cvtColor(img, cv2.COLOR_BGR2YCR_CB) 7 | return img 8 | 9 | 10 | def calculate_YCRCB_gradient(YCRCB_img): 11 | YCRCB_img_Y_channel = YCRCB_img[..., 0] 12 | sobelx = cv2.Sobel(YCRCB_img_Y_channel, -1, 1, 0, ksize=3) 13 | sobely = cv2.Sobel(YCRCB_img_Y_channel, -1, 0, 1, ksize=3) 14 | grad = cv2.addWeighted(sobelx, 0.5, sobely, 0.5, 0) 15 | return grad 16 | 17 | 18 | def calculate_grad_avg(grad): 19 | h, w = grad.shape 20 | return grad.sum() / h / w 21 | 22 | 23 | def calculate_sobel_gradient_pipeline(img_path): 24 | img_ycrcb = read_img_toYCRCB(img_path) 25 | grad = calculate_YCRCB_gradient(img_ycrcb) 26 | grad_avg = calculate_grad_avg(grad) 27 | return grad, grad_avg 28 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/models/__init__.py -------------------------------------------------------------------------------- /models/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from models.downsample import Downsample 6 | from models.normalization import make_norm_layer 7 | 8 | 9 | class DiscriminatorBasicBlock(nn.Module): 10 | def __init__( 11 | self, 12 | in_features, 13 | out_features, 14 | do_downsample=True, 15 | do_instancenorm=True, 16 | norm_cfg=None, 17 | ): 18 | super().__init__() 19 | 20 | self.norm_cfg = norm_cfg or {'type': 'in'} 21 | self.norm_cfg = {k.lower(): v for k, v in self.norm_cfg.items()} 22 | self.do_downsample = do_downsample 23 | self.do_instancenorm = do_instancenorm 24 | 25 | self.conv = nn.Conv2d( 26 | in_features, out_features, kernel_size=4, stride=1, padding=1 27 | ) 28 | self.leakyrelu = nn.LeakyReLU(0.2, True) 29 | 30 | if do_instancenorm: 31 | self.instancenorm = make_norm_layer(self.norm_cfg, num_features=out_features) 32 | 33 | if do_downsample: 34 | self.downsample = Downsample(out_features) 35 | 36 | def forward(self, x): 37 | x = self.conv(x) 38 | if self.do_instancenorm: 39 | x = self.instancenorm(x) 40 | x = self.leakyrelu(x) 41 | if self.do_downsample: 42 | x = self.downsample(x) 43 | return x 44 | 45 | 46 | class Discriminator(nn.Module): 47 | def __init__(self, in_channels=3, features=64, avg_pooling=False): 48 | super().__init__() 49 | self.block1 = DiscriminatorBasicBlock( 50 | in_channels, 51 | features, 52 | do_downsample=True, 53 | do_instancenorm=False, 54 | ) 55 | self.block2 = DiscriminatorBasicBlock( 56 | features, 57 | features * 2, 58 | do_downsample=True, 59 | do_instancenorm=True, 60 | ) 61 | self.block3 = DiscriminatorBasicBlock( 62 | features * 2, 63 | features * 4, 64 | do_downsample=True, 65 | do_instancenorm=True, 66 | ) 67 | self.block4 = DiscriminatorBasicBlock( 68 | features * 4, 69 | features * 8, 70 | 
do_downsample=False, 71 | do_instancenorm=True, 72 | ) 73 | self.conv = nn.Conv2d( 74 | features * 8, 75 | 1, 76 | kernel_size=4, 77 | stride=1, 78 | padding=1, 79 | ) 80 | self.avg_pooling = avg_pooling 81 | 82 | def forward(self, x): 83 | x = self.block1(x) 84 | x = self.block2(x) 85 | x = self.block3(x) 86 | x = self.block4(x) 87 | x = self.conv(x) 88 | if self.avg_pooling: 89 | x = F.avg_pool2d(x, x.size()[2:]) 90 | x = torch.flatten(x, 1) 91 | return x 92 | 93 | def set_requires_grad(self, requires_grad=False): 94 | for param in self.parameters(): 95 | param.requires_grad = requires_grad 96 | 97 | 98 | if __name__ == "__main__": 99 | x = torch.randn((5, 3, 256, 256)) 100 | print(x.shape) 101 | model = Discriminator(in_channels=3, avg_pooling=True) 102 | preds = model(x) 103 | print(preds.shape) 104 | -------------------------------------------------------------------------------- /models/downsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Downsample(nn.Module): 5 | def __init__(self, features): 6 | super().__init__() 7 | self.reflectionpad = nn.ReflectionPad2d(1) 8 | self.conv = nn.Conv2d(features, features, kernel_size=3, stride=2) 9 | 10 | def forward(self, x): 11 | x = self.reflectionpad(x) 12 | x = self.conv(x) 13 | return x 14 | -------------------------------------------------------------------------------- /models/model.py: -------------------------------------------------------------------------------- 1 | from models.cut import ContrastiveModel 2 | from models.cyclegan import CycleGanModel 3 | 4 | 5 | # XXX: Check if isTrain is used 6 | def get_model(config, model_name="CUT", norm_cfg=None, isTrain=True): 7 | if model_name == "CUT": 8 | model = ContrastiveModel(config, norm_cfg=norm_cfg) 9 | elif model_name == "cycleGAN": 10 | model = CycleGanModel(config, norm_cfg=norm_cfg) 11 | elif model_name == "LSeSim": 12 | print("Please use the scripts prepared in the F-LSeSim folder") 13 | raise NotImplementedError 14 | else: 15 | raise NotImplementedError 16 | return model 17 | -------------------------------------------------------------------------------- /models/normalization.py: -------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | from typing import Any, Dict 3 | 4 | import torch.nn as nn 5 | 6 | from models.kin import KernelizedInstanceNorm 7 | from models.tin import ThumbInstanceNorm 8 | 9 | 10 | # TODO: To be deprecated 11 | def get_normalization_layer(num_features, normalization="kin"): 12 | if normalization == "kin": 13 | return KernelizedInstanceNorm(num_features=num_features) 14 | elif normalization == "tin": 15 | return ThumbInstanceNorm(num_features=num_features) 16 | elif normalization == "in": 17 | return nn.InstanceNorm2d(num_features) 18 | else: 19 | raise NotImplementedError 20 | 21 | 22 | def make_norm_layer(norm_cfg: Dict[str, Any], **kwargs: Any): 23 | """ 24 | Create normalization layer based on given config and arguments. 25 | 26 | Args: 27 | norm_cfg (Dict[str, Any]): A dict of keyword arguments of normalization layer. 28 | It must have a key 'type' to specify which normalization layers will be used. 29 | It accepts upper case argument. 30 | **kwargs (Any): The keyword arguments are used to overwrite `norm_cfg`. 31 | 32 | Returns: 33 | nn.Module: A layer object. 
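    Example (a minimal usage sketch; the non-'type' keys are assumed to match
    the chosen layer's constructor, mirroring the example config):

        make_norm_layer({'type': 'in'}, num_features=64)
        make_norm_layer(
            {'TYPE': 'kin', 'KERNEL_TYPE': 'gaussian',
             'KERNEL_SIZE': 3, 'PADDING': 1},
            num_features=64,
        )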
34 | """ 35 | norm_cfg = deepcopy(norm_cfg) 36 | norm_cfg = {k.lower(): v for k, v in norm_cfg.items()} 37 | 38 | norm_cfg.update(kwargs) 39 | 40 | if 'type' not in norm_cfg: 41 | raise ValueError('"type" wasn\'t specified.') 42 | 43 | norm_type = norm_cfg['type'] 44 | del norm_cfg['type'] 45 | 46 | if norm_type == 'in': 47 | return nn.InstanceNorm2d(**norm_cfg) 48 | elif norm_type == 'tin': 49 | return ThumbInstanceNorm(**norm_cfg) 50 | elif norm_type == 'kin': 51 | return KernelizedInstanceNorm(**norm_cfg) 52 | else: 53 | raise ValueError(f'Unknown norm type: {norm_type}.') 54 | -------------------------------------------------------------------------------- /models/projector.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from models.generator import Generator 5 | 6 | 7 | class MLP(nn.Module): 8 | def __init__(self, input_nc, output_nc): 9 | super().__init__() 10 | self.mlp = nn.Sequential( 11 | *[ 12 | nn.Linear(input_nc, output_nc), 13 | nn.ReLU(), 14 | nn.Linear(output_nc, output_nc), 15 | ] 16 | ) 17 | 18 | def forward(self, x): 19 | return self.mlp(x) 20 | 21 | 22 | class Head(nn.Module): 23 | def __init__(self, in_channels=3, features=64, residuals=9): 24 | super().__init__() 25 | self.mlp_0 = MLP(3, 256) 26 | self.mlp_1 = MLP(128, 256) 27 | self.mlp_2 = MLP(256, 256) 28 | self.mlp_3 = MLP(256, 256) 29 | self.mlp_4 = MLP(256, 256) 30 | 31 | def forward(self, features): 32 | return_features = [] 33 | for feature_id, feature in enumerate(features): 34 | mlp = getattr(self, f"mlp_{feature_id}") 35 | feature = mlp(feature) 36 | norm = feature.pow(2).sum(1, keepdim=True).pow(1.0 / 2) 37 | feature = feature.div(norm + 1e-7) 38 | return_features.append(feature) 39 | return return_features 40 | 41 | 42 | if __name__ == "__main__": 43 | x = torch.randn((5, 3, 256, 256)) 44 | print(x.shape) 45 | G = Generator() 46 | H = Head() 47 | feat_k_pool, sample_ids = G(x, encode_only=True, patch_ids=None) 48 | feat_q_pool, _ = G(x, encode_only=True, patch_ids=sample_ids) 49 | print(len(feat_k_pool)) 50 | return_features = H(feat_q_pool) 51 | print(len(return_features)) 52 | -------------------------------------------------------------------------------- /models/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/models/tests/__init__.py -------------------------------------------------------------------------------- /models/tests/test_cut.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from models.model import get_model 7 | from models.kin import KernelizedInstanceNorm 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision.utils import make_grid 11 | from utils.dataset import XInferenceDataset 12 | from utils.util import (read_yaml_config, reverse_image_normalize, 13 | test_transforms) 14 | 15 | 16 | @pytest.fixture() 17 | def config(): 18 | config_path = os.path.join( 19 | "./test_data", 20 | "configs", 21 | "config_lung_lesion_for_test_cut.yaml" 22 | ) 23 | config = read_yaml_config(config_path) 24 | 25 | return config 26 | 27 | 28 | @pytest.fixture() 29 | def in_model(config): 30 | model = get_model( 31 | config=config, 32 | model_name=config["MODEL_NAME"], 33 | norm_cfg={'type': 'in'}, 34 | 
isTrain=False, 35 | ) 36 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"]) 37 | 38 | return model 39 | 40 | 41 | @pytest.fixture() 42 | def kin_model(config): 43 | model = get_model( 44 | config=config, 45 | model_name=config["MODEL_NAME"], 46 | norm_cfg=config["INFERENCE_SETTING"]["NORMALIZATION"], 47 | isTrain=False, 48 | ) 49 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"]) 50 | model.eval() 51 | 52 | return model 53 | 54 | 55 | @pytest.fixture() 56 | def dataset(config): 57 | dataset = XInferenceDataset( 58 | root_X=config["INFERENCE_SETTING"]["TEST_DIR_X"], 59 | transform=test_transforms, 60 | return_anchor=True, 61 | ) 62 | return dataset 63 | 64 | 65 | @pytest.fixture() 66 | def dataloader(dataset): 67 | loader = DataLoader( 68 | dataset, batch_size=1, shuffle=False, pin_memory=True 69 | ) 70 | return loader 71 | 72 | 73 | @pytest.fixture() 74 | def in_expected_outputs(config): 75 | expected_output_dir = os.path.join( 76 | config["INFERENCE_SETTING"]["TEST_DIR_Y"], 77 | config["EXPERIMENT_NAME"], 78 | "in", 79 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}", 80 | 81 | ) 82 | expected_output_files = sorted( 83 | os.listdir(expected_output_dir) 84 | ) 85 | 86 | expected_outputs = [] 87 | for expected_output_file in expected_output_files: 88 | expected_outputs.append(np.array( 89 | Image.open(os.path.join( 90 | expected_output_dir, 91 | expected_output_file 92 | )).convert("RGB") 93 | )) 94 | 95 | return expected_outputs 96 | 97 | 98 | @pytest.fixture() 99 | def kin_expected_outputs(config): 100 | expected_output_dir = os.path.join( 101 | config["INFERENCE_SETTING"]["TEST_DIR_Y"], 102 | config["EXPERIMENT_NAME"], 103 | "kin", 104 | config["INFERENCE_SETTING"]["MODEL_VERSION"], 105 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_" 106 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}" 107 | ) 108 | expected_output_files = sorted( 109 | os.listdir(expected_output_dir) 110 | ) 111 | 112 | expected_outputs = [] 113 | for expected_output_file in expected_output_files: 114 | expected_outputs.append(np.array( 115 | Image.open(os.path.join( 116 | expected_output_dir, 117 | expected_output_file 118 | )).convert("RGB") 119 | )) 120 | 121 | return expected_outputs 122 | 123 | 124 | def test_inference(in_model, dataloader, in_expected_outputs): 125 | """ 126 | Integration testing for IN inferece 127 | """ 128 | 129 | for idx, data in enumerate(dataloader): 130 | X, _ = data["X_img"], data["X_path"] 131 | Y_fake = in_model.inference(X) 132 | test_output_tensor = reverse_image_normalize(Y_fake) 133 | 134 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 135 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_( 136 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 137 | expected_output = in_expected_outputs[idx] 138 | 139 | assert test_output == pytest.approx(expected_output) 140 | 141 | 142 | def test_inference_with_anchor( 143 | config, 144 | kin_model, 145 | dataset, 146 | dataloader, 147 | kin_expected_outputs 148 | ): 149 | """ 150 | Integration testing for KIN inferece 151 | """ 152 | 153 | y_anchor_num, x_anchor_num = dataset.get_boundary() 154 | 155 | kin_model.init_kernelized_instance_norm_for_whole_model( 156 | y_anchor_num=y_anchor_num + 1, 157 | x_anchor_num=x_anchor_num + 1, 158 | ) 159 | 160 | for idx, data in enumerate(dataloader): 161 | X, _, y_anchor, x_anchor = ( 162 | data["X_img"], 163 | data["X_path"], 164 | data["y_idx"], 165 | data["x_idx"], 166 | ) 167 | _ = 
kin_model.inference_with_anchor( 168 | X, 169 | y_anchor=y_anchor, 170 | x_anchor=x_anchor, 171 | mode=KernelizedInstanceNorm.Mode.PHASE_CACHING, 172 | ) 173 | 174 | for idx, data in enumerate(dataloader): 175 | X, _, y_anchor, x_anchor = ( 176 | data["X_img"], 177 | data["X_path"], 178 | data["y_idx"], 179 | data["x_idx"], 180 | ) 181 | Y_fake = kin_model.inference_with_anchor( 182 | X, 183 | y_anchor=y_anchor, 184 | x_anchor=x_anchor, 185 | mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE, 186 | ) 187 | test_output_tensor = reverse_image_normalize(Y_fake) 188 | 189 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 190 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_( 191 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 192 | expected_output = kin_expected_outputs[idx] 193 | 194 | assert test_output == pytest.approx(expected_output) 195 | -------------------------------------------------------------------------------- /models/tests/test_cyclegan.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pytest 5 | import torch 6 | from models.model import get_model 7 | from models.kin import KernelizedInstanceNorm 8 | from PIL import Image 9 | from torch.utils.data import DataLoader 10 | from torchvision.utils import make_grid 11 | from utils.dataset import XInferenceDataset 12 | from utils.util import (read_yaml_config, reverse_image_normalize, 13 | test_transforms) 14 | 15 | 16 | @pytest.fixture() 17 | def config(): 18 | config_path = os.path.join( 19 | "./test_data", 20 | "configs", 21 | "config_lung_lesion_for_test_cyclegan.yaml" 22 | ) 23 | config = read_yaml_config(config_path) 24 | 25 | return config 26 | 27 | 28 | @pytest.fixture() 29 | def in_model(config): 30 | model = get_model( 31 | config=config, 32 | model_name=config["MODEL_NAME"], 33 | norm_cfg={'type': 'in'}, 34 | isTrain=False, 35 | ) 36 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"]) 37 | 38 | return model 39 | 40 | 41 | @pytest.fixture() 42 | def kin_model(config): 43 | model = get_model( 44 | config=config, 45 | model_name=config["MODEL_NAME"], 46 | norm_cfg=config['INFERENCE_SETTING']['NORMALIZATION'], 47 | isTrain=False, 48 | ) 49 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"]) 50 | model.eval() 51 | 52 | return model 53 | 54 | 55 | @pytest.fixture() 56 | def dataset(config): 57 | dataset = XInferenceDataset( 58 | root_X=config["INFERENCE_SETTING"]["TEST_DIR_X"], 59 | transform=test_transforms, 60 | return_anchor=True, 61 | ) 62 | return dataset 63 | 64 | 65 | @pytest.fixture() 66 | def dataloader(dataset): 67 | loader = DataLoader( 68 | dataset, batch_size=1, shuffle=False, pin_memory=True 69 | ) 70 | return loader 71 | 72 | 73 | @pytest.fixture() 74 | def in_expected_outputs(config): 75 | expected_output_dir = os.path.join( 76 | config["INFERENCE_SETTING"]["TEST_DIR_Y"], 77 | config["EXPERIMENT_NAME"], 78 | "in", 79 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}", 80 | 81 | ) 82 | expected_output_files = sorted( 83 | os.listdir(expected_output_dir) 84 | ) 85 | 86 | expected_outputs = [] 87 | for expected_output_file in expected_output_files: 88 | expected_outputs.append(np.array( 89 | Image.open(os.path.join( 90 | expected_output_dir, 91 | expected_output_file 92 | )).convert("RGB") 93 | )) 94 | 95 | return expected_outputs 96 | 97 | 98 | @pytest.fixture() 99 | def kin_expected_outputs(config): 100 | expected_output_dir = os.path.join( 101 | 
config["INFERENCE_SETTING"]["TEST_DIR_Y"], 102 | config["EXPERIMENT_NAME"], 103 | "kin", 104 | config["INFERENCE_SETTING"]["MODEL_VERSION"], 105 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_" 106 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}", 107 | 108 | ) 109 | expected_output_files = sorted( 110 | os.listdir(expected_output_dir) 111 | ) 112 | 113 | expected_outputs = [] 114 | for expected_output_file in expected_output_files: 115 | expected_outputs.append(np.array( 116 | Image.open(os.path.join( 117 | expected_output_dir, 118 | expected_output_file 119 | )).convert("RGB") 120 | )) 121 | 122 | return expected_outputs 123 | 124 | 125 | def test_inference(in_model, dataloader, in_expected_outputs): 126 | """ 127 | Integration testing for IN inferece 128 | """ 129 | 130 | for idx, data in enumerate(dataloader): 131 | X, _ = data["X_img"], data["X_path"] 132 | Y_fake = in_model.inference(X) 133 | test_output_tensor = reverse_image_normalize(Y_fake) 134 | 135 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 136 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_( 137 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 138 | expected_output = in_expected_outputs[idx] 139 | 140 | assert test_output == pytest.approx(expected_output) 141 | 142 | 143 | def test_inference_with_anchor( 144 | config, 145 | kin_model, 146 | dataset, 147 | dataloader, 148 | kin_expected_outputs 149 | ): 150 | """ 151 | Integration testing for KIN inferece 152 | """ 153 | 154 | y_anchor_num, x_anchor_num = dataset.get_boundary() 155 | 156 | kin_model.init_kernelized_instance_norm_for_whole_model( 157 | y_anchor_num=y_anchor_num + 1, 158 | x_anchor_num=x_anchor_num + 1, 159 | ) 160 | 161 | for idx, data in enumerate(dataloader): 162 | X, _, y_anchor, x_anchor = ( 163 | data["X_img"], 164 | data["X_path"], 165 | data["y_idx"], 166 | data["x_idx"], 167 | ) 168 | _ = kin_model.inference_with_anchor( 169 | X, 170 | y_anchor=y_anchor, 171 | x_anchor=x_anchor, 172 | mode=KernelizedInstanceNorm.Mode.PHASE_CACHING, 173 | ) 174 | 175 | for idx, data in enumerate(dataloader): 176 | X, _, y_anchor, x_anchor = ( 177 | data["X_img"], 178 | data["X_path"], 179 | data["y_idx"], 180 | data["x_idx"], 181 | ) 182 | Y_fake = kin_model.inference_with_anchor( 183 | X, 184 | y_anchor=y_anchor, 185 | x_anchor=x_anchor, 186 | mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE, 187 | ) 188 | test_output_tensor = reverse_image_normalize(Y_fake) 189 | 190 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer 191 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_( 192 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy() 193 | expected_output = kin_expected_outputs[idx] 194 | 195 | assert test_output == pytest.approx(expected_output) 196 | -------------------------------------------------------------------------------- /models/tests/test_data/configs/config_lung_lesion_for_test_cut.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT_ROOT_PATH: "./test_data/checkpoints/lung_lesion" 2 | EXPERIMENT_NAME: "lung_lesion_CUT" 3 | MODEL_NAME: "CUT" # currently only CUT and cycleGAN are support 4 | DEVICE: "cuda" 5 | 6 | TRAINING_SETTING: 7 | PAIRED_TRAINING: False 8 | NUM_EPOCHS: 100 9 | LAMBDA_Y: 1 10 | LEARNING_RATE: 0.0002 11 | BATCH_SIZE: 1 12 | NUM_WORKERS: 4 13 | SAVE_MODEL: true 14 | SAVE_MODEL_EPOCH_STEP: 10 15 | VISUALIZATION_STEP: 250 16 | LOAD_MODEL: false 17 | 
LOAD_EPOCH: 0 18 | 19 | INFERENCE_SETTING: 20 | TEST_DIR_X: "./test_data/test_dir_x/lung_lesion/" 21 | TEST_DIR_Y: "./test_data/test_dir_y/lung_lesion" 22 | MODEL_VERSION: "latest" 23 | SAVE_ORIGINAL_IMAGE: False 24 | NORMALIZATION: 25 | TYPE: "kin" 26 | PADDING: 1 27 | KERNEL_TYPE: 'gaussian' 28 | KERNEL_SIZE: 3 29 | THUMBNAIL: "None" # set to "None" if it's not required 30 | -------------------------------------------------------------------------------- /models/tests/test_data/configs/config_lung_lesion_for_test_cyclegan.yaml: -------------------------------------------------------------------------------- 1 | EXPERIMENT_ROOT_PATH: "./test_data/checkpoints/lung_lesion" 2 | EXPERIMENT_NAME: "lung_lesion_cycleGAN" 3 | MODEL_NAME: "cycleGAN" # currently only CUT and cycleGAN are support 4 | DEVICE: "cuda" 5 | 6 | TRAINING_SETTING: 7 | PAIRED_TRAINING: False 8 | NUM_EPOCHS: 100 9 | LAMBDA_Y: 1 10 | LEARNING_RATE: 0.0002 11 | BATCH_SIZE: 1 12 | NUM_WORKERS: 4 13 | SAVE_MODEL: true 14 | SAVE_MODEL_EPOCH_STEP: 10 15 | VISUALIZATION_STEP: 250 16 | LOAD_MODEL: false 17 | LOAD_EPOCH: 0 18 | 19 | INFERENCE_SETTING: 20 | TEST_DIR_X: "./test_data/test_dir_x/lung_lesion/" 21 | TEST_DIR_Y: "./test_data/test_dir_y/lung_lesion" 22 | MODEL_VERSION: "60" 23 | SAVE_ORIGINAL_IMAGE: False 24 | NORMALIZATION: 25 | TYPE: "kin" 26 | PADDING: 1 27 | KERNEL_SIZE: 3 28 | KERNEL_TYPE: 'gaussian' 29 | THUMBNAIL: "None" # set to "None" if it's not required 30 | -------------------------------------------------------------------------------- /models/tests/test_kin.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pytest 3 | import torch 4 | 5 | from ..kin import KernelizedInstanceNorm 6 | 7 | 8 | def normalize(x): 9 | std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True) 10 | return (x - mean) / std 11 | 12 | 13 | def test_forward_normal(): 14 | layer = KernelizedInstanceNorm(num_features=3, device='cpu') 15 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32) 16 | x = torch.FloatTensor(x) 17 | 18 | expected = normalize(x) 19 | 20 | check = layer.forward_normal(torch.FloatTensor(x)) 21 | 22 | assert check.numpy() == pytest.approx(expected, abs=1e-6) 23 | 24 | 25 | def test_init_collection(): 26 | layer = KernelizedInstanceNorm(num_features=3, device='cpu') 27 | layer.init_collection(y_anchor_num=10, x_anchor_num=9) 28 | 29 | expected_mean_table = np.zeros(shape=(10, 9, 3)) 30 | expected_std_table = np.zeros(shape=(10, 9, 3)) 31 | 32 | np.testing.assert_array_equal(layer.mean_table.numpy(), expected_mean_table) 33 | np.testing.assert_array_equal(layer.std_table.numpy(), expected_std_table) 34 | 35 | 36 | def test_forward_without_anchors(): 37 | layer = KernelizedInstanceNorm(num_features=3, device='cpu') 38 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32) 39 | x = torch.FloatTensor(x) 40 | expected = normalize(x) 41 | 42 | check = layer.forward(torch.FloatTensor(x)) 43 | 44 | assert check.numpy() == pytest.approx(expected, abs=1e-6) 45 | 46 | 47 | def test_forward_with_mode_1(): 48 | layer = KernelizedInstanceNorm(num_features=3, kernel_type='constant', device='cpu').eval() 49 | layer.init_collection(y_anchor_num=3, x_anchor_num=3) 50 | 51 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32) 52 | x = torch.FloatTensor(x) 53 | 54 | std, mean = torch.std_mean(x, dim=(2, 3)) 55 | 56 | expected_mean_table = np.zeros(shape=(3, 3, 3), dtype=np.float32) 57 | expected_std_table = np.zeros(shape=(3, 3, 3), dtype=np.float32) 
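    # After the PHASE_CACHING pass below, only the (0, 0) cell of each table
    # should hold this patch's statistics; every other entry stays zero.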
58 | 59 | expected_mean_table[0, 0] = mean 60 | expected_std_table[0, 0] = std 61 | 62 | check = layer.forward(x, x_anchor=0, y_anchor=0, mode=KernelizedInstanceNorm.Mode.PHASE_CACHING) 63 | 64 | assert check.detach().numpy() == pytest.approx(normalize(x).numpy(), abs=1e-6) 65 | assert layer.mean_table.numpy() == pytest.approx(expected_mean_table) 66 | assert layer.std_table.numpy() == pytest.approx(expected_std_table) 67 | 68 | 69 | def test_forward_with_mode_2(): 70 | layer = KernelizedInstanceNorm(num_features=3, kernel_type='constant', device='cpu').eval() 71 | layer.init_collection(y_anchor_num=3, x_anchor_num=3) 72 | 73 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32) 74 | x = torch.FloatTensor(x) 75 | 76 | layer.forward(x, x_anchor=1, y_anchor=1, mode=KernelizedInstanceNorm.Mode.PHASE_CACHING) 77 | check = layer.forward(x, x_anchor=1, y_anchor=1, mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE) 78 | std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True) 79 | 80 | mean /= 9 81 | std /= 9 82 | 83 | expected = (x - mean) / std 84 | 85 | assert check.detach().numpy() == pytest.approx(expected) 86 | -------------------------------------------------------------------------------- /models/tin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ThumbInstanceNorm(nn.Module): 6 | def __init__(self, num_features, affine=True): 7 | super(ThumbInstanceNorm, self).__init__() 8 | self.thumb_mean = None 9 | self.thumb_std = None 10 | self.normal_instance_normalization = False 11 | self.collection_mode = False 12 | if affine: 13 | self.weight = nn.Parameter( 14 | torch.ones(size=(1, num_features, 1, 1), requires_grad=True) 15 | ) 16 | self.bias = nn.Parameter( 17 | torch.zeros(size=(1, num_features, 1, 1), requires_grad=True) 18 | ) 19 | 20 | def calc_mean_std(self, feat, eps=1e-5): 21 | size = feat.size() 22 | assert len(size) == 4 23 | N, C = size[:2] 24 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 25 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 26 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 27 | return feat_mean, feat_std 28 | 29 | def forward(self, x): 30 | if self.training or self.normal_instance_normalization: 31 | x_mean, x_std = self.calc_mean_std(x) 32 | x = (x - x_mean) / x_std * self.weight + self.bias 33 | return x 34 | else: 35 | if self.collection_mode: 36 | assert x.shape[0] == 1 37 | x_mean, x_std = self.calc_mean_std(x) 38 | self.thumb_mean = x_mean 39 | self.thumb_std = x_std 40 | 41 | shift = x - self.thumb_mean 42 | x = shift / self.thumb_std * self.weight + self.bias 43 | return x 44 | 45 | 46 | def not_use_thumbnail_instance_norm(model): 47 | for _, layer in model.named_modules(): 48 | if isinstance(layer, ThumbInstanceNorm): 49 | layer.collection_mode = False 50 | layer.normal_instance_normalization = True 51 | 52 | 53 | def init_thumbnail_instance_norm(model): 54 | for _, layer in model.named_modules(): 55 | if isinstance(layer, ThumbInstanceNorm): 56 | layer.collection_mode = True 57 | 58 | 59 | def use_thumbnail_instance_norm(model): 60 | for _, layer in model.named_modules(): 61 | if isinstance(layer, ThumbInstanceNorm): 62 | layer.collection_mode = False 63 | layer.normal_instance_normalization = False 64 | -------------------------------------------------------------------------------- /models/upsample.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class 
Upsample(nn.Module): 5 | def __init__(self, features): 6 | super().__init__() 7 | layers = [ 8 | nn.ReplicationPad2d(1), 9 | nn.ConvTranspose2d( 10 | features, 11 | features, 12 | kernel_size=4, 13 | stride=2, 14 | padding=3, 15 | ), 16 | ] 17 | self.model = nn.Sequential(*layers) 18 | 19 | def forward(self, input): 20 | return self.model(input) 21 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==0.5.2 2 | numpy==1.22.4 3 | Pillow>=9.2.0 4 | PyYAML==5.4.1 5 | torch==1.7.0 6 | torchvision==0.8.0 7 | pytest 8 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from collections import defaultdict 4 | 5 | from torch.utils.data import DataLoader 6 | from torchvision.utils import save_image 7 | 8 | from models.model import get_model 9 | from utils.dataset import get_dataset 10 | from utils.util import read_yaml_config, reverse_image_normalize 11 | 12 | 13 | def main(): 14 | parser = argparse.ArgumentParser("Model training") 15 | parser.add_argument( 16 | "-c", 17 | "--config", 18 | type=str, 19 | default="./config.yaml", 20 | help="Path to the config file.", 21 | ) 22 | args = parser.parse_args() 23 | 24 | config = read_yaml_config(args.config) 25 | 26 | model = get_model(config=config, model_name=config["MODEL_NAME"], isTrain=True) 27 | 28 | dataset = get_dataset(config) 29 | 30 | dataloader = DataLoader( 31 | dataset, 32 | batch_size=config["TRAINING_SETTING"]["BATCH_SIZE"], 33 | shuffle=True, 34 | num_workers=config["TRAINING_SETTING"]["NUM_WORKERS"], 35 | ) 36 | 37 | for epoch in range(config["TRAINING_SETTING"]["NUM_EPOCHS"]): 38 | out = defaultdict(int) 39 | 40 | for idx, data in enumerate(dataloader): 41 | print(f"[Epoch {epoch}][Iter {idx}] Processing ...", end="\r") 42 | if epoch == 0 and idx == 0: 43 | model.data_dependent_initialize(data) 44 | model.setup() 45 | 46 | model.set_input(data) 47 | model.optimize_parameters() 48 | 49 | if idx % config["TRAINING_SETTING"]["VISUALIZATION_STEP"] == 0 and idx > 0: 50 | results = model.get_current_visuals() 51 | 52 | for img_name, img in results.items(): 53 | save_image( 54 | reverse_image_normalize(img), 55 | os.path.join( 56 | config["EXPERIMENT_ROOT_PATH"], 57 | config["EXPERIMENT_NAME"], 58 | "train", 59 | f"{epoch}_{img_name}_{idx}.png", 60 | ), 61 | ) 62 | 63 | for k, v in out.items(): 64 | out[k] /= config["TRAINING_SETTING"]["VISUALIZATION_STEP"] 65 | 66 | print(f"[Epoch {epoch}][Iter {idx}] {out}", flush=True) 67 | for k, v in out.items(): 68 | out[k] = 0 69 | 70 | losses = model.get_current_losses() 71 | for k, v in losses.items(): 72 | out[k] += v 73 | 74 | model.scheduler_step() 75 | if ( 76 | epoch % config["TRAINING_SETTING"]["SAVE_MODEL_EPOCH_STEP"] == 0 77 | and config["TRAINING_SETTING"]["SAVE_MODEL"] 78 | ): 79 | model.save_networks(epoch) 80 | 81 | model.save_networks("latest") 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /transfer.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | 4 | import cv2 5 | 6 | from utils.util import read_yaml_config 7 | 8 | 9 | def main(): 10 | """ 11 | USAGE 12 | python3 transfer.py -c config_example.yaml 13 | or 14 | 
python3 transfer.py -c config_example.yaml --skip_cropping 15 | """ 16 | parser = argparse.ArgumentParser("Model inference") 17 | parser.add_argument( 18 | "-c", 19 | "--config", 20 | type=str, 21 | default="./data/example/config.yaml", 22 | help="Path to the config file.", 23 | ) 24 | parser.add_argument("--skip_cropping", action="store_true") 25 | args = parser.parse_args() 26 | 27 | config = read_yaml_config(args.config) 28 | H, W, _ = cv2.imread(config["INFERENCE_SETTING"]["TEST_X"]).shape 29 | if not args.skip_cropping: 30 | os.system( 31 | f"python3 crop.py -i {config['INFERENCE_SETTING']['TEST_X']} " 32 | f"-o {config['INFERENCE_SETTING']['TEST_DIR_X']} " 33 | f"--patch_size {config['CROPPING_SETTING']['PATCH_SIZE']} " 34 | f"--stride {config['CROPPING_SETTING']['PATCH_SIZE']} " 35 | f"--thumbnail " 36 | f"--thumbnail_output {config['INFERENCE_SETTING']['TEST_DIR_X']}" 37 | ) 38 | print("Finish cropping and start inference") 39 | os.system(f"python3 inference.py --config {args.config}") 40 | print("Finish inference and start combining images") 41 | os.system( 42 | f"python3 combine.py --config {args.config} " 43 | f"--resize_h {H} --resize_w {W}" 44 | ) 45 | 46 | 47 | if __name__ == "__main__": 48 | main() 49 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/utils/__init__.py --------------------------------------------------------------------------------
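For quick reference, the two-pass KIN inference pattern exercised by the integration tests above can be sketched as follows; model, dataset, and dataloader are assumed to be built the same way as in the test fixtures (get_model with a 'kin' norm_cfg, XInferenceDataset with return_anchor=True):

from models.kin import KernelizedInstanceNorm

y_anchor_num, x_anchor_num = dataset.get_boundary()
model.init_kernelized_instance_norm_for_whole_model(
    y_anchor_num=y_anchor_num + 1,
    x_anchor_num=x_anchor_num + 1,
)

# Pass 1: cache per-patch instance statistics at every (y, x) anchor.
for data in dataloader:
    model.inference_with_anchor(
        data["X_img"],
        y_anchor=data["y_idx"],
        x_anchor=data["x_idx"],
        mode=KernelizedInstanceNorm.Mode.PHASE_CACHING,
    )

# Pass 2: re-run each patch, normalizing with kernel-smoothed statistics
# gathered from neighbouring anchors.
for data in dataloader:
    y_fake = model.inference_with_anchor(
        data["X_img"],
        y_anchor=data["y_idx"],
        x_anchor=data["x_idx"],
        mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE,
    )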