├── .gitignore
├── F-LSeSim
│ ├── LICENSE
│ ├── README.md
│ ├── combine.py
│ ├── data
│ │ ├── __init__.py
│ │ ├── aligned_dataset.py
│ │ ├── base_dataset.py
│ │ ├── colorization_dataset.py
│ │ ├── dataset.py
│ │ ├── image_folder.py
│ │ ├── single_dataset.py
│ │ ├── singleimage_dataset.py
│ │ ├── template_dataset.py
│ │ └── unaligned_dataset.py
│ ├── datasets
│ │ ├── bibtex
│ │ │ ├── cityscapes.tex
│ │ │ ├── facades.tex
│ │ │ ├── handbags.tex
│ │ │ ├── shoes.tex
│ │ │ └── transattr.tex
│ │ ├── combine_A_and_B.py
│ │ ├── download_cyclegan_dataset.sh
│ │ ├── download_pix2pix_dataset.sh
│ │ ├── make_dataset_aligned.py
│ │ └── prepare_cityscapes_dataset.py
│ ├── evaluations
│ │ ├── DC.py
│ │ ├── __init__.py
│ │ ├── fid_score.py
│ │ └── inception.py
│ ├── inference.py
│ ├── models
│ │ ├── __init__.py
│ │ ├── base_model.py
│ │ ├── colorization_model.py
│ │ ├── cycle_gan_model.py
│ │ ├── cyclegan_networks.py
│ │ ├── discriminator.py
│ │ ├── downsample.py
│ │ ├── generator.py
│ │ ├── kin.py
│ │ ├── losses.py
│ │ ├── networks.py
│ │ ├── normalization.py
│ │ ├── pix2pix_model.py
│ │ ├── sc_model.py
│ │ ├── sinsc_model.py
│ │ ├── stylegan_networks.py
│ │ ├── template_model.py
│ │ ├── test_model.py
│ │ ├── tin.py
│ │ ├── upsample.py
│ │ └── util.py
│ ├── options
│ │ ├── __init__.py
│ │ ├── base_options.py
│ │ ├── test_options.py
│ │ └── train_options.py
│ ├── scripts
│ │ ├── conda_deps.sh
│ │ ├── download_cyclegan_model.sh
│ │ ├── download_pix2pix_model.sh
│ │ ├── edges
│ │ │ ├── PostprocessHED.m
│ │ │ └── batch_hed.py
│ │ ├── eval_cityscapes
│ │ │ ├── caffemodel
│ │ │ │ └── deploy.prototxt
│ │ │ ├── cityscapes.py
│ │ │ ├── download_fcn8s.sh
│ │ │ ├── evaluate.py
│ │ │ └── util.py
│ │ ├── install_deps.sh
│ │ ├── test_before_push.py
│ │ ├── test_colorization.sh
│ │ ├── test_cyclegan.sh
│ │ ├── test_fid.sh
│ │ ├── test_pix2pix.sh
│ │ ├── test_sc.sh
│ │ ├── test_single.sh
│ │ ├── test_sinsc.sh
│ │ ├── train_colorization.sh
│ │ ├── train_cyclegan.sh
│ │ ├── train_pix2pix.sh
│ │ ├── train_sc.sh
│ │ ├── train_sinsc.sh
│ │ └── transfer_sc.sh
│ ├── test.py
│ ├── test_fid.py
│ ├── train.py
│ └── util
│   ├── __init__.py
│   ├── get_data.py
│   ├── html.py
│   ├── image_pool.py
│   ├── util.py
│   └── visualizer.py
├── LICENSE
├── README.md
├── appendix
│ ├── gpu_memory.py
│ ├── inference_usage.png
│ ├── proof_of_concept.py
│ └── training_usage.png
├── combine.py
├── crop.py
├── crop_pipeline.py
├── data
│ └── example
│   ├── .gitkeep
│   └── config.yaml
├── evaluation.py
├── experiments
│ └── .gitkeep
├── imgs
│ ├── Figure_KIN.jpg
│ ├── Figure_patch_with_patch_lineplot_block1_mean.jpg
│ ├── Figure_patch_with_patch_lineplot_blocks_mean.jpg
│ └── URUST_anime.gif
├── inference.py
├── metric_images_with_ref.py
├── metric_whole_image_no_ref.py
├── metric_whole_image_with_ref.py
├── metrics
│ ├── __init__.py
│ ├── calculate_fid.py
│ ├── histogram.py
│ ├── inception.py
│ ├── niqe.py
│ ├── niqe_image_params.mat
│ ├── piqe.py
│ └── sobel.py
├── models
│ ├── __init__.py
│ ├── base.py
│ ├── cut.py
│ ├── cyclegan.py
│ ├── discriminator.py
│ ├── downsample.py
│ ├── generator.py
│ ├── kin.py
│ ├── lsesim.py
│ ├── lsesim_loss.py
│ ├── model.py
│ ├── normalization.py
│ ├── projector.py
│ ├── tests
│ │ ├── __init__.py
│ │ ├── test_cut.py
│ │ ├── test_cyclegan.py
│ │ ├── test_data
│ │ │ └── configs
│ │ │   ├── config_lung_lesion_for_test_cut.yaml
│ │ │   └── config_lung_lesion_for_test_cyclegan.yaml
│ │ └── test_kin.py
│ ├── tin.py
│ └── upsample.py
├── requirements.txt
├── train.py
├── transfer.py
└── utils
  ├── __init__.py
  ├── dataset.py
  └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | __pycache__/
2 | .ipynb_checkpoints/
3 | *.out
4 | experiments/
5 | checkpoints/
6 | .pytest_cache
7 | test_dir_x/
8 | test_dir_y/
--------------------------------------------------------------------------------
/F-LSeSim/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2021 Chuanxia Zheng
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/F-LSeSim/README.md:
--------------------------------------------------------------------------------
1 |
2 | # Spatially-Correlative Loss
3 |
4 | [arXiv](https://arxiv.org/abs/2104.00854) | [website](http://www.chuanxiaz.com/publication/flsesim/)
5 |
6 |
7 |
8 |
9 |
10 | We provide the PyTorch implementation of "The Spatially-Correlative Loss for Various Image Translation Tasks". Based on the inherent self-similarity of objects, we propose a new structure-preserving loss for one-sided unsupervised I2I networks. The new loss *deals only with the spatial relationships of repeated signals, regardless of their original absolute values*.
11 |
12 | [The Spatially-Correlative Loss for Various Image Translation Tasks](https://arxiv.org/abs/2104.00854)
13 | [Chuanxia Zheng](http://www.chuanxiaz.com), [Tat-Jen Cham](http://www.ntu.edu.sg/home/astjcham/), [Jianfei Cai](https://research.monash.edu/en/persons/jianfei-cai)
14 | NTU and Monash University
15 | In CVPR2021
16 |
17 | ## ToDo
18 | - A simple example of how to use the proposed loss
19 |
20 | ## Example Results
21 |
22 | ### Unpaired Image-to-Image Translation
23 |
24 |
25 |
26 | ### Single Image Translation
27 |
28 |
29 |
30 | ### [More results on project page](http://www.chuanxiaz.com/publication/flsesim/)
31 |
32 | ## Getting Started
33 |
34 | ### Installation
35 | This code was tested with PyTorch 1.7.0, CUDA 10.2, and Python 3.7.
36 |
37 | - Install PyTorch 1.7.0, torchvision, and other dependencies from [http://pytorch.org](http://pytorch.org)
38 | - Install the Python libraries [visdom](https://github.com/facebookresearch/visdom) and [dominate](https://github.com/Knio/dominate) for visualization
39 |
40 | ```
41 | pip install visdom dominate
42 | ```
43 | - Clone this repo:
44 |
45 | ```
46 | git clone https://github.com/lyndonzheng/F-LSeSim
47 | cd F-LSeSim
48 | ```
49 |
50 | ### [Datasets](https://github.com/taesungp/contrastive-unpaired-translation/blob/master/docs/datasets.md)
51 | Please refer to the original [CUT](https://github.com/taesungp/contrastive-unpaired-translation) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix) to download datasets and learn how to create your own datasets.
52 |
53 | ### Training
54 |
55 | - Train the *single-modal* I2I translation model:
56 |
57 | ```
58 | sh ./scripts/train_sc.sh
59 | ```
60 |
61 | - Set ```--use_norm``` to use a cosine similarity map; the default similarity is a dot-product attention score. Add ```--learned_attn --augment``` for the learned self-similarity.
62 | - To view training results and loss plots, run ```python -m visdom.server``` and open the URL [http://localhost:port](http://localhost:port).
63 | - Trained models will be saved under the **checkpoints** folder.
64 | - More training options can be found in the **options** folder.
65 |
66 |
67 |
68 | - Train the *single-image* translation model:
69 |
70 | ```
71 | sh ./scripts/train_sinsc.sh
72 | ```
73 |
74 | As the *multi-modal* I2I translation model was trained on [MUNIT](https://github.com/NVlabs/MUNIT), we do not plan to merge that code into this repository. If you wish to obtain multi-modal results, please contact us at chuanxia001@e.ntu.edu.sg.
75 |
76 | ### Testing
77 |
78 | - Test the *single-modal* I2I translation model:
79 |
80 | ```
81 | sh ./scripts/test_sc.sh
82 | ```
83 |
84 | - Test the *single-image* translation model:
85 |
86 | ```
87 | sh ./scripts/test_sinsc.sh
88 | ```
89 |
90 | - Test the FID score for all training epochs:
91 |
92 | ```
93 | sh ./scripts/test_fid.sh
94 | ```
95 |
96 | ### Pretrained Models
97 |
98 | Download the pre-trained models (will be released soon) using the following links and put them under the ```checkpoints/``` directory.
99 |
100 | - ```Single-modal translation model```: [horse2zebra](https://drive.google.com/drive/folders/1k8Y5R6CnaDwfkha_lD5_yQTvoajoU6GR?usp=sharing), [semantic2image](https://drive.google.com/drive/folders/1xnF6wLTPhD35-2It8IIomJRhFZdr2qXp?usp=sharing), [apple2orange](https://drive.google.com/drive/folders/1Z9PwxkWlakDdv12Jha6WJRgO6cSfEZGs?usp=sharing)
101 | - ```Single-image translation model```: [image2monet](https://drive.google.com/drive/folders/1QcGY9H0USWHJtcifRMWh_KHOJszME6-U?usp=sharing)
102 |
103 | ## Citation
104 | ```
105 | @inproceedings{zheng2021spatiallycorrelative,
106 | title={The Spatially-Correlative Loss for Various Image Translation Tasks},
107 | author={Zheng, Chuanxia and Cham, Tat-Jen and Cai, Jianfei},
108 | booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
109 | year={2021}
110 | }
111 | ```
112 |
113 | ## Acknowledgments
114 | Our code is developed based on [CUT](https://github.com/taesungp/contrastive-unpaired-translation) and [CycleGAN](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). We also thank [pytorch-fid](https://github.com/mseitzer/pytorch-fid) for FID computation, [LPIPS](https://github.com/richzhang/PerceptualSimilarity) for diversity score, and [D&C](https://github.com/clovaai/generative-evaluation-prdc) for density and coverage evaluation.
115 |
116 |
117 |
118 |
119 |
--------------------------------------------------------------------------------
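The training and testing scripts described in the README above ultimately drive `train.py` and `test.py` with the options defined under `options/`. As a rough orientation, here is a hedged Python sketch of that flow; `TrainOptions`, `create_dataset`, `create_model`, and the epoch-count flags are assumptions based on the CycleGAN/CUT-style layout of the `options/`, `data/`, and `models/` packages, not a verbatim copy of `train_sc.sh`.

```
# Hedged sketch of the training loop behind ./scripts/train_sc.sh (not the verbatim script).
# create_model and the epoch flags are assumptions based on the CUT/CycleGAN codebase layout.
from options.train_options import TrainOptions
from data import create_dataset
from models import create_model

if __name__ == "__main__":
    opt = TrainOptions().parse()         # e.g. --model sc --learned_attn --augment
    dataset = create_dataset(opt)        # see data/__init__.py
    model = create_model(opt)
    model.setup(opt)
    for epoch in range(opt.n_epochs + opt.n_epochs_decay):
        for data in dataset:
            model.set_input(data)        # unpack a batch from the dataloader
            model.optimize_parameters()  # forward, backward, optimizer steps
        model.update_learning_rate()
```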
/F-LSeSim/combine.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | import yaml
7 | from PIL import Image
8 | from yaml.loader import SafeLoader
9 |
10 |
11 | def read_yaml_config(config_path):
12 | with open(config_path) as f:
13 | config = yaml.load(f, Loader=SafeLoader)
14 | return config
15 |
16 |
17 | def main():
18 | parser = argparse.ArgumentParser("Combine transferred images")
19 | parser.add_argument(
20 | "-c",
21 | "--config",
22 | type=str,
23 | default="./config.yaml",
24 | help="Path to the config file.",
25 | )
26 | parser.add_argument("--patch_size", type=int, help="Patch size", default=512)
27 | parser.add_argument("--resize_h", type=int, help="Resize H", default=-1)
28 | parser.add_argument("--resize_w", type=int, help="Resize W", default=-1)
29 | parser.add_argument("--read_original", action="store_true")
30 | args = parser.parse_args()
31 |
32 | config = read_yaml_config(args.config)
33 |
34 | basename = os.path.basename(config["INFERENCE_SETTING"]["TEST_X"])
35 | filename = os.path.splitext(basename)[0]
36 | path_root = os.path.join(
37 | config["EXPERIMENT_ROOT_PATH"], config["EXPERIMENT_NAME"], "test", filename
38 | )
39 | if (
40 | "OVERWRITE_OUTPUT_PATH" in config["INFERENCE_SETTING"]
41 | and config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] != ""
42 | ):
43 | path_root = config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"]
44 |
45 | path_base = os.path.join(
46 | path_root,
47 | config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"],
48 | config["INFERENCE_SETTING"]["MODEL_VERSION"],
49 | )
50 |
51 | combined_image_name = f"combined_{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_{config['INFERENCE_SETTING']['MODEL_VERSION']}.png"
52 |
53 | if config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"] == "kin":
54 | path_base = os.path.join(
55 | path_base,
56 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}",
57 | )
58 | combined_image_name = f"combined_{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_{config['INFERENCE_SETTING']['MODEL_VERSION']}_{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}.png"
59 |
60 | filenames = os.listdir(path_base)
61 | try:
62 | filenames.remove("thumbnail_Y_fake.png")
63 | except ValueError:
64 | pass
65 |
66 | y_anchor_max = 0
67 | x_anchor_max = 0
68 | for filename in filenames:
69 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4)
70 | y_anchor_max = max(y_anchor_max, int(y_anchor))
71 | x_anchor_max = max(x_anchor_max, int(x_anchor))
72 |
73 | matrix = np.zeros(
74 | (y_anchor_max + args.patch_size, x_anchor_max + args.patch_size, 3),
75 | dtype=np.uint8,
76 | )
77 |
78 | for filename in sorted(filenames):
79 | print(f"Combine {filename} ", end="\r")
80 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4)
81 | y_anchor = int(y_anchor)
82 | x_anchor = int(x_anchor)
83 | image = cv2.imread(os.path.join(path_base, filename))
84 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
85 | matrix[y_anchor : y_anchor + args.patch_size, x_anchor : x_anchor + args.patch_size, :] = image
86 |
87 | if (args.resize_h != -1) and (args.resize_w != -1):
88 | matrix = cv2.resize(matrix, (args.resize_w, args.resize_h), interpolation=cv2.INTER_CUBIC)
89 |
90 | if args.read_original:
91 | H, W, _ = cv2.imread(config["INFERENCE_SETTING"]["TEST_X"]).shape
92 | matrix = cv2.resize(matrix, (W, H), interpolation=cv2.INTER_CUBIC)
93 |
94 | matrix_image = Image.fromarray(matrix)
95 | matrix_image.save(os.path.join(path_root, combined_image_name))
96 |
97 |
98 | if __name__ == "__main__":
99 | main()
100 |
--------------------------------------------------------------------------------
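The combine.py script above stitches translated patches back into a single image by reading pixel offsets out of each patch's filename. A small hedged sketch of that convention, inferred from the `filename.split("_", 4)` call; the example filename is illustrative only.

```
# Hedged sketch: patch filenames are assumed to encode "{y_idx}_{x_idx}_{y_anchor}_{x_anchor}_<suffix>",
# where the anchors are the pixel coordinates of the patch's top-left corner in the stitched canvas.
from pathlib import Path

def patch_offset(filename: str):
    """Return (y_anchor, x_anchor) pixel offsets encoded in a patch filename."""
    _, _, y_anchor, x_anchor, _ = Path(filename).name.split("_", 4)
    return int(y_anchor), int(x_anchor)

# Illustrative (hypothetical) filename: pasted with its top-left corner at row 1536, column 3584.
assert patch_offset("3_7_1536_3584_Y_fake.png") == (1536, 3584)
```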
/F-LSeSim/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 |
15 | import torch.utils.data
16 |
17 | from data.base_dataset import BaseDataset
18 |
19 |
20 | def find_dataset_using_name(dataset_name):
21 | """Import the module "data/[dataset_name]_dataset.py".
22 |
23 | In the file, the class called DatasetNameDataset() will
24 | be instantiated. It has to be a subclass of BaseDataset,
25 | and it is case-insensitive.
26 | """
27 | dataset_filename = "data." + dataset_name + "_dataset"
28 | datasetlib = importlib.import_module(dataset_filename)
29 |
30 | dataset = None
31 | target_dataset_name = dataset_name.replace("_", "") + "dataset"
32 | for name, cls in datasetlib.__dict__.items():
33 | if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset):
34 | dataset = cls
35 |
36 | if dataset is None:
37 | raise NotImplementedError(
38 | "In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
39 | % (dataset_filename, target_dataset_name)
40 | )
41 |
42 | return dataset
43 |
44 |
45 | def get_option_setter(dataset_name):
46 | """Return the static method of the dataset class."""
47 | dataset_class = find_dataset_using_name(dataset_name)
48 | return dataset_class.modify_commandline_options
49 |
50 |
51 | def create_dataset(opt):
52 | """Create a dataset given the option.
53 |
54 | This function wraps the class CustomDatasetDataLoader.
55 | This is the main interface between this package and 'train.py'/'test.py'
56 |
57 | Example:
58 | >>> from data import create_dataset
59 | >>> dataset = create_dataset(opt)
60 | """
61 | data_loader = CustomDatasetDataLoader(opt)
62 | dataset = data_loader.load_data()
63 | return dataset
64 |
65 |
66 | class CustomDatasetDataLoader:
67 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
68 |
69 | def __init__(self, opt):
70 | """Initialize this class
71 |
72 | Step 1: create a dataset instance given the name [dataset_mode]
73 | Step 2: create a multi-threaded data loader.
74 | """
75 | self.opt = opt
76 | dataset_class = find_dataset_using_name(opt.dataset_mode)
77 | self.dataset = dataset_class(opt)
78 | print("dataset [%s] was created" % type(self.dataset).__name__)
79 | self.dataloader = torch.utils.data.DataLoader(
80 | self.dataset,
81 | batch_size=opt.batch_size,
82 | shuffle=not opt.serial_batches,
83 | num_workers=int(opt.num_threads),
84 | )
85 |
86 | def load_data(self):
87 | return self
88 |
89 | def __len__(self):
90 | """Return the number of data in the dataset"""
91 | return min(len(self.dataset), self.opt.max_dataset_size)
92 |
93 | def __iter__(self):
94 | """Return a batch of data"""
95 | for i, data in enumerate(self.dataloader):
96 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
97 | break
98 | yield data
99 |
--------------------------------------------------------------------------------
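Following the package docstring above, here is a hedged skeleton of the 'dummy' dataset it describes: saved as data/dummy_dataset.py, it would be picked up by find_dataset_using_name("dummy") and enabled with '--dataset_mode dummy'. The loading logic is a minimal single-domain placeholder, not part of the repository.

```
# Hedged skeleton for data/dummy_dataset.py as described in the docstring above.
from PIL import Image

from data.base_dataset import BaseDataset, get_transform
from data.image_folder import make_dataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        return parser  # optionally add dataset-specific flags here

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.image_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        path = self.image_paths[index]
        A = self.transform(Image.open(path).convert("RGB"))
        return {"A": A, "A_paths": path}

    def __len__(self):
        return len(self.image_paths)
```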
/F-LSeSim/data/aligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 | from data.base_dataset import BaseDataset, get_params, get_transform
6 | from data.image_folder import make_dataset
7 |
8 |
9 | class AlignedDataset(BaseDataset):
10 | """A dataset class for paired image dataset.
11 |
12 | It assumes that the directory '/path/to/data/train' contains image pairs in the form of {A,B}.
13 | During test time, you need to prepare a directory '/path/to/data/test'.
14 | """
15 |
16 | def __init__(self, opt):
17 | """Initialize this dataset class.
18 |
19 | Parameters:
20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
21 | """
22 | BaseDataset.__init__(self, opt)
23 | self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
24 | self.AB_paths = sorted(
25 | make_dataset(self.dir_AB, opt.max_dataset_size)
26 | ) # get image paths
27 | assert (
28 | self.opt.load_size >= self.opt.crop_size
29 | ) # crop_size should be smaller than the size of loaded image
30 | self.input_nc = (
31 | self.opt.output_nc if self.opt.direction == "BtoA" else self.opt.input_nc
32 | )
33 | self.output_nc = (
34 | self.opt.input_nc if self.opt.direction == "BtoA" else self.opt.output_nc
35 | )
36 |
37 | def __getitem__(self, index):
38 | """Return a data point and its metadata information.
39 |
40 | Parameters:
41 | index - - a random integer for data indexing
42 |
43 | Returns a dictionary that contains A, B, A_paths and B_paths
44 | A (tensor) - - an image in the input domain
45 | B (tensor) - - its corresponding image in the target domain
46 | A_paths (str) - - image paths
47 | B_paths (str) - - image paths (same as A_paths)
48 | """
49 | # read an image given a random integer index
50 | AB_path = self.AB_paths[index]
51 | AB = Image.open(AB_path).convert("RGB")
52 | # split AB image into A and B
53 | w, h = AB.size
54 | w2 = int(w / 2)
55 | A = AB.crop((0, 0, w2, h))
56 | B = AB.crop((w2, 0, w, h))
57 |
58 | # apply the same transform to both A and B
59 | transform_params = get_params(self.opt, A.size)
60 | A_transform = get_transform(
61 | self.opt, transform_params, grayscale=(self.input_nc == 1)
62 | )
63 | B_transform = get_transform(
64 | self.opt, transform_params, grayscale=(self.output_nc == 1)
65 | )
66 |
67 | A = A_transform(A)
68 | B = B_transform(B)
69 |
70 | return {"A": A, "B": B, "A_paths": AB_path, "B_paths": AB_path}
71 |
72 | def __len__(self):
73 | """Return the total number of images in the dataset."""
74 | return len(self.AB_paths)
75 |
--------------------------------------------------------------------------------
/F-LSeSim/data/colorization_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import torchvision.transforms as transforms
5 | from PIL import Image
6 | from skimage import color # require skimage
7 |
8 | from data.base_dataset import BaseDataset, get_transform
9 | from data.image_folder import make_dataset
10 |
11 |
12 | class ColorizationDataset(BaseDataset):
13 | """This dataset class can load a set of natural images in RGB, and convert RGB format into (L, ab) pairs in Lab color space.
14 |
15 | This dataset is required by pix2pix-based colorization model ('--model colorization')
16 | """
17 |
18 | @staticmethod
19 | def modify_commandline_options(parser, is_train):
20 | """Add new dataset-specific options, and rewrite default values for existing options.
21 |
22 | Parameters:
23 | parser -- original option parser
24 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
25 |
26 | Returns:
27 | the modified parser.
28 |
29 | By default, the number of channels for input image is 1 (L) and
30 | the number of channels for output image is 2 (ab). The direction is from A to B
31 | """
32 | parser.set_defaults(input_nc=1, output_nc=2, direction="AtoB")
33 | return parser
34 |
35 | def __init__(self, opt):
36 | """Initialize this dataset class.
37 |
38 | Parameters:
39 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
40 | """
41 | BaseDataset.__init__(self, opt)
42 | self.dir = os.path.join(opt.dataroot, opt.phase)
43 | self.AB_paths = sorted(make_dataset(self.dir, opt.max_dataset_size))
44 | assert opt.input_nc == 1 and opt.output_nc == 2 and opt.direction == "AtoB"
45 | self.transform = get_transform(self.opt, convert=False)
46 |
47 | def __getitem__(self, index):
48 | """Return a data point and its metadata information.
49 |
50 | Parameters:
51 | index - - a random integer for data indexing
52 |
53 | Returns a dictionary that contains A, B, A_paths and B_paths
54 | A (tensor) - - the L channel of an image
55 | B (tensor) - - the ab channels of the same image
56 | A_paths (str) - - image paths
57 | B_paths (str) - - image paths (same as A_paths)
58 | """
59 | path = self.AB_paths[index]
60 | im = Image.open(path).convert("RGB")
61 | im = self.transform(im)
62 | im = np.array(im)
63 | lab = color.rgb2lab(im).astype(np.float32)
64 | lab_t = transforms.ToTensor()(lab)
65 | A = lab_t[[0], ...] / 50.0 - 1.0
66 | B = lab_t[[1, 2], ...] / 110.0
67 | return {"A": A, "B": B, "A_paths": path, "B_paths": path}
68 |
69 | def __len__(self):
70 | """Return the total number of images in the dataset."""
71 | return len(self.AB_paths)
72 |
--------------------------------------------------------------------------------
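ColorizationDataset above maps RGB images to normalized (L, ab) tensors with A = L/50 - 1 and B = ab/110. For reference, a hedged sketch of the inverse mapping, useful when turning a colorization model's output back into an RGB image; the function name is illustrative and not part of the repository.

```
# Hedged sketch: invert the normalization used in ColorizationDataset.__getitem__
# and convert a predicted (L, ab) pair back to RGB with skimage.
import numpy as np
import torch
from skimage import color


def lab_tensors_to_rgb(A: torch.Tensor, B: torch.Tensor) -> np.ndarray:
    """A: 1xHxW tensor in [-1, 1] (L channel); B: 2xHxW tensor (ab channels). Returns HxWx3 float RGB."""
    L = (A + 1.0) * 50.0                              # back to [0, 100]
    ab = B * 110.0                                    # back to roughly [-110, 110]
    lab = torch.cat([L, ab], dim=0).permute(1, 2, 0)  # HxWx3 Lab image
    return color.lab2rgb(lab.detach().cpu().numpy().astype(np.float64))
```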
/F-LSeSim/data/dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from pathlib import Path
4 |
5 | import albumentations as A
6 | import numpy as np
7 | from albumentations.pytorch import ToTensorV2
8 | from PIL import Image
9 | from torch.utils.data import Dataset
10 |
11 | test_transforms = A.Compose(
12 | [
13 | A.Resize(width=512, height=512),
14 | A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255),
15 | ToTensorV2(),
16 | ],
17 | additional_targets={"image0": "image"},
18 | )
19 |
20 |
21 | def remove_file(files, file_name):
22 | try:
23 | files.remove(file_name)
24 | except Exception:
25 | pass
26 |
27 |
28 | class XInferenceDataset(Dataset):
29 | def __init__(self, root_X, transform=None, return_anchor=False, thumbnail=None):
30 | self.root_X = root_X
31 | self.transform = transform
32 | self.return_anchor = return_anchor
33 | self.thumbnail = thumbnail
34 |
35 | self.X_images = os.listdir(root_X)
36 |
37 | remove_file(self.X_images, "thumbnail.png")
38 | remove_file(self.X_images, "blank_patches_list.csv")
39 |
40 | if self.return_anchor:
41 | self.__get_boundary()
42 |
43 | self.length_dataset = len(self.X_images)
44 |
45 | def __get_boundary(self):
46 | self.y_anchor_num = 0
47 | self.x_anchor_num = 0
48 | for X_image in self.X_images:
49 | y_idx, x_idx, _, _ = Path(X_image).stem.split("_")[:4]
50 | y_idx = int(y_idx)
51 | x_idx = int(x_idx)
52 | self.y_anchor_num = max(self.y_anchor_num, y_idx)
53 | self.x_anchor_num = max(self.x_anchor_num, x_idx)
54 |
55 | def get_boundary(self):
56 | assert self.return_anchor == True
57 | return (self.y_anchor_num, self.x_anchor_num)
58 |
59 | def __len__(self):
60 | return self.length_dataset
61 |
62 | def __getitem__(self, index):
63 | X_img_name = self.X_images[index]
64 |
65 | X_path = os.path.join(self.root_X, X_img_name)
66 |
67 | X_img = np.array(Image.open(X_path).convert("RGB"))
68 |
69 | if self.transform:
70 | augmentations = self.transform(image=X_img)
71 | X_img = augmentations["image"]
72 |
73 | if self.return_anchor:
74 | y_idx, x_idx, y_anchor, x_anchor = Path(X_img_name).stem.split("_")[:4]
75 | y_idx = int(y_idx)
76 | x_idx = int(x_idx)
77 | return {
78 | "X_img": X_img,
79 | "X_path": X_path,
80 | "y_idx": y_idx,
81 | "x_idx": x_idx,
82 | "y_anchor": y_anchor,
83 | "x_anchor": x_anchor,
84 | }
85 |
86 | else:
87 | return {"X_img": X_img, "X_path": X_path}
88 |
89 | def get_thumbnail(self):
90 | thumbnail_img = np.array(Image.open(self.thumbnail).convert("RGB"))
91 | if self.transform:
92 | augmentations = self.transform(image=thumbnail_img)
93 | thumbnail_img = augmentations["image"]
94 | return thumbnail_img.unsqueeze(0)
95 |
--------------------------------------------------------------------------------
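A hedged usage sketch for XInferenceDataset above; the patch directory and thumbnail path are placeholders, and `return_anchor=True` assumes patches named `{y_idx}_{x_idx}_{y_anchor}_{x_anchor}_*.png`, as parsed in `__getitem__`.

```
# Hedged usage sketch; paths are placeholders.
from torch.utils.data import DataLoader

dataset = XInferenceDataset(
    root_X="./data/example/patches",            # directory of cropped patches (placeholder)
    transform=test_transforms,                  # the 512x512 resize/normalize pipeline defined above
    return_anchor=True,
    thumbnail="./data/example/thumbnail.png",   # low-resolution overview image (placeholder)
)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
y_anchor_max, x_anchor_max = dataset.get_boundary()
for batch in loader:
    print(batch["X_path"], batch["y_anchor"], batch["x_anchor"])
    break
```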
/F-LSeSim/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import os
8 |
9 | import torch.utils.data as data
10 | from PIL import Image
11 |
12 | IMG_EXTENSIONS = [
13 | ".jpg",
14 | ".JPG",
15 | ".jpeg",
16 | ".JPEG",
17 | ".png",
18 | ".PNG",
19 | ".ppm",
20 | ".PPM",
21 | ".bmp",
22 | ".BMP",
23 | ".tif",
24 | ".TIF",
25 | ".tiff",
26 | ".TIFF",
27 | ]
28 |
29 |
30 | def is_image_file(filename):
31 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
32 |
33 |
34 | def make_dataset(dir, max_dataset_size=float("inf")):
35 | images = []
36 | assert os.path.isdir(dir), "%s is not a valid directory" % dir
37 |
38 | for root, _, fnames in sorted(os.walk(dir)):
39 | for fname in fnames:
40 | if is_image_file(fname):
41 | path = os.path.join(root, fname)
42 | images.append(path)
43 | return images[: min(max_dataset_size, len(images))]
44 |
45 |
46 | def default_loader(path):
47 | return Image.open(path).convert("RGB")
48 |
49 |
50 | class ImageFolder(data.Dataset):
51 | def __init__(self, root, transform=None, return_paths=False, loader=default_loader):
52 | imgs = make_dataset(root)
53 | if len(imgs) == 0:
54 | raise (
55 | RuntimeError(
56 | "Found 0 images in: " + root + "\n"
57 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)
58 | )
59 | )
60 |
61 | self.root = root
62 | self.imgs = imgs
63 | self.transform = transform
64 | self.return_paths = return_paths
65 | self.loader = loader
66 |
67 | def __getitem__(self, index):
68 | path = self.imgs[index]
69 | img = self.loader(path)
70 | if self.transform is not None:
71 | img = self.transform(img)
72 | if self.return_paths:
73 | return img, path
74 | else:
75 | return img
76 |
77 | def __len__(self):
78 | return len(self.imgs)
79 |
--------------------------------------------------------------------------------
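A hedged usage sketch for the ImageFolder class above; the dataset path and transform are placeholders.

```
# Hedged usage sketch; the root path is a placeholder.
import torchvision.transforms as transforms

transform = transforms.Compose([transforms.Resize((256, 256)), transforms.ToTensor()])
folder = ImageFolder(root="./datasets/horse2zebra/trainA", transform=transform, return_paths=True)
print("found", len(folder), "images")
img, path = folder[0]   # first image tensor and its file path
```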
/F-LSeSim/data/single_dataset.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | from data.base_dataset import BaseDataset, get_transform
4 | from data.image_folder import make_dataset
5 |
6 |
7 | class SingleDataset(BaseDataset):
8 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data.
9 |
10 | It can be used for generating CycleGAN results only for one side with the model option '-model test'.
11 | """
12 |
13 | def __init__(self, opt):
14 | """Initialize this dataset class.
15 |
16 | Parameters:
17 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
18 | """
19 | BaseDataset.__init__(self, opt)
20 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size))
21 | input_nc = (
22 | self.opt.output_nc if self.opt.direction == "BtoA" else self.opt.input_nc
23 | )
24 | self.transform = get_transform(opt, grayscale=(input_nc == 1))
25 |
26 | def __getitem__(self, index):
27 | """Return a data point and its metadata information.
28 |
29 | Parameters:
30 | index - - a random integer for data indexing
31 |
32 | Returns a dictionary that contains A and A_paths
33 | A(tensor) - - an image in one domain
34 | A_paths(str) - - the path of the image
35 | """
36 | A_path = self.A_paths[index]
37 | A_img = Image.open(A_path).convert("RGB")
38 | A = self.transform(A_img)
39 | return {"A": A, "A_paths": A_path}
40 |
41 | def __len__(self):
42 | """Return the total number of images in the dataset."""
43 | return len(self.A_paths)
44 |
--------------------------------------------------------------------------------
/F-LSeSim/data/singleimage_dataset.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import random
3 |
4 | import numpy as np
5 | from PIL import Image
6 |
7 | from data.base_dataset import BaseDataset, get_transform
8 | from data.image_folder import make_dataset
9 |
10 |
11 | class SingleImageDataset(BaseDataset):
12 | """
13 | This dataset class can load unaligned/unpaired datasets.
14 |
15 | It requires two directories to host training images from domain A '/path/to/data/trainA'
16 | and from domain B '/path/to/data/trainB' respectively.
17 | You can train the model with the dataset flag '--dataroot /path/to/data'.
18 | Similarly, you need to prepare two directories:
19 | '/path/to/data/testA' and '/path/to/data/testB' during test time.
20 | """
21 |
22 | def __init__(self, opt):
23 | """Initialize this dataset class.
24 |
25 | Parameters:
26 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
27 | """
28 | BaseDataset.__init__(self, opt)
29 |
30 | self.dir_A = os.path.join(
31 | opt.dataroot, "trainA"
32 | ) # create a path '/path/to/data/trainA'
33 | self.dir_B = os.path.join(
34 | opt.dataroot, "trainB"
35 | ) # create a path '/path/to/data/trainB'
36 |
37 | if os.path.exists(self.dir_A) and os.path.exists(self.dir_B):
38 | self.A_paths = sorted(
39 | make_dataset(self.dir_A, opt.max_dataset_size)
40 | ) # load images from '/path/to/data/trainA'
41 | self.B_paths = sorted(
42 | make_dataset(self.dir_B, opt.max_dataset_size)
43 | ) # load images from '/path/to/data/trainB'
44 | self.A_size = len(self.A_paths) # get the size of dataset A
45 | self.B_size = len(self.B_paths) # get the size of dataset B
46 |
47 | assert (
48 | len(self.A_paths) == 1 and len(self.B_paths) == 1
49 | ), "SingleImageDataset class should be used with one image in each domain"
50 | A_img = Image.open(self.A_paths[0]).convert("RGB")
51 | B_img = Image.open(self.B_paths[0]).convert("RGB")
52 | print("Image sizes %s and %s" % (str(A_img.size), str(B_img.size)))
53 |
54 | self.A_img = A_img
55 | self.B_img = B_img
56 |
57 | # In single-image translation, we augment the data loader by applying
58 | # random scaling. Still, we design the data loader such that the
59 | # amount of scaling is the same within a minibatch. To do this,
60 | # we precompute the random scaling values, and repeat them by |batch_size|.
61 | A_zoom = 1 / self.opt.random_scale_max
62 | zoom_levels_A = np.random.uniform(
63 | A_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2)
64 | )
65 | self.zoom_levels_A = np.reshape(
66 | np.tile(zoom_levels_A, (1, opt.batch_size, 1)), [-1, 2]
67 | )
68 |
69 | B_zoom = 1 / self.opt.random_scale_max
70 | zoom_levels_B = np.random.uniform(
71 | B_zoom, 1.0, size=(len(self) // opt.batch_size + 1, 1, 2)
72 | )
73 | self.zoom_levels_B = np.reshape(
74 | np.tile(zoom_levels_B, (1, opt.batch_size, 1)), [-1, 2]
75 | )
76 |
77 | # While the crop locations are randomized, the negative samples should
78 | # not come from the same location. To do this, we precompute the
79 | # crop locations with no repetition.
80 | self.patch_indices_A = list(range(len(self)))
81 | random.shuffle(self.patch_indices_A)
82 | self.patch_indices_B = list(range(len(self)))
83 | random.shuffle(self.patch_indices_B)
84 |
85 | def __getitem__(self, index):
86 | """Return a data point and its metadata information.
87 |
88 | Parameters:
89 | index (int) -- a random integer for data indexing
90 |
91 | Returns a dictionary that contains A, B, A_paths and B_paths
92 | A (tensor) -- an image in the input domain
93 | B (tensor) -- its corresponding image in the target domain
94 | A_paths (str) -- image paths
95 | B_paths (str) -- image paths
96 | """
97 | A_path = self.A_paths[0]
98 | B_path = self.B_paths[0]
99 | A_img = self.A_img
100 | B_img = self.B_img
101 |
102 | # apply image transformation
103 | if self.opt.phase == "train":
104 | param = {
105 | "scale_factor": self.zoom_levels_A[index],
106 | "patch_index": self.patch_indices_A[index],
107 | "flip": random.random() > 0.5,
108 | }
109 |
110 | transform_A = get_transform(self.opt, params=param, method=Image.BILINEAR)
111 | A = transform_A(A_img)
112 |
113 | param = {
114 | "scale_factor": self.zoom_levels_B[index],
115 | "patch_index": self.patch_indices_B[index],
116 | "flip": random.random() > 0.5,
117 | }
118 | transform_B = get_transform(self.opt, params=param, method=Image.BILINEAR)
119 | B = transform_B(B_img)
120 | else:
121 | transform = get_transform(self.opt, method=Image.BILINEAR)
122 | A = transform(A_img)
123 | B = transform(B_img)
124 |
125 | return {"A": A, "B": B, "A_paths": A_path, "B_paths": B_path}
126 |
127 | def __len__(self):
128 | """Let's pretend the single image contains 100,000 crops for convenience."""
129 | return 100000
130 |
--------------------------------------------------------------------------------
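A hedged illustration of the zoom-level bookkeeping in `__init__` above: each randomly drawn scale is repeated `batch_size` times before being flattened, so every sample in a minibatch is scaled by the same factor.

```
# Hedged illustration of the precomputed zoom levels (numbers are arbitrary).
import numpy as np

batch_size, n_batches = 4, 3
zoom = np.random.uniform(0.25, 1.0, size=(n_batches, 1, 2))              # one (sy, sx) pair per minibatch
zoom_levels = np.reshape(np.tile(zoom, (1, batch_size, 1)), [-1, 2])     # repeated for each sample
assert np.array_equal(zoom_levels[0], zoom_levels[batch_size - 1])       # identical within a minibatch
```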
/F-LSeSim/data/template_dataset.py:
--------------------------------------------------------------------------------
1 | """Dataset class template
2 |
3 | This module provides a template for users to implement custom datasets.
4 | You can specify '--dataset_mode template' to use this dataset.
5 | The class name should be consistent with both the filename and its dataset_mode option.
6 | The filename should be <dataset_mode>_dataset.py
7 | The class name should be <DatasetMode>Dataset.py
8 | You need to implement the following functions:
9 | -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
10 | -- <__init__>: Initialize this dataset class.
11 | -- <__getitem__>: Return a data point and its metadata information.
12 | -- <__len__>: Return the number of images.
13 | """
14 | from data.base_dataset import BaseDataset, get_transform
15 |
16 | # from data.image_folder import make_dataset
17 | # from PIL import Image
18 |
19 |
20 | class TemplateDataset(BaseDataset):
21 | """A template dataset class for you to implement custom datasets."""
22 |
23 | @staticmethod
24 | def modify_commandline_options(parser, is_train):
25 | """Add new dataset-specific options, and rewrite default values for existing options.
26 |
27 | Parameters:
28 | parser -- original option parser
29 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
30 |
31 | Returns:
32 | the modified parser.
33 | """
34 | parser.add_argument(
35 | "--new_dataset_option", type=float, default=1.0, help="new dataset option"
36 | )
37 | parser.set_defaults(
38 | max_dataset_size=10, new_dataset_option=2.0
39 | ) # specify dataset-specific default values
40 | return parser
41 |
42 | def __init__(self, opt):
43 | """Initialize this dataset class.
44 |
45 | Parameters:
46 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
47 |
48 | A few things can be done here.
49 | - save the options (have been done in BaseDataset)
50 | - get image paths and meta information of the dataset.
51 | - define the image transformation.
52 | """
53 | # save the option and dataset root
54 | BaseDataset.__init__(self, opt)
55 | # get the image paths of your dataset;
56 | self.image_paths = (
57 | []
58 | ) # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
59 | # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
60 | self.transform = get_transform(opt)
61 |
62 | def __getitem__(self, index):
63 | """Return a data point and its metadata information.
64 |
65 | Parameters:
66 | index -- a random integer for data indexing
67 |
68 | Returns:
69 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
70 |
71 | Step 1: get a random image path: e.g., path = self.image_paths[index]
72 | Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
73 | Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
74 | Step 4: return a data point as a dictionary.
75 | """
76 | path = "temp" # needs to be a string
77 | data_A = None # needs to be a tensor
78 | data_B = None # needs to be a tensor
79 | return {"data_A": data_A, "data_B": data_B, "path": path}
80 |
81 | def __len__(self):
82 | """Return the total number of images."""
83 | return len(self.image_paths)
84 |
--------------------------------------------------------------------------------
/F-LSeSim/data/unaligned_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 |
4 | import torchvision.transforms as transforms
5 | from PIL import Image
6 |
7 | import util.util as util
8 | from data.base_dataset import BaseDataset, get_transform
9 | from data.image_folder import make_dataset
10 |
11 |
12 | def remove_file(files, file_name):
13 | try:
14 | files.remove(file_name)
15 | except Exception:
16 | pass
17 |
18 |
19 | class UnalignedDataset(BaseDataset):
20 | """
21 | This dataset class can load unaligned/unpaired datasets.
22 |
23 | It requires two directories to host training images from domain A '/path/to/data/trainA'
24 | and from domain B '/path/to/data/trainB' respectively.
25 | You can train the model with the dataset flag '--dataroot /path/to/data'.
26 | Similarly, you need to prepare two directories:
27 | '/path/to/data/testA' and '/path/to/data/testB' during test time.
28 | """
29 |
30 | def __init__(self, opt):
31 | """Initialize this dataset class.
32 |
33 | Parameters:
34 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
35 | """
36 | BaseDataset.__init__(self, opt)
37 | self.dir_A = os.path.join(
38 | opt.dataroot, opt.phase + "X"
39 | ) # create a path '/path/to/data/trainX'
40 | self.dir_B = os.path.join(
41 | opt.dataroot, opt.phase + "Y"
42 | ) # create a path '/path/to/data/trainY'
43 |
44 | self.A_paths = sorted(
45 | make_dataset(self.dir_A, opt.max_dataset_size)
46 | ) # load images from '/path/to/data/trainA'
47 | self.B_paths = sorted(
48 | make_dataset(self.dir_B, opt.max_dataset_size)
49 | ) # load images from '/path/to/data/trainB'
50 | self.A_size = len(self.A_paths) # get the size of dataset A
51 | self.B_size = len(self.B_paths) # get the size of dataset B
52 | # apply image transformation
53 | # For FastCUT mode, if in finetuning phase (learning rate is decaying)
54 | # do not perform resize-crop data augmentation of CycleGAN.
55 | # print('current_epoch', self.current_epoch)
56 | self.transform = get_transform(opt, convert=False)
57 | if self.opt.isTrain and opt.augment:
58 | self.transform_aug = transforms.Compose(
59 | [
60 | transforms.ColorJitter(
61 | brightness=0.5, contrast=0.5, saturation=0.5, hue=0.3
62 | ),
63 | transforms.ToTensor(),
64 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
65 | ]
66 | )
67 | else:
68 | self.transform_aug = None
69 | self.transform_tensor = transforms.Compose(
70 | [
71 | transforms.ToTensor(),
72 | transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
73 | ]
74 | )
75 |
76 | def __getitem__(self, index):
77 | """Return a data point and its metadata information.
78 |
79 | Parameters:
80 | index (int) -- a random integer for data indexing
81 |
82 | Returns a dictionary that contains A, B, A_paths and B_paths
83 | A (tensor) -- an image in the input domain
84 | B (tensor) -- its corresponding image in the target domain
85 | A_paths (str) -- image paths
86 | B_paths (str) -- image paths
87 | """
88 | A_path = self.A_paths[
89 | index % self.A_size
90 | ] # make sure index is within the range
91 | if self.opt.serial_batches: # make sure index is within the range
92 | index_B = index % self.B_size
93 | else: # randomize the index for domain B to avoid fixed pairs.
94 | index_B = random.randint(0, self.B_size - 1)
95 | B_path = self.B_paths[index_B]
96 | A_img = Image.open(A_path).convert("RGB")
97 | B_img = Image.open(B_path).convert("RGB")
98 | A_pil = self.transform(A_img)
99 | B_pil = self.transform(B_img)
100 | A = self.transform_tensor(A_pil)
101 | B = self.transform_tensor(B_pil)
102 | if self.opt.isTrain and self.transform_aug is not None:
103 | A_aug = self.transform_aug(A_pil)
104 | B_aug = self.transform_aug(B_pil)
105 | return {
106 | "A": A,
107 | "B": B,
108 | "A_paths": A_path,
109 | "B_paths": B_path,
110 | "A_aug": A_aug,
111 | "B_aug": B_aug,
112 | }
113 | else:
114 | return {"A": A, "B": B, "A_paths": A_path, "B_paths": B_path}
115 |
116 | def __len__(self):
117 | """Return the total number of images in the dataset.
118 |
119 | As we have two datasets with potentially different numbers of images,
120 | we take the maximum of the two.
121 | """
122 | return max(self.A_size, self.B_size)
123 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/bibtex/cityscapes.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{Cordts2016Cityscapes,
2 | title={The Cityscapes Dataset for Semantic Urban Scene Understanding},
3 | author={Cordts, Marius and Omran, Mohamed and Ramos, Sebastian and Rehfeld, Timo and Enzweiler, Markus and Benenson, Rodrigo and Franke, Uwe and Roth, Stefan and Schiele, Bernt},
4 | booktitle={Proc. of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
5 | year={2016}
6 | }
7 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/bibtex/facades.tex:
--------------------------------------------------------------------------------
1 | @INPROCEEDINGS{Tylecek13,
2 | author = {Radim Tyle{\v c}ek and Radim {\v S}{\' a}ra},
3 | title = {Spatial Pattern Templates for Recognition of Objects with Regular Structure},
4 | booktitle = {Proc. GCPR},
5 | year = {2013},
6 | address = {Saarbrucken, Germany},
7 | }
8 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/bibtex/handbags.tex:
--------------------------------------------------------------------------------
1 | @inproceedings{zhu2016generative,
2 | title={Generative Visual Manipulation on the Natural Image Manifold},
3 | author={Zhu, Jun-Yan and Kr{\"a}henb{\"u}hl, Philipp and Shechtman, Eli and Efros, Alexei A.},
4 | booktitle={Proceedings of European Conference on Computer Vision (ECCV)},
5 | year={2016}
6 | }
7 |
8 | @InProceedings{xie15hed,
9 | author = {"Xie, Saining and Tu, Zhuowen"},
10 | Title = {Holistically-Nested Edge Detection},
11 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
12 | Year = {2015},
13 | }
14 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/bibtex/shoes.tex:
--------------------------------------------------------------------------------
1 | @InProceedings{fine-grained,
2 | author = {A. Yu and K. Grauman},
3 | title = {{F}ine-{G}rained {V}isual {C}omparisons with {L}ocal {L}earning},
4 | booktitle = {Computer Vision and Pattern Recognition (CVPR)},
5 | month = {June},
6 | year = {2014}
7 | }
8 |
9 | @InProceedings{xie15hed,
10 | author = {"Xie, Saining and Tu, Zhuowen"},
11 | Title = {Holistically-Nested Edge Detection},
12 | Booktitle = "Proceedings of IEEE International Conference on Computer Vision",
13 | Year = {2015},
14 | }
15 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/bibtex/transattr.tex:
--------------------------------------------------------------------------------
1 | @article {Laffont14,
2 | title = {Transient Attributes for High-Level Understanding and Editing of Outdoor Scenes},
3 | author = {Pierre-Yves Laffont and Zhile Ren and Xiaofeng Tao and Chao Qian and James Hays},
4 | journal = {ACM Transactions on Graphics (proceedings of SIGGRAPH)},
5 | volume = {33},
6 | number = {4},
7 | year = {2014}
8 | }
9 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/combine_A_and_B.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from multiprocessing import Pool
4 |
5 | import cv2
6 | import numpy as np
7 |
8 |
9 | def image_write(path_A, path_B, path_AB):
10 | im_A = cv2.imread(
11 | path_A, 1
12 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
13 | im_B = cv2.imread(
14 | path_B, 1
15 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
16 | im_AB = np.concatenate([im_A, im_B], 1)
17 | cv2.imwrite(path_AB, im_AB)
18 |
19 |
20 | parser = argparse.ArgumentParser("create image pairs")
21 | parser.add_argument(
22 | "--fold_A",
23 | dest="fold_A",
24 | help="input directory for image A",
25 | type=str,
26 | default="../dataset/50kshoes_edges",
27 | )
28 | parser.add_argument(
29 | "--fold_B",
30 | dest="fold_B",
31 | help="input directory for image B",
32 | type=str,
33 | default="../dataset/50kshoes_jpg",
34 | )
35 | parser.add_argument(
36 | "--fold_AB",
37 | dest="fold_AB",
38 | help="output directory",
39 | type=str,
40 | default="../dataset/test_AB",
41 | )
42 | parser.add_argument(
43 | "--num_imgs", dest="num_imgs", help="number of images", type=int, default=1000000
44 | )
45 | parser.add_argument(
46 | "--use_AB",
47 | dest="use_AB",
48 | help="if true: (0001_A, 0001_B) to (0001_AB)",
49 | action="store_true",
50 | )
51 | parser.add_argument(
52 | "--no_multiprocessing",
53 | dest="no_multiprocessing",
54 | help="If used, chooses single CPU execution instead of parallel execution",
55 | action="store_true",
56 | default=False,
57 | )
58 | args = parser.parse_args()
59 |
60 | for arg in vars(args):
61 | print("[%s] = " % arg, getattr(args, arg))
62 |
63 | splits = os.listdir(args.fold_A)
64 |
65 | if not args.no_multiprocessing:
66 | pool = Pool()
67 |
68 | for sp in splits:
69 | img_fold_A = os.path.join(args.fold_A, sp)
70 | img_fold_B = os.path.join(args.fold_B, sp)
71 | img_list = os.listdir(img_fold_A)
72 | if args.use_AB:
73 | img_list = [img_path for img_path in img_list if "_A." in img_path]
74 |
75 | num_imgs = min(args.num_imgs, len(img_list))
76 | print("split = %s, use %d/%d images" % (sp, num_imgs, len(img_list)))
77 | img_fold_AB = os.path.join(args.fold_AB, sp)
78 | if not os.path.isdir(img_fold_AB):
79 | os.makedirs(img_fold_AB)
80 | print("split = %s, number of images = %d" % (sp, num_imgs))
81 | for n in range(num_imgs):
82 | name_A = img_list[n]
83 | path_A = os.path.join(img_fold_A, name_A)
84 | if args.use_AB:
85 | name_B = name_A.replace("_A.", "_B.")
86 | else:
87 | name_B = name_A
88 | path_B = os.path.join(img_fold_B, name_B)
89 | if os.path.isfile(path_A) and os.path.isfile(path_B):
90 | name_AB = name_A
91 | if args.use_AB:
92 | name_AB = name_AB.replace("_A.", ".") # remove _A
93 | path_AB = os.path.join(img_fold_AB, name_AB)
94 | if not args.no_multiprocessing:
95 | pool.apply_async(image_write, args=(path_A, path_B, path_AB))
96 | else:
97 | im_A = cv2.imread(
98 | path_A, 1
99 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
100 | im_B = cv2.imread(
101 | path_B, 1
102 | ) # python2: cv2.CV_LOAD_IMAGE_COLOR; python3: cv2.IMREAD_COLOR
103 | im_AB = np.concatenate([im_A, im_B], 1)
104 | cv2.imwrite(path_AB, im_AB)
105 | if not args.no_multiprocessing:
106 | pool.close()
107 | pool.join()
108 |
--------------------------------------------------------------------------------
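A hedged illustration of the `--use_AB` naming convention handled above: an `_A` file in `fold_A` pairs with the matching `_B` file in `fold_B`, and the concatenated result drops the suffix.

```
# Hedged illustration of the --use_AB file naming (filenames are examples only).
name_A = "0001_A.jpg"
name_B = name_A.replace("_A.", "_B.")   # "0001_B.jpg", looked up in fold_B
name_AB = name_A.replace("_A.", ".")    # "0001.jpg", written to fold_AB
assert (name_B, name_AB) == ("0001_B.jpg", "0001.jpg")
```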
/F-LSeSim/datasets/download_cyclegan_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "mini" && $FILE != "mini_pix2pix" && $FILE != "mini_colorization" ]]; then
4 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
5 | exit 1
6 | fi
7 |
8 | if [[ $FILE == "cityscapes" ]]; then
9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
11 | exit 1
12 | fi
13 |
14 | echo "Specified [$FILE]"
15 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
16 | ZIP_FILE=./datasets/$FILE.zip
17 | TARGET_DIR=./datasets/$FILE/
18 | wget -N $URL -O $ZIP_FILE
19 | mkdir $TARGET_DIR
20 | unzip $ZIP_FILE -d ./datasets/
21 | rm $ZIP_FILE
22 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/download_pix2pix_dataset.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | if [[ $FILE != "cityscapes" && $FILE != "night2day" && $FILE != "edges2handbags" && $FILE != "edges2shoes" && $FILE != "facades" && $FILE != "maps" ]]; then
4 | echo "Available datasets are cityscapes, night2day, edges2handbags, edges2shoes, facades, maps"
5 | exit 1
6 | fi
7 |
8 | if [[ $FILE == "cityscapes" ]]; then
9 | echo "Due to license issue, we cannot provide the Cityscapes dataset from our repository. Please download the Cityscapes dataset from https://cityscapes-dataset.com, and use the script ./datasets/prepare_cityscapes_dataset.py."
10 | echo "You need to download gtFine_trainvaltest.zip and leftImg8bit_trainvaltest.zip. For further instruction, please read ./datasets/prepare_cityscapes_dataset.py"
11 | exit 1
12 | fi
13 |
14 | echo "Specified [$FILE]"
15 |
16 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/$FILE.tar.gz
17 | TAR_FILE=./datasets/$FILE.tar.gz
18 | TARGET_DIR=./datasets/$FILE/
19 | wget -N $URL -O $TAR_FILE
20 | mkdir -p $TARGET_DIR
21 | tar -zxvf $TAR_FILE -C ./datasets/
22 | rm $TAR_FILE
23 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/make_dataset_aligned.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from PIL import Image
4 |
5 |
6 | def get_file_paths(folder):
7 | image_file_paths = []
8 | for root, dirs, filenames in os.walk(folder):
9 | filenames = sorted(filenames)
10 | for filename in filenames:
11 | input_path = os.path.abspath(root)
12 | file_path = os.path.join(input_path, filename)
13 | if filename.endswith(".png") or filename.endswith(".jpg"):
14 | image_file_paths.append(file_path)
15 |
16 | break # prevent descending into subfolders
17 | return image_file_paths
18 |
19 |
20 | def align_images(a_file_paths, b_file_paths, target_path):
21 | if not os.path.exists(target_path):
22 | os.makedirs(target_path)
23 |
24 | for i in range(len(a_file_paths)):
25 | img_a = Image.open(a_file_paths[i])
26 | img_b = Image.open(b_file_paths[i])
27 | assert img_a.size == img_b.size
28 |
29 | aligned_image = Image.new("RGB", (img_a.size[0] * 2, img_a.size[1]))
30 | aligned_image.paste(img_a, (0, 0))
31 | aligned_image.paste(img_b, (img_a.size[0], 0))
32 | aligned_image.save(os.path.join(target_path, "{:04d}.jpg".format(i)))
33 |
34 |
35 | if __name__ == "__main__":
36 | import argparse
37 |
38 | parser = argparse.ArgumentParser()
39 | parser.add_argument(
40 | "--dataset-path",
41 | dest="dataset_path",
42 | help="Which folder to process (it should have subfolders testA, testB, trainA and trainB)",
43 | )
44 | args = parser.parse_args()
45 |
46 | dataset_folder = args.dataset_path
47 | print(dataset_folder)
48 |
49 | test_a_path = os.path.join(dataset_folder, "testA")
50 | test_b_path = os.path.join(dataset_folder, "testB")
51 | test_a_file_paths = get_file_paths(test_a_path)
52 | test_b_file_paths = get_file_paths(test_b_path)
53 | assert len(test_a_file_paths) == len(test_b_file_paths)
54 | test_path = os.path.join(dataset_folder, "test")
55 |
56 | train_a_path = os.path.join(dataset_folder, "trainA")
57 | train_b_path = os.path.join(dataset_folder, "trainB")
58 | train_a_file_paths = get_file_paths(train_a_path)
59 | train_b_file_paths = get_file_paths(train_b_path)
60 | assert len(train_a_file_paths) == len(train_b_file_paths)
61 | train_path = os.path.join(dataset_folder, "train")
62 |
63 | align_images(test_a_file_paths, test_b_file_paths, test_path)
64 | align_images(train_a_file_paths, train_b_file_paths, train_path)
65 |
--------------------------------------------------------------------------------
/F-LSeSim/datasets/prepare_cityscapes_dataset.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 |
4 | from PIL import Image
5 |
6 | help_msg = """
7 | The dataset can be downloaded from https://cityscapes-dataset.com.
8 | Please download the datasets [gtFine_trainvaltest.zip] and [leftImg8bit_trainvaltest.zip] and unzip them.
9 | gtFine contains the semantic segmentations. Use --gtFine_dir to specify the path to the unzipped gtFine_trainvaltest directory.
10 | leftImg8bit contains the dashcam photographs. Use --leftImg8bit_dir to specify the path to the unzipped leftImg8bit_trainvaltest directory.
11 | The processed images will be placed at --output_dir.
12 |
13 | Example usage:
14 |
15 | python prepare_cityscapes_dataset.py --gtFine_dir ./gtFine/ --leftImg8bit_dir ./leftImg8bit --output_dir ./datasets/cityscapes/
16 | """
17 |
18 |
19 | def load_resized_img(path):
20 | return Image.open(path).convert("RGB").resize((256, 256))
21 |
22 |
23 | def check_matching_pair(segmap_path, photo_path):
24 | segmap_identifier = os.path.basename(segmap_path).replace("_gtFine_color", "")
25 | photo_identifier = os.path.basename(photo_path).replace("_leftImg8bit", "")
26 |
27 | assert (
28 | segmap_identifier == photo_identifier
29 | ), "[%s] and [%s] don't seem to be matching. Aborting." % (segmap_path, photo_path)
30 |
31 |
32 | def process_cityscapes(gtFine_dir, leftImg8bit_dir, output_dir, phase):
33 | save_phase = "test" if phase == "val" else "train"
34 | savedir = os.path.join(output_dir, save_phase)
35 | os.makedirs(savedir, exist_ok=True)
36 | os.makedirs(savedir + "A", exist_ok=True)
37 | os.makedirs(savedir + "B", exist_ok=True)
38 | print("Directory structure prepared at %s" % output_dir)
39 |
40 | segmap_expr = os.path.join(gtFine_dir, phase) + "/*/*_color.png"
41 | segmap_paths = glob.glob(segmap_expr)
42 | segmap_paths = sorted(segmap_paths)
43 |
44 | photo_expr = os.path.join(leftImg8bit_dir, phase) + "/*/*_leftImg8bit.png"
45 | photo_paths = glob.glob(photo_expr)
46 | photo_paths = sorted(photo_paths)
47 |
48 | assert len(segmap_paths) == len(
49 | photo_paths
50 | ), "%d images that match [%s], and %d images that match [%s]. Aborting." % (
51 | len(segmap_paths),
52 | segmap_expr,
53 | len(photo_paths),
54 | photo_expr,
55 | )
56 |
57 | for i, (segmap_path, photo_path) in enumerate(zip(segmap_paths, photo_paths)):
58 | check_matching_pair(segmap_path, photo_path)
59 | segmap = load_resized_img(segmap_path)
60 | photo = load_resized_img(photo_path)
61 |
62 | # data for pix2pix where the two images are placed side-by-side
63 | sidebyside = Image.new("RGB", (512, 256))
64 | sidebyside.paste(segmap, (256, 0))
65 | sidebyside.paste(photo, (0, 0))
66 | savepath = os.path.join(savedir, "%d.jpg" % i)
67 | sidebyside.save(savepath, format="JPEG", subsampling=0, quality=100)
68 |
69 | # data for cyclegan where the two images are stored at two distinct directories
70 | savepath = os.path.join(savedir + "A", "%d_A.jpg" % i)
71 | photo.save(savepath, format="JPEG", subsampling=0, quality=100)
72 | savepath = os.path.join(savedir + "B", "%d_B.jpg" % i)
73 | segmap.save(savepath, format="JPEG", subsampling=0, quality=100)
74 |
75 | if i % (len(segmap_paths) // 10) == 0:
76 | print(
77 | "%d / %d: last image saved at %s, " % (i, len(segmap_paths), savepath)
78 | )
79 |
80 |
81 | if __name__ == "__main__":
82 | import argparse
83 |
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument(
86 | "--gtFine_dir",
87 | type=str,
88 | required=True,
89 | help="Path to the Cityscapes gtFine directory.",
90 | )
91 | parser.add_argument(
92 | "--leftImg8bit_dir",
93 | type=str,
94 | required=True,
95 | help="Path to the Cityscapes leftImg8bit_trainvaltest directory.",
96 | )
97 | parser.add_argument(
98 | "--output_dir",
99 | type=str,
100 | required=True,
101 | default="./datasets/cityscapes",
102 | help="Directory the output images will be written to.",
103 | )
104 | opt = parser.parse_args()
105 |
106 | print(help_msg)
107 |
108 | print("Preparing Cityscapes Dataset for val phase")
109 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "val")
110 | print("Preparing Cityscapes Dataset for train phase")
111 | process_cityscapes(opt.gtFine_dir, opt.leftImg8bit_dir, opt.output_dir, "train")
112 |
113 | print("Done")
114 |
--------------------------------------------------------------------------------
/F-LSeSim/evaluations/DC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pathlib
3 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
4 |
5 | import numpy as np
6 | import torch
7 | from PIL import Image
8 | from scipy import linalg
9 | from torch.nn.functional import adaptive_avg_pool2d
10 |
11 | try:
12 | from tqdm import tqdm
13 | except ImportError:
14 |     # If tqdm is not available, provide a mock version of it
15 | def tqdm(x):
16 | return x
17 |
18 |
19 | from prdc import compute_prdc
20 |
21 | from evaluations.inception import InceptionV3
22 |
23 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
24 | parser.add_argument("--batch-size", type=int, default=100, help="Batch size to use")
25 | parser.add_argument(
26 | "--dims",
27 | type=int,
28 | default=2048,
29 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
30 | help=(
31 | "Dimensionality of Inception features to use. "
32 | "By default, uses pool3 features"
33 | ),
34 | )
35 | parser.add_argument(
36 | "-c", "--gpu", default="", type=str, help="GPU to use (leave blank for CPU only)"
37 | )
38 | parser.add_argument(
39 | "path",
40 | type=str,
41 | nargs=2,
42 | help=("Paths to the generated images or " "to .npz statistic files"),
43 | )
44 |
45 |
46 | def imread(filename):
47 | """
48 |     Loads an image file into a (height, width, 3) uint8 ndarray.
49 | """
50 | return np.asarray(Image.open(filename), dtype=np.uint8)[..., :3]
51 |
52 |
53 | def get_activations(files, model, batch_size=50, dims=2048, cuda=False):
54 | """Calculates the activations of the pool_3 layer for all images.
55 | Params:
56 | -- files : List of image files paths
57 | -- model : Instance of inception model
58 | -- batch_size : Batch size of images for the model to process at once.
59 | Make sure that the number of samples is a multiple of
60 | the batch size, otherwise some samples are ignored. This
61 | behavior is retained to match the original FID score
62 | implementation.
63 | -- dims : Dimensionality of features returned by Inception
64 | -- cuda : If set to True, use GPU
65 | Returns:
66 | -- A numpy array of dimension (num images, dims) that contains the
67 | activations of the given tensor when feeding inception with the
68 | query tensor.
69 | """
70 | model.eval()
71 |
72 | if batch_size > len(files):
73 | print(
74 | (
75 | "Warning: batch size is bigger than the data size. "
76 | "Setting batch size to data size"
77 | )
78 | )
79 | batch_size = len(files)
80 |
81 | pred_arr = np.empty((len(files), dims))
82 |
83 | for i in tqdm(range(0, len(files), batch_size)):
84 | start = i
85 | end = i + batch_size
86 |
87 | images = np.array([imread(str(f)).astype(np.float32) for f in files[start:end]])
88 |
89 | # Reshape to (n_images, 3, height, width)
90 | images = images.transpose((0, 3, 1, 2))
91 | images /= 255
92 |
93 | batch = torch.from_numpy(images).type(torch.FloatTensor)
94 | if cuda:
95 | batch = batch.cuda()
96 |
97 | pred = model(batch)[0]
98 |
99 | # If model output is not scalar, apply global spatial average pooling.
100 | # This happens if you choose a dimensionality not equal 2048.
101 | if pred.size(2) != 1 or pred.size(3) != 1:
102 | pred = adaptive_avg_pool2d(pred, output_size=(1, 1))
103 |
104 | pred_arr[start:end] = pred.cpu().data.numpy().reshape(pred.size(0), -1)
105 |
106 | return pred_arr
107 |
108 |
109 | def _compute_statistics_of_path(path, model, batch_size, dims, cuda):
110 | if path.endswith(".npz"):
111 | f = np.load(path)
112 | m, s = f["mu"][:], f["sigma"][:]
113 | f.close()
114 | else:
115 | path = pathlib.Path(path)
116 | files = list(path.glob("*.jpg")) + list(path.glob("*.png"))
117 | f = get_activations(files, model, batch_size, dims, cuda)
118 |
119 | return f
120 |
121 |
122 | def calculate_DC_given_paths(paths, batch_size, cuda, dims):
123 |
124 | for p in paths:
125 | if not os.path.exists(p):
126 | raise RuntimeError("Invalid path: %s" % p)
127 |
128 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
129 |
130 | model = InceptionV3([block_idx])
131 | if cuda:
132 | model.cuda()
133 |
134 | f0 = _compute_statistics_of_path(paths[0], model, batch_size, dims, cuda)
135 |
136 | f1 = _compute_statistics_of_path(paths[1], model, batch_size, dims, cuda)
137 |
138 | dc_value = compute_prdc(real_features=f0, fake_features=f1, nearest_k=95)
139 |
140 | return dc_value
141 |
142 |
143 | def main():
144 | args = parser.parse_args()
145 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
146 |
147 | dc_value = calculate_DC_given_paths(
148 | args.path, args.batch_size, args.gpu != "", args.dims
149 | )
150 | print(dc_value)
151 |
152 |
153 | if __name__ == "__main__":
154 | main()
155 |
--------------------------------------------------------------------------------
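Note: evaluations/DC.py above computes density and coverage with prdc.compute_prdc on InceptionV3 pool3 features extracted from two image folders. A hedged usage sketch calling the module directly (the folder paths are placeholders, not real data):

    from evaluations.DC import calculate_DC_given_paths

    # two folders of .jpg/.png images: real references and generated outputs
    dc = calculate_DC_given_paths(
        ["/path/to/real_images", "/path/to/generated_images"],
        batch_size=50,
        cuda=False,
        dims=2048,
    )
    print(dc)  # dict from prdc with precision, recall, density and coverage
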
/F-LSeSim/evaluations/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/F-LSeSim/evaluations/__init__.py
--------------------------------------------------------------------------------
/F-LSeSim/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 |     -- <set_input>: unpack data from dataset and apply preprocessing.
7 |     -- <forward>: produce intermediate results.
8 |     -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 |     -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 |     -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 |
23 | from models.base_model import BaseModel
24 |
25 |
26 | def find_model_using_name(model_name):
27 | """Import the module "models/[model_name]_model.py".
28 |
29 | In the file, the class called DatasetNameModel() will
30 | be instantiated. It has to be a subclass of BaseModel,
31 | and it is case-insensitive.
32 | """
33 | model_filename = "models." + model_name + "_model"
34 | modellib = importlib.import_module(model_filename)
35 | model = None
36 | target_model_name = model_name.replace("_", "") + "model"
37 | for name, cls in modellib.__dict__.items():
38 | if name.lower() == target_model_name.lower() and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print(
43 | "In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase."
44 | % (model_filename, target_model_name)
45 | )
46 | exit(0)
47 |
48 | return model
49 |
50 |
51 | def get_option_setter(model_name):
52 | """Return the static method of the model class."""
53 | model_class = find_model_using_name(model_name)
54 | return model_class.modify_commandline_options
55 |
56 |
57 | def create_model(opt, norm_cfg):
58 | """Create a model given the option.
59 |
60 |     This function wraps the class CustomDatasetDataLoader.
61 | This is the main interface between this package and 'train.py'/'test.py'
62 |
63 | Example:
64 | >>> from models import create_model
65 |         >>> model = create_model(opt, norm_cfg)
66 | """
67 | model = find_model_using_name(opt.model)
68 | instance = model(opt, norm_cfg=norm_cfg)
69 | print("model [%s] was created" % type(instance).__name__)
70 | return instance
71 |
--------------------------------------------------------------------------------
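To make the contract in the package docstring above concrete, a minimal, hypothetical 'dummy_model.py' could look like the sketch below (everything here is illustrative placeholder code, not part of the repository; note that create_model() passes norm_cfg, so __init__ should accept it):

    # models/dummy_model.py -- illustrative sketch only
    from .base_model import BaseModel


    class DummyModel(BaseModel):
        @staticmethod
        def modify_commandline_options(parser, is_train=True):
            # (optionally) add model-specific options and set default values
            return parser

        def __init__(self, opt, norm_cfg=None):
            BaseModel.__init__(self, opt)
            self.loss_names = []    # training losses to plot and save
            self.model_names = []   # networks used in training
            self.visual_names = []  # images to display and save
            self.optimizers = []    # one optimizer per network (or itertools.chain)

        def set_input(self, input):
            # unpack data from the dataset and apply preprocessing
            pass

        def forward(self):
            # produce intermediate results
            pass

        def optimize_parameters(self):
            # calculate losses and gradients, then update network weights
            pass

Selecting it would then be a matter of passing '--model dummy' on the command line, as the docstring describes.
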
/F-LSeSim/models/colorization_model.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from skimage import color # used for lab2rgb
4 |
5 | from .pix2pix_model import Pix2PixModel
6 |
7 |
8 | class ColorizationModel(Pix2PixModel):
9 | """This is a subclass of Pix2PixModel for image colorization (black & white image -> colorful images).
10 |
11 |     The model training requires the '--dataset_mode colorization' dataset.
12 | It trains a pix2pix model, mapping from L channel to ab channels in Lab color space.
13 | By default, the colorization dataset will automatically set '--input_nc 1' and '--output_nc 2'.
14 | """
15 |
16 | @staticmethod
17 | def modify_commandline_options(parser, is_train=True):
18 | """Add new dataset-specific options, and rewrite default values for existing options.
19 |
20 | Parameters:
21 | parser -- original option parser
22 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
23 |
24 | Returns:
25 | the modified parser.
26 |
27 | By default, we use 'colorization' dataset for this model.
28 | See the original pix2pix paper (https://arxiv.org/pdf/1611.07004.pdf) and colorization results (Figure 9 in the paper)
29 | """
30 | Pix2PixModel.modify_commandline_options(parser, is_train)
31 | parser.set_defaults(dataset_mode="colorization")
32 | return parser
33 |
34 | def __init__(self, opt):
35 | """Initialize the class.
36 |
37 | Parameters:
38 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
39 |
40 | For visualization, we set 'visual_names' as 'real_A' (input real image),
41 | 'real_B_rgb' (ground truth RGB image), and 'fake_B_rgb' (predicted RGB image)
42 |         We convert the Lab image 'real_B' (inherited from Pix2PixModel) to an RGB image 'real_B_rgb'.
43 |         We convert the Lab image 'fake_B' (inherited from Pix2PixModel) to an RGB image 'fake_B_rgb'.
44 | """
45 | # reuse the pix2pix model
46 | Pix2PixModel.__init__(self, opt)
47 | # specify the images to be visualized.
48 | self.visual_names = ["real_A", "real_B_rgb", "fake_B_rgb"]
49 |
50 | def lab2rgb(self, L, AB):
51 | """Convert an Lab tensor image to a RGB numpy output
52 | Parameters:
53 | L (1-channel tensor array): L channel images (range: [-1, 1], torch tensor array)
54 | AB (2-channel tensor array): ab channel images (range: [-1, 1], torch tensor array)
55 |
56 | Returns:
57 | rgb (RGB numpy image): rgb output images (range: [0, 255], numpy array)
58 | """
59 | AB2 = AB * 110.0
60 | L2 = (L + 1.0) * 50.0
61 | Lab = torch.cat([L2, AB2], dim=1)
62 | Lab = Lab[0].data.cpu().float().numpy()
63 | Lab = np.transpose(Lab.astype(np.float64), (1, 2, 0))
64 | rgb = color.lab2rgb(Lab) * 255
65 | return rgb
66 |
67 | def compute_visuals(self):
68 | """Calculate additional output images for visdom and HTML visualization"""
69 | self.real_B_rgb = self.lab2rgb(self.real_A, self.real_B)
70 | self.fake_B_rgb = self.lab2rgb(self.real_A, self.fake_B)
71 |
--------------------------------------------------------------------------------
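A quick numeric sanity check of the Lab scaling used by lab2rgb() above (a sketch, not repository code): L in [-1, 1] is mapped to [0, 100] and ab in [-1, 1] to [-110, 110] before skimage's color.lab2rgb is applied.

    import numpy as np
    import torch
    from skimage import color

    L = torch.zeros(1, 1, 1, 1)   # mid lightness in the network's [-1, 1] range
    AB = torch.zeros(1, 2, 1, 1)  # neutral chroma

    L2 = (L + 1.0) * 50.0         # -> 50.0
    AB2 = AB * 110.0              # -> 0.0
    Lab = torch.cat([L2, AB2], dim=1)[0].numpy().transpose(1, 2, 0)
    rgb = color.lab2rgb(Lab.astype(np.float64)) * 255
    print(rgb)                    # roughly [[[119, 119, 119]]]: mid gray
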
/F-LSeSim/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.downsample import Downsample
6 | from models.normalization import make_norm_layer
7 |
8 |
9 | class DiscriminatorBasicBlock(nn.Module):
10 | def __init__(
11 | self,
12 | in_features,
13 | out_features,
14 | do_downsample=True,
15 | do_instancenorm=True,
16 | norm_cfg=None,
17 | ):
18 | super().__init__()
19 |
20 | self.do_downsample = do_downsample
21 | self.do_instancenorm = do_instancenorm
22 | self.norm_cfg = norm_cfg or {'type': 'in'}
23 | self.norm_cfg = {k.lower(): v for k, v in self.norm_cfg.items()}
24 |
25 | self.conv = nn.Conv2d(
26 | in_features, out_features, kernel_size=4, stride=1, padding=1
27 | )
28 | self.leakyrelu = nn.LeakyReLU(0.2, True)
29 |
30 | if do_instancenorm:
31 | self.instancenorm = make_norm_layer(self.norm_cfg, num_features=out_features)
32 |
33 | if do_downsample:
34 | self.downsample = Downsample(out_features)
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | if self.do_instancenorm:
39 | x = self.instancenorm(x)
40 | x = self.leakyrelu(x)
41 | if self.do_downsample:
42 | x = self.downsample(x)
43 | return x
44 |
45 |
46 | class Discriminator(nn.Module):
47 | def __init__(self, in_channels=3, features=64, avg_pooling=False):
48 | super().__init__()
49 | self.block1 = DiscriminatorBasicBlock(
50 | in_channels, features, do_downsample=True, do_instancenorm=False
51 | )
52 | self.block2 = DiscriminatorBasicBlock(
53 | features, features * 2, do_downsample=True, do_instancenorm=True
54 | )
55 | self.block3 = DiscriminatorBasicBlock(
56 | features * 2, features * 4, do_downsample=True, do_instancenorm=True
57 | )
58 | self.block4 = DiscriminatorBasicBlock(
59 | features * 4, features * 8, do_downsample=False, do_instancenorm=True
60 | )
61 | self.conv = nn.Conv2d(features * 8, 1, kernel_size=4, stride=1, padding=1)
62 | self.avg_pooling = avg_pooling
63 |
64 | def forward(self, x):
65 | x = self.block1(x)
66 | x = self.block2(x)
67 | x = self.block3(x)
68 | x = self.block4(x)
69 | x = self.conv(x)
70 | if self.avg_pooling:
71 | x = F.avg_pool2d(x, x.size()[2:])
72 | x = torch.flatten(x, 1)
73 | return x
74 |
75 | def set_requires_grad(self, requires_grad=False):
76 | for param in self.parameters():
77 | param.requires_grad = requires_grad
78 |
79 |
80 | if __name__ == "__main__":
81 | x = torch.randn((5, 3, 256, 256))
82 | print(x.shape)
83 | model = Discriminator(in_channels=3, avg_pooling=True)
84 | preds = model(x)
85 | print(preds.shape)
86 |
--------------------------------------------------------------------------------
/F-LSeSim/models/downsample.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Downsample(nn.Module):
5 | def __init__(self, features):
6 | super().__init__()
7 | self.reflectionpad = nn.ReflectionPad2d(1)
8 | self.conv = nn.Conv2d(features, features, kernel_size=3, stride=2)
9 |
10 | def forward(self, x):
11 | x = self.reflectionpad(x)
12 | x = self.conv(x)
13 | return x
14 |
--------------------------------------------------------------------------------
/F-LSeSim/models/networks.py:
--------------------------------------------------------------------------------
1 | from torch.optim import lr_scheduler
2 |
3 | from models.generator import Generator
4 |
5 | from . import cyclegan_networks, stylegan_networks
6 |
7 |
8 | ##################################################################################
9 | # Networks
10 | ##################################################################################
11 | def define_G(
12 | input_nc,
13 | output_nc,
14 | ngf,
15 | netG,
16 | norm="batch",
17 | use_dropout=False,
18 | init_type="normal",
19 | init_gain=0.02,
20 | no_antialias=False,
21 | no_antialias_up=False,
22 | gpu_ids=[],
23 | opt=None,
24 | norm_cfg=None,
25 | ):
26 | """
27 | Create a generator
28 | :param input_nc: the number of channels in input images
29 | :param output_nc: the number of channels in output images
30 | :param ngf: the number of filters in the first conv layer
31 |     :param netG: the architecture's name: resnet_9blocks | stylegan2
32 | :param norm: the name of normalization layers used in the network: batch | instance | none
33 | :param use_dropout: if use dropout layers.
34 | :param init_type: the name of our initialization method.
35 | :param init_gain: scaling factor for normal, xavier and orthogonal.
36 | :param no_antialias: use learned down sampling layer or not
37 | :param no_antialias_up: use learned up sampling layer or not
38 | :param gpu_ids: which GPUs the network runs on: e.g., 0,1,2
39 | :param opt: options
40 | :return:
41 | """
42 |
43 | if netG == "resnet_9blocks":
44 | net = Generator(norm_cfg=norm_cfg or {'type': 'in'})
45 | elif netG == "stylegan2":
46 | net = stylegan_networks.StyleGAN2Generator(input_nc, output_nc, ngf, opt=opt)
47 | else:
48 | raise NotImplementedError("Generator model name [%s] is not recognized" % netG)
49 | return cyclegan_networks.init_net(
50 | net, init_type, init_gain, gpu_ids, initialize_weights=("stylegan2" not in netG)
51 | )
52 |
53 |
54 | def define_D(
55 | input_nc,
56 | ndf,
57 | netD,
58 | n_layers_D=3,
59 | norm="batch",
60 | init_type="normal",
61 | init_gain=0.02,
62 | no_antialias=False,
63 | gpu_ids=[],
64 | opt=None,
65 | ):
66 | """
67 | Create a discriminator
68 | :param input_nc: the number of channels in input images
69 | :param ndf: the number of filters in the first conv layer
70 | :param netD: the architecture's name
71 | :param n_layers_D: the number of conv layers in the discriminator; effective when netD=='n_layers'
72 | :param norm: the type of normalization layers used in the network
73 | :param init_type: the name of the initialization method
74 | :param init_gain: scaling factor for normal, xavier and orthogonal
75 | :param no_antialias: use learned down sampling layer or not
76 | :param gpu_ids: which GPUs the network runs on: e.g., 0,1,2
77 | :param opt: options
78 | :return:
79 | """
80 | norm_value = cyclegan_networks.get_norm_layer(norm)
81 | if netD == "basic":
82 | net = cyclegan_networks.NLayerDiscriminator(
83 | input_nc, ndf, n_layers_D, norm_value, no_antialias
84 | )
85 | elif netD == "bimulti":
86 | net = cyclegan_networks.D_NLayersMulti(
87 | input_nc, ndf, n_layers=n_layers_D, norm_layer=norm_value, num_D=2
88 | )
89 | elif "stylegan2" in netD:
90 | net = stylegan_networks.StyleGAN2Discriminator(input_nc, ndf, opt=opt)
91 | else:
92 | raise NotImplementedError(
93 | "Discriminator model name [%s] is not recognized" % netD
94 | )
95 | return cyclegan_networks.init_net(
96 | net, init_type, init_gain, gpu_ids, initialize_weights=("stylegan2" not in netD)
97 | )
98 |
99 |
100 | ###############################################################################
101 | # Helper Functions
102 | ###############################################################################
103 | def get_scheduler(optimizer, opt):
104 | """Return a learning rate scheduler
105 |
106 | Parameters:
107 | optimizer -- the optimizer of the network
108 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
109 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
110 |
111 | For 'linear', we keep the same learning rate for the first epochs
112 | and linearly decay the rate to zero over the next epochs.
113 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
114 | See https://pytorch.org/docs/stable/optim.html for more details.
115 | """
116 | if opt.lr_policy == "linear":
117 |
118 | def lambda_rule(epoch):
119 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(
120 | opt.n_epochs_decay + 1
121 | )
122 | return lr_l
123 |
124 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
125 | elif opt.lr_policy == "step":
126 | scheduler = lr_scheduler.StepLR(
127 | optimizer, step_size=opt.lr_decay_iters, gamma=0.1
128 | )
129 | elif opt.lr_policy == "plateau":
130 | scheduler = lr_scheduler.ReduceLROnPlateau(
131 | optimizer, mode="min", factor=0.2, threshold=0.01, patience=5
132 | )
133 | elif opt.lr_policy == "cosine":
134 | scheduler = lr_scheduler.CosineAnnealingLR(
135 | optimizer, T_max=opt.n_epochs, eta_min=0
136 | )
137 | else:
138 |         raise NotImplementedError(
139 |             "learning rate policy [%s] is not implemented" % opt.lr_policy
140 |         )
141 | return scheduler
142 |
--------------------------------------------------------------------------------
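For reference, the 'linear' policy in get_scheduler() above keeps the learning rate at its initial value for the first n_epochs epochs and then decays it linearly towards zero over the next n_epochs_decay epochs. A small numeric check using the TrainOptions defaults (epoch_count=1, n_epochs=100, n_epochs_decay=100):

    epoch_count, n_epochs, n_epochs_decay = 1, 100, 100

    def lambda_rule(epoch):
        return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

    print(lambda_rule(0))    # 1.0     -> full learning rate
    print(lambda_rule(99))   # 1.0     -> still full learning rate at epoch 100
    print(lambda_rule(149))  # ~0.505  -> halfway through the decay phase
    print(lambda_rule(199))  # ~0.0099 -> essentially zero at the end
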
/F-LSeSim/models/normalization.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Any, Dict
3 |
4 | import torch.nn as nn
5 |
6 | from models.kin import KernelizedInstanceNorm
7 | from models.tin import ThumbInstanceNorm
8 |
9 |
10 | # TODO: To be deprecated
11 | def get_normalization_layer(num_features, normalization="kin"):
12 | if normalization == "kin":
13 | return KernelizedInstanceNorm(num_features=num_features)
14 | elif normalization == "tin":
15 | return ThumbInstanceNorm(num_features=num_features)
16 | elif normalization == "in":
17 | return nn.InstanceNorm2d(num_features)
18 | else:
19 | raise NotImplementedError
20 |
21 |
22 | def make_norm_layer(norm_cfg: Dict[str, Any], **kwargs: Any):
23 | """
24 | Create normalization layer based on given config and arguments.
25 |
26 | Args:
27 | norm_cfg (Dict[str, Any]): A dict of keyword arguments of normalization layer.
28 |             It must have a key 'type' to specify which normalization layer will be used.
29 |             It accepts upper-case keys (they are lowercased internally).
30 | **kwargs (Any): The keyword arguments are used to overwrite `norm_cfg`.
31 |
32 | Returns:
33 | nn.Module: A layer object.
34 | """
35 | norm_cfg = deepcopy(norm_cfg)
36 | norm_cfg = {k.lower(): v for k, v in norm_cfg.items()}
37 |
38 | norm_cfg.update(kwargs)
39 |
40 | if 'type' not in norm_cfg:
41 | raise ValueError('"type" wasn\'t specified.')
42 |
43 | norm_type = norm_cfg['type']
44 | del norm_cfg['type']
45 |
46 | if norm_type == 'in':
47 | return nn.InstanceNorm2d(**norm_cfg)
48 | elif norm_type == 'tin':
49 | return ThumbInstanceNorm(**norm_cfg)
50 | elif norm_type == 'kin':
51 | return KernelizedInstanceNorm(**norm_cfg)
52 | else:
53 | raise ValueError(f'Unknown norm type: {norm_type}.')
54 |
--------------------------------------------------------------------------------
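A short usage sketch for make_norm_layer() above, illustrating the key lowercasing and the **kwargs override (the feature counts are arbitrary):

    import torch.nn as nn
    from models.normalization import make_norm_layer

    # Upper-case keys are accepted and lowercased internally.
    norm = make_norm_layer({"TYPE": "in"}, num_features=64)
    assert isinstance(norm, nn.InstanceNorm2d)

    # Keyword arguments overwrite entries already present in the config dict.
    norm = make_norm_layer({"type": "in", "num_features": 32}, num_features=64)
    assert norm.num_features == 64
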
/F-LSeSim/models/sinsc_model.py:
--------------------------------------------------------------------------------
1 | from .sc_model import SCModel
2 |
3 |
4 | class SinSCModel(SCModel):
5 | """
6 | This class implements the single image translation
7 | """
8 |
9 | @staticmethod
10 | def modify_commandline_options(parser, is_train=True):
11 | """
12 | :param parser: original options parser
13 | :param is_train: whether training phase or test phase. You can use this flag to add training-specific or test-specific options
14 | :return: the modified parser
15 | """
16 | parser = SCModel.modify_commandline_options(parser, is_train)
17 |
18 | parser.set_defaults(
19 | dataset_mode="singleimage",
20 | netG="stylegan2",
21 | stylegan2_G_num_downsampling=2,
22 | netD="stylegan2",
23 | gan_mode="nonsaturating",
24 | num_patches=1,
25 | attn_layers="4,7,9",
26 | lambda_spatial=10.0,
27 | lambda_identity=0.0,
28 | lambda_gradient=1.0,
29 | lambda_spatial_idt=0.0,
30 | ngf=8,
31 | ndf=8,
32 | lr=0.001,
33 | beta1=0.0,
34 | beta2=0.99,
35 | load_size=1024,
36 | crop_size=128,
37 | preprocess="zoom_and_patch",
38 | D_patch_size=None,
39 | )
40 |
41 | if is_train:
42 | parser.set_defaults(
43 | preprocess="zoom_and_patch",
44 | batch_size=16,
45 | save_epoch_freq=1,
46 | save_latest_freq=20000,
47 | n_epochs=4,
48 | n_epochs_decay=4,
49 | )
50 | else:
51 | parser.set_defaults(
52 | preprocess="none", # load the whole image as it is
53 | batch_size=1,
54 | num_test=1,
55 | )
56 |
57 | return parser
58 |
59 | def __init__(self, opt):
60 | super().__init__(opt)
61 |
--------------------------------------------------------------------------------
/F-LSeSim/models/test_model.py:
--------------------------------------------------------------------------------
1 | from . import networks
2 | from .base_model import BaseModel
3 |
4 |
5 | class TestModel(BaseModel):
6 | """This TesteModel can be used to generate CycleGAN results for only one direction.
7 | This model will automatically set '--dataset_mode single', which only loads the images from one collection.
8 |
9 | See the test instruction for more details.
10 | """
11 |
12 | @staticmethod
13 | def modify_commandline_options(parser, is_train=True):
14 | """Add new dataset-specific options, and rewrite default values for existing options.
15 |
16 | Parameters:
17 | parser -- original option parser
18 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
19 |
20 | Returns:
21 | the modified parser.
22 |
23 | The model can only be used during test time. It requires '--dataset_mode single'.
24 | You need to specify the network using the option '--model_suffix'.
25 | """
26 | assert not is_train, "TestModel cannot be used during training time"
27 | parser.set_defaults(dataset_mode="single")
28 | parser.add_argument(
29 | "--model_suffix",
30 | type=str,
31 | default="",
32 | help="In checkpoints_dir, [epoch]_net_G[model_suffix].pth will be loaded as the generator.",
33 | )
34 |
35 | return parser
36 |
37 | def __init__(self, opt):
38 | """Initialize the pix2pix class.
39 |
40 | Parameters:
41 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
42 | """
43 | assert not opt.isTrain
44 | BaseModel.__init__(self, opt)
45 |         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
46 | self.loss_names = []
47 |         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
48 | self.visual_names = ["real", "fake"]
49 |         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
50 | self.model_names = ["G" + opt.model_suffix] # only generator is needed.
51 | self.netG = networks.define_G(
52 | opt.input_nc,
53 | opt.output_nc,
54 | opt.ngf,
55 | opt.netG,
56 | opt.norm,
57 | not opt.no_dropout,
58 | opt.init_type,
59 | opt.init_gain,
60 | self.gpu_ids,
61 | )
62 |
63 | # assigns the model to self.netG_[suffix] so that it can be loaded
64 |         # please see <BaseModel.load_networks>
65 | setattr(self, "netG" + opt.model_suffix, self.netG) # store netG in self.
66 |
67 | def set_input(self, input):
68 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
69 |
70 | Parameters:
71 | input: a dictionary that contains the data itself and its metadata information.
72 |
73 |         We need to use 'single_dataset' dataset mode. It only loads images from one domain.
74 | """
75 | self.real = input["A"].to(self.device)
76 | self.image_paths = input["A_paths"]
77 |
78 | def forward(self):
79 | """Run forward pass."""
80 | self.fake = self.netG(self.real) # G(real)
81 |
82 | def optimize_parameters(self):
83 | """No optimization for test model."""
84 | pass
85 |
--------------------------------------------------------------------------------
/F-LSeSim/models/tin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ThumbInstanceNorm(nn.Module):
6 | def __init__(self, num_features, affine=True):
7 | super(ThumbInstanceNorm, self).__init__()
8 | self.thumb_mean = None
9 | self.thumb_std = None
10 | self.normal_instance_normalization = False
11 | self.collection_mode = False
12 | if affine == True:
13 | self.weight = nn.Parameter(
14 | torch.ones(size=(1, num_features, 1, 1), requires_grad=True)
15 | )
16 | self.bias = nn.Parameter(
17 | torch.zeros(size=(1, num_features, 1, 1), requires_grad=True)
18 | )
19 |
20 | def calc_mean_std(self, feat, eps=1e-5):
21 | size = feat.size()
22 | assert len(size) == 4
23 | N, C = size[:2]
24 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
25 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
26 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
27 | return feat_mean, feat_std
28 |
29 | def forward(self, x):
30 | if self.training or self.normal_instance_normalization:
31 | x_mean, x_std = self.calc_mean_std(x)
32 | x = (x - x_mean) / x_std * self.weight + self.bias
33 | return x
34 | else:
35 | if self.collection_mode:
36 | assert x.shape[0] == 1
37 | x_mean, x_std = self.calc_mean_std(x)
38 | self.thumb_mean = x_mean
39 | self.thumb_std = x_std
40 |
41 | x = (x - self.thumb_mean) / self.thumb_std * self.weight + self.bias
42 | return x
43 |
44 |
45 | def not_use_thumbnail_instance_norm(model):
46 | for _, layer in model.named_modules():
47 | if isinstance(layer, ThumbInstanceNorm):
48 | layer.collection_mode = False
49 | layer.normal_instance_normalization = True
50 |
51 |
52 | def init_thumbnail_instance_norm(model):
53 | for _, layer in model.named_modules():
54 | if isinstance(layer, ThumbInstanceNorm):
55 | layer.collection_mode = True
56 |
57 |
58 | def use_thumbnail_instance_norm(model):
59 | for _, layer in model.named_modules():
60 | if isinstance(layer, ThumbInstanceNorm):
61 | layer.collection_mode = False
62 | layer.normal_instance_normalization = False
63 |
--------------------------------------------------------------------------------
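The three helpers above toggle every ThumbInstanceNorm layer in a model between its modes. At inference time the intended flow is roughly: run one collection pass on a downscaled thumbnail so each layer caches its mean/std, then reuse those cached statistics for every full-resolution patch. A hedged sketch (model, thumbnail and patches are placeholders):

    import torch
    from models.tin import init_thumbnail_instance_norm, use_thumbnail_instance_norm

    model.eval()  # the cached-statistics branch is only taken outside training mode

    with torch.no_grad():
        # 1) collection pass: each TIN layer stores the thumbnail's mean/std
        init_thumbnail_instance_norm(model)
        _ = model(thumbnail)  # batch size must be 1 in collection mode

        # 2) translate full-resolution patches with the cached statistics
        use_thumbnail_instance_norm(model)
        outputs = [model(patch) for patch in patches]
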
/F-LSeSim/models/upsample.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Upsample(nn.Module):
5 | def __init__(self, features):
6 | super().__init__()
7 | layers = [
8 | nn.ReplicationPad2d(1),
9 | nn.ConvTranspose2d(features, features, kernel_size=4, stride=2, padding=3),
10 | ]
11 | self.model = nn.Sequential(*layers)
12 |
13 | def forward(self, input):
14 | return self.model(input)
15 |
--------------------------------------------------------------------------------
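A quick shape check for the Upsample block above: with ReplicationPad2d(1) followed by ConvTranspose2d(kernel_size=4, stride=2, padding=3), an HxW feature map comes out as 2Hx2W, since ((H + 2) - 1) * 2 - 2 * 3 + 4 = 2H. A minimal sketch:

    import torch
    from models.upsample import Upsample

    up = Upsample(features=64)
    x = torch.randn(1, 64, 32, 32)
    print(up(x).shape)  # torch.Size([1, 64, 64, 64]) -- spatial size doubled
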
/F-LSeSim/models/util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from scipy import signal
4 |
5 |
6 | def gkern(kernlen=1, std=3):
7 | """Returns a 2D Gaussian kernel array."""
8 | gkern1d = signal.gaussian(kernlen, std=std).reshape(kernlen, 1)
9 | gkern2d = np.outer(gkern1d, gkern1d)
10 | return gkern2d
11 |
12 |
13 | def get_kernel(padding=1, gaussian_std=3, mode="constant"):
14 | kernel_size = padding * 2 + 1
15 | if mode == "constant":
16 | kernel = torch.ones(kernel_size, kernel_size)
17 | kernel = kernel / (kernel_size * kernel_size)
18 |
19 | elif mode == "gaussian":
20 | kernel = gkern(kernel_size, std=gaussian_std)
21 | kernel = kernel / kernel.sum()
22 | kernel = torch.from_numpy(kernel.astype(np.float32))
23 |
24 | else:
25 | raise NotImplementedError
26 |
27 | return kernel
--------------------------------------------------------------------------------
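A small sketch of what get_kernel() above returns: a (2 * padding + 1) x (2 * padding + 1) kernel normalized to sum to 1, either uniform ("constant") or centre-weighted ("gaussian"):

    from models.util import get_kernel

    box = get_kernel(padding=1, mode="constant")  # 3x3 kernel, every entry 1/9
    print(box.shape, float(box.sum()))            # torch.Size([3, 3]) ~1.0

    gauss = get_kernel(padding=2, gaussian_std=3, mode="gaussian")  # 5x5, centre-weighted
    print(gauss.shape, float(gauss.sum()))        # torch.Size([5, 5]) ~1.0
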
/F-LSeSim/options/__init__.py:
--------------------------------------------------------------------------------
1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
2 |
--------------------------------------------------------------------------------
/F-LSeSim/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument(
13 | "--results_dir", type=str, default="./results/", help="saves results here."
14 | )
15 | parser.add_argument(
16 | "--aspect_ratio",
17 | type=float,
18 | default=1.0,
19 | help="aspect ratio of result images",
20 | )
21 | parser.add_argument(
22 | "--phase", type=str, default="test", help="train, val, test, etc"
23 | )
24 |         # Dropout and Batchnorm have different behavior during training and test.
25 | parser.add_argument(
26 | "--eval", action="store_true", help="use eval mode during test time."
27 | )
28 | parser.add_argument(
29 | "--num_test", type=int, default=50, help="how many test images to run"
30 | )
31 |         # rewrite default values
32 | parser.set_defaults(model="test")
33 | # To avoid cropping, the load_size should be the same as crop_size
34 | parser.set_defaults(load_size=parser.get_default("crop_size"))
35 | self.isTrain = False
36 | return parser
37 |
--------------------------------------------------------------------------------
/F-LSeSim/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument(
14 | "--display_freq",
15 | type=int,
16 | default=250,
17 | help="frequency of showing training results on screen",
18 | )
19 | parser.add_argument(
20 | "--display_ncols",
21 | type=int,
22 | default=4,
23 | help="if positive, display all images in a single visdom web panel with certain number of images per row.",
24 | )
25 | parser.add_argument(
26 | "--display_id", type=int, default=None, help="window id of the web display"
27 | )
28 | parser.add_argument(
29 | "--display_server",
30 | type=str,
31 | default="http://localhost",
32 | help="visdom server of the web display",
33 | )
34 | parser.add_argument(
35 | "--display_env",
36 | type=str,
37 | default="main",
38 | help='visdom display environment name (default is "main")',
39 | )
40 | parser.add_argument(
41 | "--display_port",
42 | type=int,
43 | default=8097,
44 | help="visdom port of the web display",
45 | )
46 | parser.add_argument(
47 | "--update_html_freq",
48 | type=int,
49 | default=1000,
50 | help="frequency of saving training results to html",
51 | )
52 | parser.add_argument(
53 | "--print_freq",
54 | type=int,
55 | default=100,
56 | help="frequency of showing training results on console",
57 | )
58 | parser.add_argument(
59 | "--no_html",
60 | action="store_true",
61 | help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/",
62 | )
63 | # network saving and loading parameters
64 | parser.add_argument(
65 | "--save_latest_freq",
66 | type=int,
67 | default=5000,
68 | help="frequency of saving the latest results",
69 | )
70 | parser.add_argument(
71 | "--save_epoch_freq",
72 | type=int,
73 | default=10,
74 | help="frequency of saving checkpoints at the end of epochs",
75 | )
76 | parser.add_argument(
77 | "--save_by_iter",
78 | action="store_true",
79 | help="whether saves model by iteration",
80 | )
81 | parser.add_argument(
82 | "--continue_train",
83 | action="store_true",
84 | help="continue training: load the latest model",
85 | )
86 | parser.add_argument(
87 | "--epoch_count",
88 | type=int,
89 | default=1,
90 | help="the starting epoch count, we save the model by , +, ...",
91 | )
92 | parser.add_argument(
93 | "--phase", type=str, default="train", help="train, val, test, etc"
94 | )
95 | # training parameters
96 | parser.add_argument(
97 | "--n_epochs",
98 | type=int,
99 | default=100,
100 | help="number of epochs with the initial learning rate",
101 | )
102 | parser.add_argument(
103 | "--n_epochs_decay",
104 | type=int,
105 | default=100,
106 | help="number of epochs to linearly decay learning rate to zero",
107 | )
108 | parser.add_argument(
109 | "--beta1", type=float, default=0.5, help="momentum term of adam"
110 | )
111 | parser.add_argument(
112 | "--beta2", type=float, default=0.999, help="momentum term of adam"
113 | )
114 | parser.add_argument(
115 | "--lr", type=float, default=0.0001, help="initial learning rate for adam"
116 | )
117 | parser.add_argument(
118 | "--gan_mode",
119 | type=str,
120 | default="lsgan",
121 | help="the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.",
122 | )
123 | parser.add_argument(
124 | "--pool_size",
125 | type=int,
126 | default=50,
127 | help="the size of image buffer that stores previously generated images",
128 | )
129 | parser.add_argument(
130 | "--lr_policy",
131 | type=str,
132 | default="linear",
133 | help="learning rate policy. [linear | step | plateau | cosine]",
134 | )
135 | parser.add_argument(
136 | "--lr_decay_iters",
137 | type=int,
138 | default=50,
139 | help="multiply by a gamma every lr_decay_iters iterations",
140 | )
141 | parser.add_argument(
142 | "-c", "--config", type=str, help="extra config for compatibility"
143 | )
144 | self.isTrain = True
145 | return parser
146 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/conda_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | conda install numpy pyyaml mkl mkl-include setuptools cmake cffi typing
3 | conda install pytorch torchvision -c pytorch # add cuda90 if CUDA 9
4 | conda install visdom dominate -c conda-forge # install visdom and dominate
5 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/download_cyclegan_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | echo "Note: available models are apple2orange, orange2apple, summer2winter_yosemite, winter2summer_yosemite, horse2zebra, zebra2horse, monet2photo, style_monet, style_cezanne, style_ukiyoe, style_vangogh, sat2map, map2sat, cityscapes_photo2label, cityscapes_label2photo, facades_photo2label, facades_label2photo, iphone2dslr_flower"
4 |
5 | echo "Specified [$FILE]"
6 |
7 | mkdir -p ./checkpoints/${FILE}_pretrained
8 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
9 | URL=http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/$FILE.pth
10 |
11 | wget -N $URL -O $MODEL_FILE
12 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/download_pix2pix_model.sh:
--------------------------------------------------------------------------------
1 | FILE=$1
2 |
3 | echo "Note: available models are edges2shoes, sat2map, map2sat, facades_label2photo, and day2night"
4 | echo "Specified [$FILE]"
5 |
6 | mkdir -p ./checkpoints/${FILE}_pretrained
7 | MODEL_FILE=./checkpoints/${FILE}_pretrained/latest_net_G.pth
8 | URL=http://efrosgans.eecs.berkeley.edu/pix2pix/models-pytorch/$FILE.pth
9 |
10 | wget -N $URL -O $MODEL_FILE
11 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/edges/PostprocessHED.m:
--------------------------------------------------------------------------------
1 | %%% Prerequisites
2 | % You need to get the cpp file edgesNmsMex.cpp from https://raw.githubusercontent.com/pdollar/edges/master/private/edgesNmsMex.cpp
3 | % and compile it in Matlab: mex edgesNmsMex.cpp
4 | % You also need to download and install Piotr's Computer Vision Matlab Toolbox: https://pdollar.github.io/toolbox/
5 |
6 | %%% parameters
7 | % hed_mat_dir: the hed mat file directory (the output of 'batch_hed.py')
8 | % edge_dir: the output HED edges directory
9 | % image_width: resize the edge map to [image_width, image_width]
10 | % threshold: threshold for image binarization (default 25.0/255.0)
11 | % small_edge: remove small edges (default 5)
12 |
13 | function [] = PostprocessHED(hed_mat_dir, edge_dir, image_width, threshold, small_edge)
14 |
15 | if ~exist(edge_dir, 'dir')
16 | mkdir(edge_dir);
17 | end
18 | fileList = dir(fullfile(hed_mat_dir, '*.mat'));
19 | nFiles = numel(fileList);
20 | fprintf('find %d mat files\n', nFiles);
21 |
22 | for n = 1 : nFiles
23 | if mod(n, 1000) == 0
24 | fprintf('process %d/%d images\n', n, nFiles);
25 | end
26 | fileName = fileList(n).name;
27 | filePath = fullfile(hed_mat_dir, fileName);
28 | jpgName = strrep(fileName, '.mat', '.jpg');
29 | edge_path = fullfile(edge_dir, jpgName);
30 |
31 | if ~exist(edge_path, 'file')
32 | E = GetEdge(filePath);
33 | E = imresize(E,[image_width,image_width]);
34 | E_simple = SimpleEdge(E, threshold, small_edge);
35 | E_simple = uint8(E_simple*255);
36 | imwrite(E_simple, edge_path, 'Quality',100);
37 | end
38 | end
39 | end
40 |
41 |
42 |
43 |
44 | function [E] = GetEdge(filePath)
45 | load(filePath);
46 | E = 1-edge_predict;
47 | end
48 |
49 | function [E4] = SimpleEdge(E, threshold, small_edge)
50 | if nargin <= 1
51 | threshold = 25.0/255.0;
52 | end
53 |
54 | if nargin <= 2
55 | small_edge = 5;
56 | end
57 |
58 | if ndims(E) == 3
59 | E = E(:,:,1);
60 | end
61 |
62 | E1 = 1 - E;
63 | E2 = EdgeNMS(E1);
64 | E3 = double(E2>=max(eps,threshold));
65 | E3 = bwmorph(E3,'thin',inf);
66 | E4 = bwareaopen(E3, small_edge);
67 | E4=1-E4;
68 | end
69 |
70 | function [E_nms] = EdgeNMS( E )
71 | E=single(E);
72 | [Ox,Oy] = gradient2(convTri(E,4));
73 | [Oxx,~] = gradient2(Ox);
74 | [Oxy,Oyy] = gradient2(Oy);
75 | O = mod(atan(Oyy.*sign(-Oxy)./(Oxx+1e-5)),pi);
76 | E_nms = edgesNmsMex(E,O,1,5,1.01,1);
77 | end
78 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/edges/batch_hed.py:
--------------------------------------------------------------------------------
1 | # HED batch processing script; modified from https://github.com/s9xie/hed/blob/master/examples/hed/HED-tutorial.ipynb
2 | # Step 1: download the hed repo: https://github.com/s9xie/hed
3 | # Step 2: download the models and protoxt, and put them under {caffe_root}/examples/hed/
4 | # Step 3: put this script under {caffe_root}/examples/hed/
5 | # Step 4: run the following script:
6 | # python batch_hed.py --images_dir=/data/to/path/photos/ --hed_mat_dir=/data/to/path/hed_mat_files/
7 | # The code sometimes crashes after computation is done. Error looks like "Check failed: ... driver shutting down". You can just kill the job.
8 | # For large images, it may run out of GPU memory, so you should resize the images before running this script.
9 | # Step 5: run the MATLAB post-processing script "PostprocessHED.m"
10 |
11 |
12 | import argparse
13 | import os
14 | import sys
15 |
16 | import caffe
17 | import numpy as np
18 | import scipy.io as sio
19 | from PIL import Image
20 |
21 |
22 | def parse_args():
23 |     parser = argparse.ArgumentParser(description="batch processing: photos->edges")
24 | parser.add_argument(
25 | "--caffe_root", dest="caffe_root", help="caffe root", default="../../", type=str
26 | )
27 | parser.add_argument(
28 | "--caffemodel",
29 | dest="caffemodel",
30 | help="caffemodel",
31 | default="./hed_pretrained_bsds.caffemodel",
32 | type=str,
33 | )
34 | parser.add_argument(
35 | "--prototxt",
36 | dest="prototxt",
37 | help="caffe prototxt file",
38 | default="./deploy.prototxt",
39 | type=str,
40 | )
41 | parser.add_argument(
42 | "--images_dir",
43 | dest="images_dir",
44 | help="directory to store input photos",
45 | type=str,
46 | )
47 | parser.add_argument(
48 | "--hed_mat_dir",
49 | dest="hed_mat_dir",
50 | help="directory to store output hed edges in mat file",
51 | type=str,
52 | )
53 | parser.add_argument(
54 | "--border", dest="border", help="padding border", type=int, default=128
55 | )
56 | parser.add_argument("--gpu_id", dest="gpu_id", help="gpu id", type=int, default=1)
57 | args = parser.parse_args()
58 | return args
59 |
60 |
61 | args = parse_args()
62 | for arg in vars(args):
63 | print("[%s] =" % arg, getattr(args, arg))
64 | # Make sure that caffe is on the python path:
65 | caffe_root = (
66 | args.caffe_root
67 | ) # this file is expected to be in {caffe_root}/examples/hed/
68 | sys.path.insert(0, caffe_root + "python")
69 |
70 |
71 | if not os.path.exists(args.hed_mat_dir):
72 | print("create output directory %s" % args.hed_mat_dir)
73 | os.makedirs(args.hed_mat_dir)
74 |
75 | imgList = os.listdir(args.images_dir)
76 | nImgs = len(imgList)
77 | print("#images = %d" % nImgs)
78 |
79 | caffe.set_mode_gpu()
80 | caffe.set_device(args.gpu_id)
81 | # load net
82 | net = caffe.Net(args.prototxt, args.caffemodel, caffe.TEST)
83 | # pad border
84 | border = args.border
85 |
86 | for i in range(nImgs):
87 | if i % 500 == 0:
88 | print("processing image %d/%d" % (i, nImgs))
89 | im = Image.open(os.path.join(args.images_dir, imgList[i]))
90 |
91 | in_ = np.array(im, dtype=np.float32)
92 | in_ = np.pad(in_, ((border, border), (border, border), (0, 0)), "reflect")
93 |
94 | in_ = in_[:, :, 0:3]
95 | in_ = in_[:, :, ::-1]
96 | in_ -= np.array((104.00698793, 116.66876762, 122.67891434))
97 | in_ = in_.transpose((2, 0, 1))
98 |     # remove the caffe.set_mode_gpu() and caffe.set_device() calls above if testing with CPU
99 |
100 | # shape for input (data blob is N x C x H x W), set data
101 | net.blobs["data"].reshape(1, *in_.shape)
102 | net.blobs["data"].data[...] = in_
103 | # run net and take argmax for prediction
104 | net.forward()
105 | fuse = net.blobs["sigmoid-fuse"].data[0][0, :, :]
106 | # get rid of the border
107 | fuse = fuse[(border + 35) : (-border + 35), (border + 35) : (-border + 35)]
108 | # save hed file to the disk
109 | name, ext = os.path.splitext(imgList[i])
110 | sio.savemat(os.path.join(args.hed_mat_dir, name + ".mat"), {"edge_predict": fuse})
111 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/eval_cityscapes/download_fcn8s.sh:
--------------------------------------------------------------------------------
1 | URL=http://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/fcn-8s-cityscapes/fcn-8s-cityscapes.caffemodel
2 | OUTPUT_FILE=./scripts/eval_cityscapes/caffemodel/fcn-8s-cityscapes.caffemodel
3 | wget -N $URL -O $OUTPUT_FILE
4 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/eval_cityscapes/evaluate.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import caffe
5 | import numpy as np
6 | import scipy.misc
7 | from cityscapes import cityscapes
8 | from PIL import Image
9 |
10 | from util import fast_hist, get_scores, segrun
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | "--cityscapes_dir",
15 | type=str,
16 | required=True,
17 | help="Path to the original cityscapes dataset",
18 | )
19 | parser.add_argument(
20 | "--result_dir",
21 | type=str,
22 | required=True,
23 | help="Path to the generated images to be evaluated",
24 | )
25 | parser.add_argument(
26 | "--output_dir", type=str, required=True, help="Where to save the evaluation results"
27 | )
28 | parser.add_argument(
29 | "--caffemodel_dir",
30 | type=str,
31 | default="./scripts/eval_cityscapes/caffemodel/",
32 | help="Where the FCN-8s caffemodel stored",
33 | )
34 | parser.add_argument("--gpu_id", type=int, default=0, help="Which gpu id to use")
35 | parser.add_argument(
36 | "--split", type=str, default="val", help="Data split to be evaluated"
37 | )
38 | parser.add_argument(
39 | "--save_output_images",
40 | type=int,
41 | default=0,
42 | help="Whether to save the FCN output images",
43 | )
44 | args = parser.parse_args()
45 |
46 |
47 | def main():
48 | if not os.path.isdir(args.output_dir):
49 | os.makedirs(args.output_dir)
50 | if args.save_output_images > 0:
51 | output_image_dir = args.output_dir + "image_outputs/"
52 | if not os.path.isdir(output_image_dir):
53 | os.makedirs(output_image_dir)
54 | CS = cityscapes(args.cityscapes_dir)
55 | n_cl = len(CS.classes)
56 | label_frames = CS.list_label_frames(args.split)
57 | caffe.set_device(args.gpu_id)
58 | caffe.set_mode_gpu()
59 | net = caffe.Net(
60 | args.caffemodel_dir + "/deploy.prototxt",
61 | args.caffemodel_dir + "fcn-8s-cityscapes.caffemodel",
62 | caffe.TEST,
63 | )
64 |
65 | hist_perframe = np.zeros((n_cl, n_cl))
66 | for i, idx in enumerate(label_frames):
67 | if i % 10 == 0:
68 | print("Evaluating: %d/%d" % (i, len(label_frames)))
69 | city = idx.split("_")[0]
70 | # idx is city_shot_frame
71 | label = CS.load_label(args.split, city, idx)
72 | im_file = args.result_dir + "/" + idx + "_leftImg8bit.png"
73 | im = np.array(Image.open(im_file))
74 | im = scipy.misc.imresize(im, (label.shape[1], label.shape[2]))
75 | out = segrun(net, CS.preprocess(im))
76 | hist_perframe += fast_hist(label.flatten(), out.flatten(), n_cl)
77 | if args.save_output_images > 0:
78 | label_im = CS.palette(label)
79 | pred_im = CS.palette(out)
80 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_pred.jpg", pred_im)
81 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_gt.jpg", label_im)
82 | scipy.misc.imsave(output_image_dir + "/" + str(i) + "_input.jpg", im)
83 |
84 | (
85 | mean_pixel_acc,
86 | mean_class_acc,
87 | mean_class_iou,
88 | per_class_acc,
89 | per_class_iou,
90 | ) = get_scores(hist_perframe)
91 | with open(args.output_dir + "/evaluation_results.txt", "w") as f:
92 | f.write("Mean pixel accuracy: %f\n" % mean_pixel_acc)
93 | f.write("Mean class accuracy: %f\n" % mean_class_acc)
94 | f.write("Mean class IoU: %f\n" % mean_class_iou)
95 | f.write("************ Per class numbers below ************\n")
96 | for i, cl in enumerate(CS.classes):
97 | while len(cl) < 15:
98 | cl = cl + " "
99 | f.write(
100 | "%s: acc = %f, iou = %f\n" % (cl, per_class_acc[i], per_class_iou[i])
101 | )
102 |
103 |
104 | main()
105 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/eval_cityscapes/util.py:
--------------------------------------------------------------------------------
1 | # The following code is modified from https://github.com/shelhamer/clockwork-fcn
2 | import numpy as np
3 |
4 |
5 | def get_out_scoremap(net):
6 | return net.blobs["score"].data[0].argmax(axis=0).astype(np.uint8)
7 |
8 |
9 | def feed_net(net, in_):
10 | """
11 | Load prepared input into net.
12 | """
13 | net.blobs["data"].reshape(1, *in_.shape)
14 | net.blobs["data"].data[...] = in_
15 |
16 |
17 | def segrun(net, in_):
18 | feed_net(net, in_)
19 | net.forward()
20 | return get_out_scoremap(net)
21 |
22 |
23 | def fast_hist(a, b, n):
24 | k = np.where((a >= 0) & (a < n))[0]
25 | bc = np.bincount(n * a[k].astype(int) + b[k], minlength=n**2)
26 | if len(bc) != n**2:
27 | # ignore this example if dimension mismatch
28 | return 0
29 | return bc.reshape(n, n)
30 |
31 |
32 | def get_scores(hist):
33 | # Mean pixel accuracy
34 | acc = np.diag(hist).sum() / (hist.sum() + 1e-12)
35 |
36 | # Per class accuracy
37 | cl_acc = np.diag(hist) / (hist.sum(1) + 1e-12)
38 |
39 | # Per class IoU
40 | iu = np.diag(hist) / (hist.sum(1) + hist.sum(0) - np.diag(hist) + 1e-12)
41 |
42 | return acc, np.nanmean(cl_acc), np.nanmean(iu), cl_acc, iu
43 |
--------------------------------------------------------------------------------
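A tiny numeric illustration of fast_hist()/get_scores() above (two classes, toy labels): the histogram is a confusion matrix with ground truth on the rows and predictions on the columns, which get_scores() reduces to pixel accuracy, mean class accuracy and mean IoU.

    import numpy as np
    from util import fast_hist, get_scores  # this module, scripts/eval_cityscapes/util.py

    gt = np.array([0, 0, 1, 1])
    pred = np.array([0, 1, 1, 1])

    hist = fast_hist(gt, pred, n=2)
    print(hist)
    # [[1 1]
    #  [0 2]]  -> one class-0 pixel was predicted as class 1

    acc, mean_cls_acc, mean_iou, _, _ = get_scores(hist)
    print(acc, mean_cls_acc, mean_iou)  # 0.75, 0.75, ~0.583
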
/F-LSeSim/scripts/install_deps.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | pip install visdom
3 | pip install dominate
4 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_before_push.py:
--------------------------------------------------------------------------------
1 | # Simple script to make sure basic usage
2 | # such as training, testing, saving and loading
3 | # runs without errors.
4 | import os
5 |
6 |
7 | def run(command):
8 | print(command)
9 | exit_status = os.system(command)
10 | if exit_status > 0:
11 | exit(1)
12 |
13 |
14 | if __name__ == "__main__":
15 | # download mini datasets
16 | if not os.path.exists("./datasets/mini"):
17 | run("bash ./datasets/download_cyclegan_dataset.sh mini")
18 |
19 | if not os.path.exists("./datasets/mini_pix2pix"):
20 | run("bash ./datasets/download_cyclegan_dataset.sh mini_pix2pix")
21 |
22 | # pretrained cyclegan model
23 | if not os.path.exists("./checkpoints/horse2zebra_pretrained/latest_net_G.pth"):
24 | run("bash ./scripts/download_cyclegan_model.sh horse2zebra")
25 | run(
26 | "python test.py --model test --dataroot ./datasets/mini --name horse2zebra_pretrained --no_dropout --num_test 1 --no_dropout"
27 | )
28 |
29 | # pretrained pix2pix model
30 | if not os.path.exists(
31 | "./checkpoints/facades_label2photo_pretrained/latest_net_G.pth"
32 | ):
33 | run("bash ./scripts/download_pix2pix_model.sh facades_label2photo")
34 | if not os.path.exists("./datasets/facades"):
35 | run("bash ./datasets/download_pix2pix_dataset.sh facades")
36 | run(
37 | "python test.py --dataroot ./datasets/facades/ --direction BtoA --model pix2pix --name facades_label2photo_pretrained --num_test 1"
38 | )
39 |
40 | # cyclegan train/test
41 | run(
42 | "python train.py --model cycle_gan --name temp_cyclegan --dataroot ./datasets/mini --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 10 --print_freq 1 --display_id -1"
43 | )
44 | run(
45 | 'python test.py --model test --name temp_cyclegan --dataroot ./datasets/mini --num_test 1 --model_suffix "_A" --no_dropout'
46 | )
47 |
48 | # pix2pix train/test
49 | run(
50 | "python train.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --n_epochs 1 --n_epochs_decay 5 --save_latest_freq 10 --display_id -1"
51 | )
52 | run(
53 | "python test.py --model pix2pix --name temp_pix2pix --dataroot ./datasets/mini_pix2pix --num_test 1"
54 | )
55 |
56 | # template train/test
57 | run(
58 | "python train.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 10 --display_id -1"
59 | )
60 | run(
61 | "python test.py --model template --name temp2 --dataroot ./datasets/mini_pix2pix --num_test 1"
62 | )
63 |
64 | # colorization train/test (optional)
65 | if not os.path.exists("./datasets/mini_colorization"):
66 | run("bash ./datasets/download_cyclegan_dataset.sh mini_colorization")
67 |
68 | run(
69 | "python train.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --n_epochs 1 --n_epochs_decay 0 --save_latest_freq 5 --display_id -1"
70 | )
71 | run(
72 | "python test.py --model colorization --name temp_color --dataroot ./datasets/mini_colorization --num_test 1"
73 | )
74 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_colorization.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_cyclegan.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --phase test --no_dropout
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_fid.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test_fid.py \
3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/horse2zebra \
4 | --checkpoints_dir ./checkpoints \
5 | --name horse2zebra \
6 | --gpu_ids 1 \
7 | --model sc \
8 | --num_test 0
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_pix2pix.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --dataset_mode aligned --norm batch
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_sc.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py \
3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/horse2zebra \
4 | --checkpoints_dir ./checkpoints \
5 | --name horse2zebra \
6 | --model sc \
7 | --num_test 0
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_single.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py --dataroot ./datasets/facades/testB/ --name facades_pix2pix --model test --netG unet_256 --direction BtoA --dataset_mode single --norm batch
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/test_sinsc.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python test.py \
3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/single_image_monet_etretat \
4 | --name image2monet \
5 | --model sinsc
--------------------------------------------------------------------------------
/F-LSeSim/scripts/train_colorization.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python train.py --dataroot ./datasets/colorization --name color_pix2pix --model colorization
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/train_cyclegan.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan --pool_size 50 --no_dropout
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/train_pix2pix.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --netG unet_256 --direction BtoA --lambda_L1 100 --dataset_mode aligned --norm batch --pool_size 0
3 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/train_sc.sh:
--------------------------------------------------------------------------------
1 | echo "Read config file: $1";
2 | mode=$(grep -A0 'MODEL_NAME:' $1 | cut -c 13- | cut -d# -f1 | tr -d '"' | sed "s/ //g"); echo "$mode"
3 | if [ "$mode" = "LSeSim" ]; then
4 | exp=$(grep -A0 'EXPERIMENT_NAME:' $1 | cut -c 18- | tr -d '"')
5 | data=$(grep -A0 'TRAIN_ROOT:' $1 | cut -c 15- | tr -d '"')
6 | augment=$(grep -A0 'Augment:' $1 | cut -c 12- | cut -d# -f1 | sed "s/ //g")
7 | if [ "$augment" = "True" ]; then
8 | set -ex
9 | python train.py \
10 | --dataroot $data \
11 | --name $exp \
12 | --model sc \
13 | --gpu_ids 1 \
14 | --lambda_spatial 10 \
15 | --lambda_gradient 0 \
16 | --attn_layers 4,7,9 \
17 | --loss_mode cos \
18 | --gan_mode lsgan \
19 | --display_port 8093 \
20 | --direction AtoB \
21 | --patch_size 64 \
22 | --learned_attn \
23 | --augment
24 | else
25 | set -ex
26 | python train.py \
27 | --dataroot $data \
28 | --name $exp \
29 | --model sc \
30 | --gpu_ids 1 \
31 | --lambda_spatial 10 \
32 | --lambda_gradient 0 \
33 | --attn_layers 4,7,9 \
34 | --loss_mode cos \
35 | --gan_mode lsgan \
36 | --display_port 8093 \
37 | --direction AtoB \
38 | --patch_size 64
39 | fi
40 | else
41 | echo "Not LSeSim model"
42 | fi
43 |
--------------------------------------------------------------------------------
/F-LSeSim/scripts/train_sinsc.sh:
--------------------------------------------------------------------------------
1 | set -ex
2 | python train.py \
3 | --dataroot /media/lyndon/c6f4bbbd-8d47-4dcb-b0db-d788fe2b25571/dataset/image_translation/single_image_monet_etretat \
4 | --name image2monet \
5 | --model sinsc \
6 | --gpu_ids 1 \
7 | --display_port 8093 \
8 | --pool_size 0
--------------------------------------------------------------------------------
/F-LSeSim/scripts/transfer_sc.sh:
--------------------------------------------------------------------------------
1 | echo "Read config file: $1";
2 | mode=$(grep -A0 'MODEL_NAME:' $1 | cut -c 13- | cut -d# -f1 | tr -d '"' | sed "s/ //g"); echo "$mode"
3 | if [ "$mode" = "LSeSim" ]; then
4 | exp=$(grep -A0 'EXPERIMENT_NAME:' $1 | cut -c 18- | tr -d '"')
5 | python3 inference.py \
6 | --name $exp \
7 | --model sc \
8 | -c $1
9 |
10 | python3 combine.py \
11 | --read_original \
12 | -c $1
13 | else
14 | echo "Not LSeSim model"
15 | fi
16 |
--------------------------------------------------------------------------------
/F-LSeSim/test.py:
--------------------------------------------------------------------------------
1 | """General-purpose test script for image-to-image translation.
2 |
3 | Once you have trained your model with train.py, you can use this script to test the model.
4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.
5 |
6 | It first creates the model and dataset given the options. It will hard-code some parameters.
7 | It then runs inference for '--num_test' images and saves the results to an HTML file.
8 |
9 | Example (You need to train models first or download pre-trained models from our website):
10 | Test a CycleGAN model (both sides):
11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
12 |
13 | Test a CycleGAN model (one side only):
14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
15 |
16 | The option '--model test' is used for generating CycleGAN results only for one side.
17 | This option will automatically set '--dataset_mode single', which only loads the images from one set.
18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
19 | which is sometimes unnecessary. The results will be saved at ./results/.
20 | Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
21 |
22 | Test a pix2pix model:
23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
24 |
25 | See options/base_options.py and options/test_options.py for more test options.
26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
28 | """
29 | import os
30 |
31 | from data import create_dataset
32 | from models import create_model
33 | from options.test_options import TestOptions
34 | from util import html
35 | from util.visualizer import save_images
36 |
37 | if __name__ == "__main__":
38 | opt = TestOptions().parse() # get test options
39 | # hard-code some parameters for test
40 |     opt.num_threads = 0  # test code only supports num_threads = 0
41 | opt.batch_size = 1 # test code only supports batch_size = 1
42 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
43 | opt.no_flip = (
44 | True # no flip; comment this line if results on flipped images are needed.
45 | )
46 | opt.display_id = (
47 | -1
48 |     ) # no visdom display; the test code saves the results to an HTML file.
49 | dataset = create_dataset(
50 | opt
51 | ) # create a dataset given opt.dataset_mode and other options
52 | model = create_model(opt) # create a model given opt.model and other options
53 | # create a website
54 | web_dir = os.path.join(
55 | opt.results_dir, opt.name, "{}_{}".format(opt.phase, opt.epoch)
56 | ) # define the website directory
57 | print("creating web directory", web_dir)
58 | webpage = html.HTML(
59 | web_dir,
60 | "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch),
61 | )
62 | # test with eval mode. This only affects layers like batchnorm and dropout.
63 |     # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment with and without eval() mode.
64 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
65 | opt.num_test = opt.num_test if opt.num_test > 0 else float("inf")
66 | for i, data in enumerate(dataset):
67 | if i == 0:
68 | model.data_dependent_initialize(data)
69 | model.setup(opt)
70 | model.parallelize()
71 | if opt.eval:
72 | model.eval()
73 | if i >= opt.num_test: # only apply our model to opt.num_test images.
74 | break
75 | model.set_input(data) # unpack data from data loader
76 | model.test() # run inference
77 | visuals = model.get_current_visuals() # get image results
78 | img_path = model.get_image_paths() # get image paths
79 | if i % 5 == 0: # save images to an HTML file
80 | print("processing (%04d)-th image... %s" % (i, img_path))
81 | save_images(
82 | webpage,
83 | visuals,
84 | img_path,
85 | aspect_ratio=opt.aspect_ratio,
86 | width=opt.display_winsize,
87 | )
88 | webpage.save() # save the HTML
89 |
--------------------------------------------------------------------------------
/F-LSeSim/test_fid.py:
--------------------------------------------------------------------------------
1 | """General-purpose test script for image-to-image translation.
2 |
3 | Once you have trained your model with train.py, you can use this script to test the model.
4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.
5 |
6 | It first creates the model and dataset given the options. It will hard-code some parameters.
7 | It then runs inference for '--num_test' images and saves the results to an HTML file.
8 |
9 | Example (You need to train models first or download pre-trained models from our website):
10 | Test a CycleGAN model (both sides):
11 | python test.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
12 |
13 | Test a CycleGAN model (one side only):
14 | python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
15 |
16 | The option '--model test' is used for generating CycleGAN results only for one side.
17 | This option will automatically set '--dataset_mode single', which only loads the images from one set.
18 | On the contrary, using '--model cycle_gan' requires loading and generating results in both directions,
19 | which is sometimes unnecessary. The results will be saved at ./results/.
20 | Use '--results_dir <directory_path_to_save_result>' to specify the results directory.
21 |
22 | Test a pix2pix model:
23 | python test.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
24 |
25 | See options/base_options.py and options/test_options.py for more test options.
26 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
27 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
28 | """
29 | import os
30 |
31 | import matplotlib.pyplot as plt
32 |
33 | import util.util as util
34 | from data import create_dataset
35 | from evaluations.fid_score import calculate_fid_given_paths
36 | from models import create_model
37 | from options.test_options import TestOptions
38 | from util import html
39 | from util.visualizer import save_images
40 |
41 | if __name__ == "__main__":
42 | opt = TestOptions().parse() # get test options
43 | # hard-code some parameters for test
44 |     opt.num_threads = 0  # test code only supports num_threads = 0
45 | opt.batch_size = 1 # test code only supports batch_size = 1
46 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
47 | opt.no_flip = (
48 | True # no flip; comment this line if results on flipped images are needed.
49 | )
50 | opt.display_id = (
51 | -1
52 |     ) # no visdom display; the test code saves the results to an HTML file.
53 | opt.num_test = opt.num_test if opt.num_test > 0 else float("inf")
54 | dataset = create_dataset(
55 | opt
56 | ) # create a dataset given opt.dataset_mode and other options
57 | train_dataset = create_dataset(util.copyconf(opt, phase="train"))
58 | # traverse all epoch for the evaluation
59 | files_list = os.listdir(opt.checkpoints_dir + "/" + opt.name)
60 |     epochs = []
61 | fid_values = {}
62 | for file in files_list:
63 | if "net_G" in file and "latest" not in file:
64 | name = file.split("_")
65 |             epochs.append(name[0])
66 |     for epoch in epochs:
67 | opt.epoch = epoch
68 | model = create_model(opt) # create a model given opt.model and other options
69 | # create a website
70 | web_dir = os.path.join(
71 | opt.results_dir, opt.name, "{}_{}".format(opt.phase, opt.epoch)
72 | ) # define the website directory
73 | print("creating web directory", web_dir)
74 | webpage = html.HTML(
75 | web_dir,
76 | "Experiment = %s, Phase = %s, Epoch = %s"
77 | % (opt.name, opt.phase, opt.epoch),
78 | )
79 | # test with eval mode. This only affects layers like batchnorm and dropout.
80 |         # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment with and without eval() mode.
81 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout.
82 | for i, data in enumerate(dataset):
83 | if i == 0:
84 | model.data_dependent_initialize(data)
85 | model.setup(
86 | opt
87 | ) # regular setup: load and print networks; create schedulers
88 | model.parallelize()
89 | if opt.eval:
90 | model.eval()
91 | if i >= opt.num_test: # only apply our model to opt.num_test images.
92 | break
93 | model.set_input(data) # unpack data from data loader
94 | model.test() # run inference
95 | visuals = model.get_current_visuals() # get image results
96 | img_path = model.get_image_paths() # get image paths
97 | if i % 5 == 0: # save images to an HTML file
98 | print("processing (%04d)-th image... %s" % (i, img_path))
99 | save_images(
100 | webpage,
101 | visuals,
102 | img_path,
103 | aspect_ratio=opt.aspect_ratio,
104 | width=opt.display_winsize,
105 | )
106 | paths = [
107 | os.path.join(web_dir, "images", "fake_B"),
108 | os.path.join(web_dir, "images", "real_B"),
109 | ]
110 | fid_value = calculate_fid_given_paths(paths, 50, True, 2048)
111 | fid_values[int(epoch)] = fid_value
112 | webpage.save() # save the HTML
113 | print(fid_values)
114 | x = []
115 | y = []
116 | for key in sorted(fid_values.keys()):
117 | x.append(key)
118 | y.append(fid_values[key])
119 | plt.figure()
120 | plt.plot(x, y)
121 | for a, b in zip(x, y):
122 | plt.text(a, b, str(round(b, 2)))
123 | plt.xlabel("Epoch")
124 | plt.ylabel("FID on test set")
125 | plt.title(opt.name)
126 | plt.savefig(os.path.join(opt.results_dir, opt.name, "fid.jpg"))
127 |
--------------------------------------------------------------------------------
/F-LSeSim/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 |
3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
6 |
7 | It first creates model, dataset, and visualizer given the option.
8 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves models.
9 | The script supports continue/resume training. Use '--continue_train' to resume your previous training.
10 |
11 | Example:
12 | Train a CycleGAN model:
13 | python train.py --dataroot ./datasets/maps --name maps_cyclegan --model cycle_gan
14 | Train a pix2pix model:
15 | python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
16 |
17 | See options/base_options.py and options/train_options.py for more training options.
18 | See training and test tips at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/tips.md
19 | See frequently asked questions at: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/docs/qa.md
20 | """
21 | import os
22 | import time
23 | from collections import defaultdict
24 |
25 | from torchvision.utils import save_image
26 |
27 | from data import create_dataset
28 | from models import create_model
29 | from options.train_options import TrainOptions
30 |
31 |
32 | # from util.visualizer import Visualizer
33 | def reverse_image_normalize(img, mean=0.5, std=0.5):
34 | return img * std + mean
35 |
36 |
37 | if __name__ == "__main__":
38 | opt = TrainOptions().parse() # get training options
39 | dataset = create_dataset(
40 | opt
41 | ) # create a dataset given opt.dataset_mode and other options
42 | dataset_size = len(dataset) # get the number of images in the dataset.
43 | print("The number of training images = %d" % dataset_size)
44 |
45 | model = create_model(
46 | opt, norm_cfg={'type': 'in'}
47 | ) # create a model given opt.model and other options
48 |
49 | total_iters = 0 # the total number of training iterations
50 |
51 | os.makedirs(os.path.join("./checkpoints/", opt.name, "test"), exist_ok=True)
52 | for epoch in range(
53 | opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
54 |     ): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
55 | epoch_start_time = time.time() # timer for entire epoch
56 | iter_data_time = time.time() # timer for data loading per iteration
57 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
58 | out = defaultdict(int)
59 | for i, data in enumerate(dataset): # inner loop within one epoch
60 | iter_start_time = time.time() # timer for computation per iteration
61 | print(f"[Epoch {epoch}][Iter {i}] Processing ...", end="\r")
62 | if total_iters % opt.print_freq == 0:
63 | t_data = iter_start_time - iter_data_time
64 |
65 | total_iters += opt.batch_size
66 | epoch_iter += opt.batch_size
67 | if epoch == opt.epoch_count and i == 0:
68 | model.data_dependent_initialize(data)
69 | model.setup(opt)
70 | model.parallelize()
71 | model.print_networks(True)
72 | model.set_input(data) # unpack data from dataset and apply preprocessing
73 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights
74 |
75 | if (
76 | total_iters % opt.display_freq == 0
77 |         ): # every <display_freq> iterations, save the current visual results to ./checkpoints/<name>/test/
78 | save_result = total_iters % opt.update_html_freq == 0
79 | model.compute_visuals()
80 | results = model.get_current_visuals()
81 |
82 | for img_name, img in results.items():
83 | save_image(
84 | reverse_image_normalize(img),
85 | os.path.join(
86 | "./checkpoints/",
87 | opt.name,
88 | "test",
89 | f"{epoch}_{img_name}_{i}.png",
90 | ),
91 | )
92 |
93 | for k, v in out.items():
94 | out[k] /= opt.display_freq
95 |
96 | print(f"[Epoch {epoch}][Iter {i}] {out}", flush=True)
97 | for k, v in out.items():
98 | out[k] = 0
99 |
100 | losses = model.get_current_losses()
101 | for k, v in losses.items():
102 | out[k] += v
103 |
104 | if (
105 | total_iters % opt.save_latest_freq == 0
106 |         ): # cache our latest model every <save_latest_freq> iterations
107 | print(
108 | "saving the latest model (epoch %d, total_iters %d)"
109 | % (epoch, total_iters)
110 | )
111 | save_suffix = "iter_%d" % total_iters if opt.save_by_iter else "latest"
112 | model.save_networks(save_suffix)
113 |
114 | iter_data_time = time.time()
115 | if (
116 | epoch % opt.save_epoch_freq == 0
117 | ): # cache our model every epochs
118 | print(
119 | "saving the model at the end of epoch %d, iters %d"
120 | % (epoch, total_iters)
121 | )
122 | model.save_networks("latest")
123 | model.save_networks(epoch)
124 |
125 | print(
126 | "End of epoch %d / %d \t Time Taken: %d sec"
127 | % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)
128 | )
129 |     model.update_learning_rate() # update learning rates at the end of every epoch.
130 |
--------------------------------------------------------------------------------
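A note on reverse_image_normalize in train.py above: the saving step assumes the dataloader normalized images to [-1, 1] with mean 0.5 and std 0.5 (as the defaults suggest), so multiplying by std and adding the mean maps them back to [0, 1] before save_image writes them to disk. A minimal sketch of that round trip; the tensor values are illustrative only, and the helper is copied from train.py so the snippet stands alone:

import torch

def reverse_image_normalize(img, mean=0.5, std=0.5):  # same helper as in train.py
    return img * std + mean

x = torch.tensor([0.0, 0.5, 1.0])            # pixel values in [0, 1]
normalized = (x - 0.5) / 0.5                 # what the dataloader would feed the model: [-1, 0, 1]
print(reverse_image_normalize(normalized))   # tensor([0.0000, 0.5000, 1.0000]), ready for save_image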
/F-LSeSim/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/F-LSeSim/util/get_data.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import os
4 | import tarfile
5 | from os.path import abspath, basename, isdir, join
6 | from warnings import warn
7 | from zipfile import ZipFile
8 |
9 | import requests
10 | from bs4 import BeautifulSoup
11 |
12 |
13 | class GetData(object):
14 | """A Python script for downloading CycleGAN or pix2pix datasets.
15 |
16 | Parameters:
17 | technique (str) -- One of: 'cyclegan' or 'pix2pix'.
18 | verbose (bool) -- If True, print additional information.
19 |
20 | Examples:
21 | >>> from util.get_data import GetData
22 | >>> gd = GetData(technique='cyclegan')
23 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed.
24 |
25 |     Alternatively, you can use the bash scripts 'datasets/download_pix2pix_dataset.sh'
26 |     and 'datasets/download_cyclegan_dataset.sh'.
27 | """
28 |
29 | def __init__(self, technique="cyclegan", verbose=True):
30 | url_dict = {
31 | "pix2pix": "http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/",
32 | "cyclegan": "https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets",
33 | }
34 | self.url = url_dict.get(technique.lower())
35 | self._verbose = verbose
36 |
37 | def _print(self, text):
38 | if self._verbose:
39 | print(text)
40 |
41 | @staticmethod
42 | def _get_options(r):
43 | soup = BeautifulSoup(r.text, "lxml")
44 | options = [
45 | h.text
46 | for h in soup.find_all("a", href=True)
47 | if h.text.endswith((".zip", "tar.gz"))
48 | ]
49 | return options
50 |
51 | def _present_options(self):
52 | r = requests.get(self.url)
53 | options = self._get_options(r)
54 | print("Options:\n")
55 | for i, o in enumerate(options):
56 | print("{0}: {1}".format(i, o))
57 | choice = input(
58 | "\nPlease enter the number of the " "dataset above you wish to download:"
59 | )
60 | return options[int(choice)]
61 |
62 | def _download_data(self, dataset_url, save_path):
63 | if not isdir(save_path):
64 | os.makedirs(save_path)
65 |
66 | base = basename(dataset_url)
67 | temp_save_path = join(save_path, base)
68 |
69 | with open(temp_save_path, "wb") as f:
70 | r = requests.get(dataset_url)
71 | f.write(r.content)
72 |
73 | if base.endswith(".tar.gz"):
74 | obj = tarfile.open(temp_save_path)
75 | elif base.endswith(".zip"):
76 | obj = ZipFile(temp_save_path, "r")
77 | else:
78 | raise ValueError("Unknown File Type: {0}.".format(base))
79 |
80 | self._print("Unpacking Data...")
81 | obj.extractall(save_path)
82 | obj.close()
83 | os.remove(temp_save_path)
84 |
85 | def get(self, save_path, dataset=None):
86 | """
87 |
88 | Download a dataset.
89 |
90 | Parameters:
91 | save_path (str) -- A directory to save the data to.
92 | dataset (str) -- (optional). A specific dataset to download.
93 | Note: this must include the file extension.
94 | If None, options will be presented for you
95 | to choose from.
96 |
97 | Returns:
98 | save_path_full (str) -- the absolute path to the downloaded data.
99 |
100 | """
101 | if dataset is None:
102 | selected_dataset = self._present_options()
103 | else:
104 | selected_dataset = dataset
105 |
106 | save_path_full = join(save_path, selected_dataset.split(".")[0])
107 |
108 | if isdir(save_path_full):
109 | warn("\n'{0}' already exists. Voiding Download.".format(save_path_full))
110 | else:
111 | self._print("Downloading Data...")
112 | url = "{0}/{1}".format(self.url, selected_dataset)
113 | self._download_data(url, save_path=save_path)
114 |
115 | return abspath(save_path_full)
116 |
--------------------------------------------------------------------------------
/F-LSeSim/util/html.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import dominate
4 | from dominate.tags import a, br, h3, img, meta, p, table, td, tr
5 |
6 |
7 | class HTML:
8 | """This HTML class allows us to save images and write texts into a single HTML file.
9 |
10 |     It consists of functions such as <add_header> (add a text header to the HTML file),
11 |     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
12 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
13 | """
14 |
15 | def __init__(self, web_dir, title, refresh=0):
16 | """Initialize the HTML classes
17 |
18 | Parameters:
19 |             web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
20 |             title (str) -- the webpage name
21 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
22 |         """
23 |         self.title = title
24 |         self.web_dir = web_dir
25 |         self.img_dir = os.path.join(self.web_dir, "images")
26 |         if not os.path.exists(self.web_dir):
27 |             os.makedirs(self.web_dir)
28 |         if not os.path.exists(self.img_dir):
29 |             os.makedirs(self.img_dir)
30 |
31 |         self.doc = dominate.document(title=title)
32 |         if refresh > 0:
33 | with self.doc.head:
34 | meta(http_equiv="refresh", content=str(refresh))
35 |
36 | def get_image_dir(self):
37 | """Return the directory that stores images"""
38 | return self.img_dir
39 |
40 | def add_header(self, text):
41 | """Insert a header to the HTML file
42 |
43 | Parameters:
44 | text (str) -- the header text
45 | """
46 | with self.doc:
47 | h3(text)
48 |
49 | def add_images(self, ims, txts, links, width=400):
50 | """add images to the HTML file
51 |
52 | Parameters:
53 | ims (str list) -- a list of image paths
54 | txts (str list) -- a list of image names shown on the website
55 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
56 | """
57 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
58 | self.doc.add(self.t)
59 | with self.t:
60 | with tr():
61 | for im, txt, link in zip(ims, txts, links):
62 | with td(
63 | style="word-wrap: break-word;", halign="center", valign="top"
64 | ):
65 | with p():
66 | with a(href=os.path.join("images", link)):
67 | img(
68 | style="width:%dpx" % width,
69 | src=os.path.join("images", im),
70 | )
71 | br()
72 | p(txt)
73 |
74 | def save(self):
75 | """save the current content to the HMTL file"""
76 | html_file = "%s/index.html" % self.web_dir
77 | f = open(html_file, "wt")
78 | f.write(self.doc.render())
79 | f.close()
80 |
81 |
82 | if __name__ == "__main__": # we show an example usage here.
83 | html = HTML("web/", "test_html")
84 | html.add_header("hello world")
85 |
86 | ims, txts, links = [], [], []
87 | for n in range(4):
88 | ims.append("image_%d.png" % n)
89 | txts.append("text_%d" % n)
90 | links.append("image_%d.png" % n)
91 | html.add_images(ims, txts, links)
92 | html.save()
93 |
--------------------------------------------------------------------------------
/F-LSeSim/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import torch
4 |
5 |
6 | class ImagePool:
7 | """This class implements an image buffer that stores previously generated images.
8 |
9 | This buffer enables us to update discriminators using a history of generated images
10 | rather than the ones produced by the latest generators.
11 | """
12 |
13 | def __init__(self, pool_size):
14 | """Initialize the ImagePool class
15 |
16 | Parameters:
17 | pool_size (int) -- the size of image buffer, if pool_size=0, no buffer will be created
18 | """
19 | self.pool_size = pool_size
20 | if self.pool_size > 0: # create an empty pool
21 | self.num_imgs = 0
22 | self.images = []
23 |
24 | def query(self, images):
25 | """Return an image from the pool.
26 |
27 | Parameters:
28 | images: the latest generated images from the generator
29 |
30 | Returns images from the buffer.
31 |
32 |         With probability 0.5, the buffer will return the input images.
33 |         With probability 0.5, the buffer will return images previously stored in the buffer,
34 |         and insert the current images into the buffer.
35 | """
36 | if self.pool_size == 0: # if the buffer size is 0, do nothing
37 | return images
38 | return_images = []
39 | for image in images:
40 | image = torch.unsqueeze(image.data, 0)
41 | if (
42 | self.num_imgs < self.pool_size
43 | ): # if the buffer is not full; keep inserting current images to the buffer
44 | self.num_imgs = self.num_imgs + 1
45 | self.images.append(image)
46 | return_images.append(image)
47 | else:
48 | p = random.uniform(0, 1)
49 | if (
50 | p > 0.5
51 | ): # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer
52 | random_id = random.randint(
53 | 0, self.pool_size - 1
54 | ) # randint is inclusive
55 | tmp = self.images[random_id].clone()
56 | self.images[random_id] = image
57 | return_images.append(tmp)
58 | else: # by another 50% chance, the buffer will return the current image
59 | return_images.append(image)
60 | return_images = torch.cat(return_images, 0) # collect all the images and return
61 | return return_images
62 |
--------------------------------------------------------------------------------
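A minimal usage sketch of the buffer above, assuming the module is importable as util.image_pool from the F-LSeSim root; the tensor shape and the pool size of 50 are illustrative choices, not values fixed by this file:

import torch

from util.image_pool import ImagePool

pool = ImagePool(pool_size=50)          # keep up to 50 previously generated images
fake_B = torch.randn(1, 3, 256, 256)    # stand-in for the generator's newest output
fake_for_D = pool.query(fake_B)         # ~50% chance per image of swapping in an older fake
# the discriminator update then uses fake_for_D rather than fake_B directly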
/F-LSeSim/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 |
4 | import argparse
5 | import importlib
6 | import os
7 | from argparse import Namespace
8 |
9 | import cv2
10 | import numpy as np
11 | import torch
12 | import torch.nn.functional as F
13 | import torchvision
14 | from PIL import Image
15 |
16 |
17 | def str2bool(v):
18 | if isinstance(v, bool):
19 | return v
20 | if v.lower() in ("yes", "true", "t", "y", "1"):
21 | return True
22 | elif v.lower() in ("no", "false", "f", "n", "0"):
23 | return False
24 | else:
25 | raise argparse.ArgumentTypeError("Boolean value expected.")
26 |
27 |
28 | def copyconf(default_opt, **kwargs):
29 | conf = Namespace(**vars(default_opt))
30 | for key in kwargs:
31 | setattr(conf, key, kwargs[key])
32 | return conf
33 |
34 |
35 | def find_class_in_module(target_cls_name, module):
36 | target_cls_name = target_cls_name.replace("_", "").lower()
37 | clslib = importlib.import_module(module)
38 | cls = None
39 | for name, clsobj in clslib.__dict__.items():
40 | if name.lower() == target_cls_name:
41 | cls = clsobj
42 |
43 | assert cls is not None, (
44 | "In %s, there should be a class whose name matches %s in lowercase without underscore(_)"
45 | % (module, target_cls_name)
46 | )
47 |
48 | return cls
49 |
50 |
51 | def tensor2im(input_image, imtype=np.uint8):
52 | """ "Converts a Tensor array into a numpy image array.
53 |
54 | Parameters:
55 | input_image (tensor) -- the input image tensor array
56 | imtype (type) -- the desired type of the converted numpy array
57 | """
58 | if not isinstance(input_image, np.ndarray):
59 | if isinstance(input_image, torch.Tensor): # get the data from a variable
60 | image_tensor = input_image.data
61 | else:
62 | return input_image
63 | image_numpy = (
64 | image_tensor[0].cpu().float().numpy()
65 | ) # convert it into a numpy array
66 | if image_numpy.shape[0] == 1: # grayscale to RGB
67 | image_numpy = np.tile(image_numpy, (3, 1, 1))
68 | image_numpy = (
69 | (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
70 |         ) # post-processing: transpose and scaling
71 | else: # if it is a numpy array, do nothing
72 | image_numpy = input_image
73 | return image_numpy.astype(imtype)
74 |
75 |
76 | def diagnose_network(net, name="network"):
77 | """Calculate and print the mean of average absolute(gradients)
78 |
79 | Parameters:
80 | net (torch network) -- Torch network
81 | name (str) -- the name of the network
82 | """
83 | mean = 0.0
84 | count = 0
85 | for param in net.parameters():
86 | if param.grad is not None:
87 | mean += torch.mean(torch.abs(param.grad.data))
88 | count += 1
89 | if count > 0:
90 | mean = mean / count
91 | print(name)
92 | print(mean)
93 |
94 |
95 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
96 | """Save a numpy image to the disk
97 |
98 | Parameters:
99 | image_numpy (numpy array) -- input numpy array
100 | image_path (str) -- the path of the image
101 | """
102 |
103 | image_pil = Image.fromarray(image_numpy)
104 | h, w, _ = image_numpy.shape
105 |
106 | if aspect_ratio > 1.0:
107 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
108 | if aspect_ratio < 1.0:
109 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
110 | image_pil.save(image_path)
111 |
112 |
113 | def print_numpy(x, val=True, shp=False):
114 | """Print the mean, min, max, median, std, and size of a numpy array
115 |
116 | Parameters:
117 | val (bool) -- if print the values of the numpy array
118 | shp (bool) -- if print the shape of the numpy array
119 | """
120 | x = x.astype(np.float64)
121 | if shp:
122 | print("shape,", x.shape)
123 | if val:
124 | x = x.flatten()
125 | print(
126 | "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f"
127 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))
128 | )
129 |
130 |
131 | def mkdirs(paths):
132 | """create empty directories if they don't exist
133 |
134 | Parameters:
135 | paths (str list) -- a list of directory paths
136 | """
137 | if isinstance(paths, list) and not isinstance(paths, str):
138 | for path in paths:
139 | mkdir(path)
140 | else:
141 | mkdir(paths)
142 |
143 |
144 | def mkdir(path):
145 | """create a single empty directory if it didn't exist
146 |
147 | Parameters:
148 | path (str) -- a single directory path
149 | """
150 | if not os.path.exists(path):
151 | os.makedirs(path)
152 |
153 |
154 | def correct_resize_label(t, size):
155 | device = t.device
156 | t = t.detach().cpu()
157 | resized = []
158 | for i in range(t.size(0)):
159 | one_t = t[i, :1]
160 | one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
161 | one_np = one_np[:, :, 0]
162 | one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
163 | resized_t = torch.from_numpy(np.array(one_image)).long()
164 | resized.append(resized_t)
165 | return torch.stack(resized, dim=0).to(device)
166 |
167 |
168 | def correct_resize(t, size, mode=Image.BICUBIC):
169 | device = t.device
170 | t = t.detach().cpu()
171 | resized = []
172 | for i in range(t.size(0)):
173 | one_t = t[i : i + 1]
174 | one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
175 | resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
176 | resized.append(resized_t)
177 | return torch.stack(resized, dim=0).to(device)
178 |
--------------------------------------------------------------------------------
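The tensor2im helper above assumes network outputs live in [-1, 1]: it takes the first image of the batch, tiles a grayscale channel to three, and rescales with (x + 1) / 2 * 255 into a uint8 HWC array. A small sketch of that mapping; the input tensor is made up for illustration:

import torch

from util.util import tensor2im

t = torch.full((1, 3, 8, 8), -1.0)      # a batch whose pixels all sit at the lower bound
img = tensor2im(t)                      # -> numpy array of shape (8, 8, 3)
print(img.dtype, img.min(), img.max())  # uint8 0 0, i.e. pure black after rescaling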
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Ming-Yang Ho
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/appendix/gpu_memory.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | import matplotlib
4 | import matplotlib.pyplot as plt
5 | import numpy as np
6 | import seaborn as sns
7 |
8 | random.seed(20222022)
9 |
10 | if __name__ == "__main__":
11 | memory_dict = {
12 | "training": {
13 | "CycleGAN": {"x": [128, 256, 512], "y": [2939, 5517, 16357]},
14 | "CUT": {"x": [128, 256, 512], "y": [2365, 4133, 11123]},
15 | "F-LSeSim": {"x": [256, 512, 1024], "y": [3585, 8813, 23023]},
16 | "L-LSeSim": {"x": [256, 512, 1024], "y": [3605, 9331, 29289]},
17 | },
18 | "inference": {
19 | "CycleGAN": {
20 | "x": [128, 256, 512, 1024, 2048],
21 | "y": [1697, 1795, 2211, 3749, 12073],
22 | },
23 | "CUT": {
24 | "x": [128, 256, 512, 1024, 2048],
25 | "y": [1602, 1695, 2111, 3654, 11987],
26 | },
27 | "F-LSeSim": {"x": [512, 1024], "y": [8571, 19269]},
28 | "L-LSeSim": {"x": [512, 1024], "y": [8571, 19269]},
29 | "CycleGAN+KIN (ours)": {
30 | "x": [128, 256, 512, 1024, 2048, 4096, 9192],
31 | "y": [2307, 2307, 2307, 2307, 2307, 2307, 2307],
32 | },
33 | "CUT+KIN (ours)": {
34 | "x": [128, 256, 512, 1024, 2048, 4096, 9192],
35 | "y": [2307, 2307, 2307, 2307, 2307, 2307, 2307],
36 | },
37 | "F/L-LSeSim+KIN (ours)": {
38 | "x": [128, 256, 512, 1024, 2048, 4096, 9192],
39 | "y": [8581, 8581, 8581, 8581, 8581, 8581, 8581],
40 | },
41 | },
42 | }
43 |
44 | pal = sns.color_palette("husl", 8)
45 | hex_colors = list(map(matplotlib.colors.rgb2hex, pal))
46 | pal
47 |
48 | marker_list = ["o", "X", "v", "s", "^", "p", "D"]
49 | plt.figure(figsize=(9, 4))
50 | for idx, model_name in enumerate(memory_dict["training"]):
51 | x = memory_dict["training"][model_name]["x"]
52 | y = memory_dict["training"][model_name]["y"]
53 | p2 = np.poly1d(np.polyfit(x, y, 2))
54 |
55 | xp = np.linspace(128, x[-1], 100)
56 | xp_extra = np.linspace(x[-1], 2000, 100)
57 | marker_on = []
58 | for x_ in x:
59 | marker_on.append(np.searchsorted(xp, x_, side="left"))
60 | # add rnd to avoid overlapping
61 | rnd = random.randint(0, 200)
62 |
63 | plt.plot(
64 | xp,
65 | p2(xp) + rnd,
66 | color=hex_colors[idx],
67 | linewidth=3.0,
68 | linestyle="-",
69 | markersize=8,
70 | marker=marker_list[idx],
71 | markevery=marker_on,
72 | label=model_name,
73 | )
74 | plt.plot(
75 | xp_extra,
76 | p2(xp_extra) + rnd,
77 | color=hex_colors[idx],
78 | linewidth=3.0,
79 | linestyle="--",
80 | )
81 |
82 | plt.xticks(
83 | [128, 256, 512, 1024, 2048, 4096],
84 | [128, 256, 512, 1024, 2048, 4096],
85 | rotation=45,
86 | )
87 | plt.yticks(
88 | [5000, 10000, 15000, 20000, 25000, 30000],
89 | [5, 10, 15, 20, 25, 30],
90 | )
91 | plt.ylim(0, 32000)
92 | plt.title("Training")
93 | plt.ylabel("GPU Memory (GB)")
94 | plt.xlabel("Resolution (√x)")
95 | plt.legend()
96 | plt.grid()
97 | # plt.show()
98 | plt.savefig("./training_usage.png", bbox_inches="tight")
99 |
100 | plt.figure(figsize=(9, 4))
101 | for idx, model_name in enumerate(memory_dict["inference"]):
102 | x = memory_dict["inference"][model_name]["x"]
103 | y = memory_dict["inference"][model_name]["y"]
104 | if "KIN" not in model_name:
105 | p2 = np.poly1d(np.polyfit(x, y, 2))
106 | else:
107 | p2 = np.poly1d(np.polyfit(x, y, 0))
108 |
109 | xp = np.linspace(128, x[-1], 100)
110 | xp_extra = np.linspace(x[-1], 10000, 100)
111 | marker_on = []
112 |
113 | # add rnd to avoid overlapping
114 | rnd = random.randint(0, 2000)
115 |
116 | for x_ in x:
117 | marker_on.append(np.searchsorted(xp, x_, side="left"))
118 | plt.plot(
119 | xp,
120 | p2(xp) + rnd,
121 | color=hex_colors[idx],
122 | linewidth=3.0,
123 | linestyle="-",
124 | markersize=8,
125 | marker=marker_list[idx],
126 | markevery=marker_on,
127 | label=model_name,
128 | )
129 | if "KIN" in model_name:
130 | plt.plot(
131 | xp_extra,
132 | p2(xp_extra) + rnd,
133 | color=hex_colors[idx],
134 | linewidth=3.0,
135 | linestyle="-",
136 | )
137 | else:
138 | plt.plot(
139 | xp_extra,
140 | p2(xp_extra) + rnd,
141 | color=hex_colors[idx],
142 | linewidth=3.0,
143 | linestyle="--",
144 | )
145 |
146 | plt.xticks(
147 | [256, 512, 1024, 2048, 4096, 9192],
148 | [256, 512, 1024, 2048, 4096, 9192],
149 | rotation=45,
150 | )
151 | plt.yticks(
152 | [5000, 10000, 15000, 20000, 25000, 30000],
153 | [5, 10, 15, 20, 25, 30],
154 | )
155 | plt.ylim(0, 32000)
156 | plt.ylabel("GPU Memory (GB)")
157 | plt.xlabel("Resolution (√x)")
158 | plt.title("Inference")
159 | plt.legend()
160 | plt.grid()
161 | # plt.show()
162 | plt.savefig("./inference_usage.jpg", bbox_inches="tight")
163 |
--------------------------------------------------------------------------------
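The script above summarizes measured GPU memory by fitting a second-degree polynomial per model with np.polyfit and extrapolating it past the largest measured resolution as a dashed segment (the constant-memory KIN variants get a degree-0 fit instead). A tiny sketch of that fit-and-extrapolate step, reusing the CycleGAN training measurements from the dictionary above:

import numpy as np

x = [128, 256, 512]                    # measured resolutions
y = [2939, 5517, 16357]                # measured memory in MB
p2 = np.poly1d(np.polyfit(x, y, 2))    # quadratic fit through the three points
print(int(p2(1024)))                   # extrapolated memory at 1024, drawn as the dashed line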
/appendix/inference_usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/appendix/inference_usage.png
--------------------------------------------------------------------------------
/appendix/training_usage.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/appendix/training_usage.png
--------------------------------------------------------------------------------
/combine.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 | import numpy as np
6 | from PIL import Image
7 |
8 | from utils.util import read_yaml_config
9 |
10 |
11 | def main():
12 | parser = argparse.ArgumentParser("Combined transferred images")
13 | parser.add_argument(
14 | "-c",
15 | "--config",
16 | type=str,
17 | default="./config.yaml",
18 | help="Path to the config file.",
19 | )
20 | parser.add_argument(
21 | "--patch_size", type=int, help="Patch size", default=512
22 | )
23 | parser.add_argument(
24 | "--resize_h", type=int, help="Resize H", default=-1
25 | )
26 | parser.add_argument(
27 | "--resize_w", type=int, help="Resize W", default=-1
28 | )
29 | args = parser.parse_args()
30 |
31 | config = read_yaml_config(args.config)
32 |
33 | basename = os.path.basename(config["INFERENCE_SETTING"]["TEST_X"])
34 | filename = os.path.splitext(basename)[0]
35 | path_root = os.path.join(
36 | config["EXPERIMENT_ROOT_PATH"],
37 | config["EXPERIMENT_NAME"],
38 | "test",
39 | filename,
40 | )
41 |
42 | if (
43 | "OVERWRITE_OUTPUT_PATH" in config["INFERENCE_SETTING"]
44 | and config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"] != ""
45 | ):
46 | path_root = config["INFERENCE_SETTING"]["OVERWRITE_OUTPUT_PATH"]
47 |
48 | path_base = os.path.join(
49 | path_root,
50 | config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"],
51 | config["INFERENCE_SETTING"]["MODEL_VERSION"],
52 | )
53 |
54 | combined_image_name = f"combined_" \
55 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}_" \
56 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}.png"
57 |
58 | if config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"] == "kin":
59 | path_base = os.path.join(
60 | path_base,
61 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}"
62 | f"_{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}",
63 | )
64 | combined_image_name = f"combined_" \
65 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['TYPE']}" \
66 | f"_{config['INFERENCE_SETTING']['MODEL_VERSION']}_" \
67 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_" \
68 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}.png"
69 |
70 | filenames = os.listdir(path_base)
71 | try:
72 | filenames.remove("thumbnail_Y_fake.png")
73 | except Exception:
74 | pass
75 |
76 | y_anchor_max = 0
77 | x_anchor_max = 0
78 | for filename in filenames:
79 | try:
80 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4)
81 | except Exception as e:
82 | raise ValueError(f"{filename} is not valid") from e
83 | y_anchor_max = max(y_anchor_max, int(y_anchor))
84 | x_anchor_max = max(x_anchor_max, int(x_anchor))
85 |
86 | matrix = np.zeros(
87 | (y_anchor_max + args.patch_size, x_anchor_max + args.patch_size, 3),
88 | dtype=np.uint8,
89 | )
90 |
91 | for filename in sorted(filenames):
92 | print(f"Combine {filename} ", end="\r")
93 | _, _, y_anchor, x_anchor, _ = filename.split("_", 4)
94 | y_anchor = int(y_anchor)
95 | x_anchor = int(x_anchor)
96 | image = cv2.imread(os.path.join(path_base, filename))
97 | image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
98 |         matrix[y_anchor:y_anchor + args.patch_size, x_anchor:x_anchor + args.patch_size, :] = image
99 |
100 | if (args.resize_h != -1) and (args.resize_w != -1):
101 |         matrix = cv2.resize(
102 |             matrix, (args.resize_w, args.resize_h), interpolation=cv2.INTER_CUBIC
103 |         )
104 |
105 | matrix_image = Image.fromarray(matrix)
106 | matrix_image.save(os.path.join(path_root, combined_image_name))
107 |
108 |
109 | if __name__ == "__main__":
110 | main()
111 |
--------------------------------------------------------------------------------
/crop.py:
--------------------------------------------------------------------------------
1 | # Author: Kaminyou (https://github.com/Kaminyou)
2 | import argparse
3 | import csv
4 | import os
5 | from pathlib import Path
6 |
7 | import cv2
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from utils.util import extend_size, is_blank_patch, reduce_size
12 |
13 | if __name__ == "__main__":
14 | """
15 | USAGE
16 | 1. prepare data belongs to domain X
17 | python3 crop.py -i ./data/example/HE_cropped.jpg \
18 | -o ./data/example/trainX/ --thumbnail \
19 | --thumbnail_output ./data/example/trainX/
20 |
21 | 2. prepare data belongs to domain Y
22 | python3 crop.py -i ./data/example/ER_cropped.jpg \
23 | -o ./data/example/trainY/ --thumbnail \
24 | --thumbnail_output ./data/example/trainY/
25 |
26 | 3. prepare data belongs to domain X required to be transferred to domain Y
27 | python3 crop.py -i ./data/example/HE_cropped.jpg \
28 | -o ./data/example/testX/ \
29 | --stride 512 --thumbnail \
30 | --thumbnail_output ./data/example/testX/
31 | """
32 | parser = argparse.ArgumentParser(
33 | description="Crop a large image into patches."
34 | )
35 | parser.add_argument(
36 | "-i",
37 | "--input",
38 | help="Input image path",
39 | required=True,
40 | )
41 | parser.add_argument(
42 | "-o",
43 | "--output",
44 | help="Output image path",
45 | default="data/initial/trainX/",
46 | )
47 | parser.add_argument(
48 | "--thumbnail",
49 | help="If crop a thumbnail or not",
50 | action="store_true",
51 | )
52 | parser.add_argument(
53 | "--thumbnail_output",
54 | help="Output image path",
55 | default="data/initial/",
56 | )
57 | parser.add_argument(
58 | "--patch_size",
59 | type=int,
60 | help="Patch size",
61 | default=512,
62 | )
63 | parser.add_argument(
64 | "--stride",
65 | type=int,
66 | help="Stride to crop patch",
67 | default=256,
68 | )
69 | parser.add_argument(
70 | "--mode",
71 | type=str,
72 | help="reduce or extend",
73 | default="reduce",
74 | )
75 | args = parser.parse_args()
76 |
77 | os.makedirs(args.output, exist_ok=True)
78 |
79 | image = cv2.imread(args.input)
80 | image_name = Path(args.input).stem
81 |     try:
82 |         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
83 |     except Exception as e:
84 |         raise ValueError(f"Failed to read a valid image from {args.input}") from e
85 |
86 | if args.thumbnail:
87 |         thumbnail = cv2.resize(
88 |             image, (args.patch_size, args.patch_size), interpolation=cv2.INTER_AREA
89 |         )
90 | thumbnail_instance = Image.fromarray(thumbnail)
91 | thumbnail_instance.save(
92 | os.path.join(args.thumbnail_output, "thumbnail.png")
93 | )
94 |
95 | h, w, c = image.shape
96 |
97 | if args.mode == "reduce":
98 | resize_fn = reduce_size
99 | resize_code = cv2.INTER_AREA
100 | elif args.mode == "extend":
101 | resize_fn = extend_size
102 | resize_code = cv2.INTER_CUBIC
103 | else:
104 | raise NotImplementedError
105 |
106 | h_resize = resize_fn(h, args.patch_size)
107 | w_resize = resize_fn(w, args.patch_size)
108 | print(f"Original size: h={h} w={w}")
109 | print(f"Resize to: h={h_resize} w={w_resize}")
110 |
111 |     image = cv2.resize(image, (w_resize, h_resize), interpolation=resize_code)
112 |
113 | h_anchors = np.arange(0, h_resize, args.stride)
114 | w_anchors = np.arange(0, w_resize, args.stride)
115 | output_num = len(h_anchors) * len(w_anchors)
116 | max_idx_digits = max(len(str(len(h_anchors))), len(str(len(w_anchors))))
117 | max_anchor_digits = max(len(str(h_anchors[-1])), len(str(w_anchors[-1])))
118 |
119 | curr_idx = 1
120 | blank_patches_list = []
121 | for y_idx, h_anchor in enumerate(h_anchors):
122 | for x_idx, w_anchor in enumerate(w_anchors):
123 | print(f"[{curr_idx} / {output_num}] Processing ...", end="\r")
124 | image_crop = image[
125 | h_anchor:h_anchor + args.patch_size,
126 | w_anchor:w_anchor + args.patch_size,
127 | :,
128 | ]
129 |
130 | # if stride < patch_size, some images will be cropped at the margin
131 | # e.g., stride = 256, patch_size = 512, image_size = 600
132 | # => [0, 512], [256, 600]
133 | # thus the output size should be double checked
134 | if image_crop.shape[0] != args.patch_size:
135 | continue
136 | if image_crop.shape[1] != args.patch_size:
137 | continue
138 |
139 | image_crop_instance = Image.fromarray(image_crop)
140 |
141 |             # filename: {y-idx}_{x-idx}_{h-anchor}_{w-anchor}_{image-name}.png
142 | filename = f"{str(y_idx).zfill(max_idx_digits)}_" \
143 | f"{str(x_idx).zfill(max_idx_digits)}_" \
144 | f"{str(h_anchor).zfill(max_anchor_digits)}_" \
145 | f"{str(w_anchor).zfill(max_anchor_digits)}_" \
146 | f"{image_name}.png"
147 |
148 | image_crop_instance.save(os.path.join(args.output, filename))
149 | blank_patches_list.append((filename, is_blank_patch(image_crop)))
150 | curr_idx += 1
151 |
152 | with open(
153 | os.path.join(args.output, "blank_patches_list.csv"),
154 | "w",
155 | encoding="UTF8",
156 | newline="",
157 | ) as f:
158 | writer = csv.writer(f)
159 | writer.writerows(blank_patches_list)
160 |
--------------------------------------------------------------------------------
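crop.py encodes each patch's grid position and pixel anchors directly in its filename, which is what lets combine.py (earlier in this repo) paste a translated patch back at matrix[y_anchor:y_anchor + patch_size, x_anchor:x_anchor + patch_size]. A small sketch of parsing one such name back apart; the example filename is hypothetical:

# {y-idx}_{x-idx}_{h-anchor}_{w-anchor}_{image-name}.png
name = "03_07_1536_3584_HE_cropped.png"
y_idx, x_idx, y_anchor, x_anchor, rest = name.split("_", 4)
print(int(y_anchor), int(x_anchor))  # 1536 3584 -- the patch's top-left corner in the full image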
/crop_pipeline.py:
--------------------------------------------------------------------------------
1 | # Author: Kaminyou (https://github.com/Kaminyou)
2 | import argparse
3 | import os
4 |
5 | from utils.util import read_yaml_config
6 |
7 | if __name__ == "__main__":
8 | parser = argparse.ArgumentParser("Model inference")
9 | parser.add_argument(
10 | "-c",
11 | "--config",
12 | type=str,
13 | default="./data/example/config.yaml",
14 | help="Path to the config file.",
15 | )
16 | args = parser.parse_args()
17 |
18 | config = read_yaml_config(args.config)["CROPPING_SETTING"]
19 | # Prepare the training set
20 | for train_x, train_y in zip(config["TRAIN_X"], config["TRAIN_Y"]):
21 | os.system(
22 | f"python3 crop.py -i {train_x} -o {config['TRAIN_DIR_X']} "
23 | f"--patch_size {config['PATCH_SIZE']} --stride {config['STRIDE']}"
24 | )
25 | os.system(
26 | f"python3 crop.py -i {train_y} -o {config['TRAIN_DIR_Y']} "
27 | f"--patch_size {config['PATCH_SIZE']} --stride {config['STRIDE']}"
28 | )
29 |
30 | # Prepare the testing set
31 | for test_x in config["TEST_X"]:
32 | basename = os.path.basename(test_x)
33 | filename = os.path.splitext(basename)[0]
34 | output_path = os.path.join(config['TEST_DIR_X'], filename)
35 | os.system(
36 | f"python3 crop.py -i {test_x} -o {output_path} "
37 | f"--thumbnail --thumbnail_output {output_path} "
38 | f"--patch_size {config['PATCH_SIZE']} "
39 | f"--stride {config['PATCH_SIZE']}"
40 | )
41 |
--------------------------------------------------------------------------------
/data/example/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/data/example/.gitkeep
--------------------------------------------------------------------------------
/data/example/config.yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT_ROOT_PATH: "./experiments/"
2 | EXPERIMENT_NAME: "example_CUT"
3 | MODEL_NAME: "CUT" # currently only CUT, cycleGAN, and LSeSim are support
4 | DEVICE: "cuda"
5 |
6 | CROPPING_SETTING:
7 | TRAIN_X:
8 | - "./data/example/HE_cropped.jpg"
9 | TRAIN_DIR_X: "./data/example/trainX/"
10 | TRAIN_Y:
11 | - "./data/example/ER_cropped.jpg"
12 | TRAIN_DIR_Y: "./data/example/trainY/"
13 | TEST_X:
14 | - "./data/example/HE_cropped.jpg"
15 | TEST_DIR_X: "./data/example/testX/"
16 | PATCH_SIZE: 512
17 | STRIDE: 256
18 |
19 | TRAINING_SETTING:
20 | PAIRED_TRAINING: False
21 | RANDOM_CROP_AUG: False
22 | TRAIN_ROOT: "./data/example/"
23 | TRAIN_DIR_X: "./data/example/trainX/"
24 | TRAIN_DIR_Y: "./data/example/trainY/"
25 | NUM_EPOCHS: 100
26 | LAMBDA_Y: 1
27 | LEARNING_RATE: 0.0002
28 | BATCH_SIZE: 1
29 | NUM_WORKERS: 8
30 | SAVE_MODEL: true
31 | SAVE_MODEL_EPOCH_STEP: 10
32 | VISUALIZATION_STEP: 250
33 | LOAD_MODEL: false
34 | LOAD_EPOCH: 0
35 | Augment: True #LSeSim
36 |
37 | INFERENCE_SETTING:
38 | TEST_X: "./data/example/HE_cropped.jpg"
39 | TEST_DIR_X: "./data/example/testX/HE_cropped/"
40 | MODEL_VERSION: "10"
41 | SAVE_ORIGINAL_IMAGE: False
42 | NORMALIZATION:
43 | TYPE: 'kin'
44 | PADDING: 1
45 | KERNEL_TYPE: 'constant'
46 | KERNEL_SIZE: 3
47 | THUMBNAIL: "./data/example/testX/HE_cropped/thumbnail.png" #"None" # set to "None" if it's not required
48 |
--------------------------------------------------------------------------------
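Scripts such as combine.py and crop_pipeline.py above load this file through utils.util.read_yaml_config and then index the nested sections directly. A minimal sketch of reading a few of the fields shown here, assuming the helper simply returns the parsed YAML as a dict (which is how the callers above treat it):

from utils.util import read_yaml_config

config = read_yaml_config("./data/example/config.yaml")
print(config["MODEL_NAME"])                                  # "CUT"
print(config["CROPPING_SETTING"]["PATCH_SIZE"])              # 512
print(config["INFERENCE_SETTING"]["NORMALIZATION"]["TYPE"])  # "kin"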
/evaluation.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | exps = ["lung_lesion"]
4 | models = ["cycleGAN", "CUT", "LSeSim"]
5 | model_epochs = [60, "latest", 100]
6 | norm_types = ["in", "tin", "kin"]
7 | kernel_types = [
8 | None,
9 | None,
10 | [
11 | "constant_1",
12 | "gaussian_1",
13 | "constant_3",
14 | "gaussian_3",
15 | "constant_5",
16 | "gaussian_5",
17 | ],
18 | ]
19 |
20 | for exp in exps:
21 | for i, model in enumerate(models):
22 | model_epoch = model_epochs[i]
23 | for j, norm_type in enumerate(norm_types):
24 | kernel_type = kernel_types[j]
25 | if kernel_type is None:
26 | command = f"python3 metric.py --exp_name {model}_{model_epoch}_{norm_type} --path-A ./experiments/{exp}/{exp}_{model}/test/{norm_type}/{model_epoch}/ --path-B ./data/lung_lesion/testX --blank_patches_list ./data/lung_lesion/testX/blank_patches_list.csv >> FID.out"
27 | os.system(command)
28 | else:
29 | for kernel in kernel_type:
30 | command = f"python3 metric.py --exp_name {model}_{model_epoch}_{norm_type}_{kernel} --path-A ./experiments/{exp}/{exp}_{model}/test/{norm_type}/{model_epoch}/{kernel}/ --path-B ./data/lung_lesion/testX --blank_patches_list ./data/lung_lesion/testX/blank_patches_list.csv >> FID.out"
31 | os.system(command)
32 |
--------------------------------------------------------------------------------
/experiments/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/experiments/.gitkeep
--------------------------------------------------------------------------------
/imgs/Figure_KIN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_KIN.jpg
--------------------------------------------------------------------------------
/imgs/Figure_patch_with_patch_lineplot_block1_mean.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_patch_with_patch_lineplot_block1_mean.jpg
--------------------------------------------------------------------------------
/imgs/Figure_patch_with_patch_lineplot_blocks_mean.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/Figure_patch_with_patch_lineplot_blocks_mean.jpg
--------------------------------------------------------------------------------
/imgs/URUST_anime.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/imgs/URUST_anime.gif
--------------------------------------------------------------------------------
/metric_images_with_ref.py:
--------------------------------------------------------------------------------
1 | import os
2 | from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
3 |
4 | import torch
5 |
6 | from metrics.calculate_fid import calculate_fid_given_two_paths
7 | from metrics.inception import InceptionV3
8 |
9 | parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
10 | parser.add_argument(
11 | "--exp_name",
12 | type=str,
13 | help="Experiment name"
14 | )
15 | parser.add_argument(
16 | "--batch-size",
17 | type=int,
18 | default=50,
19 | help="Batch size to use"
20 | )
21 | parser.add_argument(
22 | "--num-workers",
23 | type=int,
24 | help=(
25 | "Number of processes to use for data loading. "
26 | "Defaults to `min(8, num_cpus)`"
27 | ),
28 | )
29 | parser.add_argument(
30 | "--device",
31 | type=str,
32 | default=None,
33 | help="Device to use. Like cuda, cuda:0 or cpu",
34 | )
35 | parser.add_argument(
36 | "--dims",
37 | type=int,
38 | default=2048,
39 | choices=list(InceptionV3.BLOCK_INDEX_BY_DIM),
40 | help=(
41 | "Dimensionality of Inception features to use. "
42 | "By default, uses pool3 features"
43 | ),
44 | )
45 | parser.add_argument(
46 | "--path-A",
47 | type=str,
48 | help=(
49 | "Paths to the original images or "
50 | "to .npz statistic files. "
51 |         "Supports multiple paths, e.g. "
52 |         'path_a1,path_a2,path_a3 ..., separated by ",". '
53 | ),
54 | )
55 | parser.add_argument(
56 | "--path-B",
57 | type=str,
58 | help=(
59 | "Paths to the generated images or "
60 | "to .npz statistic files. "
61 |         "Supports multiple paths, e.g. "
62 |         'path_b1,path_b2,path_b3 ..., separated by ",". '
63 | ),
64 | )
65 | parser.add_argument(
66 | "--blank_patches_list_A",
67 | type=str,
68 | default=None,
69 | required=False,
70 |     help="Path to the list of blank patches for path-A",
71 | )
72 | parser.add_argument(
73 | "--blank_patches_list_B",
74 | type=str,
75 | default=None,
76 | required=False,
77 |     help="Path to the list of blank patches for path-B",
78 | )
79 |
80 |
81 | def main():
82 | args = parser.parse_args()
83 |
84 | if args.device is None:
85 | device = torch.device("cuda" if (torch.cuda.is_available()) else "cpu")
86 | else:
87 | device = torch.device(args.device)
88 |
89 | if args.num_workers is None:
90 | num_avail_cpus = len(os.sched_getaffinity(0))
91 | num_workers = min(num_avail_cpus, 8)
92 | else:
93 | num_workers = args.num_workers
94 |
95 | path_As = args.path_A.split(",")
96 | path_Bs = args.path_B.split(",")
97 |
98 | fid_value = calculate_fid_given_two_paths(
99 | path_As,
100 | path_Bs,
101 | args.batch_size,
102 | device,
103 | args.dims,
104 | num_workers,
105 | args.blank_patches_list_A,
106 | args.blank_patches_list_B,
107 | )
108 | print(f"Exp::{args.exp_name} || FID: {fid_value:.4f}")
109 |
110 |
111 | if __name__ == "__main__":
112 | main()
113 |
--------------------------------------------------------------------------------
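Example invocation of the FID script above (a sketch; the experiment and data paths are placeholders, and --blank_patches_list_A/--blank_patches_list_B are optional):

python3 metric_images_with_ref.py \
    --exp_name lung_lesion_CUT_kin \
    --path-A ./experiments/lung_lesion/lung_lesion_CUT/test/kin/latest/ \
    --path-B ./data/lung_lesion/testX \
    --blank_patches_list_B ./data/lung_lesion/testX/blank_patches_list.csv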
/metric_whole_image_no_ref.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from pathlib import Path
4 |
5 | import cv2
6 | from PIL import Image
7 |
8 | from metrics.niqe import niqe
9 | from metrics.piqe import piqe
10 | from metrics.sobel import calculate_sobel_gradient_pipeline
11 |
12 | parser = argparse.ArgumentParser()
13 | parser.add_argument(
14 | "--exp_name",
15 | type=str,
16 | default="",
17 | help="Experiment name",
18 | )
19 | parser.add_argument(
20 | "--path",
21 | type=str,
22 | required=True,
23 | help="Path to the image",
24 | )
25 | parser.add_argument(
26 | "--save_grad",
27 | action="store_true",
28 |     help="Whether to save the gradient image",
29 | )
30 |
31 | if __name__ == "__main__":
32 | args = parser.parse_args()
33 | sobel_gradient, sobel_gradient_avg = calculate_sobel_gradient_pipeline(
34 | args.path
35 | )
36 |
37 | im_bgr = cv2.imread(args.path)
38 | piqe_score, _, _, _ = piqe(im_bgr)
39 | niqe_score = niqe(im_bgr)
40 |
41 | img_path = Path(args.path)
42 |
43 | if args.save_grad:
44 | parent_path = img_path.parents[0]
45 | save_name = img_path.stem + "_grad.png"
46 | im = Image.fromarray(sobel_gradient)
47 | im.save(os.path.join(parent_path, save_name))
48 |
49 | print(
50 | f"Exp::{args.exp_name}::{img_path.stem} || "
51 | f"Grad = {sobel_gradient_avg:.4f} "
52 | f"PIQE = {piqe_score:.4f} NIQE = {niqe_score:.4f}"
53 | )
54 |
--------------------------------------------------------------------------------
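Example invocation of the no-reference metrics above (a sketch; the image path reuses the example image referenced in data/example/config.yaml, and --save_grad writes an additional *_grad.png next to the input):

python3 metric_whole_image_no_ref.py \
    --exp_name example \
    --path ./data/example/HE_cropped.jpg \
    --save_grad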
/metric_whole_image_with_ref.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | from metrics.histogram import compare_images_histogram_pipeline
5 |
6 | parser = argparse.ArgumentParser()
7 | parser.add_argument(
8 | "--exp_name",
9 | type=str,
10 | default="",
11 | help="Experiment name",
12 | )
13 | parser.add_argument(
14 | "--image_A_path",
15 | type=str,
16 | required=True,
17 | help="Path to the reference image",
18 | )
19 | parser.add_argument(
20 | "--image_B_path",
21 | type=str,
22 | required=True,
23 | help="Path to the compared image",
24 | )
25 |
26 | if __name__ == "__main__":
27 | args = parser.parse_args()
28 |
29 | similarity = compare_images_histogram_pipeline(
30 | args.image_A_path,
31 | args.image_B_path,
32 | )
33 | image_A_name = Path(args.image_A_path).stem
34 | image_B_name = Path(args.image_B_path).stem
35 | print(
36 | f"Exp::{args.exp_name}::{image_A_name} {image_B_name} || "
37 | f"Histogram corr = {similarity:.4f}"
38 | )
39 |
--------------------------------------------------------------------------------
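Example invocation of the reference-based histogram comparison above (a sketch with placeholder paths; comparing an image against itself should report a correlation of 1.0):

python3 metric_whole_image_with_ref.py \
    --exp_name example \
    --image_A_path ./data/example/HE_cropped.jpg \
    --image_B_path ./data/example/HE_cropped.jpg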
/metrics/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/metrics/__init__.py
--------------------------------------------------------------------------------
/metrics/histogram.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 |
4 | def read_img_toRGB(img_path):
5 | img = cv2.imread(img_path)
6 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
7 | return img
8 |
9 |
10 | def calculate_histogram(
11 | image, channels=[0], hist_size=[10], hist_range=[0, 256]
12 | ):
13 |     # convert from RGB to HSV before computing the histogram
14 | image = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
15 |
16 | image_hist = cv2.calcHist([image], channels, None, hist_size, hist_range)
17 | image_hist = cv2.normalize(image_hist, image_hist).flatten()
18 | return image_hist
19 |
20 |
21 | def compare_images_histogram(img_base, img_compare, method="correlation"):
22 | hist_1 = calculate_histogram(img_base)
23 | hist_2 = calculate_histogram(img_compare)
24 |
25 | if method == "intersection":
26 | comparison = cv2.compareHist(hist_1, hist_2, cv2.HISTCMP_INTERSECT)
27 | else:
28 | comparison = cv2.compareHist(hist_1, hist_2, cv2.HISTCMP_CORREL)
29 | return comparison
30 |
31 |
32 | def compare_images_histogram_pipeline(img_base_path, img_compare_path):
33 | img_base = read_img_toRGB(img_base_path)
34 | img_compare = read_img_toRGB(img_compare_path)
35 | similarity = compare_images_histogram(img_base, img_compare)
36 | return similarity
37 |
--------------------------------------------------------------------------------
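A minimal usage sketch of the pipeline above (the image paths are placeholders). The histogram is computed on the hue channel of the HSV image and compared with correlation by default, so identical images yield 1.0:

from metrics.histogram import compare_images_histogram_pipeline

similarity = compare_images_histogram_pipeline(
    "./data/example/HE_cropped.jpg",  # reference image
    "./data/example/HE_cropped.jpg",  # compared image
)
print(f"Histogram correlation: {similarity:.4f}")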
/metrics/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/metrics/niqe_image_params.mat
--------------------------------------------------------------------------------
/metrics/sobel.py:
--------------------------------------------------------------------------------
1 | import cv2
2 |
3 |
4 | def read_img_toYCRCB(img_path):
5 | img = cv2.imread(img_path)
6 | img = cv2.cvtColor(img, cv2.COLOR_BGR2YCR_CB)
7 | return img
8 |
9 |
10 | def calculate_YCRCB_gradient(YCRCB_img):
11 | YCRCB_img_Y_channel = YCRCB_img[..., 0]
12 | sobelx = cv2.Sobel(YCRCB_img_Y_channel, -1, 1, 0, ksize=3)
13 | sobely = cv2.Sobel(YCRCB_img_Y_channel, -1, 0, 1, ksize=3)
14 | grad = cv2.addWeighted(sobelx, 0.5, sobely, 0.5, 0)
15 | return grad
16 |
17 |
18 | def calculate_grad_avg(grad):
19 | h, w = grad.shape
20 | return grad.sum() / h / w
21 |
22 |
23 | def calculate_sobel_gradient_pipeline(img_path):
24 | img_ycrcb = read_img_toYCRCB(img_path)
25 | grad = calculate_YCRCB_gradient(img_ycrcb)
26 | grad_avg = calculate_grad_avg(grad)
27 | return grad, grad_avg
28 |
--------------------------------------------------------------------------------
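A minimal usage sketch of the Sobel pipeline above (the image path is a placeholder). The gradients are taken on the Y (luma) channel only, and the average is the mean over all pixels:

from metrics.sobel import calculate_sobel_gradient_pipeline

grad, grad_avg = calculate_sobel_gradient_pipeline("./data/example/HE_cropped.jpg")
print(grad.shape)         # (H, W) combined x/y gradient image
print(f"{grad_avg:.4f}")  # average gradient value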
/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/models/__init__.py
--------------------------------------------------------------------------------
/models/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from models.downsample import Downsample
6 | from models.normalization import make_norm_layer
7 |
8 |
9 | class DiscriminatorBasicBlock(nn.Module):
10 | def __init__(
11 | self,
12 | in_features,
13 | out_features,
14 | do_downsample=True,
15 | do_instancenorm=True,
16 | norm_cfg=None,
17 | ):
18 | super().__init__()
19 |
20 | self.norm_cfg = norm_cfg or {'type': 'in'}
21 | self.norm_cfg = {k.lower(): v for k, v in self.norm_cfg.items()}
22 | self.do_downsample = do_downsample
23 | self.do_instancenorm = do_instancenorm
24 |
25 | self.conv = nn.Conv2d(
26 | in_features, out_features, kernel_size=4, stride=1, padding=1
27 | )
28 | self.leakyrelu = nn.LeakyReLU(0.2, True)
29 |
30 | if do_instancenorm:
31 | self.instancenorm = make_norm_layer(self.norm_cfg, num_features=out_features)
32 |
33 | if do_downsample:
34 | self.downsample = Downsample(out_features)
35 |
36 | def forward(self, x):
37 | x = self.conv(x)
38 | if self.do_instancenorm:
39 | x = self.instancenorm(x)
40 | x = self.leakyrelu(x)
41 | if self.do_downsample:
42 | x = self.downsample(x)
43 | return x
44 |
45 |
46 | class Discriminator(nn.Module):
47 | def __init__(self, in_channels=3, features=64, avg_pooling=False):
48 | super().__init__()
49 | self.block1 = DiscriminatorBasicBlock(
50 | in_channels,
51 | features,
52 | do_downsample=True,
53 | do_instancenorm=False,
54 | )
55 | self.block2 = DiscriminatorBasicBlock(
56 | features,
57 | features * 2,
58 | do_downsample=True,
59 | do_instancenorm=True,
60 | )
61 | self.block3 = DiscriminatorBasicBlock(
62 | features * 2,
63 | features * 4,
64 | do_downsample=True,
65 | do_instancenorm=True,
66 | )
67 | self.block4 = DiscriminatorBasicBlock(
68 | features * 4,
69 | features * 8,
70 | do_downsample=False,
71 | do_instancenorm=True,
72 | )
73 | self.conv = nn.Conv2d(
74 | features * 8,
75 | 1,
76 | kernel_size=4,
77 | stride=1,
78 | padding=1,
79 | )
80 | self.avg_pooling = avg_pooling
81 |
82 | def forward(self, x):
83 | x = self.block1(x)
84 | x = self.block2(x)
85 | x = self.block3(x)
86 | x = self.block4(x)
87 | x = self.conv(x)
88 | if self.avg_pooling:
89 | x = F.avg_pool2d(x, x.size()[2:])
90 | x = torch.flatten(x, 1)
91 | return x
92 |
93 | def set_requires_grad(self, requires_grad=False):
94 | for param in self.parameters():
95 | param.requires_grad = requires_grad
96 |
97 |
98 | if __name__ == "__main__":
99 | x = torch.randn((5, 3, 256, 256))
100 | print(x.shape)
101 | model = Discriminator(in_channels=3, avg_pooling=True)
102 | preds = model(x)
103 | print(preds.shape)
104 |
--------------------------------------------------------------------------------
/models/downsample.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Downsample(nn.Module):
5 | def __init__(self, features):
6 | super().__init__()
7 | self.reflectionpad = nn.ReflectionPad2d(1)
8 | self.conv = nn.Conv2d(features, features, kernel_size=3, stride=2)
9 |
10 | def forward(self, x):
11 | x = self.reflectionpad(x)
12 | x = self.conv(x)
13 | return x
14 |
--------------------------------------------------------------------------------
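A quick shape check for the block above (a sketch, not part of the repository): reflection padding of 1 followed by a 3x3 convolution with stride 2 halves the spatial resolution while keeping the channel count.

import torch
from models.downsample import Downsample

x = torch.randn(1, 64, 256, 256)
y = Downsample(64)(x)
print(y.shape)  # torch.Size([1, 64, 128, 128])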
/models/model.py:
--------------------------------------------------------------------------------
1 | from models.cut import ContrastiveModel
2 | from models.cyclegan import CycleGanModel
3 |
4 |
5 | # XXX: Check if isTrain is used
6 | def get_model(config, model_name="CUT", norm_cfg=None, isTrain=True):
7 | if model_name == "CUT":
8 | model = ContrastiveModel(config, norm_cfg=norm_cfg)
9 | elif model_name == "cycleGAN":
10 | model = CycleGanModel(config, norm_cfg=norm_cfg)
11 | elif model_name == "LSeSim":
12 | print("Please use the scripts prepared in the F-LSeSim folder")
13 | raise NotImplementedError
14 | else:
15 | raise NotImplementedError
16 | return model
17 |
--------------------------------------------------------------------------------
/models/normalization.py:
--------------------------------------------------------------------------------
1 | from copy import deepcopy
2 | from typing import Any, Dict
3 |
4 | import torch.nn as nn
5 |
6 | from models.kin import KernelizedInstanceNorm
7 | from models.tin import ThumbInstanceNorm
8 |
9 |
10 | # TODO: To be deprecated
11 | def get_normalization_layer(num_features, normalization="kin"):
12 | if normalization == "kin":
13 | return KernelizedInstanceNorm(num_features=num_features)
14 | elif normalization == "tin":
15 | return ThumbInstanceNorm(num_features=num_features)
16 | elif normalization == "in":
17 | return nn.InstanceNorm2d(num_features)
18 | else:
19 | raise NotImplementedError
20 |
21 |
22 | def make_norm_layer(norm_cfg: Dict[str, Any], **kwargs: Any):
23 | """
24 | Create normalization layer based on given config and arguments.
25 |
26 | Args:
27 |         norm_cfg (Dict[str, Any]): A dict of keyword arguments for the normalization layer.
28 |             It must contain a key 'type' specifying which normalization layer will be used.
29 |             Keys may also be given in upper case.
30 | **kwargs (Any): The keyword arguments are used to overwrite `norm_cfg`.
31 |
32 | Returns:
33 | nn.Module: A layer object.
34 | """
35 | norm_cfg = deepcopy(norm_cfg)
36 | norm_cfg = {k.lower(): v for k, v in norm_cfg.items()}
37 |
38 | norm_cfg.update(kwargs)
39 |
40 | if 'type' not in norm_cfg:
41 | raise ValueError('"type" wasn\'t specified.')
42 |
43 | norm_type = norm_cfg['type']
44 | del norm_cfg['type']
45 |
46 | if norm_type == 'in':
47 | return nn.InstanceNorm2d(**norm_cfg)
48 | elif norm_type == 'tin':
49 | return ThumbInstanceNorm(**norm_cfg)
50 | elif norm_type == 'kin':
51 | return KernelizedInstanceNorm(**norm_cfg)
52 | else:
53 | raise ValueError(f'Unknown norm type: {norm_type}.')
54 |
--------------------------------------------------------------------------------
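A usage sketch for make_norm_layer (the KIN keyword arguments kernel_type, kernel_size and padding are assumptions inferred from the NORMALIZATION config keys; only 'type' plus whatever the chosen layer requires, such as num_features, are mandatory):

from models.normalization import make_norm_layer

# plain instance normalization
in_layer = make_norm_layer({'type': 'in'}, num_features=64)

# kernelized instance normalization; upper-case keys are lower-cased internally
kin_layer = make_norm_layer(
    {'TYPE': 'kin', 'KERNEL_TYPE': 'gaussian', 'KERNEL_SIZE': 3, 'PADDING': 1},
    num_features=64,
)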
/models/projector.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from models.generator import Generator
5 |
6 |
7 | class MLP(nn.Module):
8 | def __init__(self, input_nc, output_nc):
9 | super().__init__()
10 | self.mlp = nn.Sequential(
11 | *[
12 | nn.Linear(input_nc, output_nc),
13 | nn.ReLU(),
14 | nn.Linear(output_nc, output_nc),
15 | ]
16 | )
17 |
18 | def forward(self, x):
19 | return self.mlp(x)
20 |
21 |
22 | class Head(nn.Module):
23 | def __init__(self, in_channels=3, features=64, residuals=9):
24 | super().__init__()
25 | self.mlp_0 = MLP(3, 256)
26 | self.mlp_1 = MLP(128, 256)
27 | self.mlp_2 = MLP(256, 256)
28 | self.mlp_3 = MLP(256, 256)
29 | self.mlp_4 = MLP(256, 256)
30 |
31 | def forward(self, features):
32 | return_features = []
33 | for feature_id, feature in enumerate(features):
34 | mlp = getattr(self, f"mlp_{feature_id}")
35 | feature = mlp(feature)
36 | norm = feature.pow(2).sum(1, keepdim=True).pow(1.0 / 2)
37 | feature = feature.div(norm + 1e-7)
38 | return_features.append(feature)
39 | return return_features
40 |
41 |
42 | if __name__ == "__main__":
43 | x = torch.randn((5, 3, 256, 256))
44 | print(x.shape)
45 | G = Generator()
46 | H = Head()
47 | feat_k_pool, sample_ids = G(x, encode_only=True, patch_ids=None)
48 | feat_q_pool, _ = G(x, encode_only=True, patch_ids=sample_ids)
49 | print(len(feat_k_pool))
50 | return_features = H(feat_q_pool)
51 | print(len(return_features))
52 |
--------------------------------------------------------------------------------
/models/tests/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/models/tests/__init__.py
--------------------------------------------------------------------------------
/models/tests/test_cut.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 | import torch
6 | from models.model import get_model
7 | from models.kin import KernelizedInstanceNorm
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import make_grid
11 | from utils.dataset import XInferenceDataset
12 | from utils.util import (read_yaml_config, reverse_image_normalize,
13 | test_transforms)
14 |
15 |
16 | @pytest.fixture()
17 | def config():
18 | config_path = os.path.join(
19 | "./test_data",
20 | "configs",
21 | "config_lung_lesion_for_test_cut.yaml"
22 | )
23 | config = read_yaml_config(config_path)
24 |
25 | return config
26 |
27 |
28 | @pytest.fixture()
29 | def in_model(config):
30 | model = get_model(
31 | config=config,
32 | model_name=config["MODEL_NAME"],
33 | norm_cfg={'type': 'in'},
34 | isTrain=False,
35 | )
36 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"])
37 |
38 | return model
39 |
40 |
41 | @pytest.fixture()
42 | def kin_model(config):
43 | model = get_model(
44 | config=config,
45 | model_name=config["MODEL_NAME"],
46 | norm_cfg=config["INFERENCE_SETTING"]["NORMALIZATION"],
47 | isTrain=False,
48 | )
49 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"])
50 | model.eval()
51 |
52 | return model
53 |
54 |
55 | @pytest.fixture()
56 | def dataset(config):
57 | dataset = XInferenceDataset(
58 | root_X=config["INFERENCE_SETTING"]["TEST_DIR_X"],
59 | transform=test_transforms,
60 | return_anchor=True,
61 | )
62 | return dataset
63 |
64 |
65 | @pytest.fixture()
66 | def dataloader(dataset):
67 | loader = DataLoader(
68 | dataset, batch_size=1, shuffle=False, pin_memory=True
69 | )
70 | return loader
71 |
72 |
73 | @pytest.fixture()
74 | def in_expected_outputs(config):
75 | expected_output_dir = os.path.join(
76 | config["INFERENCE_SETTING"]["TEST_DIR_Y"],
77 | config["EXPERIMENT_NAME"],
78 | "in",
79 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}",
80 |
81 | )
82 | expected_output_files = sorted(
83 | os.listdir(expected_output_dir)
84 | )
85 |
86 | expected_outputs = []
87 | for expected_output_file in expected_output_files:
88 | expected_outputs.append(np.array(
89 | Image.open(os.path.join(
90 | expected_output_dir,
91 | expected_output_file
92 | )).convert("RGB")
93 | ))
94 |
95 | return expected_outputs
96 |
97 |
98 | @pytest.fixture()
99 | def kin_expected_outputs(config):
100 | expected_output_dir = os.path.join(
101 | config["INFERENCE_SETTING"]["TEST_DIR_Y"],
102 | config["EXPERIMENT_NAME"],
103 | "kin",
104 | config["INFERENCE_SETTING"]["MODEL_VERSION"],
105 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_"
106 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}"
107 | )
108 | expected_output_files = sorted(
109 | os.listdir(expected_output_dir)
110 | )
111 |
112 | expected_outputs = []
113 | for expected_output_file in expected_output_files:
114 | expected_outputs.append(np.array(
115 | Image.open(os.path.join(
116 | expected_output_dir,
117 | expected_output_file
118 | )).convert("RGB")
119 | ))
120 |
121 | return expected_outputs
122 |
123 |
124 | def test_inference(in_model, dataloader, in_expected_outputs):
125 | """
126 |     Integration test for IN inference
127 | """
128 |
129 | for idx, data in enumerate(dataloader):
130 | X, _ = data["X_img"], data["X_path"]
131 | Y_fake = in_model.inference(X)
132 | test_output_tensor = reverse_image_normalize(Y_fake)
133 |
134 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
135 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_(
136 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
137 | expected_output = in_expected_outputs[idx]
138 |
139 | assert test_output == pytest.approx(expected_output)
140 |
141 |
142 | def test_inference_with_anchor(
143 | config,
144 | kin_model,
145 | dataset,
146 | dataloader,
147 | kin_expected_outputs
148 | ):
149 | """
150 |     Integration test for KIN inference
151 | """
152 |
153 | y_anchor_num, x_anchor_num = dataset.get_boundary()
154 |
155 | kin_model.init_kernelized_instance_norm_for_whole_model(
156 | y_anchor_num=y_anchor_num + 1,
157 | x_anchor_num=x_anchor_num + 1,
158 | )
159 |
160 | for idx, data in enumerate(dataloader):
161 | X, _, y_anchor, x_anchor = (
162 | data["X_img"],
163 | data["X_path"],
164 | data["y_idx"],
165 | data["x_idx"],
166 | )
167 | _ = kin_model.inference_with_anchor(
168 | X,
169 | y_anchor=y_anchor,
170 | x_anchor=x_anchor,
171 | mode=KernelizedInstanceNorm.Mode.PHASE_CACHING,
172 | )
173 |
174 | for idx, data in enumerate(dataloader):
175 | X, _, y_anchor, x_anchor = (
176 | data["X_img"],
177 | data["X_path"],
178 | data["y_idx"],
179 | data["x_idx"],
180 | )
181 | Y_fake = kin_model.inference_with_anchor(
182 | X,
183 | y_anchor=y_anchor,
184 | x_anchor=x_anchor,
185 | mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE,
186 | )
187 | test_output_tensor = reverse_image_normalize(Y_fake)
188 |
189 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
190 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_(
191 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
192 | expected_output = kin_expected_outputs[idx]
193 |
194 | assert test_output == pytest.approx(expected_output)
195 |
--------------------------------------------------------------------------------
/models/tests/test_cyclegan.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import numpy as np
4 | import pytest
5 | import torch
6 | from models.model import get_model
7 | from models.kin import KernelizedInstanceNorm
8 | from PIL import Image
9 | from torch.utils.data import DataLoader
10 | from torchvision.utils import make_grid
11 | from utils.dataset import XInferenceDataset
12 | from utils.util import (read_yaml_config, reverse_image_normalize,
13 | test_transforms)
14 |
15 |
16 | @pytest.fixture()
17 | def config():
18 | config_path = os.path.join(
19 | "./test_data",
20 | "configs",
21 | "config_lung_lesion_for_test_cyclegan.yaml"
22 | )
23 | config = read_yaml_config(config_path)
24 |
25 | return config
26 |
27 |
28 | @pytest.fixture()
29 | def in_model(config):
30 | model = get_model(
31 | config=config,
32 | model_name=config["MODEL_NAME"],
33 | norm_cfg={'type': 'in'},
34 | isTrain=False,
35 | )
36 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"])
37 |
38 | return model
39 |
40 |
41 | @pytest.fixture()
42 | def kin_model(config):
43 | model = get_model(
44 | config=config,
45 | model_name=config["MODEL_NAME"],
46 | norm_cfg=config['INFERENCE_SETTING']['NORMALIZATION'],
47 | isTrain=False,
48 | )
49 | model.load_networks(config["INFERENCE_SETTING"]["MODEL_VERSION"])
50 | model.eval()
51 |
52 | return model
53 |
54 |
55 | @pytest.fixture()
56 | def dataset(config):
57 | dataset = XInferenceDataset(
58 | root_X=config["INFERENCE_SETTING"]["TEST_DIR_X"],
59 | transform=test_transforms,
60 | return_anchor=True,
61 | )
62 | return dataset
63 |
64 |
65 | @pytest.fixture()
66 | def dataloader(dataset):
67 | loader = DataLoader(
68 | dataset, batch_size=1, shuffle=False, pin_memory=True
69 | )
70 | return loader
71 |
72 |
73 | @pytest.fixture()
74 | def in_expected_outputs(config):
75 | expected_output_dir = os.path.join(
76 | config["INFERENCE_SETTING"]["TEST_DIR_Y"],
77 | config["EXPERIMENT_NAME"],
78 | "in",
79 | f"{config['INFERENCE_SETTING']['MODEL_VERSION']}",
80 |
81 | )
82 | expected_output_files = sorted(
83 | os.listdir(expected_output_dir)
84 | )
85 |
86 | expected_outputs = []
87 | for expected_output_file in expected_output_files:
88 | expected_outputs.append(np.array(
89 | Image.open(os.path.join(
90 | expected_output_dir,
91 | expected_output_file
92 | )).convert("RGB")
93 | ))
94 |
95 | return expected_outputs
96 |
97 |
98 | @pytest.fixture()
99 | def kin_expected_outputs(config):
100 | expected_output_dir = os.path.join(
101 | config["INFERENCE_SETTING"]["TEST_DIR_Y"],
102 | config["EXPERIMENT_NAME"],
103 | "kin",
104 | config["INFERENCE_SETTING"]["MODEL_VERSION"],
105 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['KERNEL_TYPE']}_"
106 | f"{config['INFERENCE_SETTING']['NORMALIZATION']['PADDING']}",
107 |
108 | )
109 | expected_output_files = sorted(
110 | os.listdir(expected_output_dir)
111 | )
112 |
113 | expected_outputs = []
114 | for expected_output_file in expected_output_files:
115 | expected_outputs.append(np.array(
116 | Image.open(os.path.join(
117 | expected_output_dir,
118 | expected_output_file
119 | )).convert("RGB")
120 | ))
121 |
122 | return expected_outputs
123 |
124 |
125 | def test_inference(in_model, dataloader, in_expected_outputs):
126 | """
127 |     Integration test for IN inference
128 | """
129 |
130 | for idx, data in enumerate(dataloader):
131 | X, _ = data["X_img"], data["X_path"]
132 | Y_fake = in_model.inference(X)
133 | test_output_tensor = reverse_image_normalize(Y_fake)
134 |
135 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
136 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_(
137 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
138 | expected_output = in_expected_outputs[idx]
139 |
140 | assert test_output == pytest.approx(expected_output)
141 |
142 |
143 | def test_inference_with_anchor(
144 | config,
145 | kin_model,
146 | dataset,
147 | dataloader,
148 | kin_expected_outputs
149 | ):
150 | """
151 |     Integration test for KIN inference
152 | """
153 |
154 | y_anchor_num, x_anchor_num = dataset.get_boundary()
155 |
156 | kin_model.init_kernelized_instance_norm_for_whole_model(
157 | y_anchor_num=y_anchor_num + 1,
158 | x_anchor_num=x_anchor_num + 1,
159 | )
160 |
161 | for idx, data in enumerate(dataloader):
162 | X, _, y_anchor, x_anchor = (
163 | data["X_img"],
164 | data["X_path"],
165 | data["y_idx"],
166 | data["x_idx"],
167 | )
168 | _ = kin_model.inference_with_anchor(
169 | X,
170 | y_anchor=y_anchor,
171 | x_anchor=x_anchor,
172 | mode=KernelizedInstanceNorm.Mode.PHASE_CACHING,
173 | )
174 |
175 | for idx, data in enumerate(dataloader):
176 | X, _, y_anchor, x_anchor = (
177 | data["X_img"],
178 | data["X_path"],
179 | data["y_idx"],
180 | data["x_idx"],
181 | )
182 | Y_fake = kin_model.inference_with_anchor(
183 | X,
184 | y_anchor=y_anchor,
185 | x_anchor=x_anchor,
186 | mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE,
187 | )
188 | test_output_tensor = reverse_image_normalize(Y_fake)
189 |
190 | # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
191 | test_output = make_grid(test_output_tensor).mul(255).add_(0.5).clamp_(
192 | 0, 255).permute(1, 2, 0).to("cpu", torch.uint8).numpy()
193 | expected_output = kin_expected_outputs[idx]
194 |
195 | assert test_output == pytest.approx(expected_output)
196 |
--------------------------------------------------------------------------------
/models/tests/test_data/configs/config_lung_lesion_for_test_cut.yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT_ROOT_PATH: "./test_data/checkpoints/lung_lesion"
2 | EXPERIMENT_NAME: "lung_lesion_CUT"
3 | MODEL_NAME: "CUT" # currently only CUT and cycleGAN are supported
4 | DEVICE: "cuda"
5 |
6 | TRAINING_SETTING:
7 | PAIRED_TRAINING: False
8 | NUM_EPOCHS: 100
9 | LAMBDA_Y: 1
10 | LEARNING_RATE: 0.0002
11 | BATCH_SIZE: 1
12 | NUM_WORKERS: 4
13 | SAVE_MODEL: true
14 | SAVE_MODEL_EPOCH_STEP: 10
15 | VISUALIZATION_STEP: 250
16 | LOAD_MODEL: false
17 | LOAD_EPOCH: 0
18 |
19 | INFERENCE_SETTING:
20 | TEST_DIR_X: "./test_data/test_dir_x/lung_lesion/"
21 | TEST_DIR_Y: "./test_data/test_dir_y/lung_lesion"
22 | MODEL_VERSION: "latest"
23 | SAVE_ORIGINAL_IMAGE: False
24 | NORMALIZATION:
25 | TYPE: "kin"
26 | PADDING: 1
27 | KERNEL_TYPE: 'gaussian'
28 | KERNEL_SIZE: 3
29 | THUMBNAIL: "None" # set to "None" if it's not required
30 |
--------------------------------------------------------------------------------
/models/tests/test_data/configs/config_lung_lesion_for_test_cyclegan.yaml:
--------------------------------------------------------------------------------
1 | EXPERIMENT_ROOT_PATH: "./test_data/checkpoints/lung_lesion"
2 | EXPERIMENT_NAME: "lung_lesion_cycleGAN"
3 | MODEL_NAME: "cycleGAN" # currently only CUT and cycleGAN are supported
4 | DEVICE: "cuda"
5 |
6 | TRAINING_SETTING:
7 | PAIRED_TRAINING: False
8 | NUM_EPOCHS: 100
9 | LAMBDA_Y: 1
10 | LEARNING_RATE: 0.0002
11 | BATCH_SIZE: 1
12 | NUM_WORKERS: 4
13 | SAVE_MODEL: true
14 | SAVE_MODEL_EPOCH_STEP: 10
15 | VISUALIZATION_STEP: 250
16 | LOAD_MODEL: false
17 | LOAD_EPOCH: 0
18 |
19 | INFERENCE_SETTING:
20 | TEST_DIR_X: "./test_data/test_dir_x/lung_lesion/"
21 | TEST_DIR_Y: "./test_data/test_dir_y/lung_lesion"
22 | MODEL_VERSION: "60"
23 | SAVE_ORIGINAL_IMAGE: False
24 | NORMALIZATION:
25 | TYPE: "kin"
26 | PADDING: 1
27 | KERNEL_SIZE: 3
28 | KERNEL_TYPE: 'gaussian'
29 | THUMBNAIL: "None" # set to "None" if it's not required
30 |
--------------------------------------------------------------------------------
/models/tests/test_kin.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pytest
3 | import torch
4 |
5 | from ..kin import KernelizedInstanceNorm
6 |
7 |
8 | def normalize(x):
9 | std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True)
10 | return (x - mean) / std
11 |
12 |
13 | def test_forward_normal():
14 | layer = KernelizedInstanceNorm(num_features=3, device='cpu')
15 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
16 | x = torch.FloatTensor(x)
17 |
18 | expected = normalize(x)
19 |
20 | check = layer.forward_normal(torch.FloatTensor(x))
21 |
22 |     assert check.detach().numpy() == pytest.approx(expected.numpy(), abs=1e-6)
23 |
24 |
25 | def test_init_collection():
26 | layer = KernelizedInstanceNorm(num_features=3, device='cpu')
27 | layer.init_collection(y_anchor_num=10, x_anchor_num=9)
28 |
29 | expected_mean_table = np.zeros(shape=(10, 9, 3))
30 | expected_std_table = np.zeros(shape=(10, 9, 3))
31 |
32 | np.testing.assert_array_equal(layer.mean_table.numpy(), expected_mean_table)
33 | np.testing.assert_array_equal(layer.std_table.numpy(), expected_std_table)
34 |
35 |
36 | def test_forward_without_anchors():
37 | layer = KernelizedInstanceNorm(num_features=3, device='cpu')
38 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
39 | x = torch.FloatTensor(x)
40 | expected = normalize(x)
41 |
42 | check = layer.forward(torch.FloatTensor(x))
43 |
44 |     assert check.detach().numpy() == pytest.approx(expected.numpy(), abs=1e-6)
45 |
46 |
47 | def test_forward_with_mode_1():
48 | layer = KernelizedInstanceNorm(num_features=3, kernel_type='constant', device='cpu').eval()
49 | layer.init_collection(y_anchor_num=3, x_anchor_num=3)
50 |
51 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
52 | x = torch.FloatTensor(x)
53 |
54 | std, mean = torch.std_mean(x, dim=(2, 3))
55 |
56 | expected_mean_table = np.zeros(shape=(3, 3, 3), dtype=np.float32)
57 | expected_std_table = np.zeros(shape=(3, 3, 3), dtype=np.float32)
58 |
59 | expected_mean_table[0, 0] = mean
60 | expected_std_table[0, 0] = std
61 |
62 | check = layer.forward(x, x_anchor=0, y_anchor=0, mode=KernelizedInstanceNorm.Mode.PHASE_CACHING)
63 |
64 | assert check.detach().numpy() == pytest.approx(normalize(x).numpy(), abs=1e-6)
65 | assert layer.mean_table.numpy() == pytest.approx(expected_mean_table)
66 | assert layer.std_table.numpy() == pytest.approx(expected_std_table)
67 |
68 |
69 | def test_forward_with_mode_2():
70 | layer = KernelizedInstanceNorm(num_features=3, kernel_type='constant', device='cpu').eval()
71 | layer.init_collection(y_anchor_num=3, x_anchor_num=3)
72 |
73 | x = np.random.normal(size=(1, 3, 32, 32)).astype(np.float32)
74 | x = torch.FloatTensor(x)
75 |
76 | layer.forward(x, x_anchor=1, y_anchor=1, mode=KernelizedInstanceNorm.Mode.PHASE_CACHING)
77 | check = layer.forward(x, x_anchor=1, y_anchor=1, mode=KernelizedInstanceNorm.Mode.PHASE_INFERENCE)
78 | std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True)
79 |
80 | mean /= 9
81 | std /= 9
82 |
83 | expected = (x - mean) / std
84 |
85 |     assert check.detach().numpy() == pytest.approx(expected.numpy())
86 |
--------------------------------------------------------------------------------
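These unit tests construct the layers with device='cpu' and need neither checkpoints nor a GPU, so they can typically be run on their own, e.g. `pytest models/tests/test_kin.py -v` from the repository root; the CUT and cycleGAN integration tests above additionally expect the checkpoints and reference outputs referenced in their test_data configs.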
/models/tin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class ThumbInstanceNorm(nn.Module):
6 | def __init__(self, num_features, affine=True):
7 | super(ThumbInstanceNorm, self).__init__()
8 | self.thumb_mean = None
9 | self.thumb_std = None
10 | self.normal_instance_normalization = False
11 | self.collection_mode = False
12 | if affine:
13 | self.weight = nn.Parameter(
14 | torch.ones(size=(1, num_features, 1, 1), requires_grad=True)
15 | )
16 | self.bias = nn.Parameter(
17 | torch.zeros(size=(1, num_features, 1, 1), requires_grad=True)
18 | )
19 |
20 | def calc_mean_std(self, feat, eps=1e-5):
21 | size = feat.size()
22 | assert len(size) == 4
23 | N, C = size[:2]
24 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
25 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
26 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
27 | return feat_mean, feat_std
28 |
29 | def forward(self, x):
30 | if self.training or self.normal_instance_normalization:
31 | x_mean, x_std = self.calc_mean_std(x)
32 | x = (x - x_mean) / x_std * self.weight + self.bias
33 | return x
34 | else:
35 | if self.collection_mode:
36 | assert x.shape[0] == 1
37 | x_mean, x_std = self.calc_mean_std(x)
38 | self.thumb_mean = x_mean
39 | self.thumb_std = x_std
40 |
41 | shift = x - self.thumb_mean
42 | x = shift / self.thumb_std * self.weight + self.bias
43 | return x
44 |
45 |
46 | def not_use_thumbnail_instance_norm(model):
47 | for _, layer in model.named_modules():
48 | if isinstance(layer, ThumbInstanceNorm):
49 | layer.collection_mode = False
50 | layer.normal_instance_normalization = True
51 |
52 |
53 | def init_thumbnail_instance_norm(model):
54 | for _, layer in model.named_modules():
55 | if isinstance(layer, ThumbInstanceNorm):
56 | layer.collection_mode = True
57 |
58 |
59 | def use_thumbnail_instance_norm(model):
60 | for _, layer in model.named_modules():
61 | if isinstance(layer, ThumbInstanceNorm):
62 | layer.collection_mode = False
63 | layer.normal_instance_normalization = False
64 |
--------------------------------------------------------------------------------
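A self-contained sketch of the thumbnail instance normalization workflow implemented above: statistics are first cached from the thumbnail in collection mode and then reused for every patch. The random tensors below are stand-ins for real thumbnail and patch features:

import torch

from models.tin import (ThumbInstanceNorm, init_thumbnail_instance_norm,
                        use_thumbnail_instance_norm)

layer = ThumbInstanceNorm(num_features=3).eval()
thumbnail = torch.randn(1, 3, 64, 64)
patch = torch.randn(1, 3, 64, 64)

init_thumbnail_instance_norm(layer)  # enable collection mode
_ = layer(thumbnail)                 # caches the thumbnail mean/std
use_thumbnail_instance_norm(layer)   # reuse the cached statistics from now on
out = layer(patch)                   # patch normalized with the thumbnail's mean/std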
/models/upsample.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Upsample(nn.Module):
5 | def __init__(self, features):
6 | super().__init__()
7 | layers = [
8 | nn.ReplicationPad2d(1),
9 | nn.ConvTranspose2d(
10 | features,
11 | features,
12 | kernel_size=4,
13 | stride=2,
14 | padding=3,
15 | ),
16 | ]
17 | self.model = nn.Sequential(*layers)
18 |
19 | def forward(self, input):
20 | return self.model(input)
21 |
--------------------------------------------------------------------------------
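A quick shape check for the block above (a sketch): replication padding of 1 followed by a 4x4 transposed convolution with stride 2 and padding 3 exactly doubles the spatial resolution.

import torch
from models.upsample import Upsample

x = torch.randn(1, 64, 128, 128)
y = Upsample(64)(x)
print(y.shape)  # torch.Size([1, 64, 256, 256])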
/requirements.txt:
--------------------------------------------------------------------------------
1 | albumentations==0.5.2
2 | numpy==1.22.4
3 | Pillow>=9.2.0
4 | PyYAML==5.4.1
5 | torch==1.7.0
6 | torchvision==0.8.0
7 | pytest
8 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from collections import defaultdict
4 |
5 | from torch.utils.data import DataLoader
6 | from torchvision.utils import save_image
7 |
8 | from models.model import get_model
9 | from utils.dataset import get_dataset
10 | from utils.util import read_yaml_config, reverse_image_normalize
11 |
12 |
13 | def main():
14 | parser = argparse.ArgumentParser("Model training")
15 | parser.add_argument(
16 | "-c",
17 | "--config",
18 | type=str,
19 | default="./config.yaml",
20 | help="Path to the config file.",
21 | )
22 | args = parser.parse_args()
23 |
24 | config = read_yaml_config(args.config)
25 |
26 | model = get_model(config=config, model_name=config["MODEL_NAME"], isTrain=True)
27 |
28 | dataset = get_dataset(config)
29 |
30 | dataloader = DataLoader(
31 | dataset,
32 | batch_size=config["TRAINING_SETTING"]["BATCH_SIZE"],
33 | shuffle=True,
34 | num_workers=config["TRAINING_SETTING"]["NUM_WORKERS"],
35 | )
36 |
37 | for epoch in range(config["TRAINING_SETTING"]["NUM_EPOCHS"]):
38 | out = defaultdict(int)
39 |
40 | for idx, data in enumerate(dataloader):
41 | print(f"[Epoch {epoch}][Iter {idx}] Processing ...", end="\r")
42 | if epoch == 0 and idx == 0:
43 | model.data_dependent_initialize(data)
44 | model.setup()
45 |
46 | model.set_input(data)
47 | model.optimize_parameters()
48 |
49 | if idx % config["TRAINING_SETTING"]["VISUALIZATION_STEP"] == 0 and idx > 0:
50 | results = model.get_current_visuals()
51 |
52 | for img_name, img in results.items():
53 | save_image(
54 | reverse_image_normalize(img),
55 | os.path.join(
56 | config["EXPERIMENT_ROOT_PATH"],
57 | config["EXPERIMENT_NAME"],
58 | "train",
59 | f"{epoch}_{img_name}_{idx}.png",
60 | ),
61 | )
62 |
63 | for k, v in out.items():
64 | out[k] /= config["TRAINING_SETTING"]["VISUALIZATION_STEP"]
65 |
66 | print(f"[Epoch {epoch}][Iter {idx}] {out}", flush=True)
67 | for k, v in out.items():
68 | out[k] = 0
69 |
70 | losses = model.get_current_losses()
71 | for k, v in losses.items():
72 | out[k] += v
73 |
74 | model.scheduler_step()
75 | if (
76 | epoch % config["TRAINING_SETTING"]["SAVE_MODEL_EPOCH_STEP"] == 0
77 | and config["TRAINING_SETTING"]["SAVE_MODEL"]
78 | ):
79 | model.save_networks(epoch)
80 |
81 | model.save_networks("latest")
82 |
83 |
84 | if __name__ == "__main__":
85 | main()
86 |
--------------------------------------------------------------------------------
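Training is launched by pointing the script above at a config file, for example (assuming the example config under data/example/ defines the training settings shown earlier):

python3 train.py --config ./data/example/config.yaml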
/transfer.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import cv2
5 |
6 | from utils.util import read_yaml_config
7 |
8 |
9 | def main():
10 | """
11 | USAGE
12 | python3 transfer.py -c config_example.yaml
13 | or
14 | python3 transfer.py -c config_example.yaml --skip_cropping
15 | """
16 | parser = argparse.ArgumentParser("Model inference")
17 | parser.add_argument(
18 | "-c",
19 | "--config",
20 | type=str,
21 | default="./data/example/config.yaml",
22 | help="Path to the config file.",
23 | )
24 | parser.add_argument("--skip_cropping", action="store_true")
25 | args = parser.parse_args()
26 |
27 | config = read_yaml_config(args.config)
28 | H, W, _ = cv2.imread(config["INFERENCE_SETTING"]["TEST_X"]).shape
29 | if not args.skip_cropping:
30 | os.system(
31 | f"python3 crop.py -i {config['INFERENCE_SETTING']['TEST_X']} "
32 | f"-o {config['INFERENCE_SETTING']['TEST_DIR_X']} "
33 | f"--patch_size {config['CROPPING_SETTING']['PATCH_SIZE']} "
34 | f"--stride {config['CROPPING_SETTING']['PATCH_SIZE']} "
35 | f"--thumbnail "
36 | f"--thumbnail_output {config['INFERENCE_SETTING']['TEST_DIR_X']}"
37 | )
38 | print("Finish cropping and start inference")
39 | os.system(f"python3 inference.py --config {args.config}")
40 | print("Finish inference and start combining images")
41 | os.system(
42 | f"python3 combine.py --config {args.config} "
43 | f"--resize_h {H} --resize_w {W}"
44 | )
45 |
46 |
47 | if __name__ == "__main__":
48 | main()
49 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kaminyou/Kernelized-Instance-Normalization/c464433105d0b712dc605d2fd52487ac7251d1d2/utils/__init__.py
--------------------------------------------------------------------------------