├── .gitignore ├── License.txt ├── README.md ├── data ├── __init__.py ├── base_data_loader.py ├── base_dataset.py ├── custom_dataset_data_loader.py ├── data_loader.py ├── fewshot_face_dataset.py ├── fewshot_pose_dataset.py ├── fewshot_street_dataset.py ├── image_folder.py ├── keypoint2img.py ├── lmdb_dataset.py └── preprocess │ ├── download_youTube_playlist.py │ ├── preprocess.py │ ├── util │ ├── check_valid.py │ ├── get_poses.py │ ├── track.py │ └── util.py │ └── youTube_playlist.txt ├── imgs ├── dance.gif ├── face.gif ├── illustration.gif ├── mona_lisa.gif ├── statue.gif └── street.gif ├── models ├── __init__.py ├── base_model.py ├── face_refiner.py ├── flownet.py ├── input_process.py ├── loss_collector.py ├── models.py ├── networks │ ├── __init__.py │ ├── architecture.py │ ├── base_network.py │ ├── discriminator.py │ ├── flownet2_pytorch │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── convert.py │ │ ├── datasets.py │ │ ├── install.sh │ │ ├── launch_docker.sh │ │ ├── losses.py │ │ ├── main.py │ │ ├── models.py │ │ ├── networks │ │ │ ├── FlowNetC.py │ │ │ ├── FlowNetFusion.py │ │ │ ├── FlowNetS.py │ │ │ ├── FlowNetSD.py │ │ │ ├── __init__.py │ │ │ ├── channelnorm_package │ │ │ │ ├── __init__.py │ │ │ │ ├── channelnorm.py │ │ │ │ ├── channelnorm_cuda.cc │ │ │ │ ├── channelnorm_kernel.cu │ │ │ │ ├── channelnorm_kernel.cuh │ │ │ │ └── setup.py │ │ │ ├── correlation_package │ │ │ │ ├── __init__.py │ │ │ │ ├── correlation.py │ │ │ │ ├── correlation_cuda.cc │ │ │ │ ├── correlation_cuda_kernel.cu │ │ │ │ ├── correlation_cuda_kernel.cuh │ │ │ │ └── setup.py │ │ │ ├── resample2d_package │ │ │ │ ├── __init__.py │ │ │ │ ├── resample2d.py │ │ │ │ ├── resample2d_cuda.cc │ │ │ │ ├── resample2d_kernel.cu │ │ │ │ ├── resample2d_kernel.cuh │ │ │ │ └── setup.py │ │ │ └── submodules.py │ │ ├── run-caffe2pytorch.sh │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── flow_utils.py │ │ │ ├── frame_utils.py │ │ │ ├── param_utils.py │ │ │ └── tools.py │ ├── generator.py │ ├── loss.py │ ├── normalization.py │ ├── sync_batchnorm │ │ ├── __init__.py │ │ ├── batchnorm.py │ │ ├── comm.py │ │ ├── replicate.py │ │ └── unittest.py │ └── vgg.py ├── trainer.py └── vid2vid_model.py ├── options ├── __init__.py ├── base_options.py ├── test_options.py └── train_options.py ├── scripts ├── download_datasets.py ├── download_flownet2.py ├── download_gdrive.py ├── face │ ├── test_256.sh │ ├── test_512.sh │ ├── train_g1_256.sh │ ├── train_g8_256.sh │ └── train_g8_512.sh ├── pose │ ├── test.sh │ ├── train_g1.sh │ └── train_g8.sh └── street │ ├── test.sh │ ├── train_g1.sh │ └── train_g8.sh ├── test.py ├── train.py └── util ├── __init__.py ├── distributed.py ├── html.py ├── image_pool.py ├── util.py └── visualizer.py /.gitignore: -------------------------------------------------------------------------------- 1 | debug* 2 | checkpoints/ 3 | datasets/ 4 | models/debug* 5 | models/networks/flownet2*/networks/*/*egg-info 6 | models/networks/flownet2*/networks/*/build 7 | models/networks/flownet2*/networks/*/__pycache__ 8 | models/networks/flownet2*/networks/*/dist 9 | models/networks/flownet2*/*.pth.tar 10 | results/ 11 | build/ 12 | .idea/ 13 | */Thumbs.db 14 | */**/__pycache__ 15 | */*.pyc 16 | */**/*.pyc 17 | */**/**/*.pyc 18 | */**/**/**/*.pyc 19 | */**/**/**/**/*.pyc 20 | */*.so* 21 | */**/*.so* 22 | */**/*.dylib* 23 | *.DS_Store 24 | *~ 25 | -------------------------------------------------------------------------------- /License.txt: -------------------------------------------------------------------------------- 1 | 
Nvidia Source Code License (1-Way Commercial) – NVIDIA CONFIDENTIAL 2 | 3 | 1. Definitions 4 | 5 | “Licensor” means any person or entity that distributes its Work. 6 | “Software” means the original work of authorship made available under this License. 7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License. 8 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates. 9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. 10 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License. 11 | 12 | 2. License Grants 13 | 14 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. 15 | 16 | 2.2 Patent Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer its Work, in whole or in part. The foregoing license applies only to the patent claims licensable by Licensor that would be infringed by Licensor’s Work (or portion thereof) individually and excluding any combinations with any other materials or technology. 17 | 18 | 3. Limitations 19 | 20 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. 21 | 22 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. 23 | 24 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or it’s affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only. 25 | 26 | 3.4 Patent Claims. 
If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately. 27 | 28 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License. 29 | 30 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately. 31 | 32 | 4. Disclaimer of Warranty. 33 | 34 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. SOME STATES’ CONSUMER LAWS DO NOT ALLOW EXCLUSION OF AN IMPLIED WARRANTY, SO THIS DISCLAIMER MAY NOT APPLY TO YOU. 35 | 36 | 5. Limitation of Liability. 37 | 38 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. 39 | 40 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import importlib 8 | from data.base_dataset import BaseDataset 9 | from util.distributed import master_only_print as print 10 | 11 | def find_dataset_using_name(dataset_name): 12 | # Given the option --dataset [datasetname], 13 | # the file "datasets/datasetname_dataset.py" 14 | # will be imported. 15 | dataset_filename = "data." + dataset_name + "_dataset" 16 | datasetlib = importlib.import_module(dataset_filename) 17 | 18 | # In the file, the class called DatasetNameDataset() will 19 | # be instantiated. It has to be a subclass of BaseDataset, 20 | # and it is case-insensitive. 21 | dataset = None 22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 23 | for name, cls in datasetlib.__dict__.items(): 24 | if name.lower() == target_dataset_name.lower() \ 25 | and issubclass(cls, BaseDataset): 26 | dataset = cls 27 | 28 | if dataset is None: 29 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase."
% (dataset_filename, target_dataset_name)) 30 | exit(0) 31 | 32 | return dataset 33 | 34 | 35 | def get_option_setter(dataset_name): 36 | dataset_class = find_dataset_using_name(dataset_name) 37 | return dataset_class.modify_commandline_options 38 | 39 | 40 | def create_dataset(opt): 41 | dataset = find_dataset_using_name(opt.dataset_mode) 42 | instance = dataset() 43 | instance.initialize(opt) 44 | print("dataset [%s] was created" % (instance.name())) 45 | return instance 46 | -------------------------------------------------------------------------------- /data/base_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | class BaseDataLoader(): 8 | def __init__(self): 9 | pass 10 | 11 | def initialize(self, opt): 12 | self.opt = opt 13 | pass 14 | 15 | def load_data(): 16 | return None 17 | 18 | 19 | 20 | -------------------------------------------------------------------------------- /data/base_dataset.py: -------------------------------------------------------------------------------- 1 | ### Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 2 | ### Licensed under the Nvidia Source Code License. 3 | 4 | from PIL import Image 5 | import numpy as np 6 | import random 7 | import torch 8 | import torch.utils.data as data 9 | import torchvision.transforms as transforms 10 | from util.distributed import master_only_print as print 11 | 12 | class BaseDataset(data.Dataset): 13 | def __init__(self): 14 | super(BaseDataset, self).__init__() 15 | self.L = self.I = self.Lr = self.Ir = None 16 | self.n_frames_total = 1 # current number of frames to train in a single iteration 17 | self.use_lmdb = False 18 | 19 | def name(self): 20 | return 'BaseDataset' 21 | 22 | def update_training_batch(self, ratio): 23 | # update the training sequence length to be longer 24 | seq_len_max = 30 25 | if self.n_frames_total < seq_len_max: 26 | self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (2**ratio)) 27 | print('--- Updating training sequence length to %d ---' % self.n_frames_total) 28 | 29 | def read_data(self, path, lmdb=None, data_type='img'): 30 | is_img = data_type == 'img' 31 | if self.use_lmdb and lmdb is not None: 32 | img, _ = lmdb.getitem_by_path(path.encode(), is_img) 33 | if is_img and len(img.mode) == 3: 34 | b, g, r = img.split() 35 | img = Image.merge("RGB", (r, g, b)) 36 | elif data_type == 'np': 37 | img = img.decode() 38 | img = np.array([[int(j) for j in i.split(',')] for i in img.splitlines()]) 39 | elif is_img: 40 | img = Image.open(path) 41 | elif data_type == 'np': 42 | img = np.loadtxt(path, delimiter=',') 43 | else: 44 | img = path 45 | return img 46 | 47 | def crop(self, img, coords): 48 | min_y, max_y, min_x, max_x = coords 49 | if isinstance(img, np.ndarray): 50 | return img[min_y:max_y, min_x:max_x] 51 | else: 52 | return img.crop((min_x, min_y, max_x, max_y)) 53 | 54 | def concat_frame(self, A, Ai, n=100): 55 | if A is None or Ai.shape[0] >= n: return Ai[-n:] 56 | else: return torch.cat([A, Ai])[-n:] 57 | 58 | def concat(self, tensors, dim=0): 59 | tensors = [t for t in tensors if t is not None] 60 | return torch.cat(tensors, dim) 61 | 62 | def get_img_params(opt, size): 63 | w, h = size 64 | new_w, new_h = w, h 65 | 66 | # resize input image 67 
| if 'resize' in opt.resize_or_crop: 68 | new_h = new_w = opt.loadSize 69 | else: 70 | if 'scale_width' in opt.resize_or_crop: 71 | new_w = opt.loadSize 72 | elif 'random_scale' in opt.resize_or_crop: 73 | new_w = random.randrange(int(opt.fineSize), int(1.2*opt.fineSize)) 74 | new_h = int(new_w * h) // w 75 | if 'crop' not in opt.resize_or_crop: 76 | new_h = int(new_w // opt.aspect_ratio) 77 | new_w = new_w // 4 * 4 78 | new_h = new_h // 4 * 4 79 | 80 | # crop resized image 81 | size_x = min(opt.loadSize, opt.fineSize) 82 | size_y = size_x // opt.aspect_ratio 83 | if not opt.isTrain: # crop central region 84 | pos_x = (new_w - size_x) // 2 85 | pos_y = (new_h - size_y) // 2 86 | else: # crop random region 87 | pos_x = random.randrange(np.maximum(1, new_w - size_x)) 88 | pos_y = random.randrange(np.maximum(1, new_h - size_y)) 89 | 90 | # for color augmentation 91 | h_b = random.uniform(-30, 30) 92 | s_a = random.uniform(0.8, 1.2) 93 | s_b = random.uniform(-10, 10) 94 | v_a = random.uniform(0.8, 1.2) 95 | v_b = random.uniform(-10, 10) 96 | 97 | flip = random.random() > 0.5 98 | return {'new_size': (new_w, new_h), 'crop_pos': (pos_x, pos_y), 'crop_size': (size_x, size_y), 'flip': flip, 99 | 'color_aug': (h_b, s_a, s_b, v_a, v_b)} 100 | 101 | def get_video_params(opt, n_frames_total, cur_seq_len, index): 102 | if opt.isTrain: 103 | n_frames_total = min(cur_seq_len, n_frames_total) # total number of frames to load 104 | max_t_step = min(opt.max_t_step, (cur_seq_len-1) // max(1, (n_frames_total-1))) 105 | t_step = random.randrange(max_t_step) + 1 # spacing between neighboring sampled frames 106 | 107 | offset_max = max(1, cur_seq_len - (n_frames_total-1)*t_step) # maximum possible frame index for the first frame 108 | if 'pose' in opt.dataset_mode: 109 | start_idx = index % offset_max # offset for the first frame to load 110 | max_range, min_range = 60, 14 # range for possible reference frames 111 | else: 112 | start_idx = random.randrange(offset_max) # offset for the first frame to load 113 | max_range, min_range = 300, 14 # range for possible reference frames 114 | 115 | ref_range = list(range(max(0, start_idx - max_range), max(1, start_idx - min_range))) \ 116 | + list(range(min(start_idx + min_range, cur_seq_len - 1), min(start_idx + max_range, cur_seq_len))) 117 | ref_indices = random.sample(ref_range, opt.n_shot) # indices for reference frames 118 | 119 | else: 120 | n_frames_total = 1 121 | start_idx = index 122 | t_step = 1 123 | ref_indices = opt.ref_img_id.split(',') 124 | ref_indices = [int(i) for i in ref_indices] 125 | 126 | return n_frames_total, start_idx, t_step, ref_indices 127 | 128 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True, color_aug=False): 129 | transform_list = [] 130 | transform_list.append(transforms.Lambda(lambda img: __scale_image(img, params['new_size'], method))) 131 | 132 | if 'crop' in opt.resize_or_crop: 133 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size']))) 134 | 135 | if opt.isTrain and color_aug: 136 | transform_list.append(transforms.Lambda(lambda img: __color_aug(img, params['color_aug']))) 137 | 138 | if opt.isTrain and not opt.no_flip: 139 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 140 | 141 | if toTensor: 142 | transform_list += [transforms.ToTensor()] 143 | 144 | if normalize: 145 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), 146 | (0.5, 0.5, 0.5))] 147 | return transforms.Compose(transform_list) 148 | 
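# Usage sketch (added for illustration; not part of the original file): how
# get_img_params() and get_transform() above are typically combined so that an
# image and its label map receive the exact same resize / crop / flip decisions.
# Assumes `opt` carries the fields referenced above (resize_or_crop, loadSize,
# fineSize, aspect_ratio, isTrain, no_flip).
def _example_paired_transform(opt, image, label):
    params = get_img_params(opt, image.size)            # sample augmentation parameters once
    transform_img = get_transform(opt, params, color_aug=opt.isTrain)
    transform_lbl = get_transform(opt, params, method=Image.NEAREST, normalize=False)
    return transform_img(image), transform_lbl(label)   # both tensors share identical geometry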
149 | def __scale_image(img, size, method=Image.BICUBIC): 150 | w, h = size 151 | return img.resize((w, h), method) 152 | 153 | def __crop(img, pos, size): 154 | ow, oh = img.size 155 | x1, y1 = pos 156 | tw, th = size 157 | return img.crop((x1, y1, x1 + tw, y1 + th)) 158 | 159 | def __flip(img, flip): 160 | if flip: 161 | return img.transpose(Image.FLIP_LEFT_RIGHT) 162 | return img 163 | 164 | def __color_aug(img, params): 165 | h, s, v = img.convert('HSV').split() 166 | h = h.point(lambda i: (i + params[0]) % 256) 167 | s = s.point(lambda i: min(255, max(0, i * params[1] + params[2]))) 168 | v = v.point(lambda i: min(255, max(0, i * params[3] + params[4]))) 169 | img = Image.merge('HSV', (h, s, v)).convert('RGB') 170 | return img -------------------------------------------------------------------------------- /data/custom_dataset_data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch.utils.data 8 | import torch.distributed as dist 9 | from data.base_data_loader import BaseDataLoader 10 | import data 11 | 12 | class CustomDatasetDataLoader(BaseDataLoader): 13 | def name(self): 14 | return 'CustomDatasetDataLoader' 15 | 16 | def initialize(self, opt): 17 | BaseDataLoader.initialize(self, opt) 18 | self.dataset = data.create_dataset(opt) 19 | if dist.is_initialized(): 20 | sampler = torch.utils.data.distributed.DistributedSampler(self.dataset) 21 | else: 22 | sampler = None 23 | 24 | self.dataloader = torch.utils.data.DataLoader( 25 | self.dataset, 26 | batch_size=opt.batchSize, 27 | shuffle=(sampler is None) and not opt.serial_batches, 28 | sampler=sampler, 29 | pin_memory=True, 30 | num_workers=int(opt.nThreads), 31 | drop_last=True 32 | ) 33 | 34 | def load_data(self): 35 | return self.dataloader 36 | 37 | def __len__(self): 38 | size = min(len(self.dataset), self.opt.max_dataset_size) 39 | ngpus = len(self.opt.gpu_ids) 40 | round_to_ngpus = (size // ngpus) * ngpus 41 | return round_to_ngpus 42 | -------------------------------------------------------------------------------- /data/data_loader.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | from util.distributed import master_only_print as print 8 | def CreateDataLoader(opt): 9 | from data.custom_dataset_data_loader import CustomDatasetDataLoader 10 | data_loader = CustomDatasetDataLoader() 11 | print(data_loader.name()) 12 | data_loader.initialize(opt) 13 | return data_loader 14 | -------------------------------------------------------------------------------- /data/fewshot_street_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os.path as path 8 | import glob 9 | import torch 10 | from PIL import Image 11 | import numpy as np 12 | 13 | from data.base_dataset import BaseDataset, get_img_params, get_video_params, get_transform 14 | from data.image_folder import make_dataset, make_grouped_dataset, check_path_valid 15 | 16 | class FewshotStreetDataset(BaseDataset): 17 | @staticmethod 18 | def modify_commandline_options(parser, is_train): 19 | parser.set_defaults(dataroot='datasets/street/') 20 | parser.add_argument('--label_nc', type=int, default=20, help='# of input label channels') 21 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels') 22 | parser.add_argument('--aspect_ratio', type=float, default=2) 23 | parser.set_defaults(resize_or_crop='random_scale_and_crop') 24 | parser.set_defaults(niter=20) 25 | parser.set_defaults(niter_single=10) 26 | parser.set_defaults(niter_step=2) 27 | parser.set_defaults(save_epoch_freq=1) 28 | 29 | ### for inference 30 | parser.add_argument('--seq_path', type=str, default='datasets/street/test_images/01/', help='path to the driving sequence') 31 | parser.add_argument('--ref_img_path', type=str, default='datasets/street/test_images/02/', help='path to the reference image') 32 | parser.add_argument('--ref_img_id', type=str, default='0', help='indices of reference frames') 33 | return parser 34 | 35 | def initialize(self, opt): 36 | self.opt = opt 37 | root = opt.dataroot 38 | self.L_is_label = self.opt.label_nc != 0 39 | 40 | if opt.isTrain: 41 | self.L_paths = sorted(make_grouped_dataset(path.join(root, 'train_labels'))) 42 | self.I_paths = sorted(make_grouped_dataset(path.join(root, 'train_images'))) 43 | check_path_valid(self.L_paths, self.I_paths) 44 | 45 | self.n_of_seqs = len(self.L_paths) 46 | print('%d sequences' % self.n_of_seqs) 47 | else: 48 | self.I_paths = sorted(make_dataset(opt.seq_path)) 49 | self.L_paths = sorted(make_dataset(opt.seq_path.replace('images', 'labels'))) 50 | self.ref_I_paths = sorted(make_dataset(opt.ref_img_path)) 51 | self.ref_L_paths = sorted(make_dataset(opt.ref_img_path.replace('images', 'labels'))) 52 | 53 | def __getitem__(self, index): 54 | opt = self.opt 55 | if opt.isTrain: 56 | L_paths = self.L_paths[index % self.n_of_seqs] 57 | I_paths = self.I_paths[index % self.n_of_seqs] 58 | ref_L_paths, ref_I_paths = L_paths, I_paths 59 | else: 60 | L_paths, I_paths = self.L_paths, self.I_paths 61 | ref_L_paths, ref_I_paths = self.ref_L_paths, self.ref_I_paths 62 | 63 | 64 | ### setting parameters 65 | n_frames_total, start_idx, t_step, ref_indices = get_video_params(opt, self.n_frames_total, len(I_paths), index) 66 | w, h = opt.fineSize, int(opt.fineSize / opt.aspect_ratio) 67 | img_params = get_img_params(opt, (w, h)) 68 | is_first_frame = opt.isTrain or index == 0 69 | 70 | transform_I = get_transform(opt, img_params, color_aug=opt.isTrain) 71 | transform_L = get_transform(opt, img_params, method=Image.NEAREST, normalize=False) if self.L_is_label else transform_I 72 | 73 | 74 | ### read in reference image 75 | Lr, Ir = self.Lr, self.Ir 76 | if is_first_frame: 77 | for idx in ref_indices: 78 | Li = self.get_image(ref_L_paths[idx], transform_L, is_label=self.L_is_label) 79 | Ii = self.get_image(ref_I_paths[idx], transform_I) 80 | Lr = self.concat_frame(Lr, Li.unsqueeze(0)) 81 | Ir = self.concat_frame(Ir, Ii.unsqueeze(0)) 82 | 83 | if not opt.isTrain: # keep track of non-changing variables during 
inference 84 | self.Lr, self.Ir = Lr, Ir 85 | 86 | 87 | ### read in target images 88 | L, I = self.L, self.I 89 | for t in range(n_frames_total): 90 | idx = start_idx + t * t_step 91 | Lt = self.get_image(L_paths[idx], transform_L, is_label=self.L_is_label) 92 | It = self.get_image(I_paths[idx], transform_I) 93 | L = self.concat_frame(L, Lt.unsqueeze(0)) 94 | I = self.concat_frame(I, It.unsqueeze(0)) 95 | 96 | if not opt.isTrain: 97 | self.L, self.I = L, I 98 | 99 | seq = path.basename(path.dirname(opt.ref_img_path)) + '-' + opt.ref_img_id + '_' + path.basename(path.dirname(opt.seq_path)) 100 | 101 | return_list = {'tgt_label': L, 'tgt_image': I, 'ref_label': Lr, 'ref_image': Ir, 102 | 'path': I_paths[idx], 'seq': seq} 103 | return return_list 104 | 105 | def get_image(self, A_path, transform_scaleA, is_label=False): 106 | if is_label: return self.get_label_tensor(A_path, transform_scaleA) 107 | A_img = self.read_data(A_path) 108 | A_scaled = transform_scaleA(A_img) 109 | return A_scaled 110 | 111 | def get_label_tensor(self, label_path, transform_label): 112 | label = self.read_data(label_path).convert('L') 113 | 114 | train2eval = self.opt.label_nc == 20 115 | if train2eval: 116 | ### 35 to 20 117 | A_label_np = np.array(label) 118 | label_mapping = np.array([19, 19, 19, 19, 19, 19, 19, 0, 1, 19, 19, 2, 3, 4, 19, 19, 19, 5, 19, 119 | 6, 7, 8, 9, 18, 10, 11, 12, 13, 14, 19, 19, 15, 16, 17, 19], dtype=np.uint8) 120 | A_label_np = label_mapping[A_label_np] 121 | label = Image.fromarray(A_label_np) 122 | 123 | label_tensor = transform_label(label) * 255.0 124 | return label_tensor 125 | 126 | def __len__(self): 127 | if not self.opt.isTrain: return len(self.L_paths) 128 | return max(10000, sum([len(L) for L in self.L_paths])) 129 | 130 | def name(self): 131 | return 'StreetDataset' -------------------------------------------------------------------------------- /data/image_folder.py: -------------------------------------------------------------------------------- 1 | ############################################################################### 2 | # Code from 3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 4 | # Modified the original code so that it also loads images from the current 5 | # directory as well as the subdirectories 6 | ############################################################################### 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import os 10 | 11 | IMG_EXTENSIONS = [ 12 | '.jpg', '.JPG', '.jpeg', '.JPEG', 13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp', 14 | '.txt', '.json', 15 | ] 16 | 17 | 18 | def is_image_file(filename): 19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 20 | 21 | def make_dataset_rec(dir, images): 22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 23 | 24 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)): 25 | for fname in fnames: 26 | if is_image_file(fname): 27 | path = os.path.join(root, fname) 28 | images.append(path) 29 | # for dname in dnames: 30 | # make_dataset_rec(os.path.join(root, dname), images) 31 | 32 | 33 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False): 34 | images = [] 35 | 36 | if read_cache: 37 | possible_filelist = os.path.join(dir, 'files.list') 38 | if os.path.isfile(possible_filelist): 39 | with open(possible_filelist, 'r') as f: 40 | images = f.read().splitlines() 41 | return images 42 | 43 | if recursive: 44 | make_dataset_rec(dir, images) 45 | else: 
46 | #assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir 47 | 48 | for root, dnames, fnames in sorted(os.walk(dir)): 49 | for fname in fnames: 50 | if is_image_file(fname): 51 | path = os.path.join(root, fname) 52 | images.append(path) 53 | 54 | if write_cache: 55 | filelist_cache = os.path.join(dir, 'files.list') 56 | with open(filelist_cache, 'w') as f: 57 | for path in images: 58 | f.write("%s\n" % path) 59 | print('wrote filelist cache at %s' % filelist_cache) 60 | 61 | return images 62 | 63 | def make_grouped_dataset(dir): 64 | images = [] 65 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 66 | fnames = sorted(os.walk(dir)) 67 | for fname in sorted(fnames): 68 | paths = [] 69 | root = fname[0] 70 | for f in sorted(fname[2]): 71 | if is_image_file(f): 72 | paths.append(os.path.join(root, f)) 73 | if len(paths) > 0: 74 | images.append(paths) 75 | return images 76 | 77 | def check_path_valid(A_paths, B_paths): 78 | if len(A_paths) != len(B_paths): 79 | print('%s not equal to %s' % (A_paths[0], B_paths[0])) 80 | assert(len(A_paths) == len(B_paths)) 81 | 82 | if isinstance(A_paths[0], list): 83 | for a, b in zip(A_paths, B_paths): 84 | if len(a) != len(b): 85 | print('%s not equal to %s' % (a[0], b[0])) 86 | assert(len(a) == len(b)) 87 | 88 | def default_loader(path): 89 | return Image.open(path).convert('RGB') 90 | 91 | 92 | class ImageFolder(data.Dataset): 93 | 94 | def __init__(self, root, transform=None, return_paths=False, 95 | loader=default_loader): 96 | imgs = make_dataset(root) 97 | if len(imgs) == 0: 98 | raise(RuntimeError("Found 0 images in: " + root + "\n" 99 | "Supported image extensions are: " + 100 | ",".join(IMG_EXTENSIONS))) 101 | 102 | self.root = root 103 | self.imgs = imgs 104 | self.transform = transform 105 | self.return_paths = return_paths 106 | self.loader = loader 107 | 108 | def __getitem__(self, index): 109 | path = self.imgs[index] 110 | img = self.loader(path) 111 | if self.transform is not None: 112 | img = self.transform(img) 113 | if self.return_paths: 114 | return img, path 115 | else: 116 | return img 117 | 118 | def __len__(self): 119 | return len(self.imgs) 120 | -------------------------------------------------------------------------------- /data/lmdb_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import lmdb 9 | import pickle 10 | import numpy as np 11 | from PIL import Image 12 | import cv2 13 | import torch.utils.data as data 14 | from util.distributed import master_only_print as print 15 | 16 | class LMDBDataset(data.Dataset): 17 | def __init__(self, root, write_cache=False): 18 | self.root = os.path.expanduser(root) 19 | self.env = lmdb.open(root, max_readers=126, readonly=True, lock=False, 20 | readahead=False, meminit=False) 21 | with self.env.begin(write=False) as txn: 22 | self.length = txn.stat()['entries'] 23 | print('LMDB file at %s opened.' 
% root) 24 | cache_file = os.path.join(root, '_cache_') 25 | if os.path.isfile(cache_file): 26 | self.keys = pickle.load(open(cache_file, "rb")) 27 | elif write_cache: 28 | print('generating keys') 29 | with self.env.begin(write=False) as txn: 30 | self.keys = [key for key, _ in txn.cursor()] 31 | pickle.dump(self.keys, open(cache_file, "wb")) 32 | print('cache file generated at %s' % cache_file) 33 | else: 34 | self.keys = [] 35 | 36 | def getitem_by_path(self, path, is_img=True): 37 | env = self.env 38 | with env.begin(write=False) as txn: 39 | buf = txn.get(path) 40 | if is_img: 41 | img = cv2.imdecode(np.fromstring(buf, dtype=np.uint8), 1) 42 | img = Image.fromarray(img) 43 | return img, path 44 | return buf, path 45 | 46 | def __getitem__(self, index): 47 | path = self.keys[index] 48 | return self.getitem_by_path(path) 49 | 50 | def __len__(self): 51 | return self.length 52 | -------------------------------------------------------------------------------- /data/preprocess/download_youTube_playlist.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from pytube import YouTube, Playlist 3 | 4 | 5 | # Example script to download youtube scripts. 6 | class MyPlaylist(Playlist): 7 | def download_all( 8 | self, 9 | download_path=None, 10 | prefix_number=True, 11 | reverse_numbering=False, 12 | idx=0, 13 | ): 14 | self.populate_video_urls() 15 | prefix_gen = self._path_num_prefix_generator(reverse_numbering) 16 | for i, link in enumerate(self.video_urls): 17 | prefix = '%03d_%03d_' % (idx + 1, i + 1) 18 | p = glob.glob(prefix[:-1] + '*.mp4') 19 | print(prefix, link) 20 | if not p: 21 | try: 22 | yt = YouTube(link) 23 | dl_stream = yt.streams.filter(adaptive=True, subtype='mp4').first() 24 | dl_stream.download(download_path, filename_prefix=prefix) 25 | except: 26 | print('cannot download') 27 | pass 28 | 29 | playlist_path = 'youTube_playlist.txt' 30 | with open(playlist_path, 'r') as f: 31 | playlists = f.read().splitlines() 32 | 33 | for i, playlist in enumerate(playlists): 34 | pl = MyPlaylist(playlist) 35 | pl.download_all(idx=i) 36 | -------------------------------------------------------------------------------- /data/preprocess/preprocess.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import glob 9 | import os.path as path 10 | import argparse 11 | import json 12 | from tqdm import tqdm 13 | 14 | from util.get_poses import extract_all_frames, run_densepose, run_openpose 15 | from util.check_valid import remove_invalid_frames, remove_static_frames, \ 16 | remove_isolated_frames, check_densepose_exists 17 | from util.track import divide_sequences 18 | from util.util import remove_folder 19 | 20 | 21 | def parse_args(): 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--steps', default='all', 24 | help='all | extract_frames | openpose | densepose | clean | divide_sequences') 25 | parser.add_argument('--video_root', required=True, 26 | help='path for videos to process') 27 | parser.add_argument('--output_root', required=True, 28 | help='path for output images') 29 | 30 | parser.add_argument('--img_folder', default='images') 31 | parser.add_argument('--openpose_folder', default='openpose') 32 | parser.add_argument('--openpose_postfix', default='_keypoints.json') 33 | parser.add_argument('--densepose_folder', default='densepose') 34 | parser.add_argument('--densepose_postfix', default='_IUV.png') 35 | parser.add_argument('--densemask_folder', default='densemask') 36 | parser.add_argument('--densemask_postfix', default='_INDS.png') 37 | parser.add_argument('--track_folder', default='tracking') 38 | 39 | parser.add_argument('--openpose_root', default='/', 40 | help='root for the OpenPose library') 41 | parser.add_argument('--densepose_root', default='/', 42 | help='root for the DensePose library') 43 | 44 | parser.add_argument('--n_skip_frames', type=int, default='100', 45 | help='Number of frames between keyframes. A larger ' 46 | 'number can expedite processing but may lose data') 47 | parser.add_argument('--min_n_of_frames', type=int, default='30', 48 | help='Minimum number of frames in the output sequence.') 49 | 50 | args = parser.parse_args() 51 | return args 52 | 53 | 54 | def rename_videos(video_root): 55 | video_paths = sorted(glob.glob(video_root + '/*.mp4')) 56 | for i, video_path in enumerate(video_paths): 57 | new_path = video_root + ('/%05d.mp4' % i) 58 | os.rename(video_path, new_path) 59 | 60 | 61 | # Remove frames that are not suitable for training. 62 | def remove_unusable_frames(args, video_idx): 63 | remove_invalid_frames(args, video_idx) 64 | check_densepose_exists(args, video_idx) 65 | remove_static_frames(args, video_idx) 66 | remove_isolated_frames(args, video_idx) 67 | video_path = path.join(args.output_root, args.img_folder, video_idx) 68 | if len(os.listdir(video_path)) == 0: 69 | remove_folder(args, video_idx) 70 | 71 | 72 | if __name__ == "__main__": 73 | args = parse_args() 74 | if args.steps == 'all': 75 | args.steps = 'openpose,densepose,clean,divide_sequences' 76 | 77 | if 'extract_frames' in args.steps or 'openpose' in args.steps or \ 78 | 'densepose' in args.steps: 79 | rename_videos(args.video_root) 80 | video_paths = sorted(glob.glob(args.video_root + '/*.mp4')) 81 | for video_path in tqdm(video_paths): 82 | if 'extract_frames' in args.steps: 83 | # Only extract frames from all the videos. 84 | extract_all_frames(args, video_path) 85 | if 'openpose' in args.steps: 86 | # Include extracting frames and running OpenPose. 87 | run_openpose(args, video_path) 88 | if 'densepose' in args.steps: 89 | # Run DensePose. 
90 | video_idx = path.basename(video_path).split('.')[0] 91 | run_densepose(args, video_idx) 92 | 93 | if 'clean' in args.steps: 94 | # Frames already extracted and openpose / densepose already run, only 95 | # remove the unusable frames in the dataset. 96 | # Note that the folder structure should be 97 | # [output_root] / [img_folder] / [sequences] / [frames], and the 98 | # names of the frames must be in format of 99 | # 'frame000000.jpg', 'frame000001.jpg', ... 100 | video_indices = sorted(glob.glob(path.join(args.output_root, 101 | args.img_folder, '*'))) 102 | video_indices = [path.basename(p) for p in video_indices] 103 | 104 | # Remove all unusable frames in the sequences. 105 | for i, video_idx in enumerate(tqdm(video_indices)): 106 | remove_unusable_frames(args, video_idx) 107 | 108 | if 'divide_sequences' in args.steps: 109 | # Finally, divide the remaining sequences into sub-sequences, where 110 | # each seb-sequence only contains one person. 111 | video_indices = sorted(glob.glob(path.join(args.output_root, 112 | args.img_folder, '*'))) 113 | video_indices = [path.basename(p) for p in video_indices] 114 | seq_indices = [] 115 | start_frame_indices, end_frame_indices, ppl_indices = [], [], [] 116 | for i, video_idx in enumerate(tqdm(video_indices)): 117 | start_frame_indices_i, end_frame_indices_i, ppl_indices_i = \ 118 | divide_sequences(args, video_idx) 119 | seq_indices += [i] * len(start_frame_indices_i) 120 | start_frame_indices += start_frame_indices_i 121 | end_frame_indices += end_frame_indices_i 122 | ppl_indices += ppl_indices_i 123 | 124 | output = dict() 125 | output['seq_indices'] = seq_indices 126 | output['start_frame_indices'] = start_frame_indices 127 | output['end_frame_indices'] = end_frame_indices 128 | output['ppl_indices'] = ppl_indices 129 | output_path = path.join(args.output_root, 'all_subsequences.json') 130 | with open(output_path, 'w') as fp: 131 | json.dump(output, fp, indent=4) 132 | -------------------------------------------------------------------------------- /data/preprocess/util/check_valid.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import glob 9 | import os.path as path 10 | import json 11 | 12 | from util.util import remove_frame, get_keypoint_array, get_frame_idx, \ 13 | get_valid_openpose_keypoints 14 | 15 | 16 | # Remove invalid frames in the video. 17 | def remove_invalid_frames(args, video_idx): 18 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx) 19 | json_paths = sorted(glob.glob(op_dir + '/*.json')) 20 | for json_path in json_paths: 21 | if not is_valid_frame(args, json_path): 22 | remove_frame(args, start=json_path) 23 | 24 | 25 | # Remove static frames in the video if no motion is detected more than 26 | # max_static_frames. 
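# Frames are scanned in order; while the OpenPose keypoints of all people move less
# than a small pixel threshold between consecutive frames, the static run is extended,
# and once such a run grows longer than max_static_frames the whole run is removed
# from every output folder (images, openpose, densepose, densemask).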
27 | def remove_static_frames(args, video_idx): 28 | max_static_frames = 5 # maximum number of frames to be static 29 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx) 30 | json_paths = sorted(glob.glob(op_dir + '/*.json')) 31 | start_idx = end_idx = 0 32 | keypoint_dicts_prev = None 33 | 34 | for json_path in json_paths: 35 | with open(json_path, encoding='utf-8') as f: 36 | keypoint_dicts = json.loads(f.read())["people"] 37 | is_moving = detect_motion(keypoint_dicts_prev, keypoint_dicts) 38 | keypoint_dicts_prev = keypoint_dicts 39 | 40 | i = get_frame_idx(json_path) 41 | if not is_moving: 42 | end_idx = i 43 | else: 44 | # If static frames longer than max_static_frames, remove them. 45 | if (end_idx - start_idx) > max_static_frames: 46 | remove_frame(args, video_idx, start_idx, end_idx) 47 | start_idx = end_idx = i 48 | 49 | 50 | # Remove small batch frames if number of consecutive frames is smaller than 51 | # min_n_of_frames. 52 | def remove_isolated_frames(args, video_idx): 53 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx) 54 | json_paths = sorted(glob.glob(op_dir + '/*.json')) 55 | 56 | if len(json_paths): 57 | start_idx = end_idx = get_frame_idx(json_paths[0]) - 1 58 | for json_path in json_paths: 59 | i = get_frame_idx(json_path) 60 | # If the frames are not consecutive, there's a breakpoint. 61 | if i != end_idx + 1: 62 | # Check if this block of frames is longer than min_n_of_frames. 63 | if (end_idx - start_idx) < args.min_n_of_frames: 64 | remove_frame(args, video_idx, start_idx, end_idx) 65 | start_idx = i 66 | end_idx = i 67 | # Need to check again at the end of sequence. 68 | if (end_idx - start_idx) < args.min_n_of_frames: 69 | remove_frame(args, video_idx, start_idx, end_idx) 70 | 71 | 72 | # Detect if motion exists between consecutive frames. 73 | def detect_motion(keypoint_dicts_1, keypoint_dicts_2): 74 | motion_thre = 5 # minimum position difference to count as motion 75 | # If it's the first frame or the number of people in these two frames 76 | # are different, return true. 77 | if keypoint_dicts_1 is None: 78 | return True 79 | if len(keypoint_dicts_1) != len(keypoint_dicts_2): 80 | return True 81 | 82 | # If the pose difference between two frames are larger than threshold, 83 | # return true. 84 | for keypoint_dict_1, keypoint_dict_2 in zip(keypoint_dicts_1, keypoint_dicts_2): 85 | pose_pts1, pose_pts2 = get_keypoint_array([keypoint_dict_1, keypoint_dict_2]) 86 | if ((abs(pose_pts1 - pose_pts2) > motion_thre) & 87 | (pose_pts1 != 0) & (pose_pts2 != 0)).any(): 88 | return True 89 | return False 90 | 91 | 92 | # If densepose did not find any person in the frame (and thus outputs nothing), 93 | # remove the frame from the dataset. 94 | def check_densepose_exists(args, video_idx): 95 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx) 96 | json_paths = sorted(glob.glob(op_dir + '/*.json')) 97 | for json_path in json_paths: 98 | dp_path = json_path.replace(args.openpose_folder, args.densepose_folder) 99 | dp_path = dp_path.replace(args.openpose_postfix, args.densepose_postfix) 100 | if not os.path.exists(dp_path): 101 | remove_frame(args, start=json_path) 102 | 103 | 104 | # Check if the frame is valid to use. 
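# A frame is kept only when OpenPose detected at least one person, at least one
# detected person is a full body (some head joint and some foot joint present),
# and at least one person does not overlap horizontally with anyone else
# (see is_full_body() and contains_non_overlapping_people() below).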
105 | def is_valid_frame(args, img_path): 106 | if img_path.endswith('.jpg'): 107 | img_path = img_path.replace(args.img_folder, args.openpose_folder) 108 | img_path = img_path.replace('.jpg', args.openpose_postfix) 109 | with open(img_path, encoding='utf-8') as f: 110 | keypoint_dicts = json.loads(f.read())["people"] 111 | return len(keypoint_dicts) > 0 and is_full_body(keypoint_dicts) and \ 112 | contains_non_overlapping_people(keypoint_dicts) 113 | 114 | 115 | # Check if the image contains a full body. 116 | def is_full_body(keypoint_dicts): 117 | if type(keypoint_dicts) != list: 118 | keypoint_dicts = [keypoint_dicts] 119 | for keypoint_dict in keypoint_dicts: 120 | pose_pts = get_keypoint_array(keypoint_dict) 121 | # Contains at least one joint of head and one joint of foot. 122 | full = pose_pts[[0, 15, 16, 17, 18], :].any() \ 123 | and pose_pts[[11, 14, 19, 20, 21, 22, 23, 24], :].any() 124 | if full: 125 | return True 126 | return False 127 | 128 | 129 | # Check whether two people overlap with each other. 130 | def has_overlap(pose_pts_1, pose_pts_2): 131 | pose_pts_1 = get_valid_openpose_keypoints(pose_pts_1)[:, 0] 132 | pose_pts_2 = get_valid_openpose_keypoints(pose_pts_2)[:, 0] 133 | # Get the x_axis bbox of the person. 134 | x1_start, x1_end = int(pose_pts_1.min()), int(pose_pts_1.max()) 135 | x2_start, x2_end = int(pose_pts_2.min()), int(pose_pts_2.max()) 136 | if x1_end < x2_start or x2_end < x1_start: 137 | return False 138 | return True 139 | 140 | 141 | # Check if the image contains at least one person that does not overlap with others. 142 | def contains_non_overlapping_people(keypoint_dicts): 143 | if len(keypoint_dicts) < 2: 144 | return True 145 | 146 | all_pose_pts = [get_keypoint_array(k) for k in keypoint_dicts] 147 | for i, pose_pts in enumerate(all_pose_pts): 148 | overlap = False 149 | for j, pose_pts2 in enumerate(all_pose_pts): 150 | if i == j: 151 | continue 152 | overlap = overlap | has_overlap(pose_pts, pose_pts2) 153 | if overlap: 154 | break 155 | if not overlap: 156 | return True 157 | return False 158 | -------------------------------------------------------------------------------- /data/preprocess/util/get_poses.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import glob 9 | import os.path as path 10 | import cv2 11 | 12 | from util.util import makedirs, remove_frame 13 | from util.check_valid import remove_invalid_frames, remove_static_frames, \ 14 | remove_isolated_frames, is_valid_frame 15 | 16 | 17 | # Run OpenPose on the extracted frames, and remove invalid frames. 18 | # To expedite the process, we will first process only keyframes in the video. 19 | # If the keyframe looks promising, we will then process the whole block of 20 | # frames for the keyframe. 21 | def run_openpose(args, video_path): 22 | video_idx = path.basename(video_path).split('.')[0] 23 | try: 24 | img_dir = path.join(args.output_root, args.img_folder, video_idx) 25 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx) 26 | img_names = sorted(glob.glob(img_dir + '/*.jpg')) 27 | op_names = sorted(glob.glob(op_dir + '/*.json')) 28 | 29 | # If the frames haven't been extracted or OpenPose hasn't been run or 30 | # finished processing. 
31 | if (not os.path.isdir(img_dir) or not os.path.isdir(op_dir) 32 | or len(img_names) != len(op_names)): 33 | makedirs(img_dir) 34 | 35 | # First run OpenPose on key frames, then decide whether to run 36 | # the whole batch of frames. 37 | extract_key_frames(args, video_path, img_dir) 38 | run_openpose_cmd(args, video_idx) 39 | 40 | # If key frame looks good, extract all frames in the batch and 41 | # run OpenPose. 42 | if args.n_skip_frames > 1: 43 | extract_valid_frames(args, video_path, img_dir) 44 | run_openpose_cmd(args, video_idx) 45 | 46 | # Remove all unusable frames. 47 | remove_invalid_frames(args, video_idx) 48 | remove_static_frames(args, video_idx) 49 | remove_isolated_frames(args, video_idx) 50 | except: 51 | raise ValueError('video %s running openpose failed' % video_idx) 52 | 53 | 54 | # Run DensePose on the extracted frames, and remove invalid frames. 55 | def run_densepose(args, video_idx): 56 | try: 57 | img_dir = path.join(args.output_root, args.img_folder, video_idx) 58 | dp_dir = path.join(args.output_root, args.densepose_folder, video_idx) 59 | img_names = sorted(glob.glob(img_dir + '/*.jpg')) 60 | dp_names = sorted(glob.glob(dp_dir + '/*.png')) 61 | 62 | if not os.path.isdir(dp_dir) or len(img_names) != len(dp_names): 63 | makedirs(dp_dir) 64 | 65 | # Run densepose. 66 | run_densepose_cmd(args, video_idx) 67 | except: 68 | raise ValueError('video %s running densepose failed' % video_idx) 69 | 70 | 71 | # Extract only the keyframes in the video. 72 | def extract_key_frames(args, video_path, img_dir): 73 | print('Extracting keyframes.') 74 | vidcap = cv2.VideoCapture(video_path) 75 | success, image = vidcap.read() 76 | frame_count = 0 77 | while success: 78 | if frame_count % args.n_skip_frames == 0: 79 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count) 80 | cv2.imwrite(write_name, image) 81 | success, image = vidcap.read() 82 | frame_count += 1 83 | 84 | 85 | # Extract valid frames from the video based on the extracted keyframes. 86 | def extract_valid_frames(args, video_path, img_dir): 87 | vidcap = cv2.VideoCapture(video_path) 88 | success, image = vidcap.read() 89 | frame_count = 0 90 | do_write = True 91 | while success: 92 | is_key_frame = frame_count % args.n_skip_frames == 0 93 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count) 94 | # Each time it's keyframe, check whether the frame is valid. If it is, 95 | # all frames following it before the next keyframe will be extracted. 96 | # Otherwise, this block of frames will be skipped and the next keyframe 97 | # will be examined. 98 | if is_key_frame: 99 | do_write = is_valid_frame(args, write_name) 100 | if not do_write: # If not valid, remove this keyframe. 101 | remove_frame(args, start=write_name) 102 | if do_write: 103 | cv2.imwrite(write_name, image) 104 | success, image = vidcap.read() 105 | frame_count += 1 106 | print('Video contains %d frames.' % frame_count) 107 | 108 | 109 | # Extract all frames from the video. 
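# Unlike extract_key_frames() / extract_valid_frames() above, this writes every frame
# of the video to [output_root]/[img_folder]/[video_idx] as frame000000.jpg, frame000001.jpg, ...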
110 | def extract_all_frames(args, video_path): 111 | video_idx = path.basename(video_path).split('.')[0] 112 | img_dir = path.join(args.output_root, args.img_folder, video_idx) 113 | makedirs(img_dir) 114 | 115 | vidcap = cv2.VideoCapture(video_path) 116 | success, image = vidcap.read() 117 | frame_count = 0 118 | while success: 119 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count) 120 | cv2.imwrite(write_name, image) 121 | success, image = vidcap.read() 122 | frame_count += 1 123 | print('Extracted %d frames' % frame_count) 124 | 125 | 126 | # Running the actual OpenPose command. 127 | def run_openpose_cmd(args, video_idx): 128 | pwd = os.getcwd() 129 | img_dir = path.join(pwd, args.output_root, args.img_folder, video_idx) 130 | op_dir = path.join(pwd, args.output_root, args.openpose_folder, video_idx) 131 | render_dir = path.join(pwd, args.output_root, 132 | args.openpose_folder + '_rendered', video_idx) 133 | makedirs(op_dir) 134 | makedirs(render_dir) 135 | 136 | cmd = 'cd %s; ./build/examples/openpose/openpose.bin --display 0 ' \ 137 | '--disable_blending --image_dir %s --write_images %s --face --hand ' \ 138 | '--face_render_threshold 0.1 --hand_render_threshold 0.02 ' \ 139 | '--write_json %s; cd %s' \ 140 | % (args.openpose_root, img_dir, render_dir, op_dir, 141 | path.join(pwd, args.output_root)) 142 | os.system(cmd) 143 | 144 | 145 | # Running the actual DensePose command. 146 | def run_densepose_cmd(args, video_idx): 147 | pwd = os.getcwd() 148 | img_dir = path.join(pwd, args.output_root, args.img_folder, video_idx) 149 | dp_dir = path.join(pwd, args.output_root, args.densepose_folder, video_idx, 'frame.png') 150 | cmd = 'python2 tools/infer_simple.py ' \ 151 | '--cfg configs/DensePose_ResNet101_FPN_s1x-e2e.yaml ' \ 152 | '--wts https://dl.fbaipublicfiles.com/densepose/DensePose_ResNet101_FPN_s1x-e2e.pkl ' \ 153 | '--output-dir %s %s' % (dp_dir, img_dir) 154 | # cmd = 'python apply_net.py show configs/densepose_rcnn_R_101_FPN_s1x.yaml ' \ 155 | # 'densepose_rcnn_R_101_FPN_s1x.pkl %s dp_segm,dp_u,dp_v --output %s' \ 156 | # % (img_dir, dp_dir) 157 | cmd = 'cd %s; ' % args.densepose_root \ 158 | + cmd \ 159 | + '; cd %s' % path.join(pwd, args.output_root) 160 | os.system(cmd) 161 | -------------------------------------------------------------------------------- /data/preprocess/util/util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import os.path as path 9 | import numpy as np 10 | 11 | 12 | # Extract keypoint array given the openpose dict. 13 | def get_keypoint_array(keypoint_dict): 14 | if type(keypoint_dict) == list: 15 | return [get_keypoint_array(d) for d in keypoint_dict] 16 | if type(keypoint_dict) != np.ndarray: 17 | keypoint_dict = np.array(keypoint_dict["pose_keypoints_2d"]).reshape(25, 3) 18 | return keypoint_dict 19 | 20 | 21 | # Only extract openpose keypoints where the confidence is larger then conf_thre. 22 | def get_valid_openpose_keypoints(keypoint_array): 23 | if type(keypoint_array) == list: 24 | return [get_valid_openpose_keypoints(k) for k in keypoint_array] 25 | return keypoint_array[keypoint_array[:, 2] > 0.01, :] 26 | 27 | 28 | # Remove particular frame(s) for all folders. 
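# `start` is either an integer frame index or a path to any per-frame file, in which
# case the video index and frame index are recovered from the path; frames start..end
# (inclusive) are then deleted from the image, openpose, densepose, and densemask folders.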
29 | def remove_frame(args, video_idx='', start=0, end=None): 30 | if not isinstance(start, int): 31 | video_idx = path.basename(path.dirname(start)) 32 | start = get_frame_idx(start) 33 | if end is None: 34 | end = start 35 | for i in range(start, end + 1): 36 | img_path = path.join(args.output_root, args.img_folder, video_idx, 37 | 'frame%06d.jpg' % i) 38 | op_path = path.join(args.output_root, args.openpose_folder, video_idx, 39 | 'frame%06d%s' % (i, args.openpose_postfix)) 40 | dp_path = path.join(args.output_root, args.densepose_folder, video_idx, 41 | 'frame%06d%s' % (i, args.densepose_postfix)) 42 | dm_path = path.join(args.output_root, args.densemask_folder, video_idx, 43 | 'frame%06d%s' % (i, args.densemask_postfix)) 44 | print('removing %s' % img_path) 45 | remove(img_path) 46 | remove(op_path) 47 | remove(dp_path) 48 | remove(dm_path) 49 | 50 | 51 | def remove_folder(args, video_idx): 52 | os.rmdir(path.join(args.output_root, args.img_folder, video_idx)) 53 | os.rmdir(path.join(args.output_root, args.openpose_folder, video_idx)) 54 | os.rmdir(path.join(args.output_root, args.densepose_folder, video_idx)) 55 | os.rmdir(path.join(args.output_root, args.densemask_folder, video_idx)) 56 | 57 | 58 | def get_frame_idx(file_name): 59 | return int(path.basename(file_name)[5:11]) 60 | 61 | 62 | def makedirs(folder): 63 | if not path.exists(folder): 64 | os.umask(0) 65 | os.makedirs(folder, mode=0o777) 66 | 67 | 68 | def remove(file_name): 69 | if path.exists(file_name): 70 | os.remove(file_name) 71 | -------------------------------------------------------------------------------- /data/preprocess/youTube_playlist.txt: -------------------------------------------------------------------------------- 1 | https://www.youtube.com/playlist?list=PLZgHzuctidRIYzMpEcbEMIFp7vGcFK4x7 2 | https://www.youtube.com/playlist?list=PL0m7UHzPZEA9R8Y6xautFgqeWnorDj2Le 3 | https://www.youtube.com/playlist?list=PLsVSF-hJhvBKV9XDaqJAEx-otHG1_XT9q -------------------------------------------------------------------------------- /imgs/dance.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/dance.gif -------------------------------------------------------------------------------- /imgs/face.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/face.gif -------------------------------------------------------------------------------- /imgs/illustration.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/illustration.gif -------------------------------------------------------------------------------- /imgs/mona_lisa.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/mona_lisa.gif -------------------------------------------------------------------------------- /imgs/statue.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/statue.gif -------------------------------------------------------------------------------- /imgs/street.gif: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/street.gif -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt -------------------------------------------------------------------------------- /models/face_refiner.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn.functional as F 9 | 10 | from .base_model import BaseModel 11 | 12 | class FaceRefineModel(BaseModel): 13 | def name(self): 14 | return 'FaceRefineModel' 15 | 16 | def initialize(self, opt, add_face_D, refine_face): 17 | BaseModel.initialize(self, opt) 18 | self.opt = opt 19 | self.add_face_D = add_face_D 20 | self.refine_face = refine_face 21 | self.face_size = int(opt.fineSize / opt.aspect_ratio) // 4 22 | 23 | ### refine the face region of the fake image 24 | def refine_face_region(self, netGf, label_valid, fake_image, label, ref_label_valid, ref_image, ref_label): 25 | label_face, fake_face_coarse = self.crop_face_region([label_valid, fake_image], label, crop_smaller=4) 26 | ref_label_face, ref_image_face = self.crop_face_region([ref_label_valid, ref_image], ref_label, crop_smaller=4) 27 | fake_face = netGf(label_face, ref_label_face.unsqueeze(1), ref_image_face.unsqueeze(1), img_coarse=fake_face_coarse.detach()) 28 | fake_image = self.replace_face_region(fake_image, fake_face, label, fake_face_coarse.detach(), crop_smaller=4) 29 | return fake_image 30 | 31 | ### crop out the face region of the image (and resize if necessary to feed into generator/discriminator) 32 | def crop_face_region(self, image, input_label, crop_smaller=0): 33 | if type(image) == list: 34 | return [self.crop_face_region(im, input_label, crop_smaller) for im in image] 35 | for i in range(input_label.size(0)): 36 | ys, ye, xs, xe = self.get_face_region(input_label[i:i+1], crop_smaller=crop_smaller) 37 | output_i = F.interpolate(image[i:i+1,-3:,ys:ye,xs:xe], size=(self.face_size, self.face_size)) 38 | output = torch.cat([output, output_i]) if i != 0 else output_i 39 | return output 40 | 41 | ### replace the face region in the fake image with the refined version 42 | def replace_face_region(self, fake_image, fake_face, input_label, fake_face_coarse=None, crop_smaller=0): 43 | fake_image = fake_image.clone() 44 | b, _, h, w = input_label.size() 45 | for i in range(b): 46 | ys, ye, xs, xe = self.get_face_region(input_label[i:i+1], crop_smaller) 47 | fake_face_i = fake_face[i:i+1] + (fake_face_coarse[i:i+1] if fake_face_coarse is not None else 0) 48 | fake_face_i = F.interpolate(fake_face_i, size=(ye-ys, xe-xs), mode='bilinear') 49 | fake_image[i:i+1,:,ys:ye,xs:xe] = torch.clamp(fake_face_i, -1, 1) 50 | return fake_image 51 | 52 | ### get coordinates of the face bounding box 53 | def 
get_face_region(self, pose, crop_smaller=0): 54 | if pose.dim() == 3: pose = pose.unsqueeze(0) 55 | elif pose.dim() == 5: pose = pose[-1,-1:] 56 | _, _, h, w = pose.size() 57 | 58 | use_openpose = not self.opt.basic_point_only and not self.opt.remove_face_labels 59 | if use_openpose: # use openpose face keypoints to identify face region 60 | face = ((pose[:,-3] > 0) & (pose[:,-2] > 0) & (pose[:,-1] > 0)).nonzero() 61 | else: # use densepose labels 62 | face = (pose[:,2] > 0.9).nonzero() 63 | if face.size(0): 64 | y, x = face[:,1], face[:,2] 65 | ys, ye, xs, xe = y.min().item(), y.max().item(), x.min().item(), x.max().item() 66 | if use_openpose: 67 | xc, yc = (xs + xe) // 2, (ys*3 + ye*2) // 5 68 | ylen = int((xe - xs) * 2.5) 69 | else: 70 | xc, yc = (xs + xe) // 2, (ys + ye) // 2 71 | ylen = int((ye - ys) * 1.25) 72 | ylen = xlen = min(w, max(32, ylen)) 73 | yc = max(ylen//2, min(h-1 - ylen//2, yc)) 74 | xc = max(xlen//2, min(w-1 - xlen//2, xc)) 75 | else: 76 | yc = h//4 77 | xc = w//2 78 | ylen = xlen = h // 32 * 8 79 | 80 | ys, ye, xs, xe = yc - ylen//2, yc + ylen//2, xc - xlen//2, xc + xlen//2 81 | if crop_smaller != 0: # crop slightly smaller region inside face 82 | ys += crop_smaller; xs += crop_smaller 83 | ye -= crop_smaller; xe -= crop_smaller 84 | return ys, ye, xs, xe -------------------------------------------------------------------------------- /models/flownet.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | import numpy as np 9 | import torch 10 | import torch.nn.functional as F 11 | import sys 12 | 13 | from .base_model import BaseModel 14 | 15 | class FlowNet(BaseModel): 16 | def name(self): 17 | return 'FlowNet' 18 | 19 | def initialize(self, opt): 20 | BaseModel.initialize(self, opt) 21 | 22 | # flownet 2 23 | from .networks.flownet2_pytorch import models as flownet2_models 24 | from .networks.flownet2_pytorch.utils import tools as flownet2_tools 25 | from .networks.flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d 26 | 27 | self.flowNet = flownet2_tools.module_to_dict(flownet2_models)['FlowNet2']().cuda() 28 | checkpoint = torch.load('models/networks/flownet2_pytorch/FlowNet2_checkpoint.pth.tar', map_location=torch.device('cpu')) 29 | self.flowNet.load_state_dict(checkpoint['state_dict']) 30 | self.flowNet.eval() 31 | self.resample = Resample2d() 32 | self.downsample = torch.nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 33 | 34 | def forward(self, data_list, epoch=0, dummy_bs=0): 35 | if data_list[0].get_device() == 0: 36 | data_list = self.remove_dummy_from_tensor(data_list, dummy_bs) 37 | image_now, image_ref = data_list 38 | image_now, image_ref = image_now[:,:,:3], image_ref[:,0:1,:3] 39 | 40 | flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None 41 | with torch.no_grad(): 42 | if not self.opt.isTrain or epoch > self.opt.niter_single: 43 | image_prev = torch.cat([image_now[:,0:1], image_now[:,:-1]], dim=1) 44 | flow_gt_prev, conf_gt_prev = self.flowNet_forward(image_now, image_prev) 45 | 46 | if self.opt.warp_ref: 47 | flow_gt_ref, conf_gt_ref = self.flowNet_forward(image_now, image_ref.expand_as(image_now)) 48 | 49 | flow_gt, conf_gt = [flow_gt_ref, flow_gt_prev], [conf_gt_ref, conf_gt_prev] 50 | 
return flow_gt, conf_gt 51 | 52 | def flowNet_forward(self, input_A, input_B): 53 | size = input_A.size() 54 | assert(len(size) == 4 or len(size) == 5) 55 | if len(size) == 5: 56 | b, n, c, h, w = size 57 | input_A = input_A.contiguous().view(-1, c, h, w) 58 | input_B = input_B.contiguous().view(-1, c, h, w) 59 | flow, conf = self.compute_flow_and_conf(input_A, input_B) 60 | return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w) 61 | else: 62 | return self.compute_flow_and_conf(input_A, input_B) 63 | 64 | def compute_flow_and_conf(self, im1, im2): 65 | assert(im1.size()[1] == 3) 66 | assert(im1.size() == im2.size()) 67 | old_h, old_w = im1.size()[2], im1.size()[3] 68 | new_h, new_w = old_h//64*64, old_w//64*64 69 | if old_h != new_h: 70 | im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear') 71 | im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear') 72 | self.flowNet.cuda(im1.get_device()) 73 | data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2) 74 | flow1 = self.flowNet(data1) 75 | conf = (self.norm(im1 - self.resample(im2, flow1)) < 0.02).float() 76 | 77 | if old_h != new_h: 78 | flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear') * old_h / new_h 79 | conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear') 80 | return flow1, conf 81 | 82 | def norm(self, t): 83 | return torch.sum(t*t, dim=1, keepdim=True) 84 | -------------------------------------------------------------------------------- /models/input_process.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | 9 | ############################# input processing ################################### 10 | def encode_input(opt, data_list, dummy_bs): 11 | if opt.isTrain and data_list[0].get_device() == 0: 12 | data_list = remove_dummy_from_tensor(opt, data_list, dummy_bs) 13 | tgt_label, tgt_image, flow_gt, conf_gt, ref_label, ref_image, prev_label, prev_real_image, prev_fake_image = data_list 14 | 15 | # target label and image 16 | tgt_label = encode_label(opt, tgt_label) 17 | tgt_image = tgt_image.cuda() 18 | 19 | # reference label and image 20 | ref_label = encode_label(opt, ref_label) 21 | ref_image = ref_image.cuda() 22 | 23 | return tgt_label, tgt_image, flow_gt, conf_gt, ref_label, ref_image, [prev_label, prev_real_image, prev_fake_image] 24 | 25 | def encode_label(opt, label_map): 26 | size = label_map.size() 27 | if len(size) == 5: 28 | bs, t, c, h, w = size 29 | label_map = label_map.view(-1, c, h, w) 30 | else: 31 | bs, c, h, w = size 32 | 33 | label_nc = opt.label_nc 34 | if label_nc == 0: 35 | input_label = label_map.cuda() 36 | else: 37 | # create one-hot vector for label map 38 | label_map = label_map.cuda() 39 | oneHot_size = (label_map.shape[0], label_nc, h, w) 40 | input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_() 41 | input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0) 42 | 43 | if len(size) == 5: 44 | return input_label.view(bs, t, -1, h, w) 45 | return input_label 46 | 47 | ### get the union of target and reference foreground masks 48 | def combine_fg_mask(fg_mask, ref_fg_mask, has_fg): 49 | return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1 50 | 51 | ### obtain the foreground mask for pose 
sequences, which only includes the human 52 | def get_fg_mask(opt, input_label, has_fg): 53 | if type(input_label) == list: 54 | return [get_fg_mask(opt, l, has_fg) for l in input_label] 55 | if not has_fg: return None 56 | if len(input_label.size()) == 5: input_label = input_label[:,0] 57 | mask = input_label[:,2:3] if opt.label_nc == 0 else -input_label[:,0:1] 58 | 59 | mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask) # make the mask slightly larger 60 | mask = (mask > -1).float() 61 | return mask 62 | 63 | ### obtain mask of different body parts 64 | def get_part_mask(pose): 65 | part_groups = [[0], [1,2], [3,4], [5,6], [7,9,8,10], [11,13,12,14], [15,17,16,18], [19,21,20,22], [23,24]] 66 | n_parts = len(part_groups) 67 | 68 | need_reshape = pose.dim() == 4 69 | if need_reshape: 70 | bo, t, h, w = pose.size() 71 | pose = pose.view(-1, h, w) 72 | b, h, w = pose.size() 73 | part = (pose / 2 + 0.5) * 24 74 | mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0) 75 | for i in range(n_parts): 76 | for j in part_groups[i]: 77 | mask[:, i] = mask[:, i] | ((part > j-0.1) & (part < j+0.1)).byte() 78 | if need_reshape: 79 | mask = mask.view(bo, t, -1, h, w) 80 | return mask.float() 81 | 82 | ### obtain mask of faces 83 | def get_face_mask(pose): 84 | if pose.dim() == 3: 85 | pose = pose.unsqueeze(1) 86 | b, t, h, w = pose.size() 87 | part = (pose / 2 + 0.5) * 24 88 | if pose.is_cuda: 89 | mask = torch.cuda.ByteTensor(b, t, h, w).fill_(0) 90 | else: 91 | mask = torch.ByteTensor(b, t, h, w).fill_(0) 92 | for j in [23,24]: 93 | mask = mask | ((part > j-0.1) & (part < j+0.1)).byte() 94 | return mask.float() 95 | 96 | ### remove partial labels in the pose map if necessary 97 | def use_valid_labels(opt, pose): 98 | if 'pose' not in opt.dataset_mode: return pose 99 | if pose is None: return pose 100 | if type(pose) == list: 101 | return [use_valid_labels(opt, p) for p in pose] 102 | assert(pose.dim() == 4 or pose.dim() == 5) 103 | if opt.pose_type == 'open': 104 | if pose.dim() == 4: pose = pose[:,3:] 105 | elif pose.dim() == 5: pose = pose[:,:,3:] 106 | elif opt.remove_face_labels: 107 | if pose.dim() == 4: 108 | face_mask = get_face_mask(pose[:,2]) 109 | pose = torch.cat([pose[:,:3] * (1 - face_mask) - face_mask, pose[:,3:]], dim=1) 110 | else: 111 | face_mask = get_face_mask(pose[:,:,2]).unsqueeze(2) 112 | pose = torch.cat([pose[:,:,:3] * (1 - face_mask) - face_mask, pose[:,:,3:]], dim=2) 113 | return pose 114 | 115 | def remove_dummy_from_tensor(opt, tensors, remove_size=0): 116 | if remove_size == 0: return tensors 117 | if tensors is None: return None 118 | if isinstance(tensors, list): 119 | return [remove_dummy_from_tensor(opt, tensor, remove_size) for tensor in tensors] 120 | if isinstance(tensors, torch.Tensor): 121 | tensors = tensors[remove_size:] 122 | return tensors -------------------------------------------------------------------------------- /models/models.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | 11 | from models.networks.sync_batchnorm import DataParallelWithCallback 12 | from models.vid2vid_model import Vid2VidModel 13 | from util.distributed import master_only 14 | from util.distributed import master_only_print as print 15 | 16 | def create_model(opt, epoch=0): 17 | model = Vid2VidModel() 18 | model.initialize(opt, epoch) 19 | print("model [%s] was created" % (model.name())) 20 | 21 | if opt.isTrain: 22 | if opt.amp != 'O0': 23 | from apex import amp 24 | print('using amp optimization') 25 | model, optimizers = amp.initialize(model, [model.optimizer_G, model.optimizer_D], 26 | opt_level=opt.amp, num_losses=2) 27 | else: 28 | optimizers = model.optimizer_G, model.optimizer_D 29 | 30 | model = WrapModel(opt, model) 31 | flowNet = None 32 | if not opt.no_flow_gt: 33 | from .flownet import FlowNet 34 | flowNet = FlowNet() 35 | flowNet.initialize(opt) 36 | flowNet = WrapModel(opt, flowNet) 37 | return model, flowNet, optimizers 38 | return model 39 | 40 | def WrapModel(opt, model): 41 | if opt.distributed: 42 | import apex 43 | model = apex.parallel.DistributedDataParallel(model.cuda(), delay_allreduce=True) 44 | else: 45 | model = MyModel(opt, model) 46 | return model 47 | 48 | @master_only 49 | def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, model, end_of_epoch=False): 50 | if not end_of_epoch: 51 | if (total_steps % opt.save_latest_freq == 0): 52 | visualizer.vis_print(opt, 'saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps)) 53 | model.module.save_networks('latest') 54 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d') 55 | model.cuda() 56 | else: 57 | if epoch % opt.save_epoch_freq == 0: 58 | visualizer.vis_print(opt, 'saving the model at the end of epoch %d, iters %d' % (epoch, total_steps)) 59 | model.module.save_networks('latest') 60 | model.module.save_networks(epoch) 61 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d') 62 | model.cuda() 63 | 64 | def update_models(opt, epoch, model, data_loader): 65 | ### linearly decay learning rate after certain iterations 66 | if epoch > opt.niter: 67 | model.module.update_learning_rate(epoch) 68 | 69 | ### train single frame first then sequence of frames 70 | if epoch == opt.niter_single + 1 and not model.module.temporal: 71 | model.module.init_temporal_model() 72 | 73 | ### gradually grow training sequence length 74 | epoch_temp = epoch - opt.niter_single 75 | if epoch_temp > 0 and ((epoch_temp - 1) % opt.niter_step) == 0: 76 | data_loader.dataset.update_training_batch((epoch_temp - 1) // opt.niter_step) 77 | 78 | 79 | class MyModel(nn.Module): 80 | def __init__(self, opt, model): 81 | super(MyModel, self).__init__() 82 | self.opt = opt 83 | model = model.cuda(opt.gpu_ids[0]) 84 | self.module = model 85 | 86 | self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids) 87 | if opt.batch_for_first_gpu != -1: 88 | self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (len(opt.gpu_ids) - 1) # batch size for each GPU 89 | else: 90 | self.bs_per_gpu = int(np.ceil(float(opt.batchSize) / len(opt.gpu_ids))) # batch size for each GPU 91 | self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize 92 | 93 | def forward(self, *inputs, **kwargs): 94 | inputs = self.add_dummy_to_tensor(inputs, self.pad_bs) 95 | outputs = self.model(*inputs, **kwargs, 
dummy_bs=self.pad_bs) 96 | if self.pad_bs == self.bs_per_gpu: # gpu 0 does 0 batch but still returns 1 batch 97 | return self.remove_dummy_from_tensor(outputs, 1) 98 | return outputs 99 | 100 | def add_dummy_to_tensor(self, tensors, add_size=0): 101 | if add_size == 0 or tensors is None: return tensors 102 | if type(tensors) == list or type(tensors) == tuple: 103 | return [self.add_dummy_to_tensor(tensor, add_size) for tensor in tensors] 104 | 105 | if isinstance(tensors, torch.Tensor): 106 | dummy = torch.zeros_like(tensors)[:add_size] 107 | tensors = torch.cat([dummy, tensors]) 108 | return tensors 109 | 110 | def remove_dummy_from_tensor(self, tensors, remove_size=0): 111 | if remove_size == 0 or tensors is None: return tensors 112 | if type(tensors) == list or type(tensors) == tuple: 113 | return [self.remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors] 114 | 115 | if isinstance(tensors, torch.Tensor): 116 | tensors = tensors[remove_size:] 117 | return tensors -------------------------------------------------------------------------------- /models/networks/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn as nn 9 | import functools 10 | import numpy as np 11 | import torch.nn.functional as F 12 | from models.networks.loss import * 13 | from models.networks.discriminator import * 14 | from models.networks.generator import * 15 | 16 | 17 | def modify_commandline_options(parser, is_train): 18 | opt, _ = parser.parse_known_args() 19 | if 'fewshot' in opt.netG: 20 | parser = FewShotGenerator.modify_commandline_options(parser, is_train) 21 | 22 | if is_train: 23 | if opt.which_model_netD == 'multiscale': 24 | parser = MultiscaleDiscriminator.modify_commandline_options(parser, is_train) 25 | elif opt.which_model_netD == 'n_layers': 26 | parser = NLayerDiscriminator.modify_commandline_options(parser, is_train) 27 | return parser 28 | 29 | def define_G(opt): 30 | if 'fewshot' in opt.netG: 31 | netG = FewShotGenerator(opt) 32 | else: 33 | raise('generator not implemented!') 34 | if opt.isTrain and opt.print_G: netG.print_network() 35 | if len(opt.gpu_ids) > 0: 36 | assert(torch.cuda.is_available()) 37 | netG.cuda() 38 | netG.init_weights(opt.init_type, opt.init_variance) 39 | return netG 40 | 41 | def define_D(opt, input_nc, ndf, n_layers_D, norm='spectralinstance', subarch='n_layers', num_D=1, getIntermFeat=False, stride=2, gpu_ids=[]): 42 | norm_layer = get_nonspade_norm_layer(opt, norm_type=norm) 43 | if opt.which_model_netD == 'multiscale': 44 | netD = MultiscaleDiscriminator(opt, input_nc, ndf, n_layers_D, norm_layer, subarch, num_D, getIntermFeat, stride, gpu_ids) 45 | elif opt.which_model_netD == 'n_layers': 46 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer, getIntermFeat) 47 | else: 48 | raise('unknown type discriminator %s!' 
% opt.which_model_netD) 49 | 50 | if opt.isTrain and opt.print_D: netD.print_network() 51 | if len(gpu_ids) > 0: 52 | assert(torch.cuda.is_available()) 53 | netD.cuda() 54 | netD.init_weights(opt.init_type, opt.init_variance) 55 | return netD 56 | -------------------------------------------------------------------------------- /models/networks/architecture.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torch.nn.utils.spectral_norm as sn 11 | 12 | from models.networks.base_network import BaseNetwork, batch_conv 13 | from models.networks.normalization import SPADE, SynchronizedBatchNorm2d 14 | 15 | def actvn(x): 16 | out = F.leaky_relu(x, 2e-1) 17 | return out 18 | 19 | def generalConv(adaptive=False, transpose=False): 20 | class NormalConv2d(nn.Conv2d): 21 | def __init__(self, *args, **kwargs): 22 | super(NormalConv2d, self).__init__(*args, **kwargs) 23 | def forward(self, input, weight=None, bias=None, stride=1): 24 | return super(NormalConv2d, self).forward(input) 25 | class NormalConvTranspose2d(nn.ConvTranspose2d): 26 | def __init__(self, *args, output_padding=1, **kwargs): 27 | #kwargs['output_padding'] = 1 28 | super(NormalConvTranspose2d, self).__init__(*args, **kwargs) 29 | def forward(self, input, weight=None, bias=None, stride=1): 30 | return super(NormalConvTranspose2d, self).forward(input) 31 | class AdaptiveConv2d(nn.Module): 32 | def __init__(self, *args, **kwargs): 33 | super().__init__() 34 | def forward(self, input, weight=None, bias=None, stride=1): 35 | return batch_conv(input, weight, bias, stride) 36 | 37 | if adaptive: return AdaptiveConv2d 38 | return NormalConv2d if not transpose else NormalConvTranspose2d 39 | 40 | def generalNorm(norm): 41 | if 'spade' in norm: return SPADE 42 | def get_norm(norm): 43 | if 'instance' in norm: 44 | return nn.InstanceNorm2d 45 | elif 'syncbatch' in norm: 46 | return SynchronizedBatchNorm2d 47 | elif 'batch' in norm: 48 | return nn.BatchNorm2d 49 | norm = get_norm(norm) 50 | class NormalNorm(norm): 51 | def __init__(self, *args, hidden_nc=0, norm='', ks=1, params_free=False, **kwargs): 52 | super(NormalNorm, self).__init__(*args, **kwargs) 53 | def forward(self, input, label=None, weight=None): 54 | return super(NormalNorm, self).forward(input) 55 | return NormalNorm 56 | 57 | class SPADEConv2d(nn.Module): 58 | def __init__(self, fin, fout, norm='batch', hidden_nc=0, kernel_size=3, padding=1, stride=1): 59 | super().__init__() 60 | self.conv = sn(nn.Conv2d(fin, fout, kernel_size=kernel_size, stride=stride, padding=padding)) 61 | 62 | Norm = generalNorm(norm) 63 | self.bn = Norm(fout, hidden_nc=hidden_nc, norm=norm, ks=3) 64 | 65 | def forward(self, x, label=None): 66 | x = self.conv(x) 67 | out = self.bn(x, label) 68 | out = actvn(out) 69 | return out 70 | 71 | class SPADEResnetBlock(nn.Module): 72 | def __init__(self, fin, fout, norm='batch', hidden_nc=0, conv_ks=3, spade_ks=1, stride=1, conv_params_free=False, norm_params_free=False): 73 | super().__init__() 74 | fhidden = min(fin, fout) 75 | self.learned_shortcut = (fin != fout) 76 | self.stride = stride 77 | Conv2d = generalConv(adaptive=conv_params_free) 78 | sn_ = sn if not 
conv_params_free else lambda x: x 79 | 80 | # Submodules 81 | self.conv_0 = sn_(Conv2d(fin, fhidden, conv_ks, stride=stride, padding=1)) 82 | self.conv_1 = sn_(Conv2d(fhidden, fout, conv_ks, padding=1)) 83 | if self.learned_shortcut: 84 | self.conv_s = sn_(Conv2d(fin, fout, 1, stride=stride, bias=False)) 85 | 86 | Norm = generalNorm(norm) 87 | self.bn_0 = Norm(fin, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free) 88 | self.bn_1 = Norm(fhidden, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free) 89 | if self.learned_shortcut: 90 | self.bn_s = Norm(fin, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free) 91 | 92 | def forward(self, x, label=None, conv_weights=[], norm_weights=[]): 93 | if not conv_weights: conv_weights = [None]*3 94 | if not norm_weights: norm_weights = [None]*3 95 | x_s = self._shortcut(x, label, conv_weights[2], norm_weights[2]) 96 | dx = self.conv_0(actvn(self.bn_0(x, label, norm_weights[0])), conv_weights[0]) 97 | dx = self.conv_1(actvn(self.bn_1(dx, label, norm_weights[1])), conv_weights[1]) 98 | out = x_s + 1.0*dx 99 | return out 100 | 101 | def _shortcut(self, x, label, conv_weights, norm_weights): 102 | if self.learned_shortcut: 103 | x_s = self.conv_s(self.bn_s(x, label, norm_weights), conv_weights) 104 | elif self.stride != 1: 105 | x_s = nn.AvgPool2d(3, stride=2, padding=1)(x) 106 | else: 107 | x_s = x 108 | return x_s 109 | -------------------------------------------------------------------------------- /models/networks/base_network.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn as nn 9 | from torch.nn import init 10 | import torch.nn.functional as F 11 | import numpy as np 12 | 13 | def get_grid(batchsize, rows, cols, gpu_id=0): 14 | hor = torch.linspace(-1.0, 1.0, cols) 15 | hor.requires_grad = False 16 | hor = hor.view(1, 1, 1, cols) 17 | hor = hor.expand(batchsize, 1, rows, cols) 18 | ver = torch.linspace(-1.0, 1.0, rows) 19 | ver.requires_grad = False 20 | ver = ver.view(1, 1, rows, 1) 21 | ver = ver.expand(batchsize, 1, rows, cols) 22 | 23 | t_grid = torch.cat([hor, ver], 1) 24 | t_grid.requires_grad = False 25 | 26 | return t_grid.cuda(gpu_id) 27 | 28 | def resample(image, flow): 29 | b, c, h, w = image.size() 30 | grid = get_grid(b, h, w, gpu_id=flow.get_device()) 31 | flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) 32 | final_grid = (grid + flow).permute(0, 2, 3, 1).cuda(image.get_device()) 33 | try: 34 | output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border', align_corners=True) 35 | except: 36 | output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border') 37 | return output 38 | 39 | ### pick the reference image that is most similar to current frame 40 | def pick_ref(refs, ref_idx): 41 | if type(refs) == list: 42 | return [pick_ref(r, ref_idx) for r in refs] 43 | if ref_idx is None: 44 | return refs[:,0] 45 | ref_idx = ref_idx.long().view(-1, 1, 1, 1, 1) 46 | ref = refs.gather(1, ref_idx.expand_as(refs)[:,0:1])[:,0] 47 | return ref 48 | 49 | def concat(a, b, dim=0): 50 | if isinstance(a, list): 51 | return [concat(ai, bi, dim) for ai, bi in zip(a, b)] 52 | if a is None: 53 | return b 54 | return torch.cat([a, b], dim=dim) 55 | 56 | def batch_conv(x, weight, bias=None, stride=1, group_size=-1): 57 | if weight is None: return x 58 | if isinstance(weight, list) or isinstance(weight, tuple): 59 | weight, bias = weight 60 | padding = weight.size()[-1] // 2 61 | groups = group_size//weight.size()[2] if group_size != -1 else 1 62 | if bias is None: bias = [None] * x.size()[0] 63 | y = None 64 | for i in range(x.size(0)): 65 | if stride >= 1: 66 | yi = F.conv2d(x[i:i+1], weight=weight[i], bias=bias[i], padding=padding, stride=stride, groups=groups) 67 | else: 68 | yi = F.conv_transpose2d(x[i:i+1], weight=weight[i], bias=bias[i,:weight.size(2)], padding=padding, stride=int(1/stride), 69 | output_padding=1, groups=groups) 70 | y = concat(y, yi) 71 | return y 72 | 73 | class BaseNetwork(nn.Module): 74 | def __init__(self): 75 | super(BaseNetwork, self).__init__() 76 | 77 | def print_network(self): 78 | if isinstance(self, list): 79 | self = self[0] 80 | num_params = 0 81 | for param in self.parameters(): 82 | num_params += param.numel() 83 | print(self) 84 | print('Total number of parameters: %d' % num_params) 85 | 86 | def init_weights(self, init_type='normal', gain=0.02): 87 | def init_func(m): 88 | classname = m.__class__.__name__ 89 | if classname.find('BatchNorm2d') != -1: 90 | if hasattr(m, 'weight') and m.weight is not None: 91 | init.normal_(m.weight.data, 1.0, gain) 92 | if hasattr(m, 'bias') and m.bias is not None: 93 | init.constant_(m.bias.data, 0.0) 94 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1): 95 | if init_type == 'normal': 96 | init.normal_(m.weight.data, 0.0, gain) 97 | elif init_type == 'xavier': 98 | init.xavier_normal_(m.weight.data, gain=gain) 99 | 
elif init_type == 'xavier_uniform': 100 | init.xavier_uniform_(m.weight.data, gain=1.0) 101 | elif init_type == 'kaiming': 102 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 103 | elif init_type == 'orthogonal': 104 | init.orthogonal_(m.weight.data, gain=gain) 105 | elif init_type == 'none': 106 | m.reset_parameters() 107 | else: 108 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 109 | if hasattr(m, 'bias') and m.bias is not None: 110 | init.constant_(m.bias.data, 0.0) 111 | 112 | self.apply(init_func) 113 | for m in self.children(): 114 | if hasattr(m, 'init_weights'): 115 | m.init_weights(init_type, gain) 116 | 117 | def load_pretrained_net(self, net_src, net_dst): 118 | source_weights = net_src.state_dict() 119 | target_weights = net_dst.state_dict() 120 | 121 | for k, v in source_weights.items(): 122 | if k in target_weights and target_weights[k].size() == v.size(): 123 | target_weights[k] = v 124 | net_dst.load_state_dict(target_weights) 125 | 126 | def reparameterize(self, mu, logvar): 127 | std = torch.exp(0.5 * logvar) 128 | eps = torch.randn_like(std) 129 | z = eps.mul(std) + mu 130 | return z 131 | 132 | def sum(self, x): 133 | if type(x) != list: return x 134 | return sum([self.sum(xi) for xi in x]) 135 | 136 | def sum_mul(self, x): 137 | assert(type(x) == list) 138 | if type(x[0]) != list: 139 | return np.prod(x) + x[0] 140 | return [self.sum_mul(xi) for xi in x] 141 | 142 | def split_weights(self, weight, sizes): 143 | if isinstance(sizes, list): 144 | weights = [] 145 | cur_size = 0 146 | for i in range(len(sizes)): 147 | next_size = cur_size + self.sum(sizes[i]) 148 | weights.append(self.split_weights(weight[:,cur_size:next_size], sizes[i])) 149 | cur_size = next_size 150 | assert(next_size == weight.size()[1]) 151 | return weights 152 | return weight 153 | 154 | def reshape_weight(self, x, weight_size): 155 | if type(weight_size[0]) == list and type(x) != list: 156 | x = self.split_weights(x, self.sum_mul(weight_size)) 157 | if type(x) == list: 158 | return [self.reshape_weight(xi, wi) for xi, wi in zip(x, weight_size)] 159 | weight_size = [x.size()[0]] + weight_size 160 | bias_size = weight_size[1] 161 | try: 162 | weight = x[:, :-bias_size].view(weight_size) 163 | bias = x[:, -bias_size:] 164 | except: 165 | weight = x.view(weight_size) 166 | bias = None 167 | return [weight, bias] 168 | 169 | def reshape_embed_input(self, x): 170 | if isinstance(x, list): 171 | return [self.reshape_embed_input(xi) for xi in zip(x)] 172 | b, c, _, _ = x.size() 173 | x = x.view(b*c, -1) 174 | return x 175 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 2 | 3 | RUN apt-get update && apt-get install -y rsync htop git openssh-server python-pip 4 | 5 | RUN pip install --upgrade pip 6 | 7 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl 8 | RUN pip install torchvision cffi tensorboardX 9 | 10 | RUN pip install tqdm scipy scikit-image colorama==0.3.7 11 | RUN pip install setproctitle pytz ipython -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/README.md: -------------------------------------------------------------------------------- 1 | # flownet2-pytorch 2 | 3 | Pytorch implementation of [FlowNet 2.0: Evolution 
of Optical Flow Estimation with Deep Networks](https://arxiv.org/abs/1612.01925). 4 | 5 | Multiple GPU training is supported, and the code provides examples for training or inference on [MPI-Sintel](http://sintel.is.tue.mpg.de/) clean and final datasets. The same commands can be used for training or inference with other datasets. See below for more detail. 6 | 7 | Inference using fp16 (half-precision) is also supported. 8 | 9 | For more help, type
10 | 11 | python main.py --help 12 | 13 | ## Network architectures 14 | Below are the different flownet neural network architectures that are provided.
15 | A batchnorm version for each network is also available. 16 | 17 | - **FlowNet2S** 18 | - **FlowNet2C** 19 | - **FlowNet2CS** 20 | - **FlowNet2CSS** 21 | - **FlowNet2SD** 22 | - **FlowNet2** 23 | 24 | ## Custom layers 25 | 26 | `FlowNet2` or `FlowNet2C*` architectures rely on custom layers `Resample2d` or `Correlation`.
27 | A PyTorch implementation of these layers with CUDA kernels is available at [./networks](./networks).
28 | Note: currently, half-precision kernels are not available for these layers. 29 | 30 | ## Data Loaders 31 | 32 | Dataloaders for FlyingChairs, FlyingThings, ChairsSDHom, and ImagesFromFolder are available in [datasets.py](./datasets.py).
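A rough usage sketch for these dataset classes is below. The constructor signature `(args, is_cropped, root)`, the import path, and the `args` field shown are assumptions inferred from how `main.py` wires up its datasets — check `datasets.py` for the exact interface before relying on it.

```
# Hedged sketch only: ImagesFromFolder's signature and the args field are assumptions,
# not a documented API; adjust to whatever datasets.py actually expects.
import argparse
import torch.utils.data as data

from datasets import ImagesFromFolder   # assumes the flownet2-pytorch root is on PYTHONPATH

args = argparse.Namespace(crop_size=[384, 512])      # assumed field, mirroring the --crop_size flag
dataset = ImagesFromFolder(args, is_cropped=False,   # assumed constructor signature
                           root='/path/to/frames')
loader = data.DataLoader(dataset, batch_size=8, shuffle=True, num_workers=4)

for batch in loader:   # each batch has whatever structure datasets.py returns
    pass               # feed it to the model of your choice
```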
33 | 34 | ## Loss Functions 35 | 36 | L1 and L2 losses with multi-scale support are available in [losses.py](./losses.py).
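A minimal usage sketch for these loss classes, based on the `losses.py` source reproduced later in this listing: each loss takes the parsed `args` object at construction, and `forward(output, target)` returns a `[loss, EPE]` pair. The import path and tensor shapes are illustrative assumptions.

```
# Sketch based on losses.py: L1Loss stores args and returns [L1 loss, end-point error].
import argparse
import torch

from losses import L1Loss   # assumes the flownet2-pytorch root is on PYTHONPATH

args = argparse.Namespace()                  # L1Loss only stores args; no fields are read here
criterion = L1Loss(args)

pred_flow = torch.randn(8, 2, 384, 512)      # predicted flow, shape (batch, 2, H, W)
gt_flow = torch.randn(8, 2, 384, 512)        # ground-truth flow, same shape
loss, epe = criterion(pred_flow, gt_flow)    # forward() returns [loss value, end-point error]
```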
37 | 38 | ## Installation 39 | 40 | # get flownet2-pytorch source 41 | git clone https://github.com/NVIDIA/flownet2-pytorch.git 42 | cd flownet2-pytorch 43 | 44 | # install custom layers 45 | bash install.sh 46 | 47 | ### Python requirements 48 | Currently, the code supports python 3 49 | * numpy 50 | * PyTorch ( == 0.4.1, for <= 0.4.0 see branch [python36-PyTorch0.4](https://github.com/NVIDIA/flownet2-pytorch/tree/python36-PyTorch0.4)) 51 | * scipy 52 | * scikit-image 53 | * tensorboardX 54 | * colorama, tqdm, setproctitle 55 | 56 | ## Converted Caffe Pre-trained Models 57 | We've included caffe pre-trained models. Should you use these pre-trained weights, please adhere to the [license agreements](https://drive.google.com/file/d/1TVv0BnNFh3rpHZvD-easMb9jYrPE2Eqd/view?usp=sharing). 58 | 59 | * [FlowNet2](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB] 60 | * [FlowNet2-C](https://drive.google.com/file/d/1BFT6b7KgKJC8rA59RmOVAXRM_S7aSfKE/view?usp=sharing)[149MB] 61 | * [FlowNet2-CS](https://drive.google.com/file/d/1iBJ1_o7PloaINpa8m7u_7TsLCX0Dt_jS/view?usp=sharing)[297MB] 62 | * [FlowNet2-CSS](https://drive.google.com/file/d/157zuzVf4YMN6ABAQgZc8rRmR5cgWzSu8/view?usp=sharing)[445MB] 63 | * [FlowNet2-CSS-ft-sd](https://drive.google.com/file/d/1R5xafCIzJCXc8ia4TGfC65irmTNiMg6u/view?usp=sharing)[445MB] 64 | * [FlowNet2-S](https://drive.google.com/file/d/1V61dZjFomwlynwlYklJHC-TLfdFom3Lg/view?usp=sharing)[148MB] 65 | * [FlowNet2-SD](https://drive.google.com/file/d/1QW03eyYG_vD-dT-Mx4wopYvtPu_msTKn/view?usp=sharing)[173MB] 66 | 67 | ## Inference 68 | # Example on MPISintel Clean 69 | python main.py --inference --model FlowNet2 --save_flow --inference_dataset MpiSintelClean \ 70 | --inference_dataset_root /path/to/mpi-sintel/clean/dataset \ 71 | --resume /path/to/checkpoints 72 | 73 | ## Training and validation 74 | 75 | # Example on MPISintel Final and Clean, with L1Loss on FlowNet2 model 76 | python main.py --batch_size 8 --model FlowNet2 --loss=L1Loss --optimizer=Adam --optimizer_lr=1e-4 \ 77 | --training_dataset MpiSintelFinal --training_dataset_root /path/to/mpi-sintel/final/dataset \ 78 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset 79 | 80 | # Example on MPISintel Final and Clean, with MultiScale loss on FlowNet2C model 81 | python main.py --batch_size 8 --model FlowNet2C --optimizer=Adam --optimizer_lr=1e-4 --loss=MultiScale --loss_norm=L1 \ 82 | --loss_numScales=5 --loss_startScale=4 --optimizer_lr=1e-4 --crop_size 384 512 \ 83 | --training_dataset FlyingChairs --training_dataset_root /path/to/flying-chairs/dataset \ 84 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset 85 | 86 | ## Results on MPI-Sintel 87 | [![Predicted flows on MPI-Sintel](./image.png)](https://www.youtube.com/watch?v=HtBmabY8aeU "Predicted flows on MPI-Sintel") 88 | 89 | ## Reference 90 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper: 91 | ```` 92 | @InProceedings{IMKDB17, 93 | author = "E. Ilg and N. Mayer and T. Saikia and M. Keuper and A. Dosovitskiy and T. 
Brox", 94 | title = "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks", 95 | booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)", 96 | month = "Jul", 97 | year = "2017", 98 | url = "http://lmb.informatik.uni-freiburg.de//Publications/2017/IMKDB17" 99 | } 100 | ```` 101 | ``` 102 | @misc{flownet2-pytorch, 103 | author = {Fitsum Reda and Robert Pottorff and Jon Barker and Bryan Catanzaro}, 104 | title = {flownet2-pytorch: Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks}, 105 | year = {2017}, 106 | publisher = {GitHub}, 107 | journal = {GitHub repository}, 108 | howpublished = {\url{https://github.com/NVIDIA/flownet2-pytorch}} 109 | } 110 | ``` 111 | ## Related Optical Flow Work from Nvidia 112 | Code (in Caffe and Pytorch): [PWC-Net](https://github.com/NVlabs/PWC-Net)
113 | Paper : [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/abs/1709.02371). 114 | 115 | ## Acknowledgments 116 | Parts of this code were derived, as noted in the code, from [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch). 117 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | 3 | import caffe 4 | from caffe.proto import caffe_pb2 5 | import sys, os 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import argparse, tempfile 11 | import numpy as np 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('caffe_model', help='input model in hdf5 or caffemodel format') 15 | parser.add_argument('prototxt_template',help='prototxt template') 16 | parser.add_argument('flownet2_pytorch', help='path to flownet2-pytorch') 17 | 18 | args = parser.parse_args() 19 | 20 | args.rgb_max = 255 21 | args.fp16 = False 22 | args.grads = {} 23 | 24 | # load models 25 | sys.path.append(args.flownet2_pytorch) 26 | 27 | import models 28 | from utils.param_utils import * 29 | 30 | width = 256 31 | height = 256 32 | keys = {'TARGET_WIDTH': width, 33 | 'TARGET_HEIGHT': height, 34 | 'ADAPTED_WIDTH':width, 35 | 'ADAPTED_HEIGHT':height, 36 | 'SCALE_WIDTH':1., 37 | 'SCALE_HEIGHT':1.,} 38 | 39 | template = '\n'.join(np.loadtxt(args.prototxt_template, dtype=str, delimiter='\n')) 40 | for k in keys: 41 | template = template.replace('$%s$'%(k),str(keys[k])) 42 | 43 | prototxt = tempfile.NamedTemporaryFile(mode='w', delete=True) 44 | prototxt.write(template) 45 | prototxt.flush() 46 | 47 | net = caffe.Net(prototxt.name, args.caffe_model, caffe.TEST) 48 | 49 | weights = {} 50 | biases = {} 51 | 52 | for k, v in list(net.params.items()): 53 | weights[k] = np.array(v[0].data).reshape(v[0].data.shape) 54 | biases[k] = np.array(v[1].data).reshape(v[1].data.shape) 55 | print((k, weights[k].shape, biases[k].shape)) 56 | 57 | if 'FlowNet2/' in args.caffe_model: 58 | model = models.FlowNet2(args) 59 | 60 | parse_flownetc(model.flownetc.modules(), weights, biases) 61 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 62 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 63 | parse_flownetsd(model.flownets_d.modules(), weights, biases, param_prefix='netsd_') 64 | parse_flownetfusion(model.flownetfusion.modules(), weights, biases, param_prefix='fuse_') 65 | 66 | state = {'epoch': 0, 67 | 'state_dict': model.state_dict(), 68 | 'best_EPE': 1e10} 69 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2_checkpoint.pth.tar')) 70 | 71 | elif 'FlowNet2-C/' in args.caffe_model: 72 | model = models.FlowNet2C(args) 73 | 74 | parse_flownetc(model.modules(), weights, biases) 75 | state = {'epoch': 0, 76 | 'state_dict': model.state_dict(), 77 | 'best_EPE': 1e10} 78 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-C_checkpoint.pth.tar')) 79 | 80 | elif 'FlowNet2-CS/' in args.caffe_model: 81 | model = models.FlowNet2CS(args) 82 | 
83 | parse_flownetc(model.flownetc.modules(), weights, biases) 84 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 85 | 86 | state = {'epoch': 0, 87 | 'state_dict': model.state_dict(), 88 | 'best_EPE': 1e10} 89 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CS_checkpoint.pth.tar')) 90 | 91 | elif 'FlowNet2-CSS/' in args.caffe_model: 92 | model = models.FlowNet2CSS(args) 93 | 94 | parse_flownetc(model.flownetc.modules(), weights, biases) 95 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 96 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 97 | 98 | state = {'epoch': 0, 99 | 'state_dict': model.state_dict(), 100 | 'best_EPE': 1e10} 101 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS_checkpoint.pth.tar')) 102 | 103 | elif 'FlowNet2-CSS-ft-sd/' in args.caffe_model: 104 | model = models.FlowNet2CSS(args) 105 | 106 | parse_flownetc(model.flownetc.modules(), weights, biases) 107 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 108 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 109 | 110 | state = {'epoch': 0, 111 | 'state_dict': model.state_dict(), 112 | 'best_EPE': 1e10} 113 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS-ft-sd_checkpoint.pth.tar')) 114 | 115 | elif 'FlowNet2-S/' in args.caffe_model: 116 | model = models.FlowNet2S(args) 117 | 118 | parse_flownetsonly(model.modules(), weights, biases, param_prefix='') 119 | state = {'epoch': 0, 120 | 'state_dict': model.state_dict(), 121 | 'best_EPE': 1e10} 122 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-S_checkpoint.pth.tar')) 123 | 124 | elif 'FlowNet2-SD/' in args.caffe_model: 125 | model = models.FlowNet2SD(args) 126 | 127 | parse_flownetsd(model.modules(), weights, biases, param_prefix='') 128 | 129 | state = {'epoch': 0, 130 | 'state_dict': model.state_dict(), 131 | 'best_EPE': 1e10} 132 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-SD_checkpoint.pth.tar')) 133 | 134 | else: 135 | print(('model type cound not be determined from input caffe model %s'%(args.caffe_model))) 136 | quit() 137 | print(("done converting ", args.caffe_model)) -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ./networks/correlation_package 3 | python setup.py install --user 4 | cd ../resample2d_package 5 | python setup.py install --user 6 | cd ../channelnorm_package 7 | python setup.py install --user 8 | cd .. 9 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/launch_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 . 
3 | sudo nvidia-docker run --rm -ti --volume=$(pwd):/flownet2-pytorch:rw --workdir=/flownet2-pytorch --ipc=host $USER/pytorch:CUDA8-py27 /bin/bash 4 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/losses.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | # freda (todo) : adversarial loss 6 | 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | 11 | def EPE(input_flow, target_flow): 12 | return torch.norm(target_flow-input_flow,p=2,dim=1).mean() 13 | 14 | class L1(nn.Module): 15 | def __init__(self): 16 | super(L1, self).__init__() 17 | def forward(self, output, target): 18 | lossvalue = torch.abs(output - target).mean() 19 | return lossvalue 20 | 21 | class L2(nn.Module): 22 | def __init__(self): 23 | super(L2, self).__init__() 24 | def forward(self, output, target): 25 | lossvalue = torch.norm(output-target,p=2,dim=1).mean() 26 | return lossvalue 27 | 28 | class L1Loss(nn.Module): 29 | def __init__(self, args): 30 | super(L1Loss, self).__init__() 31 | self.args = args 32 | self.loss = L1() 33 | self.loss_labels = ['L1', 'EPE'] 34 | 35 | def forward(self, output, target): 36 | lossvalue = self.loss(output, target) 37 | epevalue = EPE(output, target) 38 | return [lossvalue, epevalue] 39 | 40 | class L2Loss(nn.Module): 41 | def __init__(self, args): 42 | super(L2Loss, self).__init__() 43 | self.args = args 44 | self.loss = L2() 45 | self.loss_labels = ['L2', 'EPE'] 46 | 47 | def forward(self, output, target): 48 | lossvalue = self.loss(output, target) 49 | epevalue = EPE(output, target) 50 | return [lossvalue, epevalue] 51 | 52 | class MultiScale(nn.Module): 53 | def __init__(self, args, startScale = 4, numScales = 5, l_weight= 0.32, norm= 'L1'): 54 | super(MultiScale,self).__init__() 55 | 56 | self.startScale = startScale 57 | self.numScales = numScales 58 | self.loss_weights = torch.FloatTensor([(l_weight / 2 ** scale) for scale in range(self.numScales)]) 59 | self.args = args 60 | self.l_type = norm 61 | self.div_flow = 0.05 62 | assert(len(self.loss_weights) == self.numScales) 63 | 64 | if self.l_type == 'L1': 65 | self.loss = L1() 66 | else: 67 | self.loss = L2() 68 | 69 | self.multiScales = [nn.AvgPool2d(self.startScale * (2**scale), self.startScale * (2**scale)) for scale in range(self.numScales)] 70 | self.loss_labels = ['MultiScale-'+self.l_type, 'EPE'], 71 | 72 | def forward(self, output, target): 73 | lossvalue = 0 74 | epevalue = 0 75 | 76 | if type(output) is tuple: 77 | target = self.div_flow * target 78 | for i, output_ in enumerate(output): 79 | target_ = self.multiScales[i](target) 80 | epevalue += self.loss_weights[i]*EPE(output_, target_) 81 | lossvalue += self.loss_weights[i]*self.loss(output_, target_) 82 | return [lossvalue, epevalue] 83 | else: 84 | epevalue += EPE(output, target) 85 | lossvalue += self.loss(output, target) 86 | return [lossvalue, epevalue] 87 | 88 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/FlowNetC.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .correlation_package.correlation import Correlation 9 | 10 | from .submodules import * 11 | 'Parameter count , 39,175,298 ' 12 | 13 | class FlowNetC(nn.Module): 14 | def 
__init__(self,args, batchNorm=True, div_flow = 20): 15 | super(FlowNetC,self).__init__() 16 | 17 | self.batchNorm = batchNorm 18 | self.div_flow = div_flow 19 | 20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 24 | 25 | if args.fp16: 26 | self.corr = nn.Sequential( 27 | tofp32(), 28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), 29 | tofp16()) 30 | else: 31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 32 | 33 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 34 | self.conv3_1 = conv(self.batchNorm, 473, 256) 35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 36 | self.conv4_1 = conv(self.batchNorm, 512, 512) 37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 38 | self.conv5_1 = conv(self.batchNorm, 512, 512) 39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 40 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 41 | 42 | self.deconv5 = deconv(1024,512) 43 | self.deconv4 = deconv(1026,256) 44 | self.deconv3 = deconv(770,128) 45 | self.deconv2 = deconv(386,64) 46 | 47 | self.predict_flow6 = predict_flow(1024) 48 | self.predict_flow5 = predict_flow(1026) 49 | self.predict_flow4 = predict_flow(770) 50 | self.predict_flow3 = predict_flow(386) 51 | self.predict_flow2 = predict_flow(194) 52 | 53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | if m.bias is not None: 61 | init.uniform_(m.bias) 62 | init.xavier_uniform_(m.weight) 63 | 64 | if isinstance(m, nn.ConvTranspose2d): 65 | if m.bias is not None: 66 | init.uniform_(m.bias) 67 | init.xavier_uniform_(m.weight) 68 | # init_deconv_bilinear(m.weight) 69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 70 | 71 | def forward(self, x): 72 | x1 = x[:,0:3,:,:] 73 | x2 = x[:,3::,:,:] 74 | 75 | out_conv1a = self.conv1(x1) 76 | out_conv2a = self.conv2(out_conv1a) 77 | out_conv3a = self.conv3(out_conv2a) 78 | 79 | # FlownetC bottom input stream 80 | out_conv1b = self.conv1(x2) 81 | 82 | out_conv2b = self.conv2(out_conv1b) 83 | out_conv3b = self.conv3(out_conv2b) 84 | 85 | # Merge streams 86 | out_corr = self.corr(out_conv3a, out_conv3b) # False 87 | out_corr = self.corr_activation(out_corr) 88 | 89 | # Redirect top input stream and concatenate 90 | out_conv_redir = self.conv_redir(out_conv3a) 91 | 92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) 93 | 94 | # Merged conv layers 95 | out_conv3_1 = self.conv3_1(in_conv3_1) 96 | 97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) 98 | 99 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 100 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 101 | 102 | flow6 = self.predict_flow6(out_conv6) 103 | flow6_up = self.upsampled_flow6_to_5(flow6) 104 | out_deconv5 = self.deconv5(out_conv6) 105 | 106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 107 | 108 | flow5 = self.predict_flow5(concat5) 109 | flow5_up = self.upsampled_flow5_to_4(flow5) 110 | out_deconv4 
= self.deconv4(concat5) 111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 112 | 113 | flow4 = self.predict_flow4(concat4) 114 | flow4_up = self.upsampled_flow4_to_3(flow4) 115 | out_deconv3 = self.deconv3(concat4) 116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) 117 | 118 | flow3 = self.predict_flow3(concat3) 119 | flow3_up = self.upsampled_flow3_to_2(flow3) 120 | out_deconv2 = self.deconv2(concat3) 121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) 122 | 123 | flow2 = self.predict_flow2(concat2) 124 | 125 | if self.training: 126 | return flow2,flow3,flow4,flow5,flow6 127 | else: 128 | return flow2, 129 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/FlowNetFusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 581,226' 10 | 11 | class FlowNetFusion(nn.Module): 12 | def __init__(self,args, batchNorm=True): 13 | super(FlowNetFusion,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 11, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | 22 | self.deconv1 = deconv(128,32) 23 | self.deconv0 = deconv(162,16) 24 | 25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32) 26 | self.inter_conv0 = i_conv(self.batchNorm, 82, 16) 27 | 28 | self.predict_flow2 = predict_flow(128) 29 | self.predict_flow1 = predict_flow(32) 30 | self.predict_flow0 = predict_flow(16) 31 | 32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | if m.bias is not None: 38 | init.uniform_(m.bias) 39 | init.xavier_uniform_(m.weight) 40 | 41 | if isinstance(m, nn.ConvTranspose2d): 42 | if m.bias is not None: 43 | init.uniform_(m.bias) 44 | init.xavier_uniform_(m.weight) 45 | # init_deconv_bilinear(m.weight) 46 | 47 | def forward(self, x): 48 | out_conv0 = self.conv0(x) 49 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 50 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 51 | 52 | flow2 = self.predict_flow2(out_conv2) 53 | flow2_up = self.upsampled_flow2_to_1(flow2) 54 | out_deconv1 = self.deconv1(out_conv2) 55 | 56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1) 57 | out_interconv1 = self.inter_conv1(concat1) 58 | flow1 = self.predict_flow1(out_interconv1) 59 | flow1_up = self.upsampled_flow1_to_0(flow1) 60 | out_deconv0 = self.deconv0(concat1) 61 | 62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1) 63 | out_interconv0 = self.inter_conv0(concat0) 64 | flow0 = self.predict_flow0(out_interconv0) 65 | 66 | return flow0 67 | 68 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/FlowNetS.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | import math 10 | import numpy as np 11 | 12 | from .submodules import * 13 | 'Parameter count : 38,676,504 ' 14 | 15 | class 
FlowNetS(nn.Module): 16 | def __init__(self, args, input_channels = 12, batchNorm=True): 17 | super(FlowNetS,self).__init__() 18 | 19 | self.batchNorm = batchNorm 20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv3_1 = conv(self.batchNorm, 256, 256) 24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 25 | self.conv4_1 = conv(self.batchNorm, 512, 512) 26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 27 | self.conv5_1 = conv(self.batchNorm, 512, 512) 28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 29 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 30 | 31 | self.deconv5 = deconv(1024,512) 32 | self.deconv4 = deconv(1026,256) 33 | self.deconv3 = deconv(770,128) 34 | self.deconv2 = deconv(386,64) 35 | 36 | self.predict_flow6 = predict_flow(1024) 37 | self.predict_flow5 = predict_flow(1026) 38 | self.predict_flow4 = predict_flow(770) 39 | self.predict_flow3 = predict_flow(386) 40 | self.predict_flow2 = predict_flow(194) 41 | 42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 44 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | if m.bias is not None: 50 | init.uniform_(m.bias) 51 | init.xavier_uniform_(m.weight) 52 | 53 | if isinstance(m, nn.ConvTranspose2d): 54 | if m.bias is not None: 55 | init.uniform_(m.bias) 56 | init.xavier_uniform_(m.weight) 57 | # init_deconv_bilinear(m.weight) 58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 59 | 60 | def forward(self, x): 61 | out_conv1 = self.conv1(x) 62 | 63 | out_conv2 = self.conv2(out_conv1) 64 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 65 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 66 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 67 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 68 | 69 | flow6 = self.predict_flow6(out_conv6) 70 | flow6_up = self.upsampled_flow6_to_5(flow6) 71 | out_deconv5 = self.deconv5(out_conv6) 72 | 73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 74 | flow5 = self.predict_flow5(concat5) 75 | flow5_up = self.upsampled_flow5_to_4(flow5) 76 | out_deconv4 = self.deconv4(concat5) 77 | 78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 79 | flow4 = self.predict_flow4(concat4) 80 | flow4_up = self.upsampled_flow4_to_3(flow4) 81 | out_deconv3 = self.deconv3(concat4) 82 | 83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 84 | flow3 = self.predict_flow3(concat3) 85 | flow3_up = self.upsampled_flow3_to_2(flow3) 86 | out_deconv2 = self.deconv2(concat3) 87 | 88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 89 | flow2 = self.predict_flow2(concat2) 90 | 91 | if self.training: 92 | return flow2,flow3,flow4,flow5,flow6 93 | else: 94 | return flow2, 95 | 96 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/FlowNetSD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 45,371,666' 10 | 
11 | class FlowNetSD(nn.Module): 12 | def __init__(self, args, batchNorm=True): 13 | super(FlowNetSD,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 6, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2) 22 | self.conv3_1 = conv(self.batchNorm, 256, 256) 23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 24 | self.conv4_1 = conv(self.batchNorm, 512, 512) 25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 26 | self.conv5_1 = conv(self.batchNorm, 512, 512) 27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 28 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 29 | 30 | self.deconv5 = deconv(1024,512) 31 | self.deconv4 = deconv(1026,256) 32 | self.deconv3 = deconv(770,128) 33 | self.deconv2 = deconv(386,64) 34 | 35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512) 36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256) 37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128) 38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64) 39 | 40 | self.predict_flow6 = predict_flow(1024) 41 | self.predict_flow5 = predict_flow(512) 42 | self.predict_flow4 = predict_flow(256) 43 | self.predict_flow3 = predict_flow(128) 44 | self.predict_flow2 = predict_flow(64) 45 | 46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 50 | 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | if m.bias is not None: 54 | init.uniform_(m.bias) 55 | init.xavier_uniform_(m.weight) 56 | 57 | if isinstance(m, nn.ConvTranspose2d): 58 | if m.bias is not None: 59 | init.uniform_(m.bias) 60 | init.xavier_uniform_(m.weight) 61 | # init_deconv_bilinear(m.weight) 62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 63 | 64 | 65 | 66 | def forward(self, x): 67 | out_conv0 = self.conv0(x) 68 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 69 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 70 | 71 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 72 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 73 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 74 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 75 | 76 | flow6 = self.predict_flow6(out_conv6) 77 | flow6_up = self.upsampled_flow6_to_5(flow6) 78 | out_deconv5 = self.deconv5(out_conv6) 79 | 80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 81 | out_interconv5 = self.inter_conv5(concat5) 82 | flow5 = self.predict_flow5(out_interconv5) 83 | 84 | flow5_up = self.upsampled_flow5_to_4(flow5) 85 | out_deconv4 = self.deconv4(concat5) 86 | 87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 88 | out_interconv4 = self.inter_conv4(concat4) 89 | flow4 = self.predict_flow4(out_interconv4) 90 | flow4_up = self.upsampled_flow4_to_3(flow4) 91 | out_deconv3 = self.deconv3(concat4) 92 | 93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 94 | out_interconv3 = self.inter_conv3(concat3) 95 | flow3 = self.predict_flow3(out_interconv3) 96 | flow3_up = self.upsampled_flow3_to_2(flow3) 97 | out_deconv2 = self.deconv2(concat3) 98 | 99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 100 | out_interconv2 = 
self.inter_conv2(concat2) 101 | flow2 = self.predict_flow2(out_interconv2) 102 | 103 | if self.training: 104 | return flow2,flow3,flow4,flow5,flow6 105 | else: 106 | return flow2, 107 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/channelnorm_package/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from torch.nn.modules.module import Module 3 | import channelnorm_cuda 4 | 5 | class ChannelNormFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, norm_deg=2): 9 | assert input1.is_contiguous() 10 | b, _, h, w = input1.size() 11 | output = input1.new(b, 1, h, w).zero_() 12 | 13 | channelnorm_cuda.forward(input1, output, norm_deg) 14 | ctx.save_for_backward(input1, output) 15 | ctx.norm_deg = norm_deg 16 | 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input1, output = ctx.saved_tensors 22 | 23 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 24 | 25 | channelnorm.backward(input1, output, grad_output.data, 26 | grad_input1.data, ctx.norm_deg) 27 | 28 | return grad_input1, None 29 | 30 | 31 | class ChannelNorm(Module): 32 | 33 | def __init__(self, norm_deg=2): 34 | super(ChannelNorm, self).__init__() 35 | self.norm_deg = norm_deg 36 | 37 | def forward(self, input1): 38 | return ChannelNormFunction.apply(input1, self.norm_deg) 39 | 40 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "channelnorm_kernel.cuh" 5 | 6 | int channelnorm_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& output, 9 | int norm_deg) { 10 | 11 | channelnorm_kernel_forward(input1, output, norm_deg); 12 | return 1; 13 | } 14 | 15 | 16 | int channelnorm_cuda_backward( 17 | at::Tensor& input1, 18 | at::Tensor& output, 19 | at::Tensor& gradOutput, 20 | at::Tensor& gradInput1, 21 | int norm_deg) { 22 | 23 | channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); 24 | return 1; 25 | } 26 | 27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 28 | m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); 29 | m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); 30 | } 31 | 32 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_kernel.cu: -------------------------------------------------------------------------------- 
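// ChannelNorm kernels: the forward kernel below reduces over the channel dimension and
// writes output[b, 0, y, x] = sqrt( sum_c input1[b, c, y, x]^2 ), i.e. a per-pixel L2
// norm of the (B, C, H, W) input; norm_deg is passed through but only the L2 case is
// implemented. The backward kernel applies the matching gradient
// gradInput = gradOutput * input1 / (output + 1e-9).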
1 | #include 2 | #include 3 | #include 4 | 5 | #include "channelnorm_kernel.cuh" 6 | 7 | #define CUDA_NUM_THREADS 512 8 | 9 | #define DIM0(TENSOR) ((TENSOR).x) 10 | #define DIM1(TENSOR) ((TENSOR).y) 11 | #define DIM2(TENSOR) ((TENSOR).z) 12 | #define DIM3(TENSOR) ((TENSOR).w) 13 | 14 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 15 | 16 | using at::Half; 17 | 18 | template 19 | __global__ void kernel_channelnorm_update_output( 20 | const int n, 21 | const scalar_t* __restrict__ input1, 22 | const long4 input1_size, 23 | const long4 input1_stride, 24 | scalar_t* __restrict__ output, 25 | const long4 output_size, 26 | const long4 output_stride, 27 | int norm_deg) { 28 | 29 | int index = blockIdx.x * blockDim.x + threadIdx.x; 30 | 31 | if (index >= n) { 32 | return; 33 | } 34 | 35 | int dim_b = DIM0(output_size); 36 | int dim_c = DIM1(output_size); 37 | int dim_h = DIM2(output_size); 38 | int dim_w = DIM3(output_size); 39 | int dim_chw = dim_c * dim_h * dim_w; 40 | 41 | int b = ( index / dim_chw ) % dim_b; 42 | int y = ( index / dim_w ) % dim_h; 43 | int x = ( index ) % dim_w; 44 | 45 | int i1dim_c = DIM1(input1_size); 46 | int i1dim_h = DIM2(input1_size); 47 | int i1dim_w = DIM3(input1_size); 48 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; 49 | int i1dim_hw = i1dim_h * i1dim_w; 50 | 51 | float result = 0.0; 52 | 53 | for (int c = 0; c < i1dim_c; ++c) { 54 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; 55 | scalar_t val = input1[i1Index]; 56 | result += static_cast(val * val); 57 | } 58 | result = sqrt(result); 59 | output[index] = static_cast(result); 60 | } 61 | 62 | 63 | template 64 | __global__ void kernel_channelnorm_backward_input1( 65 | const int n, 66 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 67 | const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, 68 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 69 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, 70 | int norm_deg) { 71 | 72 | int index = blockIdx.x * blockDim.x + threadIdx.x; 73 | 74 | if (index >= n) { 75 | return; 76 | } 77 | 78 | float val = 0.0; 79 | 80 | int dim_b = DIM0(gradInput_size); 81 | int dim_c = DIM1(gradInput_size); 82 | int dim_h = DIM2(gradInput_size); 83 | int dim_w = DIM3(gradInput_size); 84 | int dim_chw = dim_c * dim_h * dim_w; 85 | int dim_hw = dim_h * dim_w; 86 | 87 | int b = ( index / dim_chw ) % dim_b; 88 | int y = ( index / dim_w ) % dim_h; 89 | int x = ( index ) % dim_w; 90 | 91 | 92 | int outIndex = b * dim_hw + y * dim_w + x; 93 | val = static_cast(gradOutput[outIndex]) * static_cast(input1[index]) / (static_cast(output[outIndex])+1e-9); 94 | gradInput[index] = static_cast(val); 95 | 96 | } 97 | 98 | void channelnorm_kernel_forward( 99 | at::Tensor& input1, 100 | at::Tensor& output, 101 | int norm_deg) { 102 | 103 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 104 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 105 | 106 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 107 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 108 | 109 | 
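// one thread per element of the (B, 1, H, W) output; the launch below uses
// ceil(n / CUDA_NUM_THREADS) blocks of CUDA_NUM_THREADS threads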
int n = output.numel(); 110 | 111 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { 112 | 113 | kernel_channelnorm_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 114 | //at::globalContext().getCurrentCUDAStream() >>>( 115 | n, 116 | input1.data(), 117 | input1_size, 118 | input1_stride, 119 | output.data(), 120 | output_size, 121 | output_stride, 122 | norm_deg); 123 | 124 | })); 125 | 126 | // TODO: ATen-equivalent check 127 | 128 | // THCudaCheck(cudaGetLastError()); 129 | } 130 | 131 | void channelnorm_kernel_backward( 132 | at::Tensor& input1, 133 | at::Tensor& output, 134 | at::Tensor& gradOutput, 135 | at::Tensor& gradInput1, 136 | int norm_deg) { 137 | 138 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 139 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 140 | 141 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 142 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 143 | 144 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 145 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 146 | 147 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 148 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 149 | 150 | int n = gradInput1.numel(); 151 | 152 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { 153 | 154 | kernel_channelnorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>( 155 | //at::globalContext().getCurrentCUDAStream() >>>( 156 | n, 157 | input1.data(), 158 | input1_size, 159 | input1_stride, 160 | output.data(), 161 | output_size, 162 | output_stride, 163 | gradOutput.data(), 164 | gradOutput_size, 165 | gradOutput_stride, 166 | gradInput1.data(), 167 | gradInput1_size, 168 | gradInput1_stride, 169 | norm_deg 170 | ); 171 | 172 | })); 173 | 174 | // TODO: Add ATen-equivalent check 175 | 176 | // THCudaCheck(cudaGetLastError()); 177 | } 178 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void channelnorm_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& output, 8 | int norm_deg); 9 | 10 | 11 | void channelnorm_kernel_backward( 12 | at::Tensor& input1, 13 | at::Tensor& output, 14 | at::Tensor& gradOutput, 15 | at::Tensor& gradInput1, 16 | int norm_deg); 17 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/channelnorm_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = 
['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_52,code=sm_52', 12 | '-gencode', 'arch=compute_60,code=sm_60', 13 | '-gencode', 'arch=compute_61,code=sm_61', 14 | '-gencode', 'arch=compute_70,code=sm_70', 15 | '-gencode', 'arch=compute_70,code=compute_70' 16 | ] 17 | 18 | setup( 19 | name='channelnorm_cuda', 20 | ext_modules=[ 21 | CUDAExtension('channelnorm_cuda', [ 22 | 'channelnorm_cuda.cc', 23 | 'channelnorm_kernel.cu' 24 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 25 | ], 26 | cmdclass={ 27 | 'build_ext': BuildExtension 28 | }) 29 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/correlation_package/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 9 | super(CorrelationFunction, self).__init__() 10 | self.pad_size = pad_size 11 | self.kernel_size = kernel_size 12 | self.max_displacement = max_displacement 13 | self.stride1 = stride1 14 | self.stride2 = stride2 15 | self.corr_multiply = corr_multiply 16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 17 | 18 | @staticmethod 19 | def forward(ctx, input1, input2, pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply): 20 | ctx.save_for_backward(input1, input2) 21 | ctx.pad_size = pad_size 22 | ctx.kernel_size = kernel_size 23 | ctx.max_displacement = max_displacement 24 | ctx.stride1 = stride1 25 | ctx.stride2 = stride2 26 | ctx.corr_multiply = corr_multiply 27 | 28 | with torch.cuda.device_of(input1): 29 | rbot1 = input1.new() 30 | rbot2 = input2.new() 31 | output = input1.new() 32 | 33 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 34 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply) 35 | 36 | return output 37 | 38 | @staticmethod 39 | def backward(ctx, grad_output): 40 | input1, input2 = ctx.saved_tensors 41 | 42 | with torch.cuda.device_of(input1): 43 | rbot1 = input1.new() 44 | rbot2 = input2.new() 45 | 46 | grad_input1 = input1.new() 47 | grad_input2 = input2.new() 48 | 49 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 50 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply) 51 | 52 | return grad_input1, grad_input2 53 | 54 | 55 | class Correlation(Module): 56 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 57 | super(Correlation, self).__init__() 58 | self.pad_size = pad_size 59 | self.kernel_size = kernel_size 60 | self.max_displacement = max_displacement 61 | self.stride1 = stride1 62 | self.stride2 = stride2 63 | self.corr_multiply = corr_multiply 64 | 65 | def forward(self, input1, input2): 66 | 
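# Builds a correlation cost volume: each spatial location of input1 is compared
# with displacements of input2 up to +/- max_displacement (sampled every stride2
# pixels), giving ((max_displacement/stride2)*2 + 1)**2 output channels per
# position (see the out_channel comment above and correlation_cuda.cc).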
#result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,\ 67 | # self.stride1, self.stride2, self.corr_multiply)(input1, input2) 68 | result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,\ 69 | self.stride1, self.stride2, self.corr_multiply) 70 | return result 71 | 72 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | #include 6 | #include 7 | 8 | #include "correlation_cuda_kernel.cuh" 9 | 10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 11 | int pad_size, 12 | int kernel_size, 13 | int max_displacement, 14 | int stride1, 15 | int stride2, 16 | int corr_type_multiply) 17 | { 18 | 19 | int batchSize = input1.size(0); 20 | 21 | int nInputChannels = input1.size(1); 22 | int inputHeight = input1.size(2); 23 | int inputWidth = input1.size(3); 24 | 25 | int kernel_radius = (kernel_size - 1) / 2; 26 | int border_radius = kernel_radius + max_displacement; 27 | 28 | int paddedInputHeight = inputHeight + 2 * pad_size; 29 | int paddedInputWidth = inputWidth + 2 * pad_size; 30 | 31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 32 | 33 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 34 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 35 | 36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 39 | 40 | rInput1.fill_(0); 41 | rInput2.fill_(0); 42 | output.fill_(0); 43 | 44 | int success = correlation_forward_cuda_kernel( 45 | output, 46 | output.size(0), 47 | output.size(1), 48 | output.size(2), 49 | output.size(3), 50 | output.stride(0), 51 | output.stride(1), 52 | output.stride(2), 53 | output.stride(3), 54 | input1, 55 | input1.size(1), 56 | input1.size(2), 57 | input1.size(3), 58 | input1.stride(0), 59 | input1.stride(1), 60 | input1.stride(2), 61 | input1.stride(3), 62 | input2, 63 | input2.size(1), 64 | input2.stride(0), 65 | input2.stride(1), 66 | input2.stride(2), 67 | input2.stride(3), 68 | rInput1, 69 | rInput2, 70 | pad_size, 71 | kernel_size, 72 | max_displacement, 73 | stride1, 74 | stride2, 75 | corr_type_multiply, 76 | at::cuda::getCurrentCUDAStream() 77 | //at::globalContext().getCurrentCUDAStream() 78 | ); 79 | 80 | //check for errors 81 | if (!success) { 82 | AT_ERROR("CUDA call failed"); 83 | } 84 | 85 | return 1; 86 | 87 | } 88 | 89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 90 | at::Tensor& gradInput1, at::Tensor& gradInput2, 91 | int pad_size, 92 | int kernel_size, 93 | int max_displacement, 94 | int stride1, 95 | int stride2, 96 | int corr_type_multiply) 97 | { 98 | 99 | int batchSize = input1.size(0); 100 | int nInputChannels = input1.size(1); 101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 103 | 104 | int height = input1.size(2); 105 | int width = 
input1.size(3); 106 | 107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 109 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 110 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 111 | 112 | rInput1.fill_(0); 113 | rInput2.fill_(0); 114 | gradInput1.fill_(0); 115 | gradInput2.fill_(0); 116 | 117 | int success = correlation_backward_cuda_kernel(gradOutput, 118 | gradOutput.size(0), 119 | gradOutput.size(1), 120 | gradOutput.size(2), 121 | gradOutput.size(3), 122 | gradOutput.stride(0), 123 | gradOutput.stride(1), 124 | gradOutput.stride(2), 125 | gradOutput.stride(3), 126 | input1, 127 | input1.size(1), 128 | input1.size(2), 129 | input1.size(3), 130 | input1.stride(0), 131 | input1.stride(1), 132 | input1.stride(2), 133 | input1.stride(3), 134 | input2, 135 | input2.stride(0), 136 | input2.stride(1), 137 | input2.stride(2), 138 | input2.stride(3), 139 | gradInput1, 140 | gradInput1.stride(0), 141 | gradInput1.stride(1), 142 | gradInput1.stride(2), 143 | gradInput1.stride(3), 144 | gradInput2, 145 | gradInput2.size(1), 146 | gradInput2.stride(0), 147 | gradInput2.stride(1), 148 | gradInput2.stride(2), 149 | gradInput2.stride(3), 150 | rInput1, 151 | rInput2, 152 | pad_size, 153 | kernel_size, 154 | max_displacement, 155 | stride1, 156 | stride2, 157 | corr_type_multiply, 158 | at::cuda::getCurrentCUDAStream() 159 | //at::globalContext().getCurrentCUDAStream() 160 | ); 161 | 162 | if (!success) { 163 | AT_ERROR("CUDA call failed"); 164 | } 165 | 166 | return 1; 167 | } 168 | 169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 172 | } 173 | 174 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 
88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/resample2d_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/resample2d_package/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/resample2d_package/resample2d.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.autograd import Function, Variable 3 | import resample2d_cuda 4 | 5 | class Resample2dFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, input2, kernel_size=1): 9 | assert input1.is_contiguous() 10 | assert input2.is_contiguous() 11 | 12 | ctx.save_for_backward(input1, input2) 13 | ctx.kernel_size = kernel_size 14 | 15 | _, d, _, _ = input1.size() 16 | b, _, h, w = input2.size() 17 | output = input1.new(b, d, h, w).zero_() 18 | 19 | resample2d_cuda.forward(input1, input2, output, kernel_size) 20 | 21 | return output 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | assert grad_output.is_contiguous() 26 | 27 | input1, input2 = ctx.saved_tensors 28 | 29 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 30 | grad_input2 = Variable(input1.new(input2.size()).zero_()) 31 | 32 | resample2d_cuda.backward(input1, input2, grad_output.data, 33 | grad_input1.data, grad_input2.data, 34 | ctx.kernel_size) 35 | 36 | return grad_input1, grad_input2, None 37 | 38 | class Resample2d(Module): 39 | 40 | def __init__(self, kernel_size=1): 41 | super(Resample2d, self).__init__() 42 | self.kernel_size = kernel_size 43 | 44 | def forward(self, input1, input2): 45 | input1_c = input1.contiguous() 46 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size) 47 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/resample2d_package/resample2d_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "resample2d_kernel.cuh" 5 | 6 | int resample2d_cuda_forward( 7 | 
at::Tensor& input1, 8 | at::Tensor& input2, 9 | at::Tensor& output, 10 | int kernel_size) { 11 | resample2d_kernel_forward(input1, input2, output, kernel_size); 12 | return 1; 13 | } 14 | 15 | int resample2d_cuda_backward( 16 | at::Tensor& input1, 17 | at::Tensor& input2, 18 | at::Tensor& gradOutput, 19 | at::Tensor& gradInput1, 20 | at::Tensor& gradInput2, 21 | int kernel_size) { 22 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size); 23 | return 1; 24 | } 25 | 26 | 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); 30 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)"); 31 | } 32 | 33 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/resample2d_package/resample2d_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void resample2d_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& input2, 8 | at::Tensor& output, 9 | int kernel_size); 10 | 11 | void resample2d_kernel_backward( 12 | at::Tensor& input1, 13 | at::Tensor& input2, 14 | at::Tensor& gradOutput, 15 | at::Tensor& gradInput1, 16 | at::Tensor& gradInput2, 17 | int kernel_size); 18 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/resample2d_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='resample2d_cuda', 21 | ext_modules=[ 22 | CUDAExtension('resample2d_cuda', [ 23 | 'resample2d_cuda.cc', 24 | 'resample2d_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/networks/submodules.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | 7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 8 | if batchNorm: 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), 11 | nn.BatchNorm2d(out_planes), 12 | nn.LeakyReLU(0.1,inplace=True) 13 | ) 14 | else: 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 17 | nn.LeakyReLU(0.1,inplace=True) 18 | ) 19 | 20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): 21 | if batchNorm: 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | else: 27 | 
return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 29 | ) 30 | 31 | def predict_flow(in_planes): 32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 33 | 34 | def deconv(in_planes, out_planes): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 37 | nn.LeakyReLU(0.1,inplace=True) 38 | ) 39 | 40 | class tofp16(nn.Module): 41 | def __init__(self): 42 | super(tofp16, self).__init__() 43 | 44 | def forward(self, input): 45 | return input.half() 46 | 47 | 48 | class tofp32(nn.Module): 49 | def __init__(self): 50 | super(tofp32, self).__init__() 51 | 52 | def forward(self, input): 53 | return input.float() 54 | 55 | 56 | def init_deconv_bilinear(weight): 57 | f_shape = weight.size() 58 | heigh, width = f_shape[-2], f_shape[-1] 59 | f = np.ceil(width/2.0) 60 | c = (2 * f - 1 - f % 2) / (2.0 * f) 61 | bilinear = np.zeros([heigh, width]) 62 | for x in range(width): 63 | for y in range(heigh): 64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) 65 | bilinear[x, y] = value 66 | weight.data.fill_(0.) 67 | for i in range(f_shape[0]): 68 | for j in range(f_shape[1]): 69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear) 70 | 71 | 72 | def save_grad(grads, name): 73 | def hook(grad): 74 | grads[name] = grad 75 | return hook 76 | 77 | ''' 78 | def save_grad(grads, name): 79 | def hook(grad): 80 | grads[name] = grad 81 | return hook 82 | import torch 83 | from channelnorm_package.modules.channelnorm import ChannelNorm 84 | model = ChannelNorm().cuda() 85 | grads = {} 86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True) 87 | a.register_hook(save_grad(grads, 'a')) 88 | b = model(a) 89 | y = torch.mean(b) 90 | y.backward() 91 | 92 | ''' 93 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/run-caffe2pytorch.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | #!/bin/bash 9 | 10 | FN2PYTORCH=${1:-/} 11 | 12 | # install custom layers 13 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 . 
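# then run install.sh inside that image, with the flownet2-pytorch checkout mounted
# read-write at /flownet2-pytorch, to build the custom layers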
14 | sudo nvidia-docker run --rm -ti --volume=${FN2PYTORCH}:/flownet2-pytorch:rw --workdir=/flownet2-pytorch $USER/pytorch:CUDA8-py27 /bin/bash -c "./install.sh" 15 | 16 | # convert FlowNet2-C, CS, CSS, CSS-ft-sd, SD, S and 2 to PyTorch 17 | sudo nvidia-docker run -ti --volume=${FN2PYTORCH}:/fn2pytorch:rw flownet2:latest /bin/bash -c "source /flownet2/flownet2/set-env.sh && cd /flownet2/flownet2/models && \ 18 | python /fn2pytorch/convert.py ./FlowNet2-C/FlowNet2-C_weights.caffemodel ./FlowNet2-C/FlowNet2-C_deploy.prototxt.template /fn2pytorch && 19 | python /fn2pytorch/convert.py ./FlowNet2-CS/FlowNet2-CS_weights.caffemodel ./FlowNet2-CS/FlowNet2-CS_deploy.prototxt.template /fn2pytorch && \ 20 | python /fn2pytorch/convert.py ./FlowNet2-CSS/FlowNet2-CSS_weights.caffemodel.h5 ./FlowNet2-CSS/FlowNet2-CSS_deploy.prototxt.template /fn2pytorch && \ 21 | python /fn2pytorch/convert.py ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_weights.caffemodel.h5 ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_deploy.prototxt.template /fn2pytorch && \ 22 | python /fn2pytorch/convert.py ./FlowNet2-SD/FlowNet2-SD_weights.caffemodel.h5 ./FlowNet2-SD/FlowNet2-SD_deploy.prototxt.template /fn2pytorch && \ 23 | python /fn2pytorch/convert.py ./FlowNet2-S/FlowNet2-S_weights.caffemodel.h5 ./FlowNet2-S/FlowNet2-S_deploy.prototxt.template /fn2pytorch && \ 24 | python /fn2pytorch/convert.py ./FlowNet2/FlowNet2_weights.caffemodel.h5 ./FlowNet2/FlowNet2_deploy.prototxt.template /fn2pytorch" -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/utils/__init__.py -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | TAG_CHAR = np.array([202021.25], np.float32) 4 | 5 | def readFlow(fn): 6 | """ Read .flo file in Middlebury format""" 7 | # Code adapted from: 8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 9 | 10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 11 | # print 'fn = %s'%(fn) 12 | with open(fn, 'rb') as f: 13 | magic = np.fromfile(f, np.float32, count=1) 14 | if 202021.25 != magic: 15 | print('Magic number incorrect. Invalid .flo file') 16 | return None 17 | else: 18 | w = np.fromfile(f, np.int32, count=1) 19 | h = np.fromfile(f, np.int32, count=1) 20 | # print 'Reading %d x %d flo file\n' % (w, h) 21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 22 | # Reshape data into 3D array (columns, rows, bands) 23 | # The reshape here is for visualization, the original code is (w,h,2) 24 | return np.resize(data, (int(h), int(w), 2)) 25 | 26 | def writeFlow(filename,uv,v=None): 27 | """ Write optical flow to file. 28 | 29 | If v is None, uv is assumed to contain both u and v channels, 30 | stacked in depth. 31 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
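File layout (see readFlow above): float32 magic 202021.25, int32 width,
int32 height, then width*height interleaved (u, v) float32 pairs in row-major order.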
32 | """ 33 | nBands = 2 34 | 35 | if v is None: 36 | assert(uv.ndim == 3) 37 | assert(uv.shape[2] == 2) 38 | u = uv[:,:,0] 39 | v = uv[:,:,1] 40 | else: 41 | u = uv 42 | 43 | assert(u.shape == v.shape) 44 | height,width = u.shape 45 | f = open(filename,'wb') 46 | # write the header 47 | f.write(TAG_CHAR) 48 | np.array(width).astype(np.int32).tofile(f) 49 | np.array(height).astype(np.int32).tofile(f) 50 | # arrange into matrix form 51 | tmp = np.zeros((height, width*nBands)) 52 | tmp[:,np.arange(width)*2] = u 53 | tmp[:,np.arange(width)*2 + 1] = v 54 | tmp.astype(np.float32).tofile(f) 55 | f.close() 56 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import * 3 | from scipy.misc import imread 4 | from . import flow_utils 5 | 6 | def read_gen(file_name): 7 | ext = splitext(file_name)[-1] 8 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 9 | im = imread(file_name) 10 | if im.shape[2] > 3: 11 | return im[:,:,:3] 12 | else: 13 | return im 14 | elif ext == '.bin' or ext == '.raw': 15 | return np.load(file_name) 16 | elif ext == '.flo': 17 | return flow_utils.readFlow(file_name).astype(np.float32) 18 | return [] 19 | -------------------------------------------------------------------------------- /models/networks/flownet2_pytorch/utils/tools.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import os, time, sys, math 4 | import subprocess, shutil 5 | from os.path import * 6 | import numpy as np 7 | from inspect import isclass 8 | from pytz import timezone 9 | from datetime import datetime 10 | import inspect 11 | import torch 12 | 13 | def datestr(): 14 | pacific = timezone('US/Pacific') 15 | now = datetime.now(pacific) 16 | return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute) 17 | 18 | def module_to_dict(module, exclude=[]): 19 | return dict([(x, getattr(module, x)) for x in dir(module) 20 | if isclass(getattr(module, x)) 21 | and x not in exclude 22 | and getattr(module, x) not in exclude]) 23 | 24 | class TimerBlock: 25 | def __init__(self, title): 26 | print(("{}".format(title))) 27 | 28 | def __enter__(self): 29 | self.start = time.clock() 30 | return self 31 | 32 | def __exit__(self, exc_type, exc_value, traceback): 33 | self.end = time.clock() 34 | self.interval = self.end - self.start 35 | 36 | if exc_type is not None: 37 | self.log("Operation failed\n") 38 | else: 39 | self.log("Operation finished\n") 40 | 41 | 42 | def log(self, string): 43 | duration = time.clock() - self.start 44 | units = 's' 45 | if duration > 60: 46 | duration = duration / 60. 
47 | units = 'm' 48 | print((" [{:.3f}{}] {}".format(duration, units, string))) 49 | 50 | def log2file(self, fid, string): 51 | fid = open(fid, 'a') 52 | fid.write("%s\n"%(string)) 53 | fid.close() 54 | 55 | def add_arguments_for_module(parser, module, argument_for_class, default, skip_params=[], parameter_defaults={}): 56 | argument_group = parser.add_argument_group(argument_for_class.capitalize()) 57 | 58 | module_dict = module_to_dict(module) 59 | argument_group.add_argument('--' + argument_for_class, type=str, default=default, choices=list(module_dict.keys())) 60 | 61 | args, unknown_args = parser.parse_known_args() 62 | class_obj = module_dict[vars(args)[argument_for_class]] 63 | 64 | argspec = inspect.getargspec(class_obj.__init__) 65 | 66 | defaults = argspec.defaults[::-1] if argspec.defaults else None 67 | 68 | args = argspec.args[::-1] 69 | for i, arg in enumerate(args): 70 | cmd_arg = '{}_{}'.format(argument_for_class, arg) 71 | if arg not in skip_params + ['self', 'args']: 72 | if arg in list(parameter_defaults.keys()): 73 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(parameter_defaults[arg]), default=parameter_defaults[arg]) 74 | elif (defaults is not None and i < len(defaults)): 75 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(defaults[i]), default=defaults[i]) 76 | else: 77 | print(("[Warning]: non-default argument '{}' detected on class '{}'. This argument cannot be modified via the command line" 78 | .format(arg, module.__class__.__name__))) 79 | # We don't have a good way of dealing with inferring the type of the argument 80 | # TODO: try creating a custom action and using ast's infer type? 81 | # else: 82 | # argument_group.add_argument('--{}'.format(cmd_arg), required=True) 83 | 84 | def kwargs_from_args(args, argument_for_class): 85 | argument_for_class = argument_for_class + '_' 86 | return {key[len(argument_for_class):]: value for key, value in list(vars(args).items()) if argument_for_class in key and key != argument_for_class + 'class'} 87 | 88 | def format_dictionary_of_losses(labels, values): 89 | try: 90 | string = ', '.join([('{}: {:' + ('.3f' if value >= 0.001 else '.1e') +'}').format(name, value) for name, value in zip(labels, values)]) 91 | except (TypeError, ValueError) as e: 92 | print((list(zip(labels, values)))) 93 | string = '[Log Error] ' + str(e) 94 | 95 | return string 96 | 97 | 98 | class IteratorTimer(): 99 | def __init__(self, iterable): 100 | self.iterable = iterable 101 | self.iterator = self.iterable.__iter__() 102 | 103 | def __iter__(self): 104 | return self 105 | 106 | def __len__(self): 107 | return len(self.iterable) 108 | 109 | def __next__(self): 110 | start = time.time() 111 | n = next(self.iterator) 112 | self.last_duration = (time.time() - start) 113 | return n 114 | 115 | next = __next__ 116 | 117 | def gpumemusage(): 118 | gpu_mem = subprocess.check_output("nvidia-smi | grep MiB | cut -f 3 -d '|'", shell=True).replace(' ', '').replace('\n', '').replace('i', '') 119 | all_stat = [float(a) for a in gpu_mem.replace('/','').split('MB')[:-1]] 120 | 121 | gpu_mem = '' 122 | for i in range(len(all_stat)/2): 123 | curr, tot = all_stat[2*i], all_stat[2*i+1] 124 | util = "%1.2f"%(100*curr/tot)+'%' 125 | cmem = str(int(math.ceil(curr/1024.)))+'GB' 126 | gmem = str(int(math.ceil(tot/1024.)))+'GB' 127 | gpu_mem += util + '--' + join(cmem, gmem) + ' ' 128 | return gpu_mem 129 | 130 | 131 | def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer): 132 | if args.schedule_lr_frequency > 0: 133 | 
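# every schedule_lr_frequency iterations, divide each param group's lr by
# schedule_lr_fraction, with a hard floor of 1e-6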
for param_group in optimizer.param_groups: 134 | if (global_iteration + 1) % args.schedule_lr_frequency == 0: 135 | param_group['lr'] /= float(args.schedule_lr_fraction) 136 | param_group['lr'] = float(np.maximum(param_group['lr'], 0.000001)) 137 | 138 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'): 139 | prefix_save = os.path.join(path, prefix) 140 | name = prefix_save + '_' + filename 141 | torch.save(state, name) 142 | if is_best: 143 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar') 144 | 145 | -------------------------------------------------------------------------------- /models/networks/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import numpy as np 11 | from models.networks.vgg import VGG_Activations, Vgg19 12 | 13 | # Defines the GAN loss which uses either LSGAN or the regular GAN. 14 | # When LSGAN is used, it is basically same as MSELoss, 15 | # but it abstracts away the need to create the target label tensor 16 | # that has the same size as the input 17 | class GANLoss(nn.Module): 18 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, 19 | tensor=torch.FloatTensor, opt=None): 20 | super(GANLoss, self).__init__() 21 | self.real_label = target_real_label 22 | self.fake_label = target_fake_label 23 | self.real_label_tensor = None 24 | self.fake_label_tensor = None 25 | self.Tensor = tensor 26 | self.gan_mode = gan_mode 27 | self.opt = opt 28 | if gan_mode == 'ls': 29 | pass 30 | elif gan_mode == 'original': 31 | pass 32 | elif gan_mode == 'w': 33 | pass 34 | elif gan_mode == 'hinge': 35 | pass 36 | else: 37 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode)) 38 | 39 | def get_target_tensor(self, input, target_is_real): 40 | if target_is_real: 41 | if self.real_label_tensor is None: 42 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label) 43 | return self.real_label_tensor.expand_as(input) 44 | else: 45 | if self.fake_label_tensor is None: 46 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label) 47 | return self.fake_label_tensor.expand_as(input) 48 | 49 | def loss(self, input, target_is_real, weight=None, reduce_dim=True, for_discriminator=True): 50 | if self.gan_mode == 'original': 51 | target_tensor = self.get_target_tensor(input, target_is_real) 52 | batchsize = input.size(0) 53 | loss = F.binary_cross_entropy_with_logits(input, target_tensor, weight=weight) 54 | if not reduce_dim: 55 | loss = loss.view(batchsize, -1).mean(dim=1) 56 | return loss 57 | elif self.gan_mode == 'ls': 58 | #target_tensor = self.get_target_tensor(input, target_is_real) 59 | target_tensor = input * 0 + (self.real_label if target_is_real else self.fake_label) 60 | if weight is None and reduce_dim: 61 | return F.mse_loss(input, target_tensor) 62 | error = (input - target_tensor)**2 63 | if weight is not None: 64 | error *= weight 65 | if reduce_dim: 66 | return torch.mean(error) 67 | else: 68 | return error.view(input.size(0), -1).mean(dim=1) 69 | elif self.gan_mode == 'hinge': 70 | assert weight == None 71 | assert reduce_dim == True 72 | if for_discriminator: 73 | if target_is_real: 74 | minval = 
torch.min(input - 1, input * 0) 75 | loss = -torch.mean(minval) 76 | else: 77 | minval = torch.min(-input - 1, input * 0) 78 | loss = -torch.mean(minval) 79 | else: 80 | assert target_is_real, "The generator's hinge loss must be aiming for real" 81 | loss = -torch.mean(input) 82 | 83 | return loss 84 | else: 85 | # wgan 86 | assert weight is None and reduce_dim 87 | if target_is_real: 88 | return -input.mean() 89 | else: 90 | return input.mean() 91 | 92 | def __call__(self, input, target_is_real, weight=None, reduce_dim=True, for_discriminator=True): 93 | if isinstance(input, list): 94 | loss = 0 95 | for pred_i in input: 96 | if isinstance(pred_i, list): 97 | pred_i = pred_i[-1] 98 | loss_tensor = self.loss(pred_i, target_is_real, weight, reduce_dim, for_discriminator) 99 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0) 100 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1) 101 | loss += new_loss 102 | return loss / len(input) 103 | else: 104 | return self.loss(input, target_is_real, weight, reduce_dim, for_discriminator) 105 | 106 | 107 | class VGGLoss(nn.Module): 108 | def __init__(self, opt, gpu_ids): 109 | super(VGGLoss, self).__init__() 110 | self.vgg = VGG_Activations([1, 6, 11, 20, 29]).cuda() 111 | self.criterion = nn.L1Loss() 112 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 113 | 114 | def compute_loss(self, x_vgg, y_vgg): 115 | loss = 0 116 | for i in range(len(x_vgg)): 117 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach()) 118 | return loss 119 | 120 | def forward(self, x, y): 121 | if len(x.size()) == 5: 122 | b, t, c, h, w = x.size() 123 | x, y = x.view(-1, c, h, w), y.view(-1, c, h, w) 124 | 125 | y_vgg = self.vgg(y) 126 | x_vgg = self.vgg(x) 127 | loss = self.compute_loss(x_vgg, y_vgg) 128 | return loss 129 | 130 | class MaskedL1Loss(nn.Module): 131 | def __init__(self): 132 | super(MaskedL1Loss, self).__init__() 133 | self.criterion = nn.L1Loss() 134 | 135 | def forward(self, input, target, mask): 136 | mask = mask.expand_as(input) 137 | loss = self.criterion(input * mask, target * mask) 138 | return loss 139 | 140 | class KLDLoss(nn.Module): 141 | def forward(self, mu, logvar): 142 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) 143 | 144 | 145 | 146 | 147 | -------------------------------------------------------------------------------- /models/networks/normalization.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import re 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.functional as F 11 | import torch.nn.utils.spectral_norm as sn 12 | 13 | from models.networks.base_network import batch_conv 14 | # from models.networks.sync_batchnorm import SynchronizedBatchNorm2d 15 | from apex.parallel import SyncBatchNorm as SynchronizedBatchNorm2d 16 | 17 | 18 | class SPADE(nn.Module): 19 | def __init__(self, norm_nc, hidden_nc=0, norm='batch', ks=3, params_free=False): 20 | super().__init__() 21 | pw = ks//2 22 | if not isinstance(hidden_nc, list): hidden_nc = [hidden_nc] 23 | for i, nhidden in enumerate(hidden_nc): 24 | mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 25 | mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw) 26 | 27 | if not params_free or (i != 0): 28 | s = str(i+1) if i > 0 else '' 29 | setattr(self, 'mlp_gamma%s' % s, mlp_gamma) 30 | setattr(self, 'mlp_beta%s' % s, mlp_beta) 31 | 32 | if 'batch' in norm: 33 | self.norm = SynchronizedBatchNorm2d(norm_nc, affine=False) 34 | else: 35 | self.norm = nn.InstanceNorm2d(norm_nc, affine=False, eps=0.1) 36 | 37 | def forward(self, x, maps, weights=None): 38 | if not isinstance(maps, list): maps = [maps] 39 | out = self.norm(x) 40 | for i in range(len(maps)): 41 | if maps[i] is None: continue 42 | m = F.interpolate(maps[i], size=x.size()[2:]) 43 | if weights is None or (i != 0): 44 | s = str(i+1) if i > 0 else '' 45 | gamma = getattr(self, 'mlp_gamma%s' % s)(m) 46 | beta = getattr(self, 'mlp_beta%s' % s)(m) 47 | else: 48 | j = min(i, len(weights[0])-1) 49 | gamma = batch_conv(m, weights[0][j]) 50 | beta = batch_conv(m, weights[1][j]) 51 | out = out * (1 + gamma) + beta 52 | return out 53 | 54 | def get_nonspade_norm_layer(opt, norm_type='instance'): 55 | # helper function to get # output channels of the previous layer 56 | def get_out_channel(layer): 57 | if hasattr(layer, 'out_channels'): 58 | return getattr(layer, 'out_channels') 59 | return layer.weight.size(0) 60 | 61 | # this function will be returned 62 | def add_norm_layer(layer): 63 | nonlocal norm_type 64 | if norm_type.startswith('spectral'): 65 | layer = sn(layer) 66 | subnorm_type = norm_type[len('spectral'):] 67 | 68 | if subnorm_type == 'none' or len(subnorm_type) == 0: 69 | return layer 70 | 71 | # remove bias in the previous layer, which is meaningless 72 | # since it has no effect after normalization 73 | if getattr(layer, 'bias', None) is not None: 74 | delattr(layer, 'bias') 75 | layer.register_parameter('bias', None) 76 | 77 | if subnorm_type == 'batch': 78 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True) 79 | elif subnorm_type == 'syncbatch': 80 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True) 81 | elif subnorm_type == 'instance': 82 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=True, eps=0.1) 83 | else: 84 | raise ValueError('normalization layer %s is not recognized' % subnorm_type) 85 | 86 | return nn.Sequential(layer, norm_layer) 87 | 88 | return add_norm_layer 89 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of 
Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/comm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : comm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import queue 12 | import collections 13 | import threading 14 | 15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster'] 16 | 17 | 18 | class FutureResult(object): 19 | """A thread-safe future implementation. Used only as one-to-one pipe.""" 20 | 21 | def __init__(self): 22 | self._result = None 23 | self._lock = threading.Lock() 24 | self._cond = threading.Condition(self._lock) 25 | 26 | def put(self, result): 27 | with self._lock: 28 | assert self._result is None, 'Previous result has\'t been fetched.' 29 | self._result = result 30 | self._cond.notify() 31 | 32 | def get(self): 33 | with self._lock: 34 | if self._result is None: 35 | self._cond.wait() 36 | 37 | res = self._result 38 | self._result = None 39 | return res 40 | 41 | 42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result']) 43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result']) 44 | 45 | 46 | class SlavePipe(_SlavePipeBase): 47 | """Pipe for master-slave communication.""" 48 | 49 | def run_slave(self, msg): 50 | self.queue.put((self.identifier, msg)) 51 | ret = self.result.get() 52 | self.queue.put(True) 53 | return ret 54 | 55 | 56 | class SyncMaster(object): 57 | """An abstract `SyncMaster` object. 58 | 59 | - During the replication, as the data parallel will trigger an callback of each module, all slave devices should 60 | call `register(id)` and obtain an `SlavePipe` to communicate with the master. 61 | - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected, 62 | and passed to a registered callback. 63 | - After receiving the messages, the master device should gather the information and determine to message passed 64 | back to each slave devices. 65 | """ 66 | 67 | def __init__(self, master_callback): 68 | """ 69 | 70 | Args: 71 | master_callback: a callback to be invoked after having collected messages from slave devices. 72 | """ 73 | self._master_callback = master_callback 74 | self._queue = queue.Queue() 75 | self._registry = collections.OrderedDict() 76 | self._activated = False 77 | 78 | def register_slave(self, identifier): 79 | """ 80 | Register an slave device. 81 | 82 | Args: 83 | identifier: an identifier, usually is the device id. 84 | 85 | Returns: a `SlavePipe` object which can be used to communicate with the master device. 86 | 87 | """ 88 | if self._activated: 89 | assert self._queue.empty(), 'Queue is not clean before next initialization.' 
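# a new replication round is starting: clear the registrations left over from the
# previous forward pass before handing out a fresh SlavePipe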
90 | self._activated = False 91 | self._registry.clear() 92 | future = FutureResult() 93 | self._registry[identifier] = _MasterRegistry(future) 94 | return SlavePipe(identifier, self._queue, future) 95 | 96 | def run_master(self, master_msg): 97 | """ 98 | Main entry for the master device in each forward pass. 99 | The messages were first collected from each devices (including the master device), and then 100 | an callback will be invoked to compute the message to be sent back to each devices 101 | (including the master device). 102 | 103 | Args: 104 | master_msg: the message that the master want to send to itself. This will be placed as the first 105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example. 106 | 107 | Returns: the message to be sent back to the master device. 108 | 109 | """ 110 | self._activated = True 111 | 112 | intermediates = [(0, master_msg)] 113 | for i in range(self.nr_slaves): 114 | intermediates.append(self._queue.get()) 115 | 116 | results = self._master_callback(intermediates) 117 | assert results[0][0] == 0, 'The first result should belongs to the master.' 118 | 119 | for i, res in results: 120 | if i == 0: 121 | continue 122 | self._registry[i].result.put(res) 123 | 124 | for i in range(self.nr_slaves): 125 | assert self._queue.get() is True 126 | 127 | return results[0][1] 128 | 129 | @property 130 | def nr_slaves(self): 131 | return len(self._registry) 132 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import functools 12 | 13 | #from torch.nn.parallel.data_parallel import DataParallel 14 | import torch.nn as nn 15 | 16 | __all__ = [ 17 | 'CallbackContext', 18 | 'execute_replication_callbacks', 19 | 'DataParallelWithCallback', 20 | 'patch_replication_callback' 21 | ] 22 | 23 | 24 | class DataParallel(nn.parallel.DataParallel): 25 | def replicate(self, module, device_ids): 26 | replicas = super(DataParallel, self).replicate(module, device_ids) 27 | replicas[0] = module 28 | return replicas 29 | 30 | class CallbackContext(object): 31 | pass 32 | 33 | 34 | def execute_replication_callbacks(modules): 35 | """ 36 | Execute an replication callback `__data_parallel_replicate__` on each module created by original replication. 37 | 38 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 39 | 40 | Note that, as all modules are isomorphism, we assign each sub-module with a context 41 | (shared among multiple copies of this module on different devices). 42 | Through this context, different copies can share some information. 43 | 44 | We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback 45 | of any slave copies. 
46 | """ 47 | master_copy = modules[0] 48 | nr_modules = len(list(master_copy.modules())) 49 | ctxs = [CallbackContext() for _ in range(nr_modules)] 50 | 51 | for i, module in enumerate(modules): 52 | for j, m in enumerate(module.modules()): 53 | if hasattr(m, '__data_parallel_replicate__'): 54 | m.__data_parallel_replicate__(ctxs[j], i) 55 | 56 | 57 | class DataParallelWithCallback(DataParallel): 58 | """ 59 | Data Parallel with a replication callback. 60 | 61 | An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by 62 | original `replicate` function. 63 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 64 | 65 | Examples: 66 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 67 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 68 | # sync_bn.__data_parallel_replicate__ will be invoked. 69 | """ 70 | 71 | def replicate(self, module, device_ids): 72 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 73 | execute_replication_callbacks(modules) 74 | return modules 75 | 76 | 77 | def patch_replication_callback(data_parallel): 78 | """ 79 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 80 | Useful when you have customized `DataParallel` implementation. 81 | 82 | Examples: 83 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 84 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 85 | > patch_replication_callback(sync_bn) 86 | # this is equivalent to 87 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 88 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 89 | """ 90 | 91 | assert isinstance(data_parallel, DataParallel) 92 | 93 | old_replicate = data_parallel.replicate 94 | 95 | @functools.wraps(old_replicate) 96 | def new_replicate(module, device_ids): 97 | modules = old_replicate(module, device_ids) 98 | execute_replication_callbacks(modules) 99 | return modules 100 | 101 | data_parallel.replicate = new_replicate 102 | -------------------------------------------------------------------------------- /models/networks/sync_batchnorm/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /models/networks/vgg.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | import torch 9 | import torch.nn as nn 10 | import torchvision 11 | from collections import OrderedDict 12 | 13 | class Vgg19(nn.Module): 14 | def __init__(self, requires_grad=False): 15 | super(Vgg19, self).__init__() 16 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features 17 | self.slice1 = nn.Sequential() 18 | self.slice2 = nn.Sequential() 19 | self.slice3 = nn.Sequential() 20 | self.slice4 = nn.Sequential() 21 | self.slice5 = nn.Sequential() 22 | for x in range(2): 23 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(2, 7): 25 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(7, 12): 27 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 28 | for x in range(12, 21): 29 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 30 | for x in range(21, 30): 31 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 32 | if not requires_grad: 33 | for param in self.parameters(): 34 | param.requires_grad = False 35 | 36 | def forward(self, X): 37 | h_relu1 = self.slice1(X) 38 | h_relu2 = self.slice2(h_relu1) 39 | h_relu3 = self.slice3(h_relu2) 40 | h_relu4 = self.slice4(h_relu3) 41 | h_relu5 = self.slice5(h_relu4) 42 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 43 | return out 44 | 45 | class VGG_Activations(nn.Module): 46 | def __init__(self, feature_idx): 47 | super(VGG_Activations, self).__init__() 48 | vgg_network = torchvision.models.vgg19(pretrained=True) 49 | features = list(vgg_network.features) 50 | self.features = nn.ModuleList(features).eval() 51 | self.idx_list = feature_idx 52 | 53 | def forward(self, x): 54 | results = [] 55 | for ii, model in enumerate(self.features): 56 | x = model(x) 57 | if ii in self.idx_list: 58 | results.append(x) 59 | return results 60 | -------------------------------------------------------------------------------- /models/trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
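Vgg19 above exposes the activations after five ReLU stages of a pretrained VGG-19. A common way to consume these outputs in this family of models is a weighted L1 "perceptual" loss between the features of a synthesized and a real image; the sketch below shows that pattern. The per-layer weights are the conventional pix2pixHD-style ones and are an assumption here, not necessarily what models/loss_collector.py uses.

import torch
import torch.nn as nn
from models.networks.vgg import Vgg19

class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss, self).__init__()
        self.vgg = Vgg19().cuda()
        self.criterion = nn.L1Loss()
        # weights commonly used for the five relu slices (assumed, not from this repo)
        self.weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]

    def forward(self, fake, real):
        fake_feats, real_feats = self.vgg(fake), self.vgg(real)
        loss = 0
        for w, f, r in zip(self.weights, fake_feats, real_feats):
            loss = loss + w * self.criterion(f, r.detach())
        return loss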
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import numpy as np 9 | import torch 10 | import time 11 | from collections import OrderedDict 12 | import fractions 13 | from subprocess import call 14 | def lcm(a,b): return abs(a * b)/fractions.gcd(a,b) if a and b else 0 15 | 16 | import util.util as util 17 | from util.visualizer import Visualizer 18 | from models.models import save_models, update_models 19 | from util.distributed import master_only, is_master, get_world_size 20 | from util.distributed import master_only_print as print 21 | 22 | class Trainer(): 23 | def __init__(self, opt, data_loader): 24 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt') 25 | start_epoch, epoch_iter = 1, 0 26 | ### if continue training, recover previous states 27 | if opt.continue_train: 28 | if os.path.exists(iter_path): 29 | start_epoch, epoch_iter = np.loadtxt(iter_path , delimiter=',', dtype=int) 30 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter)) 31 | 32 | print_freq = lcm(opt.print_freq, opt.batchSize) 33 | total_steps = (start_epoch-1) * len(data_loader) + epoch_iter 34 | total_steps = total_steps // print_freq * print_freq 35 | 36 | self.opt = opt 37 | self.epoch_iter, self.print_freq, self.total_steps, self.iter_path = epoch_iter, print_freq, total_steps, iter_path 38 | self.start_epoch, self.epoch_iter = start_epoch, epoch_iter 39 | self.dataset_size = len(data_loader) 40 | self.visualizer = Visualizer(opt) 41 | 42 | def start_of_iter(self, data): 43 | if self.total_steps % self.print_freq == 0: 44 | self.iter_start_time = time.time() 45 | self.total_steps += self.opt.batchSize 46 | self.epoch_iter += self.opt.batchSize 47 | self.save = self.total_steps % self.opt.display_freq == 0 48 | for k, v in data.items(): 49 | if isinstance(v, torch.Tensor): 50 | data[k] = v.cuda() 51 | return data 52 | 53 | def end_of_iter(self, loss_dicts, output_list, model): 54 | opt = self.opt 55 | epoch, epoch_iter, print_freq, total_steps = self.epoch, self.epoch_iter, self.print_freq, self.total_steps 56 | ############## Display results and errors ########## 57 | ### print out errors 58 | if is_master() and total_steps % print_freq == 0: 59 | t = (time.time() - self.iter_start_time) / print_freq / get_world_size() 60 | errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dicts.items()} 61 | self.visualizer.print_current_errors(epoch, epoch_iter, errors, t) 62 | self.visualizer.plot_current_errors(errors, total_steps) 63 | 64 | ### display output images 65 | if is_master() and self.save: 66 | visuals = save_all_tensors(opt, output_list, model) 67 | self.visualizer.display_current_results(visuals, epoch, total_steps) 68 | 69 | if is_master() and opt.print_mem: 70 | call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"]) 71 | 72 | ### save latest model 73 | save_models(opt, epoch, epoch_iter, total_steps, self.visualizer, self.iter_path, model) 74 | if epoch_iter > self.dataset_size - opt.batchSize: 75 | return True 76 | return False 77 | 78 | def start_of_epoch(self, epoch, model, data_loader): 79 | self.epoch = epoch 80 | self.epoch_start_time = time.time() 81 | if self.opt.distributed: 82 | data_loader.dataloader.sampler.set_epoch(epoch) 83 | # update model params 84 | update_models(self.opt, epoch, model, data_loader) 85 | 86 | def end_of_epoch(self, model): 87 | opt = self.opt 88 | iter_end_time = time.time() 89 | self.visualizer.vis_print(opt, 
'End of epoch %d / %d \t Time Taken: %d sec' % 90 | (self.epoch, opt.niter + opt.niter_decay, time.time() - self.epoch_start_time)) 91 | 92 | ### save model for this epoch 93 | save_models(opt, self.epoch, self.epoch_iter, self.total_steps, self.visualizer, self.iter_path, model, end_of_epoch=True) 94 | self.epoch_iter = 0 95 | 96 | def save_all_tensors(opt, output_list, model): 97 | fake_image, fake_raw_image, warped_image, flow, weight, atn_score, \ 98 | target_label, target_image, flow_gt, conf_gt, ref_label, ref_image = output_list 99 | 100 | visual_list = [('target_label', util.visualize_label(opt, target_label, model)), 101 | ('synthesized_image', util.tensor2im(fake_image)), 102 | ('target_image', util.tensor2im(target_image)), 103 | ('ref_image', util.tensor2im(ref_image, tile=True)), 104 | ('raw_image', util.tensor2im(fake_raw_image)), 105 | ('warped_images', util.tensor2im(warped_image, tile=True)), 106 | ('flows', util.tensor2flow(flow, tile=True)), 107 | ('weights', util.tensor2im(weight, normalize=False, tile=True)), 108 | ('atn_score', util.tensor2im(atn_score, normalize=False)), 109 | ] 110 | visuals = OrderedDict(visual_list) 111 | return visuals 112 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | from .base_options import BaseOptions 8 | 9 | class TestOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 13 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 14 | parser.add_argument('--how_many', type=int, default=300, help='how many test images to run') 15 | parser.set_defaults(serial_batches=True) 16 | parser.set_defaults(batchSize=1) 17 | parser.set_defaults(nThreads=1) 18 | parser.set_defaults(no_flip=True) 19 | self.isTrain = False 20 | return parser 21 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
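A note on the lcm helper at the top of trainer.py: it rounds the printing interval to a common multiple of print_freq and batchSize, so loss printing always lands on an iteration boundary (total_steps grows by batchSize each step). The one-liner relies on fractions.gcd, which was deprecated in Python 3.5 and removed in Python 3.9, and its true division returns a float under Python 3. On newer interpreters an equivalent integer-valued helper could look like the sketch below (a possible drop-in, not code from the repository).

import math

def lcm(a, b):
    # integer least common multiple; math.gcd is available from Python 3.5 onwards
    return abs(a * b) // math.gcd(a, b) if a and b else 0

# Example: with --print_freq 100 and --batchSize 6, errors are printed every
# lcm(100, 6) = 300 images processed.
assert lcm(100, 6) == 300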
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | from .base_options import BaseOptions 8 | 9 | class TrainOptions(BaseOptions): 10 | def initialize(self, parser): 11 | BaseOptions.initialize(self, parser) 12 | # for displays 13 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen') 14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console') 15 | parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results') 16 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs') 17 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 18 | parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration') 19 | parser.add_argument('--print_mem', action='store_true', help='print memory usage') 20 | parser.add_argument('--print_G', action='store_true', help='print network G') 21 | parser.add_argument('--print_D', action='store_true', help='print network D') 22 | 23 | # for training 24 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 25 | parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location') 26 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 27 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate') 28 | parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero') 29 | parser.add_argument('--niter_single', type=int, default=50, help='# of iter for single frame training') 30 | parser.add_argument('--niter_step', type=int, default=10, help='# of iter to double the length of training sequence') 31 | 32 | # for temporal 33 | parser.add_argument('--n_frames_D', type=int, default=2, help='number of frames to feed into temporal discriminator') 34 | parser.add_argument('--n_frames_total', type=int, default=2, help='the overall number of frames in a sequence to train with') 35 | parser.add_argument('--max_t_step', type=int, default=4, help='max spacing between neighboring sampled frames. If greater than 1, the network may randomly skip frames during training.') 36 | 37 | self.isTrain = True 38 | return parser 39 | -------------------------------------------------------------------------------- /scripts/download_datasets.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
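The training options above encode a temporal curriculum: the network first trains on single frames for niter_single epochs, and the help string for --niter_step states that the training sequence length is then doubled every niter_step epochs. The sketch below is only an illustration of that schedule as implied by the help strings; the actual update happens elsewhere (see update_models used by the trainer), and its exact formula and cap may differ.

# Illustrative reading of the --niter_single / --niter_step schedule; the
# repository's own update logic may use a different formula or cap.
def frames_for_epoch(epoch, niter_single=50, niter_step=10, n_frames_max=16):
    if epoch <= niter_single:
        return 1                       # single-frame (image) training phase
    n = 2 ** (1 + (epoch - niter_single - 1) // niter_step)
    return min(n, n_frames_max)

# With the defaults above: epoch 50 -> 1 frame, 51 -> 2, 61 -> 4, 71 -> 8 frames per clip.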
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | from download_gdrive import * 9 | 10 | file_id = '1NyxrzJbHgDOpf-nRJhfxsHqWNDW_D1TV' 11 | chpt_path = './datasets/' 12 | if not os.path.isdir(chpt_path): 13 | os.makedirs(chpt_path) 14 | destination = os.path.join(chpt_path, 'datasets.zip') 15 | download_file_from_google_drive(file_id, destination) 16 | unzip_file(destination, chpt_path) -------------------------------------------------------------------------------- /scripts/download_flownet2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | from download_gdrive import * 9 | os.system('cd models/networks/flownet2_pytorch/; bash install.sh; cd ../../../') 10 | 11 | file_id = '1E8re-b6csNuo-abg1vJKCDjCzlIam50F' 12 | chpt_path = './models/networks/flownet2_pytorch/' 13 | destination = os.path.join(chpt_path, 'FlowNet2_checkpoint.pth.tar') 14 | download_file_from_google_drive(file_id, destination) -------------------------------------------------------------------------------- /scripts/download_gdrive.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | # Download code taken from Code taken from https://stackoverflow.com questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039 9 | import requests, zipfile, os 10 | def download_file_from_google_drive(id, destination): 11 | URL = "https://docs.google.com/uc?export=download" 12 | session = requests.Session() 13 | response = session.get(URL, params = { 'id' : id }, stream = True) 14 | token = get_confirm_token(response) 15 | if token: 16 | params = { 'id' : id, 'confirm' : token } 17 | response = session.get(URL, params = params, stream = True) 18 | save_response_content(response, destination) 19 | def get_confirm_token(response): 20 | for key, value in response.cookies.items(): 21 | if key.startswith('download_warning'): 22 | return value 23 | return None 24 | def save_response_content(response, destination): 25 | CHUNK_SIZE = 32768 26 | with open(destination, "wb") as f: 27 | for chunk in response.iter_content(CHUNK_SIZE): 28 | if chunk: # filter out keep-alive new chunks 29 | f.write(chunk) 30 | 31 | def unzip_file(file_name, unzip_path): 32 | zip_ref = zipfile.ZipFile(file_name, 'r') 33 | zip_ref.extractall(unzip_path) 34 | zip_ref.close() 35 | os.remove(file_name) -------------------------------------------------------------------------------- /scripts/face/test_256.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python test.py --name face_256 --dataset_mode fewshot_face --adaptive_spade --warp_ref --spade_combine 9 | -------------------------------------------------------------------------------- /scripts/face/test_512.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python test.py --name face_512 --dataset_mode fewshot_face --loadSize 512 --fineSize 512 --adaptive_spade --warp_ref --spade_combine 9 | -------------------------------------------------------------------------------- /scripts/face/train_g1_256.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python train.py --name face_256 --dataset_mode fewshot_face \ 9 | --adaptive_spade --warp_ref --spade_combine --batchSize 4 --continue_train -------------------------------------------------------------------------------- /scripts/face/train_g8_256.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python train.py --name face_256 --dataset_mode fewshot_face \ 9 | --adaptive_spade --warp_ref --spade_combine \ 10 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 32 --nThreads 32 --continue_train -------------------------------------------------------------------------------- /scripts/face/train_g8_512.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python train.py --name face_512 --dataset_mode fewshot_face \ 9 | --loadSize 512 --fineSize 512 --num_D 2 \ 10 | --adaptive_spade --warp_ref --spade_combine \ 11 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 8 --nThreads 32 --continue_train -------------------------------------------------------------------------------- /scripts/pose/test.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python test.py --name pose --dataset_mode fewshot_pose \ 9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --finetune 10 | -------------------------------------------------------------------------------- /scripts/pose/train_g1.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python train.py --name pose --dataset_mode fewshot_pose \ 9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --add_face_D \ 10 | --batchSize 2 --niter 100 --niter_single 100 --continue_train 11 | -------------------------------------------------------------------------------- /scripts/pose/train_g8.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | 8 | python train.py --name pose --dataset_mode fewshot_pose \ 9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --add_face_D \ 10 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 30 --nThreads 32 --niter 100 --niter_single 100 --continue_train -------------------------------------------------------------------------------- /scripts/street/test.sh: -------------------------------------------------------------------------------- 1 | python test.py --name street --dataset_mode fewshot_street --adaptive_spade --loadSize 512 --fineSize 512 2 | -------------------------------------------------------------------------------- /scripts/street/train_g1.sh: -------------------------------------------------------------------------------- 1 | python train.py --name street --dataset_mode fewshot_street \ 2 | --adaptive_spade --loadSize 512 --fineSize 512 --batchSize 6 --continue_train 3 | -------------------------------------------------------------------------------- /scripts/street/train_g8.sh: -------------------------------------------------------------------------------- 1 | python train.py --name street --dataset_mode fewshot_street \ 2 | --adaptive_spade --loadSize 512 --fineSize 512 \ 3 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 46 --nThreads 16 --continue_train 4 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import numpy as np 9 | import torch 10 | import cv2 11 | from collections import OrderedDict 12 | 13 | from options.test_options import TestOptions 14 | from data.data_loader import CreateDataLoader 15 | from models.models import create_model 16 | import util.util as util 17 | from util.visualizer import Visualizer 18 | from util import html 19 | 20 | opt = TestOptions().parse() 21 | 22 | ### setup dataset 23 | data_loader = CreateDataLoader(opt) 24 | dataset = data_loader.load_data() 25 | 26 | ### setup models 27 | model = create_model(opt) 28 | model.eval() 29 | visualizer = Visualizer(opt) 30 | 31 | # create website 32 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch)) 33 | if opt.finetune: web_dir += '_finetune' 34 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch), infer=True) 35 | 36 | # test 37 | for i, data in enumerate(dataset): 38 | if i >= opt.how_many or i >= len(dataset): break 39 | img_path = data['path'] 40 | data_list = [data['tgt_label'], data['tgt_image'], None, None, data['ref_label'], data['ref_image'], None, None, None] 41 | synthesized_image, _, _, _, _, _ = model(data_list) 42 | 43 | synthesized_image = util.tensor2im(synthesized_image) 44 | tgt_image = util.tensor2im(data['tgt_image']) 45 | ref_image = util.tensor2im(data['ref_image'], tile=True) 46 | seq = data['seq'][0] 47 | visual_list = [ref_image, tgt_image, synthesized_image] 48 | visuals = OrderedDict([(seq, np.hstack(visual_list)), 49 | (seq + '/synthesized', synthesized_image), 50 | (seq + '/ref_image', ref_image if i == 0 else None), 51 | ]) 52 | print('process image... %s' % img_path) 53 | visualizer.save_images(webpage, visuals, img_path) 54 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import numpy as np 9 | import torch 10 | 11 | from options.train_options import TrainOptions 12 | from data.data_loader import CreateDataLoader 13 | from models.models import create_model 14 | from models.loss_collector import loss_backward 15 | from models.trainer import Trainer 16 | from util.distributed import init_dist 17 | from util.distributed import master_only_print as print 18 | 19 | def train(): 20 | opt = TrainOptions().parse() 21 | 22 | if opt.distributed: 23 | init_dist() 24 | print('batch size per GPU: %d' % opt.batchSize) 25 | torch.backends.cudnn.benchmark = True 26 | 27 | ### setup dataset 28 | data_loader = CreateDataLoader(opt) 29 | dataset = data_loader.load_data() 30 | pose = 'pose' in opt.dataset_mode 31 | 32 | ### setup trainer 33 | trainer = Trainer(opt, data_loader) 34 | 35 | ### setup models 36 | model, flowNet, [optimizer_G, optimizer_D] = create_model(opt, trainer.start_epoch) 37 | flow_gt = conf_gt = [None] * 2 38 | 39 | for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1): 40 | if opt.distributed: 41 | dataset.sampler.set_epoch(epoch) 42 | trainer.start_of_epoch(epoch, model, data_loader) 43 | n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu 44 | for idx, data in enumerate(dataset, start=trainer.epoch_iter): 45 | data = trainer.start_of_iter(data) 46 | 47 | if not opt.no_flow_gt: 48 | data_list = [data['tgt_label'], data['ref_label']] if pose else [data['tgt_image'], data['ref_image']] 49 | flow_gt, conf_gt = flowNet(data_list, epoch) 50 | data_list = [data['tgt_label'], data['tgt_image'], flow_gt, conf_gt] 51 | data_ref_list = [data['ref_label'], data['ref_image']] 52 | data_prev = [None, None, None] 53 | 54 | ############## Forward Pass ###################### 55 | for t in range(0, n_frames_total, n_frames_load): 56 | data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + data_prev 57 | 58 | d_losses = model(data_list_t, mode='discriminator') 59 | d_losses = loss_backward(opt, d_losses, optimizer_D, 1) 60 | 61 | g_losses, generated, data_prev = model(data_list_t, save_images=trainer.save, mode='generator') 62 | g_losses = loss_backward(opt, g_losses, optimizer_G, 0) 63 | 64 | loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses)) 65 | 66 | if trainer.end_of_iter(loss_dict, generated + data_list + data_ref_list, model): 67 | break 68 | trainer.end_of_epoch(model) 69 | 70 | def get_data_t(data, n_frames_load, t): 71 | if data is None: return None 72 | if type(data) == list: 73 | return [get_data_t(d, n_frames_load, t) for d in data] 74 | return data[:,t:t+n_frames_load] 75 | 76 | if __name__ == "__main__": 77 | train() -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt -------------------------------------------------------------------------------- /util/distributed.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 
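train.py drives the generator over a long sequence in chunks: get_data_t slices every tensor along the time axis, taking n_frames_load frames starting at offset t, while data_prev carries the previously generated frames between chunks. A small self-contained demonstration of the slicing is below; the function body is copied from train.py above, and the tensor shapes are made up for the example.

import torch

def get_data_t(data, n_frames_load, t):
    if data is None: return None
    if type(data) == list:
        return [get_data_t(d, n_frames_load, t) for d in data]
    return data[:, t:t + n_frames_load]

tgt_image = torch.randn(2, 8, 3, 256, 256)              # 2 sequences of 8 frames each
chunk = get_data_t([tgt_image, None], n_frames_load=2, t=4)
print(chunk[0].shape)                                    # torch.Size([2, 2, 3, 256, 256])
print(chunk[1])                                          # None is passed through untouched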
2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import os 8 | import numpy as np 9 | import random 10 | import functools 11 | import torch 12 | import torch.distributed as dist 13 | 14 | 15 | def init_dist(launcher='pytorch', backend='nccl', **kwargs): 16 | raise ValueError('Distributed training is not fully tested yet and might be unstable. ' 17 | 'If you are confident to run it, please comment out this line.') 18 | if dist.is_initialized(): 19 | return torch.cuda.current_device() 20 | set_random_seed(get_rank()) 21 | rank = int(os.environ['RANK']) 22 | num_gpus = torch.cuda.device_count() 23 | gpu_id = rank % num_gpus 24 | torch.cuda.set_device(gpu_id) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | return gpu_id 27 | 28 | 29 | def get_rank(): 30 | if dist.is_initialized(): 31 | rank = dist.get_rank() 32 | else: 33 | rank = 0 34 | return rank 35 | 36 | 37 | def get_world_size(): 38 | if dist.is_initialized(): 39 | world_size = dist.get_world_size() 40 | else: 41 | world_size = 1 42 | return world_size 43 | 44 | 45 | def master_only(func): 46 | @functools.wraps(func) 47 | def wrapper(*args, **kwargs): 48 | if get_rank() == 0: 49 | return func(*args, **kwargs) 50 | else: 51 | return None 52 | return wrapper 53 | 54 | 55 | def is_master(): 56 | """check if current process is the master""" 57 | return get_rank() == 0 58 | 59 | 60 | @master_only 61 | def master_only_print(*args): 62 | """master-only print""" 63 | print(*args) 64 | 65 | 66 | def dist_reduce_tensor(tensor): 67 | """ Reduce to rank 0 """ 68 | world_size = get_world_size() 69 | if world_size < 2: 70 | return tensor 71 | with torch.no_grad(): 72 | dist.reduce(tensor, dst=0) 73 | if get_rank() == 0: 74 | tensor /= world_size 75 | return tensor 76 | 77 | 78 | def dist_all_reduce_tensor(tensor): 79 | """ Reduce to all ranks """ 80 | world_size = get_world_size() 81 | if world_size < 2: 82 | return tensor 83 | with torch.no_grad(): 84 | dist.all_reduce(tensor) 85 | tensor.div_(world_size) 86 | return tensor 87 | 88 | 89 | def dist_all_gather_tensor(tensor): 90 | """ gather to all ranks """ 91 | world_size = get_world_size() 92 | if world_size < 2: 93 | return [tensor] 94 | tensor_list = [ 95 | torch.ones_like(tensor) for _ in range(dist.get_world_size())] 96 | with torch.no_grad(): 97 | dist.all_gather(tensor_list, tensor) 98 | return tensor_list 99 | 100 | 101 | def set_random_seed(seed): 102 | """Set random seeds for everything. 103 | Inputs: 104 | seed (int): Random seed. 105 | """ 106 | random.seed(seed) 107 | np.random.seed(seed) 108 | torch.manual_seed(seed) 109 | torch.cuda.manual_seed(seed) 110 | torch.cuda.manual_seed_all(seed) 111 | -------------------------------------------------------------------------------- /util/html.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
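The helpers in util/distributed.py follow the usual multi-process pattern: rank 0 (the master) is the only process that prints, visualizes, and saves, while tensors that should be reported consistently are averaged across ranks first. A short sketch of typical use follows; the function names inside it, such as save_checkpoint, are made up for illustration.

import torch
from util.distributed import master_only, is_master, get_world_size, dist_all_reduce_tensor

@master_only
def save_checkpoint(model, path):
    # executes on rank 0 only; on every other rank the decorator returns None
    torch.save(model.state_dict(), path)

def loss_for_logging(loss):
    # average a scalar loss over all processes so every rank logs the same value
    return dist_all_reduce_tensor(loss.detach())

if is_master():
    print('training with %d process(es)' % get_world_size())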
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import datetime 8 | import dominate 9 | from dominate.tags import * 10 | import os 11 | 12 | 13 | class HTML: 14 | def __init__(self, web_dir, title, refresh=0, infer=False): 15 | self.title = title 16 | self.web_dir = web_dir 17 | if not infer: 18 | self.img_dir = os.path.join(self.web_dir, 'images') 19 | else: 20 | self.img_dir = self.web_dir 21 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir): 22 | os.makedirs(self.web_dir) 23 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir): 24 | os.makedirs(self.img_dir) 25 | 26 | self.doc = dominate.document(title=title) 27 | with self.doc: 28 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y")) 29 | if refresh > 0: 30 | with self.doc.head: 31 | meta(http_equiv="refresh", content=str(refresh)) 32 | 33 | def get_image_dir(self): 34 | return self.img_dir 35 | 36 | def add_header(self, str): 37 | with self.doc: 38 | h3(str) 39 | 40 | def add_table(self, border=1): 41 | self.t = table(border=border, style="table-layout: fixed;") 42 | self.doc.add(self.t) 43 | 44 | def add_images(self, ims, txts, links, width=512, height=0): 45 | self.add_table() 46 | with self.t: 47 | with tr(): 48 | for im, txt, link in zip(ims, txts, links): 49 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 50 | with p(): 51 | with a(href=os.path.join('images', link)): 52 | if height != 0: 53 | img(style="width:%dpx;height:%dpx" % (width, height), src=os.path.join('images', im)) 54 | else: 55 | img(style="width:%dpx" % (width), src=os.path.join('images', im)) 56 | br() 57 | p(txt) 58 | 59 | def save(self): 60 | html_file = '%s/index.html' % self.web_dir 61 | f = open(html_file, 'wt') 62 | f.write(self.doc.render()) 63 | f.close() 64 | 65 | 66 | if __name__ == '__main__': 67 | html = HTML('web/', 'test_html') 68 | html.add_header('hello world') 69 | 70 | ims = [] 71 | txts = [] 72 | links = [] 73 | for n in range(4): 74 | ims.append('image_%d.jpg' % n) 75 | txts.append('text_%d' % n) 76 | links.append('image_%d.jpg' % n) 77 | html.add_images(ims, txts, links) 78 | html.save() 79 | -------------------------------------------------------------------------------- /util/image_pool.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved. 2 | # 3 | # This work is made available 4 | # under the Nvidia Source Code License (1-way Commercial). 
5 | # To view a copy of this license, visit 6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt 7 | import random 8 | import torch 9 | from torch.autograd import Variable 10 | class ImagePool(): 11 | def __init__(self, pool_size): 12 | self.pool_size = pool_size 13 | if self.pool_size > 0: 14 | self.num_imgs = 0 15 | self.images = [] 16 | 17 | def query(self, images): 18 | if self.pool_size == 0: 19 | return images 20 | return_images = [] 21 | for image in images.data: 22 | image = torch.unsqueeze(image, 0) 23 | if self.num_imgs < self.pool_size: 24 | self.num_imgs = self.num_imgs + 1 25 | self.images.append(image) 26 | return_images.append(image) 27 | else: 28 | p = random.uniform(0, 1) 29 | if p > 0.5: 30 | random_id = random.randint(0, self.pool_size-1) 31 | tmp = self.images[random_id].clone() 32 | self.images[random_id] = image 33 | return_images.append(tmp) 34 | else: 35 | return_images.append(image) 36 | return_images = Variable(torch.cat(return_images, 0)) 37 | return return_images 38 | --------------------------------------------------------------------------------
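ImagePool keeps a buffer of previously generated images; once the buffer is full, query() returns, with probability 0.5, a randomly chosen older fake in place of the incoming one (storing the new image instead), so the discriminator does not only see the generator's most recent outputs. A minimal usage sketch, not taken from this repository's trainer, with placeholder tensors:

import torch
from util.image_pool import ImagePool

fake_pool = ImagePool(pool_size=50)

def discriminator_inputs(real_images, fake_images):
    # query() may hand the discriminator an older fake from the pool instead of the current one
    return real_images, fake_pool.query(fake_images.detach())

real = torch.randn(4, 3, 256, 256)
fake = torch.randn(4, 3, 256, 256)
real_for_D, fake_for_D = discriminator_inputs(real, fake)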