├── RAFT
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-36.pyc
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── corr.cpython-36.pyc
│   │   ├── corr.cpython-37.pyc
│   │   ├── corr.cpython-38.pyc
│   │   ├── corr.cpython-39.pyc
│   │   ├── extractor.cpython-36.pyc
│   │   ├── extractor.cpython-37.pyc
│   │   ├── extractor.cpython-38.pyc
│   │   ├── extractor.cpython-39.pyc
│   │   ├── raft.cpython-36.pyc
│   │   ├── raft.cpython-37.pyc
│   │   ├── raft.cpython-38.pyc
│   │   ├── raft.cpython-39.pyc
│   │   ├── update.cpython-36.pyc
│   │   ├── update.cpython-37.pyc
│   │   ├── update.cpython-38.pyc
│   │   └── update.cpython-39.pyc
│   ├── corr.py
│   ├── datasets.py
│   ├── demo.py
│   ├── extractor.py
│   ├── raft.py
│   ├── update.py
│   └── utils
│       ├── __init__.py
│       ├── __pycache__
│       │   ├── __init__.cpython-36.pyc
│       │   ├── __init__.cpython-37.pyc
│       │   ├── __init__.cpython-38.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── flow_viz.cpython-36.pyc
│       │   ├── flow_viz.cpython-37.pyc
│       │   ├── flow_viz.cpython-38.pyc
│       │   ├── flow_viz.cpython-39.pyc
│       │   ├── frame_utils.cpython-36.pyc
│       │   ├── frame_utils.cpython-37.pyc
│       │   ├── frame_utils.cpython-38.pyc
│       │   ├── frame_utils.cpython-39.pyc
│       │   ├── utils.cpython-36.pyc
│       │   ├── utils.cpython-37.pyc
│       │   ├── utils.cpython-38.pyc
│       │   └── utils.cpython-39.pyc
│       ├── augmentor.py
│       ├── flow_viz.py
│       ├── frame_utils.py
│       └── utils.py
├── README.md
├── causal
│   ├── discriminator.py
│   └── gen.py
├── data
│   ├── factormatte_GANCGANFlip148_dataset.py
│   ├── gen_backgroundPosEx-Copy1.ipynb
│   ├── gen_backgroundPosEx.ipynb
│   ├── gen_foregroundPosEx.py
│   ├── homographies.txt
│   ├── keypoint_homo_short.ipynb
│   ├── misc_data_process.py
│   └── noninteraction_ind.txt
├── datasets
│   ├── confidence.py
│   └── homography.py
├── models
│   ├── factormatte_GANFlip_model.py
│   └── networks.py
├── options
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   ├── __init__.cpython-39.pyc
│   │   ├── base_options.cpython-38.pyc
│   │   ├── base_options.cpython-39.pyc
│   │   ├── test_options.cpython-39.pyc
│   │   ├── train_options.cpython-38.pyc
│   │   └── train_options.cpython-39.pyc
│   ├── base_options.py
│   ├── test_options.py
│   └── train_options.py
├── prepare_data_stage1.sh
├── requirements.txt
├── test.py
├── third_party
│   ├── __init__.py
│   ├── __init__.pyc
│   ├── __pycache__
│   │   ├── __init__.cpython-38.pyc
│   │   └── __init__.cpython-39.pyc
│   ├── data
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── base_dataset.cpython-38.pyc
│   │   │   ├── base_dataset.cpython-39.pyc
│   │   │   ├── image_folder.cpython-38.pyc
│   │   │   └── image_folder.cpython-39.pyc
│   │   ├── base_dataset.py
│   │   └── image_folder.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── __init__.cpython-39.pyc
│   │   │   ├── base_model.cpython-38.pyc
│   │   │   ├── base_model.cpython-39.pyc
│   │   │   ├── networks.cpython-38.pyc
│   │   │   ├── networks.cpython-39.pyc
│   │   │   ├── networks_lnr.cpython-38.pyc
│   │   │   └── networks_lnr.cpython-39.pyc
│   │   ├── base_model.py
│   │   ├── networks.py
│   │   └── networks_lnr.py
│   └── util
│       ├── __init__.py
│       ├── __init__.pyc
│       ├── __pycache__
│       │   ├── __init__.cpython-38.pyc
│       │   ├── __init__.cpython-39.pyc
│       │   ├── html.cpython-38.pyc
│       │   ├── html.cpython-39.pyc
│       │   ├── util.cpython-38.pyc
│       │   ├── util.cpython-39.pyc
│       │   ├── visualizer.cpython-38.pyc
│       │   └── visualizer.cpython-39.pyc
│       ├── html.py
│       ├── util.py
│       ├── util.pyc
│       └── visualizer.py
├── train_GAN.py
├── utils.py
├── video_completion.py
└── weight
    ├── edge_completion.pth
    ├── imagenet_deepfill.pth
    └── raft-things.pth
/RAFT/__init__.py:
--------------------------------------------------------------------------------
1 | # from .demo import RAFT_infer
2 | from .raft import RAFT
3 |
--------------------------------------------------------------------------------
/RAFT/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/corr.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/corr.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/corr.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/corr.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/corr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/corr.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/corr.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/corr.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/extractor.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/extractor.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/extractor.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/extractor.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/extractor.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/extractor.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/extractor.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/extractor.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/raft.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/raft.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/raft.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/raft.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/raft.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/raft.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/raft.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/raft.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/update.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/update.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/update.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/update.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/update.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/update.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/__pycache__/update.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/__pycache__/update.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/corr.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from .utils.utils import bilinear_sampler, coords_grid
4 |
5 | try:
6 | import alt_cuda_corr
7 | except:
8 | # alt_cuda_corr is not compiled
9 | pass
10 |
11 |
12 | class CorrBlock:
13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
14 | self.num_levels = num_levels
15 | self.radius = radius
16 | self.corr_pyramid = []
17 |
18 | # all pairs correlation
19 | corr = CorrBlock.corr(fmap1, fmap2)
20 |
21 | batch, h1, w1, dim, h2, w2 = corr.shape
22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2)
23 |
24 | self.corr_pyramid.append(corr)
25 | for i in range(self.num_levels-1):
26 | corr = F.avg_pool2d(corr, 2, stride=2)
27 | self.corr_pyramid.append(corr)
28 |
29 | def __call__(self, coords):
30 | r = self.radius
31 | coords = coords.permute(0, 2, 3, 1)
32 | batch, h1, w1, _ = coords.shape
33 |
34 | out_pyramid = []
35 | for i in range(self.num_levels):
36 | corr = self.corr_pyramid[i]
37 | dx = torch.linspace(-r, r, 2*r+1)
38 | dy = torch.linspace(-r, r, 2*r+1)
39 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
40 |
41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
43 | coords_lvl = centroid_lvl + delta_lvl
44 |
45 | corr = bilinear_sampler(corr, coords_lvl)
46 | corr = corr.view(batch, h1, w1, -1)
47 | out_pyramid.append(corr)
48 |
49 | out = torch.cat(out_pyramid, dim=-1)
50 | return out.permute(0, 3, 1, 2).contiguous().float()
51 |
52 | @staticmethod
53 | def corr(fmap1, fmap2):
54 | batch, dim, ht, wd = fmap1.shape
55 | fmap1 = fmap1.view(batch, dim, ht*wd)
56 | fmap2 = fmap2.view(batch, dim, ht*wd)
57 |
58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2)
59 | corr = corr.view(batch, ht, wd, 1, ht, wd)
60 | return corr / torch.sqrt(torch.tensor(dim).float())
61 |
62 |
63 | class CorrLayer(torch.autograd.Function):
64 | @staticmethod
65 | def forward(ctx, fmap1, fmap2, coords, r):
66 | fmap1 = fmap1.contiguous()
67 | fmap2 = fmap2.contiguous()
68 | coords = coords.contiguous()
69 | ctx.save_for_backward(fmap1, fmap2, coords)
70 | ctx.r = r
71 | corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
72 | return corr
73 |
74 | @staticmethod
75 | def backward(ctx, grad_corr):
76 | fmap1, fmap2, coords = ctx.saved_tensors
77 | grad_corr = grad_corr.contiguous()
78 | fmap1_grad, fmap2_grad, coords_grad = \
79 | correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
80 | return fmap1_grad, fmap2_grad, coords_grad, None
81 |
82 |
83 | class AlternateCorrBlock:
84 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
85 | self.num_levels = num_levels
86 | self.radius = radius
87 |
88 | self.pyramid = [(fmap1, fmap2)]
89 | for i in range(self.num_levels):
90 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
91 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
92 | self.pyramid.append((fmap1, fmap2))
93 |
94 | def __call__(self, coords):
95 |
96 | coords = coords.permute(0, 2, 3, 1)
97 | B, H, W, _ = coords.shape
98 |
99 | corr_list = []
100 | for i in range(self.num_levels):
101 | r = self.radius
102 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
103 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
104 |
105 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
106 | corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r)
107 | corr_list.append(corr.squeeze(1))
108 |
109 | corr = torch.stack(corr_list, dim=1)
110 | corr = corr.reshape(B, -1, H, W)
111 | return corr / 16.0
112 |
--------------------------------------------------------------------------------
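
A minimal usage sketch for CorrBlock above (editorial illustration, not a file in the repository; it assumes the repository root is on PYTHONPATH so that RAFT.corr is importable). It builds the 4-level all-pairs correlation pyramid from two 1/8-resolution feature maps and indexes it with a pixel coordinate grid, which is how raft.py uses it:

    import torch
    from RAFT.corr import CorrBlock

    # Feature maps at 1/8 input resolution, as produced by the encoder in extractor.py.
    B, D, H, W = 1, 256, 46, 62
    fmap1 = torch.randn(B, D, H, W)
    fmap2 = torch.randn(B, D, H, W)

    corr_fn = CorrBlock(fmap1, fmap2, num_levels=4, radius=4)

    # coords holds, for every pixel of frame 1, its current match location in frame 2;
    # initializing it to the identity grid corresponds to zero flow.
    ys, xs = torch.meshgrid(torch.arange(H).float(), torch.arange(W).float())
    coords = torch.stack([xs, ys], dim=0)[None]   # (B, 2, H, W), channel order (x, y)

    corr = corr_fn(coords)                        # (B, num_levels * (2*radius+1)**2, H, W)
    print(corr.shape)                             # torch.Size([1, 324, 46, 62])
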
/RAFT/datasets.py:
--------------------------------------------------------------------------------
1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch
2 |
3 | import numpy as np
4 | import torch
5 | import torch.utils.data as data
6 | import torch.nn.functional as F
7 |
8 | import os
9 | import math
10 | import random
11 | from glob import glob
12 | import os.path as osp
13 |
14 | from utils import frame_utils
15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor
16 |
17 |
18 | class FlowDataset(data.Dataset):
19 | def __init__(self, aug_params=None, sparse=False):
20 | self.augmentor = None
21 | self.sparse = sparse
22 | if aug_params is not None:
23 | if sparse:
24 | self.augmentor = SparseFlowAugmentor(**aug_params)
25 | else:
26 | self.augmentor = FlowAugmentor(**aug_params)
27 |
28 | self.is_test = False
29 | self.init_seed = False
30 | self.flow_list = []
31 | self.image_list = []
32 | self.extra_info = []
33 |
34 | def __getitem__(self, index):
35 |
36 | if self.is_test:
37 | img1 = frame_utils.read_gen(self.image_list[index][0])
38 | img2 = frame_utils.read_gen(self.image_list[index][1])
39 | img1 = np.array(img1).astype(np.uint8)[..., :3]
40 | img2 = np.array(img2).astype(np.uint8)[..., :3]
41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
43 | return img1, img2, self.extra_info[index]
44 |
45 | if not self.init_seed:
46 | worker_info = torch.utils.data.get_worker_info()
47 | if worker_info is not None:
48 | torch.manual_seed(worker_info.id)
49 | np.random.seed(worker_info.id)
50 | random.seed(worker_info.id)
51 | self.init_seed = True
52 |
53 | index = index % len(self.image_list)
54 | valid = None
55 | if self.sparse:
56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
57 | else:
58 | flow = frame_utils.read_gen(self.flow_list[index])
59 |
60 | img1 = frame_utils.read_gen(self.image_list[index][0])
61 | img2 = frame_utils.read_gen(self.image_list[index][1])
62 |
63 | flow = np.array(flow).astype(np.float32)
64 | img1 = np.array(img1).astype(np.uint8)
65 | img2 = np.array(img2).astype(np.uint8)
66 |
67 | # grayscale images
68 | if len(img1.shape) == 2:
69 | img1 = np.tile(img1[...,None], (1, 1, 3))
70 | img2 = np.tile(img2[...,None], (1, 1, 3))
71 | else:
72 | img1 = img1[..., :3]
73 | img2 = img2[..., :3]
74 |
75 | if self.augmentor is not None:
76 | if self.sparse:
77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
78 | else:
79 | img1, img2, flow = self.augmentor(img1, img2, flow)
80 |
81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float()
84 |
85 | if valid is not None:
86 | valid = torch.from_numpy(valid)
87 | else:
88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
89 |
90 | return img1, img2, flow, valid.float()
91 |
92 |
93 | def __rmul__(self, v):
94 | self.flow_list = v * self.flow_list
95 | self.image_list = v * self.image_list
96 | return self
97 |
98 | def __len__(self):
99 | return len(self.image_list)
100 |
101 |
102 | class MpiSintel(FlowDataset):
103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
104 | super(MpiSintel, self).__init__(aug_params)
105 | flow_root = osp.join(root, split, 'flow')
106 | image_root = osp.join(root, split, dstype)
107 |
108 | if split == 'test':
109 | self.is_test = True
110 |
111 | for scene in os.listdir(image_root):
112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
113 | for i in range(len(image_list)-1):
114 | self.image_list += [ [image_list[i], image_list[i+1]] ]
115 | self.extra_info += [ (scene, i) ] # scene and frame_id
116 |
117 | if split != 'test':
118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
119 |
120 |
121 | class FlyingChairs(FlowDataset):
122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
123 | super(FlyingChairs, self).__init__(aug_params)
124 |
125 | images = sorted(glob(osp.join(root, '*.ppm')))
126 | flows = sorted(glob(osp.join(root, '*.flo')))
127 | assert (len(images)//2 == len(flows))
128 |
129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
130 | for i in range(len(flows)):
131 | xid = split_list[i]
132 | if (split=='training' and xid==1) or (split=='validation' and xid==2):
133 | self.flow_list += [ flows[i] ]
134 | self.image_list += [ [images[2*i], images[2*i+1]] ]
135 |
136 |
137 | class FlyingThings3D(FlowDataset):
138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
139 | super(FlyingThings3D, self).__init__(aug_params)
140 |
141 | for cam in ['left']:
142 | for direction in ['into_future', 'into_past']:
143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
145 |
146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
148 |
149 | for idir, fdir in zip(image_dirs, flow_dirs):
150 | images = sorted(glob(osp.join(idir, '*.png')) )
151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) )
152 | for i in range(len(flows)-1):
153 | if direction == 'into_future':
154 | self.image_list += [ [images[i], images[i+1]] ]
155 | self.flow_list += [ flows[i] ]
156 | elif direction == 'into_past':
157 | self.image_list += [ [images[i+1], images[i]] ]
158 | self.flow_list += [ flows[i+1] ]
159 |
160 |
161 | class KITTI(FlowDataset):
162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
163 | super(KITTI, self).__init__(aug_params, sparse=True)
164 | if split == 'testing':
165 | self.is_test = True
166 |
167 | root = osp.join(root, split)
168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
170 |
171 | for img1, img2 in zip(images1, images2):
172 | frame_id = img1.split('/')[-1]
173 | self.extra_info += [ [frame_id] ]
174 | self.image_list += [ [img1, img2] ]
175 |
176 | if split == 'training':
177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
178 |
179 |
180 | class HD1K(FlowDataset):
181 | def __init__(self, aug_params=None, root='datasets/HD1k'):
182 | super(HD1K, self).__init__(aug_params, sparse=True)
183 |
184 | seq_ix = 0
185 | while 1:
186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
188 |
189 | if len(flows) == 0:
190 | break
191 |
192 | for i in range(len(flows)-1):
193 | self.flow_list += [flows[i]]
194 | self.image_list += [ [images[i], images[i+1]] ]
195 |
196 | seq_ix += 1
197 |
198 |
199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
200 |     """ Create the data loader for the corresponding training set """
201 |
202 | if args.stage == 'chairs':
203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
204 | train_dataset = FlyingChairs(aug_params, split='training')
205 |
206 | elif args.stage == 'things':
207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
210 | train_dataset = clean_dataset + final_dataset
211 |
212 | elif args.stage == 'sintel':
213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final')
217 |
218 | if TRAIN_DS == 'C+T+K+S+H':
219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
220 | hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
222 |
223 | elif TRAIN_DS == 'C+T+K/S':
224 | train_dataset = 100*sintel_clean + 100*sintel_final + things
225 |
226 | elif args.stage == 'kitti':
227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
228 | train_dataset = KITTI(aug_params, split='training')
229 |
230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
232 |
233 | print('Training with %d image pairs' % len(train_dataset))
234 | return train_loader
235 |
236 |
--------------------------------------------------------------------------------
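
A sketch of how fetch_dataloader above is driven (editorial illustration, not repository code): the function only reads args.stage, args.image_size and args.batch_size. Because datasets.py imports utils and utils.augmentor as top-level modules, this assumes it is run from inside the RAFT directory, with FlyingChairs unpacked under datasets/FlyingChairs_release/data and chairs_split.txt in the working directory:

    from argparse import Namespace
    import datasets  # RAFT/datasets.py

    args = Namespace(stage='chairs', image_size=[368, 496], batch_size=6)
    loader = datasets.fetch_dataloader(args)

    img1, img2, flow, valid = next(iter(loader))
    # img1, img2: (6, 3, 368, 496) float tensors; flow: (6, 2, 368, 496); valid: (6, 368, 496)
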
/RAFT/demo.py:
--------------------------------------------------------------------------------
1 | import sys
2 | import argparse
3 | import os
4 | import cv2
5 | import glob
6 | import numpy as np
7 | import torch
8 | from PIL import Image
9 |
10 | from .raft import RAFT
11 | from .utils import flow_viz
12 | from .utils.utils import InputPadder
13 |
14 |
15 |
16 | DEVICE = 'cuda'
17 |
18 | def load_image(imfile):
19 | img = np.array(Image.open(imfile)).astype(np.uint8)
20 | img = torch.from_numpy(img).permute(2, 0, 1).float()
21 | return img
22 |
23 |
24 | def load_image_list(image_files):
25 | images = []
26 | for imfile in sorted(image_files):
27 | images.append(load_image(imfile))
28 |
29 | images = torch.stack(images, dim=0)
30 | images = images.to(DEVICE)
31 |
32 | padder = InputPadder(images.shape)
33 | return padder.pad(images)[0]
34 |
35 |
36 | def viz(img, flo):
37 | img = img[0].permute(1,2,0).cpu().numpy()
38 | flo = flo[0].permute(1,2,0).cpu().numpy()
39 |
40 | # map flow to rgb image
41 | flo = flow_viz.flow_to_image(flo)
42 | # img_flo = np.concatenate([img, flo], axis=0)
43 | img_flo = flo
44 |
45 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
46 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
47 | # cv2.waitKey()
48 |
49 |
50 | def demo(args):
51 | model = torch.nn.DataParallel(RAFT(args))
52 | model.load_state_dict(torch.load(args.model))
53 |
54 | model = model.module
55 | model.to(DEVICE)
56 | model.eval()
57 |
58 | with torch.no_grad():
59 | images = glob.glob(os.path.join(args.path, '*.png')) + \
60 | glob.glob(os.path.join(args.path, '*.jpg'))
61 |
62 | images = load_image_list(images)
63 | for i in range(images.shape[0]-1):
64 | image1 = images[i,None]
65 | image2 = images[i+1,None]
66 |
67 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
68 | viz(image1, flow_up)
69 |
70 |
71 | def RAFT_infer(args):
72 | model = torch.nn.DataParallel(RAFT(args))
73 | model.load_state_dict(torch.load(args.model))
74 |
75 | model = model.module
76 | model.to(DEVICE)
77 | model.eval()
78 |
79 | return model
80 |
--------------------------------------------------------------------------------
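
A sketch of constructing the flow estimator through RAFT_infer above (editorial illustration, not repository code; the flag names are the ones read by RAFT.__init__ and forward, and the checkpoint path assumes the weight/ directory from the tree at the top). A CUDA device is required because DEVICE is hard-coded to 'cuda':

    from argparse import Namespace
    from RAFT.demo import RAFT_infer

    args = Namespace(model='weight/raft-things.pth', small=False,
                     mixed_precision=False, alternate_corr=False)
    model = RAFT_infer(args)   # unwrapped from DataParallel, moved to the GPU, set to eval()
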
/RAFT/extractor.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ResidualBlock(nn.Module):
7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
8 | super(ResidualBlock, self).__init__()
9 |
10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
12 | self.relu = nn.ReLU(inplace=True)
13 |
14 | num_groups = planes // 8
15 |
16 | if norm_fn == 'group':
17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
19 | if not stride == 1:
20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
21 |
22 | elif norm_fn == 'batch':
23 | self.norm1 = nn.BatchNorm2d(planes)
24 | self.norm2 = nn.BatchNorm2d(planes)
25 | if not stride == 1:
26 | self.norm3 = nn.BatchNorm2d(planes)
27 |
28 | elif norm_fn == 'instance':
29 | self.norm1 = nn.InstanceNorm2d(planes)
30 | self.norm2 = nn.InstanceNorm2d(planes)
31 | if not stride == 1:
32 | self.norm3 = nn.InstanceNorm2d(planes)
33 |
34 | elif norm_fn == 'none':
35 | self.norm1 = nn.Sequential()
36 | self.norm2 = nn.Sequential()
37 | if not stride == 1:
38 | self.norm3 = nn.Sequential()
39 |
40 | if stride == 1:
41 | self.downsample = None
42 |
43 | else:
44 | self.downsample = nn.Sequential(
45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
46 |
47 |
48 | def forward(self, x):
49 | y = x
50 | y = self.relu(self.norm1(self.conv1(y)))
51 | y = self.relu(self.norm2(self.conv2(y)))
52 |
53 | if self.downsample is not None:
54 | x = self.downsample(x)
55 |
56 | return self.relu(x+y)
57 |
58 |
59 |
60 | class BottleneckBlock(nn.Module):
61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1):
62 | super(BottleneckBlock, self).__init__()
63 |
64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
67 | self.relu = nn.ReLU(inplace=True)
68 |
69 | num_groups = planes // 8
70 |
71 | if norm_fn == 'group':
72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
75 | if not stride == 1:
76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
77 |
78 | elif norm_fn == 'batch':
79 | self.norm1 = nn.BatchNorm2d(planes//4)
80 | self.norm2 = nn.BatchNorm2d(planes//4)
81 | self.norm3 = nn.BatchNorm2d(planes)
82 | if not stride == 1:
83 | self.norm4 = nn.BatchNorm2d(planes)
84 |
85 | elif norm_fn == 'instance':
86 | self.norm1 = nn.InstanceNorm2d(planes//4)
87 | self.norm2 = nn.InstanceNorm2d(planes//4)
88 | self.norm3 = nn.InstanceNorm2d(planes)
89 | if not stride == 1:
90 | self.norm4 = nn.InstanceNorm2d(planes)
91 |
92 | elif norm_fn == 'none':
93 | self.norm1 = nn.Sequential()
94 | self.norm2 = nn.Sequential()
95 | self.norm3 = nn.Sequential()
96 | if not stride == 1:
97 | self.norm4 = nn.Sequential()
98 |
99 | if stride == 1:
100 | self.downsample = None
101 |
102 | else:
103 | self.downsample = nn.Sequential(
104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
105 |
106 |
107 | def forward(self, x):
108 | y = x
109 | y = self.relu(self.norm1(self.conv1(y)))
110 | y = self.relu(self.norm2(self.conv2(y)))
111 | y = self.relu(self.norm3(self.conv3(y)))
112 |
113 | if self.downsample is not None:
114 | x = self.downsample(x)
115 |
116 | return self.relu(x+y)
117 |
118 | class BasicEncoder(nn.Module):
119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
120 | super(BasicEncoder, self).__init__()
121 | self.norm_fn = norm_fn
122 |
123 | if self.norm_fn == 'group':
124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
125 |
126 | elif self.norm_fn == 'batch':
127 | self.norm1 = nn.BatchNorm2d(64)
128 |
129 | elif self.norm_fn == 'instance':
130 | self.norm1 = nn.InstanceNorm2d(64)
131 |
132 | elif self.norm_fn == 'none':
133 | self.norm1 = nn.Sequential()
134 |
135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
136 | self.relu1 = nn.ReLU(inplace=True)
137 |
138 | self.in_planes = 64
139 | self.layer1 = self._make_layer(64, stride=1)
140 | self.layer2 = self._make_layer(96, stride=2)
141 | self.layer3 = self._make_layer(128, stride=2)
142 |
143 | # output convolution
144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
145 |
146 | self.dropout = None
147 | if dropout > 0:
148 | self.dropout = nn.Dropout2d(p=dropout)
149 |
150 | for m in self.modules():
151 | if isinstance(m, nn.Conv2d):
152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
154 | if m.weight is not None:
155 | nn.init.constant_(m.weight, 1)
156 | if m.bias is not None:
157 | nn.init.constant_(m.bias, 0)
158 |
159 | def _make_layer(self, dim, stride=1):
160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
162 | layers = (layer1, layer2)
163 |
164 | self.in_planes = dim
165 | return nn.Sequential(*layers)
166 |
167 |
168 | def forward(self, x):
169 |
170 | # if input is list, combine batch dimension
171 | is_list = isinstance(x, tuple) or isinstance(x, list)
172 | if is_list:
173 | batch_dim = x[0].shape[0]
174 | x = torch.cat(x, dim=0)
175 |
176 | x = self.conv1(x)
177 | x = self.norm1(x)
178 | x = self.relu1(x)
179 |
180 | x = self.layer1(x)
181 | x = self.layer2(x)
182 | x = self.layer3(x)
183 |
184 | x = self.conv2(x)
185 |
186 | if self.training and self.dropout is not None:
187 | x = self.dropout(x)
188 |
189 | if is_list:
190 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
191 |
192 | return x
193 |
194 |
195 | class SmallEncoder(nn.Module):
196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
197 | super(SmallEncoder, self).__init__()
198 | self.norm_fn = norm_fn
199 |
200 | if self.norm_fn == 'group':
201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
202 |
203 | elif self.norm_fn == 'batch':
204 | self.norm1 = nn.BatchNorm2d(32)
205 |
206 | elif self.norm_fn == 'instance':
207 | self.norm1 = nn.InstanceNorm2d(32)
208 |
209 | elif self.norm_fn == 'none':
210 | self.norm1 = nn.Sequential()
211 |
212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
213 | self.relu1 = nn.ReLU(inplace=True)
214 |
215 | self.in_planes = 32
216 | self.layer1 = self._make_layer(32, stride=1)
217 | self.layer2 = self._make_layer(64, stride=2)
218 | self.layer3 = self._make_layer(96, stride=2)
219 |
220 | self.dropout = None
221 | if dropout > 0:
222 | self.dropout = nn.Dropout2d(p=dropout)
223 |
224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
225 |
226 | for m in self.modules():
227 | if isinstance(m, nn.Conv2d):
228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
230 | if m.weight is not None:
231 | nn.init.constant_(m.weight, 1)
232 | if m.bias is not None:
233 | nn.init.constant_(m.bias, 0)
234 |
235 | def _make_layer(self, dim, stride=1):
236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
238 | layers = (layer1, layer2)
239 |
240 | self.in_planes = dim
241 | return nn.Sequential(*layers)
242 |
243 |
244 | def forward(self, x):
245 |
246 | # if input is list, combine batch dimension
247 | is_list = isinstance(x, tuple) or isinstance(x, list)
248 | if is_list:
249 | batch_dim = x[0].shape[0]
250 | x = torch.cat(x, dim=0)
251 |
252 | x = self.conv1(x)
253 | x = self.norm1(x)
254 | x = self.relu1(x)
255 |
256 | x = self.layer1(x)
257 | x = self.layer2(x)
258 | x = self.layer3(x)
259 | x = self.conv2(x)
260 |
261 | if self.training and self.dropout is not None:
262 | x = self.dropout(x)
263 |
264 | if is_list:
265 | x = torch.split(x, [batch_dim, batch_dim], dim=0)
266 |
267 | return x
268 |
--------------------------------------------------------------------------------
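
A quick shape check for BasicEncoder above (editorial sketch, not repository code): three stride-2 stages (conv1, layer2, layer3) downsample by 8, and passing an [image1, image2] list returns one feature map per image, which is how raft.py calls fnet:

    import torch
    from RAFT.extractor import BasicEncoder

    enc = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0.0)
    x = torch.randn(1, 3, 368, 496)

    print(enc(x).shape)         # torch.Size([1, 256, 46, 62]) -- 1/8 resolution
    fmap1, fmap2 = enc([x, x])  # list input -> the batch is split back into per-image features
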
/RAFT/raft.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .update import BasicUpdateBlock, SmallUpdateBlock
7 | from .extractor import BasicEncoder, SmallEncoder
8 | from .corr import CorrBlock, AlternateCorrBlock
9 | from .utils.utils import bilinear_sampler, coords_grid, upflow8
10 |
11 | try:
12 | autocast = torch.cuda.amp.autocast
13 | except:
14 | # dummy autocast for PyTorch < 1.6
15 | class autocast:
16 | def __init__(self, enabled):
17 | pass
18 | def __enter__(self):
19 | pass
20 | def __exit__(self, *args):
21 | pass
22 |
23 |
24 | class RAFT(nn.Module):
25 | def __init__(self, args):
26 | super(RAFT, self).__init__()
27 | self.args = args
28 |
29 | if args.small:
30 | self.hidden_dim = hdim = 96
31 | self.context_dim = cdim = 64
32 | args.corr_levels = 4
33 | args.corr_radius = 3
34 |
35 | else:
36 | self.hidden_dim = hdim = 128
37 | self.context_dim = cdim = 128
38 | args.corr_levels = 4
39 | args.corr_radius = 4
40 |
41 | if 'dropout' not in args._get_kwargs():
42 | args.dropout = 0
43 |
44 | if 'alternate_corr' not in args._get_kwargs():
45 | args.alternate_corr = False
46 |
47 | # feature network, context network, and update block
48 | if args.small:
49 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
50 | self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
51 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
52 |
53 | else:
54 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
55 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
56 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
57 |
58 |
59 | def freeze_bn(self):
60 | for m in self.modules():
61 | if isinstance(m, nn.BatchNorm2d):
62 | m.eval()
63 |
64 | def initialize_flow(self, img):
65 |         """ Flow is represented as the difference between two coordinate grids: flow = coords1 - coords0 """
66 | N, C, H, W = img.shape
67 | coords0 = coords_grid(N, H//8, W//8).to(img.device)
68 | coords1 = coords_grid(N, H//8, W//8).to(img.device)
69 |
70 | # optical flow computed as difference: flow = coords1 - coords0
71 | return coords0, coords1
72 |
73 | def upsample_flow(self, flow, mask):
74 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
75 | N, _, H, W = flow.shape
76 | mask = mask.view(N, 1, 9, 8, 8, H, W)
77 | mask = torch.softmax(mask, dim=2)
78 |
79 | up_flow = F.unfold(8 * flow, [3,3], padding=1)
80 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
81 |
82 | up_flow = torch.sum(mask * up_flow, dim=2)
83 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
84 | return up_flow.reshape(N, 2, 8*H, 8*W)
85 |
86 |
87 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
88 | """ Estimate optical flow between pair of frames """
89 |
90 | image1 = 2 * (image1 / 255.0) - 1.0
91 | image2 = 2 * (image2 / 255.0) - 1.0
92 |
93 | image1 = image1.contiguous()
94 | image2 = image2.contiguous()
95 |
96 | hdim = self.hidden_dim
97 | cdim = self.context_dim
98 |
99 | # run the feature network
100 | with autocast(enabled=self.args.mixed_precision):
101 | fmap1, fmap2 = self.fnet([image1, image2])
102 |
103 | fmap1 = fmap1.float()
104 | fmap2 = fmap2.float()
105 | if self.args.alternate_corr:
106 |             corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
107 | else:
108 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
109 |
110 | # run the context network
111 | with autocast(enabled=self.args.mixed_precision):
112 | cnet = self.cnet(image1)
113 | net, inp = torch.split(cnet, [hdim, cdim], dim=1)
114 | net = torch.tanh(net)
115 | inp = torch.relu(inp)
116 |
117 | coords0, coords1 = self.initialize_flow(image1)
118 |
119 | if flow_init is not None:
120 | coords1 = coords1 + flow_init
121 |
122 | flow_predictions = []
123 | for itr in range(iters):
124 | coords1 = coords1.detach()
125 | corr = corr_fn(coords1) # index correlation volume
126 |
127 | flow = coords1 - coords0
128 | with autocast(enabled=self.args.mixed_precision):
129 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
130 |
131 | # F(t+1) = F(t) + \Delta(t)
132 | coords1 = coords1 + delta_flow
133 |
134 | # upsample predictions
135 | if up_mask is None:
136 | flow_up = upflow8(coords1 - coords0)
137 | else:
138 | flow_up = self.upsample_flow(coords1 - coords0, up_mask)
139 |
140 | flow_predictions.append(flow_up)
141 |
142 | if test_mode:
143 | return coords1 - coords0, flow_up
144 |
145 | return flow_predictions
146 |
--------------------------------------------------------------------------------
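
An end-to-end sketch of running the model above on a frame pair (editorial illustration, not repository code; the checkpoint path and image size are assumptions). The DataParallel wrapper is kept while loading because the released checkpoints store keys with a "module." prefix, mirroring demo.py, and InputPadder pads both frames to a multiple of 8 so the 1/8-resolution grids line up:

    import torch
    from argparse import Namespace
    from RAFT.raft import RAFT
    from RAFT.utils.utils import InputPadder

    args = Namespace(small=False, mixed_precision=False, alternate_corr=False)
    model = torch.nn.DataParallel(RAFT(args))
    model.load_state_dict(torch.load('weight/raft-things.pth'))
    model = model.module.cuda().eval()

    image1 = torch.randint(0, 256, (1, 3, 436, 1024)).float().cuda()
    image2 = torch.randint(0, 256, (1, 3, 436, 1024)).float().cuda()
    padder = InputPadder(image1.shape)
    image1, image2 = padder.pad(image1, image2)

    with torch.no_grad():
        flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
    # flow_low: 1/8-resolution estimate; flow_up: convex-upsampled flow at the padded full resolution
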
/RAFT/update.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class FlowHead(nn.Module):
7 | def __init__(self, input_dim=128, hidden_dim=256):
8 | super(FlowHead, self).__init__()
9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
11 | self.relu = nn.ReLU(inplace=True)
12 |
13 | def forward(self, x):
14 | return self.conv2(self.relu(self.conv1(x)))
15 |
16 | class ConvGRU(nn.Module):
17 | def __init__(self, hidden_dim=128, input_dim=192+128):
18 | super(ConvGRU, self).__init__()
19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
22 |
23 | def forward(self, h, x):
24 | hx = torch.cat([h, x], dim=1)
25 |
26 | z = torch.sigmoid(self.convz(hx))
27 | r = torch.sigmoid(self.convr(hx))
28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
29 |
30 | h = (1-z) * h + z * q
31 | return h
32 |
33 | class SepConvGRU(nn.Module):
34 | def __init__(self, hidden_dim=128, input_dim=192+128):
35 | super(SepConvGRU, self).__init__()
36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
39 |
40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
43 |
44 |
45 | def forward(self, h, x):
46 | # horizontal
47 | hx = torch.cat([h, x], dim=1)
48 | z = torch.sigmoid(self.convz1(hx))
49 | r = torch.sigmoid(self.convr1(hx))
50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
51 | h = (1-z) * h + z * q
52 |
53 | # vertical
54 | hx = torch.cat([h, x], dim=1)
55 | z = torch.sigmoid(self.convz2(hx))
56 | r = torch.sigmoid(self.convr2(hx))
57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
58 | h = (1-z) * h + z * q
59 |
60 | return h
61 |
62 | class SmallMotionEncoder(nn.Module):
63 | def __init__(self, args):
64 | super(SmallMotionEncoder, self).__init__()
65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
69 | self.conv = nn.Conv2d(128, 80, 3, padding=1)
70 |
71 | def forward(self, flow, corr):
72 | cor = F.relu(self.convc1(corr))
73 | flo = F.relu(self.convf1(flow))
74 | flo = F.relu(self.convf2(flo))
75 | cor_flo = torch.cat([cor, flo], dim=1)
76 | out = F.relu(self.conv(cor_flo))
77 | return torch.cat([out, flow], dim=1)
78 |
79 | class BasicMotionEncoder(nn.Module):
80 | def __init__(self, args):
81 | super(BasicMotionEncoder, self).__init__()
82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
88 |
89 | def forward(self, flow, corr):
90 | cor = F.relu(self.convc1(corr))
91 | cor = F.relu(self.convc2(cor))
92 | flo = F.relu(self.convf1(flow))
93 | flo = F.relu(self.convf2(flo))
94 |
95 | cor_flo = torch.cat([cor, flo], dim=1)
96 | out = F.relu(self.conv(cor_flo))
97 | return torch.cat([out, flow], dim=1)
98 |
99 | class SmallUpdateBlock(nn.Module):
100 | def __init__(self, args, hidden_dim=96):
101 | super(SmallUpdateBlock, self).__init__()
102 | self.encoder = SmallMotionEncoder(args)
103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
105 |
106 | def forward(self, net, inp, corr, flow):
107 | motion_features = self.encoder(flow, corr)
108 | inp = torch.cat([inp, motion_features], dim=1)
109 | net = self.gru(net, inp)
110 | delta_flow = self.flow_head(net)
111 |
112 | return net, None, delta_flow
113 |
114 | class BasicUpdateBlock(nn.Module):
115 | def __init__(self, args, hidden_dim=128, input_dim=128):
116 | super(BasicUpdateBlock, self).__init__()
117 | self.args = args
118 | self.encoder = BasicMotionEncoder(args)
119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
121 |
122 | self.mask = nn.Sequential(
123 | nn.Conv2d(128, 256, 3, padding=1),
124 | nn.ReLU(inplace=True),
125 | nn.Conv2d(256, 64*9, 1, padding=0))
126 |
127 | def forward(self, net, inp, corr, flow, upsample=True):
128 | motion_features = self.encoder(flow, corr)
129 | inp = torch.cat([inp, motion_features], dim=1)
130 |
131 | net = self.gru(net, inp)
132 | delta_flow = self.flow_head(net)
133 |
134 |         # scale mask to balance gradients
135 | mask = .25 * self.mask(net)
136 | return net, mask, delta_flow
137 |
138 |
139 |
140 |
--------------------------------------------------------------------------------
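
A shape sketch for one iteration of BasicUpdateBlock above (editorial illustration, not repository code). With corr_levels=4 and corr_radius=4 the motion encoder expects 324 correlation channels, and the block returns the new hidden state, the 64*9-channel convex-upsampling mask, and the flow residual that raft.py adds to coords1:

    import torch
    from argparse import Namespace
    from RAFT.update import BasicUpdateBlock

    args = Namespace(corr_levels=4, corr_radius=4)
    block = BasicUpdateBlock(args, hidden_dim=128)

    N, H, W = 1, 46, 62
    net  = torch.randn(N, 128, H, W)                   # hidden state (tanh half of the context features)
    inp  = torch.randn(N, 128, H, W)                   # input features (relu half of the context features)
    corr = torch.randn(N, 4 * (2 * 4 + 1) ** 2, H, W)  # 324 correlation channels
    flow = torch.zeros(N, 2, H, W)

    net, up_mask, delta_flow = block(net, inp, corr, flow)
    # up_mask: (N, 576, H, W); delta_flow: (N, 2, H, W)
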
/RAFT/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .flow_viz import flow_to_image
2 | from .frame_utils import writeFlow
3 |
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/flow_viz.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/flow_viz.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/flow_viz.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/flow_viz.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/flow_viz.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/frame_utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/frame_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/frame_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/frame_utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/frame_utils.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-37.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-38.pyc
--------------------------------------------------------------------------------
/RAFT/utils/__pycache__/utils.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/RAFT/utils/__pycache__/utils.cpython-39.pyc
--------------------------------------------------------------------------------
/RAFT/utils/augmentor.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import random
3 | import math
4 | from PIL import Image
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | import torch
11 | from torchvision.transforms import ColorJitter
12 | import torch.nn.functional as F
13 |
14 |
15 | class FlowAugmentor:
16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
17 |
18 | # spatial augmentation params
19 | self.crop_size = crop_size
20 | self.min_scale = min_scale
21 | self.max_scale = max_scale
22 | self.spatial_aug_prob = 0.8
23 | self.stretch_prob = 0.8
24 | self.max_stretch = 0.2
25 |
26 | # flip augmentation params
27 | self.do_flip = do_flip
28 | self.h_flip_prob = 0.5
29 | self.v_flip_prob = 0.1
30 |
31 | # photometric augmentation params
32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
33 | self.asymmetric_color_aug_prob = 0.2
34 | self.eraser_aug_prob = 0.5
35 |
36 | def color_transform(self, img1, img2):
37 | """ Photometric augmentation """
38 |
39 | # asymmetric
40 | if np.random.rand() < self.asymmetric_color_aug_prob:
41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
43 |
44 | # symmetric
45 | else:
46 | image_stack = np.concatenate([img1, img2], axis=0)
47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
48 | img1, img2 = np.split(image_stack, 2, axis=0)
49 |
50 | return img1, img2
51 |
52 | def eraser_transform(self, img1, img2, bounds=[50, 100]):
53 | """ Occlusion augmentation """
54 |
55 | ht, wd = img1.shape[:2]
56 | if np.random.rand() < self.eraser_aug_prob:
57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
58 | for _ in range(np.random.randint(1, 3)):
59 | x0 = np.random.randint(0, wd)
60 | y0 = np.random.randint(0, ht)
61 | dx = np.random.randint(bounds[0], bounds[1])
62 | dy = np.random.randint(bounds[0], bounds[1])
63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
64 |
65 | return img1, img2
66 |
67 | def spatial_transform(self, img1, img2, flow):
68 | # randomly sample scale
69 | ht, wd = img1.shape[:2]
70 | min_scale = np.maximum(
71 | (self.crop_size[0] + 8) / float(ht),
72 | (self.crop_size[1] + 8) / float(wd))
73 |
74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
75 | scale_x = scale
76 | scale_y = scale
77 | if np.random.rand() < self.stretch_prob:
78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
80 |
81 | scale_x = np.clip(scale_x, min_scale, None)
82 | scale_y = np.clip(scale_y, min_scale, None)
83 |
84 | if np.random.rand() < self.spatial_aug_prob:
85 | # rescale the images
86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
89 | flow = flow * [scale_x, scale_y]
90 |
91 | if self.do_flip:
92 | if np.random.rand() < self.h_flip_prob: # h-flip
93 | img1 = img1[:, ::-1]
94 | img2 = img2[:, ::-1]
95 | flow = flow[:, ::-1] * [-1.0, 1.0]
96 |
97 | if np.random.rand() < self.v_flip_prob: # v-flip
98 | img1 = img1[::-1, :]
99 | img2 = img2[::-1, :]
100 | flow = flow[::-1, :] * [1.0, -1.0]
101 |
102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
104 |
105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
108 |
109 | return img1, img2, flow
110 |
111 | def __call__(self, img1, img2, flow):
112 | img1, img2 = self.color_transform(img1, img2)
113 | img1, img2 = self.eraser_transform(img1, img2)
114 | img1, img2, flow = self.spatial_transform(img1, img2, flow)
115 |
116 | img1 = np.ascontiguousarray(img1)
117 | img2 = np.ascontiguousarray(img2)
118 | flow = np.ascontiguousarray(flow)
119 |
120 | return img1, img2, flow
121 |
122 | class SparseFlowAugmentor:
123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
124 | # spatial augmentation params
125 | self.crop_size = crop_size
126 | self.min_scale = min_scale
127 | self.max_scale = max_scale
128 | self.spatial_aug_prob = 0.8
129 | self.stretch_prob = 0.8
130 | self.max_stretch = 0.2
131 |
132 | # flip augmentation params
133 | self.do_flip = do_flip
134 | self.h_flip_prob = 0.5
135 | self.v_flip_prob = 0.1
136 |
137 | # photometric augmentation params
138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
139 | self.asymmetric_color_aug_prob = 0.2
140 | self.eraser_aug_prob = 0.5
141 |
142 | def color_transform(self, img1, img2):
143 | image_stack = np.concatenate([img1, img2], axis=0)
144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
145 | img1, img2 = np.split(image_stack, 2, axis=0)
146 | return img1, img2
147 |
148 | def eraser_transform(self, img1, img2):
149 | ht, wd = img1.shape[:2]
150 | if np.random.rand() < self.eraser_aug_prob:
151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0)
152 | for _ in range(np.random.randint(1, 3)):
153 | x0 = np.random.randint(0, wd)
154 | y0 = np.random.randint(0, ht)
155 | dx = np.random.randint(50, 100)
156 | dy = np.random.randint(50, 100)
157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color
158 |
159 | return img1, img2
160 |
161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
162 | ht, wd = flow.shape[:2]
163 | coords = np.meshgrid(np.arange(wd), np.arange(ht))
164 | coords = np.stack(coords, axis=-1)
165 |
166 | coords = coords.reshape(-1, 2).astype(np.float32)
167 | flow = flow.reshape(-1, 2).astype(np.float32)
168 | valid = valid.reshape(-1).astype(np.float32)
169 |
170 | coords0 = coords[valid>=1]
171 | flow0 = flow[valid>=1]
172 |
173 | ht1 = int(round(ht * fy))
174 | wd1 = int(round(wd * fx))
175 |
176 | coords1 = coords0 * [fx, fy]
177 | flow1 = flow0 * [fx, fy]
178 |
179 | xx = np.round(coords1[:,0]).astype(np.int32)
180 | yy = np.round(coords1[:,1]).astype(np.int32)
181 |
182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
183 | xx = xx[v]
184 | yy = yy[v]
185 | flow1 = flow1[v]
186 |
187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32)
189 |
190 | flow_img[yy, xx] = flow1
191 | valid_img[yy, xx] = 1
192 |
193 | return flow_img, valid_img
194 |
195 | def spatial_transform(self, img1, img2, flow, valid):
196 | # randomly sample scale
197 |
198 | ht, wd = img1.shape[:2]
199 | min_scale = np.maximum(
200 | (self.crop_size[0] + 1) / float(ht),
201 | (self.crop_size[1] + 1) / float(wd))
202 |
203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
204 | scale_x = np.clip(scale, min_scale, None)
205 | scale_y = np.clip(scale, min_scale, None)
206 |
207 | if np.random.rand() < self.spatial_aug_prob:
208 | # rescale the images
209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
212 |
213 | if self.do_flip:
214 | if np.random.rand() < 0.5: # h-flip
215 | img1 = img1[:, ::-1]
216 | img2 = img2[:, ::-1]
217 | flow = flow[:, ::-1] * [-1.0, 1.0]
218 | valid = valid[:, ::-1]
219 |
220 | margin_y = 20
221 | margin_x = 50
222 |
223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
225 |
226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
228 |
229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
230 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
233 | return img1, img2, flow, valid
234 |
235 |
236 | def __call__(self, img1, img2, flow, valid):
237 | img1, img2 = self.color_transform(img1, img2)
238 | img1, img2 = self.eraser_transform(img1, img2)
239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
240 |
241 | img1 = np.ascontiguousarray(img1)
242 | img2 = np.ascontiguousarray(img2)
243 | flow = np.ascontiguousarray(flow)
244 | valid = np.ascontiguousarray(valid)
245 |
246 | return img1, img2, flow, valid
247 |
--------------------------------------------------------------------------------
/RAFT/utils/flow_viz.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 |
4 | # MIT License
5 | #
6 | # Copyright (c) 2018 Tom Runia
7 | #
8 | # Permission is hereby granted, free of charge, to any person obtaining a copy
9 | # of this software and associated documentation files (the "Software"), to deal
10 | # in the Software without restriction, including without limitation the rights
11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
12 | # copies of the Software, and to permit persons to whom the Software is
13 | # furnished to do so, subject to conditions.
14 | #
15 | # Author: Tom Runia
16 | # Date Created: 2018-08-03
17 |
18 | import numpy as np
19 |
20 | def make_colorwheel():
21 | """
22 | Generates a color wheel for optical flow visualization as presented in:
23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
25 |
26 | Code follows the original C++ source code of Daniel Scharstein.
27 |     Code follows the Matlab source code of Deqing Sun.
28 |
29 | Returns:
30 | np.ndarray: Color wheel
31 | """
32 |
33 | RY = 15
34 | YG = 6
35 | GC = 4
36 | CB = 11
37 | BM = 13
38 | MR = 6
39 |
40 | ncols = RY + YG + GC + CB + BM + MR
41 | colorwheel = np.zeros((ncols, 3))
42 | col = 0
43 |
44 | # RY
45 | colorwheel[0:RY, 0] = 255
46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
47 | col = col+RY
48 | # YG
49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
50 | colorwheel[col:col+YG, 1] = 255
51 | col = col+YG
52 | # GC
53 | colorwheel[col:col+GC, 1] = 255
54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
55 | col = col+GC
56 | # CB
57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
58 | colorwheel[col:col+CB, 2] = 255
59 | col = col+CB
60 | # BM
61 | colorwheel[col:col+BM, 2] = 255
62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
63 | col = col+BM
64 | # MR
65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
66 | colorwheel[col:col+MR, 0] = 255
67 | return colorwheel
68 |
69 |
70 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
71 | """
72 | Applies the flow color wheel to (possibly clipped) flow components u and v.
73 |
74 | According to the C++ source code of Daniel Scharstein
75 | According to the Matlab source code of Deqing Sun
76 |
77 | Args:
78 | u (np.ndarray): Input horizontal flow of shape [H,W]
79 | v (np.ndarray): Input vertical flow of shape [H,W]
80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
81 |
82 | Returns:
83 | np.ndarray: Flow visualization image of shape [H,W,3]
84 | """
85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
86 | colorwheel = make_colorwheel() # shape [55x3]
87 | ncols = colorwheel.shape[0]
88 | rad = np.sqrt(np.square(u) + np.square(v))
89 | a = np.arctan2(-v, -u)/np.pi
90 | fk = (a+1) / 2*(ncols-1)
91 | k0 = np.floor(fk).astype(np.int32)
92 | k1 = k0 + 1
93 | k1[k1 == ncols] = 0
94 | f = fk - k0
95 | for i in range(colorwheel.shape[1]):
96 | tmp = colorwheel[:,i]
97 | col0 = tmp[k0] / 255.0
98 | col1 = tmp[k1] / 255.0
99 | col = (1-f)*col0 + f*col1
100 | idx = (rad <= 1)
101 | col[idx] = 1 - rad[idx] * (1-col[idx])
102 | col[~idx] = col[~idx] * 0.75 # out of range
103 | # Note the 2-i => BGR instead of RGB
104 | ch_idx = 2-i if convert_to_bgr else i
105 | flow_image[:,:,ch_idx] = np.floor(255 * col)
106 | return flow_image
107 |
108 |
109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, rad_max=None):
110 | """
111 |     Expects a two dimensional flow image of shape [H,W,2].
112 |
113 | Args:
114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
117 |
118 | Returns:
119 | np.ndarray: Flow visualization image of shape [H,W,3]
120 | """
121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions'
122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
123 | if clip_flow is not None:
124 | flow_uv = np.clip(flow_uv, 0, clip_flow)
125 | u = flow_uv[:,:,0]
126 | v = flow_uv[:,:,1]
127 | if not rad_max:
128 | rad = np.sqrt(np.square(u) + np.square(v))
129 | rad_max = np.max(rad)
130 | epsilon = 1e-5
131 | u = u / (rad_max + epsilon)
132 | v = v / (rad_max + epsilon)
133 | return flow_uv_to_colors(u, v, convert_to_bgr)
--------------------------------------------------------------------------------
/RAFT/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from PIL import Image
3 | from os.path import *
4 | import re
5 |
6 | import cv2
7 | cv2.setNumThreads(0)
8 | cv2.ocl.setUseOpenCL(False)
9 |
10 | TAG_CHAR = np.array([202021.25], np.float32)
11 |
12 | def readFlow(fn):
13 | """ Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, 'rb') as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print('Magic number incorrect. Invalid .flo file')
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 | def readPFM(file):
34 | file = open(file, 'rb')
35 |
36 | color = None
37 | width = None
38 | height = None
39 | scale = None
40 | endian = None
41 |
42 | header = file.readline().rstrip()
43 | if header == b'PF':
44 | color = True
45 | elif header == b'Pf':
46 | color = False
47 | else:
48 | raise Exception('Not a PFM file.')
49 |
50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
51 | if dim_match:
52 | width, height = map(int, dim_match.groups())
53 | else:
54 | raise Exception('Malformed PFM header.')
55 |
56 | scale = float(file.readline().rstrip())
57 | if scale < 0: # little-endian
58 | endian = '<'
59 | scale = -scale
60 | else:
61 | endian = '>' # big-endian
62 |
63 | data = np.fromfile(file, endian + 'f')
64 | shape = (height, width, 3) if color else (height, width)
65 |
66 | data = np.reshape(data, shape)
67 | data = np.flipud(data)
68 | return data
69 |
70 | def writeFlow(filename,uv,v=None):
71 | """ Write optical flow to file.
72 |
73 | If v is None, uv is assumed to contain both u and v channels,
74 | stacked in depth.
75 | Original code by Deqing Sun, adapted from Daniel Scharstein.
76 | """
77 | nBands = 2
78 |
79 | if v is None:
80 | assert(uv.ndim == 3)
81 | assert(uv.shape[2] == 2)
82 | u = uv[:,:,0]
83 | v = uv[:,:,1]
84 | else:
85 | u = uv
86 |
87 | assert(u.shape == v.shape)
88 | height,width = u.shape
89 | f = open(filename,'wb')
90 | # write the header
91 | f.write(TAG_CHAR)
92 | np.array(width).astype(np.int32).tofile(f)
93 | np.array(height).astype(np.int32).tofile(f)
94 | # arrange into matrix form
95 | tmp = np.zeros((height, width*nBands))
96 | tmp[:,np.arange(width)*2] = u
97 | tmp[:,np.arange(width)*2 + 1] = v
98 | tmp.astype(np.float32).tofile(f)
99 | f.close()
100 |
101 |
102 | def readFlowKITTI(filename):
103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
104 | flow = flow[:,:,::-1].astype(np.float32)
105 | flow, valid = flow[:, :, :2], flow[:, :, 2]
106 | flow = (flow - 2**15) / 64.0
107 | return flow, valid
108 |
109 | def readDispKITTI(filename):
110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
111 | valid = disp > 0.0
112 | flow = np.stack([-disp, np.zeros_like(disp)], -1)
113 | return flow, valid
114 |
115 |
116 | def writeFlowKITTI(filename, uv):
117 | uv = 64.0 * uv + 2**15
118 | valid = np.ones([uv.shape[0], uv.shape[1], 1])
119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
120 | cv2.imwrite(filename, uv[..., ::-1])
121 |
122 |
123 | def read_gen(file_name, pil=False):
124 | ext = splitext(file_name)[-1]
125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
126 | return Image.open(file_name)
127 | elif ext == '.bin' or ext == '.raw':
128 | return np.load(file_name)
129 | elif ext == '.flo':
130 | return readFlow(file_name).astype(np.float32)
131 | elif ext == '.pfm':
132 | flow = readPFM(file_name).astype(np.float32)
133 | if len(flow.shape) == 2:
134 | return flow
135 | else:
136 | return flow[:, :, :-1]
137 | return []
--------------------------------------------------------------------------------
/RAFT/utils/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import numpy as np
4 | from scipy import interpolate
5 |
6 |
7 | class InputPadder:
8 | """ Pads images such that dimensions are divisible by 8 """
9 | def __init__(self, dims, mode='sintel'):
10 | self.ht, self.wd = dims[-2:]
11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
13 | if mode == 'sintel':
14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
15 | else:
16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
17 |
18 | def pad(self, *inputs):
19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs]
20 |
21 | def unpad(self,x):
22 | ht, wd = x.shape[-2:]
23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
24 | return x[..., c[0]:c[1], c[2]:c[3]]
25 |
26 | def forward_interpolate(flow):
27 | flow = flow.detach().cpu().numpy()
28 | dx, dy = flow[0], flow[1]
29 |
30 | ht, wd = dx.shape
31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
32 |
33 | x1 = x0 + dx
34 | y1 = y0 + dy
35 |
36 | x1 = x1.reshape(-1)
37 | y1 = y1.reshape(-1)
38 | dx = dx.reshape(-1)
39 | dy = dy.reshape(-1)
40 |
41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
42 | x1 = x1[valid]
43 | y1 = y1[valid]
44 | dx = dx[valid]
45 | dy = dy[valid]
46 |
47 | flow_x = interpolate.griddata(
48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
49 |
50 | flow_y = interpolate.griddata(
51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
52 |
53 | flow = np.stack([flow_x, flow_y], axis=0)
54 | return torch.from_numpy(flow).float()
55 |
56 |
57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False):
58 | """ Wrapper for grid_sample, uses pixel coordinates """
59 | H, W = img.shape[-2:]
60 | xgrid, ygrid = coords.split([1,1], dim=-1)
61 | xgrid = 2*xgrid/(W-1) - 1
62 | ygrid = 2*ygrid/(H-1) - 1
63 |
64 | grid = torch.cat([xgrid, ygrid], dim=-1)
65 | img = F.grid_sample(img, grid, align_corners=True)
66 |
67 | if mask:
68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
69 | return img, mask.float()
70 |
71 | return img
72 |
73 |
74 | def coords_grid(batch, ht, wd):
75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
76 | coords = torch.stack(coords[::-1], dim=0).float()
77 | return coords[None].repeat(batch, 1, 1, 1)
78 |
79 |
80 | def upflow8(flow, mode='bilinear'):
81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3])
82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
83 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FactorMatte
2 | ## Environment
3 | `conda create -n factormatte python=3.9 anaconda`
4 |
5 | `conda activate factormatte`
6 |
7 | Use conda or pip to install the packages listed in requirements.txt.
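For example, with pip: `pip install -r requirements.txt`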
8 |
9 | ## Example Video
10 | ### Download the dataset and put it into the datasets/ folder
11 | https://drive.google.com/file/d/1-nZ9VA8bqRvll_4HEPGOxihIJ4o8y0kY/view?usp=sharing
12 |
13 | ### Stage 1
14 | `python train_GAN.py --name sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2 --stage 1 --dataset_mode omnimatte_GANCGANFlip148 --model omnimatte_GANFlip --dataroot ./datasets/sand_car --height 192 --width 288 --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --model_v 4 --residual_noise --strides 0,0,0 --num_Ds 0,0,0 --n_layers 0,0,0 --display_ind 63 --pos_ex_dirs , --batch_size 16 --n_epochs 1200 --bg_noise --gpu_ids 1,0 --lambda_recon_flow 1e-2`
15 |
16 | Copy the trained weights to the next stage's training folder: `cp 1110_checkpoints/sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/*1200* 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2`
17 |
18 | Run test to generate the background image: `python test.py --name sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2 --dataset_mode omnimatte_GANCGANFlip148 --model omnimatte_GANFlip --dataroot ./datasets/DVM_manstatic --prob_masks --model_v 4 --residual_noise --strides 0,0,0 --num_Ds 0,0,0 --n_layers 0,0,0 --pos_ex_dirs , --epoch 1200 --stage 1 --gpu_ids 0 --start_ind 0 --width 512 --height 288`
19 |
20 | Put it into the dataset folder; it will be used in the following stages: `cp results/sand_car_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/test_1200_/panorama.png datasets/sand_car/bg_gt.png`
21 |
22 | ### Stage 2
23 | `python train_GAN.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2 --init_flowmask --lambda_recon_flow 1e-2 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --residual_noise --strides 0,2,2 --num_Ds 0,1,3 --n_layers 0,3,3 --start_ind 0 --noninter_only --width 288 --height 192 --discriminator_transform randomcrop --pos_ex_dirs 0uniform_0gaussian_dark_0flip_0elastic_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_rawframes,0rot_0flip_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_ --gpu_ids 0 --n_epochs 2400 --continue_train --epoch 1200 --stage 2 --display_ind 15`
24 |
25 | Copy the trained weights to the next stage's training folder: `cp 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_noninter_flowmask_flowrecon1e-2/*2400* 1110_checkpoints/sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1`
26 |
27 |
28 | ### Stage 3
29 | `python train_GAN.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1 --lambda_recon 2 --lambda_recon_flow 1e-1 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --save_by_epoch --prob_masks --lambda_rgb_warp 1e-1 --lambda_alpha_warp 1e-1 --residual_noise --strides 0,2,2 --num_Ds 0,1,3 --display_ind 63 --init_flowmask --lambda_recon_3 1e-1 --start_ind 0 --discriminator_transform randomcrop --steps 148 --pos_ex_dirs 0uniform_0gaussian_dark_0flip_0elastic_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_rawframes,0rot_0flip_0.25blursigma0.20.2k5_0.25gaussian_noise_std27mean0_ --stage 3 --height 192 --width 288 --n_epochs 3200 --gpu_ids 0 --continue_train --epoch 2400 --overwrite_lambdas`
30 |
31 | ### Pretrained Weights
32 | For convenience, you can also download the weights of any stage for this dataset and start training from there on.
33 |
34 | Stage 1 weights: https://drive.google.com/drive/folders/1ERZQNM8nT2Xw9J2yzFzp3QoyxHFEZ7B4?usp=sharing
35 |
36 | Stage 2 weights: https://drive.google.com/drive/folders/1boJJ8DwPZxk9hzxUa-4nLW0vVPhXdhPL?usp=sharing
37 |
38 | Stage 3 weights: https://drive.google.com/drive/folders/1eDHuIsoON_ou_50sZ7nT4D3luiGxVvyx?usp=sharing
39 |
40 | ### Generate Results
41 | `python test.py --name sand_car_3layer_13GAN1e-3_strides22crop_D1_v4_rgbwarp1e-1_alphawarp1e-1_l2arecon1e-1dilate_recon2_148_othersretro_stage22000cont_flowrecon1e-1 --dataset_mode factormatte_GANCGANFlip148 --model factormatte_GANFlip --dataroot ./datasets/sand_car --gpu_ids 0 --prob_masks --residual_noise --pos_ex_dirs , --epoch 3200 --stage 3 --width 288 --height 192 --init_flowmask --test_suffix texts_to_put_after_fixed_folder_name`
42 |
43 |
44 | ## Custom Dataset
45 | To train on your own video, please prepare the data as follows (assume all file names are [xxxxx].png, e.g. 00001.png, 00100.png, 10001.png):
46 | 1. Extract all RGB frames and put them in the `rgb` folder.
47 | 2. Arrange the corresponding binary masks in the same order and put them in the `mask/01` folder.
48 | 3. Run `data/misc_data_process.py` to copy `mask/01` to `mask_nocushionmask/02` and to generate `mask_nocushionmask/01`; see the example command below. Please refer to the docstring in data/misc_data_process.py for details. (Redundant, TODO: generate this on the fly.)
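For example: `python data/misc_data_process.py --dataroot ./datasets/[your_folder_name]`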
49 | 4. Estimate the homography between every two consecutive frames and flatten each matrix following the template of data/homographies.txt.
50 | We provide a script in data/keypoint_homo_short.ipynb. It generates a file homographies_raw.txt. To get the final homographies.txt, run
51 | `python datasets/homography.py --homography_path ./datasets/[your_folder_name]/homographies_raw.txt --width [W] --height [H]`
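For reference, each line of homographies_raw.txt holds one flattened 3x3 homography (nine space-separated values); `datasets/homography.py` prepends `size:` and `bounds:` header lines to produce the final homographies.txt.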
52 |
53 | 5. Flow estimation by RAFT:
54 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 1`
55 |
56 |
57 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 4`
58 |
59 |
60 | `python video_completion.py --path datasets/[your_folder_name]/rgb --model weight/raft-things.pth --step 8`
61 |
62 | (As mentioned in section 7, we use multiple time scales (1, 4, 8) to reinforce consistency.)
63 |
64 | Move the generated flow matrices to your data folder:
65 |
66 | `mv RAFT_result/datasets[your_folder_name]rgb/*flow* datasets/[your_folder_name]`
67 |
68 | 6. Confidence estimate for flows:
69 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 1`
70 |
71 |
72 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 4`
73 |
74 |
75 | `python datasets/confidence.py --dataroot ./datasets/[your_folder_name] --step 8`
76 |
77 | 7. Find the simpler frames if you want to use the tricks in Section 7. Separate the frame indices as in `data/noninteraction_ind.txt`. If there are no such frames, or you do not wish to use these tricks, simply write "0, 1" in that file.
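The provided `data/noninteraction_ind.txt`, which contains `0, 40`, shows the expected format.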
78 |
79 | 8. After Stage 1, run `python data/gen_foregroundPosEx.py` to generate positive examples for the foreground, and run `data/gen_backgroundPosEx.ipynb` to generate positive examples for the background.
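Note that `data/gen_foregroundPosEx.py` reads its settings (`datadir`, `video_start_ind`, `video_end_ind`, and the augmentation probabilities) from its `__main__` block, so edit those to point at your dataset before running it.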
80 |
81 | 9. To summarize, there should be these folders and files in datasets/[your_folder_name]:
82 |
83 | forward_flow_step1, forward_flow_step4, forward_flow_step8
84 |
85 | backward_flow_step1, backward_flow_step4, backward_flow_step8
86 |
87 | confidence_step1, confidence_step4, confidence_step8
88 |
89 | homographies.txt (if you use data/keypoint_homo_short.ipynb, there should also be a "homographies_raw.txt")
90 |
91 | mask_nocushionmask (2 subfolders: "01", "02")
92 |
93 | mask (1 subfolder containing the segmentation masks of the foreground object: "01")
94 |
95 | noninteraction_ind.txt
96 |
97 | zbar.pth (Automatically generated to make sure the model starts with a fixed random noise.)
98 |
99 | dis_real_l1, dis_real_l2 (Generated after running Stage 1.)
100 |
101 |
--------------------------------------------------------------------------------
/causal/discriminator.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision.models as models
3 | import torch.nn as nn
4 | import functools
5 | import torch.nn.functional as F
6 | import numpy as np
7 |
8 |
9 | # Defines the PatchGAN discriminator with the specified arguments.
10 | class NLayerDiscriminator_4MultiscaleDiscriminator(nn.Module):
11 | def __init__(self, args, input_nc, ndf, n_layers, s, norm_layer, use_sigmoid):
12 | super(NLayerDiscriminator_4MultiscaleDiscriminator, self).__init__()
13 | self.conv = nn.Conv2d
14 | self.n_layers = n_layers
15 |
16 | kw = 4
17 | padw = int(np.ceil((kw-1.0)/2))
18 | sequence = [[self.conv(input_nc, ndf, kernel_size=kw, stride=s, padding=padw), \
19 | nn.LeakyReLU(0.2, True)]]
20 | nf = ndf
21 | # start from 1 because already 1 layer, minus 1 because another layer in the end
22 | for n in range(1, n_layers-1):
23 | nf_prev = nf
24 | nf = min(nf * 2, 512)
25 | sequence += [[
26 | self.conv(nf_prev, nf, kernel_size=kw, stride=1, padding=padw),
27 | norm_layer(nf),
28 | nn.LeakyReLU(0.2, True)
29 | ]]
30 |
31 | sequence += [[self.conv(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
32 | if use_sigmoid:
33 | sequence += [[nn.Sigmoid()]]
34 |
35 | sequence_stream = []
36 | for n in range(len(sequence)):
37 | sequence_stream += sequence[n]
38 | self.model = nn.Sequential(*sequence_stream)
39 |
40 |
41 | class MultiscaleDiscriminator(nn.Module):
42 | def __init__(self, args, stride, num_D, n_layers, norm_layer=nn.BatchNorm2d,
43 | use_sigmoid=False, ndf=64):
44 | super(MultiscaleDiscriminator, self).__init__()
45 | self.num_D = num_D
46 | self.n_layers = n_layers
47 | self.args = args
48 | if args.rgba_GAN == 'RGBA':
49 | input_nc = 4
50 | elif args.rgba_GAN == 'RGB':
51 | input_nc = 3
52 | elif args.rgba_GAN == 'A':
53 | input_nc = 1 #a+mask
54 |
55 |
56 | for i in range(num_D):
57 | print('Initializing', i, 'th-scale discriminator. n_layers', n_layers, 'ndf', ndf, 'stride', stride,\
58 | norm_layer, use_sigmoid)
59 | netD = NLayerDiscriminator_4MultiscaleDiscriminator(args, input_nc, ndf, n_layers, stride, \
60 | norm_layer, use_sigmoid)
61 | setattr(self, 'layer'+str(i), netD.model)
62 |
63 | self.downsample = nn.AvgPool2d(3, stride=2, padding=[1, 1], count_include_pad=False)
64 |
65 | def singleD_forward(self, model, x):
66 | return model(x).flatten(1)
67 |
68 | def forward(self, x):
69 | num_D = self.num_D
70 | result = []
71 | result_valid = []
72 | input_downsampled = x
73 | for i in range(num_D):
74 | model = getattr(self, 'layer'+str(num_D-1-i))
75 | patches = self.singleD_forward(model, input_downsampled)
76 | result.append(patches)
77 | if i != (num_D-1):
78 | input_downsampled = self.downsample(input_downsampled)
79 | return torch.cat(result, 1)
80 |
81 |
--------------------------------------------------------------------------------
/causal/gen.py:
--------------------------------------------------------------------------------
1 | import cv2 as cv
2 | import torch
3 | import torchvision.models as models
4 | from torchvision import transforms as T
5 |
6 | import os
7 | from PIL import Image, ImageFilter
8 | import numpy as np
9 | import scipy.ndimage
10 |
11 | def prep_data(index):
12 | name = ('00'+str(index))[-4:]+'.png'
13 | gt = np.asarray(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/rgb/'+name).convert('RGBA'))
14 | mask = np.array(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/mask/01/seg'+name))  # np.array returns a writable copy (np.asarray of a PIL image can be read-only)
15 | mask[mask!=0] = 1
16 | mask = np.expand_dims(mask,-1)
17 | cube = np.where(mask!=0)
18 | up, down = cube[0].min(), cube[0].max()
19 | left, right = cube[1].min(), cube[1].max()
20 | fg = gt * mask
21 | return [left, right, up, down], fg
22 |
23 | def one_warp(fg, boundaries):
24 | left, right, up, down = boundaries
25 | src_pts = np.array([[left,up],[right,up], [right,down], [left,down]])
26 | h = down - up
27 | w = right - left
28 | left1 = left + np.random.uniform(-0.5*w, 0.5*w)
29 | left2 = left + np.random.uniform(-0.5*w, 0.5*w)
30 | right1 = right + np.random.uniform(-0.5*w, 0.5*w)
31 | right2 = right + np.random.uniform(-0.5*w, 0.5*w)
32 | up1 = up + np.random.uniform(-0.5*h, 0.5*h)
33 | up2 = up + np.random.uniform(-0.5*h, 0.5*h)
34 | down1 = down + np.random.uniform(-0.5*h, 0.5*h)
35 | down2 = down + np.random.uniform(-0.5*h, 0.5*h)
36 | leftstart = np.random.uniform(-0.4 * (448-w), 0.4 * (448-w))
37 | upstart = np.random.uniform(-0.4 * (256-h), 0.4 * (256-h))
38 | # print(leftstart,upstart)
39 | dst_pts = np.array([[left1+leftstart,up1+upstart],[right1+leftstart,up2+upstart],[right2+leftstart,down1+upstart], [left2+leftstart,down2+upstart]])
40 | M, _ = cv.findHomography(src_pts, dst_pts, cv.RANSAC,5.0)
41 | out = cv.warpPerspective(fg, M, (448, 256), flags=cv.INTER_LINEAR)
42 | return out
43 |
44 | def gen_fg():
45 | save_dir = 'classifier_dataset/cushion_birdeye_texturecolor_suzanne/test/fg'
46 | for n in range(1000):
47 | print(n)
48 | ind = np.random.randint(80, high=221)
49 | boundaries, fg = prep_data(ind)
50 | num = np.random.randint(0, high=6)
51 | grey = np.ones((256, 448,4))*255
52 | grey[:,:,:3]=np.random.randint(0, high=256)
53 | canvas = Image.fromarray(grey.astype('uint8'))
54 |
55 | for i in range(num):
56 | out = one_warp(fg, boundaries)
57 | alpha = np.random.rand()
58 | out[:,:,-1]=(out[:,:,-1]*alpha).astype('uint8')
59 | canvas = Image.alpha_composite(canvas, Image.fromarray(out))
60 | blur = np.random.rand()
61 | if blur >0.5:
62 | r = np.random.uniform(low=0, high=5.5)
63 | canvas = canvas.filter(ImageFilter.GaussianBlur(radius = r))
64 | canvas_np = np.asarray(canvas)
65 | for j in range(256):
66 | for k in range(448):
67 | if fg[j,k,-1]==0:
68 | fg[j,k]=canvas_np[j,k]
69 | fg_img = Image.fromarray(fg).convert('RGB')
70 | fg_img.save(os.path.join(save_dir, str(n)+'_from'+str(ind)+'_test.png'))
71 |
72 |
73 | def bg_warp(bg, boundaries):
74 | left, right, up, down = boundaries
75 | src_pts = np.array([[left,up],[right,up], [right,down], [left,down]])
76 | h = 256
77 | w = 448
78 | left1 = left + np.random.uniform(-0.15*w, 0.15*w)
79 | left2 = left + np.random.uniform(-0.15*w, 0.15*w)
80 | right1 = right + np.random.uniform(-0.15*w, 0.15*w)
81 | right2 = right + np.random.uniform(-0.15*w, 0.15*w)
82 | up1 = up + np.random.uniform(-0.15*h, 0.15*h)
83 | up2 = up + np.random.uniform(-0.15*h, 0.15*h)
84 | down1 = down + np.random.uniform(-0.15*h, 0.15*h)
85 | down2 = down + np.random.uniform(-0.15*h, 0.15*h)
86 |
87 | dst_pts = np.array([[left1,up1],[right1,up2],[right2,down1], [left2,down2]])
88 | M, _ = cv.findHomography(src_pts, dst_pts, cv.RANSAC, 5.0)
89 | out = cv.warpPerspective(bg, M, (448, 256), borderMode=cv.BORDER_WRAP, flags=cv.INTER_LINEAR) #[up:down, left:right]
90 | return out
91 |
92 |
93 | save_dir = 'classifier_dataset/cushion_birdeye_texturecolor_suzanne/train/bg/'
94 | gt = np.asarray(Image.open('../datasets/cushion_birdeye_texturecolor_suzanne_nocushionmask_3layer/bg_gt.png').convert('RGBA'))
95 | left = 115
96 | right = 302
97 | up = 33
98 | down = 200
99 | boundaries = [left, right, up, down]
100 | for n in range(5000):
101 | print(n)
102 | ind = np.random.randint(80, high=221)
103 | out = bg_warp(gt, boundaries)
104 | alpha = np.random.uniform(0.85, high=1)
105 | out[:,:,-1]=(out[:,:,-1]*alpha).astype('uint8')
106 | canvas = Image.fromarray(out).convert('RGB')
107 | blur = np.random.rand()
108 | if blur >0.5:
109 | r = np.random.uniform(low=0, high=5.5)
110 | canvas = canvas.filter(ImageFilter.GaussianBlur(radius = r))
111 | canvas.save(os.path.join(save_dir, str(n)+'_from'+str(ind)+'_bg_train.png'))
--------------------------------------------------------------------------------
/data/gen_foregroundPosEx.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0, '/home/zg45/FactorMatte')
3 | import torch
4 | import torchvision.models as models
5 | from torchvision import transforms as T
6 |
7 | import os
8 | from PIL import Image, ImageFilter
9 | import numpy as np
10 | import scipy as sp
11 | import scipy.signal
12 | from shutil import copyfile
13 | import cv2 as cv
14 | from third_party.data.image_folder import make_dataset
15 |
16 |
17 | def prep_data(basedir, index):
18 | rgb_paths = sorted(make_dataset(os.path.join(basedir, 'rgb')))
19 | # mask_paths = sorted(make_dataset(os.path.join(basedir, 'l2_fake_real_comp_mask')))
20 | mask_paths = sorted(make_dataset(os.path.join(basedir, 'mask_nocushionmask/02/')))
21 | gt = np.asarray(Image.open(rgb_paths[index]).convert('RGBA')).astype('float')
22 | mask = np.asarray(Image.open(mask_paths[index]).convert('L')).astype('float')/255
23 | mask[mask != 1.] = 0
24 | if abs(mask).sum() == 0:
25 | return None, None
26 | # if 'composited' in basedir:
27 | # Optionally erode to be conservative
28 | # mask = cv.erode(mask, kernel=np.ones((12, 12)), iterations=1)
29 | mask = np.expand_dims(mask, -1)
30 | cube = np.where(mask != 0)
31 | up, down = cube[0].min(), cube[0].max()
32 | left, right = cube[1].min(), cube[1].max()
33 | fg = np.clip(gt * mask, 0, 255)
34 | return [left, right, up, down], fg.astype('uint8')
35 |
36 | def add_reflection(img, surface=140, alpha_range=[0, 0.75]):
37 | alpha = np.random.uniform(alpha_range[0], high=alpha_range[1])
38 | print(alpha)
39 | h, w, _ = img.shape
40 | start = abs(h - 2*surface)
41 | img[surface:] = alpha * img[start:surface].copy()[::-1]
42 | return img
43 |
44 | # def add_blur(img, sigma_range=[10, 20]):
45 | # sigma = 0
46 | # while sigma % 2 == 0:
47 | # # GaussianBlur only accepts odd kernel size
48 | # sigma = np.random.randint(sigma_range[0], high=sigma_range[1])
49 | # img = cv.GaussianBlur(img, (sigma, sigma), sigma/4 , borderType = cv.BORDER_REPLICATE)
50 | # return img
51 |
52 | def add_blur_1(img, sigma_range=[0.2, 1], kernel=5):
53 | img = img.astype("int16")
54 | std = np.random.uniform(sigma_range[0], high=sigma_range[1])
55 | blur_img = cv.GaussianBlur(img, (kernel, kernel), std, borderType = cv.BORDER_REPLICATE)
56 | blur_img = ceil_floor_image(blur_img)
57 | return blur_img
58 |
59 | def ceil_floor_image(image):
60 | """
61 | Args:
62 | image : numpy array of image in datatype int16
63 | Return :
64 | image : numpy array of image in datatype uint8 with ceilling(maximum 255) and flooring(minimum 0)
65 | """
66 | image[image > 255] = 255
67 | image[image < 0] = 0
68 | image = image.astype("uint8")
69 | return image
70 |
71 | def add_noise(img, std_range=[0, 20], mean=0):
72 | std = np.random.randint(std_range[0], high=std_range[1])
73 | print('std', std)
74 | gaussian_noise = np.random.normal(mean, std, img.shape)
75 | img = img.astype("int16")
76 | noise_img = img + gaussian_noise
77 | noise_img = ceil_floor_image(noise_img)
78 | return noise_img
79 |
80 | def flip(img):
81 | p = np.random.uniform()
82 | if p<0.5:
83 | img = img[:, ::-1]
84 | else:
85 | img = img[::-1, :]
86 | return img
87 |
88 | def rotate(img):
89 | deg = np.random.randint(0, 360)
90 | img = sp.ndimage.rotate(img, deg, reshape=False)
91 | return img
92 |
93 | def gen_pos_ex_fg(basedir, ind_low, ind_high, add_rot, add_flip, add_blurr_or_noise, add_gaussian_noise, \
94 | num=5000, blur_kwargs=None, noise_kwargs=None, folder_suffix=''):
95 | """
96 | ind_high is inclusive.
97 | """
98 | save_dir = os.path.join(basedir, 'dis_real_l2', '_'.join([str(add_rot) + 'rot', str(add_flip) + 'flip', \
99 | str((1-add_gaussian_noise)*add_blurr_or_noise) + 'blursigma' + str(blur_kwargs['sigma_range'][0]) + str(blur_kwargs['sigma_range'][0])+'k'+str(blur_kwargs['kernel']),\
100 | str(add_gaussian_noise*add_blurr_or_noise)+'gaussian_noise_std'+str(noise_kwargs['std_range'][0])+str(noise_kwargs['std_range'][1])+'mean'+str(noise_kwargs['mean']), folder_suffix]))
101 | os.makedirs(save_dir)
102 | for n in range(0, num):
103 | print(n)
104 | fg = None
105 | while fg is None:
106 | ind = np.random.randint(0, high=ind_high-ind_low+1)
107 | boundaries, fg = prep_data(basedir, ind)
108 |
109 | decision = np.random.uniform(size=5)
110 | print(decision)
111 | # if decision[4] < 0.5:
112 | # scale = np.random.uniform(0.2, 1.2)
113 | # print('scale', scale)
114 | # scaled = scale * fg[:,:,:3].astype('int')
115 | # fg[:,:,:3] = np.clip(scaled, 0, 255).astype('uint8')
116 | if decision[0] < add_rot:
117 | print('rot')
118 | fg = rotate(fg)
119 | if decision[1] < add_flip:
120 | print('flip')
121 | fg = flip(fg)
122 |
123 | h, w, _ = fg.shape
124 | grey = np.ones((h, w, 4))*255
125 | grey[:,:,0]=0
126 | grey[:,:,1]=255
127 | grey[:,:,2]=0
128 | # grey[:,:,:3]=np.random.randint(0, high=80)
129 | canvas = Image.fromarray(grey.astype('uint8'))
130 | canvas_np = np.asarray(canvas)
131 | for j in range(h):
132 | for k in range(w):
133 | if fg[j, k, -1] == 0:
134 | fg[j, k] = canvas_np[j,k]
135 |
136 | if decision[2] < add_blurr_or_noise:
137 | if decision[3] < add_gaussian_noise:
138 | print('gaussian noise')
139 | fg = add_noise(fg, **noise_kwargs)
140 | # blur and noise are exclusive
141 | else:
142 | print('blurr, using add_blur_1')
143 | fg = add_blur_1(fg, **blur_kwargs)
144 | fg_img = Image.fromarray(fg).convert('RGB')
145 | fg_img.save(os.path.join(save_dir, '_'.join([str(n), 'from', str(ind+ind_low), 'fg', folder_suffix])+'.png'))
146 |
147 |
148 |
149 | if __name__ == '__main__':
150 | datadir = 'datasets/composited/cloth/cloth_grail_5152'
151 | video_start_ind = 0
152 | video_end_ind = 249
153 |
154 | # The probability of applying each augmentation during the generation of each positive example
155 | add_rot = 0 #0.5
156 | add_flip = 0 #0.5
157 | # reflec_kwargs= {
158 | # 'alpha_range': [0.1, 0.7],
159 | # 'surface': 140
160 | # }
161 |
162 | add_blurr_or_noise = 0.5
163 | blur_kwargs= {'sigma_range':[0.2, 1], 'kernel': 5}
164 | add_gaussian_noise = 0.5
165 | gaussian_noise_kwargs= {'std_range':[2, 7], 'mean':0}
166 |
167 | gen_pos_ex_fg(datadir, video_start_ind, video_end_ind, add_rot, add_flip, \
168 | add_blurr_or_noise, add_gaussian_noise, num=1200, blur_kwargs=blur_kwargs, \
169 | noise_kwargs=gaussian_noise_kwargs, folder_suffix='')
170 |
171 |
--------------------------------------------------------------------------------
/data/keypoint_homo_short.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "74655d97",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "import cv2 \n",
11 | "import numpy as np \n",
12 | "import torch\n",
13 | "import torchvision.models as models\n",
14 | "from torchvision import transforms as T\n",
15 | "\n",
16 | "import os\n",
17 | "from PIL import Image, ImageFilter \n",
18 | "import numpy as np\n",
19 | "import matplotlib.pyplot as plt\n",
20 | "from matplotlib.pyplot import imshow\n",
21 | "import scipy\n",
22 | "import scipy.ndimage\n",
23 | "from scipy import ndimage\n",
24 | "from shutil import copyfile"
25 | ]
26 | },
27 | {
28 | "cell_type": "code",
29 | "execution_count": null,
30 | "id": "fa93287f",
31 | "metadata": {},
32 | "outputs": [],
33 | "source": [
34 | "def feature_mask(mask, i):\n",
35 | " \"\"\"\n",
36 | " Because the edges of a segmentation mask may be inaccurate, we dilate it for \n",
37 | " the subsequent feature matching. \n",
38 | " The feature matcher will only look for correspondence points outside the mask.\n",
39 | " \"\"\"\n",
40 | " h, w = mask.shape\n",
41 | " obj = np.nonzero(mask)\n",
42 | " if len(obj[0]) == 0:\n",
43 | " # occlusion\n",
44 | " print('occlusion!')\n",
45 | " return 255 - mask\n",
46 | " mask_homo = np.ones_like(mask)*255\n",
47 | " up, down = obj[0].min(), obj[0].max() \n",
48 | " left, right = obj[1].min(), obj[1].max() \n",
49 | " kernel = np.ones((50, 50), np.uint8)\n",
50 | " mask = cv2.dilate(mask, kernel, iterations=1)\n",
51 | " mask_homo -= mask\n",
52 | "# You can adjust the edge erosion here to include more or fewer regions\n",
53 | "# mask_homo[max(0, up-50): min(h, down+60), max(0, left-20): min(w, right+20)]=0\n",
54 | " return mask_homo"
55 | ]
56 | },
57 | {
58 | "cell_type": "code",
59 | "execution_count": null,
60 | "id": "4550b7cc",
61 | "metadata": {
62 | "scrolled": true
63 | },
64 | "outputs": [],
65 | "source": [
66 | "# the dataset folder\n",
67 | "dataset = \"composited/cloth/cloth_grail_5152\"\n",
68 | "img_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))[0]\n",
69 | "print(img_f)\n",
70 | "img1 = cv2.imread(os.path.join(\"../datasets/\", dataset, \"rgb\", img_f))\n",
71 | "old_gray = cv2.cvtColor(img1, cv2.COLOR_BGR2GRAY)\n",
72 | "\n",
73 | "h, w, _ = img1.shape\n",
74 | "img1_acc_mask = 0\n",
75 | "for mask_ind in os.listdir(os.path.join(\"../datasets/\", dataset, \"mask\")):\n",
76 | " print(mask_ind)\n",
77 | " mask_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind)))[0]\n",
78 | " print(mask_f)\n",
79 | " mask_i = cv2.imread(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind, mask_f))\n",
80 | " img1_acc_mask += mask_i\n",
81 | "img1_mask = feature_mask(cv2.cvtColor(img1_acc_mask, cv2.COLOR_BGR2GRAY),0)\n",
82 | "imshow(img1_mask)\n",
83 | "plt.show()\n",
84 | "sift = cv2.SIFT_create()\n",
85 | "# FLANN parameters\n",
86 | "FLANN_INDEX_KDTREE = 0\n",
87 | "index_params = dict(algorithm = FLANN_INDEX_KDTREE, trees = 5)\n",
88 | "search_params = dict(checks=50) # or pass empty dictionary\n",
89 | "flann = cv2.FlannBasedMatcher(index_params,search_params)\n",
90 | "\n",
91 | "start_matrix = np.identity(3)\n",
92 | "with open(os.path.join(\"../datasets/\", dataset, 'homographies_raw.txt'), 'w') as f:\n",
93 | " for i in range(len(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))):\n",
94 | " img_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"rgb\")))[i]\n",
95 | " frame = cv2.imread(os.path.join(\"../datasets/\", dataset, \"rgb\", img_f))\n",
96 | " frame_gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)\n",
97 | "\n",
98 | "\n",
99 | " frame_acc_mask = 0\n",
100 | " # If there are multiple objects, combine their masks so that \n",
101 | " # feature matching ignores correspondence points inside any of them.\n",
102 | " # Subfolders inside \"mask\" should be named \"01\", \"02\", ...\n",
103 | " for mask_ind in os.listdir(os.path.join(\"../datasets/\", dataset, \"mask\")):\n",
104 | " mask_f = sorted(os.listdir(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind)))[i]\n",
105 | " mask_i = cv2.imread(os.path.join(\"../datasets/\", dataset, \"mask/\", mask_ind, mask_f))\n",
106 | " frame_acc_mask += mask_i\n",
107 | " frame_mask = feature_mask(cv2.cvtColor(frame_acc_mask, cv2.COLOR_BGR2GRAY), i)\n",
108 | " imshow(frame_mask) \n",
109 | " plt.show()\n",
110 | " # find correspondence points between 2 frames by SIFT features.\n",
111 | " kp1, des1 = sift.detectAndCompute(old_gray, img1_mask)\n",
112 | " kp2, des2 = sift.detectAndCompute(frame_gray, frame_mask)\n",
113 | " matches = flann.knnMatch(des1,des2,k=2)\n",
114 | " tmp1 = cv2.drawKeypoints(old_gray, kp1, old_gray)\n",
115 | " tmp2 = cv2.drawKeypoints(frame_gray, kp2, frame_gray)\n",
116 | " plt.imshow(tmp1)\n",
117 | " plt.show()\n",
118 | " plt.imshow(tmp2)\n",
119 | " plt.show()\n",
120 | " good_points=[] \n",
121 | " for m, n in matches: \n",
122 | " good_points.append((m, m.distance/n.distance)) \n",
123 | " # sort the correspondence points by confidence, by default we only use the best 50.\n",
124 | " good_points.sort(key=lambda y: y[1])\n",
125 | " query_pts = np.float32([kp1[m.queryIdx] \n",
126 | " .pt for m,d in good_points[:50]]).reshape(-1, 1, 2) \n",
127 | "\n",
128 | " train_pts = np.float32([kp2[m.trainIdx] \n",
129 | " .pt for m,d in good_points[:50]]).reshape(-1, 1, 2) \n",
130 | " print('len(query_pts)',len(query_pts))\n",
131 | " # compute homography by the correspondence pairs\n",
132 | " matrix, matrix_mask = cv2.findHomography(query_pts, train_pts, cv2.RANSAC, 5.0) \n",
133 | " inliers = matrix_mask.sum()\n",
134 | " print(i, inliers, matrix)\n",
135 | " start_matrix = matrix @ start_matrix\n",
136 | " f.write(' '.join([str(i) for i in start_matrix.flatten()])+'\\n')\n",
137 | " imshow(frame_mask) \n",
138 | " plt.show()\n",
139 | " dst = cv2.warpPerspective(img1, start_matrix, (w, h), flags=cv2.INTER_LINEAR)\n",
140 | " imshow(dst) \n",
141 | " plt.show()\n",
142 | " dst = cv2.warpPerspective(old_gray, matrix, (w, h), flags=cv2.INTER_LINEAR)\n",
143 | " imshow(dst) \n",
144 | " plt.show()\n",
145 | " old_gray = frame_gray.copy()\n",
146 | " img1_mask = frame_mask.copy()\n",
147 | " imshow(frame_gray) \n",
148 | " plt.show()"
149 | ]
150 | }
151 | ],
152 | "metadata": {
153 | "kernelspec": {
154 | "display_name": "Python 3 (ipykernel)",
155 | "language": "python",
156 | "name": "python3"
157 | },
158 | "language_info": {
159 | "codemirror_mode": {
160 | "name": "ipython",
161 | "version": 3
162 | },
163 | "file_extension": ".py",
164 | "mimetype": "text/x-python",
165 | "name": "python",
166 | "nbconvert_exporter": "python",
167 | "pygments_lexer": "ipython3",
168 | "version": "3.9.12"
169 | }
170 | },
171 | "nbformat": 4,
172 | "nbformat_minor": 5
173 | }
174 |
--------------------------------------------------------------------------------
/data/misc_data_process.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from PIL import Image
4 | import numpy as np
5 | import shutil
6 |
7 |
8 | def gen_black_l1mask(p):
9 | """
10 | Assuming there's only 1 foreground object, copy its mask/01 folder
11 | to mask_nocushionmask/02.
12 | Then generate black images of the same size and same name as those in
13 | mask_nocushionmask/02 and put into mask_nocushionmask/01, which is used
14 | as the initialization of masks for the residual layer.
15 |
16 | The composition order is homography background, residual, then foreground.
17 | So the residual layer has index 1 and foreground layer's index changes to 2.
18 |
19 | TODO: the name "nocushionmask" is outdated and has no particular meaning now.
20 |
21 | Args:
22 | p (str): path to the dataset root folder.
23 | """
24 | os.makedirs(os.path.join(p, 'mask_nocushionmask/01'), exist_ok=False)
25 | shutil.copytree(os.path.join(p, 'mask/01'), os.path.join(p, 'mask_nocushionmask/02'))
26 | for f in os.listdir(os.path.join(p, 'mask_nocushionmask/02')):
27 | if 'png' in f:
28 | print(f)
29 | img = Image.open(os.path.join(p, 'mask_nocushionmask/02', f))
30 | zeros = np.zeros_like(np.array(img)).astype('uint8')
31 | zeros_img = Image.fromarray(zeros)
32 | zeros_img.save(os.path.join(p, 'mask_nocushionmask/01', f))
33 |
34 | def real_video_rgba_a(source_dir, dest_dir):
35 | """
36 | Given RGBA images in source_dir, extract the Alpha channel and store in dest_dir.
37 | Used after Stage 1 if you want to manually clean up some predicted alphas.
38 | """
39 | os.makedirs(dest_dir, exist_ok=False)
40 | for f in os.listdir(source_dir):
41 | if '.png' in f:
42 | print(f)
43 | img_a = np.asarray(Image.open(os.path.join(source_dir, f)))[:,:,-1]
44 | Image.fromarray(img_a).save(os.path.join(dest_dir, f))
45 |
46 |
47 | if __name__ == '__main__':
48 | parser = argparse.ArgumentParser()
49 | # video completion
50 | parser.add_argument('--dataroot', type=str, help='dataroot')
51 | args = parser.parse_args()
52 |
53 | gen_black_l1mask(args.dataroot)
54 | # real_video_rgba_a('results/lucia_3layer_v4_rgbwarp1e-1_alphawarp1e-1_flowrecon1e-2/test_600_/images/rgba_l1/', 'datasets/lucia/dis_gt_alpha_stage2_res')
--------------------------------------------------------------------------------
/data/noninteraction_ind.txt:
--------------------------------------------------------------------------------
1 | 0, 40
--------------------------------------------------------------------------------
/datasets/confidence.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Erika Lu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | """Generate confidence maps from optical flow."""
17 | import os
18 | import sys
19 | sys.path.append('.')
20 | from utils import readFlow, numpy2im
21 | import glob
22 | from PIL import Image
23 | import numpy as np
24 | import torch
25 | import torch.nn.functional as F
26 |
27 |
28 | def compute_confidence(flo_f, flo_b, rgb, thresh=1, thresh_p=20):
29 | """Compute confidence map from optical flow."""
30 | im_height, im_width = flo_f.shape[:2]
31 | identity_grid = np.expand_dims(create_grid(im_height, im_width), 0)
32 | warp_b = flo_b[np.newaxis] + identity_grid
33 | warp_f = flo_f[np.newaxis] + identity_grid
34 | warp_b = map_coords(warp_b, im_height, im_width)
35 | warp_f = map_coords(warp_f, im_height, im_width)
36 | identity_grid = identity_grid.transpose(0, 3, 1, 2)
37 | warped_1 = F.grid_sample(torch.from_numpy(identity_grid), torch.from_numpy(warp_b), align_corners=True)
38 | warped_2 = F.grid_sample(warped_1, torch.from_numpy(warp_f), align_corners=True).numpy()
39 | err = np.linalg.norm(warped_2 - identity_grid, axis=1)
40 | err[err > thresh] = thresh
41 | err /= thresh
42 | confidence = 1 - err
43 |
44 | rgb = np.expand_dims(rgb.transpose(2, 0, 1), 0)
45 | rgb_warped_1 = F.grid_sample(torch.from_numpy(rgb).double(), torch.from_numpy(warp_b), align_corners=True)
46 | rgb_warped_2 = F.grid_sample(rgb_warped_1, torch.from_numpy(warp_f), align_corners=True).numpy()
47 | err = np.linalg.norm(rgb_warped_2 - rgb, axis=1)
48 | confidence_p = (err < thresh_p).astype(np.float32)
49 | confidence *= confidence_p
50 |
51 | return confidence[0]
52 |
53 |
54 | def map_coords(coords, height, width):
55 | """Map coordinates from pixel-space to [-1, 1] range for torch's grid_sample function."""
56 | coords_mapped = np.stack([coords[..., 0] / (width - 1), coords[..., 1] / (height - 1)], -1)
57 | return coords_mapped * 2 - 1
58 |
59 |
60 | def create_grid(height, width):
61 | ramp_u, ramp_v = np.meshgrid(np.linspace(0, width - 1, width), np.linspace(0, height - 1, height))
62 | return np.stack([ramp_u, ramp_v], -1)
63 |
64 |
65 | if __name__ == "__main__":
66 | import argparse
67 | arguments = argparse.ArgumentParser()
68 | arguments.add_argument('--dataroot', type=str)
69 | arguments.add_argument('--step', default=1, type=int)
70 | opt = arguments.parse_args()
71 |
72 | forward_flo = sorted(glob.glob(os.path.join(opt.dataroot, 'forward_flow_step'+str(opt.step), '*.flo')))
73 | backward_flo = sorted(glob.glob(os.path.join(opt.dataroot, 'backward_flow_step'+str(opt.step), '*.flo')))
74 | assert(len(forward_flo) == len(backward_flo))
75 | rgb_paths = sorted(glob.glob(os.path.join(opt.dataroot, 'rgb', '*')))
76 | print(f'generating {len(forward_flo)} confidence maps...from', '_flow_step'+str(opt.step))
77 | outdir = os.path.join(opt.dataroot, 'confidence_step'+str(opt.step))
78 | os.makedirs(outdir, exist_ok=True)
79 | for i in range(len(forward_flo)):
80 | flo_f = readFlow(forward_flo[i])
81 | flo_b = readFlow(backward_flo[i])
82 | rgb = np.array(Image.open(rgb_paths[i]))
83 | confidence = compute_confidence(flo_f, flo_b, rgb)
84 | fp = os.path.join(outdir, f'{i+1:04d}.png')
85 | im = numpy2im(confidence)
86 | im.save(fp)
87 | print(f'saved {len(forward_flo)} confidence maps to {outdir}')
88 |
--------------------------------------------------------------------------------
/datasets/homography.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Erika Lu
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 |
16 | """Helper tools for computing the world bounds from homographies."""
17 | import os
18 | import sys
19 | sys.path.append('.')
20 | from utils import readFlow, numpy2im
21 | # import glob
22 | from PIL import Image
23 | import numpy as np
24 | import torch
25 | import torch.nn.functional as F
26 |
27 |
28 | def transform2h(x, y, m):
29 | """Applies 2d homogeneous transformation."""
30 | A = np.dot(m, np.array([x, y, np.ones(len(x))]))
31 | xt = A[0, :] / A[2, :]
32 | yt = A[1, :] / A[2, :]
33 | return xt, yt
34 |
35 |
36 | def compute_world_bounds(homographies, height, width):
37 | """Compute minimum and maximum coordinates.
38 |
39 | homographies - list of 3x3 numpy arrays
40 | height, width - video dimensions
41 | """
42 | xbounds = [0, width - 1]
43 | ybounds = [0, height - 1]
44 |
45 | for h in homographies:
46 | # find transformed image bounding box
47 | x = np.array([0, width - 1, 0, width - 1])
48 | y = np.array([0, 0, height - 1, height - 1])
49 | [xt, yt] = transform2h(x, y, np.linalg.inv(h))
50 | xbounds[0] = min(xbounds[0], min(xt))
51 | xbounds[1] = max(xbounds[1], max(xt))
52 | ybounds[0] = min(ybounds[0], min(yt))
53 | ybounds[1] = max(ybounds[1], max(yt))
54 |
55 | return xbounds, ybounds
56 |
57 |
58 | if __name__ == "__main__":
59 | import argparse
60 | arguments = argparse.ArgumentParser()
61 | arguments.add_argument('--homography_path', type=str, help='path to text file containing homographies')
62 | arguments.add_argument('--width', type=int, help='video width')
63 | arguments.add_argument('--height', type=int, help='video height')
64 | opt = arguments.parse_args()
65 |
66 | with open(opt.homography_path) as f:
67 | lines = f.readlines()
68 | homographies = [l.rstrip().split(' ') for l in lines]
69 | homographies = [[float(h) for h in l] for l in homographies]
70 | homographies = [np.array(H).reshape(3, 3) for H in homographies]
71 | xbounds, ybounds = compute_world_bounds(homographies, opt.height, opt.width)
72 | out_path = f'{opt.homography_path[:-8]}.txt'
73 | with open(out_path, 'w') as f:
74 | f.write(f'size: {opt.width} {opt.height}\n')
75 | f.write(f'bounds: {xbounds[0]} {xbounds[1]} {ybounds[0]} {ybounds[1]}\n')
76 | f.writelines(lines)
77 | print(f'saved {out_path}')
78 |
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__init__.py
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/base_options.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/base_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/base_options.cpython-39.pyc
--------------------------------------------------------------------------------
/options/__pycache__/test_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/test_options.cpython-39.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/train_options.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/train_options.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/options/__pycache__/train_options.cpython-39.pyc
--------------------------------------------------------------------------------
/options/base_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | from third_party.util import util
4 | from third_party import models
5 | from third_party import data
6 | import torch
7 | import json
8 |
9 |
10 | class BaseOptions:
11 | """This class defines options used during both training and test time.
12 |
13 | It also implements several helper functions such as parsing, printing, and saving the options.
14 | It also gathers additional options defined in functions in both dataset class and model class.
15 | """
16 |
17 | def __init__(self):
18 | """Reset the class; indicates the class hasn't been initialized"""
19 | self.initialized = False
20 |
21 | def initialize(self, parser):
22 | """Define the common options that are used in both training and test."""
23 | # basic parameters
24 | parser.add_argument(
25 | "--dataroot",
26 | required=True,
27 | help="path to images (should have subfolders rgb_256, etc)",
28 | )
29 | parser.add_argument(
30 | "--name",
31 | type=str,
32 | default="experiment_name",
33 | help="name of the experiment. It decides where to store samples and models",
34 | )
35 | parser.add_argument(
36 | "--gpu_ids",
37 | type=str,
38 | default="0",
39 | help="gpu ids: e.g. 0 0,1,2, 0,2. use -1 for CPU",
40 | )
41 | parser.add_argument(
42 | "--checkpoints_dir",
43 | type=str,
44 | default="./1110_checkpoints",
45 | help="models are saved here",
46 | )
47 | parser.add_argument("--seed", type=int, default=35, help="initial random seed")
48 | # model parameters
49 | parser.add_argument(
50 | "--model",
51 | type=str,
52 | default="factormatte_GANFlip",
53 | help="chooses which model to use. [lnr | kp2uv]",
54 | )
55 | parser.add_argument(
56 | "--num_filters",
57 | type=int,
58 | default=64,
59 | help="# filters in the first and last conv layers",
60 | )
61 | # dataset parameters
62 | parser.add_argument(
63 | "--coarseness",
64 | type=int,
65 | default=10,
66 | help="Coarness of background offset interpolation",
67 | )
68 | parser.add_argument(
69 | "--max_frames",
70 | type=int,
71 | default=200,
72 | help="Similar meaning as max_dataset_size but cannot be infinite for background interpolation.",
73 | )
74 | parser.add_argument(
75 | "--dataset_mode",
76 | type=str,
77 | default="factormatte_GANCGANFlip148",
78 | help="chooses how datasets are loaded.",
79 | )
80 | parser.add_argument(
81 | "--serial_batches",
82 | action="store_true",
83 | help="if true, takes images in order to make batches, otherwise takes them randomly",
84 | )
85 | parser.add_argument(
86 | "--num_threads", default=4, type=int, help="# threads for loading data"
87 | )
88 | parser.add_argument(
89 | "--batch_size", type=int, default=8, help="input batch size"
90 | )
91 | parser.add_argument(
92 | "--max_dataset_size",
93 | type=int,
94 | default=float("inf"),
95 | help="Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.",
96 | )
97 | parser.add_argument(
98 | "--display_winsize",
99 | type=int,
100 | default=256,
101 | help="display window size for both visdom and HTML",
102 | )
103 | # additional parameters
104 | parser.add_argument(
105 | "--epoch",
106 | type=str,
107 | default="latest",
108 | help="which epoch to load? set to latest to use latest cached model",
109 | )
110 | parser.add_argument(
111 | "--verbose",
112 | action="store_true",
113 | help="if specified, print more debugging information",
114 | )
115 | parser.add_argument(
116 | "--suffix",
117 | default="",
118 | type=str,
119 | help="customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}",
120 | )
121 | parser.add_argument(
122 | "--prob_masks",
123 | action="store_true",
124 | help="if true, use 1 over #layer probability mask initialization, otherwise binary",
125 | )
126 | parser.add_argument(
127 | "--rgba",
128 | default="L",
129 | type=str,
130 | help="If true the input FG is RGBA, RGB, or L only.",
131 | )
132 | parser.add_argument(
133 | "--rgba_GAN",
134 | default="RGB",
135 | type=str,
136 | help="If true the input to the GAN discriminator is RGBA, RGB, or L only. Only used when there exists GAN, not CGAN.",
137 | )
138 | parser.add_argument(
139 | "--residual_noise",
140 | action="store_true",
141 | help="if true, use random noise for Z initialization.",
142 | )
143 | parser.add_argument(
144 | "--bg_noise",
145 | action="store_true",
146 | help="if true, use random noise for background Z initialization.",
147 | )
148 | parser.add_argument(
149 | "--no_bg",
150 | action="store_true",
151 | help="If true exclude the bg layer as defined in the original Omnimatte.",
152 | )
153 | parser.add_argument(
154 | "--orderscale",
155 | action="store_true",
156 | help="if true, keep the original Omnimatte's version of mask scaling.",
157 | )
158 | parser.add_argument(
159 | "--steps",
160 | type=str,
161 | default="1",
162 | help="X steps apart to consider. Specify without space.",
163 | )
164 | parser.add_argument(
165 | "--noninter_only",
166 | action="store_true",
167 | help="if true, only use nonteractive frames of the video.",
168 | )
169 | parser.add_argument(
170 | "--gradient_debug",
171 | action="store_true",
172 | help="whether to do the real gradient descent or just to record the gradients.",
173 | )
174 | parser.add_argument(
175 | "--num_Ds",
176 | default="0,3,3",
177 | type=str,
178 | help="Number of multiscale discriminators.",
179 | )
180 | parser.add_argument(
181 | "--strides",
182 | default="0,2,2",
183 | type=str,
184 | help="Number of stride in the convs of multiscale discriminators.",
185 | )
186 | parser.add_argument(
187 | "--n_layers",
188 | default="0,1,3",
189 | type=str,
190 | help="Number of stride in the convs of multiscale discriminators.",
191 | )
192 | parser.add_argument(
193 | "--fg_layer_ind",
194 | type=int,
195 | default=2,
196 | help="Which layer is the foreground, starting from 0.",
197 | )
198 | parser.add_argument(
199 | "--stage",
200 | type=int,
201 | help="Tells the dataset which dis_gt_alpha to use; index starting from 1. Stage 1: get bg, shouldn't have any dis_gt_alpha; stage 2: alpha from stage 1, for regularizing color, should run only on NFs; stage 3: alpha from stage 2 to constrain the alpha in IFs.",
202 | )
203 | parser.add_argument(
204 | "--get_bg",
205 | action="store_true",
206 | help="if specified, generate the bg panorama and quit",
207 | )
208 | self.initialized = True
209 | return parser
210 |
211 | def gather_options(self):
212 | """Initialize our parser with basic options(only once).
213 | Add additional model-specific and dataset-specific options.
214 | These options are defined in the function
215 | in model and dataset classes.
216 | """
217 | if not self.initialized: # check if it has been initialized
218 | parser = argparse.ArgumentParser(
219 | formatter_class=argparse.ArgumentDefaultsHelpFormatter
220 | )
221 | parser = self.initialize(parser)
222 |
223 | # get the basic options
224 | opt, _ = parser.parse_known_args()
225 |
226 | # modify model-related parser options
227 | model_name = opt.model
228 | model_option_setter = models.get_option_setter(model_name)
229 | parser = model_option_setter(parser, self.isTrain)
230 | opt, _ = parser.parse_known_args() # parse again with new defaults
231 |
232 | # modify dataset-related parser options
233 | dataset_name = opt.dataset_mode
234 | dataset_option_setter = data.get_option_setter(dataset_name)
235 | parser = dataset_option_setter(parser, self.isTrain)
236 |
237 | # save and return the parser
238 | self.parser = parser
239 | return parser.parse_args()
240 |
241 | def print_options(self, opt):
242 | """Print and save options
243 |
244 | It will print both current options and default values (if different).
245 | It will save options into a text file / [checkpoints_dir] / opt.txt
246 | """
247 | message = ""
248 | message += "----------------- Options ---------------\n"
249 | for k, v in sorted(vars(opt).items()):
250 | comment = ""
251 | default = self.parser.get_default(k)
252 | if v != default:
253 | comment = "\t[default: %s]" % str(default)
254 | message += "{:>25}: {:<30}{}\n".format(str(k), str(v), comment)
255 | message += "----------------- End -------------------"
256 | print(message)
257 |
258 | # save to the disk
259 | expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
260 | util.mkdirs(expr_dir)
261 | file_name = os.path.join(expr_dir, "{}_opt.txt".format(opt.phase))
262 | with open(file_name, "wt") as opt_file:
263 | opt_file.write(message)
264 | opt_file.write("\n")
265 |
266 | def parse(self):
267 | """Parse our options, create checkpoints directory suffix, and set up gpu device."""
268 | opt = self.gather_options()
269 | opt.isTrain = self.isTrain # train or test
270 |
271 | # process opt.suffix
272 | if opt.suffix:
273 | suffix = ("_" + opt.suffix.format(**vars(opt))) if opt.suffix != "" else ""
274 | opt.name = opt.name + suffix
275 |
276 | self.print_options(opt)
277 |
278 | # set gpu ids
279 | str_ids = opt.gpu_ids.split(",")
280 | opt.gpu_ids = []
281 | for str_id in str_ids:
282 | id = int(str_id)
283 | if id >= 0:
284 | opt.gpu_ids.append(id)
285 | if len(opt.gpu_ids) > 0:
286 | torch.cuda.set_device(opt.gpu_ids[0])
287 |
288 | self.opt = opt
289 | return self.opt
290 |
--------------------------------------------------------------------------------
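For reference, a standalone sketch of the gpu id handling performed by parse() above; the value passed to --gpu_ids is illustrative:

# Minimal sketch of the gpu id parsing in BaseOptions.parse(), shown standalone.
gpu_ids_arg = "0,1"                                   # what --gpu_ids receives
gpu_ids = [int(s) for s in gpu_ids_arg.split(",") if int(s) >= 0]
# gpu_ids == [0, 1]; parse() then calls torch.cuda.set_device(gpu_ids[0]).
# Passing "-1" yields an empty list, which means CPU mode.
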
/options/test_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TestOptions(BaseOptions):
5 | """This class includes test options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser) # define shared options
12 | parser.add_argument(
13 | "--results_dir", type=str, default="./results/", help="saves results here."
14 | )
15 | parser.add_argument(
16 | "--aspect_ratio",
17 | type=float,
18 | default=1.0,
19 | help="aspect ratio of result images",
20 | )
21 | parser.add_argument(
22 | "--phase", type=str, default="test", help="train, val, test, etc"
23 | )
24 | parser.add_argument(
25 | "--num_test",
26 | type=int,
27 | default=float("inf"),
28 | help="how many test images to run",
29 | )
30 | parser.add_argument(
31 | "--test_suffix", type=str, default="", help="suffix to folder name"
32 | )
33 | self.isTrain = False
34 | return parser
35 |
--------------------------------------------------------------------------------
/options/train_options.py:
--------------------------------------------------------------------------------
1 | from .base_options import BaseOptions
2 |
3 |
4 | class TrainOptions(BaseOptions):
5 | """This class includes training options.
6 |
7 | It also includes shared options defined in BaseOptions.
8 | """
9 |
10 | def initialize(self, parser):
11 | parser = BaseOptions.initialize(self, parser)
12 | # visdom and HTML visualization parameters
13 | parser.add_argument(
14 | "--display_ind",
15 | type=int,
16 | default=25,
17 | help="The index frame to visualize during training.",
18 | )
19 | parser.add_argument(
20 | "--display_freq",
21 | type=int,
22 | default=10,
23 | help="frequency of showing training results on screen (in epochs)",
24 | )
25 | parser.add_argument(
26 | "--display_ncols",
27 | type=int,
28 | default=0,
29 | help="if positive, display all images in a single visdom web panel with certain number of images per row.",
30 | )
31 | parser.add_argument(
32 | "--display_id", type=int, default=1, help="window id of the web display"
33 | )
34 | parser.add_argument(
35 | "--display_server",
36 | type=str,
37 | default="http://localhost",
38 | help="visdom server of the web display",
39 | )
40 | parser.add_argument(
41 | "--display_env",
42 | type=str,
43 | default="main",
44 | help='visdom display environment name (default is "main")',
45 | )
46 | parser.add_argument(
47 | "--display_port",
48 | type=int,
49 | default=8097,
50 | help="visdom port of the web display",
51 | )
52 | parser.add_argument(
53 | "--update_html_freq",
54 | type=int,
55 | default=10,
56 | help="frequency of saving training results to html",
57 | )
58 | parser.add_argument(
59 | "--print_freq",
60 | type=int,
61 | default=10,
62 | help="frequency of showing training results on console (in steps per epoch)",
63 | )
64 | parser.add_argument(
65 | "--no_html",
66 | action="store_true",
67 | help="do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/",
68 | )
69 | # network saving and loading parameters
70 | parser.add_argument(
71 | "--save_latest_freq",
72 | type=int,
73 | default=20,
74 | help="frequency of saving the latest results (in epochs)",
75 | )
76 | parser.add_argument(
77 | "--save_by_epoch",
78 | type=bool,
79 | default=True,
80 | help='whether to save the model by "epoch" or as "latest" (overwriting the previous checkpoint)',
81 | )
82 | parser.add_argument(
83 | "--continue_train",
84 | action="store_true",
85 | help="continue training: load the latest model",
86 | )
87 | parser.add_argument(
88 | "--overwrite_lambdas",
89 | action="store_true",
90 | help="continue training and overwrite lambdas and epochs hyperparams by history",
91 | )
92 | parser.add_argument(
93 | "--overwrite_lrs",
94 | action="store_true",
95 | help="continue training and overwrite lr by history",
96 | )
97 | parser.add_argument(
98 | "--epoch_count",
99 | type=int,
100 | default=1,
101 | help="the starting epoch count, we save the model by , +, ...",
102 | )
103 | parser.add_argument(
104 | "--phase", type=str, default="train", help="train, val, test, etc"
105 | )
106 | # training parameters
107 | parser.add_argument(
108 | "--n_epochs",
109 | type=int,
110 | default=None,
111 | help="number of training epochs with the initial learning rate.\
112 | You only need to specify one of this or n_steps",
113 | )
114 | parser.add_argument(
115 | "--n_steps",
116 | type=int,
117 | default=24000,
118 | help="number of training steps with the initial learning rate",
119 | )
120 | parser.add_argument(
121 | "--n_steps_decay",
122 | type=int,
123 | default=0,
124 | help="number of steps to linearly decay learning rate to zero",
125 | )
126 | parser.add_argument(
127 | "--lr", type=float, default=0.001, help="initial learning rate for adam"
128 | )
129 | parser.add_argument(
130 | "--lr_policy",
131 | type=str,
132 | default="linear",
133 | help="learning rate policy. [linear | step | plateau | cosine]",
134 | )
135 | parser.add_argument(
136 | "--pretrained",
137 | action="store_true",
138 | help="Whether use part of a pretrained resnet18 for the discriminator.",
139 | )
140 | parser.add_argument(
141 | "--discriminator_transform",
142 | type=str,
143 | default="randomcrop",
144 | help="What transform to apply to the generated rgb before feeding into the discriminator.",
145 | )
146 | parser.add_argument(
147 | "--jitter",
148 | action="store_true",
149 | help="Whether use the original jitter for training.",
150 | )
151 |
152 | self.isTrain = True
153 | return parser
154 |
--------------------------------------------------------------------------------
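For orientation, a hedged sketch of how these option classes are typically consumed by a training entry point; the repository's actual train_GAN.py may differ in its specifics, and the epoch count below is a placeholder:

# Illustrative driver only; not the repository's actual train_GAN.py.
from options.train_options import TrainOptions
from third_party.data import create_dataset
from third_party.models import create_model

if __name__ == "__main__":
    opt = TrainOptions().parse()      # CLI flags defined in the files above
    dataset = create_dataset(opt)     # chosen by --dataset_mode
    model = create_model(opt)         # chosen by --model
    model.setup(opt)                  # load checkpoints / create schedulers
    for epoch in range(1, 11):        # placeholder epoch loop
        for data in dataset:
            model.set_input(data)
            model.optimize_parameters()
        model.update_learning_rate()
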
/prepare_data_stage1.sh:
--------------------------------------------------------------------------------
1 | # Prepare homographies_raw.txt and the mask/01 and rgb folders, then run this script.
2 | python video_completion.py --path $1/rgb --step 1
3 | python video_completion.py --path $1/rgb --step 4
4 | python video_completion.py --path $1/rgb --step 8
5 |
6 | mv RAFT_result/$(echo $1 | sed 's/\///g')rgb/*flow* $1
7 |
8 | python datasets/confidence.py --dataroot $1 --step 1
9 | python datasets/confidence.py --dataroot $1 --step 4
10 | python datasets/confidence.py --dataroot $1 --step 8
11 |
12 | python datasets/homography.py --homography_path $1/homographies_raw.txt --width $2 --height $3
13 | python data/misc_data_process.py --dataroot $1
14 |
--------------------------------------------------------------------------------
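An example invocation of the script above, assuming the layout it expects ($1 is the data root containing rgb/ and homographies_raw.txt, $2 and $3 are the frame width and height); the path and dimensions below are placeholders:

bash prepare_data_stage1.sh datasets/myvideo 448 256
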
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch=1.11.0
2 | cudatoolkit=11.2
3 | torchvision
4 | scipy
5 | tensorboard
6 | Pillow
--------------------------------------------------------------------------------
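The pins above mix pip and conda syntax (pip expects '==', and cudatoolkit is distributed via conda rather than PyPI). A hedged environment setup sketch, assuming a conda workflow; the environment name and channel are placeholders:

conda create -n factormatte python=3.9 cudatoolkit=11.2 -c conda-forge
conda activate factormatte
pip install torch==1.11.0 torchvision scipy tensorboard Pillow
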
/test.py:
--------------------------------------------------------------------------------
1 | """Script to save the full outputs of an Omnimatte model.
2 |
3 | Once you have trained the Omnimatte model with train.py, you can use this script to save the model's final omnimattes.
4 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'.
5 |
6 | It first creates a model and dataset given the options. It will hard-code some parameters.
7 | It then runs inference for '--num_test' images and save results to an HTML file.
8 |
9 | Example (after training a model):
10 | python test.py --dataroot ./datasets/tennis --name tennis
11 |
12 | Use '--results_dir ' to specify the results directory.
13 |
14 | See options/base_options.py and options/test_options.py for more test options.
15 | """
16 | import os
17 | from options.test_options import TestOptions
18 | from third_party.data import create_dataset
19 | from third_party.models import create_model
20 | from third_party.util.visualizer import save_images, save_videos
21 | from third_party.util import html
22 | import torch
23 |
24 |
25 | if __name__ == "__main__":
26 | testopt = TestOptions()
27 | opt = testopt.parse()
28 | # hard-code some parameters for test
29 | opt.num_threads = 0 # test code only supports num_threads = 0
30 | opt.batch_size = 1 # test code only supports batch_size = 1
31 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed.
32 | opt.display_id = (
33 | -1
34 | ) # no visdom display; the test code saves the results to a HTML file.
35 | dataset = create_dataset(
36 | opt
37 | ) # create a dataset given opt.dataset_mode and other options
38 | model = create_model(opt) # create a model given opt.model and other options
39 | model.setup(opt) # regular setup: load and print networks; create schedulers
40 | if opt.gradient_debug:
41 | weight = torch.load(
42 | os.path.join(opt.checkpoints_dir, opt.name, str(opt.epoch) + "_others.pth")
43 | )
44 | for i in range(len(model.discriminators)):
45 | if model.discriminators[i] is not None:
46 | model.discriminators[i].load_state_dict(
47 | weight["discriminator_l" + str(i)], strict=False
48 | )
49 | print(i, "th discriminator weights loaded unstrictly")
50 | print(
51 | "the dict in the history is",
52 | weight["discriminator_l" + str(i)].keys(),
53 | )
54 | print(
55 | "the dict in current model is",
56 | model.discriminators[i].state_dict().keys(),
57 | )
58 |
59 | # create a website
60 | web_dir = os.path.join(
61 | opt.results_dir,
62 | opt.name,
63 | "{}_{}_{}".format(opt.phase, opt.epoch, opt.test_suffix),
64 | ) # define the website directory
65 | print("creating web directory", web_dir)
66 | webpage = html.HTML(
67 | web_dir,
68 | "Experiment = %s, Phase = %s, Epoch = %s" % (opt.name, opt.phase, opt.epoch),
69 | )
70 | video_visuals = None
71 | loss_recon = 0
72 | model.do_cam_adj = True #False
73 | for i, data in enumerate(dataset):
74 | # print(i)
75 | # if i < 130:
76 | # continue
77 | if i >= opt.num_test: # only apply our model to opt.num_test images.
78 | break
79 | model.set_input(data) # unpack data from data loader
80 | model.test(i) # run inference
81 | img_path = model.get_image_paths() # get image paths
82 | if i % 5 == 0: # save images to an HTML file
83 | print("processing (%04d)-th image... %s" % (i, img_path))
84 | with torch.no_grad():
85 | visuals = model.get_results() # rgba, reconstruction, original, mask
86 | if video_visuals is None:
87 | video_visuals = visuals
88 | else:
89 | for k in video_visuals:
90 | video_visuals[k] = torch.cat((video_visuals[k], visuals[k]))
91 | for k in video_visuals:
92 | rgba = {k: visuals[k]} # for k in visuals if "rgba" in k
93 | # save RGBA layers
94 | save_images(
95 | webpage,
96 | rgba,
97 | img_path,
98 | aspect_ratio=opt.aspect_ratio,
99 | width=opt.display_winsize,
100 | )
101 | # if os.path.isdir(os.path.join(opt.dataroot, "rgb_invis_gt")):
102 | # print(
103 | # model.criterionLoss(
104 | # model.reconstruction_rgb_no_cube, model.target_image
105 | # ),
106 | # )
107 | # loss_recon += model.criterionLoss(
108 | # model.reconstruction_rgb_no_cube, model.target_image
109 | # )
110 |
111 | save_videos(webpage, video_visuals, width=opt.display_winsize)
112 | webpage.save() # save the HTML of videos
113 | with open(os.path.join(web_dir, "invis_gt_eval.txt"), "w") as f:
114 | print("avg recon no cube L1Loss " + str(loss_recon / len(dataset)), file=f)
115 |
--------------------------------------------------------------------------------
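An example command with placeholder dataset and experiment names; with --gradient_debug the script additionally expects <checkpoints_dir>/<name>/<epoch>_others.pth to hold the saved discriminator weights, alongside the usual epoch_<epoch>_net_*.pth checkpoints:

python test.py --dataroot datasets/myvideo --name myexperiment --epoch 60 --gradient_debug
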
/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__init__.py
--------------------------------------------------------------------------------
/third_party/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__init__.pyc
--------------------------------------------------------------------------------
/third_party/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/data/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes all the modules related to data loading and preprocessing
2 |
3 | To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
4 | You need to implement four functions:
5 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
6 | -- <__len__>: return the size of dataset.
7 | -- <__getitem__>: get a data point from data loader.
8 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
9 |
10 | Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
11 | See our template dataset class 'template_dataset.py' for more details.
12 | """
13 | import importlib
14 | import torch.utils.data
15 | from .base_dataset import BaseDataset
16 |
17 |
18 | def find_dataset_using_name(dataset_name):
19 | """Import the module "data/[dataset_name]_dataset.py".
20 |
21 | In the file, the class called DatasetNameDataset() will
22 | be instantiated. It has to be a subclass of BaseDataset,
23 | and it is case-insensitive.
24 | """
25 | dataset_filename = "data." + dataset_name + "_dataset"
26 | datasetlib = importlib.import_module(dataset_filename)
27 |
28 | dataset = None
29 | target_dataset_name = dataset_name.replace('_', '') + 'dataset'
30 | for name, cls in datasetlib.__dict__.items():
31 | if name.lower() == target_dataset_name.lower() \
32 | and issubclass(cls, BaseDataset):
33 | dataset = cls
34 |
35 | if dataset is None:
36 | raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
37 |
38 | return dataset
39 |
40 |
41 | def get_option_setter(dataset_name):
42 | """Return the static method of the dataset class."""
43 | dataset_class = find_dataset_using_name(dataset_name)
44 | return dataset_class.modify_commandline_options
45 |
46 |
47 | def create_dataset(opt):
48 | """Create a dataset given the option.
49 |
50 | This function wraps the class CustomDatasetDataLoader.
51 | This is the main interface between this package and 'train.py'/'test.py'
52 |
53 | Example:
54 | >>> from data import create_dataset
55 | >>> dataset = create_dataset(opt)
56 | """
57 | data_loader = CustomDatasetDataLoader(opt)
58 | dataset = data_loader.load_data()
59 | return dataset
60 |
61 |
62 | class CustomDatasetDataLoader():
63 | """Wrapper class of Dataset class that performs multi-threaded data loading"""
64 |
65 | def __init__(self, opt):
66 | """Initialize this class
67 |
68 | Step 1: create a dataset instance given the name [dataset_mode]
69 | Step 2: create a multi-threaded data loader.
70 | """
71 | self.opt = opt
72 | dataset_class = find_dataset_using_name(opt.dataset_mode)
73 | self.dataset = dataset_class(opt)
74 | print("dataset [%s] was created" % type(self.dataset).__name__)
75 | loader = torch.utils.data.DataLoader
76 | self.dataloader = loader(
77 | self.dataset,
78 | batch_size=opt.batch_size,
79 | shuffle=not opt.serial_batches,
80 | num_workers=int(opt.num_threads),
81 | persistent_workers=int(opt.num_threads) > 0,
82 | drop_last = True)
83 |
84 | def load_data(self):
85 | return self
86 |
87 | def __len__(self):
88 | """Return the number of data in the dataset"""
89 | return min(len(self.dataset), self.opt.max_dataset_size)
90 |
91 | def __iter__(self):
92 | """Return a batch of data"""
93 | for i, data in enumerate(self.dataloader):
94 | if i * self.opt.batch_size >= self.opt.max_dataset_size:
95 | break
96 | yield data
97 |
--------------------------------------------------------------------------------
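Following the naming convention described in the docstring above, a minimal hedged sketch of a custom dataset; the file would live in the top-level data/ package (e.g. data/dummy_dataset.py) and is illustrative only, not part of the repository:

# data/dummy_dataset.py -- hypothetical example, selected with --dataset_mode dummy.
from third_party.data.base_dataset import BaseDataset


class DummyDataset(BaseDataset):
    @staticmethod
    def modify_commandline_options(parser, is_train):
        parser.add_argument("--dummy_scale", type=float, default=1.0,
                            help="example dataset-specific option")
        return parser

    def __init__(self, opt):
        BaseDataset.__init__(self, opt)
        self.items = list(range(100))          # placeholder samples

    def __len__(self):
        return len(self.items)

    def __getitem__(self, index):
        return {"value": self.items[index], "index": index}
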
/third_party/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/data/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/data/__pycache__/base_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/base_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/data/__pycache__/base_dataset.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/base_dataset.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/data/__pycache__/image_folder.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/image_folder.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/data/__pycache__/image_folder.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/data/__pycache__/image_folder.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/data/base_dataset.py:
--------------------------------------------------------------------------------
1 | """This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
2 |
3 | It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
4 | """
5 | import random
6 | import numpy as np
7 | import torch.utils.data as data
8 | from PIL import Image
9 | import torchvision.transforms as transforms
10 | from abc import ABC, abstractmethod
11 |
12 |
13 | class BaseDataset(data.Dataset, ABC):
14 | """This class is an abstract base class (ABC) for datasets.
15 |
16 | To create a subclass, you need to implement the following four functions:
17 | -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
18 | -- <__len__>: return the size of dataset.
19 | -- <__getitem__>: get a data point.
20 | -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
21 | """
22 |
23 | def __init__(self, opt):
24 | """Initialize the class; save the options in the class
25 |
26 | Parameters:
27 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28 | """
29 | self.opt = opt
30 | self.root = opt.dataroot
31 |
32 | @staticmethod
33 | def modify_commandline_options(parser, is_train):
34 | """Add new dataset-specific options, and rewrite default values for existing options.
35 |
36 | Parameters:
37 | parser -- original option parser
38 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
39 |
40 | Returns:
41 | the modified parser.
42 | """
43 | return parser
44 |
45 | @abstractmethod
46 | def __len__(self):
47 | """Return the total number of images in the dataset."""
48 | return 0
49 |
50 | @abstractmethod
51 | def __getitem__(self, index):
52 | """Return a data point and its metadata information.
53 |
54 | Parameters:
55 | index - - a random integer for data indexing
56 |
57 | Returns:
58 | a dictionary of data with their names. It usually contains the data itself and its metadata information.
59 | """
60 | pass
61 |
62 |
63 | def get_params(opt, size):
64 | w, h = size
65 | new_h = h
66 | new_w = w
67 | if opt.preprocess == 'resize_and_crop':
68 | new_h = new_w = opt.load_size
69 | elif opt.preprocess == 'scale_width_and_crop':
70 | new_w = opt.load_size
71 | new_h = opt.load_size * h // w
72 |
73 | x = random.randint(0, np.maximum(0, new_w - opt.crop_size))
74 | y = random.randint(0, np.maximum(0, new_h - opt.crop_size))
75 |
76 | flip = random.random() > 0.5
77 |
78 | return {'crop_pos': (x, y), 'flip': flip}
79 |
80 |
81 | def get_transform(opt, params=None, grayscale=False, method=Image.BICUBIC, convert=True):
82 | transform_list = []
83 | if grayscale:
84 | transform_list.append(transforms.Grayscale(1))
85 | if 'resize' in opt.preprocess:
86 | osize = [opt.load_size, opt.load_size]
87 | transform_list.append(transforms.Resize(osize, method))
88 | elif 'scale_width' in opt.preprocess:
89 | transform_list.append(transforms.Lambda(lambda img: __scale_width(img, opt.load_size, opt.crop_size, method)))
90 |
91 | if 'crop' in opt.preprocess:
92 | if params is None:
93 | transform_list.append(transforms.RandomCrop(opt.crop_size))
94 | else:
95 | transform_list.append(transforms.Lambda(lambda img: __crop(img, params['crop_pos'], opt.crop_size)))
96 |
97 | if opt.preprocess == 'none':
98 | transform_list.append(transforms.Lambda(lambda img: __make_power_2(img, base=4, method=method)))
99 |
100 | if not opt.no_flip:
101 | if params is None:
102 | transform_list.append(transforms.RandomHorizontalFlip())
103 | elif params['flip']:
104 | transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
105 |
106 | if convert:
107 | transform_list += [transforms.ToTensor()]
108 | if grayscale:
109 | transform_list += [transforms.Normalize((0.5,), (0.5,))]
110 | else:
111 | transform_list += [transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
112 | return transforms.Compose(transform_list)
113 |
114 |
115 | def __make_power_2(img, base, method=Image.BICUBIC):
116 | ow, oh = img.size
117 | h = int(round(oh / base) * base)
118 | w = int(round(ow / base) * base)
119 | if h == oh and w == ow:
120 | return img
121 |
122 | __print_size_warning(ow, oh, w, h)
123 | return img.resize((w, h), method)
124 |
125 |
126 | def __scale_width(img, target_size, crop_size, method=Image.BICUBIC):
127 | ow, oh = img.size
128 | if ow == target_size and oh >= crop_size:
129 | return img
130 | w = target_size
131 | h = int(max(target_size * oh / ow, crop_size))
132 | return img.resize((w, h), method)
133 |
134 |
135 | def __crop(img, pos, size):
136 | ow, oh = img.size
137 | x1, y1 = pos
138 | tw = th = size
139 | if (ow > tw or oh > th):
140 | return img.crop((x1, y1, x1 + tw, y1 + th))
141 | return img
142 |
143 |
144 | def __flip(img, flip):
145 | if flip:
146 | return img.transpose(Image.FLIP_LEFT_RIGHT)
147 | return img
148 |
149 |
150 | def __print_size_warning(ow, oh, w, h):
151 | """Print warning information about image size(only print once)"""
152 | if not hasattr(__print_size_warning, 'has_printed'):
153 | print("The image size needs to be a multiple of 4. "
154 | "The loaded image size was (%d, %d), so it was adjusted to "
155 | "(%d, %d). This adjustment will be done to all images "
156 | "whose sizes are not multiples of 4" % (ow, oh, w, h))
157 | __print_size_warning.has_printed = True
158 |
--------------------------------------------------------------------------------
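A hedged usage sketch of get_params/get_transform above, showing how one draw of crop and flip parameters keeps two aligned images consistently augmented; the option values are placeholders:

# Illustrative only; the option values below are placeholders.
from argparse import Namespace
from PIL import Image
from third_party.data.base_dataset import get_params, get_transform

opt = Namespace(preprocess="resize_and_crop", load_size=286, crop_size=256, no_flip=False)
img_a = Image.new("RGB", (640, 360))      # stand-ins for two aligned frames
img_b = Image.new("RGB", (640, 360))

params = get_params(opt, img_a.size)      # one draw of crop position + flip
transform = get_transform(opt, params)    # reusing params -> identical augmentation
tensor_a, tensor_b = transform(img_a), transform(img_b)   # each 3 x 256 x 256
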
/third_party/data/image_folder.py:
--------------------------------------------------------------------------------
1 | """A modified image folder class
2 |
3 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
4 | so that this class can load images from both current directory and its subdirectories.
5 | """
6 |
7 | import torch.utils.data as data
8 |
9 | from PIL import Image
10 | import os
11 |
12 | IMG_EXTENSIONS = [
13 | '.jpg', '.JPG', '.jpeg', '.JPEG',
14 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
15 | '.tif', '.TIF', '.tiff', '.TIFF',
16 | ]
17 |
18 |
19 | def is_image_file(filename):
20 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
21 |
22 |
23 | def make_dataset(dir, max_dataset_size=float("inf")):
24 | images = []
25 | assert os.path.isdir(dir), '%s is not a valid directory' % dir
26 |
27 | for root, _, fnames in sorted(os.walk(dir)):
28 | for fname in fnames:
29 | if is_image_file(fname):
30 | path = os.path.join(root, fname)
31 | images.append(path)
32 | images = sorted(images)
33 | return images[:min(max_dataset_size, len(images))]
34 |
35 |
36 | def default_loader(path):
37 | return Image.open(path).convert('RGB')
38 |
39 |
40 | class ImageFolder(data.Dataset):
41 |
42 | def __init__(self, root, transform=None, return_paths=False,
43 | loader=default_loader):
44 | imgs = make_dataset(root)
45 | if len(imgs) == 0:
46 | raise(RuntimeError("Found 0 images in: " + root + "\n"
47 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
48 |
49 | self.root = root
50 | self.imgs = imgs
51 | self.transform = transform
52 | self.return_paths = return_paths
53 | self.loader = loader
54 |
55 | def __getitem__(self, index):
56 | path = self.imgs[index]
57 | img = self.loader(path)
58 | if self.transform is not None:
59 | img = self.transform(img)
60 | if self.return_paths:
61 | return img, path
62 | else:
63 | return img
64 |
65 | def __len__(self):
66 | return len(self.imgs)
67 |
--------------------------------------------------------------------------------
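A brief hedged sketch of the modified ImageFolder above; the directory path is a placeholder:

from torchvision import transforms
from third_party.data.image_folder import ImageFolder

# Recursively collects images from the folder and all of its subdirectories.
folder = ImageFolder("datasets/myvideo/rgb", transform=transforms.ToTensor(),
                     return_paths=True)
img, path = folder[0]    # image tensor plus the on-disk path it came from
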
/third_party/models/__init__.py:
--------------------------------------------------------------------------------
1 | """This package contains modules related to objective functions, optimizations, and network architectures.
2 |
3 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
4 | You need to implement the following five functions:
5 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
6 | -- <set_input>: unpack data from dataset and apply preprocessing.
7 | -- <forward>: produce intermediate results.
8 | -- <optimize_parameters>: calculate loss, gradients, and update network weights.
9 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
10 |
11 | In the function <__init__>, you need to define four lists:
12 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
13 | -- self.model_names (str list): define networks used in our training.
14 | -- self.visual_names (str list): specify the images that you want to display and save.
15 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
16 |
17 | Now you can use the model class by specifying flag '--model dummy'.
18 | See our template model class 'template_model.py' for more details.
19 | """
20 |
21 | import importlib
22 | from .base_model import BaseModel
23 |
24 |
25 | def find_model_using_name(model_name):
26 | """Import the module "models/[model_name]_model.py".
27 |
28 | In the file, the class called DatasetNameModel() will
29 | be instantiated. It has to be a subclass of BaseModel,
30 | and it is case-insensitive.
31 | """
32 | model_filename = "models." + model_name + "_model"
33 | modellib = importlib.import_module(model_filename)
34 | model = None
35 | target_model_name = model_name.replace('_', '') + 'model'
36 | for name, cls in modellib.__dict__.items():
37 | if name.lower() == target_model_name.lower() \
38 | and issubclass(cls, BaseModel):
39 | model = cls
40 |
41 | if model is None:
42 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
43 | exit(0)
44 |
45 | return model
46 |
47 |
48 | def get_option_setter(model_name):
49 | """Return the static method of the model class."""
50 | model_class = find_model_using_name(model_name)
51 | return model_class.modify_commandline_options
52 |
53 |
54 | def create_model(opt):
55 | """Create a model given the option.
56 |
57 | This function wraps the model class.
58 | This is the main interface between this package and 'train.py'/'test.py'
59 |
60 | Example:
61 | >>> from models import create_model
62 | >>> model = create_model(opt)
63 | """
64 | model = find_model_using_name(opt.model)
65 | instance = model(opt)
66 | print("model [%s] was created" % type(instance).__name__)
67 | return instance
68 |
--------------------------------------------------------------------------------
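Mirroring the dataset package, a model is located by naming convention; a minimal hedged sketch of a custom model, which would live in the top-level models/ package (e.g. models/dummy_model.py) and is illustrative only, not part of the repository:

# models/dummy_model.py -- hypothetical example, selected with --model dummy.
import torch
from third_party.models.base_model import BaseModel


class DummyModel(BaseModel):
    def __init__(self, opt):
        BaseModel.__init__(self, opt)
        self.loss_names = ["recon"]
        self.model_names = ["G"]
        self.visual_names = ["fake"]
        self.netG = torch.nn.Conv2d(3, 3, 1).to(self.device)
        self.optimizers = [torch.optim.Adam(self.netG.parameters(), lr=opt.lr)]

    def set_input(self, input):
        self.real = input["image"].to(self.device)

    def forward(self):
        self.fake = self.netG(self.real)

    def optimize_parameters(self):
        self.forward()
        self.loss_recon = torch.nn.functional.l1_loss(self.fake, self.real)
        self.optimizers[0].zero_grad()
        self.loss_recon.backward()
        self.optimizers[0].step()
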
/third_party/models/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/base_model.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/base_model.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/networks.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/networks_lnr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks_lnr.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/models/__pycache__/networks_lnr.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/models/__pycache__/networks_lnr.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/models/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from collections import OrderedDict
4 | from abc import ABC, abstractmethod
5 | import numpy as np
6 | from . import networks
7 |
8 |
9 | class BaseModel(ABC):
10 | """This class is an abstract base class (ABC) for models.
11 | To create a subclass, you need to implement the following five functions:
12 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
13 | -- <set_input>: unpack data from dataset and apply preprocessing.
14 | -- <forward>: produce intermediate results.
15 | -- <optimize_parameters>: calculate losses, gradients, and update network weights.
16 | -- <modify_commandline_options>: (optionally) add model-specific options and set default options.
17 | """
18 |
19 | def __init__(self, opt):
20 | """Initialize the BaseModel class.
21 |
22 | Parameters:
23 | opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
24 |
25 | When creating your custom class, you need to implement your own initialization.
26 | In this function, you should first call <BaseModel.__init__(self, opt)>.
27 | Then, you need to define four lists:
28 | -- self.loss_names (str list): specify the training losses that you want to plot and save.
29 | -- self.model_names (str list): define networks used in our training.
30 | -- self.visual_names (str list): specify the images that you want to display and save.
31 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
32 | """
33 | self.opt = opt
34 | self.gpu_ids = opt.gpu_ids
35 | self.isTrain = opt.isTrain
36 | self.device = (
37 | torch.device("cuda:{}".format(self.gpu_ids[0]))
38 | if self.gpu_ids
39 | else torch.device("cpu")
40 | ) # get device name: CPU or GPU
41 | self.save_dir = os.path.join(
42 | opt.checkpoints_dir, opt.name
43 | ) # save all the checkpoints to save_dir
44 | self.loss_names = []
45 | self.model_names = []
46 | self.visual_names = []
47 | self.optimizers = []
48 | self.image_paths = []
49 | self.metric = 0 # used for learning rate policy 'plateau'
50 |
51 | @staticmethod
52 | def modify_commandline_options(parser, is_train):
53 | """Add new model-specific options, and rewrite default values for existing options.
54 |
55 | Parameters:
56 | parser -- original option parser
57 | is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
58 |
59 | Returns:
60 | the modified parser.
61 | """
62 | return parser
63 |
64 | @abstractmethod
65 | def set_input(self, input):
66 | """Unpack input data from the dataloader and perform necessary pre-processing steps.
67 |
68 | Parameters:
69 | input (dict): includes the data itself and its metadata information.
70 | """
71 | pass
72 |
73 | @abstractmethod
74 | def forward(self):
75 | """Run forward pass; called by both functions and ."""
76 | pass
77 |
78 | @abstractmethod
79 | def optimize_parameters(self):
80 | """Calculate losses, gradients, and update network weights; called in every training iteration"""
81 | pass
82 |
83 | def setup(self, opt):
84 | """Load and print networks; create schedulers
85 |
86 | Parameters:
87 | opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
88 | """
89 | if self.isTrain:
90 | self.schedulers = [
91 | networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers
92 | ]
93 | if not self.isTrain or opt.continue_train:
94 | if opt.epoch != "latest":
95 | load_suffix = "epoch_" + opt.epoch
96 | else:
97 | load_suffix = opt.epoch
98 | self.load_networks(load_suffix)
99 | self.print_networks(opt.verbose)
100 |
101 | def eval(self):
102 | """Make models eval mode during test time"""
103 | for name in self.model_names:
104 | if isinstance(name, str):
105 | net = getattr(self, "net" + name)
106 | net.eval()
107 |
108 | def test(self, total_iters):
109 | """Forward function used in test time.
110 |
111 | This function wraps <forward> in no_grad() so we don't save intermediate steps for backprop.
112 | It also calls <compute_visuals> to produce additional visualization results.
113 | """
114 | if self.opt.gradient_debug:
115 | self.set_requires_grad(self.netOmnimatteGAN, True)
116 | self.netOmnimatteGAN.train()
117 | web_dir = os.path.join(
118 | self.opt.results_dir,
119 | self.opt.name,
120 | "{}_{}_{}".format(self.opt.phase, self.opt.epoch, self.opt.test_suffix),
121 | )
122 | self.forward()
123 | grad_map = self.gradient_debug(total_iters)
124 | np.save(
125 | os.path.join(
126 | web_dir,
127 | str(total_iters) + "_grad_map.npy",
128 | ),
129 | grad_map,
130 | )
131 | else:
132 | with torch.no_grad():
133 | self.forward()
134 |
135 | def compute_visuals(self):
136 | """Calculate additional output images for visdom and HTML visualization"""
137 | pass
138 |
139 | def get_image_paths(self):
140 | """Return image paths that are used to load current data"""
141 | return self.image_paths
142 |
143 | def update_learning_rate(self):
144 | """Update learning rates for all the networks; called at the end of every epoch"""
145 | old_lr = self.optimizers[0].param_groups[0]["lr"]
146 | for scheduler in self.schedulers:
147 | if self.opt.lr_policy == "plateau":
148 | scheduler.step(self.metric)
149 | else:
150 | scheduler.step()
151 |
152 | lr = self.optimizers[0].param_groups[0]["lr"]
153 | if old_lr != lr:
154 | print("learning rate %.7f -> %.7f" % (old_lr, lr))
155 |
156 | def get_current_visuals(self):
157 | """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
158 | visual_ret = OrderedDict()
159 | for name in self.visual_names:
160 | if isinstance(name, str):
161 | visual_ret[name] = getattr(self, name)
162 | return visual_ret
163 |
164 | def get_current_losses(self):
165 | """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
166 | errors_ret = OrderedDict()
167 | for name in self.loss_names:
168 | if isinstance(name, str):
169 | errors_ret[name] = float(
170 | getattr(self, "loss_" + name)
171 | ) # float(...) works for both scalar tensor and float number
172 | return errors_ret
173 |
174 | def save_networks(self, epoch):
175 | """Save all the networks to the disk.
176 |
177 | Parameters:
178 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
179 | """
180 | for name in self.model_names:
181 | if isinstance(name, str):
182 | save_filename = "%s_net_%s.pth" % (epoch, name)
183 | save_path = os.path.join(self.save_dir, save_filename)
184 | net = getattr(self, "net" + name)
185 |
186 | if len(self.gpu_ids) > 0 and torch.cuda.is_available():
187 | torch.save(net.module.cpu().state_dict(), save_path)
188 | net.cuda(self.gpu_ids[0])
189 | else:
190 | torch.save(net.cpu().state_dict(), save_path)
191 |
192 | def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
193 | """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
194 | key = keys[i]
195 | if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
196 | if module.__class__.__name__.startswith("InstanceNorm") and (
197 | key == "running_mean" or key == "running_var"
198 | ):
199 | if getattr(module, key) is None:
200 | state_dict.pop(".".join(keys))
201 | if module.__class__.__name__.startswith("InstanceNorm") and (
202 | key == "num_batches_tracked"
203 | ):
204 | state_dict.pop(".".join(keys))
205 | else:
206 | self.__patch_instance_norm_state_dict(
207 | state_dict, getattr(module, key), keys, i + 1
208 | )
209 |
210 | def load_networks(self, epoch):
211 | """Load all the networks from the disk.
212 |
213 | Parameters:
214 | epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
215 | """
216 | for name in self.model_names:
217 | if isinstance(name, str):
218 | load_filename = "%s_net_%s.pth" % (epoch, name)
219 | load_path = os.path.join(self.save_dir, load_filename)
220 | net = getattr(self, "net" + name)
221 | if isinstance(net, torch.nn.DataParallel):
222 | net = net.module
223 | print("loading the model from %s" % load_path)
224 | # if you are using PyTorch newer than 0.4 (e.g., built from
225 | # GitHub source), you can remove str() on self.device
226 | state_dict = torch.load(load_path, map_location=str(self.device))
227 | if hasattr(state_dict, "_metadata"):
228 | del state_dict._metadata
229 |
230 | # patch InstanceNorm checkpoints prior to 0.4
231 | for key in list(
232 | state_dict.keys()
233 | ): # need to copy keys here because we mutate in loop
234 | self.__patch_instance_norm_state_dict(
235 | state_dict, net, key.split(".")
236 | )
237 | if net.state_dict()[key].shape != state_dict[key].shape:
238 | print(key, 'size mismatch! duplicating...')
239 | repeat_times = [2]+[1]*(len(net.state_dict()[key].shape)-1)
240 | state_dict[key] = state_dict[key].repeat(repeat_times)
241 | print(state_dict[key].size())
242 |
243 | net.load_state_dict(state_dict)
244 |
245 | def print_networks(self, verbose):
246 | """Print the total number of parameters in the network and (if verbose) network architecture
247 |
248 | Parameters:
249 | verbose (bool) -- if verbose: print the network architecture
250 | """
251 | print("---------- Networks initialized -------------")
252 | for name in self.model_names:
253 | if isinstance(name, str):
254 | net = getattr(self, "net" + name)
255 | num_params = 0
256 | num_trainable_params = 0
257 | for param in net.parameters():
258 | num_params += param.numel()
259 | if param.requires_grad:
260 | num_trainable_params += param.numel()
261 | if verbose:
262 | print(net)
263 | print(
264 | "[Network %s] Total number of parameters : %.3f M"
265 | % (name, num_params / 1e6)
266 | )
267 | print(
268 | "[Network %s] Total number of trainable parameters : %.3f M"
269 | % (name, num_trainable_params / 1e6)
270 | )
271 | print("-----------------------------------------------")
272 |
273 | def set_requires_grad(self, nets, requires_grad=False):
274 | """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
275 | Parameters:
276 | nets (network list) -- a list of networks
277 | requires_grad (bool) -- whether the networks require gradients or not
278 | """
279 | if not isinstance(nets, list):
280 | nets = [nets]
281 | for net in nets:
282 | if net is not None:
283 | for param in net.parameters():
284 | param.requires_grad = requires_grad
285 |
--------------------------------------------------------------------------------
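For orientation, a hedged sketch of the checkpoint naming convention used by save_networks/load_networks above; "OmnimatteGAN" stands in for an entry of self.model_names and is a placeholder:

# Illustrative only; checkpoint directory and network name come from the defaults in base_options.py.
import os

checkpoints_dir, name = "./1110_checkpoints", "experiment_name"
save_dir = os.path.join(checkpoints_dir, name)
print(os.path.join(save_dir, "%s_net_%s.pth" % ("latest", "OmnimatteGAN")))
# -> ./1110_checkpoints/experiment_name/latest_net_OmnimatteGAN.pth
# With --epoch 60, setup() builds the suffix "epoch_60" and loads epoch_60_net_OmnimatteGAN.pth.
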
/third_party/models/networks.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.optim import lr_scheduler
4 |
5 |
6 | ###############################################################################
7 | # Helper Functions
8 | ###############################################################################
9 | def get_scheduler(optimizer, opt):
10 | """Return a learning rate scheduler
11 |
12 | Parameters:
13 | optimizer -- the optimizer of the network
14 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.
15 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
16 |
17 | For 'linear', we keep the same learning rate for the first epochs
18 | and linearly decay the rate to zero over the next epochs.
19 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
20 | See https://pytorch.org/docs/stable/optim.html for more details.
21 | """
22 | if opt.lr_policy == 'linear':
23 | def lambda_rule(epoch):
24 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1)
25 | return lr_l
26 |
27 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
28 | elif opt.lr_policy == 'step':
29 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
30 | elif opt.lr_policy == 'plateau':
31 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
32 | elif opt.lr_policy == 'cosine':
33 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
34 | else:
35 | raise NotImplementedError('learning rate policy [%s] is not implemented' % opt.lr_policy)
36 | return scheduler
37 |
38 |
39 | def init_net(net, gpu_ids=[]):
40 | """Initialize a network by registering CPU/GPU device (with multi-GPU support)
41 | Parameters:
42 | net (network) -- the network to be initialized
43 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2
44 |
45 | Return an initialized network.
46 | """
47 | if len(gpu_ids) > 0:
48 | assert (torch.cuda.is_available())
49 | net.to(gpu_ids[0])
50 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs
51 | return net
52 |
--------------------------------------------------------------------------------
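A hedged worked example of the 'linear' policy in get_scheduler above: with placeholder values n_epochs = 100 and n_epochs_decay = 100, the multiplier stays at 1.0 for the first 100 epochs and then decays roughly linearly to zero:

# Illustrative only; mirrors lambda_rule from get_scheduler with placeholder option values.
epoch_count, n_epochs, n_epochs_decay = 1, 100, 100

def lambda_rule(epoch):
    return 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)

print(lambda_rule(0))     # 1.0     (initial learning rate kept)
print(lambda_rule(99))    # 1.0     (still within the first n_epochs)
print(lambda_rule(149))   # ~0.505  (halfway through the decay phase)
print(lambda_rule(199))   # ~0.010  (nearly zero at the end)
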
/third_party/util/__init__.py:
--------------------------------------------------------------------------------
1 | """This package includes a miscellaneous collection of useful helper functions."""
2 |
--------------------------------------------------------------------------------
/third_party/util/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__init__.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/__init__.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/__init__.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/html.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/html.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/html.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/html.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/util.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/util.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/visualizer.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/visualizer.cpython-38.pyc
--------------------------------------------------------------------------------
/third_party/util/__pycache__/visualizer.cpython-39.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/__pycache__/visualizer.cpython-39.pyc
--------------------------------------------------------------------------------
/third_party/util/html.py:
--------------------------------------------------------------------------------
1 | import dominate
2 | from dominate.tags import meta, h3, table, tr, td, p, a, img, br, video, source
3 | import os
4 |
5 |
6 | class HTML:
7 | """This HTML class allows us to save images and write texts into a single HTML file.
8 |
9 | It consists of functions such as add_header (add a text header to the HTML file),
10 | add_images (add a row of images to the HTML file), and save (save the HTML to the disk).
11 | It is based on 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
12 | """
13 |
14 | def __init__(self, web_dir, title, refresh=0):
15 | """Initialize the HTML classes
16 |
17 | Parameters:
18 | web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir>/images/
19 | title (str) -- the webpage name
20 | refresh (int) -- how often the website refreshes itself; if 0, no refreshing
21 | """
22 | self.title = title
23 | self.web_dir = web_dir
24 | self.img_dir = os.path.join(self.web_dir, 'images')
25 | self.vid_dir = os.path.join(self.web_dir, 'videos')
26 | if not os.path.exists(self.web_dir):
27 | os.makedirs(self.web_dir)
28 | if not os.path.exists(self.img_dir):
29 | os.makedirs(self.img_dir)
30 | if not os.path.exists(self.vid_dir):
31 | os.makedirs(self.vid_dir)
32 |
33 | self.doc = dominate.document(title=title)
34 | if refresh > 0:
35 | with self.doc.head:
36 | meta(http_equiv="refresh", content=str(refresh))
37 |
38 | def get_image_dir(self):
39 | """Return the directory that stores images"""
40 | return self.img_dir
41 |
42 | def get_video_dir(self):
43 | """Return the directory that stores videos"""
44 | return self.vid_dir
45 |
46 | def add_header(self, text):
47 | """Insert a header to the HTML file
48 |
49 | Parameters:
50 | text (str) -- the header text
51 | """
52 | with self.doc:
53 | h3(text)
54 |
55 | def add_images(self, ims, txts, links, width=400):
56 | """add images to the HTML file
57 |
58 | Parameters:
59 | ims (str list) -- a list of image paths
60 | txts (str list) -- a list of image names shown on the website
61 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
62 | """
63 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
64 | self.doc.add(self.t)
65 | with self.t:
66 | with tr():
67 | for im, txt, link in zip(ims, txts, links):
68 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
69 | with p():
70 | with a(href=os.path.join('images', link)):
71 | img(style="width:%dpx" % width, src=os.path.join('images', im))
72 | br()
73 | p(txt)
74 |
75 | def add_videos(self, vids, txts, links, width=400):
76 | """add images to the HTML file
77 |
78 | Parameters:
79 | ims (str list) -- a list of image paths
80 | txts (str list) -- a list of image names shown on the website
81 | links (str list) -- a list of hyperref links; when you click an image, it will redirect you to a new page
82 | """
83 | self.t = table(border=1, style="table-layout: fixed;") # Insert a table
84 | self.doc.add(self.t)
85 | with self.t:
86 | with tr():
87 | for vid, txt, link in zip(vids, txts, links):
88 | with td(style="word-wrap: break-word;", halign="center", valign="top"):
89 | with p():
90 | with a(href=os.path.join('videos', link)):
91 | with video(style="width:%dpx" % width, controls=True):
92 | source(src=os.path.join('videos', vid), type="video/mp4")
93 | br()
94 | p(txt)
95 |
96 | def save(self):
97 | """save the current content to the HMTL file"""
98 | html_file = '%s/index.html' % self.web_dir
99 | f = open(html_file, 'wt')
100 | f.write(self.doc.render())
101 | f.close()
102 |
103 |
104 | if __name__ == '__main__': # we show an example usage here.
105 | html = HTML('web/', 'test_html')
106 | html.add_header('hello world')
107 |
108 | ims, txts, links = [], [], []
109 | for n in range(4):
110 | ims.append('image_%d.png' % n)
111 | txts.append('text_%d' % n)
112 | links.append('image_%d.png' % n)
113 | html.add_images(ims, txts, links)
114 | html.save()
115 |
--------------------------------------------------------------------------------
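The __main__ block above only exercises add_images; a minimal sketch of the video path, assuming the HTML class above is importable and that the .mp4 files already sit under web/videos/ (the file names are placeholders):

    # assuming the HTML class above has been imported from third_party/util/html.py
    html = HTML('web/', 'test_html_videos')
    html.add_header('example videos')
    vids = ['clip_%d.mp4' % n for n in range(2)]   # hypothetical files under web/videos/
    txts = ['clip_%d' % n for n in range(2)]
    html.add_videos(vids, txts, vids)              # links point back at the same files
    html.save()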
/third_party/util/util.py:
--------------------------------------------------------------------------------
1 | """This module contains simple helper functions """
2 | from __future__ import print_function
3 | import torch
4 | import numpy as np
5 | from PIL import Image
6 | import os
7 |
8 |
9 | def tensor2im(input_image, imtype=np.uint8):
10 | """"Converts a Tensor array into a numpy image array.
11 |
12 | Parameters:
13 | input_image (tensor) -- the input image tensor array
14 | imtype (type) -- the desired type of the converted numpy array
15 | """
16 | if not isinstance(input_image, np.ndarray):
17 | if isinstance(input_image, torch.Tensor): # get the data from a variable
18 | image_tensor = input_image.data
19 | else:
20 | return input_image
21 | image_numpy = image_tensor[0].cpu().float().numpy() # convert it into a numpy array
22 | if image_numpy.shape[0] == 1: # grayscale to RGB
23 | image_numpy = np.tile(image_numpy, (3, 1, 1))
24 | image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + 1) / 2.0 * 255.0 # post-processing: transpose and scaling
25 | if image_numpy.shape[-1] == 4:
26 | image_numpy = render_png(image_numpy)
27 | else: # if it is a numpy array, do nothing
28 | image_numpy = input_image
29 | if image_numpy.shape[-1] == 4:
30 | image_numpy = render_png(image_numpy)
31 | return image_numpy.astype(imtype)
32 |
33 |
34 | def render_png(image, background='checker'):
35 | height, width = image.shape[:2]
36 | if background == 'checker':
37 | checkerboard = np.kron([[136, 120] * (width//128+1), [120, 136] * (width//128+1)] * (height//128+1), np.ones((16, 16)))
38 | checkerboard = np.expand_dims(np.tile(checkerboard, (4, 4)), -1)
39 | bg = checkerboard[:height, :width]
40 | elif background == 'black':
41 | bg = np.zeros([height, width, 1])
42 | else:
43 | bg = 255 * np.ones([height, width, 1])
44 | image = image.astype(np.float32)
45 | alpha = image[:, :, 3:] / 255
46 | rendered_image = alpha * image[:, :, :3] + (1 - alpha) * bg
47 | return rendered_image.astype(np.uint8)
48 |
49 |
50 | def add_title(image, title_text):
51 | print('please put a dir for a font here')
52 | font_dir = ''
53 | from PIL import Image, ImageFont, ImageDraw
54 | import matplotlib.font_manager as fm
55 | image = Image.fromarray(image)
56 | title_font = ImageFont.truetype(font_dir, 35)
57 | image_editable = ImageDraw.Draw(image)
58 | image_editable.text((10,10), title_text, (255,255,255), stroke_fill=(0,0,0), font=title_font, stroke_width=3)
59 | return np.asarray(image)
60 |
61 |
62 | def diagnose_network(net, name='network'):
63 | """Calculate and print the mean of average absolute(gradients)
64 |
65 | Parameters:
66 | net (torch network) -- Torch network
67 | name (str) -- the name of the network
68 | """
69 | mean = 0.0
70 | count = 0
71 | for param in net.parameters():
72 | if param.grad is not None:
73 | mean += torch.mean(torch.abs(param.grad.data))
74 | count += 1
75 | if count > 0:
76 | mean = mean / count
77 | print(name)
78 | print(mean)
79 |
80 |
81 | def save_image(image_numpy, image_path, aspect_ratio=1.0):
82 | """Save a numpy image to the disk
83 |
84 | Parameters:
85 | image_numpy (numpy array) -- input numpy array
86 | image_path (str) -- the path of the image
87 | """
88 | image_pil = Image.fromarray(image_numpy)
89 | h, w, _ = image_numpy.shape
90 |
91 | if aspect_ratio > 1.0:
92 | image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
93 | if aspect_ratio < 1.0:
94 | image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
95 | image_pil.save(image_path)
96 |
97 |
98 | def print_numpy(x, val=True, shp=False):
99 | """Print the mean, min, max, median, std, and size of a numpy array
100 |
101 | Parameters:
102 | val (bool) -- whether to print the values of the numpy array
103 | shp (bool) -- whether to print the shape of the numpy array
104 | """
105 | x = x.astype(np.float64)
106 | if shp:
107 | print('shape,', x.shape)
108 | if val:
109 | x = x.flatten()
110 | print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
111 | np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
112 |
113 |
114 | def mkdirs(paths):
115 | """create empty directories if they don't exist
116 |
117 | Parameters:
118 | paths (str list) -- a list of directory paths
119 | """
120 | if isinstance(paths, list) and not isinstance(paths, str):
121 | for path in paths:
122 | mkdir(path)
123 | else:
124 | mkdir(paths)
125 |
126 |
127 | def mkdir(path):
128 | """create a single empty directory if it didn't exist
129 |
130 | Parameters:
131 | path (str) -- a single directory path
132 | """
133 | if not os.path.exists(path):
134 | os.makedirs(path)
135 |
--------------------------------------------------------------------------------
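A minimal usage sketch for tensor2im and save_image above; the random tensor stands in for a network output in [-1, 1], and the output path is a placeholder.

    import torch
    # assuming tensor2im and save_image above are imported from third_party/util/util.py
    fake = torch.rand(1, 3, 64, 64) * 2 - 1   # stand-in NCHW network output in [-1, 1]
    img = tensor2im(fake)                     # HxWx3 uint8 array in [0, 255]
    save_image(img, 'preview.png')            # hypothetical output path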
/third_party/util/util.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/third_party/util/util.pyc
--------------------------------------------------------------------------------
/train_GAN.py:
--------------------------------------------------------------------------------
1 | """Script for training an Omnimatte model on a video.
2 |
3 | You need to specify the dataset ('--dataroot') and experiment name ('--name').
4 |
5 | Example:
6 | python train_GAN.py --dataroot ./datasets/tennis --name tennis --gpu_ids 0,1
7 |
8 | The script first creates a model, dataset, and visualizer given the options.
9 | It then does standard network training. During training, it also visualizes/saves the images, prints/saves the loss
10 | plot, and saves the model.
11 | Use '--continue_train' to resume your previous training.
12 |
13 | See options/base_options.py and options/train_options.py for more training options.
14 | """
15 | import time
16 | from options.train_options import TrainOptions
17 | from third_party.data import create_dataset
18 | from third_party.models import create_model
19 | from third_party.util.visualizer import Visualizer
20 | import torch
21 | import numpy as np
22 | import random
23 | import os
24 |
25 |
26 | def main():
27 | trainopt = TrainOptions()
28 | opt = trainopt.parse()
29 |
30 | torch.manual_seed(opt.seed)
31 | np.random.seed(opt.seed)
32 | random.seed(opt.seed)
33 |
34 | dataset = create_dataset(opt)
35 | dataset_size = len(dataset)
36 | print("The number of training images = %d" % dataset_size)
37 | if opt.n_epochs is None:
38 | assert opt.n_steps, "You must specify one of n_epochs or n_steps."
39 | opt.n_epochs = int(
40 | opt.n_steps / np.ceil(dataset_size)
41 | ) # / opt.batch_size divide by bs seems weird
42 | opt.n_epochs_decay = int(opt.n_steps_decay / np.ceil(dataset_size / opt.batch_size))
43 | total_iters = 0
44 | model = create_model(opt)
45 | model.setup(opt) # regular setup: load and print networks; create schedulers
46 | if opt.continue_train:
47 | opt.epoch_count = int(opt.epoch) + 1
48 | if opt.overwrite_lambdas:
49 | # Setting parameters here will overwrite the previous code
50 | history = torch.load(
51 | os.path.join(model.save_dir, opt.epoch + "_others.pth"), map_location='cuda:0'
52 | )
53 | for name in model.lambda_names:
54 | if isinstance(name, str):
55 | setattr(model, "lambda_" + name, history["lambda_" + name])
56 | print(
57 | "lambdas overwritten args",
58 | "lambda_" + name,
59 | getattr(model, "lambda_" + name, None),
60 | )
61 | total_iters = history["total_iters"]
62 | model.jitter_rgb = history["jitter_rgb"]
63 | model.do_cam_adj = history["do_cam_adj"]
64 | # Assume that when resuming from a checkpoint, enough epochs have already passed
65 | # that the mask loss is no longer needed (set to 0)
66 | # model.mask_loss_rolloff_epoch = 0
67 | model.mask_loss_rolloff_epoch = history["mask_loss_rolloff_epoch"]
68 | print(
69 | "other params overwritten args",
70 | model.jitter_rgb,
71 | model.do_cam_adj,
72 | total_iters,
73 | opt.epoch_count,
74 | model.mask_loss_rolloff_epoch,
75 | )
76 |
77 | for i in range(len(model.discriminators)):
78 | if (model.discriminators[i] is not None) and (
79 | "discriminator_l" + str(i) in history
80 | ):
81 | model.discriminators[i].load_state_dict(
82 | history["discriminator_l" + str(i)], strict=False
83 | )
84 | print(i, "th discriminator weights loaded unstrictly")
85 | print(
86 | "the dict in the history is",
87 | history["discriminator_l" + str(i)].keys(),
88 | )
89 | print(
90 | "the dict in current model is",
91 | model.discriminators[i].state_dict().keys(),
92 | )
93 | model.discriminators[i].train()
94 |
95 | if opt.overwrite_lrs:
96 | print("lr overwritten args", history["lrs"])
97 | for i in range(len(model.optimizers)):
98 | optimizer = model.optimizers[i]
99 | for g in optimizer.param_groups:
100 | g["lr"] = history["lrs"][i]
101 |
102 | visualizer = Visualizer(opt)
103 | train(model, dataset, visualizer, opt, total_iters)
104 |
105 |
106 | def train(model, dataset, visualizer, opt, total_iters):
107 | dataset_size = len(dataset)
108 | for epoch in range(
109 | opt.epoch_count, opt.n_epochs + opt.n_epochs_decay + 1
110 | ): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
111 | epoch_start_time = time.time() # timer for entire epoch
112 | iter_data_time = time.time() # timer for data loading per iteration
113 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
114 | model.update_lambdas(epoch)
115 | if epoch == opt.epoch_count:
116 | save_result = True
117 | dp = dataset.dataset[opt.display_ind]
118 | for k, v in dp.items():
119 | if torch.is_tensor(v):
120 | dp[k] = v.unsqueeze(0)
121 | else:
122 | dp[k] = [v]
123 | model.set_input(dp)
124 | model.compute_visuals(total_iters)
125 | visualizer.display_current_results(
126 | model.get_current_visuals(), 0, save_result
127 | )
128 |
129 | for i, data in enumerate(dataset): # inner loop within one epoch
130 | iter_start_time = time.time() # timer for computation per iteration
131 | if i % opt.print_freq == 0:
132 | t_data = iter_start_time - iter_data_time
133 | # Iteration counts are not exact because the last batch may be smaller than the batch size.
134 | total_iters += opt.batch_size
135 | epoch_iter += opt.batch_size
136 | model.set_input(data)
137 | model.optimize_parameters(total_iters, epoch)
138 |
139 | if (
140 | i % opt.print_freq == 0
141 | ): # print training losses and save logging information to the disk
142 | print(opt.name)
143 | losses = model.get_current_losses()
144 | t_comp = (time.time() - iter_start_time) / opt.batch_size
145 | visualizer.print_current_losses(
146 | epoch, epoch_iter, losses, t_comp, t_data
147 | )
148 | if opt.display_id > 0:
149 | visualizer.plot_current_losses(
150 | epoch, float(epoch_iter) / dataset_size, losses
151 | )
152 | iter_data_time = time.time()
153 |
154 | if (
155 | epoch % opt.display_freq == 1
156 | ): # display images on visdom and save images to a HTML file
157 | save_result = epoch % opt.update_html_freq == 1
158 | dp = dataset.dataset[opt.display_ind]
159 | for k, v in dp.items():
160 | if torch.is_tensor(v):
161 | dp[k] = v.unsqueeze(0)
162 | else:
163 | dp[k] = [v]
164 | model.set_input(dp)
165 | model.compute_visuals(total_iters)
166 | visualizer.display_current_results(
167 | model.get_current_visuals(), epoch, save_result
168 | )
169 |
170 | if (
171 | epoch % opt.save_latest_freq == 0 or epoch == opt.epoch_count
172 | ): # opt.n_epochs + opt.n_epochs_decay: # cache our latest model every <save_latest_freq> epochs
173 | print(
174 | "saving the latest model (epoch %d, total_iters %d)"
175 | % (epoch, total_iters)
176 | )
177 | save_suffix = "epoch_%d" % epoch if opt.save_by_epoch else "latest"
178 | model.save_networks(save_suffix)
179 | others = {
180 | "lrs": [i.param_groups[0]["lr"] for i in model.optimizers],
181 | "jitter_rgb": model.jitter_rgb,
182 | "do_cam_adj": model.do_cam_adj,
183 | "total_iters": total_iters,
184 | }
185 | for i in range(len(model.discriminators)):
186 | if model.discriminators[i] is not None:
187 | others["discriminator_l" + str(i)] = model.discriminators[
188 | i
189 | ].state_dict()
190 | for name in model.lambda_names:
191 | if isinstance(name, str):
192 | others["lambda_" + name] = float(getattr(model, "lambda_" + name))
193 | others["lambda_Ds"] = torch.tensor(model.lambda_Ds)
194 | others["lambda_plausibles"] = torch.tensor(model.lambda_plausibles)
195 | others["mask_loss_rolloff_epoch"] = model.mask_loss_rolloff_epoch
196 | torch.save(
197 | others,
198 | os.path.join(opt.checkpoints_dir, opt.name, str(epoch) + "_others.pth"),
199 | )
200 |
201 | if ((epoch == 1) or (epoch % opt.update_D_epochs == 0)) and (
202 | model.optimizer_D is not None
203 | ):
204 | model.update_learning_rate([1])
205 | model.update_learning_rate(
206 | [0]
207 | ) # update learning rates at the end of every epoch.
208 | print(
209 | "End of epoch %d / %d \t Time Taken: %d sec"
210 | % (epoch, opt.n_epochs + opt.n_epochs_decay, time.time() - epoch_start_time)
211 | )
212 |
213 |
214 | def see_grad(model, dataset, visualizer, opt):
215 | total_iters = 0 # the total number of training iterations
216 | for f in os.listdir(opt.ckpt_dir):
217 | if "net_Omnimatte.pth" in f:
218 | weight = torch.load(os.path.join(opt.ckpt_dir, f))
219 | model.netOmnimatte.load_state_dict(weight)
220 | for epoch in range(
221 | 1
222 | ): # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
223 | epoch_start_time = time.time() # timer for entire epoch
224 | iter_data_time = time.time() # timer for data loading per iteration
225 | epoch_iter = 0 # the number of training iterations in current epoch, reset to 0 every epoch
226 | model.update_lambdas(epoch)
227 | for i, data in enumerate(dataset): # inner loop within one epoch
228 | if i == 0:
229 | iter_start_time = time.time() # timer for computation per iteration
230 | if i % opt.print_freq == 0:
231 | t_data = iter_start_time - iter_data_time
232 |
233 | total_iters += opt.batch_size
234 | epoch_iter += opt.batch_size
235 | model.set_input(data)
236 | model.optimize_parameters(total_iters)
237 | else:
238 | break
239 |
240 |
241 | if __name__ == "__main__":
242 | main()
243 |
--------------------------------------------------------------------------------
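For reference, a sketch of inspecting the per-epoch "_others.pth" dictionary that the save block above writes (and that --continue_train later reloads); the checkpoint path here is a placeholder.

    import torch

    # Hypothetical path; the real file is written to <checkpoints_dir>/<name>/<epoch>_others.pth.
    history = torch.load('checkpoints/example/60_others.pth', map_location='cpu')
    print(sorted(history.keys()))  # lrs, total_iters, jitter_rgb, do_cam_adj,
                                   # lambda_* values, discriminator_l* state dicts, ...
    print(history['lrs'], history['total_iters'])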
/utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn.functional as F
4 | from PIL import Image
5 |
6 |
7 | def numpy2im(np_array):
8 | """convert numpy float image to PIL Image"""
9 | return Image.fromarray((np_array * 255).astype(np.uint8))
10 |
11 |
12 | def readFlow(fn):
13 | """Read .flo file in Middlebury format"""
14 | # Code adapted from:
15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
16 |
17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
18 | # print 'fn = %s'%(fn)
19 | with open(fn, "rb") as f:
20 | magic = np.fromfile(f, np.float32, count=1)
21 | if 202021.25 != magic:
22 | print("Magic number incorrect. Invalid .flo file")
23 | return None
24 | else:
25 | w = np.fromfile(f, np.int32, count=1)
26 | h = np.fromfile(f, np.int32, count=1)
27 | # print 'Reading %d x %d flo file\n' % (w, h)
28 | data = np.fromfile(f, np.float32, count=2 * int(w) * int(h))
29 | # Reshape data into 3D array (columns, rows, bands)
30 | # The reshape here is for visualization, the original code is (w,h,2)
31 | return np.resize(data, (int(h), int(w), 2))
32 |
33 |
34 | def resize_flow(flow, width, height):
35 | orig_h, orig_w = flow.shape[1:]
36 | flow = F.interpolate(flow.unsqueeze(0), (height, width), mode="bilinear").squeeze(0)
37 | flow[0] *= width / orig_w
38 | flow[1] *= height / orig_h
39 | return flow
40 |
41 |
42 | def tensor_flow_to_image(
43 | flow_uv, clip_flow=None, convert_to_bgr=False, global_max=None
44 | ):
45 | flow_np = flow_uv.permute(1, 2, 0).cpu().numpy()
46 | image = flow_to_image(flow_np, clip_flow, convert_to_bgr, global_max)
47 | image = torch.from_numpy(image).permute(2, 0, 1)
48 | return image.float() / 255.0 * 2 - 1
49 |
50 |
51 | # The following flow visualization code is from https://github.com/tomrunia/OpticalFlow_Visualization
52 | def make_colorwheel():
53 | """
54 | Generates a color wheel for optical flow visualization as presented in:
55 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
56 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
57 | Code follows the original C++ source code of Daniel Scharstein.
58 | Code follows the Matlab source code of Deqing Sun.
59 | Returns:
60 | np.ndarray: Color wheel
61 | """
62 |
63 | RY = 15
64 | YG = 6
65 | GC = 4
66 | CB = 11
67 | BM = 13
68 | MR = 6
69 |
70 | ncols = RY + YG + GC + CB + BM + MR
71 | colorwheel = np.zeros((ncols, 3))
72 | col = 0
73 |
74 | # RY
75 | colorwheel[0:RY, 0] = 255
76 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
77 | col = col + RY
78 | # YG
79 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
80 | colorwheel[col : col + YG, 1] = 255
81 | col = col + YG
82 | # GC
83 | colorwheel[col : col + GC, 1] = 255
84 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
85 | col = col + GC
86 | # CB
87 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
88 | colorwheel[col : col + CB, 2] = 255
89 | col = col + CB
90 | # BM
91 | colorwheel[col : col + BM, 2] = 255
92 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
93 | col = col + BM
94 | # MR
95 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
96 | colorwheel[col : col + MR, 0] = 255
97 | return colorwheel
98 |
99 |
100 | def flow_uv_to_colors(u, v, convert_to_bgr=False):
101 | """
102 | Applies the flow color wheel to (possibly clipped) flow components u and v.
103 | According to the C++ source code of Daniel Scharstein
104 | According to the Matlab source code of Deqing Sun
105 | Args:
106 | u (np.ndarray): Input horizontal flow of shape [H,W]
107 | v (np.ndarray): Input vertical flow of shape [H,W]
108 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
109 | Returns:
110 | np.ndarray: Flow visualization image of shape [H,W,3]
111 | """
112 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
113 | colorwheel = make_colorwheel() # shape [55x3]
114 | ncols = colorwheel.shape[0]
115 | rad = np.sqrt(np.square(u) + np.square(v))
116 | a = np.arctan2(-v, -u) / np.pi
117 | fk = (a + 1) / 2 * (ncols - 1)
118 | k0 = np.floor(fk).astype(np.int32)
119 | k1 = k0 + 1
120 | k1[k1 == ncols] = 0
121 | f = fk - k0
122 | for i in range(colorwheel.shape[1]):
123 | tmp = colorwheel[:, i]
124 | col0 = tmp[k0] / 255.0
125 | col1 = tmp[k1] / 255.0
126 | col = (1 - f) * col0 + f * col1
127 | idx = rad <= 1
128 | col[idx] = 1 - rad[idx] * (1 - col[idx])
129 | col[~idx] = col[~idx] * 0.75 # out of range
130 | # Note the 2-i => BGR instead of RGB
131 | ch_idx = 2 - i if convert_to_bgr else i
132 | flow_image[:, :, ch_idx] = np.floor(255 * col)
133 | return flow_image
134 |
135 |
136 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False, global_max=None):
137 | """
138 | Expects a two dimensional flow image of shape [H,W,2].
139 | Args:
140 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
141 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
142 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
143 | Returns:
144 | np.ndarray: Flow visualization image of shape [H,W,3]
145 | """
146 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
147 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
148 | if clip_flow is not None:
149 | flow_uv = np.clip(flow_uv, 0, clip_flow)
150 | u = flow_uv[:, :, 0]
151 | v = flow_uv[:, :, 1]
152 | rad = np.sqrt(np.square(u) + np.square(v))
153 | rad_max = global_max if global_max else np.max(rad)
154 | # import pdb
155 |
156 | # pdb.set_trace()
157 | epsilon = 1e-5
158 | u = u / (rad_max + epsilon)
159 | v = v / (rad_max + epsilon)
160 | return flow_uv_to_colors(u, v, convert_to_bgr)
161 |
--------------------------------------------------------------------------------
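A minimal sketch of turning a Middlebury .flo file into a color image with readFlow and flow_to_image above; the file names are placeholders, and both functions are assumed to be imported from this utils.py.

    from PIL import Image
    # assuming readFlow and flow_to_image above are imported from utils.py
    flow = readFlow('00000.flo')      # hypothetical path; HxWx2 float32, or None on a bad magic number
    if flow is not None:
        rgb = flow_to_image(flow)     # HxWx3 uint8 visualization
        Image.fromarray(rgb).save('00000_flow.png')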
/video_completion.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | sys.path.append(os.path.abspath(os.path.join(__file__, '..', '..')))
4 |
5 | import argparse
6 | import os
7 | import numpy as np
8 | import torch
9 | from PIL import Image
10 | import glob
11 | import torchvision.transforms.functional as F
12 |
13 | from RAFT import utils
14 | from RAFT import RAFT
15 |
16 |
17 | def create_dir(dir):
18 | """Creates a directory if not exist.
19 | """
20 | if not os.path.exists(dir):
21 | os.makedirs(dir)
22 |
23 |
24 | def initialize_RAFT(args):
25 | """Initializes the RAFT model.
26 | """
27 | model = torch.nn.DataParallel(RAFT(args))
28 | model.load_state_dict(torch.load(args.model))
29 |
30 | model = model.module
31 | model.to('cuda')
32 | model.eval()
33 |
34 | return model
35 |
36 |
37 | def calculate_flow(args, model, video, mode):
38 | """Calculates optical flow.
39 | """
40 | if mode not in ['forward', 'backward']:
41 | raise NotImplementedError
42 |
43 | nFrame, _, imgH, imgW = video.shape
44 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
45 |
46 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')):
47 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))):
48 | # print("Loading {0}".format(flow_name), '\r', end='')
49 | # flow = utils.frame_utils.readFlow(flow_name)
50 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1)
51 | # return Flow
52 | flow_folder = 'flow' + args.path.replace("/","")
53 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_flo'))
54 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_png'))
55 |
56 | with torch.no_grad():
57 | for i in range(video.shape[0] - 1):
58 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + 1), '\r', end='')
59 | if mode == 'forward':
60 | # Flow i -> i + 1
61 | image1 = video[i, None]
62 | image2 = video[i + 1, None]
63 | elif mode == 'backward':
64 | # Flow i + 1 -> i
65 | image1 = video[i + 1, None]
66 | image2 = video[i, None]
67 | else:
68 | raise NotImplementedError
69 |
70 | _, flow = model(image1, image2, iters=20, test_mode=True)
71 | flow = flow[0].permute(1, 2, 0).cpu().numpy()
72 | Flow = np.concatenate((Flow, flow[..., None]), axis=-1)
73 |
74 | # Flow visualization.
75 | flow_img = utils.flow_viz.flow_to_image(flow)
76 | flow_img = Image.fromarray(flow_img)
77 |
78 | # Saves the flow and flow_img.
79 | flow_img.save(os.path.join(args.outroot, flow_folder, mode + '_png', '%05d.png'%i))
80 | # np.save(os.path.join(args.outroot, 'flow', mode + '_flo', '%05d.npy'%i), flow)
81 | utils.frame_utils.writeFlow(os.path.join(args.outroot, flow_folder, mode + '_flo', '%05d.flo'%i), flow)
82 |
83 | return Flow
84 |
85 |
86 | def calculate_flow_global(args, model, video, mode, step=1):
87 | """Calculates optical flow.
88 | """
89 | if mode not in ['forward', 'backward']:
90 | raise NotImplementedError
91 |
92 | nFrame, _, imgH, imgW = video.shape
93 | Flow = np.empty(((imgH, imgW, 2, 0)), dtype=np.float32)
94 |
95 | # if os.path.isdir(os.path.join(args.outroot, 'flow', mode + '_flo')):
96 | # for flow_name in sorted(glob.glob(os.path.join(args.outroot, 'flow', mode + '_flo', '*.flo'))):
97 | # print("Loading {0}".format(flow_name), '\r', end='')
98 | # flow = utils.frame_utils.readFlow(flow_name)
99 | # Flow = np.concatenate((Flow, flow[..., None]), axis=-1)
100 | # return Flow
101 | flow_folder = args.path.replace("/","")
102 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_flow_step' + str(step)))
103 | create_dir(os.path.join(args.outroot, flow_folder, mode + '_png_step' + str(step)))
104 | global_max = -10000000
105 | with torch.no_grad():
106 | # for i in range(10):
107 | for i in range(video.shape[0] - step):
108 | print("Calculating {0} flow {1:2d} <---> {2:2d}".format(mode, i, i + step), '\r', end='')
109 | if mode == 'forward':
110 | # Flow i -> i + step
111 | image1 = video[i, None]
112 | image2 = video[i + step, None]
113 | elif mode == 'backward':
114 | # Flow i + step -> i
115 | image1 = video[i + step, None]
116 | image2 = video[i, None]
117 | else:
118 | raise NotImplementedError
119 |
120 | _, flow = model(image1, image2, iters = 20, test_mode = True)
121 | flow_max = torch.sqrt(flow[0,0,:,:] ** 2 + flow[0, 1, :, :] ** 2).max()
122 | global_max = max(global_max, flow_max.cpu().numpy())
123 | print(global_max)
124 | flow = flow[0].permute(1, 2, 0).cpu().numpy()
125 | Flow = np.concatenate((Flow, flow[..., None]), axis = -1)
126 |
127 | for j in range(Flow.shape[-1]):
128 | flow=Flow[:,:,:,j]
129 | print(j)
130 | # Flow visualization.
131 | flow_img = utils.flow_viz.flow_to_image(flow, rad_max=global_max)
132 | flow_img = Image.fromarray(flow_img)
133 |
134 | # Saves the flow and flow_img.
135 | flow_img.save(os.path.join(args.outroot, flow_folder, mode + '_png_step' + str(step), '%05d.png'%j))
136 | utils.frame_utils.writeFlow(os.path.join(args.outroot, flow_folder, mode + '_flow_step' + str(step), '%05d.flo'%j), flow)
137 |
138 | return Flow
139 |
140 | def video_completion(args):
141 |
142 | # Flow model.
143 | RAFT_model = initialize_RAFT(args)
144 |
145 | # Loads frames.
146 | filename_list = glob.glob(os.path.join(args.path, '*.png'))
147 | # glob.glob(os.path.join(args.path, '*.jpg'))
148 |
149 | # Obtains imgH, imgW and nFrame.
150 | imgH, imgW = np.array(Image.open(filename_list[0]).convert('RGB')).shape[:2]
151 | nFrame = len(filename_list)
152 |
153 | # Loads video.
154 | video = []
155 | for filename in sorted(filename_list):
156 | video.append(torch.from_numpy(np.array(Image.open(filename).convert('RGB')).astype(np.uint8)).permute(2, 0, 1).float())
157 |
158 | video = torch.stack(video, dim=0)
159 | video = video.to('cuda')
160 |
161 | # Calculates the corrupted flow.
162 | print('STEP', str(args.step))
163 | corrFlowF = calculate_flow_global(args, RAFT_model, video, 'forward', step=args.step) #_interval
164 | corrFlowB = calculate_flow_global(args, RAFT_model, video, 'backward', step=args.step) #_interval
165 | print('\nFinish flow prediction.')
166 |
167 |
168 |
169 |
170 | if __name__ == '__main__':
171 | parser = argparse.ArgumentParser()
172 | # video completion
173 | parser.add_argument('--seamless', action='store_true', help='Whether to operate in the gradient domain')
174 | parser.add_argument('--edge_guide', action='store_true', help='Whether to use edges as guidance to complete flow')
175 | parser.add_argument('--mode', default='object_removal', help="modes: object_removal / video_extrapolation")
176 | parser.add_argument('--path', default='../data/tennis', help="dataset for evaluation")
177 | parser.add_argument('--outroot', default='RAFT_result/', help="output directory")
178 | parser.add_argument('--consistencyThres', dest='consistencyThres', default=np.inf, type=float, help='flow consistency error threshold')
179 | parser.add_argument('--alpha', dest='alpha', default=0.1, type=float)
180 | parser.add_argument('--Nonlocal', dest='Nonlocal', default=False, type=bool)
181 | parser.add_argument('--step', default=1, type=int)
182 |
183 | # RAFT
184 | parser.add_argument('--model', default='weight/raft-things.pth', help="restore checkpoint")
185 | parser.add_argument('--small', action='store_true', help='use small model')
186 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision')
187 | parser.add_argument('--alternate_corr', action='store_true', help='use efficient correlation implementation')
188 |
189 | args = parser.parse_args()
190 |
191 | video_completion(args)
192 |
--------------------------------------------------------------------------------
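An example invocation of the script above, using the argparse defaults for --model (weight/raft-things.pth) and --outroot (RAFT_result/); the frame directory is a placeholder and should contain the video's .png frames:

    python video_completion.py --path ./data/example_frames --outroot RAFT_result/ --step 1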
/weight/edge_completion.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/edge_completion.pth
--------------------------------------------------------------------------------
/weight/imagenet_deepfill.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/imagenet_deepfill.pth
--------------------------------------------------------------------------------
/weight/raft-things.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/jaclyngu/FactorMatte/9a4820a61775d264be8cef375c3e8b7d22bcc31b/weight/raft-things.pth
--------------------------------------------------------------------------------