├── RAFT ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── corr.cpython-36.pyc │ ├── corr.cpython-37.pyc │ ├── corr.cpython-38.pyc │ ├── corr.cpython-39.pyc │ ├── extractor.cpython-36.pyc │ ├── extractor.cpython-37.pyc │ ├── extractor.cpython-38.pyc │ ├── extractor.cpython-39.pyc │ ├── raft.cpython-36.pyc │ ├── raft.cpython-37.pyc │ ├── raft.cpython-38.pyc │ ├── raft.cpython-39.pyc │ ├── update.cpython-36.pyc │ ├── update.cpython-37.pyc │ ├── update.cpython-38.pyc │ └── update.cpython-39.pyc ├── corr.py ├── datasets.py ├── demo.py ├── extractor.py ├── raft.py ├── update.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── __init__.cpython-37.pyc │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── flow_viz.cpython-36.pyc │ ├── flow_viz.cpython-37.pyc │ ├── flow_viz.cpython-38.pyc │ ├── flow_viz.cpython-39.pyc │ ├── frame_utils.cpython-36.pyc │ ├── frame_utils.cpython-37.pyc │ ├── frame_utils.cpython-38.pyc │ ├── frame_utils.cpython-39.pyc │ ├── utils.cpython-36.pyc │ ├── utils.cpython-37.pyc │ ├── utils.cpython-38.pyc │ └── utils.cpython-39.pyc │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── README.md ├── causal ├── discriminator.py └── gen.py ├── data ├── factormatte_GANCGANFlip148_dataset.py ├── gen_backgroundPosEx-Copy1.ipynb ├── gen_backgroundPosEx.ipynb ├── gen_foregroundPosEx.py ├── homographies.txt ├── keypoint_homo_short.ipynb ├── misc_data_process.py └── noninteraction_ind.txt ├── datasets ├── confidence.py └── homography.py ├── models ├── factormatte_GANFlip_model.py └── networks.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── base_options.cpython-38.pyc │ ├── base_options.cpython-39.pyc │ ├── test_options.cpython-39.pyc │ ├── train_options.cpython-38.pyc │ └── train_options.cpython-39.pyc ├── base_options.py ├── test_options.py └── train_options.py ├── prepare_data_stage1.sh ├── requirements.txt ├── test.py ├── third_party ├── __init__.py ├── __init__.pyc ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── __init__.cpython-39.pyc ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── base_dataset.cpython-38.pyc │ │ ├── base_dataset.cpython-39.pyc │ │ ├── image_folder.cpython-38.pyc │ │ └── image_folder.cpython-39.pyc │ ├── base_dataset.py │ └── image_folder.py ├── models │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-38.pyc │ │ ├── __init__.cpython-39.pyc │ │ ├── base_model.cpython-38.pyc │ │ ├── base_model.cpython-39.pyc │ │ ├── networks.cpython-38.pyc │ │ ├── networks.cpython-39.pyc │ │ ├── networks_lnr.cpython-38.pyc │ │ └── networks_lnr.cpython-39.pyc │ ├── base_model.py │ ├── networks.py │ └── networks_lnr.py └── util │ ├── __init__.py │ ├── __init__.pyc │ ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── __init__.cpython-39.pyc │ ├── html.cpython-38.pyc │ ├── html.cpython-39.pyc │ ├── util.cpython-38.pyc │ ├── util.cpython-39.pyc │ ├── visualizer.cpython-38.pyc │ └── visualizer.cpython-39.pyc │ ├── html.py │ ├── util.py │ ├── util.pyc │ └── visualizer.py ├── train_GAN.py ├── utils.py ├── video_completion.py └── weight ├── edge_completion.pth ├── imagenet_deepfill.pth └── raft-things.pth /RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # from .demo import RAFT_infer 2 | from .raft import RAFT 3 | 
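The package's public entry point is the RAFT class re-exported above; demo.py further down wraps it in RAFT_infer. A minimal usage sketch for driving the exported class directly might look like the following — the frame filenames are hypothetical, the checkpoint path is taken from the weight/ directory in the tree above, and the Namespace flags simply mirror the attributes raft.py reads (small, mixed_precision, alternate_corr):

# Minimal sketch (not part of the repo): estimate flow between two frames and colorize it.
# Assumes a CUDA device, hypothetical frame paths, and the raft-things.pth checkpoint under weight/.
import argparse
import numpy as np
import torch
from PIL import Image

from RAFT import RAFT
from RAFT.utils import flow_viz
from RAFT.utils.utils import InputPadder

def load_image(path):
    # HWC uint8 image -> 3xHxW float tensor, as in demo.py
    img = np.array(Image.open(path)).astype(np.uint8)
    return torch.from_numpy(img).permute(2, 0, 1).float()

args = argparse.Namespace(small=False, mixed_precision=False, alternate_corr=False, dropout=0)
model = torch.nn.DataParallel(RAFT(args))                     # wrap in DataParallel before loading, as demo.py does
model.load_state_dict(torch.load('weight/raft-things.pth'))
model = model.module.to('cuda').eval()

with torch.no_grad():
    images = torch.stack([load_image('frame0.png'), load_image('frame1.png')], dim=0).to('cuda')
    padder = InputPadder(images.shape)                        # pad H and W up to multiples of 8
    images = padder.pad(images)[0]
    image1, image2 = images[0, None], images[1, None]
    flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    flow_rgb = flow_viz.flow_to_image(flow_up[0].permute(1, 2, 0).cpu().numpy())
    Image.fromarray(flow_rgb).save('flow.png')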
-------------------------------------------------------------------------------- /RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from .utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6
| import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class CorrLayer(torch.autograd.Function): 64 | @staticmethod 65 | def forward(ctx, fmap1, fmap2, coords, r): 66 | fmap1 = fmap1.contiguous() 67 | fmap2 = fmap2.contiguous() 68 | coords = coords.contiguous() 69 | ctx.save_for_backward(fmap1, fmap2, coords) 70 | ctx.r = r 71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) 72 | return corr 73 | 74 | @staticmethod 75 | def backward(ctx, grad_corr): 76 | fmap1, fmap2, coords = ctx.saved_tensors 77 | grad_corr = grad_corr.contiguous() 78 | fmap1_grad, fmap2_grad, coords_grad = \ 79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) 80 | return fmap1_grad, fmap2_grad, coords_grad, None 81 | 82 | 83 | class AlternateCorrBlock: 84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 85 | self.num_levels = num_levels 86 | self.radius = radius 87 | 88 | self.pyramid = [(fmap1, fmap2)] 89 | for i in range(self.num_levels): 90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 92 | self.pyramid.append((fmap1, fmap2)) 93 | 94 | def __call__(self, coords): 95 | 96 | coords = coords.permute(0, 2, 3, 1) 97 | B, H, W, _ = coords.shape 98 | 99 | corr_list = [] 100 | for i in range(self.num_levels): 101 | r = self.radius 102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) 103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) 104 | 105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) 107 | corr_list.append(corr.squeeze(1)) 108 | 109 | corr = torch.stack(corr_list, dim=1) 110 | corr = corr.reshape(B, -1, H, W) 111 | return corr / 16.0 112 | -------------------------------------------------------------------------------- /RAFT/datasets.py: 
-------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list 
= sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 
'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /RAFT/demo.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import argparse 3 | import os 4 | import cv2 5 | import glob 6 | import numpy as np 7 | import torch 8 | from PIL import Image 9 | 10 | from .raft import RAFT 11 | from .utils import flow_viz 12 | from .utils.utils import InputPadder 13 | 14 | 15 | 16 | DEVICE = 'cuda' 17 | 18 | def load_image(imfile): 19 | img = np.array(Image.open(imfile)).astype(np.uint8) 20 | img = torch.from_numpy(img).permute(2, 0, 1).float() 21 | return img 22 | 23 | 24 | def load_image_list(image_files): 25 | images = [] 26 | for imfile in sorted(image_files): 27 | images.append(load_image(imfile)) 28 | 29 | images = torch.stack(images, dim=0) 30 | images = images.to(DEVICE) 31 | 32 | padder = InputPadder(images.shape) 33 | return padder.pad(images)[0] 34 | 35 | 36 | def viz(img, flo): 37 | img = img[0].permute(1,2,0).cpu().numpy() 38 | flo = flo[0].permute(1,2,0).cpu().numpy() 39 | 40 | # map flow to rgb image 41 | flo = flow_viz.flow_to_image(flo) 42 | # img_flo = np.concatenate([img, flo], axis=0) 43 | img_flo = flo 44 | 45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 47 | # cv2.waitKey() 48 | 49 | 50 | def demo(args): 51 | model = torch.nn.DataParallel(RAFT(args)) 52 | model.load_state_dict(torch.load(args.model)) 53 | 54 | model = model.module 55 | model.to(DEVICE) 56 | model.eval() 57 | 58 | with torch.no_grad(): 59 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 60 | glob.glob(os.path.join(args.path, '*.jpg')) 61 | 62 | images = load_image_list(images) 63 | for i in range(images.shape[0]-1): 64 | image1 = 
images[i,None] 65 | image2 = images[i+1,None] 66 | 67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 68 | viz(image1, flow_up) 69 | 70 | 71 | def RAFT_infer(args): 72 | model = torch.nn.DataParallel(RAFT(args)) 73 | model.load_state_dict(torch.load(args.model)) 74 | 75 | model = model.module 76 | model.to(DEVICE) 77 | model.eval() 78 | 79 | return model 80 | -------------------------------------------------------------------------------- /RAFT/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if 
not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | 
self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .update import BasicUpdateBlock, SmallUpdateBlock 7 | from .extractor import BasicEncoder, SmallEncoder 8 | from .corr import CorrBlock, AlternateCorrBlock 9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8 10 | 11 | try: 12 | autocast = torch.cuda.amp.autocast 13 | except: 14 | # dummy autocast for PyTorch < 1.6 15 | class autocast: 16 | def __init__(self, enabled): 17 | pass 18 | def __enter__(self): 19 | pass 20 | def __exit__(self, *args): 21 | pass 22 | 23 | 24 | class RAFT(nn.Module): 25 | def __init__(self, args): 26 | super(RAFT, self).__init__() 27 | self.args = args 28 | 29 | if args.small: 30 | self.hidden_dim = hdim = 96 31 | self.context_dim = cdim = 64 32 | args.corr_levels = 4 33 | args.corr_radius = 3 34 | 35 | else: 36 | self.hidden_dim = hdim = 128 37 | self.context_dim = cdim = 128 38 | args.corr_levels = 4 39 | args.corr_radius = 4 40 | 41 | if 'dropout' not in args._get_kwargs(): 42 | args.dropout = 0 43 | 44 | if 'alternate_corr' not in args._get_kwargs(): 45 | args.alternate_corr = False 46 | 47 | # feature network, context network, and update block 48 | if args.small: 49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) 50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) 51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 52 | 53 | else: 54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) 55 | self.cnet = 
BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) 56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 57 | 58 | 59 | def freeze_bn(self): 60 | for m in self.modules(): 61 | if isinstance(m, nn.BatchNorm2d): 62 | m.eval() 63 | 64 | def initialize_flow(self, img): 65 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 66 | N, C, H, W = img.shape 67 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 68 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 69 | 70 | # optical flow computed as difference: flow = coords1 - coords0 71 | return coords0, coords1 72 | 73 | def upsample_flow(self, flow, mask): 74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 75 | N, _, H, W = flow.shape 76 | mask = mask.view(N, 1, 9, 8, 8, H, W) 77 | mask = torch.softmax(mask, dim=2) 78 | 79 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 81 | 82 | up_flow = torch.sum(mask * up_flow, dim=2) 83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 84 | return up_flow.reshape(N, 2, 8*H, 8*W) 85 | 86 | 87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 88 | """ Estimate optical flow between pair of frames """ 89 | 90 | image1 = 2 * (image1 / 255.0) - 1.0 91 | image2 = 2 * (image2 / 255.0) - 1.0 92 | 93 | image1 = image1.contiguous() 94 | image2 = image2.contiguous() 95 | 96 | hdim = self.hidden_dim 97 | cdim = self.context_dim 98 | 99 | # run the feature network 100 | with autocast(enabled=self.args.mixed_precision): 101 | fmap1, fmap2 = self.fnet([image1, image2]) 102 | 103 | fmap1 = fmap1.float() 104 | fmap2 = fmap2.float() 105 | if self.args.alternate_corr: 106 | corr_fn = CorrBlockAlternate(fmap1, fmap2, radius=self.args.corr_radius) 107 | else: 108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 109 | 110 | # run the context network 111 | with autocast(enabled=self.args.mixed_precision): 112 | cnet = self.cnet(image1) 113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 114 | net = torch.tanh(net) 115 | inp = torch.relu(inp) 116 | 117 | coords0, coords1 = self.initialize_flow(image1) 118 | 119 | if flow_init is not None: 120 | coords1 = coords1 + flow_init 121 | 122 | flow_predictions = [] 123 | for itr in range(iters): 124 | coords1 = coords1.detach() 125 | corr = corr_fn(coords1) # index correlation volume 126 | 127 | flow = coords1 - coords0 128 | with autocast(enabled=self.args.mixed_precision): 129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 130 | 131 | # F(t+1) = F(t) + \Delta(t) 132 | coords1 = coords1 + delta_flow 133 | 134 | # upsample predictions 135 | if up_mask is None: 136 | flow_up = upflow8(coords1 - coords0) 137 | else: 138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 139 | 140 | flow_predictions.append(flow_up) 141 | 142 | if test_mode: 143 | return coords1 - coords0, flow_up 144 | 145 | return flow_predictions 146 | -------------------------------------------------------------------------------- /RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | 
self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = 
ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .flow_viz import flow_to_image 2 | from .frame_utils import writeFlow 3 | -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/flow_viz.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-36.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/flow_viz.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/flow_viz.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/flow_viz.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-39.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/frame_utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-36.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/frame_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/frame_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/frame_utils.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-39.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /RAFT/utils/__pycache__/utils.cpython-39.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-39.pyc -------------------------------------------------------------------------------- /RAFT/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 
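# Note: mirroring the frames horizontally means the horizontal (u) flow component must change sign; the v-flip branch below negates v instead.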
95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 
| min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /RAFT/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 
28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, rad_max=None): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | if not rad_max: 128 | rad = np.sqrt(np.square(u) + np.square(v)) 129 | rad_max = np.max(rad) 130 | epsilon = 1e-5 131 | u = u / (rad_max + epsilon) 132 | v = v / (rad_max + epsilon) 133 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 
76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT/utils/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), 
method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FactorMatte 2 | ## Environment 3 | `conda create -n factormatte python=3.9 anaconda` 4 |
5 | `conda activate factormatte` 6 |
7 | Use conda or pip to install requirements.txt 8 | 9 | ## Example Video 10 | ### Download Dataset and put into the datasets/ folder 11 | https://drive.google.com/file/d/1-nZ9VA8bqRvll_4HEPGOxihIJ4o8y0kY/view?usp=sharing 12 | 13 | ### Stage 1 14 | `python train_GAN.py --name sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2 --stage 1 --dataset_mode omnimatte_GANCGANFlip148 --model omnimatte_GANFlip --dataroot ./datasets/sand_car --height 192 --width 288 --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --model_v 4 --residual_noise --strides 0,0,0 --num_Ds 0,0,0 --n_layers 0,0,0 --display_ind 63 --pos_ex_dirs , --batch_size 16 --n_epochs 1200 --bg_noise --gpu_ids 1,0 --lambda_recon_flow 1e-2` 15 | 16 | Copy the trained weights to the next stage's training folder: `cp 1110_checkpoints/sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/*1200* 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2` 17 | 18 | Run test to generate the background image: `python test.py --name sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2 --dataset_mode omnimatte_GANCGANFlip148 --model omnimatte_GANFlip --dataroot ./datasets/DVM_manstatic --prob_masks --model_v 4 --residual_noise --strides 0,0,0 --num_Ds 0,0,0 --n_layers 0,0,0 --pos_ex_dirs , --epoch 1200 --stage 1 --gpu_ids 0 --start_ind 0 --width 512 --height 288` 19 | 20 | And put it in the data folder, it'll be used for the following stages. `cp results/sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/test_1200_/panorama.png datasets/sand_car/bg_gt.png` 21 | 22 | ### Stage 2 23 | `python train_GAN.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2 --init_flowmask --lambda_recon_flow 1e-2 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --residual_noise --strides 0,2,2 --num_Ds 0,1,3 --n_layers 0,3,3 --start_ind 0 --noninter_only --width 288 --height 192 --discriminator_transform randomcrop --pos_ex_dirs 0uniform_0gaussian_dark_0flip_0elastic_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_rawframes,0rot_0flip_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_ --gpu_ids 0 --n_epochs 2400 --continue_train --epoch 1200 --stage 2 --display_ind 15` 24 | 25 | Copy the trained weights to the next stage's training folder: `cp 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2/*2400* 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1` 26 | 27 | 28 | ### Stage 3 29 | `python train_GAN.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1 --lambda_recon 2 --lambda_recon_flow 1e-1 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --residual_noise --strides 0,2,2 --num_Ds 0,1,3 --display_ind 63 --init_flowmask --lambda_recon_3 1e-1 --start_ind 0 --discriminator_transform randomcrop --steps 148 --pos_ex_dirs 
0uniform_0gaussian_dark_0flip_0elastic_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_rawframes,0rot_0flip_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_ --stage 3 --height 192 --width 288 --n_epochs 3200 --gpu_ids 0 --continue_train --epoch 2400 --overwrite_lambdas` 30 | 31 | ### Pretrained Weights 32 | For convenience, you can also download the weights of any stage for this dataset and start training from there. 33 | 34 | Stage 1 weights: https://drive.google.com/drive/folders/1ERZQNM8nT2Xw9J2yzFzp3QoyxHFEZ7B4?usp=sharing 35 | 36 | Stage 2 weights: https://drive.google.com/drive/folders/1boJJ8DwPZxk9hzxUa-4nLW0vVPhXdhPL?usp=sharing 37 | 38 | Stage 3 weights: https://drive.google.com/drive/folders/1eDHuIsoON_ou_50sZ7nT4D3luiGxVvyx?usp=sharing 39 | 40 | ### Generate Results 41 | `python test.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --gpu_ids 0 --prob_masks --residual_noise --pos_ex_dirs , --epoch 3200 --stage 3 --width 288 --height 192 --init_flowmask --test_suffix texts_to_put_after_fixed_folder_name` 42 | 43 | 44 | ## Custom Dataset 45 | To train on your custom video, please prepare it as follows (assume all file names are [xxxxx].png, e.g. 00001.png, 00100.png, 10001.png): 46 | 1. Extract all RGB frames and put them in the `rgb` folder. 47 | 2. Arrange the corresponding binary masks in the same order and put them in the `mask/01` folder. 48 | 3. Run `data/misc_data_process.py` to copy `mask/01` to `mask_nocushionmask/02` and to generate `mask_nocushionmask/01`. Please refer to the docstring in data/misc_data_process.py for details. (Redundant, TODO: generate this on the fly.) 49 | 4. Estimate the homography between every two consecutive frames and flatten each matrix following the template of data/homographies.txt. 50 | We provide a script in data/keypoint_homography_estimate.ipynb; it will generate a file homographies_raw.txt. To get the final homographies.txt, run 51 | `python datasets/homography.py --homography_path ./datasets/[your_folder_name]/homographies_raw.txt --width [W] --height [H]` 52 | 53 | 5. Estimate optical flow with RAFT: 54 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 1` 55 |
56 | 57 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 4` 58 |
59 | 60 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 8` 61 |
62 | (As mentioned in Section 7, we use multiple time scales (1, 4, 8) to reinforce consistency.) 63 |
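The three runs above differ only in the `--step` argument; if you prefer, a small shell loop (a sketch, assuming a POSIX shell and the same paths as in the commands above) produces all three sets of flows in one go:

`for s in 1 4 8; do python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step $s; done`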
64 | Move the generated flow files to your dataset folder: 65 | 66 | `mv RAFT_result/datasets[your_folder_name]rgb/*flow* datasets/[your_folder_name]` 67 | 68 | 6. Estimate confidence maps for the flows: 69 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 1` 70 |
71 | 72 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 4` 73 |
74 | 75 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 8` 76 | 77 | 7. If you want to use the tricks in Section 7, identify the simpler frames and separate out their frame indices as in `data/noninteraction_ind.txt`. If there are no such frames, or you do not wish to use these tricks, simply write "0, 1" in that file. 78 | 79 | 8. After Stage 1, run `python gen_foregroundPosEx.py` to generate positive examples for the foreground, and run `data/gen_backgroundPosEx.ipynb` to generate positive examples for the background. 80 | 81 | 9. In short, these folders and files should be present in datasets/[your_folder_name]: 82 |
83 | forward_flow_step1, forward_flow_step4, forward_flow_step8 84 |
85 | backward_flow_step1, backward_flow_step4, backward_flow_step8 86 |
87 | confidence_step1, confidence_step4, confidence_step8 88 |
89 | homographies.txt (if you use data/keypoint_homography_estimate.ipynb, there should also be a "homographies_raw.txt") 90 |
91 | mask_nocushionmask (2 subfolders: "01", "02") 92 |
93 | mask (1 subfolder containing the segmentation masks of the foreground object: "01") 94 |
95 | noninteraction_ind.txt 96 |
97 | zbar.pth (Generated automatically so that the model always starts from the same fixed random noise.) 98 |
99 | dis_real_l1, dis_real_l2 (Generated after running Stage 1.) 100 | 101 | -------------------------------------------------------------------------------- /causal/discriminator.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision.models as models 3 | import torch.nn as nn 4 | import functools 5 | import torch.nn.functional as F 6 | import numpy as np 7 | 8 | 9 | # Defines the PatchGAN discriminator with the specified arguments. 10 | class NLayerDiscriminator_4MultiscaleDiscriminator(nn.Module): 11 | def __init__(self, args, input_nc, ndf, n_layers, s, norm_layer, use_sigmoid): 12 | super(NLayerDiscriminator_4MultiscaleDiscriminator, self).__init__() 13 | self.conv = nn.Conv2d 14 | self.n_layers = n_layers 15 | 16 | kw = 4 17 | padw = int(np.ceil((kw-1.0)/2)) 18 | sequence = [[self.conv(input_nc, ndf, kernel_size=kw, stride=s, padding=padw), \ 19 | nn.LeakyReLU(0.2, True)]] 20 | nf = ndf 21 | # start from 1 because already 1 layer, minus 1 because another layer in the end 22 | for n in range(1, n_layers-1): 23 | nf_prev = nf 24 | nf = min(nf * 2, 512) 25 | sequence += [[ 26 | self.conv(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), 27 | norm_layer(nf), 28 | nn.LeakyReLU(0.2, True) 29 | ]] 30 | 31 | sequence += [[self.conv(nf, 1, kernel_size=kw, stride=1, padding=padw)]] 32 | if use_sigmoid: 33 | sequence += [[nn.Sigmoid()]] 34 | 35 | sequence_stream = [] 36 | for n in range(len(sequence)): 37 | sequence_stream += sequence[n] 38 | self.model = nn.Sequential(*sequence_stream) 39 | 40 | 41 | class MultiscaleDiscriminator(nn.Module): 42 | def __init__(self, args, stride, num_D, n_layers, norm_layer=nn.BatchNorm2d, 43 | use_sigmoid=False, ndf=64): 44 | super(MultiscaleDiscriminator, self).__init__() 45 | self.num_D = num_D 46 | self.n_layers = n_layers 47 | self.args = args 48 | if args.rgba_GAN == 'RGBA': 49 | input_nc = 4 50 | if args.rgba_GAN == 'RGB': 51 | input_nc = 3 52 | elif args.rgba_GAN == 'A': 53 | input_nc = 1 #a+mask 54 | 55 | 56 | for i in range(num_D): 57 | print('Initializing', i, 'th-scale discriminator. 
n_layers', n_layers, 'ndf', ndf, 'stride', stride,\ 58 | norm_layer, use_sigmoid) 59 | netD = NLayerDiscriminator_4MultiscaleDiscriminator(args, input_nc, ndf, n_layers, stride, \ 60 | norm_layer, use_sigmoid) 61 | setattr(self, 'layer'+str(i), netD.model) 62 | 63 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False) 64 | 65 | def singleD_forward(self, model, x): 66 | return model(x).flatten(1) 67 | 68 | def forward(self, x): 69 | num_D = self.num_D 70 | result = [] 71 | result_valid = [] 72 | input_downsampled = x 73 | for i in range(num_D): 74 | model = getattr(self, 'layer'+str(num_D-1-i)) 75 | patches = self.singleD_forward(model, input_downsampled) 76 | result.append(patches) 77 | if i != (num_D-1): 78 | input_downsampled = self.downsample(input_downsampled) 79 | return torch.cat(result, 1) 80 | 81 | -------------------------------------------------------------------------------- /causal/gen.py: -------------------------------------------------------------------------------- 1 | import cv2 as cv 2 | import torch 3 | import torchvision.models as models 4 | from torchvision import transforms as T 5 | 6 | import os 7 | from PIL import Image, ImageFilter 8 | import numpy as np 9 | import scipy.ndimage 10 | 11 | def prep_data(index): 12 | name = ('00'+str(index))[-4:]+'.png' 13 | gt = np.asarray(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/rgb/'+name).convert('RGBA')) 14 | mask = np.asarray(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/mask/01/seg'+name)) 15 | mask[mask!=0] = 1 16 | mask = np.expand_dims(mask,-1) 17 | cube = np.where(mask!=0) 18 | up, down = cube[0].min(), cube[0].max() 19 | left, right = cube[1].min(), cube[1].max() 20 | fg = gt * mask 21 | return [left, right, up, down], fg 22 | 23 | def one_warp(fg, boundaries): 24 | left, right, up, down = boundaries 25 | src_pts = np.array([[left,up],[right,up], [right,down], [left,down]]) 26 | h = down - up 27 | w = right - left 28 | left1 = left + np.random.uniform(-0.5*w, 0.5*w) 29 | left2 = left + np.random.uniform(-0.5*w, 0.5*w) 30 | right1 = right + np.random.uniform(-0.5*w, 0.5*w) 31 | right2 = right + np.random.uniform(-0.5*w, 0.5*w) 32 | up1 = up + np.random.uniform(-0.5*h, 0.5*h) 33 | up2 = up + np.random.uniform(-0.5*h, 0.5*h) 34 | down1 = down + np.random.uniform(-0.5*h, 0.5*h) 35 | down2 = down + np.random.uniform(-0.5*h, 0.5*h) 36 | leftstart = np.random.uniform(-0.4 * (448-w), 0.4 * (448-w)) 37 | upstart = np.random.uniform(-0.4 * (256-h), 0.4 * (256-h)) 38 | # print(leftstart,upstart) 39 | dst_pts = np.array([[left1+leftstart,up1+upstart],[right1+leftstart,up2+upstart],[right2+leftstart,down1+upstart], [left2+leftstart,down2+upstart]]) 40 | M, _ = cv.findHomography(src_pts, dst_pts, cv.RANSAC,5.0) 41 | out = cv.warpPerspective(fg, M, (448, 256), flags=cv.INTER_LINEAR) 42 | return out 43 | 44 | def gen_fg(): 45 | save_dir = 'classifier_dataset/cushion_birdeye_texturecolor_suzanne/test/fg' 46 | for n in range(1000): 47 | print(n) 48 | ind = np.random.randint(80, high=221) 49 | boundaries, fg = prep_data(ind) 50 | num = np.random.randint(0, high=6) 51 | grey = np.ones((256, 448,4))*255 52 | grey[:,:,:3]=np.random.randint(0, high=256) 53 | canvas = Image.fromarray(grey.astype('uint8')) 54 | 55 | for i in range(num): 56 | out = one_warp(fg, boundaries) 57 | alpha = np.random.rand() 58 | out[:,:,-1]=(out[:,:,-1]*alpha).astype('uint8') 59 | canvas = Image.alpha_composite(canvas, Image.fromarray(out)) 60 | blur = 
np.random.rand() 61 | if blur >0.5: 62 | r = np.random.uniform(low=0, high=5.5) 63 | canvas = canvas.filter(ImageFilter.GaussianBlur(radius = r)) 64 | canvas_np = np.asarray(canvas) 65 | for j in range(256): 66 | for k in range(448): 67 | if fg[j,k,-1]==0: 68 | fg[j,k]=canvas_np[j,k] 69 | fg_img = Image.fromarray(fg).convert('RGB') 70 | fg_img.save(os.path.join(save_dir, str(n)+'_from'+str(ind)+'_test.png')) 71 | 72 | 73 | def bg_warp(bg, boundaries): 74 | left, right, up, down = boundaries 75 | src_pts = np.array([[left,up],[right,up], [right,down], [left,down]]) 76 | h = 256 77 | w = 448 78 | left1 = left + np.random.uniform(-0.15*w, 0.15*w) 79 | left2 = left + np.random.uniform(-0.15*w, 0.15*w) 80 | right1 = right + np.random.uniform(-0.15*w, 0.15*w) 81 | right2 = right + np.random.uniform(-0.15*w, 0.15*w) 82 | up1 = up + np.random.uniform(-0.15*h, 0.15*h) 83 | up2 = up + np.random.uniform(-0.15*h, 0.15*h) 84 | down1 = down + np.random.uniform(-0.15*h, 0.15*h) 85 | down2 = down + np.random.uniform(-0.15*h, 0.15*h) 86 | 87 | dst_pts = np.array([[left1,up1],[right1,up2],[right2,down1], [left2,down2]]) 88 | M, _ = cv.findHomography(src_pts, dst_pts, cv.RANSAC, 5.0) 89 | out = cv.warpPerspective(bg, M, (448, 256), borderMode=cv.BORDER_WRAP, flags=cv.INTER_LINEAR) #[up:down, left:right] 90 | return out 91 | 92 | 93 | save_dir = 'classifier_dataset/cushion_birdeye_texturecolor_suzanne/train/bg/' 94 | gt = np.asarray(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/bg_gt.png').convert('RGBA')) 95 | left = 115 96 | right = 302 97 | up = 33 98 | down = 200 99 | boundaries = [left, right, up, down] 100 | for n in range(5000): 101 | print(n) 102 | ind = np.random.randint(80, high=221) 103 | out = bg_warp(gt, boundaries) 104 | alpha = np.random.uniform(0.85, high=1) 105 | out[:,:,-1]=(out[:,:,-1]*alpha).astype('uint8') 106 | canvas = Image.fromarray(out).convert('RGB') 107 | blur = np.random.rand() 108 | if blur >0.5: 109 | r = np.random.uniform(low=0, high=5.5) 110 | canvas = canvas.filter(ImageFilter.GaussianBlur(radius = r)) 111 | canvas.save(os.path.join(save_dir, str(n)+'_from'+str(ind)+'_bg_train.png')) -------------------------------------------------------------------------------- /data/gen_foregroundPosEx.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0, '/home/zg45/FactorMatte') 3 | import torch 4 | import torchvision.models as models 5 | from torchvision import transforms as T 6 | 7 | import os 8 | from PIL import Image, ImageFilter 9 | import numpy as np 10 | import scipy as sp 11 | import scipy.signal 12 | from shutil import copyfile 13 | import cv2 as cv 14 | from third_party.data.image_folder import make_dataset 15 | 16 | 17 | def prep_data(basedir, index): 18 | rgb_paths = sorted(make_dataset(os.path.join(basedir, 'rgb'))) 19 | # mask_paths = sorted(make_dataset(os.path.join(basedir, 'l2_fake_real_comp_mask'))) 20 | mask_paths = sorted(make_dataset(os.path.join(basedir, 'mask_nocushionmask/02/'))) 21 | gt = np.asarray(Image.open(rgb_paths[index]).convert('RGBA')).astype('float') 22 | mask = np.asarray(Image.open(mask_paths[index]).convert('L')).astype('float')/255 23 | mask[mask != 1.] 
= 0 24 | if abs(mask).sum() == 0: 25 | return None, None 26 | # if 'composited' in basedir: 27 | # Optionally erode to be conservative 28 | # mask = cv.erode(mask, kernel=np.ones((12, 12)), iterations=1) 29 | mask = np.expand_dims(mask, -1) 30 | cube = np.where(mask != 0) 31 | up, down = cube[0].min(), cube[0].max() 32 | left, right = cube[1].min(), cube[1].max() 33 | fg = np.clip(gt * mask, 0, 255) 34 | return [left, right, up, down], fg.astype('uint8') 35 | 36 | def add_reflection(img, surface=140, alpha_range=[0, 0.75]): 37 | alpha = np.random.uniform(alpha_range[0], high=alpha_range[1]) 38 | print(alpha) 39 | h, w, _ = img.shape 40 | start = abs(h - 2*surface) 41 | img[surface:] = alpha * img[start:surface].copy()[::-1] 42 | return img 43 | 44 | # def add_blur(img, sigma_range=[10, 20]): 45 | # sigma = 0 46 | # while sigma % 2 == 0: 47 | # # GaussianBlur only accepts odd kernel size 48 | # sigma = np.random.randint(sigma_range[0], high=sigma_range[1]) 49 | # img = cv.GaussianBlur(img, (sigma, sigma), sigma/4 , borderType = cv.BORDER_REPLICATE) 50 | # return img 51 | 52 | def add_blur_1(img, sigma_range=[0.2, 1], kernel=5): 53 | img = img.astype("int16") 54 | std = np.random.uniform(sigma_range[0], high=sigma_range[1]) 55 | blur_img = cv.GaussianBlur(img, (kernel, kernel), std, borderType = cv.BORDER_REPLICATE) 56 | blur_img = ceil_floor_image(blur_img) 57 | return blur_img 58 | 59 | def ceil_floor_image(image): 60 | """ 61 | Args: 62 | image : numpy array of image in datatype int16 63 | Return : 64 | image : numpy array of image in datatype uint8 with ceilling(maximum 255) and flooring(minimum 0) 65 | """ 66 | image[image > 255] = 255 67 | image[image < 0] = 0 68 | image = image.astype("uint8") 69 | return image 70 | 71 | def add_noise(img, std_range=[0, 20], mean=0): 72 | std = np.random.randint(std_range[0], high=std_range[1]) 73 | print('std', std) 74 | gaussian_noise = np.random.normal(mean, std, img.shape) 75 | img = img.astype("int16") 76 | noise_img = img + gaussian_noise 77 | noise_img = ceil_floor_image(noise_img) 78 | return noise_img 79 | 80 | def flip(img): 81 | p = np.random.uniform() 82 | if p<0.5: 83 | img = img[:, ::-1] 84 | else: 85 | img = img[::-1, :] 86 | return img 87 | 88 | def rotate(img): 89 | deg = np.random.randint(0, 360) 90 | img = sp.ndimage.rotate(img, deg, reshape=False) 91 | return img 92 | 93 | def gen_pos_ex_fg(basedir, ind_low, ind_high, add_rot, add_flip, add_blurr_or_noise, add_gaussian_noise, \ 94 | num=5000, blur_kwargs=None, noise_kwargs=None, folder_suffix=''): 95 | """ 96 | ind high exclusive 97 | """ 98 | save_dir = os.path.join(basedir, 'dis_real_l2', '_'.join([str(add_rot) + 'rot', str(add_flip) + 'flip', \ 99 | str((1-add_gaussian_noise)*add_blurr_or_noise) + 'blursigma' + str(blur_kwargs['sigma_range'][0]) + str(blur_kwargs['sigma_range'][0])+'k'+str(blur_kwargs['kernel']),\ 100 | str(add_gaussian_noise*add_blurr_or_noise)+'gaussian_noise_std'+str(noise_kwargs['std_range'][0])+str(noise_kwargs['std_range'][1])+'mean'+str(noise_kwargs['mean']), folder_suffix])) 101 | os.makedirs(save_dir) 102 | for n in range(0, num): 103 | print(n) 104 | fg = None 105 | while fg is None: 106 | ind = np.random.randint(0, high=ind_high-ind_low+1) 107 | boundaries, fg = prep_data(basedir, ind) 108 | 109 | decision = np.random.uniform(size=5) 110 | print(decision) 111 | # if decision[4] < 0.5: 112 | # scale = np.random.uniform(0.2, 1.2) 113 | # print('scale', scale) 114 | # scaled = scale * fg[:,:,:3].astype('int') 115 | # fg[:,:,:3] = np.clip(scaled, 0, 
255).astype('uint8') 116 | if decision[0] < add_rot: 117 | print('rot') 118 | fg = rotate(fg) 119 | if decision[1] < add_flip: 120 | print('flip') 121 | fg = flip(fg) 122 | 123 | h, w, _ = fg.shape 124 | grey = np.ones((h, w, 4))*255 125 | grey[:,:,0]=0 126 | grey[:,:,1]=255 127 | grey[:,:,2]=0 128 | # grey[:,:,:3]=np.random.randint(0, high=80) 129 | canvas = Image.fromarray(grey.astype('uint8')) 130 | canvas_np = np.asarray(canvas) 131 | for j in range(h): 132 | for k in range(w): 133 | if fg[j, k, -1] == 0: 134 | fg[j, k] = canvas_np[j,k] 135 | 136 | if decision[2] < add_blurr_or_noise: 137 | if decision[3] < add_gaussian_noise: 138 | print('gaussian noise') 139 | fg = add_noise(fg, **noise_kwargs) 140 | # blur and noise are exclusive 141 | else: 142 | print('blurr, using add_blur_1') 143 | fg = add_blur_1(fg, **blur_kwargs) 144 | fg_img = Image.fromarray(fg).convert('RGB') 145 | fg_img.save(os.path.join(save_dir, '_'.join([str(n), 'from', str(ind+ind_low), 'fg', folder_suffix])+'.png')) 146 | 147 | 148 | 149 | if __name__ == '__main__': 150 | datadir = 'datasets/composited/cloth/cloth_grail_5152' 151 | video_start_ind = 0 152 | video_end_ind = 249 153 | 154 | # The probability of applying each augmentation during the generation of each positive example 155 | add_rot = 0 #0.5 156 | add_flip = 0 #0.5 157 | # reflec_kwargs= { 158 | # 'alpha_range': [0.1, 0.7], 159 | # 'surface': 140 160 | # } 161 | 162 | add_blurr_or_noise = 0.5 163 | blur_kwargs= {'sigma_range':[0.2, 1], 'kernel': 5} 164 | add_gaussian_noise = 0.5 165 | gaussian_noise_kwargs= {'std_range':[2, 7], 'mean':0} 166 | 167 | gen_pos_ex_fg(datadir, video_start_ind, video_end_ind, add_rot, add_flip, \ 168 | add_blurr_or_noise, add_gaussian_noise, num=1200, blur_kwargs=blur_kwargs, \ 169 | noise_kwargs=gaussian_noise_kwargs, folder_suffix='') 170 | 171 | -------------------------------------------------------------------------------- /data/keypoint_homo_short.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "74655d97", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "import cv2 \n", 11 | "import numpy as np \n", 12 | "import torch\n", 13 | "import torchvision.models as models\n", 14 | "from torchvision import transforms as T\n", 15 | "\n", 16 | "import os\n", 17 | "from PIL import Image, ImageFilter \n", 18 | "import numpy as np\n", 19 | "import matplotlib.pyplot as plt\n", 20 | "from matplotlib.pyplot import imshow\n", 21 | "import scipy\n", 22 | "import scipy.ndimage\n", 23 | "from scipy import ndimage\n", 24 | "from shutil import copyfile" 25 | ] 26 | }, 27 | { 28 | "cell_type": "code", 29 | "execution_count": null, 30 | "id": "fa93287f", 31 | "metadata": {}, 32 | "outputs": [], 33 | "source": [ 34 | "def feature_mask(mask, i):\n", 35 | " \"\"\"\n", 36 | " Because the edges of a segmentaation mask may be inacurate, we dilate it for \n", 37 | " the subsequent feature matching. 
\n", 38 | " The feature matcher will only look for correspondence points outside the mask.\n", 39 | " \"\"\"\n", 40 | " h, w = mask.shape\n", 41 | " obj = np.nonzero(mask)\n", 42 | " if len(obj[0]) == 0:\n", 43 | " # occlusion\n", 44 | " print('occlusion!')\n", 45 | " return 255 - mask\n", 46 | " mask_homo = np.ones_like(mask)*255\n", 47 | " up, down = obj[0].min(), obj[0].max() \n", 48 | " left, right = obj[1].min(), obj[1].max() \n", 49 | " kernel = np.ones((50, 50), np.uint8)\n", 50 | " mask = cv2.dilate(mask, kernel, iterations=1)\n", 51 | " mask_homo -= mask\n", 52 | " # You can adjust the edge erosion here to inlude more regions or less\n", 53 | "# mask_homo[max(0, up-50): min(h, down+60), max(0, left-20): min(w, right+20)]=0\n", 54 | " return mask_homo" 55 | ] 56 | }, 57 | { 58 | "cell_type": "code", 59 | "execution_count": null, 60 | "id": "4550b7cc", 61 | "metadata": { 62 | "scrolled": true 63 | }, 64 | "outputs": [], 65 | "source": [ 66 | "# the dataset folder\n", 67 | "dataset = \"composited/cloth/cloth_grail_5152\"\n", 68 | "img_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))[0]\n", 69 | "print(img_f)\n", 70 | "img1 = cv2.imread(os.path.join(\"../datasets/\", dataset, \"rgb\", img_f))\n", 71 | "old_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)\n", 72 | "\n", 73 | "h, w, _ = img1.shape\n", 74 | "img1_acc_mask = 0\n", 75 | "for mask_ind in os.listdir(os.path.join(\"../datasets/\", dataset, \"mask\")):\n", 76 | " print(mask_ind)\n", 77 | " mask_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind)))[0]\n", 78 | " print(mask_f)\n", 79 | " mask_i = cv2.imread(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind, mask_f))\n", 80 | " img1_acc_mask += mask_i\n", 81 | "img1_mask = feature_mask(cv2.cvtColor(img1_acc_mask, cv2.COLOR_BGR2GRAY),0)\n", 82 | "imshow(img1_mask)\n", 83 | "plt.show()\n", 84 | "sift = cv2.SIFT_create()\n", 85 | "# FLANN parameters\n", 86 | "FLANN_INDEX_KDTREE = 0\n", 87 | "index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)\n", 88 | "search_params = dict(checks=50) # or pass empty dictionary\n", 89 | "flann = cv2.FlannBasedMatcher(index_params,search_params)\n", 90 | "\n", 91 | "start_matrix = np.identity(3)\n", 92 | "with open(os.path.join(\"../datasets/\", dataset, 'homographies_raw.txt'), 'w') as f:\n", 93 | " for i in range(len(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))):\n", 94 | " img_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))[i]\n", 95 | " frame = cv2.imread(os.path.join(\"../datasets/\", dataset, \"rgb\", img_f))\n", 96 | " frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n", 97 | "\n", 98 | "\n", 99 | " frame_acc_mask = 0\n", 100 | " # if there are multiple objects, collate their masks together, so that \n", 101 | " # feature masking will consider the points within none of them.\n", 102 | " # subfolders inside \"mask\" should be names as \"01\", \"02\", ...\n", 103 | " for mask_ind in os.listdir(os.path.join(\"../datasets/\", dataset, \"mask\")):\n", 104 | " mask_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind)))[i]\n", 105 | " mask_i = cv2.imread(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind, mask_f))\n", 106 | " frame_acc_mask += mask_i\n", 107 | " frame_mask = feature_mask(cv2.cvtColor(frame_acc_mask, cv2.COLOR_BGR2GRAY), i)\n", 108 | " imshow(frame_mask) \n", 109 | " plt.show()\n", 110 | " # find correspondence points between 2 frames by SIFT features.\n", 111 | " kp1, des1 = 
sift.detectAndCompute(old_gray, img1_mask)\n", 112 | " kp2, des2 = sift.detectAndCompute(frame_gray, frame_mask)\n", 113 | " matches = flann.knnMatch(des1,des2,k=2)\n", 114 | " tmp1 = cv2.drawKeypoints(old_gray, kp1, old_gray)\n", 115 | " tmp2 = cv2.drawKeypoints(frame_gray, kp2, frame_gray)\n", 116 | " plt.imshow(tmp1)\n", 117 | " plt.show()\n", 118 | " plt.imshow(tmp2)\n", 119 | " plt.show()\n", 120 | " good_points=[] \n", 121 | " for m, n in matches: \n", 122 | " good_points.append((m, m.distance/n.distance)) \n", 123 | " # sort the correspondence points by confidence, by default we only use the best 50.\n", 124 | " good_points.sort(key=lambda y: y[1])\n", 125 | " query_pts = np.float32([kp1[m.queryIdx] \n", 126 | " .pt for m,d in good_points[:50]]).reshape(-1, 1, 2) \n", 127 | "\n", 128 | " train_pts = np.float32([kp2[m.trainIdx] \n", 129 | " .pt for m,d in good_points[:50]]).reshape(-1, 1, 2) \n", 130 | " print('len(query_pts)',len(query_pts))\n", 131 | " # compute homography by the correspondence pairs\n", 132 | " matrix, matrix_mask = cv2.findHomography(query_pts, train_pts, cv2.RANSAC, 5.0) \n", 133 | " inliers = matrix_mask.sum()\n", 134 | " print(i, inliers, matrix)\n", 135 | " start_matrix = matrix @ start_matrix\n", 136 | " f.write(' '.join([str(i) for i in start_matrix.flatten()])+'\\n')\n", 137 | " imshow(frame_mask) \n", 138 | " plt.show()\n", 139 | " dst = cv2.warpPerspective(img1, start_matrix, (w, h), flags=cv2.INTER_LINEAR)\n", 140 | " imshow(dst) \n", 141 | " plt.show()\n", 142 | " dst = cv2.warpPerspective(old_gray, matrix, (w, h), flags=cv2.INTER_LINEAR)\n", 143 | " imshow(dst) \n", 144 | " plt.show()\n", 145 | " old_gray = frame_gray.copy()\n", 146 | " img1_mask = frame_mask.copy()\n", 147 | " imshow(frame_gray) \n", 148 | " plt.show()" 149 | ] 150 | } 151 | ], 152 | "metadata": { 153 | "kernelspec": { 154 | "display_name": "Python 3 (ipykernel)", 155 | "language": "python", 156 | "name": "python3" 157 | }, 158 | "language_info": { 159 | "codemirror_mode": { 160 | "name": "ipython", 161 | "version": 3 162 | }, 163 | "file_extension": ".py", 164 | "mimetype": "text/x-python", 165 | "name": "python", 166 | "nbconvert_exporter": "python", 167 | "pygments_lexer": "ipython3", 168 | "version": "3.9.12" 169 | } 170 | }, 171 | "nbformat": 4, 172 | "nbformat_minor": 5 173 | } 174 | -------------------------------------------------------------------------------- /data/misc_data_process.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from PIL import Image 4 | import numpy as np 5 | import shutil 6 | 7 | 8 | def gen_black_l1mask(p): 9 | """ 10 | Assuming there's only 1 foreground object, copy its mask/01 folder 11 | to mask_nocushionmask/02. 12 | Then generate black images of the same size and same name as those in 13 | mask_nocushionmask/02 and put into mask_nocushionmask/01, which is used 14 | as the initialization of masks for the residual layer. 15 | 16 | The composition order is homography background, residual, then foreground. 17 | So the residual layer has index 1 and foreground layer's index changes to 2. 18 | 19 | TODO: the name "nocushionmask" is outdated and has no particular meaning now. 
20 | 21 | Args: 22 | p (_type_): _description_ 23 | """ 24 | os.makedirs(os.path.join(p, 'mask_nocushionmask/01'), exist_ok=False) 25 | shutil.copytree(os.path.join(p, 'mask/01'), os.path.join(p, 'mask_nocushionmask/02')) 26 | for f in os.listdir(os.path.join(p, 'mask_nocushionmask/02')): 27 | if 'png' in f: 28 | print(f) 29 | img = Image.open(os.path.join(p, 'mask_nocushionmask/02', f)) 30 | zeros = np.zeros_like(np.array(img)).astype('uint8') 31 | zeros_img = Image.fromarray(zeros) 32 | zeros_img.save(os.path.join(p, 'mask_nocushionmask/01', f)) 33 | 34 | def real_video_rgba_a(source_dir, dest_dir): 35 | """ 36 | Given RGBA images in source_dir, extract the Alpha channel and store in dest_dir. 37 | Used after Stage 1 if you want to manually clean up some predicted alphas. 38 | """ 39 | os.makedirs(dest_dir, exist_ok=False) 40 | for f in os.listdir(source_dir): 41 | if '.png' in f: 42 | print(f) 43 | img_a = np.asarray(Image.open(os.path.join(source_dir, f)))[:,:,-1] 44 | Image.fromarray(img_a).save(os.path.join(dest_dir, f)) 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | # video completion 50 | parser.add_argument('--dataroot', type=str, help='dataroot') 51 | args = parser.parse_args() 52 | 53 | gen_black_l1mask(args.dataroot) 54 | # real_video_rgba_a('results/lucia_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/test_600_/images/rgba_l1/', 'datasets/lucia/dis_gt_alpha_stage2_res') -------------------------------------------------------------------------------- /data/noninteraction_ind.txt: -------------------------------------------------------------------------------- 1 | 0, 40 -------------------------------------------------------------------------------- /datasets/confidence.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Erika Lu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | 15 | 16 | """Generate confidence maps from optical flow.""" 17 | import os 18 | import sys 19 | sys.path.append('.') 20 | from utils import readFlow, numpy2im 21 | import glob 22 | from PIL import Image 23 | import numpy as np 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | 28 | def compute_confidence(flo_f, flo_b, rgb, thresh=1, thresh_p=20): 29 | """Compute confidence map from optical flow.""" 30 | im_height, im_width = flo_f.shape[:2] 31 | identity_grid = np.expand_dims(create_grid(im_height, im_width), 0) 32 | warp_b = flo_b[np.newaxis] + identity_grid 33 | warp_f = flo_f[np.newaxis] + identity_grid 34 | warp_b = map_coords(warp_b, im_height, im_width) 35 | warp_f = map_coords(warp_f, im_height, im_width) 36 | identity_grid = identity_grid.transpose(0, 3, 1, 2) 37 | warped_1 = F.grid_sample(torch.from_numpy(identity_grid), torch.from_numpy(warp_b), align_corners=True) 38 | warped_2 = F.grid_sample(warped_1, torch.from_numpy(warp_f), align_corners=True).numpy() 39 | err = np.linalg.norm(warped_2 - identity_grid, axis=1) 40 | err[err > thresh] = thresh 41 | err /= thresh 42 | confidence = 1 - err 43 | 44 | rgb = np.expand_dims(rgb.transpose(2, 0, 1), 0) 45 | rgb_warped_1 = F.grid_sample(torch.from_numpy(rgb).double(), torch.from_numpy(warp_b), align_corners=True) 46 | rgb_warped_2 = F.grid_sample(rgb_warped_1, torch.from_numpy(warp_f), align_corners=True).numpy() 47 | err = np.linalg.norm(rgb_warped_2 - rgb, axis=1) 48 | confidence_p = (err < thresh_p).astype(np.float32) 49 | confidence *= confidence_p 50 | 51 | return confidence[0] 52 | 53 | 54 | def map_coords(coords, height, width): 55 | """Map coordinates from pixel-space to [-1, 1] range for torch's grid_sample function.""" 56 | coords_mapped = np.stack([coords[..., 0] / (width - 1), coords[..., 1] / (height - 1)], -1) 57 | return coords_mapped * 2 - 1 58 | 59 | 60 | def create_grid(height, width): 61 | ramp_u, ramp_v = np.meshgrid(np.linspace(0, width - 1, width), np.linspace(0, height - 1, height)) 62 | return np.stack([ramp_u, ramp_v], -1) 63 | 64 | 65 | if __name__ == "__main__": 66 | import argparse 67 | arguments = argparse.ArgumentParser() 68 | arguments.add_argument('--dataroot', type=str) 69 | arguments.add_argument('--step', default=1, type=int) 70 | opt = arguments.parse_args() 71 | 72 | forward_flo = sorted(glob.glob(os.path.join(opt.dataroot, 'forward_flow_step'+str(opt.step), '*.flo'))) 73 | backward_flo = sorted(glob.glob(os.path.join(opt.dataroot, 'backward_flow_step'+str(opt.step), '*.flo'))) 74 | assert(len(forward_flo) == len(backward_flo)) 75 | rgb_paths = sorted(glob.glob(os.path.join(opt.dataroot, 'rgb', '*'))) 76 | print(f'generating {len(forward_flo)} confidence maps...from', '_flow_step'+str(opt.step)) 77 | outdir = os.path.join(opt.dataroot, 'confidence_step'+str(opt.step)) 78 | os.makedirs(outdir, exist_ok=True) 79 | for i in range(len(forward_flo)): 80 | flo_f = readFlow(forward_flo[i]) 81 | flo_b = readFlow(backward_flo[i]) 82 | rgb = np.array(Image.open(rgb_paths[i])) 83 | confidence = compute_confidence(flo_f, flo_b, rgb) 84 | fp = os.path.join(outdir, f'{i+1:04d}.png') 85 | im = numpy2im(confidence) 86 | im.save(fp) 87 | print(f'saved {len(forward_flo)} confidence maps to {outdir}') 88 | -------------------------------------------------------------------------------- /datasets/homography.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Erika Lu 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | 
# you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | 16 | """Helper tools for computing the world bounds from homographies.""" 17 | import os 18 | import sys 19 | sys.path.append('.') 20 | from utils import readFlow, numpy2im 21 | # import glob 22 | from PIL import Image 23 | import numpy as np 24 | import torch 25 | import torch.nn.functional as F 26 | 27 | 28 | def transform2h(x, y, m): 29 | """Applies 2d homogeneous transformation.""" 30 | A = np.dot(m, np.array([x, y, np.ones(len(x))])) 31 | xt = A[0, :] / A[2, :] 32 | yt = A[1, :] / A[2, :] 33 | return xt, yt 34 | 35 | 36 | def compute_world_bounds(homographies, height, width): 37 | """Compute minimum and maximum coordinates. 38 | 39 | homographies - list of 3x3 numpy arrays 40 | height, width - video dimensions 41 | """ 42 | xbounds = [0, width - 1] 43 | ybounds = [0, height - 1] 44 | 45 | for h in homographies: 46 | # find transformed image bounding box 47 | x = np.array([0, width - 1, 0, width - 1]) 48 | y = np.array([0, 0, height - 1, height - 1]) 49 | [xt, yt] = transform2h(x, y, np.linalg.inv(h)) 50 | xbounds[0] = min(xbounds[0], min(xt)) 51 | xbounds[1] = max(xbounds[1], max(xt)) 52 | ybounds[0] = min(ybounds[0], min(yt)) 53 | ybounds[1] = max(ybounds[1], max(yt)) 54 | 55 | return xbounds, ybounds 56 | 57 | 58 | if __name__ == "__main__": 59 | import argparse 60 | arguments = argparse.ArgumentParser() 61 | arguments.add_argument('--homography_path', type=str, help='path to text file containing homographies') 62 | arguments.add_argument('--width', type=int, help='video width') 63 | arguments.add_argument('--height', type=int, help='video height') 64 | opt = arguments.parse_args() 65 | 66 | with open(opt.homography_path) as f: 67 | lines = f.readlines() 68 | homographies = [l.rstrip().split(' ') for l in lines] 69 | homographies = [[float(h) for h in l] for l in homographies] 70 | homographies = [np.array(H).reshape(3, 3) for H in homographies] 71 | xbounds, ybounds = compute_world_bounds(homographies, opt.height, opt.width) 72 | out_path = f'{opt.homography_path[:-8]}.txt' 73 | with open(out_path, 'w') as f: 74 | f.write(f'size: {opt.width} {opt.height}\n') 75 | f.write(f'bounds: {xbounds[0]} {xbounds[1]} {ybounds[0]} {ybounds[1]}\n') 76 | f.writelines(lines) 77 | print(f'saved {out_path}') 78 | -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__init__.py -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-39.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/base_options.cpython-38.pyc -------------------------------------------------------------------------------- /options/__pycache__/base_options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/base_options.cpython-39.pyc -------------------------------------------------------------------------------- /options/__pycache__/test_options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/test_options.cpython-39.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/train_options.cpython-38.pyc -------------------------------------------------------------------------------- /options/__pycache__/train_options.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/train_options.cpython-39.pyc -------------------------------------------------------------------------------- /options/base_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from third_party.util import util 4 | from third_party import models 5 | from third_party import data 6 | import torch 7 | import json 8 | 9 | 10 | class BaseOptions: 11 | """This class defines options used during both training and test time. 12 | 13 | It also implements several helper functions such as parsing, printing, and saving the options. 14 | It also gathers additional options defined in functions in both dataset class and model class. 15 | """ 16 | 17 | def __init__(self): 18 | """Reset the class; indicates the class hasn't been initialized""" 19 | self.initialized = False 20 | 21 | def initialize(self, parser): 22 | """Define the common options that are used in both training and test.""" 23 | # basic parameters 24 | parser.add_argument( 25 | "--dataroot", 26 | required=True, 27 | help="path to images (should have subfolders rgb_256, etc)", 28 | ) 29 | parser.add_argument( 30 | "--name", 31 | type=str, 32 | default="experiment_name", 33 | help="name of the experiment. It decides where to store samples and models", 34 | ) 35 | parser.add_argument( 36 | "--gpu_ids", 37 | type=str, 38 | default="0", 39 | help="gpu ids: e.g. 0 0,1,2, 0,2. 
use -1 for CPU", 40 | ) 41 | parser.add_argument( 42 | "--checkpoints_dir", 43 | type=str, 44 | default="./1110_checkpoints", 45 | help="models are saved here", 46 | ) 47 | parser.add_argument("--seed", type=int, default=35, help="initial random seed") 48 | # model parameters 49 | parser.add_argument( 50 | "--model", 51 | type=str, 52 | default="factormatte_GANFlip", 53 | help="chooses which model to use. [lnr | kp2uv]", 54 | ) 55 | parser.add_argument( 56 | "--num_filters", 57 | type=int, 58 | default=64, 59 | help="# filters in the first and last conv layers", 60 | ) 61 | # dataset parameters 62 | parser.add_argument( 63 | "--coarseness", 64 | type=int, 65 | default=10, 66 | help="Coarness of background offset interpolation", 67 | ) 68 | parser.add_argument( 69 | "--max_frames", 70 | type=int, 71 | default=200, 72 | help="Similar meaning as max_dataset_size but cannot be infinite for background interpolation.", 73 | ) 74 | parser.add_argument( 75 | "--dataset_mode", 76 | type=str, 77 | default="factormatte_GANCGANFlip148", 78 | help="chooses how datasets are loaded.", 79 | ) 80 | parser.add_argument( 81 | "--serial_batches", 82 | action="store_true", 83 | help="if true, takes images in order to make batches, otherwise takes them randomly", 84 | ) 85 | parser.add_argument( 86 | "--num_threads", default=4, type=int, help="# threads for loading data" 87 | ) 88 | parser.add_argument( 89 | "--batch_size", type=int, default=8, help="input batch size" 90 | ) 91 | parser.add_argument( 92 | "--max_dataset_size", 93 | type=int, 94 | default=float("inf"), 95 | help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.", 96 | ) 97 | parser.add_argument( 98 | "--display_winsize", 99 | type=int, 100 | default=256, 101 | help="display window size for both visdom and HTML", 102 | ) 103 | # additional parameters 104 | parser.add_argument( 105 | "--epoch", 106 | type=str, 107 | default="latest", 108 | help="which epoch to load? set to latest to use latest cached model", 109 | ) 110 | parser.add_argument( 111 | "--verbose", 112 | action="store_true", 113 | help="if specified, print more debugging information", 114 | ) 115 | parser.add_argument( 116 | "--suffix", 117 | default="", 118 | type=str, 119 | help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}", 120 | ) 121 | parser.add_argument( 122 | "--prob_masks", 123 | action="store_true", 124 | help="if true, use 1 over #layer probability mask initialization, otherwise binary", 125 | ) 126 | parser.add_argument( 127 | "--rgba", 128 | default="L", 129 | type=str, 130 | help="If true the input FG is RGBA, RGB, or L only.", 131 | ) 132 | parser.add_argument( 133 | "--rgba_GAN", 134 | default="RGB", 135 | type=str, 136 | help="If true the input to the GAN discriminator is RGBA, RGB, or L only. 
Only used when there exists GAN, not CGAN.", 137 | ) 138 | parser.add_argument( 139 | "--residual_noise", 140 | action="store_true", 141 | help="if true, use random noise for Z initialization.", 142 | ) 143 | parser.add_argument( 144 | "--bg_noise", 145 | action="store_true", 146 | help="if true, use random noise for background Z initialization.", 147 | ) 148 | parser.add_argument( 149 | "--no_bg", 150 | action="store_true", 151 | help="If true exclude the bg layer as defined in the original Omnimatte.", 152 | ) 153 | parser.add_argument( 154 | "--orderscale", 155 | action="store_true", 156 | help="if true, keep the original Omnimatte's version of mask scaling.", 157 | ) 158 | parser.add_argument( 159 | "--steps", 160 | type=str, 161 | default="1", 162 | help="X steps apart to consider. Specify without space.", 163 | ) 164 | parser.add_argument( 165 | "--noninter_only", 166 | action="store_true", 167 | help="if true, only use nonteractive frames of the video.", 168 | ) 169 | parser.add_argument( 170 | "--gradient_debug", 171 | action="store_true", 172 | help="whether to do the real gradient descent or just to record the gradients.", 173 | ) 174 | parser.add_argument( 175 | "--num_Ds", 176 | default="0,3,3", 177 | type=str, 178 | help="Number of multiscale discriminators.", 179 | ) 180 | parser.add_argument( 181 | "--strides", 182 | default="0,2,2", 183 | type=str, 184 | help="Number of stride in the convs of multiscale discriminators.", 185 | ) 186 | parser.add_argument( 187 | "--n_layers", 188 | default="0,1,3", 189 | type=str, 190 | help="Number of stride in the convs of multiscale discriminators.", 191 | ) 192 | parser.add_argument( 193 | "--fg_layer_ind", 194 | type=int, 195 | default=2, 196 | help="Which layer is the foreground, starting from 0.", 197 | ) 198 | parser.add_argument( 199 | "--stage", 200 | type=int, 201 | help="Tells the dataset which dis_gt_alpha to use; index starting from 1. Stage 1: get bg, shouldn't have any dis_gt_alpha; stage 2: alpha from stage 1, for regularizing color, should run only on NFs; stage 3: alpha from stage 2 to constrain the alpha in IFs.", 202 | ) 203 | parser.add_argument( 204 | "--get_bg", 205 | action="store_true", 206 | help="if specified, generate the bg panorama and quit", 207 | ) 208 | self.initialized = True 209 | return parser 210 | 211 | def gather_options(self): 212 | """Initialize our parser with basic options(only once). 213 | Add additional model-specific and dataset-specific options. 214 | These options are defined in the function 215 | in model and dataset classes. 
216 | """ 217 | if not self.initialized: # check if it has been initialized 218 | parser = argparse.ArgumentParser( 219 | formatter_class=argparse.ArgumentDefaultsHelpFormatter 220 | ) 221 | parser = self.initialize(parser) 222 | 223 | # get the basic options 224 | opt, _ = parser.parse_known_args() 225 | 226 | # modify model-related parser options 227 | model_name = opt.model 228 | model_option_setter = models.get_option_setter(model_name) 229 | parser = model_option_setter(parser, self.isTrain) 230 | opt, _ = parser.parse_known_args() # parse again with new defaults 231 | 232 | # modify dataset-related parser options 233 | dataset_name = opt.dataset_mode 234 | dataset_option_setter = data.get_option_setter(dataset_name) 235 | parser = dataset_option_setter(parser, self.isTrain) 236 | 237 | # save and return the parser 238 | self.parser = parser 239 | return parser.parse_args() 240 | 241 | def print_options(self, opt): 242 | """Print and save options 243 | 244 | It will print both current options and default values(if different). 245 | It will save options into a text file / [checkpoints_dir] / opt.txt 246 | """ 247 | message = "" 248 | message += "----------------- Options ---------------\n" 249 | for k, v in sorted(vars(opt).items()): 250 | comment = "" 251 | default = self.parser.get_default(k) 252 | if v != default: 253 | comment = "\t[default: %s]" % str(default) 254 | message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment) 255 | message += "----------------- End -------------------" 256 | print(message) 257 | 258 | # save to the disk 259 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name) 260 | util.mkdirs(expr_dir) 261 | file_name = os.path.join(expr_dir, "{}_opt.txt".format(opt.phase)) 262 | with open(file_name, "wt") as opt_file: 263 | opt_file.write(message) 264 | opt_file.write("\n") 265 | 266 | def parse(self): 267 | """Parse our options, create checkpoints directory suffix, and set up gpu device.""" 268 | opt = self.gather_options() 269 | opt.isTrain = self.isTrain # train or test 270 | 271 | # process opt.suffix 272 | if opt.suffix: 273 | suffix = ("_" + opt.suffix.format(**vars(opt))) if opt.suffix != "" else "" 274 | opt.name = opt.name + suffix 275 | 276 | self.print_options(opt) 277 | 278 | # set gpu ids 279 | str_ids = opt.gpu_ids.split(",") 280 | opt.gpu_ids = [] 281 | for str_id in str_ids: 282 | id = int(str_id) 283 | if id >= 0: 284 | opt.gpu_ids.append(id) 285 | if len(opt.gpu_ids) > 0: 286 | torch.cuda.set_device(opt.gpu_ids[0]) 287 | 288 | self.opt = opt 289 | return self.opt 290 | -------------------------------------------------------------------------------- /options/test_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TestOptions(BaseOptions): 5 | """This class includes test options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) # define shared options 12 | parser.add_argument( 13 | "--results_dir", type=str, default="./results/", help="saves results here." 
14 | ) 15 | parser.add_argument( 16 | "--aspect_ratio", 17 | type=float, 18 | default=1.0, 19 | help="aspect ratio of result images", 20 | ) 21 | parser.add_argument( 22 | "--phase", type=str, default="test", help="train, val, test, etc" 23 | ) 24 | parser.add_argument( 25 | "--num_test", 26 | type=int, 27 | default=float("inf"), 28 | help="how many test images to run", 29 | ) 30 | parser.add_argument( 31 | "--test_suffix", type=str, default="", help="suffix to folder name" 32 | ) 33 | self.isTrain = False 34 | return parser 35 | -------------------------------------------------------------------------------- /options/train_options.py: -------------------------------------------------------------------------------- 1 | from .base_options import BaseOptions 2 | 3 | 4 | class TrainOptions(BaseOptions): 5 | """This class includes training options. 6 | 7 | It also includes shared options defined in BaseOptions. 8 | """ 9 | 10 | def initialize(self, parser): 11 | parser = BaseOptions.initialize(self, parser) 12 | # visdom and HTML visualization parameters 13 | parser.add_argument( 14 | "--display_ind", 15 | type=int, 16 | default=25, 17 | help="The index frame to visualize during training.", 18 | ) 19 | parser.add_argument( 20 | "--display_freq", 21 | type=int, 22 | default=10, 23 | help="frequency of showing training results on screen (in epochs)", 24 | ) 25 | parser.add_argument( 26 | "--display_ncols", 27 | type=int, 28 | default=0, 29 | help="if positive, display all images in a single visdom web panel with certain number of images per row.", 30 | ) 31 | parser.add_argument( 32 | "--display_id", type=int, default=1, help="window id of the web display" 33 | ) 34 | parser.add_argument( 35 | "--display_server", 36 | type=str, 37 | default="http://localhost", 38 | help="visdom server of the web display", 39 | ) 40 | parser.add_argument( 41 | "--display_env", 42 | type=str, 43 | default="main", 44 | help='visdom display environment name (default is "main")', 45 | ) 46 | parser.add_argument( 47 | "--display_port", 48 | type=int, 49 | default=8097, 50 | help="visdom port of the web display", 51 | ) 52 | parser.add_argument( 53 | "--update_html_freq", 54 | type=int, 55 | default=10, 56 | help="frequency of saving training results to html", 57 | ) 58 | parser.add_argument( 59 | "--print_freq", 60 | type=int, 61 | default=10, 62 | help="frequency of showing training results on console (in steps per epoch)", 63 | ) 64 | parser.add_argument( 65 | "--no_html", 66 | action="store_true", 67 | help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/", 68 | ) 69 | # network saving and loading parameters 70 | parser.add_argument( 71 | "--save_latest_freq", 72 | type=int, 73 | default=20, 74 | help="frequency of saving the latest results (in epochs)", 75 | ) 76 | parser.add_argument( 77 | "--save_by_epoch", 78 | type=bool, 79 | default=True, 80 | help='whether saves model as "epoch" or "latest" (overwrites previous)', 81 | ) 82 | parser.add_argument( 83 | "--continue_train", 84 | action="store_true", 85 | help="continue training: load the latest model", 86 | ) 87 | parser.add_argument( 88 | "--overwrite_lambdas", 89 | action="store_true", 90 | help="continue training and overwrite lambdas and epochs hyperparams by history", 91 | ) 92 | parser.add_argument( 93 | "--overwrite_lrs", 94 | action="store_true", 95 | help="continue training and overwrite lr by history", 96 | ) 97 | parser.add_argument( 98 | "--epoch_count", 99 | type=int, 100 | default=1, 101 | help="the starting 
epoch count, we save the model by , +, ...", 102 | ) 103 | parser.add_argument( 104 | "--phase", type=str, default="train", help="train, val, test, etc" 105 | ) 106 | # training parameters 107 | parser.add_argument( 108 | "--n_epochs", 109 | type=int, 110 | default=None, 111 | help="number of training epochs with the initial learning rate.\ 112 | You only need to specify one of this or n_steps", 113 | ) 114 | parser.add_argument( 115 | "--n_steps", 116 | type=int, 117 | default=24000, 118 | help="number of training steps with the initial learning rate", 119 | ) 120 | parser.add_argument( 121 | "--n_steps_decay", 122 | type=int, 123 | default=0, 124 | help="number of steps to linearly decay learning rate to zero", 125 | ) 126 | parser.add_argument( 127 | "--lr", type=float, default=0.001, help="initial learning rate for adam" 128 | ) 129 | parser.add_argument( 130 | "--lr_policy", 131 | type=str, 132 | default="linear", 133 | help="learning rate policy. [linear | step | plateau | cosine]", 134 | ) 135 | parser.add_argument( 136 | "--pretrained", 137 | action="store_true", 138 | help="Whether use part of a pretrained resnet18 for the discriminator.", 139 | ) 140 | parser.add_argument( 141 | "--discriminator_transform", 142 | type=str, 143 | default="randomcrop", 144 | help="What transform to apply to the generated rgb before feeding into the discriminator.", 145 | ) 146 | parser.add_argument( 147 | "--jitter", 148 | action="store_true", 149 | help="Whether use the original jitter for training.", 150 | ) 151 | 152 | self.isTrain = True 153 | return parser 154 | -------------------------------------------------------------------------------- /prepare_data_stage1.sh: -------------------------------------------------------------------------------- 1 | # Get ready the homographies_raw.txt, mask/01, rgb folder and run this! 2 | python video_completion.py --path $1/rgb --step 1 3 | python video_completion.py --path $1/rgb --step 4 4 | python video_completion.py --path $1/rgb --step 8 5 | 6 | mv RAFT_result/$(echo $1 | sed 's/\///g')rgb/*flow* $1 7 | 8 | python datasets/confidence.py --dataroot $1 --step 1 9 | python datasets/confidence.py --dataroot $1 --step 4 10 | python datasets/confidence.py --dataroot $1 --step 8 11 | 12 | python datasets/homography.py --homography_path $1/homographies_raw.txt --width $2 --height $3 13 | python data/misc_data_process.py --dataroot $1 14 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch=1.11.0 2 | cudatoolkit=11.2 3 | torchvision 4 | scipy 5 | tensorboard 6 | Pillow -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """Script to save the full outputs of an Omnimatte model. 2 | 3 | Once you have trained the Omnimatte model with train.py, you can use this script to save the model's final omnimattes. 4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. 5 | 6 | It first creates a model and dataset given the options. It will hard-code some parameters. 7 | It then runs inference for '--num_test' images and save results to an HTML file. 8 | 9 | Example (after training a model): 10 | python test.py --dataroot ./datasets/tennis --name tennis 11 | 12 | Use '--results_dir ' to specify the results directory. 
13 | 14 | See options/base_options.py and options/test_options.py for more test options. 15 | """ 16 | import os 17 | from options.test_options import TestOptions 18 | from third_party.data import create_dataset 19 | from third_party.models import create_model 20 | from third_party.util.visualizer import save_images, save_videos 21 | from third_party.util import html 22 | import torch 23 | 24 | 25 | if __name__ == "__main__": 26 | testopt = TestOptions() 27 | opt = testopt.parse() 28 | # hard-code some parameters for test 29 | opt.num_threads = 0 # test code only supports num_threads = 0 30 | opt.batch_size = 1 # test code only supports batch_size = 1 31 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 32 | opt.display_id = ( 33 | -1 34 | ) # no visdom display; the test code saves the results to a HTML file. 35 | dataset = create_dataset( 36 | opt 37 | ) # create a dataset given opt.dataset_mode and other options 38 | model = create_model(opt) # create a model given opt.model and other options 39 | model.setup(opt) # regular setup: load and print networks; create schedulers 40 | if opt.gradient_debug: 41 | weight = torch.load( 42 | os.path.join(opt.checkpoints_dir, opt.name, str(opt.epoch) + "_others.pth") 43 | ) 44 | for i in range(len(model.discriminators)): 45 | if model.discriminators[i] is not None: 46 | model.discriminators[i].load_state_dict( 47 | weight["discriminator_l" + str(i)], strict=False 48 | ) 49 | print(i, "th discriminator weights loaded unstrictly") 50 | print( 51 | "the dict in the history is", 52 | weight["discriminator_l" + str(i)].keys(), 53 | ) 54 | print( 55 | "the dict in current model is", 56 | model.discriminators[i].state_dict().keys(), 57 | ) 58 | 59 | # create a website 60 | web_dir = os.path.join( 61 | opt.results_dir, 62 | opt.name, 63 | "{}_{}_{}".format(opt.phase, opt.epoch, opt.test_suffix), 64 | ) # define the website directory 65 | print("creating web directory", web_dir) 66 | webpage = html.HTML( 67 | web_dir, 68 | "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch), 69 | ) 70 | video_visuals = None 71 | loss_recon = 0 72 | model.do_cam_adj = True #False 73 | for i, data in enumerate(dataset): 74 | # print(i) 75 | # if i < 130: 76 | # continue 77 | if i >= opt.num_test: # only apply our model to opt.num_test images. 78 | break 79 | model.set_input(data) # unpack data from data loader 80 | model.test(i) # run inference 81 | img_path = model.get_image_paths() # get image paths 82 | if i % 5 == 0: # save images to an HTML file 83 | print("processing (%04d)-th image... 
%s" % (i, img_path)) 84 | with torch.no_grad(): 85 | visuals = model.get_results() # rgba, reconstruction, original, mask 86 | if video_visuals is None: 87 | video_visuals = visuals 88 | else: 89 | for k in video_visuals: 90 | video_visuals[k] = torch.cat((video_visuals[k], visuals[k])) 91 | for k in video_visuals: 92 | rgba = {k: visuals[k]} # for k in visuals if "rgba" in k 93 | # save RGBA layers 94 | save_images( 95 | webpage, 96 | rgba, 97 | img_path, 98 | aspect_ratio=opt.aspect_ratio, 99 | width=opt.display_winsize, 100 | ) 101 | # if os.path.isdir(os.path.join(opt.dataroot, "rgb_invis_gt")): 102 | # print( 103 | # model.criterionLoss( 104 | # model.reconstruction_rgb_no_cube, model.target_image 105 | # ), 106 | # ) 107 | # loss_recon += model.criterionLoss( 108 | # model.reconstruction_rgb_no_cube, model.target_image 109 | # ) 110 | 111 | save_videos(webpage, video_visuals, width=opt.display_winsize) 112 | webpage.save() # save the HTML of videos 113 | with open(os.path.join(web_dir, "invis_gt_eval.txt"), "w") as f: 114 | print("avg recon no cube L1Loss " + str(loss_recon / len(dataset)), file=f) 115 | -------------------------------------------------------------------------------- /third_party/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__init__.py -------------------------------------------------------------------------------- /third_party/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__init__.pyc -------------------------------------------------------------------------------- /third_party/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/data/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes all the modules related to data loading and preprocessing 2 | 3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset. 4 | You need to implement four functions: 5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 6 | -- <__len__>: return the size of dataset. 7 | -- <__getitem__>: get a data point from data loader. 8 | -- : (optionally) add dataset-specific options and set default options. 9 | 10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'. 11 | See our template dataset class 'template_dataset.py' for more details. 
12 | """ 13 | import importlib 14 | import torch.utils.data 15 | from .base_dataset import BaseDataset 16 | 17 | 18 | def find_dataset_using_name(dataset_name): 19 | """Import the module "data/[dataset_name]_dataset.py". 20 | 21 | In the file, the class called DatasetNameDataset() will 22 | be instantiated. It has to be a subclass of BaseDataset, 23 | and it is case-insensitive. 24 | """ 25 | dataset_filename = "data." + dataset_name + "_dataset" 26 | datasetlib = importlib.import_module(dataset_filename) 27 | 28 | dataset = None 29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset' 30 | for name, cls in datasetlib.__dict__.items(): 31 | if name.lower() == target_dataset_name.lower() \ 32 | and issubclass(cls, BaseDataset): 33 | dataset = cls 34 | 35 | if dataset is None: 36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name)) 37 | 38 | return dataset 39 | 40 | 41 | def get_option_setter(dataset_name): 42 | """Return the static method of the dataset class.""" 43 | dataset_class = find_dataset_using_name(dataset_name) 44 | return dataset_class.modify_commandline_options 45 | 46 | 47 | def create_dataset(opt): 48 | """Create a dataset given the option. 49 | 50 | This function wraps the class CustomDatasetDataLoader. 51 | This is the main interface between this package and 'train.py'/'test.py' 52 | 53 | Example: 54 | >>> from data import create_dataset 55 | >>> dataset = create_dataset(opt) 56 | """ 57 | data_loader = CustomDatasetDataLoader(opt) 58 | dataset = data_loader.load_data() 59 | return dataset 60 | 61 | 62 | class CustomDatasetDataLoader(): 63 | """Wrapper class of Dataset class that performs multi-threaded data loading""" 64 | 65 | def __init__(self, opt): 66 | """Initialize this class 67 | 68 | Step 1: create a dataset instance given the name [dataset_mode] 69 | Step 2: create a multi-threaded data loader. 
70 | """ 71 | self.opt = opt 72 | dataset_class = find_dataset_using_name(opt.dataset_mode) 73 | self.dataset = dataset_class(opt) 74 | print("dataset [%s] was created" % type(self.dataset).__name__) 75 | loader = torch.utils.data.DataLoader 76 | self.dataloader = loader( 77 | self.dataset, 78 | batch_size=opt.batch_size, 79 | shuffle=not opt.serial_batches, 80 | num_workers=int(opt.num_threads), 81 | persistent_workers=int(opt.num_threads) > 0, 82 | drop_last = True) 83 | 84 | def load_data(self): 85 | return self 86 | 87 | def __len__(self): 88 | """Return the number of data in the dataset""" 89 | return min(len(self.dataset), self.opt.max_dataset_size) 90 | 91 | def __iter__(self): 92 | """Return a batch of data""" 93 | for i, data in enumerate(self.dataloader): 94 | if i * self.opt.batch_size >= self.opt.max_dataset_size: 95 | break 96 | yield data 97 | -------------------------------------------------------------------------------- /third_party/data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/data/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/data/__pycache__/base_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/base_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/data/__pycache__/base_dataset.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/base_dataset.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/data/__pycache__/image_folder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/image_folder.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/data/__pycache__/image_folder.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/image_folder.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/data/base_dataset.py: -------------------------------------------------------------------------------- 1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets. 2 | 3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses. 
4 | """ 5 | import random 6 | import numpy as np 7 | import torch.utils.data as data 8 | from PIL import Image 9 | import torchvision.transforms as transforms 10 | from abc import ABC, abstractmethod 11 | 12 | 13 | class BaseDataset(data.Dataset, ABC): 14 | """This class is an abstract base class (ABC) for datasets. 15 | 16 | To create a subclass, you need to implement the following four functions: 17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt). 18 | -- <__len__>: return the size of dataset. 19 | -- <__getitem__>: get a data point. 20 | -- : (optionally) add dataset-specific options and set default options. 21 | """ 22 | 23 | def __init__(self, opt): 24 | """Initialize the class; save the options in the class 25 | 26 | Parameters: 27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 28 | """ 29 | self.opt = opt 30 | self.root = opt.dataroot 31 | 32 | @staticmethod 33 | def modify_commandline_options(parser, is_train): 34 | """Add new dataset-specific options, and rewrite default values for existing options. 35 | 36 | Parameters: 37 | parser -- original option parser 38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 39 | 40 | Returns: 41 | the modified parser. 42 | """ 43 | return parser 44 | 45 | @abstractmethod 46 | def __len__(self): 47 | """Return the total number of images in the dataset.""" 48 | return 0 49 | 50 | @abstractmethod 51 | def __getitem__(self, index): 52 | """Return a data point and its metadata information. 53 | 54 | Parameters: 55 | index - - a random integer for data indexing 56 | 57 | Returns: 58 | a dictionary of data with their names. It ususally contains the data itself and its metadata information. 
59 | """ 60 | pass 61 | 62 | 63 | def get_params(opt, size): 64 | w, h = size 65 | new_h = h 66 | new_w = w 67 | if opt.preprocess == 'resize_and_crop': 68 | new_h = new_w = opt.load_size 69 | elif opt.preprocess == 'scale_width_and_crop': 70 | new_w = opt.load_size 71 | new_h = opt.load_size * h // w 72 | 73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size)) 74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size)) 75 | 76 | flip = random.random() > 0.5 77 | 78 | return {'crop_pos': (x, y), 'flip': flip} 79 | 80 | 81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True): 82 | transform_list = [] 83 | if grayscale: 84 | transform_list.append(transforms.Grayscale(1)) 85 | if 'resize' in opt.preprocess: 86 | osize = [opt.load_size, opt.load_size] 87 | transform_list.append(transforms.Resize(osize, method)) 88 | elif 'scale_width' in opt.preprocess: 89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method))) 90 | 91 | if 'crop' in opt.preprocess: 92 | if params is None: 93 | transform_list.append(transforms.RandomCrop(opt.crop_size)) 94 | else: 95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size))) 96 | 97 | if opt.preprocess == 'none': 98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method))) 99 | 100 | if not opt.no_flip: 101 | if params is None: 102 | transform_list.append(transforms.RandomHorizontalFlip()) 103 | elif params['flip']: 104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip']))) 105 | 106 | if convert: 107 | transform_list += [transforms.ToTensor()] 108 | if grayscale: 109 | transform_list += [transforms.Normalize((0.5,), (0.5,))] 110 | else: 111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] 112 | return transforms.Compose(transform_list) 113 | 114 | 115 | def __make_power_2(img, base, method=Image.BICUBIC): 116 | ow, oh = img.size 117 | h = int(round(oh / base) * base) 118 | w = int(round(ow / base) * base) 119 | if h == oh and w == ow: 120 | return img 121 | 122 | __print_size_warning(ow, oh, w, h) 123 | return img.resize((w, h), method) 124 | 125 | 126 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC): 127 | ow, oh = img.size 128 | if ow == target_size and oh >= crop_size: 129 | return img 130 | w = target_size 131 | h = int(max(target_size * oh / ow, crop_size)) 132 | return img.resize((w, h), method) 133 | 134 | 135 | def __crop(img, pos, size): 136 | ow, oh = img.size 137 | x1, y1 = pos 138 | tw = th = size 139 | if (ow > tw or oh > th): 140 | return img.crop((x1, y1, x1 + tw, y1 + th)) 141 | return img 142 | 143 | 144 | def __flip(img, flip): 145 | if flip: 146 | return img.transpose(Image.FLIP_LEFT_RIGHT) 147 | return img 148 | 149 | 150 | def __print_size_warning(ow, oh, w, h): 151 | """Print warning information about image size(only print once)""" 152 | if not hasattr(__print_size_warning, 'has_printed'): 153 | print("The image size needs to be a multiple of 4. " 154 | "The loaded image size was (%d, %d), so it was adjusted to " 155 | "(%d, %d). 
This adjustment will be done to all images " 156 | "whose sizes are not multiples of 4" % (ow, oh, w, h)) 157 | __print_size_warning.has_printed = True 158 | -------------------------------------------------------------------------------- /third_party/data/image_folder.py: -------------------------------------------------------------------------------- 1 | """A modified image folder class 2 | 3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 4 | so that this class can load images from both current directory and its subdirectories. 5 | """ 6 | 7 | import torch.utils.data as data 8 | 9 | from PIL import Image 10 | import os 11 | 12 | IMG_EXTENSIONS = [ 13 | '.jpg', '.JPG', '.jpeg', '.JPEG', 14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 15 | '.tif', '.TIF', '.tiff', '.TIFF', 16 | ] 17 | 18 | 19 | def is_image_file(filename): 20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 21 | 22 | 23 | def make_dataset(dir, max_dataset_size=float("inf")): 24 | images = [] 25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 26 | 27 | for root, _, fnames in sorted(os.walk(dir)): 28 | for fname in fnames: 29 | if is_image_file(fname): 30 | path = os.path.join(root, fname) 31 | images.append(path) 32 | images = sorted(images) 33 | return images[:min(max_dataset_size, len(images))] 34 | 35 | 36 | def default_loader(path): 37 | return Image.open(path).convert('RGB') 38 | 39 | 40 | class ImageFolder(data.Dataset): 41 | 42 | def __init__(self, root, transform=None, return_paths=False, 43 | loader=default_loader): 44 | imgs = make_dataset(root) 45 | if len(imgs) == 0: 46 | raise(RuntimeError("Found 0 images in: " + root + "\n" 47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 48 | 49 | self.root = root 50 | self.imgs = imgs 51 | self.transform = transform 52 | self.return_paths = return_paths 53 | self.loader = loader 54 | 55 | def __getitem__(self, index): 56 | path = self.imgs[index] 57 | img = self.loader(path) 58 | if self.transform is not None: 59 | img = self.transform(img) 60 | if self.return_paths: 61 | return img, path 62 | else: 63 | return img 64 | 65 | def __len__(self): 66 | return len(self.imgs) 67 | -------------------------------------------------------------------------------- /third_party/models/__init__.py: -------------------------------------------------------------------------------- 1 | """This package contains modules related to objective functions, optimizations, and network architectures. 2 | 3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 4 | You need to implement the following five functions: 5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 6 | -- : unpack data from dataset and apply preprocessing. 7 | -- : produce intermediate results. 8 | -- : calculate loss, gradients, and update network weights. 9 | -- : (optionally) add model-specific options and set default options. 10 | 11 | In the function <__init__>, you need to define four lists: 12 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 13 | -- self.model_names (str list): define networks used in our training. 14 | -- self.visual_names (str list): specify the images that you want to display and save. 15 | -- self.optimizers (optimizer list): define and initialize optimizers. 
You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 16 | 17 | Now you can use the model class by specifying flag '--model dummy'. 18 | See our template model class 'template_model.py' for more details. 19 | """ 20 | 21 | import importlib 22 | from .base_model import BaseModel 23 | 24 | 25 | def find_model_using_name(model_name): 26 | """Import the module "models/[model_name]_model.py". 27 | 28 | In the file, the class called DatasetNameModel() will 29 | be instantiated. It has to be a subclass of BaseModel, 30 | and it is case-insensitive. 31 | """ 32 | model_filename = "models." + model_name + "_model" 33 | modellib = importlib.import_module(model_filename) 34 | model = None 35 | target_model_name = model_name.replace('_', '') + 'model' 36 | for name, cls in modellib.__dict__.items(): 37 | if name.lower() == target_model_name.lower() \ 38 | and issubclass(cls, BaseModel): 39 | model = cls 40 | 41 | if model is None: 42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 43 | exit(0) 44 | 45 | return model 46 | 47 | 48 | def get_option_setter(model_name): 49 | """Return the static method of the model class.""" 50 | model_class = find_model_using_name(model_name) 51 | return model_class.modify_commandline_options 52 | 53 | 54 | def create_model(opt): 55 | """Create a model given the option. 56 | 57 | This function warps the class CustomDatasetDataLoader. 58 | This is the main interface between this package and 'train.py'/'test.py' 59 | 60 | Example: 61 | >>> from models import create_model 62 | >>> model = create_model(opt) 63 | """ 64 | model = find_model_using_name(opt.model) 65 | instance = model(opt) 66 | print("model [%s] was created" % type(instance).__name__) 67 | return instance 68 | -------------------------------------------------------------------------------- /third_party/models/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/base_model.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/base_model.cpython-39.pyc -------------------------------------------------------------------------------- 
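The factory in third_party/models/__init__.py above resolves models purely by naming convention: passing '--model dummy' imports models/dummy_model.py and picks the BaseModel subclass whose lowercased name equals 'dummymodel' (third_party/data/__init__.py applies the same convention to '--dataset_mode'). The skeleton below is a minimal, hypothetical illustration of that convention; dummy_model.py, DummyModel, and the --dummy_weight flag are not part of this repository.

# models/dummy_model.py (hypothetical; shown only to illustrate the registration convention)
from third_party.models import BaseModel


class DummyModel(BaseModel):
    """Placeholder model: satisfies the abstract methods required by BaseModel."""

    @staticmethod
    def modify_commandline_options(parser, is_train):
        # Flags added here are merged into the main parser by BaseOptions.gather_options().
        parser.add_argument("--dummy_weight", type=float, default=1.0)  # illustrative flag
        return parser

    def __init__(self, opt):
        BaseModel.__init__(self, opt)

    def set_input(self, input):
        pass  # unpack the dict yielded by the dataset

    def forward(self):
        pass

    def optimize_parameters(self):
        pass

With such a file in place, '--model dummy' would make create_model(opt) instantiate DummyModel(opt).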
/third_party/models/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/networks.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/networks_lnr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks_lnr.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/models/__pycache__/networks_lnr.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks_lnr.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from collections import OrderedDict 4 | from abc import ABC, abstractmethod 5 | import numpy as np 6 | from . import networks 7 | 8 | 9 | class BaseModel(ABC): 10 | """This class is an abstract base class (ABC) for models. 11 | To create a subclass, you need to implement the following five functions: 12 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 13 | -- : unpack data from dataset and apply preprocessing. 14 | -- : produce intermediate results. 15 | -- : calculate losses, gradients, and update network weights. 16 | -- : (optionally) add model-specific options and set default options. 17 | """ 18 | 19 | def __init__(self, opt): 20 | """Initialize the BaseModel class. 21 | 22 | Parameters: 23 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions 24 | 25 | When creating your custom class, you need to implement your own initialization. 26 | In this function, you should first call 27 | Then, you need to define four lists: 28 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 29 | -- self.model_names (str list): define networks used in our training. 30 | -- self.visual_names (str list): specify the images that you want to display and save. 31 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example. 
32 | """ 33 | self.opt = opt 34 | self.gpu_ids = opt.gpu_ids 35 | self.isTrain = opt.isTrain 36 | self.device = ( 37 | torch.device("cuda:{}".format(self.gpu_ids[0])) 38 | if self.gpu_ids 39 | else torch.device("cpu") 40 | ) # get device name: CPU or GPU 41 | self.save_dir = os.path.join( 42 | opt.checkpoints_dir, opt.name 43 | ) # save all the checkpoints to save_dir 44 | self.loss_names = [] 45 | self.model_names = [] 46 | self.visual_names = [] 47 | self.optimizers = [] 48 | self.image_paths = [] 49 | self.metric = 0 # used for learning rate policy 'plateau' 50 | 51 | @staticmethod 52 | def modify_commandline_options(parser, is_train): 53 | """Add new model-specific options, and rewrite default values for existing options. 54 | 55 | Parameters: 56 | parser -- original option parser 57 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options. 58 | 59 | Returns: 60 | the modified parser. 61 | """ 62 | return parser 63 | 64 | @abstractmethod 65 | def set_input(self, input): 66 | """Unpack input data from the dataloader and perform necessary pre-processing steps. 67 | 68 | Parameters: 69 | input (dict): includes the data itself and its metadata information. 70 | """ 71 | pass 72 | 73 | @abstractmethod 74 | def forward(self): 75 | """Run forward pass; called by both functions and .""" 76 | pass 77 | 78 | @abstractmethod 79 | def optimize_parameters(self): 80 | """Calculate losses, gradients, and update network weights; called in every training iteration""" 81 | pass 82 | 83 | def setup(self, opt): 84 | """Load and print networks; create schedulers 85 | 86 | Parameters: 87 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions 88 | """ 89 | if self.isTrain: 90 | self.schedulers = [ 91 | networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers 92 | ] 93 | if not self.isTrain or opt.continue_train: 94 | if opt.epoch != "latest": 95 | load_suffix = "epoch_" + opt.epoch 96 | else: 97 | load_suffix = opt.epoch 98 | self.load_networks(load_suffix) 99 | self.print_networks(opt.verbose) 100 | 101 | def eval(self): 102 | """Make models eval mode during test time""" 103 | for name in self.model_names: 104 | if isinstance(name, str): 105 | net = getattr(self, "net" + name) 106 | net.eval() 107 | 108 | def test(self, total_iters): 109 | """Forward function used in test time. 
110 | 111 | This function wraps function in no_grad() so we don't save intermediate steps for backprop 112 | It also calls to produce additional visualization results 113 | """ 114 | if self.opt.gradient_debug: 115 | self.set_requires_grad(self.netOmnimatteGAN, True) 116 | self.netOmnimatteGAN.train() 117 | web_dir = os.path.join( 118 | self.opt.results_dir, 119 | self.opt.name, 120 | "{}_{}_{}".format(self.opt.phase, self.opt.epoch, self.opt.test_suffix), 121 | ) 122 | self.forward() 123 | grad_map = self.gradient_debug(total_iters) 124 | np.save( 125 | os.path.join( 126 | web_dir, 127 | str(total_iters) + "_grad_map.npy", 128 | ), 129 | grad_map, 130 | ) 131 | else: 132 | with torch.no_grad(): 133 | self.forward() 134 | 135 | def compute_visuals(self): 136 | """Calculate additional output images for visdom and HTML visualization""" 137 | pass 138 | 139 | def get_image_paths(self): 140 | """Return image paths that are used to load current data""" 141 | return self.image_paths 142 | 143 | def update_learning_rate(self): 144 | """Update learning rates for all the networks; called at the end of every epoch""" 145 | old_lr = self.optimizers[0].param_groups[0]["lr"] 146 | for scheduler in self.schedulers: 147 | if self.opt.lr_policy == "plateau": 148 | scheduler.step(self.metric) 149 | else: 150 | scheduler.step() 151 | 152 | lr = self.optimizers[0].param_groups[0]["lr"] 153 | if old_lr != lr: 154 | print("learning rate %.7f -> %.7f" % (old_lr, lr)) 155 | 156 | def get_current_visuals(self): 157 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML""" 158 | visual_ret = OrderedDict() 159 | for name in self.visual_names: 160 | if isinstance(name, str): 161 | visual_ret[name] = getattr(self, name) 162 | return visual_ret 163 | 164 | def get_current_losses(self): 165 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file""" 166 | errors_ret = OrderedDict() 167 | for name in self.loss_names: 168 | if isinstance(name, str): 169 | errors_ret[name] = float( 170 | getattr(self, "loss_" + name) 171 | ) # float(...) works for both scalar tensor and float number 172 | return errors_ret 173 | 174 | def save_networks(self, epoch): 175 | """Save all the networks to the disk. 
176 | 177 | Parameters: 178 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 179 | """ 180 | for name in self.model_names: 181 | if isinstance(name, str): 182 | save_filename = "%s_net_%s.pth" % (epoch, name) 183 | save_path = os.path.join(self.save_dir, save_filename) 184 | net = getattr(self, "net" + name) 185 | 186 | if len(self.gpu_ids) > 0 and torch.cuda.is_available(): 187 | torch.save(net.module.cpu().state_dict(), save_path) 188 | net.cuda(self.gpu_ids[0]) 189 | else: 190 | torch.save(net.cpu().state_dict(), save_path) 191 | 192 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0): 193 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)""" 194 | key = keys[i] 195 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer 196 | if module.__class__.__name__.startswith("InstanceNorm") and ( 197 | key == "running_mean" or key == "running_var" 198 | ): 199 | if getattr(module, key) is None: 200 | state_dict.pop(".".join(keys)) 201 | if module.__class__.__name__.startswith("InstanceNorm") and ( 202 | key == "num_batches_tracked" 203 | ): 204 | state_dict.pop(".".join(keys)) 205 | else: 206 | self.__patch_instance_norm_state_dict( 207 | state_dict, getattr(module, key), keys, i + 1 208 | ) 209 | 210 | def load_networks(self, epoch): 211 | """Load all the networks from the disk. 212 | 213 | Parameters: 214 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name) 215 | """ 216 | for name in self.model_names: 217 | if isinstance(name, str): 218 | load_filename = "%s_net_%s.pth" % (epoch, name) 219 | load_path = os.path.join(self.save_dir, load_filename) 220 | net = getattr(self, "net" + name) 221 | if isinstance(net, torch.nn.DataParallel): 222 | net = net.module 223 | print("loading the model from %s" % load_path) 224 | # if you are using PyTorch newer than 0.4 (e.g., built from 225 | # GitHub source), you can remove str() on self.device 226 | state_dict = torch.load(load_path, map_location=str(self.device)) 227 | if hasattr(state_dict, "_metadata"): 228 | del state_dict._metadata 229 | 230 | # patch InstanceNorm checkpoints prior to 0.4 231 | for key in list( 232 | state_dict.keys() 233 | ): # need to copy keys here because we mutate in loop 234 | self.__patch_instance_norm_state_dict( 235 | state_dict, net, key.split(".") 236 | ) 237 | if net.state_dict()[key].shape != state_dict[key].shape: 238 | print(key, 'size mismatch! 
duplicating...') 239 | repeat_times = [2]+[1]*(len(net.state_dict()[key].shape)-1) 240 | state_dict[key] = state_dict[key].repeat(repeat_times) 241 | print(state_dict[key].size()) 242 | 243 | net.load_state_dict(state_dict) 244 | 245 | def print_networks(self, verbose): 246 | """Print the total number of parameters in the network and (if verbose) network architecture 247 | 248 | Parameters: 249 | verbose (bool) -- if verbose: print the network architecture 250 | """ 251 | print("---------- Networks initialized -------------") 252 | for name in self.model_names: 253 | if isinstance(name, str): 254 | net = getattr(self, "net" + name) 255 | num_params = 0 256 | num_trainable_params = 0 257 | for param in net.parameters(): 258 | num_params += param.numel() 259 | if param.requires_grad: 260 | num_trainable_params += param.numel() 261 | if verbose: 262 | print(net) 263 | print( 264 | "[Network %s] Total number of parameters : %.3f M" 265 | % (name, num_params / 1e6) 266 | ) 267 | print( 268 | "[Network %s] Total number of trainable parameters : %.3f M" 269 | % (name, num_trainable_params / 1e6) 270 | ) 271 | print("-----------------------------------------------") 272 | 273 | def set_requires_grad(self, nets, requires_grad=False): 274 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations 275 | Parameters: 276 | nets (network list) -- a list of networks 277 | requires_grad (bool) -- whether the networks require gradients or not 278 | """ 279 | if not isinstance(nets, list): 280 | nets = [nets] 281 | for net in nets: 282 | if net is not None: 283 | for param in net.parameters(): 284 | param.requires_grad = requires_grad 285 | -------------------------------------------------------------------------------- /third_party/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.optim import lr_scheduler 4 | 5 | 6 | ############################################################################### 7 | # Helper Functions 8 | ############################################################################### 9 | def get_scheduler(optimizer, opt): 10 | """Return a learning rate scheduler 11 | 12 | Parameters: 13 | optimizer -- the optimizer of the network 14 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  15 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 16 | 17 | For 'linear', we keep the same learning rate for the first epochs 18 | and linearly decay the rate to zero over the next epochs. 19 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 20 | See https://pytorch.org/docs/stable/optim.html for more details. 
21 | """ 22 | if opt.lr_policy == 'linear': 23 | def lambda_rule(epoch): 24 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 25 | return lr_l 26 | 27 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 28 | elif opt.lr_policy == 'step': 29 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1) 30 | elif opt.lr_policy == 'plateau': 31 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 32 | elif opt.lr_policy == 'cosine': 33 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 34 | else: 35 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 36 | return scheduler 37 | 38 | 39 | def init_net(net, gpu_ids=[]): 40 | """Initialize a network by registering CPU/GPU device (with multi-GPU support) 41 | Parameters: 42 | net (network) -- the network to be initialized 43 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 44 | 45 | Return an initialized network. 46 | """ 47 | if len(gpu_ids) > 0: 48 | assert (torch.cuda.is_available()) 49 | net.to(gpu_ids[0]) 50 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 51 | return net 52 | -------------------------------------------------------------------------------- /third_party/util/__init__.py: -------------------------------------------------------------------------------- 1 | """This package includes a miscellaneous collection of useful helper functions.""" 2 | -------------------------------------------------------------------------------- /third_party/util/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__init__.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/__init__.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/__init__.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/html.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/html.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/html.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/html.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/util.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/util.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/visualizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/visualizer.cpython-38.pyc -------------------------------------------------------------------------------- /third_party/util/__pycache__/visualizer.cpython-39.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/visualizer.cpython-39.pyc -------------------------------------------------------------------------------- /third_party/util/html.py: -------------------------------------------------------------------------------- 1 | import dominate 2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source 3 | import os 4 | 5 | 6 | class HTML: 7 | """This HTML class allows us to save images and write texts into a single HTML file. 8 | 9 | It consists of functions such as (add a text header to the HTML file), 10 | (add a row of images to the HTML file), and (save the HTML to the disk). 11 | It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API. 12 | """ 13 | 14 | def __init__(self, web_dir, title, refresh=0): 15 | """Initialize the HTML classes 16 | 17 | Parameters: 18 | web_dir (str) -- a directory that stores the webpage. 
HTML file will be created at /index.html; images will be saved at 0: 35 | with self.doc.head: 36 | meta(http_equiv="refresh", content=str(refresh)) 37 | 38 | def get_image_dir(self): 39 | """Return the directory that stores images""" 40 | return self.img_dir 41 | 42 | def get_video_dir(self): 43 | """Return the directory that stores videos""" 44 | return self.vid_dir 45 | 46 | def add_header(self, text): 47 | """Insert a header to the HTML file 48 | 49 | Parameters: 50 | text (str) -- the header text 51 | """ 52 | with self.doc: 53 | h3(text) 54 | 55 | def add_images(self, ims, txts, links, width=400): 56 | """add images to the HTML file 57 | 58 | Parameters: 59 | ims (str list) -- a list of image paths 60 | txts (str list) -- a list of image names shown on the website 61 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 62 | """ 63 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 64 | self.doc.add(self.t) 65 | with self.t: 66 | with tr(): 67 | for im, txt, link in zip(ims, txts, links): 68 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 69 | with p(): 70 | with a(href=os.path.join('images', link)): 71 | img(style="width:%dpx" % width, src=os.path.join('images', im)) 72 | br() 73 | p(txt) 74 | 75 | def add_videos(self, vids, txts, links, width=400): 76 | """add images to the HTML file 77 | 78 | Parameters: 79 | ims (str list) -- a list of image paths 80 | txts (str list) -- a list of image names shown on the website 81 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page 82 | """ 83 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table 84 | self.doc.add(self.t) 85 | with self.t: 86 | with tr(): 87 | for vid, txt, link in zip(vids, txts, links): 88 | with td(style="word-wrap: break-word;", halign="center", valign="top"): 89 | with p(): 90 | with a(href=os.path.join('videos', link)): 91 | with video(style="width:%dpx" % width, controls=True): 92 | source(src=os.path.join('videos', vid), type="video/mp4") 93 | br() 94 | p(txt) 95 | 96 | def save(self): 97 | """save the current content to the HMTL file""" 98 | html_file = '%s/index.html' % self.web_dir 99 | f = open(html_file, 'wt') 100 | f.write(self.doc.render()) 101 | f.close() 102 | 103 | 104 | if __name__ == '__main__': # we show an example usage here. 105 | html = HTML('web/', 'test_html') 106 | html.add_header('hello world') 107 | 108 | ims, txts, links = [], [], [] 109 | for n in range(4): 110 | ims.append('image_%d.png' % n) 111 | txts.append('text_%d' % n) 112 | links.append('image_%d.png' % n) 113 | html.add_images(ims, txts, links) 114 | html.save() 115 | -------------------------------------------------------------------------------- /third_party/util/util.py: -------------------------------------------------------------------------------- 1 | """This module contains simple helper functions """ 2 | from __future__ import print_function 3 | import torch 4 | import numpy as np 5 | from PIL import Image 6 | import os 7 | 8 | 9 | def tensor2im(input_image, imtype=np.uint8): 10 | """"Converts a Tensor array into a numpy image array. 
11 | 12 | Parameters: 13 | input_image (tensor) -- the input image tensor array 14 | imtype (type) -- the desired type of the converted numpy array 15 | """ 16 | if not isinstance(input_image, np.ndarray): 17 | if isinstance(input_image, torch.Tensor): # get the data from a variable 18 | image_tensor = input_image.data 19 | else: 20 | return input_image 21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array 22 | if image_numpy.shape[0] == 1: # grayscale to RGB 23 | image_numpy = np.tile(image_numpy, (3, 1, 1)) 24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: tranpose and scaling 25 | if image_numpy.shape[-1] == 4: 26 | image_numpy = render_png(image_numpy) 27 | else: # if it is a numpy array, do nothing 28 | image_numpy = input_image 29 | if image_numpy.shape[-1] == 4: 30 | image_numpy = render_png(image_numpy) 31 | return image_numpy.astype(imtype) 32 | 33 | 34 | def render_png(image, background='checker'): 35 | height, width = image.shape[:2] 36 | if background == 'checker': 37 | checkerboard = np.kron([[136, 120] * (width//128+1), [120, 136] * (width//128+1)] * (height//128+1), np.ones((16, 16))) 38 | checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1) 39 | bg = checkerboard[:height, :width] 40 | elif background == 'black': 41 | bg = np.zeros([height, width, 1]) 42 | else: 43 | bg = 255 * np.ones([height, width, 1]) 44 | image = image.astype(np.float32) 45 | alpha = image[:, :, 3:] / 255 46 | rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg 47 | return rendered_image.astype(np.uint8) 48 | 49 | 50 | def add_title(image, title_text): 51 | Print('please put a dir for a font here') 52 | font_dir = '' 53 | from PIL import Image, ImageFont, ImageDraw 54 | import matplotlib.font_manager as fm 55 | image = Image.fromarray(image) 56 | title_font = ImageFont.truetype(font_dir, 35) 57 | image_editable = ImageDraw.Draw(image) 58 | image_editable.text((10,10), title_text, (255,255,255), stroke_fill=(0,0,0), font=title_font, stroke_width=3) 59 | return np.asarray(image) 60 | 61 | 62 | def diagnose_network(net, name='network'): 63 | """Calculate and print the mean of average absolute(gradients) 64 | 65 | Parameters: 66 | net (torch network) -- Torch network 67 | name (str) -- the name of the network 68 | """ 69 | mean = 0.0 70 | count = 0 71 | for param in net.parameters(): 72 | if param.grad is not None: 73 | mean += torch.mean(torch.abs(param.grad.data)) 74 | count += 1 75 | if count > 0: 76 | mean = mean / count 77 | print(name) 78 | print(mean) 79 | 80 | 81 | def save_image(image_numpy, image_path, aspect_ratio=1.0): 82 | """Save a numpy image to the disk 83 | 84 | Parameters: 85 | image_numpy (numpy array) -- input numpy array 86 | image_path (str) -- the path of the image 87 | """ 88 | image_pil = Image.fromarray(image_numpy) 89 | h, w, _ = image_numpy.shape 90 | 91 | if aspect_ratio > 1.0: 92 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC) 93 | if aspect_ratio < 1.0: 94 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC) 95 | image_pil.save(image_path) 96 | 97 | 98 | def print_numpy(x, val=True, shp=False): 99 | """Print the mean, min, max, median, std, and size of a numpy array 100 | 101 | Parameters: 102 | val (bool) -- if print the values of the numpy array 103 | shp (bool) -- if print the shape of the numpy array 104 | """ 105 | x = x.astype(np.float64) 106 | if shp: 107 | print('shape,', x.shape) 108 | if val: 109 | x = x.flatten() 110 | 
98 | def print_numpy(x, val=True, shp=False):
99 |     """Print the mean, min, max, median, std, and size of a numpy array
100 | 
101 |     Parameters:
102 |         val (bool) -- if print the values of the numpy array
103 |         shp (bool) -- if print the shape of the numpy array
104 |     """
105 |     x = x.astype(np.float64)
106 |     if shp:
107 |         print('shape,', x.shape)
108 |     if val:
109 |         x = x.flatten()
110 |         print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
111 |             np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
112 | 
113 | 
114 | def mkdirs(paths):
115 |     """create empty directories if they don't exist
116 | 
117 |     Parameters:
118 |         paths (str list) -- a list of directory paths
119 |     """
120 |     if isinstance(paths, list) and not isinstance(paths, str):
121 |         for path in paths:
122 |             mkdir(path)
123 |     else:
124 |         mkdir(paths)
125 | 
126 | 
127 | def mkdir(path):
128 |     """create a single empty directory if it didn't exist
129 | 
130 |     Parameters:
131 |         path (str) -- a single directory path
132 |     """
133 |     if not os.path.exists(path):
134 |         os.makedirs(path)
135 | 
--------------------------------------------------------------------------------
/third_party/util/util.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/util.pyc
--------------------------------------------------------------------------------
/train_GAN.py:
--------------------------------------------------------------------------------
1 | """Script for training an Omnimatte model on a video.
2 | 
3 | You need to specify the dataset ('--dataroot') and experiment name ('--name').
4 | 
5 | Example:
6 |     python train_GAN.py --dataroot ./datasets/tennis --name tennis --gpu_ids 0,1
7 | 
8 | The script first creates a model, dataset, and visualizer given the options.
9 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss
10 | plot, and saves the model.
11 | Use '--continue_train' to resume your previous training.
12 | 
13 | See options/base_options.py and options/train_options.py for more training options.
14 | """
15 | import time
16 | from options.train_options import TrainOptions
17 | from third_party.data import create_dataset
18 | from third_party.models import create_model
19 | from third_party.util.visualizer import Visualizer
20 | import torch
21 | import numpy as np
22 | import random
23 | import os
24 | 
25 | 
26 | def main():
27 |     trainopt = TrainOptions()
28 |     opt = trainopt.parse()
29 | 
30 |     torch.manual_seed(opt.seed)
31 |     np.random.seed(opt.seed)
32 |     random.seed(opt.seed)
33 | 
34 |     dataset = create_dataset(opt)
35 |     dataset_size = len(dataset)
36 |     print("The number of training images = %d" % dataset_size)
37 |     if opt.n_epochs is None:
38 |         assert opt.n_steps, "You must specify one of n_epochs or n_steps."
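        # Illustrative arithmetic (numbers are made up, not repository defaults): with
        # n_steps = 30000 and a dataset of 100 frames, the conversion below gives
        # n_epochs = int(30000 / 100) = 300; with n_steps_decay = 10000 and batch_size = 4
        # it gives n_epochs_decay = int(10000 / ceil(100 / 4)) = 400.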
39 | opt.n_epochs = int( 40 | opt.n_steps / np.ceil(dataset_size) 41 | ) # / opt.batch_size divide by bs seems wierd 42 | opt.n_epochs_decay = int(opt.n_steps_decay / np.ceil(dataset_size / opt.batch_size)) 43 | total_iters = 0 44 | model = create_model(opt) 45 | model.setup(opt) # regular setup: load and print networks; create schedulers 46 | if opt.continue_train: 47 | opt.epoch_count = int(opt.epoch) + 1 48 | if opt.overwrite_lambdas: 49 | # Setting parameters here will overwrite the previous code 50 | history = torch.load( 51 | os.path.join(model.save_dir, opt.epoch + "_others.pth"), map_location='cuda:0' 52 | ) 53 | for name in model.lambda_names: 54 | if isinstance(name, str): 55 | setattr(model, "lambda_" + name, history["lambda_" + name]) 56 | print( 57 | "lambdas overwritten args", 58 | "lambda_" + name, 59 | getattr(model, "lambda_" + name, None), 60 | ) 61 | total_iters = history["total_iters"] 62 | model.jitter_rgb = history["jitter_rgb"] 63 | model.do_cam_adj = history["do_cam_adj"] 64 | # Assume when continue by loading, there're already plenty of epochs passed 65 | # such that mask loss is no longer needed (set to 0) 66 | # model.mask_loss_rolloff_epoch = 0 67 | model.mask_loss_rolloff_epoch = history["mask_loss_rolloff_epoch"] 68 | print( 69 | "other params overwritten args", 70 | model.jitter_rgb, 71 | model.do_cam_adj, 72 | total_iters, 73 | opt.epoch_count, 74 | model.mask_loss_rolloff_epoch, 75 | ) 76 | 77 | for i in range(len(model.discriminators)): 78 | if (model.discriminators[i] is not None) and ( 79 | "discriminator_l" + str(i) in history 80 | ): 81 | model.discriminators[i].load_state_dict( 82 | history["discriminator_l" + str(i)], strict=False 83 | ) 84 | print(i, "th discriminator weights loaded unstrictly") 85 | print( 86 | "the dict in the history is", 87 | history["discriminator_l" + str(i)].keys(), 88 | ) 89 | print( 90 | "the dict in current model is", 91 | model.discriminators[i].state_dict().keys(), 92 | ) 93 | model.discriminators[i].train() 94 | 95 | if opt.overwrite_lrs: 96 | print("lr overwritten args", history["lrs"]) 97 | for i in range(len(model.optimizers)): 98 | optimizer = model.optimizers[i] 99 | for g in optimizer.param_groups: 100 | g["lr"] = history["lrs"][i] 101 | 102 | visualizer = Visualizer(opt) 103 | train(model, dataset, visualizer, opt, total_iters) 104 | 105 | 106 | def train(model, dataset, visualizer, opt, total_iters): 107 | dataset_size = len(dataset) 108 | for epoch in range( 109 | opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1 110 | ): # outer loop for different epochs; we save the model by , + 111 | epoch_start_time = time.time() # timer for entire epoch 112 | iter_data_time = time.time() # timer for data loading per iteration 113 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 114 | model.update_lambdas(epoch) 115 | if epoch == opt.epoch_count: 116 | save_result = True 117 | dp = dataset.dataset[opt.display_ind] 118 | for k, v in dp.items(): 119 | if torch.is_tensor(v): 120 | dp[k] = v.unsqueeze(0) 121 | else: 122 | dp[k] = [v] 123 | model.set_input(dp) 124 | model.compute_visuals(total_iters) 125 | visualizer.display_current_results( 126 | model.get_current_visuals(), 0, save_result 127 | ) 128 | 129 | for i, data in enumerate(dataset): # inner loop within one epoch 130 | iter_start_time = time.time() # timer for computation per iteration 131 | if i % opt.print_freq == 0: 132 | t_data = iter_start_time - iter_data_time 133 | # #iters are not exact because the last batch 
might not suffice. 134 | total_iters += opt.batch_size 135 | epoch_iter += opt.batch_size 136 | model.set_input(data) 137 | model.optimize_parameters(total_iters, epoch) 138 | 139 | if ( 140 | i % opt.print_freq == 0 141 | ): # print training losses and save logging information to the disk 142 | print(opt.name) 143 | losses = model.get_current_losses() 144 | t_comp = (time.time() - iter_start_time) / opt.batch_size 145 | visualizer.print_current_losses( 146 | epoch, epoch_iter, losses, t_comp, t_data 147 | ) 148 | if opt.display_id > 0: 149 | visualizer.plot_current_losses( 150 | epoch, float(epoch_iter) / dataset_size, losses 151 | ) 152 | iter_data_time = time.time() 153 | 154 | if ( 155 | epoch % opt.display_freq == 1 156 | ): # display images on visdom and save images to a HTML file 157 | save_result = epoch % opt.update_html_freq == 1 158 | dp = dataset.dataset[opt.display_ind] 159 | for k, v in dp.items(): 160 | if torch.is_tensor(v): 161 | dp[k] = v.unsqueeze(0) 162 | else: 163 | dp[k] = [v] 164 | model.set_input(dp) 165 | model.compute_visuals(total_iters) 166 | visualizer.display_current_results( 167 | model.get_current_visuals(), epoch, save_result 168 | ) 169 | 170 | if ( 171 | epoch % opt.save_latest_freq == 0 or epoch == opt.epoch_count 172 | ): # opt.n_epochs + opt.n_epochs_decay: # cache our latest model every epochs 173 | print( 174 | "saving the latest model (epoch %d, total_iters %d)" 175 | % (epoch, total_iters) 176 | ) 177 | save_suffix = "epoch_%d" % epoch if opt.save_by_epoch else "latest" 178 | model.save_networks(save_suffix) 179 | others = { 180 | "lrs": [i.param_groups[0]["lr"] for i in model.optimizers], 181 | "jitter_rgb": model.jitter_rgb, 182 | "do_cam_adj": model.do_cam_adj, 183 | "total_iters": total_iters, 184 | } 185 | for i in range(len(model.discriminators)): 186 | if model.discriminators[i] is not None: 187 | others["discriminator_l" + str(i)] = model.discriminators[ 188 | i 189 | ].state_dict() 190 | for name in model.lambda_names: 191 | if isinstance(name, str): 192 | others["lambda_" + name] = float(getattr(model, "lambda_" + name)) 193 | others["lambda_Ds"] = torch.tensor(model.lambda_Ds) 194 | others["lambda_plausibles"] = torch.tensor(model.lambda_plausibles) 195 | others["mask_loss_rolloff_epoch"] = model.mask_loss_rolloff_epoch 196 | torch.save( 197 | others, 198 | os.path.join(opt.checkpoints_dir, opt.name, str(epoch) + "_others.pth"), 199 | ) 200 | 201 | if ((epoch == 1) or (epoch % opt.update_D_epochs == 0)) and ( 202 | model.optimizer_D is not None 203 | ): 204 | model.update_learning_rate([1]) 205 | model.update_learning_rate( 206 | [0] 207 | ) # update learning rates at the end of every epoch. 
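        # Note on resuming (the epoch number below is hypothetical): the "..._others.pth"
        # dict written above (e.g. 60_others.pth) pairs with the networks saved by
        # model.save_networks(), and main() reads it back when training is restarted with
        # --continue_train --epoch 60 --overwrite_lambdas --overwrite_lrs.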
208 | print( 209 | "End of epoch %d / %d \t Time Taken: %d sec" 210 | % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time) 211 | ) 212 | 213 | 214 | def see_grad(model, dataset, visualizer, opt): 215 | total_iters = 0 # the total number of training iterations 216 | for f in os.listdir(opt.ckpt_dir): 217 | if "net_Omnimatte.pth" in f: 218 | weight = torch.load(os.path.join(opt.ckpt_dir, f)) 219 | model.netOmnimatte.load_state_dict(weight) 220 | for epoch in range( 221 | 1 222 | ): # outer loop for different epochs; we save the model by , + 223 | epoch_start_time = time.time() # timer for entire epoch 224 | iter_data_time = time.time() # timer for data loading per iteration 225 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch 226 | model.update_lambdas(epoch) 227 | for i, data in enumerate(dataset): # inner loop within one epoch 228 | if i == 0: 229 | iter_start_time = time.time() # timer for computation per iteration 230 | if i % opt.print_freq == 0: 231 | t_data = iter_start_time - iter_data_time 232 | 233 | total_iters += opt.batch_size 234 | epoch_iter += opt.batch_size 235 | model.set_input(data) 236 | model.optimize_parameters(total_iters) 237 | else: 238 | break 239 | 240 | 241 | if __name__ == "__main__": 242 | main() 243 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from PIL import Image 5 | 6 | 7 | def numpy2im(np_array): 8 | """convert numpy float image to PIL Image""" 9 | return Image.fromarray((np_array * 255).astype(np.uint8)) 10 | 11 | 12 | def readFlow(fn): 13 | """Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, "rb") as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print("Magic number incorrect. Invalid .flo file") 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | 34 | def resize_flow(flow, width, height): 35 | orig_h, orig_w = flow.shape[1:] 36 | flow = F.interpolate(flow.unsqueeze(0), (height, width), mode="bilinear").squeeze(0) 37 | flow[0] *= width / orig_w 38 | flow[1] *= height / orig_h 39 | return flow 40 | 41 | 42 | def tensor_flow_to_image( 43 | flow_uv, clip_flow=None, convert_to_bgr=False, global_max=None 44 | ): 45 | flow_np = flow_uv.permute(1, 2, 0).cpu().numpy() 46 | image = flow_to_image(flow_np, clip_flow, convert_to_bgr, global_max) 47 | image = torch.from_numpy(image).permute(2, 0, 1) 48 | return image.float() / 255.0 * 2 - 1 49 | 50 | 51 | # The following flow visualization code is from https://github.com/tomrunia/OpticalFlow_Visualization 52 | def make_colorwheel(): 53 | """ 54 | Generates a color wheel for optical flow visualization as presented in: 55 | Baker et al. 
"A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 56 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 57 | Code follows the original C++ source code of Daniel Scharstein. 58 | Code follows the the Matlab source code of Deqing Sun. 59 | Returns: 60 | np.ndarray: Color wheel 61 | """ 62 | 63 | RY = 15 64 | YG = 6 65 | GC = 4 66 | CB = 11 67 | BM = 13 68 | MR = 6 69 | 70 | ncols = RY + YG + GC + CB + BM + MR 71 | colorwheel = np.zeros((ncols, 3)) 72 | col = 0 73 | 74 | # RY 75 | colorwheel[0:RY, 0] = 255 76 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 77 | col = col + RY 78 | # YG 79 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 80 | colorwheel[col : col + YG, 1] = 255 81 | col = col + YG 82 | # GC 83 | colorwheel[col : col + GC, 1] = 255 84 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 85 | col = col + GC 86 | # CB 87 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 88 | colorwheel[col : col + CB, 2] = 255 89 | col = col + CB 90 | # BM 91 | colorwheel[col : col + BM, 2] = 255 92 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 93 | col = col + BM 94 | # MR 95 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 96 | colorwheel[col : col + MR, 0] = 255 97 | return colorwheel 98 | 99 | 100 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 101 | """ 102 | Applies the flow color wheel to (possibly clipped) flow components u and v. 103 | According to the C++ source code of Daniel Scharstein 104 | According to the Matlab source code of Deqing Sun 105 | Args: 106 | u (np.ndarray): Input horizontal flow of shape [H,W] 107 | v (np.ndarray): Input vertical flow of shape [H,W] 108 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 109 | Returns: 110 | np.ndarray: Flow visualization image of shape [H,W,3] 111 | """ 112 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 113 | colorwheel = make_colorwheel() # shape [55x3] 114 | ncols = colorwheel.shape[0] 115 | rad = np.sqrt(np.square(u) + np.square(v)) 116 | a = np.arctan2(-v, -u) / np.pi 117 | fk = (a + 1) / 2 * (ncols - 1) 118 | k0 = np.floor(fk).astype(np.int32) 119 | k1 = k0 + 1 120 | k1[k1 == ncols] = 0 121 | f = fk - k0 122 | for i in range(colorwheel.shape[1]): 123 | tmp = colorwheel[:, i] 124 | col0 = tmp[k0] / 255.0 125 | col1 = tmp[k1] / 255.0 126 | col = (1 - f) * col0 + f * col1 127 | idx = rad <= 1 128 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 129 | col[~idx] = col[~idx] * 0.75 # out of range 130 | # Note the 2-i => BGR instead of RGB 131 | ch_idx = 2 - i if convert_to_bgr else i 132 | flow_image[:, :, ch_idx] = np.floor(255 * col) 133 | return flow_image 134 | 135 | 136 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, global_max=None): 137 | """ 138 | Expects a two dimensional flow image of shape. 139 | Args: 140 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 141 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 142 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
143 | Returns: 144 | np.ndarray: Flow visualization image of shape [H,W,3] 145 | """ 146 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 147 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 148 | if clip_flow is not None: 149 | flow_uv = np.clip(flow_uv, 0, clip_flow) 150 | u = flow_uv[:, :, 0] 151 | v = flow_uv[:, :, 1] 152 | rad = np.sqrt(np.square(u) + np.square(v)) 153 | rad_max = global_max if global_max else np.max(rad) 154 | # import pdb 155 | 156 | # pdb.set_trace() 157 | epsilon = 1e-5 158 | u = u / (rad_max + epsilon) 159 | v = v / (rad_max + epsilon) 160 | return flow_uv_to_colors(u, v, convert_to_bgr) 161 | -------------------------------------------------------------------------------- /video_completion.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..'))) 4 | 5 | import argparse 6 | import os 7 | import numpy as np 8 | import torch 9 | from PIL import Image 10 | import glob 11 | import torchvision.transforms.functional as F 12 | 13 | from RAFT import utils 14 | from RAFT import RAFT 15 | 16 | 17 | def create_dir(dir): 18 | """Creates a directory if not exist. 19 | """ 20 | if not os.path.exists(dir): 21 | os.makedirs(dir) 22 | 23 | 24 | def initialize_RAFT(args): 25 | """Initializes the RAFT model. 26 | """ 27 | model = torch.nn.DataParallel(RAFT(args)) 28 | model.load_state_dict(torch.load(args.model)) 29 | 30 | model = model.module 31 | model.to('cuda') 32 | model.eval() 33 | 34 | return model 35 | 36 | 37 | def calculate_flow(args, model, video, mode): 38 | """Calculates optical flow. 39 | """ 40 | if mode not in ['forward', 'backward']: 41 | raise NotImplementedError 42 | 43 | nFrame, _, imgH, imgW = video.shape 44 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) 45 | 46 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')): 47 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))): 48 | # print("Loading {0}".format(flow_name), '\r', end='') 49 | # flow = utils.frame_utils.readFlow(flow_name) 50 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1) 51 | # return Flow 52 | flow_folder = 'flow' + args.path.replace("/","") 53 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_flo')) 54 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_png')) 55 | 56 | with torch.no_grad(): 57 | for i in range(video.shape[0] - 1): 58 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='') 59 | if mode == 'forward': 60 | # Flow i -> i + 1 61 | image1 = video[i, None] 62 | image2 = video[i + 1, None] 63 | elif mode == 'backward': 64 | # Flow i + 1 -> i 65 | image1 = video[i + 1, None] 66 | image2 = video[i, None] 67 | else: 68 | raise NotImplementedError 69 | 70 | _, flow = model(image1, image2, iters=20, test_mode=True) 71 | flow = flow[0].permute(1, 2, 0).cpu().numpy() 72 | Flow = np.concatenate((Flow, flow[..., None]), axis=-1) 73 | 74 | # Flow visualization. 75 | flow_img = utils.flow_viz.flow_to_image(flow) 76 | flow_img = Image.fromarray(flow_img) 77 | 78 | # Saves the flow and flow_img. 
79 | flow_img.save(os.path.join(args.outroot, flow_folder, mode + '_png', '%05d.png'%i)) 80 | # np.save(os.path.join(args.outroot, 'flow', mode + '_flo', '%05d.npy'%i), flow) 81 | utils.frame_utils.writeFlow(os.path.join(args.outroot, flow_folder, mode + '_flo', '%05d.flo'%i), flow) 82 | 83 | return Flow 84 | 85 | 86 | def calculate_flow_global(args, model, video, mode, step=1): 87 | """Calculates optical flow. 88 | """ 89 | if mode not in ['forward', 'backward']: 90 | raise NotImplementedError 91 | 92 | nFrame, _, imgH, imgW = video.shape 93 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32) 94 | 95 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')): 96 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))): 97 | # print("Loading {0}".format(flow_name), '\r', end='') 98 | # flow = utils.frame_utils.readFlow(flow_name) 99 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1) 100 | # return Flow 101 | flow_folder = args.path.replace("/","") 102 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_flow_step' + str(step))) 103 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_png_step' + str(step))) 104 | global_max = -10000000 105 | with torch.no_grad(): 106 | # for i in range(10): 107 | for i in range(video.shape[0] - step): 108 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + step), '\r', end='') 109 | if mode == 'forward': 110 | # Flow i -> i + 1 111 | image1 = video[i, None] 112 | image2 = video[i + step, None] 113 | elif mode == 'backward': 114 | # Flow i + 1 -> i 115 | image1 = video[i + step, None] 116 | image2 = video[i, None] 117 | else: 118 | raise NotImplementedError 119 | 120 | _, flow = model(image1, image2, iters = 20, test_mode = True) 121 | flow_max = torch.sqrt(flow[0,0,:,:] ** 2 + flow[0, 1, :, :] ** 2).max() 122 | global_max = max(global_max, flow_max.cpu().numpy()) 123 | print(global_max) 124 | flow = flow[0].permute(1, 2, 0).cpu().numpy() 125 | Flow = np.concatenate((Flow, flow[..., None]), axis = -1) 126 | 127 | for j in range(Flow.shape[-1]): 128 | flow=Flow[:,:,:,j] 129 | print(j) 130 | # Flow visualization. 131 | flow_img = utils.flow_viz.flow_to_image(flow, rad_max=global_max) 132 | flow_img = Image.fromarray(flow_img) 133 | 134 | # Saves the flow and flow_img. 135 | flow_img.save(os.path.join(args.outroot, flow_folder, mode + '_png_step' + str(step), '%05d.png'%j)) 136 | utils.frame_utils.writeFlow(os.path.join(args.outroot, flow_folder, mode + '_flow_step' + str(step), '%05d.flo'%j), flow) 137 | 138 | return Flow 139 | 140 | def video_completion(args): 141 | 142 | # Flow model. 143 | RAFT_model = initialize_RAFT(args) 144 | 145 | # Loads frames. 146 | filename_list = glob.glob(os.path.join(args.path, '*.png')) 147 | # glob.glob(os.path.join(args.path, '*.jpg')) 148 | 149 | # Obtains imgH, imgW and nFrame. 150 | imgH, imgW = np.array(Image.open(filename_list[0]).convert('RGB')).shape[:2] 151 | nFrame = len(filename_list) 152 | 153 | # Loads video. 154 | video = [] 155 | for filename in sorted(filename_list): 156 | video.append(torch.from_numpy(np.array(Image.open(filename).convert('RGB')).astype(np.uint8)).permute(2, 0, 1).float()) 157 | 158 | video = torch.stack(video, dim=0) 159 | video = video.to('cuda') 160 | 161 | # Calcutes the corrupted flow. 
162 | print('STEP', str(args.step)) 163 | corrFlowF = calculate_flow_global(args, RAFT_model, video, 'forward', step=args.step) #_interval 164 | corrFlowB = calculate_flow_global(args, RAFT_model, video, 'backward', step=args.step) #_interval 165 | print('\nFinish flow prediction.') 166 | 167 | 168 | 169 | 170 | if __name__ == '__main__': 171 | parser = argparse.ArgumentParser() 172 | # video completion 173 | parser.add_argument('--seamless', action='store_true', help='Whether operate in the gradient domain') 174 | parser.add_argument('--edge_guide', action='store_true', help='Whether use edge as guidance to complete flow') 175 | parser.add_argument('--mode', default='object_removal', help="modes: object_removal / video_extrapolation") 176 | parser.add_argument('--path', default='../data/tennis', help="dataset for evaluation") 177 | parser.add_argument('--outroot', default='RAFT_result/', help="output directory") 178 | parser.add_argument('--consistencyThres', dest='consistencyThres', default=np.inf, type=float, help='flow consistency error threshold') 179 | parser.add_argument('--alpha', dest='alpha', default=0.1, type=float) 180 | parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool) 181 | parser.add_argument('--step', default=1, type=int) 182 | 183 | # RAFT 184 | parser.add_argument('--model', default='weight/raft-things.pth', help="restore checkpoint") 185 | parser.add_argument('--small', action='store_true', help='use small model') 186 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 187 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 188 | 189 | args = parser.parse_args() 190 | 191 | video_completion(args) 192 | -------------------------------------------------------------------------------- /weight/edge_completion.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/edge_completion.pth -------------------------------------------------------------------------------- /weight/imagenet_deepfill.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/imagenet_deepfill.pth -------------------------------------------------------------------------------- /weight/raft-things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/raft-things.pth --------------------------------------------------------------------------------
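The raft-things.pth checkpoint above is the only weight that video_completion.py needs for flow precomputation. A minimal end-to-end sketch, run from the repository root; the frame folder, its PNG contents, and the frame index below are assumptions, not repository defaults:

    python video_completion.py --path datasets/tennis/rgb --outroot RAFT_result/ --model weight/raft-things.pth --step 1

    # The saved .flo files can then be reloaded and visualized with the root-level utils.py
    # helpers; calculate_flow_global names its output folder by stripping '/' from --path.
    from utils import readFlow, flow_to_image
    flow = readFlow('RAFT_result/datasetstennisrgb/forward_flow_step1/00000.flo')
    vis = flow_to_image(flow)   # H x W x 3 uint8 color coding of the flow field

Note that initialize_RAFT moves the model to 'cuda', so this step requires a GPU.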