33 | (?P<release>[0-9]+(?:\.[0-9]+)*) # release segment
34 | (?P<pre> # pre-release
35 | [-_\.]?
36 | (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
37 | [-_\.]?
38 | (?P<pre_n>[0-9]+)?
39 | )?
40 | (?P<post> # post release
41 | (?:-(?P<post_n1>[0-9]+))
42 | |
43 | (?:
44 | [-_\.]?
45 | (?P<post_l>post|rev|r)
46 | [-_\.]?
47 | (?P<post_n2>[0-9]+)?
48 | )
49 | )?
50 | (?P<dev> # dev release
51 | [-_\.]?
52 | (?P<dev_l>dev)
53 | [-_\.]?
54 | (?P<dev_n>[0-9]+)?
55 | )?
56 | )
57 | (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))? # local version
58 | """
59 |
60 |
61 | def _validate_input(
62 | tensors: List[torch.Tensor],
63 | dim_range: Tuple[int, int] = (0, -1),
64 | data_range: Tuple[float, float] = (0., -1.),
65 | # size_dim_range: Tuple[float, float] = (0., -1.),
66 | size_range: Optional[Tuple[int, int]] = None,
67 | ) -> None:
68 | r"""Check that input(-s) satisfies the requirements
69 | Args:
70 | tensors: Tensors to check
71 | dim_range: Allowed number of dimensions. (min, max)
72 | data_range: Allowed range of values in tensors. (min, max)
73 | size_range: Dimensions to include in size comparison. (start_dim, end_dim + 1)
74 | """
75 |
76 | if not __debug__:
77 | return
78 |
79 | x = tensors[0]
80 |
81 | for t in tensors:
82 | assert torch.is_tensor(t), f'Expected torch.Tensor, got {type(t)}'
83 | assert t.device == x.device, f'Expected tensors to be on {x.device}, got {t.device}'
84 |
85 | if size_range is None:
86 | assert t.size() == x.size(), f'Expected tensors with same size, got {t.size()} and {x.size()}'
87 | else:
88 | assert t.size()[size_range[0]: size_range[1]] == x.size()[size_range[0]: size_range[1]], \
89 | f'Expected tensors with same size at given dimensions, got {t.size()} and {x.size()}'
90 |
91 | if dim_range[0] == dim_range[1]:
92 | assert t.dim() == dim_range[0], f'Expected number of dimensions to be {dim_range[0]}, got {t.dim()}'
93 | elif dim_range[0] < dim_range[1]:
94 | assert dim_range[0] <= t.dim() <= dim_range[1], \
95 | f'Expected number of dimensions to be between {dim_range[0]} and {dim_range[1]}, got {t.dim()}'
96 |
97 | if data_range[0] < data_range[1]:
98 | assert data_range[0] <= t.min(), \
99 | f'Expected values to be greater or equal to {data_range[0]}, got {t.min()}'
100 | assert t.max() <= data_range[1], \
101 | f'Expected values to be lower or equal to {data_range[1]}, got {t.max()}'
102 |
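# Usage sketch (editor's note, not part of the original module): with two matching
# 4-D tensors whose values already lie in [0, 1], every assertion above passes silently:
#
#     x = torch.rand(4, 3, 32, 32)
#     y = torch.rand(4, 3, 32, 32)
#     _validate_input([x, y], dim_range=(4, 4), data_range=(0., 1.))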
103 |
104 | def _reduce(x: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
105 | r"""Reduce input in batch dimension if needed.
106 |
107 | Args:
108 | x: Tensor with shape (N, *).
109 | reduction: Specifies the reduction type:
110 | ``'none'`` | ``'mean'`` | ``'sum'``. Default: ``'mean'``
111 | """
112 | if reduction == 'none':
113 | return x
114 | elif reduction == 'mean':
115 | return x.mean(dim=0)
116 | elif reduction == 'sum':
117 | return x.sum(dim=0)
118 | else:
119 | raise ValueError("Unknown reduction. Expected one of {'none', 'mean', 'sum'}")
120 |
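# Example (editor's note, not part of the original module): for a batch of three rows,
# ``_reduce`` collapses only the batch dimension:
#
#     x = torch.arange(6.).reshape(3, 2)
#     _reduce(x, 'mean')   # tensor([2., 3.])
#     _reduce(x, 'sum')    # tensor([6., 9.])
#     _reduce(x, 'none')   # unchanged, shape (3, 2)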
121 |
122 | def _parse_version(version: Union[str, bytes]) -> Tuple[int, ...]:
123 | """ Parses valid Python versions according to Semver and PEP 440 specifications.
124 | For more on Semver check: https://semver.org/
125 | For more on PEP 440 check: https://www.python.org/dev/peps/pep-0440/.
126 |
127 | Implementation is inspired by:
128 | - https://github.com/python-semver
129 | - https://github.com/pypa/packaging
130 |
131 | Args:
132 | version: a version string (or bytes) to parse, e.g. ``'1.13.1+cu116'``.
133 |
134 | Returns:
135 | the numeric release segment as a tuple, e.g. ``(1, 13, 1)``; an empty tuple if parsing fails.
136 | """
137 | if isinstance(version, bytes):
138 | version = version.decode("UTF-8")
139 | elif not isinstance(version, str):  # bytes were already handled above
140 | raise TypeError(f"not expecting type {type(version)}")
141 |
142 | # Semver processing
143 | match = SEMVER_VERSION_PATTERN.match(version)
144 | if match:
145 | matched_version_parts: Dict[str, Any] = match.groupdict()
146 | release = tuple([int(matched_version_parts[k]) for k in ['major', 'minor', 'patch']])
147 | return release
148 |
149 | # PEP 440 processing
150 | regex = re.compile(r"^\s*" + PEP_440_VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)
151 | match = regex.search(version)
152 |
153 | if match is None:
154 | warnings.warn(f"{version} is not a valid SemVer or PEP 440 string")
155 | return tuple()
156 |
157 | release = tuple(int(i) for i in match.group("release").split("."))
158 | return release
159 |
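if __name__ == '__main__':
    # Editor's sketch (not part of the original file): quick sanity checks for _parse_version.
    assert _parse_version('2.0.0') == (2, 0, 0)             # plain SemVer
    assert _parse_version('1.13.1+cu116') == (1, 13, 1)     # PEP 440 with a local version tag
    assert _parse_version(b'0.6.11') == (0, 6, 11)          # bytes input is decoded first
    assert _parse_version('not-a-version') == tuple()       # unparsable -> empty tuple (with a warning)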
--------------------------------------------------------------------------------
/libs/metric/pytorch_fid/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = '0.3.0'
2 |
3 | import torch
4 | from einops import rearrange, repeat
5 |
6 | from .inception import InceptionV3
7 | from .fid_score import calculate_frechet_distance
8 |
9 |
10 | class PytorchFIDFactory(torch.nn.Module):
11 | """
12 |
13 | Args:
14 | channels:
15 | inception_block_idx:
16 |
17 | Examples:
18 | >>> fid_factory = PytorchFIDFactory()
19 | >>> fid_score = fid_factory.score(real_samples=data, fake_samples=all_images)
20 | >>> print(fid_score)
21 | """
22 |
23 | def __init__(self, channels: int = 3, inception_block_idx: int = 2048):
24 | super().__init__()
25 | self.channels = channels
26 |
27 | # load models
28 | assert inception_block_idx in InceptionV3.BLOCK_INDEX_BY_DIM
29 | block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[inception_block_idx]
30 | self.inception_v3 = InceptionV3([block_idx])
31 |
32 | @torch.no_grad()
33 | def calculate_activation_statistics(self, samples):
34 | features = self.inception_v3(samples)[0]
35 | features = rearrange(features, '... 1 1 -> ...')
36 |
37 | mu = torch.mean(features, dim=0).cpu()
38 | sigma = torch.cov(features.t()).cpu()  # torch.cov expects variables as rows: (D, N) -> (D, D) covariance
39 | return mu, sigma
40 |
41 | def score(self, real_samples, fake_samples):
42 | if self.channels == 1:
43 | real_samples, fake_samples = map(
44 | lambda t: repeat(t, 'b 1 ... -> b c ...', c=3), (real_samples, fake_samples)
45 | )
46 |
47 | min_batch = min(real_samples.shape[0], fake_samples.shape[0])
48 | real_samples, fake_samples = map(lambda t: t[:min_batch], (real_samples, fake_samples))
49 |
50 | m1, s1 = self.calculate_activation_statistics(real_samples)
51 | m2, s2 = self.calculate_activation_statistics(fake_samples)
52 |
53 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
54 | return fid_value
55 |
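# Editor's note (hedged): `calculate_frechet_distance` is imported from `.fid_score`, which is
# not shown in this listing. For reference, the quantity it is expected to compute between the
# Gaussian statistics (m1, s1) and (m2, s2) is the standard Fréchet distance, e.g.:
#
#     import numpy as np
#     from scipy import linalg
#
#     def frechet_distance(m1, s1, m2, s2):
#         diff = m1 - m2
#         covmean, _ = linalg.sqrtm(s1.dot(s2), disp=False)
#         if np.iscomplexobj(covmean):
#             covmean = covmean.real
#         return diff.dot(diff) + np.trace(s1) + np.trace(s2) - 2 * np.trace(covmean)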
--------------------------------------------------------------------------------
/libs/modules/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/libs/modules/resizer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .resizer import resize
7 | from . import interp_methods
8 |
9 | __all__ = ['resize', 'interp_methods']
10 |
--------------------------------------------------------------------------------
/libs/modules/resizer/interp_methods.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Description:
3 |
4 | from math import pi
5 |
6 | try:
7 | import torch
8 | except ImportError:
9 | torch = None
10 |
11 | try:
12 | import numpy
13 | except ImportError:
14 | numpy = None
15 |
16 | if numpy is None and torch is None:
17 | raise ImportError("Must have either Numpy or PyTorch but both not found")
18 |
19 |
20 | def set_framework_dependencies(x):
21 | if numpy is not None and type(x) is numpy.ndarray:
22 | to_dtype = lambda a: a
23 | fw = numpy
24 | else:
25 | to_dtype = lambda a: a.to(x.dtype)
26 | fw = torch
27 | eps = fw.finfo(fw.float32).eps
28 | return fw, to_dtype, eps
29 |
30 |
31 | def support_sz(sz):
32 | def wrapper(f):
33 | f.support_sz = sz
34 | return f
35 |
36 | return wrapper
37 |
38 |
39 | @support_sz(4)
40 | def cubic(x):
41 | fw, to_dtype, eps = set_framework_dependencies(x)
42 | absx = fw.abs(x)
43 | absx2 = absx ** 2
44 | absx3 = absx ** 3
45 | return ((1.5 * absx3 - 2.5 * absx2 + 1.) * to_dtype(absx <= 1.) +
46 | (-0.5 * absx3 + 2.5 * absx2 - 4. * absx + 2.) *
47 | to_dtype((1. < absx) & (absx <= 2.)))
48 |
49 |
50 | @support_sz(4)
51 | def lanczos2(x):
52 | fw, to_dtype, eps = set_framework_dependencies(x)
53 | return (((fw.sin(pi * x) * fw.sin(pi * x / 2) + eps) /
54 | ((pi ** 2 * x ** 2 / 2) + eps)) * to_dtype(abs(x) < 2))
55 |
56 |
57 | @support_sz(6)
58 | def lanczos3(x):
59 | fw, to_dtype, eps = set_framework_dependencies(x)
60 | return (((fw.sin(pi * x) * fw.sin(pi * x / 3) + eps) /
61 | ((pi ** 2 * x ** 2 / 3) + eps)) * to_dtype(abs(x) < 3))
62 |
63 |
64 | @support_sz(2)
65 | def linear(x):
66 | fw, to_dtype, eps = set_framework_dependencies(x)
67 | return ((x + 1) * to_dtype((-1 <= x) & (x < 0)) + (1 - x) *
68 | to_dtype((0 <= x) & (x <= 1)))
69 |
70 |
71 | @support_sz(1)
72 | def box(x):
73 | fw, to_dtype, eps = set_framework_dependencies(x)
74 | return to_dtype((-1 <= x) & (x < 0)) + to_dtype((0 <= x) & (x <= 1))
75 |
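if __name__ == '__main__':
    # Editor's sketch (not part of the original file), assuming the torch backend is available:
    # evaluate the cubic kernel on a few sample points.
    x = torch.linspace(-2., 2., steps=9)
    w = cubic(x)
    print(cubic.support_sz)              # 4, attached by the @support_sz decorator
    print(float(w[4]))                   # 1.0 at x = 0
    print(float(w[0]), float(w[-1]))     # 0.0 at |x| = 2: the kernel has compact support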
--------------------------------------------------------------------------------
/libs/modules/vision/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .inception import inception_v3
7 | from .vgg import vgg16, vgg19
8 |
9 | __all__ = [
10 | 'inception_v3',
11 | 'vgg16',
12 | 'vgg19'
13 | ]
14 |
--------------------------------------------------------------------------------
/libs/solver/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/libs/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | from . import lazy
6 |
7 | # __getattr__, __dir__, __all__ = lazy.attach(
8 | # __name__,
9 | # submodules={},
10 | # submod_attrs={
11 | # 'misc': ['identity', 'exists', 'default', 'has_int_squareroot', 'sum_params', 'cycle', 'num_to_groups',
12 | # 'extract', 'normalize', 'unnormalize'],
13 | # 'tqdm': ['tqdm_decorator'],
14 | # 'lazy': ['load']
15 | # }
16 | # )
17 |
18 | from .misc import (
19 | identity,
20 | exists,
21 | default,
22 | has_int_squareroot,
23 | sum_params,
24 | cycle,
25 | num_to_groups,
26 | extract,
27 | normalize,
28 | unnormalize
29 | )
30 | from .tqdm import tqdm_decorator
31 |
--------------------------------------------------------------------------------
/libs/utils/imshow.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import pathlib
7 | from pathlib import Path
8 | from typing import Union, List, Text, BinaryIO
9 |
10 | import matplotlib.pyplot as plt
11 | import torch
12 | import torchvision.transforms as transforms
13 |
14 | __all__ = [
15 | 'show_tensor_image',
16 | 'show_images',
17 | 'simulate_forward_diffusion',
18 | 'save_grid_images_and_labels',
19 | 'save_grid_images_and_captions'
20 | ]
21 |
22 | reverse_transforms = transforms.Compose([
23 | # unnormalizing to [0,1]
24 | transforms.Lambda(lambda t: torch.clamp((t + 1) / 2, min=0.0, max=1.0)),
25 | # Add 0.5 after unnormalizing to [0, 255]
26 | transforms.Lambda(lambda t: torch.clamp(t * 255. + 0.5, min=0, max=255)),
27 | # CHW to HWC
28 | transforms.Lambda(lambda t: t.permute(1, 2, 0)),
29 | # to numpy ndarray, dtype int8
30 | transforms.Lambda(lambda t: t.to('cpu', torch.uint8).numpy()),
31 | # Converts a numpy ndarray of shape H x W x C to a PIL Image
32 | transforms.ToPILImage(),
33 | ])
34 |
35 |
36 | def show_tensor_image(image, title="", f_name=None):
37 | # Take first image of batch
38 | if len(image.shape) == 4:
39 | image = image[0, :, :, :]
40 | plt.imshow(reverse_transforms(image))
41 | plt.title(title)
42 |
43 | if f_name is not None:  # save whenever a path is given
44 | plt.savefig(f_name)
45 | plt.close()
46 |
47 |
48 | def show_images(dataset, num_samples=20, cols=4):
49 | """ Plots some samples from the dataset """
50 | plt.figure(figsize=(15, 15))
51 | for i, img in enumerate(dataset):
52 | if i == num_samples:
53 | break
54 | plt.subplot(num_samples // cols + 1, cols, i + 1)  # subplot expects integer grid sizes
55 | plt.imshow(img[0])
56 | plt.close()
57 |
58 |
59 | def simulate_forward_diffusion(
60 | image,
61 | dataloader: torch.utils.data.DataLoader,
62 | T: int,
63 | ddpm: torch.nn.Module,
64 | num_images: int,
65 | ):
66 | """ Simulate forward diffusion
67 | Args:
68 | image: add noise to this image
69 | image = next(iter(dataloader))[0]
70 | dataloader:
71 | T:
72 | ddpm:
73 | num_images:
74 | """
75 | plt.figure(figsize=(15, 15))
76 | plt.axis('off')
77 |
78 | stepsize = int(T / num_images)
79 |
80 | for idx in range(0, T, stepsize):
81 | t = torch.Tensor([idx]).type(torch.int64)
82 | plt.subplot(1, num_images + 1, idx // stepsize + 1)  # subplot expects an integer index
83 | image, noise = ddpm.q_sample(image, t)
84 | show_tensor_image(image)
85 |
86 | plt.savefig(f"forward-step-{stepsize}.png")
87 | plt.close()
88 |
89 |
90 | @torch.no_grad()
91 | def save_grid_images_and_labels(
92 | images: Union[torch.Tensor, List[torch.Tensor]],
93 | probs: Union[torch.Tensor, List[torch.Tensor]],
94 | labels: Union[torch.Tensor, List[torch.Tensor]],
95 | classes: Union[torch.Tensor, List[torch.Tensor]],
96 | fp: Union[Text, pathlib.Path, BinaryIO],
97 | nrow: int = 4,
98 | normalize: bool = True
99 | ) -> None:
100 | """Save a given Tensor into an image file.
101 | """
102 | num_images = len(images)
103 | num_rows, num_cols = get_subplot_shape(num_images, nrow)
104 |
105 | fig = plt.figure(figsize=(25, 20))
106 |
107 | for i in range(num_images):
108 | ax = fig.add_subplot(num_rows, num_cols, i + 1)
109 |
110 | image, true_label, prob = images[i], labels[i], probs[i]
111 |
112 | true_prob = prob[true_label]
113 | pred_prob, pred_label = torch.max(prob, dim=0)
114 | true_class = classes[true_label]
115 |
116 | pred_class = classes[pred_label]
117 |
118 | if normalize:
119 | image = reverse_transforms(image)
120 |
121 | ax.imshow(image)
122 | title = f'true label: {true_class} ({true_prob:.3f})\n ' \
123 | f'pred label: {pred_class} ({pred_prob:.3f})'
124 | ax.set_title(title, fontsize=20)
125 | ax.axis('off')
126 |
127 | fig.subplots_adjust(hspace=0.3)
128 |
129 | plt.savefig(fp)
130 | plt.close()
131 |
132 |
133 | @torch.no_grad()
134 | def save_grid_images_and_captions(
135 | images: Union[torch.Tensor, List[torch.Tensor]],
136 | captions: List,
137 | fp: Union[Text, pathlib.Path, BinaryIO],
138 | nrow: int = 4,
139 | normalize: bool = True
140 | ) -> None:
141 | """
142 | Save a grid of images and their captions into an image file.
143 |
144 | Args:
145 | images (Union[torch.Tensor, List[torch.Tensor]]): A list of images to display.
146 | captions (List): A list of captions for each image.
147 | fp (Union[Text, pathlib.Path, BinaryIO]): The file path to save the image to.
148 | nrow (int, optional): The number of images to display in each row. Defaults to 4.
149 | normalize (bool, optional): Whether to map the image back to a displayable range via ``reverse_transforms``. Defaults to True.
150 | """
151 | num_images = len(images)
152 | num_rows, num_cols = get_subplot_shape(num_images, nrow)
153 |
154 | fig = plt.figure(figsize=(25, 20))
155 |
156 | for i in range(num_images):
157 | ax = fig.add_subplot(num_rows, num_cols, i + 1)
158 | image, caption = images[i], captions[i]
159 |
160 | if normalize:
161 | image = reverse_transforms(image)
162 |
163 | ax.imshow(image)
164 | title = f'"{caption}"' if num_images > 1 else f'"{captions}"'
165 | title = insert_newline(title)
166 | ax.set_title(title, fontsize=20)
167 | ax.axis('off')
168 |
169 | fig.subplots_adjust(hspace=0.3)
170 |
171 | plt.savefig(fp)
172 | plt.close()
173 |
174 |
175 | def get_subplot_shape(num_images, nrow):
176 | """
177 | Calculate the number of rows and columns required to display images in a grid.
178 |
179 | Args:
180 | num_images (int): The total number of images to display.
181 | nrow (int): The maximum number of images to display in each row.
182 |
183 | Returns:
184 | Tuple[int, int]: The number of rows and columns required to display images in a grid.
185 | """
186 | num_cols = min(num_images, nrow)
187 | num_rows = (num_images + num_cols - 1) // num_cols
188 | return num_rows, num_cols
189 |
190 |
191 | def insert_newline(string, point=9):
192 | # split by blank
193 | words = string.split()
194 | if len(words) <= point:
195 | return string
196 |
197 | word_chunks = [words[i:i + point] for i in range(0, len(words), point)]
198 | new_string = "\n".join(" ".join(chunk) for chunk in word_chunks)
199 | return new_string
200 |
--------------------------------------------------------------------------------
/libs/utils/lazy.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import importlib
7 | import importlib.util
8 | import os
9 | import sys
10 |
11 |
12 | def attach(package_name, submodules=None, submod_attrs=None):
13 | """Attach lazily loaded submodules, functions, or other attributes.
14 |
15 | Typically, modules import submodules and attributes as follows::
16 |
17 | import mysubmodule
18 | import anothersubmodule
19 |
20 | from .foo import someattr
21 |
22 | The idea is to replace a package's `__getattr__`, `__dir__`, and
23 | `__all__`, such that all imports work exactly the way they did
24 | before, except that they are only imported when used.
25 |
26 | The typical way to call this function, replacing the above imports, is::
27 |
28 | __getattr__, __lazy_dir__, __all__ = lazy.attach(
29 | __name__,
30 | ['mysubmodule', 'anothersubmodule'],
31 | {'foo': 'someattr'}
32 | )
33 |
34 | This functionality requires Python 3.7 or higher.
35 |
36 | Parameters
37 | ----------
38 | package_name : str
39 | Typically use ``__name__``.
40 | submodules : set
41 | List of submodules to attach.
42 | submod_attrs : dict
43 | Dictionary of submodule -> list of attributes / functions.
44 | These attributes are imported as they are used.
45 |
46 | Returns
47 | -------
48 | __getattr__, __dir__, __all__
49 |
50 | """
51 | if submod_attrs is None:
52 | submod_attrs = {}
53 |
54 | if submodules is None:
55 | submodules = set()
56 | else:
57 | submodules = set(submodules)
58 |
59 | attr_to_modules = {
60 | attr: mod for mod, attrs in submod_attrs.items() for attr in attrs
61 | }
62 |
63 | __all__ = list(submodules | attr_to_modules.keys())
64 |
65 | def __getattr__(name):
66 | if name in submodules:
67 | return importlib.import_module(f'{package_name}.{name}')
68 | elif name in attr_to_modules:
69 | submod = importlib.import_module(
70 | f'{package_name}.{attr_to_modules[name]}'
71 | )
72 | return getattr(submod, name)
73 | else:
74 | raise AttributeError(f'No {package_name} attribute {name}')
75 |
76 | def __dir__():
77 | return __all__
78 |
79 | eager_import = os.environ.get('EAGER_IMPORT', '')
80 | if eager_import not in ['', '0', 'false']:
81 | for attr in set(attr_to_modules.keys()) | submodules:
82 | __getattr__(attr)
83 |
84 | return __getattr__, __dir__, list(__all__)
85 |
86 |
87 | def load(fullname):
88 | """Return a lazily imported proxy for a module.
89 |
90 | We often see the following pattern::
91 |
92 | def myfunc():
93 | import scipy as sp
94 | sp.argmin(...)
95 | ....
96 |
97 | This is to prevent a module, in this case `scipy`, from being
98 | imported at function definition time, since that can be slow.
99 |
100 | This function provides a proxy module that, upon access, imports
101 | the actual module. So the idiom equivalent to the above example is::
102 |
103 | sp = lazy.load("scipy")
104 |
105 | def myfunc():
106 | sp.argmin(...)
107 | ....
108 |
109 | The initial import time is fast because the actual import is delayed
110 | until the first attribute is requested. The overall import time may
111 | decrease as well for users that don't make use of large portions
112 | of the library.
113 |
114 | Parameters
115 | ----------
116 | fullname : str
117 | The full name of the module or submodule to import. For example::
118 |
119 | sp = lazy.load('scipy') # import scipy as sp
120 | spla = lazy.load('scipy.linalg') # import scipy.linalg as spla
121 |
122 | Returns
123 | -------
124 | pm : importlib.util._LazyModule
125 | Proxy module. Can be used like any regularly imported module.
126 | Actual loading of the module occurs upon first attribute request.
127 |
128 | """
129 | try:
130 | return sys.modules[fullname]
131 | except KeyError:
132 | pass
133 |
134 | spec = importlib.util.find_spec(fullname)
135 | if spec is None:
136 | raise ModuleNotFoundError(f"No module named '{fullname}'")
137 |
138 | module = importlib.util.module_from_spec(spec)
139 | sys.modules[fullname] = module
140 |
141 | loader = importlib.util.LazyLoader(spec.loader)
142 | loader.exec_module(module)
143 |
144 | return module
145 |
--------------------------------------------------------------------------------
/libs/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import os
7 | import sys
8 | import errno
9 |
10 |
11 | def get_logger(logs_dir: str, file_name: str = "log.txt"):
12 | logger = PrintLogger(os.path.join(logs_dir, file_name))
13 | sys.stdout = logger # record all python print
14 | return logger
15 |
16 |
17 | class PrintLogger(object):
18 |
19 | def __init__(self, fpath=None):
20 | """
21 | python standard input/output records
22 | """
23 | self.console = sys.stdout
24 | self.file = None
25 | if fpath is not None:
26 | mkdir_if_missing(os.path.dirname(fpath))
27 | self.file = open(fpath, 'w')
28 |
29 | def __del__(self):
30 | self.close()
31 |
32 | def __enter__(self):
33 | pass
34 |
35 | def __exit__(self, *args):
36 | self.close()
37 |
38 | def write(self, msg):
39 | self.console.write(msg)
40 | if self.file is not None:
41 | self.file.write(msg)
42 |
43 | def write_in(self, msg):
44 | """write in log only, not console"""
45 | if self.file is not None:
46 | self.file.write(msg)
47 |
48 | def flush(self):
49 | self.console.flush()
50 | if self.file is not None:
51 | self.file.flush()
52 | os.fsync(self.file.fileno())
53 |
54 | def close(self):
55 | self.console.close()
56 | if self.file is not None:
57 | self.file.close()
58 |
59 |
60 | def mkdir_if_missing(dir_path):
61 | try:
62 | os.makedirs(dir_path)
63 | except OSError as e:
64 | if e.errno != errno.EEXIST:
65 | raise
66 |
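if __name__ == '__main__':
    # Editor's sketch (not part of the original file): after get_logger(), every print()
    # is mirrored to ./logs/log.txt because sys.stdout is replaced by the PrintLogger.
    logger = get_logger('./logs')
    print('hello')                      # goes to the console and to ./logs/log.txt
    logger.write_in('file only\n')      # written to the log file only, skips the console
    logger.flush()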
--------------------------------------------------------------------------------
/libs/utils/meter.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from enum import Enum
7 |
8 | import torch
9 | import torch.distributed as dist
10 |
11 |
12 | class Summary(Enum):
13 | NONE = 0
14 | AVERAGE = 1
15 | SUM = 2
16 | COUNT = 3
17 |
18 |
19 | class AverageMeter(object):
20 | """Computes and stores the average and current value"""
21 |
22 | def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
23 | self.name = name
24 | self.fmt = fmt
25 | self.summary_type = summary_type
26 | self.reset()
27 |
28 | def reset(self):
29 | self.val = 0
30 | self.avg = 0
31 | self.sum = 0
32 | self.count = 0
33 |
34 | def update(self, val, n=1):
35 | self.val = val
36 | self.sum += val * n
37 | self.count += n
38 | self.avg = self.sum / self.count
39 |
40 | def all_reduce(self):
41 | if torch.cuda.is_available():
42 | device = torch.device("cuda")
43 | elif torch.backends.mps.is_available():
44 | device = torch.device("mps")
45 | else:
46 | device = torch.device("cpu")
47 |
48 | total = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
49 | dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
50 | self.sum, self.count = total.tolist()
51 | self.avg = self.sum / self.count
52 |
53 | def __str__(self):
54 | fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
55 | return fmtstr.format(**self.__dict__)
56 |
57 | def summary(self):
58 | fmtstr = ''
59 | if self.summary_type is Summary.NONE:
60 | fmtstr = ''
61 | elif self.summary_type is Summary.AVERAGE:
62 | fmtstr = '{name} {avg:.3f}'
63 | elif self.summary_type is Summary.SUM:
64 | fmtstr = '{name} {sum:.3f}'
65 | elif self.summary_type is Summary.COUNT:
66 | fmtstr = '{name} {count:.3f}'
67 | else:
68 | raise ValueError('invalid summary type %r' % self.summary_type)
69 |
70 | return fmtstr.format(**self.__dict__)
71 |
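if __name__ == '__main__':
    # Editor's sketch (not part of the original file): track a running loss over two batches of 32 samples.
    loss_meter = AverageMeter('loss', fmt=':.4f', summary_type=Summary.AVERAGE)
    loss_meter.update(0.9, n=32)
    loss_meter.update(0.7, n=32)
    print(loss_meter)            # "loss 0.7000 (0.8000)"  -> latest value (running average)
    print(loss_meter.summary())  # "loss 0.800"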
--------------------------------------------------------------------------------
/libs/utils/misc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import math
7 |
8 | import torch
9 |
10 |
11 | def identity(t, *args, **kwargs):
12 | """return t"""
13 | return t
14 |
15 |
16 | def exists(x):
17 | """whether x is None or not"""
18 | return x is not None
19 |
20 |
21 | def default(val, d):
22 | """ternary judgment: val != None ? val : d"""
23 | if exists(val):
24 | return val
25 | return d() if callable(d) else d
26 |
27 |
28 | def has_int_squareroot(num):
29 | return (math.sqrt(num) ** 2) == num
30 |
31 |
32 | def num_to_groups(num, divisor):
33 | groups = num // divisor
34 | remainder = num % divisor
35 | arr = [divisor] * groups
36 | if remainder > 0:
37 | arr.append(remainder)
38 | return arr
39 |
40 |
41 | #################################################################################
42 | # Model Utils #
43 | #################################################################################
44 |
45 | def sum_params(model: torch.nn.Module, eps: float = 1e6):  # parameter count scaled by 1/eps (1e6 -> millions)
46 | return sum(p.numel() for p in model.parameters()) / eps
47 |
48 |
49 | #################################################################################
50 | # DataLoader Utils #
51 | #################################################################################
52 |
53 | def cycle(dl):
54 | while True:
55 | for data in dl:
56 | yield data
57 |
58 |
59 | #################################################################################
60 | # Diffusion Model Utils #
61 | #################################################################################
62 |
63 | def extract(a, t, x_shape):
64 | b, *_ = t.shape
65 | assert x_shape[0] == b
66 | out = a.gather(-1, t) # 1-D tensor, shape: (b,)
67 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) # shape: [b, 1, 1, 1]
68 |
69 |
70 | def unnormalize(x):
71 | """unnormalize_to_zero_to_one"""
72 | x = (x + 1) * 0.5 # Map the data interval to [0, 1]
73 | return torch.clamp(x, 0.0, 1.0)
74 |
75 |
76 | def normalize(x):
77 | """normalize_to_neg_one_to_one"""
78 | x = x * 2 - 1 # Map the data interval to [-1, 1]
79 | return torch.clamp(x, -1.0, 1.0)
80 |
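if __name__ == '__main__':
    # Editor's sketch (not part of the original file): gather per-timestep coefficients
    # and broadcast them over a batch of images, as a diffusion model would.
    alphas = torch.linspace(0.99, 0.01, steps=10)   # one coefficient per timestep
    t = torch.tensor([0, 4, 9])                     # a timestep for each batch element
    x = torch.randn(3, 3, 32, 32)
    coef = extract(alphas, t, x.shape)              # shape (3, 1, 1, 1), broadcastable as coef * x
    assert coef.shape == (3, 1, 1, 1)
    u = torch.rand(2, 3, 8, 8)
    assert torch.allclose(unnormalize(normalize(u)), u)   # the two maps are inverses on [0, 1]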
--------------------------------------------------------------------------------
/libs/utils/model_summary.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import sys
7 | from collections import OrderedDict
8 |
9 | import numpy as np
10 | import torch
11 |
12 | layer_modules = (torch.nn.MultiheadAttention,)
13 |
14 |
15 | def summary(model, input_data=None, input_data_args=None, input_shape=None, input_dtype=torch.FloatTensor,
16 | batch_size=-1,
17 | *args, **kwargs):
18 | """
19 | provide example input data in at least one of the ways below:
20 | ① input_data ---> model.forward(input_data)
21 | ② input_data_args ---> model.forward(*input_data_args)
22 | ③ input_shape & input_dtype ---> model.forward(*[torch.rand(2, *size).type(input_dtype) for size in input_shape])
23 | """
24 |
25 | hooks = []
26 | summary = OrderedDict()
27 |
28 | def register_hook(module):
29 | def hook(module, inputs, outputs):
30 |
31 | class_name = str(module.__class__).split(".")[-1].split("'")[0]
32 | module_idx = len(summary)
33 |
34 | key = "%s-%i" % (class_name, module_idx + 1)
35 |
36 | info = OrderedDict()
37 | info["id"] = id(module)
38 | if isinstance(outputs, (list, tuple)):
39 | try:
40 | info["out"] = [batch_size] + list(outputs[0].size())[1:]
41 | except AttributeError:
42 | # pack_padded_seq and pad_packed_seq store feature into data attribute
43 | info["out"] = [batch_size] + list(outputs[0].data.size())[1:]
44 | else:
45 | info["out"] = [batch_size] + list(outputs.size())[1:]
46 |
47 | info["params_nt"], info["params"] = 0, 0
48 | for name, param in module.named_parameters():
49 | info["params"] += param.nelement() * param.requires_grad
50 | info["params_nt"] += param.nelement() * (not param.requires_grad)
51 |
52 | summary[key] = info
53 |
54 | # ignore Sequential and ModuleList and other containers
55 | if isinstance(module, layer_modules) or not module._modules:
56 | hooks.append(module.register_forward_hook(hook))
57 |
58 | model.apply(register_hook)
59 |
60 | # multiple inputs to the network
61 | if isinstance(input_shape, tuple):
62 | input_shape = [input_shape]
63 |
64 | if input_data is not None:
65 | x = [input_data]
66 | elif input_shape is not None:
67 | # batch_size of 2 for batchnorm
68 | x = [torch.rand(2, *size).type(input_dtype) for size in input_shape]
69 | elif input_data_args is not None:
70 | x = input_data_args
71 | else:
72 | x = []
73 | try:
74 | with torch.no_grad():
75 | model(*x) if not (kwargs or args) else model(*x, *args, **kwargs)
76 | except Exception:
77 | # This can be useful for debugging
78 | print("Failed to run summary...")
79 | raise
80 | finally:
81 | for hook in hooks:
82 | hook.remove()
83 | summary_logs = []
84 | summary_logs.append("--------------------------------------------------------------------------")
85 | line_new = "{:<30} {:>20} {:>20}".format("Layer (type)", "Output Shape", "Param #")
86 | summary_logs.append(line_new)
87 | summary_logs.append("==========================================================================")
88 | total_params = 0
89 | total_output = 0
90 | trainable_params = 0
91 | for layer in summary:
92 | # layer, output_shape, params
93 | line_new = "{:<30} {:>20} {:>20}".format(
94 | layer,
95 | str(summary[layer]["out"]),
96 | "{0:,}".format(summary[layer]["params"] + summary[layer]["params_nt"])
97 | )
98 | total_params += (summary[layer]["params"] + summary[layer]["params_nt"])
99 | total_output += np.prod(summary[layer]["out"])
100 | trainable_params += summary[layer]["params"]
101 | summary_logs.append(line_new)
102 |
103 | # assume 4 bytes/number
104 | if input_data is not None:
105 | total_input_size = abs(sys.getsizeof(input_data) / (1024 ** 2.))
106 | elif input_shape is not None:
107 | total_input_size = abs(np.prod(input_shape) * batch_size * 4. / (1024 ** 2.))
108 | else:
109 | total_input_size = 0.0
110 | total_output_size = abs(2. * total_output * 4. / (1024 ** 2.)) # x2 for gradients
111 | total_params_size = abs(total_params * 4. / (1024 ** 2.))
112 | total_size = total_params_size + total_output_size + total_input_size
113 |
114 | summary_logs.append("==========================================================================")
115 | summary_logs.append("Total params: {0:,}".format(total_params))
116 | summary_logs.append("Trainable params: {0:,}".format(trainable_params))
117 | summary_logs.append("Non-trainable params: {0:,}".format(total_params - trainable_params))
118 | summary_logs.append("--------------------------------------------------------------------------")
119 | summary_logs.append("Input size (MB): %0.6f" % total_input_size)
120 | summary_logs.append("Forward/backward pass size (MB): %0.6f" % total_output_size)
121 | summary_logs.append("Params size (MB): %0.6f" % total_params_size)
122 | summary_logs.append("Estimated Total Size (MB): %0.6f" % total_size)
123 | summary_logs.append("--------------------------------------------------------------------------")
124 |
125 | summary_info = "\n".join(summary_logs)
126 |
127 | print(summary_info)
128 | return summary_info
129 |
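if __name__ == '__main__':
    # Editor's sketch (not part of the original file): summarize a small MLP by shape;
    # a random (2, 128) input is built internally from input_shape.
    mlp = torch.nn.Sequential(
        torch.nn.Linear(128, 64),
        torch.nn.ReLU(),
        torch.nn.Linear(64, 10),
    )
    summary(mlp, input_shape=(128,))
    # equivalent call with explicit data:
    # summary(mlp, input_data=torch.randn(2, 128))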
--------------------------------------------------------------------------------
/libs/utils/tqdm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from typing import Callable
7 | from tqdm.auto import tqdm
8 |
9 |
10 | def tqdm_decorator(func: Callable):
11 | """A decorator function called tqdm_decorator that takes a function as an argument and
12 | returns a new function that wraps the input function with a tqdm progress bar.
13 |
14 | Noting: **The input function is assumed to have an object self as its first argument**, which contains a step attribute,
15 | an args attribute with a train_num_steps attribute, and an accelerator attribute with an is_main_process attribute.
16 |
17 | Args:
18 | func: tqdm_decorator
19 |
20 | Returns:
21 | a new function that wraps the input function with a tqdm progress bar.
22 | """
23 |
24 | def wrapper(*args, **kwargs):
25 | with tqdm(initial=args[0].step,
26 | total=args[0].args.train_num_steps,
27 | disable=not args[0].accelerator.is_main_process) as pbar:
28 | func(*args, **kwargs, pbar=pbar)
29 |
30 | return wrapper
31 |
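if __name__ == '__main__':
    # Editor's sketch (not part of the original file) of the object layout tqdm_decorator
    # expects: a `self` carrying `step`, `args.train_num_steps` and `accelerator.is_main_process`.
    # ToyTrainer and SimpleNamespace are illustrative stand-ins, not names from this repo.
    from types import SimpleNamespace

    class ToyTrainer:
        def __init__(self):
            self.step = 0
            self.args = SimpleNamespace(train_num_steps=10)
            self.accelerator = SimpleNamespace(is_main_process=True)

        @tqdm_decorator
        def train(self, pbar=None):
            while self.step < self.args.train_num_steps:
                self.step += 1
                pbar.update(1)

    ToyTrainer().train()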
--------------------------------------------------------------------------------
/pipelines/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/pipelines/inversion/ILVR.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from argparse import Namespace
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torchvision import utils as tv_util
11 | from tqdm import tqdm
12 |
13 | from libs.engine import ModelState
14 | from libs.utils import cycle
15 | from sketch_nn.augment.resizer import Resizer
16 |
17 |
18 | class ILVRPipeline(ModelState):
19 |
20 | def __init__(
21 | self,
22 | args: Namespace,
23 | eps_model: nn.Module,
24 | eps_model_path: str,
25 | diffusion: nn.Module,
26 | dataloader: torch.utils.data.DataLoader
27 | ):
28 | super().__init__(args)
29 | self.args = args
30 |
31 | # set log path
32 | self.results_path = self.results_path.joinpath(f"{args.task}-sample")
33 | self.results_path.mkdir(exist_ok=True)
34 |
35 | self.diffusion = diffusion
36 |
37 | # create eps_model
38 | self.print(f"loading SDE from `{eps_model_path}` ....")
39 | self.eps_model = self.load_ckpt_model_only(eps_model, eps_model_path)
40 | if args.model.use_fp16:
41 | self.eps_model.convert_to_fp16()
42 | self.eps_model.eval()
43 | self.print(f"-> eps_model Params: {(sum(p.numel() for p in self.eps_model.parameters()) / 1e6):.3f}M")
44 |
45 | self.eps_model, self.dataloader = self.accelerator.prepare(self.eps_model, dataloader)
46 | self.dataloader = cycle(self.dataloader)
47 |
48 | def sample(self):
49 | device = self.accelerator.device
50 | accelerator = self.accelerator
51 |
52 | sample = next(iter(self.dataloader))
53 | batch_size = sample["image"].shape[0] # get real batch_size
54 | image_size = self.args.image_size
55 |
56 | down_N = self.args.down_N
57 | shape = (batch_size, 3, image_size, image_size)
58 | shape_d = (
59 | batch_size, 3, int(image_size / down_N), int(image_size / down_N)
60 | )
61 | down = Resizer(shape, 1 / down_N).to(device)
62 | up = Resizer(shape_d, down_N).to(device)
63 | resizers = (down, up)
64 |
65 | extra_kwargs = {}
66 | model_kwargs = {}
67 |
68 | i = 0
69 | with tqdm(initial=i, total=self.args.total_samples, disable=not accelerator.is_main_process) as pbar:
70 | while i < self.args.total_samples:  # i is advanced by batch_size below
71 | sample = next(self.dataloader)
72 | ref_img, name = sample["image"], sample["fname"]
73 | extra_kwargs["ref_img"] = ref_img
74 |
75 | sample = self.diffusion.p_sample_loop(
76 | self.eps_model,
77 | (batch_size, 3, image_size, image_size),
78 | clip_denoised=self.args.diffusion.clip_denoised,
79 | model_kwargs=model_kwargs,
80 | resizers=resizers,
81 | range_t=self.args.range_t,
82 | extra_kwargs=extra_kwargs
83 | )
84 |
85 | if self.accelerator.is_main_process:
86 | sample = self.accelerator.gather(sample)
87 | sample = (sample + 1) / 2
88 |
89 | if self.args.get_final_results:
90 | for b in range(sample.shape[0]):
91 | name_ = name[b].split(".")[0] # Remove file suffixes
92 | save_path = self.results_path / f"{int(self.step + b)}-{name_}.png"
93 | tv_util.save_image(sample[b], save_path)
94 | else:
95 | for b in range(sample.shape[0]):
96 | save_path = self.results_path.joinpath(
97 | f"i-{i}-b-{b}-t-down_N-{down_N}-rt-{self.args.range_t}.png"
98 | )
99 | # (x0, sampled)
100 | save_grids = torch.cat(
101 | [ref_img[b].unsqueeze_(0), sample[b].unsqueeze_(0)],
102 | dim=0
103 | )
104 | tv_util.save_image(save_grids.float(), save_path, nrow=sample.shape[0])
105 |
106 | i += batch_size
107 | pbar.update(1)
108 | self.close()
109 |
--------------------------------------------------------------------------------
/pipelines/inversion/ILVR_mixup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from argparse import Namespace
7 |
8 | import torch
9 | import torch.nn as nn
10 | from torchvision import utils as tv_util
11 | from tqdm import tqdm
12 |
13 | from libs.engine import ModelState
14 | from libs.utils import cycle
15 | from sketch_nn.augment.resizer import Resizer
16 |
17 | class ILVRMixupPipeline(ModelState):
18 |
19 | def __init__(
20 | self,
21 | args: Namespace,
22 | eps_model: nn.Module,
23 | eps_model_path: str,
24 | diffusion: nn.Module,
25 | src_dataloader: torch.utils.data.DataLoader,
26 | ref_dataloader: torch.utils.data.DataLoader
27 | ):
28 | super().__init__(args)
29 | self.args = args
30 |
31 | # set log path
32 | self.results_path = self.results_path.joinpath(f"{args.task}-sample")
33 | self.results_path.mkdir(exist_ok=True)
34 |
35 | self.diffusion = diffusion
36 |
37 | # create eps_model
38 | self.print(f"loading SDE from `{eps_model_path}` ....")
39 | self.eps_model = self.load_ckpt_model_only(eps_model, eps_model_path)
40 | if args.model.use_fp16:
41 | self.eps_model.convert_to_fp16()
42 | self.eps_model.eval()
43 | self.print(f"-> eps_model Params: {(sum(p.numel() for p in self.eps_model.parameters()) / 1e6):.3f}M")
44 |
45 | self.eps_model = self.accelerator.prepare(self.eps_model)
46 | self.src_dataloader, self.ref_dataloader = self.accelerator.prepare(src_dataloader, ref_dataloader)
47 | self.src_dataloader = cycle(self.src_dataloader)
48 | self.ref_dataloader = cycle(self.ref_dataloader)
49 |
50 | def sample(self):
51 | device = self.accelerator.device
52 |
53 | sample = next(iter(self.src_dataloader))
54 | batch_size = sample["image"].shape[0] # get real batch_size
55 | image_size = self.args.image_size
56 |
57 | down_N = self.args.down_N
58 | shape = (batch_size, 3, image_size, image_size)
59 | shape_d = (
60 | batch_size, 3, int(image_size / down_N), int(image_size / down_N)
61 | )
62 | down = Resizer(shape, 1 / down_N).to(device)
63 | up = Resizer(shape_d, down_N).to(device)
64 | resizers = (down, up)
65 |
66 | model_kwargs = {}
67 | i = 0
68 |
69 | with tqdm(initial=i, total=self.args.total_samples, disable=not self.accelerator.is_local_main_process) as pbar:
70 | while i < self.args.total_samples:
71 | src_sample = next(self.src_dataloader)
72 | src_input, src_name = src_sample["image"], src_sample["fname"]
73 | ref_sample = next(self.ref_dataloader)
74 | ref_input, ref_name = ref_sample["image"], ref_sample["fname"]
75 |
76 | extra_kwargs = {
77 | "src_input": src_input,
78 | "ref_input": ref_input,
79 | "fuse_scale": self.args.fuse_scale
80 | }
81 |
82 | sample = self.diffusion.p_sample_loop(
83 | self.eps_model,
84 | (batch_size, 3, image_size, image_size),
85 | clip_denoised=self.args.diffusion.clip_denoised,
86 | model_kwargs=model_kwargs,
87 | resizers=resizers,
88 | range_t=self.args.range_t,
89 | extra_kwargs=extra_kwargs
90 | )
91 |
92 | if self.accelerator.is_main_process:
93 | sample = self.accelerator.gather(sample)
94 | sample = (sample + 1) / 2
95 |
96 | if self.args.get_final_results:
97 | for b in range(sample.shape[0]):
98 | s_name_ = src_name[b].split(".")[0] # Remove file suffixes
99 | r_name_ = ref_name[b].split(".")[0] # Remove file suffixes
100 | save_path = self.results_path / f"{int(self.step + b)}-{s_name_}_to_{r_name_}.png"
101 | tv_util.save_image(sample[b], save_path)
102 | else:
103 | for b in range(sample.shape[0]):
104 | save_path = self.results_path.joinpath(
105 | f"i-{i}-b-{b}-t-down_N-{down_N}-rt-{self.args.range_t}.png"
106 | )
107 | # (x0, sampled)
108 | save_grids = torch.cat(
109 | [ref_input[b].unsqueeze_(0), sample[b].unsqueeze_(0)],
110 | dim=0
111 | )
112 | tv_util.save_image(save_grids.float(), save_path, nrow=sample.shape[0])
113 |
114 | i += batch_size
115 | pbar.update(1)
116 |
117 | self.close()
118 |
--------------------------------------------------------------------------------
/pipelines/inversion/SDEdit_iter_pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from datetime import datetime
7 |
8 | import torch
9 | from torchvision import utils as tv_util
10 | from tqdm import tqdm
11 |
12 | from libs.engine import ModelState
13 | from libs.utils import cycle
14 | from sketch_nn.augment.resizer import Resizer
15 | from sketch_nn.methods.inversion.SDEdit_iter import IterativeSDEdit
16 |
17 |
18 | class IterativeSDEditPipeline(ModelState):
19 |
20 | def __init__(self, args, sde_model, sde_path, src_dataloader, ref_dataloader):
21 | super().__init__(args)
22 | self.args = args
23 |
24 | self.print(f"loading SDE from `{sde_path}` ....")
25 | self.sde_model = self.load_ckpt_model_only(sde_model, sde_path)
26 | self.print(f"-> SDE Params: {(sum(p.numel() for p in sde_model.parameters()) / 1e6):.3f}M")
27 |
28 | self.results_path = self.results_path.joinpath(f"{args.dataset}-{args.task}-sample-seed-{args.seed}")
29 | self.results_path.mkdir(exist_ok=True)
30 |
31 | dpm_cfg = args.diffusion
32 | self.SDEdit = IterativeSDEdit(dpm_cfg.timesteps, dpm_cfg.beta_schedule, dpm_cfg.var_type)
33 |
34 | self.SDEdit, self.sde_model = self.accelerator.prepare(self.SDEdit, self.sde_model)
35 | self.src_dataloader, self.ref_dataloader = self.accelerator.prepare(src_dataloader, ref_dataloader)
36 | self.src_dataloader = cycle(self.src_dataloader)
37 | self.ref_dataloader = cycle(self.ref_dataloader)
38 |
39 | self.print()
40 |
41 | def sample(self):
42 | accelerator = self.accelerator
43 | device = self.accelerator.device
44 |
45 | sample = next(iter(self.src_dataloader))
46 | batch_size = sample["image"].shape[0] # online batch_size
47 | image_size = self.args.image_size
48 |
49 | s_down_N = self.args.src_down_N
50 | shape = (batch_size, 3, image_size, image_size)
51 | s_shape_d = (
52 | batch_size, 3, int(image_size / s_down_N), int(image_size / s_down_N)
53 | )
54 | src_down = Resizer(shape, 1 / s_down_N).to(device)
55 | src_up = Resizer(s_shape_d, s_down_N).to(device)
56 | low_passer = (src_down, src_up)
57 |
58 | model_kwargs = {}
59 | iter_kwargs = {
60 | 'low_passer': low_passer,
61 | 'fusion_scale': self.args.fusion_scale
62 | }
63 | i = 0
64 | with tqdm(initial=i, total=self.args.total_samples, disable=not accelerator.is_main_process) as pbar:
65 | while i < self.args.total_samples:
66 | src_sample = next(self.src_dataloader)
67 | src_input, name = src_sample["image"], src_sample["fname"]
68 | ref_sample = next(self.ref_dataloader)
69 | ref_input = ref_sample["image"]
70 |
71 | model_kwargs['step'] = i
72 |
73 | start_time = datetime.now()
74 | results = self.SDEdit.iterative_sampling_progressive(
75 | src_input,
76 | ref_input,
77 | self.args.iter_step,
78 | iter_kwargs,
79 | list(self.args.repeat_step),
80 | list(self.args.perturb_step),
81 | model=self.sde_model,
82 | model_kwargs=model_kwargs,
83 | device=device,
84 | recorder=pbar
85 | )
86 | pbar.set_description(f"one batch time: {datetime.now() - start_time}, "
87 | f"total_iter: {self.args.iter_step}")
88 |
89 | if accelerator.is_main_process:
90 | results = accelerator.gather(results)
91 | # gather final result
92 | for b in range(batch_size):
93 | all_iter_grids = []
94 | for ith in range(self.args.iter_step):
95 | for kth in range(self.args.repeat_step[ith]):
96 | x0, final, perturb_x0, blurred_x0 = results[f"{ith}-{kth}-th"]
97 | # (x0, perturbed_x0, kth_translated_x0)
98 | # (x0, perturbed_x0, src, blurred_x0, kth_translated_x0)
99 | save_grids = torch.cat([x0[b].unsqueeze_(0),
100 | perturb_x0[b].unsqueeze_(0),
101 | src_input[b].unsqueeze_(0)
102 | if ith != 0 else torch.zeros_like(src_input[b]).unsqueeze_(0),
103 | blurred_x0[b].unsqueeze_(0),
104 | final[b].unsqueeze_(0)], dim=0)
105 | all_iter_grids.append(save_grids)
106 | # visual
107 | img_name = name[b].split(".")[0] # Remove file suffixes
108 | save_path = self.results_path.joinpath(
109 | f"{i}-{img_name}-iter-{ith + 1}-K-{kth + 1}-t-{self.args.perturb_step}.png"
110 | )
111 | tv_util.save_image(torch.cat(all_iter_grids, dim=0), save_path, nrow=5)
112 |
113 | i += 1
114 | pbar.update(1)
115 |
116 | self.close()
117 |
--------------------------------------------------------------------------------
/pipelines/inversion/SDEdit_pipeline.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from datetime import datetime
7 |
8 | import torch
9 | from torchvision import utils as tv_util
10 | from tqdm import tqdm
11 |
12 | from libs.engine import ModelState
13 | from dpm_nn.inversion.SDEdit import SDEdit
14 |
15 |
16 | class SDEditPipeline(ModelState):
17 |
18 | def __init__(self, args, sde_model, sde_path, dataloader, use_dpm_solver: bool = False):
19 | super().__init__(args)
20 | self.args = args
21 | self.use_dpm_solver = use_dpm_solver
22 |
23 | self.print(f"loading SDE from `{sde_path}` ....")
24 | self.sde_model = self.load_ckpt_model_only(sde_model, sde_path)
25 | self.print(f"-> SDE Params: {(sum(p.numel() for p in sde_model.parameters()) / 1e6):.3f}M")
26 |
27 | dpm_cfg = args.diffusion
28 | self.SDEdit = SDEdit(dpm_cfg.timesteps, dpm_cfg.beta_schedule, dpm_cfg.var_type)
29 |
30 | self.SDEdit, self.sde_model = \
31 | self.accelerator.prepare(self.SDEdit, self.sde_model)
32 | self.dataloader = self.accelerator.prepare(dataloader)
33 |
34 | self.print()
35 |
36 | def sample(self):
37 | accelerator = self.accelerator
38 | device = self.accelerator.device
39 |
40 | sample = next(iter(self.dataloader))
41 | batch_size = sample["image"].shape[0] # online batch_size
42 | image_size = self.args.image_size
43 |
44 | model_kwargs = {}
45 | with tqdm(self.dataloader, disable=not accelerator.is_local_main_process) as pbar:
46 | for i, sample in enumerate(pbar):
47 | start_time = datetime.now()
48 |
49 | src_input, name = sample["image"], sample["fname"]
50 |
51 | model_kwargs['step'] = i
52 | results = self.SDEdit.sampling_progressive(
53 | src_input,
54 | mask=sample.get('mask', None), # editing mask
55 | repeat_step=self.args.repeat_step,
56 | perturb_step=self.args.perturb_step,
57 | model=self.sde_model,
58 | model_kwargs=model_kwargs,
59 | device=device,
60 | recorder=pbar,
61 | use_dpm_solver=self.use_dpm_solver
62 | )
63 |
64 | pbar.set_description(f"time per batch: {datetime.now() - start_time}")
65 | # pbar.write(f"Running time: {datetime.now() - start_time} | batch_size: {batch_size} \n")
66 |
67 | if accelerator.is_main_process:
68 | results = accelerator.gather(results)
69 | # gather final result
70 | for b in range(batch_size):
71 | all_iter_grids = []
72 | for kth in range(len(results)):
73 | x0, final, perturb_x0 = results[f"{kth}-th"]
74 | # (x0, perturbed_x0, kth_translated_x0)
75 | save_grids = torch.cat(
76 | [x0[b].unsqueeze_(0), perturb_x0[b].unsqueeze_(0), final[b].unsqueeze_(0)],
77 | dim=0
78 | )
79 | all_iter_grids.append(save_grids)
80 | # visual
81 | img_name = name[b].split(".")[0] # Remove file suffixes
82 | save_path = self.results_path.joinpath(
83 | f"i-{i}-{img_name}-B-{b}-K-{kth + 1}-t-{self.args.perturb_step}.png"
84 | )
85 | tv_util.save_image(torch.cat(all_iter_grids, dim=0), save_path, nrow=3)
86 |
87 | self.close()
88 |
--------------------------------------------------------------------------------
/pipelines/inversion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py @ file:///opt/conda/conda-bld/absl-py_1639803114343/work
2 | accelerate==0.18.0
3 | aiohttp @ file:///tmp/build/80754af9/aiohttp_1646806366512/work
4 | aiosignal @ file:///tmp/build/80754af9/aiosignal_1637843061372/work
5 | antlr4-python3-runtime==4.9.3
6 | async-timeout @ file:///tmp/build/80754af9/async-timeout_1637851218186/work
7 | attrs @ file:///opt/conda/conda-bld/attrs_1642510447205/work
8 | beautifulsoup4==4.12.2
9 | blinker==1.4
10 | Bottleneck @ file:///opt/conda/conda-bld/bottleneck_1657175564434/work
11 | brotlipy==0.7.0
12 | cachetools @ file:///tmp/build/80754af9/cachetools_1619597386817/work
13 | certifi @ file:///croot/certifi_1665076670883/work/certifi
14 | cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
15 | charset-normalizer @ file:///tmp/build/80754af9/charset-normalizer_1630003229654/work
16 | click @ file:///tmp/build/80754af9/click_1646038465422/work
17 | clip @ git+https://github.com/openai/CLIP.git@d50d76daa670286dd6cacf3bcd80b5e4823fc8e1
18 | cmake==3.26.1
19 | contourpy==1.0.6
20 | cryptography @ file:///tmp/build/80754af9/cryptography_1652083738073/work
21 | cycler @ file:///tmp/build/80754af9/cycler_1637851556182/work
22 | docker-pycreds==0.4.0
23 | einops==0.5.0
24 | filelock==3.9.0
25 | fonttools==4.25.0
26 | frozenlist @ file:///tmp/build/80754af9/frozenlist_1637767111923/work
27 | fsspec==2023.4.0
28 | ftfy==6.1.1
29 | future==0.18.2
30 | gitdb==4.0.10
31 | GitPython==3.1.30
32 | google-auth @ file:///opt/conda/conda-bld/google-auth_1646735974934/work
33 | google-auth-oauthlib @ file:///tmp/build/80754af9/google-auth-oauthlib_1617120569401/work
34 | googledrivedownloader @ file:///home/conda/feedstock_root/build_artifacts/googledrivedownloader_1619807768586/work
35 | grpcio @ file:///tmp/build/80754af9/grpcio_1637590823556/work
36 | huggingface-hub==0.12.0
37 | idna @ file:///tmp/build/80754af9/idna_1637925883363/work
38 | importlib-metadata @ file:///tmp/build/80754af9/importlib-metadata_1648562408398/work
39 | Jinja2 @ file:///opt/conda/conda-bld/jinja2_1647436528585/work
40 | joblib @ file:///tmp/build/80754af9/joblib_1635411271373/work
41 | kiwisolver @ file:///opt/conda/conda-bld/kiwisolver_1653292039266/work
42 | kornia==0.6.11
43 | lightning-utilities==0.8.0
44 | lit==16.0.0
45 | Markdown @ file:///tmp/build/80754af9/markdown_1614363528767/work
46 | MarkupSafe @ file:///opt/conda/conda-bld/markupsafe_1654597864307/work
47 | matplotlib==3.6.2
48 | mkl-fft==1.3.1
49 | mkl-random @ file:///tmp/build/80754af9/mkl_random_1626186064646/work
50 | mkl-service==2.4.0
51 | mpmath==1.3.0
52 | multidict @ file:///opt/conda/conda-bld/multidict_1662369340274/work
53 | munkres==1.1.4
54 | networkx @ file:///opt/conda/conda-bld/networkx_1657784097507/work
55 | numexpr @ file:///opt/conda/conda-bld/numexpr_1656940300424/work
56 | numpy @ file:///tmp/abs_653_j00fmm/croots/recipe/numpy_and_numpy_base_1659432701727/work
57 | nvidia-cublas-cu11==11.10.3.66
58 | nvidia-cuda-cupti-cu11==11.7.101
59 | nvidia-cuda-nvrtc-cu11==11.7.99
60 | nvidia-cuda-runtime-cu11==11.7.99
61 | nvidia-cudnn-cu11==8.5.0.96
62 | nvidia-cufft-cu11==10.9.0.58
63 | nvidia-curand-cu11==10.2.10.91
64 | nvidia-cusolver-cu11==11.4.0.1
65 | nvidia-cusparse-cu11==11.7.4.91
66 | nvidia-nccl-cu11==2.14.3
67 | nvidia-nvtx-cu11==11.7.91
68 | oauthlib @ file:///tmp/abs_08ngfezid4/croots/recipe/oauthlib_1659642459222/work
69 | omegaconf==2.3.0
70 | opencv-python==4.7.0.72
71 | packaging @ file:///tmp/build/80754af9/packaging_1637314298585/work
72 | pandas==1.4.3
73 | pathtools==0.1.2
74 | Pillow==9.2.0
75 | ply==3.11
76 | promise==2.3
77 | protobuf==3.20.1
78 | psutil==5.9.2
79 | pyasn1 @ file:///Users/ktietz/demo/mc3/conda-bld/pyasn1_1629708007385/work
80 | pyasn1-modules==0.2.8
81 | pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
82 | PyJWT @ file:///opt/conda/conda-bld/pyjwt_1657544592787/work
83 | pyOpenSSL @ file:///opt/conda/conda-bld/pyopenssl_1643788558760/work
84 | pyparsing @ file:///opt/conda/conda-bld/pyparsing_1661452539315/work
85 | PyQt5-sip==12.11.0
86 | PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
87 | python-dateutil @ file:///tmp/build/80754af9/python-dateutil_1626374649649/work
88 | python-louvain @ file:///tmp/build/80754af9/python-louvain_1612304551119/work
89 | pytz @ file:///opt/conda/conda-bld/pytz_1654762638606/work
90 | PyYAML==6.0
91 | regex==2022.10.31
92 | requests @ file:///opt/conda/conda-bld/requests_1657734628632/work
93 | requests-oauthlib==1.3.0
94 | rsa @ file:///tmp/build/80754af9/rsa_1614366226499/work
95 | scikit-learn @ file:///tmp/abs_d76175bc-917a-47d4-9994-b56265948a6328vmoe2o/croots/recipe/scikit-learn_1658419412415/work
96 | scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy_1653073867187/work
97 | sentencepiece==0.1.99
98 | sentry-sdk==1.12.1
99 | setproctitle==1.3.2
100 | shortuuid==1.0.11
101 | sip @ file:///tmp/abs_44cd77b_pu/croots/recipe/sip_1659012365470/work
102 | six @ file:///tmp/build/80754af9/six_1644875935023/work
103 | smmap==5.0.0
104 | soupsieve==2.4.1
105 | sympy==1.11.1
106 | tensorboard @ file:///home/builder/stiwari/miniconda3/envs/tf_new_env/conda-bld/tensorboard_1661447826088/work/tensorboard-2.9.0-py3-none-any.whl
107 | tensorboard-data-server @ file:///tmp/build/80754af9/tensorboard-data-server_1633035064162/work/tensorboard_data_server-0.6.0-py3-none-manylinux2010_x86_64.whl
108 | tensorboard-plugin-wit @ file:///home/builder/tkoch/workspace/tensorflow/tensorboard-plugin-wit_1658918494740/work/tensorboard_plugin_wit-1.8.1-py3-none-any.whl
109 | threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
110 | tokenizers==0.13.2
111 | toml @ file:///tmp/build/80754af9/toml_1616166611790/work
112 | torch==1.13.1+cu116
113 | torch-cluster @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-cluster_1631029005429/work
114 | torch-geometric @ file:///usr/share/miniconda/envs/test/conda-bld/pyg_1640156451028/work
115 | torch-scatter @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-scatter_1634900577572/work
116 | torch-sparse @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-sparse_1631173533284/work
117 | torch-spline-conv @ file:///usr/share/miniconda/envs/test/conda-bld/pytorch-spline-conv_1631007898768/work
118 | torchaudio==0.13.1+cu116
119 | torchmetrics==0.11.4
120 | torchvision==0.14.1+cu116
121 | tornado @ file:///opt/conda/conda-bld/tornado_1662061693373/work
122 | tqdm @ file:///opt/conda/conda-bld/tqdm_1650891076910/work
123 | transformers==4.26.0
124 | triton==2.0.0
125 | typing_extensions @ file:///tmp/abs_ben9emwtky/croots/recipe/typing_extensions_1659638822008/work
126 | urllib3 @ file:///tmp/abs_5dhwnz6atv/croots/recipe/urllib3_1659110457909/work
127 | wandb==0.13.7
128 | wcwidth==0.2.5
129 | Werkzeug @ file:///opt/conda/conda-bld/werkzeug_1645628268370/work
130 | yacs @ file:///tmp/build/80754af9/yacs_1634047592950/work
131 | yarl @ file:///opt/conda/conda-bld/yarl_1661437085904/work
132 | zipp @ file:///opt/conda/conda-bld/zipp_1652341764480/work
133 |
--------------------------------------------------------------------------------
/run/run_SDEdit.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import os
7 | import sys
8 | import argparse
9 |
10 | from accelerate.utils import set_seed
11 |
12 | sys.path.append(os.path.split(os.path.abspath(os.path.dirname(__file__)))[0])
13 |
14 | from libs.utils.argparse import (merge_and_update_config, accelerate_parser, base_data_parser, base_sampling_parser)
15 | from dpm_nn.guided_dpm.build_ADMs import ADMs_build_util
16 | from sketch_nn.dataset.build import build_image2image_translation_dataset
17 |
18 |
19 | def main(args):
20 |     assert len(args.data_folder) > 0, "At least one dataset folder must be provided!"
21 |
22 | args.batch_size = args.valid_batch_size
23 |
24 | if args.task == "base": # SDEdit - image to image translation
25 | from pipelines.inversion.SDEdit_pipeline import SDEditPipeline
26 |
27 | dataloader = build_image2image_translation_dataset(args.dataset, args.data_folder,
28 | split=args.split,
29 | image_size=args.image_size,
30 | batch_size=args.valid_batch_size,
31 | shuffle=args.shuffle, drop_last=True,
32 | num_workers=args.num_workers)
33 |
34 | sde_model, _ = ADMs_build_util(args.image_size, args.num_classes, args.model, args.diffusion)
35 |
36 | SDEdit = SDEditPipeline(args, sde_model, args.sdepath, dataloader, args.use_dpm_solver)
37 | SDEdit.sample()
38 |
39 | elif args.task == "mask": # TODO: SDEdit - image editing
40 | pass
41 |
42 | elif args.task == "ref":
43 | from pipelines.inversion.SDEdit_iter_pipeline import IterativeSDEditPipeline
44 |
45 | src_dataloader = build_image2image_translation_dataset(args.dataset, args.data_folder,
46 | split=args.split,
47 | image_size=args.image_size,
48 | batch_size=args.valid_batch_size,
49 | shuffle=args.shuffle, drop_last=True,
50 | num_workers=args.num_workers)
51 | ref_dataloader = build_image2image_translation_dataset(args.dataset, args.ref_data_folder,
52 | split=args.split,
53 | image_size=args.image_size,
54 | batch_size=args.valid_batch_size,
55 | shuffle=args.shuffle, drop_last=True,
56 | num_workers=args.num_workers)
57 |
58 | sde_model, _ = ADMs_build_util(args.image_size, args.num_classes, args.model, args.diffusion)
59 |
60 | SDEdit = IterativeSDEditPipeline(args, sde_model, args.sdepath, src_dataloader, ref_dataloader)
61 | SDEdit.sample()
62 |
63 |
64 | if __name__ == '__main__':
65 | """
66 | ## cat2dog, base sampling, SDEdit:
67 | CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/cat2dog-img256.yaml -sdepath ./checkpoint/InvSDE/afhq_dog_4m.pt -dpath ./dataset/afhq/val/cat -respath ./workdir/sdedit_cat -vbz 32 -final -ts 500
68 | CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/cat2dog-img256.yaml -sdepath ./checkpoint/InvSDE/afhq_dog_4m.pt -dpath ./dataset/afhq/val/dog -respath ./workdir/sdedit_dog -vbz 32 -final -ts 500
69 |
70 | ## SDEdit + ref:
71 | CUDA_VISIBLE_DEVICES=0 python run/run_SDEdit.py -c SDEdit/iter-cat2dog-img256-p400-k33-dN32.yaml --task ref -sdepath ./checkpoint/afhq_dog_4m.pt -dpath ./dataset/afhq/train/cat -rdpath ./dataset/afhq/train_edge_map/dog -respath /data2/xingxm/skgruns/ -vbz 8
72 | """
73 |
74 | parser = argparse.ArgumentParser(
75 | description="SDEdit",
76 | parents=[accelerate_parser(), base_data_parser(), base_sampling_parser()]
77 | )
78 |
79 | # flag
80 | parser.add_argument("-tk", "--task",
81 | default="base", type=str, choices=["base", "mask", "ref"],
82 | help="guided image synthesis and editing.")
83 | # config
84 | parser.add_argument("-c", "--config",
85 | required=True, type=str,
86 | default="SDEdit/cat2dog-img256.yaml",
87 | help="YAML/YML file for configuration.")
88 | # data path
89 | parser.add_argument("-dpath", "--data_folder",
90 | nargs="+", type=str,
91 |                         # default=['./dataset/afhq/val/cat'],
92 | # default=['./dataset/afhq/train/cat', './dataset/afhq/train/dog'],
93 | # default=['./dataset/afhq/train/cat', './dataset/afhq/train/wild', './dataset/afhq/train/dog'],
94 | help="single input for single-domain, multi inputs for multi-domain")
95 | parser.add_argument("-rdpath", "--ref_data_folder",
96 | nargs="+", type=str, default=None,
97 |                         # default=['./dataset/afhq/val/cat'],
98 | # default=['./dataset/afhq/train/cat', './dataset/afhq/train/dog'],
99 | # default=['./dataset/afhq/train/cat', './dataset/afhq/train/wild', './dataset/afhq/train/dog'],
100 | help="single input for single-domain, multi inputs for multi-domain")
101 | # model path
102 | parser.add_argument("-sdepath",
103 | default="./checkpoint/afhq_dog_4m.pt", type=str,
104 |                         help="place the pretrained model at `./checkpoint/afhq_dog_4m.pt`; "
105 |                              "if None, train from scratch")
106 | # use dpm-solver
107 | parser.add_argument("-uds", "--use_dpm_solver",
108 | action='store_true',
109 |                         help="use dpm_solver to accelerate sampling.")
110 | # sampling mode
111 | parser.add_argument("-final", "--get_final_results",
112 | action='store_true',
113 |                         help="save only the final outputs instead of visualizing intermediate results.")
114 |
115 | args = parser.parse_args()
116 | args = merge_and_update_config(args)
117 |
118 | set_seed(args.seed)
119 | main(args)
120 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 |
5 | """
6 | Description: How to install
7 | # all:
8 | pip install omegaconf tqdm scipy opencv-python einops BeautifulSoup4 timm matplotlib torchmetrics accelerate diffusers triton transformers -i https://pypi.tuna.tsinghua.edu.cn/simple
9 |
10 | # CLIP:
11 | pip install git+https://github.com/openai/CLIP.git -i https://pypi.tuna.tsinghua.edu.cn/simple
12 |
13 | # torch 1.13.1:
14 | conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia
15 |
16 | # xformers (python=3.10):
17 | conda install xformers -c xformers
18 | xFormers - Toolbox to Accelerate Research on Transformers:
19 | https://github.com/facebookresearch/xformers
20 | """
21 |
22 | from setuptools import setup, find_packages
23 |
24 | setup(
25 | name='SketchGuidedGeneration',
26 | packages=find_packages(),
27 | version='0.0.13',
28 | license='MIT',
29 | description='Sketch Guided Content Generation',
30 | author='XiMing Xing',
31 | author_email='ximingxing@gmail.com',
32 | url='https://github.com/ximinng/SketchGeneration/',
33 | long_description_content_type='text/markdown',
34 | keywords=[
35 | 'artificial intelligence',
36 | 'generative models',
37 | 'sketch'
38 | ],
39 | install_requires=[
40 | 'omegaconf', # YAML processor
41 | 'accelerate', # Hugging Face - pytorch distributed configuration
42 | 'diffusers', # Hugging Face - diffusion models
43 | 'transformers', # Hugging Face - transformers
44 | 'einops',
45 | 'pillow',
46 | 'torch>=1.13.1',
47 | 'torchvision',
48 | 'tensorboard',
49 | 'torchmetrics',
50 | 'tqdm', # progress bar
51 | 'timm', # computer vision models
52 | "numpy", # numpy
53 | 'matplotlib',
54 | 'scikit-learn',
55 | 'omegaconf', # configs
56 | 'Pillow', # keep the PIL.Image.Resampling deprecation away,
57 | 'wandb', # weights & Biases
58 | 'opencv-python', # cv2
59 | 'BeautifulSoup4'
60 | ],
61 | classifiers=[
62 | 'Development Status :: 4 - Beta',
63 | 'Intended Audience :: Developers',
64 | 'Topic :: Scientific/Engineering :: Artificial Intelligence',
65 | 'License :: OSI Approved :: MIT License',
66 | 'Programming Language :: Python :: 3.8',
67 | ],
68 | )
69 |
--------------------------------------------------------------------------------
/sketch_nn/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from . import augment
7 | from . import dataset
8 | from . import edge_map
9 | from . import methods
10 | from . import model
11 | from . import photo2sketch
12 | from . import rasterize
13 |
14 | __version__ = '0.0.12'
15 |
--------------------------------------------------------------------------------
/sketch_nn/augment/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .mixup import Mixup
7 |
--------------------------------------------------------------------------------
/sketch_nn/augment/mixup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import torch
7 | import numpy as np
8 |
9 |
10 | class Mixup(object):
11 | """
12 |     "mixup: Beyond Empirical Risk Minimization", ICLR 2018 (https://arxiv.org/abs/1710.09412).
13 |     Reference implementation: https://github.com/facebookresearch/mixup-cifar10
14 | """
15 |
16 | def single_domain_mix(self, x, y, alpha=1.0, device='cpu'):
17 | if alpha > 0:
18 | lam = np.random.beta(alpha, alpha)
19 | else:
20 | lam = 1
21 |
22 | batch_size = x.size()[0]
23 | index = torch.randperm(batch_size).to(device)
24 |
25 | mixed_x = lam * x + (1 - lam) * x[index, :]
26 | y_a, y_b = y, y[index]
27 | return mixed_x, y_a, y_b, lam
28 |
29 | def dual_domains_mix(self, x1, x2, y1, y2, alpha=1.0, device='cpu'):
30 | if alpha > 0:
31 | lam = np.random.beta(alpha, alpha)
32 | else:
33 | lam = 1
34 |
35 | mixed_x = lam * x1 + (1 - lam) * x2
36 | return mixed_x, y1, y2, lam
37 |
38 | def criterion(self, criterion, pred, y_a, y_b, lam):
39 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
40 |
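For orientation, a minimal training-step sketch showing how `single_domain_mix` and `criterion` are meant to be combined; the linear classifier and synthetic batch below are placeholders, not part of the repository:

```
import torch
import torch.nn as nn

from sketch_nn.augment import Mixup

# Toy stand-ins: a linear classifier and one synthetic batch of 8 samples.
model = nn.Linear(32, 10)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
x, y = torch.randn(8, 32), torch.randint(0, 10, (8,))

mixup = Mixup()
mixed_x, y_a, y_b, lam = mixup.single_domain_mix(x, y, alpha=1.0, device=x.device)
pred = model(mixed_x)
# The mixed loss is the lam-weighted sum of the losses against both label sets.
loss = mixup.criterion(criterion, pred, y_a, y_b, lam)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```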
--------------------------------------------------------------------------------
/sketch_nn/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | # sketch dataset
7 | from .mnist import MNISTDataset
8 | from .sketchx_shoe_chairV2 import SketchXShoeAndChairCoordDataset, SketchXShoeAndChairPhotoDataset
9 | from .sketchy import SketchyDataset
10 | # real image dataset
11 | from .cifar10 import CIFAR10Dataset
12 | from .imagenet import ImageNetDataset
13 | # common
14 | from .base_dataset import MultiDomainDataset, SingleDomainDataset, SingleDomainWithFileNameDataset
15 |
16 | # utils
17 | from .base_dataset import is_image_file
18 |
--------------------------------------------------------------------------------
/sketch_nn/dataset/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 |
7 | ImageSuffices = [
8 | '.jpg', '.JPG', '.jpeg', '.JPEG',
9 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
10 | '.tif', '.TIF', '.tiff', '.TIFF',
11 | ]
12 |
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in ImageSuffices)
16 |
--------------------------------------------------------------------------------
/sketch_nn/edge_map/DoG/XDoG.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import numpy as np
7 | import cv2
8 | from scipy import ndimage as ndi
9 | from skimage import filters
10 |
11 |
12 | class XDoG:
13 |
14 | def __init__(self,
15 | gamma=0.98,
16 | phi=200,
17 | eps=-0.1,
18 | sigma=0.8,
19 | k=10,
20 | binarize: bool = True):
21 | """
22 | XDoG algorithm.
23 |
24 | Args:
25 |             gamma: Weight of the larger-scale Gaussian in the difference of Gaussians
26 |             phi: Sharpness of the tanh soft-thresholding applied to the edge response
27 |             eps: Threshold on the DoG response that separates edge from non-edge regions
28 |             sigma: Standard deviation of the smaller Gaussian filter; controls the degree of smoothing
29 |             k: Ratio between the standard deviations of the two Gaussian filters (k=10 or k=1.6)
30 |             binarize(bool): Whether to binarize the output with Otsu's threshold
31 | """
32 |
33 | super(XDoG, self).__init__()
34 |
35 | self.gamma = gamma
36 | assert 0 <= self.gamma <= 1
37 |
38 | self.phi = phi
39 | assert 0 <= self.phi <= 1500
40 |
41 | self.eps = eps
42 | assert -1 <= self.eps <= 1
43 |
44 | self.sigma = sigma
45 | assert 0.1 <= self.sigma <= 10
46 |
47 | self.k = k
48 | assert 1 <= self.k <= 100
49 |
50 | self.binarize = binarize
51 |
52 | def __call__(self, img):
53 | # to gray if image is not already grayscale
54 | if len(img.shape) == 3 and img.shape[2] == 3:
55 | img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
56 | elif len(img.shape) == 3 and img.shape[2] == 4:
57 | img = cv2.cvtColor(img, cv2.COLOR_BGRA2GRAY)
58 |
59 | if np.isnan(img).any():
60 | img[np.isnan(img)] = np.mean(img[~np.isnan(img)])
61 |
62 | # gaussian filter
63 | imf1 = ndi.gaussian_filter(img, self.sigma)
64 | imf2 = ndi.gaussian_filter(img, self.sigma * self.k)
65 | imdiff = imf1 - self.gamma * imf2
66 |
67 | # XDoG
68 | imdiff = (imdiff < self.eps) * 1.0 + (imdiff >= self.eps) * (1.0 + np.tanh(self.phi * imdiff))
69 |
70 | # normalize
71 | imdiff -= imdiff.min()
72 | imdiff /= imdiff.max()
73 |
74 | if self.binarize:
75 | th = filters.threshold_otsu(imdiff)
76 | imdiff = (imdiff >= th).astype('float32')
77 |
78 | return imdiff
79 |
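A minimal usage sketch; `photo.png` is a placeholder path, and the image is converted to a float array in [0, 1] so the DoG response stays in the range the default eps/phi values expect:

```
import cv2

from sketch_nn.edge_map.DoG import XDoG

xdog = XDoG(gamma=0.98, phi=200, eps=-0.1, sigma=0.8, k=10, binarize=True)

# "photo.png" is a placeholder; the detector handles the grayscale conversion itself.
img = cv2.imread("photo.png").astype("float32") / 255.0
edge_map = xdog(img)  # float array in [0, 1]
cv2.imwrite("edge_map.png", (edge_map * 255).astype("uint8"))
```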
--------------------------------------------------------------------------------
/sketch_nn/edge_map/DoG/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .XDoG import XDoG
7 |
8 | __all__ = ['XDoG']
9 |
--------------------------------------------------------------------------------
/sketch_nn/edge_map/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/edge_map/canny/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import cv2
7 |
8 |
9 | class CannyDetector:
10 |
11 | def __call__(self, img, low_threshold, high_threshold, L2gradient=False):
12 |         return cv2.Canny(img, low_threshold, high_threshold, L2gradient=L2gradient)
13 |
14 |
15 | __all__ = ['CannyDetector']
16 |
--------------------------------------------------------------------------------
/sketch_nn/edge_map/image_grads/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .laplacian import LaplacianDetector
7 |
8 | __all__ = ['LaplacianDetector']
9 |
--------------------------------------------------------------------------------
/sketch_nn/edge_map/image_grads/laplacian.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 |
7 | import cv2
8 |
9 |
10 | class LaplacianDetector:
11 |
12 | def __call__(self, img):
13 | return cv2.Laplacian(img, cv2.CV_64F)
14 |
--------------------------------------------------------------------------------
/sketch_nn/methods/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/methods/inversion/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/model/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/InformativeDrawings/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/InformativeDrawings/default_config.yaml:
--------------------------------------------------------------------------------
1 | input_nc: 3
2 | output_nc: 1
3 | n_blocks: 3
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | import torch
4 |
5 |
6 | class BaseModel:
7 |
8 | def name(self):
9 | return 'BaseModel'
10 |
11 | def initialize(self, opt):
12 | self.opt = opt
13 | self.isTrain = opt.isTrain
14 | self.device = torch.device("cuda" if opt.use_cuda else "cpu")
15 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)
16 |
17 | def set_input(self, input):
18 | self.input = input
19 |
20 | def forward(self):
21 | pass
22 |
23 | # used in test time, no backprop
24 | def test(self):
25 | pass
26 |
27 | def get_image_paths(self):
28 | pass
29 |
30 | def optimize_parameters(self):
31 | pass
32 |
33 | def get_current_visuals(self):
34 | return self.input
35 |
36 | def get_current_errors(self):
37 | return {}
38 |
39 | def save(self, label):
40 | pass
41 |
42 | # helper saving function that can be used by subclasses
43 | def save_network(self, network, network_label, epoch_label):
44 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
45 | save_path = os.path.join(self.save_dir, save_filename)
46 | torch.save(network.cpu().state_dict(), save_path)
47 | network = network.to(self.device)
48 |
49 | # helper loading function that can be used by subclasses
50 | def load_network(self, network, network_label, epoch_label):
51 | save_filename = '%s_net_%s.pth' % (epoch_label, network_label)
52 | if self.opt.pretrain_path:
53 | save_path = os.path.join(self.opt.pretrain_path, save_filename)
54 | else:
55 | save_path = os.path.join(self.save_dir, save_filename)
56 | network.load_state_dict(torch.load(save_path))
57 |
58 | # update learning rate (called once every epoch)
59 | def update_learning_rate(self):
60 | for scheduler in self.schedulers:
61 | scheduler.step()
62 | lr = self.optimizers[0].param_groups[0]['lr']
63 | print('learning rate = %.7f' % lr)
64 |
65 | def set_requires_grad(self, nets, requires_grad=False):
66 | if not isinstance(nets, list):
67 | nets = [nets]
68 | for net in nets:
69 | if net is not None:
70 | for param in net.parameters():
71 | param.requires_grad = requires_grad
72 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 |
6 | from .util import mkdirs
7 |
8 |
9 | class BaseOptions():
10 | def __init__(self):
11 | self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12 | self.initialized = False
13 |
14 | def initialize(self):
15 | self.parser.add_argument('--dataroot', required=True,
16 | help='path to images (should have subfolders trainA, trainB, valA, valB, etc)')
17 | self.parser.add_argument('--batchSize', type=int, default=1, help='input batch size')
18 | self.parser.add_argument('--loadSize', type=int, default=286, help='scale images to this size')
19 | self.parser.add_argument('--fineSize', type=int, default=256, help='then crop to this size')
20 | self.parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
21 | self.parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels')
22 | self.parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in first conv layer')
23 | self.parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in first conv layer')
24 | self.parser.add_argument('--which_model_netD', type=str, default='basic', help='selects model to use for netD')
25 | self.parser.add_argument('--which_model_netG', type=str, default='resnet_9blocks',
26 | help='selects model to use for netG')
27 | self.parser.add_argument('--n_layers_D', type=int, default=3, help='only used if which_model_netD==n_layers')
28 | self.parser.add_argument('--no-cuda', action='store_true', default=False,
29 | help='disable CUDA training (please use CUDA_VISIBLE_DEVICES to select GPU)')
30 | self.parser.add_argument('--name', type=str, default='experiment_name',
31 | help='name of the experiment. It decides where to store samples and models')
32 | self.parser.add_argument('--dataset_mode', type=str, default='unaligned',
33 | help='chooses how datasets are loaded. [unaligned | aligned | single]')
34 | self.parser.add_argument('--model', type=str, default='cycle_gan',
35 | help='chooses which model to use. cycle_gan, pix2pix, test')
36 | self.parser.add_argument('--which_direction', type=str, default='AtoB', help='AtoB or BtoA')
37 | self.parser.add_argument('--nThreads', default=6, type=int, help='# threads for loading data')
38 | self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
39 | self.parser.add_argument('--norm', type=str, default='instance',
40 | help='instance normalization or batch normalization')
41 | self.parser.add_argument('--serial_batches', action='store_true',
42 | help='if true, takes images in order to make batches, otherwise takes them randomly')
43 | self.parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
44 | self.parser.add_argument('--display_id', type=int, default=1, help='window id of the web display')
45 | self.parser.add_argument('--display_server', type=str, default="http://localhost",
46 | help='visdom server of the web display')
47 | self.parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display')
48 | self.parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator')
49 | self.parser.add_argument('--max_dataset_size', type=int, default=float("inf"),
50 | help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
51 | self.parser.add_argument('--resize_or_crop', type=str, default='resize_and_crop',
52 | help='scaling and cropping of images at load time [resize_and_crop|crop|scale_width|scale_width_and_crop]')
53 | self.parser.add_argument('--no_flip', action='store_true',
54 | help='if specified, do not flip the images for data augmentation')
55 | self.parser.add_argument('--init_type', type=str, default='normal',
56 | help='network initialization [normal|xavier|kaiming|orthogonal]')
57 | self.parser.add_argument('--render_dir', type=str, default='sketch-rendered')
58 | self.parser.add_argument('--aug_folder', type=str, default='width-5')
59 | self.parser.add_argument('--stroke_dir', type=str, default='')
60 | self.parser.add_argument('--crop', action='store_true')
61 | self.parser.add_argument('--rotate', action='store_true')
62 | self.parser.add_argument('--color_jitter', action='store_true')
63 | self.parser.add_argument('--stroke_no_couple', action='store_true', help='')
64 | self.parser.add_argument('--pretrain_path', type=str, default='')
65 | self.parser.add_argument('--nGT', type=int, default=5)
66 | self.parser.add_argument('--rot_int_max', type=int, default=3)
67 | self.parser.add_argument('--jitter_amount', type=float, default=0.02)
68 | self.parser.add_argument('--inverse_gamma', action='store_true')
69 | self.parser.add_argument('--img_mean', type=float, nargs='+')
70 | self.parser.add_argument('--img_std', type=float, nargs='+')
71 | self.parser.add_argument('--lst_file', type=str)
72 | self.initialized = True
73 |
74 | def parse(self):
75 | if not self.initialized:
76 | self.initialize()
77 | self.opt = self.parser.parse_args()
78 | self.opt.isTrain = self.isTrain # train or test
79 |
80 | self.opt.use_cuda = not self.opt.no_cuda and torch.cuda.is_available()
81 | args = vars(self.opt)
82 |
83 | print('------------ Options -------------')
84 | for k, v in sorted(args.items()):
85 | print('%s: %s' % (str(k), str(v)))
86 | print('-------------- End ----------------')
87 |
88 | # save to the disk
89 | expr_dir = os.path.join(self.opt.checkpoints_dir, self.opt.name)
90 | mkdirs(expr_dir)
91 | file_name = os.path.join(expr_dir, 'opt.txt')
92 | with open(file_name, 'wt') as opt_file:
93 | opt_file.write('------------ Options -------------\n')
94 | for k, v in sorted(args.items()):
95 | opt_file.write('%s: %s\n' % (str(k), str(v)))
96 | opt_file.write('-------------- End ----------------\n')
97 | return self.opt
98 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/default_config.yaml:
--------------------------------------------------------------------------------
1 | which_model_netG: "resnet_9blocks"
2 | input_nc: 3
3 | output_nc: 1
4 | norm: "batch"
5 | use_dropout: False
6 | n_blocks: 9
7 | ngf: 64
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 | import torch
4 |
5 |
6 | class ImagePool():
7 | def __init__(self, pool_size):
8 | self.pool_size = pool_size
9 | if self.pool_size > 0:
10 | self.num_imgs = 0
11 | self.images = []
12 |
13 | def query(self, images):
14 | if self.pool_size == 0:
15 | return images
16 | return_images = []
17 | for image in images:
18 | image = torch.unsqueeze(image, 0)
19 | if self.num_imgs < self.pool_size:
20 | self.num_imgs = self.num_imgs + 1
21 | self.images.append(image)
22 | return_images.append(image)
23 | else:
24 | p = random.uniform(0, 1)
25 | if p > 0.5:
26 | random_id = random.randint(0, self.pool_size-1)
27 | tmp = self.images[random_id].clone()
28 | self.images[random_id] = image
29 | return_images.append(tmp)
30 | else:
31 | return_images.append(image)
32 | return_images = torch.cat(return_images, 0)
33 | return return_images
34 |
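For reference, a minimal usage sketch of the history buffer; the batch below is a random placeholder standing in for generator outputs:

```
import torch

from sketch_nn.photo2sketch.PhotoSketching.image_pool import ImagePool

pool = ImagePool(pool_size=50)

# Placeholder batch of generated (fake) images (N, C, H, W).
fake_images = torch.randn(4, 3, 256, 256)

# Once the pool is full, each returned image is, with 50% probability, swapped
# for a previously generated one -- the history trick used in the pix2pix/CycleGAN
# training recipe to stabilize the discriminator.
fake_for_discriminator = pool.query(fake_images)
```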
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | def initialize(self):
6 | BaseOptions.initialize(self)
7 | self.parser.add_argument('--ntest', type=int, default=float("inf"), help='# of test examples.')
8 | self.parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
9 | self.parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images')
10 | self.parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
11 | self.parser.add_argument('--which_epoch', type=str, default='latest',
12 | help='which epoch to load? set to latest to use latest cached model')
13 | self.parser.add_argument('--how_many', type=int, default=50, help='how many test images to run')
14 | self.parser.add_argument('--file_name', type=str, default='')
15 | self.parser.add_argument('--suffix', type=str, default='')
16 | self.isTrain = False
17 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/PhotoSketching/util.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import torch
3 | import numpy as np
4 | from PIL import Image
5 | import inspect
6 | import re
7 | import numpy as np
8 | import os
9 | import collections
10 |
11 | # Converts a Tensor into a Numpy array
12 | # |imtype|: the desired type of the converted numpy array
13 | def tensor2im(image_tensor, imtype=np.uint8):
14 | image_numpy = image_tensor[0].cpu().float().numpy()
15 | if image_numpy.shape[0] == 1:
16 | image_numpy = np.tile(image_numpy, (3, 1, 1))
17 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
18 | return image_numpy.astype(imtype)
19 |
20 | def tensor2im2(image_tensor, imtype=np.uint8):
21 | image_numpy = image_tensor.detach().cpu().float().numpy()
22 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0
23 | return image_numpy.astype(imtype)
24 |
25 | def tensor2im3(image_tensor, imtype=np.uint8):
26 | image_numpy = 1.0 - image_tensor.detach().cpu().float().numpy()
27 | if image_numpy.shape[0] == 1:
28 | image_numpy = np.tile(image_numpy, (3, 1, 1))
29 | image_numpy = 255.0 * np.transpose(image_numpy, (1, 2, 0))
30 | return image_numpy.astype(imtype)
31 |
32 | def tensor2im4(image_tensor, img_mean, img_std, imtype=np.uint8):
33 | image_numpy = image_tensor.detach().cpu().float().numpy()
34 | n_channel = len(img_mean)
35 | for c in range(n_channel):
36 | image_numpy[c, :, :] = image_numpy[c, :, :]*img_std[c] + img_mean[c]
37 | if image_numpy.shape[0] == 1:
38 | image_numpy = np.tile(image_numpy, (3, 1, 1))
39 | image_numpy = 255.0 * np.transpose(image_numpy, (1, 2, 0))
40 | return image_numpy.astype(imtype)
41 |
42 | def diagnose_network(net, name='network'):
43 | mean = 0.0
44 | count = 0
45 | for param in net.parameters():
46 | if param.grad is not None:
47 | mean += torch.mean(torch.abs(param.grad.detach()))
48 | count += 1
49 | if count > 0:
50 | mean = mean / count
51 | print(name)
52 | print(mean)
53 |
54 |
55 | def save_image(image_numpy, image_path):
56 | image_pil = Image.fromarray(image_numpy)
57 | image_pil.save(image_path)
58 |
59 |
60 | def print_numpy(x, val=True, shp=False):
61 | x = x.astype(np.float64)
62 | if shp:
63 | print('shape,', x.shape)
64 | if val:
65 | x = x.flatten()
66 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
67 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
68 |
69 |
70 | def mkdirs(paths):
71 | if isinstance(paths, list) and not isinstance(paths, str):
72 | for path in paths:
73 | mkdir(path)
74 | else:
75 | mkdir(paths)
76 |
77 |
78 | def mkdir(path):
79 | if not os.path.exists(path):
80 | os.makedirs(path)
81 |
--------------------------------------------------------------------------------
/sketch_nn/photo2sketch/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import os
7 | from argparse import Namespace
8 | from typing import Union, Dict
9 | from functools import lru_cache
10 | from omegaconf import OmegaConf
11 |
12 | __all__ = ["photo2sketch_model_build_util", "photo2sketch_available_models"]
13 |
14 | _METHODS = ["PhotoSketching", "InformativeDrawings"]
15 |
16 |
17 | def photo2sketch_available_models():
18 | return _METHODS
19 |
20 |
21 | @lru_cache()
22 | def default_config_path(dir_name: str) -> str:
23 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), dir_name, "default_config.yaml")
24 |
25 |
26 | def photo2sketch_model_build_util(
27 | method: str = "PhotoSketching",
28 | model_config: Union[Namespace, Dict] = None
29 | ):
30 | assert method in _METHODS, f"Model {method} not recognized."
31 |
32 | if model_config is None: # load default configuration
33 | config_path = default_config_path(method)
34 | model_config = OmegaConf.load(config_path)
35 |
36 | if method == "PhotoSketching":
37 | from .PhotoSketching.networks import ResnetGenerator, get_norm_layer
38 | norm_layer = get_norm_layer(norm_type=model_config.norm)
39 | model = ResnetGenerator(model_config.input_nc, model_config.output_nc,
40 | model_config.ngf, norm_layer, model_config.use_dropout,
41 | model_config.n_blocks)
42 | return model
43 | elif method == "InformativeDrawings":
44 | from .InformativeDrawings.model import Generator
45 | model = Generator(model_config.input_nc, model_config.output_nc, model_config.n_blocks)
46 | return model
47 | else:
48 | raise ModuleNotFoundError("Model [%s] not recognized." % method)
49 |
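A minimal sketch of how the builder is intended to be used; the checkpoint path in the comment is a placeholder, and with `model_config=None` the bundled `default_config.yaml` of the chosen method is loaded:

```
from sketch_nn.photo2sketch import (photo2sketch_available_models,
                                    photo2sketch_model_build_util)

print(photo2sketch_available_models())  # ['PhotoSketching', 'InformativeDrawings']

# model_config=None -> the method's default_config.yaml is used.
model = photo2sketch_model_build_util(method="PhotoSketching")
model.eval()

# A real run would load released weights here, e.g. (placeholder path):
# model.load_state_dict(torch.load("photosketching.pth", map_location="cpu"))
```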
--------------------------------------------------------------------------------
/sketch_nn/rasterize/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | from .rasterize import sketch_vector_rasterize
--------------------------------------------------------------------------------
/sketch_nn/rasterize/bresenham.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description: Implementation of Bresenham's line drawing algorithm.
5 | # See en.wikipedia.org/wiki/Bresenham's_line_algorithm
6 |
7 |
8 | def bresenham_algo(x0, y0, x1, y1):
9 | """
10 | Yield integer coordinates on the line from (x0, y0) to (x1, y1).
11 | Input coordinates should be integers.
12 |
13 | Examples:
14 |         >>> from sketch_nn.rasterize.bresenham import bresenham_algo
15 |         >>> list(bresenham_algo(-1, -4, 3, 2))
16 | [(-1, -4), (0, -3), (0, -2), (1, -1), (2, 0), (2, 1), (3, 2)]
17 |
18 | Args:
19 | x0: integer coordinates
20 | y0: integer coordinates
21 | x1: integer coordinates
22 | y1: integer coordinates
23 |
24 |     Yields:
25 |         Integer (x, y) points on the line, including both the start and the end point.
26 | """
27 | dx = x1 - x0
28 | dy = y1 - y0
29 |
30 | xsign = 1 if dx > 0 else -1
31 | ysign = 1 if dy > 0 else -1
32 |
33 | dx = abs(dx)
34 | dy = abs(dy)
35 |
36 | if dx > dy:
37 | xx, xy, yx, yy = xsign, 0, 0, ysign
38 | else:
39 | dx, dy = dy, dx
40 | xx, xy, yx, yy = 0, ysign, xsign, 0
41 |
42 | D = 2 * dy - dx
43 | y = 0
44 |
45 | for x in range(dx + 1):
46 | yield x0 + x * xx + y * yx, y0 + x * xy + y * yy
47 | if D >= 0:
48 | y += 1
49 | D -= 2 * dx
50 | D += 2 * dy
51 |
--------------------------------------------------------------------------------
/sketch_nn/rasterize/rasterize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
6 | import numpy as np
7 | import scipy.ndimage
8 |
9 | from .bresenham import bresenham_algo
10 |
11 |
12 | def get_stroke_num(vector_image):
13 | return len(np.split(vector_image[:, :2], np.where(vector_image[:, 2])[0] + 1, axis=0)[:-1])
14 |
15 |
16 | def select_strokes(vector_image, strokes):
17 | """
18 | select strokes
19 | Args:
20 | vector_image: vector_image(x,y,p) coordinate array
21 | strokes: after keeping only selected strokes
22 |
23 | Returns:
24 |
25 | """
26 | c = vector_image
27 | c_split = np.split(c[:, :2], np.where(c[:, 2])[0] + 1, axis=0)[:-1]
28 |
29 | c_selected = []
30 | for i in strokes:
31 | c_selected.append(c_split[i])
32 |
33 | xyp = []
34 | for i in c_selected:
35 | p = np.zeros((len(i), 1))
36 | p[-1] = 1
37 | xyp.append(np.hstack((i, p)))
38 | xyp = np.concatenate(xyp)
39 | return xyp
40 |
41 |
42 | def batch_points2png(vector_images, Side=256):
43 | for vector_image in vector_images:
44 | pixel_length = 0
45 | # number_of_samples = random
46 | sample_freq = list(np.round(np.linspace(0, len(vector_image), 18)[1:]))
47 | Sample_len = []
48 | raster_images = []
49 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
50 | initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
51 | for i in range(0, len(vector_image)):
52 | if i > 0:
53 | if vector_image[i - 1, 2] == 1:
54 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
55 |
56 | cordList = list(bresenham_algo(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
57 | pixel_length += len(cordList)
58 |
59 | for cord in cordList:
60 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
61 | raster_image[cord[1], cord[0]] = 255.0
62 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
63 |
64 | if i in sample_freq:
65 | raster_images.append(scipy.ndimage.binary_dilation(raster_image, iterations=2) * 255.0)
66 | Sample_len.append(pixel_length)
67 |
68 | raster_images.append(scipy.ndimage.binary_dilation(raster_image, iterations=3) * 255.0)
69 | Sample_len.append(pixel_length)
70 |
71 | return raster_images
72 |
73 |
74 | def points2png(vector_image, Side=256):
75 | raster_image = np.zeros((int(Side), int(Side)), dtype=np.float32)
76 | initX, initY = int(vector_image[0, 0]), int(vector_image[0, 1])
77 | pixel_length = 0
78 |
79 | for i in range(0, len(vector_image)):
80 | if i > 0:
81 | if vector_image[i - 1, 2] == 1:
82 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
83 |
84 | cordList = list(bresenham_algo(initX, initY, int(vector_image[i, 0]), int(vector_image[i, 1])))
85 | pixel_length += len(cordList)
86 |
87 | for cord in cordList:
88 | if (cord[0] > 0 and cord[1] > 0) and (cord[0] < Side and cord[1] < Side):
89 | raster_image[cord[1], cord[0]] = 255.0
90 | initX, initY = int(vector_image[i, 0]), int(vector_image[i, 1])
91 |
92 | raster_image = scipy.ndimage.binary_dilation(raster_image) * 255.0
93 | return raster_image
94 |
95 |
96 | def preprocess(sketch_points, side=256.0):
97 |     sketch_points = sketch_points.astype(np.float64)
98 | sketch_points[:, :2] = sketch_points[:, :2] / np.array([256, 256])
99 | sketch_points[:, :2] = sketch_points[:, :2] * side
100 | sketch_points = np.round(sketch_points)
101 | return sketch_points
102 |
103 |
104 | def sketch_vector_rasterize(sketch_points):
105 | sketch_points = preprocess(sketch_points)
106 | raster_images = points2png(sketch_points)
107 | return raster_images
108 |
109 |
110 | def convert_to_red(image):
111 | l = image.shape[1]
112 | image[1] = np.zeros((l, l))
113 | image[2] = np.zeros((l, l))
114 | return image
115 |
116 |
117 | def convert_to_green(image):
118 | l = image.shape[1]
119 | image[0] = np.zeros((l, l))
120 | image[2] = np.zeros((l, l))
121 | return image
122 |
123 |
124 | def convert_to_blue(image):
125 | l = image.shape[1]
126 | image[0] = np.zeros((l, l))
127 | image[1] = np.zeros((l, l))
128 | return image
129 |
130 |
131 | def convert_to_black(image):
132 | l = image.shape[1]
133 | image[0] = np.zeros((l, l))
134 | image[1] = np.zeros((l, l))
135 | image[2] = np.zeros((l, l))
136 | return image
137 |
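A minimal usage sketch with made-up coordinates, illustrating the (x, y, p) point format the rasterizer expects, where p == 1 marks the end of a stroke:

```
import numpy as np

from sketch_nn.rasterize import sketch_vector_rasterize

# Toy (x, y, p) sketch with two strokes; coordinates are assumed to lie in a
# 256x256 canvas, as in preprocess().
sketch = np.array([
    [10, 10, 0],
    [120, 40, 1],   # end of the first stroke
    [60, 200, 0],
    [200, 220, 1],  # end of the second stroke
], dtype=np.float64)

raster = sketch_vector_rasterize(sketch)  # (256, 256) array with strokes at 255.0
```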
--------------------------------------------------------------------------------
/style_transfer/AdaIN/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-AdaIN
2 |
3 | This is an unofficial PyTorch implementation of the paper Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization [Huang+, ICCV2017].
4 | I'm really grateful to the [original implementation](https://github.com/xunhuang1995/AdaIN-style) in Torch by the authors, which is very useful.
5 |
6 | 
7 |
8 | ## Requirements
9 | Please install requirements by `pip install -r requirements.txt`
10 |
11 | - Python 3.5+
12 | - PyTorch 0.4+
13 | - TorchVision
14 | - Pillow
15 |
16 | (optional, for training)
17 | - tqdm
18 | - TensorboardX
19 |
20 | ## Usage
21 |
22 | ### Download models
23 | Download [decoder.pth](https://drive.google.com/file/d/1bMfhMMwPeXnYSQI6cDWElSZxOxc6aVyr/view?usp=sharing)/[vgg_normalized.pth](https://drive.google.com/file/d/1EpkBA2K2eYILDSyPTt0fztz59UjAIpZU/view?usp=sharing) and put them under `models/`.
24 |
25 | ### Test
26 | Use `--content` and `--style` to provide the respective path to the content and style image.
27 | ```
28 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/cornell.jpg --style input/style/woman_with_hat_matisse.jpg
29 | ```
30 |
31 | You can also run the code on directories of content and style images using `--content_dir` and `--style_dir`. It will save every possible combination of content and styles to the output directory.
32 | ```
33 | CUDA_VISIBLE_DEVICES= python test.py --content_dir input/content --style_dir input/style
34 | ```
35 |
36 | This is an example of mixing four styles by specifying `--style` and `--style_interpolation_weights` option.
37 | ```
38 | CUDA_VISIBLE_DEVICES= python test.py --content input/content/avril.jpg --style input/style/picasso_self_portrait.jpg,input/style/impronte_d_artista.jpg,input/style/trial.jpg,input/style/antimonocromatismo.jpg --style_interpolation_weights 1,1,1,1 --content_size 512 --style_size 512 --crop
39 | ```
40 |
41 | Some other options:
42 | * `--content_size`: New (minimum) size for the content image. The original size is kept if set to 0.
43 | * `--style_size`: New (minimum) size for the style image. The original size is kept if set to 0.
44 | * `--alpha`: Adjust the degree of stylization. It should be a value between 0.0 and 1.0 (default).
45 | * `--preserve_color`: Preserve the color of the content image.
46 |
47 |
48 | ### Train
49 | Use `--content_dir` and `--style_dir` to provide the respective directory to the content and style images.
50 | ```
51 | CUDA_VISIBLE_DEVICES=<gpu_id> python train.py --content_dir <content_dir> --style_dir <style_dir>
52 | ```
53 |
54 | For more details and parameters, please refer to the `--help` option.
55 |
56 | I share the model trained by this code [here](https://drive.google.com/file/d/1YIBRdgGBoVllLhmz_N7PwfeP5V9Vz2Nr/view?usp=sharing)
57 |
58 | ## References
59 | - [1]: X. Huang and S. Belongie. "Arbitrary Style Transfer in Real-time with Adaptive Instance Normalization.", in ICCV, 2017.
60 | - [2]: [Original implementation in Torch](https://github.com/xunhuang1995/AdaIN-style)
61 |
--------------------------------------------------------------------------------
/style_transfer/AdaIN/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 | # URL: https://github.com/naoto0804/pytorch-AdaIN
6 |
7 | from .function import coral, adaptive_instance_normalization
8 |
9 | __all__ = ['coral', 'adaptive_instance_normalization']
10 |
--------------------------------------------------------------------------------
/style_transfer/AdaIN/function.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def calc_mean_std(feat, eps=1e-5):
5 | # eps is a small value added to the variance to avoid divide-by-zero.
6 | size = feat.size()
7 | assert (len(size) == 4)
8 | N, C = size[:2]
9 | feat_var = feat.view(N, C, -1).var(dim=2) + eps
10 | feat_std = feat_var.sqrt().view(N, C, 1, 1)
11 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
12 | return feat_mean, feat_std
13 |
14 |
15 | def adaptive_instance_normalization(content_feat, style_feat):
16 | assert (content_feat.size()[:2] == style_feat.size()[:2])
17 | size = content_feat.size()
18 | style_mean, style_std = calc_mean_std(style_feat)
19 | content_mean, content_std = calc_mean_std(content_feat)
20 |
21 | normalized_feat = (content_feat - content_mean.expand(
22 | size)) / content_std.expand(size)
23 | return normalized_feat * style_std.expand(size) + style_mean.expand(size)
24 |
25 |
26 | def _calc_feat_flatten_mean_std(feat):
27 | # takes 3D feat (C, H, W), return mean and std of array within channels
28 | assert (feat.size()[0] == 3)
29 | assert (isinstance(feat, torch.FloatTensor))
30 | feat_flatten = feat.view(3, -1)
31 | mean = feat_flatten.mean(dim=-1, keepdim=True)
32 | std = feat_flatten.std(dim=-1, keepdim=True)
33 | return feat_flatten, mean, std
34 |
35 |
36 | def _mat_sqrt(x):
37 | U, D, V = torch.svd(x)
38 | return torch.mm(torch.mm(U, D.pow(0.5).diag()), V.t())
39 |
40 |
41 | def coral(source, target):
42 | # assume both source and target are 3D array (C, H, W)
43 | # Note: flatten -> f
44 |
45 | source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
46 | source_f_norm = (source_f - source_f_mean.expand_as(
47 | source_f)) / source_f_std.expand_as(source_f)
48 | source_f_cov_eye = \
49 | torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)
50 |
51 | target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
52 | target_f_norm = (target_f - target_f_mean.expand_as(
53 | target_f)) / target_f_std.expand_as(target_f)
54 | target_f_cov_eye = \
55 | torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)
56 |
57 | source_f_norm_transfer = torch.mm(
58 | _mat_sqrt(target_f_cov_eye),
59 | torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
60 | source_f_norm)
61 | )
62 |
63 | source_f_transfer = source_f_norm_transfer * \
64 | target_f_std.expand_as(source_f_norm) + \
65 | target_f_mean.expand_as(source_f_norm)
66 |
67 | return source_f_transfer.view(source.size())
68 |
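To make the channel-wise statistics transfer concrete, a small sketch with random placeholder feature maps (not part of the repository):

```
import torch

from style_transfer.AdaIN.function import adaptive_instance_normalization, calc_mean_std

# Toy feature maps (N, C, H, W); the values are random placeholders.
content_feat = torch.randn(1, 512, 32, 32)
style_feat = torch.randn(1, 512, 32, 32)

# AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y), computed per channel.
out = adaptive_instance_normalization(content_feat, style_feat)

# The output keeps the content's spatial structure but (approximately) adopts
# the style's channel-wise mean and standard deviation.
out_mean, out_std = calc_mean_std(out)
style_mean, style_std = calc_mean_std(style_feat)
```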
--------------------------------------------------------------------------------
/style_transfer/AdaIN/net.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 | from .function import calc_mean_std, adaptive_instance_normalization as adain
4 |
5 | decoder = nn.Sequential(
6 | nn.ReflectionPad2d((1, 1, 1, 1)),
7 | nn.Conv2d(512, 256, (3, 3)),
8 | nn.ReLU(),
9 | nn.Upsample(scale_factor=2, mode='nearest'),
10 | nn.ReflectionPad2d((1, 1, 1, 1)),
11 | nn.Conv2d(256, 256, (3, 3)),
12 | nn.ReLU(),
13 | nn.ReflectionPad2d((1, 1, 1, 1)),
14 | nn.Conv2d(256, 256, (3, 3)),
15 | nn.ReLU(),
16 | nn.ReflectionPad2d((1, 1, 1, 1)),
17 | nn.Conv2d(256, 256, (3, 3)),
18 | nn.ReLU(),
19 | nn.ReflectionPad2d((1, 1, 1, 1)),
20 | nn.Conv2d(256, 128, (3, 3)),
21 | nn.ReLU(),
22 | nn.Upsample(scale_factor=2, mode='nearest'),
23 | nn.ReflectionPad2d((1, 1, 1, 1)),
24 | nn.Conv2d(128, 128, (3, 3)),
25 | nn.ReLU(),
26 | nn.ReflectionPad2d((1, 1, 1, 1)),
27 | nn.Conv2d(128, 64, (3, 3)),
28 | nn.ReLU(),
29 | nn.Upsample(scale_factor=2, mode='nearest'),
30 | nn.ReflectionPad2d((1, 1, 1, 1)),
31 | nn.Conv2d(64, 64, (3, 3)),
32 | nn.ReLU(),
33 | nn.ReflectionPad2d((1, 1, 1, 1)),
34 | nn.Conv2d(64, 3, (3, 3)),
35 | )
36 |
37 | vgg = nn.Sequential(
38 | nn.Conv2d(3, 3, (1, 1)),
39 | nn.ReflectionPad2d((1, 1, 1, 1)),
40 | nn.Conv2d(3, 64, (3, 3)),
41 | nn.ReLU(), # relu1-1
42 | nn.ReflectionPad2d((1, 1, 1, 1)),
43 | nn.Conv2d(64, 64, (3, 3)),
44 | nn.ReLU(), # relu1-2
45 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
46 | nn.ReflectionPad2d((1, 1, 1, 1)),
47 | nn.Conv2d(64, 128, (3, 3)),
48 | nn.ReLU(), # relu2-1
49 | nn.ReflectionPad2d((1, 1, 1, 1)),
50 | nn.Conv2d(128, 128, (3, 3)),
51 | nn.ReLU(), # relu2-2
52 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
53 | nn.ReflectionPad2d((1, 1, 1, 1)),
54 | nn.Conv2d(128, 256, (3, 3)),
55 | nn.ReLU(), # relu3-1
56 | nn.ReflectionPad2d((1, 1, 1, 1)),
57 | nn.Conv2d(256, 256, (3, 3)),
58 | nn.ReLU(), # relu3-2
59 | nn.ReflectionPad2d((1, 1, 1, 1)),
60 | nn.Conv2d(256, 256, (3, 3)),
61 | nn.ReLU(), # relu3-3
62 | nn.ReflectionPad2d((1, 1, 1, 1)),
63 | nn.Conv2d(256, 256, (3, 3)),
64 | nn.ReLU(), # relu3-4
65 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
66 | nn.ReflectionPad2d((1, 1, 1, 1)),
67 | nn.Conv2d(256, 512, (3, 3)),
68 | nn.ReLU(), # relu4-1, this is the last layer used
69 | nn.ReflectionPad2d((1, 1, 1, 1)),
70 | nn.Conv2d(512, 512, (3, 3)),
71 | nn.ReLU(), # relu4-2
72 | nn.ReflectionPad2d((1, 1, 1, 1)),
73 | nn.Conv2d(512, 512, (3, 3)),
74 | nn.ReLU(), # relu4-3
75 | nn.ReflectionPad2d((1, 1, 1, 1)),
76 | nn.Conv2d(512, 512, (3, 3)),
77 | nn.ReLU(), # relu4-4
78 | nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
79 | nn.ReflectionPad2d((1, 1, 1, 1)),
80 | nn.Conv2d(512, 512, (3, 3)),
81 | nn.ReLU(), # relu5-1
82 | nn.ReflectionPad2d((1, 1, 1, 1)),
83 | nn.Conv2d(512, 512, (3, 3)),
84 | nn.ReLU(), # relu5-2
85 | nn.ReflectionPad2d((1, 1, 1, 1)),
86 | nn.Conv2d(512, 512, (3, 3)),
87 | nn.ReLU(), # relu5-3
88 | nn.ReflectionPad2d((1, 1, 1, 1)),
89 | nn.Conv2d(512, 512, (3, 3)),
90 | nn.ReLU() # relu5-4
91 | )
92 |
93 |
94 | class Net(nn.Module):
95 | def __init__(self, encoder, decoder):
96 | super(Net, self).__init__()
97 | enc_layers = list(encoder.children())
98 | self.enc_1 = nn.Sequential(*enc_layers[:4]) # input -> relu1_1
99 | self.enc_2 = nn.Sequential(*enc_layers[4:11]) # relu1_1 -> relu2_1
100 | self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
101 | self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
102 | self.decoder = decoder
103 | self.mse_loss = nn.MSELoss()
104 |
105 | # fix the encoder
106 | for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4']:
107 | for param in getattr(self, name).parameters():
108 | param.requires_grad = False
109 |
110 | # extract relu1_1, relu2_1, relu3_1, relu4_1 from input image
111 | def encode_with_intermediate(self, input):
112 | results = [input]
113 | for i in range(4):
114 | func = getattr(self, 'enc_{:d}'.format(i + 1))
115 | results.append(func(results[-1]))
116 | return results[1:]
117 |
118 | # extract relu4_1 from input image
119 | def encode(self, input):
120 | for i in range(4):
121 | input = getattr(self, 'enc_{:d}'.format(i + 1))(input)
122 | return input
123 |
124 | def calc_content_loss(self, input, target):
125 | assert (input.size() == target.size())
126 | assert (target.requires_grad is False)
127 | return self.mse_loss(input, target)
128 |
129 | def calc_style_loss(self, input, target):
130 | assert (input.size() == target.size())
131 | assert (target.requires_grad is False)
132 | input_mean, input_std = calc_mean_std(input)
133 | target_mean, target_std = calc_mean_std(target)
134 | return self.mse_loss(input_mean, target_mean) + \
135 | self.mse_loss(input_std, target_std)
136 |
137 | def forward(self, content, style, alpha=1.0):
138 | assert 0 <= alpha <= 1
139 | style_feats = self.encode_with_intermediate(style)
140 | content_feat = self.encode(content)
141 | t = adain(content_feat, style_feats[-1])
142 | t = alpha * t + (1 - alpha) * content_feat
143 |
144 | g_t = self.decoder(t)
145 | g_t_feats = self.encode_with_intermediate(g_t)
146 |
147 | loss_c = self.calc_content_loss(g_t_feats[-1], t)
148 | loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
149 | for i in range(1, 4):
150 | loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
151 | return loss_c, loss_s
152 |
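For orientation, a minimal inference-time sketch mirroring the alpha blend in `Net.forward`; the image tensors are random placeholders and no pretrained weights (vgg_normalized.pth / decoder.pth) are loaded here, so the output is only structurally meaningful:

```
import torch

from style_transfer.AdaIN.net import Net, decoder, vgg
from style_transfer.AdaIN.function import adaptive_instance_normalization as adain

# Placeholders standing in for preprocessed content/style images (N, 3, H, W).
content = torch.randn(1, 3, 256, 256)
style = torch.randn(1, 3, 256, 256)
alpha = 0.8  # 1.0 = full stylization, 0.0 = reconstruct from content features only

net = Net(vgg, decoder).eval()
with torch.no_grad():
    content_feat = net.encode(content)
    t = adain(content_feat, net.encode(style))
    t = alpha * t + (1 - alpha) * content_feat
    stylized = net.decoder(t)  # (1, 3, 256, 256)
```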
--------------------------------------------------------------------------------
/style_transfer/STROTSS/README.md:
--------------------------------------------------------------------------------
1 | # PyTorch implementation of Style Transfer by Relaxed Optimal Transport and Self-Similarity (STROTSS) with improvements
2 |
3 | Implements [STROTSS](https://arxiv.org/abs/1904.12785) with sinkhorn EMD as introduced in the paper [Interactive Neural Style Transfer with artists](https://arxiv.org/pdf/2003.06659).
4 |
5 | This code is inspired by [the original implementation](https://github.com/nkolkin13/STROTSS) released by the authors of STROTSS.
6 |
7 |
8 | ## Dependencies:
9 | * python3 >= 3.6
10 | * pytorch >= 1.0
11 | * torchvision >= 0.4
12 | * imageio >= 2.2
13 | * numpy >= 1.1
14 |
15 | ## Usage:
16 |
17 | * standard
18 | ```
19 | python test.py -c images/content_im.jpg -s images/style_im.jpg
20 | ```
21 | * Sinkhorn earth mover's distance
22 | ```
23 | python test.py -c images/content_im.jpg -s images/style_im.jpg --use_sinkhorn
24 | ```
25 | * guidance masks
26 | ```
27 | python test.py -c images/content_im.jpg -s images/style_im.jpg --content_guidance images/content_guidance.jpg --style_guidance images/style_guidance
28 | ```
29 | General usage
30 | ```
31 | python test.py
32 | --content CONTENT
33 | --style STYLE
34 | [--output OUTPUT]
35 | [--content_weight CONTENT_WEIGHT]
36 | [--max_scale MAX_SCALE]
37 | [--seed SEED]
38 | [--content_guidance CONTENT_GUIDANCE]
39 | [--style_guidance STYLE_GUIDANCE]
40 | [--print_freq PRINT_FREQ]
41 | [--use_sinkhorn]
42 | [--sinkhorn_reg SINKHORN_REG]
43 | [--sinkhorn_maxiter SINKHORN_MAXITER]
44 | ```
45 |
46 | ## Citation
47 |
48 | If you use this code, please cite [the original STROTSS paper](https://arxiv.org/abs/1904.12785) and
49 | ```
50 | @article{kerdreux2020interactive,
51 | title={Interactive Neural Style Transfer with Artists},
52 | author={Kerdreux, Thomas and Thiry, Louis and Kerdreux, Erwan},
53 | journal={arXiv preprint arXiv:2003.06659},
54 | year={2020}
55 | }
56 | ```
57 |
--------------------------------------------------------------------------------
/style_transfer/STROTSS/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------
/style_transfer/STROTSS/style_transfer.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import time
3 |
4 | import imageio
5 | import torch
6 |
7 | from . import utils
8 | from . import vgg_pt
9 | from . import loss_utils
10 |
11 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
12 |
13 |
14 | def style_transfer(stylized_im, content_im, style_path, output_path,
15 | long_side, content_weight, content_regions, style_regions,
16 | lr, print_freq=100, max_iter=250,
17 | resample_freq=1, optimize_laplacian_pyramid=True,
18 | use_sinkhorn=False, sinkhorn_reg=0.1, sinkhorn_maxiter=30):
19 | cnn = vgg_pt.Vgg16_pt().to(device)
20 |
21 | phi = lambda x: cnn.forward(x)
22 | phi2 = lambda x, y, z: cnn.forward_cat(x, z, samps=y, forward_func=cnn.forward)
23 |
24 | if optimize_laplacian_pyramid:
25 | laplacian_pyramid = utils.create_laplacian_pyramid(stylized_im, pyramid_depth=5)
26 | parameters = [torch.nn.Parameter(li.data, requires_grad=True) for li in laplacian_pyramid]
27 | else:
28 | parameters = [torch.nn.Parameter(stylized_im.data, requires_grad=True)]
29 |
30 | optimizer = torch.optim.RMSprop(parameters, lr=lr)
31 |
32 | content_im_cnn_features = cnn(content_im)
33 |
34 | style_image_paths = glob.glob(style_path + '*')[::3]
35 |
36 | strotss_loss = loss_utils.RelaxedOptimalTransportSelfSimilarityLoss(
37 | use_sinkhorn=use_sinkhorn, sinkhorn_reg=sinkhorn_reg, sinkhorn_maxiter=sinkhorn_maxiter)
38 |
39 | style_features = []
40 | for style_region in style_regions:
41 | style_features.append(utils.load_style_features(phi2, style_image_paths, style_region,
42 | subsamps=1000, scale=long_side, inner=5))
43 |
44 | if optimize_laplacian_pyramid:
45 | stylized_im = utils.synthetize_image_from_laplacian_pyramid(parameters)
46 | else:
47 | stylized_im = parameters[0]
48 |
49 | resized_content_regions = []
50 | for content_region in content_regions:
51 | resized_content_region = utils.resize(torch.from_numpy(content_region),
52 | (stylized_im.size(3), stylized_im.size(2)), mode='nearest').numpy()
53 | resized_content_regions.append(resized_content_region.astype('bool'))
54 |
55 | for i in range(max_iter):
56 | if i == 200:
57 | optimizer = torch.optim.RMSprop(parameters, lr=0.1 * lr)
58 |
59 | optimizer.zero_grad()
60 | if optimize_laplacian_pyramid:
61 | stylized_im = utils.synthetize_image_from_laplacian_pyramid(parameters)
62 | else:
63 | stylized_im = parameters[0]
64 |
65 | if i == 0 or i % (resample_freq * 10) == 0:
66 | for i_region, resized_content_region in enumerate(resized_content_regions):
67 | strotss_loss.init_inds(content_im_cnn_features, style_features[i_region], resized_content_region,
68 | i_region)
69 |
70 | if i == 0 or i % resample_freq == 0:
71 | strotss_loss.shuffle_feature_inds()
72 |
73 | stylized_im_cnn_features = cnn(stylized_im)
74 |
75 | loss = strotss_loss.eval(stylized_im_cnn_features,
76 | content_im_cnn_features, style_features,
77 | content_weight=content_weight, moment_weight=1.0)
78 |
79 | loss.backward()
80 | optimizer.step()
81 |
82 | if i % print_freq == 0:
83 | print(f'step {i}/{max_iter}, loss {loss.item():.6f}')
84 |
85 | return stylized_im, loss
86 |
87 |
88 | def run_style_transfer(content_path, style_path, content_weight, max_scale, content_regions, style_regions,
89 | output_path='./output.png', print_freq=100, use_sinkhorn=False, sinkhorn_reg=0.1,
90 | sinkhorn_maxiter=30):
91 | smallest_size = 64
92 | start = time.time()
93 |
94 | content_image, style_image = utils.load_img(content_path), utils.load_img(style_path)
95 | _, content_H, content_W = content_image.size()
96 | _, style_H, style_W = style_image.size()
97 | print(f'content image size {content_H}x{content_W}, style image size {style_H}x{style_W}')
98 |
99 | for scale in range(1, max_scale + 1):
100 | t0 = time.time()
101 |
102 | scaled_size = smallest_size * (2 ** (scale - 1))
103 |
104 | print('Processing scale {}/{}, size {}...'.format(scale, max_scale, scaled_size))
105 |
106 | content_scaled_size = (int(content_H * scaled_size / content_W), scaled_size) if content_H < content_W else (
107 | scaled_size, int(content_W * scaled_size / content_H))
108 | content_image_scaled = utils.resize(content_image.unsqueeze(0), content_scaled_size).to(device)
109 | bottom_laplacian = content_image_scaled - utils.resize(utils.downsample(content_image_scaled),
110 | content_scaled_size)
111 |
112 | lr = 2e-3
113 | if scale == 1:
114 | style_image_mean = style_image.unsqueeze(0).mean(dim=(2, 3), keepdim=True).to(device)
115 | stylized_im = style_image_mean + bottom_laplacian
116 | elif scale > 1 and scale < max_scale:
117 | stylized_im = utils.resize(stylized_im.clone(), content_scaled_size) + bottom_laplacian
118 | elif scale == max_scale:
119 | stylized_im = utils.resize(stylized_im.clone(), content_scaled_size)
120 | lr = 1e-3
121 |
122 | stylized_im, final_loss = style_transfer(stylized_im, content_image_scaled, style_path, output_path,
123 | scaled_size, content_weight, content_regions, style_regions, lr,
124 | print_freq=print_freq, use_sinkhorn=use_sinkhorn,
125 | sinkhorn_reg=sinkhorn_reg, sinkhorn_maxiter=sinkhorn_maxiter)
126 |
127 | content_weight /= 2.0
128 | print('...done in {:.1f} sec, final loss {:.4f}'.format(time.time() - t0, final_loss.item()))
129 |
130 | print('Finished in {:.1f} secs'.format(time.time() - start))
131 |
132 | canvas = torch.clamp(stylized_im[0], -0.5, 0.5).data.cpu().numpy().transpose(1, 2, 0)
133 | print(f'Saving output to {output_path}.')
134 | imageio.imwrite(output_path, canvas)
135 |
136 | return final_loss, stylized_im
137 |
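As a rough illustration of the coarse-to-fine schedule implemented by `run_style_transfer` above: the working resolution's long side at scale `s` is `smallest_size * 2 ** (s - 1)` pixels, the content weight is halved after every scale, and the final scale switches to a smaller learning rate. A minimal sketch of that schedule (plain Python, assuming the default `max_scale=4` and the CLI default `content_weight=0.5`, which `test.py` multiplies by 16):
```
# Sketch of the per-scale schedule used by run_style_transfer; not part of the library code.
smallest_size = 64
content_weight = 16 * 0.5                              # test.py scales the CLI value (default 0.5) by 16
max_scale = 4
for scale in range(1, max_scale + 1):
    scaled_size = smallest_size * (2 ** (scale - 1))   # long side of the working resolution
    lr = 1e-3 if scale == max_scale else 2e-3          # the last scale uses a smaller learning rate
    print(f'scale {scale}: long side {scaled_size}px, content_weight {content_weight}, lr {lr}')
    content_weight /= 2.0                              # halved after each scale, as in run_style_transfer
# scale 1: long side 64px, content_weight 8.0, lr 0.002
# scale 2: long side 128px, content_weight 4.0, lr 0.002
# scale 3: long side 256px, content_weight 2.0, lr 0.002
# scale 4: long side 512px, content_weight 1.0, lr 0.001
```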
--------------------------------------------------------------------------------
/style_transfer/STROTSS/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | import imageio
4 | import numpy as np
5 | import torch
6 |
7 | from .style_transfer import run_style_transfer
8 | from .utils import extract_regions
9 |
10 | if __name__ == '__main__':
11 | parser = argparse.ArgumentParser('Style transfer by relaxed optimal transport with sinkhorn distance')
12 | parser.add_argument('--content', '-c', help="path of content img", required=True)
13 | parser.add_argument('--style', '-s', help="path of style img", required=True)
14 | parser.add_argument('--output', '-o', help="path of output img", default='output.png')
15 | parser.add_argument('--content_weight', type=float, help='content weight for the style transfer', default=0.5)
16 | parser.add_argument('--max_scale', type=int, help='max scale for the style transfer', default=4)
17 | parser.add_argument('--seed', type=int, help='random seed', default=0)
18 | parser.add_argument('--content_guidance', default='', help="path of content guidance region image")
19 | parser.add_argument('--style_guidance', default='', help="path of style guidance regions image")
20 | parser.add_argument('--print_freq', type=int, default=100, help='print frequency for the loss')
21 | parser.add_argument('--use_sinkhorn', action='store_true', help='use the Sinkhorn algorithm for the earth mover distance')
22 | parser.add_argument('--sinkhorn_reg', type=float, help='regularization parameter for the Sinkhorn algorithm', default=0.1)
23 | parser.add_argument('--sinkhorn_maxiter', type=int, default=30, help='number of iterations for the Sinkhorn algorithm')
24 |
25 | args = parser.parse_args()
26 |
27 | torch.manual_seed(args.seed)
28 | np.random.seed(args.seed)
29 | content_weight = 16 * args.content_weight
30 | max_scale = args.max_scale
31 | use_guidance_region = args.content_guidance and args.style_guidance
32 |
33 | if use_guidance_region:
34 | content_regions, style_regions = extract_regions(args.content_guidance, args.style_guidance)
35 | else:
36 | content_img, style_img = imageio.imread(args.content), imageio.imread(args.style)
37 | content_regions = [np.ones(content_img.shape[:2], dtype=np.float32)]
38 | style_regions = [np.ones(style_img.shape[:2], dtype=np.float32)]
39 |
40 | loss, canvas = run_style_transfer(args.content, args.style, content_weight,
41 | max_scale, content_regions, style_regions, args.output,
42 | print_freq=args.print_freq, use_sinkhorn=args.use_sinkhorn,
43 | sinkhorn_reg=args.sinkhorn_reg,
44 | sinkhorn_maxiter=args.sinkhorn_maxiter)
45 |
--------------------------------------------------------------------------------
/style_transfer/STROTSS/utils.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 |
3 | import imageio
4 | import numpy as np
5 | import torch
6 | import torch.nn.functional as F
7 | import torchvision
8 |
9 | device = 'cuda' if torch.cuda.is_available() else 'cpu'
10 |
11 |
12 | def downsample(img, factor=2, mode='bilinear'):
13 | img_H, img_W = img.size(2), img.size(3)
14 | return F.interpolate(img, (max(img_H // factor, 1), max(img_W // factor, 1)), mode=mode)
15 |
16 |
17 | def resize(img, size, mode='bilinear'):
18 | if len(img.shape) == 2:
19 | return F.interpolate(img.unsqueeze(0).unsqueeze(0), size, mode=mode)[0, 0]
20 | elif len(img.shape) == 3:
21 | return F.interpolate(img.unsqueeze(0), size, mode=mode)[0]
22 | return F.interpolate(img, size, mode=mode)
23 |
24 |
25 | def load_img(img_path, size=None):
26 | img = torchvision.transforms.functional.to_tensor(Image.open(img_path).convert('RGB')) - 0.5
27 | if size is None:
28 | return img
29 | elif isinstance(size, (int, float)):
30 | return F.interpolate(img.unsqueeze(0), scale_factor=size / img.size(1), mode='bilinear')[0]
31 | else:
32 | return F.interpolate(img.unsqueeze(0), size, mode='bilinear')[0]
33 |
34 |
35 | def create_laplacian_pyramid(image, pyramid_depth):
36 | laplacian_pyramid = []
37 | current_image = image
38 | for i in range(pyramid_depth):
39 | laplacian_pyramid.append(current_image - resize(downsample(current_image), current_image.shape[2:4]))
40 | current_image = downsample(current_image)
41 | laplacian_pyramid.append(current_image)
42 |
43 | return laplacian_pyramid
44 |
45 |
46 | def synthetize_image_from_laplacian_pyramid(laplacian_pyramid):
47 | current_image = laplacian_pyramid[-1]
48 | for i in range(len(laplacian_pyramid) - 2, -1, -1):
49 | up_x = laplacian_pyramid[i].size(2)
50 | up_y = laplacian_pyramid[i].size(3)
51 | current_image = laplacian_pyramid[i] + resize(current_image, (up_x, up_y))
52 |
53 | return current_image
54 |
55 |
56 | YUV_transform = torch.from_numpy(np.float32([
57 | [0.577350, 0.577350, 0.577350],
58 | [-0.577350, 0.788675, -0.211325],
59 | [-0.577350, -0.211325, 0.788675]
60 | ])).to(device)
61 |
62 |
63 | def rgb_to_yuv(rgb):
64 | global YUV_transform
65 | return torch.mm(YUV_transform, rgb)
66 |
67 |
68 | def extract_regions(content_path, style_path, min_count=10000):
69 | style_guidance_img = imageio.imread(style_path).transpose(1, 0, 2)
70 | content_guidance_img = imageio.imread(content_path).transpose(1, 0, 2)
71 |
72 | color_codes, color_counts = np.unique(style_guidance_img.reshape(-1, style_guidance_img.shape[2]), axis=0,
73 | return_counts=True)
74 |
75 | color_codes = color_codes[color_counts > min_count]
76 |
77 | content_regions = []
78 | style_regions = []
79 |
80 | for color_code in color_codes:
81 | color_code = color_code[np.newaxis, np.newaxis, :]
82 |
83 | style_regions.append((np.abs(style_guidance_img - color_code).sum(axis=2) == 0).astype(np.float32))
84 | content_regions.append((np.abs(content_guidance_img - color_code).sum(axis=2) == 0).astype(np.float32))
85 |
86 | return [content_regions, style_regions]
87 |
88 |
89 | def load_style_features(features_extractor, paths, style_region, subsamps=-1, scale=-1, inner=1):
90 | features = []
91 |
92 | for p in paths:
93 | style_im = load_img(p, size=scale).unsqueeze(0).to(device)
94 |
95 | r = resize(torch.from_numpy(style_region), (style_im.size(3), style_im.size(2))).numpy()
96 |
97 | # 'inner' repeats the random feature subsampling and concatenates each draw along the sample axis (dim 2)
98 | for j in range(inner):
99 | with torch.no_grad():
100 | features_j = features_extractor(style_im, subsamps, r)
101 |
102 | features_j = [feat_j.view(feat_j.size(0), feat_j.size(1), -1, 1) for feat_j in features_j]
103 |
104 | if len(features) == 0:
105 | features = features_j
106 | else:
107 | features = [torch.cat([features_j[i], features[i]], 2) for i in range(len(features))]
108 |
109 | return features
110 |
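A quick way to check that `create_laplacian_pyramid` and `synthetize_image_from_laplacian_pyramid` above are inverse operations is a round-trip test. This is a minimal sketch, assuming it is run from the `STROTSS` directory so that `utils` imports directly, with a random tensor standing in for an image in the same `(N, C, H, W)`, `[-0.5, 0.5]` convention used by `load_img`:
```
# Round-trip sketch for the Laplacian pyramid helpers; not part of the library code.
import torch

from utils import create_laplacian_pyramid, synthetize_image_from_laplacian_pyramid

x = torch.rand(1, 3, 256, 192) - 0.5                     # stand-in for a preprocessed image batch
pyramid = create_laplacian_pyramid(x, pyramid_depth=5)
print(len(pyramid))                                      # 6: five detail bands plus the low-pass residual
x_rec = synthetize_image_from_laplacian_pyramid(pyramid)
print(torch.allclose(x, x_rec, atol=1e-5))               # True: reconstruction matches up to float error
```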
--------------------------------------------------------------------------------
/style_transfer/STROTSS/vgg_pt.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | import torchvision
6 |
7 |
8 | class Vgg16_pt(torch.nn.Module):
9 | def __init__(self, requires_grad=False, use_random=True):
10 | super(Vgg16_pt, self).__init__()
11 | # load pretrained model
12 | self.vgg_layers = torchvision.models.vgg16(
13 | weights=torchvision.models.VGG16_Weights.DEFAULT
14 | ).features
15 | self.use_random = use_random
16 |
17 | if not requires_grad:
18 | for param in self.parameters():
19 | param.requires_grad = False
20 |
21 | self.layer_indices = [1, 3, 6, 8, 11, 13, 15, 22, 29]
22 | self.inds = range(len(self.layer_indices) + 1)  # forward_base returns the input plus one feature map per listed layer
23 |
24 | def forward_base(self, X):
25 | l2 = [X]
26 | x = X
27 | for i in range(30):
28 | x = self.vgg_layers[i].forward(x)
29 | if i in self.layer_indices:
30 | l2.append(x)
31 |
32 | return l2
33 |
34 | def forward(self, X):
35 | return self.forward_base(X)
36 |
37 | def forward_cat(self, X, r, samps=100, forward_func=None):
38 |
39 | if not forward_func:
40 | forward_func = self.forward
41 |
42 | x = X
43 | out2 = forward_func(X)
44 |
45 | # the guidance mask may arrive with a trailing
46 | # channel axis; keep a single channel if so
47 | if r.ndim == 3:
48 |     r = r[:, :, 0]
49 |
50 | if r.max() < 0.1:
51 | region_mask = np.greater(r.flatten() + 1., 0.5)
52 | else:
53 | region_mask = np.greater(r.flatten(), 0.5)
54 |
55 | xx, xy = np.meshgrid(np.array(range(x.size(2))), np.array(range(x.size(3))))
56 | xx = np.expand_dims(xx.flatten(), 1)
57 | xy = np.expand_dims(xy.flatten(), 1)
58 | xc = np.concatenate([xx, xy], 1)
59 | xc = xc[region_mask, :]
60 |
61 | const2 = min(samps, xc.shape[0])
62 |
63 | if self.use_random:
64 | np.random.shuffle(xc)
65 | else:
66 | xc = xc[::(xc.shape[0] // const2), :]
67 |
68 | xx = xc[:const2, 0]
69 | yy = xc[:const2, 1]
70 |
71 | temp = X
72 | temp_list = [temp[:, :, xx[j], yy[j]].unsqueeze(2).unsqueeze(3) for j in range(const2)]
73 | temp = torch.cat(temp_list, 2)
74 |
75 | l2 = []
76 | for i in range(len(out2)):
77 |
78 | temp = out2[i]
79 |
80 | if i > 0 and out2[i].size(2) < out2[i - 1].size(2):
81 | xx = xx / 2.0
82 | yy = yy / 2.0
83 |
84 | xx = np.clip(xx, 0, temp.size(2) - 1).astype(np.int32)
85 | yy = np.clip(yy, 0, temp.size(3) - 1).astype(np.int32)
86 |
87 | temp_list = [temp[:, :, xx[j], yy[j]].unsqueeze(2).unsqueeze(3) for j in range(const2)]
88 | temp = torch.cat(temp_list, 2)
89 |
90 | l2.append(temp.clone().detach())
91 |
92 | out2 = [torch.cat([li.contiguous() for li in l2], 1)]
93 |
94 | return out2
95 |
96 | def forward_diff(self, X):
97 | l2 = self.forward_base(X)
98 |
99 | out2 = [l2[i].contiguous() for i in self.inds]
100 |
101 | for i in range(len(out2)):
102 | temp = out2[i]
103 | temp2 = F.pad(temp, (2, 2, 0, 0), value=1.)
104 | temp3 = F.pad(temp, (0, 0, 2, 2), value=1.)
105 | out2[i] = torch.cat(
106 | [temp, temp2[:, :, :, 4:], temp2[:, :, :, :-4], temp3[:, :, 4:, :], temp3[:, :, :-4, :]], 1)
107 |
108 | return out2
109 |
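The core idea of `forward_cat` above is to draw a fixed set of pixel coordinates inside the guidance region and reuse those coordinates, halved whenever the feature map resolution drops, to gather features from every VGG layer at the same spatial locations. Below is a simplified sketch of that coordinate handling (NumPy only; it uses `np.argwhere` instead of the meshgrid construction in the method, and the mask shape and sample count are made up for illustration):
```
# Simplified sketch of the spatial subsampling in Vgg16_pt.forward_cat; not the method itself.
import numpy as np

H, W, samps = 8, 8, 5
region = np.zeros((H, W), dtype=np.float32)
region[2:6, 2:6] = 1.0                         # toy binary guidance mask

coords = np.argwhere(region > 0.5)             # (row, col) pairs inside the region
np.random.shuffle(coords)
coords = coords[:min(samps, len(coords))]      # keep at most `samps` sampled locations

# A layer whose feature map is 2x smaller reuses the same locations with halved indices.
coarse = np.clip(coords // 2, 0, H // 2 - 1)
print(coords)
print(coarse)
```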
--------------------------------------------------------------------------------
/style_transfer/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # Copyright (c) XiMing Xing. All rights reserved.
3 | # Author: XiMing Xing
4 | # Description:
5 |
--------------------------------------------------------------------------------