├── imlib ├── __init__.py ├── basic.py ├── transform.py └── dtype.py ├── requirements.txt ├── datasets ├── 102flowers_dataset_readme.txt ├── download_pix2pix_datasets.py └── download_cyclegan_datasets.sh ├── pylib ├── __init__.py ├── serialization.py ├── processing.py ├── argument.py ├── path.py └── timer.py ├── loss.py ├── config.py ├── README.md ├── histogram_layers.py ├── prepare_dataset.py ├── models.py └── train.py /imlib/__init__.py: -------------------------------------------------------------------------------- 1 | from imlib.basic import * 2 | from imlib.dtype import * 3 | from imlib.transform import * 4 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | oyaml==1.0 2 | numpy==1.18.1 3 | tensorflow_gpu==2.1.0 4 | scikit_image==0.17.2 5 | skimage==0.0 6 | tensorflow==2.3.1 7 | tqdm==4.53.0 8 | -------------------------------------------------------------------------------- /datasets/102flowers_dataset_readme.txt: -------------------------------------------------------------------------------- 1 | 102 Flower Category Database 2 | ---------------------------------------------- 3 | https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html 4 | 5 | -Downloads 6 | --Dataset images -------------------------------------------------------------------------------- /pylib/__init__.py: -------------------------------------------------------------------------------- 1 | from pylib.argument import * 2 | from pylib.processing import * 3 | from pylib.path import * 4 | from pylib.serialization import * 5 | from pylib.timer import * 6 | 7 | import pprint 8 | 9 | pp = pprint.pprint 10 | -------------------------------------------------------------------------------- /datasets/download_pix2pix_datasets.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | 4 | 
dataset_name = 'edges2shoes' # cityscapes, facades, edges2shoes edges2handbags 5 | _URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/' + dataset_name + '.tar.gz' 6 | 7 | path_to_zip = tf.keras.utils.get_file(dataset_name + '.tar.gz', 8 | origin=_URL, 9 | extract=True) 10 | 11 | PATH = os.path.join(os.path.dirname(path_to_zip), dataset_name + '/') 12 | -------------------------------------------------------------------------------- /datasets/download_cyclegan_datasets.sh: -------------------------------------------------------------------------------- 1 | mkdir datasets 2 | FILE=$1 3 | 4 | if [[ $FILE != "ae_photos" && $FILE != "apple2orange" && $FILE != "summer2winter_yosemite" && $FILE != "horse2zebra" && $FILE != "monet2photo" && $FILE != "cezanne2photo" && $FILE != "ukiyoe2photo" && $FILE != "vangogh2photo" && $FILE != "maps" && $FILE != "cityscapes" && $FILE != "facades" && $FILE != "iphone2dslr_flower" && $FILE != "ae_photos" ]]; then 5 | echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos" 6 | exit 1 7 | fi 8 | 9 | URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip 10 | ZIP_FILE=./datasets/$FILE.zip 11 | TARGET_DIR=./datasets/$FILE/ 12 | wget -N $URL -O $ZIP_FILE 13 | mkdir $TARGET_DIR 14 | unzip $ZIP_FILE -d ./datasets/ 15 | rm $ZIP_FILE 16 | -------------------------------------------------------------------------------- /imlib/basic.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.io as iio 3 | 4 | from imlib import dtype 5 | 6 | 7 | def imread(path, as_gray=False, **kwargs): 8 | """Return a float64 image in [-1.0, 1.0].""" 9 | image = iio.imread(path, as_gray, **kwargs) 10 | if image.dtype == np.uint8: 11 | image = image / 127.5 - 1 12 | elif image.dtype == np.uint16: 13 | image = 
image / 32767.5 - 1 14 | elif image.dtype in [np.float32, np.float64]: 15 | image = image * 2 - 1.0 16 | else: 17 | raise Exception("Inavailable image dtype: %s!" % image.dtype) 18 | return image 19 | 20 | 21 | def imwrite(image, path, quality=95, **plugin_args): 22 | """Save a [-1.0, 1.0] image.""" 23 | iio.imsave(path, dtype.im2uint(image), quality=quality, **plugin_args) 24 | 25 | 26 | def imshow(image): 27 | """Show a [-1.0, 1.0] image.""" 28 | iio.imshow(dtype.im2uint(image)) 29 | 30 | 31 | show = iio.show 32 | -------------------------------------------------------------------------------- /imlib/transform.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import skimage.color as color 3 | import skimage.transform as transform 4 | 5 | 6 | rgb2gray = color.rgb2gray 7 | gray2rgb = color.gray2rgb 8 | 9 | imresize = transform.resize 10 | imrescale = transform.rescale 11 | 12 | 13 | def immerge(images, n_rows=None, n_cols=None, padding=0, pad_value=0): 14 | """Merge images to an image with (n_rows * h) * (n_cols * w). 15 | 16 | Parameters 17 | ---------- 18 | images : numpy.array or object which can be converted to numpy.array 19 | Images in shape of N * H * W(* C=1 or 3). 
20 | 21 | """ 22 | images = np.array(images) 23 | n = images.shape[0] 24 | if n_rows: 25 | n_rows = max(min(n_rows, n), 1) 26 | n_cols = int(n - 0.5) // n_rows + 1 27 | elif n_cols: 28 | n_cols = max(min(n_cols, n), 1) 29 | n_rows = int(n - 0.5) // n_cols + 1 30 | else: 31 | n_rows = int(n ** 0.5) 32 | n_cols = int(n - 0.5) // n_rows + 1 33 | 34 | h, w = images.shape[1], images.shape[2] 35 | shape = (h * n_rows + padding * (n_rows - 1), 36 | w * n_cols + padding * (n_cols - 1)) 37 | if images.ndim == 4: 38 | shape += (images.shape[3],) 39 | img = np.full(shape, pad_value, dtype=images.dtype) 40 | 41 | for idx, image in enumerate(images): 42 | i = idx % n_cols 43 | j = idx // n_cols 44 | img[j * (h + padding):j * (h + padding) + h, 45 | i * (w + padding):i * (w + padding) + w, ...] = image 46 | 47 | return img 48 | -------------------------------------------------------------------------------- /pylib/serialization.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import pickle 4 | 5 | 6 | def _check_ext(path, default_ext): 7 | name, ext = os.path.splitext(path) 8 | if ext == '': 9 | if default_ext[0] == '.': 10 | default_ext = default_ext[1:] 11 | path = name + '.' 
+ default_ext 12 | return path 13 | 14 | 15 | def save_json(path, obj, **kwargs): 16 | # default 17 | if 'indent' not in kwargs: 18 | kwargs['indent'] = 4 19 | if 'separators' not in kwargs: 20 | kwargs['separators'] = (',', ': ') 21 | 22 | path = _check_ext(path, 'json') 23 | 24 | # wrap json.dump 25 | with open(path, 'w') as f: 26 | json.dump(obj, f, **kwargs) 27 | 28 | 29 | def load_json(path, **kwargs): 30 | # wrap json.load 31 | with open(path) as f: 32 | return json.load(f, **kwargs) 33 | 34 | 35 | def save_yaml(path, data, **kwargs): 36 | import oyaml as yaml 37 | 38 | path = _check_ext(path, 'yml') 39 | 40 | with open(path, 'w') as f: 41 | yaml.dump(data, f, **kwargs) 42 | 43 | 44 | def load_yaml(path, **kwargs): 45 | import oyaml as yaml 46 | with open(path) as f: 47 | return yaml.load(f, **kwargs) 48 | 49 | 50 | def save_pickle(path, obj, **kwargs): 51 | 52 | path = _check_ext(path, 'pkl') 53 | 54 | # wrap pickle.dump 55 | with open(path, 'wb') as f: 56 | pickle.dump(obj, f, **kwargs) 57 | 58 | 59 | def load_pickle(path, **kwargs): 60 | # wrap pickle.load 61 | with open(path, 'rb') as f: 62 | return pickle.load(f, **kwargs) 63 | -------------------------------------------------------------------------------- /pylib/processing.py: -------------------------------------------------------------------------------- 1 | import concurrent.futures 2 | import functools 3 | import multiprocessing 4 | 5 | 6 | def run_parallels(work_fn, iterable, max_workers=None, chunksize=1, processing_bar=True, backend_executor=multiprocessing.Pool, debug=False): 7 | if not debug: 8 | with backend_executor(max_workers) as executor: 9 | try: 10 | works = executor.imap(work_fn, iterable, chunksize=chunksize) # for multiprocessing.Pool 11 | except: 12 | works = executor.map(work_fn, iterable, chunksize=chunksize) 13 | 14 | if processing_bar: 15 | try: 16 | import tqdm 17 | try: 18 | total = len(iterable) 19 | except: 20 | total = None 21 | works = tqdm.tqdm(works, total=total) 22 | 
except ImportError: 23 | print('`import tqdm` fails! Run without processing bar!') 24 | 25 | results = list(works) 26 | else: 27 | results = [work_fn(i) for i in iterable] 28 | return results 29 | 30 | run_parallels_mp = run_parallels 31 | run_parallels_cfprocess = functools.partial(run_parallels, backend_executor=concurrent.futures.ProcessPoolExecutor) 32 | run_parallels_cfthread = functools.partial(run_parallels, backend_executor=concurrent.futures.ThreadPoolExecutor) 33 | 34 | 35 | if __name__ == '__main__': 36 | import time 37 | 38 | def work(i): 39 | time.sleep(0.0001) 40 | i**i 41 | return i 42 | 43 | t = time.time() 44 | results = run_parallels_mp(work, range(10000), max_workers=2, chunksize=1, processing_bar=True, debug=False) 45 | for i in results: 46 | print(i) 47 | print(time.time() - t) 48 | -------------------------------------------------------------------------------- /pylib/argument.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import functools 3 | import json 4 | 5 | from pylib import serialization 6 | 7 | 8 | GLOBAL_COMMAND_PARSER = argparse.ArgumentParser() 9 | 10 | 11 | def _serialization_wrapper(func): 12 | @functools.wraps(func) 13 | def _wrapper(*args, **kwargs): 14 | to_json = kwargs.pop("to_json", None) 15 | to_yaml = kwargs.pop("to_yaml", None) 16 | namespace = func(*args, **kwargs) 17 | if to_json: 18 | args_to_json(to_json, namespace) 19 | if to_yaml: 20 | args_to_yaml(to_yaml, namespace) 21 | return namespace 22 | return _wrapper 23 | 24 | 25 | def str2bool(v): 26 | if v.lower() in ('yes', 'true', 't', 'y', '1'): 27 | return True 28 | elif v.lower() in ('no', 'false', 'f', 'n', '0'): 29 | return False 30 | else: 31 | raise argparse.ArgumentTypeError('Boolean value expected!') 32 | 33 | 34 | def argument(*args, **kwargs): 35 | """Wrap argparse.add_argument.""" 36 | if 'type'in kwargs: 37 | if issubclass(kwargs['type'], bool): 38 | kwargs['type'] = str2bool 39 | elif 
issubclass(kwargs['type'], dict): 40 | kwargs['type'] = json.loads 41 | return GLOBAL_COMMAND_PARSER.add_argument(*args, **kwargs) 42 | 43 | 44 | arg = argument 45 | 46 | 47 | @_serialization_wrapper 48 | def args(args=None, namespace=None): 49 | """Parse args using the global parser.""" 50 | namespace = GLOBAL_COMMAND_PARSER.parse_args(args=args, namespace=namespace) 51 | return namespace 52 | 53 | 54 | @_serialization_wrapper 55 | def args_from_xxx(obj, parser, check=True): 56 | """Load args from xxx ignoring type and choices with default still valid. 57 | 58 | Parameters 59 | ---------- 60 | parser: function 61 | Should return a dict. 62 | 63 | """ 64 | dict_ = parser(obj) 65 | namespace = argparse.ArgumentParser().parse_args(args='') # '' for not to accept command line args 66 | for k, v in dict_.items(): 67 | namespace.__setattr__(k, v) 68 | return namespace 69 | 70 | 71 | args_from_dict = functools.partial(args_from_xxx, parser=lambda x: x) 72 | args_from_json = functools.partial(args_from_xxx, parser=serialization.load_json) 73 | args_from_yaml = functools.partial(args_from_xxx, parser=serialization.load_yaml) 74 | 75 | 76 | def args_to_json(path, namespace, **kwagrs): 77 | serialization.save_json(path, vars(namespace), **kwagrs) 78 | 79 | 80 | def args_to_yaml(path, namespace, **kwagrs): 81 | serialization.save_yaml(path, vars(namespace), **kwagrs) 82 | -------------------------------------------------------------------------------- /imlib/dtype.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf): 5 | # check type 6 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!' 7 | 8 | # check dtype 9 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes] 10 | assert images.dtype in dtypes, 'dtype of `images` shoud be one of %s!' 
% dtypes 11 | 12 | # check nan and inf 13 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!' 14 | 15 | # check value 16 | if min_value not in [None, -np.inf]: 17 | l = '[' + str(min_value) 18 | else: 19 | l = '(-inf' 20 | min_value = -np.inf 21 | if max_value not in [None, np.inf]: 22 | r = str(max_value) + ']' 23 | else: 24 | r = 'inf)' 25 | max_value = np.inf 26 | assert np.min(images) >= min_value and np.max(images) <= max_value, \ 27 | '`images` should be in the range of %s!' % (l + ',' + r) 28 | 29 | 30 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None): 31 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype.""" 32 | _check(images, [np.float32, np.float64], -1.0, 1.0) 33 | dtype = dtype if dtype else images.dtype 34 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype) 35 | 36 | 37 | def float2im(images): 38 | """Transform images from [0, 1.0] to [-1.0, 1.0].""" 39 | _check(images, [np.float32, np.float64], 0.0, 1.0) 40 | return images * 2 - 1.0 41 | 42 | 43 | def float2uint(images): 44 | """Transform images from [0, 1.0] to uint8.""" 45 | _check(images, [np.float32, np.float64], -0.0, 1.0) 46 | return (images * 255).astype(np.uint8) 47 | 48 | 49 | def im2uint(images): 50 | """Transform images from [-1.0, 1.0] to uint8.""" 51 | return to_range(images, 0, 255, np.uint8) 52 | 53 | 54 | def im2float(images): 55 | """Transform images from [-1.0, 1.0] to [0.0, 1.0].""" 56 | return to_range(images, 0.0, 1.0) 57 | 58 | 59 | def uint2im(images): 60 | """Transform images from uint8 to [-1.0, 1.0] of float64.""" 61 | _check(images, np.uint8) 62 | return images / 127.5 - 1.0 63 | 64 | 65 | def uint2float(images): 66 | """Transform images from uint8 to [0.0, 1.0] of float64.""" 67 | _check(images, np.uint8) 68 | return images / 255.0 69 | 70 | 71 | def cv2im(images): 72 | """Transform opencv images to [-1.0, 1.0].""" 73 | images = uint2im(images) 74 | return images[..., ::-1] 75 | 76 | 77 | 
def im2cv(images): 78 | """Transform images from [-1.0, 1.0] to opencv images.""" 79 | images = im2uint(images) 80 | return images[..., ::-1] 81 | -------------------------------------------------------------------------------- /pylib/path.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import fnmatch 3 | import os 4 | import glob as _glob 5 | import sys 6 | 7 | 8 | def add_path(paths): 9 | if not isinstance(paths, (list, tuple)): 10 | paths = [paths] 11 | for path in paths: 12 | if path not in sys.path: 13 | sys.path.insert(0, path) 14 | 15 | 16 | def mkdir(paths): 17 | if not isinstance(paths, (list, tuple)): 18 | paths = [paths] 19 | for path in paths: 20 | if not os.path.exists(path): 21 | os.makedirs(path) 22 | 23 | 24 | def split(path): 25 | """Return dir, name, ext.""" 26 | dir, name_ext = os.path.split(path) 27 | name, ext = os.path.splitext(name_ext) 28 | return dir, name, ext 29 | 30 | 31 | def directory(path): 32 | return split(path)[0] 33 | 34 | 35 | def name(path): 36 | return split(path)[1] 37 | 38 | 39 | def ext(path): 40 | return split(path)[2] 41 | 42 | 43 | def name_ext(path): 44 | return ''.join(split(path)[1:]) 45 | 46 | 47 | def change_ext(path, ext): 48 | if ext[0] == '.': 49 | ext = ext[1:] 50 | return os.path.splitext(path)[0] + '.' 
+ ext 51 | 52 | 53 | asbpath = os.path.abspath 54 | 55 | 56 | join = os.path.join 57 | 58 | 59 | def prefix(path, prefixes, sep='-'): 60 | prefixes = prefixes if isinstance(prefixes, (list, tuple)) else [prefixes] 61 | dir, name, ext = split(path) 62 | return join(dir, sep.join(prefixes) + sep + name + ext) 63 | 64 | 65 | def suffix(path, suffixes, sep='-'): 66 | suffixes = suffixes if isinstance(suffixes, (list, tuple)) else [suffixes] 67 | dir, name, ext = split(path) 68 | return join(dir, name + sep + sep.join(suffixes) + ext) 69 | 70 | 71 | def prefix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'): 72 | return prefix(path, prefixes=datetime.datetime.now().strftime(fmt), sep=sep) 73 | 74 | 75 | def suffix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'): 76 | return suffix(path, suffixes=datetime.datetime.now().strftime(fmt), sep=sep) 77 | 78 | 79 | def glob(dir, pats, recursive=False): # faster than match, python3 only 80 | pats = pats if isinstance(pats, (list, tuple)) else [pats] 81 | matches = [] 82 | for pat in pats: 83 | matches += _glob.glob(os.path.join(dir, pat), recursive=recursive) 84 | return matches 85 | 86 | 87 | def match(dir, pats, recursive=False): # slow 88 | pats = pats if isinstance(pats, (list, tuple)) else [pats] 89 | 90 | iterator = list(os.walk(dir)) 91 | if not recursive: 92 | iterator = iterator[0:1] 93 | 94 | matches = [] 95 | for pat in pats: 96 | for root, _, file_names in iterator: 97 | for file_name in fnmatch.filter(file_names, pat): 98 | matches.append(os.path.join(root, file_name)) 99 | 100 | return matches 101 | -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from histogram_layers import HistogramLayers, HistogramLayersCT 3 | 4 | bc_loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True) 5 | 6 | 7 | # generator loss mi 8 | def 
generator_loss_color_transfer(disc_generated_output, gen_output, inp, target, args): 9 | gan_loss = bc_loss_object(tf.ones_like(disc_generated_output), disc_generated_output) 10 | 11 | # histograms instances 12 | hist_1 = HistogramLayersCT(out=gen_output[..., 0], tar=target[..., 0], src=inp[..., 0], args=args) 13 | hist_2 = HistogramLayersCT(out=gen_output[..., 1], tar=target[..., 1], src=inp[..., 1], args=args) 14 | hist_3 = HistogramLayersCT(out=gen_output[..., 2], tar=target[..., 2], src=inp[..., 2], args=args) 15 | 16 | # mi loss 17 | mi_loss_1 = hist_1.calc_cond_entropy_loss_src_out() 18 | mi_loss_2 = hist_2.calc_cond_entropy_loss_src_out() 19 | mi_loss_3 = hist_3.calc_cond_entropy_loss_src_out() 20 | 21 | mi_loss = (mi_loss_1 + mi_loss_2 + mi_loss_3) / 3 22 | 23 | # hist loss 24 | hist_loss_1 = hist_1.calc_hist_loss_tar_out() 25 | hist_loss_2 = hist_2.calc_hist_loss_tar_out() 26 | hist_loss_3 = hist_3.calc_hist_loss_tar_out() 27 | 28 | hist_loss = (hist_loss_1 + hist_loss_2 + hist_loss_3) / 3 29 | 30 | total_gen_loss = (args.gan_loss_weight * gan_loss) + (args.mi_loss_weight * mi_loss) + \ 31 | (args.hist_loss_weight * hist_loss) 32 | 33 | return total_gen_loss, gan_loss, mi_loss, hist_loss 34 | 35 | 36 | def generator_loss(disc_generated_output, gen_output, target, args): 37 | gan_loss = bc_loss_object(tf.ones_like(disc_generated_output), disc_generated_output) 38 | 39 | # histograms instances 40 | hist_1 = HistogramLayers(out=gen_output[..., 0], tar=target[..., 0], args=args) 41 | hist_2 = HistogramLayers(out=gen_output[..., 1], tar=target[..., 1], args=args) 42 | hist_3 = HistogramLayers(out=gen_output[..., 2], tar=target[..., 2], args=args) 43 | 44 | # mi loss 45 | mi_loss_1 = hist_1.calc_cond_entropy_loss_tar_out() 46 | mi_loss_2 = hist_2.calc_cond_entropy_loss_tar_out() 47 | mi_loss_3 = hist_3.calc_cond_entropy_loss_tar_out() 48 | 49 | mi_loss = (mi_loss_1 + mi_loss_2 + mi_loss_3) / 3 50 | 51 | # hist loss 52 | hist_loss_1 = 
hist_1.calc_hist_loss_tar_out() 53 | hist_loss_2 = hist_2.calc_hist_loss_tar_out() 54 | hist_loss_3 = hist_3.calc_hist_loss_tar_out() 55 | 56 | hist_loss = (hist_loss_1 + hist_loss_2 + hist_loss_3) / 3 57 | 58 | total_gen_loss = (args.gan_loss_weight * gan_loss) + (args.mi_loss_weight * mi_loss) + ( 59 | args.hist_loss_weight * hist_loss) 60 | 61 | return total_gen_loss, gan_loss, mi_loss, hist_loss 62 | 63 | 64 | def discriminator_loss(disc_real_output, disc_generated_output): 65 | real_loss = bc_loss_object(tf.ones_like(disc_real_output), disc_real_output) 66 | 67 | generated_loss = bc_loss_object(tf.zeros_like(disc_generated_output), disc_generated_output) 68 | 69 | total_disc_loss = real_loss + generated_loss 70 | 71 | return total_disc_loss 72 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | 4 | def options(): 5 | parser = argparse.ArgumentParser() 6 | 7 | # basic parameters 8 | parser.add_argument('--dataroot', required=True, help='path to data directory (should have subfolder with the dataset name)') 9 | parser.add_argument('--task', default='edges2photos', choices=['edges2photos', 'color_transfer', 'colorization']) 10 | parser.add_argument('--dataset', choices=['edges2shoes', 'edges2handbags', '102flowers', 'summer2winter_yosemite']) 11 | parser.add_argument('--output_dir', default='output', help='models, samples and logs are saved here') 12 | parser.add_argument('--output_pre_dir', default='', help='if specified load checkpoint from this directory') 13 | 14 | # model parameters 15 | parser.add_argument('--img_h', type=int, default=256, help='crop image to this image height') 16 | parser.add_argument('--img_w', type=int, default=256, help='crop image to this image width') 17 | parser.add_argument('--img_c', type=int, default=3, help='# of input image channels') 18 | parser.add_argument('--img_out_c', 
type=int, default=3, help='# of output image channels') 19 | 20 | # histogram layers parameters 21 | parser.add_argument('--bin_num', type=int, default=256, help='histogram layers - number of bins') 22 | parser.add_argument('--kernel_width_ratio', type=float, default=2.5, help='histogram layers - scale kernel width') 23 | 24 | # optimizer parameters 25 | parser.add_argument('--gen_lr', type=float, default=2e-4) 26 | parser.add_argument('--gen_beta_1', type=float, default=0.5) 27 | parser.add_argument('--dis_lr', type=float, default=2e-4) 28 | parser.add_argument('--dis_beta_1', type=float, default=0.5) 29 | 30 | # data prepare parameters 31 | parser.add_argument('--min_val', type=float, default=-1.0, help="normalize image values to this min") 32 | parser.add_argument('--max_val', type=float, default=1.0, help="normalize image values to this max") 33 | parser.add_argument('--yuv', type=bool, default=True, help="convert images to YUV colorspace") 34 | 35 | # training params 36 | parser.add_argument('--batch_size', type=int, help='input batch size') 37 | parser.add_argument('--epochs', type=int, help='number of training epochs') 38 | parser.add_argument('--cp_freq', type=int, help='checkpoint frequency - save every # epochs') 39 | parser.add_argument('--buffer_size', type=int, help='size of shuffle buffer') 40 | 41 | # data prepare params 42 | parser.add_argument('--a2b', type=bool, help='image translation direction AtoB or BtoA') 43 | parser.add_argument('--jitter', type=bool, help='random jitter for data augmentation') 44 | 45 | # loss params 46 | parser.add_argument('--gan_loss_weight', type=float) 47 | parser.add_argument('--mi_loss_weight', type=float) 48 | parser.add_argument('--hist_loss_weight', type=float) 49 | 50 | return parser 51 | 52 | 53 | def define_task_default_params(parser): 54 | args = parser.parse_args() 55 | 56 | if args.task == 'colorization': 57 | parser.set_defaults(dataset='summer2winter_yosemite', batch_size=4, epochs=200, cp_freq=20, 
buffer_size=400, 58 | a2b=True, jitter=True, gan_loss_weight=1, mi_loss_weight=1, hist_loss_weight=20) 59 | 60 | elif args.task == 'color_transfer': 61 | parser.set_defaults(dataset='102flowers', batch_size=4, epochs=100, cp_freq=10, buffer_size=400, a2b=False, 62 | jitter=True, gan_loss_weight=1, mi_loss_weight=1, hist_loss_weight=100) 63 | 64 | elif args.task == 'edges2photos': 65 | parser.set_defaults(dataset='edges2shoes', batch_size=4, epochs=15, cp_freq=5, buffer_size=400, a2b=False, 66 | jitter=False, gan_loss_weight=1, mi_loss_weight=1, hist_loss_weight=100) 67 | -------------------------------------------------------------------------------- /pylib/timer.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import timeit 3 | 4 | 5 | class Timer: # deprecated, use tqdm instead 6 | """A timer as a context manager. 7 | 8 | Wraps around a timer. A custom timer can be passed 9 | to the constructor. The default timer is timeit.default_timer. 10 | 11 | Note that the latter measures wall clock time, not CPU time! 12 | On Unix systems, it corresponds to time.time. 13 | On Windows systems, it corresponds to time.clock. 14 | 15 | Parameters 16 | ---------- 17 | print_at_exit : boolean 18 | If True, print when exiting context. 19 | format : str 20 | `ms`, `s` or `datetime`. 21 | 22 | References 23 | ---------- 24 | - https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py. 25 | 26 | 27 | """ 28 | 29 | def __init__(self, fmt='s', print_at_exit=True, timer=timeit.default_timer): 30 | assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!" 
31 | self._fmt = fmt 32 | self._print_at_exit = print_at_exit 33 | self._timer = timer 34 | self.start() 35 | 36 | def __enter__(self): 37 | """Start the timer in the context manager scope.""" 38 | self.restart() 39 | return self 40 | 41 | def __exit__(self, exc_type, exc_value, exc_traceback): 42 | """Print the end time.""" 43 | if self._print_at_exit: 44 | print(str(self)) 45 | 46 | def __str__(self): 47 | return self.fmt(self.elapsed)[1] 48 | 49 | def start(self): 50 | self.start_time = self._timer() 51 | 52 | restart = start 53 | 54 | @property 55 | def elapsed(self): 56 | """Return the current elapsed time since last (re)start.""" 57 | return self._timer() - self.start_time 58 | 59 | def fmt(self, second): 60 | if self._fmt == 'ms': 61 | time_fmt = second * 1000 62 | time_str = '%s %s' % (time_fmt, self._fmt) 63 | elif self._fmt == 's': 64 | time_fmt = second 65 | time_str = '%s %s' % (time_fmt, self._fmt) 66 | elif self._fmt == 'datetime': 67 | time_fmt = datetime.timedelta(seconds=second) 68 | time_str = str(time_fmt) 69 | return time_fmt, time_str 70 | 71 | 72 | def timeit(run_times=1, **timer_kwargs): 73 | """Function decorator displaying the function execution time. 74 | 75 | All kwargs are the arguments taken by the Timer class constructor. 
76 | 77 | """ 78 | # store Timer kwargs in local variable so the namespace isn't polluted 79 | # by different level args and kwargs 80 | 81 | def decorator(f): 82 | def wrapper(*args, **kwargs): 83 | timer_kwargs.update(print_at_exit=False) 84 | with Timer(**timer_kwargs) as t: 85 | for _ in range(run_times): 86 | out = f(*args, **kwargs) 87 | fmt = '[*] Execution time of function "%(function_name)s" for %(run_times)d runs is %(execution_time)s = %(execution_time_each)s * %(run_times)d [*]' 88 | context = {'function_name': f.__name__, 'run_times': run_times, 'execution_time': t, 'execution_time_each': t.fmt(t.elapsed / run_times)[1]} 89 | print(fmt % context) 90 | return out 91 | return wrapper 92 | 93 | return decorator 94 | 95 | 96 | if __name__ == "__main__": 97 | import time 98 | 99 | # 1 100 | print(1) 101 | with Timer() as t: 102 | time.sleep(1) 103 | print(t) 104 | time.sleep(1) 105 | 106 | with Timer(fmt='datetime') as t: 107 | time.sleep(1) 108 | 109 | # 2 110 | print(2) 111 | t = Timer(fmt='ms') 112 | time.sleep(2) 113 | print(t) 114 | 115 | t = Timer(fmt='datetime') 116 | time.sleep(1) 117 | print(t) 118 | 119 | # 3 120 | print(3) 121 | 122 | @timeit(run_times=5, fmt='s') 123 | def blah(): 124 | time.sleep(2) 125 | 126 | blah() 127 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
2 | 3 | # Differentiable Histogram Loss Functions for Intensity-based Image-to-Image Translation ([TPAMI 2023](https://ieeexplore.ieee.org/document/10133915))
Official TensorFlow Implementation 4 |
5 | 6 | ## Abstract 7 | We introduce the HueNet - a novel deep learning framework for a differentiable construction of intensity (1D) and joint (2D) histograms and present its applicability to paired and unpaired image-to-image translation problems. The key idea is an innovative technique for augmenting a generative neural network by histogram layers appended to the image generator. These histogram layers allow us to define two new histogram-based loss functions for constraining the structural appearance of the synthesized output image and its color distribution. Specifically, the color similarity loss is defined by the Earth Mover's Distance between the intensity histograms of the network output and a color reference image. The structural similarity loss is determined by the mutual information between the output and a content reference image based on their joint histogram. Although the HueNet can be applied to a variety of image-to-image translation problems, we chose to demonstrate its strength on the tasks of color transfer, exemplar-based image colorization, and edges → photo, where the colors of the output image are predefined. 8 | 9 | Prerequisites 10 | ------------------------------------------------------------- 11 | The code runs on linux machines with NVIDIA GPUs. 
12 | 13 | 14 | Installation 15 | ------------------------------------------------------------- 16 | - TensorFlow 2.0 `pip install tensorflow-gpu` 17 | - TensorFlow Addons `pip install tensorflow-addons` 18 | - (if you encounter "tf.summary.histogram fails with TypeError", run `pip install --upgrade tb-nightly`) 19 | - scikit-image, oyaml, tqdm 20 | - Python 3.6 21 | 22 | - For pip users, run: 23 | `pip install -r requirements.txt` 24 | 25 | 26 | DeepHist edges2photos 27 | ------------------------------------------------------------- 28 | Download an edges2photos dataset (e.g. edges2shoes): 29 | 30 | `python ./datasets/download_pix2pix_datasets.py` 31 | 32 | * edit `dataset_name` in `./datasets/download_pix2pix_datasets.py` to download a different dataset 33 | 34 | Train a model: 35 | `python train.py --dataroot /home/<username>/.keras/datasets --task edges2photos` 36 | 37 | 38 | DeepHist colorization 39 | ------------------------------------------------------------- 40 | Download a CycleGAN dataset (e.g. summer2winter_yosemite): 41 | 42 | `bash ./datasets/download_cyclegan_datasets.sh summer2winter_yosemite` 43 | 44 | Train a model: 45 | `python train.py --dataroot ./datasets --task colorization` 46 | 47 | 48 | DeepHist color transfer 49 | ------------------------------------------------------------- 50 | Download and unzip the 102 Flower Category Database from: 51 | https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html 52 | 53 | Train a model: 54 | `python train.py --dataroot ./datasets --task color_transfer` 55 | 56 | 57 | References 58 | ------------------------------------------------------------- 59 | ```sh 60 | @inproceedings{CycleGAN2017, 61 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks}, 62 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A}, 63 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on}, 64 | year={2017} 65 | } 66 | @inproceedings{isola2017image, 67 | 
title={Image-to-Image Translation with Conditional Adversarial Networks}, 68 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A}, 69 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on}, 70 | year={2017} 71 | } 72 | ``` 73 | 74 | Citation 75 | ------------------------------------------------------------- 76 | If you find either the code or the paper useful for your research, cite our paper: 77 | ```sh 78 | @ARTICLE{10133915, 79 | author={Avi-Aharon, Mor and Arbelle, Assaf and Raviv, Tammy Riklin}, 80 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence}, 81 | title={Differentiable Histogram Loss Functions for Intensity-based Image-to-Image Translation}, 82 | year={2023}, 83 | volume={}, 84 | number={}, 85 | pages={1-12}, 86 | doi={10.1109/TPAMI.2023.3278287}} 87 | ``` 88 | -------------------------------------------------------------------------------- /histogram_layers.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | 5 | class HistogramLayers(object): 6 | """Network augmentation for 1D and 2D (Joint) histograms construction, 7 | Calculate Earth Mover's Distance, Mutual Information loss 8 | between output and target 9 | """ 10 | 11 | def __init__(self, out, tar, args): 12 | self.bin_num = args.bin_num 13 | self.min_val = args.min_val 14 | self.max_val = args.max_val 15 | self.interval_length = (self.max_val - self.min_val) / self.bin_num 16 | self.kernel_width = self.interval_length / args.kernel_width_ratio 17 | self.maps_out = self.calc_activation_maps(out) 18 | self.maps_tar = self.calc_activation_maps(tar) 19 | self.n_pixels = self.maps_out.get_shape().as_list()[1] # number of pixels in image (H*W) 20 | self.bs = self.maps_out.get_shape().as_list()[0] # batch size 21 | 22 | def calc_activation_maps(self, img): 23 | # apply approximated shifted rect (bin_num) functions on img 24 | 
bins_min_max = np.linspace(self.min_val, self.max_val, self.bin_num + 1) 25 | bins_av = (bins_min_max[0:-1] + bins_min_max[1:]) / 2 26 | bins_av = tf.constant(bins_av, dtype=tf.float32) # shape = (,bin_num) 27 | bins_av = tf.expand_dims(bins_av, axis=0) # shape = (1,bin_num) 28 | bins_av = tf.expand_dims(bins_av, axis=0) # shape = (1,1,bin_num) 29 | img_flat = tf.expand_dims(tf.keras.layers.Flatten()(img), axis=-1) 30 | maps = self.activation_func(img_flat, bins_av) # shape = (batch_size,H*W,bin_num) 31 | return maps 32 | 33 | def activation_func(self, img_flat, bins_av): 34 | img_minus_bins_av = tf.subtract(img_flat, bins_av) # shape= (batch_size,H*W,bin_num) 35 | img_plus_bins_av = tf.add(img_flat, bins_av) # shape = (batch_size,H*W,bin_num) 36 | maps = tf.math.sigmoid((img_minus_bins_av + self.interval_length / 2) / self.kernel_width) \ 37 | - tf.math.sigmoid((img_minus_bins_av - self.interval_length / 2) / self.kernel_width) \ 38 | + tf.math.sigmoid((img_plus_bins_av - 2 * self.min_val + self.interval_length / 2) / self.kernel_width) \ 39 | - tf.math.sigmoid((img_plus_bins_av - 2 * self.min_val - self.interval_length / 2) / self.kernel_width) \ 40 | + tf.math.sigmoid((img_plus_bins_av - 2 * self.max_val + self.interval_length / 2) / self.kernel_width) \ 41 | - tf.math.sigmoid((img_plus_bins_av - 2 * self.max_val - self.interval_length / 2) / self.kernel_width) 42 | return maps 43 | 44 | def calc_cond_entropy_loss(self, maps_x, maps_y): 45 | pxy = tf.matmul(maps_x, maps_y, transpose_a=True) / self.n_pixels 46 | py = tf.reduce_sum(pxy, 1) 47 | # calc conditional entropy: H(X|Y)=-sum_(x,y) p(x,y)log(p(x,y)/p(y)) 48 | hy = tf.reduce_sum(tf.math.xlogy(py, py), 1) 49 | hxy = tf.reduce_sum(tf.math.xlogy(pxy, pxy), [1, 2]) 50 | cond_entropy = hy - hxy 51 | mean_cond_entropy = tf.reduce_mean(cond_entropy) 52 | return mean_cond_entropy 53 | 54 | def ecdf(self, maps): 55 | # calculate the CDF of p 56 | p = tf.reduce_sum(maps, 1) / self.n_pixels # 
shape=(batch_size,bin_bum) 57 | return tf.cumsum(p, 1) 58 | 59 | def emd_loss(self, maps, maps_hat): 60 | ecdf_p = self.ecdf(maps) # shape=(batch_size, bin_bum) 61 | ecdf_p_hat = self.ecdf(maps_hat) # shape=(batch_size, bin_bum) 62 | emd = tf.reduce_mean(tf.pow(tf.abs(ecdf_p - ecdf_p_hat), 2), axis=-1) # shape=(batch_size,1) 63 | emd = tf.pow(emd, 1 / 2) 64 | return tf.reduce_mean(emd) # shape=0 65 | 66 | def calc_hist_loss_tar_out(self): 67 | return self.emd_loss(self.maps_tar, self.maps_out) 68 | 69 | def calc_cond_entropy_loss_tar_out(self): 70 | return self.calc_cond_entropy_loss(self.maps_tar, self.maps_out) 71 | 72 | def calc_relative_mi_tar_out(self): 73 | return self.calc_relative_mi(self.maps_tar, self.maps_out) 74 | 75 | 76 | class HistogramLayersCT(HistogramLayers): 77 | """ Used for Color Transfer 78 | EMD(TAR, OUT), MI(SRC, OUT) 79 | """ 80 | 81 | def __init__(self, out, tar, src, args): 82 | super().__init__(out, tar, args) 83 | self.maps_src = self.calc_activation_maps(src) 84 | 85 | @tf.function 86 | def calc_cond_entropy_loss_src_out(self): 87 | return self.calc_cond_entropy_loss(self.maps_src, self.maps_out) -------------------------------------------------------------------------------- /prepare_dataset.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import imlib as im 3 | import numpy as np 4 | 5 | 6 | def load_hist(real_image, args): 7 | hist1 = tf.histogram_fixed_width(real_image[..., 0], [args.min_val, args.max_val], nbins=args.bin_num) 8 | hist2 = tf.histogram_fixed_width(real_image[..., 1], [args.min_val, args.max_val], nbins=args.bin_num) 9 | hist3 = tf.histogram_fixed_width(real_image[..., 2], [args.min_val, args.max_val], nbins=args.bin_num) 10 | 11 | return hist1, hist2, hist3 12 | 13 | 14 | def load(image_file): 15 | image = tf.io.read_file(image_file) 16 | image = tf.image.decode_jpeg(image) 17 | 18 | w = tf.shape(image)[1] 19 | 20 | w = w // 2 21 | real_image = 
image[:, :w, :] 22 | input_image = image[:, w:, :] 23 | 24 | input_image = tf.cast(input_image, tf.float32) 25 | real_image = tf.cast(real_image, tf.float32) 26 | 27 | return input_image, real_image 28 | 29 | 30 | def load_colorization(image_file): 31 | image = tf.io.read_file(image_file) 32 | image = tf.image.decode_jpeg(image) 33 | 34 | image = tf.cast(image, tf.float32) 35 | gray_image = tf.image.rgb_to_grayscale(image) 36 | 37 | gray_image_3c = tf.keras.layers.concatenate([gray_image, gray_image, gray_image]) 38 | 39 | return gray_image_3c, image 40 | 41 | 42 | def resize(input_image, real_image, height, width): 43 | input_image = tf.image.resize(input_image, [height, width], 44 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 45 | real_image = tf.image.resize(real_image, [height, width], 46 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 47 | 48 | return input_image, real_image 49 | 50 | 51 | # from rgb {0,..,255} to yuv [-1,1] 52 | def rgb2yuv(img): 53 | img = tf.image.rgb_to_yuv(img / 255) 54 | img = tf.stack([img[..., 0] * 2 - 1, img[..., 1] * 2, img[..., 2] * 2], -1) 55 | return img 56 | 57 | 58 | # normalizing the images to [-1, 1] (and convert to yuv if needed) 59 | def normalize(input_image, real_image, yuv): 60 | if yuv: 61 | input_image = rgb2yuv(input_image) 62 | real_image = rgb2yuv(real_image) 63 | else: 64 | input_image = (input_image / 127.5) - 1 65 | real_image = (real_image / 127.5) - 1 66 | 67 | return input_image, real_image 68 | 69 | 70 | def load_image(image_file, train, args): 71 | def random_jitter(input_image, real_image): 72 | def random_crop(input_image, real_image): 73 | stacked_image = tf.stack([input_image, real_image], axis=0) 74 | cropped_image = tf.image.random_crop( 75 | stacked_image, size=[2, args.img_h, args.img_w, args.img_c]) 76 | 77 | return cropped_image[0], cropped_image[1] 78 | 79 | # resizing to 286 x 286 x 3 80 | input_image, real_image = resize(input_image, real_image, 286, 286) 81 | 82 | # randomly cropping to 256 
x 256 x 3 83 | input_image, real_image = random_crop(input_image, real_image) 84 | 85 | if tf.random.uniform(()) > 0.5: 86 | # random mirroring 87 | input_image = tf.image.flip_left_right(input_image) 88 | real_image = tf.image.flip_left_right(real_image) 89 | 90 | return input_image, real_image 91 | 92 | if args.task == 'colorization': 93 | input_image, real_image = load_colorization(image_file) 94 | else: 95 | input_image, real_image = load(image_file) 96 | if train and args.jitter: 97 | input_image, real_image = random_jitter(input_image, real_image) 98 | else: 99 | input_image, real_image = resize(input_image, real_image, 100 | args.img_h, args.img_w) 101 | input_image, real_image = normalize(input_image, real_image, args.yuv) 102 | 103 | if args.a2b: 104 | return input_image, (real_image, load_hist(real_image, args)) 105 | else: 106 | return real_image, (input_image, load_hist(input_image, args)) 107 | 108 | 109 | def load_single_image(image_file, hist, train, args): 110 | def random_jitter_single(image): 111 | # resizing to 286 x 286 x 3 112 | image = tf.image.resize(image, [286, 286], 113 | method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 114 | # randomly cropping to 256 x 256 x 3 115 | image = tf.image.random_crop(image, size=[args.img_h, args.img_w, args.img_c]) 116 | # random mirroring 117 | image = tf.image.random_flip_left_right(image) 118 | return image 119 | 120 | # load image 121 | image = tf.io.read_file(image_file) 122 | image = tf.image.decode_jpeg(image) 123 | image = tf.cast(image, tf.float32) 124 | # resize or jitter 125 | if train and args.jitter: 126 | image = random_jitter_single(image) 127 | else: 128 | image = tf.image.resize(image, [args.img_h, args.img_w], method=tf.image.ResizeMethod.NEAREST_NEIGHBOR) 129 | # normalize 130 | image = rgb2yuv(image) 131 | if hist: 132 | return image, load_hist(image, args) 133 | else: 134 | return image 135 | 136 | 137 | # from yuv [-1,1] to rgb [-1,1] 138 | def yuv2rgb(img): 139 | img = tf.stack([(img[..., 
0] + 1) / 2, img[..., 1] / 2, img[..., 2] / 2], -1) 140 | img = tf.image.yuv_to_rgb(img) 141 | img = img * 2 - 1 142 | return img 143 | 144 | 145 | def save_images(model, test_input, tar, tar_hist, filename, args): 146 | prediction = model([test_input, tar_hist[0], tar_hist[1], tar_hist[2]], training=True) 147 | if args.yuv: 148 | test_input = yuv2rgb(test_input) 149 | tar = yuv2rgb(tar) 150 | prediction = yuv2rgb(prediction) 151 | img = np.concatenate([test_input[0], tar[0], prediction[0]], axis=1) 152 | img = tf.clip_by_value(img, -1.0, 1.0).numpy() 153 | im.imwrite(img, filename) 154 | 155 | 156 | def shuffle_zip(inp_ds, tar_ds, args, seed_inp=None, seed_tar=None): 157 | inp_ds = inp_ds.shuffle(buffer_size=args.buffer_size, seed=seed_inp).batch(batch_size=args.batch_size) 158 | tar_ds = tar_ds.shuffle(buffer_size=args.buffer_size, seed=seed_tar).batch(batch_size=args.batch_size) 159 | return tf.data.Dataset.zip((inp_ds, tar_ds)) 160 | -------------------------------------------------------------------------------- /models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def downsample(filters, size, apply_batchnorm=True): 5 | initializer = tf.random_normal_initializer(0., 0.02) 6 | 7 | result = tf.keras.Sequential() 8 | result.add( 9 | tf.keras.layers.Conv2D(filters, size, strides=2, padding='same', 10 | kernel_initializer=initializer, use_bias=False)) 11 | 12 | if apply_batchnorm: 13 | result.add(tf.keras.layers.BatchNormalization()) 14 | 15 | result.add(tf.keras.layers.LeakyReLU()) 16 | 17 | return result 18 | 19 | 20 | class ReflectionPadding2D(tf.keras.layers.Layer): 21 | ''' 22 | 2D Reflection Padding 23 | Attributes: 24 | - padding: (padding_width, padding_height) tuple 25 | ''' 26 | 27 | def __init__(self, padding=(1, 1), **kwargs): 28 | self.padding = tuple(padding) 29 | super(ReflectionPadding2D, self).__init__(**kwargs) 30 | 31 | def compute_output_shape(self, input_shape): 32 
| return ( 33 | input_shape[0], input_shape[1] + 2 * self.padding[0], input_shape[2] + 2 * self.padding[1], input_shape[3]) 34 | 35 | def call(self, input_tensor, mask=None): 36 | padding_width, padding_height = self.padding 37 | return tf.pad(input_tensor, [[0, 0], [padding_height, padding_height], [padding_width, padding_width], [0, 0]], 38 | 'REFLECT') 39 | 40 | 41 | # replace Conv2DTranspose with upsampling and conv2d, add bias 42 | def upsample(filters, size, apply_dropout=False): 43 | initializer = tf.random_normal_initializer(0., 0.02) 44 | 45 | result = tf.keras.Sequential() 46 | 47 | result.add(tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear')) 48 | result.add(ReflectionPadding2D(padding=(1, 1))) 49 | result.add(tf.keras.layers.Conv2D(filters, 3, strides=1, padding='valid', 50 | use_bias=True, kernel_initializer=initializer)) 51 | 52 | result.add(tf.keras.layers.BatchNormalization()) 53 | 54 | if apply_dropout: 55 | result.add(tf.keras.layers.Dropout(0.5)) 56 | 57 | result.add(tf.keras.layers.ReLU()) 58 | 59 | return result 60 | 61 | 62 | def Generator(args): 63 | inputs = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_c]) 64 | 65 | input_hist_1 = tf.keras.layers.Input(shape=[args.bin_num]) 66 | input_hist_2 = tf.keras.layers.Input(shape=[args.bin_num]) 67 | input_hist_3 = tf.keras.layers.Input(shape=[args.bin_num]) 68 | 69 | hist_1_em = tf.keras.layers.Embedding(args.img_h * args.img_w + 1, 1, input_length=args.bin_num)(input_hist_1) 70 | hist_2_em = tf.keras.layers.Embedding(args.img_h * args.img_w + 1, 1, input_length=args.bin_num)(input_hist_2) 71 | hist_3_em = tf.keras.layers.Embedding(args.img_h * args.img_w + 1, 1, input_length=args.bin_num)(input_hist_3) 72 | 73 | hist_em = tf.keras.layers.Concatenate()([hist_1_em, hist_2_em, hist_3_em]) 74 | hist_em_flat = tf.keras.layers.Flatten()(hist_em) 75 | hist = tf.keras.layers.Dense(args.bin_num * 2, input_shape=(args.bin_num * 3,), activation='relu')( 76 | hist_em_flat) 77 
| 78 | down_stack = [ 79 | downsample(64, 4, apply_batchnorm=False), # (bs, 128, 128, 64) 80 | downsample(128, 4), # (bs, 64, 64, 128) 81 | downsample(256, 4), # (bs, 32, 32, 256) 82 | downsample(512, 4), # (bs, 16, 16, 512) 83 | downsample(512, 4), # (bs, 8, 8, 512) 84 | downsample(512, 4), # (bs, 4, 4, 512) 85 | downsample(512, 4), # (bs, 2, 2, 512) 86 | downsample(512, 4), # (bs, 1, 1, 512) 87 | ] 88 | 89 | up_stack = [ 90 | upsample(512, 4, apply_dropout=True), # (bs, 2, 2, 1024) 91 | upsample(512, 4, apply_dropout=True), # (bs, 4, 4, 1024) 92 | upsample(512, 4, apply_dropout=True), # (bs, 8, 8, 1024) 93 | upsample(512, 4), # (bs, 16, 16, 1024) 94 | upsample(256, 4), # (bs, 32, 32, 512) 95 | upsample(128, 4), # (bs, 64, 64, 256) 96 | upsample(64, 4), # (bs, 128, 128, 128) 97 | ] 98 | 99 | initializer = tf.random_normal_initializer(0., 0.02) 100 | last = tf.keras.layers.Conv2DTranspose(args.img_out_c, 4, 101 | strides=2, 102 | padding='same', 103 | kernel_initializer=initializer, 104 | activation='tanh') # (bs, 256, 256, 3) 105 | 106 | x = inputs 107 | 108 | # Downsampling through the model 109 | skips = [] 110 | for down in down_stack: 111 | x = down(x) 112 | skips.append(x) 113 | 114 | skips = reversed(skips[:-1]) 115 | 116 | # concat input histogram 117 | hist = tf.expand_dims(hist, 1) # shape=(None,1, 512) 118 | hist = tf.expand_dims(hist, 1) # shape=(None,1,1 512) 119 | x = tf.keras.layers.Concatenate()([x, hist]) # shape=(None, 1, 1, 1024) 120 | 121 | # Upsampling and establishing the skip connections 122 | for up, skip in zip(up_stack, skips): 123 | x = up(x) 124 | x = tf.keras.layers.Concatenate()([x, skip]) 125 | 126 | x = last(x) 127 | 128 | return tf.keras.Model(inputs=[inputs, input_hist_1, input_hist_2, input_hist_3], outputs=x) 129 | 130 | 131 | def Discriminator(args, conditional): 132 | initializer = tf.random_normal_initializer(0., 0.02) 133 | 134 | inp = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_c], name='input_image') 135 | 
if not conditional: 136 | x = inp 137 | inputs = inp 138 | else: 139 | tar = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_c], name='target_image') 140 | inputs = [inp, tar] 141 | x = tf.keras.layers.concatenate([inp, tar]) # (bs, 256, 256, channels*2) 142 | 143 | down1 = downsample(64, 4, False)(x) # (bs, 128, 128, 64) 144 | down2 = downsample(128, 4)(down1) # (bs, 64, 64, 128) 145 | down3 = downsample(256, 4)(down2) # (bs, 32, 32, 256) 146 | 147 | zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3) # (bs, 34, 34, 256) 148 | conv = tf.keras.layers.Conv2D(512, 4, strides=1, 149 | kernel_initializer=initializer, 150 | use_bias=False)(zero_pad1) # (bs, 31, 31, 512) 151 | 152 | batchnorm1 = tf.keras.layers.BatchNormalization()(conv) 153 | 154 | leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1) 155 | 156 | zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu) # (bs, 33, 33, 512) 157 | 158 | last = tf.keras.layers.Conv2D(1, 4, strides=1, 159 | kernel_initializer=initializer)(zero_pad2) # (bs, 30, 30, 1) 160 | 161 | return tf.keras.Model(inputs=inputs, outputs=last) 162 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import os 3 | import datetime 4 | import time 5 | 6 | import pathlib 7 | import random 8 | 9 | import pylib as py 10 | from prepare_dataset import load_image, load_single_image, save_images, shuffle_zip 11 | from models import Generator, Discriminator 12 | from loss import generator_loss, discriminator_loss, generator_loss_color_transfer 13 | import atexit 14 | import config 15 | 16 | # parameters 17 | parser = config.options() 18 | config.define_task_default_params(parser) 19 | args = parser.parse_args() 20 | 21 | # output_dir 22 | output_dir = py.join(args.output_dir, args.dataset) 23 | py.mkdir(output_dir) 24 | 25 | # save settings 26 | py.args_to_yaml(py.join(output_dir, 
'settings.yml'), args) 27 | 28 | # Input Pipeline 29 | PATH = args.dataroot + '/' + args.dataset + '/' 30 | paired = False if args.task == 'color_transfer' else True 31 | 32 | if paired: 33 | # create paired dataset 34 | if args.dataset == 'summer2winter_yosemite': 35 | train_ds = tf.data.Dataset.list_files(PATH + 'train*/*.jpg') 36 | test_ds = tf.data.Dataset.list_files(PATH + 'test*/*.jpg') 37 | else: 38 | train_ds = tf.data.Dataset.list_files(PATH + 'train/*.jpg') 39 | test_ds = tf.data.Dataset.list_files(PATH + 'val/*.jpg') 40 | 41 | train_ds = train_ds.map(lambda x: load_image(image_file=x, train=True, args=args)) 42 | train_ds = train_ds.shuffle(args.buffer_size) 43 | train_ds = train_ds.batch(args.batch_size) 44 | test_ds = test_ds.map(lambda x: load_image(image_file=x, train=False, args=args)) 45 | test_ds = test_ds.batch(args.batch_size) 46 | 47 | else: 48 | # create unpaired dataset 49 | data_root = pathlib.Path(PATH) 50 | all_image_paths = list(data_root.glob('*')) 51 | all_image_paths = [str(path) for path in all_image_paths] 52 | 53 | # shuffle paths 54 | random.seed(1) 55 | random.shuffle(all_image_paths) 56 | 57 | # split to train and validation 58 | TRAIN_PER = 0.9 59 | train_size = round(len(all_image_paths) * TRAIN_PER) 60 | train_paths = all_image_paths[0:train_size] 61 | val_paths = all_image_paths[train_size:] 62 | 63 | # create dataset 64 | train_A = tf.data.Dataset.from_tensor_slices(train_paths).map( 65 | lambda x: load_single_image(image_file=x, hist=False, train=True, args=args)) 66 | train_B = tf.data.Dataset.from_tensor_slices(train_paths).map( 67 | lambda x: load_single_image(image_file=x, hist=True, train=True, args=args)) 68 | test_A = tf.data.Dataset.from_tensor_slices(val_paths).map( 69 | lambda x: load_single_image(image_file=x, hist=False, train=False, args=args)) 70 | test_B = tf.data.Dataset.from_tensor_slices(val_paths).map( 71 | lambda x: load_single_image(image_file=x, hist=True, train=False, args=args)) 72 | 73 | # Define 
the models 74 | generator = Generator(args=args) 75 | discriminator = Discriminator(args=args, conditional=paired) 76 | 77 | # Define the Optimizers and Checkpoint-saver 78 | generator_optimizer = tf.keras.optimizers.Adam(args.gen_lr, beta_1=args.gen_beta_1) 79 | discriminator_optimizer = tf.keras.optimizers.Adam(args.dis_lr, beta_1=args.dis_beta_1) 80 | 81 | checkpoint_dir = py.join(output_dir, 'checkpoints') 82 | checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt") 83 | checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer, 84 | discriminator_optimizer=discriminator_optimizer, 85 | generator=generator, 86 | discriminator=discriminator) 87 | # restore checkpoints 88 | if args.output_pre_dir: 89 | # restoring the latest checkpoint in checkpoint_dir 90 | output_pre_dir = py.join(args.output_pre_dir, args.dataset) 91 | checkpoint_pre_dir = py.join(output_pre_dir, 'checkpoints') 92 | checkpoint.restore(tf.train.latest_checkpoint(checkpoint_pre_dir)) 93 | print('restore checkpoint: ' + tf.train.latest_checkpoint(checkpoint_pre_dir)) 94 | else: 95 | print('No checkpoint to restore ...') 96 | 97 | 98 | def exit_handler(): 99 | print('Save checkpoint before exit...') 100 | checkpoint.save(file_prefix=checkpoint_prefix) 101 | 102 | 103 | atexit.register(exit_handler) 104 | 105 | # sample 106 | sample_dir = py.join(output_dir, 'samples_training') 107 | py.mkdir(sample_dir) 108 | 109 | # logs 110 | log_dir = py.join(output_dir, 'logs/') 111 | summary_writer = tf.summary.create_file_writer( 112 | log_dir + "fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")) 113 | 114 | 115 | # Train 116 | @tf.function 117 | def train_step(input_image, target, target_hists): 118 | with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape: 119 | gen_output = generator([input_image, target_hists[0], target_hists[1], target_hists[2]], training=True) 120 | 121 | if paired: 122 | disc_real_output = discriminator([input_image, target], training=True) 123 | 
disc_generated_output = discriminator([input_image, gen_output], training=True) 124 | gen_total_loss, gen_gan_loss, gen_mi_loss, gen_hist_loss = generator_loss_color_transfer( 125 | disc_generated_output, gen_output, input_image, target, args) 126 | else: 127 | disc_real_output = discriminator(input_image, training=True) 128 | disc_generated_output = discriminator(gen_output, training=True) 129 | gen_total_loss, gen_gan_loss, gen_mi_loss, gen_hist_loss = generator_loss(disc_generated_output, gen_output, 130 | target, args) 131 | 132 | disc_loss = discriminator_loss(disc_real_output, disc_generated_output) 133 | 134 | generator_gradients = gen_tape.gradient(gen_total_loss, 135 | generator.trainable_variables) 136 | discriminator_gradients = disc_tape.gradient(disc_loss, 137 | discriminator.trainable_variables) 138 | 139 | generator_optimizer.apply_gradients(zip(generator_gradients, 140 | generator.trainable_variables)) 141 | discriminator_optimizer.apply_gradients(zip(discriminator_gradients, 142 | discriminator.trainable_variables)) 143 | 144 | return {'gen_total_loss': gen_total_loss, 145 | 'gen_gan_loss': gen_gan_loss, 146 | 'gen_mi_loss': gen_mi_loss, 147 | 'gen_hist_loss': gen_hist_loss, 148 | 'disc_loss': disc_loss} 149 | 150 | 151 | def fit(): 152 | total_loss_dict = {'gen_total_loss': 0, 'gen_gan_loss': 0, 'gen_mi_loss': 0, 'gen_hist_loss': 0, 'disc_loss': 0} 153 | 154 | for epoch in range(args.epochs): 155 | start = time.time() 156 | global train_ds, test_ds 157 | 158 | if not paired: 159 | train_ds = shuffle_zip(train_A, train_B, args) 160 | test_ds = shuffle_zip(test_A, test_B, args) 161 | 162 | for example_input, (example_target, example_target_hists) in test_ds.take(1): 163 | save_images(generator, example_input, example_target, example_target_hists, 164 | py.join(sample_dir, 'epoch-%09d.jpg' % epoch), args) 165 | 166 | print("Epoch: ", epoch) 167 | 168 | # Train 169 | total_loss_dict = dict.fromkeys(total_loss_dict, 0) # initialize total loss with zero 
170 | 171 | for n, (input_image, (target, target_hists)) in train_ds.enumerate(): 172 | 173 | loss_dict = train_step(input_image, target, target_hists) 174 | 175 | for loss_name in total_loss_dict: 176 | total_loss_dict[loss_name] += loss_dict[loss_name] 177 | 178 | if (n + 1) % 100 == 0 or n == 0: 179 | print(n.numpy() + 1, end=' ') 180 | for loss_name in loss_dict: 181 | print(loss_name, ':', loss_dict[loss_name].numpy(), end=' ') 182 | print() 183 | 184 | # write logs 185 | with summary_writer.as_default(): 186 | for loss_name in total_loss_dict: 187 | tf.summary.scalar(loss_name, total_loss_dict[loss_name] / float(n + 1), step=epoch) 188 | summary_writer.flush() # Mor: try to fix tensor-board graphs 189 | 190 | # saving (checkpoint) the model every args.cp_freq epochs 191 | if (epoch + 1) % args.cp_freq == 0: 192 | checkpoint.save(file_prefix=checkpoint_prefix) 193 | print("save checkpoint ...") 194 | 195 | print('Time taken for epoch {} is {} sec\n'.format(epoch + 1, 196 | time.time() - start)) 197 | checkpoint.save(file_prefix=checkpoint_prefix) 198 | print("save checkpoint ...") 199 | 200 | 201 | fit() 202 | --------------------------------------------------------------------------------