├── .gitignore ├── LICENSE ├── README.md ├── data ├── __init__.py ├── base_dataset.py ├── image_folder.py ├── jointset_dataset.py ├── noise_dataset.py ├── noiseframe_dataset.py ├── noiseshufflevideo_dataset.py └── randomvideo_dataset.py ├── dnnlib ├── __init__.py └── util.py ├── docs ├── results1.gif ├── results2.gif └── teaser.png ├── environment └── stylefacev.yaml ├── legacy.py ├── models ├── __init__.py ├── base_model.py ├── diy_networks.py ├── lmcode_networks.py ├── networks.py ├── reenact_model.py ├── resnet.py ├── rnn_dnet.py ├── rnn_dnet3.py ├── rnn_losses.py ├── rnn_net.py ├── sample_model.py ├── stylefacevadv_model.py ├── stylepose_model.py ├── stylepre_model.py ├── stylernn_model.py └── stylernns_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts └── vid2img.py ├── test.py ├── torch_utils ├── __init__.py ├── custom_ops.py ├── misc.py ├── ops │ ├── __init__.py │ ├── bias_act.cpp │ ├── bias_act.cu │ ├── bias_act.h │ ├── bias_act.py │ ├── conv2d_gradfix.py │ ├── conv2d_resample.py │ ├── filtered_lrelu.cpp │ ├── filtered_lrelu.cu │ ├── filtered_lrelu.h │ ├── filtered_lrelu.py │ ├── filtered_lrelu_ns.cu │ ├── filtered_lrelu_rd.cu │ ├── filtered_lrelu_wr.cu │ ├── fma.py │ ├── grid_sample_gradfix.py │ ├── upfirdn2d.cpp │ ├── upfirdn2d.cu │ ├── upfirdn2d.h │ └── upfirdn2d.py ├── persistence.py └── training_stats.py ├── train.py └── util ├── __init__.py ├── get_data.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | debug* 3 | datasets/ 4 | checkpoints/ 5 | results/ 6 | build/ 7 | dist/ 8 | torch.egg-info/ 9 | */**/__pycache__ 10 | torch/version.py 11 | torch/csrc/generic/TensorMethods.cpp 12 | torch/lib/*.so* 13 | torch/lib/*.dylib* 14 | torch/lib/*.h 15 | torch/lib/build 16 | torch/lib/tmp_install 17 | torch/lib/include 18 | torch/lib/torch_shm_manager 19 | torch/csrc/cudnn/cuDNN.cpp 20 | torch/csrc/nn/THNN.cwrap 21 | torch/csrc/nn/THNN.cpp 22 | torch/csrc/nn/THCUNN.cwrap 23 | torch/csrc/nn/THCUNN.cpp 24 | torch/csrc/nn/THNN_generic.cwrap 25 | torch/csrc/nn/THNN_generic.cpp 26 | torch/csrc/nn/THNN_generic.h 27 | docs/src/**/* 28 | test/data/legacy_modules.t7 29 | test/data/gpu_tensors.pt 30 | test/htmlcov 31 | test/.coverage 32 | */*.pyc 33 | */**/*.pyc 34 | */**/**/*.pyc 35 | */**/**/**/*.pyc 36 | */**/**/**/**/*.pyc 37 | */*.so* 38 | */**/*.so* 39 | */**/*.dylib* 40 | test/data/legacy_serialized.pt 41 | *~ 42 | .idea 43 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Haonan Qiu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StyleFaceV - Official PyTorch Implementation 2 | 3 | This repository provides the official PyTorch implementation for the following paper: 4 | 5 | **StyleFaceV: Face Video Generation via Decomposing and Recomposing Pretrained StyleGAN3**
6 | [Haonan Qiu](http://haonanqiu.com/), [Yuming Jiang](https://yumingj.github.io/), [Hang Zhou](https://hangz-nju-cuhk.github.io/), [Wayne Wu](https://wywu.github.io/), and [Ziwei Liu](https://liuziwei7.github.io/)
7 | arXiv, 2022. 8 | 9 | From [MMLab@NTU](https://www.mmlab-ntu.com/index.html) affiliated with S-Lab, Nanyang Technological University and SenseTime Research. 10 | 11 | 12 | 13 | [**[Project Page]**](http://haonanqiu.com/projects/StyleFaceV.html) | [**[Paper]**](https://arxiv.org/abs/2208.07862) | [**[Demo Video]**](https://youtu.be/BZNLcD04-Fc) 14 | 15 | 16 | ### Generated Samples 17 | 18 | 19 | 20 | 21 | 22 | ## Updates 23 | 24 | - [07/2022] Paper and demo video are released. 25 | - [07/2022] Code is released. 26 | 27 | ## Installation 28 | **Clone this repo:** 29 | ```bash 30 | git clone https://github.com/arthur-qiu/StyleFaceV.git 31 | cd StyleFaceV 32 | ``` 33 | 34 | **Dependencies:** 35 | 36 | All dependencies for defining the environment are provided in `environment/stylefacev.yaml`. 37 | We recommend using [Anaconda](https://docs.anaconda.com/anaconda/install/) to manage the Python environment: 38 | 39 | ```bash 40 | conda env create -f ./environment/stylefacev.yaml 41 | conda activate stylefacev 42 | ``` 43 | 44 | ## Datasets 45 | 46 | Image Data: [Unaligned FFHQ](https://github.com/NVlabs/ffhq-dataset) 47 | 48 | Video Data: [RAVDESS](https://zenodo.org/record/1188976) 49 | 50 | Download the processed video data via this [Google Drive](https://drive.google.com/file/d/17tMHrpvTm08ixAwnzTI9dN0BhCjmCwgV/view?usp=sharing) or process the data via this [repo](https://github.com/AliaksandrSiarohin/video-preprocessing). 51 | 52 | Put all the data at the path `../data`. 53 | 54 | Transform the video data into `.png` frames: 55 | 56 | ```bash 57 | python scripts/vid2img.py 58 | ``` 59 | 60 | ## Sampling 61 | 62 | ### Pretrained Models 63 | 64 | Pretrained models can be downloaded from this [Google Drive](https://drive.google.com/file/d/1c_JWfDjN44XpI8OG24p3FkdEufJGsv34/view?usp=sharing). Unzip the file and place the contents under the repository root with the following structure: 65 | ``` 66 | pretrained_models 67 | ├── network-snapshot-005000.pkl # StyleGAN3 checkpoint finetuned on both RAVDESS and unaligned FFHQ. 68 | ├── wing.ckpt # Face Alignment model from https://github.com/protossw512/AdaptiveWingLoss. 69 | ├── motion_net.pth # trained motion sampler. 70 | ├── pre_net.pth 71 | └── pre_pose_net.pth 72 | checkpoints/stylefacev 73 | ├── latest_net_FE.pth # appearance extractor + recomposition 74 | ├── latest_net_FE_lm.pth # first half of the pose extractor 75 | └── latest_net_FE_pose.pth # second half of the pose extractor 76 | ``` 77 | 78 | ### Generating Videos 79 | 80 | ```bash 81 | python test.py --dataroot ../data/actor_align_512_png --name stylefacev \ 82 | --network_pkl=pretrained_models/network-snapshot-005000.pkl --model sample \ 83 | --model_names FE,FE_pose,FE_lm --rnn_path pretrained_models/motion_net.pth \ 84 | --n_frames_G 60 --num_test=64 --results_dir './sample_results/' 85 | ``` 86 | 87 | ## Training 88 | 89 | ### Pre Stage 90 | 91 | If you want to use new datasets, please finetune the StyleGAN3 model first. 92 | 93 | This stage is trained purely on image data and helps convergence.
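If you are adapting the code to a new dataset, it can save time to first confirm that your finetuned StyleGAN3 checkpoint loads and renders a frame. The snippet below is a minimal sketch (not a script shipped with this repo) that reuses the loading pattern from `models/reenact_model.py`; the checkpoint path is only an example.

```python
import numpy as np
import torch

import dnnlib   # run from the repo root so dnnlib/ and legacy.py are importable
import legacy

device = torch.device('cuda')
with dnnlib.util.open_url('pretrained_models/network-snapshot-005000.pkl') as f:
    G = legacy.load_network_pkl(f)['G_ema'].eval().to(device)  # finetuned StyleGAN3 generator

z = torch.from_numpy(np.random.RandomState(0).randn(1, G.z_dim)).to(device)  # random latent
w = G.mapping(z, None)                    # map z to w+ (no class label)
img = G.synthesis(w, noise_mode='const')  # NCHW image in [-1, 1]
print(img.shape)
```

The two pre-stage training commands are then: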
94 | 95 | ```bash 96 | python train.py --dataroot ../data/actor_align_512_png --name stylepose \ 97 | --network_pkl=pretrained_models/network-snapshot-005000.pkl \ 98 | --model stylepose --n_epochs 5 --n_epochs_decay 5 99 | python train.py --dataroot ../data/actor_align_512_png --name stylefacev_pre \ 100 | --network_pkl=pretrained_models/network-snapshot-005000.pkl \ 101 | --model stylepre --pose_path checkpoints/stylepose/latest_net_FE.pth 102 | ``` 103 | 104 | You can also use `pre_net.pth` and `pre_pose_net.pth` from the `pretrained_models` folder. 105 | 106 | ```bash 107 | python train.py --dataroot ../data/actor_align_512_png --name stylefacev_pre \ 108 | --network_pkl=pretrained_models/network-snapshot-005000.pkl --model stylepre \ 109 | --pre_path pretrained_models/pre_net.pth --pose_path pretrained_models/pre_pose_net.pth 110 | ``` 111 | 112 | ### Decomposing and Recomposing Pipeline 113 | 114 | ```bash 115 | python train.py --dataroot ../data/actor_align_512_png --name stylefacev \ 116 | --network_pkl=pretrained_models/network-snapshot-005000.pkl --model stylefacevadv \ 117 | --pose_path pretrained_models/pre_pose_net.pth \ 118 | --pre_path checkpoints/stylefacev_pre/latest_net_FE.pth \ 119 | --n_epochs 50 --n_epochs_decay 50 --lr 0.0002 120 | ``` 121 | 122 | ### Motion Sampler 123 | 124 | ```bash 125 | python train.py --dataroot ../data/actor_align_512_png --name motion \ 126 | --network_pkl=pretrained_models/network-snapshot-005000.pkl --model stylernn \ 127 | --pre_path checkpoints/stylefacev/latest_net_FE.pth \ 128 | --pose_path checkpoints/stylefacev/latest_net_FE_pose.pth \ 129 | --lm_path checkpoints/stylefacev/latest_net_FE_lm.pth \ 130 | --n_frames_G 30 131 | ``` 132 | 133 | If you do not have a 32GB GPU, reduce `n_frames_G` (e.g., 12 for 16GB), or add supervision only on the pose representations: 134 | 135 | ```bash 136 | python train.py --dataroot ../data/actor_align_512_png --name motion \ 137 | --network_pkl=pretrained_models/network-snapshot-005000.pkl --model stylernns \ 138 | --pose_path checkpoints/stylefacev/latest_net_FE_pose.pth \ 139 | --lm_path checkpoints/stylefacev/latest_net_FE_lm.pth \ 140 | --n_frames_G 30 141 | ``` 142 | 143 | ## Citation 144 | 145 | If you find this work useful for your research, please consider citing our paper: 146 | 147 | ```bibtex 148 | @misc{https://doi.org/10.48550/arxiv.2208.07862, 149 | doi = {10.48550/ARXIV.2208.07862}, 150 | url = {https://arxiv.org/abs/2208.07862}, 151 | author = {Qiu, Haonan and Jiang, Yuming and Zhou, Hang and Wu, Wayne and Liu, Ziwei}, 152 | keywords = {Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences}, 153 | title = {StyleFaceV: Face Video Generation via Decomposing and Recomposing Pretrained StyleGAN3}, 154 | publisher = {arXiv}, 155 | year = {2022}, 156 | copyright = {arXiv.org perpetual, non-exclusive license} 157 | } 158 | ``` 159 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing. 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from data.base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | self.dataloader = torch.utils.data.DataLoader( 76 | self.dataset, 77 | batch_size=opt.batch_size, 78 | shuffle=not opt.serial_batches, 79 | num_workers=int(opt.num_threads)) 80 | 81 | def load_data(self): 82 | return self 83 | 84 | def __len__(self): 85 | """Return the number of data in the dataset""" 86 | return min(len(self.dataset), self.opt.max_dataset_size) 87 | 88 | def __iter__(self): 89 | """Return a batch of data""" 90 | for i, data in enumerate(self.dataloader): 91 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 92 | break 93 | yield data 94 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
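A typical use inside a dataset subclass looks like the following sketch (here `opt` must carry the usual preprocess/load_size/crop_size/no_flip flags and `img` is a PIL image; see data/jointset_dataset.py for a real example):

    >>> from data.base_dataset import get_params, get_transform
    >>> params = get_params(opt, img.size)
    >>> transform = get_transform(opt, params, grayscale=False)
    >>> tensor = transform(img)  # 3 x H x W tensor scaled to [-1, 1]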
4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor()] 108 | if grayscale: 109 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 110 | else: 111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 112 | return transforms.Compose(transform_list) 113 | 114 | 115 | def __make_power_2(img, base, method=Image.BICUBIC): 116 | ow, oh = img.size 117 | h = int(round(oh / base) * base) 118 | w = int(round(ow / base) * base) 119 | if h == oh and w == ow: 120 | return img 121 | 122 | __print_size_warning(ow, oh, w, h) 123 | return img.resize((w, h), method) 124 | 125 | 126 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 127 | ow, oh = img.size 128 | if ow == target_size and oh >= crop_size: 129 | return img 130 | w = target_size 131 | h = int(max(target_size * oh / ow, crop_size)) 132 | return img.resize((w, h), method) 133 | 134 | 135 | def __crop(img, pos, size): 136 | ow, oh = img.size 137 | x1, y1 = pos 138 | tw = th = size 139 | if (ow > tw or oh > th): 140 | return img.crop((x1, y1, x1 + tw, y1 + th)) 141 | return img 142 | 143 | 144 | def __flip(img, flip): 145 | if flip: 146 | return img.transpose(Image.FLIP_LEFT_RIGHT) 147 | return img 148 | 149 | 150 | def __print_size_warning(ow, oh, w, h): 151 | """Print warning information about image size(only print once)""" 152 | if not hasattr(__print_size_warning, 'has_printed'): 153 | print("The image size needs to be a multiple of 4. " 154 | "The loaded image size was (%d, %d), so it was adjusted to " 155 | "(%d, %d). 
This adjustment will be done to all images " 156 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 157 | __print_size_warning.has_printed = True 158 | -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | 12 | IMG_EXTENSIONS = [ 13 | '.jpg', '.JPG', '.jpeg', '.JPEG', 14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 15 | '.tif', '.TIF', '.tiff', '.TIFF', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_id_dataset(dir, max_dataset_size=float("inf")): 24 | ids = [] 25 | images = [] 26 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 27 | 28 | id_names = sorted(os.listdir(dir)) 29 | for id_name in id_names: 30 | id_path = os.path.join(dir, id_name) 31 | fnames = os.listdir(id_path) 32 | for fname in fnames: 33 | path = os.path.join(dir, id_name, fname) 34 | images.append(path) 35 | ids.append(id_name) 36 | return images[:min(max_dataset_size, len(images))], ids[:min(max_dataset_size, len(ids))] 37 | 38 | def make_noid_dataset(dir, max_dataset_size=float("inf")): 39 | images = [] 40 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 41 | 42 | fnames = sorted(os.listdir(dir)) 43 | for fname in fnames: 44 | path = os.path.join(dir, fname) 45 | images.append(path) 46 | return images[:min(max_dataset_size, len(images))] 47 | 48 | def make_dataset(dir, max_dataset_size=float("inf")): 49 | images = [] 50 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 51 | 52 | for root, _, fnames in sorted(os.walk(dir)): 53 | for fname in fnames: 54 | if is_image_file(fname): 55 | path = os.path.join(root, fname) 56 | images.append(path) 57 | return images[:min(max_dataset_size, len(images))] 58 | 59 | 60 | def default_loader(path): 61 | return Image.open(path).convert('RGB') 62 | 63 | 64 | class ImageFolder(data.Dataset): 65 | 66 | def __init__(self, root, transform=None, return_paths=False, 67 | loader=default_loader): 68 | imgs = make_dataset(root) 69 | if len(imgs) == 0: 70 | raise(RuntimeError("Found 0 images in: " + root + "\n" 71 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 72 | 73 | self.root = root 74 | self.imgs = imgs 75 | self.transform = transform 76 | self.return_paths = return_paths 77 | self.loader = loader 78 | 79 | def __getitem__(self, index): 80 | path = self.imgs[index] 81 | img = self.loader(path) 82 | if self.transform is not None: 83 | img = self.transform(img) 84 | if self.return_paths: 85 | return img, path 86 | else: 87 | return img 88 | 89 | def __len__(self): 90 | return len(self.imgs) 91 | -------------------------------------------------------------------------------- /data/jointset_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_params, get_transform 2 | from data.image_folder import make_id_dataset, make_dataset 3 | from PIL import Image 4 | import random 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class 
JointSetDataset(BaseDataset): 11 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 12 | 13 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.A_paths, self.A_ids = make_id_dataset(opt.dataroot, opt.max_dataset_size) 24 | self.B_paths = make_dataset(opt.dataroot2, opt.max_dataset_size) 25 | 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | 28 | def __getitem__(self, index): 29 | """Return a data point and its metadata information. 30 | 31 | Parameters: 32 | index - - a random integer for data indexing 33 | 34 | Returns a dictionary that contains A and A_paths 35 | A(tensor) - - an image in one domain 36 | A_paths(str) - - the path of the image 37 | """ 38 | # A_id = self.A_ids[index] 39 | B_path = self.B_paths[index] 40 | B_img = Image.open(B_path).convert('RGB') 41 | 42 | transform_params = get_params(self.opt, B_img.size) 43 | self.transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 44 | 45 | B = self.transform(B_img) 46 | 47 | A_index = index % len(self.A_paths) 48 | A_video = self.A_paths[A_index] 49 | A_frames = os.listdir(A_video) 50 | A_frame = random.sample(A_frames, 1)[0] 51 | A_path = os.path.join(A_video, A_frame) 52 | A_img = Image.open(A_path).convert('RGB') 53 | A = self.transform(A_img) 54 | 55 | return {'A': A, 'A_paths': A_path, 'B': B, 'B_paths': B_path} 56 | 57 | def __len__(self): 58 | """Return the total number of images in the dataset.""" 59 | return len(self.B_paths) 60 | -------------------------------------------------------------------------------- /data/noise_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform 2 | from data.image_folder import make_dataset 3 | import numpy as np 4 | import torch 5 | 6 | 7 | class NoiseDataset(BaseDataset): 8 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 9 | 10 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 11 | """ 12 | 13 | def __init__(self, opt): 14 | """Initialize this dataset class. 15 | 16 | Parameters: 17 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 18 | """ 19 | BaseDataset.__init__(self, opt) 20 | self.A_paths = sorted(make_dataset(opt.dataroot, opt.max_dataset_size)) 21 | # input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 22 | # self.transform = get_transform(opt, grayscale=(input_nc == 1)) 23 | 24 | def __getitem__(self, index): 25 | """Return a data point and its metadata information. 
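(Note for this dataset: 'A' is not an image but a 512-dim StyleGAN latent drawn deterministically from the index, i.e. torch.from_numpy(np.random.RandomState(index).randn(512)), so the same index always yields the same sample.)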
26 | 27 | Parameters: 28 | index - - a random integer for data indexing 29 | 30 | Returns a dictionary that contains A and A_paths 31 | A(tensor) - - an image in one domain 32 | A_paths(str) - - the path of the image 33 | """ 34 | A_path = self.A_paths[index] 35 | A = torch.from_numpy(np.random.RandomState(index).randn(512)) 36 | return {'A': A, 'A_paths': A_path} 37 | 38 | def __len__(self): 39 | """Return the total number of images in the dataset.""" 40 | return len(self.A_paths) 41 | -------------------------------------------------------------------------------- /data/noiseframe_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_params, get_transform 2 | from data.image_folder import make_id_dataset, make_dataset, make_noid_dataset 3 | from PIL import Image 4 | import random 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class NoiseFrameDataset(BaseDataset): 11 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 12 | 13 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | if 'FaceForensicspp' in opt.dataroot: 24 | self.A_paths = make_noid_dataset(opt.dataroot, opt.max_dataset_size) 25 | else: 26 | self.A_paths, self.A_ids = make_id_dataset(opt.dataroot, opt.max_dataset_size) 27 | 28 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 29 | self.nun_frames = self.opt.num_frame 30 | self.seq_frames = self.opt.seq_frame 31 | self.max_gap = self.opt.max_gap 32 | self.max_dataset_size = self.opt.max_dataset_size 33 | 34 | def __getitem__(self, index): 35 | """Return a data point and its metadata information. 
36 | 37 | Parameters: 38 | index - - a random integer for data indexing 39 | 40 | Returns a dictionary that contains A and A_paths 41 | A(tensor) - - an image in one domain 42 | A_paths(str) - - the path of the image 43 | """ 44 | 45 | A_list = [] 46 | A_index = index % len(self.A_paths) 47 | A_video = self.A_paths[A_index] 48 | A_frames = sorted(os.listdir(A_video)) 49 | max_frames = len(A_frames) 50 | while max_frames < 60: 51 | A_index = (A_index + 1) % len(self.A_paths) 52 | A_video = self.A_paths[A_index] 53 | A_frames = sorted(os.listdir(A_video)) 54 | max_frames = len(A_frames) 55 | 56 | first_index = random.randint(0, max_frames - 1 - (self.nun_frames -1) * self.seq_frames) 57 | last_index = first_index + (self.nun_frames -1) * self.seq_frames 58 | gap_index = random.randint(5, self.max_gap) 59 | if first_index > 14: 60 | app_index = first_index - gap_index 61 | else: 62 | app_index = last_index + gap_index 63 | 64 | for i in range(first_index, first_index + self.nun_frames * self.seq_frames, self.seq_frames): 65 | A_frame = A_frames[i] 66 | # print(A_frame) 67 | A_path = os.path.join(A_video, A_frame) 68 | A_img = Image.open(A_path).convert('RGB') 69 | 70 | if i == first_index: 71 | transform_params = get_params(self.opt, A_img.size) 72 | self.transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 73 | 74 | A = self.transform(A_img) 75 | A_list.append(A.unsqueeze(0)) 76 | 77 | As = torch.cat(A_list, 0) 78 | 79 | A_frame = A_frames[app_index] 80 | # print(A_frame) 81 | A_path = os.path.join(A_video, A_frame) 82 | A_img = Image.open(A_path).convert('RGB') 83 | 84 | A = self.transform(A_img) 85 | 86 | randindex = random.randint(0, self.max_dataset_size - 1) 87 | B = torch.from_numpy(np.random.RandomState(randindex).randn(512)) 88 | 89 | return {'A': A, 'B': B, 'As': As, 'A_paths': A_path} 90 | 91 | def __len__(self): 92 | """Return the total number of images in the dataset.""" 93 | return len(self.A_paths) 94 | -------------------------------------------------------------------------------- /data/noiseshufflevideo_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform, get_params 2 | from data.image_folder import make_id_dataset 3 | from PIL import Image 4 | import random 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class NoiseShuffleVideoDataset(BaseDataset): 11 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 12 | 13 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.opt = opt 24 | self.A_paths, self.A_ids = make_id_dataset(opt.dataroot, opt.max_dataset_size) 25 | 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | 28 | def __getitem__(self, index): 29 | """Return a data point and its metadata information. 
30 | 31 | Parameters: 32 | index - - a random integer for data indexing 33 | 34 | Returns a dictionary that contains A and A_paths 35 | A(tensor) - - an image in one domain 36 | A_paths(str) - - the path of the image 37 | """ 38 | # A_id = self.A_ids[index] 39 | A_list = [] 40 | random.seed(index) 41 | A_index = int(random.random() * (len(self.A_paths) - 1)) 42 | A_video = self.A_paths[A_index] 43 | A_frames = sorted(os.listdir(A_video)) 44 | max_frames = len(A_frames) 45 | while max_frames < 60: 46 | A_index = (A_index + 1) % len(self.A_paths) 47 | A_video = self.A_paths[A_index] 48 | A_frames = sorted(os.listdir(A_video)) 49 | max_frames = len(A_frames) 50 | 51 | for i in range(max_frames): 52 | A_frame = A_frames[i] 53 | # print(A_frame) 54 | A_path = os.path.join(A_video, A_frame) 55 | A_img = Image.open(A_path).convert('RGB') 56 | 57 | if i == 0: 58 | transform_params = get_params(self.opt, A_img.size) 59 | self.transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 60 | 61 | A = self.transform(A_img) 62 | A_list.append(A.unsqueeze(0)) 63 | 64 | A = torch.cat(A_list, 0) 65 | B = torch.from_numpy(np.random.RandomState(index).randn(512)) 66 | 67 | return {'A': A, 'A_paths': A_path, 'B': B} 68 | 69 | def __len__(self): 70 | """Return the total number of images in the dataset.""" 71 | return len(self.A_paths) 72 | -------------------------------------------------------------------------------- /data/randomvideo_dataset.py: -------------------------------------------------------------------------------- 1 | from data.base_dataset import BaseDataset, get_transform, get_params 2 | from data.image_folder import make_id_dataset 3 | from PIL import Image 4 | import random 5 | import os 6 | import numpy as np 7 | import torch 8 | 9 | 10 | class RandomVideoDataset(BaseDataset): 11 | """This dataset class can load a set of images specified by the path --dataroot /path/to/data. 12 | 13 | It can be used for generating CycleGAN results only for one side with the model option '-model test'. 14 | """ 15 | 16 | def __init__(self, opt): 17 | """Initialize this dataset class. 18 | 19 | Parameters: 20 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 21 | """ 22 | BaseDataset.__init__(self, opt) 23 | self.opt = opt 24 | self.A_paths, self.A_ids = make_id_dataset(opt.dataroot, opt.max_dataset_size) 25 | 26 | self.input_nc = self.opt.output_nc if self.opt.direction == 'BtoA' else self.opt.input_nc 27 | 28 | def __getitem__(self, index): 29 | """Return a data point and its metadata information. 
30 | 31 | Parameters: 32 | index - - a random integer for data indexing 33 | 34 | Returns a dictionary that contains A and A_paths 35 | A(tensor) - - an image in one domain 36 | A_paths(str) - - the path of the image 37 | """ 38 | # A_id = self.A_ids[index] 39 | A_list = [] 40 | A_video = self.A_paths[index] 41 | A_frames = sorted(os.listdir(A_video)) 42 | max_frames = len(A_frames) 43 | while max_frames < 60: 44 | A_index = (index + 1) % len(self.A_paths) 45 | A_video = self.A_paths[A_index] 46 | A_frames = sorted(os.listdir(A_video)) 47 | max_frames = len(A_frames) 48 | 49 | for i in range(max_frames): 50 | A_frame = A_frames[i] 51 | # print(A_frame) 52 | A_path = os.path.join(A_video, A_frame) 53 | A_img = Image.open(A_path).convert('RGB') 54 | 55 | if i == 0: 56 | transform_params = get_params(self.opt, A_img.size) 57 | self.transform = get_transform(self.opt, transform_params, grayscale=(self.input_nc == 1)) 58 | 59 | A = self.transform(A_img) 60 | A_list.append(A.unsqueeze(0)) 61 | 62 | A = torch.cat(A_list, 0) 63 | 64 | return {'A': A, 'A_paths': A_path} 65 | 66 | def __len__(self): 67 | """Return the total number of images in the dataset.""" 68 | return len(self.A_paths) 69 | -------------------------------------------------------------------------------- /dnnlib/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
8 | 9 | from .util import EasyDict, make_cache_dir_path 10 | -------------------------------------------------------------------------------- /docs/results1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arthur-qiu/StyleFaceV/b4b74d222cbe30b1924477a612ed3e78029b2c85/docs/results1.gif -------------------------------------------------------------------------------- /docs/results2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arthur-qiu/StyleFaceV/b4b74d222cbe30b1924477a612ed3e78029b2c85/docs/results2.gif -------------------------------------------------------------------------------- /docs/teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/arthur-qiu/StyleFaceV/b4b74d222cbe30b1924477a612ed3e78029b2c85/docs/teaser.png -------------------------------------------------------------------------------- /environment/stylefacev.yaml: -------------------------------------------------------------------------------- 1 | name: stylefacev 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=4.5=1_gnu 10 | - _tflow_select=2.3.0=mkl 11 | - aiohttp=3.8.1=py38h7f8727e_0 12 | - aiosignal=1.2.0=pyhd3eb1b0_0 13 | - astor=0.8.1=py38h06a4308_0 14 | - astunparse=1.6.3=py_0 15 | - async-timeout=4.0.1=pyhd3eb1b0_0 16 | - attrs=21.2.0=pyhd3eb1b0_0 17 | - av=8.0.3=py38h2c5b837_0 18 | - blas=1.0=mkl 19 | - blinker=1.4=py38h06a4308_0 20 | - brotli=1.0.9=he6710b0_2 21 | - brotlipy=0.7.0=py38h27cfd23_1003 22 | - bzip2=1.0.8=h7b6447c_0 23 | - c-ares=1.17.1=h27cfd23_0 24 | - ca-certificates=2022.3.29=h06a4308_0 25 | - certifi=2021.10.8=py38h06a4308_0 26 | - cffi=1.14.6=py38h400218f_0 27 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 28 | - click=8.0.3=pyhd3eb1b0_0 29 | - cryptography=35.0.0=py38hd23ed53_0 30 | - cudatoolkit=11.1.74=h6bb024c_0 31 | - cycler=0.10.0=py38_0 32 | - dataclasses=0.8=pyh6d0b6a4_7 33 | - dbus=1.13.18=hb2f20db_0 34 | - dill=0.3.4=pyhd3eb1b0_0 35 | - expat=2.4.1=h2531618_2 36 | - ffmpeg=4.3.2=hca11adc_0 37 | - fontconfig=2.13.1=h6c09931_0 38 | - fonttools=4.25.0=pyhd3eb1b0_0 39 | - freetype=2.11.0=h70c0345_0 40 | - frozenlist=1.2.0=py38h7f8727e_0 41 | - gast=0.4.0=pyhd3eb1b0_0 42 | - glib=2.69.1=h5202010_0 43 | - gmp=6.2.1=h2531618_2 44 | - gnutls=3.6.15=he1e5248_0 45 | - google-pasta=0.2.0=pyhd3eb1b0_0 46 | - grpcio=1.42.0=py38hce63b2e_0 47 | - gst-plugins-base=1.14.0=h8213a91_2 48 | - gstreamer=1.14.0=h28cd5cc_2 49 | - h5py=2.10.0=py38hd6299e0_1 50 | - hdf5=1.10.6=hb1b8bf9_0 51 | - icu=58.2=he6710b0_3 52 | - idna=3.2=pyhd3eb1b0_0 53 | - imageio=2.9.0=pyhd3eb1b0_0 54 | - importlib-metadata=4.8.2=py38h06a4308_0 55 | - intel-openmp=2021.4.0=h06a4308_3561 56 | - jpeg=9d=h7f8727e_0 57 | - keras-preprocessing=1.1.2=pyhd3eb1b0_0 58 | - kiwisolver=1.3.1=py38h2531618_0 59 | - lame=3.100=h7b6447c_0 60 | - lcms2=2.12=h3be6417_0 61 | - ld_impl_linux-64=2.35.1=h7274673_9 62 | - libffi=3.3=he6710b0_2 63 | - libgcc-ng=9.3.0=h5101ec6_17 64 | - libgfortran-ng=7.5.0=ha8ba4b0_17 65 | - libgfortran4=7.5.0=ha8ba4b0_17 66 | - libgomp=9.3.0=h5101ec6_17 67 | - libiconv=1.15=h63c8f33_5 68 | - libidn2=2.3.2=h7f8727e_0 69 | - libpng=1.6.37=hbc83047_0 70 | - libprotobuf=3.17.2=h4ff587b_1 71 | - libstdcxx-ng=9.3.0=hd4cf53a_17 72 | - libtasn1=4.16.0=h27cfd23_0 73 | - libtiff=4.2.0=h85742a9_0 74 | - libunistring=0.9.10=h27cfd23_0 
75 | - libuuid=1.0.3=h7f8727e_2 76 | - libuv=1.40.0=h7b6447c_0 77 | - libwebp-base=1.2.0=h27cfd23_0 78 | - libxcb=1.14=h7b6447c_0 79 | - libxml2=2.9.12=h03d6c58_0 80 | - lz4-c=1.9.3=h295c915_1 81 | - matplotlib-base=3.4.2=py38hab158f2_0 82 | - mkl=2021.4.0=h06a4308_640 83 | - mkl-service=2.4.0=py38h7f8727e_0 84 | - mkl_fft=1.3.1=py38hd3c417c_0 85 | - mkl_random=1.2.2=py38h51133e4_0 86 | - multidict=5.1.0=py38h27cfd23_2 87 | - munch=2.5.0=pyhd3eb1b0_0 88 | - munkres=1.1.4=py_0 89 | - ncurses=6.3=h7f8727e_2 90 | - nettle=3.7.3=hbbd107a_1 91 | - ninja=1.10.2=hff7bd54_1 92 | - numpy=1.21.2=py38h20f2e39_0 93 | - numpy-base=1.21.2=py38h79a1101_0 94 | - olefile=0.46=pyhd3eb1b0_0 95 | - openh264=2.1.1=h780b84a_0 96 | - openjpeg=2.4.0=h3ad879b_0 97 | - openssl=1.1.1l=h7f8727e_0 98 | - opt_einsum=3.3.0=pyhd3eb1b0_1 99 | - pcre=8.45=h295c915_0 100 | - pillow=8.3.1=py38h2c7a002_0 101 | - pip=21.2.4=py38h06a4308_0 102 | - psutil=5.8.0=py38h27cfd23_1 103 | - pyasn1=0.4.8=pyhd3eb1b0_0 104 | - pycparser=2.21=pyhd3eb1b0_0 105 | - pyjwt=1.7.1=py38_0 106 | - pyopenssl=21.0.0=pyhd3eb1b0_1 107 | - pyparsing=3.0.4=pyhd3eb1b0_0 108 | - pyqt=5.9.2=py38h05f1152_4 109 | - pysocks=1.7.1=py38h06a4308_0 110 | - python=3.8.12=h12debd9_0 111 | - python-dateutil=2.8.2=pyhd3eb1b0_0 112 | - python-flatbuffers=2.0=pyhd3eb1b0_0 113 | - python_abi=3.8=2_cp38 114 | - pytorch=1.10.0=py3.8_cuda11.1_cudnn8.0.5_0 115 | - pytorch-mutex=1.0=cuda 116 | - qt=5.9.7=h5867ecd_1 117 | - readline=8.1=h27cfd23_0 118 | - requests=2.26.0=pyhd3eb1b0_0 119 | - requests-oauthlib=1.3.0=py_0 120 | - setuptools=58.0.4=py38h06a4308_0 121 | - sip=4.19.13=py38he6710b0_0 122 | - six=1.16.0=pyhd3eb1b0_0 123 | - sqlite=3.36.0=hc218d9a_0 124 | - termcolor=1.1.0=py38h06a4308_1 125 | - tk=8.6.11=h1ccaba5_0 126 | - torchaudio=0.10.0=py38_cu111 127 | - torchvision=0.11.1=py38_cu111 128 | - tornado=6.1=py38h27cfd23_0 129 | - tqdm=4.62.2=pyhd3eb1b0_1 130 | - typing-extensions=3.10.0.2=hd3eb1b0_0 131 | - typing_extensions=3.10.0.2=pyh06a4308_0 132 | - urllib3=1.26.7=pyhd3eb1b0_0 133 | - werkzeug=2.0.2=pyhd3eb1b0_0 134 | - wheel=0.37.0=pyhd3eb1b0_1 135 | - wrapt=1.13.3=py38h7f8727e_2 136 | - x264=1!161.3030=h7f98852_1 137 | - xz=5.2.5=h7b6447c_0 138 | - yaml=0.2.5=h7b6447c_0 139 | - yarl=1.6.3=py38h27cfd23_0 140 | - zipp=3.6.0=pyhd3eb1b0_0 141 | - zlib=1.2.11=h7b6447c_3 142 | - zstd=1.4.9=haebb681_0 143 | - pip: 144 | - absl-py==1.0.0 145 | - cachetools==4.2.4 146 | - dominate==2.6.0 147 | - easydict==1.9 148 | - ffmpeg-python==0.2.0 149 | - future==0.18.2 150 | - glfw==2.2.0 151 | - google-auth==2.3.3 152 | - google-auth-oauthlib==0.4.6 153 | - imageio-ffmpeg==0.4.3 154 | - imgui==1.3.0 155 | - joblib==1.1.0 156 | - jsonpatch==1.32 157 | - jsonpointer==2.2 158 | - lpips==0.1.4 159 | - markdown==3.3.6 160 | - matplotlib==3.5.0 161 | - networkx==2.6.3 162 | - oauthlib==3.1.1 163 | - opencv-python==4.5.4.58 164 | - packaging==21.3 165 | - protobuf==3.19.1 166 | - pyasn1-modules==0.2.8 167 | - pyopengl==3.1.5 168 | - pyspng==0.1.0 169 | - pytube==11.0.2 170 | - pywavelets==1.2.0 171 | - pyyaml==6.0 172 | - pyzmq==22.3.0 173 | - rsa==4.8 174 | - scikit-image==0.18.3 175 | - scikit-learn==1.0.1 176 | - scipy==1.7.3 177 | - setuptools-scm==6.3.2 178 | - tensorboard==2.7.0 179 | - tensorboard-data-server==0.6.1 180 | - tensorboard-plugin-wit==1.8.0 181 | - threadpoolctl==3.0.0 182 | - tifffile==2021.11.2 183 | - tomli==1.2.2 184 | - torch-fidelity==0.3.0 185 | - torchfile==0.1.0 186 | - visdom==0.1.8.9 187 | - websocket-client==1.2.3 188 | 
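The pinned environment targets PyTorch 1.10 with CUDA 11.1 (plus torchvision 0.11). After `conda activate stylefacev`, a quick sanity check is the minimal sketch below (not part of the repo):

```python
import torch
import torchvision

print(torch.__version__)          # expected 1.10.x per environment/stylefacev.yaml
print(torchvision.__version__)    # expected 0.11.x
print(torch.cuda.is_available())  # should be True on a CUDA 11.1-capable GPU
```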
-------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from models.base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | print(model_filename) 34 | modellib = importlib.import_module(model_filename) 35 | model = None 36 | target_model_name = model_name.replace('_', '') + 'model' 37 | for name, cls in modellib.__dict__.items(): 38 | if name.lower() == target_model_name.lower() \ 39 | and issubclass(cls, BaseModel): 40 | model = cls 41 | 42 | if model is None: 43 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 44 | exit(0) 45 | 46 | return model 47 | 48 | 49 | def get_option_setter(model_name): 50 | """Return the static method of the model class.""" 51 | model_class = find_model_using_name(model_name) 52 | return model_class.modify_commandline_options 53 | 54 | 55 | def create_model(opt): 56 | """Create a model given the option. 57 | 58 | This function warps the class CustomDatasetDataLoader. 59 | This is the main interface between this package and 'train.py'/'test.py' 60 | 61 | Example: 62 | >>> from models import create_model 63 | >>> model = create_model(opt) 64 | """ 65 | model = find_model_using_name(opt.model) 66 | instance = model(opt) 67 | print("model [%s] was created" % type(instance).__name__) 68 | return instance 69 | -------------------------------------------------------------------------------- /models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | from . 
import networks 6 | 7 | 8 | class BaseModel(ABC): 9 | """This class is an abstract base class (ABC) for models. 10 | To create a subclass, you need to implement the following five functions: 11 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 12 | -- : unpack data from dataset and apply preprocessing. 13 | -- : produce intermediate results. 14 | -- : calculate losses, gradients, and update network weights. 15 | -- : (optionally) add model-specific options and set default options. 16 | """ 17 | 18 | def __init__(self, opt): 19 | """Initialize the BaseModel class. 20 | 21 | Parameters: 22 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 23 | 24 | When creating your custom class, you need to implement your own initialization. 25 | In this function, you should first call 26 | Then, you need to define four lists: 27 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 28 | -- self.model_names (str list): define networks used in our training. 29 | -- self.visual_names (str list): specify the images that you want to display and save. 30 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 31 | """ 32 | self.opt = opt 33 | self.gpu_ids = opt.gpu_ids 34 | self.isTrain = opt.isTrain 35 | self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU 36 | self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir 37 | if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark. 38 | torch.backends.cudnn.benchmark = True 39 | self.loss_names = [] 40 | self.model_names = [] 41 | self.visual_names = [] 42 | self.optimizers = [] 43 | self.image_paths = [] 44 | self.metric = 0 # used for learning rate policy 'plateau' 45 | 46 | @staticmethod 47 | def modify_commandline_options(parser, is_train): 48 | """Add new model-specific options, and rewrite default values for existing options. 49 | 50 | Parameters: 51 | parser -- original option parser 52 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 53 | 54 | Returns: 55 | the modified parser. 56 | """ 57 | return parser 58 | 59 | @abstractmethod 60 | def set_input(self, input): 61 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 62 | 63 | Parameters: 64 | input (dict): includes the data itself and its metadata information. 
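A typical implementation just unpacks the dict and moves tensors to the right device, e.g. (a sketch based on the models in this repo):

    >>> self.real_A = input['A'].to(self.device)
    >>> self.image_paths = input['A_paths']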
65 | """ 66 | pass 67 | 68 | @abstractmethod 69 | def forward(self): 70 | """Run forward pass; called by both functions and .""" 71 | pass 72 | 73 | @abstractmethod 74 | def optimize_parameters(self): 75 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 76 | pass 77 | 78 | def setup(self, opt): 79 | """Load and print networks; create schedulers 80 | 81 | Parameters: 82 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 83 | """ 84 | if self.isTrain: 85 | self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers] 86 | if not self.isTrain or opt.continue_train: 87 | load_suffix = 'iter_%d' % opt.load_iter if opt.load_iter > 0 else opt.epoch 88 | self.load_networks(load_suffix) 89 | self.print_networks(opt.verbose) 90 | 91 | def eval(self): 92 | """Make models eval mode during test time""" 93 | for name in self.model_names: 94 | if isinstance(name, str): 95 | net = getattr(self, 'net' + name) 96 | net.eval() 97 | 98 | def test(self): 99 | """Forward function used in test time. 100 | 101 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 102 | It also calls to produce additional visualization results 103 | """ 104 | with torch.no_grad(): 105 | self.forward() 106 | self.compute_visuals() 107 | 108 | def compute_visuals(self): 109 | """Calculate additional output images for visdom and HTML visualization""" 110 | pass 111 | 112 | def get_image_paths(self): 113 | """ Return image paths that are used to load current data""" 114 | return self.image_paths 115 | 116 | def update_learning_rate(self): 117 | """Update learning rates for all the networks; called at the end of every epoch""" 118 | old_lr = self.optimizers[0].param_groups[0]['lr'] 119 | for scheduler in self.schedulers: 120 | if self.opt.lr_policy == 'plateau': 121 | scheduler.step(self.metric) 122 | else: 123 | scheduler.step() 124 | 125 | lr = self.optimizers[0].param_groups[0]['lr'] 126 | print('learning rate %.7f -> %.7f' % (old_lr, lr)) 127 | 128 | def get_current_visuals(self): 129 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 130 | visual_ret = OrderedDict() 131 | for name in self.visual_names: 132 | if isinstance(name, str): 133 | visual_ret[name] = getattr(self, name) 134 | return visual_ret 135 | 136 | def get_current_losses(self): 137 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 138 | errors_ret = OrderedDict() 139 | for name in self.loss_names: 140 | if isinstance(name, str): 141 | errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number 142 | return errors_ret 143 | 144 | def save_networks(self, epoch): 145 | """Save all the networks to the disk. 
146 | 147 | Parameters: 148 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 149 | """ 150 | for name in self.model_names: 151 | if isinstance(name, str): 152 | save_filename = '%s_net_%s.pth' % (epoch, name) 153 | save_path = os.path.join(self.save_dir, save_filename) 154 | net = getattr(self, 'net' + name) 155 | 156 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 157 | if hasattr(net, 'module'): 158 | torch.save(net.module.cpu().state_dict(), save_path) 159 | net.cuda(self.gpu_ids[0]) 160 | else: 161 | torch.save(net.cpu().state_dict(), save_path) 162 | net.cuda(self.gpu_ids[0]) 163 | else: 164 | torch.save(net.cpu().state_dict(), save_path) 165 | 166 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 167 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 168 | key = keys[i] 169 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 170 | if module.__class__.__name__.startswith('InstanceNorm') and \ 171 | (key == 'running_mean' or key == 'running_var'): 172 | if getattr(module, key) is None: 173 | state_dict.pop('.'.join(keys)) 174 | if module.__class__.__name__.startswith('InstanceNorm') and \ 175 | (key == 'num_batches_tracked'): 176 | state_dict.pop('.'.join(keys)) 177 | else: 178 | self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1) 179 | 180 | def load_networks(self, epoch): 181 | """Load all the networks from the disk. 182 | 183 | Parameters: 184 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 185 | """ 186 | for name in self.model_names: 187 | if isinstance(name, str): 188 | load_filename = '%s_net_%s.pth' % (epoch, name) 189 | load_path = os.path.join(self.save_dir, load_filename) 190 | net = getattr(self, 'net' + name) 191 | if isinstance(net, torch.nn.DataParallel): 192 | net = net.module 193 | print('loading the model from %s' % load_path) 194 | # if you are using PyTorch newer than 0.4 (e.g., built from 195 | # GitHub source), you can remove str() on self.device 196 | state_dict = torch.load(load_path, map_location=str(self.device)) 197 | if hasattr(state_dict, '_metadata'): 198 | del state_dict._metadata 199 | 200 | # patch InstanceNorm checkpoints prior to 0.4 201 | for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop 202 | self.__patch_instance_norm_state_dict(state_dict, net, key.split('.')) 203 | net.load_state_dict(state_dict) 204 | 205 | def print_networks(self, verbose): 206 | """Print the total number of parameters in the network and (if verbose) network architecture 207 | 208 | Parameters: 209 | verbose (bool) -- if verbose: print the network architecture 210 | """ 211 | print('---------- Networks initialized -------------') 212 | for name in self.model_names: 213 | if isinstance(name, str): 214 | net = getattr(self, 'net' + name) 215 | num_params = 0 216 | for param in net.parameters(): 217 | num_params += param.numel() 218 | if verbose: 219 | print(net) 220 | print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6)) 221 | print('-----------------------------------------------') 222 | 223 | def set_requires_grad(self, nets, requires_grad=False): 224 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 225 | Parameters: 226 | nets (network list) -- a list of networks 227 | requires_grad (bool) -- whether the networks require gradients or not 228 | """ 229 | if not isinstance(nets, list): 230 | nets = 
[nets] 231 | for net in nets: 232 | if net is not None: 233 | for param in net.parameters(): 234 | param.requires_grad = requires_grad 235 | -------------------------------------------------------------------------------- /models/reenact_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | from . import lmcode_networks 5 | from . import diy_networks 6 | from . import resnet 7 | 8 | import dnnlib 9 | import legacy 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import random 13 | import os 14 | 15 | from . import rnn_net 16 | 17 | def make_transform(translate, angle): 18 | m = np.eye(3) 19 | s = np.sin(angle/360.0*np.pi*2) 20 | c = np.cos(angle/360.0*np.pi*2) 21 | m[0][0] = c 22 | m[0][1] = s 23 | m[0][2] = translate[0] 24 | m[1][0] = -s 25 | m[1][1] = c 26 | m[1][2] = translate[1] 27 | return m 28 | 29 | class ReenactModel(BaseModel): 30 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 31 | 32 | The model training requires '--dataset_mode aligned' dataset. 33 | By default, it uses a '--netG unet256' U-Net generator, 34 | a '--netD basic' discriminator (PatchGAN), 35 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 36 | 37 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 38 | """ 39 | @staticmethod 40 | def modify_commandline_options(parser, is_train=True): 41 | """Add new dataset-specific options, and rewrite default values for existing options. 42 | 43 | Parameters: 44 | parser -- original option parser 45 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 46 | 47 | Returns: 48 | the modified parser. 49 | 50 | For pix2pix, we do not use image buffer 51 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 52 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 53 | """ 54 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 55 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='noiseshufflevideo', num_test = 32) 56 | parser.add_argument('--pose_path', type=str, default='', help='path for pose net') 57 | parser.add_argument('--rnn_path', type=str, default='', help='path for rnn net') 58 | parser.add_argument('--n_frames_G', type=int, default=60) 59 | parser.add_argument('--w_residual', type=float, default=0.2) 60 | parser.add_argument('--num_point', type=int, default=14) 61 | parser.add_argument('--model_names', type=str, default='') 62 | if is_train: 63 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 64 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss') 65 | 66 | return parser 67 | 68 | def __init__(self, opt): 69 | """Initialize the pix2pix class. 70 | 71 | Parameters: 72 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 73 | """ 74 | BaseModel.__init__(self, opt) 75 | # specify the training losses you want to print out. The training/test scripts will call 76 | self.loss_names = ['G_L1', 'G_VGG', 'G_W'] 77 | # specify the images you want to save/display. 
The training/test scripts will call 78 | self.visual_names = ['real_B', 'fake_B', 'fake_BA', 'fake_BR', 'real_A', 'fake_A', 'fake_AB', 'fake_AR', 'real_vid_B', 'fake_vid_B', 'fake_vid_AB', 'fake_vid_AR', 'fake_vid_BR', 'fake_vid'] 79 | # specify the models you want to save to the disk. The training/test scripts will call and 80 | if self.isTrain: 81 | self.model_names = ['FE'] 82 | else: # during test time, only load G 83 | self.model_names = ['FE'] 84 | if opt.model_names != '': 85 | str_models = opt.model_names.split(',') 86 | self.model_names = [] 87 | for str_model in str_models: 88 | self.model_names.append(str_model) 89 | # define networks (both generator and discriminator) 90 | with dnnlib.util.open_url(opt.network_pkl) as f: 91 | self.netG = legacy.load_network_pkl(f)['G_ema'].eval().to(self.gpu_ids[0]) # type: ignore 92 | 93 | lm_path = 'pretrained_models/wing.ckpt' 94 | self.netFE_lm = lmcode_networks.FAN(fname_pretrained=lm_path).eval().to(self.gpu_ids[0]) 95 | self.netFE_pose = diy_networks._resposenet(num_point=opt.num_point).eval().to(self.gpu_ids[0]) 96 | if opt.pose_path != '': 97 | self.netFE_pose.load_state_dict(torch.load(opt.pose_path)) 98 | 99 | self.netFE = resnet.wide_resdisnet50_2(num_classes=512 * 16).to(self.gpu_ids[0]) 100 | self.netFE = networks.init_net(self.netFE, opt.init_type, opt.init_gain, self.gpu_ids) 101 | 102 | self.netR = rnn_net.RNNModule(w_residual = opt.w_residual).to(self.gpu_ids[0]) 103 | if opt.rnn_path != '': 104 | self.netR.load_state_dict(torch.load(opt.rnn_path)) 105 | self.n_frames_G = opt.n_frames_G 106 | self.style_gan_size = 8 107 | 108 | self.m_zero = make_transform((0.0,0.0),(0.0)) 109 | self.count = 0 110 | 111 | 112 | def set_input(self, input): 113 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 114 | 115 | Parameters: 116 | input (dict): include the data itself and its metadata information. 117 | 118 | The option 'direction' can be used to swap images in domain A and domain B. 
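For reference, a minimal hedged sketch of the generator-loading pattern used in __init__ above; the pickle path is a placeholder, and the (16, 512) W+ shape is taken from the .view(-1, 16, 512) calls in this file.

    import dnnlib
    import legacy

    def load_generator(network_pkl, device='cuda'):
        # Same pattern as __init__: take the EMA generator out of a StyleGAN3 pickle.
        with dnnlib.util.open_url(network_pkl) as f:
            return legacy.load_network_pkl(f)['G_ema'].eval().to(device)

    # G = load_generator('pretrained_models/stylegan3.pkl')   # placeholder path
    # z = torch.randn(1, G.z_dim, device='cuda')
    # w = G.mapping(z, None)                                  # W+ codes, viewed as (-1, 16, 512) in this repo
    # img = G.synthesis(w, noise_mode='const').clamp(-1, 1)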
119 | """ 120 | self.real_Bs = input['A'].to(self.device) 121 | self.image_paths = input['A_paths'] 122 | self.count += 1 123 | self.image_paths[0] = os.path.split(self.image_paths[0])[0] + '/' + str(self.count) + '.png' 124 | 125 | real_v_list = [] 126 | with torch.no_grad(): 127 | for i in range(self.real_Bs.shape[1]): 128 | real_v_list.append(self.netFE_pose(self.netFE_lm.get_heatmap(self.real_Bs[:,i,...], b_preprocess=False), mode = 1).unsqueeze(1)) 129 | 130 | self.real_v = torch.cat(real_v_list, 1).detach() 131 | 132 | self.real_z = input['B'].to(self.device) 133 | 134 | def forward(self): 135 | """Run forward pass; called by both functions and .""" 136 | 137 | x_fake, self.rand_in, self.rand_rec = self.netR(self.real_v[:, 0].view(self.opt.batch_size, self.style_gan_size * self.style_gan_size), self.n_frames_G) 138 | x_fake = x_fake.view(self.opt.batch_size, self.n_frames_G, 1, self.style_gan_size, 139 | self.style_gan_size) 140 | 141 | if self.n_frames_G == 30: 142 | self.real_R_pose = torch.cat([x_fake, x_fake], 1) 143 | else: 144 | self.real_R_pose = x_fake 145 | 146 | if hasattr(self.netG.synthesis, 'input'): 147 | self.netG.synthesis.input.transform.copy_(torch.from_numpy(self.m_zero)) 148 | 149 | self.real_A_w = self.netG.mapping(self.real_z, None) 150 | self.real_A = self.netG.synthesis(self.real_A_w, noise_mode='const').detach().clamp(-1, 1) 151 | if self.real_A.shape[2] != 256: 152 | self.real_A = F.interpolate(self.real_A, size=(256, 256), mode='area') 153 | self.real_A_heat = self.netFE_lm.get_heatmap(self.real_A, b_preprocess=False) 154 | self.real_A_pose = self.netFE_pose(self.real_A_heat, mode=1).detach() 155 | self.real_A_app = self.netFE(self.real_A, mode=1).detach() 156 | self.fake_A_w = self.netFE(self.real_A_app, self.real_A_pose, mode=2).view(-1, 16, 512) 157 | self.fake_A = self.netG.synthesis(self.fake_A_w, noise_mode='const') # G(A) 158 | 159 | self.real_B_list = [] 160 | self.fake_B_list = [] 161 | self.real_A_list = [] 162 | self.fake_A_list = [] 163 | self.fake_AB_list = [] 164 | self.fake_BA_list = [] 165 | self.fake_AR_list = [] 166 | self.fake_BR_list = [] 167 | # for i in range(self.real_Bs.shape[1]): 168 | self.real_B_app = self.netFE(self.real_Bs[:,0,...], mode=1) 169 | for i in range(60): 170 | self.real_B = self.real_Bs[:,i,...] 
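# Note on the recompositions below: fake_B keeps B's appearance and B's pose (self-reconstruction),
# fake_AB drives A's appearance with B's pose (reenactment), fake_BA drives B's appearance with
# A's pose, and fake_AR / fake_BR swap in the RNN-sampled pose sequence real_R_pose.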
171 | if self.real_B.shape[2] != 256: 172 | self.real_B = F.interpolate(self.real_B, size=(256, 256), mode='area') 173 | self.real_B_heat = self.netFE_lm.get_heatmap(self.real_B, b_preprocess=False) 174 | self.real_B_pose = self.netFE_pose(self.real_B_heat, mode=1).detach() 175 | 176 | self.fake_B_w = self.netFE(self.real_B_app, self.real_B_pose, mode=2).view(-1, 16, 512) 177 | self.fake_AB_w = self.netFE(self.real_A_app, self.real_B_pose, mode=2).view(-1, 16, 512) 178 | self.fake_BA_w = self.netFE(self.real_B_app, self.real_A_pose, mode=2).view(-1, 16, 512) 179 | 180 | self.fake_AR_w = self.netFE(self.real_A_app, self.real_R_pose[:,i,...], mode=2).view(-1, 16, 512) 181 | self.fake_BR_w = self.netFE(self.real_B_app, self.real_R_pose[:,i,...], mode=2).view(-1, 16, 512) 182 | 183 | self.fake_B = self.netG.synthesis(self.fake_B_w, noise_mode='const') # G(A) 184 | self.fake_AB = self.netG.synthesis(self.fake_AB_w, noise_mode='const') # G(A) 185 | self.fake_BA = self.netG.synthesis(self.fake_BA_w, noise_mode='const') # G(A) 186 | self.fake_AR = self.netG.synthesis(self.fake_AR_w, noise_mode='const') # G(A) 187 | self.fake_BR = self.netG.synthesis(self.fake_BR_w, noise_mode='const') # G(A) 188 | 189 | self.real_B_list.append(self.real_B.clamp(-1, 1)) 190 | self.fake_B_list.append(self.fake_B.clamp(-1, 1)) 191 | self.fake_AB_list.append(self.fake_AB.clamp(-1, 1)) 192 | self.fake_BA_list.append(self.fake_BA.clamp(-1, 1)) 193 | self.fake_AR_list.append(self.fake_AR.clamp(-1, 1)) 194 | self.fake_BR_list.append(self.fake_BR.clamp(-1, 1)) 195 | self.real_A_list.append(self.real_A.clamp(-1, 1)) 196 | self.fake_A_list.append(self.fake_A.clamp(-1, 1)) 197 | 198 | def optimize_parameters(self): 199 | self.forward() # compute fake images: G(A) 200 | # update G 201 | self.optimizer_FE.zero_grad() # set G's gradients to zero 202 | self.backward_G() # calculate graidents for G 203 | self.optimizer_FE.step() # udpate G's weights 204 | 205 | def compute_visuals(self): 206 | 207 | self.real_A = self.real_A.clamp(-1, 1) 208 | self.fake_A = self.fake_A.clamp(-1, 1) 209 | self.fake_AB = self.fake_AB.clamp(-1, 1) 210 | self.real_B = self.real_B.clamp(-1, 1) 211 | self.fake_B = self.fake_B.clamp(-1, 1) 212 | self.fake_BA = self.fake_BA.clamp(-1, 1) 213 | self.fake_AR = self.fake_AR.clamp(-1, 1) 214 | self.fake_BR = self.fake_BR.clamp(-1, 1) 215 | 216 | self.real_vid_B = torch.cat(self.real_B_list, 0) 217 | self.fake_vid_B = torch.cat(self.fake_B_list, 0) 218 | self.fake_vid_AB = torch.cat(self.fake_AB_list, 0) 219 | self.fake_vid_BA = torch.cat(self.fake_BA_list, 0) 220 | self.fake_vid_AR = torch.cat(self.fake_AR_list, 0) 221 | self.fake_vid_BR = torch.cat(self.fake_BR_list, 0) 222 | self.real_vid_A = torch.cat(self.real_A_list, 0) 223 | self.fake_vid_A = torch.cat(self.fake_A_list, 0) 224 | 225 | self.fake_vid = torch.cat([torch.cat([self.real_vid_B, self.fake_vid_BA, self.fake_vid_B, self.fake_vid_BR], dim = 3), torch.cat([self.real_vid_A, self.fake_vid_A, self.fake_vid_AB, self.fake_vid_AR], dim = 3)], dim = 2) 226 | 227 | -------------------------------------------------------------------------------- /models/rnn_losses.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 
3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | import sys 10 | 11 | import torch 12 | import torch.nn.functional as F 13 | import torch.nn as nn 14 | 15 | 16 | def loss_hinge_dis(dis_fake, dis_real): 17 | loss_real = torch.mean(F.relu(1. - dis_real)) 18 | loss_fake = torch.mean(F.relu(1. + dis_fake)) 19 | return loss_real, loss_fake 20 | 21 | 22 | def loss_hinge_gen(dis_fake): 23 | loss = -torch.mean(dis_fake) 24 | return loss 25 | 26 | 27 | def compute_gradient_penalty_T(real_B, fake_B, modelD, num_D): 28 | alpha = torch.rand(list(real_B.size())[0], 1, 1, 1, 1) 29 | alpha = alpha.expand(real_B.size()).cuda(real_B.get_device()) 30 | 31 | interpolates = alpha * real_B.data + (1 - alpha) * fake_B.data 32 | interpolates = torch.tensor(interpolates, requires_grad=True) 33 | 34 | pred_interpolates = modelD(interpolates) 35 | 36 | gradient_penalty = 0 37 | if isinstance(pred_interpolates, list): 38 | for cur_pred in pred_interpolates: 39 | gradients = torch.autograd.grad(outputs=cur_pred[-1], 40 | inputs=interpolates, 41 | grad_outputs=torch.ones( 42 | cur_pred[-1].size()).cuda( 43 | real_B.get_device()), 44 | create_graph=True, 45 | retain_graph=True, 46 | only_inputs=True)[0] 47 | 48 | gradient_penalty += ((gradients.norm(2, dim=1) - 1)**2).mean() 49 | else: 50 | sys.exit('output is not list!') 51 | 52 | gradient_penalty = (gradient_penalty / num_D) * 10 53 | return gradient_penalty 54 | 55 | 56 | class GANLoss(nn.Module): 57 | def __init__(self, 58 | use_lsgan=True, 59 | target_real_label=1.0, 60 | target_fake_label=0.0, 61 | tensor=torch.FloatTensor): 62 | super(GANLoss, self).__init__() 63 | self.real_label = target_real_label 64 | self.fake_label = target_fake_label 65 | self.real_label_var = None 66 | self.fake_label_var = None 67 | self.Tensor = tensor 68 | if use_lsgan: 69 | self.loss = nn.MSELoss() 70 | else: 71 | self.loss = nn.BCELoss() 72 | 73 | def get_target_tensor(self, input, target_is_real): 74 | target_tensor = None 75 | if target_is_real: 76 | create_label = ((self.real_label_var is None) 77 | or (self.real_label_var.numel() != input.numel())) 78 | if create_label: 79 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 80 | self.real_label_var = torch.tensor(real_tensor, 81 | requires_grad=False) 82 | target_tensor = self.real_label_var 83 | else: 84 | create_label = ((self.fake_label_var is None) 85 | or (self.fake_label_var.numel() != input.numel())) 86 | if create_label: 87 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 88 | self.fake_label_var = torch.tensor(fake_tensor, 89 | requires_grad=False) 90 | target_tensor = self.fake_label_var 91 | 92 | if input.is_cuda: 93 | target_tensor = target_tensor.cuda() 94 | return target_tensor 95 | 96 | def __call__(self, input, target_is_real): 97 | if isinstance(input[0], list): 98 | loss = 0 99 | for input_i in input: 100 | pred = input_i[-1] 101 | target_tensor = self.get_target_tensor(pred, target_is_real) 102 | loss += 
self.loss(pred, target_tensor) 103 | return loss 104 | else: 105 | target_tensor = self.get_target_tensor(input[-1], target_is_real) 106 | return self.loss(input[-1], target_tensor) 107 | 108 | 109 | class Relativistic_Average_LSGAN(GANLoss): 110 | ''' 111 | Relativistic average LSGAN 112 | ''' 113 | def __call__(self, input_1, input_2, target_is_real): 114 | if isinstance(input_1[0], list): 115 | loss = 0 116 | for input_i, _input_i in zip(input_1, input_2): 117 | pred = input_i[-1] 118 | _pred = _input_i[-1] 119 | target_tensor = self.get_target_tensor(pred, target_is_real) 120 | loss += self.loss(pred - torch.mean(_pred), target_tensor) 121 | return loss 122 | else: 123 | target_tensor = self.get_target_tensor(input_1[-1], target_is_real) 124 | return self.loss(input_1[-1] - torch.mean(input_2[-1]), 125 | target_tensor) 126 | -------------------------------------------------------------------------------- /models/rnn_net.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright Snap Inc. 2021. This sample code is made available by Snap Inc. for informational purposes only. 3 | No license, whether implied or otherwise, is granted in or to such code (including any rights to copy, modify, 4 | publish, distribute and/or commercialize such code), unless you have entered into a separate agreement for such rights. 5 | Such code is provided as-is, without warranty of any kind, express or implied, including any warranties of merchantability, 6 | title, fitness for a particular purpose, non-infringement, or that such code is free of defects, errors or viruses. 7 | In no event will Snap Inc. be liable for any damages or losses of any kind arising from the sample code or your use thereof. 8 | """ 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.nn import init 13 | import torch.optim as optim 14 | 15 | 16 | class RNNModule(nn.Module): 17 | def __init__(self, 18 | z_dim=64, 19 | h_dim=64, 20 | w_residual=0.2): 21 | super(RNNModule, self).__init__() 22 | 23 | self.z_dim = z_dim 24 | self.h_dim = h_dim 25 | self.w_residual = w_residual 26 | 27 | self.enc_cell = nn.LSTMCell(z_dim, h_dim) 28 | self.cell = nn.LSTMCell(z_dim, h_dim) 29 | self.w = nn.Parameter(torch.FloatTensor(h_dim, h_dim)) 30 | self.b = nn.Parameter(torch.FloatTensor(h_dim)) 31 | self.fc1 = nn.Linear(h_dim * 2, z_dim) 32 | self.relu = nn.ReLU() 33 | self.fc2 = nn.Linear(z_dim, z_dim) 34 | 35 | self.init_weights() 36 | 37 | def init_optim(self, lr, beta1, beta2): 38 | self.optim = optim.Adam(params=self.parameters(), 39 | lr=lr, 40 | betas=(beta1, beta2), 41 | weight_decay=0, 42 | eps=1e-8) 43 | 44 | def init_weights(self): 45 | for module in self.modules(): 46 | if (isinstance(module, nn.LSTMCell)): 47 | for name, param in module.named_parameters(): 48 | if ('weight_ih' in name) or ('weight_hh' in name): 49 | mul = param.shape[0] // 4 50 | for idx in range(4): 51 | init.orthogonal_(param[idx * mul:(idx + 1) * mul]) 52 | elif 'bias' in name: 53 | param.data.fill_(0) 54 | if (isinstance(module, nn.Linear)): 55 | init.orthogonal_(module.weight) 56 | 57 | nn.init.normal_(self.w, std=0.02) 58 | self.b.data.fill_(0.0) 59 | 60 | def forward(self, z, n_frame): 61 | 62 | out = [z] 63 | h_, c_ = self.enc_cell(z) 64 | h = [h_] 65 | c = [c_] 66 | e = [] 67 | for i in range(n_frame - 1): 68 | e_ = self.get_initial_state_z(z.shape[0]) 69 | h_, c_ = self.cell(e_, (h[-1], c[-1])) 70 | mul = torch.matmul(h_, self.w) + self.b 71 | mul = torch.tanh(mul) 72 | e.append(e_) 73 | h.append(h_) 74 
| c.append(c_) 75 | out_ = out[-1] + self.w_residual * mul 76 | out.append(out_) 77 | 78 | out = [item.unsqueeze(1) for item in out] 79 | 80 | out = torch.cat(out, dim=1).view(-1, self.z_dim) 81 | 82 | e = [item.unsqueeze(1) for item in e] 83 | e = torch.cat(e, dim=1).view(-1, self.z_dim) 84 | 85 | hh = h[1:] 86 | hh = [item.unsqueeze(1) for item in hh] 87 | hh = torch.cat(hh, dim=1).view(-1, self.h_dim) 88 | 89 | cc = c[1:] 90 | cc = [item.unsqueeze(1) for item in cc] 91 | cc = torch.cat(cc, dim=1).view(-1, self.h_dim) 92 | 93 | hc = torch.cat((hh, cc), dim=1) 94 | e_rec = self.fc2(self.relu(self.fc1(hc))) 95 | 96 | return out, e, e_rec 97 | 98 | def get_initial_state_z(self, batchSize): 99 | return torch.cuda.FloatTensor(batchSize, self.z_dim).normal_() 100 | -------------------------------------------------------------------------------- /models/stylepose_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | from . import lmcode_networks 5 | from . import diy_networks 6 | from . import resnet 7 | 8 | import dnnlib 9 | import legacy 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import random 13 | 14 | def draw_points(image, pose): 15 | new_image = image.clone() 16 | size = 256 17 | for i in range(pose.shape[0]): 18 | for j in range(pose.shape[1]): 19 | pose_w = int(pose[i, j, 0] * size) 20 | pose_h = int(pose[i, j, 1] * size) 21 | new_image[i, 0, pose_h-3:pose_h+3, pose_w-3:pose_w+3] = -1 22 | new_image[i, 1, pose_h - 3:pose_h + 3, pose_w - 3:pose_w + 3] = -1 23 | new_image[i, 2, pose_h - 3:pose_h + 3, pose_w - 3:pose_w + 3] = -1 24 | 25 | return new_image 26 | 27 | def make_transform(translate, angle): 28 | m = np.eye(3) 29 | s = np.sin(angle/360.0*np.pi*2) 30 | c = np.cos(angle/360.0*np.pi*2) 31 | m[0][0] = c 32 | m[0][1] = s 33 | m[0][2] = translate[0] 34 | m[1][0] = -s 35 | m[1][1] = c 36 | m[1][2] = translate[1] 37 | return m 38 | 39 | def sample_trans(): 40 | translate = (0.5 * (random.random() - 0.5), 0.5 * (random.random() - 0.5)) 41 | rotate = 90 * (random.random() - 0.5) 42 | m = make_transform(translate, rotate) 43 | m = np.linalg.inv(m) 44 | return m 45 | 46 | class StylePoseModel(BaseModel): 47 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 48 | 49 | The model training requires '--dataset_mode aligned' dataset. 50 | By default, it uses a '--netG unet256' U-Net generator, 51 | a '--netD basic' discriminator (PatchGAN), 52 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 53 | 54 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 55 | """ 56 | @staticmethod 57 | def modify_commandline_options(parser, is_train=True): 58 | """Add new dataset-specific options, and rewrite default values for existing options. 59 | 60 | Parameters: 61 | parser -- original option parser 62 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 63 | 64 | Returns: 65 | the modified parser. 66 | 67 | For pix2pix, we do not use image buffer 68 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 69 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 
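Stepping back to rnn_net.py above, a hedged usage sketch for RNNModule; the batch size and frame count are illustrative, and since get_initial_state_z samples with torch.cuda.FloatTensor a GPU is assumed.

    rnn = RNNModule(z_dim=64, h_dim=64, w_residual=0.2).cuda()
    z0 = torch.randn(2, 64, device='cuda')    # first pose map: an 8x8 grid flattened to 64 dims, as in ReenactModel.forward
    out, e, e_rec = rnn(z0, 60)               # out: (2*60, 64) latent pose trajectory built from learned residuals
    poses = out.view(2, 60, 1, 8, 8)          # back to per-frame pose maps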
70 | """ 71 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 72 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='jointset') 73 | if is_train: 74 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 75 | parser.add_argument('--num_point', type=int, default=14) 76 | parser.add_argument('--dataroot2', type=str, default='../data/realign1024x1024_random-shift0.1') 77 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss') 78 | 79 | return parser 80 | 81 | def __init__(self, opt): 82 | """Initialize the pix2pix class. 83 | 84 | Parameters: 85 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 86 | """ 87 | BaseModel.__init__(self, opt) 88 | # specify the training losses you want to print out. The training/test scripts will call 89 | self.loss_names = ['G_L2', 'G_L2B'] 90 | # specify the models you want to save to the disk. The training/test scripts will call and 91 | self.visual_names = ['real_A', 'real_B', 'fake_A', 'fake_B'] 92 | if self.isTrain: 93 | self.model_names = ['FE'] 94 | else: # during test time, only load G 95 | self.model_names = ['FE'] 96 | self.visual_names = ['real_A', 'real_B', 'fake_A', 'fake_B', 'real_A_map', 'real_B_map'] 97 | # define networks (both generator and discriminator) 98 | with dnnlib.util.open_url(opt.network_pkl) as f: 99 | self.netG = legacy.load_network_pkl(f)['G_ema'].to(self.gpu_ids[0]) # type: ignore 100 | 101 | lm_path = 'pretrain/wing.ckpt' 102 | self.netFE_lm = lmcode_networks.FAN(fname_pretrained=lm_path).eval().to(self.gpu_ids[0]) 103 | 104 | self.netFE = diy_networks._resposenet(num_point=opt.num_point).to(self.gpu_ids[0]) 105 | 106 | if self.isTrain: 107 | # define loss functions 108 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 109 | self.criterionL2 = torch.nn.MSELoss() 110 | # initialize optimizers; schedulers will be automatically created by function . 111 | self.optimizer_FE = torch.optim.Adam(self.netFE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 112 | self.optimizers.append(self.optimizer_FE) 113 | 114 | # Load VGG16 feature detector. 115 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 116 | with dnnlib.util.open_url(url) as f: 117 | self.vgg16 = torch.jit.load(f).eval().to(self.gpu_ids[0]) 118 | 119 | self.m_zero = make_transform((0.0,0.0),(0.0)) 120 | 121 | 122 | def set_input(self, input): 123 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 124 | 125 | Parameters: 126 | input (dict): include the data itself and its metadata information. 127 | 128 | The option 'direction' can be used to swap images in domain A and domain B. 
129 | """ 130 | self.real_A = input['A'].to(self.device) 131 | self.real_B = input['B'].to(self.device) 132 | self.image_paths = input['B_paths'] 133 | 134 | def forward(self): 135 | """Run forward pass; called by both functions and .""" 136 | 137 | if self.real_A.shape[2] != 256: 138 | self.real_A = F.interpolate(self.real_A, size=(256, 256), mode='area') 139 | self.real_A_heat = self.netFE_lm.get_heatmap(self.real_A, b_preprocess=False) 140 | self.real_A_pose = self.netFE(self.real_A_heat).view(-1, 14, 2) 141 | self.real_A_lm = self.netFE_lm.get_landmark(self.real_A).detach().to(self.device) / 256 142 | self.real_A_key = torch.cat([ 143 | torch.sum(self.real_A_lm[:, 33:42, ...], dim=1, keepdim=True) / 9, 144 | torch.sum(self.real_A_lm[:, 42:51, ...], dim=1, keepdim=True) / 9, 145 | self.real_A_lm[:, [62], ...], 146 | self.real_A_lm[:, [96], ...], 147 | self.real_A_lm[:, [66], ...], 148 | self.real_A_lm[:, [70], ...], 149 | self.real_A_lm[:, [97], ...], 150 | self.real_A_lm[:, [74], ...], 151 | self.real_A_lm[:, [54], ...], 152 | self.real_A_lm[:, [79], ...], 153 | self.real_A_lm[:, [85], ...], 154 | self.real_A_lm[:, [88], ...], 155 | self.real_A_lm[:, [92], ...], 156 | torch.sum(self.real_A_lm[:, [90,94], ...], dim=1, keepdim=True) / 2, 157 | ], 1) 158 | 159 | self.fake_A = draw_points(self.real_A.clone(), self.real_A_pose) 160 | 161 | def forward_B(self): 162 | if self.real_B.shape[2] != 256: 163 | self.real_B = F.interpolate(self.real_B, size=(256, 256), mode='area') 164 | self.real_B_heat = self.netFE_lm.get_heatmap(self.real_B, b_preprocess=False) 165 | self.real_B_pose = self.netFE(self.real_B_heat).view(-1, 14, 2) 166 | self.real_B_lm = self.netFE_lm.get_landmark(self.real_B).detach().to(self.device) / 256 167 | self.real_B_key = torch.cat([ 168 | torch.sum(self.real_B_lm[:, 33:42, ...], dim=1, keepdim=True) / 9, 169 | torch.sum(self.real_B_lm[:, 42:51, ...], dim=1, keepdim=True) / 9, 170 | self.real_B_lm[:, [62], ...], 171 | self.real_B_lm[:, [96], ...], 172 | self.real_B_lm[:, [66], ...], 173 | self.real_B_lm[:, [70], ...], 174 | self.real_B_lm[:, [97], ...], 175 | self.real_B_lm[:, [74], ...], 176 | self.real_B_lm[:, [54], ...], 177 | self.real_B_lm[:, [79], ...], 178 | self.real_B_lm[:, [85], ...], 179 | self.real_B_lm[:, [88], ...], 180 | self.real_B_lm[:, [92], ...], 181 | torch.sum(self.real_B_lm[:, [90, 94], ...], dim=1, keepdim=True) / 2, 182 | ], 1) 183 | 184 | self.fake_B = draw_points(self.real_B.clone(), self.real_B_pose) 185 | 186 | def backward_G(self): 187 | """Calculate GAN and L1 loss for the generator""" 188 | 189 | # Second, G(A) = B 190 | self.loss_G_L2 = 10 * self.criterionL2(self.real_A_pose, self.real_A_key) 191 | # combine loss and calculate gradients 192 | self.loss_G = self.loss_G_L2 193 | self.loss_G.backward() 194 | 195 | def backward_G_B(self): 196 | """Calculate GAN and L1 loss for the generator""" 197 | 198 | # Second, G(A) = B 199 | self.loss_G_L2B = 10 * self.criterionL2(self.real_B_pose, self.real_B_key) 200 | # combine loss and calculate gradients 201 | self.loss_G_B = self.loss_G_L2B 202 | self.loss_G_B.backward() 203 | 204 | def optimize_parameters(self): 205 | self.forward() # compute fake images: G(A) 206 | # update G 207 | self.optimizer_FE.zero_grad() # set G's gradients to zero 208 | self.backward_G() # calculate graidents for G 209 | self.optimizer_FE.step() # udpate G's weights 210 | 211 | self.forward_B() 212 | self.optimizer_FE.zero_grad() # set G's gradients to zero 213 | self.backward_G_B() # calculate graidents for G 214 | 
self.optimizer_FE.step() # udpate G's weights 215 | 216 | def compute_visuals(self): 217 | """Calculate additional output images for visdom and HTML visualization""" 218 | self.forward_B() 219 | self.criterionL2 = torch.nn.MSELoss() 220 | self.loss_G_L2 = 10 * self.criterionL2(self.real_A_pose, self.real_A_key) 221 | self.loss_G_L2B = 10 * self.criterionL2(self.real_B_pose, self.real_B_key) 222 | self.fake_A = draw_points(self.real_A, self.real_A_pose) 223 | self.fake_B = draw_points(self.real_B, self.real_B_pose) 224 | self.real_A = draw_points(self.real_A, self.real_A_key) 225 | self.real_B = draw_points(self.real_B, self.real_B_key) 226 | 227 | self.real_A_map = self.netFE(self.real_A_heat, mode=1).detach() 228 | self.real_A_map = (self.real_A_map - torch.min(self.real_A_map)) / (torch.max(self.real_A_map) -torch.min(self.real_A_map)) * 2 -1 229 | self.real_B_map = self.netFE(self.real_B_heat, mode=1).detach() 230 | self.real_B_map = (self.real_B_map - torch.min(self.real_B_map)) / ( 231 | torch.max(self.real_B_map) - torch.min(self.real_B_map)) * 2 - 1 -------------------------------------------------------------------------------- /models/stylepre_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from .base_model import BaseModel 3 | from . import networks 4 | from . import lmcode_networks 5 | from . import diy_networks 6 | from . import resnet 7 | 8 | import dnnlib 9 | import legacy 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import random 13 | import os 14 | 15 | def make_transform(translate, angle): 16 | m = np.eye(3) 17 | s = np.sin(angle/360.0*np.pi*2) 18 | c = np.cos(angle/360.0*np.pi*2) 19 | m[0][0] = c 20 | m[0][1] = s 21 | m[0][2] = translate[0] 22 | m[1][0] = -s 23 | m[1][1] = c 24 | m[1][2] = translate[1] 25 | return m 26 | 27 | def sample_trans(): 28 | translate = (0.5 * (random.random() - 0.5), 0.5 * (random.random() - 0.5)) 29 | rotate = 90 * (random.random() - 0.5) 30 | m = make_transform(translate, rotate) 31 | m = np.linalg.inv(m) 32 | return m 33 | 34 | class StylePreModel(BaseModel): 35 | """ This class implements the pix2pix model, for learning a mapping from input images to output images given paired data. 36 | 37 | The model training requires '--dataset_mode aligned' dataset. 38 | By default, it uses a '--netG unet256' U-Net generator, 39 | a '--netD basic' discriminator (PatchGAN), 40 | and a '--gan_mode' vanilla GAN loss (the cross-entropy objective used in the orignal GAN paper). 41 | 42 | pix2pix paper: https://arxiv.org/pdf/1611.07004.pdf 43 | """ 44 | @staticmethod 45 | def modify_commandline_options(parser, is_train=True): 46 | """Add new dataset-specific options, and rewrite default values for existing options. 47 | 48 | Parameters: 49 | parser -- original option parser 50 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 51 | 52 | Returns: 53 | the modified parser. 54 | 55 | For pix2pix, we do not use image buffer 56 | The training objective is: GAN Loss + lambda_L1 * ||G(A)-B||_1 57 | By default, we use vanilla GAN loss, UNet with batchnorm, and aligned datasets. 
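A short hedged sketch of how sample_trans() above is wired into StyleGAN3 during self-supervised training; it mirrors StylePreModel.forward below, with G and w assumed to come from the usual mapping call.

    m = sample_trans()                   # random shift in [-0.25, 0.25] and rotation in [-45, 45] degrees, inverted
    if hasattr(G.synthesis, 'input'):    # only StyleGAN3 generators expose the input transform
        G.synthesis.input.transform.copy_(torch.from_numpy(m))
    img_shifted = G.synthesis(w, noise_mode='const').clamp(-1, 1)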
58 | """ 59 | # changing the default values to match the pix2pix paper (https://phillipi.github.io/pix2pix/) 60 | parser.set_defaults(norm='batch', netG='unet_256', dataset_mode='noise') 61 | parser.add_argument('--pose_path', type=str, default='', help='path for pose net') 62 | parser.add_argument('--num_point', type=int, default=14) 63 | parser.add_argument('--pre_path', type=str, default='', help='path for pretrain') 64 | if is_train: 65 | parser.set_defaults(pool_size=0, gan_mode='vanilla') 66 | parser.add_argument('--lambda_L1', type=float, default=1.0, help='weight for L1 loss') 67 | 68 | return parser 69 | 70 | def __init__(self, opt): 71 | """Initialize the pix2pix class. 72 | 73 | Parameters: 74 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 75 | """ 76 | BaseModel.__init__(self, opt) 77 | # specify the training losses you want to print out. The training/test scripts will call 78 | self.loss_names = ['G_L1', 'G_VGG', 'G_W'] 79 | # specify the images you want to save/display. The training/test scripts will call 80 | self.visual_names = ['real_B', 'real_A', 'fake_B'] 81 | # specify the models you want to save to the disk. The training/test scripts will call and 82 | if self.isTrain: 83 | self.model_names = ['FE'] 84 | else: # during test time, only load G 85 | self.model_names = ['FE'] 86 | # define networks (both generator and discriminator) 87 | with dnnlib.util.open_url(opt.network_pkl) as f: 88 | self.netG = legacy.load_network_pkl(f)['G_ema'].eval().to(self.gpu_ids[0]) # type: ignore 89 | 90 | lm_path = 'pretrained_models/wing.ckpt' 91 | self.netFE_lm = lmcode_networks.FAN(fname_pretrained=lm_path).eval().to(self.gpu_ids[0]) 92 | self.netFE_pose = diy_networks._resposenet(num_point=opt.num_point).eval().to(self.gpu_ids[0]) 93 | if opt.pose_path != '': 94 | self.netFE_pose.load_state_dict(torch.load(opt.pose_path)) 95 | 96 | self.netFE = resnet.wide_resdisnet50_2(num_classes=512 * 16).to(self.gpu_ids[0]) 97 | if opt.pre_path != '': 98 | try: 99 | self.netFE.load_state_dict(torch.load(opt.pre_path), strict=True) 100 | except: 101 | import collections 102 | model_dic = torch.load(opt.pre_path) 103 | new_state_dict = collections.OrderedDict() 104 | for k, v in model_dic.items(): 105 | name = k.replace('module.', '') 106 | new_state_dict[name] = v 107 | self.netFE.load_state_dict(new_state_dict, strict=True) 108 | 109 | if self.isTrain: 110 | # define loss functions 111 | self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) 112 | self.criterionL1 = torch.nn.L1Loss() 113 | # initialize optimizers; schedulers will be automatically created by function . 114 | self.optimizer_FE = torch.optim.Adam(self.netFE.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999)) 115 | self.optimizers.append(self.optimizer_FE) 116 | 117 | # Load VGG16 feature detector. 118 | url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt' 119 | with dnnlib.util.open_url(url) as f: 120 | self.vgg16 = torch.jit.load(f).eval().to(self.gpu_ids[0]) 121 | 122 | self.m_zero = make_transform((0.0,0.0),(0.0)) 123 | 124 | 125 | def set_input(self, input): 126 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 127 | 128 | Parameters: 129 | input (dict): include the data itself and its metadata information. 130 | 131 | The option 'direction' can be used to swap images in domain A and domain B. 
132 | """ 133 | self.real_z = input['A'].to(self.device) 134 | self.image_paths = input['A_paths'] 135 | 136 | def forward(self): 137 | """Run forward pass; called by both functions and .""" 138 | if hasattr(self.netG.synthesis, 'input'): 139 | self.netG.synthesis.input.transform.copy_(torch.from_numpy(self.m_zero)) 140 | 141 | with torch.no_grad(): 142 | self.real_A_w = self.netG.mapping(self.real_z, None) 143 | self.real_A = self.netG.synthesis(self.real_A_w, noise_mode='const').detach().clamp(-1, 1) 144 | if self.real_A.shape[2] != 256: 145 | self.real_A = F.interpolate(self.real_A, size=(256, 256), mode='area') 146 | self.real_A_heat = self.netFE_lm.get_heatmap(self.real_A, b_preprocess=False) 147 | self.real_A_pose = self.netFE_pose(self.real_A_heat, mode=1).detach() 148 | 149 | m = sample_trans() 150 | self.netG.synthesis.input.transform.copy_(torch.from_numpy(m)) 151 | self.real_B = self.netG.synthesis(self.real_A_w, noise_mode='const').detach().clamp(-1, 1) 152 | if self.real_B.shape[2] != 256: 153 | self.real_B = F.interpolate(self.real_B, size=(256, 256), mode='area') 154 | 155 | self.real_B_app = self.netFE(self.real_B, mode=1) 156 | self.fake_B_w = self.netFE(self.real_B_app, self.real_A_pose, mode=2).view(-1, 16, 512) 157 | 158 | self.netG.synthesis.input.transform.copy_(torch.from_numpy(self.m_zero)) 159 | self.fake_B = self.netG.synthesis(self.fake_B_w, noise_mode='const') # G(A) 160 | if self.fake_B.shape[2] != 256: 161 | self.fake_B = F.interpolate(self.fake_B, size=(256, 256), mode='area') 162 | 163 | def backward_G(self): 164 | """Calculate GAN and L1 loss for the generator""" 165 | 166 | # Second, G(A) = B 167 | self.loss_G_L1 = 1 * self.opt.lambda_L1 * self.criterionL1(self.fake_B, self.real_A) 168 | 169 | self.loss_G_VGG = 100 * self.opt.lambda_L1 * self.criterionL1(self.vgg16(self.fake_B), self.vgg16(self.real_A)) 170 | 171 | self.loss_G_W = 100 * self.opt.lambda_L1 * self.criterionL1(self.fake_B_w[:,1:,:], self.real_A_w[:,1:,:]) 172 | 173 | # combine loss and calculate gradients 174 | self.loss_G = self.loss_G_L1 + self.loss_G_VGG + self.loss_G_W 175 | self.loss_G.backward() 176 | 177 | def optimize_parameters(self): 178 | self.forward() # compute fake images: G(A) 179 | # update G 180 | self.optimizer_FE.zero_grad() # set G's gradients to zero 181 | self.backward_G() # calculate graidents for G 182 | self.optimizer_FE.step() # udpate G's weights 183 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | """This package options includes option modules: training options, test options, and basic options (used in both training and test).""" 2 | -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from util import util 4 | import torch 5 | import models 6 | import data 7 | 8 | 9 | class BaseOptions(): 10 | """This class defines options used during both training and test time. 11 | 12 | It also implements several helper functions such as parsing, printing, and saving the options. 13 | It also gathers additional options defined in functions in both dataset class and model class. 
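For orientation, a hedged sketch of how these option classes are consumed end to end; the pattern mirrors test.py further below, and the command-line values (including the 'stylepre' model name) are illustrative.

    # python train.py --dataroot ./datasets/faces --name my_run --model stylepre \
    #                 --network_pkl pretrained_models/stylegan3.pkl --gpu_ids 0
    from options.train_options import TrainOptions
    from data import create_dataset
    from models import create_model

    opt = TrainOptions().parse()    # base options + model-specific + dataset-specific defaults
    dataset = create_dataset(opt)   # dataset class chosen from opt.dataset_mode
    model = create_model(opt)       # model class chosen from opt.model
    model.setup(opt)                # load/print networks, create schedulers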
14 | """ 15 | 16 | def __init__(self): 17 | """Reset the class; indicates the class hasn't been initailized""" 18 | self.initialized = False 19 | 20 | def initialize(self, parser): 21 | """Define the common options that are used in both training and test.""" 22 | # basic parameters 23 | parser.add_argument('--dataroot', required=True, help='path to images (should have subfolders trainA, trainB, valA, valB, etc)') 24 | parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment. It decides where to store samples and models') 25 | parser.add_argument('--network_pkl', type=str, help='Network pickle filename') 26 | parser.add_argument('--use_wandb', action='store_true', help='use wandb') 27 | parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU') 28 | parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') 29 | # model parameters 30 | parser.add_argument('--model', type=str, default='cycle_gan', help='chooses which model to use. [cycle_gan | pix2pix | test | colorization]') 31 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels: 3 for RGB and 1 for grayscale') 32 | parser.add_argument('--output_nc', type=int, default=3, help='# of output image channels: 3 for RGB and 1 for grayscale') 33 | parser.add_argument('--ngf', type=int, default=64, help='# of gen filters in the last conv layer') 34 | parser.add_argument('--ndf', type=int, default=64, help='# of discrim filters in the first conv layer') 35 | parser.add_argument('--netD', type=str, default='basic', help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator') 36 | parser.add_argument('--netG', type=str, default='resnet_9blocks', help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_256 | unet_128]') 37 | parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') 38 | parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [instance | batch | none]') 39 | parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal | xavier | kaiming | orthogonal]') 40 | parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') 41 | parser.add_argument('--no_dropout', action='store_true', help='no dropout for the generator') 42 | # dataset parameters 43 | parser.add_argument('--dataset_mode', type=str, default='unaligned', help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]') 44 | parser.add_argument('--direction', type=str, default='AtoB', help='AtoB or BtoA') 45 | parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly') 46 | parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data') 47 | parser.add_argument('--batch_size', type=int, default=1, help='input batch size') 48 | parser.add_argument('--load_size', type=int, default=256, help='scale images to this size') 49 | parser.add_argument('--crop_size', type=int, default=256, help='then crop to this size') 50 | parser.add_argument('--max_dataset_size', type=int, default=20000, help='Maximum number of samples allowed per dataset. 
If the dataset directory contains more than max_dataset_size, only a subset is loaded.') 51 | parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='scaling and cropping of images at load time [resize_and_crop | crop | scale_width | scale_width_and_crop | none]') 52 | parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the images for data augmentation') 53 | parser.add_argument('--display_winsize', type=int, default=256, help='display window size for both visdom and HTML') 54 | # additional parameters 55 | parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model') 56 | parser.add_argument('--load_iter', type=int, default='0', help='which iteration to load? if load_iter > 0, the code will load models by iter_[load_iter]; otherwise, the code will load models by [epoch]') 57 | parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information') 58 | parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}') 59 | self.initialized = True 60 | return parser 61 | 62 | def gather_options(self): 63 | """Initialize our parser with basic options(only once). 64 | Add additional model-specific and dataset-specific options. 65 | These options are defined in the function 66 | in model and dataset classes. 67 | """ 68 | if not self.initialized: # check if it has been initialized 69 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 70 | parser = self.initialize(parser) 71 | 72 | # get the basic options 73 | opt, _ = parser.parse_known_args() 74 | 75 | # modify model-related parser options 76 | model_name = opt.model 77 | model_option_setter = models.get_option_setter(model_name) 78 | parser = model_option_setter(parser, self.isTrain) 79 | opt, _ = parser.parse_known_args() # parse again with new defaults 80 | 81 | # modify dataset-related parser options 82 | dataset_name = opt.dataset_mode 83 | dataset_option_setter = data.get_option_setter(dataset_name) 84 | parser = dataset_option_setter(parser, self.isTrain) 85 | 86 | # save and return the parser 87 | self.parser = parser 88 | return parser.parse_args() 89 | 90 | def print_options(self, opt): 91 | """Print and save options 92 | 93 | It will print both current options and default values(if different). 
94 | It will save options into a text file / [checkpoints_dir] / opt.txt 95 | """ 96 | message = '' 97 | message += '----------------- Options ---------------\n' 98 | for k, v in sorted(vars(opt).items()): 99 | comment = '' 100 | default = self.parser.get_default(k) 101 | if v != default: 102 | comment = '\t[default: %s]' % str(default) 103 | message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment) 104 | message += '----------------- End -------------------' 105 | print(message) 106 | 107 | # save to the disk 108 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 109 | util.mkdirs(expr_dir) 110 | file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase)) 111 | with open(file_name, 'wt') as opt_file: 112 | opt_file.write(message) 113 | opt_file.write('\n') 114 | 115 | def parse(self): 116 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 117 | opt = self.gather_options() 118 | opt.isTrain = self.isTrain # train or test 119 | 120 | # process opt.suffix 121 | if opt.suffix: 122 | suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else '' 123 | opt.name = opt.name + suffix 124 | 125 | self.print_options(opt) 126 | 127 | # set gpu ids 128 | str_ids = opt.gpu_ids.split(',') 129 | opt.gpu_ids = [] 130 | for str_id in str_ids: 131 | id = int(str_id) 132 | if id >= 0: 133 | opt.gpu_ids.append(id) 134 | if len(opt.gpu_ids) > 0: 135 | torch.cuda.set_device(opt.gpu_ids[0]) 136 | 137 | self.opt = opt 138 | return self.opt 139 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 13 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 14 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 15 | # Dropout and Batchnorm has different behavioir during training and test. 16 | parser.add_argument('--eval', action='store_true', help='use eval mode during test time.') 17 | parser.add_argument('--num_test', type=int, default=50, help='how many test images to run') 18 | # rewrite devalue values 19 | parser.set_defaults(model='test') 20 | # To avoid cropping, the load_size should be the same as crop_size 21 | parser.set_defaults(load_size=parser.get_default('crop_size')) 22 | self.isTrain = False 23 | return parser 24 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 
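As a concrete example of the --suffix handling in parse() above: passing --suffix {model}_{netG}_size{load_size} makes opt.suffix.format(**vars(opt)) expand from the parsed option values, so with the defaults listed earlier the run would be stored under ./checkpoints/experiment_name_cycle_gan_resnet_9blocks_size256/.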
8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument('--display_freq', type=int, default=400, help='frequency of showing training results on screen') 14 | parser.add_argument('--display_ncols', type=int, default=4, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 15 | parser.add_argument('--display_id', type=int, default=0, help='window id of the web display') 16 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 17 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 18 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 19 | parser.add_argument('--update_html_freq', type=int, default=1000, help='frequency of saving training results to html') 20 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 21 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 22 | # network saving and loading parameters 23 | parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results') 24 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 25 | parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration') 26 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 27 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 28 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 29 | # training parameters 30 | parser.add_argument('--n_epochs', type=int, default=50, help='number of epochs with the initial learning rate') 31 | parser.add_argument('--n_epochs_decay', type=int, default=50, help='number of epochs to linearly decay learning rate to zero') 32 | parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam') 33 | parser.add_argument('--lr', type=float, default=0.0002, help='initial learning rate for adam') 34 | parser.add_argument('--gan_mode', type=str, default='lsgan', help='the type of GAN objective. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.') 35 | parser.add_argument('--pool_size', type=int, default=50, help='the size of image buffer that stores previously generated images') 36 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. 
[linear | step | plateau | cosine]') 37 | parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations') 38 | 39 | parser.add_argument('--epoch_gan', type=int, default=0, 40 | help='finetune the whole model with GAN loss finally') 41 | 42 | self.isTrain = True 43 | return parser 44 | -------------------------------------------------------------------------------- /scripts/vid2img.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import os 3 | import glob 4 | import numpy as np 5 | 6 | video_name = '../data/actor_align_512/Actor_*/*.mp4' 7 | 8 | files = sorted(glob.glob(video_name)) 9 | 10 | for index in range(len(files)): 11 | file = files[index] 12 | file1 = file.replace('.mp4', '/').replace('actor_align_512', 'actor_align_512_png') 13 | if not os.path.exists(file1): 14 | os.makedirs(file1) 15 | cap=cv.VideoCapture(file) 16 | isOpened=cap.isOpened() 17 | i=0 18 | while(isOpened) and i < 9000: 19 | i=i+1 20 | flag,frame=cap.read() 21 | # fileName = '%03d'%i+".jpg" 22 | name = ('00000' + str(i))[-5:] 23 | if flag == True : 24 | cv.imwrite(file1 + name + ".png", frame) 25 | cv.waitKey(1) 26 | else: 27 | break 28 | cap.release() 29 | print('end') -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """General-purpose test script for image-to-image translation. 2 | 3 | Once you have trained your model with train.py, you can use this script to test the model. 4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. 5 | 6 | It first creates model and dataset given the option. It will hard-code some parameters. 7 | It then runs inference for '--num_test' images and save results to an HTML file. 8 | 9 | """ 10 | import os 11 | from options.test_options import TestOptions 12 | from data import create_dataset 13 | from models import create_model 14 | from util.visualizer import save_videos 15 | from util import html 16 | 17 | try: 18 | import wandb 19 | except ImportError: 20 | print('Warning: wandb package cannot be found. The option "--use_wandb" will result in error.') 21 | 22 | 23 | if __name__ == '__main__': 24 | opt = TestOptions().parse() # get test options 25 | # hard-code some parameters for test 26 | opt.num_threads = 0 # test code only supports num_threads = 0 27 | opt.batch_size = 1 # test code only supports batch_size = 1 28 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 29 | opt.no_flip = True # no flip; comment this line if results on flipped images are needed. 30 | opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 
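# A typical invocation of this script for the reenactment model (flag values are illustrative;
# see options/base_options.py, options/test_options.py and the chosen model's
# modify_commandline_options for the authoritative list):
#   python test.py --dataroot ./datasets/faces --name stylefacev --model reenact \
#       --network_pkl pretrained_models/stylegan3.pkl --pose_path checkpoints/pose.pth \
#       --rnn_path checkpoints/rnn.pth --num_test 32 --eval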
31 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 32 | model = create_model(opt) # create a model given opt.model and other options 33 | model.setup(opt) # regular setup: load and print networks; create schedulers 34 | 35 | # initialize logger 36 | if opt.use_wandb: 37 | wandb_run = wandb.init(project='CycleGAN-and-pix2pix', name=opt.name, config=opt) if not wandb.run else wandb.run 38 | wandb_run._label(repo='CycleGAN-and-pix2pix') 39 | 40 | # create a website 41 | web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory 42 | if opt.load_iter > 0: # load_iter is 0 by default 43 | web_dir = '{:s}_iter{:d}'.format(web_dir, opt.load_iter) 44 | print('creating web directory', web_dir) 45 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 46 | # test with eval mode. This only affects layers like batchnorm and dropout. 47 | # For [pix2pix]: we use batchnorm and dropout in the original pix2pix. You can experiment it with and without eval() mode. 48 | # For [CycleGAN]: It should not affect CycleGAN as CycleGAN uses instancenorm without dropout. 49 | if opt.eval: 50 | model.eval() 51 | for i, data in enumerate(dataset): 52 | if i >= opt.num_test: # only apply our model to opt.num_test images. 53 | break 54 | model.set_input(data) # unpack data from data loader 55 | model.test() # run inference 56 | visuals = model.get_current_visuals() # get image results 57 | img_path = model.get_image_paths() # get image paths 58 | if i % 5 == 0: # save images to an HTML file 59 | print('processing (%04d)-th image... %s' % (i, img_path)) 60 | save_videos(webpage, visuals, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize, use_wandb=opt.use_wandb) 61 | webpage.save() # save the HTML 62 | -------------------------------------------------------------------------------- /torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/custom_ops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | import glob 10 | import hashlib 11 | import importlib 12 | import os 13 | import re 14 | import shutil 15 | import uuid 16 | 17 | import torch 18 | import torch.utils.cpp_extension 19 | from torch.utils.file_baton import FileBaton 20 | 21 | #---------------------------------------------------------------------------- 22 | # Global options. 
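# Note: with verbosity = 'full', the plugin builds below pass verbose=True to
# torch.utils.cpp_extension.load, which prints the full ninja build log on the first compile.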
23 | 24 | verbosity = 'brief' # Verbosity level: 'none', 'brief', 'full' 25 | 26 | #---------------------------------------------------------------------------- 27 | # Internal helper funcs. 28 | 29 | def _find_compiler_bindir(): 30 | patterns = [ 31 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Professional/VC/Tools/MSVC/*/bin/Hostx64/x64', 32 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/BuildTools/VC/Tools/MSVC/*/bin/Hostx64/x64', 33 | 'C:/Program Files (x86)/Microsoft Visual Studio/*/Community/VC/Tools/MSVC/*/bin/Hostx64/x64', 34 | 'C:/Program Files (x86)/Microsoft Visual Studio */vc/bin', 35 | ] 36 | for pattern in patterns: 37 | matches = sorted(glob.glob(pattern)) 38 | if len(matches): 39 | return matches[-1] 40 | return None 41 | 42 | #---------------------------------------------------------------------------- 43 | 44 | def _get_mangled_gpu_name(): 45 | name = torch.cuda.get_device_name().lower() 46 | out = [] 47 | for c in name: 48 | if re.match('[a-z0-9_-]+', c): 49 | out.append(c) 50 | else: 51 | out.append('-') 52 | return ''.join(out) 53 | 54 | #---------------------------------------------------------------------------- 55 | # Main entry point for compiling and loading C++/CUDA plugins. 56 | 57 | _cached_plugins = dict() 58 | 59 | def get_plugin(module_name, sources, headers=None, source_dir=None, **build_kwargs): 60 | assert verbosity in ['none', 'brief', 'full'] 61 | if headers is None: 62 | headers = [] 63 | if source_dir is not None: 64 | sources = [os.path.join(source_dir, fname) for fname in sources] 65 | headers = [os.path.join(source_dir, fname) for fname in headers] 66 | 67 | # Already cached? 68 | if module_name in _cached_plugins: 69 | return _cached_plugins[module_name] 70 | 71 | # Print status. 72 | if verbosity == 'full': 73 | print(f'Setting up PyTorch plugin "{module_name}"...') 74 | elif verbosity == 'brief': 75 | print(f'Setting up PyTorch plugin "{module_name}"... ', end='', flush=True) 76 | verbose_build = (verbosity == 'full') 77 | 78 | # Compile and load. 79 | try: # pylint: disable=too-many-nested-blocks 80 | # Make sure we can find the necessary compiler binaries. 81 | if os.name == 'nt' and os.system("where cl.exe >nul 2>nul") != 0: 82 | compiler_bindir = _find_compiler_bindir() 83 | if compiler_bindir is None: 84 | raise RuntimeError(f'Could not find MSVC/GCC/CLANG installation on this computer. Check _find_compiler_bindir() in "{__file__}".') 85 | os.environ['PATH'] += ';' + compiler_bindir 86 | 87 | # Some containers set TORCH_CUDA_ARCH_LIST to a list that can either 88 | # break the build or unnecessarily restrict what's available to nvcc. 89 | # Unset it to let nvcc decide based on what's available on the 90 | # machine. 91 | os.environ['TORCH_CUDA_ARCH_LIST'] = '' 92 | 93 | # Incremental build md5sum trickery. Copies all the input source files 94 | # into a cached build directory under a combined md5 digest of the input 95 | # source files. Copying is done only if the combined digest has changed. 96 | # This keeps input file timestamps and filenames the same as in previous 97 | # extension builds, allowing for fast incremental rebuilds. 98 | # 99 | # This optimization is done only in case all the source files reside in 100 | # a single directory (just for simplicity) and if the TORCH_EXTENSIONS_DIR 101 | # environment variable is set (we take this as a signal that the user 102 | # actually cares about this.) 
103 | # 104 | # EDIT: We now do it regardless of TORCH_EXTENSIOS_DIR, in order to work 105 | # around the *.cu dependency bug in ninja config. 106 | # 107 | all_source_files = sorted(sources + headers) 108 | all_source_dirs = set(os.path.dirname(fname) for fname in all_source_files) 109 | if len(all_source_dirs) == 1: # and ('TORCH_EXTENSIONS_DIR' in os.environ): 110 | 111 | # Compute combined hash digest for all source files. 112 | hash_md5 = hashlib.md5() 113 | for src in all_source_files: 114 | with open(src, 'rb') as f: 115 | hash_md5.update(f.read()) 116 | 117 | # Select cached build directory name. 118 | source_digest = hash_md5.hexdigest() 119 | build_top_dir = torch.utils.cpp_extension._get_build_directory(module_name, verbose=verbose_build) # pylint: disable=protected-access 120 | cached_build_dir = os.path.join(build_top_dir, f'{source_digest}-{_get_mangled_gpu_name()}') 121 | 122 | if not os.path.isdir(cached_build_dir): 123 | tmpdir = f'{build_top_dir}/srctmp-{uuid.uuid4().hex}' 124 | os.makedirs(tmpdir) 125 | for src in all_source_files: 126 | shutil.copyfile(src, os.path.join(tmpdir, os.path.basename(src))) 127 | try: 128 | os.replace(tmpdir, cached_build_dir) # atomic 129 | except OSError: 130 | # source directory already exists, delete tmpdir and its contents. 131 | shutil.rmtree(tmpdir) 132 | if not os.path.isdir(cached_build_dir): raise 133 | 134 | # Compile. 135 | cached_sources = [os.path.join(cached_build_dir, os.path.basename(fname)) for fname in sources] 136 | torch.utils.cpp_extension.load(name=module_name, build_directory=cached_build_dir, 137 | verbose=verbose_build, sources=cached_sources, **build_kwargs) 138 | else: 139 | torch.utils.cpp_extension.load(name=module_name, verbose=verbose_build, sources=sources, **build_kwargs) 140 | 141 | # Load. 142 | module = importlib.import_module(module_name) 143 | 144 | except: 145 | if verbosity == 'brief': 146 | print('Failed!') 147 | raise 148 | 149 | # Print status and add to cache dict. 150 | if verbosity == 'full': 151 | print(f'Done setting up PyTorch plugin "{module_name}".') 152 | elif verbosity == 'brief': 153 | print('Done.') 154 | _cached_plugins[module_name] = module 155 | return module 156 | 157 | #---------------------------------------------------------------------------- 158 | -------------------------------------------------------------------------------- /torch_utils/ops/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | # empty 10 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. 
Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "bias_act.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static bool has_same_layout(torch::Tensor x, torch::Tensor y) 17 | { 18 | if (x.dim() != y.dim()) 19 | return false; 20 | for (int64_t i = 0; i < x.dim(); i++) 21 | { 22 | if (x.size(i) != y.size(i)) 23 | return false; 24 | if (x.size(i) >= 2 && x.stride(i) != y.stride(i)) 25 | return false; 26 | } 27 | return true; 28 | } 29 | 30 | //------------------------------------------------------------------------ 31 | 32 | static torch::Tensor bias_act(torch::Tensor x, torch::Tensor b, torch::Tensor xref, torch::Tensor yref, torch::Tensor dy, int grad, int dim, int act, float alpha, float gain, float clamp) 33 | { 34 | // Validate arguments. 35 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 36 | TORCH_CHECK(b.numel() == 0 || (b.dtype() == x.dtype() && b.device() == x.device()), "b must have the same dtype and device as x"); 37 | TORCH_CHECK(xref.numel() == 0 || (xref.sizes() == x.sizes() && xref.dtype() == x.dtype() && xref.device() == x.device()), "xref must have the same shape, dtype, and device as x"); 38 | TORCH_CHECK(yref.numel() == 0 || (yref.sizes() == x.sizes() && yref.dtype() == x.dtype() && yref.device() == x.device()), "yref must have the same shape, dtype, and device as x"); 39 | TORCH_CHECK(dy.numel() == 0 || (dy.sizes() == x.sizes() && dy.dtype() == x.dtype() && dy.device() == x.device()), "dy must have the same dtype and device as x"); 40 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 41 | TORCH_CHECK(b.dim() == 1, "b must have rank 1"); 42 | TORCH_CHECK(b.numel() == 0 || (dim >= 0 && dim < x.dim()), "dim is out of bounds"); 43 | TORCH_CHECK(b.numel() == 0 || b.numel() == x.size(dim), "b has wrong number of elements"); 44 | TORCH_CHECK(grad >= 0, "grad must be non-negative"); 45 | 46 | // Validate layout. 47 | TORCH_CHECK(x.is_non_overlapping_and_dense(), "x must be non-overlapping and dense"); 48 | TORCH_CHECK(b.is_contiguous(), "b must be contiguous"); 49 | TORCH_CHECK(xref.numel() == 0 || has_same_layout(xref, x), "xref must have the same layout as x"); 50 | TORCH_CHECK(yref.numel() == 0 || has_same_layout(yref, x), "yref must have the same layout as x"); 51 | TORCH_CHECK(dy.numel() == 0 || has_same_layout(dy, x), "dy must have the same layout as x"); 52 | 53 | // Create output tensor. 54 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 55 | torch::Tensor y = torch::empty_like(x); 56 | TORCH_CHECK(has_same_layout(y, x), "y must have the same layout as x"); 57 | 58 | // Initialize CUDA kernel parameters. 59 | bias_act_kernel_params p; 60 | p.x = x.data_ptr(); 61 | p.b = (b.numel()) ? b.data_ptr() : NULL; 62 | p.xref = (xref.numel()) ? xref.data_ptr() : NULL; 63 | p.yref = (yref.numel()) ? yref.data_ptr() : NULL; 64 | p.dy = (dy.numel()) ? dy.data_ptr() : NULL; 65 | p.y = y.data_ptr(); 66 | p.grad = grad; 67 | p.act = act; 68 | p.alpha = alpha; 69 | p.gain = gain; 70 | p.clamp = clamp; 71 | p.sizeX = (int)x.numel(); 72 | p.sizeB = (int)b.numel(); 73 | p.stepB = (b.numel()) ? (int)x.stride(dim) : 1; 74 | 75 | // Choose CUDA kernel. 
76 | void* kernel; 77 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 78 | { 79 | kernel = choose_bias_act_kernel(p); 80 | }); 81 | TORCH_CHECK(kernel, "no CUDA kernel found for the specified activation func"); 82 | 83 | // Launch CUDA kernel. 84 | p.loopX = 4; 85 | int blockSize = 4 * 32; 86 | int gridSize = (p.sizeX - 1) / (p.loopX * blockSize) + 1; 87 | void* args[] = {&p}; 88 | AT_CUDA_CHECK(cudaLaunchKernel(kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 89 | return y; 90 | } 91 | 92 | //------------------------------------------------------------------------ 93 | 94 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 95 | { 96 | m.def("bias_act", &bias_act); 97 | } 98 | 99 | //------------------------------------------------------------------------ 100 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include "bias_act.h" 11 | 12 | //------------------------------------------------------------------------ 13 | // Helpers. 14 | 15 | template struct InternalType; 16 | template <> struct InternalType { typedef double scalar_t; }; 17 | template <> struct InternalType { typedef float scalar_t; }; 18 | template <> struct InternalType { typedef float scalar_t; }; 19 | 20 | //------------------------------------------------------------------------ 21 | // CUDA kernel. 22 | 23 | template 24 | __global__ void bias_act_kernel(bias_act_kernel_params p) 25 | { 26 | typedef typename InternalType::scalar_t scalar_t; 27 | int G = p.grad; 28 | scalar_t alpha = (scalar_t)p.alpha; 29 | scalar_t gain = (scalar_t)p.gain; 30 | scalar_t clamp = (scalar_t)p.clamp; 31 | scalar_t one = (scalar_t)1; 32 | scalar_t two = (scalar_t)2; 33 | scalar_t expRange = (scalar_t)80; 34 | scalar_t halfExpRange = (scalar_t)40; 35 | scalar_t seluScale = (scalar_t)1.0507009873554804934193349852946; 36 | scalar_t seluAlpha = (scalar_t)1.6732632423543772848170429916717; 37 | 38 | // Loop over elements. 39 | int xi = blockIdx.x * p.loopX * blockDim.x + threadIdx.x; 40 | for (int loopIdx = 0; loopIdx < p.loopX && xi < p.sizeX; loopIdx++, xi += blockDim.x) 41 | { 42 | // Load. 43 | scalar_t x = (scalar_t)((const T*)p.x)[xi]; 44 | scalar_t b = (p.b) ? (scalar_t)((const T*)p.b)[(xi / p.stepB) % p.sizeB] : 0; 45 | scalar_t xref = (p.xref) ? (scalar_t)((const T*)p.xref)[xi] : 0; 46 | scalar_t yref = (p.yref) ? (scalar_t)((const T*)p.yref)[xi] : 0; 47 | scalar_t dy = (p.dy) ? (scalar_t)((const T*)p.dy)[xi] : one; 48 | scalar_t yy = (gain != 0) ? yref / gain : 0; 49 | scalar_t y = 0; 50 | 51 | // Apply bias. 52 | ((G == 0) ? x : xref) += b; 53 | 54 | // linear 55 | if (A == 1) 56 | { 57 | if (G == 0) y = x; 58 | if (G == 1) y = x; 59 | } 60 | 61 | // relu 62 | if (A == 2) 63 | { 64 | if (G == 0) y = (x > 0) ? x : 0; 65 | if (G == 1) y = (yy > 0) ? x : 0; 66 | } 67 | 68 | // lrelu 69 | if (A == 3) 70 | { 71 | if (G == 0) y = (x > 0) ? 
x : x * alpha; 72 | if (G == 1) y = (yy > 0) ? x : x * alpha; 73 | } 74 | 75 | // tanh 76 | if (A == 4) 77 | { 78 | if (G == 0) { scalar_t c = exp(x); scalar_t d = one / c; y = (x < -expRange) ? -one : (x > expRange) ? one : (c - d) / (c + d); } 79 | if (G == 1) y = x * (one - yy * yy); 80 | if (G == 2) y = x * (one - yy * yy) * (-two * yy); 81 | } 82 | 83 | // sigmoid 84 | if (A == 5) 85 | { 86 | if (G == 0) y = (x < -expRange) ? 0 : one / (exp(-x) + one); 87 | if (G == 1) y = x * yy * (one - yy); 88 | if (G == 2) y = x * yy * (one - yy) * (one - two * yy); 89 | } 90 | 91 | // elu 92 | if (A == 6) 93 | { 94 | if (G == 0) y = (x >= 0) ? x : exp(x) - one; 95 | if (G == 1) y = (yy >= 0) ? x : x * (yy + one); 96 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + one); 97 | } 98 | 99 | // selu 100 | if (A == 7) 101 | { 102 | if (G == 0) y = (x >= 0) ? seluScale * x : (seluScale * seluAlpha) * (exp(x) - one); 103 | if (G == 1) y = (yy >= 0) ? x * seluScale : x * (yy + seluScale * seluAlpha); 104 | if (G == 2) y = (yy >= 0) ? 0 : x * (yy + seluScale * seluAlpha); 105 | } 106 | 107 | // softplus 108 | if (A == 8) 109 | { 110 | if (G == 0) y = (x > expRange) ? x : log(exp(x) + one); 111 | if (G == 1) y = x * (one - exp(-yy)); 112 | if (G == 2) { scalar_t c = exp(-yy); y = x * c * (one - c); } 113 | } 114 | 115 | // swish 116 | if (A == 9) 117 | { 118 | if (G == 0) 119 | y = (x < -expRange) ? 0 : x / (exp(-x) + one); 120 | else 121 | { 122 | scalar_t c = exp(xref); 123 | scalar_t d = c + one; 124 | if (G == 1) 125 | y = (xref > halfExpRange) ? x : x * c * (xref + d) / (d * d); 126 | else 127 | y = (xref > halfExpRange) ? 0 : x * c * (xref * (two - d) + two * d) / (d * d * d); 128 | yref = (xref < -expRange) ? 0 : xref / (exp(-xref) + one) * gain; 129 | } 130 | } 131 | 132 | // Apply gain. 133 | y *= gain * dy; 134 | 135 | // Clamp. 136 | if (clamp >= 0) 137 | { 138 | if (G == 0) 139 | y = (y > -clamp & y < clamp) ? y : (y >= 0) ? clamp : -clamp; 140 | else 141 | y = (yref > -clamp & yref < clamp) ? y : 0; 142 | } 143 | 144 | // Store. 145 | ((T*)p.y)[xi] = (T)y; 146 | } 147 | } 148 | 149 | //------------------------------------------------------------------------ 150 | // CUDA kernel selection. 151 | 152 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p) 153 | { 154 | if (p.act == 1) return (void*)bias_act_kernel; 155 | if (p.act == 2) return (void*)bias_act_kernel; 156 | if (p.act == 3) return (void*)bias_act_kernel; 157 | if (p.act == 4) return (void*)bias_act_kernel; 158 | if (p.act == 5) return (void*)bias_act_kernel; 159 | if (p.act == 6) return (void*)bias_act_kernel; 160 | if (p.act == 7) return (void*)bias_act_kernel; 161 | if (p.act == 8) return (void*)bias_act_kernel; 162 | if (p.act == 9) return (void*)bias_act_kernel; 163 | return NULL; 164 | } 165 | 166 | //------------------------------------------------------------------------ 167 | // Template specializations. 168 | 169 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 170 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 171 | template void* choose_bias_act_kernel (const bias_act_kernel_params& p); 172 | 173 | //------------------------------------------------------------------------ 174 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | //------------------------------------------------------------------------ 10 | // CUDA kernel parameters. 11 | 12 | struct bias_act_kernel_params 13 | { 14 | const void* x; // [sizeX] 15 | const void* b; // [sizeB] or NULL 16 | const void* xref; // [sizeX] or NULL 17 | const void* yref; // [sizeX] or NULL 18 | const void* dy; // [sizeX] or NULL 19 | void* y; // [sizeX] 20 | 21 | int grad; 22 | int act; 23 | float alpha; 24 | float gain; 25 | float clamp; 26 | 27 | int sizeX; 28 | int sizeB; 29 | int stepB; 30 | int loopX; 31 | }; 32 | 33 | //------------------------------------------------------------------------ 34 | // CUDA kernel selection. 35 | 36 | template void* choose_bias_act_kernel(const bias_act_kernel_params& p); 37 | 38 | //------------------------------------------------------------------------ 39 | -------------------------------------------------------------------------------- /torch_utils/ops/bias_act.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom PyTorch ops for efficient bias and activation.""" 10 | 11 | import os 12 | import numpy as np 13 | import torch 14 | import dnnlib 15 | 16 | from .. import custom_ops 17 | from .. 
import misc 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | activation_funcs = { 22 | 'linear': dnnlib.EasyDict(func=lambda x, **_: x, def_alpha=0, def_gain=1, cuda_idx=1, ref='', has_2nd_grad=False), 23 | 'relu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.relu(x), def_alpha=0, def_gain=np.sqrt(2), cuda_idx=2, ref='y', has_2nd_grad=False), 24 | 'lrelu': dnnlib.EasyDict(func=lambda x, alpha, **_: torch.nn.functional.leaky_relu(x, alpha), def_alpha=0.2, def_gain=np.sqrt(2), cuda_idx=3, ref='y', has_2nd_grad=False), 25 | 'tanh': dnnlib.EasyDict(func=lambda x, **_: torch.tanh(x), def_alpha=0, def_gain=1, cuda_idx=4, ref='y', has_2nd_grad=True), 26 | 'sigmoid': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x), def_alpha=0, def_gain=1, cuda_idx=5, ref='y', has_2nd_grad=True), 27 | 'elu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.elu(x), def_alpha=0, def_gain=1, cuda_idx=6, ref='y', has_2nd_grad=True), 28 | 'selu': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.selu(x), def_alpha=0, def_gain=1, cuda_idx=7, ref='y', has_2nd_grad=True), 29 | 'softplus': dnnlib.EasyDict(func=lambda x, **_: torch.nn.functional.softplus(x), def_alpha=0, def_gain=1, cuda_idx=8, ref='y', has_2nd_grad=True), 30 | 'swish': dnnlib.EasyDict(func=lambda x, **_: torch.sigmoid(x) * x, def_alpha=0, def_gain=np.sqrt(2), cuda_idx=9, ref='x', has_2nd_grad=True), 31 | } 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | _plugin = None 36 | _null_tensor = torch.empty([0]) 37 | 38 | def _init(): 39 | global _plugin 40 | if _plugin is None: 41 | _plugin = custom_ops.get_plugin( 42 | module_name='bias_act_plugin', 43 | sources=['bias_act.cpp', 'bias_act.cu'], 44 | headers=['bias_act.h'], 45 | source_dir=os.path.dirname(__file__), 46 | extra_cuda_cflags=['--use_fast_math'], 47 | ) 48 | return True 49 | 50 | #---------------------------------------------------------------------------- 51 | 52 | def bias_act(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None, impl='cuda'): 53 | r"""Fused bias and activation function. 54 | 55 | Adds bias `b` to activation tensor `x`, evaluates activation function `act`, 56 | and scales the result by `gain`. Each of the steps is optional. In most cases, 57 | the fused op is considerably more efficient than performing the same calculation 58 | using standard PyTorch ops. It supports first and second order gradients, 59 | but not third order gradients. 60 | 61 | Args: 62 | x: Input activation tensor. Can be of any shape. 63 | b: Bias vector, or `None` to disable. Must be a 1D tensor of the same type 64 | as `x`. The shape must be known, and it must match the dimension of `x` 65 | corresponding to `dim`. 66 | dim: The dimension in `x` corresponding to the elements of `b`. 67 | The value of `dim` is ignored if `b` is not specified. 68 | act: Name of the activation function to evaluate, or `"linear"` to disable. 69 | Can be e.g. `"relu"`, `"lrelu"`, `"tanh"`, `"sigmoid"`, `"swish"`, etc. 70 | See `activation_funcs` for a full list. `None` is not allowed. 71 | alpha: Shape parameter for the activation function, or `None` to use the default. 72 | gain: Scaling factor for the output tensor, or `None` to use default. 73 | See `activation_funcs` for the default scaling of each activation function. 74 | If unsure, consider specifying 1. 75 | clamp: Clamp the output values to `[-clamp, +clamp]`, or `None` to disable 76 | the clamping (default). 
77 | impl: Name of the implementation to use. Can be `"ref"` or `"cuda"` (default). 78 | 79 | Returns: 80 | Tensor of the same shape and datatype as `x`. 81 | """ 82 | assert isinstance(x, torch.Tensor) 83 | assert impl in ['ref', 'cuda'] 84 | if impl == 'cuda' and x.device.type == 'cuda' and _init(): 85 | return _bias_act_cuda(dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp).apply(x, b) 86 | return _bias_act_ref(x=x, b=b, dim=dim, act=act, alpha=alpha, gain=gain, clamp=clamp) 87 | 88 | #---------------------------------------------------------------------------- 89 | 90 | @misc.profiled_function 91 | def _bias_act_ref(x, b=None, dim=1, act='linear', alpha=None, gain=None, clamp=None): 92 | """Slow reference implementation of `bias_act()` using standard TensorFlow ops. 93 | """ 94 | assert isinstance(x, torch.Tensor) 95 | assert clamp is None or clamp >= 0 96 | spec = activation_funcs[act] 97 | alpha = float(alpha if alpha is not None else spec.def_alpha) 98 | gain = float(gain if gain is not None else spec.def_gain) 99 | clamp = float(clamp if clamp is not None else -1) 100 | 101 | # Add bias. 102 | if b is not None: 103 | assert isinstance(b, torch.Tensor) and b.ndim == 1 104 | assert 0 <= dim < x.ndim 105 | assert b.shape[0] == x.shape[dim] 106 | x = x + b.reshape([-1 if i == dim else 1 for i in range(x.ndim)]) 107 | 108 | # Evaluate activation function. 109 | alpha = float(alpha) 110 | x = spec.func(x, alpha=alpha) 111 | 112 | # Scale by gain. 113 | gain = float(gain) 114 | if gain != 1: 115 | x = x * gain 116 | 117 | # Clamp. 118 | if clamp >= 0: 119 | x = x.clamp(-clamp, clamp) # pylint: disable=invalid-unary-operand-type 120 | return x 121 | 122 | #---------------------------------------------------------------------------- 123 | 124 | _bias_act_cuda_cache = dict() 125 | 126 | def _bias_act_cuda(dim=1, act='linear', alpha=None, gain=None, clamp=None): 127 | """Fast CUDA implementation of `bias_act()` using custom ops. 128 | """ 129 | # Parse arguments. 130 | assert clamp is None or clamp >= 0 131 | spec = activation_funcs[act] 132 | alpha = float(alpha if alpha is not None else spec.def_alpha) 133 | gain = float(gain if gain is not None else spec.def_gain) 134 | clamp = float(clamp if clamp is not None else -1) 135 | 136 | # Lookup from cache. 137 | key = (dim, act, alpha, gain, clamp) 138 | if key in _bias_act_cuda_cache: 139 | return _bias_act_cuda_cache[key] 140 | 141 | # Forward op. 
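# The autograd Function below implements the fused forward pass for one
# (dim, act, alpha, gain, clamp) configuration. As a rough usage sketch of the
# public bias_act() wrapper above (assuming a CUDA device and a working
# nvcc/MSVC toolchain so the plugin can be compiled):
#
#   import torch
#   from torch_utils.ops import bias_act
#   x = torch.randn(4, 512, 16, 16, device='cuda')
#   b = torch.randn(512, device='cuda')
#   y     = bias_act.bias_act(x, b, dim=1, act='lrelu')              # fused CUDA path
#   y_ref = bias_act.bias_act(x, b, dim=1, act='lrelu', impl='ref')  # reference path
#
# The two results should agree up to numerical precision.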
142 | class BiasActCuda(torch.autograd.Function): 143 | @staticmethod 144 | def forward(ctx, x, b): # pylint: disable=arguments-differ 145 | ctx.memory_format = torch.channels_last if x.ndim > 2 and x.stride(1) == 1 else torch.contiguous_format 146 | x = x.contiguous(memory_format=ctx.memory_format) 147 | b = b.contiguous() if b is not None else _null_tensor 148 | y = x 149 | if act != 'linear' or gain != 1 or clamp >= 0 or b is not _null_tensor: 150 | y = _plugin.bias_act(x, b, _null_tensor, _null_tensor, _null_tensor, 0, dim, spec.cuda_idx, alpha, gain, clamp) 151 | ctx.save_for_backward( 152 | x if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 153 | b if 'x' in spec.ref or spec.has_2nd_grad else _null_tensor, 154 | y if 'y' in spec.ref else _null_tensor) 155 | return y 156 | 157 | @staticmethod 158 | def backward(ctx, dy): # pylint: disable=arguments-differ 159 | dy = dy.contiguous(memory_format=ctx.memory_format) 160 | x, b, y = ctx.saved_tensors 161 | dx = None 162 | db = None 163 | 164 | if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]: 165 | dx = dy 166 | if act != 'linear' or gain != 1 or clamp >= 0: 167 | dx = BiasActCudaGrad.apply(dy, x, b, y) 168 | 169 | if ctx.needs_input_grad[1]: 170 | db = dx.sum([i for i in range(dx.ndim) if i != dim]) 171 | 172 | return dx, db 173 | 174 | # Backward op. 175 | class BiasActCudaGrad(torch.autograd.Function): 176 | @staticmethod 177 | def forward(ctx, dy, x, b, y): # pylint: disable=arguments-differ 178 | ctx.memory_format = torch.channels_last if dy.ndim > 2 and dy.stride(1) == 1 else torch.contiguous_format 179 | dx = _plugin.bias_act(dy, b, x, y, _null_tensor, 1, dim, spec.cuda_idx, alpha, gain, clamp) 180 | ctx.save_for_backward( 181 | dy if spec.has_2nd_grad else _null_tensor, 182 | x, b, y) 183 | return dx 184 | 185 | @staticmethod 186 | def backward(ctx, d_dx): # pylint: disable=arguments-differ 187 | d_dx = d_dx.contiguous(memory_format=ctx.memory_format) 188 | dy, x, b, y = ctx.saved_tensors 189 | d_dy = None 190 | d_x = None 191 | d_b = None 192 | d_y = None 193 | 194 | if ctx.needs_input_grad[0]: 195 | d_dy = BiasActCudaGrad.apply(d_dx, x, b, y) 196 | 197 | if spec.has_2nd_grad and (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]): 198 | d_x = _plugin.bias_act(d_dx, b, x, y, dy, 2, dim, spec.cuda_idx, alpha, gain, clamp) 199 | 200 | if spec.has_2nd_grad and ctx.needs_input_grad[2]: 201 | d_b = d_x.sum([i for i in range(d_x.ndim) if i != dim]) 202 | 203 | return d_dy, d_x, d_b, d_y 204 | 205 | # Add to cache. 206 | _bias_act_cuda_cache[key] = BiasActCuda 207 | return BiasActCuda 208 | 209 | #---------------------------------------------------------------------------- 210 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
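# Illustrative usage of this module (a sketch; assumes CUDA input tensors and cuDNN):
#
#   from torch_utils.ops import conv2d_gradfix
#   conv2d_gradfix.enabled = True                  # opt in to the custom op
#   y = conv2d_gradfix.conv2d(x, w, bias=b, stride=1, padding=1)
#   with conv2d_gradfix.no_weight_gradients():     # suppress gradients w.r.t. the weights
#       y = conv2d_gradfix.conv2d(x, w, padding=1)
#
# With `enabled = False` (the default) both calls fall back to
# torch.nn.functional.conv2d, and the weight-gradient switch has no effect.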
8 | 9 | """Custom replacement for `torch.nn.functional.conv2d` that supports 10 | arbitrarily high order gradients with zero performance penalty.""" 11 | 12 | import contextlib 13 | import torch 14 | 15 | # pylint: disable=redefined-builtin 16 | # pylint: disable=arguments-differ 17 | # pylint: disable=protected-access 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | enabled = False # Enable the custom op by setting this to true. 22 | weight_gradients_disabled = False # Forcefully disable computation of gradients with respect to the weights. 23 | 24 | @contextlib.contextmanager 25 | def no_weight_gradients(disable=True): 26 | global weight_gradients_disabled 27 | old = weight_gradients_disabled 28 | if disable: 29 | weight_gradients_disabled = True 30 | yield 31 | weight_gradients_disabled = old 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1): 36 | if _should_use_custom_op(input): 37 | return _conv2d_gradfix(transpose=False, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=0, dilation=dilation, groups=groups).apply(input, weight, bias) 38 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, dilation=dilation, groups=groups) 39 | 40 | def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_padding=0, groups=1, dilation=1): 41 | if _should_use_custom_op(input): 42 | return _conv2d_gradfix(transpose=True, weight_shape=weight.shape, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation).apply(input, weight, bias) 43 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, stride=stride, padding=padding, output_padding=output_padding, groups=groups, dilation=dilation) 44 | 45 | #---------------------------------------------------------------------------- 46 | 47 | def _should_use_custom_op(input): 48 | assert isinstance(input, torch.Tensor) 49 | if (not enabled) or (not torch.backends.cudnn.enabled): 50 | return False 51 | if input.device.type != 'cuda': 52 | return False 53 | return True 54 | 55 | def _tuple_of_ints(xs, ndim): 56 | xs = tuple(xs) if isinstance(xs, (tuple, list)) else (xs,) * ndim 57 | assert len(xs) == ndim 58 | assert all(isinstance(x, int) for x in xs) 59 | return xs 60 | 61 | #---------------------------------------------------------------------------- 62 | 63 | _conv2d_gradfix_cache = dict() 64 | _null_tensor = torch.empty([0]) 65 | 66 | def _conv2d_gradfix(transpose, weight_shape, stride, padding, output_padding, dilation, groups): 67 | # Parse arguments. 68 | ndim = 2 69 | weight_shape = tuple(weight_shape) 70 | stride = _tuple_of_ints(stride, ndim) 71 | padding = _tuple_of_ints(padding, ndim) 72 | output_padding = _tuple_of_ints(output_padding, ndim) 73 | dilation = _tuple_of_ints(dilation, ndim) 74 | 75 | # Lookup from cache. 76 | key = (transpose, weight_shape, stride, padding, output_padding, dilation, groups) 77 | if key in _conv2d_gradfix_cache: 78 | return _conv2d_gradfix_cache[key] 79 | 80 | # Validate arguments. 
81 | assert groups >= 1 82 | assert len(weight_shape) == ndim + 2 83 | assert all(stride[i] >= 1 for i in range(ndim)) 84 | assert all(padding[i] >= 0 for i in range(ndim)) 85 | assert all(dilation[i] >= 0 for i in range(ndim)) 86 | if not transpose: 87 | assert all(output_padding[i] == 0 for i in range(ndim)) 88 | else: # transpose 89 | assert all(0 <= output_padding[i] < max(stride[i], dilation[i]) for i in range(ndim)) 90 | 91 | # Helpers. 92 | common_kwargs = dict(stride=stride, padding=padding, dilation=dilation, groups=groups) 93 | def calc_output_padding(input_shape, output_shape): 94 | if transpose: 95 | return [0, 0] 96 | return [ 97 | input_shape[i + 2] 98 | - (output_shape[i + 2] - 1) * stride[i] 99 | - (1 - 2 * padding[i]) 100 | - dilation[i] * (weight_shape[i + 2] - 1) 101 | for i in range(ndim) 102 | ] 103 | 104 | # Forward & backward. 105 | class Conv2d(torch.autograd.Function): 106 | @staticmethod 107 | def forward(ctx, input, weight, bias): 108 | assert weight.shape == weight_shape 109 | ctx.save_for_backward( 110 | input if weight.requires_grad else _null_tensor, 111 | weight if input.requires_grad else _null_tensor, 112 | ) 113 | ctx.input_shape = input.shape 114 | 115 | # Simple 1x1 convolution => cuBLAS (only on Volta, not on Ampere). 116 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0) and torch.cuda.get_device_capability(input.device) < (8, 0): 117 | a = weight.reshape(groups, weight_shape[0] // groups, weight_shape[1]) 118 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1) 119 | c = (a.transpose(1, 2) if transpose else a) @ b.permute(1, 2, 0, 3).flatten(2) 120 | c = c.reshape(-1, input.shape[0], *input.shape[2:]).transpose(0, 1) 121 | c = c if bias is None else c + bias.unsqueeze(0).unsqueeze(2).unsqueeze(3) 122 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 123 | 124 | # General case => cuDNN. 125 | if transpose: 126 | return torch.nn.functional.conv_transpose2d(input=input, weight=weight, bias=bias, output_padding=output_padding, **common_kwargs) 127 | return torch.nn.functional.conv2d(input=input, weight=weight, bias=bias, **common_kwargs) 128 | 129 | @staticmethod 130 | def backward(ctx, grad_output): 131 | input, weight = ctx.saved_tensors 132 | input_shape = ctx.input_shape 133 | grad_input = None 134 | grad_weight = None 135 | grad_bias = None 136 | 137 | if ctx.needs_input_grad[0]: 138 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output.shape) 139 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 140 | grad_input = op.apply(grad_output, weight, None) 141 | assert grad_input.shape == input_shape 142 | 143 | if ctx.needs_input_grad[1] and not weight_gradients_disabled: 144 | grad_weight = Conv2dGradWeight.apply(grad_output, input) 145 | assert grad_weight.shape == weight_shape 146 | 147 | if ctx.needs_input_grad[2]: 148 | grad_bias = grad_output.sum([0, 2, 3]) 149 | 150 | return grad_input, grad_weight, grad_bias 151 | 152 | # Gradient with respect to the weights. 
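# Conv2dGradWeight below computes the gradient w.r.t. the weights from
# (grad_output, input). Making it an autograd Function of its own, whose backward is
# again expressed through Conv2d.apply and _conv2d_gradfix, is what allows this module
# to support second- and higher-order gradients.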
153 | class Conv2dGradWeight(torch.autograd.Function): 154 | @staticmethod 155 | def forward(ctx, grad_output, input): 156 | ctx.save_for_backward( 157 | grad_output if input.requires_grad else _null_tensor, 158 | input if grad_output.requires_grad else _null_tensor, 159 | ) 160 | ctx.grad_output_shape = grad_output.shape 161 | ctx.input_shape = input.shape 162 | 163 | # Simple 1x1 convolution => cuBLAS (on both Volta and Ampere). 164 | if weight_shape[2:] == stride == dilation == (1, 1) and padding == (0, 0): 165 | a = grad_output.reshape(grad_output.shape[0], groups, grad_output.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 166 | b = input.reshape(input.shape[0], groups, input.shape[1] // groups, -1).permute(1, 2, 0, 3).flatten(2) 167 | c = (b @ a.transpose(1, 2) if transpose else a @ b.transpose(1, 2)).reshape(weight_shape) 168 | return c.contiguous(memory_format=(torch.channels_last if input.stride(1) == 1 else torch.contiguous_format)) 169 | 170 | # General case => cuDNN. 171 | name = 'aten::cudnn_convolution_transpose_backward_weight' if transpose else 'aten::cudnn_convolution_backward_weight' 172 | flags = [torch.backends.cudnn.benchmark, torch.backends.cudnn.deterministic, torch.backends.cudnn.allow_tf32] 173 | return torch._C._jit_get_operation(name)(weight_shape, grad_output, input, padding, stride, dilation, groups, *flags) 174 | 175 | @staticmethod 176 | def backward(ctx, grad2_grad_weight): 177 | grad_output, input = ctx.saved_tensors 178 | grad_output_shape = ctx.grad_output_shape 179 | input_shape = ctx.input_shape 180 | grad2_grad_output = None 181 | grad2_input = None 182 | 183 | if ctx.needs_input_grad[0]: 184 | grad2_grad_output = Conv2d.apply(input, grad2_grad_weight, None) 185 | assert grad2_grad_output.shape == grad_output_shape 186 | 187 | if ctx.needs_input_grad[1]: 188 | p = calc_output_padding(input_shape=input_shape, output_shape=grad_output_shape) 189 | op = _conv2d_gradfix(transpose=(not transpose), weight_shape=weight_shape, output_padding=p, **common_kwargs) 190 | grad2_input = op.apply(grad_output, grad2_grad_weight, None) 191 | assert grad2_input.shape == input_shape 192 | 193 | return grad2_grad_output, grad2_input 194 | 195 | _conv2d_gradfix_cache[key] = Conv2d 196 | return Conv2d 197 | 198 | #---------------------------------------------------------------------------- 199 | -------------------------------------------------------------------------------- /torch_utils/ops/conv2d_resample.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """2D convolution with optional up/downsampling.""" 10 | 11 | import torch 12 | 13 | from .. import misc 14 | from . import conv2d_gradfix 15 | from . 
import upfirdn2d 16 | from .upfirdn2d import _parse_padding 17 | from .upfirdn2d import _get_filter_size 18 | 19 | #---------------------------------------------------------------------------- 20 | 21 | def _get_weight_shape(w): 22 | with misc.suppress_tracer_warnings(): # this value will be treated as a constant 23 | shape = [int(sz) for sz in w.shape] 24 | misc.assert_shape(w, shape) 25 | return shape 26 | 27 | #---------------------------------------------------------------------------- 28 | 29 | def _conv2d_wrapper(x, w, stride=1, padding=0, groups=1, transpose=False, flip_weight=True): 30 | """Wrapper for the underlying `conv2d()` and `conv_transpose2d()` implementations. 31 | """ 32 | _out_channels, _in_channels_per_group, kh, kw = _get_weight_shape(w) 33 | 34 | # Flip weight if requested. 35 | # Note: conv2d() actually performs correlation (flip_weight=True) not convolution (flip_weight=False). 36 | if not flip_weight and (kw > 1 or kh > 1): 37 | w = w.flip([2, 3]) 38 | 39 | # Execute using conv2d_gradfix. 40 | op = conv2d_gradfix.conv_transpose2d if transpose else conv2d_gradfix.conv2d 41 | return op(x, w, stride=stride, padding=padding, groups=groups) 42 | 43 | #---------------------------------------------------------------------------- 44 | 45 | @misc.profiled_function 46 | def conv2d_resample(x, w, f=None, up=1, down=1, padding=0, groups=1, flip_weight=True, flip_filter=False): 47 | r"""2D convolution with optional up/downsampling. 48 | 49 | Padding is performed only once at the beginning, not between the operations. 50 | 51 | Args: 52 | x: Input tensor of shape 53 | `[batch_size, in_channels, in_height, in_width]`. 54 | w: Weight tensor of shape 55 | `[out_channels, in_channels//groups, kernel_height, kernel_width]`. 56 | f: Low-pass filter for up/downsampling. Must be prepared beforehand by 57 | calling upfirdn2d.setup_filter(). None = identity (default). 58 | up: Integer upsampling factor (default: 1). 59 | down: Integer downsampling factor (default: 1). 60 | padding: Padding with respect to the upsampled image. Can be a single number 61 | or a list/tuple `[x, y]` or `[x_before, x_after, y_before, y_after]` 62 | (default: 0). 63 | groups: Split input channels into N groups (default: 1). 64 | flip_weight: False = convolution, True = correlation (default: True). 65 | flip_filter: False = convolution, True = correlation (default: False). 66 | 67 | Returns: 68 | Tensor of the shape `[batch_size, num_channels, out_height, out_width]`. 69 | """ 70 | # Validate arguments. 71 | assert isinstance(x, torch.Tensor) and (x.ndim == 4) 72 | assert isinstance(w, torch.Tensor) and (w.ndim == 4) and (w.dtype == x.dtype) 73 | assert f is None or (isinstance(f, torch.Tensor) and f.ndim in [1, 2] and f.dtype == torch.float32) 74 | assert isinstance(up, int) and (up >= 1) 75 | assert isinstance(down, int) and (down >= 1) 76 | assert isinstance(groups, int) and (groups >= 1) 77 | out_channels, in_channels_per_group, kh, kw = _get_weight_shape(w) 78 | fw, fh = _get_filter_size(f) 79 | px0, px1, py0, py1 = _parse_padding(padding) 80 | 81 | # Adjust padding to account for up/downsampling. 82 | if up > 1: 83 | px0 += (fw + up - 1) // 2 84 | px1 += (fw - up) // 2 85 | py0 += (fh + up - 1) // 2 86 | py1 += (fh - up) // 2 87 | if down > 1: 88 | px0 += (fw - down + 1) // 2 89 | px1 += (fw - down) // 2 90 | py0 += (fh - down + 1) // 2 91 | py1 += (fh - down) // 2 92 | 93 | # Fast path: 1x1 convolution with downsampling only => downsample first, then convolve. 
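# (A 1x1 kernel acts purely per pixel, so it commutes with the FIR resampling;
#  running the resampling first lets the convolution operate on the smaller tensor.
#  Illustrative shapes with down=2: [N,C,64,64] -> upfirdn2d -> [N,C,32,32] -> 1x1 conv -> [N,C',32,32].)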
94 | if kw == 1 and kh == 1 and (down > 1 and up == 1): 95 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 96 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 97 | return x 98 | 99 | # Fast path: 1x1 convolution with upsampling only => convolve first, then upsample. 100 | if kw == 1 and kh == 1 and (up > 1 and down == 1): 101 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 102 | x = upfirdn2d.upfirdn2d(x=x, f=f, up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 103 | return x 104 | 105 | # Fast path: downsampling only => use strided convolution. 106 | if down > 1 and up == 1: 107 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0,px1,py0,py1], flip_filter=flip_filter) 108 | x = _conv2d_wrapper(x=x, w=w, stride=down, groups=groups, flip_weight=flip_weight) 109 | return x 110 | 111 | # Fast path: upsampling with optional downsampling => use transpose strided convolution. 112 | if up > 1: 113 | if groups == 1: 114 | w = w.transpose(0, 1) 115 | else: 116 | w = w.reshape(groups, out_channels // groups, in_channels_per_group, kh, kw) 117 | w = w.transpose(1, 2) 118 | w = w.reshape(groups * in_channels_per_group, out_channels // groups, kh, kw) 119 | px0 -= kw - 1 120 | px1 -= kw - up 121 | py0 -= kh - 1 122 | py1 -= kh - up 123 | pxt = max(min(-px0, -px1), 0) 124 | pyt = max(min(-py0, -py1), 0) 125 | x = _conv2d_wrapper(x=x, w=w, stride=up, padding=[pyt,pxt], groups=groups, transpose=True, flip_weight=(not flip_weight)) 126 | x = upfirdn2d.upfirdn2d(x=x, f=f, padding=[px0+pxt,px1+pxt,py0+pyt,py1+pyt], gain=up**2, flip_filter=flip_filter) 127 | if down > 1: 128 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 129 | return x 130 | 131 | # Fast path: no up/downsampling, padding supported by the underlying implementation => use plain conv2d. 132 | if up == 1 and down == 1: 133 | if px0 == px1 and py0 == py1 and px0 >= 0 and py0 >= 0: 134 | return _conv2d_wrapper(x=x, w=w, padding=[py0,px0], groups=groups, flip_weight=flip_weight) 135 | 136 | # Fallback: Generic reference implementation. 137 | x = upfirdn2d.upfirdn2d(x=x, f=(f if up > 1 else None), up=up, padding=[px0,px1,py0,py1], gain=up**2, flip_filter=flip_filter) 138 | x = _conv2d_wrapper(x=x, w=w, groups=groups, flip_weight=flip_weight) 139 | if down > 1: 140 | x = upfirdn2d.upfirdn2d(x=x, f=f, down=down, flip_filter=flip_filter) 141 | return x 142 | 143 | #---------------------------------------------------------------------------- 144 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct filtered_lrelu_kernel_params 15 | { 16 | // These parameters decide which kernel to use. 
17 | int up; // upsampling ratio (1, 2, 4) 18 | int down; // downsampling ratio (1, 2, 4) 19 | int2 fuShape; // [size, 1] | [size, size] 20 | int2 fdShape; // [size, 1] | [size, size] 21 | 22 | int _dummy; // Alignment. 23 | 24 | // Rest of the parameters. 25 | const void* x; // Input tensor. 26 | void* y; // Output tensor. 27 | const void* b; // Bias tensor. 28 | unsigned char* s; // Sign tensor in/out. NULL if unused. 29 | const float* fu; // Upsampling filter. 30 | const float* fd; // Downsampling filter. 31 | 32 | int2 pad0; // Left/top padding. 33 | float gain; // Additional gain factor. 34 | float slope; // Leaky ReLU slope on negative side. 35 | float clamp; // Clamp after nonlinearity. 36 | int flip; // Filter kernel flip for gradient computation. 37 | 38 | int tilesXdim; // Original number of horizontal output tiles. 39 | int tilesXrep; // Number of horizontal tiles per CTA. 40 | int blockZofs; // Block z offset to support large minibatch, channel dimensions. 41 | 42 | int4 xShape; // [width, height, channel, batch] 43 | int4 yShape; // [width, height, channel, batch] 44 | int2 sShape; // [width, height] - width is in bytes. Contiguous. Zeros if unused. 45 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 46 | int swLimit; // Active width of sign tensor in bytes. 47 | 48 | longlong4 xStride; // Strides of all tensors except signs, same component order as shapes. 49 | longlong4 yStride; // 50 | int64_t bStride; // 51 | longlong3 fuStride; // 52 | longlong3 fdStride; // 53 | }; 54 | 55 | struct filtered_lrelu_act_kernel_params 56 | { 57 | void* x; // Input/output, modified in-place. 58 | unsigned char* s; // Sign tensor in/out. NULL if unused. 59 | 60 | float gain; // Additional gain factor. 61 | float slope; // Leaky ReLU slope on negative side. 62 | float clamp; // Clamp after nonlinearity. 63 | 64 | int4 xShape; // [width, height, channel, batch] 65 | longlong4 xStride; // Input/output tensor strides, same order as in shape. 66 | int2 sShape; // [width, height] - width is in elements. Contiguous. Zeros if unused. 67 | int2 sOfs; // [ofs_x, ofs_y] - offset between upsampled data and sign tensor. 68 | }; 69 | 70 | //------------------------------------------------------------------------ 71 | // CUDA kernel specialization. 72 | 73 | struct filtered_lrelu_kernel_spec 74 | { 75 | void* setup; // Function for filter kernel setup. 76 | void* exec; // Function for main operation. 77 | int2 tileOut; // Width/height of launch tile. 78 | int numWarps; // Number of warps per thread block, determines launch block size. 79 | int xrep; // For processing multiple horizontal tiles per thread block. 80 | int dynamicSharedKB; // How much dynamic shared memory the exec kernel wants. 81 | }; 82 | 83 | //------------------------------------------------------------------------ 84 | // CUDA kernel selection. 85 | 86 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 87 | template void* choose_filtered_lrelu_act_kernel(void); 88 | template cudaError_t copy_filters(cudaStream_t stream); 89 | 90 | //------------------------------------------------------------------------ 91 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_ns.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for no signs mode (no gradients required). 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_rd.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign read mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/filtered_lrelu_wr.cu: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include "filtered_lrelu.cu" 10 | 11 | // Template/kernel specializations for sign write mode. 12 | 13 | // Full op, 32-bit indexing. 14 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 15 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 16 | 17 | // Full op, 64-bit indexing. 18 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 19 | template filtered_lrelu_kernel_spec choose_filtered_lrelu_kernel(const filtered_lrelu_kernel_params& p, int sharedKB); 20 | 21 | // Activation/signs only for generic variant. 64-bit indexing. 22 | template void* choose_filtered_lrelu_act_kernel(void); 23 | template void* choose_filtered_lrelu_act_kernel(void); 24 | template void* choose_filtered_lrelu_act_kernel(void); 25 | 26 | // Copy filters to constant memory. 27 | template cudaError_t copy_filters(cudaStream_t stream); 28 | -------------------------------------------------------------------------------- /torch_utils/ops/fma.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 
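# Illustrative usage (a sketch): fma(a, b, c) returns a * b + c with normal
# broadcasting semantics, matching torch.addcmul(c, a, b) in the forward pass.
#
#   import torch
#   from torch_utils.ops import fma
#   a, b = torch.randn(8, 512), torch.randn(8, 512)
#   c = torch.randn(512)                  # broadcast against a * b
#   out = fma.fma(a, b, c)                # same values as a * b + c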
8 | 9 | """Fused multiply-add, with slightly faster gradients than `torch.addcmul()`.""" 10 | 11 | import torch 12 | 13 | #---------------------------------------------------------------------------- 14 | 15 | def fma(a, b, c): # => a * b + c 16 | return _FusedMultiplyAdd.apply(a, b, c) 17 | 18 | #---------------------------------------------------------------------------- 19 | 20 | class _FusedMultiplyAdd(torch.autograd.Function): # a * b + c 21 | @staticmethod 22 | def forward(ctx, a, b, c): # pylint: disable=arguments-differ 23 | out = torch.addcmul(c, a, b) 24 | ctx.save_for_backward(a, b) 25 | ctx.c_shape = c.shape 26 | return out 27 | 28 | @staticmethod 29 | def backward(ctx, dout): # pylint: disable=arguments-differ 30 | a, b = ctx.saved_tensors 31 | c_shape = ctx.c_shape 32 | da = None 33 | db = None 34 | dc = None 35 | 36 | if ctx.needs_input_grad[0]: 37 | da = _unbroadcast(dout * b, a.shape) 38 | 39 | if ctx.needs_input_grad[1]: 40 | db = _unbroadcast(dout * a, b.shape) 41 | 42 | if ctx.needs_input_grad[2]: 43 | dc = _unbroadcast(dout, c_shape) 44 | 45 | return da, db, dc 46 | 47 | #---------------------------------------------------------------------------- 48 | 49 | def _unbroadcast(x, shape): 50 | extra_dims = x.ndim - len(shape) 51 | assert extra_dims >= 0 52 | dim = [i for i in range(x.ndim) if x.shape[i] > 1 and (i < extra_dims or shape[i - extra_dims] == 1)] 53 | if len(dim): 54 | x = x.sum(dim=dim, keepdim=True) 55 | if extra_dims: 56 | x = x.reshape(-1, *x.shape[extra_dims+1:]) 57 | assert x.shape == shape 58 | return x 59 | 60 | #---------------------------------------------------------------------------- 61 | -------------------------------------------------------------------------------- /torch_utils/ops/grid_sample_gradfix.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Custom replacement for `torch.nn.functional.grid_sample` that 10 | supports arbitrarily high order gradients between the input and output. 11 | Only works on 2D images and assumes 12 | `mode='bilinear'`, `padding_mode='zeros'`, `align_corners=False`.""" 13 | 14 | import torch 15 | 16 | # pylint: disable=redefined-builtin 17 | # pylint: disable=arguments-differ 18 | # pylint: disable=protected-access 19 | 20 | #---------------------------------------------------------------------------- 21 | 22 | enabled = False # Enable the custom op by setting this to true. 
23 | 24 | #---------------------------------------------------------------------------- 25 | 26 | def grid_sample(input, grid): 27 | if _should_use_custom_op(): 28 | return _GridSample2dForward.apply(input, grid) 29 | return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 30 | 31 | #---------------------------------------------------------------------------- 32 | 33 | def _should_use_custom_op(): 34 | return enabled 35 | 36 | #---------------------------------------------------------------------------- 37 | 38 | class _GridSample2dForward(torch.autograd.Function): 39 | @staticmethod 40 | def forward(ctx, input, grid): 41 | assert input.ndim == 4 42 | assert grid.ndim == 4 43 | output = torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=False) 44 | ctx.save_for_backward(input, grid) 45 | return output 46 | 47 | @staticmethod 48 | def backward(ctx, grad_output): 49 | input, grid = ctx.saved_tensors 50 | grad_input, grad_grid = _GridSample2dBackward.apply(grad_output, input, grid) 51 | return grad_input, grad_grid 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | class _GridSample2dBackward(torch.autograd.Function): 56 | @staticmethod 57 | def forward(ctx, grad_output, input, grid): 58 | op = torch._C._jit_get_operation('aten::grid_sampler_2d_backward') 59 | grad_input, grad_grid = op(grad_output, input, grid, 0, 0, False) 60 | ctx.save_for_backward(grid) 61 | return grad_input, grad_grid 62 | 63 | @staticmethod 64 | def backward(ctx, grad2_grad_input, grad2_grad_grid): 65 | _ = grad2_grad_grid # unused 66 | grid, = ctx.saved_tensors 67 | grad2_grad_output = None 68 | grad2_input = None 69 | grad2_grid = None 70 | 71 | if ctx.needs_input_grad[0]: 72 | grad2_grad_output = _GridSample2dForward.apply(grad2_grad_input, grid) 73 | 74 | assert not ctx.needs_input_grad[2] 75 | return grad2_grad_output, grad2_input, grad2_grid 76 | 77 | #---------------------------------------------------------------------------- 78 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.cpp: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | #include 11 | #include 12 | #include "upfirdn2d.h" 13 | 14 | //------------------------------------------------------------------------ 15 | 16 | static torch::Tensor upfirdn2d(torch::Tensor x, torch::Tensor f, int upx, int upy, int downx, int downy, int padx0, int padx1, int pady0, int pady1, bool flip, float gain) 17 | { 18 | // Validate arguments. 
19 | TORCH_CHECK(x.is_cuda(), "x must reside on CUDA device"); 20 | TORCH_CHECK(f.device() == x.device(), "f must reside on the same device as x"); 21 | TORCH_CHECK(f.dtype() == torch::kFloat, "f must be float32"); 22 | TORCH_CHECK(x.numel() <= INT_MAX, "x is too large"); 23 | TORCH_CHECK(f.numel() <= INT_MAX, "f is too large"); 24 | TORCH_CHECK(x.numel() > 0, "x has zero size"); 25 | TORCH_CHECK(f.numel() > 0, "f has zero size"); 26 | TORCH_CHECK(x.dim() == 4, "x must be rank 4"); 27 | TORCH_CHECK(f.dim() == 2, "f must be rank 2"); 28 | TORCH_CHECK((x.size(0)-1)*x.stride(0) + (x.size(1)-1)*x.stride(1) + (x.size(2)-1)*x.stride(2) + (x.size(3)-1)*x.stride(3) <= INT_MAX, "x memory footprint is too large"); 29 | TORCH_CHECK(f.size(0) >= 1 && f.size(1) >= 1, "f must be at least 1x1"); 30 | TORCH_CHECK(upx >= 1 && upy >= 1, "upsampling factor must be at least 1"); 31 | TORCH_CHECK(downx >= 1 && downy >= 1, "downsampling factor must be at least 1"); 32 | 33 | // Create output tensor. 34 | const at::cuda::OptionalCUDAGuard device_guard(device_of(x)); 35 | int outW = ((int)x.size(3) * upx + padx0 + padx1 - (int)f.size(1) + downx) / downx; 36 | int outH = ((int)x.size(2) * upy + pady0 + pady1 - (int)f.size(0) + downy) / downy; 37 | TORCH_CHECK(outW >= 1 && outH >= 1, "output must be at least 1x1"); 38 | torch::Tensor y = torch::empty({x.size(0), x.size(1), outH, outW}, x.options(), x.suggest_memory_format()); 39 | TORCH_CHECK(y.numel() <= INT_MAX, "output is too large"); 40 | TORCH_CHECK((y.size(0)-1)*y.stride(0) + (y.size(1)-1)*y.stride(1) + (y.size(2)-1)*y.stride(2) + (y.size(3)-1)*y.stride(3) <= INT_MAX, "output memory footprint is too large"); 41 | 42 | // Initialize CUDA kernel parameters. 43 | upfirdn2d_kernel_params p; 44 | p.x = x.data_ptr(); 45 | p.f = f.data_ptr(); 46 | p.y = y.data_ptr(); 47 | p.up = make_int2(upx, upy); 48 | p.down = make_int2(downx, downy); 49 | p.pad0 = make_int2(padx0, pady0); 50 | p.flip = (flip) ? 1 : 0; 51 | p.gain = gain; 52 | p.inSize = make_int4((int)x.size(3), (int)x.size(2), (int)x.size(1), (int)x.size(0)); 53 | p.inStride = make_int4((int)x.stride(3), (int)x.stride(2), (int)x.stride(1), (int)x.stride(0)); 54 | p.filterSize = make_int2((int)f.size(1), (int)f.size(0)); 55 | p.filterStride = make_int2((int)f.stride(1), (int)f.stride(0)); 56 | p.outSize = make_int4((int)y.size(3), (int)y.size(2), (int)y.size(1), (int)y.size(0)); 57 | p.outStride = make_int4((int)y.stride(3), (int)y.stride(2), (int)y.stride(1), (int)y.stride(0)); 58 | p.sizeMajor = (p.inStride.z == 1) ? p.inSize.w : p.inSize.w * p.inSize.z; 59 | p.sizeMinor = (p.inStride.z == 1) ? p.inSize.z : 1; 60 | 61 | // Choose CUDA kernel. 62 | upfirdn2d_kernel_spec spec; 63 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(x.scalar_type(), "upfirdn2d_cuda", [&] 64 | { 65 | spec = choose_upfirdn2d_kernel(p); 66 | }); 67 | 68 | // Set looping options. 69 | p.loopMajor = (p.sizeMajor - 1) / 16384 + 1; 70 | p.loopMinor = spec.loopMinor; 71 | p.loopX = spec.loopX; 72 | p.launchMinor = (p.sizeMinor - 1) / p.loopMinor + 1; 73 | p.launchMajor = (p.sizeMajor - 1) / p.loopMajor + 1; 74 | 75 | // Compute grid size. 
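// (Worked example for the output size computed above, as an illustration:
//  a 64-pixel-wide input with upx=2, padx0=2, padx1=1, a 4-tap filter and downx=1
//  gives outW = (64*2 + 2 + 1 - 4 + 1) / 1 = 128, i.e. exact 2x upsampling.
//  The launch grid below is then sized from the output dimensions and the
//  per-block output tile of the kernel selected above.)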
76 | dim3 blockSize, gridSize; 77 | if (spec.tileOutW < 0) // large 78 | { 79 | blockSize = dim3(4, 32, 1); 80 | gridSize = dim3( 81 | ((p.outSize.y - 1) / blockSize.x + 1) * p.launchMinor, 82 | (p.outSize.x - 1) / (blockSize.y * p.loopX) + 1, 83 | p.launchMajor); 84 | } 85 | else // small 86 | { 87 | blockSize = dim3(256, 1, 1); 88 | gridSize = dim3( 89 | ((p.outSize.y - 1) / spec.tileOutH + 1) * p.launchMinor, 90 | (p.outSize.x - 1) / (spec.tileOutW * p.loopX) + 1, 91 | p.launchMajor); 92 | } 93 | 94 | // Launch CUDA kernel. 95 | void* args[] = {&p}; 96 | AT_CUDA_CHECK(cudaLaunchKernel(spec.kernel, gridSize, blockSize, args, 0, at::cuda::getCurrentCUDAStream())); 97 | return y; 98 | } 99 | 100 | //------------------------------------------------------------------------ 101 | 102 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) 103 | { 104 | m.def("upfirdn2d", &upfirdn2d); 105 | } 106 | 107 | //------------------------------------------------------------------------ 108 | -------------------------------------------------------------------------------- /torch_utils/ops/upfirdn2d.h: -------------------------------------------------------------------------------- 1 | // Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | // 3 | // NVIDIA CORPORATION and its licensors retain all intellectual property 4 | // and proprietary rights in and to this software, related documentation 5 | // and any modifications thereto. Any use, reproduction, disclosure or 6 | // distribution of this software and related documentation without an express 7 | // license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | #include 10 | 11 | //------------------------------------------------------------------------ 12 | // CUDA kernel parameters. 13 | 14 | struct upfirdn2d_kernel_params 15 | { 16 | const void* x; 17 | const float* f; 18 | void* y; 19 | 20 | int2 up; 21 | int2 down; 22 | int2 pad0; 23 | int flip; 24 | float gain; 25 | 26 | int4 inSize; // [width, height, channel, batch] 27 | int4 inStride; 28 | int2 filterSize; // [width, height] 29 | int2 filterStride; 30 | int4 outSize; // [width, height, channel, batch] 31 | int4 outStride; 32 | int sizeMinor; 33 | int sizeMajor; 34 | 35 | int loopMinor; 36 | int loopMajor; 37 | int loopX; 38 | int launchMinor; 39 | int launchMajor; 40 | }; 41 | 42 | //------------------------------------------------------------------------ 43 | // CUDA kernel specialization. 44 | 45 | struct upfirdn2d_kernel_spec 46 | { 47 | void* kernel; 48 | int tileOutW; 49 | int tileOutH; 50 | int loopMinor; 51 | int loopX; 52 | }; 53 | 54 | //------------------------------------------------------------------------ 55 | // CUDA kernel selection. 56 | 57 | template upfirdn2d_kernel_spec choose_upfirdn2d_kernel(const upfirdn2d_kernel_params& p); 58 | 59 | //------------------------------------------------------------------------ 60 | -------------------------------------------------------------------------------- /torch_utils/persistence.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. 
Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for pickling Python code alongside other data. 10 | 11 | The pickled code is automatically imported into a separate Python module 12 | during unpickling. This way, any previously exported pickles will remain 13 | usable even if the original code is no longer available, or if the current 14 | version of the code is not consistent with what was originally pickled.""" 15 | 16 | import sys 17 | import pickle 18 | import io 19 | import inspect 20 | import copy 21 | import uuid 22 | import types 23 | import dnnlib 24 | 25 | #---------------------------------------------------------------------------- 26 | 27 | _version = 6 # internal version number 28 | _decorators = set() # {decorator_class, ...} 29 | _import_hooks = [] # [hook_function, ...] 30 | _module_to_src_dict = dict() # {module: src, ...} 31 | _src_to_module_dict = dict() # {src: module, ...} 32 | 33 | #---------------------------------------------------------------------------- 34 | 35 | def persistent_class(orig_class): 36 | r"""Class decorator that extends a given class to save its source code 37 | when pickled. 38 | 39 | Example: 40 | 41 | from torch_utils import persistence 42 | 43 | @persistence.persistent_class 44 | class MyNetwork(torch.nn.Module): 45 | def __init__(self, num_inputs, num_outputs): 46 | super().__init__() 47 | self.fc = MyLayer(num_inputs, num_outputs) 48 | ... 49 | 50 | @persistence.persistent_class 51 | class MyLayer(torch.nn.Module): 52 | ... 53 | 54 | When pickled, any instance of `MyNetwork` and `MyLayer` will save its 55 | source code alongside other internal state (e.g., parameters, buffers, 56 | and submodules). This way, any previously exported pickle will remain 57 | usable even if the class definitions have been modified or are no 58 | longer available. 59 | 60 | The decorator saves the source code of the entire Python module 61 | containing the decorated class. It does *not* save the source code of 62 | any imported modules. Thus, the imported modules must be available 63 | during unpickling, also including `torch_utils.persistence` itself. 64 | 65 | It is ok to call functions defined in the same module from the 66 | decorated class. However, if the decorated class depends on other 67 | classes defined in the same module, they must be decorated as well. 68 | This is illustrated in the above example in the case of `MyLayer`. 69 | 70 | It is also possible to employ the decorator just-in-time before 71 | calling the constructor. For example: 72 | 73 | cls = MyLayer 74 | if want_to_make_it_persistent: 75 | cls = persistence.persistent_class(cls) 76 | layer = cls(num_inputs, num_outputs) 77 | 78 | As an additional feature, the decorator also keeps track of the 79 | arguments that were used to construct each instance of the decorated 80 | class. The arguments can be queried via `obj.init_args` and 81 | `obj.init_kwargs`, and they are automatically pickled alongside other 82 | object state. 
A typical use case is to first unpickle a previous 83 | instance of a persistent class, and then upgrade it to use the latest 84 | version of the source code: 85 | 86 | with open('old_pickle.pkl', 'rb') as f: 87 | old_net = pickle.load(f) 88 | new_net = MyNetwork(*old_obj.init_args, **old_obj.init_kwargs) 89 | misc.copy_params_and_buffers(old_net, new_net, require_all=True) 90 | """ 91 | assert isinstance(orig_class, type) 92 | if is_persistent(orig_class): 93 | return orig_class 94 | 95 | assert orig_class.__module__ in sys.modules 96 | orig_module = sys.modules[orig_class.__module__] 97 | orig_module_src = _module_to_src(orig_module) 98 | 99 | class Decorator(orig_class): 100 | _orig_module_src = orig_module_src 101 | _orig_class_name = orig_class.__name__ 102 | 103 | def __init__(self, *args, **kwargs): 104 | super().__init__(*args, **kwargs) 105 | self._init_args = copy.deepcopy(args) 106 | self._init_kwargs = copy.deepcopy(kwargs) 107 | assert orig_class.__name__ in orig_module.__dict__ 108 | _check_pickleable(self.__reduce__()) 109 | 110 | @property 111 | def init_args(self): 112 | return copy.deepcopy(self._init_args) 113 | 114 | @property 115 | def init_kwargs(self): 116 | return dnnlib.EasyDict(copy.deepcopy(self._init_kwargs)) 117 | 118 | def __reduce__(self): 119 | fields = list(super().__reduce__()) 120 | fields += [None] * max(3 - len(fields), 0) 121 | if fields[0] is not _reconstruct_persistent_obj: 122 | meta = dict(type='class', version=_version, module_src=self._orig_module_src, class_name=self._orig_class_name, state=fields[2]) 123 | fields[0] = _reconstruct_persistent_obj # reconstruct func 124 | fields[1] = (meta,) # reconstruct args 125 | fields[2] = None # state dict 126 | return tuple(fields) 127 | 128 | Decorator.__name__ = orig_class.__name__ 129 | _decorators.add(Decorator) 130 | return Decorator 131 | 132 | #---------------------------------------------------------------------------- 133 | 134 | def is_persistent(obj): 135 | r"""Test whether the given object or class is persistent, i.e., 136 | whether it will save its source code when pickled. 137 | """ 138 | try: 139 | if obj in _decorators: 140 | return True 141 | except TypeError: 142 | pass 143 | return type(obj) in _decorators # pylint: disable=unidiomatic-typecheck 144 | 145 | #---------------------------------------------------------------------------- 146 | 147 | def import_hook(hook): 148 | r"""Register an import hook that is called whenever a persistent object 149 | is being unpickled. A typical use case is to patch the pickled source 150 | code to avoid errors and inconsistencies when the API of some imported 151 | module has changed. 152 | 153 | The hook should have the following signature: 154 | 155 | hook(meta) -> modified meta 156 | 157 | `meta` is an instance of `dnnlib.EasyDict` with the following fields: 158 | 159 | type: Type of the persistent object, e.g. `'class'`. 160 | version: Internal version number of `torch_utils.persistence`. 161 | module_src Original source code of the Python module. 162 | class_name: Class name in the original Python module. 163 | state: Internal state of the object. 164 | 165 | Example: 166 | 167 | @persistence.import_hook 168 | def wreck_my_network(meta): 169 | if meta.class_name == 'MyNetwork': 170 | print('MyNetwork is being imported. 
I will wreck it!') 171 | meta.module_src = meta.module_src.replace("True", "False") 172 | return meta 173 | """ 174 | assert callable(hook) 175 | _import_hooks.append(hook) 176 | 177 | #---------------------------------------------------------------------------- 178 | 179 | def _reconstruct_persistent_obj(meta): 180 | r"""Hook that is called internally by the `pickle` module to unpickle 181 | a persistent object. 182 | """ 183 | meta = dnnlib.EasyDict(meta) 184 | meta.state = dnnlib.EasyDict(meta.state) 185 | for hook in _import_hooks: 186 | meta = hook(meta) 187 | assert meta is not None 188 | 189 | assert meta.version == _version 190 | module = _src_to_module(meta.module_src) 191 | 192 | assert meta.type == 'class' 193 | orig_class = module.__dict__[meta.class_name] 194 | decorator_class = persistent_class(orig_class) 195 | obj = decorator_class.__new__(decorator_class) 196 | 197 | setstate = getattr(obj, '__setstate__', None) 198 | if callable(setstate): 199 | setstate(meta.state) # pylint: disable=not-callable 200 | else: 201 | obj.__dict__.update(meta.state) 202 | return obj 203 | 204 | #---------------------------------------------------------------------------- 205 | 206 | def _module_to_src(module): 207 | r"""Query the source code of a given Python module. 208 | """ 209 | src = _module_to_src_dict.get(module, None) 210 | if src is None: 211 | src = inspect.getsource(module) 212 | _module_to_src_dict[module] = src 213 | _src_to_module_dict[src] = module 214 | return src 215 | 216 | def _src_to_module(src): 217 | r"""Get or create a Python module for the given source code. 218 | """ 219 | module = _src_to_module_dict.get(src, None) 220 | if module is None: 221 | module_name = "_imported_module_" + uuid.uuid4().hex 222 | module = types.ModuleType(module_name) 223 | sys.modules[module_name] = module 224 | _module_to_src_dict[module] = src 225 | _src_to_module_dict[src] = module 226 | exec(src, module.__dict__) # pylint: disable=exec-used 227 | return module 228 | 229 | #---------------------------------------------------------------------------- 230 | 231 | def _check_pickleable(obj): 232 | r"""Check that the given object is pickleable, raising an exception if 233 | it is not. This function is expected to be considerably more efficient 234 | than actually pickling the object. 235 | """ 236 | def recurse(obj): 237 | if isinstance(obj, (list, tuple, set)): 238 | return [recurse(x) for x in obj] 239 | if isinstance(obj, dict): 240 | return [[recurse(x), recurse(y)] for x, y in obj.items()] 241 | if isinstance(obj, (str, int, float, bool, bytes, bytearray)): 242 | return None # Python primitive types are pickleable. 243 | if f'{type(obj).__module__}.{type(obj).__name__}' in ['numpy.ndarray', 'torch.Tensor', 'torch.nn.parameter.Parameter']: 244 | return None # NumPy arrays and PyTorch tensors are pickleable. 245 | if is_persistent(obj): 246 | return None # Persistent objects are pickleable, by virtue of the constructor check. 247 | return obj 248 | with io.BytesIO() as f: 249 | pickle.dump(recurse(obj), f) 250 | 251 | #---------------------------------------------------------------------------- 252 | -------------------------------------------------------------------------------- /torch_utils/training_stats.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
2 | # 3 | # NVIDIA CORPORATION and its licensors retain all intellectual property 4 | # and proprietary rights in and to this software, related documentation 5 | # and any modifications thereto. Any use, reproduction, disclosure or 6 | # distribution of this software and related documentation without an express 7 | # license agreement from NVIDIA CORPORATION is strictly prohibited. 8 | 9 | """Facilities for reporting and collecting training statistics across 10 | multiple processes and devices. The interface is designed to minimize 11 | synchronization overhead as well as the amount of boilerplate in user 12 | code.""" 13 | 14 | import re 15 | import numpy as np 16 | import torch 17 | import dnnlib 18 | 19 | from . import misc 20 | 21 | #---------------------------------------------------------------------------- 22 | 23 | _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] 24 | _reduce_dtype = torch.float32 # Data type to use for initial per-tensor reduction. 25 | _counter_dtype = torch.float64 # Data type to use for the internal counters. 26 | _rank = 0 # Rank of the current process. 27 | _sync_device = None # Device to use for multiprocess communication. None = single-process. 28 | _sync_called = False # Has _sync() been called yet? 29 | _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor 30 | _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor 31 | 32 | #---------------------------------------------------------------------------- 33 | 34 | def init_multiprocessing(rank, sync_device): 35 | r"""Initializes `torch_utils.training_stats` for collecting statistics 36 | across multiple processes. 37 | 38 | This function must be called after 39 | `torch.distributed.init_process_group()` and before `Collector.update()`. 40 | The call is not necessary if multi-process collection is not needed. 41 | 42 | Args: 43 | rank: Rank of the current process. 44 | sync_device: PyTorch device to use for inter-process 45 | communication, or None to disable multi-process 46 | collection. Typically `torch.device('cuda', rank)`. 47 | """ 48 | global _rank, _sync_device 49 | assert not _sync_called 50 | _rank = rank 51 | _sync_device = sync_device 52 | 53 | #---------------------------------------------------------------------------- 54 | 55 | @misc.profiled_function 56 | def report(name, value): 57 | r"""Broadcasts the given set of scalars to all interested instances of 58 | `Collector`, across device and process boundaries. 59 | 60 | This function is expected to be extremely cheap and can be safely 61 | called from anywhere in the training loop, loss function, or inside a 62 | `torch.nn.Module`. 63 | 64 | Warning: The current implementation expects the set of unique names to 65 | be consistent across processes. Please make sure that `report()` is 66 | called at least once for each unique name by each process, and in the 67 | same order. If a given process has no scalars to broadcast, it can do 68 | `report(name, [])` (empty list). 69 | 70 | Args: 71 | name: Arbitrary string specifying the name of the statistic. 72 | Averages are accumulated separately for each unique name. 73 | value: Arbitrary set of scalars. Can be a list, tuple, 74 | NumPy array, PyTorch tensor, or Python scalar. 75 | 76 | Returns: 77 | The same `value` that was passed in. 
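    Example (illustrative only; the statistic names used by this repository's
    models may differ):

        loss = compute_loss()                        # hypothetical helper
        training_stats.report('Loss/G/total', loss)  # returns `loss` unchanged
        loss.backward()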
78 | """ 79 | if name not in _counters: 80 | _counters[name] = dict() 81 | 82 | elems = torch.as_tensor(value) 83 | if elems.numel() == 0: 84 | return value 85 | 86 | elems = elems.detach().flatten().to(_reduce_dtype) 87 | moments = torch.stack([ 88 | torch.ones_like(elems).sum(), 89 | elems.sum(), 90 | elems.square().sum(), 91 | ]) 92 | assert moments.ndim == 1 and moments.shape[0] == _num_moments 93 | moments = moments.to(_counter_dtype) 94 | 95 | device = moments.device 96 | if device not in _counters[name]: 97 | _counters[name][device] = torch.zeros_like(moments) 98 | _counters[name][device].add_(moments) 99 | return value 100 | 101 | #---------------------------------------------------------------------------- 102 | 103 | def report0(name, value): 104 | r"""Broadcasts the given set of scalars by the first process (`rank = 0`), 105 | but ignores any scalars provided by the other processes. 106 | See `report()` for further details. 107 | """ 108 | report(name, value if _rank == 0 else []) 109 | return value 110 | 111 | #---------------------------------------------------------------------------- 112 | 113 | class Collector: 114 | r"""Collects the scalars broadcasted by `report()` and `report0()` and 115 | computes their long-term averages (mean and standard deviation) over 116 | user-defined periods of time. 117 | 118 | The averages are first collected into internal counters that are not 119 | directly visible to the user. They are then copied to the user-visible 120 | state as a result of calling `update()` and can then be queried using 121 | `mean()`, `std()`, `as_dict()`, etc. Calling `update()` also resets the 122 | internal counters for the next round, so that the user-visible state 123 | effectively reflects averages collected between the last two calls to 124 | `update()`. 125 | 126 | Args: 127 | regex: Regular expression defining which statistics to 128 | collect. The default is to collect everything. 129 | keep_previous: Whether to retain the previous averages if no 130 | scalars were collected on a given round 131 | (default: True). 132 | """ 133 | def __init__(self, regex='.*', keep_previous=True): 134 | self._regex = re.compile(regex) 135 | self._keep_previous = keep_previous 136 | self._cumulative = dict() 137 | self._moments = dict() 138 | self.update() 139 | self._moments.clear() 140 | 141 | def names(self): 142 | r"""Returns the names of all statistics broadcasted so far that 143 | match the regular expression specified at construction time. 144 | """ 145 | return [name for name in _counters if self._regex.fullmatch(name)] 146 | 147 | def update(self): 148 | r"""Copies current values of the internal counters to the 149 | user-visible state and resets them for the next round. 150 | 151 | If `keep_previous=True` was specified at construction time, the 152 | operation is skipped for statistics that have received no scalars 153 | since the last update, retaining their previous averages. 154 | 155 | This method performs a number of GPU-to-CPU transfers and one 156 | `torch.distributed.all_reduce()`. It is intended to be called 157 | periodically in the main training loop, typically once every 158 | N training steps. 
159 | """ 160 | if not self._keep_previous: 161 | self._moments.clear() 162 | for name, cumulative in _sync(self.names()): 163 | if name not in self._cumulative: 164 | self._cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 165 | delta = cumulative - self._cumulative[name] 166 | self._cumulative[name].copy_(cumulative) 167 | if float(delta[0]) != 0: 168 | self._moments[name] = delta 169 | 170 | def _get_delta(self, name): 171 | r"""Returns the raw moments that were accumulated for the given 172 | statistic between the last two calls to `update()`, or zero if 173 | no scalars were collected. 174 | """ 175 | assert self._regex.fullmatch(name) 176 | if name not in self._moments: 177 | self._moments[name] = torch.zeros([_num_moments], dtype=_counter_dtype) 178 | return self._moments[name] 179 | 180 | def num(self, name): 181 | r"""Returns the number of scalars that were accumulated for the given 182 | statistic between the last two calls to `update()`, or zero if 183 | no scalars were collected. 184 | """ 185 | delta = self._get_delta(name) 186 | return int(delta[0]) 187 | 188 | def mean(self, name): 189 | r"""Returns the mean of the scalars that were accumulated for the 190 | given statistic between the last two calls to `update()`, or NaN if 191 | no scalars were collected. 192 | """ 193 | delta = self._get_delta(name) 194 | if int(delta[0]) == 0: 195 | return float('nan') 196 | return float(delta[1] / delta[0]) 197 | 198 | def std(self, name): 199 | r"""Returns the standard deviation of the scalars that were 200 | accumulated for the given statistic between the last two calls to 201 | `update()`, or NaN if no scalars were collected. 202 | """ 203 | delta = self._get_delta(name) 204 | if int(delta[0]) == 0 or not np.isfinite(float(delta[1])): 205 | return float('nan') 206 | if int(delta[0]) == 1: 207 | return float(0) 208 | mean = float(delta[1] / delta[0]) 209 | raw_var = float(delta[2] / delta[0]) 210 | return np.sqrt(max(raw_var - np.square(mean), 0)) 211 | 212 | def as_dict(self): 213 | r"""Returns the averages accumulated between the last two calls to 214 | `update()` as an `dnnlib.EasyDict`. The contents are as follows: 215 | 216 | dnnlib.EasyDict( 217 | NAME = dnnlib.EasyDict(num=FLOAT, mean=FLOAT, std=FLOAT), 218 | ... 219 | ) 220 | """ 221 | stats = dnnlib.EasyDict() 222 | for name in self.names(): 223 | stats[name] = dnnlib.EasyDict(num=self.num(name), mean=self.mean(name), std=self.std(name)) 224 | return stats 225 | 226 | def __getitem__(self, name): 227 | r"""Convenience getter. 228 | `collector[name]` is a synonym for `collector.mean(name)`. 229 | """ 230 | return self.mean(name) 231 | 232 | #---------------------------------------------------------------------------- 233 | 234 | def _sync(names): 235 | r"""Synchronize the global cumulative counters across devices and 236 | processes. Called internally by `Collector.update()`. 237 | """ 238 | if len(names) == 0: 239 | return [] 240 | global _sync_called 241 | _sync_called = True 242 | 243 | # Collect deltas within current rank. 244 | deltas = [] 245 | device = _sync_device if _sync_device is not None else torch.device('cpu') 246 | for name in names: 247 | delta = torch.zeros([_num_moments], dtype=_counter_dtype, device=device) 248 | for counter in _counters[name].values(): 249 | delta.add_(counter.to(device)) 250 | counter.copy_(torch.zeros_like(counter)) 251 | deltas.append(delta) 252 | deltas = torch.stack(deltas) 253 | 254 | # Sum deltas across ranks. 
255 |     if _sync_device is not None:
256 |         torch.distributed.all_reduce(deltas)
257 | 
258 |     # Update cumulative values.
259 |     deltas = deltas.cpu()
260 |     for idx, name in enumerate(names):
261 |         if name not in _cumulative:
262 |             _cumulative[name] = torch.zeros([_num_moments], dtype=_counter_dtype)
263 |         _cumulative[name].add_(deltas[idx])
264 | 
265 |     # Return name-value pairs.
266 |     return [(name, _cumulative[name]) for name in names]
267 | 
268 | #----------------------------------------------------------------------------
269 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """General-purpose training script for image-to-image translation.
2 | 
3 | This script works for various models (with option '--model': e.g., pix2pix, cyclegan, colorization) and
4 | different datasets (with option '--dataset_mode': e.g., aligned, unaligned, single, colorization).
5 | You need to specify the dataset ('--dataroot'), experiment name ('--name'), and model ('--model').
6 | 
7 | It first creates the model, dataset, and visualizer given the options.
8 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss plot, and saves the models.
9 | The script supports continuing/resuming training. Use '--continue_train' to resume your previous training.
10 | 
11 | """
12 | import time
13 | from options.train_options import TrainOptions
14 | from data import create_dataset
15 | from models import create_model
16 | from util.visualizer import Visualizer
17 | 
18 | if __name__ == '__main__':
19 |     opt = TrainOptions().parse()   # get training options
20 |     dataset = create_dataset(opt)  # create a dataset given opt.dataset_mode and other options
21 |     dataset_size = len(dataset)    # get the number of images in the dataset.
22 |     print('The number of training images = %d' % dataset_size)
23 | 
24 |     model = create_model(opt)      # create a model given opt.model and other options
25 |     model.setup(opt)               # regular setup: load and print networks; create schedulers
26 |     visualizer = Visualizer(opt)   # create a visualizer that displays/saves images and plots
27 |     total_iters = 0                # the total number of training iterations
28 | 
29 |     for epoch in range(opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1):    # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
30 |         epoch_start_time = time.time()  # timer for entire epoch
31 |         iter_data_time = time.time()    # timer for data loading per iteration
32 |         epoch_iter = 0                  # the number of training iterations in current epoch, reset to 0 every epoch
33 |         visualizer.reset()              # reset the visualizer: make sure it saves the results to HTML at least once every epoch
34 |         model.update_learning_rate()    # update learning rates in the beginning of every epoch.
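        # When opt.epoch_gan is set, the adversarial loss is switched on for roughly
        # the last opt.epoch_gan epochs of training (see the condition below).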
35 | 36 | if opt.epoch_gan > 0 and epoch + opt.epoch_gan >= opt.n_epochs + opt.n_epochs_decay: 37 | model.Use_GAN_Loss() 38 | 39 | for i, data in enumerate(dataset): # inner loop within one epoch 40 | iter_start_time = time.time() # timer for computation per iteration 41 | if total_iters % opt.print_freq == 0: 42 | t_data = iter_start_time - iter_data_time 43 | 44 | total_iters += opt.batch_size 45 | epoch_iter += opt.batch_size 46 | model.set_input(data) # unpack data from dataset and apply preprocessing 47 | model.optimize_parameters() # calculate loss functions, get gradients, update network weights 48 | 49 | if total_iters % opt.display_freq == 0: # display images on visdom and save images to a HTML file 50 | save_result = total_iters % opt.update_html_freq == 0 51 | model.compute_visuals() 52 | visualizer.display_current_results(model.get_current_visuals(), epoch, save_result) 53 | 54 | if total_iters % opt.print_freq == 0: # print training losses and save logging information to the disk 55 | losses = model.get_current_losses() 56 | t_comp = (time.time() - iter_start_time) / opt.batch_size 57 | visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data) 58 | if opt.display_id > 0: 59 | visualizer.plot_current_losses(epoch, float(epoch_iter) / dataset_size, losses) 60 | 61 | if total_iters % opt.save_latest_freq == 0: # cache our latest model every iterations 62 | print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters)) 63 | save_suffix = 'iter_%d' % total_iters if opt.save_by_iter else 'latest' 64 | model.save_networks(save_suffix) 65 | 66 | iter_data_time = time.time() 67 | if epoch % opt.save_epoch_freq == 0: # cache our model every epochs 68 | print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters)) 69 | model.save_networks('latest') 70 | model.save_networks(epoch) 71 | 72 | print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)) 73 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /util/get_data.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os 3 | import tarfile 4 | import requests 5 | from warnings import warn 6 | from zipfile import ZipFile 7 | from bs4 import BeautifulSoup 8 | from os.path import abspath, isdir, join, basename 9 | 10 | 11 | class GetData(object): 12 | """A Python script for downloading CycleGAN or pix2pix datasets. 13 | 14 | Parameters: 15 | technique (str) -- One of: 'cyclegan' or 'pix2pix'. 16 | verbose (bool) -- If True, print additional information. 17 | 18 | Examples: 19 | >>> from util.get_data import GetData 20 | >>> gd = GetData(technique='cyclegan') 21 | >>> new_data_path = gd.get(save_path='./datasets') # options will be displayed. 22 | 23 | Alternatively, You can use bash scripts: 'scripts/download_pix2pix_model.sh' 24 | and 'scripts/download_cyclegan_model.sh'. 
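    Note: downloading requires the 'requests', 'beautifulsoup4' (bs4) and 'lxml'
    packages, since the dataset index page is parsed with BeautifulSoup's 'lxml' parser.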
25 | """ 26 | 27 | def __init__(self, technique='cyclegan', verbose=True): 28 | url_dict = { 29 | 'pix2pix': 'http://efrosgans.eecs.berkeley.edu/pix2pix/datasets/', 30 | 'cyclegan': 'https://people.eecs.berkeley.edu/~taesung_park/CycleGAN/datasets' 31 | } 32 | self.url = url_dict.get(technique.lower()) 33 | self._verbose = verbose 34 | 35 | def _print(self, text): 36 | if self._verbose: 37 | print(text) 38 | 39 | @staticmethod 40 | def _get_options(r): 41 | soup = BeautifulSoup(r.text, 'lxml') 42 | options = [h.text for h in soup.find_all('a', href=True) 43 | if h.text.endswith(('.zip', 'tar.gz'))] 44 | return options 45 | 46 | def _present_options(self): 47 | r = requests.get(self.url) 48 | options = self._get_options(r) 49 | print('Options:\n') 50 | for i, o in enumerate(options): 51 | print("{0}: {1}".format(i, o)) 52 | choice = input("\nPlease enter the number of the " 53 | "dataset above you wish to download:") 54 | return options[int(choice)] 55 | 56 | def _download_data(self, dataset_url, save_path): 57 | if not isdir(save_path): 58 | os.makedirs(save_path) 59 | 60 | base = basename(dataset_url) 61 | temp_save_path = join(save_path, base) 62 | 63 | with open(temp_save_path, "wb") as f: 64 | r = requests.get(dataset_url) 65 | f.write(r.content) 66 | 67 | if base.endswith('.tar.gz'): 68 | obj = tarfile.open(temp_save_path) 69 | elif base.endswith('.zip'): 70 | obj = ZipFile(temp_save_path, 'r') 71 | else: 72 | raise ValueError("Unknown File Type: {0}.".format(base)) 73 | 74 | self._print("Unpacking Data...") 75 | obj.extractall(save_path) 76 | obj.close() 77 | os.remove(temp_save_path) 78 | 79 | def get(self, save_path, dataset=None): 80 | """ 81 | 82 | Download a dataset. 83 | 84 | Parameters: 85 | save_path (str) -- A directory to save the data to. 86 | dataset (str) -- (optional). A specific dataset to download. 87 | Note: this must include the file extension. 88 | If None, options will be presented for you 89 | to choose from. 90 | 91 | Returns: 92 | save_path_full (str) -- the absolute path to the downloaded data. 93 | 94 | """ 95 | if dataset is None: 96 | selected_dataset = self._present_options() 97 | else: 98 | selected_dataset = dataset 99 | 100 | save_path_full = join(save_path, selected_dataset.split('.')[0]) 101 | 102 | if isdir(save_path_full): 103 | warn("\n'{0}' already exists. Voiding Download.".format( 104 | save_path_full)) 105 | else: 106 | self._print('Downloading Data...') 107 | url = "{0}/{1}".format(self.url, selected_dataset) 108 | self._download_data(url, save_path=save_path) 109 | 110 | return abspath(save_path_full) 111 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 |             title (str)   -- the webpage name
20 |             refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 |         """
22 |         self.title = title
23 |         self.web_dir = web_dir
24 |         self.img_dir = os.path.join(self.web_dir, 'images')
25 |         if not os.path.exists(self.web_dir):
26 |             os.makedirs(self.web_dir)
27 |         if not os.path.exists(self.img_dir):
28 |             os.makedirs(self.img_dir)
29 | 
30 |         self.doc = dominate.document(title=title)
31 |         if refresh > 0:
32 |             with self.doc.head:
33 |                 meta(http_equiv="refresh", content=str(refresh))
34 | 
35 |     def get_image_dir(self):
36 |         """Return the directory that stores images"""
37 |         return self.img_dir
38 | 
39 |     def add_header(self, text):
40 |         """Insert a header to the HTML file
41 | 
42 |         Parameters:
43 |             text (str) -- the header text
44 |         """
45 |         with self.doc:
46 |             h3(text)
47 | 
48 |     def add_images(self, ims, txts, links, width=400):
49 |         """Add images to the HTML file
50 | 
51 |         Parameters:
52 |             ims (str list)   -- a list of image paths
53 |             txts (str list)  -- a list of image names shown on the website
54 |             links (str list) -- a list of hyperlinks; when you click an image, it will redirect you to a new page
55 |         """
56 |         self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
57 |         self.doc.add(self.t)
58 |         with self.t:
59 |             with tr():
60 |                 for im, txt, link in zip(ims, txts, links):
61 |                     with td(style="word-wrap: break-word;", halign="center", valign="top"):
62 |                         with p():
63 |                             with a(href=os.path.join('images', link)):
64 |                                 img(style="width:%dpx" % width, src=os.path.join('images', im))
65 |                             br()
66 |                             p(txt)
67 | 
68 |     def save(self):
69 |         """Save the current content to the HTML file"""
70 |         html_file = '%s/index.html' % self.web_dir
71 |         f = open(html_file, 'wt')
72 |         f.write(self.doc.render())
73 |         f.close()
74 | 
75 | 
76 | if __name__ == '__main__':  # we show an example usage here.
77 |     html = HTML('web/', 'test_html')
78 |     html.add_header('hello world')
79 | 
80 |     ims, txts, links = [], [], []
81 |     for n in range(4):
82 |         ims.append('image_%d.png' % n)
83 |         txts.append('text_%d' % n)
84 |         links.append('image_%d.png' % n)
85 |     html.add_images(ims, txts, links)
86 |     html.save()
87 | 
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | import random
2 | import torch
3 | 
4 | 
5 | class ImagePool():
6 |     """This class implements an image buffer that stores previously generated images.
7 | 
8 |     This buffer enables us to update discriminators using a history of generated images
9 |     rather than the ones produced by the latest generators.
10 |     """
11 | 
12 |     def __init__(self, pool_size):
13 |         """Initialize the ImagePool class
14 | 
15 |         Parameters:
16 |             pool_size (int) -- the size of the image buffer; if pool_size=0, no buffer will be created
17 |         """
18 |         self.pool_size = pool_size
19 |         if self.pool_size > 0:  # create an empty pool
20 |             self.num_imgs = 0
21 |             self.images = []
22 | 
23 |     def query(self, images):
24 |         """Return an image from the pool.
25 | 
26 |         Parameters:
27 |             images: the latest generated images from the generator
28 | 
29 |         Returns images from the buffer.
30 | 
31 |         With 50% probability, the buffer returns the input images.
32 |         With 50% probability, the buffer returns images previously stored in the buffer
33 |         and inserts the current images into the buffer.
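        For example, a discriminator update might draw its fakes via
        `fake_pool.query(fake_images)` so that the discriminator sees a mixture
        of freshly generated and historical samples (illustrative usage; the
        variable names are not taken from this repository).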
34 | """ 35 | if self.pool_size == 0: # if the buffer size is 0, do nothing 36 | return images 37 | return_images = [] 38 | for image in images: 39 | image = torch.unsqueeze(image.data, 0) 40 | if self.num_imgs < self.pool_size: # if the buffer is not full; keep inserting current images to the buffer 41 | self.num_imgs = self.num_imgs + 1 42 | self.images.append(image) 43 | return_images.append(image) 44 | else: 45 | p = random.uniform(0, 1) 46 | if p > 0.5: # by 50% chance, the buffer will return a previously stored image, and insert the current image into the buffer 47 | random_id = random.randint(0, self.pool_size - 1) # randint is inclusive 48 | tmp = self.images[random_id].clone() 49 | self.images[random_id] = image 50 | return_images.append(tmp) 51 | else: # by another 50% chance, the buffer will return the current image 52 | return_images.append(image) 53 | return_images = torch.cat(return_images, 0) # collect all the images and return 54 | return return_images 55 | -------------------------------------------------------------------------------- /util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = np.clip(image_numpy, -1, 1) 25 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 26 | else: # if it is a numpy array, do nothing 27 | image_numpy = input_image 28 | return image_numpy.astype(imtype) 29 | 30 | 31 | def diagnose_network(net, name='network'): 32 | """Calculate and print the mean of average absolute(gradients) 33 | 34 | Parameters: 35 | net (torch network) -- Torch network 36 | name (str) -- the name of the network 37 | """ 38 | mean = 0.0 39 | count = 0 40 | for param in net.parameters(): 41 | if param.grad is not None: 42 | mean += torch.mean(torch.abs(param.grad.data)) 43 | count += 1 44 | if count > 0: 45 | mean = mean / count 46 | print(name) 47 | print(mean) 48 | 49 | 50 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 51 | """Save a numpy image to the disk 52 | 53 | Parameters: 54 | image_numpy (numpy array) -- input numpy array 55 | image_path (str) -- the path of the image 56 | """ 57 | 58 | image_pil = Image.fromarray(image_numpy) 59 | h, w, _ = image_numpy.shape 60 | 61 | if aspect_ratio > 1.0: 62 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 63 | if aspect_ratio < 1.0: 64 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 65 | image_pil.save(image_path) 66 | 67 | 68 | def print_numpy(x, val=True, shp=False): 69 | """Print the mean, min, max, median, std, and size of a numpy array 70 | 71 | Parameters: 72 | val (bool) -- if print the values of the numpy 
array 73 | shp (bool) -- if print the shape of the numpy array 74 | """ 75 | x = x.astype(np.float64) 76 | if shp: 77 | print('shape,', x.shape) 78 | if val: 79 | x = x.flatten() 80 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % ( 81 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x))) 82 | 83 | 84 | def mkdirs(paths): 85 | """create empty directories if they don't exist 86 | 87 | Parameters: 88 | paths (str list) -- a list of directory paths 89 | """ 90 | if isinstance(paths, list) and not isinstance(paths, str): 91 | for path in paths: 92 | mkdir(path) 93 | else: 94 | mkdir(paths) 95 | 96 | 97 | def mkdir(path): 98 | """create a single empty directory if it didn't exist 99 | 100 | Parameters: 101 | path (str) -- a single directory path 102 | """ 103 | if not os.path.exists(path): 104 | os.makedirs(path) 105 | --------------------------------------------------------------------------------
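As a closing illustration, the sketch below shows how the helpers in util/util.py typically fit together when dumping a generated frame to disk. It is a minimal sketch: the tensor shape, output folder, and file name are assumptions made for illustration, not conventions prescribed by this repository.

    import os
    import torch

    from util import util

    # Pretend generator output: one image in [-1, 1], NCHW layout,
    # which is the range/layout that tensor2im() expects.
    fake = torch.tanh(torch.randn(1, 3, 256, 256))

    out_dir = './results/demo'           # hypothetical output folder
    util.mkdirs([out_dir])               # create it if it does not exist

    image_numpy = util.tensor2im(fake)   # -> HxWx3 uint8 array in [0, 255]
    util.save_image(image_numpy, os.path.join(out_dir, 'frame_0000.png'))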