├── .gitignore
├── License.txt
├── README.md
├── data
│   ├── __init__.py
│   ├── base_data_loader.py
│   ├── base_dataset.py
│   ├── custom_dataset_data_loader.py
│   ├── data_loader.py
│   ├── fewshot_face_dataset.py
│   ├── fewshot_pose_dataset.py
│   ├── fewshot_street_dataset.py
│   ├── image_folder.py
│   ├── keypoint2img.py
│   ├── lmdb_dataset.py
│   └── preprocess
│       ├── download_youTube_playlist.py
│       ├── preprocess.py
│       ├── util
│       │   ├── check_valid.py
│       │   ├── get_poses.py
│       │   ├── track.py
│       │   └── util.py
│       └── youTube_playlist.txt
├── imgs
│   ├── dance.gif
│   ├── face.gif
│   ├── illustration.gif
│   ├── mona_lisa.gif
│   ├── statue.gif
│   └── street.gif
├── models
│   ├── __init__.py
│   ├── base_model.py
│   ├── face_refiner.py
│   ├── flownet.py
│   ├── input_process.py
│   ├── loss_collector.py
│   ├── models.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── architecture.py
│   │   ├── base_network.py
│   │   ├── discriminator.py
│   │   ├── flownet2_pytorch
│   │   │   ├── Dockerfile
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── convert.py
│   │   │   ├── datasets.py
│   │   │   ├── install.sh
│   │   │   ├── launch_docker.sh
│   │   │   ├── losses.py
│   │   │   ├── main.py
│   │   │   ├── models.py
│   │   │   ├── networks
│   │   │   │   ├── FlowNetC.py
│   │   │   │   ├── FlowNetFusion.py
│   │   │   │   ├── FlowNetS.py
│   │   │   │   ├── FlowNetSD.py
│   │   │   │   ├── __init__.py
│   │   │   │   ├── channelnorm_package
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── channelnorm.py
│   │   │   │   │   ├── channelnorm_cuda.cc
│   │   │   │   │   ├── channelnorm_kernel.cu
│   │   │   │   │   ├── channelnorm_kernel.cuh
│   │   │   │   │   └── setup.py
│   │   │   │   ├── correlation_package
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── correlation.py
│   │   │   │   │   ├── correlation_cuda.cc
│   │   │   │   │   ├── correlation_cuda_kernel.cu
│   │   │   │   │   ├── correlation_cuda_kernel.cuh
│   │   │   │   │   └── setup.py
│   │   │   │   ├── resample2d_package
│   │   │   │   │   ├── __init__.py
│   │   │   │   │   ├── resample2d.py
│   │   │   │   │   ├── resample2d_cuda.cc
│   │   │   │   │   ├── resample2d_kernel.cu
│   │   │   │   │   ├── resample2d_kernel.cuh
│   │   │   │   │   └── setup.py
│   │   │   │   └── submodules.py
│   │   │   ├── run-caffe2pytorch.sh
│   │   │   └── utils
│   │   │       ├── __init__.py
│   │   │       ├── flow_utils.py
│   │   │       ├── frame_utils.py
│   │   │       ├── param_utils.py
│   │   │       └── tools.py
│   │   ├── generator.py
│   │   ├── loss.py
│   │   ├── normalization.py
│   │   ├── sync_batchnorm
│   │   │   ├── __init__.py
│   │   │   ├── batchnorm.py
│   │   │   ├── comm.py
│   │   │   ├── replicate.py
│   │   │   └── unittest.py
│   │   └── vgg.py
│   ├── trainer.py
│   └── vid2vid_model.py
├── options
│   ├── __init__.py
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── scripts
│   ├── download_datasets.py
│   ├── download_flownet2.py
│   ├── download_gdrive.py
│   ├── face
│   │   ├── test_256.sh
│   │   ├── test_512.sh
│   │   ├── train_g1_256.sh
│   │   ├── train_g8_256.sh
│   │   └── train_g8_512.sh
│   ├── pose
│   │   ├── test.sh
│   │   ├── train_g1.sh
│   │   └── train_g8.sh
│   └── street
│       ├── test.sh
│       ├── train_g1.sh
│       └── train_g8.sh
├── test.py
├── train.py
└── util
    ├── __init__.py
    ├── distributed.py
    ├── html.py
    ├── image_pool.py
    ├── util.py
    └── visualizer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | debug*
2 | checkpoints/
3 | datasets/
4 | models/debug*
5 | models/networks/flownet2*/networks/*/*egg-info
6 | models/networks/flownet2*/networks/*/build
7 | models/networks/flownet2*/networks/*/__pycache__
8 | models/networks/flownet2*/networks/*/dist
9 | models/networks/flownet2*/*.pth.tar
10 | results/
11 | build/
12 | .idea/
13 | */Thumbs.db
14 | */**/__pycache__
15 | */*.pyc
16 | */**/*.pyc
17 | */**/**/*.pyc
18 | */**/**/**/*.pyc
19 | */**/**/**/**/*.pyc
20 | */*.so*
21 | */**/*.so*
22 | */**/*.dylib*
23 | *.DS_Store
24 | *~
25 |
--------------------------------------------------------------------------------
/License.txt:
--------------------------------------------------------------------------------
1 | Nvidia Source Code License (1-Way Commercial)
2 |
3 | 1. Definitions
4 |
5 | “Licensor” means any person or entity that distributes its Work.
6 | “Software” means the original work of authorship made available under this License.
7 | “Work” means the Software and any additions to or derivative works of the Software that are made available under this License.
8 | “Nvidia Processors” means any central processing unit (CPU), graphics processing unit (GPU), field-programmable gate array (FPGA), application-specific integrated circuit (ASIC) or any combination thereof designed, made, sold, or provided by Nvidia or its affiliates.
9 | The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this License, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
10 | Works, including the Software, are “made available” under this License by including in or with the Work either (a) a copyright notice referencing the applicability of this License to the Work, or (b) a copy of this License.
11 |
12 | 2. License Grants
13 |
14 | 2.1 Copyright Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
15 |
16 | 2.2 Patent Grant. Subject to the terms and conditions of this License, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free patent license to make, have made, use, sell, offer for sale, import, and otherwise transfer its Work, in whole or in part. The foregoing license applies only to the patent claims licensable by Licensor that would be infringed by Licensor’s Work (or portion thereof) individually and excluding any combinations with any other materials or technology.
17 |
18 | 3. Limitations
19 |
20 | 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this License, (b) you include a complete copy of this License with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
21 |
22 | 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this License (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
23 |
24 | 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. The Work or derivative works thereof may be used or intended for use by Nvidia or its affiliates commercially or non-commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
25 |
26 | 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this License from such Licensor (including the grants in Sections 2.1 and 2.2) will terminate immediately.
27 |
28 | 3.5 Trademarks. This License does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this License.
29 |
30 | 3.6 Termination. If you violate any term of this License, then your rights under this License (including the grants in Sections 2.1 and 2.2) will terminate immediately.
31 |
32 | 4. Disclaimer of Warranty.
33 |
34 | THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. SOME STATES’ CONSUMER LAWS DO NOT ALLOW EXCLUSION OF AN IMPLIED WARRANTY, SO THIS DISCLAIMER MAY NOT APPLY TO YOU.
35 |
36 | 5. Limitation of Liability.
37 |
38 | EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
39 |
40 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import importlib
8 | from data.base_dataset import BaseDataset
9 | from util.distributed import master_only_print as print
10 |
11 | def find_dataset_using_name(dataset_name):
12 |     # Given the option --dataset_mode [dataset_name],
13 |     # the file "data/[dataset_name]_dataset.py"
14 |     # will be imported.
15 | dataset_filename = "data." + dataset_name + "_dataset"
16 | datasetlib = importlib.import_module(dataset_filename)
17 |
18 | # In the file, the class called DatasetNameDataset() will
19 | # be instantiated. It has to be a subclass of BaseDataset,
20 | # and it is case-insensitive.
21 | dataset = None
22 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
23 | for name, cls in datasetlib.__dict__.items():
24 | if name.lower() == target_dataset_name.lower() \
25 | and issubclass(cls, BaseDataset):
26 | dataset = cls
27 |
28 | if dataset is None:
29 | print("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
30 | exit(0)
31 |
32 | return dataset
33 |
34 |
35 | def get_option_setter(dataset_name):
36 | dataset_class = find_dataset_using_name(dataset_name)
37 | return dataset_class.modify_commandline_options
38 |
39 |
40 | def create_dataset(opt):
41 | dataset = find_dataset_using_name(opt.dataset_mode)
42 | instance = dataset()
43 | instance.initialize(opt)
44 | print("dataset [%s] was created" % (instance.name()))
45 | return instance
46 |
--------------------------------------------------------------------------------
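
A minimal sketch of how these lookup helpers are typically driven; the `opt` namespace below is a hypothetical stand-in for the options parsed under options/ and is an assumption, not repository code:

    # Hedged sketch; `opt` would normally come from the options parsers.
    from argparse import Namespace
    import data

    opt = Namespace(dataset_mode='fewshot_street')
    dataset_cls = data.find_dataset_using_name(opt.dataset_mode)
    # -> resolves data/fewshot_street_dataset.py and returns FewshotStreetDataset
    option_setter = data.get_option_setter(opt.dataset_mode)
    # -> FewshotStreetDataset.modify_commandline_options, used to add dataset-specific flags
    # dataset = data.create_dataset(opt)  # would also call instance.initialize(opt) with the full options
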
/data/base_data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | class BaseDataLoader():
8 | def __init__(self):
9 | pass
10 |
11 | def initialize(self, opt):
12 | self.opt = opt
13 | pass
14 |
15 |     def load_data(self):
16 | return None
17 |
18 |
19 |
20 |
--------------------------------------------------------------------------------
/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | ### Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
2 | ### Licensed under the Nvidia Source Code License.
3 |
4 | from PIL import Image
5 | import numpy as np
6 | import random
7 | import torch
8 | import torch.utils.data as data
9 | import torchvision.transforms as transforms
10 | from util.distributed import master_only_print as print
11 |
12 | class BaseDataset(data.Dataset):
13 | def __init__(self):
14 | super(BaseDataset, self).__init__()
15 | self.L = self.I = self.Lr = self.Ir = None
16 | self.n_frames_total = 1 # current number of frames to train in a single iteration
17 | self.use_lmdb = False
18 |
19 | def name(self):
20 | return 'BaseDataset'
21 |
22 | def update_training_batch(self, ratio):
23 | # update the training sequence length to be longer
24 | seq_len_max = 30
25 | if self.n_frames_total < seq_len_max:
26 | self.n_frames_total = min(seq_len_max, self.opt.n_frames_total * (2**ratio))
27 | print('--- Updating training sequence length to %d ---' % self.n_frames_total)
28 |
29 | def read_data(self, path, lmdb=None, data_type='img'):
30 | is_img = data_type == 'img'
31 | if self.use_lmdb and lmdb is not None:
32 | img, _ = lmdb.getitem_by_path(path.encode(), is_img)
33 | if is_img and len(img.mode) == 3:
34 | b, g, r = img.split()
35 | img = Image.merge("RGB", (r, g, b))
36 | elif data_type == 'np':
37 | img = img.decode()
38 | img = np.array([[int(j) for j in i.split(',')] for i in img.splitlines()])
39 | elif is_img:
40 | img = Image.open(path)
41 | elif data_type == 'np':
42 | img = np.loadtxt(path, delimiter=',')
43 | else:
44 | img = path
45 | return img
46 |
47 | def crop(self, img, coords):
48 | min_y, max_y, min_x, max_x = coords
49 | if isinstance(img, np.ndarray):
50 | return img[min_y:max_y, min_x:max_x]
51 | else:
52 | return img.crop((min_x, min_y, max_x, max_y))
53 |
54 | def concat_frame(self, A, Ai, n=100):
55 | if A is None or Ai.shape[0] >= n: return Ai[-n:]
56 | else: return torch.cat([A, Ai])[-n:]
57 |
58 | def concat(self, tensors, dim=0):
59 | tensors = [t for t in tensors if t is not None]
60 | return torch.cat(tensors, dim)
61 |
62 | def get_img_params(opt, size):
63 | w, h = size
64 | new_w, new_h = w, h
65 |
66 | # resize input image
67 | if 'resize' in opt.resize_or_crop:
68 | new_h = new_w = opt.loadSize
69 | else:
70 | if 'scale_width' in opt.resize_or_crop:
71 | new_w = opt.loadSize
72 | elif 'random_scale' in opt.resize_or_crop:
73 | new_w = random.randrange(int(opt.fineSize), int(1.2*opt.fineSize))
74 | new_h = int(new_w * h) // w
75 | if 'crop' not in opt.resize_or_crop:
76 | new_h = int(new_w // opt.aspect_ratio)
77 | new_w = new_w // 4 * 4
78 | new_h = new_h // 4 * 4
79 |
80 | # crop resized image
81 | size_x = min(opt.loadSize, opt.fineSize)
82 | size_y = size_x // opt.aspect_ratio
83 | if not opt.isTrain: # crop central region
84 | pos_x = (new_w - size_x) // 2
85 | pos_y = (new_h - size_y) // 2
86 | else: # crop random region
87 | pos_x = random.randrange(np.maximum(1, new_w - size_x))
88 | pos_y = random.randrange(np.maximum(1, new_h - size_y))
89 |
90 | # for color augmentation
91 | h_b = random.uniform(-30, 30)
92 | s_a = random.uniform(0.8, 1.2)
93 | s_b = random.uniform(-10, 10)
94 | v_a = random.uniform(0.8, 1.2)
95 | v_b = random.uniform(-10, 10)
96 |
97 | flip = random.random() > 0.5
98 | return {'new_size': (new_w, new_h), 'crop_pos': (pos_x, pos_y), 'crop_size': (size_x, size_y), 'flip': flip,
99 | 'color_aug': (h_b, s_a, s_b, v_a, v_b)}
100 |
101 | def get_video_params(opt, n_frames_total, cur_seq_len, index):
102 | if opt.isTrain:
103 | n_frames_total = min(cur_seq_len, n_frames_total) # total number of frames to load
104 | max_t_step = min(opt.max_t_step, (cur_seq_len-1) // max(1, (n_frames_total-1)))
105 | t_step = random.randrange(max_t_step) + 1 # spacing between neighboring sampled frames
106 |
107 | offset_max = max(1, cur_seq_len - (n_frames_total-1)*t_step) # maximum possible frame index for the first frame
108 | if 'pose' in opt.dataset_mode:
109 | start_idx = index % offset_max # offset for the first frame to load
110 | max_range, min_range = 60, 14 # range for possible reference frames
111 | else:
112 | start_idx = random.randrange(offset_max) # offset for the first frame to load
113 | max_range, min_range = 300, 14 # range for possible reference frames
114 |
115 | ref_range = list(range(max(0, start_idx - max_range), max(1, start_idx - min_range))) \
116 | + list(range(min(start_idx + min_range, cur_seq_len - 1), min(start_idx + max_range, cur_seq_len)))
117 | ref_indices = random.sample(ref_range, opt.n_shot) # indices for reference frames
118 |
119 | else:
120 | n_frames_total = 1
121 | start_idx = index
122 | t_step = 1
123 | ref_indices = opt.ref_img_id.split(',')
124 | ref_indices = [int(i) for i in ref_indices]
125 |
126 | return n_frames_total, start_idx, t_step, ref_indices
127 |
128 | def get_transform(opt, params, method=Image.BICUBIC, normalize=True, toTensor=True, color_aug=False):
129 | transform_list = []
130 | transform_list.append(transforms.Lambda(lambda img: __scale_image(img, params['new_size'], method)))
131 |
132 | if 'crop' in opt.resize_or_crop:
133 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], params['crop_size'])))
134 |
135 | if opt.isTrain and color_aug:
136 | transform_list.append(transforms.Lambda(lambda img: __color_aug(img, params['color_aug'])))
137 |
138 | if opt.isTrain and not opt.no_flip:
139 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
140 |
141 | if toTensor:
142 | transform_list += [transforms.ToTensor()]
143 |
144 | if normalize:
145 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5),
146 | (0.5, 0.5, 0.5))]
147 | return transforms.Compose(transform_list)
148 |
149 | def __scale_image(img, size, method=Image.BICUBIC):
150 | w, h = size
151 | return img.resize((w, h), method)
152 |
153 | def __crop(img, pos, size):
154 | ow, oh = img.size
155 | x1, y1 = pos
156 | tw, th = size
157 | return img.crop((x1, y1, x1 + tw, y1 + th))
158 |
159 | def __flip(img, flip):
160 | if flip:
161 | return img.transpose(Image.FLIP_LEFT_RIGHT)
162 | return img
163 |
164 | def __color_aug(img, params):
165 | h, s, v = img.convert('HSV').split()
166 | h = h.point(lambda i: (i + params[0]) % 256)
167 | s = s.point(lambda i: min(255, max(0, i * params[1] + params[2])))
168 | v = v.point(lambda i: min(255, max(0, i * params[3] + params[4])))
169 | img = Image.merge('HSV', (h, s, v)).convert('RGB')
170 | return img
--------------------------------------------------------------------------------
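
A minimal sketch of the augmentation pipeline built by get_img_params and get_transform, assuming option fields with the names referenced above and a hypothetical input frame:

    # Hedged sketch; the opt fields and the frame path are illustrative assumptions.
    from argparse import Namespace
    from PIL import Image
    from data.base_dataset import get_img_params, get_transform

    opt = Namespace(resize_or_crop='random_scale_and_crop', loadSize=512, fineSize=512,
                    aspect_ratio=2, isTrain=True, no_flip=False)
    img = Image.open('frame000000.jpg')                     # hypothetical input frame
    params = get_img_params(opt, img.size)                  # one set of params, reused for a whole clip
    transform = get_transform(opt, params, color_aug=True)  # scale -> crop -> color aug -> flip -> tensor -> normalize
    tensor = transform(img)                                 # 3 x H x W tensor roughly in [-1, 1]
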
/data/custom_dataset_data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch.utils.data
8 | import torch.distributed as dist
9 | from data.base_data_loader import BaseDataLoader
10 | import data
11 |
12 | class CustomDatasetDataLoader(BaseDataLoader):
13 | def name(self):
14 | return 'CustomDatasetDataLoader'
15 |
16 | def initialize(self, opt):
17 | BaseDataLoader.initialize(self, opt)
18 | self.dataset = data.create_dataset(opt)
19 | if dist.is_initialized():
20 | sampler = torch.utils.data.distributed.DistributedSampler(self.dataset)
21 | else:
22 | sampler = None
23 |
24 | self.dataloader = torch.utils.data.DataLoader(
25 | self.dataset,
26 | batch_size=opt.batchSize,
27 | shuffle=(sampler is None) and not opt.serial_batches,
28 | sampler=sampler,
29 | pin_memory=True,
30 | num_workers=int(opt.nThreads),
31 | drop_last=True
32 | )
33 |
34 | def load_data(self):
35 | return self.dataloader
36 |
37 | def __len__(self):
38 | size = min(len(self.dataset), self.opt.max_dataset_size)
39 | ngpus = len(self.opt.gpu_ids)
40 | round_to_ngpus = (size // ngpus) * ngpus
41 | return round_to_ngpus
42 |
--------------------------------------------------------------------------------
/data/data_loader.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | from util.distributed import master_only_print as print
8 | def CreateDataLoader(opt):
9 | from data.custom_dataset_data_loader import CustomDatasetDataLoader
10 | data_loader = CustomDatasetDataLoader()
11 | print(data_loader.name())
12 | data_loader.initialize(opt)
13 | return data_loader
14 |
--------------------------------------------------------------------------------
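
A minimal sketch of how CreateDataLoader is consumed on the training side; the TrainOptions().parse() entry point and the batch keys are assumptions based on options/ and the few-shot datasets above:

    # Hedged sketch; assumes TrainOptions().parse() as the options entry point (see options/).
    from options.train_options import TrainOptions
    from data.data_loader import CreateDataLoader

    opt = TrainOptions().parse()
    data_loader = CreateDataLoader(opt)
    dataset = data_loader.load_data()                       # a torch.utils.data.DataLoader
    for data in dataset:
        tgt_label, tgt_image = data['tgt_label'], data['tgt_image']
        ref_label, ref_image = data['ref_label'], data['ref_image']
        # forward / backward passes would go here (see train.py)
        break
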
/data/fewshot_street_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os.path as path
8 | import glob
9 | import torch
10 | from PIL import Image
11 | import numpy as np
12 |
13 | from data.base_dataset import BaseDataset, get_img_params, get_video_params, get_transform
14 | from data.image_folder import make_dataset, make_grouped_dataset, check_path_valid
15 |
16 | class FewshotStreetDataset(BaseDataset):
17 | @staticmethod
18 | def modify_commandline_options(parser, is_train):
19 | parser.set_defaults(dataroot='datasets/street/')
20 | parser.add_argument('--label_nc', type=int, default=20, help='# of input label channels')
21 | parser.add_argument('--input_nc', type=int, default=3, help='# of input image channels')
22 | parser.add_argument('--aspect_ratio', type=float, default=2)
23 | parser.set_defaults(resize_or_crop='random_scale_and_crop')
24 | parser.set_defaults(niter=20)
25 | parser.set_defaults(niter_single=10)
26 | parser.set_defaults(niter_step=2)
27 | parser.set_defaults(save_epoch_freq=1)
28 |
29 | ### for inference
30 | parser.add_argument('--seq_path', type=str, default='datasets/street/test_images/01/', help='path to the driving sequence')
31 | parser.add_argument('--ref_img_path', type=str, default='datasets/street/test_images/02/', help='path to the reference image')
32 | parser.add_argument('--ref_img_id', type=str, default='0', help='indices of reference frames')
33 | return parser
34 |
35 | def initialize(self, opt):
36 | self.opt = opt
37 | root = opt.dataroot
38 | self.L_is_label = self.opt.label_nc != 0
39 |
40 | if opt.isTrain:
41 | self.L_paths = sorted(make_grouped_dataset(path.join(root, 'train_labels')))
42 | self.I_paths = sorted(make_grouped_dataset(path.join(root, 'train_images')))
43 | check_path_valid(self.L_paths, self.I_paths)
44 |
45 | self.n_of_seqs = len(self.L_paths)
46 | print('%d sequences' % self.n_of_seqs)
47 | else:
48 | self.I_paths = sorted(make_dataset(opt.seq_path))
49 | self.L_paths = sorted(make_dataset(opt.seq_path.replace('images', 'labels')))
50 | self.ref_I_paths = sorted(make_dataset(opt.ref_img_path))
51 | self.ref_L_paths = sorted(make_dataset(opt.ref_img_path.replace('images', 'labels')))
52 |
53 | def __getitem__(self, index):
54 | opt = self.opt
55 | if opt.isTrain:
56 | L_paths = self.L_paths[index % self.n_of_seqs]
57 | I_paths = self.I_paths[index % self.n_of_seqs]
58 | ref_L_paths, ref_I_paths = L_paths, I_paths
59 | else:
60 | L_paths, I_paths = self.L_paths, self.I_paths
61 | ref_L_paths, ref_I_paths = self.ref_L_paths, self.ref_I_paths
62 |
63 |
64 | ### setting parameters
65 | n_frames_total, start_idx, t_step, ref_indices = get_video_params(opt, self.n_frames_total, len(I_paths), index)
66 | w, h = opt.fineSize, int(opt.fineSize / opt.aspect_ratio)
67 | img_params = get_img_params(opt, (w, h))
68 | is_first_frame = opt.isTrain or index == 0
69 |
70 | transform_I = get_transform(opt, img_params, color_aug=opt.isTrain)
71 | transform_L = get_transform(opt, img_params, method=Image.NEAREST, normalize=False) if self.L_is_label else transform_I
72 |
73 |
74 | ### read in reference image
75 | Lr, Ir = self.Lr, self.Ir
76 | if is_first_frame:
77 | for idx in ref_indices:
78 | Li = self.get_image(ref_L_paths[idx], transform_L, is_label=self.L_is_label)
79 | Ii = self.get_image(ref_I_paths[idx], transform_I)
80 | Lr = self.concat_frame(Lr, Li.unsqueeze(0))
81 | Ir = self.concat_frame(Ir, Ii.unsqueeze(0))
82 |
83 | if not opt.isTrain: # keep track of non-changing variables during inference
84 | self.Lr, self.Ir = Lr, Ir
85 |
86 |
87 | ### read in target images
88 | L, I = self.L, self.I
89 | for t in range(n_frames_total):
90 | idx = start_idx + t * t_step
91 | Lt = self.get_image(L_paths[idx], transform_L, is_label=self.L_is_label)
92 | It = self.get_image(I_paths[idx], transform_I)
93 | L = self.concat_frame(L, Lt.unsqueeze(0))
94 | I = self.concat_frame(I, It.unsqueeze(0))
95 |
96 | if not opt.isTrain:
97 | self.L, self.I = L, I
98 |
99 | seq = path.basename(path.dirname(opt.ref_img_path)) + '-' + opt.ref_img_id + '_' + path.basename(path.dirname(opt.seq_path))
100 |
101 | return_list = {'tgt_label': L, 'tgt_image': I, 'ref_label': Lr, 'ref_image': Ir,
102 | 'path': I_paths[idx], 'seq': seq}
103 | return return_list
104 |
105 | def get_image(self, A_path, transform_scaleA, is_label=False):
106 | if is_label: return self.get_label_tensor(A_path, transform_scaleA)
107 | A_img = self.read_data(A_path)
108 | A_scaled = transform_scaleA(A_img)
109 | return A_scaled
110 |
111 | def get_label_tensor(self, label_path, transform_label):
112 | label = self.read_data(label_path).convert('L')
113 |
114 | train2eval = self.opt.label_nc == 20
115 | if train2eval:
116 | ### 35 to 20
117 | A_label_np = np.array(label)
118 | label_mapping = np.array([19, 19, 19, 19, 19, 19, 19, 0, 1, 19, 19, 2, 3, 4, 19, 19, 19, 5, 19,
119 | 6, 7, 8, 9, 18, 10, 11, 12, 13, 14, 19, 19, 15, 16, 17, 19], dtype=np.uint8)
120 | A_label_np = label_mapping[A_label_np]
121 | label = Image.fromarray(A_label_np)
122 |
123 | label_tensor = transform_label(label) * 255.0
124 | return label_tensor
125 |
126 | def __len__(self):
127 | if not self.opt.isTrain: return len(self.L_paths)
128 | return max(10000, sum([len(L) for L in self.L_paths]))
129 |
130 | def name(self):
131 | return 'StreetDataset'
--------------------------------------------------------------------------------
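
A small worked example of the 35-to-20 class remapping applied in get_label_tensor when label_nc == 20; the raw label values below are hypothetical Cityscapes label IDs:

    # Hedged sketch of the lookup-table remapping used above.
    import numpy as np

    label_mapping = np.array([19, 19, 19, 19, 19, 19, 19, 0, 1, 19, 19, 2, 3, 4, 19, 19, 19, 5, 19,
                              6, 7, 8, 9, 18, 10, 11, 12, 13, 14, 19, 19, 15, 16, 17, 19], dtype=np.uint8)
    raw_label = np.array([[7, 8, 26],
                          [11, 13, 33]], dtype=np.uint8)    # hypothetical raw label IDs
    print(label_mapping[raw_label])                         # -> [[0 1 12] [2 4 17]]
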
/data/image_folder.py:
--------------------------------------------------------------------------------
1 | ###############################################################################
2 | # Code from
3 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
4 | # Modified the original code so that it also loads images from the current
5 | # directory as well as the subdirectories
6 | ###############################################################################
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import os
10 |
11 | IMG_EXTENSIONS = [
12 | '.jpg', '.JPG', '.jpeg', '.JPEG',
13 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', '.tiff', '.webp',
14 | '.txt', '.json',
15 | ]
16 |
17 |
18 | def is_image_file(filename):
19 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
20 |
21 | def make_dataset_rec(dir, images):
22 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
23 |
24 | for root, dnames, fnames in sorted(os.walk(dir, followlinks=True)):
25 | for fname in fnames:
26 | if is_image_file(fname):
27 | path = os.path.join(root, fname)
28 | images.append(path)
29 | # for dname in dnames:
30 | # make_dataset_rec(os.path.join(root, dname), images)
31 |
32 |
33 | def make_dataset(dir, recursive=False, read_cache=False, write_cache=False):
34 | images = []
35 |
36 | if read_cache:
37 | possible_filelist = os.path.join(dir, 'files.list')
38 | if os.path.isfile(possible_filelist):
39 | with open(possible_filelist, 'r') as f:
40 | images = f.read().splitlines()
41 | return images
42 |
43 | if recursive:
44 | make_dataset_rec(dir, images)
45 | else:
46 | #assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
47 |
48 | for root, dnames, fnames in sorted(os.walk(dir)):
49 | for fname in fnames:
50 | if is_image_file(fname):
51 | path = os.path.join(root, fname)
52 | images.append(path)
53 |
54 | if write_cache:
55 | filelist_cache = os.path.join(dir, 'files.list')
56 | with open(filelist_cache, 'w') as f:
57 | for path in images:
58 | f.write("%s\n" % path)
59 | print('wrote filelist cache at %s' % filelist_cache)
60 |
61 | return images
62 |
63 | def make_grouped_dataset(dir):
64 | images = []
65 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
66 | fnames = sorted(os.walk(dir))
67 | for fname in sorted(fnames):
68 | paths = []
69 | root = fname[0]
70 | for f in sorted(fname[2]):
71 | if is_image_file(f):
72 | paths.append(os.path.join(root, f))
73 | if len(paths) > 0:
74 | images.append(paths)
75 | return images
76 |
77 | def check_path_valid(A_paths, B_paths):
78 | if len(A_paths) != len(B_paths):
79 | print('%s not equal to %s' % (A_paths[0], B_paths[0]))
80 | assert(len(A_paths) == len(B_paths))
81 |
82 | if isinstance(A_paths[0], list):
83 | for a, b in zip(A_paths, B_paths):
84 | if len(a) != len(b):
85 | print('%s not equal to %s' % (a[0], b[0]))
86 | assert(len(a) == len(b))
87 |
88 | def default_loader(path):
89 | return Image.open(path).convert('RGB')
90 |
91 |
92 | class ImageFolder(data.Dataset):
93 |
94 | def __init__(self, root, transform=None, return_paths=False,
95 | loader=default_loader):
96 | imgs = make_dataset(root)
97 | if len(imgs) == 0:
98 | raise(RuntimeError("Found 0 images in: " + root + "\n"
99 | "Supported image extensions are: " +
100 | ",".join(IMG_EXTENSIONS)))
101 |
102 | self.root = root
103 | self.imgs = imgs
104 | self.transform = transform
105 | self.return_paths = return_paths
106 | self.loader = loader
107 |
108 | def __getitem__(self, index):
109 | path = self.imgs[index]
110 | img = self.loader(path)
111 | if self.transform is not None:
112 | img = self.transform(img)
113 | if self.return_paths:
114 | return img, path
115 | else:
116 | return img
117 |
118 | def __len__(self):
119 | return len(self.imgs)
120 |
--------------------------------------------------------------------------------
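
A minimal sketch of the grouped listing helpers as the datasets above use them; the directory paths are hypothetical:

    # Hedged sketch; directories must exist on disk for this to run.
    from data.image_folder import make_grouped_dataset, check_path_valid

    label_seqs = make_grouped_dataset('datasets/street/train_labels')  # [[seq0 frames...], [seq1 frames...], ...]
    image_seqs = make_grouped_dataset('datasets/street/train_images')
    check_path_valid(label_seqs, image_seqs)   # asserts matching sequence and per-sequence frame counts
    print(len(label_seqs), 'sequences,', len(label_seqs[0]), 'frames in the first one')
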
/data/lmdb_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import lmdb
9 | import pickle
10 | import numpy as np
11 | from PIL import Image
12 | import cv2
13 | import torch.utils.data as data
14 | from util.distributed import master_only_print as print
15 |
16 | class LMDBDataset(data.Dataset):
17 | def __init__(self, root, write_cache=False):
18 | self.root = os.path.expanduser(root)
19 | self.env = lmdb.open(root, max_readers=126, readonly=True, lock=False,
20 | readahead=False, meminit=False)
21 | with self.env.begin(write=False) as txn:
22 | self.length = txn.stat()['entries']
23 | print('LMDB file at %s opened.' % root)
24 | cache_file = os.path.join(root, '_cache_')
25 | if os.path.isfile(cache_file):
26 | self.keys = pickle.load(open(cache_file, "rb"))
27 | elif write_cache:
28 | print('generating keys')
29 | with self.env.begin(write=False) as txn:
30 | self.keys = [key for key, _ in txn.cursor()]
31 | pickle.dump(self.keys, open(cache_file, "wb"))
32 | print('cache file generated at %s' % cache_file)
33 | else:
34 | self.keys = []
35 |
36 | def getitem_by_path(self, path, is_img=True):
37 | env = self.env
38 | with env.begin(write=False) as txn:
39 | buf = txn.get(path)
40 | if is_img:
41 |             img = cv2.imdecode(np.frombuffer(buf, dtype=np.uint8), 1)
42 | img = Image.fromarray(img)
43 | return img, path
44 | return buf, path
45 |
46 | def __getitem__(self, index):
47 | path = self.keys[index]
48 | return self.getitem_by_path(path)
49 |
50 | def __len__(self):
51 | return self.length
52 |
--------------------------------------------------------------------------------
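
A minimal sketch of direct LMDBDataset access; the .lmdb path and the key layout (keys are the encoded original file paths, as implied by BaseDataset.read_data) are assumptions:

    # Hedged sketch; the database path and key are hypothetical.
    from data.lmdb_dataset import LMDBDataset

    db = LMDBDataset('datasets/street/train_images.lmdb')
    img, key = db.getitem_by_path('train_images/01/frame000000.jpg'.encode())
    # img is a PIL image decoded with cv2, so channels are in BGR order;
    # BaseDataset.read_data swaps them back to RGB.
    print(img.size, key)
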
/data/preprocess/download_youTube_playlist.py:
--------------------------------------------------------------------------------
1 | import glob
2 | from pytube import YouTube, Playlist
3 |
4 |
5 | # Example script to download the videos in the YouTube playlists listed below.
6 | class MyPlaylist(Playlist):
7 | def download_all(
8 | self,
9 | download_path=None,
10 | prefix_number=True,
11 | reverse_numbering=False,
12 | idx=0,
13 | ):
14 | self.populate_video_urls()
15 | prefix_gen = self._path_num_prefix_generator(reverse_numbering)
16 | for i, link in enumerate(self.video_urls):
17 | prefix = '%03d_%03d_' % (idx + 1, i + 1)
18 | p = glob.glob(prefix[:-1] + '*.mp4')
19 | print(prefix, link)
20 | if not p:
21 | try:
22 | yt = YouTube(link)
23 | dl_stream = yt.streams.filter(adaptive=True, subtype='mp4').first()
24 | dl_stream.download(download_path, filename_prefix=prefix)
25 |             except Exception:
26 |                 print('cannot download %s' % link)
27 | pass
28 |
29 | playlist_path = 'youTube_playlist.txt'
30 | with open(playlist_path, 'r') as f:
31 | playlists = f.read().splitlines()
32 |
33 | for i, playlist in enumerate(playlists):
34 | pl = MyPlaylist(playlist)
35 | pl.download_all(idx=i)
36 |
--------------------------------------------------------------------------------
/data/preprocess/preprocess.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import glob
9 | import os.path as path
10 | import argparse
11 | import json
12 | from tqdm import tqdm
13 |
14 | from util.get_poses import extract_all_frames, run_densepose, run_openpose
15 | from util.check_valid import remove_invalid_frames, remove_static_frames, \
16 | remove_isolated_frames, check_densepose_exists
17 | from util.track import divide_sequences
18 | from util.util import remove_folder
19 |
20 |
21 | def parse_args():
22 | parser = argparse.ArgumentParser()
23 | parser.add_argument('--steps', default='all',
24 | help='all | extract_frames | openpose | densepose | clean | divide_sequences')
25 | parser.add_argument('--video_root', required=True,
26 | help='path for videos to process')
27 | parser.add_argument('--output_root', required=True,
28 | help='path for output images')
29 |
30 | parser.add_argument('--img_folder', default='images')
31 | parser.add_argument('--openpose_folder', default='openpose')
32 | parser.add_argument('--openpose_postfix', default='_keypoints.json')
33 | parser.add_argument('--densepose_folder', default='densepose')
34 | parser.add_argument('--densepose_postfix', default='_IUV.png')
35 | parser.add_argument('--densemask_folder', default='densemask')
36 | parser.add_argument('--densemask_postfix', default='_INDS.png')
37 | parser.add_argument('--track_folder', default='tracking')
38 |
39 | parser.add_argument('--openpose_root', default='/',
40 | help='root for the OpenPose library')
41 | parser.add_argument('--densepose_root', default='/',
42 | help='root for the DensePose library')
43 |
44 |     parser.add_argument('--n_skip_frames', type=int, default=100,
45 | help='Number of frames between keyframes. A larger '
46 | 'number can expedite processing but may lose data')
47 |     parser.add_argument('--min_n_of_frames', type=int, default=30,
48 | help='Minimum number of frames in the output sequence.')
49 |
50 | args = parser.parse_args()
51 | return args
52 |
53 |
54 | def rename_videos(video_root):
55 | video_paths = sorted(glob.glob(video_root + '/*.mp4'))
56 | for i, video_path in enumerate(video_paths):
57 | new_path = video_root + ('/%05d.mp4' % i)
58 | os.rename(video_path, new_path)
59 |
60 |
61 | # Remove frames that are not suitable for training.
62 | def remove_unusable_frames(args, video_idx):
63 | remove_invalid_frames(args, video_idx)
64 | check_densepose_exists(args, video_idx)
65 | remove_static_frames(args, video_idx)
66 | remove_isolated_frames(args, video_idx)
67 | video_path = path.join(args.output_root, args.img_folder, video_idx)
68 | if len(os.listdir(video_path)) == 0:
69 | remove_folder(args, video_idx)
70 |
71 |
72 | if __name__ == "__main__":
73 | args = parse_args()
74 | if args.steps == 'all':
75 | args.steps = 'openpose,densepose,clean,divide_sequences'
76 |
77 | if 'extract_frames' in args.steps or 'openpose' in args.steps or \
78 | 'densepose' in args.steps:
79 | rename_videos(args.video_root)
80 | video_paths = sorted(glob.glob(args.video_root + '/*.mp4'))
81 | for video_path in tqdm(video_paths):
82 | if 'extract_frames' in args.steps:
83 | # Only extract frames from all the videos.
84 | extract_all_frames(args, video_path)
85 | if 'openpose' in args.steps:
86 | # Include extracting frames and running OpenPose.
87 | run_openpose(args, video_path)
88 | if 'densepose' in args.steps:
89 | # Run DensePose.
90 | video_idx = path.basename(video_path).split('.')[0]
91 | run_densepose(args, video_idx)
92 |
93 | if 'clean' in args.steps:
94 |         # Frames have already been extracted and OpenPose / DensePose have already
95 |         # been run; only remove the unusable frames in the dataset.
96 | # Note that the folder structure should be
97 | # [output_root] / [img_folder] / [sequences] / [frames], and the
98 | # names of the frames must be in format of
99 | # 'frame000000.jpg', 'frame000001.jpg', ...
100 | video_indices = sorted(glob.glob(path.join(args.output_root,
101 | args.img_folder, '*')))
102 | video_indices = [path.basename(p) for p in video_indices]
103 |
104 | # Remove all unusable frames in the sequences.
105 | for i, video_idx in enumerate(tqdm(video_indices)):
106 | remove_unusable_frames(args, video_idx)
107 |
108 | if 'divide_sequences' in args.steps:
109 | # Finally, divide the remaining sequences into sub-sequences, where
110 |     # each sub-sequence only contains one person.
111 | video_indices = sorted(glob.glob(path.join(args.output_root,
112 | args.img_folder, '*')))
113 | video_indices = [path.basename(p) for p in video_indices]
114 | seq_indices = []
115 | start_frame_indices, end_frame_indices, ppl_indices = [], [], []
116 | for i, video_idx in enumerate(tqdm(video_indices)):
117 | start_frame_indices_i, end_frame_indices_i, ppl_indices_i = \
118 | divide_sequences(args, video_idx)
119 | seq_indices += [i] * len(start_frame_indices_i)
120 | start_frame_indices += start_frame_indices_i
121 | end_frame_indices += end_frame_indices_i
122 | ppl_indices += ppl_indices_i
123 |
124 | output = dict()
125 | output['seq_indices'] = seq_indices
126 | output['start_frame_indices'] = start_frame_indices
127 | output['end_frame_indices'] = end_frame_indices
128 | output['ppl_indices'] = ppl_indices
129 | output_path = path.join(args.output_root, 'all_subsequences.json')
130 | with open(output_path, 'w') as fp:
131 | json.dump(output, fp, indent=4)
132 |
--------------------------------------------------------------------------------
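
A minimal sketch of reading back the sub-sequence index written by the divide_sequences step; the output_root value is hypothetical:

    # Hedged sketch; mirrors the dict written with json.dump above.
    import json
    import os.path as path

    output_root = 'datasets/pose'                           # hypothetical --output_root
    with open(path.join(output_root, 'all_subsequences.json')) as fp:
        subseq = json.load(fp)
    for seq, start, end, ppl in zip(subseq['seq_indices'], subseq['start_frame_indices'],
                                    subseq['end_frame_indices'], subseq['ppl_indices']):
        print('video %d: frames %d-%d, person %d' % (seq, start, end, ppl))
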
/data/preprocess/util/check_valid.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import glob
9 | import os.path as path
10 | import json
11 |
12 | from util.util import remove_frame, get_keypoint_array, get_frame_idx, \
13 | get_valid_openpose_keypoints
14 |
15 |
16 | # Remove invalid frames in the video.
17 | def remove_invalid_frames(args, video_idx):
18 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx)
19 | json_paths = sorted(glob.glob(op_dir + '/*.json'))
20 | for json_path in json_paths:
21 | if not is_valid_frame(args, json_path):
22 | remove_frame(args, start=json_path)
23 |
24 |
25 | # Remove static frames from the video if no motion is detected for more than
26 | # max_static_frames consecutive frames.
27 | def remove_static_frames(args, video_idx):
28 | max_static_frames = 5 # maximum number of frames to be static
29 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx)
30 | json_paths = sorted(glob.glob(op_dir + '/*.json'))
31 | start_idx = end_idx = 0
32 | keypoint_dicts_prev = None
33 |
34 | for json_path in json_paths:
35 | with open(json_path, encoding='utf-8') as f:
36 | keypoint_dicts = json.loads(f.read())["people"]
37 | is_moving = detect_motion(keypoint_dicts_prev, keypoint_dicts)
38 | keypoint_dicts_prev = keypoint_dicts
39 |
40 | i = get_frame_idx(json_path)
41 | if not is_moving:
42 | end_idx = i
43 | else:
44 | # If static frames longer than max_static_frames, remove them.
45 | if (end_idx - start_idx) > max_static_frames:
46 | remove_frame(args, video_idx, start_idx, end_idx)
47 | start_idx = end_idx = i
48 |
49 |
50 | # Remove short, isolated blocks of frames whose number of consecutive frames is
51 | # smaller than min_n_of_frames.
52 | def remove_isolated_frames(args, video_idx):
53 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx)
54 | json_paths = sorted(glob.glob(op_dir + '/*.json'))
55 |
56 | if len(json_paths):
57 | start_idx = end_idx = get_frame_idx(json_paths[0]) - 1
58 | for json_path in json_paths:
59 | i = get_frame_idx(json_path)
60 | # If the frames are not consecutive, there's a breakpoint.
61 | if i != end_idx + 1:
62 | # Check if this block of frames is longer than min_n_of_frames.
63 | if (end_idx - start_idx) < args.min_n_of_frames:
64 | remove_frame(args, video_idx, start_idx, end_idx)
65 | start_idx = i
66 | end_idx = i
67 | # Need to check again at the end of sequence.
68 | if (end_idx - start_idx) < args.min_n_of_frames:
69 | remove_frame(args, video_idx, start_idx, end_idx)
70 |
71 |
72 | # Detect if motion exists between consecutive frames.
73 | def detect_motion(keypoint_dicts_1, keypoint_dicts_2):
74 | motion_thre = 5 # minimum position difference to count as motion
75 |     # If it's the first frame or the number of people in the two frames
76 |     # is different, return true.
77 | if keypoint_dicts_1 is None:
78 | return True
79 | if len(keypoint_dicts_1) != len(keypoint_dicts_2):
80 | return True
81 |
82 |     # If the pose difference between the two frames is larger than the threshold,
83 |     # return true.
84 | for keypoint_dict_1, keypoint_dict_2 in zip(keypoint_dicts_1, keypoint_dicts_2):
85 | pose_pts1, pose_pts2 = get_keypoint_array([keypoint_dict_1, keypoint_dict_2])
86 | if ((abs(pose_pts1 - pose_pts2) > motion_thre) &
87 | (pose_pts1 != 0) & (pose_pts2 != 0)).any():
88 | return True
89 | return False
90 |
91 |
92 | # If densepose did not find any person in the frame (and thus outputs nothing),
93 | # remove the frame from the dataset.
94 | def check_densepose_exists(args, video_idx):
95 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx)
96 | json_paths = sorted(glob.glob(op_dir + '/*.json'))
97 | for json_path in json_paths:
98 | dp_path = json_path.replace(args.openpose_folder, args.densepose_folder)
99 | dp_path = dp_path.replace(args.openpose_postfix, args.densepose_postfix)
100 | if not os.path.exists(dp_path):
101 | remove_frame(args, start=json_path)
102 |
103 |
104 | # Check if the frame is valid to use.
105 | def is_valid_frame(args, img_path):
106 | if img_path.endswith('.jpg'):
107 | img_path = img_path.replace(args.img_folder, args.openpose_folder)
108 | img_path = img_path.replace('.jpg', args.openpose_postfix)
109 | with open(img_path, encoding='utf-8') as f:
110 | keypoint_dicts = json.loads(f.read())["people"]
111 | return len(keypoint_dicts) > 0 and is_full_body(keypoint_dicts) and \
112 | contains_non_overlapping_people(keypoint_dicts)
113 |
114 |
115 | # Check if the image contains a full body.
116 | def is_full_body(keypoint_dicts):
117 | if type(keypoint_dicts) != list:
118 | keypoint_dicts = [keypoint_dicts]
119 | for keypoint_dict in keypoint_dicts:
120 | pose_pts = get_keypoint_array(keypoint_dict)
121 | # Contains at least one joint of head and one joint of foot.
122 | full = pose_pts[[0, 15, 16, 17, 18], :].any() \
123 | and pose_pts[[11, 14, 19, 20, 21, 22, 23, 24], :].any()
124 | if full:
125 | return True
126 | return False
127 |
128 |
129 | # Check whether two people overlap with each other.
130 | def has_overlap(pose_pts_1, pose_pts_2):
131 | pose_pts_1 = get_valid_openpose_keypoints(pose_pts_1)[:, 0]
132 | pose_pts_2 = get_valid_openpose_keypoints(pose_pts_2)[:, 0]
133 | # Get the x_axis bbox of the person.
134 | x1_start, x1_end = int(pose_pts_1.min()), int(pose_pts_1.max())
135 | x2_start, x2_end = int(pose_pts_2.min()), int(pose_pts_2.max())
136 | if x1_end < x2_start or x2_end < x1_start:
137 | return False
138 | return True
139 |
140 |
141 | # Check if the image contains at least one person that does not overlap with others.
142 | def contains_non_overlapping_people(keypoint_dicts):
143 | if len(keypoint_dicts) < 2:
144 | return True
145 |
146 | all_pose_pts = [get_keypoint_array(k) for k in keypoint_dicts]
147 | for i, pose_pts in enumerate(all_pose_pts):
148 | overlap = False
149 | for j, pose_pts2 in enumerate(all_pose_pts):
150 | if i == j:
151 | continue
152 | overlap = overlap | has_overlap(pose_pts, pose_pts2)
153 | if overlap:
154 | break
155 | if not overlap:
156 | return True
157 | return False
158 |
--------------------------------------------------------------------------------
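
A minimal sketch of the motion test with synthetic OpenPose entries (25 joints of x, y, confidence); the import path assumes data/preprocess is the working directory, as in preprocess.py:

    # Hedged sketch; the keypoint dicts are synthetic.
    from util.check_valid import detect_motion

    still = {'pose_keypoints_2d': [100.0, 200.0, 0.9] * 25}
    moved = {'pose_keypoints_2d': [110.0, 200.0, 0.9] * 25}   # x shifted by 10 px, above motion_thre of 5
    print(detect_motion([still], [still]))   # False: no joint moved more than the threshold
    print(detect_motion([still], [moved]))   # True: motion detected
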
/data/preprocess/util/get_poses.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import glob
9 | import os.path as path
10 | import cv2
11 |
12 | from util.util import makedirs, remove_frame
13 | from util.check_valid import remove_invalid_frames, remove_static_frames, \
14 | remove_isolated_frames, is_valid_frame
15 |
16 |
17 | # Run OpenPose on the extracted frames, and remove invalid frames.
18 | # To expedite the process, we will first process only keyframes in the video.
19 | # If the keyframe looks promising, we will then process the whole block of
20 | # frames for the keyframe.
21 | def run_openpose(args, video_path):
22 | video_idx = path.basename(video_path).split('.')[0]
23 | try:
24 | img_dir = path.join(args.output_root, args.img_folder, video_idx)
25 | op_dir = path.join(args.output_root, args.openpose_folder, video_idx)
26 | img_names = sorted(glob.glob(img_dir + '/*.jpg'))
27 | op_names = sorted(glob.glob(op_dir + '/*.json'))
28 |
29 | # If the frames haven't been extracted or OpenPose hasn't been run or
30 | # finished processing.
31 | if (not os.path.isdir(img_dir) or not os.path.isdir(op_dir)
32 | or len(img_names) != len(op_names)):
33 | makedirs(img_dir)
34 |
35 | # First run OpenPose on key frames, then decide whether to run
36 | # the whole batch of frames.
37 | extract_key_frames(args, video_path, img_dir)
38 | run_openpose_cmd(args, video_idx)
39 |
40 | # If key frame looks good, extract all frames in the batch and
41 | # run OpenPose.
42 | if args.n_skip_frames > 1:
43 | extract_valid_frames(args, video_path, img_dir)
44 | run_openpose_cmd(args, video_idx)
45 |
46 | # Remove all unusable frames.
47 | remove_invalid_frames(args, video_idx)
48 | remove_static_frames(args, video_idx)
49 | remove_isolated_frames(args, video_idx)
50 |     except Exception:
51 | raise ValueError('video %s running openpose failed' % video_idx)
52 |
53 |
54 | # Run DensePose on the extracted frames, and remove invalid frames.
55 | def run_densepose(args, video_idx):
56 | try:
57 | img_dir = path.join(args.output_root, args.img_folder, video_idx)
58 | dp_dir = path.join(args.output_root, args.densepose_folder, video_idx)
59 | img_names = sorted(glob.glob(img_dir + '/*.jpg'))
60 | dp_names = sorted(glob.glob(dp_dir + '/*.png'))
61 |
62 | if not os.path.isdir(dp_dir) or len(img_names) != len(dp_names):
63 | makedirs(dp_dir)
64 |
65 | # Run densepose.
66 | run_densepose_cmd(args, video_idx)
67 |     except Exception:
68 | raise ValueError('video %s running densepose failed' % video_idx)
69 |
70 |
71 | # Extract only the keyframes in the video.
72 | def extract_key_frames(args, video_path, img_dir):
73 | print('Extracting keyframes.')
74 | vidcap = cv2.VideoCapture(video_path)
75 | success, image = vidcap.read()
76 | frame_count = 0
77 | while success:
78 | if frame_count % args.n_skip_frames == 0:
79 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count)
80 | cv2.imwrite(write_name, image)
81 | success, image = vidcap.read()
82 | frame_count += 1
83 |
84 |
85 | # Extract valid frames from the video based on the extracted keyframes.
86 | def extract_valid_frames(args, video_path, img_dir):
87 | vidcap = cv2.VideoCapture(video_path)
88 | success, image = vidcap.read()
89 | frame_count = 0
90 | do_write = True
91 | while success:
92 | is_key_frame = frame_count % args.n_skip_frames == 0
93 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count)
94 |         # Each time it's a keyframe, check whether the frame is valid. If it is,
95 | # all frames following it before the next keyframe will be extracted.
96 | # Otherwise, this block of frames will be skipped and the next keyframe
97 | # will be examined.
98 | if is_key_frame:
99 | do_write = is_valid_frame(args, write_name)
100 | if not do_write: # If not valid, remove this keyframe.
101 | remove_frame(args, start=write_name)
102 | if do_write:
103 | cv2.imwrite(write_name, image)
104 | success, image = vidcap.read()
105 | frame_count += 1
106 | print('Video contains %d frames.' % frame_count)
107 |
108 |
109 | # Extract all frames from the video.
110 | def extract_all_frames(args, video_path):
111 | video_idx = path.basename(video_path).split('.')[0]
112 | img_dir = path.join(args.output_root, args.img_folder, video_idx)
113 | makedirs(img_dir)
114 |
115 | vidcap = cv2.VideoCapture(video_path)
116 | success, image = vidcap.read()
117 | frame_count = 0
118 | while success:
119 | write_name = path.join(img_dir, "frame%06d.jpg" % frame_count)
120 | cv2.imwrite(write_name, image)
121 | success, image = vidcap.read()
122 | frame_count += 1
123 | print('Extracted %d frames' % frame_count)
124 |
125 |
126 | # Running the actual OpenPose command.
127 | def run_openpose_cmd(args, video_idx):
128 | pwd = os.getcwd()
129 | img_dir = path.join(pwd, args.output_root, args.img_folder, video_idx)
130 | op_dir = path.join(pwd, args.output_root, args.openpose_folder, video_idx)
131 | render_dir = path.join(pwd, args.output_root,
132 | args.openpose_folder + '_rendered', video_idx)
133 | makedirs(op_dir)
134 | makedirs(render_dir)
135 |
136 | cmd = 'cd %s; ./build/examples/openpose/openpose.bin --display 0 ' \
137 | '--disable_blending --image_dir %s --write_images %s --face --hand ' \
138 | '--face_render_threshold 0.1 --hand_render_threshold 0.02 ' \
139 | '--write_json %s; cd %s' \
140 | % (args.openpose_root, img_dir, render_dir, op_dir,
141 | path.join(pwd, args.output_root))
142 | os.system(cmd)
143 |
144 |
145 | # Running the actual DensePose command.
146 | def run_densepose_cmd(args, video_idx):
147 | pwd = os.getcwd()
148 | img_dir = path.join(pwd, args.output_root, args.img_folder, video_idx)
149 | dp_dir = path.join(pwd, args.output_root, args.densepose_folder, video_idx, 'frame.png')
150 | cmd = 'python2 tools/infer_simple.py ' \
151 | '--cfg configs/DensePose_ResNet101_FPN_s1x-e2e.yaml ' \
152 | '--wts https://dl.fbaipublicfiles.com/densepose/DensePose_ResNet101_FPN_s1x-e2e.pkl ' \
153 | '--output-dir %s %s' % (dp_dir, img_dir)
154 | # cmd = 'python apply_net.py show configs/densepose_rcnn_R_101_FPN_s1x.yaml ' \
155 | # 'densepose_rcnn_R_101_FPN_s1x.pkl %s dp_segm,dp_u,dp_v --output %s' \
156 | # % (img_dir, dp_dir)
157 | cmd = 'cd %s; ' % args.densepose_root \
158 | + cmd \
159 | + '; cd %s' % path.join(pwd, args.output_root)
160 | os.system(cmd)
161 |
--------------------------------------------------------------------------------
/data/preprocess/util/util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import os.path as path
9 | import numpy as np
10 |
11 |
12 | # Extract keypoint array given the openpose dict.
13 | def get_keypoint_array(keypoint_dict):
14 | if type(keypoint_dict) == list:
15 | return [get_keypoint_array(d) for d in keypoint_dict]
16 | if type(keypoint_dict) != np.ndarray:
17 | keypoint_dict = np.array(keypoint_dict["pose_keypoints_2d"]).reshape(25, 3)
18 | return keypoint_dict
19 |
20 |
21 | # Only extract openpose keypoints where the confidence is larger than the threshold (0.01).
22 | def get_valid_openpose_keypoints(keypoint_array):
23 | if type(keypoint_array) == list:
24 | return [get_valid_openpose_keypoints(k) for k in keypoint_array]
25 | return keypoint_array[keypoint_array[:, 2] > 0.01, :]
26 |
27 |
28 | # Remove particular frame(s) for all folders.
29 | def remove_frame(args, video_idx='', start=0, end=None):
30 | if not isinstance(start, int):
31 | video_idx = path.basename(path.dirname(start))
32 | start = get_frame_idx(start)
33 | if end is None:
34 | end = start
35 | for i in range(start, end + 1):
36 | img_path = path.join(args.output_root, args.img_folder, video_idx,
37 | 'frame%06d.jpg' % i)
38 | op_path = path.join(args.output_root, args.openpose_folder, video_idx,
39 | 'frame%06d%s' % (i, args.openpose_postfix))
40 | dp_path = path.join(args.output_root, args.densepose_folder, video_idx,
41 | 'frame%06d%s' % (i, args.densepose_postfix))
42 | dm_path = path.join(args.output_root, args.densemask_folder, video_idx,
43 | 'frame%06d%s' % (i, args.densemask_postfix))
44 | print('removing %s' % img_path)
45 | remove(img_path)
46 | remove(op_path)
47 | remove(dp_path)
48 | remove(dm_path)
49 |
50 |
51 | def remove_folder(args, video_idx):
52 | os.rmdir(path.join(args.output_root, args.img_folder, video_idx))
53 | os.rmdir(path.join(args.output_root, args.openpose_folder, video_idx))
54 | os.rmdir(path.join(args.output_root, args.densepose_folder, video_idx))
55 | os.rmdir(path.join(args.output_root, args.densemask_folder, video_idx))
56 |
57 |
58 | def get_frame_idx(file_name):
59 | return int(path.basename(file_name)[5:11])
60 |
61 |
62 | def makedirs(folder):
63 | if not path.exists(folder):
64 | os.umask(0)
65 | os.makedirs(folder, mode=0o777)
66 |
67 |
68 | def remove(file_name):
69 | if path.exists(file_name):
70 | os.remove(file_name)
71 |
--------------------------------------------------------------------------------
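
A minimal sketch of the keypoint helpers on a synthetic OpenPose person entry; as above, the import path assumes data/preprocess is the working directory:

    # Hedged sketch; the keypoint values are synthetic.
    from util.util import get_keypoint_array, get_valid_openpose_keypoints

    person = {'pose_keypoints_2d': [50.0, 80.0, 0.8] * 5 + [0.0, 0.0, 0.0] * 20}  # 5 confident joints, 20 missing
    pts = get_keypoint_array(person)             # (25, 3) array of x, y, confidence
    valid = get_valid_openpose_keypoints(pts)    # keeps only rows with confidence > 0.01
    print(pts.shape, valid.shape)                # (25, 3) (5, 3)
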
/data/preprocess/youTube_playlist.txt:
--------------------------------------------------------------------------------
1 | https://www.youtube.com/playlist?list=PLZgHzuctidRIYzMpEcbEMIFp7vGcFK4x7
2 | https://www.youtube.com/playlist?list=PL0m7UHzPZEA9R8Y6xautFgqeWnorDj2Le
3 | https://www.youtube.com/playlist?list=PLsVSF-hJhvBKV9XDaqJAEx-otHG1_XT9q
--------------------------------------------------------------------------------
/imgs/dance.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/dance.gif
--------------------------------------------------------------------------------
/imgs/face.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/face.gif
--------------------------------------------------------------------------------
/imgs/illustration.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/illustration.gif
--------------------------------------------------------------------------------
/imgs/mona_lisa.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/mona_lisa.gif
--------------------------------------------------------------------------------
/imgs/statue.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/statue.gif
--------------------------------------------------------------------------------
/imgs/street.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/imgs/street.gif
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
--------------------------------------------------------------------------------
/models/face_refiner.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn.functional as F
9 |
10 | from .base_model import BaseModel
11 |
12 | class FaceRefineModel(BaseModel):
13 | def name(self):
14 | return 'FaceRefineModel'
15 |
16 | def initialize(self, opt, add_face_D, refine_face):
17 | BaseModel.initialize(self, opt)
18 | self.opt = opt
19 | self.add_face_D = add_face_D
20 | self.refine_face = refine_face
21 | self.face_size = int(opt.fineSize / opt.aspect_ratio) // 4
22 |
23 | ### refine the face region of the fake image
24 | def refine_face_region(self, netGf, label_valid, fake_image, label, ref_label_valid, ref_image, ref_label):
25 | label_face, fake_face_coarse = self.crop_face_region([label_valid, fake_image], label, crop_smaller=4)
26 | ref_label_face, ref_image_face = self.crop_face_region([ref_label_valid, ref_image], ref_label, crop_smaller=4)
27 | fake_face = netGf(label_face, ref_label_face.unsqueeze(1), ref_image_face.unsqueeze(1), img_coarse=fake_face_coarse.detach())
28 | fake_image = self.replace_face_region(fake_image, fake_face, label, fake_face_coarse.detach(), crop_smaller=4)
29 | return fake_image
30 |
31 | ### crop out the face region of the image (and resize if necessary to feed into generator/discriminator)
32 | def crop_face_region(self, image, input_label, crop_smaller=0):
33 | if type(image) == list:
34 | return [self.crop_face_region(im, input_label, crop_smaller) for im in image]
35 | for i in range(input_label.size(0)):
36 | ys, ye, xs, xe = self.get_face_region(input_label[i:i+1], crop_smaller=crop_smaller)
37 | output_i = F.interpolate(image[i:i+1,-3:,ys:ye,xs:xe], size=(self.face_size, self.face_size))
38 | output = torch.cat([output, output_i]) if i != 0 else output_i
39 | return output
40 |
41 | ### replace the face region in the fake image with the refined version
42 | def replace_face_region(self, fake_image, fake_face, input_label, fake_face_coarse=None, crop_smaller=0):
43 | fake_image = fake_image.clone()
44 | b, _, h, w = input_label.size()
45 | for i in range(b):
46 | ys, ye, xs, xe = self.get_face_region(input_label[i:i+1], crop_smaller)
47 | fake_face_i = fake_face[i:i+1] + (fake_face_coarse[i:i+1] if fake_face_coarse is not None else 0)
48 | fake_face_i = F.interpolate(fake_face_i, size=(ye-ys, xe-xs), mode='bilinear')
49 | fake_image[i:i+1,:,ys:ye,xs:xe] = torch.clamp(fake_face_i, -1, 1)
50 | return fake_image
51 |
52 | ### get coordinates of the face bounding box
53 | def get_face_region(self, pose, crop_smaller=0):
54 | if pose.dim() == 3: pose = pose.unsqueeze(0)
55 | elif pose.dim() == 5: pose = pose[-1,-1:]
56 | _, _, h, w = pose.size()
57 |
58 | use_openpose = not self.opt.basic_point_only and not self.opt.remove_face_labels
59 | if use_openpose: # use openpose face keypoints to identify face region
60 | face = ((pose[:,-3] > 0) & (pose[:,-2] > 0) & (pose[:,-1] > 0)).nonzero()
61 | else: # use densepose labels
62 | face = (pose[:,2] > 0.9).nonzero()
63 | if face.size(0):
64 | y, x = face[:,1], face[:,2]
65 | ys, ye, xs, xe = y.min().item(), y.max().item(), x.min().item(), x.max().item()
66 | if use_openpose:
67 | xc, yc = (xs + xe) // 2, (ys*3 + ye*2) // 5
68 | ylen = int((xe - xs) * 2.5)
69 | else:
70 | xc, yc = (xs + xe) // 2, (ys + ye) // 2
71 | ylen = int((ye - ys) * 1.25)
72 | ylen = xlen = min(w, max(32, ylen))
73 | yc = max(ylen//2, min(h-1 - ylen//2, yc))
74 | xc = max(xlen//2, min(w-1 - xlen//2, xc))
75 | else:
76 | yc = h//4
77 | xc = w//2
78 | ylen = xlen = h // 32 * 8
79 |
80 | ys, ye, xs, xe = yc - ylen//2, yc + ylen//2, xc - xlen//2, xc + xlen//2
81 | if crop_smaller != 0: # crop slightly smaller region inside face
82 | ys += crop_smaller; xs += crop_smaller
83 | ye -= crop_smaller; xe -= crop_smaller
84 | return ys, ye, xs, xe
--------------------------------------------------------------------------------
/models/flownet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | import numpy as np
9 | import torch
10 | import torch.nn.functional as F
11 | import sys
12 |
13 | from .base_model import BaseModel
14 |
15 | class FlowNet(BaseModel):
16 | def name(self):
17 | return 'FlowNet'
18 |
19 | def initialize(self, opt):
20 | BaseModel.initialize(self, opt)
21 |
22 | # flownet 2
23 | from .networks.flownet2_pytorch import models as flownet2_models
24 | from .networks.flownet2_pytorch.utils import tools as flownet2_tools
25 | from .networks.flownet2_pytorch.networks.resample2d_package.resample2d import Resample2d
26 |
27 | self.flowNet = flownet2_tools.module_to_dict(flownet2_models)['FlowNet2']().cuda()
28 | checkpoint = torch.load('models/networks/flownet2_pytorch/FlowNet2_checkpoint.pth.tar', map_location=torch.device('cpu'))
29 | self.flowNet.load_state_dict(checkpoint['state_dict'])
30 | self.flowNet.eval()
31 | self.resample = Resample2d()
32 | self.downsample = torch.nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
33 |
34 | def forward(self, data_list, epoch=0, dummy_bs=0):
35 | if data_list[0].get_device() == 0:
36 | data_list = self.remove_dummy_from_tensor(data_list, dummy_bs)
37 | image_now, image_ref = data_list
38 | image_now, image_ref = image_now[:,:,:3], image_ref[:,0:1,:3]
39 |
40 | flow_gt_prev = flow_gt_ref = conf_gt_prev = conf_gt_ref = None
41 | with torch.no_grad():
42 | if not self.opt.isTrain or epoch > self.opt.niter_single:
43 | image_prev = torch.cat([image_now[:,0:1], image_now[:,:-1]], dim=1)
44 | flow_gt_prev, conf_gt_prev = self.flowNet_forward(image_now, image_prev)
45 |
46 | if self.opt.warp_ref:
47 | flow_gt_ref, conf_gt_ref = self.flowNet_forward(image_now, image_ref.expand_as(image_now))
48 |
49 | flow_gt, conf_gt = [flow_gt_ref, flow_gt_prev], [conf_gt_ref, conf_gt_prev]
50 | return flow_gt, conf_gt
51 |
52 | def flowNet_forward(self, input_A, input_B):
53 | size = input_A.size()
54 | assert(len(size) == 4 or len(size) == 5)
55 | if len(size) == 5:
56 | b, n, c, h, w = size
57 | input_A = input_A.contiguous().view(-1, c, h, w)
58 | input_B = input_B.contiguous().view(-1, c, h, w)
59 | flow, conf = self.compute_flow_and_conf(input_A, input_B)
60 | return flow.view(b, n, 2, h, w), conf.view(b, n, 1, h, w)
61 | else:
62 | return self.compute_flow_and_conf(input_A, input_B)
63 |
64 | def compute_flow_and_conf(self, im1, im2):
65 | assert(im1.size()[1] == 3)
66 | assert(im1.size() == im2.size())
67 | old_h, old_w = im1.size()[2], im1.size()[3]
68 | new_h, new_w = old_h//64*64, old_w//64*64
69 | if old_h != new_h:
70 | im1 = F.interpolate(im1, size=(new_h, new_w), mode='bilinear')
71 | im2 = F.interpolate(im2, size=(new_h, new_w), mode='bilinear')
72 | self.flowNet.cuda(im1.get_device())
73 | data1 = torch.cat([im1.unsqueeze(2), im2.unsqueeze(2)], dim=2)
74 | flow1 = self.flowNet(data1)
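# confidence mask: warp im2 toward im1 with the predicted flow; pixels whose
# squared warping error is below 0.02 are treated as reliable (1), the rest as 0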
75 | conf = (self.norm(im1 - self.resample(im2, flow1)) < 0.02).float()
76 |
77 | if old_h != new_h:
78 | flow1 = F.interpolate(flow1, size=(old_h, old_w), mode='bilinear') * old_h / new_h
79 | conf = F.interpolate(conf, size=(old_h, old_w), mode='bilinear')
80 | return flow1, conf
81 |
82 | def norm(self, t):
83 | return torch.sum(t*t, dim=1, keepdim=True)
84 |
--------------------------------------------------------------------------------
/models/input_process.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 |
9 | ############################# input processing ###################################
10 | def encode_input(opt, data_list, dummy_bs):
11 | if opt.isTrain and data_list[0].get_device() == 0:
12 | data_list = remove_dummy_from_tensor(opt, data_list, dummy_bs)
13 | tgt_label, tgt_image, flow_gt, conf_gt, ref_label, ref_image, prev_label, prev_real_image, prev_fake_image = data_list
14 |
15 | # target label and image
16 | tgt_label = encode_label(opt, tgt_label)
17 | tgt_image = tgt_image.cuda()
18 |
19 | # reference label and image
20 | ref_label = encode_label(opt, ref_label)
21 | ref_image = ref_image.cuda()
22 |
23 | return tgt_label, tgt_image, flow_gt, conf_gt, ref_label, ref_image, [prev_label, prev_real_image, prev_fake_image]
24 |
25 | def encode_label(opt, label_map):
26 | size = label_map.size()
27 | if len(size) == 5:
28 | bs, t, c, h, w = size
29 | label_map = label_map.view(-1, c, h, w)
30 | else:
31 | bs, c, h, w = size
32 |
33 | label_nc = opt.label_nc
34 | if label_nc == 0:
35 | input_label = label_map.cuda()
36 | else:
37 | # create one-hot vector for label map
38 | label_map = label_map.cuda()
39 | oneHot_size = (label_map.shape[0], label_nc, h, w)
40 | input_label = torch.cuda.FloatTensor(torch.Size(oneHot_size)).zero_()
41 | input_label = input_label.scatter_(1, label_map.long().cuda(), 1.0)
42 |
43 | if len(size) == 5:
44 | return input_label.view(bs, t, -1, h, w)
45 | return input_label
46 |
47 | ### get the union of target and reference foreground masks
48 | def combine_fg_mask(fg_mask, ref_fg_mask, has_fg):
49 | return ((fg_mask > 0) | (ref_fg_mask > 0)).float() if has_fg else 1
50 |
51 | ### obtain the foreground mask for pose sequences, which only includes the human
52 | def get_fg_mask(opt, input_label, has_fg):
53 | if type(input_label) == list:
54 | return [get_fg_mask(opt, l, has_fg) for l in input_label]
55 | if not has_fg: return None
56 | if len(input_label.size()) == 5: input_label = input_label[:,0]
57 | mask = input_label[:,2:3] if opt.label_nc == 0 else -input_label[:,0:1]
58 |
59 | mask = torch.nn.MaxPool2d(15, padding=7, stride=1)(mask) # make the mask slightly larger
60 | mask = (mask > -1).float()
61 | return mask
62 |
63 | ### obtain mask of different body parts
64 | def get_part_mask(pose):
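# the pose channel carries DensePose part labels (0 = background, 1-24 = body parts);
# related part indices (e.g. left/right counterparts) are merged into a single mask channel below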
65 | part_groups = [[0], [1,2], [3,4], [5,6], [7,9,8,10], [11,13,12,14], [15,17,16,18], [19,21,20,22], [23,24]]
66 | n_parts = len(part_groups)
67 |
68 | need_reshape = pose.dim() == 4
69 | if need_reshape:
70 | bo, t, h, w = pose.size()
71 | pose = pose.view(-1, h, w)
72 | b, h, w = pose.size()
73 | part = (pose / 2 + 0.5) * 24
74 | mask = torch.cuda.ByteTensor(b, n_parts, h, w).fill_(0)
75 | for i in range(n_parts):
76 | for j in part_groups[i]:
77 | mask[:, i] = mask[:, i] | ((part > j-0.1) & (part < j+0.1)).byte()
78 | if need_reshape:
79 | mask = mask.view(bo, t, -1, h, w)
80 | return mask.float()
81 |
82 | ### obtain mask of faces
83 | def get_face_mask(pose):
84 | if pose.dim() == 3:
85 | pose = pose.unsqueeze(1)
86 | b, t, h, w = pose.size()
87 | part = (pose / 2 + 0.5) * 24
88 | if pose.is_cuda:
89 | mask = torch.cuda.ByteTensor(b, t, h, w).fill_(0)
90 | else:
91 | mask = torch.ByteTensor(b, t, h, w).fill_(0)
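# DensePose part labels 23 and 24 correspond to the head region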
92 | for j in [23,24]:
93 | mask = mask | ((part > j-0.1) & (part < j+0.1)).byte()
94 | return mask.float()
95 |
96 | ### remove partial labels in the pose map if necessary
97 | def use_valid_labels(opt, pose):
98 | if 'pose' not in opt.dataset_mode: return pose
99 | if pose is None: return pose
100 | if type(pose) == list:
101 | return [use_valid_labels(opt, p) for p in pose]
102 | assert(pose.dim() == 4 or pose.dim() == 5)
103 | if opt.pose_type == 'open':
104 | if pose.dim() == 4: pose = pose[:,3:]
105 | elif pose.dim() == 5: pose = pose[:,:,3:]
106 | elif opt.remove_face_labels:
107 | if pose.dim() == 4:
108 | face_mask = get_face_mask(pose[:,2])
109 | pose = torch.cat([pose[:,:3] * (1 - face_mask) - face_mask, pose[:,3:]], dim=1)
110 | else:
111 | face_mask = get_face_mask(pose[:,:,2]).unsqueeze(2)
112 | pose = torch.cat([pose[:,:,:3] * (1 - face_mask) - face_mask, pose[:,:,3:]], dim=2)
113 | return pose
114 |
115 | def remove_dummy_from_tensor(opt, tensors, remove_size=0):
116 | if remove_size == 0: return tensors
117 | if tensors is None: return None
118 | if isinstance(tensors, list):
119 | return [remove_dummy_from_tensor(opt, tensor, remove_size) for tensor in tensors]
120 | if isinstance(tensors, torch.Tensor):
121 | tensors = tensors[remove_size:]
122 | return tensors
--------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 |
11 | from models.networks.sync_batchnorm import DataParallelWithCallback
12 | from models.vid2vid_model import Vid2VidModel
13 | from util.distributed import master_only
14 | from util.distributed import master_only_print as print
15 |
16 | def create_model(opt, epoch=0):
17 | model = Vid2VidModel()
18 | model.initialize(opt, epoch)
19 | print("model [%s] was created" % (model.name()))
20 |
21 | if opt.isTrain:
22 | if opt.amp != 'O0':
23 | from apex import amp
24 | print('using amp optimization')
25 | model, optimizers = amp.initialize(model, [model.optimizer_G, model.optimizer_D],
26 | opt_level=opt.amp, num_losses=2)
27 | else:
28 | optimizers = model.optimizer_G, model.optimizer_D
29 |
30 | model = WrapModel(opt, model)
31 | flowNet = None
32 | if not opt.no_flow_gt:
33 | from .flownet import FlowNet
34 | flowNet = FlowNet()
35 | flowNet.initialize(opt)
36 | flowNet = WrapModel(opt, flowNet)
37 | return model, flowNet, optimizers
38 | return model
39 |
40 | def WrapModel(opt, model):
41 | if opt.distributed:
42 | import apex
43 | model = apex.parallel.DistributedDataParallel(model.cuda(), delay_allreduce=True)
44 | else:
45 | model = MyModel(opt, model)
46 | return model
47 |
48 | @master_only
49 | def save_models(opt, epoch, epoch_iter, total_steps, visualizer, iter_path, model, end_of_epoch=False):
50 | if not end_of_epoch:
51 | if (total_steps % opt.save_latest_freq == 0):
52 | visualizer.vis_print(opt, 'saving the latest model (epoch %d, total_steps %d)' % (epoch, total_steps))
53 | model.module.save_networks('latest')
54 | np.savetxt(iter_path, (epoch, epoch_iter), delimiter=',', fmt='%d')
55 | model.cuda()
56 | else:
57 | if epoch % opt.save_epoch_freq == 0:
58 | visualizer.vis_print(opt, 'saving the model at the end of epoch %d, iters %d' % (epoch, total_steps))
59 | model.module.save_networks('latest')
60 | model.module.save_networks(epoch)
61 | np.savetxt(iter_path, (epoch+1, 0), delimiter=',', fmt='%d')
62 | model.cuda()
63 |
64 | def update_models(opt, epoch, model, data_loader):
65 | ### linearly decay learning rate after certain iterations
66 | if epoch > opt.niter:
67 | model.module.update_learning_rate(epoch)
68 |
69 | ### train single frame first then sequence of frames
70 | if epoch == opt.niter_single + 1 and not model.module.temporal:
71 | model.module.init_temporal_model()
72 |
73 | ### gradually grow training sequence length
74 | epoch_temp = epoch - opt.niter_single
75 | if epoch_temp > 0 and ((epoch_temp - 1) % opt.niter_step) == 0:
76 | data_loader.dataset.update_training_batch((epoch_temp - 1) // opt.niter_step)
77 |
78 |
79 | class MyModel(nn.Module):
80 | def __init__(self, opt, model):
81 | super(MyModel, self).__init__()
82 | self.opt = opt
83 | model = model.cuda(opt.gpu_ids[0])
84 | self.module = model
85 |
86 | self.model = DataParallelWithCallback(model, device_ids=opt.gpu_ids)
87 | if opt.batch_for_first_gpu != -1:
88 | self.bs_per_gpu = (opt.batchSize - opt.batch_for_first_gpu) // (len(opt.gpu_ids) - 1) # batch size for each GPU
89 | else:
90 | self.bs_per_gpu = int(np.ceil(float(opt.batchSize) / len(opt.gpu_ids))) # batch size for each GPU
91 | self.pad_bs = self.bs_per_gpu * len(opt.gpu_ids) - opt.batchSize
92 |
93 | def forward(self, *inputs, **kwargs):
94 | inputs = self.add_dummy_to_tensor(inputs, self.pad_bs)
95 | outputs = self.model(*inputs, **kwargs, dummy_bs=self.pad_bs)
96 | if self.pad_bs == self.bs_per_gpu: # gpu 0 does 0 batch but still returns 1 batch
97 | return self.remove_dummy_from_tensor(outputs, 1)
98 | return outputs
99 |
100 | def add_dummy_to_tensor(self, tensors, add_size=0):
101 | if add_size == 0 or tensors is None: return tensors
102 | if type(tensors) == list or type(tensors) == tuple:
103 | return [self.add_dummy_to_tensor(tensor, add_size) for tensor in tensors]
104 |
105 | if isinstance(tensors, torch.Tensor):
106 | dummy = torch.zeros_like(tensors)[:add_size]
107 | tensors = torch.cat([dummy, tensors])
108 | return tensors
109 |
110 | def remove_dummy_from_tensor(self, tensors, remove_size=0):
111 | if remove_size == 0 or tensors is None: return tensors
112 | if type(tensors) == list or type(tensors) == tuple:
113 | return [self.remove_dummy_from_tensor(tensor, remove_size) for tensor in tensors]
114 |
115 | if isinstance(tensors, torch.Tensor):
116 | tensors = tensors[remove_size:]
117 | return tensors
--------------------------------------------------------------------------------
/models/networks/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn as nn
9 | import functools
10 | import numpy as np
11 | import torch.nn.functional as F
12 | from models.networks.loss import *
13 | from models.networks.discriminator import *
14 | from models.networks.generator import *
15 |
16 |
17 | def modify_commandline_options(parser, is_train):
18 | opt, _ = parser.parse_known_args()
19 | if 'fewshot' in opt.netG:
20 | parser = FewShotGenerator.modify_commandline_options(parser, is_train)
21 |
22 | if is_train:
23 | if opt.which_model_netD == 'multiscale':
24 | parser = MultiscaleDiscriminator.modify_commandline_options(parser, is_train)
25 | elif opt.which_model_netD == 'n_layers':
26 | parser = NLayerDiscriminator.modify_commandline_options(parser, is_train)
27 | return parser
28 |
29 | def define_G(opt):
30 | if 'fewshot' in opt.netG:
31 | netG = FewShotGenerator(opt)
32 | else:
33 | raise NotImplementedError('generator not implemented!')
34 | if opt.isTrain and opt.print_G: netG.print_network()
35 | if len(opt.gpu_ids) > 0:
36 | assert(torch.cuda.is_available())
37 | netG.cuda()
38 | netG.init_weights(opt.init_type, opt.init_variance)
39 | return netG
40 |
41 | def define_D(opt, input_nc, ndf, n_layers_D, norm='spectralinstance', subarch='n_layers', num_D=1, getIntermFeat=False, stride=2, gpu_ids=[]):
42 | norm_layer = get_nonspade_norm_layer(opt, norm_type=norm)
43 | if opt.which_model_netD == 'multiscale':
44 | netD = MultiscaleDiscriminator(opt, input_nc, ndf, n_layers_D, norm_layer, subarch, num_D, getIntermFeat, stride, gpu_ids)
45 | elif opt.which_model_netD == 'n_layers':
46 | netD = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer, getIntermFeat)
47 | else:
48 | raise NotImplementedError('unknown discriminator type %s!' % opt.which_model_netD)
49 |
50 | if opt.isTrain and opt.print_D: netD.print_network()
51 | if len(gpu_ids) > 0:
52 | assert(torch.cuda.is_available())
53 | netD.cuda()
54 | netD.init_weights(opt.init_type, opt.init_variance)
55 | return netD
56 |
--------------------------------------------------------------------------------
/models/networks/architecture.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torch.nn.utils.spectral_norm as sn
11 |
12 | from models.networks.base_network import BaseNetwork, batch_conv
13 | from models.networks.normalization import SPADE, SynchronizedBatchNorm2d
14 |
15 | def actvn(x):
16 | out = F.leaky_relu(x, 2e-1)
17 | return out
18 |
19 | def generalConv(adaptive=False, transpose=False):
20 | class NormalConv2d(nn.Conv2d):
21 | def __init__(self, *args, **kwargs):
22 | super(NormalConv2d, self).__init__(*args, **kwargs)
23 | def forward(self, input, weight=None, bias=None, stride=1):
24 | return super(NormalConv2d, self).forward(input)
25 | class NormalConvTranspose2d(nn.ConvTranspose2d):
26 | def __init__(self, *args, output_padding=1, **kwargs):
27 | #kwargs['output_padding'] = 1
28 | super(NormalConvTranspose2d, self).__init__(*args, **kwargs)
29 | def forward(self, input, weight=None, bias=None, stride=1):
30 | return super(NormalConvTranspose2d, self).forward(input)
31 | class AdaptiveConv2d(nn.Module):
32 | def __init__(self, *args, **kwargs):
33 | super().__init__()
34 | def forward(self, input, weight=None, bias=None, stride=1):
35 | return batch_conv(input, weight, bias, stride)
36 |
37 | if adaptive: return AdaptiveConv2d
38 | return NormalConv2d if not transpose else NormalConvTranspose2d
39 |
40 | def generalNorm(norm):
41 | if 'spade' in norm: return SPADE
42 | def get_norm(norm):
43 | if 'instance' in norm:
44 | return nn.InstanceNorm2d
45 | elif 'syncbatch' in norm:
46 | return SynchronizedBatchNorm2d
47 | elif 'batch' in norm:
48 | return nn.BatchNorm2d
49 | norm = get_norm(norm)
50 | class NormalNorm(norm):
51 | def __init__(self, *args, hidden_nc=0, norm='', ks=1, params_free=False, **kwargs):
52 | super(NormalNorm, self).__init__(*args, **kwargs)
53 | def forward(self, input, label=None, weight=None):
54 | return super(NormalNorm, self).forward(input)
55 | return NormalNorm
56 |
57 | class SPADEConv2d(nn.Module):
58 | def __init__(self, fin, fout, norm='batch', hidden_nc=0, kernel_size=3, padding=1, stride=1):
59 | super().__init__()
60 | self.conv = sn(nn.Conv2d(fin, fout, kernel_size=kernel_size, stride=stride, padding=padding))
61 |
62 | Norm = generalNorm(norm)
63 | self.bn = Norm(fout, hidden_nc=hidden_nc, norm=norm, ks=3)
64 |
65 | def forward(self, x, label=None):
66 | x = self.conv(x)
67 | out = self.bn(x, label)
68 | out = actvn(out)
69 | return out
70 |
71 | class SPADEResnetBlock(nn.Module):
72 | def __init__(self, fin, fout, norm='batch', hidden_nc=0, conv_ks=3, spade_ks=1, stride=1, conv_params_free=False, norm_params_free=False):
73 | super().__init__()
74 | fhidden = min(fin, fout)
75 | self.learned_shortcut = (fin != fout)
76 | self.stride = stride
77 | Conv2d = generalConv(adaptive=conv_params_free)
78 | sn_ = sn if not conv_params_free else lambda x: x
79 |
80 | # Submodules
81 | self.conv_0 = sn_(Conv2d(fin, fhidden, conv_ks, stride=stride, padding=1))
82 | self.conv_1 = sn_(Conv2d(fhidden, fout, conv_ks, padding=1))
83 | if self.learned_shortcut:
84 | self.conv_s = sn_(Conv2d(fin, fout, 1, stride=stride, bias=False))
85 |
86 | Norm = generalNorm(norm)
87 | self.bn_0 = Norm(fin, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free)
88 | self.bn_1 = Norm(fhidden, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free)
89 | if self.learned_shortcut:
90 | self.bn_s = Norm(fin, hidden_nc=hidden_nc, norm=norm, ks=spade_ks, params_free=norm_params_free)
91 |
92 | def forward(self, x, label=None, conv_weights=[], norm_weights=[]):
93 | if not conv_weights: conv_weights = [None]*3
94 | if not norm_weights: norm_weights = [None]*3
95 | x_s = self._shortcut(x, label, conv_weights[2], norm_weights[2])
96 | dx = self.conv_0(actvn(self.bn_0(x, label, norm_weights[0])), conv_weights[0])
97 | dx = self.conv_1(actvn(self.bn_1(dx, label, norm_weights[1])), conv_weights[1])
98 | out = x_s + 1.0*dx
99 | return out
100 |
101 | def _shortcut(self, x, label, conv_weights, norm_weights):
102 | if self.learned_shortcut:
103 | x_s = self.conv_s(self.bn_s(x, label, norm_weights), conv_weights)
104 | elif self.stride != 1:
105 | x_s = nn.AvgPool2d(3, stride=2, padding=1)(x)
106 | else:
107 | x_s = x
108 | return x_s
109 |
--------------------------------------------------------------------------------
/models/networks/base_network.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn as nn
9 | from torch.nn import init
10 | import torch.nn.functional as F
11 | import numpy as np
12 |
13 | def get_grid(batchsize, rows, cols, gpu_id=0):
14 | hor = torch.linspace(-1.0, 1.0, cols)
15 | hor.requires_grad = False
16 | hor = hor.view(1, 1, 1, cols)
17 | hor = hor.expand(batchsize, 1, rows, cols)
18 | ver = torch.linspace(-1.0, 1.0, rows)
19 | ver.requires_grad = False
20 | ver = ver.view(1, 1, rows, 1)
21 | ver = ver.expand(batchsize, 1, rows, cols)
22 |
23 | t_grid = torch.cat([hor, ver], 1)
24 | t_grid.requires_grad = False
25 |
26 | return t_grid.cuda(gpu_id)
27 |
28 | def resample(image, flow):
29 | b, c, h, w = image.size()
30 | grid = get_grid(b, h, w, gpu_id=flow.get_device())
31 | flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1)
32 | final_grid = (grid + flow).permute(0, 2, 3, 1).cuda(image.get_device())
33 | try:
34 | output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border', align_corners=True)
35 | except TypeError: # older PyTorch versions do not accept the align_corners argument
36 | output = F.grid_sample(image, final_grid, mode='bilinear', padding_mode='border')
37 | return output
38 |
39 | ### pick the reference image that is most similar to current frame
40 | def pick_ref(refs, ref_idx):
41 | if type(refs) == list:
42 | return [pick_ref(r, ref_idx) for r in refs]
43 | if ref_idx is None:
44 | return refs[:,0]
45 | ref_idx = ref_idx.long().view(-1, 1, 1, 1, 1)
46 | ref = refs.gather(1, ref_idx.expand_as(refs)[:,0:1])[:,0]
47 | return ref
48 |
49 | def concat(a, b, dim=0):
50 | if isinstance(a, list):
51 | return [concat(ai, bi, dim) for ai, bi in zip(a, b)]
52 | if a is None:
53 | return b
54 | return torch.cat([a, b], dim=dim)
55 |
56 | def batch_conv(x, weight, bias=None, stride=1, group_size=-1):
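# per-sample convolution: weight (and bias) hold one kernel per batch element, so each
# x[i] is convolved with its own predicted weights (used by AdaptiveConv2d in architecture.py)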
57 | if weight is None: return x
58 | if isinstance(weight, list) or isinstance(weight, tuple):
59 | weight, bias = weight
60 | padding = weight.size()[-1] // 2
61 | groups = group_size//weight.size()[2] if group_size != -1 else 1
62 | if bias is None: bias = [None] * x.size()[0]
63 | y = None
64 | for i in range(x.size(0)):
65 | if stride >= 1:
66 | yi = F.conv2d(x[i:i+1], weight=weight[i], bias=bias[i], padding=padding, stride=stride, groups=groups)
67 | else:
68 | yi = F.conv_transpose2d(x[i:i+1], weight=weight[i], bias=bias[i,:weight.size(2)], padding=padding, stride=int(1/stride),
69 | output_padding=1, groups=groups)
70 | y = concat(y, yi)
71 | return y
72 |
73 | class BaseNetwork(nn.Module):
74 | def __init__(self):
75 | super(BaseNetwork, self).__init__()
76 |
77 | def print_network(self):
78 | if isinstance(self, list):
79 | self = self[0]
80 | num_params = 0
81 | for param in self.parameters():
82 | num_params += param.numel()
83 | print(self)
84 | print('Total number of parameters: %d' % num_params)
85 |
86 | def init_weights(self, init_type='normal', gain=0.02):
87 | def init_func(m):
88 | classname = m.__class__.__name__
89 | if classname.find('BatchNorm2d') != -1:
90 | if hasattr(m, 'weight') and m.weight is not None:
91 | init.normal_(m.weight.data, 1.0, gain)
92 | if hasattr(m, 'bias') and m.bias is not None:
93 | init.constant_(m.bias.data, 0.0)
94 | elif hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
95 | if init_type == 'normal':
96 | init.normal_(m.weight.data, 0.0, gain)
97 | elif init_type == 'xavier':
98 | init.xavier_normal_(m.weight.data, gain=gain)
99 | elif init_type == 'xavier_uniform':
100 | init.xavier_uniform_(m.weight.data, gain=1.0)
101 | elif init_type == 'kaiming':
102 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
103 | elif init_type == 'orthogonal':
104 | init.orthogonal_(m.weight.data, gain=gain)
105 | elif init_type == 'none':
106 | m.reset_parameters()
107 | else:
108 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
109 | if hasattr(m, 'bias') and m.bias is not None:
110 | init.constant_(m.bias.data, 0.0)
111 |
112 | self.apply(init_func)
113 | for m in self.children():
114 | if hasattr(m, 'init_weights'):
115 | m.init_weights(init_type, gain)
116 |
117 | def load_pretrained_net(self, net_src, net_dst):
118 | source_weights = net_src.state_dict()
119 | target_weights = net_dst.state_dict()
120 |
121 | for k, v in source_weights.items():
122 | if k in target_weights and target_weights[k].size() == v.size():
123 | target_weights[k] = v
124 | net_dst.load_state_dict(target_weights)
125 |
126 | def reparameterize(self, mu, logvar):
127 | std = torch.exp(0.5 * logvar)
128 | eps = torch.randn_like(std)
129 | z = eps.mul(std) + mu
130 | return z
131 |
132 | def sum(self, x):
133 | if type(x) != list: return x
134 | return sum([self.sum(xi) for xi in x])
135 |
136 | def sum_mul(self, x):
137 | assert(type(x) == list)
138 | if type(x[0]) != list:
139 | return np.prod(x) + x[0]
140 | return [self.sum_mul(xi) for xi in x]
141 |
142 | def split_weights(self, weight, sizes):
143 | if isinstance(sizes, list):
144 | weights = []
145 | cur_size = 0
146 | for i in range(len(sizes)):
147 | next_size = cur_size + self.sum(sizes[i])
148 | weights.append(self.split_weights(weight[:,cur_size:next_size], sizes[i]))
149 | cur_size = next_size
150 | assert(next_size == weight.size()[1])
151 | return weights
152 | return weight
153 |
154 | def reshape_weight(self, x, weight_size):
155 | if type(weight_size[0]) == list and type(x) != list:
156 | x = self.split_weights(x, self.sum_mul(weight_size))
157 | if type(x) == list:
158 | return [self.reshape_weight(xi, wi) for xi, wi in zip(x, weight_size)]
159 | weight_size = [x.size()[0]] + weight_size
160 | bias_size = weight_size[1]
161 | try:
162 | weight = x[:, :-bias_size].view(weight_size)
163 | bias = x[:, -bias_size:]
164 | except:
165 | weight = x.view(weight_size)
166 | bias = None
167 | return [weight, bias]
168 |
169 | def reshape_embed_input(self, x):
170 | if isinstance(x, list):
171 | return [self.reshape_embed_input(xi) for xi in x]
172 | b, c, _, _ = x.size()
173 | x = x.view(b*c, -1)
174 | return x
175 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04
2 |
3 | RUN apt-get update && apt-get install -y rsync htop git openssh-server python-pip
4 |
5 | RUN pip install --upgrade pip
6 |
7 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
8 | RUN pip install torchvision cffi tensorboardX
9 |
10 | RUN pip install tqdm scipy scikit-image colorama==0.3.7
11 | RUN pip install setproctitle pytz ipython
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/README.md:
--------------------------------------------------------------------------------
1 | # flownet2-pytorch
2 |
3 | Pytorch implementation of [FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks](https://arxiv.org/abs/1612.01925).
4 |
5 | Multiple GPU training is supported, and the code provides examples for training or inference on [MPI-Sintel](http://sintel.is.tue.mpg.de/) clean and final datasets. The same commands can be used for training or inference with other datasets. See below for more detail.
6 |
7 | Inference using fp16 (half-precision) is also supported.
8 |
9 | For more help, type
10 |
11 | python main.py --help
12 |
13 | ## Network architectures
14 | Below are the different flownet neural network architectures that are provided.
15 | A batchnorm version for each network is also available.
16 |
17 | - **FlowNet2S**
18 | - **FlowNet2C**
19 | - **FlowNet2CS**
20 | - **FlowNet2CSS**
21 | - **FlowNet2SD**
22 | - **FlowNet2**
23 |
24 | ## Custom layers
25 |
26 | `FlowNet2` or `FlowNet2C*` architectures rely on the custom layers `Resample2d` and `Correlation`.
27 | A PyTorch implementation of these layers with CUDA kernels is available at [./networks](./networks).
28 | Note: Currently, half-precision kernels are not available for these layers.
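A minimal usage sketch (an illustration added here, not from the upstream documentation): it assumes the extensions have been built with `install.sh` and that the script is run from this directory, so the wrapper classes are importable from `./networks`; the `Correlation` settings mirror those used inside `FlowNetC`, and both layers expect CUDA tensors.

import torch
from networks.resample2d_package.resample2d import Resample2d
from networks.correlation_package.correlation import Correlation

# warp an image with a dense (2-channel) flow field
resample = Resample2d()
img = torch.randn(1, 3, 64, 64).cuda()
flow = torch.zeros(1, 2, 64, 64).cuda()
warped = resample(img, flow)  # same shape as img

# build a cost volume between two feature maps (settings as in FlowNetC)
corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
cost = corr(torch.randn(1, 256, 48, 64).cuda(), torch.randn(1, 256, 48, 64).cuda())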
29 |
30 | ## Data Loaders
31 |
32 | Dataloaders for FlyingChairs, FlyingThings, ChairsSDHom and ImagesFromFolder are available in [datasets.py](./datasets.py).
33 |
34 | ## Loss Functions
35 |
36 | L1 and L2 losses with multi-scale support are available in [losses.py](./losses.py).
37 |
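A rough sketch of how these losses are constructed and evaluated (an illustration added here, not from the upstream documentation; it assumes the script runs next to `losses.py`, and since the loss modules only store a reference to `args`, an empty namespace is enough):

from argparse import Namespace
import torch
from losses import L1Loss, MultiScale

args = Namespace()  # the loss modules only keep a reference to args
criterion = MultiScale(args, startScale=4, numScales=5, norm='L1')

pred = torch.randn(2, 2, 64, 64)  # predicted flow, shape (B, 2, H, W)
gt = torch.randn(2, 2, 64, 64)    # ground-truth flow
loss, epe = criterion(pred, gt)   # both loss classes return [loss_value, EPE]
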
38 | ## Installation
39 |
40 | # get flownet2-pytorch source
41 | git clone https://github.com/NVIDIA/flownet2-pytorch.git
42 | cd flownet2-pytorch
43 |
44 | # install custom layers
45 | bash install.sh
46 |
47 | ### Python requirements
48 | Currently, the code supports Python 3:
49 | * numpy
50 | * PyTorch ( == 0.4.1, for <= 0.4.0 see branch [python36-PyTorch0.4](https://github.com/NVIDIA/flownet2-pytorch/tree/python36-PyTorch0.4))
51 | * scipy
52 | * scikit-image
53 | * tensorboardX
54 | * colorama, tqdm, setproctitle
55 |
56 | ## Converted Caffe Pre-trained Models
57 | We've included caffe pre-trained models. Should you use these pre-trained weights, please adhere to the [license agreements](https://drive.google.com/file/d/1TVv0BnNFh3rpHZvD-easMb9jYrPE2Eqd/view?usp=sharing).
58 |
59 | * [FlowNet2](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB]
60 | * [FlowNet2-C](https://drive.google.com/file/d/1BFT6b7KgKJC8rA59RmOVAXRM_S7aSfKE/view?usp=sharing)[149MB]
61 | * [FlowNet2-CS](https://drive.google.com/file/d/1iBJ1_o7PloaINpa8m7u_7TsLCX0Dt_jS/view?usp=sharing)[297MB]
62 | * [FlowNet2-CSS](https://drive.google.com/file/d/157zuzVf4YMN6ABAQgZc8rRmR5cgWzSu8/view?usp=sharing)[445MB]
63 | * [FlowNet2-CSS-ft-sd](https://drive.google.com/file/d/1R5xafCIzJCXc8ia4TGfC65irmTNiMg6u/view?usp=sharing)[445MB]
64 | * [FlowNet2-S](https://drive.google.com/file/d/1V61dZjFomwlynwlYklJHC-TLfdFom3Lg/view?usp=sharing)[148MB]
65 | * [FlowNet2-SD](https://drive.google.com/file/d/1QW03eyYG_vD-dT-Mx4wopYvtPu_msTKn/view?usp=sharing)[173MB]
66 |
67 | ## Inference
68 | # Example on MPISintel Clean
69 | python main.py --inference --model FlowNet2 --save_flow --inference_dataset MpiSintelClean \
70 | --inference_dataset_root /path/to/mpi-sintel/clean/dataset \
71 | --resume /path/to/checkpoints
72 |
73 | ## Training and validation
74 |
75 | # Example on MPISintel Final and Clean, with L1Loss on FlowNet2 model
76 | python main.py --batch_size 8 --model FlowNet2 --loss=L1Loss --optimizer=Adam --optimizer_lr=1e-4 \
77 | --training_dataset MpiSintelFinal --training_dataset_root /path/to/mpi-sintel/final/dataset \
78 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset
79 |
80 | # Example on MPISintel Final and Clean, with MultiScale loss on FlowNet2C model
81 | python main.py --batch_size 8 --model FlowNet2C --optimizer=Adam --optimizer_lr=1e-4 --loss=MultiScale --loss_norm=L1 \
82 | --loss_numScales=5 --loss_startScale=4 --optimizer_lr=1e-4 --crop_size 384 512 \
83 | --training_dataset FlyingChairs --training_dataset_root /path/to/flying-chairs/dataset \
84 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset
85 |
86 | ## Results on MPI-Sintel
87 | [](https://www.youtube.com/watch?v=HtBmabY8aeU "Predicted flows on MPI-Sintel")
88 |
89 | ## Reference
90 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper:
91 | ````
92 | @InProceedings{IMKDB17,
93 | author = "E. Ilg and N. Mayer and T. Saikia and M. Keuper and A. Dosovitskiy and T. Brox",
94 | title = "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks",
95 | booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)",
96 | month = "Jul",
97 | year = "2017",
98 | url = "http://lmb.informatik.uni-freiburg.de//Publications/2017/IMKDB17"
99 | }
100 | ````
101 | ```
102 | @misc{flownet2-pytorch,
103 | author = {Fitsum Reda and Robert Pottorff and Jon Barker and Bryan Catanzaro},
104 | title = {flownet2-pytorch: Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks},
105 | year = {2017},
106 | publisher = {GitHub},
107 | journal = {GitHub repository},
108 | howpublished = {\url{https://github.com/NVIDIA/flownet2-pytorch}}
109 | }
110 | ```
111 | ## Related Optical Flow Work from Nvidia
112 | Code (in Caffe and Pytorch): [PWC-Net](https://github.com/NVlabs/PWC-Net)
113 | Paper : [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/abs/1709.02371).
114 |
115 | ## Acknowledgments
116 | Parts of this code were derived, as noted in the code, from [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch).
117 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/convert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2.7
2 |
3 | import caffe
4 | from caffe.proto import caffe_pb2
5 | import sys, os
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import argparse, tempfile
11 | import numpy as np
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('caffe_model', help='input model in hdf5 or caffemodel format')
15 | parser.add_argument('prototxt_template',help='prototxt template')
16 | parser.add_argument('flownet2_pytorch', help='path to flownet2-pytorch')
17 |
18 | args = parser.parse_args()
19 |
20 | args.rgb_max = 255
21 | args.fp16 = False
22 | args.grads = {}
23 |
24 | # load models
25 | sys.path.append(args.flownet2_pytorch)
26 |
27 | import models
28 | from utils.param_utils import *
29 |
30 | width = 256
31 | height = 256
32 | keys = {'TARGET_WIDTH': width,
33 | 'TARGET_HEIGHT': height,
34 | 'ADAPTED_WIDTH':width,
35 | 'ADAPTED_HEIGHT':height,
36 | 'SCALE_WIDTH':1.,
37 | 'SCALE_HEIGHT':1.,}
38 |
39 | template = '\n'.join(np.loadtxt(args.prototxt_template, dtype=str, delimiter='\n'))
40 | for k in keys:
41 | template = template.replace('$%s$'%(k),str(keys[k]))
42 |
43 | prototxt = tempfile.NamedTemporaryFile(mode='w', delete=True)
44 | prototxt.write(template)
45 | prototxt.flush()
46 |
47 | net = caffe.Net(prototxt.name, args.caffe_model, caffe.TEST)
48 |
49 | weights = {}
50 | biases = {}
51 |
52 | for k, v in list(net.params.items()):
53 | weights[k] = np.array(v[0].data).reshape(v[0].data.shape)
54 | biases[k] = np.array(v[1].data).reshape(v[1].data.shape)
55 | print((k, weights[k].shape, biases[k].shape))
56 |
57 | if 'FlowNet2/' in args.caffe_model:
58 | model = models.FlowNet2(args)
59 |
60 | parse_flownetc(model.flownetc.modules(), weights, biases)
61 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
62 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
63 | parse_flownetsd(model.flownets_d.modules(), weights, biases, param_prefix='netsd_')
64 | parse_flownetfusion(model.flownetfusion.modules(), weights, biases, param_prefix='fuse_')
65 |
66 | state = {'epoch': 0,
67 | 'state_dict': model.state_dict(),
68 | 'best_EPE': 1e10}
69 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2_checkpoint.pth.tar'))
70 |
71 | elif 'FlowNet2-C/' in args.caffe_model:
72 | model = models.FlowNet2C(args)
73 |
74 | parse_flownetc(model.modules(), weights, biases)
75 | state = {'epoch': 0,
76 | 'state_dict': model.state_dict(),
77 | 'best_EPE': 1e10}
78 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-C_checkpoint.pth.tar'))
79 |
80 | elif 'FlowNet2-CS/' in args.caffe_model:
81 | model = models.FlowNet2CS(args)
82 |
83 | parse_flownetc(model.flownetc.modules(), weights, biases)
84 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
85 |
86 | state = {'epoch': 0,
87 | 'state_dict': model.state_dict(),
88 | 'best_EPE': 1e10}
89 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CS_checkpoint.pth.tar'))
90 |
91 | elif 'FlowNet2-CSS/' in args.caffe_model:
92 | model = models.FlowNet2CSS(args)
93 |
94 | parse_flownetc(model.flownetc.modules(), weights, biases)
95 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
96 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
97 |
98 | state = {'epoch': 0,
99 | 'state_dict': model.state_dict(),
100 | 'best_EPE': 1e10}
101 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS_checkpoint.pth.tar'))
102 |
103 | elif 'FlowNet2-CSS-ft-sd/' in args.caffe_model:
104 | model = models.FlowNet2CSS(args)
105 |
106 | parse_flownetc(model.flownetc.modules(), weights, biases)
107 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
108 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
109 |
110 | state = {'epoch': 0,
111 | 'state_dict': model.state_dict(),
112 | 'best_EPE': 1e10}
113 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS-ft-sd_checkpoint.pth.tar'))
114 |
115 | elif 'FlowNet2-S/' in args.caffe_model:
116 | model = models.FlowNet2S(args)
117 |
118 | parse_flownetsonly(model.modules(), weights, biases, param_prefix='')
119 | state = {'epoch': 0,
120 | 'state_dict': model.state_dict(),
121 | 'best_EPE': 1e10}
122 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-S_checkpoint.pth.tar'))
123 |
124 | elif 'FlowNet2-SD/' in args.caffe_model:
125 | model = models.FlowNet2SD(args)
126 |
127 | parse_flownetsd(model.modules(), weights, biases, param_prefix='')
128 |
129 | state = {'epoch': 0,
130 | 'state_dict': model.state_dict(),
131 | 'best_EPE': 1e10}
132 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-SD_checkpoint.pth.tar'))
133 |
134 | else:
135 | print(('model type could not be determined from input caffe model %s'%(args.caffe_model)))
136 | quit()
137 | print(("done converting ", args.caffe_model))
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ./networks/correlation_package
3 | python setup.py install --user
4 | cd ../resample2d_package
5 | python setup.py install --user
6 | cd ../channelnorm_package
7 | python setup.py install --user
8 | cd ..
9 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/launch_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 .
3 | sudo nvidia-docker run --rm -ti --volume=$(pwd):/flownet2-pytorch:rw --workdir=/flownet2-pytorch --ipc=host $USER/pytorch:CUDA8-py27 /bin/bash
4 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/losses.py:
--------------------------------------------------------------------------------
1 | '''
2 | Portions of this code copyright 2017, Clement Pinard
3 | '''
4 |
5 | # freda (todo) : adversarial loss
6 |
7 | import torch
8 | import torch.nn as nn
9 | import math
10 |
11 | def EPE(input_flow, target_flow):
12 | return torch.norm(target_flow-input_flow,p=2,dim=1).mean()
13 |
14 | class L1(nn.Module):
15 | def __init__(self):
16 | super(L1, self).__init__()
17 | def forward(self, output, target):
18 | lossvalue = torch.abs(output - target).mean()
19 | return lossvalue
20 |
21 | class L2(nn.Module):
22 | def __init__(self):
23 | super(L2, self).__init__()
24 | def forward(self, output, target):
25 | lossvalue = torch.norm(output-target,p=2,dim=1).mean()
26 | return lossvalue
27 |
28 | class L1Loss(nn.Module):
29 | def __init__(self, args):
30 | super(L1Loss, self).__init__()
31 | self.args = args
32 | self.loss = L1()
33 | self.loss_labels = ['L1', 'EPE']
34 |
35 | def forward(self, output, target):
36 | lossvalue = self.loss(output, target)
37 | epevalue = EPE(output, target)
38 | return [lossvalue, epevalue]
39 |
40 | class L2Loss(nn.Module):
41 | def __init__(self, args):
42 | super(L2Loss, self).__init__()
43 | self.args = args
44 | self.loss = L2()
45 | self.loss_labels = ['L2', 'EPE']
46 |
47 | def forward(self, output, target):
48 | lossvalue = self.loss(output, target)
49 | epevalue = EPE(output, target)
50 | return [lossvalue, epevalue]
51 |
52 | class MultiScale(nn.Module):
53 | def __init__(self, args, startScale = 4, numScales = 5, l_weight= 0.32, norm= 'L1'):
54 | super(MultiScale,self).__init__()
55 |
56 | self.startScale = startScale
57 | self.numScales = numScales
58 | self.loss_weights = torch.FloatTensor([(l_weight / 2 ** scale) for scale in range(self.numScales)])
59 | self.args = args
60 | self.l_type = norm
61 | self.div_flow = 0.05
62 | assert(len(self.loss_weights) == self.numScales)
63 |
64 | if self.l_type == 'L1':
65 | self.loss = L1()
66 | else:
67 | self.loss = L2()
68 |
69 | self.multiScales = [nn.AvgPool2d(self.startScale * (2**scale), self.startScale * (2**scale)) for scale in range(self.numScales)]
70 | self.loss_labels = ['MultiScale-'+self.l_type, 'EPE']
71 |
72 | def forward(self, output, target):
73 | lossvalue = 0
74 | epevalue = 0
75 |
76 | if type(output) is tuple:
77 | target = self.div_flow * target
78 | for i, output_ in enumerate(output):
79 | target_ = self.multiScales[i](target)
80 | epevalue += self.loss_weights[i]*EPE(output_, target_)
81 | lossvalue += self.loss_weights[i]*self.loss(output_, target_)
82 | return [lossvalue, epevalue]
83 | else:
84 | epevalue += EPE(output, target)
85 | lossvalue += self.loss(output, target)
86 | return [lossvalue, epevalue]
87 |
88 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/FlowNetC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .correlation_package.correlation import Correlation
9 |
10 | from .submodules import *
11 | 'Parameter count , 39,175,298 '
12 |
13 | class FlowNetC(nn.Module):
14 | def __init__(self,args, batchNorm=True, div_flow = 20):
15 | super(FlowNetC,self).__init__()
16 |
17 | self.batchNorm = batchNorm
18 | self.div_flow = div_flow
19 |
20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2)
21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1)
24 |
25 | if args.fp16:
26 | self.corr = nn.Sequential(
27 | tofp32(),
28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1),
29 | tofp16())
30 | else:
31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
32 |
33 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True)
34 | self.conv3_1 = conv(self.batchNorm, 473, 256)
35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
36 | self.conv4_1 = conv(self.batchNorm, 512, 512)
37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
38 | self.conv5_1 = conv(self.batchNorm, 512, 512)
39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
40 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
41 |
42 | self.deconv5 = deconv(1024,512)
43 | self.deconv4 = deconv(1026,256)
44 | self.deconv3 = deconv(770,128)
45 | self.deconv2 = deconv(386,64)
46 |
47 | self.predict_flow6 = predict_flow(1024)
48 | self.predict_flow5 = predict_flow(1026)
49 | self.predict_flow4 = predict_flow(770)
50 | self.predict_flow3 = predict_flow(386)
51 | self.predict_flow2 = predict_flow(194)
52 |
53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
57 |
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | if m.bias is not None:
61 | init.uniform_(m.bias)
62 | init.xavier_uniform_(m.weight)
63 |
64 | if isinstance(m, nn.ConvTranspose2d):
65 | if m.bias is not None:
66 | init.uniform_(m.bias)
67 | init.xavier_uniform_(m.weight)
68 | # init_deconv_bilinear(m.weight)
69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
70 |
71 | def forward(self, x):
72 | x1 = x[:,0:3,:,:]
73 | x2 = x[:,3::,:,:]
74 |
75 | out_conv1a = self.conv1(x1)
76 | out_conv2a = self.conv2(out_conv1a)
77 | out_conv3a = self.conv3(out_conv2a)
78 |
79 | # FlownetC bottom input stream
80 | out_conv1b = self.conv1(x2)
81 |
82 | out_conv2b = self.conv2(out_conv1b)
83 | out_conv3b = self.conv3(out_conv2b)
84 |
85 | # Merge streams
86 | out_corr = self.corr(out_conv3a, out_conv3b) # False
87 | out_corr = self.corr_activation(out_corr)
88 |
89 | # Redirect top input stream and concatenate
90 | out_conv_redir = self.conv_redir(out_conv3a)
91 |
92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
93 |
94 | # Merged conv layers
95 | out_conv3_1 = self.conv3_1(in_conv3_1)
96 |
97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
98 |
99 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
100 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
101 |
102 | flow6 = self.predict_flow6(out_conv6)
103 | flow6_up = self.upsampled_flow6_to_5(flow6)
104 | out_deconv5 = self.deconv5(out_conv6)
105 |
106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
107 |
108 | flow5 = self.predict_flow5(concat5)
109 | flow5_up = self.upsampled_flow5_to_4(flow5)
110 | out_deconv4 = self.deconv4(concat5)
111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
112 |
113 | flow4 = self.predict_flow4(concat4)
114 | flow4_up = self.upsampled_flow4_to_3(flow4)
115 | out_deconv3 = self.deconv3(concat4)
116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1)
117 |
118 | flow3 = self.predict_flow3(concat3)
119 | flow3_up = self.upsampled_flow3_to_2(flow3)
120 | out_deconv2 = self.deconv2(concat3)
121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1)
122 |
123 | flow2 = self.predict_flow2(concat2)
124 |
125 | if self.training:
126 | return flow2,flow3,flow4,flow5,flow6
127 | else:
128 | return flow2,
129 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/FlowNetFusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .submodules import *
9 | 'Parameter count = 581,226'
10 |
11 | class FlowNetFusion(nn.Module):
12 | def __init__(self,args, batchNorm=True):
13 | super(FlowNetFusion,self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv0 = conv(self.batchNorm, 11, 64)
17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2)
18 | self.conv1_1 = conv(self.batchNorm, 64, 128)
19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2)
20 | self.conv2_1 = conv(self.batchNorm, 128, 128)
21 |
22 | self.deconv1 = deconv(128,32)
23 | self.deconv0 = deconv(162,16)
24 |
25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32)
26 | self.inter_conv0 = i_conv(self.batchNorm, 82, 16)
27 |
28 | self.predict_flow2 = predict_flow(128)
29 | self.predict_flow1 = predict_flow(32)
30 | self.predict_flow0 = predict_flow(16)
31 |
32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
34 |
35 | for m in self.modules():
36 | if isinstance(m, nn.Conv2d):
37 | if m.bias is not None:
38 | init.uniform_(m.bias)
39 | init.xavier_uniform_(m.weight)
40 |
41 | if isinstance(m, nn.ConvTranspose2d):
42 | if m.bias is not None:
43 | init.uniform_(m.bias)
44 | init.xavier_uniform_(m.weight)
45 | # init_deconv_bilinear(m.weight)
46 |
47 | def forward(self, x):
48 | out_conv0 = self.conv0(x)
49 | out_conv1 = self.conv1_1(self.conv1(out_conv0))
50 | out_conv2 = self.conv2_1(self.conv2(out_conv1))
51 |
52 | flow2 = self.predict_flow2(out_conv2)
53 | flow2_up = self.upsampled_flow2_to_1(flow2)
54 | out_deconv1 = self.deconv1(out_conv2)
55 |
56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1)
57 | out_interconv1 = self.inter_conv1(concat1)
58 | flow1 = self.predict_flow1(out_interconv1)
59 | flow1_up = self.upsampled_flow1_to_0(flow1)
60 | out_deconv0 = self.deconv0(concat1)
61 |
62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1)
63 | out_interconv0 = self.inter_conv0(concat0)
64 | flow0 = self.predict_flow0(out_interconv0)
65 |
66 | return flow0
67 |
68 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/FlowNetS.py:
--------------------------------------------------------------------------------
1 | '''
2 | Portions of this code copyright 2017, Clement Pinard
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 |
9 | import math
10 | import numpy as np
11 |
12 | from .submodules import *
13 | 'Parameter count : 38,676,504 '
14 |
15 | class FlowNetS(nn.Module):
16 | def __init__(self, args, input_channels = 12, batchNorm=True):
17 | super(FlowNetS,self).__init__()
18 |
19 | self.batchNorm = batchNorm
20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2)
21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
23 | self.conv3_1 = conv(self.batchNorm, 256, 256)
24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
25 | self.conv4_1 = conv(self.batchNorm, 512, 512)
26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
27 | self.conv5_1 = conv(self.batchNorm, 512, 512)
28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
29 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
30 |
31 | self.deconv5 = deconv(1024,512)
32 | self.deconv4 = deconv(1026,256)
33 | self.deconv3 = deconv(770,128)
34 | self.deconv2 = deconv(386,64)
35 |
36 | self.predict_flow6 = predict_flow(1024)
37 | self.predict_flow5 = predict_flow(1026)
38 | self.predict_flow4 = predict_flow(770)
39 | self.predict_flow3 = predict_flow(386)
40 | self.predict_flow2 = predict_flow(194)
41 |
42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
44 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
46 |
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d):
49 | if m.bias is not None:
50 | init.uniform_(m.bias)
51 | init.xavier_uniform_(m.weight)
52 |
53 | if isinstance(m, nn.ConvTranspose2d):
54 | if m.bias is not None:
55 | init.uniform_(m.bias)
56 | init.xavier_uniform_(m.weight)
57 | # init_deconv_bilinear(m.weight)
58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
59 |
60 | def forward(self, x):
61 | out_conv1 = self.conv1(x)
62 |
63 | out_conv2 = self.conv2(out_conv1)
64 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
65 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
66 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
67 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
68 |
69 | flow6 = self.predict_flow6(out_conv6)
70 | flow6_up = self.upsampled_flow6_to_5(flow6)
71 | out_deconv5 = self.deconv5(out_conv6)
72 |
73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
74 | flow5 = self.predict_flow5(concat5)
75 | flow5_up = self.upsampled_flow5_to_4(flow5)
76 | out_deconv4 = self.deconv4(concat5)
77 |
78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
79 | flow4 = self.predict_flow4(concat4)
80 | flow4_up = self.upsampled_flow4_to_3(flow4)
81 | out_deconv3 = self.deconv3(concat4)
82 |
83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
84 | flow3 = self.predict_flow3(concat3)
85 | flow3_up = self.upsampled_flow3_to_2(flow3)
86 | out_deconv2 = self.deconv2(concat3)
87 |
88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
89 | flow2 = self.predict_flow2(concat2)
90 |
91 | if self.training:
92 | return flow2,flow3,flow4,flow5,flow6
93 | else:
94 | return flow2,
95 |
96 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/FlowNetSD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .submodules import *
9 | 'Parameter count = 45,371,666'
10 |
11 | class FlowNetSD(nn.Module):
12 | def __init__(self, args, batchNorm=True):
13 | super(FlowNetSD,self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv0 = conv(self.batchNorm, 6, 64)
17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2)
18 | self.conv1_1 = conv(self.batchNorm, 64, 128)
19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2)
20 | self.conv2_1 = conv(self.batchNorm, 128, 128)
21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2)
22 | self.conv3_1 = conv(self.batchNorm, 256, 256)
23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
24 | self.conv4_1 = conv(self.batchNorm, 512, 512)
25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
26 | self.conv5_1 = conv(self.batchNorm, 512, 512)
27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
28 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
29 |
30 | self.deconv5 = deconv(1024,512)
31 | self.deconv4 = deconv(1026,256)
32 | self.deconv3 = deconv(770,128)
33 | self.deconv2 = deconv(386,64)
34 |
35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512)
36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256)
37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128)
38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64)
39 |
40 | self.predict_flow6 = predict_flow(1024)
41 | self.predict_flow5 = predict_flow(512)
42 | self.predict_flow4 = predict_flow(256)
43 | self.predict_flow3 = predict_flow(128)
44 | self.predict_flow2 = predict_flow(64)
45 |
46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
50 |
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv2d):
53 | if m.bias is not None:
54 | init.uniform_(m.bias)
55 | init.xavier_uniform_(m.weight)
56 |
57 | if isinstance(m, nn.ConvTranspose2d):
58 | if m.bias is not None:
59 | init.uniform_(m.bias)
60 | init.xavier_uniform_(m.weight)
61 | # init_deconv_bilinear(m.weight)
62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
63 |
64 |
65 |
66 | def forward(self, x):
67 | out_conv0 = self.conv0(x)
68 | out_conv1 = self.conv1_1(self.conv1(out_conv0))
69 | out_conv2 = self.conv2_1(self.conv2(out_conv1))
70 |
71 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
72 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
73 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
74 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
75 |
76 | flow6 = self.predict_flow6(out_conv6)
77 | flow6_up = self.upsampled_flow6_to_5(flow6)
78 | out_deconv5 = self.deconv5(out_conv6)
79 |
80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
81 | out_interconv5 = self.inter_conv5(concat5)
82 | flow5 = self.predict_flow5(out_interconv5)
83 |
84 | flow5_up = self.upsampled_flow5_to_4(flow5)
85 | out_deconv4 = self.deconv4(concat5)
86 |
87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
88 | out_interconv4 = self.inter_conv4(concat4)
89 | flow4 = self.predict_flow4(out_interconv4)
90 | flow4_up = self.upsampled_flow4_to_3(flow4)
91 | out_deconv3 = self.deconv3(concat4)
92 |
93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
94 | out_interconv3 = self.inter_conv3(concat3)
95 | flow3 = self.predict_flow3(out_interconv3)
96 | flow3_up = self.upsampled_flow3_to_2(flow3)
97 | out_deconv2 = self.deconv2(concat3)
98 |
99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
100 | out_interconv2 = self.inter_conv2(concat2)
101 | flow2 = self.predict_flow2(out_interconv2)
102 |
103 | if self.training:
104 | return flow2,flow3,flow4,flow5,flow6
105 | else:
106 | return flow2,
107 |
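
A minimal usage sketch for FlowNetSD, assuming the repository root is on sys.path and the intermediate package __init__ modules import cleanly; the empty args namespace is only a placeholder, since the constructor accepts but never reads it. The network takes two RGB frames stacked along the channel axis (6 channels) and, in eval mode, returns a 1-tuple holding the finest flow at 1/4 resolution:

import argparse
import torch

# Assumed import path; adjust to however the repo is placed on sys.path.
from models.networks.flownet2_pytorch.networks.FlowNetSD import FlowNetSD

args = argparse.Namespace()                 # placeholder; unused by the constructor
net = FlowNetSD(args, batchNorm=True).eval()

pair = torch.randn(1, 6, 256, 256)          # two RGB frames stacked channel-wise
with torch.no_grad():
    flow2, = net(pair)                      # eval mode returns a 1-tuple
print(flow2.shape)                          # torch.Size([1, 2, 64, 64])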
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/channelnorm_package/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Function, Variable
2 | from torch.nn.modules.module import Module
3 | import channelnorm_cuda
4 |
5 | class ChannelNormFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, input1, norm_deg=2):
9 | assert input1.is_contiguous()
10 | b, _, h, w = input1.size()
11 | output = input1.new(b, 1, h, w).zero_()
12 |
13 | channelnorm_cuda.forward(input1, output, norm_deg)
14 | ctx.save_for_backward(input1, output)
15 | ctx.norm_deg = norm_deg
16 |
17 | return output
18 |
19 | @staticmethod
20 | def backward(ctx, grad_output):
21 | input1, output = ctx.saved_tensors
22 |
23 | grad_input1 = Variable(input1.new(input1.size()).zero_())
24 |
25 | channelnorm_cuda.backward(input1, output, grad_output.data,
26 | grad_input1.data, ctx.norm_deg)
27 |
28 | return grad_input1, None
29 |
30 |
31 | class ChannelNorm(Module):
32 |
33 | def __init__(self, norm_deg=2):
34 | super(ChannelNorm, self).__init__()
35 | self.norm_deg = norm_deg
36 |
37 | def forward(self, input1):
38 | return ChannelNormFunction.apply(input1, self.norm_deg)
39 |
40 |
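
A minimal sketch of ChannelNorm, assuming the channelnorm_cuda extension has been built and a CUDA device is available; it collapses the channel dimension into a per-pixel L2 norm, which the plain-PyTorch expression below reproduces:

import torch
from channelnorm import ChannelNorm         # assumes this directory is on sys.path

norm = ChannelNorm().cuda()
x = torch.randn(2, 3, 32, 32, device='cuda').contiguous()
y = norm(x)                                  # (B, C, H, W) -> (B, 1, H, W)
print(y.shape)                               # torch.Size([2, 1, 32, 32])

# Cross-check against a plain PyTorch L2 norm over channels.
ref = x.pow(2).sum(dim=1, keepdim=True).sqrt()
print(torch.allclose(y, ref, atol=1e-4))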
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 |
4 | #include "channelnorm_kernel.cuh"
5 |
6 | int channelnorm_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& output,
9 | int norm_deg) {
10 |
11 | channelnorm_kernel_forward(input1, output, norm_deg);
12 | return 1;
13 | }
14 |
15 |
16 | int channelnorm_cuda_backward(
17 | at::Tensor& input1,
18 | at::Tensor& output,
19 | at::Tensor& gradOutput,
20 | at::Tensor& gradInput1,
21 | int norm_deg) {
22 |
23 | channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
24 | return 1;
25 | }
26 |
27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 | m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
29 | m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/Context.h>
3 | #include <ATen/cuda/CUDAContext.h>
4 |
5 | #include "channelnorm_kernel.cuh"
6 |
7 | #define CUDA_NUM_THREADS 512
8 |
9 | #define DIM0(TENSOR) ((TENSOR).x)
10 | #define DIM1(TENSOR) ((TENSOR).y)
11 | #define DIM2(TENSOR) ((TENSOR).z)
12 | #define DIM3(TENSOR) ((TENSOR).w)
13 |
14 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
15 |
16 | using at::Half;
17 |
18 | template <typename scalar_t>
19 | __global__ void kernel_channelnorm_update_output(
20 | const int n,
21 | const scalar_t* __restrict__ input1,
22 | const long4 input1_size,
23 | const long4 input1_stride,
24 | scalar_t* __restrict__ output,
25 | const long4 output_size,
26 | const long4 output_stride,
27 | int norm_deg) {
28 |
29 | int index = blockIdx.x * blockDim.x + threadIdx.x;
30 |
31 | if (index >= n) {
32 | return;
33 | }
34 |
35 | int dim_b = DIM0(output_size);
36 | int dim_c = DIM1(output_size);
37 | int dim_h = DIM2(output_size);
38 | int dim_w = DIM3(output_size);
39 | int dim_chw = dim_c * dim_h * dim_w;
40 |
41 | int b = ( index / dim_chw ) % dim_b;
42 | int y = ( index / dim_w ) % dim_h;
43 | int x = ( index ) % dim_w;
44 |
45 | int i1dim_c = DIM1(input1_size);
46 | int i1dim_h = DIM2(input1_size);
47 | int i1dim_w = DIM3(input1_size);
48 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
49 | int i1dim_hw = i1dim_h * i1dim_w;
50 |
51 | float result = 0.0;
52 |
53 | for (int c = 0; c < i1dim_c; ++c) {
54 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
55 | scalar_t val = input1[i1Index];
56 | result += static_cast<float>(val * val);
57 | }
58 | result = sqrt(result);
59 | output[index] = static_cast<scalar_t>(result);
60 | }
61 |
62 |
63 | template <typename scalar_t>
64 | __global__ void kernel_channelnorm_backward_input1(
65 | const int n,
66 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
67 | const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride,
68 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
69 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride,
70 | int norm_deg) {
71 |
72 | int index = blockIdx.x * blockDim.x + threadIdx.x;
73 |
74 | if (index >= n) {
75 | return;
76 | }
77 |
78 | float val = 0.0;
79 |
80 | int dim_b = DIM0(gradInput_size);
81 | int dim_c = DIM1(gradInput_size);
82 | int dim_h = DIM2(gradInput_size);
83 | int dim_w = DIM3(gradInput_size);
84 | int dim_chw = dim_c * dim_h * dim_w;
85 | int dim_hw = dim_h * dim_w;
86 |
87 | int b = ( index / dim_chw ) % dim_b;
88 | int y = ( index / dim_w ) % dim_h;
89 | int x = ( index ) % dim_w;
90 |
91 |
92 | int outIndex = b * dim_hw + y * dim_w + x;
93 | val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
94 | gradInput[index] = static_cast<scalar_t>(val);
95 |
96 | }
97 |
98 | void channelnorm_kernel_forward(
99 | at::Tensor& input1,
100 | at::Tensor& output,
101 | int norm_deg) {
102 |
103 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
104 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
105 |
106 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
107 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
108 |
109 | int n = output.numel();
110 |
111 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
112 |
113 | kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
114 | //at::globalContext().getCurrentCUDAStream() >>>(
115 | n,
116 | input1.data<scalar_t>(),
117 | input1_size,
118 | input1_stride,
119 | output.data<scalar_t>(),
120 | output_size,
121 | output_stride,
122 | norm_deg);
123 |
124 | }));
125 |
126 | // TODO: ATen-equivalent check
127 |
128 | // THCudaCheck(cudaGetLastError());
129 | }
130 |
131 | void channelnorm_kernel_backward(
132 | at::Tensor& input1,
133 | at::Tensor& output,
134 | at::Tensor& gradOutput,
135 | at::Tensor& gradInput1,
136 | int norm_deg) {
137 |
138 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
139 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
140 |
141 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
142 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
143 |
144 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
145 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
146 |
147 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
148 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
149 |
150 | int n = gradInput1.numel();
151 |
152 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
153 |
154 | kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::cuda::getCurrentCUDAStream() >>>(
155 | //at::globalContext().getCurrentCUDAStream() >>>(
156 | n,
157 | input1.data<scalar_t>(),
158 | input1_size,
159 | input1_stride,
160 | output.data<scalar_t>(),
161 | output_size,
162 | output_stride,
163 | gradOutput.data<scalar_t>(),
164 | gradOutput_size,
165 | gradOutput_stride,
166 | gradInput1.data<scalar_t>(),
167 | gradInput1_size,
168 | gradInput1_stride,
169 | norm_deg
170 | );
171 |
172 | }));
173 |
174 | // TODO: Add ATen-equivalent check
175 |
176 | // THCudaCheck(cudaGetLastError());
177 | }
178 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/channelnorm_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void channelnorm_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& output,
8 | int norm_deg);
9 |
10 |
11 | void channelnorm_kernel_backward(
12 | at::Tensor& input1,
13 | at::Tensor& output,
14 | at::Tensor& gradOutput,
15 | at::Tensor& gradInput1,
16 | int norm_deg);
17 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/channelnorm_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_52,code=sm_52',
12 | '-gencode', 'arch=compute_60,code=sm_60',
13 | '-gencode', 'arch=compute_61,code=sm_61',
14 | '-gencode', 'arch=compute_70,code=sm_70',
15 | '-gencode', 'arch=compute_70,code=compute_70'
16 | ]
17 |
18 | setup(
19 | name='channelnorm_cuda',
20 | ext_modules=[
21 | CUDAExtension('channelnorm_cuda', [
22 | 'channelnorm_cuda.cc',
23 | 'channelnorm_kernel.cu'
24 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
25 | ],
26 | cmdclass={
27 | 'build_ext': BuildExtension
28 | })
29 |
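
As a hedged alternative to running python setup.py install inside this directory, the same extension can be JIT-compiled with PyTorch's cpp_extension loader; this only produces a module object for quick experiments and does not install channelnorm_cuda for the import in channelnorm.py:

from torch.utils.cpp_extension import load   # JIT build; run from this directory

channelnorm_cuda = load(
    name='channelnorm_cuda',
    sources=['channelnorm_cuda.cc', 'channelnorm_kernel.cu'],
    extra_cflags=['-std=c++11'],             # mirrors cxx_args above
    verbose=True)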
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/correlation_package/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | @staticmethod
19 | def forward(ctx, input1, input2, pad_size, kernel_size, max_displacement, stride1, stride2, corr_multiply):
20 | ctx.save_for_backward(input1, input2)
21 | ctx.pad_size = pad_size
22 | ctx.kernel_size = kernel_size
23 | ctx.max_displacement = max_displacement
24 | ctx.stride1 = stride1
25 | ctx.stride2 = stride2
26 | ctx.corr_multiply = corr_multiply
27 |
28 | with torch.cuda.device_of(input1):
29 | rbot1 = input1.new()
30 | rbot2 = input2.new()
31 | output = input1.new()
32 |
33 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
34 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply)
35 |
36 | return output
37 |
38 | @staticmethod
39 | def backward(ctx, grad_output):
40 | input1, input2 = ctx.saved_tensors
41 |
42 | with torch.cuda.device_of(input1):
43 | rbot1 = input1.new()
44 | rbot2 = input2.new()
45 |
46 | grad_input1 = input1.new()
47 | grad_input2 = input2.new()
48 |
49 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
50 | ctx.pad_size, ctx.kernel_size, ctx.max_displacement,ctx.stride1, ctx.stride2, ctx.corr_multiply)
51 |
52 | return grad_input1, grad_input2
53 |
54 |
55 | class Correlation(Module):
56 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
57 | super(Correlation, self).__init__()
58 | self.pad_size = pad_size
59 | self.kernel_size = kernel_size
60 | self.max_displacement = max_displacement
61 | self.stride1 = stride1
62 | self.stride2 = stride2
63 | self.corr_multiply = corr_multiply
64 |
65 | def forward(self, input1, input2):
66 | #result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,\
67 | # self.stride1, self.stride2, self.corr_multiply)(input1, input2)
68 | result = CorrelationFunction.apply(input1, input2, self.pad_size, self.kernel_size, self.max_displacement,\
69 | self.stride1, self.stride2, self.corr_multiply)
70 | return result
71 |
72 |
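
A minimal sketch of the Correlation layer, assuming the correlation_cuda extension is built and a CUDA device is available. With the settings FlowNetC uses in this repo (pad_size=20, kernel_size=1, max_displacement=20, stride2=2), the cost volume has ((20/2)*2 + 1)^2 = 441 channels at the input resolution:

import torch
from correlation import Correlation          # assumes this directory is on sys.path

corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                   stride1=1, stride2=2, corr_multiply=1).cuda()
feat_a = torch.randn(1, 256, 48, 64, device='cuda').contiguous()
feat_b = torch.randn(1, 256, 48, 64, device='cuda').contiguous()
cost_volume = corr(feat_a, feat_b)
print(cost_volume.shape)                     # torch.Size([1, 441, 48, 64])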
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/Context.h>
4 | #include <ATen/cuda/CUDAContext.h>
5 | #include <stdio.h>
6 | #include <iostream>
7 |
8 | #include "correlation_cuda_kernel.cuh"
9 |
10 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
11 | int pad_size,
12 | int kernel_size,
13 | int max_displacement,
14 | int stride1,
15 | int stride2,
16 | int corr_type_multiply)
17 | {
18 |
19 | int batchSize = input1.size(0);
20 |
21 | int nInputChannels = input1.size(1);
22 | int inputHeight = input1.size(2);
23 | int inputWidth = input1.size(3);
24 |
25 | int kernel_radius = (kernel_size - 1) / 2;
26 | int border_radius = kernel_radius + max_displacement;
27 |
28 | int paddedInputHeight = inputHeight + 2 * pad_size;
29 | int paddedInputWidth = inputWidth + 2 * pad_size;
30 |
31 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
32 |
33 | int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
34 | int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
35 |
36 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
37 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
38 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
39 |
40 | rInput1.fill_(0);
41 | rInput2.fill_(0);
42 | output.fill_(0);
43 |
44 | int success = correlation_forward_cuda_kernel(
45 | output,
46 | output.size(0),
47 | output.size(1),
48 | output.size(2),
49 | output.size(3),
50 | output.stride(0),
51 | output.stride(1),
52 | output.stride(2),
53 | output.stride(3),
54 | input1,
55 | input1.size(1),
56 | input1.size(2),
57 | input1.size(3),
58 | input1.stride(0),
59 | input1.stride(1),
60 | input1.stride(2),
61 | input1.stride(3),
62 | input2,
63 | input2.size(1),
64 | input2.stride(0),
65 | input2.stride(1),
66 | input2.stride(2),
67 | input2.stride(3),
68 | rInput1,
69 | rInput2,
70 | pad_size,
71 | kernel_size,
72 | max_displacement,
73 | stride1,
74 | stride2,
75 | corr_type_multiply,
76 | at::cuda::getCurrentCUDAStream()
77 | //at::globalContext().getCurrentCUDAStream()
78 | );
79 |
80 | //check for errors
81 | if (!success) {
82 | AT_ERROR("CUDA call failed");
83 | }
84 |
85 | return 1;
86 |
87 | }
88 |
89 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
90 | at::Tensor& gradInput1, at::Tensor& gradInput2,
91 | int pad_size,
92 | int kernel_size,
93 | int max_displacement,
94 | int stride1,
95 | int stride2,
96 | int corr_type_multiply)
97 | {
98 |
99 | int batchSize = input1.size(0);
100 | int nInputChannels = input1.size(1);
101 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
102 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
103 |
104 | int height = input1.size(2);
105 | int width = input1.size(3);
106 |
107 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
108 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
109 | gradInput1.resize_({batchSize, nInputChannels, height, width});
110 | gradInput2.resize_({batchSize, nInputChannels, height, width});
111 |
112 | rInput1.fill_(0);
113 | rInput2.fill_(0);
114 | gradInput1.fill_(0);
115 | gradInput2.fill_(0);
116 |
117 | int success = correlation_backward_cuda_kernel(gradOutput,
118 | gradOutput.size(0),
119 | gradOutput.size(1),
120 | gradOutput.size(2),
121 | gradOutput.size(3),
122 | gradOutput.stride(0),
123 | gradOutput.stride(1),
124 | gradOutput.stride(2),
125 | gradOutput.stride(3),
126 | input1,
127 | input1.size(1),
128 | input1.size(2),
129 | input1.size(3),
130 | input1.stride(0),
131 | input1.stride(1),
132 | input1.stride(2),
133 | input1.stride(3),
134 | input2,
135 | input2.stride(0),
136 | input2.stride(1),
137 | input2.stride(2),
138 | input2.stride(3),
139 | gradInput1,
140 | gradInput1.stride(0),
141 | gradInput1.stride(1),
142 | gradInput1.stride(2),
143 | gradInput1.stride(3),
144 | gradInput2,
145 | gradInput2.size(1),
146 | gradInput2.stride(0),
147 | gradInput2.stride(1),
148 | gradInput2.stride(2),
149 | gradInput2.stride(3),
150 | rInput1,
151 | rInput2,
152 | pad_size,
153 | kernel_size,
154 | max_displacement,
155 | stride1,
156 | stride2,
157 | corr_type_multiply,
158 | at::cuda::getCurrentCUDAStream()
159 | //at::globalContext().getCurrentCUDAStream()
160 | );
161 |
162 | if (!success) {
163 | AT_ERROR("CUDA call failed");
164 | }
165 |
166 | return 1;
167 | }
168 |
169 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
170 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
171 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
172 | }
173 |
174 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/Context.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='correlation_cuda',
21 | ext_modules=[
22 | CUDAExtension('correlation_cuda', [
23 | 'correlation_cuda.cc',
24 | 'correlation_cuda_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/resample2d_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/networks/resample2d_package/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/resample2d_package/resample2d.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.autograd import Function, Variable
3 | import resample2d_cuda
4 |
5 | class Resample2dFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, input1, input2, kernel_size=1):
9 | assert input1.is_contiguous()
10 | assert input2.is_contiguous()
11 |
12 | ctx.save_for_backward(input1, input2)
13 | ctx.kernel_size = kernel_size
14 |
15 | _, d, _, _ = input1.size()
16 | b, _, h, w = input2.size()
17 | output = input1.new(b, d, h, w).zero_()
18 |
19 | resample2d_cuda.forward(input1, input2, output, kernel_size)
20 |
21 | return output
22 |
23 | @staticmethod
24 | def backward(ctx, grad_output):
25 | assert grad_output.is_contiguous()
26 |
27 | input1, input2 = ctx.saved_tensors
28 |
29 | grad_input1 = Variable(input1.new(input1.size()).zero_())
30 | grad_input2 = Variable(input1.new(input2.size()).zero_())
31 |
32 | resample2d_cuda.backward(input1, input2, grad_output.data,
33 | grad_input1.data, grad_input2.data,
34 | ctx.kernel_size)
35 |
36 | return grad_input1, grad_input2, None
37 |
38 | class Resample2d(Module):
39 |
40 | def __init__(self, kernel_size=1):
41 | super(Resample2d, self).__init__()
42 | self.kernel_size = kernel_size
43 |
44 | def forward(self, input1, input2):
45 | input1_c = input1.contiguous()
46 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size)
47 |
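
A minimal sketch of Resample2d, assuming the resample2d_cuda extension is built and a CUDA device is available; it warps input1 by the flow field in input2, so an all-zero flow should reproduce the input:

import torch
from resample2d import Resample2d            # assumes this directory is on sys.path

warp = Resample2d()
image = torch.randn(1, 3, 128, 128, device='cuda')
flow = torch.zeros(1, 2, 128, 128, device='cuda').contiguous()
warped = warp(image, flow)                   # zero flow -> identity warp
print(torch.allclose(warped, image))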
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/resample2d_package/resample2d_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/torch.h>
3 |
4 | #include "resample2d_kernel.cuh"
5 |
6 | int resample2d_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& input2,
9 | at::Tensor& output,
10 | int kernel_size) {
11 | resample2d_kernel_forward(input1, input2, output, kernel_size);
12 | return 1;
13 | }
14 |
15 | int resample2d_cuda_backward(
16 | at::Tensor& input1,
17 | at::Tensor& input2,
18 | at::Tensor& gradOutput,
19 | at::Tensor& gradInput1,
20 | at::Tensor& gradInput2,
21 | int kernel_size) {
22 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size);
23 | return 1;
24 | }
25 |
26 |
27 |
28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
29 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
30 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/resample2d_package/resample2d_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void resample2d_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& input2,
8 | at::Tensor& output,
9 | int kernel_size);
10 |
11 | void resample2d_kernel_backward(
12 | at::Tensor& input1,
13 | at::Tensor& input2,
14 | at::Tensor& gradOutput,
15 | at::Tensor& gradInput1,
16 | at::Tensor& gradInput2,
17 | int kernel_size);
18 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/resample2d_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='resample2d_cuda',
21 | ext_modules=[
22 | CUDAExtension('resample2d_cuda', [
23 | 'resample2d_cuda.cc',
24 | 'resample2d_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/networks/submodules.py:
--------------------------------------------------------------------------------
1 | # freda (todo) :
2 |
3 | import torch.nn as nn
4 | import torch
5 | import numpy as np
6 |
7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
8 | if batchNorm:
9 | return nn.Sequential(
10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
11 | nn.BatchNorm2d(out_planes),
12 | nn.LeakyReLU(0.1,inplace=True)
13 | )
14 | else:
15 | return nn.Sequential(
16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
17 | nn.LeakyReLU(0.1,inplace=True)
18 | )
19 |
20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True):
21 | if batchNorm:
22 | return nn.Sequential(
23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
24 | nn.BatchNorm2d(out_planes),
25 | )
26 | else:
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
29 | )
30 |
31 | def predict_flow(in_planes):
32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
33 |
34 | def deconv(in_planes, out_planes):
35 | return nn.Sequential(
36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
37 | nn.LeakyReLU(0.1,inplace=True)
38 | )
39 |
40 | class tofp16(nn.Module):
41 | def __init__(self):
42 | super(tofp16, self).__init__()
43 |
44 | def forward(self, input):
45 | return input.half()
46 |
47 |
48 | class tofp32(nn.Module):
49 | def __init__(self):
50 | super(tofp32, self).__init__()
51 |
52 | def forward(self, input):
53 | return input.float()
54 |
55 |
56 | def init_deconv_bilinear(weight):
57 | f_shape = weight.size()
58 | heigh, width = f_shape[-2], f_shape[-1]
59 | f = np.ceil(width/2.0)
60 | c = (2 * f - 1 - f % 2) / (2.0 * f)
61 | bilinear = np.zeros([heigh, width])
62 | for x in range(width):
63 | for y in range(heigh):
64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
65 | bilinear[x, y] = value
66 | weight.data.fill_(0.)
67 | for i in range(f_shape[0]):
68 | for j in range(f_shape[1]):
69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear)
70 |
71 |
72 | def save_grad(grads, name):
73 | def hook(grad):
74 | grads[name] = grad
75 | return hook
76 |
77 | '''
78 | def save_grad(grads, name):
79 | def hook(grad):
80 | grads[name] = grad
81 | return hook
82 | import torch
83 | from channelnorm_package.modules.channelnorm import ChannelNorm
84 | model = ChannelNorm().cuda()
85 | grads = {}
86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True)
87 | a.register_hook(save_grad(grads, 'a'))
88 | b = model(a)
89 | y = torch.mean(b)
90 | y.backward()
91 |
92 | '''
93 |
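
A small sketch of how these helpers compose, assuming this directory is on sys.path so submodules imports directly: conv with stride=2 halves the resolution, deconv doubles it, and predict_flow maps any feature map to a 2-channel flow:

import torch
from submodules import conv, deconv, predict_flow

block = conv(batchNorm=True, in_planes=64, out_planes=128, stride=2)
up = deconv(128, 64)
flow_head = predict_flow(128)

x = torch.randn(1, 64, 64, 64)
feat = block(x)                              # (1, 128, 32, 32): stride-2 conv halves H, W
print(flow_head(feat).shape)                 # (1, 2, 32, 32): 2-channel flow head
print(up(feat).shape)                        # (1, 64, 64, 64): deconv doubles H, W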
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/run-caffe2pytorch.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | #!/bin/bash
9 |
10 | FN2PYTORCH=${1:-/}
11 |
12 | # install custom layers
13 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 .
14 | sudo nvidia-docker run --rm -ti --volume=${FN2PYTORCH}:/flownet2-pytorch:rw --workdir=/flownet2-pytorch $USER/pytorch:CUDA8-py27 /bin/bash -c "./install.sh"
15 |
16 | # convert FlowNet2-C, CS, CSS, CSS-ft-sd, SD, S and 2 to PyTorch
17 | sudo nvidia-docker run -ti --volume=${FN2PYTORCH}:/fn2pytorch:rw flownet2:latest /bin/bash -c "source /flownet2/flownet2/set-env.sh && cd /flownet2/flownet2/models && \
18 | python /fn2pytorch/convert.py ./FlowNet2-C/FlowNet2-C_weights.caffemodel ./FlowNet2-C/FlowNet2-C_deploy.prototxt.template /fn2pytorch &&
19 | python /fn2pytorch/convert.py ./FlowNet2-CS/FlowNet2-CS_weights.caffemodel ./FlowNet2-CS/FlowNet2-CS_deploy.prototxt.template /fn2pytorch && \
20 | python /fn2pytorch/convert.py ./FlowNet2-CSS/FlowNet2-CSS_weights.caffemodel.h5 ./FlowNet2-CSS/FlowNet2-CSS_deploy.prototxt.template /fn2pytorch && \
21 | python /fn2pytorch/convert.py ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_weights.caffemodel.h5 ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_deploy.prototxt.template /fn2pytorch && \
22 | python /fn2pytorch/convert.py ./FlowNet2-SD/FlowNet2-SD_weights.caffemodel.h5 ./FlowNet2-SD/FlowNet2-SD_deploy.prototxt.template /fn2pytorch && \
23 | python /fn2pytorch/convert.py ./FlowNet2-S/FlowNet2-S_weights.caffemodel.h5 ./FlowNet2-S/FlowNet2-S_deploy.prototxt.template /fn2pytorch && \
24 | python /fn2pytorch/convert.py ./FlowNet2/FlowNet2_weights.caffemodel.h5 ./FlowNet2/FlowNet2_deploy.prototxt.template /fn2pytorch"
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/few-shot-vid2vid/009e23f17f4ce1f227fdb4bbf50f81f706cd2c04/models/networks/flownet2_pytorch/utils/__init__.py
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | TAG_CHAR = np.array([202021.25], np.float32)
4 |
5 | def readFlow(fn):
6 | """ Read .flo file in Middlebury format"""
7 | # Code adapted from:
8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
9 |
10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
11 | # print 'fn = %s'%(fn)
12 | with open(fn, 'rb') as f:
13 | magic = np.fromfile(f, np.float32, count=1)
14 | if 202021.25 != magic:
15 | print('Magic number incorrect. Invalid .flo file')
16 | return None
17 | else:
18 | w = np.fromfile(f, np.int32, count=1)
19 | h = np.fromfile(f, np.int32, count=1)
20 | # print 'Reading %d x %d flo file\n' % (w, h)
21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
22 | # Reshape data into 3D array (columns, rows, bands)
23 | # The reshape here is for visualization, the original code is (w,h,2)
24 | return np.resize(data, (int(h), int(w), 2))
25 |
26 | def writeFlow(filename,uv,v=None):
27 | """ Write optical flow to file.
28 |
29 | If v is None, uv is assumed to contain both u and v channels,
30 | stacked in depth.
31 | Original code by Deqing Sun, adapted from Daniel Scharstein.
32 | """
33 | nBands = 2
34 |
35 | if v is None:
36 | assert(uv.ndim == 3)
37 | assert(uv.shape[2] == 2)
38 | u = uv[:,:,0]
39 | v = uv[:,:,1]
40 | else:
41 | u = uv
42 |
43 | assert(u.shape == v.shape)
44 | height,width = u.shape
45 | f = open(filename,'wb')
46 | # write the header
47 | f.write(TAG_CHAR)
48 | np.array(width).astype(np.int32).tofile(f)
49 | np.array(height).astype(np.int32).tofile(f)
50 | # arrange into matrix form
51 | tmp = np.zeros((height, width*nBands))
52 | tmp[:,np.arange(width)*2] = u
53 | tmp[:,np.arange(width)*2 + 1] = v
54 | tmp.astype(np.float32).tofile(f)
55 | f.close()
56 |
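
A round-trip sketch for the two helpers above (the file name is illustrative only); writeFlow stores width, height, and interleaved u/v as float32, and readFlow restores an (H, W, 2) array on little-endian machines:

import numpy as np
from flow_utils import readFlow, writeFlow   # assumes this directory is on sys.path

flow = np.random.randn(120, 160, 2).astype(np.float32)   # (H, W, 2)
writeFlow('example.flo', flow)                            # example file name
restored = readFlow('example.flo')
print(restored.shape, np.allclose(flow, restored))        # (120, 160, 2) True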
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os.path import *
3 | from scipy.misc import imread
4 | from . import flow_utils
5 |
6 | def read_gen(file_name):
7 | ext = splitext(file_name)[-1]
8 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
9 | im = imread(file_name)
10 | if im.shape[2] > 3:
11 | return im[:,:,:3]
12 | else:
13 | return im
14 | elif ext == '.bin' or ext == '.raw':
15 | return np.load(file_name)
16 | elif ext == '.flo':
17 | return flow_utils.readFlow(file_name).astype(np.float32)
18 | return []
19 |
--------------------------------------------------------------------------------
/models/networks/flownet2_pytorch/utils/tools.py:
--------------------------------------------------------------------------------
1 | # freda (todo) :
2 |
3 | import os, time, sys, math
4 | import subprocess, shutil
5 | from os.path import *
6 | import numpy as np
7 | from inspect import isclass
8 | from pytz import timezone
9 | from datetime import datetime
10 | import inspect
11 | import torch
12 |
13 | def datestr():
14 | pacific = timezone('US/Pacific')
15 | now = datetime.now(pacific)
16 | return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute)
17 |
18 | def module_to_dict(module, exclude=[]):
19 | return dict([(x, getattr(module, x)) for x in dir(module)
20 | if isclass(getattr(module, x))
21 | and x not in exclude
22 | and getattr(module, x) not in exclude])
23 |
24 | class TimerBlock:
25 | def __init__(self, title):
26 | print(("{}".format(title)))
27 |
28 | def __enter__(self):
29 | self.start = time.clock()
30 | return self
31 |
32 | def __exit__(self, exc_type, exc_value, traceback):
33 | self.end = time.clock()
34 | self.interval = self.end - self.start
35 |
36 | if exc_type is not None:
37 | self.log("Operation failed\n")
38 | else:
39 | self.log("Operation finished\n")
40 |
41 |
42 | def log(self, string):
43 | duration = time.clock() - self.start
44 | units = 's'
45 | if duration > 60:
46 | duration = duration / 60.
47 | units = 'm'
48 | print((" [{:.3f}{}] {}".format(duration, units, string)))
49 |
50 | def log2file(self, fid, string):
51 | fid = open(fid, 'a')
52 | fid.write("%s\n"%(string))
53 | fid.close()
54 |
55 | def add_arguments_for_module(parser, module, argument_for_class, default, skip_params=[], parameter_defaults={}):
56 | argument_group = parser.add_argument_group(argument_for_class.capitalize())
57 |
58 | module_dict = module_to_dict(module)
59 | argument_group.add_argument('--' + argument_for_class, type=str, default=default, choices=list(module_dict.keys()))
60 |
61 | args, unknown_args = parser.parse_known_args()
62 | class_obj = module_dict[vars(args)[argument_for_class]]
63 |
64 | argspec = inspect.getargspec(class_obj.__init__)
65 |
66 | defaults = argspec.defaults[::-1] if argspec.defaults else None
67 |
68 | args = argspec.args[::-1]
69 | for i, arg in enumerate(args):
70 | cmd_arg = '{}_{}'.format(argument_for_class, arg)
71 | if arg not in skip_params + ['self', 'args']:
72 | if arg in list(parameter_defaults.keys()):
73 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(parameter_defaults[arg]), default=parameter_defaults[arg])
74 | elif (defaults is not None and i < len(defaults)):
75 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(defaults[i]), default=defaults[i])
76 | else:
77 | print(("[Warning]: non-default argument '{}' detected on class '{}'. This argument cannot be modified via the command line"
78 | .format(arg, module.__class__.__name__)))
79 | # We don't have a good way of dealing with inferring the type of the argument
80 | # TODO: try creating a custom action and using ast's infer type?
81 | # else:
82 | # argument_group.add_argument('--{}'.format(cmd_arg), required=True)
83 |
84 | def kwargs_from_args(args, argument_for_class):
85 | argument_for_class = argument_for_class + '_'
86 | return {key[len(argument_for_class):]: value for key, value in list(vars(args).items()) if argument_for_class in key and key != argument_for_class + 'class'}
87 |
88 | def format_dictionary_of_losses(labels, values):
89 | try:
90 | string = ', '.join([('{}: {:' + ('.3f' if value >= 0.001 else '.1e') +'}').format(name, value) for name, value in zip(labels, values)])
91 | except (TypeError, ValueError) as e:
92 | print((list(zip(labels, values))))
93 | string = '[Log Error] ' + str(e)
94 |
95 | return string
96 |
97 |
98 | class IteratorTimer():
99 | def __init__(self, iterable):
100 | self.iterable = iterable
101 | self.iterator = self.iterable.__iter__()
102 |
103 | def __iter__(self):
104 | return self
105 |
106 | def __len__(self):
107 | return len(self.iterable)
108 |
109 | def __next__(self):
110 | start = time.time()
111 | n = next(self.iterator)
112 | self.last_duration = (time.time() - start)
113 | return n
114 |
115 | next = __next__
116 |
117 | def gpumemusage():
118 | gpu_mem = subprocess.check_output("nvidia-smi | grep MiB | cut -f 3 -d '|'", shell=True).decode().replace(' ', '').replace('\n', '').replace('i', '')
119 | all_stat = [float(a) for a in gpu_mem.replace('/','').split('MB')[:-1]]
120 |
121 | gpu_mem = ''
122 | for i in range(len(all_stat)//2):
123 | curr, tot = all_stat[2*i], all_stat[2*i+1]
124 | util = "%1.2f"%(100*curr/tot)+'%'
125 | cmem = str(int(math.ceil(curr/1024.)))+'GB'
126 | gmem = str(int(math.ceil(tot/1024.)))+'GB'
127 | gpu_mem += util + '--' + join(cmem, gmem) + ' '
128 | return gpu_mem
129 |
130 |
131 | def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer):
132 | if args.schedule_lr_frequency > 0:
133 | for param_group in optimizer.param_groups:
134 | if (global_iteration + 1) % args.schedule_lr_frequency == 0:
135 | param_group['lr'] /= float(args.schedule_lr_fraction)
136 | param_group['lr'] = float(np.maximum(param_group['lr'], 0.000001))
137 |
138 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'):
139 | prefix_save = os.path.join(path, prefix)
140 | name = prefix_save + '_' + filename
141 | torch.save(state, name)
142 | if is_best:
143 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar')
144 |
145 |
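
A small usage sketch of TimerBlock, assuming pytz is installed (tools.py imports it at module level); note that TimerBlock relies on time.clock(), which only exists on Python versions before 3.8:

from tools import TimerBlock                 # assumes this directory is on sys.path

with TimerBlock('Building FlowNet2') as block:
    block.log('constructed model')
    block.log('loaded checkpoint')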
--------------------------------------------------------------------------------
/models/networks/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import numpy as np
11 | from models.networks.vgg import VGG_Activations, Vgg19
12 |
13 | # Defines the GAN loss which uses either LSGAN or the regular GAN.
14 | # When LSGAN is used, it is basically the same as MSELoss,
15 | # but it abstracts away the need to create the target label tensor
16 | # that has the same size as the input
17 | class GANLoss(nn.Module):
18 | def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0,
19 | tensor=torch.FloatTensor, opt=None):
20 | super(GANLoss, self).__init__()
21 | self.real_label = target_real_label
22 | self.fake_label = target_fake_label
23 | self.real_label_tensor = None
24 | self.fake_label_tensor = None
25 | self.Tensor = tensor
26 | self.gan_mode = gan_mode
27 | self.opt = opt
28 | if gan_mode == 'ls':
29 | pass
30 | elif gan_mode == 'original':
31 | pass
32 | elif gan_mode == 'w':
33 | pass
34 | elif gan_mode == 'hinge':
35 | pass
36 | else:
37 | raise ValueError('Unexpected gan_mode {}'.format(gan_mode))
38 |
39 | def get_target_tensor(self, input, target_is_real):
40 | if target_is_real:
41 | if self.real_label_tensor is None:
42 | self.real_label_tensor = self.Tensor(1).fill_(self.real_label)
43 | return self.real_label_tensor.expand_as(input)
44 | else:
45 | if self.fake_label_tensor is None:
46 | self.fake_label_tensor = self.Tensor(1).fill_(self.fake_label)
47 | return self.fake_label_tensor.expand_as(input)
48 |
49 | def loss(self, input, target_is_real, weight=None, reduce_dim=True, for_discriminator=True):
50 | if self.gan_mode == 'original':
51 | target_tensor = self.get_target_tensor(input, target_is_real)
52 | batchsize = input.size(0)
53 | loss = F.binary_cross_entropy_with_logits(input, target_tensor, weight=weight)
54 | if not reduce_dim:
55 | loss = loss.view(batchsize, -1).mean(dim=1)
56 | return loss
57 | elif self.gan_mode == 'ls':
58 | #target_tensor = self.get_target_tensor(input, target_is_real)
59 | target_tensor = input * 0 + (self.real_label if target_is_real else self.fake_label)
60 | if weight is None and reduce_dim:
61 | return F.mse_loss(input, target_tensor)
62 | error = (input - target_tensor)**2
63 | if weight is not None:
64 | error *= weight
65 | if reduce_dim:
66 | return torch.mean(error)
67 | else:
68 | return error.view(input.size(0), -1).mean(dim=1)
69 | elif self.gan_mode == 'hinge':
70 | assert weight == None
71 | assert reduce_dim == True
72 | if for_discriminator:
73 | if target_is_real:
74 | minval = torch.min(input - 1, input * 0)
75 | loss = -torch.mean(minval)
76 | else:
77 | minval = torch.min(-input - 1, input * 0)
78 | loss = -torch.mean(minval)
79 | else:
80 | assert target_is_real, "The generator's hinge loss must be aiming for real"
81 | loss = -torch.mean(input)
82 |
83 | return loss
84 | else:
85 | # wgan
86 | assert weight is None and reduce_dim
87 | if target_is_real:
88 | return -input.mean()
89 | else:
90 | return input.mean()
91 |
92 | def __call__(self, input, target_is_real, weight=None, reduce_dim=True, for_discriminator=True):
93 | if isinstance(input, list):
94 | loss = 0
95 | for pred_i in input:
96 | if isinstance(pred_i, list):
97 | pred_i = pred_i[-1]
98 | loss_tensor = self.loss(pred_i, target_is_real, weight, reduce_dim, for_discriminator)
99 | bs = 1 if len(loss_tensor.size()) == 0 else loss_tensor.size(0)
100 | new_loss = torch.mean(loss_tensor.view(bs, -1), dim=1)
101 | loss += new_loss
102 | return loss / len(input)
103 | else:
104 | return self.loss(input, target_is_real, weight, reduce_dim, for_discriminator)
105 |
106 |
107 | class VGGLoss(nn.Module):
108 | def __init__(self, opt, gpu_ids):
109 | super(VGGLoss, self).__init__()
110 | self.vgg = VGG_Activations([1, 6, 11, 20, 29]).cuda()
111 | self.criterion = nn.L1Loss()
112 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
113 |
114 | def compute_loss(self, x_vgg, y_vgg):
115 | loss = 0
116 | for i in range(len(x_vgg)):
117 | loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
118 | return loss
119 |
120 | def forward(self, x, y):
121 | if len(x.size()) == 5:
122 | b, t, c, h, w = x.size()
123 | x, y = x.view(-1, c, h, w), y.view(-1, c, h, w)
124 |
125 | y_vgg = self.vgg(y)
126 | x_vgg = self.vgg(x)
127 | loss = self.compute_loss(x_vgg, y_vgg)
128 | return loss
129 |
130 | class MaskedL1Loss(nn.Module):
131 | def __init__(self):
132 | super(MaskedL1Loss, self).__init__()
133 | self.criterion = nn.L1Loss()
134 |
135 | def forward(self, input, target, mask):
136 | mask = mask.expand_as(input)
137 | loss = self.criterion(input * mask, target * mask)
138 | return loss
139 |
140 | class KLDLoss(nn.Module):
141 | def forward(self, mu, logvar):
142 | return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
143 |
144 |
145 |
146 |
147 |
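
A minimal sketch of GANLoss with the hinge objective applied to hypothetical patch-discriminator logits, assuming the repo root is on sys.path (the module also pulls in the VGG definitions, so torchvision must be installed); because __call__ also unwraps lists, per-scale predictions from a multiscale discriminator can be passed in directly:

import torch
from models.networks.loss import GANLoss

criterion = GANLoss('hinge', tensor=torch.FloatTensor)

pred_real = torch.randn(4, 1, 30, 30)        # hypothetical patch-discriminator logits
pred_fake = torch.randn(4, 1, 30, 30)

d_loss = criterion(pred_real, True, for_discriminator=True) + \
         criterion(pred_fake, False, for_discriminator=True)
g_loss = criterion(pred_fake, True, for_discriminator=False)
print(d_loss.item(), g_loss.item())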
--------------------------------------------------------------------------------
/models/networks/normalization.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import re
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | import torch.nn.utils.spectral_norm as sn
12 |
13 | from models.networks.base_network import batch_conv
14 | # from models.networks.sync_batchnorm import SynchronizedBatchNorm2d
15 | from apex.parallel import SyncBatchNorm as SynchronizedBatchNorm2d
16 |
17 |
18 | class SPADE(nn.Module):
19 | def __init__(self, norm_nc, hidden_nc=0, norm='batch', ks=3, params_free=False):
20 | super().__init__()
21 | pw = ks//2
22 | if not isinstance(hidden_nc, list): hidden_nc = [hidden_nc]
23 | for i, nhidden in enumerate(hidden_nc):
24 | mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
25 | mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
26 |
27 | if not params_free or (i != 0):
28 | s = str(i+1) if i > 0 else ''
29 | setattr(self, 'mlp_gamma%s' % s, mlp_gamma)
30 | setattr(self, 'mlp_beta%s' % s, mlp_beta)
31 |
32 | if 'batch' in norm:
33 | self.norm = SynchronizedBatchNorm2d(norm_nc, affine=False)
34 | else:
35 | self.norm = nn.InstanceNorm2d(norm_nc, affine=False, eps=0.1)
36 |
37 | def forward(self, x, maps, weights=None):
38 | if not isinstance(maps, list): maps = [maps]
39 | out = self.norm(x)
40 | for i in range(len(maps)):
41 | if maps[i] is None: continue
42 | m = F.interpolate(maps[i], size=x.size()[2:])
43 | if weights is None or (i != 0):
44 | s = str(i+1) if i > 0 else ''
45 | gamma = getattr(self, 'mlp_gamma%s' % s)(m)
46 | beta = getattr(self, 'mlp_beta%s' % s)(m)
47 | else:
48 | j = min(i, len(weights[0])-1)
49 | gamma = batch_conv(m, weights[0][j])
50 | beta = batch_conv(m, weights[1][j])
51 | out = out * (1 + gamma) + beta
52 | return out
53 |
54 | def get_nonspade_norm_layer(opt, norm_type='instance'):
55 | # helper function to get # output channels of the previous layer
56 | def get_out_channel(layer):
57 | if hasattr(layer, 'out_channels'):
58 | return getattr(layer, 'out_channels')
59 | return layer.weight.size(0)
60 |
61 | # this function will be returned
62 | def add_norm_layer(layer):
63 | nonlocal norm_type
64 | if norm_type.startswith('spectral'):
65 | layer = sn(layer)
66 | subnorm_type = norm_type[len('spectral'):]
67 |
68 | if subnorm_type == 'none' or len(subnorm_type) == 0:
69 | return layer
70 |
71 | # remove bias in the previous layer, which is meaningless
72 | # since it has no effect after normalization
73 | if getattr(layer, 'bias', None) is not None:
74 | delattr(layer, 'bias')
75 | layer.register_parameter('bias', None)
76 |
77 | if subnorm_type == 'batch':
78 | norm_layer = nn.BatchNorm2d(get_out_channel(layer), affine=True)
79 | elif subnorm_type == 'syncbatch':
80 | norm_layer = SynchronizedBatchNorm2d(get_out_channel(layer), affine=True)
81 | elif subnorm_type == 'instance':
82 | norm_layer = nn.InstanceNorm2d(get_out_channel(layer), affine=True, eps=0.1)
83 | else:
84 | raise ValueError('normalization layer %s is not recognized' % subnorm_type)
85 |
86 | return nn.Sequential(layer, norm_layer)
87 |
88 | return add_norm_layer
89 |
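
A minimal sketch of SPADE with instance normalization, assuming apex is installed since this module imports it unconditionally (or with the commented-out sync_batchnorm import restored); the 36-channel conditioning map is a hypothetical label/pose input and is resized internally to the feature resolution:

import torch
from models.networks.normalization import SPADE

spade = SPADE(norm_nc=512, hidden_nc=36, norm='instance')
feat = torch.randn(2, 512, 16, 16)                 # generator feature map
cond = torch.randn(2, 36, 128, 128)                # hypothetical 36-channel label/pose map
out = spade(feat, cond)                            # cond is resized to 16x16 internally
print(out.shape)                                   # torch.Size([2, 512, 16, 16])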
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : __init__.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
12 | from .replicate import DataParallelWithCallback, patch_replication_callback
13 |
--------------------------------------------------------------------------------
/models/networks/sync_batchnorm/comm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : comm.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import queue
12 | import collections
13 | import threading
14 |
15 | __all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
16 |
17 |
18 | class FutureResult(object):
19 | """A thread-safe future implementation. Used only as one-to-one pipe."""
20 |
21 | def __init__(self):
22 | self._result = None
23 | self._lock = threading.Lock()
24 | self._cond = threading.Condition(self._lock)
25 |
26 | def put(self, result):
27 | with self._lock:
28 |             assert self._result is None, 'Previous result hasn\'t been fetched.'
29 | self._result = result
30 | self._cond.notify()
31 |
32 | def get(self):
33 | with self._lock:
34 | if self._result is None:
35 | self._cond.wait()
36 |
37 | res = self._result
38 | self._result = None
39 | return res
40 |
41 |
42 | _MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
43 | _SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
44 |
45 |
46 | class SlavePipe(_SlavePipeBase):
47 | """Pipe for master-slave communication."""
48 |
49 | def run_slave(self, msg):
50 | self.queue.put((self.identifier, msg))
51 | ret = self.result.get()
52 | self.queue.put(True)
53 | return ret
54 |
55 |
56 | class SyncMaster(object):
57 | """An abstract `SyncMaster` object.
58 |
59 |     - During the replication, as data parallel triggers a callback on each module, all slave devices should
60 |     call `register(id)` and obtain a `SlavePipe` to communicate with the master.
61 |     - During the forward pass, the master device invokes `run_master`; all messages from the slave devices are
62 |     collected and passed to a registered callback.
63 |     - After receiving the messages, the master device gathers the information and determines the message to be
64 |     passed back to each slave device.
65 | """
66 |
67 | def __init__(self, master_callback):
68 | """
69 |
70 | Args:
71 | master_callback: a callback to be invoked after having collected messages from slave devices.
72 | """
73 | self._master_callback = master_callback
74 | self._queue = queue.Queue()
75 | self._registry = collections.OrderedDict()
76 | self._activated = False
77 |
78 | def register_slave(self, identifier):
79 | """
80 |         Register a slave device.
81 |
82 | Args:
83 |             identifier: an identifier, usually the device id.
84 |
85 | Returns: a `SlavePipe` object which can be used to communicate with the master device.
86 |
87 | """
88 | if self._activated:
89 | assert self._queue.empty(), 'Queue is not clean before next initialization.'
90 | self._activated = False
91 | self._registry.clear()
92 | future = FutureResult()
93 | self._registry[identifier] = _MasterRegistry(future)
94 | return SlavePipe(identifier, self._queue, future)
95 |
96 | def run_master(self, master_msg):
97 | """
98 | Main entry for the master device in each forward pass.
99 |         The messages are first collected from each device (including the master device), and then
100 |         a callback is invoked to compute the message to be sent back to each device
101 |         (including the master device).
102 |
103 | Args:
104 |             master_msg: the message that the master wants to send to itself. This will be placed as the first
105 | message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
106 |
107 | Returns: the message to be sent back to the master device.
108 |
109 | """
110 | self._activated = True
111 |
112 | intermediates = [(0, master_msg)]
113 | for i in range(self.nr_slaves):
114 | intermediates.append(self._queue.get())
115 |
116 | results = self._master_callback(intermediates)
117 |         assert results[0][0] == 0, 'The first result should belong to the master.'
118 |
119 | for i, res in results:
120 | if i == 0:
121 | continue
122 | self._registry[i].result.put(res)
123 |
124 | for i in range(self.nr_slaves):
125 | assert self._queue.get() is True
126 |
127 | return results[0][1]
128 |
129 | @property
130 | def nr_slaves(self):
131 | return len(self._registry)
132 |
--------------------------------------------------------------------------------
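To make the master/slave handshake in comm.py concrete, here is a small self-contained sketch (a toy example, not repository code): two slave threads register, each sends a number, and the master callback sums all messages and broadcasts the total back to every device.

    import threading
    from models.networks.sync_batchnorm.comm import SyncMaster

    def sum_callback(intermediates):
        # intermediates is a list of (identifier, msg); return one (identifier, result)
        # per device, keeping the master (identifier 0) first, as run_master asserts.
        total = sum(msg for _, msg in intermediates)
        return [(i, total) for i, _ in intermediates]

    master = SyncMaster(sum_callback)
    pipes = [master.register_slave(i) for i in (1, 2)]    # two slave "devices"

    results = {}
    def run_slave(pipe, value):
        results[pipe.identifier] = pipe.run_slave(value)  # blocks until the master replies

    threads = [threading.Thread(target=run_slave, args=(p, v)) for p, v in zip(pipes, (10, 20))]
    for t in threads:
        t.start()
    master_result = master.run_master(1)                  # the master contributes msg = 1
    for t in threads:
        t.join()
    # master_result == 31 and results == {1: 31, 2: 31}
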
/models/networks/sync_batchnorm/replicate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : replicate.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import functools
12 |
13 | #from torch.nn.parallel.data_parallel import DataParallel
14 | import torch.nn as nn
15 |
16 | __all__ = [
17 | 'CallbackContext',
18 | 'execute_replication_callbacks',
19 | 'DataParallelWithCallback',
20 | 'patch_replication_callback'
21 | ]
22 |
23 |
24 | class DataParallel(nn.parallel.DataParallel):
25 | def replicate(self, module, device_ids):
26 | replicas = super(DataParallel, self).replicate(module, device_ids)
27 | replicas[0] = module
28 | return replicas
29 |
30 | class CallbackContext(object):
31 | pass
32 |
33 |
34 | def execute_replication_callbacks(modules):
35 | """
36 |     Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.
37 |
38 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
39 |
40 |     Note that, as all modules are isomorphic, we assign each sub-module a context
41 | (shared among multiple copies of this module on different devices).
42 | Through this context, different copies can share some information.
43 |
44 |     We guarantee that the callback on the master copy (the first copy) will be called ahead of the callbacks
45 |     of any slave copies.
46 | """
47 | master_copy = modules[0]
48 | nr_modules = len(list(master_copy.modules()))
49 | ctxs = [CallbackContext() for _ in range(nr_modules)]
50 |
51 | for i, module in enumerate(modules):
52 | for j, m in enumerate(module.modules()):
53 | if hasattr(m, '__data_parallel_replicate__'):
54 | m.__data_parallel_replicate__(ctxs[j], i)
55 |
56 |
57 | class DataParallelWithCallback(DataParallel):
58 | """
59 | Data Parallel with a replication callback.
60 |
61 |     A replication callback `__data_parallel_replicate__` of each module will be invoked after the replicas are
62 |     created by the original `replicate` function.
63 |     The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`.
64 |
65 | Examples:
66 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
67 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
68 | # sync_bn.__data_parallel_replicate__ will be invoked.
69 | """
70 |
71 | def replicate(self, module, device_ids):
72 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
73 | execute_replication_callbacks(modules)
74 | return modules
75 |
76 |
77 | def patch_replication_callback(data_parallel):
78 | """
79 | Monkey-patch an existing `DataParallel` object. Add the replication callback.
80 | Useful when you have customized `DataParallel` implementation.
81 |
82 | Examples:
83 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
84 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
85 | > patch_replication_callback(sync_bn)
86 | # this is equivalent to
87 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
88 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
89 | """
90 |
91 | assert isinstance(data_parallel, DataParallel)
92 |
93 | old_replicate = data_parallel.replicate
94 |
95 | @functools.wraps(old_replicate)
96 | def new_replicate(module, device_ids):
97 | modules = old_replicate(module, device_ids)
98 | execute_replication_callbacks(modules)
99 | return modules
100 |
101 | data_parallel.replicate = new_replicate
102 |
--------------------------------------------------------------------------------
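The only contract `execute_replication_callbacks` relies on is that a sub-module may define `__data_parallel_replicate__(ctx, copy_id)`: the same `ctx` object is shared by all replicas of that sub-module, and the master copy (copy_id == 0) is always called first. A hypothetical module illustrating the contract (an assumption, not repository code):

    import torch.nn as nn

    class ToyModule(nn.Module):
        def __data_parallel_replicate__(self, ctx, copy_id):
            self._is_master = (copy_id == 0)
            if self._is_master:
                ctx.master = self   # replicas on other devices can reach the master copy via ctx
            self._ctx = ctx
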
/models/networks/sync_batchnorm/unittest.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # File : unittest.py
3 | # Author : Jiayuan Mao
4 | # Email : maojiayuan@gmail.com
5 | # Date : 27/01/2018
6 | #
7 | # This file is part of Synchronized-BatchNorm-PyTorch.
8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
9 | # Distributed under MIT License.
10 |
11 | import unittest
12 |
13 | import numpy as np
14 | from torch.autograd import Variable
15 |
16 |
17 | def as_numpy(v):
18 | if isinstance(v, Variable):
19 | v = v.data
20 | return v.cpu().numpy()
21 |
22 |
23 | class TorchTestCase(unittest.TestCase):
24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
25 | npa, npb = as_numpy(a), as_numpy(b)
26 | self.assertTrue(
27 | np.allclose(npa, npb, atol=atol),
28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
29 | )
30 |
--------------------------------------------------------------------------------
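A short usage sketch of the helper above (an assumption, not repository code): `assertTensorClose` compares two tensors element-wise within the default absolute tolerance of 1e-3.

    import torch
    from models.networks.sync_batchnorm.unittest import TorchTestCase

    class ToyTest(TorchTestCase):
        def test_close(self):
            a = torch.zeros(3)
            b = torch.zeros(3) + 1e-4
            self.assertTensorClose(a, b)   # passes: max difference 1e-4 < atol 1e-3
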
/models/networks/vgg.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torchvision
11 | from collections import OrderedDict
12 |
13 | class Vgg19(nn.Module):
14 | def __init__(self, requires_grad=False):
15 | super(Vgg19, self).__init__()
16 | vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
17 | self.slice1 = nn.Sequential()
18 | self.slice2 = nn.Sequential()
19 | self.slice3 = nn.Sequential()
20 | self.slice4 = nn.Sequential()
21 | self.slice5 = nn.Sequential()
22 | for x in range(2):
23 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(2, 7):
25 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
26 | for x in range(7, 12):
27 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
28 | for x in range(12, 21):
29 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
30 | for x in range(21, 30):
31 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
32 | if not requires_grad:
33 | for param in self.parameters():
34 | param.requires_grad = False
35 |
36 | def forward(self, X):
37 | h_relu1 = self.slice1(X)
38 | h_relu2 = self.slice2(h_relu1)
39 | h_relu3 = self.slice3(h_relu2)
40 | h_relu4 = self.slice4(h_relu3)
41 | h_relu5 = self.slice5(h_relu4)
42 | out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
43 | return out
44 |
45 | class VGG_Activations(nn.Module):
46 | def __init__(self, feature_idx):
47 | super(VGG_Activations, self).__init__()
48 | vgg_network = torchvision.models.vgg19(pretrained=True)
49 | features = list(vgg_network.features)
50 | self.features = nn.ModuleList(features).eval()
51 | self.idx_list = feature_idx
52 |
53 | def forward(self, x):
54 | results = []
55 | for ii, model in enumerate(self.features):
56 | x = model(x)
57 | if ii in self.idx_list:
58 | results.append(x)
59 | return results
60 |
--------------------------------------------------------------------------------
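The `Vgg19` wrapper above returns activations from five successive slices, the usual ingredient of a perceptual (VGG) loss. A minimal sketch of that use follows; the per-layer weights below are illustrative assumptions, not the repository's actual values.

    import torch
    import torch.nn as nn
    from models.networks.vgg import Vgg19

    vgg = Vgg19().eval()                                     # frozen, pretrained feature extractor
    l1 = nn.L1Loss()
    weights = [1.0 / 32, 1.0 / 16, 1.0 / 8, 1.0 / 4, 1.0]    # assumed example weights

    def perceptual_loss(fake, real):
        fake_feats, real_feats = vgg(fake), vgg(real)
        return sum(w * l1(f, r.detach()) for w, f, r in zip(weights, fake_feats, real_feats))

    loss = perceptual_loss(torch.randn(1, 3, 256, 256), torch.randn(1, 3, 256, 256))
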
/models/trainer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import numpy as np
9 | import torch
10 | import time
11 | from collections import OrderedDict
12 | import math
13 | from subprocess import call
14 | def lcm(a, b): return abs(a * b) // math.gcd(a, b) if a and b else 0
15 |
16 | import util.util as util
17 | from util.visualizer import Visualizer
18 | from models.models import save_models, update_models
19 | from util.distributed import master_only, is_master, get_world_size
20 | from util.distributed import master_only_print as print
21 |
22 | class Trainer():
23 | def __init__(self, opt, data_loader):
24 | iter_path = os.path.join(opt.checkpoints_dir, opt.name, 'iter.txt')
25 | start_epoch, epoch_iter = 1, 0
26 | ### if continue training, recover previous states
27 | if opt.continue_train:
28 | if os.path.exists(iter_path):
29 |                 start_epoch, epoch_iter = np.loadtxt(iter_path, delimiter=',', dtype=int)
30 | print('Resuming from epoch %d at iteration %d' % (start_epoch, epoch_iter))
31 |
32 | print_freq = lcm(opt.print_freq, opt.batchSize)
33 | total_steps = (start_epoch-1) * len(data_loader) + epoch_iter
34 | total_steps = total_steps // print_freq * print_freq
35 |
36 | self.opt = opt
37 | self.epoch_iter, self.print_freq, self.total_steps, self.iter_path = epoch_iter, print_freq, total_steps, iter_path
38 | self.start_epoch, self.epoch_iter = start_epoch, epoch_iter
39 | self.dataset_size = len(data_loader)
40 | self.visualizer = Visualizer(opt)
41 |
42 | def start_of_iter(self, data):
43 | if self.total_steps % self.print_freq == 0:
44 | self.iter_start_time = time.time()
45 | self.total_steps += self.opt.batchSize
46 | self.epoch_iter += self.opt.batchSize
47 | self.save = self.total_steps % self.opt.display_freq == 0
48 | for k, v in data.items():
49 | if isinstance(v, torch.Tensor):
50 | data[k] = v.cuda()
51 | return data
52 |
53 | def end_of_iter(self, loss_dicts, output_list, model):
54 | opt = self.opt
55 | epoch, epoch_iter, print_freq, total_steps = self.epoch, self.epoch_iter, self.print_freq, self.total_steps
56 | ############## Display results and errors ##########
57 | ### print out errors
58 | if is_master() and total_steps % print_freq == 0:
59 | t = (time.time() - self.iter_start_time) / print_freq / get_world_size()
60 | errors = {k: v.data.item() if not isinstance(v, int) else v for k, v in loss_dicts.items()}
61 | self.visualizer.print_current_errors(epoch, epoch_iter, errors, t)
62 | self.visualizer.plot_current_errors(errors, total_steps)
63 |
64 | ### display output images
65 | if is_master() and self.save:
66 | visuals = save_all_tensors(opt, output_list, model)
67 | self.visualizer.display_current_results(visuals, epoch, total_steps)
68 |
69 | if is_master() and opt.print_mem:
70 | call(["nvidia-smi", "--format=csv", "--query-gpu=memory.used,memory.free"])
71 |
72 | ### save latest model
73 | save_models(opt, epoch, epoch_iter, total_steps, self.visualizer, self.iter_path, model)
74 | if epoch_iter > self.dataset_size - opt.batchSize:
75 | return True
76 | return False
77 |
78 | def start_of_epoch(self, epoch, model, data_loader):
79 | self.epoch = epoch
80 | self.epoch_start_time = time.time()
81 | if self.opt.distributed:
82 | data_loader.dataloader.sampler.set_epoch(epoch)
83 | # update model params
84 | update_models(self.opt, epoch, model, data_loader)
85 |
86 | def end_of_epoch(self, model):
87 | opt = self.opt
88 | iter_end_time = time.time()
89 | self.visualizer.vis_print(opt, 'End of epoch %d / %d \t Time Taken: %d sec' %
90 | (self.epoch, opt.niter + opt.niter_decay, time.time() - self.epoch_start_time))
91 |
92 | ### save model for this epoch
93 | save_models(opt, self.epoch, self.epoch_iter, self.total_steps, self.visualizer, self.iter_path, model, end_of_epoch=True)
94 | self.epoch_iter = 0
95 |
96 | def save_all_tensors(opt, output_list, model):
97 | fake_image, fake_raw_image, warped_image, flow, weight, atn_score, \
98 | target_label, target_image, flow_gt, conf_gt, ref_label, ref_image = output_list
99 |
100 | visual_list = [('target_label', util.visualize_label(opt, target_label, model)),
101 | ('synthesized_image', util.tensor2im(fake_image)),
102 | ('target_image', util.tensor2im(target_image)),
103 | ('ref_image', util.tensor2im(ref_image, tile=True)),
104 | ('raw_image', util.tensor2im(fake_raw_image)),
105 | ('warped_images', util.tensor2im(warped_image, tile=True)),
106 | ('flows', util.tensor2flow(flow, tile=True)),
107 | ('weights', util.tensor2im(weight, normalize=False, tile=True)),
108 | ('atn_score', util.tensor2im(atn_score, normalize=False)),
109 | ]
110 | visuals = OrderedDict(visual_list)
111 | return visuals
112 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
--------------------------------------------------------------------------------
/options/test_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | from .base_options import BaseOptions
8 |
9 | class TestOptions(BaseOptions):
10 | def initialize(self, parser):
11 | BaseOptions.initialize(self, parser)
12 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.')
13 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
14 | parser.add_argument('--how_many', type=int, default=300, help='how many test images to run')
15 | parser.set_defaults(serial_batches=True)
16 | parser.set_defaults(batchSize=1)
17 | parser.set_defaults(nThreads=1)
18 | parser.set_defaults(no_flip=True)
19 | self.isTrain = False
20 | return parser
21 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | from .base_options import BaseOptions
8 |
9 | class TrainOptions(BaseOptions):
10 | def initialize(self, parser):
11 | BaseOptions.initialize(self, parser)
12 | # for displays
13 | parser.add_argument('--display_freq', type=int, default=100, help='frequency of showing training results on screen')
14 | parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
15 | parser.add_argument('--save_latest_freq', type=int, default=1000, help='frequency of saving the latest results')
16 | parser.add_argument('--save_epoch_freq', type=int, default=5, help='frequency of saving checkpoints at the end of epochs')
17 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/')
18 | parser.add_argument('--debug', action='store_true', help='only do one epoch and displays at each iteration')
19 | parser.add_argument('--print_mem', action='store_true', help='print memory usage')
20 | parser.add_argument('--print_G', action='store_true', help='print network G')
21 | parser.add_argument('--print_D', action='store_true', help='print network D')
22 |
23 | # for training
24 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
25 | parser.add_argument('--load_pretrain', type=str, default='', help='load the pretrained model from the specified location')
26 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
27 | parser.add_argument('--niter', type=int, default=50, help='# of iter at starting learning rate')
28 | parser.add_argument('--niter_decay', type=int, default=50, help='# of iter to linearly decay learning rate to zero')
29 | parser.add_argument('--niter_single', type=int, default=50, help='# of iter for single frame training')
30 | parser.add_argument('--niter_step', type=int, default=10, help='# of iter to double the length of training sequence')
31 |
32 | # for temporal
33 | parser.add_argument('--n_frames_D', type=int, default=2, help='number of frames to feed into temporal discriminator')
34 | parser.add_argument('--n_frames_total', type=int, default=2, help='the overall number of frames in a sequence to train with')
35 | parser.add_argument('--max_t_step', type=int, default=4, help='max spacing between neighboring sampled frames. If greater than 1, the network may randomly skip frames during training.')
36 |
37 | self.isTrain = True
38 | return parser
39 |
--------------------------------------------------------------------------------
/scripts/download_datasets.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | from download_gdrive import *
9 |
10 | file_id = '1NyxrzJbHgDOpf-nRJhfxsHqWNDW_D1TV'
11 | chpt_path = './datasets/'
12 | if not os.path.isdir(chpt_path):
13 | os.makedirs(chpt_path)
14 | destination = os.path.join(chpt_path, 'datasets.zip')
15 | download_file_from_google_drive(file_id, destination)
16 | unzip_file(destination, chpt_path)
--------------------------------------------------------------------------------
/scripts/download_flownet2.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | from download_gdrive import *
9 | os.system('cd models/networks/flownet2_pytorch/; bash install.sh; cd ../../../')
10 |
11 | file_id = '1E8re-b6csNuo-abg1vJKCDjCzlIam50F'
12 | chpt_path = './models/networks/flownet2_pytorch/'
13 | destination = os.path.join(chpt_path, 'FlowNet2_checkpoint.pth.tar')
14 | download_file_from_google_drive(file_id, destination)
--------------------------------------------------------------------------------
/scripts/download_gdrive.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | # Download code taken from https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive/39225039#39225039
9 | import requests, zipfile, os
10 | def download_file_from_google_drive(id, destination):
11 | URL = "https://docs.google.com/uc?export=download"
12 | session = requests.Session()
13 | response = session.get(URL, params = { 'id' : id }, stream = True)
14 | token = get_confirm_token(response)
15 | if token:
16 | params = { 'id' : id, 'confirm' : token }
17 | response = session.get(URL, params = params, stream = True)
18 | save_response_content(response, destination)
19 | def get_confirm_token(response):
20 | for key, value in response.cookies.items():
21 | if key.startswith('download_warning'):
22 | return value
23 | return None
24 | def save_response_content(response, destination):
25 | CHUNK_SIZE = 32768
26 | with open(destination, "wb") as f:
27 | for chunk in response.iter_content(CHUNK_SIZE):
28 | if chunk: # filter out keep-alive new chunks
29 | f.write(chunk)
30 |
31 | def unzip_file(file_name, unzip_path):
32 | zip_ref = zipfile.ZipFile(file_name, 'r')
33 | zip_ref.extractall(unzip_path)
34 | zip_ref.close()
35 | os.remove(file_name)
--------------------------------------------------------------------------------
/scripts/face/test_256.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python test.py --name face_256 --dataset_mode fewshot_face --adaptive_spade --warp_ref --spade_combine
9 |
--------------------------------------------------------------------------------
/scripts/face/test_512.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python test.py --name face_512 --dataset_mode fewshot_face --loadSize 512 --fineSize 512 --adaptive_spade --warp_ref --spade_combine
9 |
--------------------------------------------------------------------------------
/scripts/face/train_g1_256.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python train.py --name face_256 --dataset_mode fewshot_face \
9 | --adaptive_spade --warp_ref --spade_combine --batchSize 4 --continue_train
--------------------------------------------------------------------------------
/scripts/face/train_g8_256.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python train.py --name face_256 --dataset_mode fewshot_face \
9 | --adaptive_spade --warp_ref --spade_combine \
10 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 32 --nThreads 32 --continue_train
--------------------------------------------------------------------------------
/scripts/face/train_g8_512.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python train.py --name face_512 --dataset_mode fewshot_face \
9 | --loadSize 512 --fineSize 512 --num_D 2 \
10 | --adaptive_spade --warp_ref --spade_combine \
11 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 8 --nThreads 32 --continue_train
--------------------------------------------------------------------------------
/scripts/pose/test.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python test.py --name pose --dataset_mode fewshot_pose \
9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --finetune
10 |
--------------------------------------------------------------------------------
/scripts/pose/train_g1.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python train.py --name pose --dataset_mode fewshot_pose \
9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --add_face_D \
10 | --batchSize 2 --niter 100 --niter_single 100 --continue_train
11 |
--------------------------------------------------------------------------------
/scripts/pose/train_g8.sh:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 |
8 | python train.py --name pose --dataset_mode fewshot_pose \
9 | --adaptive_spade --warp_ref --spade_combine --remove_face_labels --add_face_D \
10 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 30 --nThreads 32 --niter 100 --niter_single 100 --continue_train
--------------------------------------------------------------------------------
/scripts/street/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --name street --dataset_mode fewshot_street --adaptive_spade --loadSize 512 --fineSize 512
2 |
--------------------------------------------------------------------------------
/scripts/street/train_g1.sh:
--------------------------------------------------------------------------------
1 | python train.py --name street --dataset_mode fewshot_street \
2 | --adaptive_spade --loadSize 512 --fineSize 512 --batchSize 6 --continue_train
3 |
--------------------------------------------------------------------------------
/scripts/street/train_g8.sh:
--------------------------------------------------------------------------------
1 | python train.py --name street --dataset_mode fewshot_street \
2 | --adaptive_spade --loadSize 512 --fineSize 512 \
3 | --gpu_ids 0,1,2,3,4,5,6,7 --batchSize 46 --nThreads 16 --continue_train
4 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import numpy as np
9 | import torch
10 | import cv2
11 | from collections import OrderedDict
12 |
13 | from options.test_options import TestOptions
14 | from data.data_loader import CreateDataLoader
15 | from models.models import create_model
16 | import util.util as util
17 | from util.visualizer import Visualizer
18 | from util import html
19 |
20 | opt = TestOptions().parse()
21 |
22 | ### setup dataset
23 | data_loader = CreateDataLoader(opt)
24 | dataset = data_loader.load_data()
25 |
26 | ### setup models
27 | model = create_model(opt)
28 | model.eval()
29 | visualizer = Visualizer(opt)
30 |
31 | # create website
32 | web_dir = os.path.join(opt.results_dir, opt.name, '%s_%s' % (opt.phase, opt.which_epoch))
33 | if opt.finetune: web_dir += '_finetune'
34 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.which_epoch), infer=True)
35 |
36 | # test
37 | for i, data in enumerate(dataset):
38 | if i >= opt.how_many or i >= len(dataset): break
39 | img_path = data['path']
40 | data_list = [data['tgt_label'], data['tgt_image'], None, None, data['ref_label'], data['ref_image'], None, None, None]
41 | synthesized_image, _, _, _, _, _ = model(data_list)
42 |
43 | synthesized_image = util.tensor2im(synthesized_image)
44 | tgt_image = util.tensor2im(data['tgt_image'])
45 | ref_image = util.tensor2im(data['ref_image'], tile=True)
46 | seq = data['seq'][0]
47 | visual_list = [ref_image, tgt_image, synthesized_image]
48 | visuals = OrderedDict([(seq, np.hstack(visual_list)),
49 | (seq + '/synthesized', synthesized_image),
50 | (seq + '/ref_image', ref_image if i == 0 else None),
51 | ])
52 |     print('processing image... %s' % img_path)
53 | visualizer.save_images(webpage, visuals, img_path)
54 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import numpy as np
9 | import torch
10 |
11 | from options.train_options import TrainOptions
12 | from data.data_loader import CreateDataLoader
13 | from models.models import create_model
14 | from models.loss_collector import loss_backward
15 | from models.trainer import Trainer
16 | from util.distributed import init_dist
17 | from util.distributed import master_only_print as print
18 |
19 | def train():
20 | opt = TrainOptions().parse()
21 |
22 | if opt.distributed:
23 | init_dist()
24 | print('batch size per GPU: %d' % opt.batchSize)
25 | torch.backends.cudnn.benchmark = True
26 |
27 | ### setup dataset
28 | data_loader = CreateDataLoader(opt)
29 | dataset = data_loader.load_data()
30 | pose = 'pose' in opt.dataset_mode
31 |
32 | ### setup trainer
33 | trainer = Trainer(opt, data_loader)
34 |
35 | ### setup models
36 | model, flowNet, [optimizer_G, optimizer_D] = create_model(opt, trainer.start_epoch)
37 | flow_gt = conf_gt = [None] * 2
38 |
39 | for epoch in range(trainer.start_epoch, opt.niter + opt.niter_decay + 1):
40 | if opt.distributed:
41 | dataset.sampler.set_epoch(epoch)
42 | trainer.start_of_epoch(epoch, model, data_loader)
43 | n_frames_total, n_frames_load = data_loader.dataset.n_frames_total, opt.n_frames_per_gpu
44 | for idx, data in enumerate(dataset, start=trainer.epoch_iter):
45 | data = trainer.start_of_iter(data)
46 |
47 | if not opt.no_flow_gt:
48 | data_list = [data['tgt_label'], data['ref_label']] if pose else [data['tgt_image'], data['ref_image']]
49 | flow_gt, conf_gt = flowNet(data_list, epoch)
50 | data_list = [data['tgt_label'], data['tgt_image'], flow_gt, conf_gt]
51 | data_ref_list = [data['ref_label'], data['ref_image']]
52 | data_prev = [None, None, None]
53 |
54 | ############## Forward Pass ######################
55 | for t in range(0, n_frames_total, n_frames_load):
56 | data_list_t = get_data_t(data_list, n_frames_load, t) + data_ref_list + data_prev
57 |
58 | d_losses = model(data_list_t, mode='discriminator')
59 | d_losses = loss_backward(opt, d_losses, optimizer_D, 1)
60 |
61 | g_losses, generated, data_prev = model(data_list_t, save_images=trainer.save, mode='generator')
62 | g_losses = loss_backward(opt, g_losses, optimizer_G, 0)
63 |
64 | loss_dict = dict(zip(model.module.lossCollector.loss_names, g_losses + d_losses))
65 |
66 | if trainer.end_of_iter(loss_dict, generated + data_list + data_ref_list, model):
67 | break
68 | trainer.end_of_epoch(model)
69 |
70 | def get_data_t(data, n_frames_load, t):
71 | if data is None: return None
72 | if type(data) == list:
73 | return [get_data_t(d, n_frames_load, t) for d in data]
74 | return data[:,t:t+n_frames_load]
75 |
76 | if __name__ == "__main__":
77 | train()
--------------------------------------------------------------------------------
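To illustrate the temporal chunking performed by `get_data_t` in the forward-pass loop above (toy shapes chosen only for illustration):

    import torch
    from train import get_data_t

    seq = torch.randn(1, 6, 3, 64, 64)                       # [batch, n_frames_total, C, H, W]
    chunks = [get_data_t(seq, 2, t) for t in range(0, 6, 2)]
    # three chunks, each of shape [1, 2, 3, 64, 64]; lists of tensors are sliced recursively
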
/util/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
--------------------------------------------------------------------------------
/util/distributed.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import os
8 | import numpy as np
9 | import random
10 | import functools
11 | import torch
12 | import torch.distributed as dist
13 |
14 |
15 | def init_dist(launcher='pytorch', backend='nccl', **kwargs):
16 | raise ValueError('Distributed training is not fully tested yet and might be unstable. '
17 | 'If you are confident to run it, please comment out this line.')
18 | if dist.is_initialized():
19 | return torch.cuda.current_device()
20 | set_random_seed(get_rank())
21 | rank = int(os.environ['RANK'])
22 | num_gpus = torch.cuda.device_count()
23 | gpu_id = rank % num_gpus
24 | torch.cuda.set_device(gpu_id)
25 | dist.init_process_group(backend=backend, **kwargs)
26 | return gpu_id
27 |
28 |
29 | def get_rank():
30 | if dist.is_initialized():
31 | rank = dist.get_rank()
32 | else:
33 | rank = 0
34 | return rank
35 |
36 |
37 | def get_world_size():
38 | if dist.is_initialized():
39 | world_size = dist.get_world_size()
40 | else:
41 | world_size = 1
42 | return world_size
43 |
44 |
45 | def master_only(func):
46 | @functools.wraps(func)
47 | def wrapper(*args, **kwargs):
48 | if get_rank() == 0:
49 | return func(*args, **kwargs)
50 | else:
51 | return None
52 | return wrapper
53 |
54 |
55 | def is_master():
56 | """check if current process is the master"""
57 | return get_rank() == 0
58 |
59 |
60 | @master_only
61 | def master_only_print(*args):
62 | """master-only print"""
63 | print(*args)
64 |
65 |
66 | def dist_reduce_tensor(tensor):
67 | """ Reduce to rank 0 """
68 | world_size = get_world_size()
69 | if world_size < 2:
70 | return tensor
71 | with torch.no_grad():
72 | dist.reduce(tensor, dst=0)
73 | if get_rank() == 0:
74 | tensor /= world_size
75 | return tensor
76 |
77 |
78 | def dist_all_reduce_tensor(tensor):
79 | """ Reduce to all ranks """
80 | world_size = get_world_size()
81 | if world_size < 2:
82 | return tensor
83 | with torch.no_grad():
84 | dist.all_reduce(tensor)
85 | tensor.div_(world_size)
86 | return tensor
87 |
88 |
89 | def dist_all_gather_tensor(tensor):
90 | """ gather to all ranks """
91 | world_size = get_world_size()
92 | if world_size < 2:
93 | return [tensor]
94 | tensor_list = [
95 | torch.ones_like(tensor) for _ in range(dist.get_world_size())]
96 | with torch.no_grad():
97 | dist.all_gather(tensor_list, tensor)
98 | return tensor_list
99 |
100 |
101 | def set_random_seed(seed):
102 | """Set random seeds for everything.
103 | Inputs:
104 | seed (int): Random seed.
105 | """
106 | random.seed(seed)
107 | np.random.seed(seed)
108 | torch.manual_seed(seed)
109 | torch.cuda.manual_seed(seed)
110 | torch.cuda.manual_seed_all(seed)
111 |
--------------------------------------------------------------------------------
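A small usage sketch of the helpers above (a toy, single-process example): `master_only` silences every process except rank 0, and `dist_all_reduce_tensor` averages a tensor across ranks (a no-op when the world size is 1).

    import torch
    from util.distributed import master_only, dist_all_reduce_tensor

    @master_only
    def log(msg):
        print(msg)

    loss = torch.tensor(0.5)
    avg_loss = dist_all_reduce_tensor(loss)   # unchanged when running on a single process
    log('average loss: %.4f' % avg_loss.item())
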
/util/html.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import datetime
8 | import dominate
9 | from dominate.tags import *
10 | import os
11 |
12 |
13 | class HTML:
14 | def __init__(self, web_dir, title, refresh=0, infer=False):
15 | self.title = title
16 | self.web_dir = web_dir
17 | if not infer:
18 | self.img_dir = os.path.join(self.web_dir, 'images')
19 | else:
20 | self.img_dir = self.web_dir
21 | if len(self.web_dir) > 0 and not os.path.exists(self.web_dir):
22 | os.makedirs(self.web_dir)
23 | if len(self.web_dir) > 0 and not os.path.exists(self.img_dir):
24 | os.makedirs(self.img_dir)
25 |
26 | self.doc = dominate.document(title=title)
27 | with self.doc:
28 | h1(datetime.datetime.now().strftime("%I:%M%p on %B %d, %Y"))
29 | if refresh > 0:
30 | with self.doc.head:
31 | meta(http_equiv="refresh", content=str(refresh))
32 |
33 | def get_image_dir(self):
34 | return self.img_dir
35 |
36 | def add_header(self, str):
37 | with self.doc:
38 | h3(str)
39 |
40 | def add_table(self, border=1):
41 | self.t = table(border=border, style="table-layout: fixed;")
42 | self.doc.add(self.t)
43 |
44 | def add_images(self, ims, txts, links, width=512, height=0):
45 | self.add_table()
46 | with self.t:
47 | with tr():
48 | for im, txt, link in zip(ims, txts, links):
49 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
50 | with p():
51 | with a(href=os.path.join('images', link)):
52 | if height != 0:
53 | img(style="width:%dpx;height:%dpx" % (width, height), src=os.path.join('images', im))
54 | else:
55 | img(style="width:%dpx" % (width), src=os.path.join('images', im))
56 | br()
57 | p(txt)
58 |
59 | def save(self):
60 | html_file = '%s/index.html' % self.web_dir
61 | f = open(html_file, 'wt')
62 | f.write(self.doc.render())
63 | f.close()
64 |
65 |
66 | if __name__ == '__main__':
67 | html = HTML('web/', 'test_html')
68 | html.add_header('hello world')
69 |
70 | ims = []
71 | txts = []
72 | links = []
73 | for n in range(4):
74 | ims.append('image_%d.jpg' % n)
75 | txts.append('text_%d' % n)
76 | links.append('image_%d.jpg' % n)
77 | html.add_images(ims, txts, links)
78 | html.save()
79 |
--------------------------------------------------------------------------------
/util/image_pool.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2019, NVIDIA Corporation. All rights reserved.
2 | #
3 | # This work is made available
4 | # under the Nvidia Source Code License (1-way Commercial).
5 | # To view a copy of this license, visit
6 | # https://nvlabs.github.io/few-shot-vid2vid/License.txt
7 | import random
8 | import torch
9 | from torch.autograd import Variable
10 | class ImagePool():
11 | def __init__(self, pool_size):
12 | self.pool_size = pool_size
13 | if self.pool_size > 0:
14 | self.num_imgs = 0
15 | self.images = []
16 |
17 | def query(self, images):
18 | if self.pool_size == 0:
19 | return images
20 | return_images = []
21 | for image in images.data:
22 | image = torch.unsqueeze(image, 0)
23 | if self.num_imgs < self.pool_size:
24 | self.num_imgs = self.num_imgs + 1
25 | self.images.append(image)
26 | return_images.append(image)
27 | else:
28 | p = random.uniform(0, 1)
29 | if p > 0.5:
30 | random_id = random.randint(0, self.pool_size-1)
31 | tmp = self.images[random_id].clone()
32 | self.images[random_id] = image
33 | return_images.append(tmp)
34 | else:
35 | return_images.append(image)
36 | return_images = Variable(torch.cat(return_images, 0))
37 | return return_images
38 |
--------------------------------------------------------------------------------
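A short usage sketch of `ImagePool` in a GAN update (a toy example, not repository code): the discriminator is fed a mix of freshly generated and previously generated images, a common trick for stabilizing adversarial training.

    import torch
    from util.image_pool import ImagePool

    pool = ImagePool(pool_size=50)
    fake = torch.randn(4, 3, 64, 64)   # current generator output
    fake_for_D = pool.query(fake)      # once the pool is full, each image may be swapped for an older one
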