├── imlib
├── __init__.py
├── basic.py
├── transform.py
└── dtype.py
├── requirements.txt
├── datasets
├── 102flowers_dataset_readme.txt
├── download_pix2pix_datasets.py
└── download_cyclegan_datasets.sh
├── pylib
├── __init__.py
├── serialization.py
├── processing.py
├── argument.py
├── path.py
└── timer.py
├── loss.py
├── config.py
├── README.md
├── histogram_layers.py
├── prepare_dataset.py
├── models.py
└── train.py
/imlib/__init__.py:
--------------------------------------------------------------------------------
1 | from imlib.basic import *
2 | from imlib.dtype import *
3 | from imlib.transform import *
4 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | oyaml==1.0
2 | numpy==1.18.1
3 | tensorflow_gpu==2.1.0
4 | scikit_image==0.17.2
5 | skimage==0.0
6 | tensorflow==2.3.1
7 | tqdm==4.53.0
8 |
--------------------------------------------------------------------------------
/datasets/102flowers_dataset_readme.txt:
--------------------------------------------------------------------------------
1 | 102 Flower Category Database
2 | ----------------------------------------------
3 | https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
4 |
5 | -Downloads
6 | --Dataset images
--------------------------------------------------------------------------------
/pylib/__init__.py:
--------------------------------------------------------------------------------
1 | from pylib.argument import *
2 | from pylib.processing import *
3 | from pylib.path import *
4 | from pylib.serialization import *
5 | from pylib.timer import *
6 |
7 | import pprint
8 |
# Short alias so callers can use `pylib.pp(obj)` for pretty-printing.
pp = pprint.pprint
10 |
--------------------------------------------------------------------------------
/datasets/download_pix2pix_datasets.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 |
import sys

# Dataset to download: first command-line argument, defaulting to edges2shoes.
# Known pix2pix datasets: cityscapes, facades, edges2shoes, edges2handbags.
# NOTE(review): the pix2pix datasets are reportedly also served from
# http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/ — verify if this host fails.
dataset_name = sys.argv[1] if len(sys.argv) > 1 else 'edges2shoes'
_URL = 'https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/' + dataset_name + '.tar.gz'

# Download (cached by Keras) and extract the archive next to the downloaded file.
path_to_zip = tf.keras.utils.get_file(dataset_name + '.tar.gz',
                                      origin=_URL,
                                      extract=True)

# Directory holding the extracted dataset.
PATH = os.path.join(os.path.dirname(path_to_zip), dataset_name + '/')
12 |
--------------------------------------------------------------------------------
/datasets/download_cyclegan_datasets.sh:
--------------------------------------------------------------------------------
#!/usr/bin/env bash
# Download a CycleGAN dataset and unzip it into ./datasets/<name>/.
# Usage: bash download_cyclegan_datasets.sh <dataset_name>
mkdir -p ./datasets
FILE=$1

case "$FILE" in
    apple2orange|summer2winter_yosemite|horse2zebra|monet2photo|cezanne2photo|ukiyoe2photo|vangogh2photo|maps|cityscapes|facades|iphone2dslr_flower|ae_photos)
        ;;
    *)
        echo "Available datasets are: apple2orange, summer2winter_yosemite, horse2zebra, monet2photo, cezanne2photo, ukiyoe2photo, vangogh2photo, maps, cityscapes, facades, iphone2dslr_flower, ae_photos"
        exit 1
        ;;
esac

# NOTE(review): these datasets are reportedly also mirrored at
# http://efrosgans.eecs.berkeley.edu/cyclegan/datasets/ — verify if this host fails.
URL=https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets/$FILE.zip
ZIP_FILE=./datasets/$FILE.zip
TARGET_DIR=./datasets/$FILE/

# Abort (and drop the partial/empty file) instead of unzipping a failed download.
# (-N was removed: wget ignores timestamping when -O is given.)
if ! wget "$URL" -O "$ZIP_FILE"; then
    rm -f "$ZIP_FILE"
    exit 1
fi
mkdir -p "$TARGET_DIR"
unzip "$ZIP_FILE" -d ./datasets/
rm "$ZIP_FILE"
16 |
--------------------------------------------------------------------------------
/imlib/basic.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage.io as iio
3 |
4 | from imlib import dtype
5 |
6 |
def imread(path, as_gray=False, **kwargs):
    """Read an image with skimage.io and rescale its values to [-1.0, 1.0].

    Integer images are mapped linearly from their full range: uint8 via
    ``x / 127.5 - 1`` and uint16 via ``x / 32767.5 - 1``. Both results are
    float64. Float images are mapped via ``x * 2 - 1.0`` and keep their dtype,
    so float32 input stays float32.
    NOTE(review): float input is assumed to lie in [0.0, 1.0] — confirm.

    Raises
    ------
    Exception
        If the decoded image has any other dtype.
    """
    image = iio.imread(path, as_gray=as_gray, **kwargs)
    pixel_type = image.dtype
    if pixel_type == np.uint8:
        return image / 127.5 - 1
    if pixel_type == np.uint16:
        return image / 32767.5 - 1
    if pixel_type in [np.float32, np.float64]:
        return image * 2 - 1.0
    raise Exception("Inavailable image dtype: %s!" % image.dtype)
19 |
20 |
def imwrite(image, path, quality=95, **plugin_args):
    """Write a [-1.0, 1.0] float image to `path` after converting it to uint8.

    `quality` and any extra keyword arguments are forwarded to
    `skimage.io.imsave`.
    """
    as_uint8 = dtype.im2uint(image)
    iio.imsave(path, as_uint8, quality=quality, **plugin_args)


def imshow(image):
    """Display a [-1.0, 1.0] float image with skimage.io after converting it to uint8."""
    as_uint8 = dtype.im2uint(image)
    iio.imshow(as_uint8)


# Alias of skimage.io.show, re-exported for convenience.
show = iio.show
32 |
--------------------------------------------------------------------------------
/imlib/transform.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import skimage.color as color
3 | import skimage.transform as transform
4 |
5 |
# Thin aliases of scikit-image color/transform helpers, re-exported through
# `from imlib.transform import *` in imlib/__init__.py.
rgb2gray = color.rgb2gray
gray2rgb = color.gray2rgb

imresize = transform.resize
imrescale = transform.rescale
12 |
def immerge(images, n_rows=None, n_cols=None, padding=0, pad_value=0):
    """Merge a batch of images into one grid image of (n_rows * h) x (n_cols * w).

    Parameters
    ----------
    images : numpy.array or object which can be converted to numpy.array
        Images in shape of N * H * W(* C=1 or 3).
    n_rows, n_cols : int, optional
        Grid layout. If `n_rows` is given, it is clamped to [1, N] and
        `n_cols` is derived from it. Otherwise, if `n_cols` is given, it is
        clamped and `n_rows` is derived. Otherwise ``int(sqrt(N))`` rows
        are used (at least 1).
    padding : int
        Number of pixels between neighbouring tiles.
    pad_value : scalar
        Fill value for padding and for unused grid cells.

    Returns
    -------
    numpy.ndarray
        The merged image, with the same dtype as `images`.
    """
    images = np.array(images)
    n = images.shape[0]
    # `max(n - 1, 0) // k + 1` equals ceil(n / k) for n >= 1, and 1 for n == 0.
    if n_rows:
        n_rows = max(min(n_rows, n), 1)
        n_cols = max(n - 1, 0) // n_rows + 1
    elif n_cols:
        n_cols = max(min(n_cols, n), 1)
        n_rows = max(n - 1, 0) // n_cols + 1
    else:
        # Clamp to 1 so an empty batch does not divide by zero below.
        n_rows = max(int(n ** 0.5), 1)
        n_cols = max(n - 1, 0) // n_rows + 1

    h, w = images.shape[1], images.shape[2]
    shape = (h * n_rows + padding * (n_rows - 1),
             w * n_cols + padding * (n_cols - 1))
    if images.ndim == 4:
        shape += (images.shape[3],)
    img = np.full(shape, pad_value, dtype=images.dtype)

    # Fill the grid row-major: image idx goes to row idx // n_cols, column idx % n_cols.
    for idx, image in enumerate(images):
        row, col = divmod(idx, n_cols)
        top = row * (h + padding)
        left = col * (w + padding)
        img[top:top + h, left:left + w, ...] = image

    return img
48 |
--------------------------------------------------------------------------------
/pylib/serialization.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import pickle
4 |
5 |
def _check_ext(path, default_ext):
    """Append `default_ext` to `path` if `path` has no extension.

    `default_ext` may be given with or without its leading dot ('json' or
    '.json'). Paths that already have an extension are returned unchanged.
    """
    name, ext = os.path.splitext(path)
    if ext == '':
        # startswith() instead of default_ext[0] so an empty string does not raise IndexError.
        if default_ext.startswith('.'):
            default_ext = default_ext[1:]
        path = name + '.' + default_ext
    return path


def save_json(path, obj, **kwargs):
    """Dump `obj` as JSON to `path` ('.json' is appended if there is no extension).

    Defaults to ``indent=4`` and ``separators=(',', ': ')`` unless overridden
    in `kwargs`, which are forwarded to `json.dump`.
    """
    kwargs.setdefault('indent', 4)
    kwargs.setdefault('separators', (',', ': '))

    path = _check_ext(path, 'json')

    # JSON text is UTF-8 by specification; don't depend on the locale encoding.
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(obj, f, **kwargs)


def load_json(path, **kwargs):
    """Load and return the JSON document at `path`; `kwargs` go to `json.load`."""
    with open(path, encoding='utf-8') as f:
        return json.load(f, **kwargs)


def save_yaml(path, data, **kwargs):
    """Dump `data` as YAML (via oyaml) to `path` ('.yml' is appended if there is no extension)."""
    import oyaml as yaml

    path = _check_ext(path, 'yml')

    with open(path, 'w', encoding='utf-8') as f:
        yaml.dump(data, f, **kwargs)


def load_yaml(path, **kwargs):
    """Load and return the YAML document at `path` (via oyaml).

    PyYAML >= 6 requires an explicit `Loader`; SafeLoader is used unless the
    caller passes one. Only pass an unsafe loader for trusted files.
    """
    import oyaml as yaml

    kwargs.setdefault('Loader', yaml.SafeLoader)
    with open(path, encoding='utf-8') as f:
        return yaml.load(f, **kwargs)


def save_pickle(path, obj, **kwargs):
    """Pickle `obj` to `path` ('.pkl' is appended if there is no extension)."""
    path = _check_ext(path, 'pkl')

    with open(path, 'wb') as f:
        pickle.dump(obj, f, **kwargs)


def load_pickle(path, **kwargs):
    """Unpickle and return the object stored at `path`.

    Security: unpickling can execute arbitrary code — never use on untrusted files.
    """
    with open(path, 'rb') as f:
        return pickle.load(f, **kwargs)
63 |
--------------------------------------------------------------------------------
/pylib/processing.py:
--------------------------------------------------------------------------------
1 | import concurrent.futures
2 | import functools
3 | import multiprocessing
4 |
5 |
def run_parallels(work_fn, iterable, max_workers=None, chunksize=1, processing_bar=True, backend_executor=multiprocessing.Pool, debug=False):
    """Apply `work_fn` to every item of `iterable` in parallel; return the results in input order.

    Parameters
    ----------
    work_fn : callable
        Function applied to each item. It must be picklable for process-based backends.
    iterable : iterable
        Work items. `len()` is used (when available) to size the progress bar.
    max_workers : int or None
        Passed positionally to `backend_executor`.
    chunksize : int
        Items handed to a worker at a time.
    processing_bar : bool
        Wrap the results in a tqdm bar if tqdm is importable.
    backend_executor : callable
        `multiprocessing.Pool`-like (has `imap`) or a `concurrent.futures`
        executor (has `map`). Must be usable as a context manager.
    debug : bool
        If True, run serially in the current process (no pool, no bar).
    """
    if debug:
        return [work_fn(i) for i in iterable]

    with backend_executor(max_workers) as executor:
        # multiprocessing pools expose a lazy, ordered `imap`; concurrent.futures executors only have `map`.
        if hasattr(executor, 'imap'):
            works = executor.imap(work_fn, iterable, chunksize=chunksize)
        else:
            works = executor.map(work_fn, iterable, chunksize=chunksize)

        if processing_bar:
            try:
                import tqdm
            except ImportError:
                print('`import tqdm` fails! Run without processing bar!')
            else:
                try:
                    total = len(iterable)
                except TypeError:  # e.g. generators have no len()
                    total = None
                works = tqdm.tqdm(works, total=total)

        # Consume inside the `with`: leaving it shuts the pool down (Pool.__exit__ terminates workers).
        return list(works)


run_parallels_mp = run_parallels
run_parallels_cfprocess = functools.partial(run_parallels, backend_executor=concurrent.futures.ProcessPoolExecutor)
run_parallels_cfthread = functools.partial(run_parallels, backend_executor=concurrent.futures.ThreadPoolExecutor)
33 |
34 |
if __name__ == '__main__':
    import time

    def work(i):
        """Toy job: a short sleep plus some throw-away CPU work."""
        time.sleep(0.0001)
        i**i  # result intentionally discarded; only burns CPU
        return i

    started_at = time.time()
    outputs = run_parallels_mp(work, range(10000), max_workers=2, chunksize=1, processing_bar=True, debug=False)
    for value in outputs:
        print(value)
    elapsed = time.time() - started_at
    print(elapsed)
48 |
--------------------------------------------------------------------------------
/pylib/argument.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import functools
3 | import json
4 |
5 | from pylib import serialization
6 |
7 |
GLOBAL_COMMAND_PARSER = argparse.ArgumentParser()


def _serialization_wrapper(func):
    """Add optional `to_json` / `to_yaml` keyword arguments to a namespace-returning function.

    The keywords are removed before calling `func`. If given, the returned
    namespace is also dumped to that path with `args_to_json` / `args_to_yaml`.
    """
    @functools.wraps(func)
    def _wrapper(*args, **kwargs):
        to_json = kwargs.pop("to_json", None)
        to_yaml = kwargs.pop("to_yaml", None)
        namespace = func(*args, **kwargs)
        if to_json:
            args_to_json(to_json, namespace)
        if to_yaml:
            args_to_yaml(to_yaml, namespace)
        return namespace
    return _wrapper


def str2bool(v):
    """Convert a yes/no style string (case-insensitive) to a bool.

    Raises argparse.ArgumentTypeError for anything else, so argparse reports a usage error.
    """
    lowered = v.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected!')


def argument(*args, **kwargs):
    """Wrap argparse.add_argument on GLOBAL_COMMAND_PARSER.

    `type=bool` is replaced by `str2bool` (so "false" parses as False), and
    `type=dict` by `json.loads`. Any other `type` is passed through
    unchanged, including plain functions and lambdas.
    """
    type_ = kwargs.get('type')
    # issubclass() only accepts classes; callables such as str2bool or a lambda would raise TypeError.
    if isinstance(type_, type):
        if issubclass(type_, bool):
            kwargs['type'] = str2bool
        elif issubclass(type_, dict):
            kwargs['type'] = json.loads
    return GLOBAL_COMMAND_PARSER.add_argument(*args, **kwargs)


arg = argument


@_serialization_wrapper
def args(args=None, namespace=None):
    """Parse args using the global parser (accepts `to_json` / `to_yaml` keywords)."""
    namespace = GLOBAL_COMMAND_PARSER.parse_args(args=args, namespace=namespace)
    return namespace
52 |
53 |
@_serialization_wrapper
def args_from_xxx(obj, parser, check=True):
    """Build an argparse.Namespace from `obj` via `parser`.

    No type or `choices` validation is applied: every key/value of the dict
    becomes a namespace attribute as-is. `check` is accepted for
    compatibility and is unused. Also accepts `to_json` / `to_yaml`
    keywords to dump the resulting namespace.

    Parameters
    ----------
    parser: function
        Should return a dict.

    """
    loaded = parser(obj)
    return argparse.Namespace(**loaded)


args_from_dict = functools.partial(args_from_xxx, parser=lambda x: x)
args_from_json = functools.partial(args_from_xxx, parser=serialization.load_json)
args_from_yaml = functools.partial(args_from_xxx, parser=serialization.load_yaml)
74 |
75 |
def args_to_json(path, namespace, **kwargs):
    """Save the attributes of `namespace` as JSON via serialization.save_json."""
    as_dict = vars(namespace)
    serialization.save_json(path, as_dict, **kwargs)


def args_to_yaml(path, namespace, **kwargs):
    """Save the attributes of `namespace` as YAML via serialization.save_yaml."""
    as_dict = vars(namespace)
    serialization.save_yaml(path, as_dict, **kwargs)
82 |
--------------------------------------------------------------------------------
/imlib/dtype.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def _check(images, dtypes, min_value=-np.inf, max_value=np.inf):
5 | # check type
6 | assert isinstance(images, np.ndarray), '`images` should be np.ndarray!'
7 |
8 | # check dtype
9 | dtypes = dtypes if isinstance(dtypes, (list, tuple)) else [dtypes]
10 | assert images.dtype in dtypes, 'dtype of `images` shoud be one of %s!' % dtypes
11 |
12 | # check nan and inf
13 | assert np.all(np.isfinite(images)), '`images` contains NaN or Inf!'
14 |
15 | # check value
16 | if min_value not in [None, -np.inf]:
17 | l = '[' + str(min_value)
18 | else:
19 | l = '(-inf'
20 | min_value = -np.inf
21 | if max_value not in [None, np.inf]:
22 | r = str(max_value) + ']'
23 | else:
24 | r = 'inf)'
25 | max_value = np.inf
26 | assert np.min(images) >= min_value and np.max(images) <= max_value, \
27 | '`images` should be in the range of %s!' % (l + ',' + r)
28 |
29 |
30 | def to_range(images, min_value=0.0, max_value=1.0, dtype=None):
31 | """Transform images from [-1.0, 1.0] to [min_value, max_value] of dtype."""
32 | _check(images, [np.float32, np.float64], -1.0, 1.0)
33 | dtype = dtype if dtype else images.dtype
34 | return ((images + 1.) / 2. * (max_value - min_value) + min_value).astype(dtype)
35 |
36 |
37 | def float2im(images):
38 | """Transform images from [0, 1.0] to [-1.0, 1.0]."""
39 | _check(images, [np.float32, np.float64], 0.0, 1.0)
40 | return images * 2 - 1.0
41 |
42 |
43 | def float2uint(images):
44 | """Transform images from [0, 1.0] to uint8."""
45 | _check(images, [np.float32, np.float64], -0.0, 1.0)
46 | return (images * 255).astype(np.uint8)
47 |
48 |
49 | def im2uint(images):
50 | """Transform images from [-1.0, 1.0] to uint8."""
51 | return to_range(images, 0, 255, np.uint8)
52 |
53 |
54 | def im2float(images):
55 | """Transform images from [-1.0, 1.0] to [0.0, 1.0]."""
56 | return to_range(images, 0.0, 1.0)
57 |
58 |
59 | def uint2im(images):
60 | """Transform images from uint8 to [-1.0, 1.0] of float64."""
61 | _check(images, np.uint8)
62 | return images / 127.5 - 1.0
63 |
64 |
65 | def uint2float(images):
66 | """Transform images from uint8 to [0.0, 1.0] of float64."""
67 | _check(images, np.uint8)
68 | return images / 255.0
69 |
70 |
71 | def cv2im(images):
72 | """Transform opencv images to [-1.0, 1.0]."""
73 | images = uint2im(images)
74 | return images[..., ::-1]
75 |
76 |
77 | def im2cv(images):
78 | """Transform images from [-1.0, 1.0] to opencv images."""
79 | images = im2uint(images)
80 | return images[..., ::-1]
81 |
--------------------------------------------------------------------------------
/pylib/path.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import fnmatch
3 | import os
4 | import glob as _glob
5 | import sys
6 |
7 |
def add_path(paths):
    """Prepend each path in `paths` to sys.path unless it is already present.

    `paths` may be a single path or a list/tuple of paths.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        if path not in sys.path:
            sys.path.insert(0, path)


def mkdir(paths):
    """Create each directory in `paths` (including parents) if it does not exist.

    `paths` may be a single path or a list/tuple of paths. Existing paths,
    including existing files, are skipped silently.
    """
    if not isinstance(paths, (list, tuple)):
        paths = [paths]
    for path in paths:
        if not os.path.exists(path):
            # exist_ok: another process may create the directory between the check and this call.
            os.makedirs(path, exist_ok=True)
22 |
23 |
def split(path):
    """Return dir, name, ext of `path`; `ext` keeps its leading dot ('' if none)."""
    head, tail = os.path.split(path)
    stem, extension = os.path.splitext(tail)
    return head, stem, extension


def directory(path):
    """Return the directory part of `path`."""
    return split(path)[0]


def name(path):
    """Return the file name of `path` without its extension."""
    return split(path)[1]


def ext(path):
    """Return the extension of `path`, including the leading dot."""
    return split(path)[2]


def name_ext(path):
    """Return the file name of `path` including its extension."""
    return ''.join(split(path)[1:])


def change_ext(path, ext):
    """Replace the extension of `path` with `ext` (leading dot optional).

    An empty `ext` removes the extension instead of raising IndexError.
    """
    root = os.path.splitext(path)[0]
    if ext.startswith('.'):
        ext = ext[1:]
    return root + '.' + ext if ext else root


abspath = os.path.abspath
asbpath = abspath  # original misspelled name, kept for backward compatibility


join = os.path.join


def prefix(path, prefixes, sep='-'):
    """Prepend `prefixes` (a string or list/tuple, joined by `sep`) to the file name of `path`."""
    prefixes = prefixes if isinstance(prefixes, (list, tuple)) else [prefixes]
    head, stem, extension = split(path)
    return join(head, sep.join(prefixes) + sep + stem + extension)


def suffix(path, suffixes, sep='-'):
    """Append `suffixes` (a string or list/tuple, joined by `sep`) to the file name of `path`, before the extension."""
    suffixes = suffixes if isinstance(suffixes, (list, tuple)) else [suffixes]
    head, stem, extension = split(path)
    return join(head, stem + sep + sep.join(suffixes) + extension)


def prefix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'):
    """Prefix the file name of `path` with the current local time formatted by `fmt`."""
    return prefix(path, prefixes=datetime.datetime.now().strftime(fmt), sep=sep)


def suffix_now(path, fmt="%Y-%m-%d-%H:%M:%S", sep='-'):
    """Suffix the file name of `path` with the current local time formatted by `fmt`."""
    return suffix(path, suffixes=datetime.datetime.now().strftime(fmt), sep=sep)
77 |
78 |
def glob(dir, pats, recursive=False):  # faster than match, python3 only
    """Return paths in `dir` matching any glob pattern in `pats` (a string or list/tuple).

    With `recursive=True`, '**' in a pattern matches any number of subdirectories.
    Results are grouped by pattern, in the order of `pats`.
    """
    pats = pats if isinstance(pats, (list, tuple)) else [pats]
    matches = []
    for pat in pats:
        matches += _glob.glob(os.path.join(dir, pat), recursive=recursive)
    return matches


def match(dir, pats, recursive=False):  # slow
    """Return file paths under `dir` whose base name matches any fnmatch pattern in `pats`.

    Only the top level of `dir` is searched unless `recursive` is True.
    A missing `dir` yields []. Results are grouped by pattern, in the order of `pats`.
    """
    pats = pats if isinstance(pats, (list, tuple)) else [pats]

    if recursive:
        # Materialized once because it is re-iterated for every pattern.
        walk = list(os.walk(dir))
    else:
        # Only the top-level entry is needed; don't walk the whole tree to discard it.
        top = next(os.walk(dir), None)
        walk = [top] if top is not None else []

    matches = []
    for pat in pats:
        for root, _, file_names in walk:
            for file_name in fnmatch.filter(file_names, pat):
                matches.append(os.path.join(root, file_name))

    return matches
101 |
--------------------------------------------------------------------------------
/loss.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from histogram_layers import HistogramLayers, HistogramLayersCT
3 |
bc_loss_object = tf.keras.losses.BinaryCrossentropy(from_logits=True)


def _mean_over_channels(losses):
    """Average per-channel losses, adding left to right exactly like ``(l0 + l1 + l2) / 3``."""
    return sum(losses[1:], losses[0]) / len(losses)


# generator loss mi
def generator_loss_color_transfer(disc_generated_output, gen_output, inp, target, args):
    """Generator loss for the color-transfer task.

    The total is a weighted sum of three terms:
    - the adversarial loss of the discriminator logits against an all-ones target;
    - the conditional-entropy (MI) loss between the source `inp` and the output;
    - the histogram loss between the target and the output.

    The last two terms are averaged over the first three channels (last axis).
    Their weights are `args.gan_loss_weight`, `args.mi_loss_weight` and
    `args.hist_loss_weight`.

    Returns (total_gen_loss, gan_loss, mi_loss, hist_loss).
    """
    gan_loss = bc_loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # One histogram-layer instance per channel (assumes >= 3 channels on the last axis — TODO confirm).
    per_channel = [HistogramLayersCT(out=gen_output[..., c], tar=target[..., c], src=inp[..., c], args=args)
                   for c in range(3)]

    mi_loss = _mean_over_channels([h.calc_cond_entropy_loss_src_out() for h in per_channel])
    hist_loss = _mean_over_channels([h.calc_hist_loss_tar_out() for h in per_channel])

    total_gen_loss = (args.gan_loss_weight * gan_loss) + (args.mi_loss_weight * mi_loss) + \
        (args.hist_loss_weight * hist_loss)

    return total_gen_loss, gan_loss, mi_loss, hist_loss


def generator_loss(disc_generated_output, gen_output, target, args):
    """Generator loss for the edges2photos and colorization tasks.

    This is the same weighted sum as `generator_loss_color_transfer`, except
    that the conditional-entropy (MI) term is computed between the target and
    the output, using HistogramLayers.

    Returns (total_gen_loss, gan_loss, mi_loss, hist_loss).
    """
    gan_loss = bc_loss_object(tf.ones_like(disc_generated_output), disc_generated_output)

    # One histogram-layer instance per channel (assumes >= 3 channels on the last axis — TODO confirm).
    per_channel = [HistogramLayers(out=gen_output[..., c], tar=target[..., c], args=args)
                   for c in range(3)]

    mi_loss = _mean_over_channels([h.calc_cond_entropy_loss_tar_out() for h in per_channel])
    hist_loss = _mean_over_channels([h.calc_hist_loss_tar_out() for h in per_channel])

    total_gen_loss = (args.gan_loss_weight * gan_loss) + (args.mi_loss_weight * mi_loss) + (
        args.hist_loss_weight * hist_loss)

    return total_gen_loss, gan_loss, mi_loss, hist_loss


def discriminator_loss(disc_real_output, disc_generated_output):
    """Binary cross-entropy (from logits): real outputs vs. ones plus generated outputs vs. zeros."""
    real_target = tf.ones_like(disc_real_output)
    fake_target = tf.zeros_like(disc_generated_output)
    return bc_loss_object(real_target, disc_real_output) + bc_loss_object(fake_target, disc_generated_output)
72 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 |
def _str2bool(value):
    """Parse a yes/no style command-line string into a bool.

    argparse's ``type=bool`` would call bool('False'), which is True, so it
    cannot be used for boolean options.
    """
    lowered = value.lower()
    if lowered in ('yes', 'true', 't', 'y', '1'):
        return True
    if lowered in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected!')


def options():
    """Build the command-line parser for training.

    Task-specific defaults are added afterwards by define_task_default_params().
    Boolean options accept yes/no, true/false, t/f, y/n and 1/0.
    """
    parser = argparse.ArgumentParser()

    # basic parameters
    parser.add_argument('--dataroot', required=True, help='path to data directory (should have subfolder with the dataset name)')
    parser.add_argument('--task', default='edges2photos', choices=['edges2photos', 'color_transfer', 'colorization'])
    parser.add_argument('--dataset', choices=['edges2shoes', 'edges2handbags', '102flowers', 'summer2winter_yosemite'])
    parser.add_argument('--output_dir', default='output', help='models, samples and logs are saved here')
    parser.add_argument('--output_pre_dir', default='', help='if specified load checkpoint from this directory')

    # model parameters
    parser.add_argument('--img_h', type=int, default=256, help='crop image to this image height')
    parser.add_argument('--img_w', type=int, default=256, help='crop image to this image width')
    parser.add_argument('--img_c', type=int, default=3, help='# of input image channels')
    parser.add_argument('--img_out_c', type=int, default=3, help='# of output image channels')

    # histogram layers parameters
    parser.add_argument('--bin_num', type=int, default=256, help='histogram layers - number of bins')
    parser.add_argument('--kernel_width_ratio', type=float, default=2.5, help='histogram layers - scale kernel width')

    # optimizer parameters
    parser.add_argument('--gen_lr', type=float, default=2e-4)
    parser.add_argument('--gen_beta_1', type=float, default=0.5)
    parser.add_argument('--dis_lr', type=float, default=2e-4)
    parser.add_argument('--dis_beta_1', type=float, default=0.5)

    # data prepare parameters
    parser.add_argument('--min_val', type=float, default=-1.0, help="normalize image values to this min")
    parser.add_argument('--max_val', type=float, default=1.0, help="normalize image values to this max")
    parser.add_argument('--yuv', type=_str2bool, default=True, help="convert images to YUV colorspace")

    # training params
    parser.add_argument('--batch_size', type=int, help='input batch size')
    parser.add_argument('--epochs', type=int, help='number of training epochs')
    parser.add_argument('--cp_freq', type=int, help='checkpoint frequency - save every # epochs')
    parser.add_argument('--buffer_size', type=int, help='size of shuffle buffer')

    # data prepare params
    parser.add_argument('--a2b', type=_str2bool, help='image translation direction AtoB or BtoA')
    parser.add_argument('--jitter', type=_str2bool, help='random jitter for data augmentation')

    # loss params
    parser.add_argument('--gan_loss_weight', type=float)
    parser.add_argument('--mi_loss_weight', type=float)
    parser.add_argument('--hist_loss_weight', type=float)

    return parser
51 |
52 |
def define_task_default_params(parser):
    """Install task-specific defaults on `parser`, based on the `--task` given on the command line.

    Parses sys.argv once to read `task`, then calls `parser.set_defaults`
    with the preset for that task. Nothing is returned; callers parse
    again to get the final arguments. Unknown tasks leave the parser
    untouched.
    """
    task = parser.parse_args().task

    shared = dict(batch_size=4, buffer_size=400, gan_loss_weight=1, mi_loss_weight=1)
    presets = {
        'colorization': dict(shared, dataset='summer2winter_yosemite', epochs=200, cp_freq=20,
                             a2b=True, jitter=True, hist_loss_weight=20),
        'color_transfer': dict(shared, dataset='102flowers', epochs=100, cp_freq=10,
                               a2b=False, jitter=True, hist_loss_weight=100),
        'edges2photos': dict(shared, dataset='edges2shoes', epochs=15, cp_freq=5,
                             a2b=False, jitter=False, hist_loss_weight=100),
    }

    preset = presets.get(task)
    if preset is not None:
        parser.set_defaults(**preset)
67 |
--------------------------------------------------------------------------------
/pylib/timer.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import timeit
3 |
4 |
class Timer:  # deprecated, use tqdm instead
    """A timer usable directly or as a context manager.

    The timer starts when constructed and restarts on ``__enter__``. The
    default clock is ``timeit.default_timer`` (``time.perf_counter`` since
    Python 3.3), which measures wall-clock time, not CPU time. Any
    zero-argument callable returning seconds may be passed instead.

    Parameters
    ----------
    fmt : str
        `ms`, `s` or `datetime`; how elapsed time is rendered by `fmt()`/`str()`.
    print_at_exit : boolean
        If True, print the elapsed time when exiting the context.
    timer : callable
        Zero-argument clock function returning seconds.

    References
    ----------
    - https://github.com/brouberol/contexttimer/blob/master/contexttimer/__init__.py.
    """

    def __init__(self, fmt='s', print_at_exit=True, timer=timeit.default_timer):
        assert fmt in ['ms', 's', 'datetime'], "`fmt` should be 'ms', 's' or 'datetime'!"
        self._fmt = fmt
        self._print_at_exit = print_at_exit
        self._timer = timer
        self.start()

    def __enter__(self):
        """Restart the timer on entering the context and return it."""
        self.restart()
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        """Print the elapsed time if `print_at_exit`; exceptions are not suppressed."""
        if self._print_at_exit:
            print(str(self))

    def __str__(self):
        return self.fmt(self.elapsed)[1]

    def start(self):
        """(Re)start timing from now."""
        self.start_time = self._timer()

    restart = start

    @property
    def elapsed(self):
        """Return the current elapsed time (seconds) since last (re)start."""
        return self._timer() - self.start_time

    def fmt(self, second):
        """Return ``(value, text)`` for `second` seconds in this timer's format.

        'ms' -> milliseconds, 's' -> seconds, 'datetime' -> datetime.timedelta.
        """
        if self._fmt == 'ms':
            time_fmt = second * 1000
            time_str = '%s %s' % (time_fmt, self._fmt)
        elif self._fmt == 's':
            time_fmt = second
            time_str = '%s %s' % (time_fmt, self._fmt)
        elif self._fmt == 'datetime':
            time_fmt = datetime.timedelta(seconds=second)
            time_str = str(time_fmt)
        return time_fmt, time_str


def timeit(run_times=1, **timer_kwargs):
    """Function decorator printing the execution time of `run_times` calls.

    All kwargs are the arguments taken by the Timer class constructor;
    `print_at_exit` is always forced to False because the decorator prints
    its own report. The wrapped function returns the result of its last call.

    Raises
    ------
    ValueError
        If `run_times` < 1, since there would be no result to return.
    """
    import functools

    if run_times < 1:
        raise ValueError('`run_times` should be >= 1!')

    def decorator(f):
        @functools.wraps(f)
        def wrapper(*args, **kwargs):
            # Build a fresh dict instead of mutating the shared `timer_kwargs` on every call.
            with Timer(**dict(timer_kwargs, print_at_exit=False)) as t:
                for _ in range(run_times):
                    out = f(*args, **kwargs)
            fmt = '[*] Execution time of function "%(function_name)s" for %(run_times)d runs is %(execution_time)s = %(execution_time_each)s * %(run_times)d [*]'
            context = {'function_name': f.__name__, 'run_times': run_times, 'execution_time': t, 'execution_time_each': t.fmt(t.elapsed / run_times)[1]}
            print(fmt % context)
            return out
        return wrapper

    return decorator
94 |
95 |
if __name__ == "__main__":
    import time

    pause = time.sleep

    # Demo 1: context-manager usage; the elapsed time is printed on exit.
    print(1)
    with Timer() as t:
        pause(1)
        print(t)
        pause(1)

    with Timer(fmt='datetime') as t:
        pause(1)

    # Demo 2: manual usage; the timer starts on construction.
    print(2)
    for timer_fmt, seconds in (('ms', 2), ('datetime', 1)):
        t = Timer(fmt=timer_fmt)
        pause(seconds)
        print(t)

    # Demo 3: decorator usage.
    print(3)

    @timeit(run_times=5, fmt='s')
    def blah():
        pause(2)

    blah()
127 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Differentiable Histogram Loss Functions for Intensity-based Image-to-Image Translation ([TPAMI 2023](https://ieeexplore.ieee.org/document/10133915))
Official TensorFlow Implementation
4 |
5 |
6 | ## Abstract
7 | We introduce the HueNet - a novel deep learning framework for a differentiable construction of intensity (1D) and joint (2D) histograms and present its applicability to paired and unpaired image-to-image translation problems. The key idea is an innovative technique for augmenting a generative neural network by histogram layers appended to the image generator. These histogram layers allow us to define two new histogram-based loss functions for constraining the structural appearance of the synthesized output image and its color distribution. Specifically, the color similarity loss is defined by the Earth Mover's Distance between the intensity histograms of the network output and a color reference image. The structural similarity loss is determined by the mutual information between the output and a content reference image based on their joint histogram. Although the HueNet can be applied to a variety of image-to-image translation problems, we chose to demonstrate its strength on the tasks of color transfer, exemplar-based image colorization, and edges → photo, where the colors of the output image are predefined.
8 |
9 | Prerequisites
10 | -------------------------------------------------------------
11 | The code runs on Linux machines with NVIDIA GPUs.
12 |
13 |
14 | Installation
15 | -------------------------------------------------------------
16 | - TensorFlow 2.x `pip install tensorflow-gpu`
17 | - TensorFlow Addons `pip install tensorflow-addons`
18 | - (if you encounter "tf.summary.histogram fails with TypeError", run `pip install --upgrade tb-nightly`)
19 | - scikit-image, oyaml, tqdm
20 | - Python 3.6
21 |
22 | - For pip users, please type the command:
23 | pip install -r requirements.txt
24 |
25 |
26 | DeepHist edges2photos
27 | -------------------------------------------------------------
28 | Download an edges2photos dataset (e.g. edges2shoes):
29 |
30 | `python ./datasets/download_pix2pix_datasets.py`
31 |
32 | * edit `dataset_name` in `./datasets/download_pix2pix_datasets.py` for other datasets
33 |
34 | Train a model:
35 | `python train.py --dataroot ~/.keras/datasets --task edges2photos`
36 |
37 |
38 | DeepHist colorization
39 | -------------------------------------------------------------
40 | Download a CycleGAN dataset (e.g. summer2winter_yosemite):
41 |
42 | `bash ./datasets/download_cyclegan_datasets.sh summer2winter_yosemite`
43 |
44 | Train a model:
45 | `python train.py --dataroot ./datasets --task colorization`
46 |
47 |
48 | DeepHist color transfer
49 | -------------------------------------------------------------
50 | Download and unzip 102 Flower Category Database from:
51 | https://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
52 |
53 | Train a model:
54 | `python train.py --dataroot ./datasets --task color_transfer`
55 |
56 |
57 | References
58 | -------------------------------------------------------------
59 | ```sh
60 | @inproceedings{CycleGAN2017,
61 | title={Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks},
62 | author={Zhu, Jun-Yan and Park, Taesung and Isola, Phillip and Efros, Alexei A},
63 | booktitle={Computer Vision (ICCV), 2017 IEEE International Conference on},
64 | year={2017}
65 | }
66 | @inproceedings{isola2017image,
67 | title={Image-to-Image Translation with Conditional Adversarial Networks},
68 | author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
69 | booktitle={Computer Vision and Pattern Recognition (CVPR), 2017 IEEE Conference on},
70 | year={2017}
71 | }
72 | ```
73 |
74 | Citation
75 | -------------------------------------------------------------
76 | If you find either the code or the paper useful for your research, cite our paper:
77 | ```sh
78 | @ARTICLE{10133915,
79 | author={Avi-Aharon, Mor and Arbelle, Assaf and Raviv, Tammy Riklin},
80 | journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
81 | title={Differentiable Histogram Loss Functions for Intensity-based Image-to-Image Translation},
82 | year={2023},
83 | volume={},
84 | number={},
85 | pages={1-12},
86 | doi={10.1109/TPAMI.2023.3278287}}
87 | ```
88 |
--------------------------------------------------------------------------------
/histogram_layers.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 |
class HistogramLayers(object):
    """Differentiable (soft) histogram construction and histogram-based losses.

    Every value of an image is softly assigned to ``bin_num`` equal-width bins
    spanning ``[min_val, max_val]`` (see ``activation_func``). The resulting
    per-value "activation maps" are used to build 1D histograms/CDFs (for an
    EMD-style loss) and 2D joint histograms (for a conditional-entropy loss)
    between the generator output and the target.

    Args:
        out: generator output batch. Must have a static shape; it is flattened
            per example, so presumably (batch, H, W, C) — TODO confirm.
        tar: target batch, same shape as ``out``.
        args: namespace providing ``bin_num``, ``min_val``, ``max_val`` and
            ``kernel_width_ratio``.
    """

    def __init__(self, out, tar, args):
        self.bin_num = args.bin_num
        self.min_val = args.min_val
        self.max_val = args.max_val
        # Width of one bin in value space.
        self.interval_length = (self.max_val - self.min_val) / self.bin_num
        # Sigmoid temperature of the soft bin edges (smaller -> closer to a hard histogram).
        self.kernel_width = self.interval_length / args.kernel_width_ratio
        self.maps_out = self.calc_activation_maps(out)
        self.maps_tar = self.calc_activation_maps(tar)
        # Activation maps have static shape (batch, n_values, bin_num); n_values is the
        # number of flattened values per example (H*W for single-channel input).
        self.n_pixels = self.maps_out.get_shape().as_list()[1]
        self.bs = self.maps_out.get_shape().as_list()[0]  # batch size

    def calc_activation_maps(self, img):
        """Return soft bin memberships of shape (batch, n_values, bin_num) for ``img``."""
        # Bin centres: midpoints of bin_num equal intervals over [min_val, max_val].
        bins_min_max = np.linspace(self.min_val, self.max_val, self.bin_num + 1)
        bins_av = (bins_min_max[0:-1] + bins_min_max[1:]) / 2
        bins_av = tf.constant(bins_av, dtype=tf.float32)  # shape = (bin_num,)
        bins_av = tf.expand_dims(bins_av, axis=0)  # shape = (1, bin_num)
        bins_av = tf.expand_dims(bins_av, axis=0)  # shape = (1, 1, bin_num)
        # (batch, n_values, 1) so it broadcasts against the bin centres.
        img_flat = tf.expand_dims(tf.keras.layers.Flatten()(img), axis=-1)
        maps = self.activation_func(img_flat, bins_av)  # shape = (batch, n_values, bin_num)
        return maps

    def activation_func(self, img_flat, bins_av):
        """Soft rectangular window of width ``interval_length`` around each bin centre.

        Each window is a difference of two sigmoids. Besides the window centred
        on the bin centre ``c``, two mirrored windows centred at
        ``2*min_val - c`` and ``2*max_val - c`` are added, so values falling just
        outside ``[min_val, max_val]`` are reflected back into the edge bins.
        """
        img_minus_bins_av = tf.subtract(img_flat, bins_av)  # shape = (batch, n_values, bin_num)
        img_plus_bins_av = tf.add(img_flat, bins_av)  # shape = (batch, n_values, bin_num)
        maps = tf.math.sigmoid((img_minus_bins_av + self.interval_length / 2) / self.kernel_width) \
            - tf.math.sigmoid((img_minus_bins_av - self.interval_length / 2) / self.kernel_width) \
            + tf.math.sigmoid((img_plus_bins_av - 2 * self.min_val + self.interval_length / 2) / self.kernel_width) \
            - tf.math.sigmoid((img_plus_bins_av - 2 * self.min_val - self.interval_length / 2) / self.kernel_width) \
            + tf.math.sigmoid((img_plus_bins_av - 2 * self.max_val + self.interval_length / 2) / self.kernel_width) \
            - tf.math.sigmoid((img_plus_bins_av - 2 * self.max_val - self.interval_length / 2) / self.kernel_width)
        return maps

    def calc_cond_entropy_loss(self, maps_x, maps_y):
        """Return the batch mean of the conditional entropy H(X|Y).

        Uses H(X|Y) = H(X,Y) - H(Y), with the joint distribution estimated from
        the soft activation maps of X and Y.
        """
        # Joint histogram p(x, y): shape (batch, bin_num_x, bin_num_y).
        pxy = tf.matmul(maps_x, maps_y, transpose_a=True) / self.n_pixels
        # Marginal p(y): sum over the x axis -> (batch, bin_num_y).
        py = tf.reduce_sum(pxy, 1)
        # xlogy(p, p) = p*log(p), and 0 when p == 0, so hy = -H(Y) and hxy = -H(X,Y).
        hy = tf.reduce_sum(tf.math.xlogy(py, py), 1)
        hxy = tf.reduce_sum(tf.math.xlogy(pxy, pxy), [1, 2])
        cond_entropy = hy - hxy  # H(X,Y) - H(Y) = H(X|Y), shape (batch,)
        mean_cond_entropy = tf.reduce_mean(cond_entropy)
        return mean_cond_entropy

    def ecdf(self, maps):
        """Return the soft empirical CDF over bins, shape (batch, bin_num)."""
        p = tf.reduce_sum(maps, 1) / self.n_pixels  # normalized histogram, shape = (batch, bin_num)
        return tf.cumsum(p, 1)

    def emd_loss(self, maps, maps_hat):
        """Return the batch mean of the root-mean-square difference between two CDFs."""
        ecdf_p = self.ecdf(maps)  # shape = (batch, bin_num)
        ecdf_p_hat = self.ecdf(maps_hat)  # shape = (batch, bin_num)
        emd = tf.reduce_mean(tf.pow(tf.abs(ecdf_p - ecdf_p_hat), 2), axis=-1)  # shape = (batch,)
        emd = tf.pow(emd, 1 / 2)
        return tf.reduce_mean(emd)  # scalar

    def calc_relative_mi(self, maps_x, maps_y):
        """Hook for a relative mutual-information loss between two sets of maps.

        Not implemented in this class. It previously did not exist at all, so
        ``calc_relative_mi_tar_out`` failed with a confusing AttributeError.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            'calc_relative_mi is not implemented; override it in a subclass.')

    def calc_hist_loss_tar_out(self):
        """EMD-style histogram loss between target and output."""
        return self.emd_loss(self.maps_tar, self.maps_out)

    def calc_cond_entropy_loss_tar_out(self):
        """Batch-mean conditional entropy H(TAR | OUT)."""
        return self.calc_cond_entropy_loss(self.maps_tar, self.maps_out)

    def calc_relative_mi_tar_out(self):
        """Relative MI between target and output (delegates to ``calc_relative_mi``)."""
        return self.calc_relative_mi(self.maps_tar, self.maps_out)
74 |
75 |
class HistogramLayersCT(HistogramLayers):
    """Histogram losses for color transfer.

    Adds activation maps for a third image, ``src``. The output can then be
    compared to the target by histogram (EMD(TAR, OUT), via the inherited
    ``calc_hist_loss_tar_out``) and to the source by conditional entropy
    (``calc_cond_entropy_loss_src_out``).
    """

    def __init__(self, out, tar, src, args):
        super().__init__(out, tar, args)
        self.maps_src = self.calc_activation_maps(src)

    def calc_cond_entropy_loss_src_out(self):
        """Return the batch-mean conditional entropy H(SRC | OUT)."""
        # Deliberately not wrapped in @tf.function, matching the other methods.
        # A tf.function on a method is traced once per instance, and a new
        # instance is built per loss evaluation, so eager callers would retrace
        # every time. Inside an outer tf.function it would be traced there anyway.
        return self.calc_cond_entropy_loss(self.maps_src, self.maps_out)
--------------------------------------------------------------------------------
/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import imlib as im
3 | import numpy as np
4 |
5 |
def load_hist(real_image, args):
    """Return hard histograms of channels 0, 1 and 2 of ``real_image`` as a 3-tuple.

    Each histogram has ``args.bin_num`` equal-width bins over
    ``[args.min_val, args.max_val]``. Following ``tf.histogram_fixed_width``,
    values below/above the range are counted in the first/last bin.
    """
    value_range = [args.min_val, args.max_val]
    return tuple(
        tf.histogram_fixed_width(real_image[..., channel], value_range, nbins=args.bin_num)
        for channel in range(3)
    )
12 |
13 |
def load(image_file):
    """Read a JPEG that holds two images side by side and split it in half.

    Args:
        image_file: path (string tensor) of the JPEG file.

    Returns:
        ``(input_image, real_image)``: the right half and the left half of the
        file respectively, both float32 with the decoded 0..255 pixel values.
    """
    raw = tf.io.read_file(image_file)
    # Force 3 channels. Otherwise grayscale JPEGs decode to 1-channel tensors,
    # and the channel dimension is statically unknown for later stacking/cropping.
    image = tf.image.decode_jpeg(raw, channels=3)

    half_width = tf.shape(image)[1] // 2
    real_image = image[:, :half_width, :]
    input_image = image[:, half_width:, :]

    return tf.cast(input_image, tf.float32), tf.cast(real_image, tf.float32)
28 |
29 |
def load_colorization(image_file):
    """Load a color JPEG and derive its 3-channel grayscale version.

    Args:
        image_file: path (string tensor) of the JPEG file.

    Returns:
        ``(gray_image_3c, image)``: the grayscale image replicated over three
        channels, and the original RGB image, both float32 in 0..255.
    """
    raw = tf.io.read_file(image_file)
    # Force RGB decoding so rgb_to_grayscale always gets a 3-channel input.
    image = tf.image.decode_jpeg(raw, channels=3)
    image = tf.cast(image, tf.float32)

    gray_image = tf.image.rgb_to_grayscale(image)
    # Replicate the single gray channel to 3 channels (same as concatenating it 3 times).
    gray_image_3c = tf.image.grayscale_to_rgb(gray_image)

    return gray_image_3c, image
40 |
41 |
def resize(input_image, real_image, height, width):
    """Resize both images to ``(height, width)`` with nearest-neighbour sampling."""
    target_size = [height, width]
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    resized_input = tf.image.resize(input_image, target_size, method=nearest)
    resized_real = tf.image.resize(real_image, target_size, method=nearest)
    return resized_input, resized_real
49 |
50 |
# Convert RGB in {0,..,255} to a rescaled YUV representation.
def rgb2yuv(img):
    """Map 0..255 RGB to YUV with Y rescaled to [-1, 1] and U, V doubled.

    The image is scaled to [0, 1] before ``tf.image.rgb_to_yuv``. Y is then
    mapped with ``2*Y - 1`` and both chroma channels are multiplied by 2.
    NOTE(review): with TF's YUV coefficients the doubled V channel can
    presumably exceed [-1, 1] (about +/-1.23) — confirm this range is intended
    for the histogram range and the generator's tanh output.
    """
    yuv = tf.image.rgb_to_yuv(img / 255)
    luma = yuv[..., 0] * 2 - 1
    chroma_u = yuv[..., 1] * 2
    chroma_v = yuv[..., 2] * 2
    return tf.stack([luma, chroma_u, chroma_v], -1)
56 |
57 |
# Normalize both images from 0..255 into model range (optionally converting to YUV).
def normalize(input_image, real_image, yuv):
    """Return ``(input_image, real_image)`` rescaled for the model.

    With ``yuv`` set, both images go through ``rgb2yuv``. Otherwise they are
    linearly mapped from [0, 255] to [-1, 1].
    """
    if not yuv:
        return (input_image / 127.5) - 1, (real_image / 127.5) - 1
    return rgb2yuv(input_image), rgb2yuv(real_image)
68 |
69 |
def load_image(image_file, train, args):
    """Load one paired example: images, target histograms and direction handling.

    Args:
        image_file: path (string tensor) of the image file.
        train: whether random jitter may be applied (also requires ``args.jitter``).
        args: namespace with ``task``, ``jitter``, ``img_h``, ``img_w``,
            ``img_c``, ``yuv``, ``a2b`` and the histogram settings.

    Returns:
        ``(input_image, (real_image, hists_of_real))`` when ``args.a2b``;
        otherwise the roles are swapped: ``(real_image, (input_image, hists_of_input))``.
    """

    def random_jitter(input_image, real_image):
        # Upscale both images to 286x286 before cropping.
        input_image, real_image = resize(input_image, real_image, 286, 286)

        # Stack the pair so both images get exactly the same random crop window.
        pair = tf.stack([input_image, real_image], axis=0)
        pair = tf.image.random_crop(pair, size=[2, args.img_h, args.img_w, args.img_c])
        input_image, real_image = pair[0], pair[1]

        # Mirror both images horizontally with probability 1/2.
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            real_image = tf.image.flip_left_right(real_image)

        return input_image, real_image

    loader = load_colorization if args.task == 'colorization' else load
    input_image, real_image = loader(image_file)

    if train and args.jitter:
        input_image, real_image = random_jitter(input_image, real_image)
    else:
        input_image, real_image = resize(input_image, real_image, args.img_h, args.img_w)

    input_image, real_image = normalize(input_image, real_image, args.yuv)

    if not args.a2b:
        return real_image, (input_image, load_hist(input_image, args))
    return input_image, (real_image, load_hist(real_image, args))
107 |
108 |
def load_single_image(image_file, hist, train, args):
    """Load one image for the unpaired (color transfer) pipeline.

    Args:
        image_file: path (string tensor) of the JPEG file.
        hist: when true, also return the per-channel histograms of the image.
        train: whether random jitter may be applied (also requires ``args.jitter``).
        args: namespace with ``jitter``, ``img_h``, ``img_w``, ``img_c`` and
            the histogram settings.

    Returns:
        The image converted with ``rgb2yuv``, or ``(image, histograms)`` if ``hist``.
        NOTE(review): YUV conversion here is unconditional (``args.yuv`` is not
        consulted), while ``save_images`` only converts back when ``args.yuv``
        is set — confirm configs keep these consistent.
    """

    def random_jitter_single(image):
        # Upscale to 286x286, take a random img_h x img_w crop, randomly mirror.
        image = tf.image.resize(image, [286, 286],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
        image = tf.image.random_crop(image, size=[args.img_h, args.img_w, args.img_c])
        image = tf.image.random_flip_left_right(image)
        return image

    raw = tf.io.read_file(image_file)
    # Force RGB decoding: grayscale JPEGs would otherwise produce 1-channel
    # tensors that break random_crop (img_c channels) and rgb2yuv.
    image = tf.image.decode_jpeg(raw, channels=3)
    image = tf.cast(image, tf.float32)

    if train and args.jitter:
        image = random_jitter_single(image)
    else:
        image = tf.image.resize(image, [args.img_h, args.img_w],
                                method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)

    image = rgb2yuv(image)
    if hist:
        return image, load_hist(image, args)
    return image
135 |
136 |
# Inverse of rgb2yuv: rescaled YUV -> RGB mapped to [-1, 1].
def yuv2rgb(img):
    """Undo the ``rgb2yuv`` rescaling, convert to RGB, and map [0, 1] to [-1, 1]."""
    luma = (img[..., 0] + 1) / 2
    chroma_u = img[..., 1] / 2
    chroma_v = img[..., 2] / 2
    rgb = tf.image.yuv_to_rgb(tf.stack([luma, chroma_u, chroma_v], -1))
    return rgb * 2 - 1
143 |
144 |
def save_images(model, test_input, tar, tar_hist, filename, args):
    """Write an [input | target | prediction] strip for the first example of a batch.

    The model is called with ``training=True``. For Keras layers such as
    Dropout and BatchNormalization this selects training-mode behaviour
    (BatchNormalization also uses batch statistics and updates its moving
    averages). Images are converted back from YUV when ``args.yuv`` is set,
    then clipped to [-1, 1] before writing.
    """
    model_inputs = [test_input] + [tar_hist[channel] for channel in range(3)]
    prediction = model(model_inputs, training=True)

    panels = [test_input, tar, prediction]
    if args.yuv:
        panels = [yuv2rgb(panel) for panel in panels]

    # Place the first example of each batch side by side along the width axis.
    strip = tf.concat([panel[0] for panel in panels], axis=1)
    im.imwrite(tf.clip_by_value(strip, -1.0, 1.0).numpy(), filename)
154 |
155 |
def shuffle_zip(inp_ds, tar_ds, args, seed_inp=None, seed_tar=None):
    """Shuffle and batch two datasets independently, then zip them element-wise.

    Because each side is shuffled on its own, the resulting (input, target)
    pairs are random pairings rather than aligned examples.
    """
    def _shuffled_batches(dataset, seed):
        shuffled = dataset.shuffle(buffer_size=args.buffer_size, seed=seed)
        return shuffled.batch(batch_size=args.batch_size)

    inputs = _shuffled_batches(inp_ds, seed_inp)
    targets = _shuffled_batches(tar_ds, seed_tar)
    return tf.data.Dataset.zip((inputs, targets))
160 |
--------------------------------------------------------------------------------
/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 |
def downsample(filters, size, apply_batchnorm=True):
    """Return a Conv2D(stride 2, 'same') -> [BatchNorm] -> LeakyReLU block.

    The convolution has no bias and uses N(0, 0.02) kernel initialisation.
    BatchNormalization is included only when ``apply_batchnorm`` is true.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()

    stages = [tf.keras.layers.Conv2D(filters, size, strides=2, padding='same',
                                     kernel_initializer=weight_init, use_bias=False)]
    if apply_batchnorm:
        stages.append(tf.keras.layers.BatchNormalization())
    stages.append(tf.keras.layers.LeakyReLU())

    for stage in stages:
        block.add(stage)
    return block
18 |
19 |
class ReflectionPadding2D(tf.keras.layers.Layer):
    """2D reflection padding for 4D (batch, height, width, channels) tensors.

    Attributes:
        padding: ``(padding_width, padding_height)`` tuple. ``padding_width``
            columns are reflected on the left and right; ``padding_height``
            rows are reflected on the top and bottom.
    """

    def __init__(self, padding=(1, 1), **kwargs):
        self.padding = tuple(padding)
        super(ReflectionPadding2D, self).__init__(**kwargs)

    def compute_output_shape(self, input_shape):
        """Return the padded shape, leaving unknown (None) dimensions as None."""
        padding_width, padding_height = self.padding
        batch, height, width, channels = tf.TensorShape(input_shape).as_list()
        # Height grows by the vertical padding, width by the horizontal padding.
        # This matches `call`; the previous version had the two swapped.
        if height is not None:
            height += 2 * padding_height
        if width is not None:
            width += 2 * padding_width
        return (batch, height, width, channels)

    def call(self, input_tensor, mask=None):
        padding_width, padding_height = self.padding
        return tf.pad(input_tensor, [[0, 0], [padding_height, padding_height], [padding_width, padding_width], [0, 0]],
                      'REFLECT')

    def get_config(self):
        """Include ``padding`` so the layer can be re-created from its config."""
        config = super(ReflectionPadding2D, self).get_config()
        config.update({'padding': self.padding})
        return config
39 |
40 |
# Upsampling block built from bilinear UpSampling2D + reflection-padded 3x3 conv
# (instead of Conv2DTranspose); the conv keeps its bias.
def upsample(filters, size, apply_dropout=False):
    """Return an upsample(x2) -> pad -> Conv2D(3x3) -> BatchNorm -> [Dropout] -> ReLU block.

    Note: ``size`` is not used. The kernel is fixed at 3 to match the 1-pixel
    reflection padding. The parameter is kept for call compatibility.
    """
    weight_init = tf.random_normal_initializer(0., 0.02)
    block = tf.keras.Sequential()

    stages = [
        tf.keras.layers.UpSampling2D(size=(2, 2), interpolation='bilinear'),
        ReflectionPadding2D(padding=(1, 1)),
        tf.keras.layers.Conv2D(filters, 3, strides=1, padding='valid',
                               use_bias=True, kernel_initializer=weight_init),
        tf.keras.layers.BatchNormalization(),
    ]
    if apply_dropout:
        stages.append(tf.keras.layers.Dropout(0.5))
    stages.append(tf.keras.layers.ReLU())

    for stage in stages:
        block.add(stage)
    return block
60 |
61 |
def Generator(args):
    """U-Net generator whose bottleneck is joined with an embedding of 3 histograms.

    Model inputs: ``[image, hist_1, hist_2, hist_3]``. The image has shape
    ``(img_h, img_w, img_c)``; each histogram has shape ``(bin_num,)``. Every
    bin value is used as an index into an Embedding of size
    ``img_h * img_w + 1``, i.e. it is treated as a pixel count. The embedded
    histograms go through Flatten + Dense(2 * bin_num, relu) and are
    concatenated to the deepest encoder activation. The shape comments assume
    a 256x256 input, so that activation is 1x1.

    Output: Conv2DTranspose with ``img_out_c`` channels and tanh activation.

    Note: layer construction and wiring order are kept fixed so that
    checkpoints of existing models still restore.
    """
    image_input = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_c])
    hist_inputs = [tf.keras.layers.Input(shape=[args.bin_num]) for _ in range(3)]

    count_vocab = args.img_h * args.img_w + 1
    hist_embeddings = [
        tf.keras.layers.Embedding(count_vocab, 1, input_length=args.bin_num)(hist_input)
        for hist_input in hist_inputs
    ]

    hist_code = tf.keras.layers.Concatenate()(hist_embeddings)
    hist_code = tf.keras.layers.Flatten()(hist_code)
    hist_code = tf.keras.layers.Dense(args.bin_num * 2, input_shape=(args.bin_num * 3,), activation='relu')(
        hist_code)

    # Encoder: 256 -> 1 spatially (for a 256x256 input); the first block has no batchnorm.
    encoder = [downsample(64, 4, apply_batchnorm=False)]
    encoder += [downsample(n_filters, 4) for n_filters in (128, 256, 512, 512, 512, 512, 512)]

    # Decoder: the first three blocks use dropout; channel counts after skip concat double.
    decoder = [upsample(512, 4, apply_dropout=True) for _ in range(3)]
    decoder += [upsample(n_filters, 4) for n_filters in (512, 256, 128, 64)]

    output_layer = tf.keras.layers.Conv2DTranspose(args.img_out_c, 4,
                                                   strides=2,
                                                   padding='same',
                                                   kernel_initializer=tf.random_normal_initializer(0., 0.02),
                                                   activation='tanh')

    features = image_input
    encoder_outputs = []
    for encode in encoder:
        features = encode(features)
        encoder_outputs.append(features)

    # The deepest activation is the bottleneck, not a skip connection.
    skip_connections = reversed(encoder_outputs[:-1])

    # Lift the histogram code to (batch, 1, 1, 2*bin_num) and join it at the bottleneck.
    hist_code = tf.expand_dims(hist_code, 1)
    hist_code = tf.expand_dims(hist_code, 1)
    features = tf.keras.layers.Concatenate()([features, hist_code])

    for decode, skip in zip(decoder, skip_connections):
        features = decode(features)
        features = tf.keras.layers.Concatenate()([features, skip])

    features = output_layer(features)

    return tf.keras.Model(inputs=[image_input] + hist_inputs, outputs=features)
129 |
130 |
def Discriminator(args, conditional):
    """PatchGAN-style discriminator returning a map of per-patch logits.

    Args:
        args: namespace with ``img_h``, ``img_w``, ``img_c`` and ``img_out_c``.
        conditional: when true the model takes ``[input_image, target_image]``
            and scores the pair. Otherwise it takes a single image.

    Returns:
        A Keras model whose output, per the shape comments, is (bs, 30, 30, 1)
        for 256x256 inputs.
    """
    initializer = tf.random_normal_initializer(0., 0.02)

    inp = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_c], name='input_image')
    if not conditional:
        x = inp
        inputs = inp
    else:
        # The second input receives the real target or the generator output, both of
        # which have img_out_c channels (the generator's last layer emits img_out_c).
        tar = tf.keras.layers.Input(shape=[args.img_h, args.img_w, args.img_out_c], name='target_image')
        inputs = [inp, tar]
        x = tf.keras.layers.concatenate([inp, tar])  # (bs, 256, 256, img_c + img_out_c)

    down1 = downsample(64, 4, False)(x)  # (bs, 128, 128, 64)
    down2 = downsample(128, 4)(down1)  # (bs, 64, 64, 128)
    down3 = downsample(256, 4)(down2)  # (bs, 32, 32, 256)

    zero_pad1 = tf.keras.layers.ZeroPadding2D()(down3)  # (bs, 34, 34, 256)
    conv = tf.keras.layers.Conv2D(512, 4, strides=1,
                                  kernel_initializer=initializer,
                                  use_bias=False)(zero_pad1)  # (bs, 31, 31, 512)

    batchnorm1 = tf.keras.layers.BatchNormalization()(conv)

    leaky_relu = tf.keras.layers.LeakyReLU()(batchnorm1)

    zero_pad2 = tf.keras.layers.ZeroPadding2D()(leaky_relu)  # (bs, 33, 33, 512)

    last = tf.keras.layers.Conv2D(1, 4, strides=1,
                                  kernel_initializer=initializer)(zero_pad2)  # (bs, 30, 30, 1)

    return tf.keras.Model(inputs=inputs, outputs=last)
162 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import os
3 | import datetime
4 | import time
5 |
6 | import pathlib
7 | import random
8 |
9 | import pylib as py
10 | from prepare_dataset import load_image, load_single_image, save_images, shuffle_zip
11 | from models import Generator, Discriminator
12 | from loss import generator_loss, discriminator_loss, generator_loss_color_transfer
13 | import atexit
14 | import config
15 |
# parameters
parser = config.options()
config.define_task_default_params(parser)
args = parser.parse_args()

# output_dir
output_dir = py.join(args.output_dir, args.dataset)
py.mkdir(output_dir)

# save settings
py.args_to_yaml(py.join(output_dir, 'settings.yml'), args)

# Input Pipeline
PATH = args.dataroot + '/' + args.dataset + '/'
# Color transfer is the only task trained on unpaired images.
paired = args.task != 'color_transfer'

if paired:
    # create paired dataset
    if args.dataset == 'summer2winter_yosemite':
        # Match every train*/test* folder of this dataset.
        train_ds = tf.data.Dataset.list_files(PATH + 'train*/*.jpg')
        test_ds = tf.data.Dataset.list_files(PATH + 'test*/*.jpg')
    else:
        train_ds = tf.data.Dataset.list_files(PATH + 'train/*.jpg')
        test_ds = tf.data.Dataset.list_files(PATH + 'val/*.jpg')

    train_ds = train_ds.map(lambda x: load_image(image_file=x, train=True, args=args))
    train_ds = train_ds.shuffle(args.buffer_size)
    train_ds = train_ds.batch(args.batch_size)
    test_ds = test_ds.map(lambda x: load_image(image_file=x, train=False, args=args))
    test_ds = test_ds.batch(args.batch_size)

else:
    # create unpaired dataset
    data_root = pathlib.Path(PATH)
    # Sort before the seeded shuffle. glob() order depends on the filesystem, so
    # without sorting the train/val split is not reproducible across machines.
    # Sub-directories are skipped because tf.io.read_file cannot read them.
    all_image_paths = sorted(str(path) for path in data_root.glob('*') if path.is_file())

    # shuffle paths
    random.seed(1)
    random.shuffle(all_image_paths)

    # split to train and validation
    TRAIN_PER = 0.9
    train_size = round(len(all_image_paths) * TRAIN_PER)
    train_paths = all_image_paths[0:train_size]
    val_paths = all_image_paths[train_size:]

    # create dataset: A supplies input images, B supplies targets with histograms
    train_A = tf.data.Dataset.from_tensor_slices(train_paths).map(
        lambda x: load_single_image(image_file=x, hist=False, train=True, args=args))
    train_B = tf.data.Dataset.from_tensor_slices(train_paths).map(
        lambda x: load_single_image(image_file=x, hist=True, train=True, args=args))
    test_A = tf.data.Dataset.from_tensor_slices(val_paths).map(
        lambda x: load_single_image(image_file=x, hist=False, train=False, args=args))
    test_B = tf.data.Dataset.from_tensor_slices(val_paths).map(
        lambda x: load_single_image(image_file=x, hist=True, train=False, args=args))

# Define the models
generator = Generator(args=args)
discriminator = Discriminator(args=args, conditional=paired)

# Define the Optimizers and Checkpoint-saver
generator_optimizer = tf.keras.optimizers.Adam(args.gen_lr, beta_1=args.gen_beta_1)
discriminator_optimizer = tf.keras.optimizers.Adam(args.dis_lr, beta_1=args.dis_beta_1)

checkpoint_dir = py.join(output_dir, 'checkpoints')
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)
# restore checkpoints
if args.output_pre_dir:
    # restore the latest checkpoint saved by a previous run
    output_pre_dir = py.join(args.output_pre_dir, args.dataset)
    checkpoint_pre_dir = py.join(output_pre_dir, 'checkpoints')
    latest_checkpoint = tf.train.latest_checkpoint(checkpoint_pre_dir)
    if latest_checkpoint is None:
        # latest_checkpoint returns None when nothing is found. Fail loudly instead
        # of crashing on string concatenation or silently training from scratch.
        raise FileNotFoundError('No checkpoint found in %s' % checkpoint_pre_dir)
    checkpoint.restore(latest_checkpoint)
    print('restore checkpoint: ' + latest_checkpoint)
else:
    print('No checkpoint to restore ...')
96 |
97 |
@atexit.register
def exit_handler():
    """Save a checkpoint when the interpreter terminates normally."""
    print('Save checkpoint before exit...')
    checkpoint.save(file_prefix=checkpoint_prefix)


# directory for per-epoch sample images
sample_dir = py.join(output_dir, 'samples_training')
py.mkdir(sample_dir)

# TensorBoard logs, one sub-directory per run (timestamped)
log_dir = py.join(output_dir, 'logs/')
run_stamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = tf.summary.create_file_writer(log_dir + "fit/" + run_stamp)
113 |
114 |
# Train
@tf.function
def train_step(input_image, target, target_hists):
    """Run one generator and one discriminator optimisation step on a batch.

    Args:
        input_image: batch of generator inputs.
        target: batch of target images.
        target_hists: 3-tuple of per-channel target histograms fed to the generator.

    Returns:
        dict with 'gen_total_loss', 'gen_gan_loss', 'gen_mi_loss',
        'gen_hist_loss' and 'disc_loss' for this batch.
    """
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        gen_output = generator([input_image, target_hists[0], target_hists[1], target_hists[2]], training=True)

        if paired:
            # Paired tasks: `target` is the ground truth for `input_image`. The
            # conditional discriminator scores (input, image) pairs, and the
            # generic generator loss is computed against `target`.
            disc_real_output = discriminator([input_image, target], training=True)
            disc_generated_output = discriminator([input_image, gen_output], training=True)
            gen_total_loss, gen_gan_loss, gen_mi_loss, gen_hist_loss = generator_loss(
                disc_generated_output, gen_output, target, args)
        else:
            # Color transfer (the only unpaired task): `target` is an independently
            # shuffled image that only supplies the desired colors. Use the
            # color-transfer loss, which is also given the source `input_image`.
            # (These two loss calls were previously swapped between the branches.)
            disc_real_output = discriminator(input_image, training=True)
            disc_generated_output = discriminator(gen_output, training=True)
            gen_total_loss, gen_gan_loss, gen_mi_loss, gen_hist_loss = generator_loss_color_transfer(
                disc_generated_output, gen_output, input_image, target, args)

        disc_loss = discriminator_loss(disc_real_output, disc_generated_output)

    generator_gradients = gen_tape.gradient(gen_total_loss,
                                            generator.trainable_variables)
    discriminator_gradients = disc_tape.gradient(disc_loss,
                                                 discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(generator_gradients,
                                            generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(discriminator_gradients,
                                                discriminator.trainable_variables))

    return {'gen_total_loss': gen_total_loss,
            'gen_gan_loss': gen_gan_loss,
            'gen_mi_loss': gen_mi_loss,
            'gen_hist_loss': gen_hist_loss,
            'disc_loss': disc_loss}
149 |
150 |
def fit():
    """Train for ``args.epochs`` epochs.

    Each epoch:
    - re-zips the unpaired datasets when not paired;
    - writes a sample image from the first test batch;
    - trains on every batch, printing losses on the first batch and every 100th;
    - logs epoch-mean losses to TensorBoard;
    - checkpoints every ``args.cp_freq`` epochs.

    A final checkpoint is saved after the last epoch.
    """
    # Declared up front: these module-level datasets are rebound below for unpaired data.
    global train_ds, test_ds

    loss_names = ('gen_total_loss', 'gen_gan_loss', 'gen_mi_loss', 'gen_hist_loss', 'disc_loss')

    for epoch in range(args.epochs):
        start = time.time()

        if not paired:
            # Re-shuffle both sides each epoch so the random A/B pairings change.
            train_ds = shuffle_zip(train_A, train_B, args)
            test_ds = shuffle_zip(test_A, test_B, args)

        for example_input, (example_target, example_target_hists) in test_ds.take(1):
            save_images(generator, example_input, example_target, example_target_hists,
                        py.join(sample_dir, 'epoch-%09d.jpg' % epoch), args)

        print("Epoch: ", epoch)

        # Train
        total_loss_dict = dict.fromkeys(loss_names, 0)  # per-epoch running sums
        num_batches = 0

        for n, (input_image, (target, target_hists)) in train_ds.enumerate():
            loss_dict = train_step(input_image, target, target_hists)
            num_batches += 1

            for loss_name in total_loss_dict:
                total_loss_dict[loss_name] += loss_dict[loss_name]

            if (n + 1) % 100 == 0 or n == 0:
                print(n.numpy() + 1, end=' ')
                for loss_name in loss_dict:
                    print(loss_name, ':', loss_dict[loss_name].numpy(), end=' ')
                print()

        # write logs: epoch-mean of each loss. Skipped when the epoch had no batches;
        # previously this raised NameError (or reused a stale count) on an empty dataset.
        if num_batches:
            with summary_writer.as_default():
                for loss_name in total_loss_dict:
                    tf.summary.scalar(loss_name, total_loss_dict[loss_name] / float(num_batches), step=epoch)
            summary_writer.flush()  # make sure TensorBoard sees the epoch's scalars

        # saving (checkpoint) the model every args.cp_freq epochs
        if (epoch + 1) % args.cp_freq == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
            print("save checkpoint ...")

        print('Time taken for epoch {} is {} sec\n'.format(epoch + 1,
                                                           time.time() - start))
    checkpoint.save(file_prefix=checkpoint_prefix)
    print("save checkpoint ...")


fit()
202 |
--------------------------------------------------------------------------------