├── checkpoints └── MOVE_CHECKPOINTS_HERE.txt ├── .gitignore ├── networks ├── backbones │ ├── __init__.py │ ├── functions.py │ └── fsmunit.py ├── __init__.py ├── base_model.py └── fsmunit_model.py ├── teaser.png ├── options ├── log_options.py ├── train_options.py └── __init__.py ├── util ├── __init__.py └── callbacks.py ├── requirements.yml ├── dataset ├── list_dz_twilight.txt └── create_dataset.py ├── data ├── image_folder.py ├── __init__.py ├── anchor_dataset.py └── base_dataset.py ├── inference_general.py ├── inference_exemplar.py ├── train.py ├── README.md └── LICENSE /checkpoints/MOVE_CHECKPOINTS_HERE.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints/**/*.pth 2 | __pycache__ 3 | -------------------------------------------------------------------------------- /networks/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | from .functions import * 2 | -------------------------------------------------------------------------------- /teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/astra-vision/ManiFest/HEAD/teaser.png -------------------------------------------------------------------------------- /options/log_options.py: -------------------------------------------------------------------------------- 1 | import munch 2 | 3 | def LogOptions(): 4 | lo = munch.Munch() 5 | # Save images each x iters 6 | lo.display_freq = 1000 7 | 8 | # Print info each x iters 9 | lo.print_freq = 10 10 | return lo 11 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | from torch.nn import DataParallel 3 | 4 | import sys 5 | 6 | class DataParallelPassthrough(DataParallel): 7 | def __getattr__(self, name): 8 | try: 9 | return super().__getattr__(name) 10 | except AttributeError: 11 | return getattr(self.module, name) 12 | -------------------------------------------------------------------------------- /requirements.yml: -------------------------------------------------------------------------------- 1 | name: fewshot 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - pytorch==1.8.0 7 | - torchvision==0.9.0 8 | - cudatoolkit=10.2 9 | - pip: 10 | - human-id==0.1.0.post3 11 | - kornia==0.5.11 12 | - munch==2.5.0 13 | - numpy==1.21.2 14 | - pillow==8.3.2 15 | - pytorch-lightning==1.4.9 16 | - wandb==0.12.4 17 | - torchmetrics==0.6.0 18 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | import munch 2 | 3 | 4 | def TrainOptions(): 5 | to = munch.Munch() 6 | # Iterations 7 | to.total_iterations = 150001 8 | # Save checkpoint every x iters 9 | to.save_latest_freq = 5000 10 | 11 | # Save checkpoint every x epochs 12 | to.save_epoch_freq = 5 13 | 14 | # Adam settings 15 | to.beta1 = 0.5 16 | 17 | # gan type 18 | to.gan_mode = 'lsgan' 19 | 20 | return to 21 | -------------------------------------------------------------------------------- /dataset/list_dz_twilight.txt: 
-------------------------------------------------------------------------------- 1 | GOPR0348_frame_000001_rgb_anon.png 2 | GOPR0348_frame_000004_rgb_anon.png 3 | GOPR0348_frame_000032_rgb_anon.png 4 | GOPR0348_frame_000038_rgb_anon.png 5 | GOPR0348_frame_000081_rgb_anon.png 6 | GOPR0348_frame_000092_rgb_anon.png 7 | GOPR0348_frame_000130_rgb_anon.png 8 | GOPR0348_frame_000159_rgb_anon.png 9 | GOPR0348_frame_000186_rgb_anon.png 10 | GOPR0348_frame_000258_rgb_anon.png 11 | GOPR0348_frame_000309_rgb_anon.png 12 | GOPR0348_frame_000372_rgb_anon.png 13 | GOPR0348_frame_000379_rgb_anon.png 14 | GOPR0348_frame_000388_rgb_anon.png 15 | GOPR0348_frame_000428_rgb_anon.png 16 | GOPR0348_frame_000576_rgb_anon.png 17 | GOPR0348_frame_000588_rgb_anon.png 18 | GOPR0348_frame_000610_rgb_anon.png 19 | GOPR0348_frame_000686_rgb_anon.png 20 | GOPR0348_frame_000699_rgb_anon.png 21 | GOPR0348_frame_000717_rgb_anon.png 22 | GOPR0348_frame_000727_rgb_anon.png 23 | GOPR0348_frame_000747_rgb_anon.png 24 | GOPR0348_frame_000769_rgb_anon.png 25 | GOPR0348_frame_000808_rgb_anon.png 26 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | 3 | from argparse import ArgumentParser as AP 4 | from .train_options import TrainOptions 5 | from .log_options import LogOptions 6 | from networks import get_model_options 7 | from data import get_dataset_options 8 | import munch 9 | 10 | 11 | def get_options(cmdline_opt): 12 | 13 | bo = munch.Munch() 14 | # Set the number of channels of input image 15 | # Set the number of channels of output image 16 | bo.input_nc = 3 17 | bo.output_nc = 3 18 | bo.gpu_ids = cmdline_opt.gpus 19 | # Dataset options 20 | bo.dataset_mode = cmdline_opt.dataset 21 | bo.model = cmdline_opt.model 22 | # Scheduling policies 23 | bo.lr = cmdline_opt.learning_rate 24 | bo.lr_policy = cmdline_opt.scheduler_policy 25 | bo.decay_iters_step = cmdline_opt.decay_iters_step 26 | bo.decay_step_gamma = cmdline_opt.decay_step_gamma 27 | 28 | opts = [] 29 | opts.append(get_model_options(bo.model)()) 30 | opts.append(get_dataset_options(bo.dataset_mode)()) 31 | opts.append(LogOptions()) 32 | opts.append(TrainOptions()) 33 | 34 | # Checks for Nones 35 | opts = [x for x in opts if x] 36 | for x in opts: 37 | bo.update(x) 38 | return bo 39 | -------------------------------------------------------------------------------- /util/callbacks.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from hashlib import md5 3 | import os 4 | from pytorch_lightning.core.saving import save_hparams_to_yaml 5 | 6 | class LogAndCheckpointEveryNSteps(pl.Callback): 7 | """ 8 | Save a checkpoint/logs every N steps 9 | """ 10 | 11 | def __init__( 12 | self, 13 | save_step_frequency=50, 14 | viz_frequency=5, 15 | log_frequency=5 16 | ): 17 | self.save_step_frequency = save_step_frequency 18 | self.viz_frequency = viz_frequency 19 | self.log_frequency = log_frequency 20 | 21 | def on_batch_end(self, trainer: pl.Trainer, _): 22 | global_step = trainer.global_step 23 | 24 | # Saving checkpoint 25 | if global_step % self.save_step_frequency == 0 and global_step != 0: 26 | filename = "iter_{}.pth".format(global_step) 27 | ckpt_path = os.path.join(trainer.checkpoint_callback.dirpath, filename) 28 | 
trainer.save_checkpoint(ckpt_path) 29 | 30 | # Logging losses 31 | if global_step % self.log_frequency == 0 and global_step != 0: 32 | trainer.model.log_current_losses() 33 | 34 | # Image visualization 35 | if global_step % self.viz_frequency == 0 and global_step != 0: 36 | trainer.model.log_current_visuals() 37 | 38 | class Hash(pl.Callback): 39 | 40 | def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): 41 | if batch_idx == 99: 42 | print("Hash " + md5(pl_module.state_dict()["netG_B.dec.model.4.conv.weight"].cpu().detach().numpy()).hexdigest()) 43 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | This enables dynamic loading of models, similarly to what happens with the dataset. 3 | """ 4 | 5 | import importlib 6 | from networks.base_model import BaseModel 7 | 8 | 9 | def find_model_using_name(model_name): 10 | """Import the module "networks/[model_name]_model.py". 11 | 12 | In the file, the class called DatasetNameModel() will 13 | be instantiated. It has to be a subclass of BaseModel, 14 | and it is case-insensitive. 15 | """ 16 | model_filename = "networks." + model_name + "_model" 17 | modellib = importlib.import_module(model_filename) 18 | model = None 19 | target_model_name = model_name.replace('_', '') + 'model' 20 | for name, cls in modellib.__dict__.items(): 21 | if name.lower() == target_model_name.lower() \ 22 | and issubclass(cls, BaseModel): 23 | model = cls 24 | 25 | if model is None: 26 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 27 | exit(0) 28 | 29 | return model 30 | 31 | 32 | def get_model_options(model_name): 33 | model_filename = "networks." + model_name + "_model" 34 | modellib = importlib.import_module(model_filename) 35 | for name, cls in modellib.__dict__.items(): 36 | if name.lower() == 'modeloptions': 37 | return cls 38 | return None 39 | 40 | def create_model(opt): 41 | """Create a model given the option. 42 | 43 | This function warps the class CustomDatasetDataLoader. 44 | This is the main interface between this package and 'train.py'/'test.py' 45 | 46 | Example: 47 | >>> from networks import create_model 48 | >>> model = create_model(opt) 49 | """ 50 | model = find_model_using_name(opt.model) 51 | instance = model(opt) 52 | return instance 53 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 
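Example (illustrative; the path assumes the symlinked layout produced by
dataset/create_dataset.py):

    >>> from data.image_folder import ImageFolder
    >>> ds = ImageFolder('dataset/day2night/trainS', return_paths=True)
    >>> img, path = ds[0]  # PIL.Image and its file path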
5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | import os.path 12 | 13 | IMG_EXTENSIONS = [ 14 | '.jpg', '.JPG', '.jpeg', '.JPEG', 15 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 16 | '.tif', '.TIF', '.tiff', '.TIFF', 17 | ] 18 | 19 | 20 | def is_image_file(filename): 21 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 22 | 23 | 24 | def make_dataset(dir, max_dataset_size=float("inf")): 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | for fname in sorted(os.listdir(dir)): 29 | if is_image_file(fname): 30 | path = os.path.join(dir, fname) 31 | images.append(path) 32 | return images[:min(max_dataset_size, len(images))] 33 | 34 | 35 | def default_loader(path): 36 | return Image.open(path).convert('RGB') 37 | 38 | 39 | class ImageFolder(data.Dataset): 40 | 41 | def __init__(self, root, transform=None, return_paths=False, 42 | loader=default_loader): 43 | imgs = make_dataset(root) 44 | if len(imgs) == 0: 45 | raise(RuntimeError("Found 0 images in: " + root + "\n" 46 | "Supported image extensions are: " + 47 | ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /inference_general.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import os 4 | from PIL import Image 5 | from munch import Munch 6 | from torchvision.transforms import ToPILImage, ToTensor 7 | from networks import create_model 8 | 9 | from argparse import ArgumentParser as AP 10 | 11 | def main(ap): 12 | CHECKPOINT = ap.checkpoint 13 | OUTPUT_DIR = ap.output_dir 14 | INPUT_DIR = ap.input_dir 15 | # Load parameters 16 | #with open(os.path.join(root_dir, 'hparams.yaml')) as cfg_file: 17 | ckpt_path = torch.load(CHECKPOINT, map_location='cpu') 18 | hparams = ckpt_path['hyper_parameters'] 19 | opt = Munch(hparams).opt 20 | print(opt.model) 21 | print(opt.seed) 22 | opt.phase = 'val' 23 | opt.no_flip = True 24 | # Load parameters to the model, load the checkpoint 25 | model = create_model(opt) 26 | model = model.load_from_checkpoint(CHECKPOINT) 27 | # Transfer the model to the GPU 28 | model.to('cuda') 29 | val_ds = INPUT_DIR 30 | image_list = os.listdir(val_ds) 31 | os.makedirs('{}/general'.format(OUTPUT_DIR), exist_ok=True) 32 | for index, im_path in enumerate(image_list): 33 | print('{}/{}:{}'.format(index + 1, len(image_list), im_path)) 34 | original_image = Image.open(os.path.join(val_ds, im_path)) 35 | original_size = original_image.size 36 | im = original_image.resize((480, 256), Image.BILINEAR) 37 | style_array = torch.randn(1, 8, 1, 1).cuda() 38 | im = ToTensor()(im) * 2 - 1 39 | im = im.cuda().unsqueeze(0) 40 | result = model.forward(im, style_array, type='global', ref_image=None) 41 | result = torch.clamp(result, -1, 1) 42 | img_global = ToPILImage()((result[0] + 1) / 2).resize(original_size, Image.BILINEAR) 43 | img_global.save('{}/general/{}'.format(OUTPUT_DIR, im_path)) 44 | 45 | 46 | if __name__ == '__main__': 47 | ap 
= AP() 48 | ap.add_argument('--checkpoint', required=True, type=str, help='checkpoint to load') 49 | ap.add_argument('--output_dir', required=True, type=str, help='where to save images') 50 | ap.add_argument('--input_dir', default='datasets/acdc_day2night/valRC', type=str, help='directory with images to translate') 51 | ap = ap.parse_args() 52 | main(ap) 53 | 54 | 55 | 56 | 57 | -------------------------------------------------------------------------------- /inference_exemplar.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | import torch 3 | import os 4 | from math import pi 5 | from PIL import Image 6 | from munch import Munch 7 | from torchvision.transforms import ToPILImage, ToTensor 8 | from networks import find_model_using_name, create_model 9 | 10 | from argparse import ArgumentParser as AP 11 | 12 | def main(ap): 13 | CHECKPOINT = ap.checkpoint 14 | OUTPUT_DIR = ap.output_dir 15 | INPUT_DIR = ap.input_dir 16 | EXEMPLAR_IMAGE = ap.exemplar_image 17 | # Load parameters 18 | #with open(os.path.join(root_dir, 'hparams.yaml')) as cfg_file: 19 | ckpt_path = torch.load(CHECKPOINT, map_location='cpu') 20 | hparams = ckpt_path['hyper_parameters'] 21 | opt = Munch(hparams).opt 22 | opt.phase = 'val' 23 | opt.no_flip = True 24 | # Load parameters to the model, load the checkpoint 25 | model = create_model(opt) 26 | model = model.load_from_checkpoint(CHECKPOINT) 27 | # Transfer the model to the GPU 28 | model.to('cuda') 29 | val_ds = INPUT_DIR 30 | 31 | im_ref = Image.open(EXEMPLAR_IMAGE).resize((480, 256), Image.BILINEAR) 32 | im_ref = ToTensor()(im_ref) * 2 - 1 33 | im_ref = im_ref.cuda().unsqueeze(0) 34 | 35 | os.makedirs('{}/exemplar'.format(OUTPUT_DIR), exist_ok=True) 36 | for index, im_path in enumerate(os.listdir(val_ds)): 37 | print(index) 38 | im = Image.open(os.path.join(val_ds, im_path)).resize((480, 256), Image.BILINEAR) 39 | im = ToTensor()(im) * 2 - 1 40 | im = im.cuda().unsqueeze(0) 41 | style_array = torch.randn(1, 8, 1, 1).cuda() 42 | result = model.forward(im, style_array, type='exemplar', ref_image=im_ref) 43 | result = torch.clamp(result, -1, 1) 44 | img_global = ToPILImage()((result[0] + 1) / 2) 45 | img_global.save('{}/exemplar/{}'.format(OUTPUT_DIR, im_path)) 46 | 47 | if __name__ == '__main__': 48 | ap = AP() 49 | ap.add_argument('--checkpoint', required=True, type=str, help='checkpoint to load') 50 | ap.add_argument('--output_dir', required=True, type=str, help='where to save images') 51 | ap.add_argument('--input_dir', default='datasets/acdc_day2night/valRC', type=str, help='directory with images to translate') 52 | ap.add_argument('--exemplar_image', required=True, type=str, help='exemplar_image') 53 | ap = ap.parse_args() 54 | main(ap) 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os 3 | os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' # Disables tensorflow loggings 4 | 5 | from options import get_options 6 | from data import create_dataset 7 | from networks import create_model, get_model_options 8 | from argparse import ArgumentParser as AP 9 | 10 | import pytorch_lightning as pl 11 | from pytorch_lightning.loggers import WandbLogger 12 | 13 | from util.callbacks import LogAndCheckpointEveryNSteps 14 | from human_id import generate_id 15 | 16 | def start(cmdline): 17 | 18 | pl.trainer.seed_everything(cmdline.seed) 19 | opt 
= get_options(cmdline) 20 | opt.phase = 'train' 21 | opt.seed = cmdline.seed 22 | 23 | callbacks = [] 24 | 25 | logger = None 26 | if not cmdline.debug: 27 | logger = WandbLogger(name=cmdline.comment, save_dir="./experiments", project='fewshot-aw') 28 | logger.log_hyperparams(opt) 29 | callbacks.append(LogAndCheckpointEveryNSteps(save_step_frequency=opt.save_latest_freq, 30 | viz_frequency=opt.display_freq, 31 | log_frequency=opt.print_freq)) 32 | root_dir = './experiments' 33 | else: 34 | root_dir = os.path.join('/tmp', generate_id()) 35 | 36 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 37 | model = create_model(opt) # create a model given opt.model and other options 38 | 39 | precision = 16 if cmdline.mixed_precision else 32 40 | 41 | trainer = pl.Trainer(default_root_dir=os.path.join(root_dir, 'checkpoints'), callbacks=callbacks, 42 | gpus=cmdline.gpus, logger=logger, precision=precision, amp_level='01') 43 | 44 | trainer.fit(model, dataset) 45 | 46 | 47 | if __name__ == '__main__': 48 | ap = AP() 49 | ap.add_argument('--id', default=None, type=str, help='Set an existing uuid to resume a training') 50 | ap.add_argument('--debug', default=False, action='store_true', help='Disables experiment saving') 51 | ap.add_argument('--comment', required=True, help='run identifier') 52 | ap.add_argument('--gpus', default=[0], type=int, nargs='+', help='gpus to train on') 53 | ap.add_argument('--model', default='comomunit', type=str, help='Choose model for training') 54 | ap.add_argument('--dataset', default='anchor', type=str, help='Module name of the dataset importer') 55 | ap.add_argument('--learning_rate', default=0.0001, type=float, help='Learning rate') 56 | ap.add_argument('--scheduler_policy', default='step', type=str, help='Scheduler policy') 57 | ap.add_argument('--decay_iters_step', default=200000, type=int, help='Decay iterations step') 58 | ap.add_argument('--decay_step_gamma', default=0.5, type=float, help='Decay step gamma') 59 | ap.add_argument('--seed', default=2, type=int, help='Random seed') 60 | ap.add_argument('--mixed_precision', default=False, action='store_true', help='Use mixed precision to reduce memory usage') 61 | start(ap.parse_args()) 62 | 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # ManiFest: Manifold Deformation for Few-shot Image Translation 2 | 3 | [ManiFest: Manifold Deformation for Few-shot Image Translation](https://arxiv.org/abs/2111.13681) 4 | Fabio Pizzati, Jean-François Lalonde, Raoul de Charette 5 | 6 | ECCV 2022 7 | 8 | ## Preview 9 | 10 | ![teaser](teaser.png) 11 | 12 | ## Citation 13 | To cite our paper, please use 14 | ``` 15 | @inproceedings{pizzati2021manifest, 16 | title={{ManiFest: Manifold Deformation for Few-shot Image Translation}}, 17 | author={Pizzati, Fabio and Lalonde, Jean-François and de Charette, Raoul}, 18 | booktitle={ECCV}, 19 | year={2022} 20 | } 21 | ``` 22 | 23 | ## Prerequisites 24 | Please create an environment using the `requirements.yml` file provided. 25 | 26 | ```conda env create -f requirements.yml``` 27 | 28 | Download the pretrained models and the pretrained VGG used for the style alignment loss by following the link: 29 | 30 | ``` 31 | https://www.rocq.inria.fr/rits_files/computer-vision/manifest/manifest_checkpoints.tar.gz 32 | ``` 33 | 34 | Move the VGG network weights in the `res` folder and the checkpoints in the `checkpoints` one. 
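To quickly check that a downloaded checkpoint is readable, its stored hyper-parameters can be inspected with a few lines of Python mirroring the loading code in `inference_general.py` (the checkpoint filename below is only a placeholder):

```
import torch
from munch import Munch

ckpt = torch.load('checkpoints/day2night.pth', map_location='cpu')  # placeholder path
opt = Munch(ckpt['hyper_parameters']).opt
print(opt.model, opt.seed)
```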
35 | 36 | ## Inference 37 | We provide pretrained models for the day2night, day2twilight and clear2fog tasks as described in the paper. 38 | 39 | To perform `general` inference using the pretrained model, please run the following command: 40 | 41 | ``` 42 | python inference_general.py --input_dir --output_dir --checkpoint 43 | ``` 44 | 45 | To perform `exemplar` inference, please use 46 | 47 | ``` 48 | python inference_exemplar.py --input_dir --output_dir --checkpoint --exemplar_image 49 | ``` 50 | 51 | ## Training 52 | 53 | We provide training code for all three tasks. 54 | 55 | Download the [ACDC](https://acdc.vision.ee.ethz.ch/), 56 | [VIPER](https://playing-for-benchmarks.org/) and 57 | [Dark Zurich](https://www.trace.ethz.ch/publications/2019/GCMA_UIoU/) datasets. 58 | 59 | Then, run the scripts provided in the `datasets' directory to create symbolic links. 60 | 61 | ``` 62 | python create_dataset.py --root_acdc --root_viper --root_dz 63 | ``` 64 | 65 | To start training, modify the `data/anchor_dataset.py` file and choose among `day2night`, `day2twilight` 66 | or `clear2fog` in the `root` option. Finally, start the training with 67 | 68 | ``` 69 | python train.py --comment "review training" --model fsmunit --dataset anchor 70 | ``` 71 | 72 | If you don't have a WANDB api key, please run 73 | 74 | ``` 75 | WANDB_MODE=offline python train.py --comment "review training" --model fsmunit --dataset anchor 76 | ``` 77 | 78 | ## Code structure 79 | When extending the code, please consider the following structure. The `train.py` file intializes logging utilities and set up callbacks for model saving and debug. The main training logic 80 | is in `networks/fsmunit_model.py`. In `networks/backbones/fsmunit.py` it's possible to find the architectural components. 81 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | __init__.py 3 | Enables dynamic loading of datasets, depending on an argument. 4 | """ 5 | import importlib 6 | import torch.utils.data 7 | from data.base_dataset import BaseDataset 8 | 9 | 10 | def find_dataset_using_name(dataset_name): 11 | """Import the module "data/[dataset_name]_dataset.py". 12 | 13 | In the file, the class called DatasetNameDataset() will 14 | be instantiated. It has to be a subclass of BaseDataset, 15 | and it is case-insensitive. 16 | """ 17 | dataset_filename = "data." + dataset_name + "_dataset" 18 | datasetlib = importlib.import_module(dataset_filename) 19 | 20 | dataset = None 21 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 22 | for name, cls in datasetlib.__dict__.items(): 23 | if name.lower() == target_dataset_name.lower() \ 24 | and issubclass(cls, BaseDataset): 25 | dataset = cls 26 | 27 | if dataset is None: 28 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 29 | 30 | return dataset 31 | 32 | 33 | def create_dataset(opt): 34 | """Create a dataset given the option. 35 | 36 | This function wraps the class CustomDatasetDataLoader. 
37 | This is the main interface between this package and 'train.py'/'test.py' 38 | 39 | Example: 40 | >>> from data import create_dataset 41 | >>> dataset = create_dataset(opt) 42 | """ 43 | data_loader = CustomDatasetDataLoader(opt) 44 | dataset = data_loader.load_data() 45 | return dataset 46 | 47 | def get_dataset_options(dataset_name): 48 | dataset_filename = "data." + dataset_name + "_dataset" 49 | datalib = importlib.import_module(dataset_filename) 50 | for name, cls in datalib.__dict__.items(): 51 | if name.lower() == 'datasetoptions': 52 | return cls 53 | return None 54 | 55 | class CustomDatasetDataLoader(): 56 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 57 | 58 | def __init__(self, opt): 59 | """Initialize this class 60 | 61 | Step 1: create a dataset instance given the name [dataset_mode] 62 | Step 2: create a multi-threaded data loader. 63 | """ 64 | self.opt = opt 65 | dataset_class = find_dataset_using_name(opt.dataset_mode) 66 | self.dataset = dataset_class(opt) 67 | self.dataloader = torch.utils.data.DataLoader( 68 | self.dataset, 69 | batch_size=opt.batch_size, 70 | shuffle=not opt.serial_batches, 71 | num_workers=int(opt.num_threads)) 72 | 73 | def load_data(self): 74 | return self 75 | 76 | def __len__(self): 77 | """Return the number of data in the dataset""" 78 | return min(len(self.dataset), self.opt.max_dataset_size) 79 | 80 | def __iter__(self): 81 | """Return a batch of data""" 82 | for i, data in enumerate(self.dataloader): 83 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 84 | break 85 | yield data 86 | -------------------------------------------------------------------------------- /dataset/create_dataset.py: -------------------------------------------------------------------------------- 1 | # This script creates datasets. We'll link every image in the target domain for ACDC experiemnts, and select only 25 2 | # by using the dataloader. This was necessary to run ablation studies on seeds. 
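# Resulting layout for each task (reconstructed from create_ds below, shown for orientation):
#   <name>/trainS : source images      (ACDC '*_ref' frames, or Dark Zurich 'train/day')
#   <name>/trainT : few-shot target    (ACDC target-condition frames, or the Dark Zurich
#                                       twilight frames listed in list_dz_twilight.txt)
#   <name>/trainA : synthetic anchor   (VIPER sequences matching the codes defined below)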
3 | 4 | import os 5 | from argparse import ArgumentParser as AP 6 | 7 | # Viper sequence codes 8 | day_codes = ['001', '002', '003', '004', '005', 9 | '006', '044', '045', '046', '047', 10 | '048', '049', '050', '051', '065', 11 | '066', '067', '068', '069', ] 12 | sunset_codes = ['007', '014', '015', '016', '017', '018', 13 | '019', '020', '021', '022', '023', '024', 14 | '025', '026', '027', '028', '029'] 15 | night_codes = ['008', '009', '010', '011', '012', '013', 16 | '052', '053', '054', '055', '056', '057', 17 | '058', '070', '071', '072', '073', '074', 18 | '075', '076', '077', ] 19 | rain_codes = ['030', '031', '059', '060', '061', '062', '063', '064'] 20 | snow_codes = ['032', '033', '034', '035', '036', '037', 21 | '038', '039', '040', '041', '042', '043'] 22 | 23 | def get_base_acdc(root, condition, subset): 24 | base = os.path.join(root, 'rgb_anon', condition) 25 | base_source = os.path.join(base, '{}_ref'.format(subset)) 26 | base_target = os.path.join(base, '{}'.format(subset)) 27 | return base, base_source, base_target 28 | 29 | def create_ds(root, viper_root, name, subset, condition, anchor, type='acdc'): 30 | os.makedirs(name, exist_ok=True) 31 | os.makedirs('{}/{}S'.format(name, subset), exist_ok=True) 32 | os.makedirs('{}/{}T'.format(name, subset), exist_ok=True) 33 | os.makedirs('{}/{}A'.format(name, subset), exist_ok=True) 34 | if type == 'acdc': 35 | base, base_source, base_target = get_base_acdc(root, condition, subset) 36 | link(base_source, '{}/{}S'.format(name, subset)) 37 | link(base_target, '{}/{}T'.format(name, subset)) 38 | elif type == 'dz': 39 | with open('list_dz_twilight.txt') as file: 40 | dz_gt_set = file.read().splitlines() 41 | base_source = os.path.join(root, 'rgb_anon', 'train', 'day') 42 | link(base_source, '{}/{}S'.format(name, subset)) 43 | base_target = os.path.join(root, 'rgb_anon', 'train', 'twilight', 'GOPR0348') 44 | files_target = [os.path.join(base_target, x) for x in dz_gt_set] 45 | for f in files_target: 46 | os.symlink(f, '{}/{}T/{}'.format(name, subset, os.path.basename(f))) 47 | 48 | viper_root = os.path.join(viper_root, '{}/img'.format(subset)) 49 | print(anchor) 50 | if anchor == 'day': 51 | seqs = day_codes 52 | elif anchor == 'night': 53 | seqs = night_codes 54 | elif anchor == 'rain': 55 | seqs = rain_codes 56 | elif anchor == 'sunset': 57 | seqs = sunset_codes 58 | elif anchor == 'snow': 59 | seqs = snow_codes 60 | else: 61 | raise('Wrong anchor!') 62 | 63 | # create synthetic anchor 64 | for x in os.listdir(viper_root): 65 | dir_path = os.path.join(viper_root, x) 66 | if x not in seqs: 67 | print("{} skipped".format(x)) 68 | continue 69 | for y in os.listdir(dir_path): 70 | os.symlink(os.path.join(dir_path, y), os.path.join(name, '{}A'.format(subset), y)) 71 | 72 | 73 | def link(directory, output_dir): 74 | for x in os.listdir(directory): 75 | dir_path = os.path.join(directory, x) 76 | if os.path.isdir(dir_path): 77 | for y in os.listdir(dir_path): 78 | os.symlink(os.path.join(dir_path, y), os.path.join(output_dir, y)) 79 | 80 | if __name__ == '__main__': 81 | ap = AP() 82 | ap.add_argument('--root_acdc', help='ACDC root', required=True, type=str) 83 | ap.add_argument('--root_dz', help='DZ root', required=True, type=str) 84 | ap.add_argument('--root_viper', help='Viper root', required=True, type=str) 85 | opt = ap.parse_args() 86 | root_acdc = opt.root_acdc 87 | root_viper = opt.root_viper 88 | 89 | create_ds(opt.root_acdc, opt.root_viper, 'day2night', 'train', 'night', 'night') 90 | create_ds(opt.root_acdc, opt.root_viper, 
'clear2fog', 'train', 'fog', 'day') 91 | create_ds(opt.root_dz, opt.root_viper, 'day2twilight', 'train', None, 'night', type='dz') 92 | 93 | -------------------------------------------------------------------------------- /data/anchor_dataset.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | from data.base_dataset import BaseDataset, get_transform 3 | from data.image_folder import make_dataset 4 | from PIL import Image 5 | import random 6 | import pytorch_lightning as pl 7 | import munch 8 | 9 | def DatasetOptions(): 10 | do = munch.Munch() 11 | do.serial_batches = False 12 | do.num_threads = 8 13 | do.batch_size = 1 14 | do.load_size = 480 15 | do.dataroot = 'dataset/day2twilight' 16 | do.crop_size = 0 17 | do.max_dataset_size = float('inf') 18 | do.preprocess = 'none' 19 | do.no_flip = False 20 | do.num_images = 25 21 | do.num_classes_A = 1 22 | do.num_classes_B = 2 23 | do.size_resize = (480, 256) 24 | 25 | return do 26 | 27 | class AnchorDataset(BaseDataset): 28 | 29 | 30 | def __init__(self, opt): 31 | BaseDataset.__init__(self, opt) 32 | self.dir_A_id = os.path.join(opt.dataroot, opt.phase + 'S') # create a path '/path/to/data/trainA' 33 | self.dir_A_m = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainB' 34 | self.dir_T = os.path.join(opt.dataroot, opt.phase + 'T') # create a path '/path/to/data/trainB' 35 | 36 | self.S_paths = sorted(make_dataset(self.dir_A_id, opt.max_dataset_size)) # load images from '/path/to/data/trainA' 37 | self.A_m_paths = sorted(make_dataset(self.dir_A_m, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 38 | self.T_paths = sorted(make_dataset(self.dir_T, opt.max_dataset_size)) # load images from '/path/to/data/trainB' 39 | 40 | self.S_size = len(self.S_paths) # get the size of dataset A 41 | self.A_m_size = len(self.A_m_paths) # get the size of dataset B 42 | self.GT_size = len(self.T_paths) # get the size of dataset B 43 | 44 | random.shuffle(self.S_paths) 45 | random.shuffle(self.A_m_paths) 46 | random.shuffle(self.T_paths) 47 | 48 | input_nc = self.opt.input_nc # get the number of channels of input image 49 | output_nc = self.opt.output_nc # get the number of channels of output image 50 | self.transform_A_id = get_transform(self.opt, grayscale=(input_nc == 1)) 51 | self.transform_A_m = get_transform(self.opt, grayscale=(output_nc == 1)) 52 | self.transform_T = get_transform(self.opt, grayscale=(output_nc == 1)) 53 | 54 | def __getitem__(self, index): 55 | """Return a data point and its metadata information. 
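        The anchor image A is sampled from one of two classes drawn at random:
        class 0 picks a synthetic VIPER anchor (trainA), class 1 picks another
        source image (trainS). The few-shot target T is always drawn from the
        first opt.num_images entries of the shuffled target list, so only
        num_images real target images (25 by default) are ever used, matching
        the note in dataset/create_dataset.py.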
56 | 57 | Parameters: 58 | index (int) -- a random integer for data indexing 59 | 60 | Returns a dictionary that contains A, B, S_paths and A_m_paths 61 | A (tensor) -- an image in the input domain 62 | B (tensor) -- its corresponding image in the target domain 63 | S_paths (str) -- image paths 64 | A_m_paths (str) -- image paths 65 | """ 66 | S_path = self.S_paths[index % self.S_size] # make sure index is within then range 67 | index_T = random.randint(0, self.opt.num_images - 1) 68 | class_anchor = random.randint(0, self.opt.num_classes_B - 1) # Multi-class support 69 | if class_anchor == 0: 70 | index_A = random.randint(0, len(self.A_m_paths) - 1) 71 | A_path = self.A_m_paths[index_A] 72 | elif class_anchor == 1: 73 | index_A = random.randint(0, len(self.S_paths) - 1) 74 | A_path = self.S_paths[index_A] 75 | 76 | T_path = self.T_paths[index_T] 77 | 78 | S_img = Image.open(S_path).convert('RGB').resize(self.opt.size_resize, Image.BILINEAR) 79 | A_img = Image.open(A_path).convert('RGB').resize(self.opt.size_resize, Image.BILINEAR) 80 | T_img = Image.open(T_path).convert('RGB').resize(self.opt.size_resize, Image.BILINEAR) 81 | 82 | S = self.transform_A_id(S_img) 83 | T = self.transform_T(T_img) 84 | 85 | A = self.transform_A_m(A_img) 86 | 87 | return {'S': S, 'A': A, 'class_anchor': class_anchor, 'T': T, 88 | 'S_paths': S_path, 'A_m_paths': A_path, 'T_paths': T_path, 'S_class': 0} 89 | 90 | def __len__(self): 91 | """Return the total number of images in the dataset. 92 | 93 | As we have two datasets with potentially different number of images, 94 | we take a maximum of 95 | """ 96 | return max(self.S_size, self.A_m_size) 97 | -------------------------------------------------------------------------------- /networks/backbones/functions.py: -------------------------------------------------------------------------------- 1 | """ 2 | functions.py 3 | Here we get helper functions to 1) get schedulers given an option 2) initialize the network weights. 4 | """ 5 | 6 | import torch 7 | from torch.nn import init 8 | from torch.optim import lr_scheduler 9 | 10 | ############################################################################### 11 | # Helper Functions 12 | ############################################################################### 13 | 14 | def get_scheduler(optimizer, opt): 15 | """Return a learning rate scheduler 16 | 17 | Parameters: 18 | optimizer -- the optimizer of the network 19 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  20 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 21 | 22 | For 'linear', we keep the same learning rate for the first epochs 23 | and linearly decay the rate to zero over the next epochs. 24 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 25 | See https://pytorch.org/docs/stable/optim.html for more details. 
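    Example (a minimal sketch with the 'step' policy used by default in train.py;
    the optimizer construction is illustrative and `net`/`opt` are assumed to be
    in scope, not necessarily what the model builds internally):

        optimizer = torch.optim.Adam(net.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
        scheduler = get_scheduler(optimizer, opt)  # opt.lr_policy == 'step'
        scheduler.step()  # multiplies the lr by 0.1 every opt.decay_iters_step scheduler steps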
26 | """ 27 | if opt.lr_policy == 'linear': 28 | def lambda_rule(iteration): 29 | lr_l = 1.0 - max(0, logger.get_global_step() - opt.static_iters) / float(opt.decay_iters + 1) 30 | return lr_l 31 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 32 | elif opt.lr_policy == 'step': 33 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.decay_iters_step, gamma=0.1) 34 | elif opt.lr_policy == 'plateau': 35 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 36 | elif opt.lr_policy == 'cosine': 37 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 38 | else: 39 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 40 | return scheduler 41 | 42 | 43 | def init_weights(net, init_type='normal', init_gain=0.02): 44 | """Initialize network weights. 45 | 46 | Parameters: 47 | net (network) -- network to be initialized 48 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 49 | init_gain (float) -- scaling factor for normal, xavier and orthogonal. 50 | 51 | We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might 52 | work better for some applications. Feel free to try yourself. 53 | """ 54 | def init_func(m): # define the initialization function 55 | classname = m.__class__.__name__ 56 | if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 57 | if init_type == 'normal': 58 | init.normal_(m.weight.data, 0.0, init_gain) 59 | elif init_type == 'xavier': 60 | init.xavier_normal_(m.weight.data, gain=init_gain) 61 | elif init_type == 'kaiming': 62 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 63 | elif init_type == 'orthogonal': 64 | init.orthogonal_(m.weight.data, gain=init_gain) 65 | else: 66 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 67 | if hasattr(m, 'bias') and m.bias is not None: 68 | init.constant_(m.bias.data, 0.0) 69 | elif classname.find('BatchNorm2d') != -1: # BatchNorm Layer's weight is not a matrix; only normal distribution applies. 70 | init.normal_(m.weight.data, 1.0, init_gain) 71 | init.constant_(m.bias.data, 0.0) 72 | 73 | net.apply(init_func) # apply the initialization function 74 | 75 | 76 | def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]): 77 | """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights 78 | Parameters: 79 | net (network) -- the network to be initialized 80 | init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal 81 | gain (float) -- scaling factor for normal, xavier and orthogonal. 82 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 83 | 84 | Return an initialized network. 85 | """ 86 | init_weights(net, init_type, init_gain=init_gain) 87 | return net 88 | -------------------------------------------------------------------------------- /networks/base_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_model.py 3 | Abstract definition of a model, where helper functions as image extraction and gradient propagation are defined. 
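Subclasses declare `loss_names` and `visual_names`; each entry must match a
`loss_<name>` scalar attribute or an image tensor attribute on the model, which
`log_current_losses` and `log_current_visuals` read when invoked by the
LogAndCheckpointEveryNSteps callback (util/callbacks.py) at the frequencies set
in options/log_options.py.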
4 | """ 5 | 6 | from collections import OrderedDict 7 | from abc import abstractmethod 8 | 9 | import pytorch_lightning as pl 10 | from torch.optim import lr_scheduler 11 | 12 | from torchvision.transforms import ToPILImage 13 | import wandb 14 | class BaseModel(pl.LightningModule): 15 | 16 | def __init__(self, opt): 17 | super().__init__() 18 | #self.hparams = opt 19 | self.opt = opt 20 | self.gpu_ids = opt.gpu_ids 21 | self.loss_names = [] 22 | self.model_names = [] 23 | self.visual_names = [] 24 | self.image_paths = [] 25 | self.save_hyperparameters() 26 | self.schedulers = [] 27 | self.metric = 0 # used for learning rate policy 'plateau' 28 | 29 | @abstractmethod 30 | def set_input(self, input): 31 | pass 32 | 33 | def eval(self): 34 | for name in self.model_names: 35 | if isinstance(name, str): 36 | net = getattr(self, 'net' + name) 37 | net.eval() 38 | 39 | def compute_visuals(self): 40 | pass 41 | 42 | def get_image_paths(self): 43 | return self.image_paths 44 | 45 | def update_learning_rate(self): 46 | for scheduler in self.schedulers: 47 | if self.opt.lr_policy == 'plateau': 48 | scheduler.step(self.metric) 49 | else: 50 | scheduler.step() 51 | 52 | lr = self.optimizers[0].param_groups[0]['lr'] 53 | return lr 54 | 55 | def get_current_visuals(self): 56 | visual_ret = OrderedDict() 57 | for name in self.visual_names: 58 | if isinstance(name, str): 59 | visual_ret[name] = (getattr(self, name).detach() + 1) / 2 60 | return visual_ret 61 | def log_current_losses(self): 62 | losses = '\n' 63 | for name in self.loss_names: 64 | if isinstance(name, str): 65 | loss_value = float(getattr(self, 'loss_' + name)) 66 | self.logger.experiment.log({'loss_{}'.format(name): loss_value}, self.trainer.global_step) 67 | losses += 'loss_{}={:.4f}\t'.format(name, loss_value) 68 | print(losses) 69 | 70 | def log_current_visuals(self): 71 | visuals = self.get_current_visuals() 72 | for key, viz in visuals.items(): 73 | img = ToPILImage()(viz[0].cpu()) 74 | self.logger.experiment.log({'img_{}'.format(key): wandb.Image(img)}, self.trainer.global_step) 75 | 76 | def get_scheduler(self, opt, optimizer): 77 | if opt.lr_policy == 'linear': 78 | def lambda_rule(iter): 79 | lr_l = 1.0 - max(0, self.trainer.global_step - opt.static_iters) / float(opt.decay_iters + 1) 80 | return lr_l 81 | 82 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 83 | elif opt.lr_policy == 'step': 84 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.decay_iters_step, gamma=0.5) 85 | elif opt.lr_policy == 'plateau': 86 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 87 | elif opt.lr_policy == 'cosine': 88 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 89 | else: 90 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 91 | return scheduler 92 | 93 | def print_networks(self): 94 | for name in self.model_names: 95 | if isinstance(name, str): 96 | net = getattr(self, 'net' + name) 97 | num_params = 0 98 | for param in net.parameters(): 99 | num_params += param.numel() 100 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 101 | 102 | def get_optimizer_dict(self): 103 | return_dict = {} 104 | for index, opt in enumerate(self.optimizers): 105 | return_dict['Optimizer_{}'.format(index)] = opt 106 | return return_dict 107 | 108 | def set_requires_grad(self, nets, requires_grad=False): 109 | if not isinstance(nets, list): 110 | nets = [nets] 111 
| for net in nets: 112 | if net is not None: 113 | for param in net.parameters(): 114 | param.requires_grad = requires_grad 115 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | base_dataset.py: 3 | All datasets are a subclass of BaseDataset and implement abstract methods. 4 | Includes augmentation strategies which can be used at sampling time. 5 | """ 6 | import random 7 | import numpy as np 8 | import torch.utils.data as data 9 | from PIL import Image 10 | import torchvision.transforms as transforms 11 | from abc import ABC, abstractmethod 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | """ 21 | 22 | def __init__(self, opt): 23 | """Initialize the class; save the options in the class 24 | 25 | Parameters: 26 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 27 | """ 28 | self.opt = opt 29 | self.root = opt.dataroot 30 | 31 | @abstractmethod 32 | def __len__(self): 33 | """Return the total number of images in the dataset.""" 34 | return 0 35 | 36 | @abstractmethod 37 | def __getitem__(self, index): 38 | """Return a data point and its metadata information. 39 | 40 | Parameters: 41 | index - - a random integer for data indexing 42 | 43 | Returns: 44 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
45 | """ 46 | pass 47 | 48 | 49 | def get_params(opt, size): 50 | w, h = size 51 | new_h = h 52 | new_w = w 53 | if opt.preprocess == 'resize_and_crop': 54 | new_h = new_w = opt.load_size 55 | elif opt.preprocess == 'scale_width_and_crop': 56 | new_w = opt.load_size 57 | new_h = opt.load_size * h // w 58 | 59 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 60 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 61 | flip = random.random() > 0.5 62 | 63 | return {'crop_pos': (x, y), 'flip': flip} 64 | 65 | 66 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 67 | transform_list = [] 68 | if grayscale: 69 | transform_list.append(transforms.Grayscale(1)) 70 | if 'resize' in opt.preprocess: 71 | osize = [opt.load_size, opt.load_size] 72 | transform_list.append(transforms.Resize(osize, method)) 73 | elif 'scale_width' in opt.preprocess: 74 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 75 | 76 | if 'crop' in opt.preprocess: 77 | if params is None: 78 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 79 | else: 80 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 81 | 82 | if opt.preprocess == 'none': 83 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 84 | 85 | if not opt.no_flip: 86 | if params is None: 87 | transform_list.append(transforms.RandomHorizontalFlip()) 88 | elif params['flip']: 89 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 90 | 91 | if convert: 92 | transform_list += [transforms.ToTensor()] 93 | if grayscale: 94 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 95 | else: 96 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 97 | return transforms.Compose(transform_list) 98 | 99 | def get_masked_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 100 | transform_list = [] 101 | if grayscale: 102 | transform_list.append(transforms.Grayscale(1)) 103 | if 'resize' in opt.preprocess: 104 | osize = [opt.load_size, opt.load_size] 105 | transform_list.append(transforms.Resize(osize, method)) 106 | elif 'scale_width' in opt.preprocess: 107 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 108 | 109 | if 'crop' in opt.preprocess: 110 | if params is None: 111 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 112 | else: 113 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 114 | 115 | if opt.preprocess == 'none': 116 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 117 | 118 | if not opt.no_flip: 119 | if params is None: 120 | transform_list.append(transforms.RandomHorizontalFlip()) 121 | elif params['flip']: 122 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 123 | 124 | if convert: 125 | transform_list += [transforms.ToTensor()] 126 | if grayscale: 127 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 128 | else: 129 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 130 | return transforms.Compose(transform_list) 131 | 132 | 133 | def __make_power_2(img, base, method=Image.BICUBIC): 134 | ow, oh = img.size 135 | h = int(round(oh / base) * base) 136 | w = int(round(ow / base) * base) 137 | if h == oh 
and w == ow: 138 | return img 139 | 140 | __print_size_warning(ow, oh, w, h) 141 | return img.resize((w, h), method) 142 | 143 | 144 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 145 | ow, oh = img.size 146 | if ow == target_size and oh >= crop_size: 147 | return img 148 | w = target_size 149 | h = int(max(target_size * oh / ow, crop_size)) 150 | return img.resize((w, h), method) 151 | 152 | 153 | def __crop(img, pos, size): 154 | ow, oh = img.size 155 | x1, y1 = pos 156 | tw = th = size 157 | if (ow > tw or oh > th): 158 | return img.crop((x1, y1, x1 + tw, y1 + th)) 159 | return img 160 | 161 | 162 | def __flip(img, flip): 163 | if flip: 164 | return img.transpose(Image.FLIP_LEFT_RIGHT) 165 | return img 166 | 167 | 168 | def __print_size_warning(ow, oh, w, h): 169 | """Print warning information about image size(only print once)""" 170 | if not hasattr(__print_size_warning, 'has_printed'): 171 | logger.warning("The image size needs to be a multiple of 4. " 172 | "The loaded image size was (%d, %d), so it was adjusted to " 173 | "(%d, %d). This adjustment will be done to all images " 174 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 175 | __print_size_warning.has_printed = True 176 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2022 Vislab/Ambarella 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /networks/fsmunit_model.py: -------------------------------------------------------------------------------- 1 | ''' 2 | In the provided file, 3 | Source images are identified by "A" 4 | Anchor images are identified by "B" 5 | Few-shot target by "GT" 6 | 7 | We exploit the additional loss described in the supp, called loss_vgg_gan 8 | L_style is called loss_vgg_style 9 | ''' 10 | 11 | import torch 12 | import itertools 13 | from .base_model import BaseModel 14 | from .backbones import fsmunit as networks 15 | import munch 16 | import random 17 | import kornia.augmentation as K 18 | 19 | def ModelOptions(): 20 | mo = munch.Munch() 21 | # Generator 22 | mo.gen_dim = 64 23 | mo.style_dim = 8 24 | mo.gen_activ = 'relu' 25 | mo.n_downsample = 2 26 | mo.n_res = 4 27 | mo.gen_pad_type = 'reflect' 28 | mo.mlp_dim = 256 29 | 30 | # Discriminiator 31 | mo.disc_dim = 64 32 | mo.disc_norm = 'none' 33 | mo.disc_activ = 'lrelu' 34 | mo.disc_n_layer = 4 35 | mo.num_scales = 3 36 | mo.disc_pad_type = 'reflect' 37 | 38 | # Initialization 39 | mo.init_type_gen = 'kaiming' 40 | mo.init_type_disc = 'normal' 41 | mo.init_gain = 0.02 42 | 43 | # Weights 44 | mo.lambda_gan = 1 45 | mo.lambda_gan_patches = 1 46 | mo.lambda_rec_image = 10 47 | mo.lambda_rec_style = 1 48 | mo.lambda_rec_content = 1 49 | mo.lambda_rec_cycle = 10 50 | mo.lambda_vgg = 0.5 51 | 52 | mo.lambda_vgg_style = 2e-4 53 | mo.lambda_vgg_fs_res = 0.5 54 | mo.lambda_vgg_fs = 0.5 55 | return mo 56 | 57 | class FSMunitModel(BaseModel): 58 | 59 | def __init__(self, opt): 60 | BaseModel.__init__(self, opt) 61 | # specify the training losses you want to print out. The training/test scripts will call 62 | self.loss_names = ['D_A', 'G_A', 'cycle_A', 'rec_A', 'rec_style_A', 'rec_content_A', 'vgg_A','G_patches', 'D_patches', 63 | 'D_B', 'G_B', 'cycle_B', 'rec_B', 'rec_style_B', 'rec_content_B', 'vgg_B', 'vgg_style'] 64 | # specify the images you want to save/display. The training/test scripts will call 65 | visual_names_A = ['real_A', 'fake_B', 'fake_B_weights', 'fake_B_residual', 'fake_B_style', 'rec_A_img', 'rec_A_cycle'] 66 | visual_names_B = ['real_B', 'fake_A', 'rec_B_img', 'rec_B_cycle', 'real_GT'] 67 | 68 | self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B 69 | # specify the models you want to save to the disk. The training/test scripts will call and . 
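        # Note: only the networks listed below are saved to disk; the patch
        # discriminator netD_patches defined further down is not part of this list.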
70 | self.model_names = ['G_A', 'G_B', 'D_A', 'D_B'] 71 | 72 | 73 | self.netG_A = networks.define_G_munit(opt.input_nc, opt.output_nc, opt.gen_dim, opt.style_dim, opt.n_downsample, 74 | opt.n_res, opt.gen_pad_type, opt.mlp_dim, opt.gen_activ, opt.init_type_gen, 75 | opt.init_gain, self.gpu_ids, num_classes=self.opt.num_classes_A) 76 | self.netG_B = networks.define_G_munit(opt.input_nc, opt.output_nc, opt.gen_dim, opt.style_dim, opt.n_downsample, 77 | opt.n_res, opt.gen_pad_type, opt.mlp_dim, opt.gen_activ, opt.init_type_gen, 78 | opt.init_gain, self.gpu_ids, num_classes=self.opt.num_classes_B) 79 | 80 | self.netD_A = networks.define_D_munit(opt.output_nc, opt.disc_dim, opt.disc_norm, opt.disc_activ, opt.disc_n_layer, 81 | opt.gan_mode, opt.num_scales, opt.disc_pad_type, opt.init_type_disc, 82 | opt.init_gain, self.gpu_ids, num_classes=self.opt.num_classes_B) 83 | self.netD_B = networks.define_D_munit(opt.output_nc, opt.disc_dim, opt.disc_norm, opt.disc_activ, opt.disc_n_layer, 84 | opt.gan_mode, opt.num_scales, opt.disc_pad_type, opt.init_type_disc, 85 | opt.init_gain, self.gpu_ids, num_classes=self.opt.num_classes_A) 86 | 87 | # Patch-based discriminator 88 | self.netD_patches = networks.define_D_munit(opt.output_nc, opt.disc_dim, opt.disc_norm, opt.disc_activ, opt.disc_n_layer, 89 | opt.gan_mode, 2, opt.disc_pad_type, opt.init_type_disc, 90 | opt.init_gain, self.gpu_ids, num_classes=1) 91 | 92 | self.num_classes_B = self.opt.num_classes_B 93 | self.num_classes_A = self.opt.num_classes_A 94 | 95 | if opt.lambda_vgg > 0: 96 | self.instance_norm = torch.nn.InstanceNorm2d(512) 97 | self.vgg = networks.Vgg16() 98 | # TODO: pass pretrained weights path as argument 99 | self.vgg.load_state_dict(torch.load('checkpoints/ex-i2iwand/res/vgg_imagenet.pth')) 100 | self.vgg.to(self.device) 101 | self.vgg.eval() 102 | for param in self.vgg.parameters(): 103 | param.requires_grad = False 104 | self.weights = torch.nn.Parameter(torch.zeros(self.num_classes_B).cuda()) 105 | 106 | self.augmentations_patches = torch.nn.Sequential( 107 | K.RandomHorizontalFlip(), 108 | K.RandomAffine(360, p=1), 109 | K.RandomCrop(size=(64, 64)) 110 | ) 111 | 112 | self.alternate = False 113 | 114 | def configure_optimizers(self): 115 | opt_G = torch.optim.Adam([{'params': self.netG_A.parameters()}, {'params': self.netG_B.parameters()}, {'params': [self.weights], 'lr': 0.01}], 116 | weight_decay=0.0001, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 117 | opt_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters(), self.netD_patches.parameters()), 118 | weight_decay=0.0001, lr=self.opt.lr, betas=(self.opt.beta1, 0.999)) 119 | scheduler_G = self.get_scheduler(self.opt, opt_G) 120 | scheduler_D = self.get_scheduler(self.opt, opt_D) 121 | return [opt_D, opt_G], [scheduler_D, scheduler_G] 122 | 123 | def reconCriterion(self, input, target): 124 | return torch.mean(torch.abs(input - target)) 125 | 126 | def set_input(self, input): 127 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 128 | 129 | Parameters: 130 | input (dict): include the data itself and its metadata information. 131 | 132 | The option 'direction' can be used to swap domain A and domain B. 
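
        Keys consumed below: 'S' (source image), 'A' (anchor image), 'T' (the
        few-shot target), 'class_anchor' and 'S_class' (anchor and source class
        indices), and 'S_paths' (paths of the source images).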
133 | """ 134 | self.real_A = input['S'] 135 | self.real_B = input['A'] 136 | self.class_B = input['class_anchor'] 137 | self.class_A = input['S_class'] 138 | self.real_GT = input['T'] 139 | self.image_paths = input['S_paths'] 140 | self.random_class_B = random.randint(0, self.num_classes_B - 1) 141 | self.random_class_A = random.randint(0, self.num_classes_A - 1) 142 | 143 | def __vgg_preprocess(self, batch): 144 | tensortype = type(batch) 145 | (r, g, b) = torch.chunk(batch, 3, dim=1) 146 | batch = torch.cat((b, g, r), dim=1) # convert RGB to BGR 147 | batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] 148 | mean = tensortype(batch.data.size()).to(self.device) 149 | mean[:, 0, :, :] = 103.939 150 | mean[:, 1, :, :] = 116.779 151 | mean[:, 2, :, :] = 123.680 152 | batch = batch.sub(mean) # subtract mean 153 | return batch 154 | 155 | def __calc_mean_std(self, feat, eps=1e-5): 156 | # eps is a small value added to the variance to avoid divide-by-zero. 157 | size = feat.size() 158 | assert (len(size) == 4) 159 | N, C = size[:2] 160 | feat_var = feat.view(N, C, -1).var(dim=2) + eps 161 | feat_std = feat_var.sqrt().view(N, C, 1, 1) 162 | feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1) 163 | return feat_mean, feat_std 164 | 165 | def __compute_vgg_loss_onlyperceptual(self, img, target): 166 | img_vgg = self.__vgg_preprocess(img) 167 | target_vgg = self.__vgg_preprocess(target) 168 | img_fea = self.vgg(img_vgg) 169 | target_fea = self.vgg(target_vgg) 170 | return torch.mean((self.instance_norm(img_fea) - self.instance_norm(target_fea)) ** 2) 171 | 172 | def __get_vgg_style_loss(self, f1, f2): 173 | style_loss = 0 174 | for x, y in zip(f1, f2): 175 | x_mean, x_std = self.__calc_mean_std(x) 176 | y_mean, y_std = self.__calc_mean_std(y) 177 | style_loss += torch.mean((x_mean - y_mean) ** 2) + torch.mean((x_std - y_std) ** 2) 178 | return style_loss 179 | 180 | def get_fewshot_style_code(self, im): 181 | _, f = self.vgg(im, with_style=True) 182 | means, stds = [],[] 183 | for x in f: 184 | x_mean, x_std = self.__calc_mean_std(x) 185 | means.append(x_mean.squeeze()) 186 | stds.append(x_std.squeeze()) 187 | style_feat = torch.cat(means + stds, 0) 188 | return style_feat.unsqueeze(0).unsqueeze(-1).unsqueeze(-1) 189 | 190 | def __compute_vgg_loss(self, img, target, style): 191 | img_vgg = self.__vgg_preprocess(img) 192 | target_vgg = self.__vgg_preprocess(target) 193 | style_vgg = self.__vgg_preprocess(style) 194 | img_fea, style_feat_img = self.vgg(img_vgg, with_style=True) 195 | target_fea = self.vgg(target_vgg) 196 | _, style_feat_style = self.vgg(style_vgg, with_style=True) 197 | content_loss = torch.mean((self.instance_norm(img_fea) - self.instance_norm(target_fea)) ** 2) 198 | style_loss = self.__get_vgg_style_loss(style_feat_img, style_feat_style) 199 | return content_loss + style_loss * self.opt.lambda_vgg_style 200 | 201 | def forward(self, im, style_B_fake = None, type='global', ref_image=None, weights = None): 202 | """Run forward pass; called by both functions and .""" 203 | # Random style sampling 204 | if style_B_fake is None: 205 | style_B_fake = torch.randn(im.size(0), self.opt.style_dim, 1, 1).to(self.device) 206 | 207 | if weights is None: 208 | weights = torch.softmax(self.weights, 0) 209 | 210 | # Encoding 211 | self.content_A, self.style_A_real = self.netG_A.encode(im) 212 | if type == 'global': 213 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_global(self.content_A, 214 | style_B_fake, 215 | style_B_fake, 216 | weights) 217 | elif type 
== 'exemplar': 218 | 219 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_exemplar(self.content_A, 220 | style_B_fake, 221 | self.get_fewshot_style_code( 222 | ref_image), 223 | weights) 224 | else: 225 | raise NotImplementedError 226 | 227 | self.fake_B = self.fake_B_weights + self.fake_B_residual 228 | return self.fake_B 229 | 230 | def forward_train(self): 231 | """Run forward pass; called by both functions and .""" 232 | # Random style sampling 233 | self.style_A_fake = torch.randn(self.real_A.size(0), self.opt.style_dim, 1, 1).to(self.device) 234 | self.style_B_fake = torch.randn(self.real_B.size(0), self.opt.style_dim, 1, 1).to(self.device) 235 | 236 | # Encoding 237 | self.content_A, self.style_A_real = self.netG_A.encode(self.real_A) 238 | self.content_B, self.style_B_real = self.netG_B.encode(self.real_B) 239 | 240 | # Reconstruction 241 | self.rec_A_img = self.netG_A.decode(self.content_A, self.style_A_real, self.class_A) 242 | self.rec_B_img = self.netG_B.decode(self.content_B, self.style_B_real, self.class_B) 243 | 244 | # Cross domain 245 | self.fake_B = self.netG_B.decode(self.content_A, self.style_B_fake, self.class_B) 246 | self.fake_A = self.netG_A.decode(self.content_B, self.style_A_fake, self.class_A) 247 | 248 | # Re-encoding everyting 249 | self.rec_content_B, self.rec_style_A = self.netG_A.encode(self.fake_A) 250 | self.rec_content_A, self.rec_style_B = self.netG_B.encode(self.fake_B) 251 | 252 | if self.opt.lambda_rec_cycle > 0: 253 | self.rec_A_cycle = self.netG_A.decode(self.rec_content_A, self.style_A_real, self.class_A) 254 | self.rec_B_cycle = self.netG_B.decode(self.rec_content_B, self.style_B_real, self.class_B) 255 | if not self.alternate: 256 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_exemplar(self.content_A, 257 | self.style_B_fake, 258 | self.get_fewshot_style_code(self.real_GT), 259 | torch.softmax(self.weights, 0)) 260 | self.alternate = True 261 | else: 262 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_global(self.content_A, 263 | self.style_B_fake, 264 | self.style_B_fake, 265 | torch.softmax(self.weights, 0)) 266 | self.alternate = False 267 | self.fake_B_style = self.fake_B_weights + self.fake_B_residual 268 | self.fake_B_style = torch.clamp(self.fake_B_style, -1, 1) 269 | 270 | def training_step_D(self): 271 | with torch.no_grad(): 272 | # Random style sampling 273 | self.style_A_fake = torch.randn(self.real_A.size(0), self.opt.style_dim, 1, 1).to(self.device) 274 | self.style_B_fake = torch.randn(self.real_B.size(0), self.opt.style_dim, 1, 1).to(self.device) 275 | 276 | # Encoding 277 | self.content_A, self.style_A_real = self.netG_A.encode(self.real_A) 278 | self.content_B, self.style_B_real = self.netG_B.encode(self.real_B) 279 | 280 | self.fake_B = self.netG_B.decode(self.content_A, self.style_B_fake, self.class_B) 281 | self.fake_A = self.netG_A.decode(self.content_B, self.style_A_fake, self.class_A) 282 | if not self.alternate: 283 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_exemplar(self.content_A,self.style_B_fake, self.get_fewshot_style_code(self.real_GT), torch.softmax(self.weights, 0)) 284 | else: 285 | self.fake_B_weights, self.fake_B_residual = self.netG_B.decode_weighted_global(self.content_A,self.style_B_fake, self.style_B_fake, torch.softmax(self.weights, 0)) 286 | self.fake_B_style = self.fake_B_weights + self.fake_B_residual 287 | 288 | self.loss_D_A = self.netD_A.calc_dis_loss(self.fake_B, self.real_B, self.class_B, 
self.device) * self.opt.lambda_gan 289 | self.loss_D_B = self.netD_B.calc_dis_loss(self.fake_A, self.real_A, self.class_A, self.device) * self.opt.lambda_gan 290 | self.loss_D_patches = self.netD_patches.calc_dis_loss(self.augmentations_patches(self.fake_B_style), 291 | self.augmentations_patches(self.real_GT), 292 | 0, self.device)* self.opt.lambda_gan_patches 293 | loss_D = self.loss_D_A + self.loss_D_B + self.loss_D_patches 294 | return loss_D 295 | 296 | 297 | def training_step_G(self): 298 | self.forward_train() 299 | self.loss_rec_A = self.reconCriterion(self.rec_A_img, self.real_A) * self.opt.lambda_rec_image 300 | self.loss_rec_B = self.reconCriterion(self.rec_B_img, self.real_B) * self.opt.lambda_rec_image 301 | 302 | self.loss_rec_style_A = self.reconCriterion(self.rec_style_A, self.style_A_fake) * self.opt.lambda_rec_style 303 | self.loss_rec_style_B = self.reconCriterion(self.rec_style_B, self.style_B_fake) * self.opt.lambda_rec_style 304 | 305 | self.loss_rec_content_A = self.reconCriterion(self.rec_content_A, self.content_A) * self.opt.lambda_rec_content 306 | self.loss_rec_content_B = self.reconCriterion(self.rec_content_B, self.content_B) * self.opt.lambda_rec_content 307 | 308 | if self.opt.lambda_rec_cycle > 0: 309 | self.loss_cycle_A = self.reconCriterion(self.rec_A_cycle, self.real_A) * self.opt.lambda_rec_cycle 310 | self.loss_cycle_B = self.reconCriterion(self.rec_B_cycle, self.real_B) * self.opt.lambda_rec_cycle 311 | else: 312 | self.loss_cycle_A = 0 313 | self.loss_cycle_B = 0 314 | 315 | self.loss_G_A = self.netD_A.calc_gen_loss(self.fake_B, self.class_B, self.device) * self.opt.lambda_gan 316 | self.loss_G_B = self.netD_B.calc_gen_loss(self.fake_A, self.class_A, self.device) * self.opt.lambda_gan 317 | self.loss_G_patches = self.netD_patches.calc_gen_loss( 318 | self.augmentations_patches(self.fake_B_style), 0, self.device) * self.opt.lambda_gan_patches 319 | 320 | if self.opt.lambda_vgg > 0: 321 | self.loss_vgg_A = self.__compute_vgg_loss_onlyperceptual(self.fake_A, self.real_B) * self.opt.lambda_vgg 322 | self.loss_vgg_B = self.__compute_vgg_loss_onlyperceptual(self.fake_B, self.real_A) * self.opt.lambda_vgg 323 | else: 324 | self.loss_vgg_A = 0 325 | self.loss_vgg_B = 0 326 | 327 | self.loss_vgg_style = self.__compute_vgg_loss(self.fake_B_style, self.real_A, self.real_GT) * self.opt.lambda_vgg_fs_res 328 | self.loss_vgg_gan = self.__compute_vgg_loss(self.fake_B_weights, self.real_A, self.real_GT) * self.opt.lambda_vgg_fs # Using the manifold to get nearer 329 | 330 | self.loss_G = self.loss_rec_A + self.loss_rec_B + self.loss_rec_style_A + self.loss_rec_style_B + \ 331 | self.loss_rec_content_A + self.loss_rec_content_B + self.loss_cycle_A + self.loss_cycle_B + \ 332 | self.loss_G_A + self.loss_G_B + self.loss_vgg_A + self.loss_vgg_B + self.loss_vgg_style + \ 333 | self.loss_G_patches + self.loss_vgg_gan 334 | 335 | return self.loss_G 336 | 337 | def training_step(self, batch, batch_idx, optimizer_idx): 338 | 339 | self.set_input(batch) 340 | if optimizer_idx == 0: 341 | self.set_requires_grad([self.netD_A, self.netD_B, self.netD_patches], True) 342 | self.set_requires_grad([self.netG_A, self.netG_B], False) 343 | 344 | return self.training_step_D() 345 | elif optimizer_idx == 1: 346 | self.set_requires_grad([self.netD_A, self.netD_B, self.netD_patches], False) # Ds require no gradients when optimizing Gs 347 | self.set_requires_grad([self.netG_A, self.netG_B], True) 348 | 349 | return self.training_step_G() 350 | 
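# ---------------------------------------------------------------------------
# Dimensionality sanity check (illustrative sketch): get_fewshot_style_code()
# concatenates the per-channel means and standard deviations of the four VGG
# style features (relu1_1, relu2_1, relu3_1, relu4_1), i.e.
# (64 + 128 + 256 + 512) * 2 = 1920 values, which matches the input size of
# mlp_exemplar in backbones/fsmunit.py. The random tensors below merely stand
# in for the VGG feature maps.
if __name__ == '__main__':
    _feats = [torch.randn(1, c, 32, 32) for c in (64, 128, 256, 512)]
    _means = [f.view(1, f.size(1), -1).mean(dim=2) for f in _feats]
    _stds = [f.view(1, f.size(1), -1).std(dim=2) for f in _feats]
    _style = torch.cat(_means + _stds, dim=1)
    print(_style.shape)  # expected: torch.Size([1, 1920])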
-------------------------------------------------------------------------------- /networks/backbones/fsmunit.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | import functools 5 | from torch.optim import lr_scheduler 6 | import torch.nn.functional as F 7 | from .functions import init_net, init_weights, get_scheduler 8 | 9 | 10 | ######################################################################################################################## 11 | # MUNIT architecture 12 | ######################################################################################################################## 13 | 14 | def define_G_munit(input_nc, output_nc, gen_dim, style_dim, n_downsample, n_res, 15 | pad_type, mlp_dim, activ='relu', init_type = 'kaiming', init_gain=0.02, gpu_ids=[], num_classes = 1): 16 | gen = AdaINGen(input_nc, output_nc, gen_dim, style_dim, n_downsample, n_res, activ, pad_type, mlp_dim, num_classes) 17 | return init_net(gen, init_type=init_type, init_gain = init_gain, gpu_ids = gpu_ids) 18 | 19 | def define_D_munit(input_nc, disc_dim, norm, activ, n_layer, gan_type, num_scales, pad_type, 20 | init_type = 'kaiming', init_gain = 0.02, gpu_ids = [], num_classes = 1): 21 | disc = MsImageDis(input_nc, n_layer, gan_type, disc_dim, norm, activ, num_scales, pad_type, num_classes) 22 | return init_net(disc, init_type=init_type, init_gain = init_gain, gpu_ids = gpu_ids) 23 | 24 | 25 | class MsImageDis(nn.Module): 26 | # Multi-scale discriminator architecture 27 | def __init__(self, input_dim, n_layer, gan_type, dim, norm, activ, num_scales, pad_type, num_classes): 28 | super(MsImageDis, self).__init__() 29 | self.n_layer = n_layer 30 | self.gan_type = gan_type 31 | self.dim = dim 32 | self.norm = norm 33 | self.activ = activ 34 | self.num_scales = num_scales 35 | self.pad_type = pad_type 36 | self.input_dim = input_dim 37 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 38 | self.cnns = nn.ModuleList() 39 | self.heads = nn.ModuleList() 40 | self.num_classes = num_classes 41 | for _ in range(self.num_scales): 42 | cnn, heads_single = self._make_net(num_classes) 43 | self.cnns.append(cnn) 44 | self.heads.append(heads_single) 45 | 46 | def _make_net(self, num_classes): 47 | dim = self.dim 48 | cnn_x = [] 49 | cnn_x += [Conv2dBlock(self.input_dim, dim, 4, 2, 1, norm='none', activation=self.activ, pad_type=self.pad_type)] 50 | for i in range(self.n_layer - 1): 51 | cnn_x += [Conv2dBlock(dim, dim * 2, 4, 2, 1, norm=self.norm, activation=self.activ, pad_type=self.pad_type)] 52 | dim *= 2 53 | heads = nn.ModuleList() 54 | for _ in range(0, num_classes): 55 | heads += [nn.Conv2d(dim, 1, 1, 1, 0)] 56 | cnn_x = nn.Sequential(*cnn_x) 57 | return cnn_x, heads 58 | 59 | def forward(self, x, selected_class): 60 | outputs = [] 61 | for model, head in zip(self.cnns, self.heads): 62 | out = model(x) 63 | out = head[selected_class](out) 64 | outputs.append(out) 65 | x = self.downsample(x) 66 | return outputs 67 | 68 | def calc_dis_loss(self, input_fake, input_real, selected_class, device): 69 | # calculate the loss to train D 70 | outs0 = self.forward(input_fake, selected_class) 71 | outs1 = self.forward(input_real, selected_class) 72 | loss = 0 73 | 74 | for it, (out0, out1) in enumerate(zip(outs0, outs1)): 75 | if self.gan_type == 'lsgan': 76 | loss += torch.mean((out0 - 0)**2) + torch.mean((out1 - 1)**2) 77 | elif self.gan_type == 'nsgan': 78 | all0 = 
torch.zeros_like(out0).to(device) 79 | all1 = torch.ones_like(out1).to(device) 80 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all0) + 81 | F.binary_cross_entropy(F.sigmoid(out1), all1)) 82 | else: 83 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 84 | return loss 85 | 86 | def calc_gen_loss(self, input_fake, selected_class, device): 87 | # calculate the loss to train G 88 | outs0 = self.forward(input_fake, selected_class) 89 | loss = 0 90 | for it, (out0) in enumerate(outs0): 91 | if self.gan_type == 'lsgan': 92 | loss += torch.mean((out0 - 1)**2) # LSGAN 93 | elif self.gan_type == 'nsgan': 94 | all1 = torch.ones_like(out0.data).to(device) 95 | loss += torch.mean(F.binary_cross_entropy(F.sigmoid(out0), all1)) 96 | else: 97 | assert 0, "Unsupported GAN type: {}".format(self.gan_type) 98 | return loss 99 | 100 | class AdaINGen(nn.Module): 101 | # AdaIN auto-encoder architecture 102 | def __init__(self, input_dim, output_dim, dim, style_dim, n_downsample, n_res, activ, pad_type, mlp_dim, num_class): 103 | super(AdaINGen, self).__init__() 104 | 105 | # style encoder 106 | self.enc_style = StyleEncoder(4, input_dim, dim, style_dim, norm='none', activ=activ, pad_type=pad_type) 107 | 108 | # content encoder 109 | self.enc_content = ContentEncoder(n_downsample, n_res, input_dim, dim, 'instance', activ, pad_type=pad_type) 110 | self.dec = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_dim, res_norm='adain', activ=activ, pad_type=pad_type) 111 | self.res_exemplar = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_dim, res_norm='adain', activ=activ, pad_type=pad_type) 112 | self.res_global = Decoder(n_downsample, n_res, self.enc_content.output_dim, output_dim, res_norm='adain', activ=activ, pad_type=pad_type) 113 | # MLP to generate AdaIN parameters 114 | self.mlp = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ, num_class=num_class) 115 | self.mlp_exemplar = MLP(1920, self.get_num_adain_params(self.dec), mlp_dim, 1, norm='none', activ=activ, num_class=1) 116 | self.mlp_global = MLP(style_dim, self.get_num_adain_params(self.dec), mlp_dim, 3, norm='none', activ=activ, num_class=1) 117 | self.num_class = num_class 118 | 119 | def forward(self, images): 120 | # reconstruct an image 121 | content, style_fake = self.encode(images) 122 | images_recon = self.decode(content, style_fake) 123 | return images_recon 124 | 125 | def encode(self, images): 126 | # encode an image to its content and style codes 127 | style_fake = self.enc_style(images) 128 | content = self.enc_content(images) 129 | return content, style_fake 130 | 131 | def decode(self, content, style, selected_class): 132 | # decode content and style codes to an image 133 | adain_params = self.mlp(style, selected_class) 134 | self.assign_adain_params(adain_params, self.dec) 135 | images = self.dec(content) 136 | return images 137 | 138 | def decode_weighted_exemplar(self, content, style, style_residual, weights): 139 | # decode content and style codes to an image 140 | adain_params = None 141 | for selected_class in range(0, self.num_class): 142 | if adain_params is None: 143 | adain_params = self.mlp(style, selected_class) * weights[selected_class] 144 | else: 145 | adain_params += self.mlp(style, selected_class) * weights[selected_class] 146 | adain_res_params = self.mlp_exemplar(style_residual, 0) 147 | self.assign_adain_params(adain_params, self.dec) 148 | self.assign_adain_params(adain_res_params, self.res_exemplar) 149 | images = 
self.dec(content) 150 | residual = self.res_exemplar(content) 151 | return images, residual 152 | 153 | def decode_weighted_global(self, content, style, style_residual, weights): 154 | # decode content and style codes to an image 155 | adain_params = None 156 | for selected_class in range(0, self.num_class): 157 | if adain_params is None: 158 | adain_params = self.mlp(style, selected_class) * weights[selected_class] 159 | else: 160 | adain_params += self.mlp(style, selected_class) * weights[selected_class] 161 | adain_res_params = self.mlp_global(style_residual, 0) 162 | self.assign_adain_params(adain_params, self.dec) 163 | self.assign_adain_params(adain_res_params, self.res_global) 164 | images = self.dec(content) 165 | residual = self.res_global(content) 166 | return images, residual 167 | 168 | def assign_adain_params(self, adain_params, model): 169 | # assign the adain_params to the AdaIN layers in model 170 | for m in model.modules(): 171 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 172 | mean = adain_params[:, :m.num_features] 173 | std = adain_params[:, m.num_features:2*m.num_features] 174 | m.bias = mean.contiguous().view(-1) 175 | m.weight = std.contiguous().view(-1) 176 | if adain_params.size(1) > 2*m.num_features: 177 | adain_params = adain_params[:, 2*m.num_features:] 178 | 179 | def get_num_adain_params(self, model): 180 | # return the number of AdaIN parameters needed by the model 181 | num_adain_params = 0 182 | for m in model.modules(): 183 | if m.__class__.__name__ == "AdaptiveInstanceNorm2d": 184 | num_adain_params += 2*m.num_features 185 | return num_adain_params 186 | 187 | 188 | class StyleEncoder(nn.Module): 189 | def __init__(self, n_downsample, input_dim, dim, style_dim, norm, activ, pad_type): 190 | super(StyleEncoder, self).__init__() 191 | self.model = [] 192 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 193 | for i in range(2): 194 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 195 | dim *= 2 196 | for i in range(n_downsample - 2): 197 | self.model += [Conv2dBlock(dim, dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 198 | self.model += [nn.AdaptiveAvgPool2d(1)] # global average pooling 199 | self.model += [nn.Conv2d(dim, style_dim, 1, 1, 0)] 200 | self.model = nn.Sequential(*self.model) 201 | self.output_dim = dim 202 | 203 | def forward(self, x): 204 | return self.model(x) 205 | 206 | class ContentEncoder(nn.Module): 207 | def __init__(self, n_downsample, n_res, input_dim, dim, norm, activ, pad_type): 208 | super(ContentEncoder, self).__init__() 209 | self.model = [] 210 | self.model += [Conv2dBlock(input_dim, dim, 7, 1, 3, norm=norm, activation=activ, pad_type=pad_type)] 211 | # downsampling blocks 212 | for i in range(n_downsample): 213 | self.model += [Conv2dBlock(dim, 2 * dim, 4, 2, 1, norm=norm, activation=activ, pad_type=pad_type)] 214 | dim *= 2 215 | # residual blocks 216 | self.model += [ResBlocks(n_res, dim, norm=norm, activation=activ, pad_type=pad_type)] 217 | self.model = nn.Sequential(*self.model) 218 | self.output_dim = dim 219 | 220 | def forward(self, x): 221 | return self.model(x) 222 | 223 | class Decoder(nn.Module): 224 | def __init__(self, n_upsample, n_res, dim, output_dim, res_norm='adain', activ='relu', pad_type='zero'): 225 | super(Decoder, self).__init__() 226 | 227 | self.model = [] 228 | # AdaIN residual blocks 229 | self.model += [ResBlocks(n_res, dim, res_norm, activ, pad_type=pad_type)] 230 | # 
upsampling blocks 231 | for i in range(n_upsample): 232 | self.model += [nn.Upsample(scale_factor=2), 233 | Conv2dBlock(dim, dim // 2, 5, 1, 2, norm='layer', activation=activ, pad_type=pad_type)] 234 | dim //= 2 235 | # use reflection padding in the last conv layer 236 | self.model += [Conv2dBlock(dim, output_dim, 7, 1, 3, norm='none', activation='tanh', pad_type=pad_type)] 237 | self.model = nn.Sequential(*self.model) 238 | 239 | def forward(self, x): 240 | return self.model(x) 241 | 242 | 243 | class ResBlocks(nn.Module): 244 | def __init__(self, num_blocks, dim, norm='instance', activation='relu', pad_type='zero'): 245 | super(ResBlocks, self).__init__() 246 | self.model = [] 247 | for i in range(num_blocks): 248 | self.model += [ResBlock(dim, norm=norm, activation=activation, pad_type=pad_type)] 249 | self.model = nn.Sequential(*self.model) 250 | 251 | def forward(self, x): 252 | return self.model(x) 253 | 254 | class MLP(nn.Module): 255 | def __init__(self, input_dim, output_dim, dim, n_blk, norm='none', activ='relu', num_class=1): 256 | 257 | super(MLP, self).__init__() 258 | self.model = [] 259 | self.model += [LinearBlock(input_dim, dim, norm=norm, activation=activ)] 260 | for i in range(n_blk - 2): 261 | self.model += [LinearBlock(dim, dim, norm=norm, activation=activ)] 262 | #self.model += [LinearBlock(dim, output_dim, norm='none', activation='none')] # no output activations 263 | self.heads = nn.ModuleList() 264 | for _ in range(0, num_class): 265 | self.heads += [LinearBlock(dim, output_dim, norm='none', activation='none')] 266 | self.model = nn.Sequential(*self.model) 267 | 268 | def forward(self, x, selected_class): 269 | out = self.model(x.view(x.size(0), -1)) 270 | out = self.heads[selected_class](out) 271 | return out 272 | 273 | ################################################################################## 274 | # Basic Blocks 275 | ################################################################################## 276 | class ResBlock(nn.Module): 277 | def __init__(self, dim, norm='instance', activation='relu', pad_type='zero'): 278 | super(ResBlock, self).__init__() 279 | 280 | model = [] 281 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation=activation, pad_type=pad_type)] 282 | model += [Conv2dBlock(dim ,dim, 3, 1, 1, norm=norm, activation='none', pad_type=pad_type)] 283 | self.model = nn.Sequential(*model) 284 | 285 | def forward(self, x): 286 | residual = x 287 | out = self.model(x) 288 | out += residual 289 | return out 290 | 291 | class Conv2dBlock(nn.Module): 292 | def __init__(self, input_dim ,output_dim, kernel_size, stride, 293 | padding=0, norm='none', activation='relu', pad_type='zero'): 294 | super(Conv2dBlock, self).__init__() 295 | self.use_bias = True 296 | # initialize padding 297 | if pad_type == 'reflect': 298 | self.pad = nn.ReflectionPad2d(padding) 299 | elif pad_type == 'replicate': 300 | self.pad = nn.ReplicationPad2d(padding) 301 | elif pad_type == 'zero': 302 | self.pad = nn.ZeroPad2d(padding) 303 | else: 304 | assert 0, "Unsupported padding type: {}".format(pad_type) 305 | 306 | # initialize normalization 307 | norm_dim = output_dim 308 | if norm == 'batch': 309 | self.norm = nn.BatchNorm2d(norm_dim) 310 | elif norm == 'instance': 311 | #self.norm = nn.InstanceNorm2d(norm_dim, track_running_stats=True) 312 | self.norm = nn.InstanceNorm2d(norm_dim) 313 | elif norm == 'layer': 314 | self.norm = LayerNorm(norm_dim) 315 | elif norm == 'adain': 316 | self.norm = AdaptiveInstanceNorm2d(norm_dim) 317 | elif norm == 'none' or norm == 
'spectral': 318 | self.norm = None 319 | else: 320 | assert 0, "Unsupported normalization: {}".format(norm) 321 | 322 | # initialize activation 323 | if activation == 'relu': 324 | self.activation = nn.ReLU(inplace=True) 325 | elif activation == 'lrelu': 326 | self.activation = nn.LeakyReLU(0.2, inplace=True) 327 | elif activation == 'prelu': 328 | self.activation = nn.PReLU() 329 | elif activation == 'selu': 330 | self.activation = nn.SELU(inplace=True) 331 | elif activation == 'tanh': 332 | self.activation = nn.Tanh() 333 | elif activation == 'none': 334 | self.activation = None 335 | else: 336 | assert 0, "Unsupported activation: {}".format(activation) 337 | 338 | # initialize convolution 339 | if norm == 'spectral': 340 | self.conv = SpectralNorm(nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias)) 341 | else: 342 | self.conv = nn.Conv2d(input_dim, output_dim, kernel_size, stride, bias=self.use_bias) 343 | 344 | def forward(self, x): 345 | x = self.conv(self.pad(x)) 346 | if self.norm: 347 | x = self.norm(x) 348 | if self.activation: 349 | x = self.activation(x) 350 | return x 351 | 352 | class LinearBlock(nn.Module): 353 | def __init__(self, input_dim, output_dim, norm='none', activation='relu'): 354 | super(LinearBlock, self).__init__() 355 | use_bias = True 356 | # initialize fully connected layer 357 | if norm == 'spectral': 358 | self.fc = SpectralNorm(nn.Linear(input_dim, output_dim, bias=use_bias)) 359 | else: 360 | self.fc = nn.Linear(input_dim, output_dim, bias=use_bias) 361 | 362 | # initialize normalization 363 | norm_dim = output_dim 364 | if norm == 'batch': 365 | self.norm = nn.BatchNorm1d(norm_dim) 366 | elif norm == 'instance': 367 | self.norm = nn.InstanceNorm1d(norm_dim) 368 | elif norm == 'layer': 369 | self.norm = LayerNorm(norm_dim) 370 | elif norm == 'none' or norm == 'spectral': 371 | self.norm = None 372 | else: 373 | assert 0, "Unsupported normalization: {}".format(norm) 374 | 375 | # initialize activation 376 | if activation == 'relu': 377 | self.activation = nn.ReLU(inplace=True) 378 | elif activation == 'lrelu': 379 | self.activation = nn.LeakyReLU(0.2, inplace=True) 380 | elif activation == 'prelu': 381 | self.activation = nn.PReLU() 382 | elif activation == 'selu': 383 | self.activation = nn.SELU(inplace=True) 384 | elif activation == 'tanh': 385 | self.activation = nn.Tanh() 386 | elif activation == 'none': 387 | self.activation = None 388 | else: 389 | assert 0, "Unsupported activation: {}".format(activation) 390 | 391 | def forward(self, x): 392 | out = self.fc(x) 393 | if self.norm: 394 | out = self.norm(out) 395 | if self.activation: 396 | out = self.activation(out) 397 | return out 398 | 399 | 400 | class Vgg16(nn.Module): 401 | def __init__(self): 402 | super(Vgg16, self).__init__() 403 | self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) 404 | self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) 405 | 406 | self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1) 407 | self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1) 408 | 409 | self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1) 410 | self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 411 | self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1) 412 | 413 | self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1) 414 | self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 415 | self.conv4_3 = nn.Conv2d(512, 512, 
kernel_size=3, stride=1, padding=1) 416 | 417 | self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 418 | self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 419 | self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1) 420 | 421 | def forward(self, X, with_style=False): 422 | h = F.relu(self.conv1_1(X), inplace=True) 423 | relu1_1 = h 424 | h = F.relu(self.conv1_2(h), inplace=True) 425 | h = F.max_pool2d(h, kernel_size=2, stride=2) 426 | 427 | h = F.relu(self.conv2_1(h), inplace=True) 428 | relu2_1 = h 429 | h = F.relu(self.conv2_2(h), inplace=True) 430 | h = F.max_pool2d(h, kernel_size=2, stride=2) 431 | 432 | h = F.relu(self.conv3_1(h), inplace=True) 433 | relu3_1 = h 434 | h = F.relu(self.conv3_2(h), inplace=True) 435 | h = F.relu(self.conv3_3(h), inplace=True) 436 | h = F.max_pool2d(h, kernel_size=2, stride=2) 437 | 438 | h = F.relu(self.conv4_1(h), inplace=True) 439 | relu4_1 = h 440 | 441 | h = F.relu(self.conv4_2(h), inplace=True) 442 | h = F.relu(self.conv4_3(h), inplace=True) 443 | 444 | h = F.relu(self.conv5_1(h), inplace=True) 445 | h = F.relu(self.conv5_2(h), inplace=True) 446 | h = F.relu(self.conv5_3(h), inplace=True) 447 | relu5_3 = h 448 | if not with_style: 449 | return relu5_3 450 | return relu5_3, [relu1_1, relu2_1, relu3_1, relu4_1] 451 | 452 | 453 | class AdaptiveInstanceNorm2d(nn.Module): 454 | def __init__(self, num_features, eps=1e-5, momentum=0.1): 455 | super(AdaptiveInstanceNorm2d, self).__init__() 456 | self.num_features = num_features 457 | self.eps = eps 458 | self.momentum = momentum 459 | # weight and bias are dynamically assigned 460 | self.weight = None 461 | self.bias = None 462 | # just dummy buffers, not used 463 | self.register_buffer('running_mean', torch.zeros(num_features)) 464 | self.register_buffer('running_var', torch.ones(num_features)) 465 | 466 | 467 | 468 | def forward(self, x): 469 | assert self.weight is not None and self.bias is not None, "Please assign weight and bias before calling AdaIN!" 470 | b, c = x.size(0), x.size(1) 471 | if self.weight.type() == 'torch.cuda.HalfTensor': 472 | running_mean = self.running_mean.repeat(b).to(torch.float16) 473 | running_var = self.running_var.repeat(b).to(torch.float16) 474 | else: 475 | running_mean = self.running_mean.repeat(b) 476 | running_var = self.running_var.repeat(b) 477 | # Apply instance norm 478 | x_reshaped = x.contiguous().view(1, b * c, *x.size()[2:]) 479 | 480 | out = F.batch_norm( 481 | x_reshaped, running_mean, running_var, self.weight, self.bias, 482 | True, self.momentum, self.eps) 483 | 484 | return out.view(b, c, *x.size()[2:]) 485 | 486 | def __repr__(self): 487 | return self.__class__.__name__ + '(' + str(self.num_features) + ')' 488 | 489 | 490 | class LayerNorm(nn.Module): 491 | def __init__(self, num_features, eps=1e-5, affine=True): 492 | super(LayerNorm, self).__init__() 493 | self.num_features = num_features 494 | self.affine = affine 495 | self.eps = eps 496 | 497 | if self.affine: 498 | self.gamma = nn.Parameter(torch.Tensor(num_features).uniform_()) 499 | self.beta = nn.Parameter(torch.zeros(num_features)) 500 | 501 | def forward(self, x): 502 | shape = [-1] + [1] * (x.dim() - 1) 503 | # print(x.size()) 504 | if x.size(0) == 1: 505 | # These two lines run much faster in pytorch 0.4 than the two lines listed below. 
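            # (Both branches compute the same per-sample statistics; the batch-size-1
            # case just flattens the whole tensor instead of keeping the batch dim.)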
506 | mean = x.view(-1).mean().view(*shape) 507 | std = x.view(-1).std().view(*shape) 508 | else: 509 | mean = x.view(x.size(0), -1).mean(1).view(*shape) 510 | std = x.view(x.size(0), -1).std(1).view(*shape) 511 | 512 | x = (x - mean) / (std + self.eps) 513 | 514 | if self.affine: 515 | shape = [1, -1] + [1] * (x.dim() - 2) 516 | x = x * self.gamma.view(*shape) + self.beta.view(*shape) 517 | return x 518 | 519 | def l2normalize(v, eps=1e-12): 520 | return v / (v.norm() + eps) 521 | 522 | 523 | class SpectralNorm(nn.Module): 524 | """ 525 | Based on the paper "Spectral Normalization for Generative Adversarial Networks" by Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida 526 | and the Pytorch implementation https://github.com/christiancosgrove/pytorch-spectral-normalization-gan 527 | """ 528 | def __init__(self, module, name='weight', power_iterations=1): 529 | super(SpectralNorm, self).__init__() 530 | self.module = module 531 | self.name = name 532 | self.power_iterations = power_iterations 533 | if not self._made_params(): 534 | self._make_params() 535 | 536 | def _update_u_v(self): 537 | u = getattr(self.module, self.name + "_u") 538 | v = getattr(self.module, self.name + "_v") 539 | w = getattr(self.module, self.name + "_bar") 540 | 541 | height = w.data.shape[0] 542 | for _ in range(self.power_iterations): 543 | v.data = l2normalize(torch.mv(torch.t(w.view(height,-1).data), u.data)) 544 | u.data = l2normalize(torch.mv(w.view(height,-1).data, v.data)) 545 | 546 | # sigma = torch.dot(u.data, torch.mv(w.view(height,-1).data, v.data)) 547 | sigma = u.dot(w.view(height, -1).mv(v)) 548 | setattr(self.module, self.name, w / sigma.expand_as(w)) 549 | 550 | def _made_params(self): 551 | try: 552 | u = getattr(self.module, self.name + "_u") 553 | v = getattr(self.module, self.name + "_v") 554 | w = getattr(self.module, self.name + "_bar") 555 | return True 556 | except AttributeError: 557 | return False 558 | 559 | 560 | def _make_params(self): 561 | w = getattr(self.module, self.name) 562 | 563 | height = w.data.shape[0] 564 | width = w.view(height, -1).data.shape[1] 565 | 566 | u = nn.Parameter(w.data.new(height).normal_(0, 1), requires_grad=False) 567 | v = nn.Parameter(w.data.new(width).normal_(0, 1), requires_grad=False) 568 | u.data = l2normalize(u.data) 569 | v.data = l2normalize(v.data) 570 | w_bar = nn.Parameter(w.data) 571 | 572 | del self.module._parameters[self.name] 573 | 574 | self.module.register_parameter(self.name + "_u", u) 575 | self.module.register_parameter(self.name + "_v", v) 576 | self.module.register_parameter(self.name + "_bar", w_bar) 577 | 578 | 579 | def forward(self, *args): 580 | self._update_u_v() 581 | return self.module.forward(*args) --------------------------------------------------------------------------------
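As a side note on the SpectralNorm wrapper above: it keeps an estimate of the largest singular value of each weight matrix up to date with a few power iterations and divides the weight by that estimate at every forward pass. The short, self-contained sketch below (tensor sizes are arbitrary) runs the same iteration on a plain matrix and compares the estimate against the exact top singular value.

import torch

W = torch.randn(64, 128)
u = torch.randn(64)
u = u / u.norm()
v = torch.randn(128)
for _ in range(30):  # power iterations, as in SpectralNorm._update_u_v
    v = torch.mv(W.t(), u)
    v = v / (v.norm() + 1e-12)
    u = torch.mv(W, v)
    u = u / (u.norm() + 1e-12)
sigma = torch.dot(u, torch.mv(W, v))  # estimated spectral norm
print(sigma.item(), torch.svd(W).S[0].item())  # the two values should be close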