├── .gitignore ├── Dockerfile ├── README.md ├── datasets ├── __init__.py ├── base_dataset.py ├── data_utils.py ├── texture_dataset.py └── warp_dataset.py ├── environment.yml ├── inference.py ├── media ├── diagram.png ├── example.png ├── texture_custom_data_example.png ├── texture_train_example.png └── warp_train_example.png ├── models ├── README.md ├── __init__.py ├── base_gan.py ├── base_model.py ├── pix2pix_model.py ├── texture_model.py └── warp_model.py ├── modules ├── __init__.py ├── discriminators.py ├── layers.py ├── loss.py ├── losses │ ├── __init__.py │ ├── adversarial.py │ └── perceptual.py ├── pix2pix_modules.py └── swapnet_modules.py ├── optimizers └── __init__.py ├── options ├── base_options.py ├── test_options.py └── train_options.py ├── test ├── Test TextureDataset Draw ROIs.ipynb └── Test TextureDataset.ipynb ├── train.py └── util ├── __init__.py ├── calculate_imagedir_stats.py ├── decode_labels.py ├── draw_rois.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Project 2 | # machine learning models 3 | checkpoints 4 | # actual data for datasets 5 | data 6 | results 7 | 8 | 9 | .idea 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .coverage 51 | .coverage.* 52 | .cache 53 | nosetests.xml 54 | coverage.xml 55 | *.cover 56 | .hypothesis/ 57 | .pytest_cache/ 58 | 59 | # Translations 60 | *.mo 61 | *.pot 62 | 63 | # Django stuff: 64 | *.log 65 | local_settings.py 66 | db.sqlite3 67 | 68 | # Flask stuff: 69 | instance/ 70 | .webassets-cache 71 | 72 | # Scrapy stuff: 73 | .scrapy 74 | 75 | # Sphinx documentation 76 | docs/_build/ 77 | 78 | # PyBuilder 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # celery beat schedule file 88 | celerybeat-schedule 89 | 90 | # SageMath parsed files 91 | *.sage.py 92 | 93 | # Environments 94 | .env 95 | .venv 96 | env/ 97 | venv/ 98 | ENV/ 99 | env.bak/ 100 | venv.bak/ 101 | 102 | # Spyder project settings 103 | .spyderproject 104 | .spyproject 105 | 106 | # Rope project settings 107 | .ropeproject 108 | 109 | # mkdocs documentation 110 | /site 111 | 112 | # mypy 113 | .mypy_cache/ 114 | -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 2 | 3 | RUN apt-get update 4 | 5 | RUN apt-get install -y \ 6 | build-essential wget 7 | 8 | RUN apt-get install -y git 9 | RUN apt-get install -y curl 10 | 11 | WORKDIR /app/ 12 | 13 | RUN echo "Installing and creating Miniconda environment..." 
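# Illustrative build/run commands (a sketch; the "swapnet" image tag is arbitrary).
# GPU access via the NVIDIA Container Toolkit is needed both at build time (for the
# ROI compilation below) and at run time:
#   docker build -t swapnet .
#   docker run --gpus all -it swapnet /bin/bash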
14 | # Install Miniconda 15 | RUN curl -so /miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh \ 16 | && chmod +x /miniconda.sh \ 17 | && /miniconda.sh -b -p /miniconda \ 18 | && rm /miniconda.sh \ 19 | && echo ". /miniconda/etc/profile.d/conda.sh" >> ~/.bashrc 20 | 21 | ENV PATH=/miniconda/bin:$PATH 22 | 23 | RUN mkdir SwapNet 24 | 25 | # Create the environment, set to activate automatically 26 | RUN cd SwapNet && \ 27 | wget https://raw.githubusercontent.com/andrewjong/SwapNet/master/environment.yml \ 28 | && conda env create \ 29 | && echo "source activate swapnet" >> ~/.bashrc 30 | 31 | # Checking environment, required for ROI to build properly 32 | # this command should display gpu properties 33 | RUN /bin/bash -c "nvidia-smi || echo 'nvidia-smi failed. A GPU is necessary to properly compile the ROI dependency. Make sure you have the NVIDIA Container Toolkit installed and enabled-by-default by editing /etc/docker/daemon.json'" 34 | 35 | # This should print true 36 | RUN /bin/bash -c "source activate swapnet && python -c 'import torch; print(torch.cuda.is_available())'" 37 | 38 | # CUDA Home should not be none 39 | RUN /bin/bash -c "source activate swapnet && python -c 'import torch;from torch.utils.cpp_extension import CUDA_HOME; print(CUDA_HOME)'" 40 | 41 | # ROI Dependency 42 | RUN echo "Compiling ROI dependency..." 43 | 44 | RUN git clone https://github.com/jwyang/faster-rcnn.pytorch.git # clone to a SEPARATE project directory 45 | 46 | RUN /bin/bash -c "source activate swapnet && cd faster-rcnn.pytorch && git checkout pytorch-1.0 && pip install -r requirements.txt" 47 | 48 | RUN /bin/bash -c "source activate swapnet && cd faster-rcnn.pytorch/lib/pycocotools && wget https://raw.githubusercontent.com/muaz-urwa/temp_files/master/setup.py && python setup.py build_ext --inplace" 49 | 50 | RUN /bin/bash -c "source activate swapnet && cd faster-rcnn.pytorch/lib && python setup.py build develop" 51 | 52 | RUN /bin/bash -c "source activate swapnet && ln -s /app/faster-rcnn.pytorch/lib /app/SwapNet/lib" 53 | 54 | RUN /bin/bash -c "source activate swapnet && conda install seaborn" 55 | 56 | RUN echo "Done!" 57 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | from torchvision.transforms import transforms 5 | 6 | from datasets.base_dataset import BaseDataset 7 | 8 | 9 | def find_dataset_using_name(dataset_name): 10 | """Import the module "data/[dataset_name]_dataset.py". 11 | 12 | In the file, the class called DatasetNameDataset() will 13 | be instantiated. It has to be a subclass of BaseDataset, 14 | and it is case-insensitive. 15 | """ 16 | dataset_filename = "datasets." + dataset_name + "_dataset" 17 | datasetlib = importlib.import_module(dataset_filename) 18 | 19 | dataset = None 20 | target_dataset_name = dataset_name.replace("_", "") + "dataset" 21 | for name, cls in datasetlib.__dict__.items(): 22 | if name.lower() == target_dataset_name.lower() and issubclass(cls, BaseDataset): 23 | dataset = cls 24 | 25 | if dataset is None: 26 | raise NotImplementedError( 27 | f"In {dataset_filename}.py, there should be a subclass of BaseDataset " 28 | f"with class name that matches {target_dataset_name} in lowercase." 
29 | ) 30 | 31 | return dataset 32 | 33 | def get_options_modifier(dataset_name): 34 | """Return the static method of the dataset class.""" 35 | dataset_class = find_dataset_using_name(dataset_name) 36 | return dataset_class.modify_commandline_options 37 | 38 | 39 | def create_dataset(opt, **ds_kwargs): 40 | """Create a dataset given the option. 41 | 42 | This function wraps the class CappedDataLoader. 43 | This is the main interface between this package and 'train.py'/'test.py' 44 | 45 | Example: 46 | >>> from datasets import create_dataset 47 | >>> dataset = create_dataset(opt) 48 | """ 49 | data_loader = CappedDataLoader(opt, **ds_kwargs) 50 | return data_loader 51 | 52 | 53 | class CappedDataLoader: 54 | """Wrapper class of Dataset class that caps the data limit at the specified 55 | max_dataset_size """ 56 | 57 | def __init__(self, opt, **ds_kwargs): 58 | """Initialize this class 59 | 60 | Step 1: create a dataset instance given the name [dataset_mode] 61 | Step 2: create a multi-threaded data loader. 62 | """ 63 | self.opt = opt 64 | dname = opt.dataset if opt.dataset else opt.model 65 | print(f"Creating dataset {dname}...", end=" ") 66 | dataset_class = find_dataset_using_name(dname) 67 | self.dataset = dataset_class(opt, **ds_kwargs) 68 | print(f"dataset [{type(self.dataset).__name__}] was created") 69 | self.dataloader = torch.utils.data.DataLoader( 70 | self.dataset, 71 | batch_size=opt.batch_size, 72 | shuffle=opt.shuffle_data, 73 | num_workers=opt.num_workers, 74 | ) 75 | 76 | def __len__(self): 77 | """Return the number of data in the dataset""" 78 | return min(len(self.dataset), self.opt.max_dataset_size) 79 | 80 | def __iter__(self): 81 | """Return a batch of data""" 82 | for i, data in enumerate(self.dataloader): 83 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 84 | break 85 | yield data 86 | 87 | 88 | def get_transforms(opt): 89 | """ 90 | Return Composed torchvision transforms based on specified arguments. 91 | """ 92 | transforms_list = [] 93 | if "none" in opt.input_transforms: 94 | return 95 | every = "all" in opt.input_transforms 96 | 97 | if every or "vflip" in opt.input_transforms: 98 | transforms_list.append(transforms.RandomVerticalFlip()) 99 | if every or "hflip" in opt.input_transforms: 100 | transforms_list.append(transforms.RandomHorizontalFlip()) 101 | if every or "affine" in opt.input_transforms: 102 | transforms_list.append( 103 | transforms.RandomAffine( 104 | degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=20 105 | ) 106 | ) 107 | if every or "perspective" in opt.input_transforms: 108 | transforms_list.append(transforms.RandomPerspective()) 109 | 110 | return transforms.RandomOrder(transforms_list) 111 | -------------------------------------------------------------------------------- /datasets/base_dataset.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch.utils.data as data 4 | 5 | 6 | class BaseDataset(data.Dataset, ABC): 7 | """This class is an abstract base class (ABC) for datasets. 8 | To create a subclass, you need to implement the following four functions: 9 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 10 | -- <__len__>: return the size of dataset. 11 | -- <__getitem__>: get a data point. 12 | -- : (optionally) add dataset-specific options and set default options. 
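    A minimal subclass sketch (illustrative only; the class name, file listing, and
    returned keys are hypothetical):

        class ExampleDataset(BaseDataset):
            def __init__(self, opt):
                super().__init__(opt)
                self.files = sorted(os.listdir(self.root))  # assumes `import os`

            def __len__(self):
                return len(self.files)

            def __getitem__(self, index):
                return {"image_paths": os.path.join(self.root, self.files[index])}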
13 | """ 14 | 15 | def __init__(self, opt): 16 | """Initialize the class; save the options in the class 17 | Parameters: 18 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 19 | """ 20 | self.opt = opt 21 | self.root = opt.dataroot 22 | self.crop_bounds = self.parse_crop_bounds() 23 | self.is_train = opt.is_train 24 | 25 | @staticmethod 26 | def modify_commandline_options(parser, is_train): 27 | """Add new dataset-specific options, and rewrite default values for existing options. 28 | Parameters: 29 | parser -- original option parser 30 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 31 | Returns: 32 | the modified parser. 33 | """ 34 | return parser 35 | 36 | @abstractmethod 37 | def __len__(self): 38 | """Return the total number of images in the dataset.""" 39 | return 0 40 | 41 | @abstractmethod 42 | def __getitem__(self, index): 43 | """Return a data point and its metadata information. 44 | Parameters: 45 | index - - a random integer for data indexing 46 | Returns: 47 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 48 | """ 49 | pass 50 | 51 | def parse_crop_bounds(self): 52 | if isinstance(self.opt.crop_size, int) and self.opt.crop_size < self.opt.load_size: 53 | minimum = int((self.opt.load_size - self.opt.crop_size) / 2) 54 | maximum = self.opt.load_size - minimum 55 | crop_bounds = (minimum, minimum), (maximum, maximum) # assuming square 56 | else: 57 | crop_bounds = eval(self.opt.crop_bounds) if self.opt.crop_bounds else None 58 | return crop_bounds 59 | -------------------------------------------------------------------------------- /datasets/data_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from scipy import sparse 4 | import pandas as pd 5 | import random 6 | from collections import Counter 7 | 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | from scipy.sparse import load_npz 12 | from torch import Tensor 13 | from torchvision.transforms import functional as TF 14 | 15 | IMG_EXTENSIONS = [ 16 | ".jpg", 17 | ".JPG", 18 | ".jpeg", 19 | ".JPEG", 20 | ".png", 21 | ".PNG", 22 | ".ppm", 23 | ".PPM", 24 | ".bmp", 25 | ".BMP", 26 | ] 27 | NP_EXTENSIONS = [".npz"] # numpy compressed 28 | 29 | 30 | def get_norm_stats(dataroot, key): 31 | try: 32 | df = pd.read_json( 33 | os.path.join(dataroot, "normalization_stats.json"), lines=True 34 | ).set_index("path") 35 | except ValueError: 36 | raise ValueError(f"Could not find 'normalization_stats.json' for {dataroot}") 37 | series = df.loc[key] 38 | return series["means"], series["stds"] 39 | 40 | 41 | def unnormalize(tensor, mean, std, clamp=True, inplace=False): 42 | if not inplace: 43 | tensor = tensor.clone() 44 | 45 | def unnormalize_1(ten, men, st): 46 | for t, m, s in zip(ten, men, st): 47 | t.mul_(s).add_(m) 48 | if clamp: 49 | t.clamp_(0, 1) 50 | 51 | if tensor.shape == 4: 52 | # then we have batch size in front or something 53 | for t in tensor: 54 | unnormalize_1(t, mean, std) 55 | else: 56 | unnormalize_1(tensor, mean, std) 57 | 58 | return tensor 59 | 60 | 61 | def scale_tensor(tensor, scale_each=False, range=None): 62 | """ 63 | From torchvision's make_grid 64 | :return: 65 | """ 66 | tensor = tensor.clone() # avoid modifying tensor in-place 67 | if range is not None: 68 | assert isinstance( 69 | range, tuple 70 | ), "range has to be a tuple (min, max) if 
specified. min and max are numbers" 71 | 72 | def norm_ip(img, min, max): 73 | img.clamp_(min=min, max=max) 74 | img.add_(-min).div_(max - min + 1e-5) 75 | 76 | def norm_range(t, range): 77 | if range is not None: 78 | norm_ip(t, range[0], range[1]) 79 | else: 80 | norm_ip(t, float(t.min()), float(t.max())) 81 | 82 | if scale_each is True: 83 | for t in tensor: # loop over mini-batch dimension 84 | norm_range(t, range) 85 | else: 86 | norm_range(tensor, range) 87 | 88 | return tensor 89 | 90 | 91 | def is_image_file(filename): 92 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 93 | 94 | 95 | def in_extensions(filename, extensions): 96 | return any(filename.endswith(extension) for extension in extensions) 97 | 98 | 99 | def find_valid_files(dir, extensions=None, max_dataset_size=float("inf")): 100 | """ 101 | Get all the images recursively under a dir. 102 | Args: 103 | dir: 104 | extensions: specific extensions to look for. else will use IMG_EXTENSIONS 105 | max_dataset_size: 106 | 107 | Returns: found files, where each item is a tuple (id, ext) 108 | 109 | """ 110 | if isinstance(extensions, str): 111 | extensions = [extensions] 112 | images = [] 113 | assert os.path.isdir(dir), "%s is not a valid directory" % dir 114 | 115 | for root, _, fnames in sorted(os.walk(dir, followlinks=True)): 116 | for fname in fnames: 117 | if in_extensions(fname, extensions if extensions else IMG_EXTENSIONS): 118 | path = os.path.join(root, fname) 119 | images.append(path) 120 | return images[: min(max_dataset_size, len(images))] 121 | 122 | 123 | def get_dir_file_extension(dir, check=5): 124 | """ 125 | Guess what extensions are for all files in a dir. 126 | Args: 127 | dir: 128 | check: 129 | 130 | Returns: 131 | 132 | """ 133 | exts = [] 134 | for root, _, fnames in os.walk(dir, followlinks=True): 135 | for fname in fnames[:check]: 136 | ext = os.path.splitext(fname)[1] 137 | exts.append(ext) 138 | if len(exts) == 0: 139 | raise ValueError(f"did not find any files under dir: {dir}") 140 | return Counter(exts).most_common(1)[0][0] 141 | 142 | 143 | def remove_top_dir(dir, n=1): 144 | """ 145 | Removes the top level dirs from a path 146 | Args: 147 | dir: 148 | n: 149 | 150 | Returns: 151 | 152 | """ 153 | parts = dir.split(os.path.sep) 154 | top_removed = os.path.sep.join(parts[n:]) 155 | return top_removed 156 | 157 | 158 | def remove_extension(fname): 159 | return os.path.splitext(fname)[0] 160 | 161 | 162 | def change_extension(fname, ext1, ext2): 163 | """ 164 | :return: file name with new extension 165 | """ 166 | return fname[: -len(ext1)] + ext2 167 | 168 | 169 | def crop_tensors(*tensors, crop_bounds=((0, 0), (-1, -1))): 170 | """ 171 | Crop multiple tensors 172 | Args: 173 | *tensors: 174 | crop_bounds: 175 | 176 | Returns: 177 | 178 | """ 179 | ret = [] 180 | for t in tensors: 181 | ret.append(crop_tensor(t, crop_bounds)) 182 | 183 | return ret[0] if len(ret) == 1 else ret 184 | 185 | 186 | def crop_tensor(tensor: Tensor, crop_bounds): 187 | """ 188 | Crops a tensor at the given crop bounds. 189 | :param tensor: 190 | :param crop_bounds: (x_min, y_min), (x_max, y_max) 191 | :return: 192 | """ 193 | (x_min, y_min), (x_max, y_max) = crop_bounds 194 | return tensor[:, y_min:y_max, x_min:x_max] 195 | 196 | 197 | def crop_rois(rois, crop_bounds): 198 | """ 199 | Crop roi coordinates 200 | 201 | roi coordinates should be 202 | xmin, ymin, xmax, ymax 203 | ..., ..., ..., ... 
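    Illustrative example (hypothetical numbers): with
    crop_bounds = ((10, 20), (110, 120)), a roi [5, 30, 60, 90] is clipped to the
    crop window and shifted by the crop origin, giving [0, 10, 50, 70].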
204 | :param rois: 205 | :param crop_bounds: 206 | :return: 207 | """ 208 | # TODO: might have to worry about nan values? 209 | if isinstance(rois, np.ndarray): 210 | clip, stack, copy = (np.clip, np.stack, lambda x: x.copy()) 211 | min = lambda inp, *args: inp.min(*args) 212 | elif isinstance(rois, torch.Tensor): 213 | clip, stack, copy = (torch.clamp, torch.stack, lambda x: x.clone()) 214 | # must do [0] because torch.min() returns two values 215 | min = lambda inp, *args: inp.min(*args)[0] 216 | else: 217 | raise ValueError( 218 | f"input must be numpy ndarray or torch Tensor, received {type(rois)}" 219 | ) 220 | 221 | if crop_bounds is not None: 222 | rois = copy(rois) 223 | (x_min, y_min), (x_max, y_max) = crop_bounds 224 | # clip the x-axis to be within bounds. xmin and xmax index 225 | xs = rois[:, [0, 2]] 226 | xs = clip(xs, x_min, x_max - 1) 227 | xs -= x_min 228 | # clip the y-axis to be within bounds. ymin and ymax index 229 | ys = rois[:, (1, 3)] 230 | ys = clip(ys, y_min, y_max - 1) 231 | ys -= y_min 232 | # put it back together again 233 | rois = stack((xs[:, 0], ys[:, 0], xs[:, 1], ys[:, 1]), 1) 234 | return rois 235 | 236 | 237 | def random_image_roi_flip(img, rois, vp=0.5, hp=0.5): 238 | """ 239 | Randomly flips an image and associated ROI tensor together. 240 | I.e. if the image flips, the ROI will flip to match. 241 | Args: 242 | img: a PIL image 243 | rois: 244 | vp: 245 | hp: 246 | Returns: flipped PIL image, flipped rois 247 | """ 248 | W, H = img.size 249 | 250 | if random.random() < vp: 251 | img = TF.vflip(img) 252 | center = int(H / 2) 253 | flip_rois_(rois, 0, center) 254 | 255 | if random.random() < hp: 256 | img = TF.hflip(img) 257 | center = int(W / 2) 258 | flip_rois_(rois, 1, center) 259 | 260 | return img, rois 261 | 262 | 263 | def flip_rois_(rois, axis, center): 264 | """ 265 | Flips rois in place, along the given axis, at the given center value 266 | Args: 267 | rois: roi tensor 268 | axis: 0 for a vertical flip, 1 for a horizontal flip 269 | center: a positive integer, where to flip 270 | E.g. if axis=1 271 | ------------ ------------ 272 | | | | | | | 273 | | + | | => | | + | 274 | | | | | | | 275 | | | | | | | 276 | ------------ ------------ 277 | Returns: 278 | """ 279 | if axis == 0: # vertical flip, flip y values 280 | min_idx, max_idx = -3, -1 # use negative indexing in case of batch in 1st dim 281 | elif axis == 1: # horizontal flip, flip x values 282 | min_idx, max_idx = -4, -2 283 | else: 284 | raise ValueError(f"dim argument must be 0 or 1, received {axis}") 285 | 286 | # put side by side 287 | values = torch.stack((rois[:, min_idx], rois[:, max_idx])) 288 | # compute the flip 289 | values -= center 290 | values *= -1 291 | values += center 292 | # max and min are now swapped because we flipped 293 | max, min = torch.chunk(values, 2) 294 | # put them back in 295 | rois[:, min_idx], rois[:, max_idx] = min, max 296 | 297 | 298 | def decompress_cloth_segment(fname, n_labels) -> Tensor: 299 | """ 300 | Load cloth segmentation sparse matrix npz file and output a one hot tensor. 301 | :return: tensor of size(H,W,n_labels) 302 | """ 303 | try: 304 | data_sparse = load_npz(fname) 305 | except Exception as e: 306 | print("Could not decompress cloth segment:", fname) 307 | raise e 308 | return to_onehot_tensor(data_sparse, n_labels) 309 | 310 | 311 | def compress_and_save_cloth(cloth_tensor, fname): 312 | """ 313 | Assumes the tensor is a 1 hot encoded tensor. 314 | Compresses a tensor to a scipy sparse matrix, saves to the given file. 
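    Illustrative usage (shape and output path are hypothetical):

        one_hot = torch.zeros(19, 256, 256)  # (n_labels, H, W) cloth segmentation
        compress_and_save_cloth(one_hot, "results/warp/sample_0001.npz")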
315 | Args: 316 | cloth_tensor: 317 | fname: 318 | 319 | Returns: 320 | """ 321 | assert len(cloth_tensor.shape) == 3, "can only compress 1 tensor at a time. remove the preceeding batch size" 322 | max_only = cloth_tensor.argmax(dim=0) 323 | as_numpy = max_only.cpu().numpy() 324 | # use column sparse matrix, because saves a bit more space for a person standing. 325 | # there's more empty space to the sides of the person 326 | as_sparse = sparse.csc_matrix(as_numpy) 327 | sparse.save_npz(fname, as_sparse) 328 | 329 | 330 | def to_onehot_tensor(sp_matrix, n_labels): 331 | """ 332 | convert sparse scipy labels matrix to onehot pt tensor of size (n_labels,H,W) 333 | Note: sparse tensors aren't supported in multiprocessing https://github.com/pytorch/pytorch/issues/20248 334 | 335 | :param sp_matrix: sparse 2d scipy matrix, with entries in range(n_labels) 336 | :return: pt tensor of size(n_labels,H,W) 337 | """ 338 | sp_matrix = sp_matrix.tocoo() 339 | indices = np.vstack((sp_matrix.data, sp_matrix.row, sp_matrix.col)) 340 | indices = torch.LongTensor(indices) 341 | values = torch.Tensor([1.0] * sp_matrix.nnz) 342 | shape = (n_labels,) + sp_matrix.shape 343 | return torch.sparse.FloatTensor(indices, values, torch.Size(shape)).to_dense() 344 | 345 | 346 | def per_channel_transform(input_tensor, transform_function) -> Tensor: 347 | """ 348 | Randomly transform each of n_channels of input data. 349 | Out of place operation 350 | :param input_tensor: must be a numpy array of size (n_channels, w, h) 351 | :param transform_function: any torchvision transforms class 352 | :return: transformed pt tensor 353 | """ 354 | input_tensor = input_tensor.numpy() 355 | tform_input_np = np.zeros(shape=input_tensor.shape, dtype=input_tensor.dtype) 356 | n_channels = input_tensor.shape[0] 357 | for i in range(n_channels): 358 | tform_input_np[i] = np.array( 359 | transform_function(Image.fromarray(input_tensor[i])) 360 | ) 361 | return torch.from_numpy(tform_input_np) 362 | -------------------------------------------------------------------------------- /datasets/texture_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import numpy as np 4 | import pandas as pd 5 | import torch 6 | from PIL import Image 7 | from torch import nn 8 | from torchvision import transforms as transforms 9 | from torchvision.transforms import functional as tf 10 | 11 | from datasets import BaseDataset, get_transforms 12 | from datasets.data_utils import ( 13 | IMG_EXTENSIONS, 14 | find_valid_files, 15 | get_dir_file_extension, 16 | remove_extension, 17 | decompress_cloth_segment, 18 | random_image_roi_flip, 19 | crop_tensors, 20 | crop_rois, 21 | get_norm_stats, 22 | ) 23 | from util.util import remove_prefix 24 | 25 | 26 | class TextureDataset(BaseDataset): 27 | """ Texture dataset for the texture module of SwapNet """ 28 | 29 | @staticmethod 30 | def modify_commandline_options(parser, is_train): 31 | # transforms 32 | parser.add_argument( 33 | "--input_transforms", 34 | nargs="+", 35 | default="none", 36 | choices=("none", "hflip", "vflip", "all"), 37 | help="what random transforms to perform on the input ('all' for all transforms)", 38 | ) 39 | if is_train: 40 | parser.set_defaults(input_transforms=("hflip", "vflip")) 41 | return parser 42 | 43 | def __init__(self, opt, texture_dir=None, cloth_dir=None): 44 | """ 45 | 46 | Args: 47 | opt: Namespace object 48 | texture_dir (str): optional override path to texture dir 49 | cloth_dir (str): optional override path to cloth 
dir 50 | """ 51 | super().__init__(opt) 52 | # get all texture files 53 | self.texture_dir = ( 54 | texture_dir if texture_dir else os.path.join(opt.dataroot, "texture") 55 | ) 56 | self.texture_files = find_valid_files(self.texture_dir, IMG_EXTENSIONS) 57 | 58 | self.texture_norm_stats = get_norm_stats( 59 | os.path.dirname(self.texture_dir), "texture" 60 | ) 61 | opt.texture_norm_stats = self.texture_norm_stats 62 | self._normalize_texture = transforms.Normalize(*self.texture_norm_stats) 63 | 64 | # cloth files 65 | self.cloth_dir = cloth_dir if cloth_dir else os.path.join(opt.dataroot, "cloth") 66 | self.cloth_ext = get_dir_file_extension(self.cloth_dir) 67 | if not self.is_train: 68 | self.cloth_files = find_valid_files(self.cloth_dir, extensions=".npz") 69 | if not opt.shuffle_data: 70 | self.cloth_files.sort() 71 | 72 | # load rois 73 | self.rois_db = os.path.join(opt.dataroot, "rois.csv") 74 | self.rois_df = pd.read_csv(self.rois_db, index_col=0) 75 | # todo: remove None values preemptively, else we have to fill in with 0 76 | self.rois_df = self.rois_df.replace("None", 0).astype(np.float32) 77 | 78 | # # per-channel transforms on the input 79 | # self.input_transform = get_transforms(opt) 80 | 81 | def __len__(self): 82 | if self.is_train: 83 | return len(self.texture_files) 84 | else: 85 | return min(len(self.texture_files), len(self.cloth_files)) 86 | 87 | def __getitem__(self, index: int): 88 | """ """ 89 | # (1) Get target texture. 90 | target_texture_file = self.texture_files[index] 91 | target_texture_img = Image.open(target_texture_file).convert("RGB") 92 | 93 | target_texture_tensor = self._normalize_texture( 94 | tf.to_tensor(tf.resize(target_texture_img, self.opt.load_size)) 95 | ) 96 | 97 | # file id for matching cloth and matching ROI 98 | file_id = remove_prefix( 99 | remove_extension(target_texture_file), self.texture_dir + "/" 100 | ) 101 | 102 | # (2) Get corresponding cloth if train, else cloth at index if inference. 103 | cloth_file = ( 104 | os.path.join(self.cloth_dir, file_id + self.cloth_ext) 105 | if self.is_train 106 | else self.cloth_files[index] 107 | ) 108 | cloth_tensor = decompress_cloth_segment(cloth_file, n_labels=19) 109 | # resize cloth tensor 110 | # We have to unsqueeze because interpolate expects batch in dim1 111 | cloth_tensor = nn.functional.interpolate( 112 | cloth_tensor.unsqueeze(0), size=self.opt.load_size 113 | ).squeeze() 114 | 115 | # (3) Get and scale corresponding roi. 116 | original_size = target_texture_img.size[0] # PIL width 117 | scale = float(self.opt.load_size) / original_size 118 | rois = np.rint(self.rois_df.loc[file_id].values * scale) 119 | rois_tensor = torch.from_numpy(rois) 120 | 121 | # (4) Get randomly flipped input. 
122 | # input will be randomly flipped of target; if we flip input, we must flip rois 123 | hflip = ( 124 | 0.5 if any(t in self.opt.input_transforms for t in ("hflip", "all")) else 0 125 | ) 126 | vflip = ( 127 | 0.5 if any(t in self.opt.input_transforms for t in ("vflip", "all")) else 0 128 | ) 129 | input_texture_image, rois_tensor = random_image_roi_flip( 130 | target_texture_img, rois_tensor, vp=vflip, hp=hflip 131 | ) 132 | input_texture_tensor = self._normalize_texture( 133 | tf.to_tensor(tf.resize(input_texture_image, self.opt.load_size)) 134 | ) 135 | 136 | # do cropping if needed 137 | if self.crop_bounds: 138 | input_texture_tensor, cloth_tensor, target_texture_tensor = crop_tensors( 139 | input_texture_tensor, 140 | cloth_tensor, 141 | target_texture_tensor, 142 | crop_bounds=self.crop_bounds, 143 | ) 144 | rois_tensor = crop_rois(rois_tensor, self.crop_bounds) 145 | 146 | # assert shapes 147 | assert ( 148 | input_texture_tensor.shape[-2:] 149 | == target_texture_tensor.shape[-2:] 150 | == cloth_tensor.shape[-2:] 151 | ), f"input {input_texture_tensor.shape}; target {target_texture_tensor.shape}; cloth {cloth_tensor.shape}" 152 | 153 | return { 154 | "texture_paths": target_texture_file, 155 | "input_textures": input_texture_tensor, 156 | "rois": rois_tensor, 157 | "cloth_paths": cloth_file, 158 | "cloths": cloth_tensor, 159 | "target_textures": target_texture_tensor, 160 | } 161 | -------------------------------------------------------------------------------- /datasets/warp_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from typing import Tuple 4 | 5 | from torch import Tensor 6 | from PIL import Image 7 | from torch import nn 8 | from torchvision import transforms as transforms 9 | 10 | from datasets import BaseDataset, get_transforms 11 | from datasets.data_utils import ( 12 | get_dir_file_extension, 13 | remove_top_dir, 14 | remove_extension, 15 | find_valid_files, 16 | decompress_cloth_segment, 17 | per_channel_transform, 18 | crop_tensors, 19 | get_norm_stats, 20 | ) 21 | 22 | 23 | class WarpDataset(BaseDataset): 24 | """ Warp dataset for the warp module of SwapNet """ 25 | 26 | @staticmethod 27 | def modify_commandline_options(parser, is_train): 28 | parser.add_argument( 29 | "--input_transforms", 30 | nargs="+", 31 | default="none", 32 | choices=("none", "hflip", "vflip", "affine", "perspective", "all"), 33 | help="what random transforms to perform on the input ('all' for all transforms)", 34 | ) 35 | if is_train: 36 | parser.set_defaults( 37 | input_transforms=("hflip", "vflip", "affine", "perspective") 38 | ) 39 | parser.add_argument( 40 | "--per_channel_transform", 41 | action="store_true", 42 | default=True, # TODO: make this a toggle based on if data is RGB or labels 43 | help="Perform the transform for each label instead of on the image as a " 44 | "whole. 
--cloth_representation must be 'labels'.", 45 | ) 46 | return parser 47 | 48 | def __init__(self, opt, cloth_dir=None, body_dir=None): 49 | """ 50 | 51 | Args: 52 | opt: 53 | cloth_dir: (optional) path to cloth dir, if provided 54 | body_dir: (optional) path to body dir, if provided 55 | """ 56 | super().__init__(opt) 57 | 58 | self.cloth_dir = cloth_dir if cloth_dir else os.path.join(opt.dataroot, "cloth") 59 | print("cloth dir", self.cloth_dir) 60 | extensions = [".npz"] if self.opt.cloth_representation == "labels" else None 61 | print("Extensions:", extensions) 62 | self.cloth_files = find_valid_files(self.cloth_dir, extensions) 63 | if not opt.shuffle_data: 64 | self.cloth_files.sort() 65 | 66 | self.body_dir = body_dir if body_dir else os.path.join(opt.dataroot, "body") 67 | if not self.is_train: # only load these during inference 68 | self.body_files = find_valid_files(self.body_dir) 69 | if not opt.shuffle_data: 70 | self.body_files.sort() 71 | print("body dir", self.body_dir) 72 | self.body_norm_stats = get_norm_stats(os.path.dirname(self.body_dir), "body") 73 | opt.body_norm_stats = self.body_norm_stats 74 | self._normalize_body = transforms.Normalize(*self.body_norm_stats) 75 | 76 | self.cloth_transform = get_transforms(opt) 77 | 78 | def __len__(self): 79 | """ 80 | Get the length of usable images. Note the length of cloth and body segmentations should be same 81 | """ 82 | if not self.is_train: 83 | return min(len(self.cloth_files), len(self.body_files)) 84 | else: 85 | return len(self.cloth_files) 86 | 87 | def _load_cloth(self, index) -> Tuple[str, Tensor, Tensor]: 88 | """ 89 | Loads the cloth file as a tensor 90 | """ 91 | cloth_file = self.cloth_files[index] 92 | target_cloth_tensor = decompress_cloth_segment( 93 | cloth_file, self.opt.cloth_channels 94 | ) 95 | if self.is_train: 96 | # during train, we want to do some fancy transforms 97 | if self.opt.dataset_mode == "image": 98 | # in image mode, the input cloth is the same as the target cloth 99 | input_cloth_tensor = target_cloth_tensor.clone() 100 | elif self.opt.dataset_mode == "video": 101 | # video mode, can choose a random image 102 | cloth_file = self.cloth_files[random.randint(0, len(self)) - 1] 103 | input_cloth_tensor = decompress_cloth_segment( 104 | cloth_file, self.opt.cloth_channels 105 | ) 106 | else: 107 | raise ValueError(self.opt.dataset_mode) 108 | 109 | # apply the transformation for input cloth segmentation 110 | if self.cloth_transform: 111 | input_cloth_tensor = self._perform_cloth_transform(input_cloth_tensor) 112 | 113 | return cloth_file, input_cloth_tensor, target_cloth_tensor 114 | else: 115 | # during inference, we just want to load the current cloth 116 | return cloth_file, target_cloth_tensor, target_cloth_tensor 117 | 118 | def _load_body(self, index): 119 | """ Loads the body file as a tensor """ 120 | if self.is_train: 121 | # use corresponding strategy during train 122 | cloth_file = self.cloth_files[index] 123 | body_file = get_corresponding_file(cloth_file, self.body_dir) 124 | else: 125 | # else we have to load by index 126 | body_file = self.body_files[index] 127 | as_pil_image = Image.open(body_file).convert("RGB") 128 | body_tensor = self._normalize_body(transforms.ToTensor()(as_pil_image)) 129 | return body_file, body_tensor 130 | 131 | def _perform_cloth_transform(self, cloth_tensor): 132 | """ Either does per-channel transform or whole-image transform """ 133 | if self.opt.per_channel_transform: 134 | return per_channel_transform(cloth_tensor, self.cloth_transform) 135 | else: 
136 | raise NotImplementedError("Sorry, per_channel_transform must be true") 137 | # return self.input_transform(cloth_tensor) 138 | 139 | def __getitem__(self, index): 140 | """ 141 | :returns: 142 | For training, return (input) AUGMENTED cloth seg, (input) body seg and (target) cloth seg 143 | of the SAME image 144 | For inference (e.g validation), return (input) cloth seg and (input) body seg 145 | of 2 different images 146 | """ 147 | 148 | # the input cloth segmentation 149 | cloth_file, input_cloth_tensor, target_cloth_tensor = self._load_cloth(index) 150 | body_file, body_tensor = self._load_body(index) 151 | 152 | # RESIZE TENSORS 153 | # We have to unsqueeze because interpolate expects batch in dim1 154 | input_cloth_tensor = nn.functional.interpolate( 155 | input_cloth_tensor.unsqueeze(0), size=self.opt.load_size 156 | ).squeeze() 157 | if self.is_train: 158 | target_cloth_tensor = nn.functional.interpolate( 159 | target_cloth_tensor.unsqueeze(0), size=self.opt.load_size 160 | ).squeeze() 161 | body_tensor = nn.functional.interpolate( 162 | body_tensor.unsqueeze(0), 163 | size=self.opt.load_size, 164 | mode="bilinear", # same as default for torchvision.resize 165 | ).squeeze() 166 | 167 | # crop to the proper image size 168 | if self.crop_bounds: 169 | input_cloth_tensor, body_tensor = crop_tensors( 170 | input_cloth_tensor, body_tensor, crop_bounds=self.crop_bounds 171 | ) 172 | if self.is_train: # avoid extra work if we don't need targets for inference 173 | target_cloth_tensor = crop_tensors( 174 | target_cloth_tensor, crop_bounds=self.crop_bounds 175 | ) 176 | 177 | return { 178 | "body_paths": body_file, 179 | "bodys": body_tensor, 180 | "cloth_paths": cloth_file, 181 | "input_cloths": input_cloth_tensor, 182 | "target_cloths": target_cloth_tensor, 183 | } 184 | 185 | 186 | def get_corresponding_file(original, target_dir, target_ext=None): 187 | """ 188 | Say an original file is 189 | dataroot/subject/body/SAMPLE_ID.jpg 190 | 191 | And we want the corresponding file 192 | dataroot/subject/cloth/SAMPLE_ID.npz 193 | 194 | The corresponding file is in target_dir dataroot/subject/cloth, so we replace the 195 | top level directories with the target dir 196 | 197 | Args: 198 | original: 199 | target_dir: 200 | target_ext: 201 | 202 | Returns: 203 | 204 | """ 205 | # number of top dir to replace 206 | num_top_parts = len(target_dir.split(os.path.sep)) 207 | # replace the top dirs 208 | top_removed = remove_top_dir(original, num_top_parts) 209 | target_file = os.path.join(target_dir, top_removed) 210 | # extension of files in the target dir 211 | if not target_ext: 212 | target_ext = get_dir_file_extension(target_dir) 213 | # change the extension 214 | target_file = remove_extension(target_file) + target_ext 215 | return target_file 216 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: swapnet 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - attrs=19.1.0=py37_1 7 | - backcall=0.1.0=py37_0 8 | - blas=1.0=mkl 9 | - bleach=3.1.0=py37_0 10 | - ca-certificates=2019.5.15=1 11 | - certifi=2019.6.16=py37_1 12 | - cffi=1.12.2=py37h2e261b9_1 13 | - cudatoolkit=10.0.130=0 14 | - dbus=1.13.6=h746ee38_0 15 | - decorator=4.4.0=py37_1 16 | - defusedxml=0.5.0=py37_1 17 | - entrypoints=0.3=py37_0 18 | - expat=2.2.6=he6710b0_0 19 | - fontconfig=2.13.0=h9420a91_0 20 | - freetype=2.9.1=h8a8886c_1 21 | - glib=2.56.2=hd408876_0 22 | - 
gmp=6.1.2=h6c8ec71_1 23 | - gst-plugins-base=1.14.0=hbbd80ab_1 24 | - gstreamer=1.14.0=hb453b48_1 25 | - icu=58.2=h9c2bf20_1 26 | - intel-openmp=2019.3=199 27 | - ipykernel=5.1.0=py37h39e3cac_0 28 | - ipython=7.4.0=py37h39e3cac_0 29 | - ipython_genutils=0.2.0=py37_0 30 | - ipywidgets=7.4.2=py37_0 31 | - jedi=0.13.3=py37_0 32 | - jinja2=2.10=py37_0 33 | - jpeg=9b=h024ee3a_2 34 | - jsonschema=3.0.1=py37_0 35 | - jupyter=1.0.0=py37_7 36 | - jupyter_client=5.2.4=py37_0 37 | - jupyter_console=6.0.0=py37_0 38 | - jupyter_core=4.4.0=py37_0 39 | - libedit=3.1.20181209=hc058e9b_0 40 | - libffi=3.2.1=hd88cf55_4 41 | - libgcc-ng=8.2.0=hdf63c60_1 42 | - libgfortran-ng=7.3.0=hdf63c60_0 43 | - libpng=1.6.36=hbc83047_0 44 | - libsodium=1.0.16=h1bed415_0 45 | - libstdcxx-ng=8.2.0=hdf63c60_1 46 | - libtiff=4.0.10=h2733197_2 47 | - libuuid=1.0.3=h1bed415_2 48 | - libxcb=1.13=h1bed415_1 49 | - libxml2=2.9.9=he19cac6_0 50 | - markupsafe=1.1.1=py37h7b6447c_0 51 | - mistune=0.8.4=py37h7b6447c_0 52 | - mkl=2019.3=199 53 | - mkl_fft=1.0.10=py37ha843d7b_0 54 | - mkl_random=1.0.2=py37hd81dba3_0 55 | - nbconvert=5.4.1=py37_3 56 | - nbformat=4.4.0=py37_0 57 | - ncurses=6.1=he6710b0_1 58 | - ninja=1.9.0=py37hfd86e86_0 59 | - notebook=5.7.8=py37_0 60 | - numpy=1.16.2=py37h7e9f1db_0 61 | - numpy-base=1.16.2=py37hde5b4d6_0 62 | - olefile=0.46=py37_0 63 | - openssl=1.1.1c=h7b6447c_1 64 | - pandoc=2.2.3.2=0 65 | - pandocfilters=1.4.2=py37_1 66 | - parso=0.3.4=py37_0 67 | - pcre=8.43=he6710b0_0 68 | - pexpect=4.6.0=py37_0 69 | - pickleshare=0.7.5=py37_0 70 | - pillow=5.4.1=py37h34e0f95_0 71 | - pip=19.0.3=py37_0 72 | - prometheus_client=0.6.0=py37_0 73 | - prompt_toolkit=2.0.9=py37_0 74 | - ptyprocess=0.6.0=py37_0 75 | - pycparser=2.19=py37_0 76 | - pygments=2.3.1=py37_0 77 | - pyqt=5.9.2=py37h05f1152_2 78 | - pyrsistent=0.14.11=py37h7b6447c_0 79 | - python=3.7.3=h0371630_0 80 | - python-dateutil=2.8.0=py37_0 81 | - pytorch=1.2.0=py3.7_cuda10.0.130_cudnn7.6.2_0 82 | - pyzmq=18.0.0=py37he6710b0_0 83 | - qt=5.9.7=h5867ecd_1 84 | - qtconsole=4.4.3=py37_0 85 | - readline=7.0=h7b6447c_5 86 | - send2trash=1.5.0=py37_0 87 | - setuptools=40.8.0=py37_0 88 | - sip=4.19.8=py37hf484d3e_0 89 | - six=1.12.0=py37_0 90 | - sqlite=3.27.2=h7b6447c_0 91 | - terminado=0.8.1=py37_1 92 | - testpath=0.4.2=py37_0 93 | - tk=8.6.8=hbc83047_0 94 | - torchvision=0.4.0=py37_cu100 95 | - tornado=6.0.2=py37h7b6447c_0 96 | - traitlets=4.3.2=py37_0 97 | - wcwidth=0.1.7=py37_0 98 | - webencodings=0.5.1=py37_1 99 | - wheel=0.33.1=py37_0 100 | - widgetsnbextension=3.4.2=py37_0 101 | - xz=5.2.4=h14c3975_4 102 | - zeromq=4.3.1=he6710b0_3 103 | - zlib=1.2.11=h7b6447c_3 104 | - zstd=1.3.7=h0b5b093_0 105 | - pip: 106 | - adabound==0.0.5 107 | - chardet==3.0.4 108 | - cycler==0.10.0 109 | - cython==0.29.6 110 | - dominate==2.4.0 111 | - easydict==1.9 112 | - idna==2.8 113 | - kiwisolver==1.0.1 114 | - matplotlib==3.0.3 115 | - msgpack==0.6.1 116 | - opencv-python==4.0.0.21 117 | - pandas==0.24.2 118 | - protobuf==3.7.1 119 | - pyparsing==2.3.1 120 | - pytz==2019.1 121 | - pyyaml==5.1 122 | - requests==2.22.0 123 | - scipy==1.2.1 124 | - tensorboardx==1.6 125 | - torchfile==0.1.0 126 | - tqdm==4.31.1 127 | - urllib3==1.25.3 128 | - visdom==0.1.8.8 129 | - websocket-client==0.56.0 130 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | import os 4 | from typing import Callable 5 | 6 | from tqdm import 
tqdm 7 | 8 | from datasets import create_dataset 9 | from datasets.data_utils import compress_and_save_cloth, remove_extension 10 | from models import create_model 11 | from options.base_options import load 12 | from options.test_options import TestOptions 13 | from util import html 14 | from util.util import PromptOnce 15 | from util.visualizer import save_images 16 | 17 | WARP_SUBDIR = "warp" 18 | TEXTURE_SUBDIR = "texture" 19 | 20 | 21 | # FUNCTIONS SHOULD NOT BE IMPORTED BY OTHER MODULES. THEY ARE ONLY HELPER METHODS, 22 | # AND DEPEND ON GLOBAL VARIABLES UNDER MAIN 23 | 24 | 25 | def _setup(subfolder_name, create_webpage=True): 26 | """ 27 | Setup outdir, create a webpage 28 | Args: 29 | subfolder_name: name of the outdir and where the webpage files should go 30 | 31 | Returns: 32 | 33 | """ 34 | out_dir = get_out_dir(subfolder_name) 35 | PromptOnce.makedirs(out_dir, not opt.no_confirm) 36 | webpage = None 37 | if create_webpage: 38 | webpage = html.HTML( 39 | out_dir, 40 | f"Experiment = {opt.name}, Phase = {subfolder_name} inference, " 41 | f"Loaded Epoch = {opt.load_epoch}", 42 | ) 43 | return out_dir, webpage 44 | 45 | 46 | def get_out_dir(subfolder_name): 47 | return os.path.join(opt.results_dir, subfolder_name) 48 | 49 | 50 | def _rebuild_from_checkpoint(checkpoint_file, same_crop_load_size=False, **ds_kwargs): 51 | """ 52 | Loads a model and dataset based on the config in a particular dir. 53 | Args: 54 | checkpoint_file: dir containing args.json and model checkpoints 55 | **ds_kwargs: override kwargs for dataset 56 | 57 | Returns: loaded model, initialized dataset 58 | 59 | """ 60 | checkpoint_dir = os.path.dirname(checkpoint_file) 61 | # read the config file so we can load in the model 62 | loaded_opt = load(copy.deepcopy(opt), os.path.join(checkpoint_dir, "args.json")) 63 | # force certain attributes in the loaded cfg 64 | override_namespace( 65 | loaded_opt, 66 | is_train=False, 67 | batch_size=1, 68 | shuffle_data=opt.shuffle_data, # let inference opt take precedence 69 | ) 70 | if same_crop_load_size: # need to override this if we're using intermediates 71 | loaded_opt.load_size = loaded_opt.crop_size 72 | model = create_model(loaded_opt) 73 | # loads the checkpoint 74 | model.load_model_weights("generator", checkpoint_file).eval() 75 | model.print_networks(opt.verbose) 76 | 77 | dataset = create_dataset(loaded_opt, **ds_kwargs) 78 | 79 | return model, dataset 80 | 81 | 82 | def override_namespace(namespace, **kwargs): 83 | """ 84 | Simply overrides the attributes in the object with the specified keyword arguments 85 | Args: 86 | namespace: argparse.Namespace object 87 | **kwargs: keyword/value pairs to use as override 88 | """ 89 | assert isinstance(namespace, argparse.Namespace) 90 | for k, v in kwargs.items(): 91 | setattr(namespace, k, v) 92 | 93 | 94 | def _run_test_loop(model, dataset, webpage=None, iteration_post_hook: Callable = None): 95 | """ 96 | 97 | Args: 98 | model: object that extends BaseModel 99 | dataset: object that extends BaseDataset 100 | webpage: webpage object for saving 101 | iteration_post_hook: a function to call at the end of every iteration 102 | 103 | Returns: 104 | 105 | """ 106 | 107 | total = min(len(dataset), opt.max_dataset_size) 108 | with tqdm(total=total, unit="img") as pbar: 109 | for i, data in enumerate(dataset): 110 | if i >= total: 111 | break 112 | model.set_input(data) # set input 113 | model.test() # forward pass 114 | image_paths = model.get_image_paths() # ids of the loaded images 115 | 116 | if webpage: 117 | visuals = 
model.get_current_visuals() 118 | save_images(webpage, visuals, image_paths, width=opt.display_winsize) 119 | 120 | if iteration_post_hook: 121 | iteration_post_hook(local=locals()) 122 | 123 | pbar.update() 124 | 125 | if webpage: 126 | webpage.save() 127 | 128 | 129 | def _run_warp(): 130 | """ 131 | Runs the warp stage 132 | """ 133 | warp_out, webpage = _setup(WARP_SUBDIR, create_webpage=not opt.skip_intermediates) 134 | 135 | print(f"Rebuilding warp from {opt.warp_checkpoint}") 136 | warp_model, warp_dataset = _rebuild_from_checkpoint( 137 | opt.warp_checkpoint, cloth_dir=opt.cloth_dir, body_dir=opt.body_dir 138 | ) 139 | 140 | def save_cloths_npz(local): 141 | """ 142 | We must store the intermediate cloths as .npz files 143 | """ 144 | name = "_to_".join( 145 | [remove_extension(os.path.basename(p)) for p in local["image_paths"][0]] 146 | ) 147 | out_name = os.path.join(warp_out, name) 148 | # save the warped cloths 149 | compress_and_save_cloth(local["model"].fakes[0], out_name) 150 | 151 | print(f"Warping cloth to match body segmentations in {opt.body_dir}...") 152 | 153 | try: 154 | _run_test_loop( 155 | warp_model, warp_dataset, webpage, iteration_post_hook=save_cloths_npz 156 | ) 157 | except KeyboardInterrupt: 158 | print("Ending warp early.") 159 | print(f"Warp results stored in {warp_out}") 160 | 161 | 162 | def _run_texture(): 163 | """ 164 | Runs the texture stage. If opt.warp_checkpoint is also True, then it will use those 165 | intermediate cloth outputs as the texture stage's input. 166 | """ 167 | texture_out, webpage = _setup(TEXTURE_SUBDIR, create_webpage=True) 168 | 169 | if opt.warp_checkpoint: # if intermediate, cloth dir is the warped cloths 170 | cloth_dir = get_out_dir(WARP_SUBDIR) 171 | else: # otherwise if texture checkpoint alone, use what the user specified 172 | cloth_dir = opt.cloth_dir 173 | 174 | print(f"Rebuilding texture from {opt.texture_checkpoint}") 175 | texture_model, texture_dataset = _rebuild_from_checkpoint( 176 | opt.texture_checkpoint, 177 | same_crop_load_size=True if opt.warp_checkpoint else False, 178 | texture_dir=opt.texture_dir, 179 | cloth_dir=cloth_dir, 180 | ) 181 | 182 | print(f"Texturing cloth segmentations in {cloth_dir}...") 183 | try: 184 | _run_test_loop(texture_model, texture_dataset, webpage) 185 | except KeyboardInterrupt: 186 | print("Ending texture early.") 187 | print(f"Textured results stored in {texture_out}") 188 | 189 | 190 | if __name__ == "__main__": 191 | config = TestOptions() 192 | config.parse() 193 | opt = config.opt 194 | 195 | # override checkpoint options 196 | if opt.checkpoint: 197 | if not opt.warp_checkpoint: 198 | opt.warp_checkpoint = os.path.join( 199 | opt.checkpoint, "warp", f"{opt.load_epoch}_net_generator.pth" 200 | ) 201 | print("Set warp_checkpoint to", opt.warp_checkpoint) 202 | if not opt.texture_checkpoint: 203 | opt.texture_checkpoint = os.path.join( 204 | opt.checkpoint, "texture", f"{opt.load_epoch}_net_generator.pth" 205 | ) 206 | print("Set texture_checkpoint to", opt.texture_checkpoint) 207 | 208 | # use dataroot if not individually provided 209 | for subdir in ("body", "cloth", "texture"): 210 | attribute = f"{subdir}_dir" 211 | if not getattr(opt, attribute): 212 | setattr(opt, attribute, os.path.join(opt.dataroot, subdir)) 213 | 214 | # Run warp stage 215 | if opt.warp_checkpoint: 216 | print("Running warp inference...") 217 | _run_warp() 218 | 219 | # Run texture stage 220 | if opt.texture_checkpoint: 221 | print("Running texture inference...") 222 | _run_texture() 223 | 224 | 
print("\nDone!") 225 | -------------------------------------------------------------------------------- /media/diagram.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjong/SwapNet/2b5fa1fcefb25a9ac5dfc4964b2e2cb2c82d110e/media/diagram.png -------------------------------------------------------------------------------- /media/example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjong/SwapNet/2b5fa1fcefb25a9ac5dfc4964b2e2cb2c82d110e/media/example.png -------------------------------------------------------------------------------- /media/texture_custom_data_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjong/SwapNet/2b5fa1fcefb25a9ac5dfc4964b2e2cb2c82d110e/media/texture_custom_data_example.png -------------------------------------------------------------------------------- /media/texture_train_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjong/SwapNet/2b5fa1fcefb25a9ac5dfc4964b2e2cb2c82d110e/media/texture_train_example.png -------------------------------------------------------------------------------- /media/warp_train_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewjong/SwapNet/2b5fa1fcefb25a9ac5dfc4964b2e2cb2c82d110e/media/warp_train_example.png -------------------------------------------------------------------------------- /models/README.md: -------------------------------------------------------------------------------- 1 | "Models" are a bigger concept. 2 | A model class knows how to calculate loss and optimize its parameters. 3 | 4 | Unlike a "module" which simply extend PyTorch modules. -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from models.base_model import BaseModel 3 | 4 | 5 | def find_model_using_name(model_name): 6 | """Import the module "models/[model_name]_model.py". 7 | In the file, the class called DatasetNameModel() will 8 | be instantiated. It has to be a subclass of BaseModel, 9 | and it is case-insensitive. 10 | """ 11 | model_filename = "models." + model_name + "_model" 12 | modellib = importlib.import_module(model_filename) 13 | model = None 14 | target_model_name = model_name.replace('_', '') + 'model' 15 | for name, cls in modellib.__dict__.items(): 16 | if name.lower() == target_model_name.lower() \ 17 | and issubclass(cls, BaseModel): 18 | model = cls 19 | 20 | if model is None: 21 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 22 | exit(0) 23 | 24 | return model 25 | 26 | 27 | def get_options_modifier(model_name): 28 | """Return the static method of the model class.""" 29 | model_class = find_model_using_name(model_name) 30 | return model_class.modify_commandline_options 31 | 32 | 33 | def create_model(opt): 34 | """Create a model given the option. 35 | This function warps the class CappedDataLoader. 
36 | This is the main interface between this package and 'train.py'/'test.py' 37 | Example: 38 | >>> from models import create_model 39 | >>> model = create_model(opt) 40 | """ 41 | model = find_model_using_name(opt.model) 42 | instance = model(opt) 43 | print("model [%s] was created" % type(instance).__name__) 44 | return instance 45 | -------------------------------------------------------------------------------- /models/base_gan.py: -------------------------------------------------------------------------------- 1 | """ 2 | A general framework for GAN training. 3 | """ 4 | from argparse import ArgumentParser 5 | from abc import ABC, abstractmethod 6 | 7 | import optimizers 8 | from models import BaseModel 9 | import modules.loss 10 | from modules import discriminators 11 | from modules.discriminators import Discriminator 12 | 13 | 14 | class BaseGAN(BaseModel, ABC): 15 | @staticmethod 16 | def modify_commandline_options(parser: ArgumentParser, is_train): 17 | """ 18 | Adds several GAN-related training arguments. 19 | Child classes should call to extend this static method. 20 | >>> parser = super(ChildClass, ChildClass).modify_commandline_options( 21 | >>> parser, is_train 22 | >>> ) 23 | """ 24 | if is_train: 25 | # gan mode choice 26 | parser.add_argument( 27 | "--gan_mode", 28 | help="gan regularization to use", 29 | default="vanilla", 30 | choices=( 31 | "vanilla", 32 | "wgan", 33 | "wgan-gp", 34 | "lsgan", 35 | "dragan-gp", 36 | "dragan-lp", 37 | "mescheder-r1-gp", 38 | "mescheder-r2-gp", 39 | ), 40 | ) 41 | parser.add_argument( 42 | "--lambda_gan", 43 | type=float, 44 | default=1.0, 45 | help="weight for adversarial loss", 46 | ) 47 | parser.add_argument( 48 | "--lambda_discriminator", 49 | type=float, 50 | default=1.0, 51 | help="weight for discriminator loss", 52 | ) 53 | parser.add_argument( 54 | "--lambda_gp", 55 | help="weight parameter for gradient penalty", 56 | type=float, 57 | default=10, 58 | ) 59 | # discriminator choice 60 | parser.add_argument( 61 | "--discriminator", 62 | default="basic", 63 | choices=("basic", "pixel", "n_layers"), 64 | help="what discriminator type to use", 65 | ) 66 | parser.add_argument( 67 | "--n_layers_D", 68 | type=int, 69 | default=3, 70 | help="only used if discriminator==n_layers", 71 | ) 72 | parser.add_argument( 73 | "--norm", 74 | type=str, 75 | default="instance", 76 | help="instance normalization or batch normalization [instance | batch | none]", 77 | ) 78 | # optimizer choice 79 | parser.add_argument( 80 | "--optimizer_G", 81 | "--opt_G", 82 | "--optim_G", 83 | help="optimizer for generator", 84 | default="AdamW", 85 | choices=("AdamW", "AdaBound"), 86 | ) 87 | parser.add_argument( 88 | "--lr", 89 | "--g_lr", 90 | "--learning_rate", 91 | type=float, 92 | # as recommended by "10 Lessons I Learned Training GANs For a Year" 93 | default=0.0001, 94 | help="initial learning rate for generator", 95 | ) 96 | parser.add_argument('--beta1', type=float, default=0.5, 97 | help='momentum term of adam') 98 | parser.add_argument( 99 | "--optimizer_D", 100 | "--opt_D", 101 | "--optim_D", 102 | help="optimizer for discriminator", 103 | default="AdamW", 104 | choices=("AdamW", "AdaBound"), 105 | ) 106 | parser.add_argument( 107 | "--d_lr", 108 | type=float, 109 | # as recommended by "10 Lessons I Learned Training GANs For a Year" 110 | default=0.0004, 111 | help="initial learning rate for Discriminator", 112 | ) 113 | parser.add_argument( 114 | "--d_wt_decay", 115 | "--d_weight_decay", 116 | dest="d_weight_decay", 117 | default=0.01, 118 | type=float, 
119 | help="optimizer L2 weight decay", 120 | ) 121 | parser.add_argument( 122 | "--gan_label_mode", 123 | default="smooth", 124 | choices=("hard", "smooth"), 125 | help="whether to use hard (real 1.0 and fake 0.0) or smooth " 126 | "(real [0.7, 1.1] and fake [0., 0.3]) values for labels", 127 | ) 128 | return parser 129 | 130 | def __init__(self, opt): 131 | """ 132 | Sets the generator, discriminator, and optimizers. 133 | 134 | Sets self.net_generator to the return value of self.define_G() 135 | 136 | Args: 137 | opt: 138 | """ 139 | super().__init__(opt) 140 | self.net_generator = self.define_G().to(self.device) 141 | modules.init_weights(self.net_generator, opt.init_type, opt.init_gain) 142 | 143 | self.model_names = ["generator"] 144 | 145 | if self.is_train: 146 | # setup discriminator 147 | self.net_discriminator = discriminators.define_D( 148 | self.get_D_inchannels(), 64, opt.discriminator, opt.n_layers_D, opt.norm 149 | ).to(self.device) 150 | modules.init_weights(self.net_discriminator, opt.init_type, opt.init_gain) 151 | 152 | # load discriminator only at train time 153 | self.model_names.append("discriminator") 154 | 155 | # setup GAN loss 156 | use_smooth = True if opt.gan_label_mode == "smooth" else False 157 | self.criterion_GAN = modules.loss.GANLoss( 158 | opt.gan_mode, smooth_labels=use_smooth 159 | ).to(self.device) 160 | 161 | if opt.lambda_discriminator: 162 | self.loss_names = ["D", "D_real", "D_fake"] 163 | if any(gp_mode in opt.gan_mode for gp_mode in ["gp", "lp"]): 164 | self.loss_names += ["D_gp"] 165 | self.loss_names += ["G"] 166 | if opt.lambda_gan: 167 | self.loss_names += ["G_gan"] 168 | 169 | # Define optimizers 170 | self.optimizer_G = optimizers.define_optimizer( 171 | self.net_generator.parameters(), opt, "G" 172 | ) 173 | self.optimizer_D = optimizers.define_optimizer( 174 | self.net_discriminator.parameters(), opt, "D" 175 | ) 176 | self.optimizer_names = ("G", "D") 177 | 178 | @abstractmethod 179 | def get_D_inchannels(self): 180 | """ 181 | Return number of channels for discriminator input. 182 | Called when constructing the Discriminator network. 183 | """ 184 | pass 185 | 186 | @abstractmethod 187 | def define_G(self): 188 | """ 189 | Return the generator module. Called in init() 190 | The returned value is set to self.net_generator(). 
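        Illustrative override (a sketch; the generator class name is hypothetical):

            def define_G(self):
                return MyWarpGenerator(self.opt)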
191 | """ 192 | pass 193 | 194 | def optimize_parameters(self): 195 | self.forward() 196 | # update D 197 | self.optimizer_D.zero_grad() 198 | self.backward_D() 199 | self.optimizer_D.step() 200 | # update G 201 | self.optimizer_G.zero_grad() 202 | self.backward_G() 203 | self.optimizer_G.step() 204 | 205 | def backward_D(self): 206 | """ 207 | Calculates loss and backpropagates for the discriminator 208 | """ 209 | # https://github.com/martinarjovsky/WassersteinGAN/blob/f7a01e82007ea408647c451b9e1c8f1932a3db67/main.py#L185 210 | if self.opt.gan_mode == "wgan": 211 | # clamp parameters to a cube 212 | for p in self.net_discriminator.parameters(): 213 | p.data.clamp(-0.01, 0.01) 214 | 215 | # calculate fake 216 | pred_fake = self.net_discriminator(self.fakes.detach()) 217 | self.loss_D_fake = self.criterion_GAN(pred_fake, False) 218 | # calculate real 219 | pred_real = self.net_discriminator(self.targets) 220 | self.loss_D_real = self.criterion_GAN(pred_real, True) 221 | 222 | self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real) * self.opt.lambda_discriminator 223 | 224 | if any(gp_mode in self.opt.gan_mode for gp_mode in ["gp", "lp"]): 225 | # calculate gradient penalty 226 | self.loss_D_gp = modules.loss.gradient_penalty( 227 | self.net_discriminator, self.targets, self.fakes, self.opt.gan_mode 228 | ) 229 | self.loss_D += self.opt.lambda_gp * self.loss_D_gp 230 | 231 | self.loss_D.backward() 232 | 233 | @abstractmethod 234 | def backward_G(self): 235 | """ 236 | Calculate loss and backpropagates for the generator 237 | """ 238 | pass 239 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | 6 | # from . import networks 7 | from util.util import PromptOnce 8 | 9 | 10 | class BaseModel(ABC): 11 | """This class is an abstract base class (ABC) for models. 12 | To create a subclass, you need to implement the following five functions: 13 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 14 | -- : unpack data from dataset and apply preprocessing. 15 | -- : produce intermediate results. 16 | -- : calculate losses, gradients, and update network weights. 17 | -- : (optionally) add model-specific options and set default options. 18 | """ 19 | 20 | def __init__(self, opt): 21 | """Initialize the BaseModel class. 22 | Parameters: 23 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 24 | When creating your custom class, you need to implement your own initialization. 25 | In this fucntion, you should first call 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): specify the images that you want to display and save. 29 | -- self.visual_names (str list): define networks used in our training. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 
31 | """ 32 | self.opt = opt 33 | self.gpu_id = opt.gpu_id 34 | self.is_train = opt.is_train 35 | # get device name: CPU or GPU 36 | self.device = ( 37 | torch.device(f"cuda:{self.gpu_id}") 38 | if self.gpu_id is not None 39 | else torch.device("cpu") 40 | ) 41 | # save all the checkpoints to save_dir 42 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) 43 | if self.is_train: 44 | PromptOnce.makedirs(self.save_dir, not opt.no_confirm) 45 | 46 | self.loss_names = [] 47 | self.model_names = [] 48 | self.visual_names = [] 49 | self.optimizer_names = [] 50 | # self.optimizers = [] 51 | self.image_paths = [] 52 | self.metric = 0 # used for learning rate policy 'plateau' 53 | 54 | @staticmethod 55 | def modify_commandline_options(parser, is_train): 56 | """Add new model-specific options, and rewrite default values for existing options. 57 | Parameters: 58 | parser: -- original option parser 59 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 60 | Returns: 61 | the modified parser. 62 | """ 63 | return parser 64 | 65 | @abstractmethod 66 | def set_input(self, input): 67 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 68 | Parameters: 69 | input (dict): includes the data itself and its metadata information. 70 | """ 71 | pass 72 | 73 | @abstractmethod 74 | def forward(self): 75 | """Run forward pass; called by both functions and .""" 76 | pass 77 | 78 | @abstractmethod 79 | def optimize_parameters(self): 80 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 81 | pass 82 | 83 | def setup(self, opt): 84 | """Load and print networks; create schedulers 85 | Parameters: 86 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 87 | """ 88 | # if self.is_train: 89 | # self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 90 | if not self.is_train or opt.continue_train: 91 | self.load_checkpoint_dir(opt.load_epoch) 92 | self.print_networks(opt.verbose) 93 | return self 94 | 95 | def eval(self): 96 | """Make models eval mode during test time""" 97 | for name in self.model_names: 98 | if isinstance(name, str): 99 | net = getattr(self, "net_" + name) 100 | net.eval() 101 | return self 102 | 103 | def test(self): 104 | """Forward function used in test time. 105 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 106 | It also calls to produce additional visualization results 107 | """ 108 | with torch.no_grad(): 109 | self.forward() 110 | self.compute_visuals() 111 | 112 | def compute_visuals(self): 113 | """Calculate additional output images for visdom and HTML visualization""" 114 | pass 115 | 116 | def get_image_paths(self): 117 | """ Return image paths that are used to load current data""" 118 | return self.image_paths 119 | 120 | def update_learning_rate(self): 121 | """Update learning rates for all the networks; called at the end of every epoch""" 122 | # for scheduler in self.schedulers: 123 | # if self.opt.lr_policy == 'plateau': 124 | # scheduler.step(self.metric) 125 | # else: 126 | # scheduler.step() 127 | 128 | lr = self.optimizers[0].param_groups[0]["lr"] 129 | print("learning rate = %.7f" % lr) 130 | 131 | def get_current_visuals(self): 132 | """Return visualization images. 
train.py will display these images with visdom, and save the images to a HTML""" 133 | visual_ret = OrderedDict() 134 | for name in self.visual_names: 135 | if isinstance(name, str): 136 | visual_ret[name] = getattr(self, name) 137 | return visual_ret 138 | 139 | def get_current_losses(self): 140 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 141 | errors_ret = OrderedDict() 142 | for name in self.loss_names: 143 | if isinstance(name, str): 144 | errors_ret[name] = float( 145 | getattr(self, "loss_" + name) 146 | ) # float(...) works for both scalar tensor and float number 147 | return errors_ret 148 | 149 | def save_checkpoint(self, epoch): 150 | """Save all the networks to the disk. 151 | 152 | Or save latest. 153 | Parameters: 154 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 155 | """ 156 | for name in self.model_names: 157 | if isinstance(name, str): 158 | save_filename = f"{epoch}_net_{name}.pth" 159 | save_path = os.path.join(self.save_dir, save_filename) 160 | net = getattr(self, f"net_{name}") 161 | 162 | if self.gpu_id is not None and torch.cuda.is_available(): 163 | torch.save(net.cpu().state_dict(), save_path) 164 | net.cuda(self.gpu_id) 165 | else: 166 | torch.save(net.cpu().state_dict(), save_path) 167 | # todo: save optimizers too! 168 | for name in self.optimizer_names: 169 | if isinstance(name, str): 170 | save_filename = f"{epoch}_optim_{name}.pth" 171 | save_path = os.path.join(self.save_dir, save_filename) 172 | optim = getattr(self, f"optimizer_{name}") 173 | torch.save(optim.state_dict(), save_path) 174 | 175 | def load_model_weights(self, model_name, weights_file): 176 | """ Loads the weights for a single model 177 | 178 | Args: 179 | model_name: name of the model to load parameters into 180 | weights_file: path to weights file 181 | """ 182 | net = getattr(self, f"net_{model_name}") 183 | print(f"loading the model {model_name} from {weights_file}") 184 | state_dict = torch.load(weights_file, map_location=self.device) 185 | if hasattr(state_dict, "_metadata"): 186 | del state_dict._metadata 187 | 188 | net.load_state_dict(state_dict) 189 | return self 190 | 191 | def load_checkpoint_dir(self, epoch): 192 | """Load all the networks from the disk. 
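        A rough usage sketch (as called from setup(); file names follow the
        '{epoch}_net_{name}.pth' pattern written by save_checkpoint):
        >>> model.load_checkpoint_dir(opt.load_epoch)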
193 | Parameters: 194 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 195 | """ 196 | for name in self.model_names: 197 | if isinstance(name, str): 198 | load_filename = f"{epoch}_net_{name}.pth" 199 | load_path = os.path.join(self.save_dir, load_filename) 200 | self.load_model_weights(name, load_path) 201 | 202 | if self.is_train: 203 | for name in self.optimizer_names: 204 | if isinstance(name, str): 205 | load_filename = f"{epoch}_optim_{name}.pth" 206 | load_path = os.path.join(self.save_dir, load_filename) 207 | optim = getattr(self, f"optimizer_{name}") 208 | print(f"loading the optimizer {name} from {load_path}") 209 | state_dict = torch.load(load_path) 210 | if hasattr(state_dict, "_metadata"): 211 | del state_dict._metadata 212 | optim.load_state_dict(state_dict) 213 | return self 214 | 215 | def print_networks(self, verbose): 216 | """Print the total number of parameters in the network and (if verbose) network architecture 217 | Parameters: 218 | verbose (bool) -- if verbose: print the network architecture 219 | """ 220 | print("---------- Networks initialized -------------") 221 | for name in self.model_names: 222 | if isinstance(name, str): 223 | net = getattr(self, "net_" + name) 224 | num_params = 0 225 | for param in net.parameters(): 226 | num_params += param.numel() 227 | if verbose: 228 | print(net) 229 | print( 230 | "[Network %s] Total number of parameters : %.3f M" 231 | % (name, num_params / 1e6) 232 | ) 233 | print("-----------------------------------------------") 234 | 235 | def set_requires_grad(self, nets, requires_grad=False): 236 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 237 | Parameters: 238 | nets (network list) -- a list of networks 239 | requires_grad (bool) -- whether the networks require gradients or not 240 | """ 241 | if not isinstance(nets, list): 242 | nets = [nets] 243 | for net in nets: 244 | if net is not None: 245 | for param in net.parameters(): 246 | param.requires_grad = requires_grad 247 | -------------------------------------------------------------------------------- /models/pix2pix_model.py: -------------------------------------------------------------------------------- 1 | from datasets.data_utils import unnormalize, scale_tensor 2 | from models import BaseModel 3 | import torch 4 | 5 | from modules.discriminators import define_D 6 | from modules.loss import GANLoss 7 | from modules.pix2pix_modules import define_G 8 | from util.decode_labels import decode_cloth_labels 9 | 10 | 11 | class Pix2PixModel(BaseModel): 12 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 13 | 14 | The model training requires '--dataset_mode aligned' dataset. 15 | By default, it uses a '--netG unet256' U-Net generator, 16 | a '--netD basic' discriminator (PatchGAN), 17 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 18 | 19 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 20 | """ 21 | @staticmethod 22 | def modify_commandline_options(parser, is_train=True): 23 | """Add new dataset-specific options, and rewrite default values for existing options. 24 | 25 | Parameters: 26 | parser -- original option parser 27 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 28 | 29 | Returns: 30 | the modified parser. 
31 | 32 | For pix2pix, we do not use image buffer 33 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 34 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 35 | """ 36 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 37 | parser.set_defaults(norm='batch', netG='unet_128') 38 | if is_train: 39 | parser.add_argument('--lambda_l1', type=float, default=10, help='weight for L1 loss') 40 | # gan mode choice 41 | parser.add_argument( 42 | "--gan_mode", 43 | help="gan regularization to use", 44 | default="vanilla", 45 | choices=( 46 | "vanilla", 47 | "wgan", 48 | "wgan-gp", 49 | "lsgan", 50 | "dragan-gp", 51 | "dragan-lp", 52 | "mescheder-r1-gp", 53 | "mescheder-r2-gp", 54 | ), 55 | ) 56 | parser.add_argument( 57 | "--lambda_gan", 58 | type=float, 59 | default=1.0, 60 | help="weight for adversarial loss", 61 | ) 62 | parser.add_argument( 63 | "--lambda_gp", 64 | help="weight parameter for gradient penalty", 65 | type=float, 66 | default=10, 67 | ) 68 | # discriminator choice 69 | parser.add_argument( 70 | "--discriminator", 71 | default="basic", 72 | choices=("basic", "pixel", "n_layers"), 73 | help="what discriminator type to use", 74 | ) 75 | parser.add_argument( 76 | "--n_layers_D", 77 | type=int, 78 | default=3, 79 | help="only used if discriminator==n_layers", 80 | ) 81 | parser.add_argument( 82 | "--norm", 83 | type=str, 84 | default="instance", 85 | help="instance normalization or batch normalization [instance | batch | none]", 86 | ) 87 | # optimizer choice 88 | parser.add_argument( 89 | "--optimizer_G", 90 | help="optimizer for generator", 91 | default="AdamW", 92 | choices=("AdamW", "AdaBound"), 93 | ) 94 | parser.add_argument( 95 | "--optimizer_D", 96 | help="optimizer for discriminator", 97 | default="AdamW", 98 | choices=("AdamW", "AdaBound"), 99 | ) 100 | parser.add_argument('--beta1', type=float, default=0.5, 101 | help='momentum term of adam') 102 | parser.add_argument( 103 | "--gan_label_mode", 104 | default="smooth", 105 | choices=("hard", "smooth"), 106 | help="whether to use hard (real 1.0 and fake 0.0) or smooth " 107 | "(real [0.7, 1.1] and fake [0., 0.3]) values for labels", 108 | ) 109 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 110 | return parser 111 | 112 | 113 | def __init__(self, opt): 114 | """Initialize the pix2pix class. 115 | 116 | Parameters: 117 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 118 | """ 119 | BaseModel.__init__(self, opt) 120 | # specify the training losses you want to print out. The training/test scripts will call 121 | self.loss_names = ['G', 'G_GAN', 'G_L1', 'D_real', 'D_fake'] 122 | # specify the images you want to save/display. The training/test scripts will call 123 | # self.visual_names = ['real_A', 'fake_B', 'real_B'] 124 | self.visual_names = ['cloth_decoded', 'fakes_scaled', 'textures_unnormalized'] 125 | # specify the models you want to save to the disk. 
The training/test scripts will call and 126 | if self.is_train: 127 | self.model_names = ['G', 'D'] 128 | else: # during test time, only load G 129 | self.model_names = ['G'] 130 | # define networks (both generator and discriminator) 131 | self.net_G = define_G(opt.cloth_channels + 36, opt.texture_channels, 64, "unet_128", opt.norm, True, opt.init_type, opt.init_gain).to(self.device) 132 | 133 | if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc 134 | self.net_D = define_D(opt.cloth_channels + 36 + opt.texture_channels, 64, opt.discriminator, opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain).to(self.device) 135 | 136 | if self.is_train: 137 | # define loss functions 138 | use_smooth = True if opt.gan_label_mode == "smooth" else False 139 | self.criterionGAN = GANLoss(opt.gan_mode, smooth_labels=use_smooth).to(self.device) 140 | self.criterionL1 = torch.nn.L1Loss() 141 | # initialize optimizers; schedulers will be automatically created by function . 142 | self.optimizer_G = torch.optim.Adam(self.net_G.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 143 | self.optimizer_D = torch.optim.Adam(self.net_D.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 144 | # self.optimizers.append(self.optimizer_G) 145 | # self.optimizers.append(self.optimizer_D) 146 | 147 | def set_input(self, input): 148 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 149 | 150 | Parameters: 151 | input (dict): include the data itself and its metadata information. 152 | 153 | The option 'direction' can be used to swap images in domain A and domain B. 154 | """ 155 | # AtoB = self.opt.direction == 'AtoB' 156 | # self.real_A = input['A' if AtoB else 'B'].to(self.device) 157 | # self.real_B = input['B' if AtoB else 'A'].to(self.device) 158 | to_concat = torch.zeros((self.opt.batch_size, 36, self.opt.crop_size, self.opt.crop_size), device=self.device) 159 | self.real_A = torch.cat((to_concat, input["cloths"].to(self.device)), 1) 160 | 161 | # self.real_A = torch.randn_like(cloth_tensor).to(self.device) 162 | self.real_B = input["target_textures"].to(self.device) 163 | 164 | def compute_visuals(self): 165 | self.cloth_decoded = decode_cloth_labels(self.real_A) 166 | self.fakes_scaled = scale_tensor(self.fake_B) 167 | self.textures_unnormalized = unnormalize( 168 | self.real_B, *self.opt.texture_norm_stats 169 | ) 170 | 171 | def forward(self): 172 | """Run forward pass; called by both functions and .""" 173 | self.fake_B = self.net_G(self.real_A) # G(A) 174 | 175 | def backward_D(self): 176 | """Calculate GAN loss for the discriminator""" 177 | # Fake; stop backprop to the generator by detaching fake_B 178 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) # we use conditional GANs; we need to feed both input and output to the discriminator 179 | pred_fake = self.net_D(fake_AB.detach()) 180 | self.loss_D_fake = self.criterionGAN(pred_fake, False) 181 | # Real 182 | real_AB = torch.cat((self.real_A, self.real_B), 1) 183 | pred_real = self.net_D(real_AB) 184 | self.loss_D_real = self.criterionGAN(pred_real, True) 185 | # combine loss and calculate gradients 186 | self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5 187 | self.loss_D.backward() 188 | 189 | def backward_G(self): 190 | """Calculate GAN and L1 loss for the generator""" 191 | # First, G(A) should fake the discriminator 192 | fake_AB = torch.cat((self.real_A, self.fake_B), 1) 193 | pred_fake = 
self.net_D(fake_AB) 194 | self.loss_G_GAN = self.criterionGAN(pred_fake, True) 195 | # Second, G(A) = B 196 | self.loss_G_L1 = self.criterionL1(self.fake_B, self.real_B) * self.opt.lambda_l1 197 | # combine loss and calculate gradients 198 | self.loss_G = self.loss_G_GAN + self.loss_G_L1 199 | self.loss_G.backward() 200 | 201 | def optimize_parameters(self): 202 | self.forward() # compute fake images: G(A) 203 | # update D 204 | self.set_requires_grad(self.net_D, True) # enable backprop for D 205 | self.optimizer_D.zero_grad() # set D's gradients to zero 206 | self.backward_D() # calculate gradients for D 207 | self.optimizer_D.step() # update D's weights 208 | # update G 209 | self.set_requires_grad(self.net_D, False) # D requires no gradients when optimizing G 210 | self.optimizer_G.zero_grad() # set G's gradients to zero 211 | self.backward_G() # calculate graidents for G 212 | self.optimizer_G.step() # udpate G's weights 213 | -------------------------------------------------------------------------------- /models/texture_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import modules.losses 7 | from datasets.data_utils import unnormalize, scale_tensor 8 | from models.base_gan import BaseGAN 9 | from modules import get_norm_layer 10 | from modules.pix2pix_modules import UnetGenerator 11 | from modules.swapnet_modules import TextureModule 12 | from util.decode_labels import decode_cloth_labels 13 | from util.draw_rois import draw_rois_on_texture 14 | 15 | 16 | class TextureModel(BaseGAN): 17 | """ 18 | Implements training steps of the SwapNet Texture Module. 19 | """ 20 | 21 | @staticmethod 22 | def modify_commandline_options(parser: ArgumentParser, is_train): 23 | parser = super(TextureModel, TextureModel).modify_commandline_options( 24 | parser, is_train 25 | ) 26 | if is_train: 27 | parser.add_argument( 28 | "--netG", 29 | default="swapnet", 30 | choices=["swapnet", "unet_128"] 31 | ) 32 | parser.add_argument( 33 | "--lambda_l1", 34 | type=float, 35 | default=10, 36 | help="weight for L1 loss in final term", 37 | ) 38 | parser.add_argument( 39 | "--lambda_content", 40 | type=float, 41 | default=20, 42 | help="weight for content loss in final term", 43 | ) 44 | parser.add_argument( 45 | "--lambda_style", 46 | type=float, 47 | default=1e-8, # experimentally set to be within the same magnitude as l1 and content 48 | help="weight for content loss in final term", 49 | ) 50 | # based on the num entries in self.visual_names during training 51 | parser.set_defaults(display_ncols=5) 52 | return parser 53 | 54 | def __init__(self, opt): 55 | super().__init__(opt) 56 | 57 | # TODO: decode cloth visual 58 | self.visual_names = [ 59 | "textures_unnormalized", 60 | "cloths_decoded", 61 | "fakes", 62 | "fakes_scaled", 63 | ] 64 | if self.is_train: 65 | self.visual_names.append("targets_unnormalized") 66 | # Define additional loss for generator 67 | self.criterion_L1 = nn.L1Loss().to(self.device) 68 | self.criterion_perceptual = modules.losses.PerceptualLoss( 69 | use_style=opt.lambda_style != 0).to(self.device) 70 | 71 | for loss in ["l1", "content", "style"]: 72 | if getattr(opt, "lambda_" + loss) != 0: 73 | self.loss_names.append(f"G_{loss}") 74 | 75 | def compute_visuals(self): 76 | self.textures_unnormalized = unnormalize( 77 | self.textures, *self.opt.texture_norm_stats 78 | ) 79 | self.textures_unnormalized = draw_rois_on_texture( 80 | self.rois, 
self.textures_unnormalized 81 | ) 82 | self.cloths_decoded = decode_cloth_labels(self.cloths) 83 | 84 | self.fakes_scaled = scale_tensor(self.fakes, scale_each=True) 85 | 86 | if self.is_train: 87 | self.targets_unnormalized = unnormalize( 88 | self.targets, *self.opt.texture_norm_stats 89 | ) 90 | # all batch, only first 3 channels 91 | # self.DEBUG_random_input = self.net_generator.DEBUG_random_input[:, :3] # take the top 3 layers, to 'sample' the RGB image 92 | 93 | def get_D_inchannels(self): 94 | return self.opt.texture_channels + self.opt.cloth_channels 95 | 96 | def define_G(self): 97 | if self.opt.netG == "unet_128": 98 | norm_layer = get_norm_layer("batch") 99 | return UnetGenerator( 100 | self.opt.texture_channels, self.opt.texture_channels, 7, 64, norm_layer=norm_layer, use_dropout=True 101 | ) 102 | elif self.opt.netG == "swapnet": 103 | return TextureModule( 104 | texture_channels=self.opt.texture_channels, 105 | cloth_channels=self.opt.cloth_channels, 106 | num_roi=self.opt.body_channels, 107 | img_size=self.opt.crop_size, 108 | norm_type=self.opt.norm, 109 | ) 110 | else: 111 | raise ValueError("Cannot find implementation for " + self.opt.netG) 112 | 113 | def set_input(self, input): 114 | self.textures = input["input_textures"].to(self.device) 115 | self.rois = input["rois"].to(self.device) 116 | self.cloths = input["cloths"].to(self.device) 117 | self.targets = input["target_textures"].to(self.device) 118 | 119 | self.image_paths = tuple(zip(input["cloth_paths"], input["texture_paths"])) 120 | 121 | def forward(self): 122 | if self.opt.netG == "swapnet": 123 | self.fakes = self.net_generator(self.textures, self.rois, self.cloths) 124 | elif self.opt.netG.startswith("unet_"): 125 | self.fakes = self.net_generator(self.textures) 126 | 127 | def backward_D(self): 128 | """ 129 | Calculates loss and backpropagates for the discriminator 130 | """ 131 | # https://github.com/martinarjovsky/WassersteinGAN/blob/f7a01e82007ea408647c451b9e1c8f1932a3db67/main.py#L185 132 | if self.opt.gan_mode == "wgan": 133 | # clamp parameters to a cube 134 | for p in self.net_discriminator.parameters(): 135 | p.data.clamp(-0.01, 0.01) 136 | 137 | # calculate fake 138 | fake_AB = torch.cat((self.cloths, self.fakes), 1) 139 | pred_fake = self.net_discriminator(fake_AB.detach()) 140 | self.loss_D_fake = self.criterion_GAN(pred_fake, False) 141 | # calculate real 142 | real_AB = torch.cat((self.cloths, self.targets), 1) 143 | pred_real = self.net_discriminator(real_AB) 144 | self.loss_D_real = self.criterion_GAN(pred_real, True) 145 | 146 | self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real) 147 | 148 | if any(gp_mode in self.opt.gan_mode for gp_mode in ["gp", "lp"]): 149 | # calculate gradient penalty 150 | self.loss_D_gp = modules.loss.gradient_penalty( 151 | self.net_discriminator, self.targets, self.fakes, self.opt.gan_mode 152 | ) 153 | self.loss_D += self.opt.lambda_gp * self.loss_D_gp 154 | 155 | self.loss_D.backward() 156 | 157 | def backward_G(self): 158 | """ 159 | Backward G for Texture stage. 160 | Loss composed of GAN loss, L1 loss, and feature loss. 
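        Roughly, the weighted sum assembled at the end of this method is:
            loss_G = lambda_gan * L_GAN + lambda_l1 * L_L1
                     + lambda_content * L_content + lambda_style * L_style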
161 | Returns: 162 | 163 | """ 164 | fake_AB = torch.cat((self.cloths, self.fakes), 1) 165 | pred_fake = self.net_discriminator(fake_AB) 166 | self.loss_G_gan = self.criterion_GAN(pred_fake, True) * self.opt.lambda_gan 167 | 168 | self.loss_G_l1 = ( 169 | self.criterion_L1(self.fakes, self.targets) * self.opt.lambda_l1 170 | ) 171 | self.loss_G_content = self.loss_G_style = 0 172 | if self.opt.lambda_content != 0 or self.opt.lambda_style != 0: 173 | self.loss_G_content, self.loss_G_style = self.criterion_perceptual( 174 | self.fakes, self.targets) 175 | self.loss_G_content *= self.opt.lambda_content 176 | self.loss_G_style *= self.opt.lambda_style 177 | 178 | # weighted sum 179 | self.loss_G = self.loss_G_gan + self.loss_G_l1 + self.loss_G_content + self.loss_G_style 180 | self.loss_G.backward() 181 | -------------------------------------------------------------------------------- /models/warp_model.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from torch import nn 5 | 6 | import modules.loss 7 | from datasets.data_utils import unnormalize, remove_top_dir 8 | from models import BaseModel 9 | from models.base_gan import BaseGAN 10 | from modules.swapnet_modules import WarpModule 11 | from util.decode_labels import decode_cloth_labels 12 | 13 | 14 | class WarpModel(BaseGAN): 15 | """ 16 | Implements training steps of the SwapNet Texture Module. 17 | """ 18 | 19 | @staticmethod 20 | def modify_commandline_options(parser: ArgumentParser, is_train): 21 | """ 22 | Adds warp_mode option for generator loss. This is because Khiem found out using 23 | plain Cross Entropy works just fine. CE mode saves time and space by not having 24 | to train an additional discriminator network. 25 | """ 26 | if is_train: 27 | parser.add_argument("--warp_mode", default="gan", choices=("gan", "ce")) 28 | parser.add_argument( 29 | "--lambda_ce", 30 | type=float, 31 | default=100, 32 | help="weight for cross entropy loss in final term", 33 | ) 34 | # based on the num entries in self.visual_names during training 35 | parser.set_defaults(display_ncols=4) 36 | # https://stackoverflow.com/questions/26788214/super-and-staticmethod-interaction 37 | parser = super(WarpModel, WarpModel).modify_commandline_options( 38 | parser, is_train 39 | ) 40 | return parser 41 | 42 | def __init__(self, opt): 43 | """ 44 | Initialize the WarpModel. Either in GAN mode or plain Cross Entropy mode. 
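        A rough sketch of the two configurations this constructor produces
        (flags as defined in modify_commandline_options above):
        >>> WarpModel(opt)  # opt.warp_mode == "gan": generator + discriminator, adversarial + CE losses
        >>> WarpModel(opt)  # opt.warp_mode == "ce": generator only, cross entropy loss alone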
45 | Args: 46 | opt: 47 | """ 48 | # 3 for RGB 49 | self.body_channels = ( 50 | opt.body_channels if opt.body_representation == "labels" else 3 51 | ) 52 | # 3 for RGB 53 | self.cloth_channels = ( 54 | opt.cloth_channels if opt.cloth_representation == "labels" else 3 55 | ) 56 | 57 | BaseGAN.__init__(self, opt) 58 | 59 | # TODO: decode visuals for cloth 60 | self.visual_names = ["inputs_decoded", "bodys_unnormalized", "fakes_decoded"] 61 | 62 | if self.is_train: 63 | self.visual_names.append( 64 | "targets_decoded" 65 | ) # only show targets during training 66 | # we use cross entropy loss in both 67 | self.criterion_CE = nn.CrossEntropyLoss() 68 | if opt.warp_mode != "gan": 69 | # remove discriminator related things if no GAN 70 | self.model_names = ["generator"] 71 | self.loss_names = "G" 72 | del self.net_discriminator 73 | del self.optimizer_D 74 | self.optimizer_names = ["G"] 75 | else: 76 | self.loss_names += ["G_ce"] 77 | 78 | def compute_visuals(self): 79 | self.inputs_decoded = decode_cloth_labels(self.inputs) 80 | self.bodys_unnormalized = unnormalize(self.bodys, *self.opt.body_norm_stats) 81 | self.targets_decoded = decode_cloth_labels(self.targets) 82 | self.fakes_decoded = decode_cloth_labels(self.fakes) 83 | 84 | def define_G(self): 85 | """ 86 | The generator is the Warp Module. 87 | """ 88 | return WarpModule( 89 | body_channels=self.body_channels, cloth_channels=self.cloth_channels 90 | ) 91 | 92 | def get_D_inchannels(self): 93 | """ 94 | The Warp stage discriminator is a conditional discriminator. 95 | This means we concatenate the generated warped cloth with the body segmentation. 96 | """ 97 | return self.cloth_channels + self.body_channels 98 | 99 | def set_input(self, input): 100 | self.bodys = input["bodys"].to(self.device) 101 | self.inputs = input["input_cloths"].to(self.device) 102 | self.targets = input["target_cloths"].to(self.device) 103 | 104 | self.image_paths = tuple(zip(input["cloth_paths"], input["body_paths"])) 105 | 106 | def forward(self): 107 | self.fakes = self.net_generator(self.bodys, self.inputs) 108 | 109 | def backward_D(self): 110 | """ 111 | Warp stage's custom backward_D implementation passes CONDITIONED input to 112 | the discriminator. Concats the bodys with the cloth 113 | """ 114 | # calculate fake 115 | conditioned_fake = torch.cat((self.bodys, self.fakes), 1) 116 | pred_fake = self.net_discriminator(conditioned_fake.detach()) 117 | self.loss_D_fake = self.criterion_GAN(pred_fake, False) 118 | # calculate real 119 | conditioned_real = torch.cat((self.bodys, self.targets), 1) 120 | pred_real = self.net_discriminator(conditioned_real) 121 | self.loss_D_real = self.criterion_GAN(pred_real, True) 122 | 123 | self.loss_D = 0.5 * (self.loss_D_fake + self.loss_D_real) 124 | 125 | # calculate gradient penalty 126 | if any(gp_mode in self.opt.gan_mode for gp_mode in ["gp", "lp"]): 127 | self.loss_D_gp = ( 128 | modules.loss.gradient_penalty( 129 | self.net_discriminator, 130 | conditioned_real, 131 | conditioned_fake, 132 | self.opt.gan_mode, 133 | ) 134 | * self.opt.lambda_gp 135 | ) 136 | self.loss_D += self.loss_D_gp 137 | 138 | # final loss 139 | self.loss_D.backward() 140 | 141 | def backward_G(self): 142 | """ 143 | If GAN mode, loss is weighted sum of cross entropy loss and adversarial GAN 144 | loss. Else, loss is just cross entropy loss. 
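        Roughly, mirroring the code below:
            gan mode: loss_G = lambda_gan * L_GAN(D(bodys, fakes)) + lambda_ce * CE(fakes, argmax(targets))
            ce mode:  loss_G = lambda_ce * CE(fakes, argmax(targets))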
145 | """ 146 | # cross entropy loss needed for both gan mode and ce mode 147 | loss_ce = ( 148 | self.criterion_CE(self.fakes, torch.argmax(self.targets, dim=1)) 149 | * self.opt.lambda_ce 150 | ) 151 | 152 | # if we're in GAN mode, calculate adversarial loss too 153 | if self.opt.warp_mode == "gan": 154 | self.loss_G_ce = loss_ce # store loss_ce 155 | 156 | # calculate adversarial loss 157 | conditioned_fake = torch.cat((self.bodys, self.fakes), 1) 158 | pred_fake = self.net_discriminator(conditioned_fake) 159 | self.loss_G_gan = self.criterion_GAN(pred_fake, True) * self.opt.lambda_gan 160 | 161 | # total loss is weighted sum 162 | self.loss_G = self.loss_G_gan + self.loss_G_ce 163 | else: 164 | # otherwise our only loss is cross entropy 165 | self.loss_G = loss_ce 166 | 167 | self.loss_G.backward() 168 | 169 | def optimize_parameters(self): 170 | """ 171 | Optimize both G and D if in GAN mode, else just G. 172 | Returns: 173 | 174 | """ 175 | if self.opt.warp_mode == "gan": 176 | # will optimize both D and G 177 | super().optimize_parameters() 178 | else: 179 | self.forward() 180 | # optimize G only 181 | self.optimizer_G.zero_grad() 182 | self.backward_G() 183 | self.optimizer_G.step() 184 | -------------------------------------------------------------------------------- /modules/__init__.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | from torch import nn 4 | from torch.nn import init 5 | 6 | 7 | def init_weights(net, init_type="normal", init_gain=0.02): 8 | """Initialize network weights. 9 | 10 | Parameters: 11 | net (network) -- network to be initialized 12 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 13 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 14 | 15 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 16 | work better for some applications. Feel free to try yourself. 17 | """ 18 | 19 | def init_func(m): # define the initialization function 20 | classname = m.__class__.__name__ 21 | if hasattr(m, "weight") and ( 22 | classname.find("Conv") != -1 or classname.find("Linear") != -1 23 | ): 24 | if init_type == "normal": 25 | init.normal_(m.weight.data, 0.0, init_gain) 26 | elif init_type == "xavier": 27 | init.xavier_normal_(m.weight.data, gain=init_gain) 28 | elif init_type == "kaiming": 29 | init.kaiming_normal_(m.weight.data, a=0, mode="fan_in") 30 | elif init_type == "orthogonal": 31 | init.orthogonal_(m.weight.data, gain=init_gain) 32 | else: 33 | raise NotImplementedError( 34 | "initialization method [%s] is not implemented" % init_type 35 | ) 36 | if hasattr(m, "bias") and m.bias is not None: 37 | init.constant_(m.bias.data, 0.0) 38 | elif ( 39 | classname.find("BatchNorm2d") != -1 40 | ): # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 41 | init.normal_(m.weight.data, 1.0, init_gain) 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | print("initialize network with %s" % init_type) 45 | net.apply(init_func) # apply the initialization function 46 | 47 | 48 | class Identity(nn.Module): 49 | def forward(self, x): 50 | return x 51 | 52 | 53 | def get_norm_layer(norm_type="instance"): 54 | """Return a normalization layer 55 | 56 | Parameters: 57 | norm_type (str) -- the name of the normalization layer: batch | instance | none 58 | 59 | For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). 
60 | For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. 61 | """ 62 | if norm_type == "batch": 63 | norm_layer = functools.partial( 64 | nn.BatchNorm2d, affine=True, track_running_stats=True 65 | ) 66 | elif norm_type == "instance": 67 | norm_layer = functools.partial( 68 | nn.InstanceNorm2d, affine=False, track_running_stats=False 69 | ) 70 | elif norm_type == "none": 71 | norm_layer = lambda x: Identity() 72 | else: 73 | raise NotImplementedError("normalization layer [%s] is not found" % norm_type) 74 | return norm_layer 75 | -------------------------------------------------------------------------------- /modules/discriminators.py: -------------------------------------------------------------------------------- 1 | """ 2 | Discriminators to be used in GAN systems. 3 | """ 4 | import torch 5 | from torch import nn 6 | import functools 7 | 8 | from modules import get_norm_layer 9 | 10 | 11 | class Discriminator(nn.Module): 12 | def __init__(self, in_channels=3, img_size=512): 13 | super(Discriminator, self).__init__() 14 | 15 | def discriminator_block(in_feat, out_feat, bn=True): 16 | block = [ 17 | nn.Conv2d(in_feat, out_feat, 3, 2, 1), 18 | nn.LeakyReLU(0.2, inplace=True), 19 | nn.Dropout2d(0.25), 20 | ] 21 | if bn: 22 | block.append(nn.BatchNorm2d(out_feat, 0.8)) 23 | return block 24 | 25 | self.model = nn.Sequential( 26 | *discriminator_block(in_channels, 16, bn=False), 27 | *discriminator_block(16, 32), 28 | *discriminator_block(32, 64), 29 | *discriminator_block(64, 128), 30 | ) 31 | 32 | # The height and width of downsampled image 33 | ds_size = img_size // 2 ** 4 34 | self.adv_layer = nn.Sequential( 35 | nn.Linear(128 * ds_size ** 2, 1), # a linear layer 36 | ) 37 | 38 | def forward(self, input): 39 | out = self.model(input) 40 | out = out.view(out.shape[0], -1) 41 | validity = self.adv_layer(out) 42 | 43 | return validity 44 | 45 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', 46 | init_gain=0.02, gpu_ids=[]): 47 | """Create a discriminator 48 | 49 | Parameters: 50 | input_nc (int) -- the number of channels in input images 51 | ndf (int) -- the number of filters in the first conv layer 52 | netD (str) -- the architecture's name: basic | n_layers | pixel 53 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 54 | norm (str) -- the type of normalization layers used in the network. 55 | init_type (str) -- the name of the initialization method. 56 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 57 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 58 | 59 | Returns a discriminator 60 | 61 | Our current implementation provides three types of discriminators: 62 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 63 | It can classify whether 70×70 overlapping patches are real or fake. 64 | Such a patch-level discriminator architecture has fewer parameters 65 | than a full-image discriminator and can work on arbitrarily-sized images 66 | in a fully convolutional fashion. 67 | 68 | [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator 69 | with the parameter (default=3 as used in [basic] (PatchGAN).) 70 | 71 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 72 | It encourages greater color diversity but has no effect on spatial statistics. 73 | 74 | The discriminator has been initialized by . 
It uses Leakly RELU for non-linearity. 75 | """ 76 | net = None 77 | norm_layer = get_norm_layer(norm_type=norm) 78 | 79 | if netD == 'basic': # default PatchGAN classifier 80 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 81 | elif netD == 'n_layers': # more options 82 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 83 | elif netD == 'pixel': # classify if each pixel is real or fake 84 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 85 | else: 86 | raise NotImplementedError( 87 | 'Discriminator model name [%s] is not recognized' % netD) 88 | return net 89 | 90 | 91 | class NLayerDiscriminator(nn.Module): 92 | """Defines a PatchGAN discriminator""" 93 | 94 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 95 | """Construct a PatchGAN discriminator 96 | 97 | Parameters: 98 | input_nc (int) -- the number of channels in input images 99 | ndf (int) -- the number of filters in the last conv layer 100 | n_layers (int) -- the number of conv layers in the discriminator 101 | norm_layer -- normalization layer 102 | """ 103 | super(NLayerDiscriminator, self).__init__() 104 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 105 | use_bias = norm_layer.func == nn.InstanceNorm2d 106 | else: 107 | use_bias = norm_layer == nn.InstanceNorm2d 108 | 109 | kw = 4 110 | padw = 1 111 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] 112 | nf_mult = 1 113 | nf_mult_prev = 1 114 | for n in range(1, n_layers): # gradually increase the number of filters 115 | nf_mult_prev = nf_mult 116 | nf_mult = min(2 ** n, 8) 117 | sequence += [ 118 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 119 | norm_layer(ndf * nf_mult), 120 | nn.LeakyReLU(0.2, True) 121 | ] 122 | 123 | nf_mult_prev = nf_mult 124 | nf_mult = min(2 ** n_layers, 8) 125 | sequence += [ 126 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 127 | norm_layer(ndf * nf_mult), 128 | nn.LeakyReLU(0.2, True) 129 | ] 130 | 131 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 132 | self.model = nn.Sequential(*sequence) 133 | 134 | def forward(self, input): 135 | """Standard forward.""" 136 | return self.model(input) 137 | 138 | 139 | class PixelDiscriminator(nn.Module): 140 | """Defines a 1x1 PatchGAN discriminator (pixelGAN)""" 141 | 142 | def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d): 143 | """Construct a 1x1 PatchGAN discriminator 144 | 145 | Parameters: 146 | input_nc (int) -- the number of channels in input images 147 | ndf (int) -- the number of filters in the last conv layer 148 | norm_layer -- normalization layer 149 | """ 150 | super(PixelDiscriminator, self).__init__() 151 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 152 | use_bias = norm_layer.func == nn.InstanceNorm2d 153 | else: 154 | use_bias = norm_layer == nn.InstanceNorm2d 155 | 156 | self.net = [ 157 | nn.Conv2d(input_nc, ndf, kernel_size=1, stride=1, padding=0), 158 | nn.LeakyReLU(0.2, True), 159 | nn.Conv2d(ndf, ndf * 2, kernel_size=1, stride=1, padding=0, bias=use_bias), 160 | norm_layer(ndf * 2), 161 | nn.LeakyReLU(0.2, True), 162 | nn.Conv2d(ndf * 2, 1, kernel_size=1, stride=1, padding=0, bias=use_bias)] 163 | 164 | self.net = 
nn.Sequential(*self.net) 165 | 166 | def forward(self, input): 167 | """Standard forward.""" 168 | return self.net(input) 169 | -------------------------------------------------------------------------------- /modules/layers.py: -------------------------------------------------------------------------------- 1 | """ 2 | Layers / building blocks meant to be used by larger modules. 3 | """ 4 | import torch 5 | from torch import nn 6 | 7 | ############################## 8 | # U-NET 9 | ############################## 10 | 11 | 12 | class UNetDown(nn.Module): 13 | def __init__(self, in_size, out_size, normalize=True, dropout=0.0): 14 | super(UNetDown, self).__init__() 15 | layers = [nn.Conv2d(in_size, out_size, 4, 2, 1, bias=False)] 16 | if normalize: 17 | layers.append(nn.InstanceNorm2d(out_size)) 18 | layers.append(nn.LeakyReLU(0.2)) 19 | if dropout: 20 | layers.append(nn.Dropout(dropout)) 21 | self.model = nn.Sequential(*layers) 22 | 23 | def forward(self, x): 24 | return self.model(x) 25 | 26 | 27 | class UNetUp(nn.Module): 28 | def __init__(self, in_size, out_size, dropout=0.0): 29 | super(UNetUp, self).__init__() 30 | layers = [ 31 | nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False), 32 | nn.InstanceNorm2d(out_size), 33 | nn.ReLU(inplace=True), 34 | ] 35 | if dropout: 36 | layers.append(nn.Dropout(dropout)) 37 | 38 | self.model = nn.Sequential(*layers) 39 | 40 | def forward(self, x, skip_input): 41 | x = self.model(x) 42 | if skip_input is not None: 43 | x = torch.cat((x, skip_input), 1) 44 | return x 45 | 46 | 47 | class DualUNetUp(UNetUp): 48 | """ 49 | My guess of how dual u-net works, according to the paper 50 | "Multi-View Image Generation from a Single-View" 51 | 52 | @author Andrew 53 | """ 54 | 55 | def __init__(self, in_size, out_size, dropout=0.0): 56 | super(DualUNetUp, self).__init__(in_size, out_size, dropout) 57 | 58 | def forward(self, x, skip_input_1, skip_input_2): 59 | x = self.model(x) 60 | # print("DualUNetUp before cat:", x.shape) 61 | x = torch.cat((x, skip_input_1, skip_input_2), 1) 62 | 63 | return x 64 | 65 | 66 | class GeneratorUNet(nn.Module): 67 | """ 68 | Just for reference, how to create a U-Net 69 | """ 70 | def __init__(self, in_channels=3, out_channels=3): 71 | super(GeneratorUNet, self).__init__() 72 | 73 | self.down1 = UNetDown(in_channels, 64, normalize=False) 74 | self.down2 = UNetDown(64, 128) 75 | self.down3 = UNetDown(128, 256) 76 | self.down4 = UNetDown(256, 512, dropout=0.5) 77 | 78 | self.down5 = UNetDown(512, 512, dropout=0.5) 79 | self.down6 = UNetDown(512, 512, dropout=0.5) 80 | self.down7 = UNetDown(512, 512, dropout=0.5) 81 | self.down8 = UNetDown(512, 512, normalize=False, dropout=0.5) 82 | self.up1 = UNetUp(512, 512, dropout=0.5) 83 | self.up2 = UNetUp(1024, 512, dropout=0.5) 84 | self.up3 = UNetUp(1024, 512, dropout=0.5) 85 | self.up4 = UNetUp(1024, 512, dropout=0.5) 86 | 87 | self.up5 = UNetUp(1024, 256) 88 | self.up6 = UNetUp(512, 128) 89 | self.up7 = UNetUp(256, 64) 90 | self.final = nn.Sequential( 91 | nn.Upsample(scale_factor=2), 92 | nn.ZeroPad2d((1, 0, 1, 0)), 93 | nn.Conv2d(128, out_channels, 4, padding=1), 94 | nn.Tanh(), 95 | ) 96 | 97 | def forward(self, x): 98 | # U-Net generator with skip connections from encoder to decoder 99 | print("x shape:", x.shape) 100 | d1 = self.down1(x) 101 | d2 = self.down2(d1) 102 | d3 = self.down3(d2) 103 | d4 = self.down4(d3) 104 | print("d4 shape:", d4.shape) 105 | d5 = self.down5(d4) 106 | d6 = self.down6(d5) 107 | d7 = self.down7(d6) 108 | d8 = self.down8(d7) 109 | u1 = self.up1(d8, 
d7) 110 | u2 = self.up2(u1, d6) 111 | u3 = self.up3(u2, d5) 112 | u4 = self.up4(u3, d4) 113 | print("u4 shape:", u4.shape) 114 | u5 = self.up5(u4, d3) 115 | print("u5 shape:", u5.shape) 116 | u6 = self.up6(u5, d2) 117 | u7 = self.up7(u6, d1) 118 | 119 | return self.final(u7) 120 | 121 | ############################## 122 | # ResBlock 123 | ############################## 124 | 125 | 126 | class ResidualBlock(nn.Module): 127 | def __init__(self, in_features, dropout=0.0): 128 | super(ResidualBlock, self).__init__() 129 | 130 | conv_block = [ 131 | nn.ReflectionPad2d(1), 132 | nn.Conv2d(in_features, in_features, 3), 133 | nn.InstanceNorm2d(in_features), 134 | nn.ReLU(inplace=True), 135 | nn.Dropout(dropout), # added by AJ 136 | nn.ReflectionPad2d(1), 137 | nn.Conv2d(in_features, in_features, 3), 138 | nn.InstanceNorm2d(in_features), 139 | ] 140 | 141 | self.conv_block = nn.Sequential(*conv_block) 142 | 143 | def forward(self, x): 144 | return x + self.conv_block(x) 145 | 146 | -------------------------------------------------------------------------------- /modules/loss.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom loss modules 3 | """ 4 | 5 | from abc import ABC 6 | 7 | import torch 8 | from torch import nn 9 | from torchvision import models 10 | 11 | 12 | class GANLoss(nn.Module): 13 | """Define different GAN objectives. 14 | 15 | The GANLoss class abstracts away the need to create the target label tensor 16 | that has the same size as the input. 17 | """ 18 | 19 | default_real = 1.0 20 | default_fake = 0 21 | default_smooth_real = (0.7, 1.1) 22 | default_smooth_fake = (0.0, 0.3) 23 | 24 | def __init__( 25 | self, 26 | gan_mode, 27 | smooth_labels=True, 28 | target_real_label=None, 29 | target_fake_label=None, 30 | ): 31 | """ Initialize the GANLoss class. 32 | 33 | Parameters: 34 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 35 | target_real_label (bool) - - label for a real image 36 | target_fake_label (bool) - - label of a fake image 37 | 38 | Note: Do not use sigmoid as the last layer of Discriminator. 39 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 40 | """ 41 | super().__init__() 42 | if target_real_label is None: 43 | target_real_label = ( 44 | self.default_smooth_real if smooth_labels else self.default_real 45 | ) 46 | if target_fake_label is None: 47 | target_fake_label = ( 48 | self.default_smooth_fake if smooth_labels else self.default_fake 49 | ) 50 | 51 | self.register_buffer("real_label", torch.tensor(target_real_label)) 52 | self.register_buffer("fake_label", torch.tensor(target_fake_label)) 53 | self.gan_mode = gan_mode 54 | if gan_mode == "lsgan": 55 | self.loss = nn.MSELoss() 56 | # according to DRAGAN GitHub, dragan also uses BCE loss: https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb 57 | elif gan_mode in ["vanilla", "dragan", "dragan-gp", "dragan-lp"]: 58 | self.loss = nn.BCEWithLogitsLoss() 59 | elif "wgan" in gan_mode: 60 | self.loss = None 61 | else: 62 | raise NotImplementedError("gan mode %s not implemented" % gan_mode) 63 | 64 | @staticmethod 65 | def rand_between(low, high, normal=False): 66 | """ 67 | Args: 68 | low: a torch.Tensor 69 | high: a torch.Tensor 70 | normal: whether to use normal distribution. 
if not, will use uniform 71 | 72 | Returns: random tensor between low and high 73 | """ 74 | if normal: 75 | return torch.randn(1) * (high - low) + low 76 | else: 77 | return torch.rand(1) * (high - low) + low 78 | 79 | def get_target_tensor(self, prediction, target_is_real): 80 | """Create label tensors with the same size as the input. 81 | 82 | Parameters: 83 | prediction (tensor) - - tpyically the prediction from a discriminator 84 | target_is_real (bool) - - if the ground truth label is for real images or fake images 85 | 86 | Returns: 87 | A label tensor filled with ground truth label, and with the size of the input 88 | """ 89 | 90 | if target_is_real: 91 | # smooth labels 92 | if len(self.real_label) == 2: 93 | low, high = self.real_label 94 | target_tensor = GANLoss.rand_between(low, high).to( 95 | self.real_label.device 96 | ) 97 | else: 98 | target_tensor = self.real_label 99 | else: 100 | # smooth labels 101 | if len(self.fake_label) == 2: 102 | low, high = self.real_label 103 | target_tensor = GANLoss.rand_between(low, high).to( 104 | self.fake_label.device 105 | ) 106 | else: 107 | target_tensor = self.fake_label 108 | return target_tensor.expand_as(prediction) 109 | 110 | def __call__(self, prediction, target_is_real): 111 | """Calculate loss given Discriminator's output and ground truth labels. 112 | 113 | Parameters: 114 | prediction (tensor) - - typically the prediction output from a discriminator 115 | target_is_real (bool) - - if the ground truth label is for real images or fake images 116 | 117 | Returns: 118 | the calculated loss. 119 | """ 120 | if self.gan_mode in ["lsgan", "vanilla", "dragan-gp", "dragan-lp"]: 121 | target_tensor = self.get_target_tensor(prediction, target_is_real) 122 | loss = self.loss(prediction, target_tensor) 123 | elif "wgan" in self.gan_mode: 124 | if target_is_real: 125 | loss = -prediction.mean() 126 | else: 127 | loss = prediction.mean() 128 | else: 129 | raise ValueError(f"{self.gan_mode} not recognized") 130 | return loss 131 | 132 | 133 | def gradient_penalty(f, real, fake, mode, p_norm=2): 134 | """ 135 | From https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch/blob/master/torchprob/gan/loss.py 136 | Args: 137 | f: a discriminator 138 | real: target 139 | fake: generated 140 | mode: 141 | p_norm: 142 | 143 | Returns: 144 | 145 | """ 146 | 147 | def _gradient_penalty(f, real, fake=None, penalty_type="gp", p_norm=2): 148 | def _interpolate(a, b=None): 149 | if b is None: # interpolation in DRAGAN 150 | beta = torch.rand_like(a) 151 | b = a + 0.5 * a.std() * beta 152 | shape = [a.size(0)] + [1] * (a.dim() - 1) 153 | alpha = torch.rand(shape, device=a.device) 154 | inter = a + alpha * (b - a) 155 | return inter 156 | 157 | x = _interpolate(real, fake).detach() 158 | x.requires_grad = True 159 | pred = f(x) 160 | grad = torch.autograd.grad( 161 | pred, x, grad_outputs=torch.ones_like(pred), create_graph=True 162 | )[0] 163 | norm = grad.view(grad.size(0), -1).norm(p=p_norm, dim=1) 164 | 165 | if penalty_type == "gp": 166 | gp = ((norm - 1) ** 2).mean() 167 | elif penalty_type == "lp": 168 | gp = (torch.max(torch.zeros_like(norm), norm - 1) ** 2).mean() 169 | 170 | return gp 171 | 172 | if not mode or mode == "vanilla": 173 | gp = torch.tensor(0, dtype=real.dtype, device=real.device) 174 | elif mode in ["dragan", "dragan-gp", "dragan-lp"]: 175 | penalty_type = "gp" if mode == "dragan" else mode[-2:] 176 | gp = _gradient_penalty(f, real, penalty_type=penalty_type, p_norm=p_norm) 177 | elif mode in ["wgan-gp", "wgan-lp"]: 178 | gp = 
_gradient_penalty(f, real, fake, penalty_type=mode[-2:], p_norm=p_norm) 179 | else: 180 | raise ValueError("Don't know how to handle gan mode", mode) 181 | 182 | # TODO: implement mescheder's simplified gradient penalties 183 | 184 | return gp 185 | 186 | 187 | def get_vgg_feature_loss(opt, nlayers): 188 | """ 189 | Initialize a MultieLayerFeatureLoss module with VGG19 feature extractor 190 | :param opt: command line arguments 191 | :param nlayers: 192 | :return: 193 | """ 194 | feature_extractor = models.vgg19(pretrained=True) 195 | vgg_inp_size = 224 196 | scale = vgg_inp_size / opt.crop_size 197 | return MultiLayerFeatureLoss(feature_extractor, scale, num_layers=nlayers) 198 | 199 | 200 | class FeatureLoss(ABC, nn.Module): 201 | def __init__(self, feature_extractor, scale=224 / 512): 202 | super().__init__() 203 | # set to eval mode to disable dropout and such 204 | self.feature_extractor = feature_extractor.eval() 205 | self.scale = scale 206 | 207 | def downsize(self, *inputs): 208 | """ 209 | Downsize the inputs so they match the size required for the pretrained model 210 | :param inputs: 211 | :return: 212 | """ 213 | outs = [] 214 | for a in inputs: 215 | outs.append(nn.functional.interpolate(a, scale_factor=self.scale)) 216 | 217 | return tuple(outs) 218 | 219 | 220 | class L1FeatureLoss(FeatureLoss): 221 | def __init__(self, feature_extractor, scale): 222 | super().__init__(feature_extractor, scale) 223 | self.loss_fn = nn.L1Loss() 224 | 225 | def forward(self, generated, actual): 226 | generated, actual = self.downsize(generated, actual) 227 | generated_feat = self.feature_extractor(generated.detach()) 228 | actual_feat = self.feature_extractor(actual.detach()) 229 | 230 | loss = self.loss_fn(generated_feat, actual_feat) 231 | return loss 232 | 233 | 234 | class MultiLayerFeatureLoss(FeatureLoss): 235 | """ 236 | Computes the feature loss with the last n layers of a deep feature extractor. 237 | """ 238 | 239 | def __init__(self, feature_extractor, scale, loss_fn=nn.L1Loss(), num_layers=3): 240 | """ 241 | :param feature_extractor: an pretrained model, i.e. resnet18(), vgg19() 242 | :param loss_fn: an initialized loss function 243 | :param num_layers: number of layers from the end to keep. e.g. 3 will compute 244 | the loss using the last 3 layers of the feature extractor network 245 | """ 246 | # e.g. 
VGG 247 | super().__init__(feature_extractor, scale) 248 | 249 | features = list(feature_extractor.features) 250 | self.num_layers = num_layers 251 | self.loss_fn = loss_fn 252 | 253 | self.layer_weights = [i + 1 / num_layers for i in range(num_layers)] 254 | 255 | self.features = nn.ModuleList(features).eval() 256 | 257 | start = len(self.features) - num_layers 258 | end = len(self.features) 259 | self.layers_to_keep = {i for i in range(start, end)} 260 | 261 | def extract_intermediate_layers(self, x): 262 | """ 263 | Extracts features of intermediate layers using the feature extractor 264 | :param x: the input 265 | :return: 266 | """ 267 | results = [] 268 | for ii, model in enumerate(self.features): 269 | x = model(x) 270 | if ii in self.layers_to_keep: 271 | results.append(x) 272 | 273 | return results 274 | 275 | def forward(self, generated, actual): 276 | generated, actual = self.downsize(generated, actual) 277 | generated_feat_list = self.extract_intermediate_layers(generated) 278 | actual_feat_list = self.extract_intermediate_layers(actual) 279 | total_loss = 0 280 | 281 | for i, w in enumerate(self.layer_weights): 282 | total_loss += w * self.loss_fn(generated_feat_list[i], actual_feat_list[i]) 283 | 284 | return total_loss 285 | -------------------------------------------------------------------------------- /modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from typing import Tuple 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .adversarial import GANLoss 8 | from .perceptual import PerceptualLoss 9 | 10 | ############### 11 | # CHARBONNIER # 12 | ############### 13 | 14 | class L1_Charbonnier_loss(nn.Module): 15 | """L1 Charbonnierloss. 16 | Credit: https://github.com/twtygqyy/pytorch-SRDenseNet/blob/a3185aa9838d1746a6c133caa7b57aaad1e40fd0/srdensenet.py#L134 17 | """ 18 | 19 | def __init__(self): 20 | super(L1_Charbonnier_loss, self).__init__() 21 | self.eps = 1e-6 22 | 23 | def forward(self, X, Y): 24 | diff = torch.add(X, -Y) 25 | error = torch.sqrt(diff * diff + self.eps) 26 | loss = torch.sum(error) 27 | return loss 28 | 29 | 30 | def gaussian(window_size, sigma): 31 | x = torch.arange(window_size).float() - window_size // 2 32 | if window_size % 2 == 0: 33 | x = x + 0.5 34 | gauss = torch.exp((-x.pow(2.0) / float(2 * sigma ** 2))) 35 | return gauss / gauss.sum() 36 | 37 | 38 | def get_gaussian_kernel1d(kernel_size: int, 39 | sigma: float, 40 | force_even: bool = False) -> torch.Tensor: 41 | r"""Function that returns Gaussian filter coefficients. 42 | 43 | Args: 44 | kernel_size (int): filter size. It should be odd and positive. 45 | sigma (float): gaussian standard deviation. 46 | force_even (bool): overrides requirement for odd kernel size. 47 | 48 | Returns: 49 | Tensor: 1D tensor with gaussian filter coefficients. 50 | 51 | Shape: 52 | - Output: :math:`(\text{kernel_size})` 53 | 54 | Examples:: 55 | 56 | >>> kornia.image.get_gaussian_kernel(3, 2.5) 57 | tensor([0.3243, 0.3513, 0.3243]) 58 | 59 | >>> kornia.image.get_gaussian_kernel(5, 1.5) 60 | tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) 61 | """ 62 | if (not isinstance(kernel_size, int) or ( 63 | (kernel_size % 2 == 0) and not force_even) or ( 64 | kernel_size <= 0)): 65 | raise TypeError( 66 | "kernel_size must be an odd positive integer. 
" 67 | "Got {}".format(kernel_size) 68 | ) 69 | window_1d: torch.Tensor = gaussian(kernel_size, sigma) 70 | return window_1d 71 | 72 | 73 | def get_gaussian_kernel2d( 74 | kernel_size: Tuple[int, int], 75 | sigma: Tuple[float, float], 76 | force_even: bool = False) -> torch.Tensor: 77 | r"""Function that returns Gaussian filter matrix coefficients. 78 | 79 | Args: 80 | kernel_size (Tuple[int, int]): filter sizes in the x and y direction. 81 | Sizes should be odd and positive. 82 | sigma (Tuple[int, int]): gaussian standard deviation in the x and y 83 | direction. 84 | force_even (bool): overrides requirement for odd kernel size. 85 | 86 | Returns: 87 | Tensor: 2D tensor with gaussian filter matrix coefficients. 88 | 89 | Shape: 90 | - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` 91 | 92 | Examples:: 93 | 94 | >>> kornia.image.get_gaussian_kernel2d((3, 3), (1.5, 1.5)) 95 | tensor([[0.0947, 0.1183, 0.0947], 96 | [0.1183, 0.1478, 0.1183], 97 | [0.0947, 0.1183, 0.0947]]) 98 | 99 | >>> kornia.image.get_gaussian_kernel2d((3, 5), (1.5, 1.5)) 100 | tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], 101 | [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], 102 | [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) 103 | """ 104 | if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: 105 | raise TypeError( 106 | "kernel_size must be a tuple of length two. Got {}".format( 107 | kernel_size 108 | ) 109 | ) 110 | if not isinstance(sigma, tuple) or len(sigma) != 2: 111 | raise TypeError( 112 | "sigma must be a tuple of length two. Got {}".format(sigma) 113 | ) 114 | ksize_x, ksize_y = kernel_size 115 | sigma_x, sigma_y = sigma 116 | kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even) 117 | kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even) 118 | kernel_2d: torch.Tensor = torch.matmul( 119 | kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t() 120 | ) 121 | return kernel_2d 122 | 123 | 124 | ################################## 125 | # SSIM | CREDIT: TorchGeometry library 126 | ################################## 127 | 128 | class SSIM(nn.Module): 129 | r"""Creates a criterion that measures the Structural Similarity (SSIM) 130 | index between each element in the input `x` and target `y`. 131 | 132 | The index can be described as: 133 | 134 | .. math:: 135 | 136 | \text{SSIM}(x, y) = \frac{(2\mu_x\mu_y+c_1)(2\sigma_{xy}+c_2)} 137 | {(\mu_x^2+\mu_y^2+c_1)(\sigma_x^2+\sigma_y^2+c_2)} 138 | 139 | where: 140 | - :math:`c_1=(k_1 L)^2` and :math:`c_2=(k_2 L)^2` are two variables to 141 | stabilize the division with weak denominator. 142 | - :math:`L` is the dynamic range of the pixel-values (typically this is 143 | :math:`2^{\#\text{bits per pixel}}-1`). 144 | 145 | the loss, or the Structural dissimilarity (DSSIM) can be finally described 146 | as: 147 | 148 | .. math:: 149 | 150 | \text{loss}(x, y) = \frac{1 - \text{SSIM}(x, y)}{2} 151 | 152 | Arguments: 153 | window_size (int): the size of the kernel. 154 | max_val (float): the dynamic range of the images. Default: 1. 155 | reduction (str, optional): Specifies the reduction to apply to the 156 | output: 'none' | 'mean' | 'sum'. 'none': no reduction will be applied, 157 | 'mean': the sum of the output will be divided by the number of elements 158 | in the output, 'sum': the output will be summed. Default: 'none'. 159 | 160 | Returns: 161 | Tensor: the ssim index. 
162 | 163 | Shape: 164 | - Input: :math:`(B, C, H, W)` 165 | - Target :math:`(B, C, H, W)` 166 | - Output: scale, if reduction is 'none', then :math:`(B, C, H, W)` 167 | 168 | Examples:: 169 | 170 | >>> input1 = torch.rand(1, 4, 5, 5) 171 | >>> input2 = torch.rand(1, 4, 5, 5) 172 | >>> ssim = kornia.losses.SSIM(5, reduction='none') 173 | >>> loss = ssim(input1, input2) # 1x4x5x5 174 | """ 175 | 176 | def __init__( 177 | self, 178 | window_size: int, 179 | reduction: str = 'none', 180 | max_val: float = 1.0) -> None: 181 | super(SSIM, self).__init__() 182 | self.window_size: int = window_size 183 | self.max_val: float = max_val 184 | self.reduction: str = reduction 185 | 186 | self.window: torch.Tensor = get_gaussian_kernel2d( 187 | (window_size, window_size), (1.5, 1.5)) 188 | self.padding: int = self.compute_zero_padding(window_size) 189 | 190 | self.C1: float = (0.01 * self.max_val) ** 2 191 | self.C2: float = (0.03 * self.max_val) ** 2 192 | 193 | @staticmethod 194 | def compute_zero_padding(kernel_size: int) -> int: 195 | """Computes zero padding.""" 196 | return (kernel_size - 1) // 2 197 | 198 | def filter2D( 199 | self, 200 | input: torch.Tensor, 201 | kernel: torch.Tensor, 202 | channel: int) -> torch.Tensor: 203 | return F.conv2d(input, kernel, padding=self.padding, groups=channel) 204 | 205 | def forward( # type: ignore 206 | self, 207 | img1: torch.Tensor, 208 | img2: torch.Tensor) -> torch.Tensor: 209 | if not torch.is_tensor(img1): 210 | raise TypeError("Input img1 type is not a torch.Tensor. Got {}" 211 | .format(type(img1))) 212 | if not torch.is_tensor(img2): 213 | raise TypeError("Input img2 type is not a torch.Tensor. Got {}" 214 | .format(type(img2))) 215 | if not len(img1.shape) == 4: 216 | raise ValueError("Invalid img1 shape, we expect BxCxHxW. Got: {}" 217 | .format(img1.shape)) 218 | if not len(img2.shape) == 4: 219 | raise ValueError("Invalid img2 shape, we expect BxCxHxW. Got: {}" 220 | .format(img2.shape)) 221 | if not img1.shape == img2.shape: 222 | raise ValueError("img1 and img2 shapes must be the same. Got: {}" 223 | .format(img1.shape, img2.shape)) 224 | if not img1.device == img2.device: 225 | raise ValueError("img1 and img2 must be in the same device. Got: {}" 226 | .format(img1.device, img2.device)) 227 | if not img1.dtype == img2.dtype: 228 | raise ValueError("img1 and img2 must be in the same dtype. Got: {}" 229 | .format(img1.dtype, img2.dtype)) 230 | # prepare kernel 231 | b, c, h, w = img1.shape 232 | tmp_kernel: torch.Tensor = self.window.to(img1.device).to(img1.dtype) 233 | kernel: torch.Tensor = tmp_kernel.repeat(c, 1, 1, 1) 234 | 235 | # compute local mean per channel 236 | mu1: torch.Tensor = self.filter2D(img1, kernel, c) 237 | mu2: torch.Tensor = self.filter2D(img2, kernel, c) 238 | 239 | mu1_sq = mu1.pow(2) 240 | mu2_sq = mu2.pow(2) 241 | mu1_mu2 = mu1 * mu2 242 | 243 | # compute local sigma per channel 244 | sigma1_sq = self.filter2D(img1 * img1, kernel, c) - mu1_sq 245 | sigma2_sq = self.filter2D(img2 * img2, kernel, c) - mu2_sq 246 | sigma12 = self.filter2D(img1 * img2, kernel, c) - mu1_mu2 247 | 248 | ssim_map = ((2 * mu1_mu2 + self.C1) * (2 * sigma12 + self.C2)) / \ 249 | ((mu1_sq + mu2_sq + self.C1) * (sigma1_sq + sigma2_sq + self.C2)) 250 | 251 | loss = torch.clamp(torch.tensor(1.) - ssim_map, min=0, max=1) / 2. 
252 | 253 | if self.reduction == 'mean': 254 | loss = torch.mean(loss) 255 | elif self.reduction == 'sum': 256 | loss = torch.sum(loss) 257 | elif self.reduction == 'none': 258 | pass 259 | return loss 260 | 261 | 262 | # functional interface 263 | def ssim( 264 | img1: torch.Tensor, 265 | img2: torch.Tensor, 266 | window_size: int, 267 | reduction: str = 'none', 268 | max_val: float = 1.0) -> torch.Tensor: 269 | r"""Function that measures the Structural Similarity (SSIM) index between 270 | each element in the input `x` and target `y`. 271 | 272 | See :class:`~kornia.losses.SSIM` for details. 273 | """ 274 | return SSIM(window_size, reduction, max_val)(img1, img2) 275 | -------------------------------------------------------------------------------- /modules/losses/adversarial.py: -------------------------------------------------------------------------------- 1 | """ 2 | Custom loss modules 3 | """ 4 | 5 | import torch 6 | from torch import nn 7 | 8 | 9 | class GANLoss(nn.Module): 10 | """Define different GAN objectives. 11 | 12 | The GANLoss class abstracts away the need to create the target label tensor 13 | that has the same size as the input. 14 | """ 15 | 16 | default_real = 1.0 17 | default_fake = 0 18 | default_smooth_real = (0.7, 1.1) 19 | default_smooth_fake = (0.0, 0.3) 20 | 21 | def __init__( 22 | self, 23 | gan_mode, 24 | smooth_labels=True, 25 | target_real_label=None, 26 | target_fake_label=None, 27 | ): 28 | """ Initialize the GANLoss class. 29 | 30 | Parameters: 31 | gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. 32 | target_real_label (bool) - - label for a real image 33 | target_fake_label (bool) - - label of a fake image 34 | 35 | Note: Do not use sigmoid as the last layer of Discriminator. 36 | LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. 37 | """ 38 | super().__init__() 39 | if target_real_label is None: 40 | target_real_label = ( 41 | self.default_smooth_real if smooth_labels else self.default_real 42 | ) 43 | if target_fake_label is None: 44 | target_fake_label = ( 45 | self.default_smooth_fake if smooth_labels else self.default_fake 46 | ) 47 | 48 | self.register_buffer("real_label", torch.tensor(target_real_label)) 49 | self.register_buffer("fake_label", torch.tensor(target_fake_label)) 50 | self.gan_mode = gan_mode 51 | if gan_mode == "lsgan": 52 | self.loss = nn.MSELoss() 53 | # according to DRAGAN GitHub, dragan also uses BCE loss: https://github.com/kodalinaveen3/DRAGAN/blob/master/DRAGAN.ipynb 54 | elif gan_mode in ["vanilla", "dragan", "dragan-gp", "dragan-lp"]: 55 | self.loss = nn.BCEWithLogitsLoss() 56 | elif "wgan" in gan_mode: 57 | self.loss = None 58 | else: 59 | raise NotImplementedError("gan mode %s not implemented" % gan_mode) 60 | 61 | @staticmethod 62 | def rand_between(low, high, normal=False): 63 | """ 64 | Args: 65 | low: a torch.Tensor 66 | high: a torch.Tensor 67 | normal: whether to use normal distribution. if not, will use uniform 68 | 69 | Returns: random tensor between low and high 70 | """ 71 | rand_func = torch.randn if normal else torch.rand 72 | return rand_func(1) * (high - low) + low 73 | 74 | def get_target_tensor(self, prediction, target_is_real): 75 | """Create label tensors with the same size as the input. 
76 | 
77 |         Parameters:
78 |             prediction (tensor) - - typically the prediction from a discriminator
79 |             target_is_real (bool) - - if the ground truth label is for real images or fake images
80 | 
81 |         Returns:
82 |             A label tensor filled with ground truth label, and with the size of the input
83 |         """
84 | 
85 |         if target_is_real:
86 |             # smooth labels
87 |             if len(self.real_label) == 2:
88 |                 low, high = self.real_label
89 |                 target_tensor = GANLoss.rand_between(low, high).to(
90 |                     self.real_label.device
91 |                 )
92 |             else:
93 |                 target_tensor = self.real_label
94 |         else:
95 |             # smooth labels
96 |             if len(self.fake_label) == 2:
97 |                 low, high = self.fake_label
98 |                 target_tensor = GANLoss.rand_between(low, high).to(
99 |                     self.fake_label.device
100 |                 )
101 |             else:
102 |                 target_tensor = self.fake_label
103 |         return target_tensor.expand_as(prediction)
104 | 
105 |     def __call__(self, prediction, target_is_real):
106 |         """Calculate loss given Discriminator's output and ground truth labels.
107 | 
108 |         Parameters:
109 |             prediction (tensor) - - typically the prediction output from a discriminator
110 |             target_is_real (bool) - - if the ground truth label is for real images or fake images
111 | 
112 |         Returns:
113 |             the calculated loss.
114 |         """
115 |         if self.gan_mode in ["lsgan", "vanilla", "dragan", "dragan-gp", "dragan-lp"]:
116 |             target_tensor = self.get_target_tensor(prediction, target_is_real)
117 |             loss = self.loss(prediction, target_tensor)
118 |         elif "wgan" in self.gan_mode:
119 |             if target_is_real:
120 |                 loss = -prediction.mean()
121 |             else:
122 |                 loss = prediction.mean()
123 |         else:
124 |             raise ValueError(f"{self.gan_mode} not recognized")
125 |         return loss
126 | 
127 | 
128 | def gradient_penalty(f, real, fake, mode, p_norm=2):
129 |     """
130 |     From https://github.com/LynnHo/DCGAN-LSGAN-WGAN-GP-DRAGAN-Pytorch/blob/master/torchprob/gan/loss.py
131 |     Args:
132 |         f: a discriminator
133 |         real: target
134 |         fake: generated
135 |         mode: the gan mode string, e.g. "wgan-gp", "wgan-lp", or "dragan"
136 |         p_norm: the norm degree used to measure the gradient
137 | 
138 |     Returns:
139 |         the gradient penalty term as a scalar tensor
140 |     """
141 | 
142 |     def _gradient_penalty(f, real, fake=None, penalty_type="gp", p_norm=2):
143 |         def _interpolate(a, b=None):
144 |             if b is None:  # interpolation in DRAGAN
145 |                 beta = torch.rand_like(a)
146 |                 b = a + 0.5 * a.std() * beta
147 |             shape = [a.size(0)] + [1] * (a.dim() - 1)
148 |             alpha = torch.rand(shape, device=a.device)
149 |             inter = a + alpha * (b - a)
150 |             return inter
151 | 
152 |         x = _interpolate(real, fake).detach()
153 |         x.requires_grad = True
154 |         pred = f(x)
155 |         grad = torch.autograd.grad(
156 |             pred, x, grad_outputs=torch.ones_like(pred), create_graph=True
157 |         )[0]
158 |         norm = grad.view(grad.size(0), -1).norm(p=p_norm, dim=1)
159 | 
160 |         if penalty_type == "gp":
161 |             gp = ((norm - 1) ** 2).mean()
162 |         elif penalty_type == "lp":
163 |             gp = (torch.max(torch.zeros_like(norm), norm - 1) ** 2).mean()
164 | 
165 |         return gp
166 | 
167 |     if not mode or mode == "vanilla":
168 |         gp = torch.tensor(0, dtype=real.dtype, device=real.device)
169 |     elif mode in ["dragan", "dragan-gp", "dragan-lp"]:
170 |         penalty_type = "gp" if mode == "dragan" else mode[-2:]
171 |         gp = _gradient_penalty(f, real, penalty_type=penalty_type, p_norm=p_norm)
172 |     elif mode in ["wgan-gp", "wgan-lp"]:
173 |         gp = _gradient_penalty(f, real, fake, penalty_type=mode[-2:], p_norm=p_norm)
174 |     else:
175 |         raise ValueError("Don't know how to handle gan mode", mode)
176 | 
177 |     # TODO: implement mescheder's simplified gradient penalties
178 | 
179 |     return gp
180 | 
181 | 
182 | 
--------------------------------------------------------------------------------
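# ---------------------------------------------------------------------------
# Illustrative sketch (hypothetical, not part of the repository): one way the
# GANLoss and gradient_penalty defined above might be combined in a single
# discriminator update. `netD`, `real_images`, `fake_images`, and the penalty
# weight of 10 are assumptions chosen for the example, not values fixed here.
from modules.losses.adversarial import GANLoss, gradient_penalty


def discriminator_step_sketch(netD, real_images, fake_images, mode="dragan-gp"):
    criterion = GANLoss(mode, smooth_labels=True)
    pred_real = netD(real_images)
    pred_fake = netD(fake_images.detach())
    loss_d = criterion(pred_real, True) + criterion(pred_fake, False)
    # a weight of 10 on the penalty is a common WGAN-GP/DRAGAN choice (assumption)
    gp = gradient_penalty(netD, real_images, fake_images.detach(), mode=mode)
    return loss_d + 10 * gp
# ---------------------------------------------------------------------------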
/modules/losses/perceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import vgg16 4 | 5 | 6 | def gram_matrix(tensor): 7 | b, c, h, w = tensor.size() 8 | tensor = tensor.view(b * c, h * w) 9 | gram = torch.mm(tensor, tensor.t()) 10 | return gram 11 | 12 | 13 | class PerceptualLoss(nn.Module): 14 | def __init__(self, normalize=True, use_style=False): 15 | """ 16 | 17 | Args: 18 | normalize: 19 | use_style: whether to calculate style loss using gram matrix 20 | """ 21 | super(PerceptualLoss, self).__init__() 22 | 23 | self.normalize = normalize 24 | self.use_style = use_style 25 | 26 | vgg = vgg16(pretrained=True).features 27 | 28 | slices_idx = [ 29 | [0, 4], # until 5th layer 30 | [4, 9], # until 10th layer 31 | [9, 16], # until 17th layer 32 | [16, 23], # until 23rd layer 33 | [23, 30], # until 31st layer 34 | ] 35 | 36 | self.net = torch.nn.Sequential() 37 | 38 | for i, idx in enumerate(slices_idx): 39 | seq = torch.nn.Sequential() 40 | for j in range(idx[0], idx[1]): 41 | seq.add_module(str(j), vgg[j]) 42 | self.net.add_module(str(i), seq) 43 | 44 | for p in self.parameters(): 45 | p.requires_grad = False 46 | 47 | self.mse = nn.MSELoss() 48 | 49 | def forward(self, output, target): 50 | 51 | output_f = self.get_features(output) 52 | with torch.no_grad(): 53 | target_f = self.get_features(target) 54 | 55 | content_losses = [] 56 | style_losses = [] 57 | 58 | for o, t in zip(output_f, target_f): 59 | content_losses.append(self.mse(o, t)) 60 | if self.use_style: 61 | gram_output = gram_matrix(output) 62 | gram_target = gram_matrix(target) 63 | style_losses.append(self.mse(gram_output, gram_target)) 64 | content_loss = sum(content_losses) 65 | style_loss = sum(style_losses) 66 | return content_loss, style_loss 67 | 68 | def get_features(self, x): 69 | """Assumes x in [0, 1]: transform to [-1, 1].""" 70 | x = 2.0 * x - 1.0 71 | feats = [] 72 | for i, s in enumerate(self.net): 73 | x = s(x) 74 | if self.normalize: # unit L2 norm over features, this implies the loss is a cosine loss in feature space 75 | f = x / (torch.sqrt(torch.pow(x, 2).sum(1, keepdim=True)) + 1e-8) 76 | else: 77 | f = x 78 | feats.append(f) 79 | return feats 80 | -------------------------------------------------------------------------------- /modules/pix2pix_modules.py: -------------------------------------------------------------------------------- 1 | import functools 2 | 3 | import torch 4 | from torch import nn 5 | 6 | from modules import get_norm_layer, init_weights 7 | from modules.discriminators import NLayerDiscriminator, PixelDiscriminator 8 | 9 | 10 | def define_G( 11 | input_nc, 12 | output_nc, 13 | ngf, 14 | netG, 15 | norm="batch", 16 | use_dropout=False, 17 | init_type="normal", 18 | init_gain=0.02, 19 | ): 20 | """Create a generator 21 | 22 | Parameters: 23 | input_nc (int) -- the number of channels in input images 24 | output_nc (int) -- the number of channels in output images 25 | ngf (int) -- the number of filters in the last conv layer 26 | netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 27 | norm (str) -- the name of normalization layers used in the network: batch | instance | none 28 | use_dropout (bool) -- if use dropout layers. 29 | init_type (str) -- the name of our initialization method. 30 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 
31 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 32 | 33 | Returns a generator 34 | 35 | Our current implementation provides two types of generators: 36 | U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) 37 | The original U-Net paper: https://arxiv.org/abs/1505.04597 38 | 39 | Resnet-based generator: [resnet_6blocks] (with 6 Resnet blocks) and [resnet_9blocks] (with 9 Resnet blocks) 40 | Resnet-based generator consists of several Resnet blocks between a few downsampling/upsampling operations. 41 | We adapt Torch code from Justin Johnson's neural style transfer project (https://github.com/jcjohnson/fast-neural-style). 42 | 43 | 44 | The generator has been initialized by . It uses RELU for non-linearity. 45 | """ 46 | net = None 47 | norm_layer = get_norm_layer(norm_type=norm) 48 | 49 | if netG == "resnet_9blocks": 50 | raise NotImplementedError 51 | elif netG == "resnet_6blocks": 52 | raise NotImplementedError 53 | elif netG == "unet_128": 54 | net = UnetGenerator( 55 | input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout 56 | ) 57 | elif netG == "unet_256": 58 | net = UnetGenerator( 59 | input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout 60 | ) 61 | else: 62 | raise NotImplementedError("Generator model name [%s] is not recognized" % netG) 63 | init_weights(net, init_type, init_gain) 64 | return net 65 | 66 | def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal', init_gain=0.02): 67 | """Create a discriminator 68 | 69 | Parameters: 70 | input_nc (int) -- the number of channels in input images 71 | ndf (int) -- the number of filters in the first conv layer 72 | netD (str) -- the architecture's name: basic | n_layers | pixel 73 | n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' 74 | norm (str) -- the type of normalization layers used in the network. 75 | init_type (str) -- the name of the initialization method. 76 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 77 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 78 | 79 | Returns a discriminator 80 | 81 | Our current implementation provides three types of discriminators: 82 | [basic]: 'PatchGAN' classifier described in the original pix2pix paper. 83 | It can classify whether 70×70 overlapping patches are real or fake. 84 | Such a patch-level discriminator architecture has fewer parameters 85 | than a full-image discriminator and can work on arbitrarily-sized images 86 | in a fully convolutional fashion. 87 | 88 | [n_layers]: With this mode, you cna specify the number of conv layers in the discriminator 89 | with the parameter (default=3 as used in [basic] (PatchGAN).) 90 | 91 | [pixel]: 1x1 PixelGAN discriminator can classify whether a pixel is real or not. 92 | It encourages greater color diversity but has no effect on spatial statistics. 93 | 94 | The discriminator has been initialized by . It uses Leakly RELU for non-linearity. 
95 | """ 96 | net = None 97 | norm_layer = get_norm_layer(norm_type=norm) 98 | 99 | if netD == 'basic': # default PatchGAN classifier 100 | net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer) 101 | elif netD == 'n_layers': # more options 102 | net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer) 103 | elif netD == 'pixel': # classify if each pixel is real or fake 104 | net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer) 105 | else: 106 | raise NotImplementedError('Discriminator model name [%s] is not recognized' % netD) 107 | 108 | init_weights(net, init_type, init_gain) 109 | return net 110 | 111 | 112 | 113 | class UnetGenerator(nn.Module): 114 | """Create a Unet-based generator""" 115 | 116 | def __init__( 117 | self, 118 | input_nc, 119 | output_nc, 120 | num_downs, 121 | ngf=64, 122 | norm_layer=nn.BatchNorm2d, 123 | use_dropout=False, 124 | ): 125 | """Construct a Unet generator 126 | Parameters: 127 | input_nc (int) -- the number of channels in input images 128 | output_nc (int) -- the number of channels in output images 129 | num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, 130 | image of size 128x128 will become of size 1x1 # at the bottleneck 131 | ngf (int) -- the number of filters in the last conv layer 132 | norm_layer -- normalization layer 133 | 134 | We construct the U-Net from the innermost layer to the outermost layer. 135 | It is a recursive process. 136 | """ 137 | super(UnetGenerator, self).__init__() 138 | # construct unet structure 139 | unet_block = UnetSkipConnectionBlock( 140 | ngf * 8, 141 | ngf * 8, 142 | input_nc=None, 143 | submodule=None, 144 | norm_layer=norm_layer, 145 | innermost=True, 146 | ) # add the innermost layer 147 | for i in range(num_downs - 5): # add intermediate layers with ngf * 8 filters 148 | unet_block = UnetSkipConnectionBlock( 149 | ngf * 8, 150 | ngf * 8, 151 | input_nc=None, 152 | submodule=unet_block, 153 | norm_layer=norm_layer, 154 | use_dropout=use_dropout, 155 | ) 156 | # gradually reduce the number of filters from ngf * 8 to ngf 157 | unet_block = UnetSkipConnectionBlock( 158 | ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer 159 | ) 160 | unet_block = UnetSkipConnectionBlock( 161 | ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer 162 | ) 163 | unet_block = UnetSkipConnectionBlock( 164 | ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer 165 | ) 166 | self.model = UnetSkipConnectionBlock( 167 | output_nc, 168 | ngf, 169 | input_nc=input_nc, 170 | submodule=unet_block, 171 | outermost=True, 172 | norm_layer=norm_layer, 173 | ) # add the outermost layer 174 | 175 | def forward(self, input): 176 | """Standard forward""" 177 | return self.model(input) 178 | 179 | 180 | class UnetSkipConnectionBlock(nn.Module): 181 | """Defines the Unet submodule with skip connection. 182 | X -------------------identity---------------------- 183 | |-- downsampling -- |submodule| -- upsampling --| 184 | """ 185 | 186 | def __init__( 187 | self, 188 | outer_nc, 189 | inner_nc, 190 | input_nc=None, 191 | submodule=None, 192 | outermost=False, 193 | innermost=False, 194 | norm_layer=nn.BatchNorm2d, 195 | use_dropout=False, 196 | ): 197 | """Construct a Unet submodule with skip connections. 
198 | 199 | Parameters: 200 | outer_nc (int) -- the number of filters in the outer conv layer 201 | inner_nc (int) -- the number of filters in the inner conv layer 202 | input_nc (int) -- the number of channels in input images/features 203 | submodule (UnetSkipConnectionBlock) -- previously defined submodules 204 | outermost (bool) -- if this module is the outermost module 205 | innermost (bool) -- if this module is the innermost module 206 | norm_layer -- normalization layer 207 | user_dropout (bool) -- if use dropout layers. 208 | """ 209 | super(UnetSkipConnectionBlock, self).__init__() 210 | self.outermost = outermost 211 | if type(norm_layer) == functools.partial: 212 | use_bias = norm_layer.func == nn.InstanceNorm2d 213 | else: 214 | use_bias = norm_layer == nn.InstanceNorm2d 215 | if input_nc is None: 216 | input_nc = outer_nc 217 | downconv = nn.Conv2d( 218 | input_nc, inner_nc, kernel_size=4, stride=2, padding=1, bias=use_bias 219 | ) 220 | downrelu = nn.LeakyReLU(0.2, True) 221 | downnorm = norm_layer(inner_nc) 222 | uprelu = nn.ReLU(True) 223 | upnorm = norm_layer(outer_nc) 224 | 225 | if outermost: 226 | upconv = nn.ConvTranspose2d( 227 | inner_nc * 2, outer_nc, kernel_size=4, stride=2, padding=1 228 | ) 229 | down = [downconv] 230 | up = [uprelu, upconv, nn.Tanh()] 231 | model = down + [submodule] + up 232 | elif innermost: 233 | upconv = nn.ConvTranspose2d( 234 | inner_nc, outer_nc, kernel_size=4, stride=2, padding=1, bias=use_bias 235 | ) 236 | down = [downrelu, downconv] 237 | up = [uprelu, upconv, upnorm] 238 | model = down + up 239 | else: 240 | upconv = nn.ConvTranspose2d( 241 | inner_nc * 2, 242 | outer_nc, 243 | kernel_size=4, 244 | stride=2, 245 | padding=1, 246 | bias=use_bias, 247 | ) 248 | down = [downrelu, downconv, downnorm] 249 | up = [uprelu, upconv, upnorm] 250 | 251 | if use_dropout: 252 | model = down + [submodule] + up + [nn.Dropout(0.5)] 253 | else: 254 | model = down + [submodule] + up 255 | 256 | self.model = nn.Sequential(*model) 257 | 258 | def forward(self, x): 259 | if self.outermost: 260 | return self.model(x) 261 | else: # add skip connections 262 | return torch.cat([x, self.model(x)], 1) 263 | -------------------------------------------------------------------------------- /modules/swapnet_modules.py: -------------------------------------------------------------------------------- 1 | """ 2 | Networks based on SwapNet (Raj et al. 2018). 3 | """ 4 | import code 5 | import logging 6 | import sys 7 | import math 8 | 9 | import torch 10 | 11 | from modules import pix2pix_modules, get_norm_layer 12 | 13 | sys.path.append("../lib") 14 | from torchvision.ops import RoIAlign as ROIAlign 15 | from torch import nn 16 | 17 | from modules.layers import UNetDown, UNetUp, DualUNetUp, ResidualBlock 18 | 19 | wm_log = logging.getLogger("warp_module_shape") 20 | 21 | 22 | class WarpModule(nn.Module): 23 | """ 24 | The warping module takes a body segmentation to represent the "pose", 25 | and an input clothing segmentation to transform to match the pose. 
26 | """ 27 | 28 | def __init__(self, body_channels=3, cloth_channels=19, dropout=0.5): 29 | super(WarpModule, self).__init__() 30 | 31 | ###################### 32 | # Body pre-encoding # (top left of SwapNet diagram) 33 | ###################### 34 | self.body_down1 = UNetDown(body_channels, 64, normalize=False) 35 | self.body_down2 = UNetDown(64, 128) 36 | self.body_down3 = UNetDown(128, 256) 37 | self.body_down4 = UNetDown(256, 512, dropout=dropout) 38 | 39 | ###################### 40 | # Cloth pre-encoding # (bottom left of SwapNet diagram) 41 | ###################### 42 | self.cloth_down1 = UNetDown(cloth_channels, 64, normalize=False) 43 | self.cloth_down2 = UNetDown(64, 128) 44 | self.cloth_down3 = UNetDown(128, 256) 45 | self.cloth_down4 = UNetDown(256, 512) 46 | self.cloth_down5 = UNetDown(512, 1024, dropout=dropout) 47 | self.cloth_down6 = UNetDown(1024, 1024, normalize=False, dropout=dropout) 48 | # the two UNetUp's below will be used WITHOUT concatenation. 49 | # hence the input size will not double 50 | self.cloth_up1 = UNetUp(1024, 1024) 51 | self.cloth_up2 = UNetUp(1024, 512) 52 | 53 | ###################### 54 | # Resblocks # (middle of SwapNet diagram) 55 | ###################### 56 | self.resblocks = nn.Sequential( 57 | # I don't really know if dropout should go here. I'm just guessing 58 | ResidualBlock(1024, dropout=dropout), 59 | ResidualBlock(1024, dropout=dropout), 60 | ResidualBlock(1024, dropout=dropout), 61 | ResidualBlock(1024, dropout=dropout), 62 | ) 63 | 64 | ###################### 65 | # Dual Decoding # (right of SwapNet diagram, maybe) 66 | ###################### 67 | # The SwapNet diagram just says "cloth" decoder, so I don't know if they're 68 | # actually doing dual decoding like I've done here. 69 | # Still, I think it's cool and it makes more sense to me. 70 | # Found from "Multi-view Image Generation from a Single-View". 71 | # --------------------- 72 | # input encoded (512) & cat body_d4 (512) cloth_d4 (512) 73 | self.dual_up1 = DualUNetUp(1024, 256) 74 | # input dual_up1 (256) & cat body_d3 (256) cloth_d3 (256) 75 | self.dual_up2 = DualUNetUp(3 * 256, 128) 76 | # input dual_up2 (128) & cat body_d2 (128) cloth_d2 (128) 77 | self.dual_up3 = DualUNetUp(3 * 128, 64) 78 | 79 | # TBH I don't really know what the below code does. 80 | # like why don't we dualnetup with down1? 81 | # maybe specific to pix2pix? hm, if so maybe we should replicate. 82 | # ------ 83 | # update: OHHH I get it now. it's because U-Net only outputs half the size as 84 | # the original image, hence we need to upsample. 
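        # Shape bookkeeping for the block below: dual_up3 outputs 64 channels and,
        # in forward(), is concatenated with body_d1 and cloth_d1 (64 channels each),
        # giving 3 * 64 channels at half the input resolution. The upsample restores
        # the original size and the final 4x4 conv maps down to cloth_channels.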
85 | self.upsample_and_pad = nn.Sequential( 86 | nn.Upsample(scale_factor=2), 87 | nn.ZeroPad2d((1, 0, 1, 0)), 88 | nn.Conv2d(3 * 64, cloth_channels, 4, padding=1), 89 | nn.Tanh(), 90 | ) 91 | 92 | def forward(self, body, cloth): 93 | wm_log.debug("body shape:", body.shape) 94 | wm_log.debug("cloth shape:", cloth.shape) 95 | wm_log.debug("shapes should match except in the channel dim") 96 | ###################### 97 | # Body pre-encoding # 98 | ###################### 99 | body_d1 = self.body_down1(body) 100 | body_d2 = self.body_down2(body_d1) 101 | wm_log.debug("body_d2 shape, should be 128 channel:", body_d2.shape) 102 | body_d3 = self.body_down3(body_d2) 103 | wm_log.debug("body_d3 shape, should be 256 channel:", body_d3.shape) 104 | body_d4 = self.body_down4(body_d3) 105 | wm_log.debug("body_d4 shape, should be 512 channel:", body_d4.shape) 106 | 107 | wm_log.debug("==============") 108 | ###################### 109 | # Cloth pre-encoding # 110 | ###################### 111 | cloth_d1 = self.cloth_down1(cloth) 112 | cloth_d2 = self.cloth_down2(cloth_d1) 113 | wm_log.debug("cloth_d2 shape, should be 128 channel:", cloth_d2.shape) 114 | cloth_d3 = self.cloth_down3(cloth_d2) 115 | wm_log.debug("cloth_d3 shape, should be 256 channel:", cloth_d3.shape) 116 | cloth_d4 = self.cloth_down4(cloth_d3) 117 | wm_log.debug("cloth_d4 shape, should be 512 channel:", cloth_d4.shape) 118 | cloth_d5 = self.cloth_down5(cloth_d4) 119 | wm_log.debug("cloth_d5 shape, should be 1024 channel:", cloth_d5.shape) 120 | cloth_d6 = self.cloth_down6(cloth_d5) 121 | wm_log.debug("cloth_d6 shape, should be 1024 channel:", cloth_d6.shape) 122 | cloth_u1 = self.cloth_up1(cloth_d6, None) 123 | wm_log.debug("cloth_u1 shape, should be 1024 channel:", cloth_u1.shape) 124 | cloth_u2 = self.cloth_up2(cloth_u1, None) 125 | wm_log.debug("cloth_u2 shape, should be 512 channel:", cloth_u2.shape) 126 | 127 | ####################### 128 | # Combine & Resblocks # 129 | ####################### 130 | # cat on the channel dimension? 
should be same HxW 131 | body_and_cloth = torch.cat((body_d4, cloth_u2), dim=1) 132 | wm_log.debug( 133 | "body_and_cloth shape, should be 1024 channel:", body_and_cloth.shape 134 | ) 135 | encoded = self.resblocks(body_and_cloth) 136 | wm_log.debug("encoded shape, should be 1024 channel:", encoded.shape) 137 | 138 | # ###################### 139 | # # Dual Decoding # 140 | # ###################### 141 | dual_u1 = self.dual_up1(encoded, body_d3, cloth_d3) 142 | wm_log.debug("dual_u1 shape, should be 3*256 channel:", dual_u1.shape) 143 | dual_u2 = self.dual_up2(dual_u1, body_d2, cloth_d2) 144 | wm_log.debug("dual_u2 shape, should be 3*128 channel:", dual_u2.shape) 145 | dual_u3 = self.dual_up3(dual_u2, body_d1, cloth_d1) 146 | wm_log.debug("dual_u3 shape, should be 3*64 channel:", dual_u3.shape) 147 | 148 | # this is from that commented out code in the __init__() 149 | upsampled = self.upsample_and_pad(dual_u3) 150 | wm_log.debug("upsampled shape, should be original channel:", upsampled.shape) 151 | return upsampled 152 | 153 | 154 | class TextureModule(nn.Module): 155 | def __init__( 156 | self, 157 | texture_channels=3, 158 | cloth_channels=19, 159 | num_roi=12, 160 | norm_type="batch", 161 | dropout=0.5, 162 | unet_type="pix2pix", 163 | img_size=128, 164 | ): 165 | super(TextureModule, self).__init__() 166 | self.roi_align = ROIAlign( 167 | output_size=(128, 128), spatial_scale=1, sampling_ratio=1 168 | ) 169 | 170 | self.num_roi = num_roi 171 | channels = texture_channels * num_roi 172 | self.encode = UNetDown(channels, channels) 173 | 174 | # UNET 175 | 176 | if unet_type == "pix2pix": 177 | # fast log2 of img_size, int only. E.g. if size=128 => num_downs=7 178 | num_downs = math.frexp(img_size)[1] - 1 179 | use_dropout = True if dropout is not None else False 180 | norm_layer = get_norm_layer(norm_type=norm_type) 181 | self.unet = pix2pix_modules.UnetGenerator( 182 | channels + cloth_channels, 183 | texture_channels, 184 | num_downs, 185 | norm_layer=norm_layer, 186 | use_dropout=use_dropout, 187 | ) 188 | else: 189 | self.unet = nn.Sequential( 190 | UNetDown(channels + cloth_channels, 64, normalize=False), 191 | UNetDown(64, 128), 192 | UNetDown(128, 256), 193 | UNetDown(256, 512, dropout=dropout), 194 | UNetDown(512, 1024, dropout=dropout), 195 | UNetDown(1024, 1024, normalize=False, dropout=dropout), 196 | UNetUp(1024, 1024, dropout=dropout), 197 | UNetUp(2 * 1024, 512, dropout=dropout), 198 | UNetUp(2 * 512, 256), 199 | UNetUp(2 * 256, 128), 200 | UNetUp(2 * 128, 64), 201 | # upsample and pad 202 | nn.Upsample(scale_factor=2), 203 | nn.ZeroPad2d((1, 0, 1, 0)), 204 | nn.Conv2d(128, texture_channels, 4, padding=1), 205 | nn.Tanh(), 206 | ) 207 | # print("u-net:", self.unet) 208 | 209 | @staticmethod 210 | def reshape_rois(rois): 211 | """ 212 | Takes a (batch x num_rois x num_coordinates) and reshapes it into a 2D tensor. 213 | The 2D tensor has the first column as the batch index and the remaining columns 214 | as the coordinates. 215 | 216 | num_coordinates should be 4. 217 | 218 | :param rois: a (batch x num_rois x num_coordinates) tensor. 
coordinates is 4 219 | :return: a 2D tensor formatted for roi layers 220 | """ 221 | # Append batch to rois 222 | # get the batch indices 223 | b_idx = torch.arange(rois.shape[0]).unsqueeze_(-1) 224 | # expand out and reshape to to batchx1 dimension 225 | b_idx = b_idx.expand(rois.shape[0], rois.shape[1]).reshape(-1).unsqueeze_(-1) 226 | b_idx = b_idx.to(rois.device).type(rois.dtype) 227 | reshaped = rois.view(-1, rois.shape[-1]) 228 | reshaped = torch.cat((b_idx, reshaped), dim=1) 229 | return reshaped 230 | 231 | def forward(self, input_tex, rois, cloth): 232 | rois = TextureModule.reshape_rois(rois) 233 | # do roi alignment 234 | pooled_rois = self.roi_align(input_tex, rois) 235 | # reshape the pooled rois such that pool output goes in the channels instead of 236 | # batch size 237 | batch_size = int(pooled_rois.shape[0] / self.num_roi) 238 | pooled_rois = pooled_rois.view( 239 | batch_size, -1, pooled_rois.shape[2], pooled_rois.shape[3] 240 | ) 241 | 242 | encoded_tex = self.encode(pooled_rois) 243 | 244 | scale_factor = input_tex.shape[2] / encoded_tex.shape[2] 245 | upsampled_tex = nn.functional.interpolate( 246 | encoded_tex, scale_factor=scale_factor 247 | ) 248 | 249 | # # DEBUG DEBUG: see if GAN works without textured input 250 | # s = torch.sum(upsampled_tex).item() 251 | # if s != 0: 252 | # print( 253 | # f"Warning!!! Input of 0 did not result in upsampled_tex of 0. Sum is {s}. Setting to 0s anyway" 254 | # ) 255 | # upsampled_tex = torch.zeros_like(upsampled_tex) 256 | 257 | # concat on the channel dimension 258 | tex_with_cloth = torch.cat((upsampled_tex, cloth), 1) 259 | 260 | return self.unet(tex_with_cloth) 261 | -------------------------------------------------------------------------------- /optimizers/__init__.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from adabound import AdaBound 5 | from torch.optim import AdamW 6 | 7 | AdamW, AdaBound 8 | 9 | 10 | def get_options_modifier(optimizer_name): 11 | optimizer_name = optimizer_name.lower() 12 | 13 | # Adam or AdamW 14 | if "adam" in optimizer_name.lower(): 15 | return adam_modifier 16 | # AdaBound or AdaBoundW 17 | elif "adabound" in optimizer_name.lower(): 18 | return adabound_modifier 19 | elif "sgd" in optimizer_name.lower(): 20 | raise NotImplementedError 21 | else: 22 | raise NotImplementedError 23 | 24 | 25 | def adam_modifier(parser: ArgumentParser, *_): 26 | parser.add_argument("--b1", type=float, default=0.9, help="Adam b1") 27 | parser.add_argument("--b2", type=float, default=0.999, help="Adam b2") 28 | return parser 29 | 30 | 31 | def adabound_modifier(parser: ArgumentParser, *_): 32 | parser = adam_modifier(parser) 33 | parser.add_argument("--final_lr", type=float, default=0.1, help="AdaBound final_lr") 34 | return parser 35 | 36 | 37 | def define_optimizer(parameters, opt, net: str) -> torch.optim.Optimizer: 38 | """ 39 | Return an initialized Optimizer class 40 | :param opt: 41 | :param net: 42 | :param parameters: 43 | :return: 44 | """ 45 | # check whether optimizer_G or optimizer_D 46 | if net != "D" and net != "G": 47 | raise ValueError(f"net arg must be 'D' or 'G', received {net}") 48 | arg = "optimizer_" + net 49 | choice = getattr(opt, arg) 50 | 51 | # add optimizer kwargs 52 | lr = opt.d_lr if net == "D" else opt.lr 53 | wd = opt.d_weight_decay if net == "D" else opt.weight_decay 54 | kwargs = {"lr": lr, "weight_decay": wd, "betas": (opt.b1, opt.b2)} 55 | if choice == "AdaBound": 56 | 
kwargs["final_lr"] = opt.final_lr 57 | 58 | optim_class = eval(choice) 59 | optimizer = optim_class(parameters, **kwargs) 60 | return optimizer 61 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | """ 2 | A list of base options common to all stages 3 | """ 4 | import copy 5 | import sys 6 | import argparse 7 | import json 8 | import os 9 | 10 | import torch 11 | 12 | import datasets 13 | import models 14 | import optimizers 15 | from util.util import PromptOnce 16 | 17 | datasets, models, optimizers # so auto import doesn't remove above 18 | 19 | 20 | class BaseOptions: 21 | def __init__(self): 22 | parser = argparse.ArgumentParser( 23 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 24 | conflict_handler="resolve", 25 | ) 26 | # == EXPERIMENT SETUP == 27 | parser.add_argument( 28 | "--config_file", 29 | help="load arguments from a json file instead of command line", 30 | ) 31 | parser.add_argument( 32 | "--name", 33 | default="my_experiment", 34 | help="name of the experiment, determines where things are saved", 35 | ) 36 | parser.add_argument( 37 | "--comments", 38 | default="", 39 | help="additional comments to add to this experiment, saved in args.json", 40 | ) 41 | parser.add_argument("--verbose", action="store_true") 42 | parser.add_argument( 43 | "--display_winsize", 44 | type=int, 45 | default=256, 46 | help="display window size for both visdom and HTML", 47 | ) 48 | # == MODEL INIT / LOADING / SAVING == 49 | parser.add_argument( 50 | "--model", help="which model to run", choices=("warp", "texture", "pix2pix") 51 | ) 52 | parser.add_argument( 53 | "--checkpoints_dir", default="./checkpoints", help="Where to save models" 54 | ) 55 | parser.add_argument( 56 | "--load_epoch", 57 | default="latest", 58 | help="epoch to load (use with --continue_train or for inference, 'latest' " 59 | "for latest ", 60 | ) 61 | # == DATA / IMAGE LOADING == 62 | parser.add_argument( 63 | "--dataroot", 64 | required=True, 65 | help="path to data, should contain 'cloth/', 'body/', 'texture/', " 66 | "'rois.csv'", 67 | ) 68 | parser.add_argument( 69 | "--dataset", help="dataset class to use, if none then will use model name" 70 | ) 71 | parser.add_argument( 72 | "--dataset_mode", 73 | default="image", 74 | choices=("image", "video"), 75 | help="how data is formatted. video mode allows additional source inputs" 76 | "from other frames of the video", 77 | ) 78 | # channels 79 | parser.add_argument( 80 | "--cloth_representation", 81 | default="labels", # default according to SwapNet 82 | choices=("rgb", "labels"), 83 | help="which representation the cloth segmentations are in. 'labels' means " 84 | "a 2D tensor where each value is the cloth label. 'rgb' ", 85 | ) 86 | parser.add_argument( 87 | "--body_representation", 88 | default="rgb", # default according to SwapNet 89 | choices=("rgb", "labels"), 90 | help="which representation the body segmentations are in", 91 | ) 92 | parser.add_argument( 93 | "--cloth_channels", 94 | default=19, 95 | type=int, 96 | help="only used if --cloth_representation == 'labels'. cloth segmentation " 97 | "number of channels", 98 | ) 99 | parser.add_argument( 100 | "--body_channels", 101 | default=12, 102 | type=int, 103 | help="only used if --body_representation == 'labels'. body segmentation " 104 | "number of channels. 
Use 12 for neural body fitting output", 105 | ) 106 | parser.add_argument( 107 | "--texture_channels", 108 | default=3, 109 | type=int, 110 | help="RGB textured image number of channels", 111 | ) 112 | # image dimension / editing 113 | parser.add_argument( 114 | "--pad", action="store_true", help="add a padding to make image square" 115 | ) 116 | parser.add_argument( 117 | "--load_size", 118 | default=128, 119 | type=int, 120 | help="scale images (after padding) to this size", 121 | ) 122 | parser.add_argument( 123 | "--crop_size", type=int, default=128, help="then crop to this size" 124 | ) 125 | parser.add_argument( 126 | "--crop_bounds", 127 | help="DO NOT USE WITH --crop_size. crop images to a region: ((xmin, ymin), (xmax, ymax))", 128 | ) 129 | # == ITERATION PROPERTIES == 130 | parser.add_argument( 131 | "--max_dataset_size", type=int, default=float("inf"), help="cap on data" 132 | ) 133 | parser.add_argument( 134 | "--batch_size", type=int, default=8, help="batch size to load data" 135 | ) 136 | parser.add_argument( 137 | "--shuffle_data", 138 | default=True, 139 | type=bool, 140 | help="whether to shuffle dataset (default is True)", 141 | ) 142 | parser.add_argument( 143 | "--num_workers", 144 | default=4, 145 | type=int, 146 | help="number of CPU threads for data loading", 147 | ) 148 | parser.add_argument( 149 | "--gpu_id", default=0, type=int, help="gpu id to use. -1 for cpu" 150 | ) 151 | parser.add_argument( 152 | "--no_confirm", action="store_true", help="do not prompt for confirmations" 153 | ) 154 | 155 | self._parser = parser 156 | self.is_train = None 157 | 158 | def gather_options(self): 159 | """ 160 | Gathers options from all modifieable thingies. 161 | :return: 162 | """ 163 | parser = self._parser 164 | 165 | # basic options 166 | opt, _ = parser.parse_known_args() 167 | parser.set_defaults(dataset=opt.model) 168 | opt.batch_size 169 | 170 | # modify options for each arg that can do so 171 | modifiers = ["model", "dataset"] 172 | if self.is_train: 173 | modifiers.append("optimizer_D") 174 | for arg in modifiers: 175 | # becomes model(s), dataset(s), optimizer(s) 176 | import_source = eval(arg.split("_")[0] + "s") 177 | # becomes e.g. 
opt.model, opt.dataset, opt.optimizer 178 | name = getattr(opt, arg) 179 | print(arg, name) 180 | if name is not None: 181 | options_modifier = import_source.get_options_modifier(name) 182 | parser = options_modifier(parser, self.is_train) 183 | opt, _ = parser.parse_known_args() 184 | # hacky, add optimizer G params if different from opt_D 185 | if arg is "optimizer_D" and opt.optimizer_D != opt.optimizer_G: 186 | modifiers.append("optimizer_G") 187 | 188 | self._parser = parser 189 | final_opt = self._parser.parse_args() 190 | return final_opt 191 | 192 | @staticmethod 193 | def _validate(opt): 194 | """ 195 | Validate that options are correct 196 | :return: 197 | """ 198 | assert ( 199 | opt.crop_size <= opt.load_size 200 | ), "Crop size must be less than or equal to load size " 201 | 202 | def parse(self, print_options=True, store_options=True, user_overrides=True): 203 | """ 204 | 205 | Args: 206 | print_options: print the options to screen when parsed 207 | store_options: save the arguments to file: "{opt.checkpoints_dir}/{opt.name}/args.json" 208 | 209 | Returns: 210 | 211 | """ 212 | opt = self.gather_options() 213 | opt.is_train = self.is_train 214 | 215 | # perform assertions on arguments 216 | BaseOptions._validate(opt) 217 | 218 | if opt.gpu_id > 0: 219 | torch.cuda.set_device(opt.gpu_id) 220 | torch.backends.cudnn.benchmark = True 221 | 222 | self.opt = opt 223 | 224 | # Load options from config file if present 225 | if opt.config_file: 226 | self.load(opt.config_file, user_overrides) 227 | 228 | if print_options: # print what we parsed 229 | self.print() 230 | 231 | root = opt.checkpoints_dir if self.is_train else opt.results_dir 232 | self.save_file = os.path.join(root, opt.name, "args.json") 233 | if store_options: # store options to file 234 | self.save() 235 | return opt 236 | 237 | def print(self): 238 | """ 239 | prints the options nicely 240 | :return: 241 | """ 242 | d = vars(self.opt) 243 | print("=====OPTIONS======") 244 | for k, v in d.items(): 245 | print(k, ":", v) 246 | print("==================") 247 | 248 | def save(self): 249 | """ 250 | Saves to a .json file 251 | :return: 252 | """ 253 | d = vars(self.opt) 254 | 255 | PromptOnce.makedirs(os.path.dirname(self.save_file), not self.opt.no_confirm) 256 | with open(self.save_file, "w") as f: 257 | f.write(json.dumps(d, indent=4)) 258 | 259 | def load(self, json_file, user_overrides): 260 | load(self.opt, json_file, user_overrides=user_overrides) 261 | 262 | 263 | def load(opt, json_file, user_overrides=True): 264 | """ 265 | 266 | Args: 267 | opt: Namespace that will get modified 268 | json_file: 269 | user_overrides: whether user command line arguments should override anything being loaded from the config file 270 | 271 | """ 272 | opt = copy.deepcopy(opt) 273 | with open(json_file, "r") as f: 274 | args = json.load(f) 275 | 276 | # if the user specifies arguments on the command line, don't override these 277 | if user_overrides: 278 | user_args = filter(lambda a: a.startswith("--"), sys.argv[1:]) 279 | user_args = set( 280 | [a.lstrip("-") for a in user_args] 281 | ) # get rid of left dashes 282 | print("Not overriding:", user_args) 283 | 284 | # override default options with values in config file 285 | for k, v in args.items(): 286 | # only override if not specified on the cmdline 287 | if not user_overrides or (user_overrides and k not in user_args): 288 | setattr(opt, k, v) 289 | # but make sure the config file matches up 290 | opt.config_file = json_file 291 | return opt 292 | 
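# Illustrative sketch of the precedence rule implemented by load() above; the
# config contents and values are hypothetical and nothing here ships with the
# project.
if __name__ == "__main__":
    import tempfile
    from argparse import Namespace

    # write a throwaway config file so the demo is self-contained
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as tmp:
        json.dump({"batch_size": 16, "lr": 0.001}, tmp)

    demo_opt = Namespace(batch_size=4, lr=0.01, config_file=None)
    demo_opt = load(demo_opt, tmp.name, user_overrides=False)
    # With user_overrides=False every key in the file is applied, so batch_size
    # becomes 16 and lr becomes 0.001. With user_overrides=True, any key that was
    # also passed on the command line (present in sys.argv as "--<key>") keeps its
    # command-line value instead of being overridden by the config file.
    print(demo_opt)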
-------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from options.base_options import BaseOptions 4 | 5 | 6 | class TestOptions(BaseOptions): 7 | def __init__(self, **defaults): 8 | super().__init__() 9 | self.is_train = False 10 | parser = self._parser 11 | 12 | parser.set_defaults(max_dataset_size=50, shuffle_data=False) 13 | parser.add_argument( 14 | "--interval", 15 | metavar="N", 16 | default=1, 17 | type=int, 18 | help="only run every n images", 19 | ) 20 | parser.add_argument( 21 | "--warp_checkpoint", 22 | help="Use this to run the warp stage. Specifies the checkpoint file of " 23 | "warp stage model, containing args.json file in same dir", 24 | ) 25 | parser.add_argument( 26 | "--texture_checkpoint", 27 | help="Use this to run the texture stage. Specifies the checkpoint dir of " 28 | "texture stage containing args.json file", 29 | ) 30 | parser.add_argument( 31 | "--checkpoint", 32 | help="Shorthand for both warp and texture checkpoint to use the 'latest' " 33 | "generator file (or specify using --load_epoch). This should be the " 34 | "root dir containing warp/ and texture/ checkpoint folders.", 35 | ) 36 | parser.add_argument( 37 | "--body_dir", 38 | help="Directory to use as target bodys for where the cloth will be placed " 39 | "on. If same directory as --cloth_root, use --shuffle_data to achieve " 40 | "clothing transfer. If not provided, will uses --dataroot/body", 41 | ) 42 | parser.add_argument( 43 | "--cloth_dir", 44 | help="Directory to use for the clothing source. If same directory as " 45 | "--body_root, use --shuffle_data to achieve clothing transfer. If not " 46 | "provided, will use --dataroot/cloth", 47 | ) 48 | parser.add_argument( 49 | "--texture_dir", 50 | help="Directory to use for the clothing source. If same directory as " 51 | "--body_root, use --shuffle_data to achieve clothing transfer. 
If not " 52 | "provided, will use --dataroot/texture", 53 | ) 54 | parser.add_argument( 55 | "--results_dir", 56 | default="results", 57 | help="folder to output intermediate and final results", 58 | ) 59 | parser.add_argument( 60 | "--skip_intermediates", 61 | action="store_true", 62 | help="choose not to save intermediate cloth visuals as images for warp " 63 | "stage (instead, just save .npz files)", 64 | ) 65 | 66 | parser.add_argument( 67 | "--dataroot", 68 | required=False, 69 | help="path to dataroot if cloth, body, and texture not individually specified", 70 | ) 71 | # remove arguments 72 | parser.add_argument( 73 | "--model", help=argparse.SUPPRESS 74 | ) # remove model as we restore from checkpoint 75 | parser.add_argument("--name", default="", help=argparse.SUPPRESS) 76 | 77 | parser.set_defaults(**defaults) 78 | 79 | @staticmethod 80 | def _validate(opt): 81 | super(TestOptions, TestOptions)._validate(opt) 82 | 83 | if not (opt.body_dir or opt.cloth_dir or opt.texture_dir or opt.dataroot): 84 | raise ValueError( 85 | "Must either (1) specify --dataroot, or (2) --body_dir, --cloth_dir, " 86 | "and --texture_dir individually" 87 | ) 88 | 89 | if not opt.dataroot: 90 | if opt.warp_checkpoint and not opt.body_dir: 91 | raise ValueError("Warp stage must have body_dir") 92 | if opt.texture_checkpoint and not opt.texture_dir: 93 | raise ValueError("Texture stage must have texture_dir") 94 | 95 | if not opt.warp_checkpoint and not opt.texture_checkpoint: 96 | raise ValueError("Must set either warp_checkpoint or texture_checkpoint") 97 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from options.base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | def __init__(self): 6 | super().__init__() 7 | self.is_train = True 8 | parser = self._parser 9 | # override the model arg from base options, such that model is REQUIRED 10 | parser.add_argument( 11 | "--model", 12 | help="which model to run", 13 | choices=("warp", "texture", "pix2pix"), 14 | required=True 15 | ) 16 | parser.add_argument( 17 | "--continue_train", 18 | action="store_true", 19 | help="continue training from latest checkpoint", 20 | ) 21 | # visdom and HTML visualization parameters 22 | parser.add_argument( 23 | "--display_freq", 24 | type=int, 25 | default=400, 26 | help="frequency of showing training results on screen", 27 | ) 28 | parser.add_argument( 29 | "--display_ncols", 30 | type=int, 31 | default=4, 32 | help="if positive, display all images in a single visdom web panel with " 33 | "certain number of images per row.", 34 | ) 35 | parser.add_argument( 36 | "--display_id", type=int, default=1, help="window id of the web display" 37 | ) 38 | parser.add_argument( 39 | "--display_server", 40 | type=str, 41 | default="http://localhost", 42 | help="visdom server of the web display", 43 | ) 44 | parser.add_argument( 45 | "--display_env", 46 | type=str, 47 | default="main", 48 | help='visdom display environment name (default is "main")', 49 | ) 50 | parser.add_argument( 51 | "--display_port", 52 | type=int, 53 | default=8097, 54 | help="visdom port of the web display", 55 | ) 56 | parser.add_argument( 57 | "--update_html_freq", 58 | type=int, 59 | default=1000, 60 | help="frequency of saving training results to html", 61 | ) 62 | parser.add_argument( 63 | "--print_freq", 64 | type=int, 65 | default=100, 66 | help="frequency of showing training results on 
console", 67 | ) 68 | parser.add_argument( 69 | "--no_html", 70 | action="store_true", 71 | help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/", 72 | ) 73 | # Training parameters 74 | parser.add_argument( 75 | "--n_epochs", "--num_epochs", default=20, type=int, help="number of epochs to train until" 76 | ) 77 | parser.add_argument( 78 | "--start_epoch", type=int, default=0, help="epoch to start training from" 79 | ) 80 | parser.add_argument( 81 | "--sample_freq", 82 | help="how often to sample and save image results from the generator", 83 | ) 84 | parser.add_argument( 85 | "--checkpoint_freq", 86 | default=2, 87 | type=int, 88 | help="how often to save checkpoints. negative numbers for middle of epoch", 89 | ) 90 | parser.add_argument( 91 | "--latest_checkpoint_freq", 92 | default=5120, 93 | type=int, 94 | help="how often (in iterations) to save latest checkpoint", 95 | ) 96 | parser.add_argument( 97 | "--save_by_iter", 98 | action="store_true", 99 | help="whether saves model by iteration", 100 | ) 101 | parser.add_argument( 102 | "--lr", 103 | "--learning_rate", 104 | type=float, 105 | default=0.01, 106 | help="initial learning rate", 107 | ) 108 | parser.add_argument( 109 | "--wt_decay", 110 | "--weight_decay", 111 | dest="weight_decay", 112 | default=0, 113 | type=float, 114 | help="optimizer L2 weight decay", 115 | ) 116 | # weights init 117 | parser.add_argument( 118 | "--init_type", 119 | default="kaiming", 120 | choices=("normal", "xavier", "kaiming"), 121 | help="weights initialization method", 122 | ) 123 | parser.add_argument( 124 | "--init_gain", default=0.02, type=float, help="init scaling factor" 125 | ) 126 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | General-purpose training script for image-to-image translation, adapted for SwapNet. 3 | 4 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, 5 | colorization) and different datasets (with option '--dataset_mode': e.g., image, video). 6 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model 7 | ('--model'). 8 | 9 | It first creates the model, dataset, and visualizer given the options. 10 | It then does standard network training. During the training, it also visualize/save 11 | the images, print/save the loss plot, and save models. 12 | The script supports continue/resume training. Use '--continue_train' to resume your 13 | previous training. 14 | 15 | Example: 16 | Train the warp model: 17 | python train.py --name warp_stage --model warp --dataroot data/deep_fashion 18 | Train the texture model: 19 | python train.py --name texture_stage --model texture --dataroot data/deep_fashion 20 | """ 21 | from tqdm import tqdm 22 | import time 23 | from options.train_options import TrainOptions 24 | from datasets import create_dataset 25 | from models import create_model 26 | 27 | from util.visualizer import Visualizer 28 | 29 | print = tqdm.write 30 | 31 | if __name__ == "__main__": 32 | opt = TrainOptions().parse(store_options=True) # get training options 33 | # create a dataset given opt.dataset_mode and other options 34 | dataset = create_dataset(opt) 35 | dataset_size = len(dataset) # get the number of images in the dataset. 
36 | print(f"The number of training images = {dataset_size:d}") 37 | 38 | model = create_model(opt) # create a model given opt.model and other options 39 | model.setup(opt) # regular setup: load and print networks; create schedulers 40 | # create a visualizer that display/save images and plots 41 | visualizer = Visualizer(opt) 42 | total_iters = 0 # the total number of training iterations 43 | 44 | # outer loop for different epochs; 45 | # we save the model by # , + 46 | for epoch in tqdm( 47 | range(opt.start_epoch + 1, opt.n_epochs + 1), desc="Completed Epochs" 48 | ): 49 | epoch_start_time = time.time() # timer for entire epoch 50 | iter_data_time = time.time() # timer for data loading per iteration 51 | # the number of training iterations in current epoch, reset to 0 every epoch 52 | epoch_iter = 0 53 | 54 | with tqdm(total=len(dataset), unit="image") as pbar: 55 | for i, data in enumerate(dataset): # inner loop within one epoch 56 | iter_start_time = time.time() # timer for computation per iteration 57 | if total_iters % opt.print_freq == 0: 58 | t_data = iter_start_time - iter_data_time 59 | visualizer.reset() 60 | total_iters += opt.batch_size 61 | epoch_iter += opt.batch_size 62 | model.set_input(data) # unpack data from dataset and preprocess 63 | # calculate loss functions, get gradients, update network weights 64 | model.optimize_parameters() 65 | 66 | if total_iters % opt.display_freq == 0: 67 | # display images on visdom and save images to a HTML file 68 | save_result = total_iters % opt.update_html_freq == 0 69 | model.compute_visuals() 70 | visualizer.display_current_results( 71 | model.get_current_visuals(), epoch, save_result 72 | ) 73 | 74 | losses = model.get_current_losses() 75 | Visualizer.just_print_losses( 76 | epoch, losses, print_func=lambda m: pbar.set_description(m) 77 | ) 78 | if total_iters % opt.print_freq == 0: 79 | # print training losses and save logging information to the disk 80 | t_comp = (time.time() - iter_start_time) / opt.batch_size 81 | visualizer.print_current_losses( 82 | epoch, 83 | epoch_iter, 84 | losses, 85 | t_comp, 86 | t_data, 87 | print_func=lambda *args: None, 88 | ) 89 | if opt.display_id > 0: 90 | visualizer.plot_current_losses( 91 | epoch - 1, float(epoch_iter) / dataset_size, losses 92 | ) 93 | if ( 94 | opt.latest_checkpoint_freq 95 | and total_iters % opt.latest_checkpoint_freq == 0 96 | ): 97 | # cache our latest model every iterations 98 | print( 99 | f"saving the latest model (epoch {epoch:d}, total_iters {total_iters:d}) " 100 | ) 101 | save_prefix = ( 102 | "iter_%d" % total_iters if opt.save_by_iter else f"latest" 103 | ) 104 | model.save_checkpoint(save_prefix) 105 | 106 | iter_data_time = time.time() 107 | # weird unpacking to get the batch_size (we can't use opt.batch_size in case total len is not a multiple of batch_size 108 | pbar.update(len(tuple(data.values())[0])) 109 | 110 | if opt.checkpoint_freq and epoch % opt.checkpoint_freq == 0: 111 | # cache our model every epochs 112 | print( 113 | f"saving the model at the end of epoch {epoch:d}, iters {total_iters:d}" 114 | ) 115 | model.save_checkpoint("latest") 116 | model.save_checkpoint(epoch) 117 | 118 | # print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs, time.time() - epoch_start_time)) 119 | # model.update_learning_rate() # update learning rates at the end of every epoch. 
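    # Resuming an interrupted run (illustrative; the experiment name and data path
    # are examples, the flags come from options/train_options.py and base_options.py):
    #   python train.py --name warp_stage --model warp --dataroot data/deep_fashion \
    #       --continue_train --load_epoch latest
    # With --continue_train, model.setup(opt) is expected to restore the checkpoint
    # selected by --load_epoch from checkpoints/<name>/, and the epoch loop above
    # starts again from --start_epoch + 1.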
120 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /util/calculate_imagedir_stats.py: -------------------------------------------------------------------------------- 1 | """ 2 | in this script, we calculate the image per channel mean and standard 3 | deviation in the training set, do not calculate the statistics on the 4 | whole dataset, as per here http://cs231n.github.io/neural-networks-2/#datapre 5 | gist source: https://gist.github.com/jdhao/9a86d4b9e4f79c5330d54de991461fd6 6 | """ 7 | import argparse 8 | import json 9 | import os 10 | 11 | import numpy as np 12 | from os import listdir, chdir 13 | from tqdm import tqdm 14 | from os.path import join, isdir 15 | from glob import glob 16 | import cv2 17 | import timeit 18 | import sys 19 | 20 | sys.path.append('.') 21 | 22 | # number of channels of the dataset image, 3 for color jpg, 1 for grayscale img 23 | CHANNEL_NUM = 3 24 | 25 | 26 | def cal_dir_stat(root): 27 | pixel_num = 0 # store all pixel number in the dataset 28 | channel_sum = np.zeros(CHANNEL_NUM) 29 | channel_sum_squared = np.zeros(CHANNEL_NUM) 30 | 31 | im_pths = glob(join(root,"**/*"+".jpg"), recursive=True) 32 | print(len(im_pths)) 33 | 34 | for path in tqdm(im_pths): 35 | im = cv2.imread(path) # image in M*N*3 shape, channel in BGR order 36 | im = im / 255.0 37 | pixel_num += im.size / CHANNEL_NUM 38 | channel_sum += np.sum(im, axis=(0, 1)) 39 | channel_sum_squared += np.sum(np.square(im), axis=(0, 1)) 40 | 41 | bgr_mean = channel_sum / pixel_num 42 | bgr_std = np.sqrt(channel_sum_squared / pixel_num - np.square(bgr_mean)) 43 | 44 | rgb_mean = list(bgr_mean)[::-1] 45 | rgb_std = list(bgr_std)[::-1] 46 | return rgb_mean, rgb_std 47 | 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument("data_dir", nargs="+") 50 | parser.add_argument("--output_file", help="file to append output to", default="normalization_stats.json") 51 | args = parser.parse_args() 52 | 53 | # The script assumes that under train_root, there are separate directories for each class 54 | # of training images. 
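# Example usage (hypothetical paths and made-up statistics, for illustration only):
#
#   python util/calculate_imagedir_stats.py data/deep_fashion/texture
#
# computes the per-channel RGB mean/std over every *.jpg found recursively under the
# given directory and appends one JSON line per directory to
# data/deep_fashion/normalization_stats.json (i.e. <data_dir>/../normalization_stats.json), e.g.:
#
#   {"path": "texture", "means": [0.43, 0.41, 0.39], "stds": [0.23, 0.22, 0.21]}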
55 | for d in args.data_dir: 56 | print("Calculating stats for", d) 57 | train_root = d 58 | start = timeit.default_timer() 59 | mean, std = cal_dir_stat(train_root) 60 | end = timeit.default_timer() 61 | print("elapsed time: {}".format(end - start)) 62 | print("mean:{}\nstd:{}".format(mean, std)) 63 | 64 | def file_has_lines(f): 65 | for i, _ in enumerate(f): 66 | if i > 1: 67 | return True 68 | else: 69 | return False 70 | 71 | stats = { 72 | "path": d.split(os.path.sep)[-1], 73 | "means": mean, 74 | "stds": std 75 | } 76 | 77 | of = os.path.join(d, "..", args.output_file) 78 | 79 | with open(of, "a") as f: 80 | f.write(json.dumps(stats) + "\n") 81 | 82 | print("Done!") 83 | -------------------------------------------------------------------------------- /util/decode_labels.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image 4 | 5 | n_classes = 20 6 | # colour map 7 | label_colours = [(0,0,0) 8 | # 0=Background 9 | ,(128,0,0),(255,0,0),(0,85,0),(170,0,51),(255,85,0) 10 | # 1=Hat, 2=Hair, 3=Glove, 4=Sunglasses, 5=UpperClothes 11 | ,(0,0,85),(0,119,221),(85,85,0),(0,85,85),(85,51,0) 12 | # 6=Dress, 7=Coat, 8=Socks, 9=Pants, 10=Jumpsuits 13 | ,(52,86,128),(0,128,0),(0,0,255),(51,170,221),(0,255,255) 14 | # 11=Scarf, 12=Skirt, 13=Face, 14=LeftArm, 15=RightArm 15 | ,(85,255,170),(170,255,85),(255,255,0),(255,170,0)] 16 | # 16=LeftLeg, 17=RightLeg, 18=LeftShoe, 19=RightShoe 17 | 18 | 19 | # take out sunglasses 20 | label_colours = label_colours[:4] + label_colours[5:] 21 | n_classes = 19 22 | 23 | 24 | def decode_cloth_labels(pt_tensor, num_images=-1, num_classes=n_classes): 25 | """Decode batch of segmentation masks. 26 | AJ comment: Converts the tensor into a RGB image. 27 | Args: 28 | as_tf_order: result of inference after taking argmax. 29 | num_images: number of images to decode from the batch. 30 | Returns: 31 | A batch with num_images RGB images of the same size as the input. 32 | """ 33 | # change to H x W x C order 34 | tf_order = pt_tensor.permute(0, 2, 3, 1) 35 | argmax = tf_order.argmax(dim=-1, keepdim=True) 36 | mask = argmax.cpu().numpy() 37 | 38 | n, h, w, c = mask.shape 39 | if num_images < 0: 40 | num_images = n 41 | assert(n >= num_images), 'Batch size %d should be greater or equal than number of images to save %d.' % (n, num_images) 42 | outputs = np.zeros((num_images, h, w, 3), dtype=np.uint8) 43 | for i in range(num_images): 44 | img = Image.new('RGB', (len(mask[i, 0]), len(mask[i]))) 45 | pixels = img.load() 46 | # AJ: this enumerates the "rows" of the image (I think) 47 | for j_, j in enumerate(mask[i, :, :, 0]): 48 | for k_, k in enumerate(j): 49 | if k < n_classes: 50 | pixels[k_,j_] = label_colours[k] 51 | outputs[i] = np.array(img) 52 | 53 | # convert back to tensor. 
effectively puts back into range [0,1] 54 | back_to_pt = torch.from_numpy(outputs).permute(0, 3, 1, 2) 55 | return back_to_pt 56 | -------------------------------------------------------------------------------- /util/draw_rois.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image, ImageDraw 3 | import torch 4 | import seaborn 5 | 6 | from util.util import tensor2im 7 | 8 | NUM_BODY_LABELS = 12 9 | 10 | BODY_COLORS = (np.array(seaborn.color_palette("hls", NUM_BODY_LABELS)) * 255).astype( 11 | np.uint8 12 | ) 13 | 14 | 15 | # TODO move to util file 16 | def draw_rois_on_texture( 17 | rois: torch.Tensor, texture_tensors: torch.Tensor, width_factor=0.01 18 | ): 19 | """ 20 | Draw each region of interest as a colored rectangle on its texture image. 21 | Args: 22 | rois: batch of ROI boxes, one (x1, y1, x2, y2) row per body-part label 23 | texture_tensors: batch of texture image tensors to draw on 24 | width_factor: outline width as a fraction of the image width 25 | 26 | Returns: 27 | a numpy batch of RGB images with the ROIs drawn on them 28 | """ 29 | 30 | samples = [] 31 | 32 | # do for all in the batch 33 | for roi_batch, t in zip(rois, texture_tensors): 34 | # unsqueeze because of batch annoyances 35 | im = Image.fromarray(tensor2im(t.unsqueeze(0))) 36 | draw = ImageDraw.Draw(im) 37 | for i, roi_row in enumerate(roi_batch.cpu()): 38 | draw.rectangle( 39 | roi_row.numpy(), 40 | outline=tuple(BODY_COLORS[i]), 41 | width=int(round(width_factor * im.size[0])), 42 | ) 43 | 44 | samples.append(np.array(im)) 45 | 46 | # return to batch size 47 | return np.stack(samples) 48 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as add_header (add a text header to the HTML file), 10 | add_images (add a row of images to the HTML file), and save (save the HTML to the disk). 11 | It is based on the Python library 'dominate', a library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML class 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage.
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/ 19 | title (str) -- the webpage name 20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing 21 | """ 22 | self.title = title 23 | self.web_dir = web_dir 24 | self.img_dir = os.path.join(self.web_dir, 'images') 25 | if not os.path.exists(self.web_dir): 26 | os.makedirs(self.web_dir) 27 | if not os.path.exists(self.img_dir): 28 | os.makedirs(self.img_dir) 29 | 30 | self.doc = dominate.document(title=title) 31 | if refresh > 0: 32 | with self.doc.head: 33 | meta(http_equiv="refresh", content=str(refresh)) 34 | 35 | def get_image_dir(self): 36 | """Return the directory that stores images""" 37 | return self.img_dir 38 | 39 | def add_header(self, text): 40 | """Insert a header to the HTML file 41 | 42 | Parameters: 43 | text (str) -- the header text 44 | """ 45 | with self.doc: 46 | h3(text) 47 | 48 | def add_images(self, ims, txts, links, width=400): 49 | """Add images to the HTML file 50 | 51 | Parameters: 52 | ims (str list) -- a list of image paths 53 | txts (str list) -- a list of image names shown on the website 54 | links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page 55 | """ 56 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 57 | self.doc.add(self.t) 58 | with self.t: 59 | with tr(): 60 | for im, txt, link in zip(ims, txts, links): 61 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 62 | with p(): 63 | with a(href=os.path.join('images', link)): 64 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 65 | br() 66 | p(txt) 67 | 68 | def save(self): 69 | """Save the current content to the HTML file""" 70 | html_file = '%s/index.html' % self.web_dir 71 | f = open(html_file, 'wt') 72 | f.write(self.doc.render()) 73 | f.close() 74 | 75 | 76 | if __name__ == '__main__': # we show an example usage here. 77 | html = HTML('web/', 'test_html') 78 | html.add_header('hello world') 79 | 80 | ims, txts, links = [], [], [] 81 | for n in range(4): 82 | ims.append('image_%d.png' % n) 83 | txts.append('text_%d' % n) 84 | links.append('image_%d.png' % n) 85 | html.add_images(ims, txts, links) 86 | html.save() 87 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | 4 | 5 | class ImagePool(): 6 | """This class implements an image buffer that stores previously generated images. 7 | 8 | This buffer enables us to update discriminators using a history of generated images 9 | rather than the ones produced by the latest generators. 10 | """ 11 | 12 | def __init__(self, pool_size): 13 | """Initialize the ImagePool class 14 | 15 | Parameters: 16 | pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created 17 | """ 18 | self.pool_size = pool_size 19 | if self.pool_size > 0: # create an empty pool 20 | self.num_imgs = 0 21 | self.images = [] 22 | 23 | def query(self, images): 24 | """Return an image from the pool. 25 | 26 | Parameters: 27 | images: the latest generated images from the generator 28 | 29 | Returns images from the buffer. 30 | 31 | With 50% probability, the buffer will return the input images. 32 | With 50% probability, the buffer will return images previously stored in the buffer 33 | and insert the current images into the buffer.
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. If there are multiple images 11 | in the batch, simply converts and returns the first 12 | 13 | Parameters: 14 | input_image (tensor) -- the input image tensor array 15 | imtype (type) -- the desired type of the converted numpy array 16 | """ 17 | if not isinstance(input_image, np.ndarray): 18 | if isinstance(input_image, torch.Tensor): # get the data from a variable 19 | image_tensor = input_image.data 20 | else: 21 | return input_image 22 | # select the first and convert it into a numpy array 23 | image_numpy: np.ndarray = image_tensor[0].cpu().numpy() 24 | if image_numpy.shape[0] == 1: # grayscale to RGB 25 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 26 | # post-processing: tranpose and scaling 27 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 28 | else: # if it is a numpy array 29 | if len(input_image.shape) == 4: # select first in batch 30 | input_image = input_image[0] 31 | image_numpy = input_image 32 | return image_numpy.astype(imtype) 33 | 34 | 35 | def diagnose_network(net, name="network"): 36 | """Calculate and print the mean of average absolute(gradients) 37 | 38 | Parameters: 39 | net (torch network) -- Torch network 40 | name (str) -- the name of the network 41 | """ 42 | mean = 0.0 43 | count = 0 44 | for param in net.parameters(): 45 | if param.grad is not None: 46 | mean += torch.mean(torch.abs(param.grad.data)) 47 | count += 1 48 | if count > 0: 49 | mean = mean / count 50 | print(name) 51 | print(mean) 52 | 53 | 54 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 55 | """Save a numpy image to the disk 56 | 57 | Parameters: 58 | image_numpy (numpy array) -- input numpy array 59 | image_path (str) -- the path of the image 60 | """ 61 | 62 | image_pil = Image.fromarray(image_numpy) 63 | h, w, _ = image_numpy.shape 64 | 65 | if aspect_ratio > 1.0: 66 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 67 | if aspect_ratio < 1.0: 68 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 69 | image_pil.save(image_path) 70 | 71 | 72 | def print_numpy(x, val=True, 
shp=False): 73 | """Print the mean, min, max, median, std, and size of a numpy array 74 | 75 | Parameters: 76 | val (bool) -- if print the values of the numpy array 77 | shp (bool) -- if print the shape of the numpy array 78 | """ 79 | x = x.astype(np.float64) 80 | if shp: 81 | print("shape,", x.shape) 82 | if val: 83 | x = x.flatten() 84 | print( 85 | "mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f" 86 | % (np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)) 87 | ) 88 | 89 | 90 | def remove_prefix(text, prefix): 91 | """ 92 | Remove prefix from a string 93 | :param text: 94 | :param prefix: 95 | :return: 96 | """ 97 | return text[text.startswith(prefix) and len(prefix) :] 98 | 99 | 100 | class PromptOnce: 101 | """ 102 | Prompts the user if a path already exists. However, it will only prompt once during 103 | the whole run of the program. 104 | """ 105 | 106 | already_asked = False 107 | 108 | @staticmethod 109 | def makedirs(path, prompt=True): 110 | try: 111 | os.makedirs(path) 112 | PromptOnce.already_asked = True 113 | except FileExistsError as e: 114 | if prompt and len(os.listdir(path)) != 0 and not PromptOnce.already_asked: 115 | print(f"The experiment directory '{path}' already exists.") 116 | print(" Here are its contents:") 117 | print("\t", os.listdir(path)) 118 | a = input( 119 | f"\n Existing data will be overwritten!\n" 120 | f" Are you sure you want to continue? (y/N): " 121 | ) 122 | if a.lower().strip() != "y": 123 | print(" Did not receive confirmation to overwrite. Exiting...") 124 | quit() 125 | print() 126 | PromptOnce.already_asked = True 127 | 128 | 129 | def mkdirs(paths): 130 | """create empty directories if they don't exist 131 | 132 | Parameters: 133 | paths (str list) -- a list of directory paths 134 | """ 135 | if isinstance(paths, list) and not isinstance(paths, str): 136 | for path in paths: 137 | mkdir(path) 138 | else: 139 | mkdir(paths) 140 | 141 | 142 | def mkdir(path): 143 | """create a single empty directory if it didn't exist 144 | 145 | Parameters: 146 | path (str) -- a single directory path 147 | """ 148 | if not os.path.exists(path): 149 | os.makedirs(path) 150 | -------------------------------------------------------------------------------- /util/visualizer.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import sys 4 | import ntpath 5 | import time 6 | 7 | import torch 8 | from datasets.data_utils import remove_extension 9 | from . import util, html 10 | from subprocess import Popen, PIPE 11 | 12 | from tqdm import tqdm 13 | print = tqdm.write 14 | 15 | if sys.version_info[0] == 2: 16 | VisdomExceptionBase = Exception 17 | else: 18 | VisdomExceptionBase = ConnectionError 19 | 20 | 21 | def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256): 22 | """Save images to the disk. 23 | 24 | Parameters: 25 | webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details) 26 | visuals (OrderedDict) -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs 27 | image_path (str) -- the string is used to create image paths 28 | aspect_ratio (float) -- the aspect ratio of saved images 29 | width (int) -- the images will be resized to width x width 30 | 31 | This function will save images stored in 'visuals' to the HTML file specified by 'webpage'. 
32 | """ 33 | image_dir = webpage.get_image_dir() 34 | name = "_to_".join([remove_extension(ntpath.basename(p)) for p in image_path[0]]) 35 | 36 | webpage.add_header(name) 37 | ims, txts, links = [], [], [] 38 | 39 | for label, im_data in visuals.items(): 40 | im = util.tensor2im(im_data) 41 | image_name = f'{name}_{label}.png' 42 | save_path = os.path.join(image_dir, image_name) 43 | util.save_image(im, save_path, aspect_ratio=aspect_ratio) 44 | ims.append(image_name) 45 | txts.append(label) 46 | links.append(image_name) 47 | webpage.add_images(ims, txts, links, width=width) 48 | 49 | 50 | class Visualizer(): 51 | """This class includes several functions that can display/save images and print/save logging information. 52 | 53 | It uses the Python library 'visdom' for display, and the Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images. 54 | """ 55 | 56 | def __init__(self, opt): 57 | """Initialize the Visualizer class 58 | 59 | Parameters: 60 | opt -- stores all the experiment flags; needs to be a subclass of BaseOptions 61 | Step 1: Cache the training/test options 62 | Step 2: connect to a visdom server 63 | Step 3: create an HTML object for saving HTML files 64 | Step 4: create a logging file to store training losses 65 | """ 66 | self.opt = opt # cache the option 67 | self.display_id = opt.display_id 68 | self.use_html = opt.is_train and not opt.no_html 69 | self.win_size = opt.display_winsize 70 | self.name = opt.name 71 | self.port = opt.display_port 72 | self.saved = False 73 | if self.display_id > 0: # connect to a visdom server at opt.display_server and opt.display_port 74 | import visdom 75 | self.ncols = opt.display_ncols 76 | self.vis = visdom.Visdom(server=opt.display_server, port=opt.display_port, env=opt.display_env) 77 | if not self.vis.check_connection(): 78 | self.create_visdom_connections() 79 | print("You can monitor training in visdom! Go to http://localhost:8097 in your browser.") 80 | 81 | if self.use_html: # create an HTML object at [opt.checkpoints_dir]/[opt.name]/web/; images will be saved under [opt.checkpoints_dir]/[opt.name]/web/images/ 82 | self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web') 83 | self.img_dir = os.path.join(self.web_dir, 'images') 84 | print('create web directory %s...' % self.web_dir) 85 | util.mkdirs([self.web_dir, self.img_dir]) 86 | # create a logging file to store training losses 87 | self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt') 88 | with open(self.log_name, "a") as log_file: 89 | now = time.strftime("%c") 90 | log_file.write('================ Training Loss (%s) ================\n' % now) 91 | 92 | def reset(self): 93 | """Reset the self.saved status""" 94 | self.saved = False 95 | 96 | def create_visdom_connections(self): 97 | """If the program could not connect to the Visdom server, this function will start a new server at port < self.port > """ 98 | cmd = sys.executable + ' -m visdom.server -p %d &>/dev/null &' % self.port 99 | print('\n\nCould not connect to Visdom server. \n Trying to start a server....') 100 | print('Command: %s' % cmd) 101 | Popen(cmd, shell=True, stdout=PIPE, stderr=PIPE) 102 | 103 | def display_current_results(self, visuals, epoch, save_result): 104 | """Display current results on visdom; save current results to an HTML file.
105 | 106 | Parameters: 107 | visuals (OrderedDict) -- dictionary of images to display or save 108 | epoch (int) -- the current epoch 109 | save_result (bool) -- whether to save the current results to an HTML file 110 | """ 111 | if self.display_id > 0: # show images in the browser using visdom 112 | ncols = self.ncols 113 | if ncols > 0: # show all the images in one visdom panel 114 | ncols = min(ncols, len(visuals)) 115 | h, w = next(iter(visuals.values())).shape[:2] 116 | table_css = """<style> 117 | table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center} 118 | table td {width: %dpx; height: %dpx; padding: 4px; outline: 4px solid black} 119 | </style>""" % (w, h) # create a table css 120 | # create a table of images. 121 | title = self.name 122 | label_html = '' 123 | label_html_row = '' 124 | images = [] 125 | idx = 0 126 | for label, image in visuals.items(): 127 | image_numpy = util.tensor2im(image) 128 | label_html_row += '<td>%s</td>' % label 129 | images.append(image_numpy.transpose([2, 0, 1])) 130 | idx += 1 131 | if idx % ncols == 0: 132 | label_html += '<tr>%s</tr>' % label_html_row 133 | label_html_row = '' 134 | white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255 135 | while idx % ncols != 0: 136 | images.append(white_image) 137 | label_html_row += '<td></td>' 138 | idx += 1 139 | if label_html_row != '': 140 | label_html += '<tr>%s</tr>' % label_html_row 141 | 142 | # ADD IMAGES TO VISDOM 143 | try: 144 | self.vis.images(images, nrow=ncols, win=self.display_id + 1, 145 | padding=2, opts=dict(title=title + ' images')) 146 | label_html = '<table>%s</table>
' % label_html 147 | self.vis.text(table_css + label_html, win=self.display_id + 2, 148 | opts=dict(title=title + ' labels')) 149 | except VisdomExceptionBase: 150 | self.create_visdom_connections() 151 | 152 | else: # show each image in a separate visdom panel; 153 | idx = 1 154 | try: 155 | for label, image in visuals.items(): 156 | image_numpy = util.tensor2im(image) 157 | self.vis.image(image_numpy.transpose([2, 0, 1]), opts=dict(title=label), 158 | win=self.display_id + idx) 159 | idx += 1 160 | except VisdomExceptionBase: 161 | self.create_visdom_connections() 162 | 163 | if self.use_html and (save_result or not self.saved): # save images to an HTML file if they haven't been saved. 164 | self.saved = True 165 | # save images to the disk 166 | for label, image in visuals.items(): 167 | image_numpy = util.tensor2im(image) 168 | img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label)) 169 | util.save_image(image_numpy, img_path) 170 | 171 | # update website 172 | webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=1) 173 | for n in range(epoch, 0, -1): 174 | webpage.add_header('epoch [%d]' % n) 175 | ims, txts, links = [], [], [] 176 | 177 | for label, image_numpy in visuals.items(): 178 | image_numpy = util.tensor2im(image) 179 | img_path = 'epoch%.3d_%s.png' % (n, label) 180 | ims.append(img_path) 181 | txts.append(label) 182 | links.append(img_path) 183 | webpage.add_images(ims, txts, links, width=self.win_size) 184 | webpage.save() 185 | 186 | def plot_current_losses(self, epoch, counter_ratio, losses): 187 | """display the current losses on visdom display: dictionary of error labels and values 188 | 189 | Parameters: 190 | epoch (int) -- current epoch 191 | counter_ratio (float) -- progress (percentage) in the current epoch, between 0 to 1 192 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 193 | """ 194 | if not hasattr(self, 'plot_data'): 195 | self.plot_data = {'X': [], 'Y': [], 'legend': list(losses.keys())} 196 | self.plot_data['X'].append(epoch + counter_ratio) 197 | self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']]) 198 | try: 199 | self.vis.line( 200 | X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1), 201 | Y=np.array(self.plot_data['Y']), 202 | opts={ 203 | 'title': self.name + ' loss over time', 204 | 'legend': self.plot_data['legend'], 205 | 'xlabel': 'epoch', 206 | 'ylabel': 'loss'}, 207 | win=self.display_id) 208 | except VisdomExceptionBase: 209 | self.create_visdom_connections() 210 | 211 | # losses: same format as |losses| of plot_current_losses 212 | def print_current_losses(self, epoch, iters, losses, t_comp, t_data, print_func=print): 213 | """print current losses on console; also save the losses to the disk 214 | 215 | Parameters: 216 | epoch (int) -- current epoch 217 | iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch) 218 | losses (OrderedDict) -- training losses stored in the format of (name, float) pairs 219 | t_comp (float) -- computational time per data point (normalized by batch_size) 220 | t_data (float) -- data loading time per data point (normalized by batch_size) 221 | print_func -- how to print loss information (default is the standard print function) 222 | """ 223 | message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data) 224 | for k, v in losses.items(): 225 | message += '%s: %.3f ' % (k, v) 226 | 227 | print_func(message) # print 
the message 228 | with open(self.log_name, "a") as log_file: 229 | log_file.write('%s\n' % message) # save the message 230 | 231 | @staticmethod 232 | def just_print_losses(epoch, losses, print_func=print): 233 | message = f'(epoch: {epoch:d}) ' 234 | for k, v in losses.items(): 235 | message += '%s: %.3f ' % (k, v) 236 | print_func(message) 237 | --------------------------------------------------------------------------------
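For reference, a minimal sketch (not code from this repository) of how an image history buffer like util/image_pool.ImagePool is commonly wired into a GAN discriminator update; the discriminator netD, the criterion, and the optimizer below are placeholders, not the actual objects built in models/base_gan.py:

    import torch
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)  # pool_size=0 would disable the history buffer

    def discriminator_step(netD, criterion, optimizer_D, real, fake):
        """One hypothetical D update that mixes fresh fakes with stored ones."""
        # query() returns a batch in which each fake is, with 50% probability,
        # swapped for an older generated image kept in the buffer
        fake_for_D = pool.query(fake.detach())

        pred_real = netD(real)
        pred_fake = netD(fake_for_D)
        loss_D = 0.5 * (
            criterion(pred_real, torch.ones_like(pred_real))
            + criterion(pred_fake, torch.zeros_like(pred_fake))
        )

        optimizer_D.zero_grad()
        loss_D.backward()
        optimizer_D.step()
        return loss_D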