├── RAFT_core ├── .DS_Store ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── corr.cpython-37.pyc │ ├── extractor.cpython-37.pyc │ ├── raft.cpython-37.pyc │ └── update.cpython-37.pyc ├── corr.py ├── datasets.py ├── extractor.py ├── raft-things.pth-no-zip ├── raft.py ├── update.py └── utils │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-37.pyc │ └── utils.cpython-37.pyc │ ├── augmentor.py │ ├── flow_viz.py │ ├── frame_utils.py │ └── utils.py ├── README.md ├── TC_cal.py ├── TC_cal_total.py ├── VC_perclip.py ├── VC_perclip_total.py ├── evaluation.sh ├── evaluator_test.py └── utils.py /RAFT_core/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/.DS_Store -------------------------------------------------------------------------------- /RAFT_core/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__init__.py -------------------------------------------------------------------------------- /RAFT_core/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/corr.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__pycache__/corr.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/extractor.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__pycache__/extractor.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/raft.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__pycache__/raft.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/__pycache__/update.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/__pycache__/update.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid 4 | 5 | try: 6 | import alt_cuda_corr 7 | except: 8 | # alt_cuda_corr is not compiled 9 | pass 10 | 11 | 12 | class CorrBlock: 13 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 14 | self.num_levels = num_levels 15 | self.radius = radius 16 | self.corr_pyramid = [] 17 | 18 | # all pairs correlation 19 | corr = CorrBlock.corr(fmap1, fmap2) 
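        # CorrBlock.corr() above returns the all-pairs similarity volume with shape
        # (batch, h1, w1, 1, h2, w2): one dot-product score for every pair of feature
        # locations. The lines below flatten it to (batch*h1*w1, 1, h2, w2) and then
        # average-pool it num_levels-1 times, producing the multi-scale correlation
        # pyramid that __call__ later looks up with bilinear_sampler.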
20 | 21 | batch, h1, w1, dim, h2, w2 = corr.shape 22 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 23 | 24 | self.corr_pyramid.append(corr) 25 | for i in range(self.num_levels-1): 26 | corr = F.avg_pool2d(corr, 2, stride=2) 27 | self.corr_pyramid.append(corr) 28 | 29 | def __call__(self, coords): 30 | r = self.radius 31 | coords = coords.permute(0, 2, 3, 1) 32 | batch, h1, w1, _ = coords.shape 33 | 34 | out_pyramid = [] 35 | for i in range(self.num_levels): 36 | corr = self.corr_pyramid[i] 37 | dx = torch.linspace(-r, r, 2*r+1) 38 | dy = torch.linspace(-r, r, 2*r+1) 39 | delta = torch.stack(torch.meshgrid(dy, dx), dim=-1).to(coords.device) 40 | 41 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 42 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 43 | coords_lvl = centroid_lvl + delta_lvl 44 | 45 | corr = bilinear_sampler(corr, coords_lvl) 46 | corr = corr.view(batch, h1, w1, -1) 47 | out_pyramid.append(corr) 48 | 49 | out = torch.cat(out_pyramid, dim=-1) 50 | return out.permute(0, 3, 1, 2).contiguous().float() 51 | 52 | @staticmethod 53 | def corr(fmap1, fmap2): 54 | batch, dim, ht, wd = fmap1.shape 55 | fmap1 = fmap1.view(batch, dim, ht*wd) 56 | fmap2 = fmap2.view(batch, dim, ht*wd) 57 | 58 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 59 | corr = corr.view(batch, ht, wd, 1, ht, wd) 60 | return corr / torch.sqrt(torch.tensor(dim).float()) 61 | 62 | 63 | class AlternateCorrBlock: 64 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 65 | self.num_levels = num_levels 66 | self.radius = radius 67 | 68 | self.pyramid = [(fmap1, fmap2)] 69 | for i in range(self.num_levels): 70 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 71 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 72 | self.pyramid.append((fmap1, fmap2)) 73 | 74 | def __call__(self, coords): 75 | coords = coords.permute(0, 2, 3, 1) 76 | B, H, W, _ = coords.shape 77 | dim = self.pyramid[0][0].shape[1] 78 | 79 | corr_list = [] 80 | for i in range(self.num_levels): 81 | r = self.radius 82 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 83 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 84 | 85 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 86 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 87 | corr_list.append(corr.squeeze(1)) 88 | 89 | corr = torch.stack(corr_list, dim=1) 90 | corr = corr.reshape(B, -1, H, W) 91 | return corr / torch.sqrt(torch.tensor(dim).float()) 92 | -------------------------------------------------------------------------------- /RAFT_core/datasets.py: -------------------------------------------------------------------------------- 1 | # Data loading based on https://github.com/NVIDIA/flownet2-pytorch 2 | 3 | import numpy as np 4 | import torch 5 | import torch.utils.data as data 6 | import torch.nn.functional as F 7 | 8 | import os 9 | import math 10 | import random 11 | from glob import glob 12 | import os.path as osp 13 | 14 | from utils import frame_utils 15 | from utils.augmentor import FlowAugmentor, SparseFlowAugmentor 16 | 17 | 18 | class FlowDataset(data.Dataset): 19 | def __init__(self, aug_params=None, sparse=False): 20 | self.augmentor = None 21 | self.sparse = sparse 22 | if aug_params is not None: 23 | if sparse: 24 | self.augmentor = SparseFlowAugmentor(**aug_params) 25 | else: 26 | self.augmentor = FlowAugmentor(**aug_params) 27 | 28 | self.is_test = False 29 | self.init_seed = False 30 | self.flow_list = [] 31 | self.image_list = [] 32 | self.extra_info = [] 33 | 34 | def __getitem__(self, index): 35 | 36 | if 
self.is_test: 37 | img1 = frame_utils.read_gen(self.image_list[index][0]) 38 | img2 = frame_utils.read_gen(self.image_list[index][1]) 39 | img1 = np.array(img1).astype(np.uint8)[..., :3] 40 | img2 = np.array(img2).astype(np.uint8)[..., :3] 41 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 42 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 43 | return img1, img2, self.extra_info[index] 44 | 45 | if not self.init_seed: 46 | worker_info = torch.utils.data.get_worker_info() 47 | if worker_info is not None: 48 | torch.manual_seed(worker_info.id) 49 | np.random.seed(worker_info.id) 50 | random.seed(worker_info.id) 51 | self.init_seed = True 52 | 53 | index = index % len(self.image_list) 54 | valid = None 55 | if self.sparse: 56 | flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) 57 | else: 58 | flow = frame_utils.read_gen(self.flow_list[index]) 59 | 60 | img1 = frame_utils.read_gen(self.image_list[index][0]) 61 | img2 = frame_utils.read_gen(self.image_list[index][1]) 62 | 63 | flow = np.array(flow).astype(np.float32) 64 | img1 = np.array(img1).astype(np.uint8) 65 | img2 = np.array(img2).astype(np.uint8) 66 | 67 | # grayscale images 68 | if len(img1.shape) == 2: 69 | img1 = np.tile(img1[...,None], (1, 1, 3)) 70 | img2 = np.tile(img2[...,None], (1, 1, 3)) 71 | else: 72 | img1 = img1[..., :3] 73 | img2 = img2[..., :3] 74 | 75 | if self.augmentor is not None: 76 | if self.sparse: 77 | img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) 78 | else: 79 | img1, img2, flow = self.augmentor(img1, img2, flow) 80 | 81 | img1 = torch.from_numpy(img1).permute(2, 0, 1).float() 82 | img2 = torch.from_numpy(img2).permute(2, 0, 1).float() 83 | flow = torch.from_numpy(flow).permute(2, 0, 1).float() 84 | 85 | if valid is not None: 86 | valid = torch.from_numpy(valid) 87 | else: 88 | valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) 89 | 90 | return img1, img2, flow, valid.float() 91 | 92 | 93 | def __rmul__(self, v): 94 | self.flow_list = v * self.flow_list 95 | self.image_list = v * self.image_list 96 | return self 97 | 98 | def __len__(self): 99 | return len(self.image_list) 100 | 101 | 102 | class MpiSintel(FlowDataset): 103 | def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): 104 | super(MpiSintel, self).__init__(aug_params) 105 | flow_root = osp.join(root, split, 'flow') 106 | image_root = osp.join(root, split, dstype) 107 | 108 | if split == 'test': 109 | self.is_test = True 110 | 111 | for scene in os.listdir(image_root): 112 | image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) 113 | for i in range(len(image_list)-1): 114 | self.image_list += [ [image_list[i], image_list[i+1]] ] 115 | self.extra_info += [ (scene, i) ] # scene and frame_id 116 | 117 | if split != 'test': 118 | self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) 119 | 120 | 121 | class FlyingChairs(FlowDataset): 122 | def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): 123 | super(FlyingChairs, self).__init__(aug_params) 124 | 125 | images = sorted(glob(osp.join(root, '*.ppm'))) 126 | flows = sorted(glob(osp.join(root, '*.flo'))) 127 | assert (len(images)//2 == len(flows)) 128 | 129 | split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) 130 | for i in range(len(flows)): 131 | xid = split_list[i] 132 | if (split=='training' and xid==1) or (split=='validation' and xid==2): 133 | self.flow_list += [ flows[i] ] 134 | self.image_list += [ [images[2*i], images[2*i+1]] ] 135 | 
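# Note: FlowDataset.__rmul__ above is what makes the dataset weighting in fetch_dataloader()
# below possible (e.g. 100*sintel_clean + 200*kitti + things): multiplying a dataset by an
# integer repeats its image/flow lists, and `+` chains datasets through the standard
# torch.utils.data.Dataset.__add__ (a ConcatDataset).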
136 | 137 | class FlyingThings3D(FlowDataset): 138 | def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): 139 | super(FlyingThings3D, self).__init__(aug_params) 140 | 141 | for cam in ['left']: 142 | for direction in ['into_future', 'into_past']: 143 | image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) 144 | image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) 145 | 146 | flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) 147 | flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) 148 | 149 | for idir, fdir in zip(image_dirs, flow_dirs): 150 | images = sorted(glob(osp.join(idir, '*.png')) ) 151 | flows = sorted(glob(osp.join(fdir, '*.pfm')) ) 152 | for i in range(len(flows)-1): 153 | if direction == 'into_future': 154 | self.image_list += [ [images[i], images[i+1]] ] 155 | self.flow_list += [ flows[i] ] 156 | elif direction == 'into_past': 157 | self.image_list += [ [images[i+1], images[i]] ] 158 | self.flow_list += [ flows[i+1] ] 159 | 160 | 161 | class KITTI(FlowDataset): 162 | def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): 163 | super(KITTI, self).__init__(aug_params, sparse=True) 164 | if split == 'testing': 165 | self.is_test = True 166 | 167 | root = osp.join(root, split) 168 | images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) 169 | images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) 170 | 171 | for img1, img2 in zip(images1, images2): 172 | frame_id = img1.split('/')[-1] 173 | self.extra_info += [ [frame_id] ] 174 | self.image_list += [ [img1, img2] ] 175 | 176 | if split == 'training': 177 | self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) 178 | 179 | 180 | class HD1K(FlowDataset): 181 | def __init__(self, aug_params=None, root='datasets/HD1k'): 182 | super(HD1K, self).__init__(aug_params, sparse=True) 183 | 184 | seq_ix = 0 185 | while 1: 186 | flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) 187 | images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) 188 | 189 | if len(flows) == 0: 190 | break 191 | 192 | for i in range(len(flows)-1): 193 | self.flow_list += [flows[i]] 194 | self.image_list += [ [images[i], images[i+1]] ] 195 | 196 | seq_ix += 1 197 | 198 | 199 | def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): 200 | """ Create the data loader for the corresponding trainign set """ 201 | 202 | if args.stage == 'chairs': 203 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True} 204 | train_dataset = FlyingChairs(aug_params, split='training') 205 | 206 | elif args.stage == 'things': 207 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} 208 | clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') 209 | final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') 210 | train_dataset = clean_dataset + final_dataset 211 | 212 | elif args.stage == 'sintel': 213 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} 214 | things = FlyingThings3D(aug_params, dstype='frames_cleanpass') 215 | sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') 216 | sintel_final = MpiSintel(aug_params, split='training', dstype='final') 217 | 218 | if TRAIN_DS == 'C+T+K+S+H': 219 | kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) 220 | hd1k = 
HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) 221 | train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things 222 | 223 | elif TRAIN_DS == 'C+T+K/S': 224 | train_dataset = 100*sintel_clean + 100*sintel_final + things 225 | 226 | elif args.stage == 'kitti': 227 | aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} 228 | train_dataset = KITTI(aug_params, split='training') 229 | 230 | train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, 231 | pin_memory=False, shuffle=True, num_workers=4, drop_last=True) 232 | 233 | print('Training with %d image pairs' % len(train_dataset)) 234 | return train_loader 235 | 236 | -------------------------------------------------------------------------------- /RAFT_core/extractor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ResidualBlock(nn.Module): 7 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 8 | super(ResidualBlock, self).__init__() 9 | 10 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) 11 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | num_groups = planes // 8 15 | 16 | if norm_fn == 'group': 17 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 18 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 19 | if not stride == 1: 20 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 21 | 22 | elif norm_fn == 'batch': 23 | self.norm1 = nn.BatchNorm2d(planes) 24 | self.norm2 = nn.BatchNorm2d(planes) 25 | if not stride == 1: 26 | self.norm3 = nn.BatchNorm2d(planes) 27 | 28 | elif norm_fn == 'instance': 29 | self.norm1 = nn.InstanceNorm2d(planes) 30 | self.norm2 = nn.InstanceNorm2d(planes) 31 | if not stride == 1: 32 | self.norm3 = nn.InstanceNorm2d(planes) 33 | 34 | elif norm_fn == 'none': 35 | self.norm1 = nn.Sequential() 36 | self.norm2 = nn.Sequential() 37 | if not stride == 1: 38 | self.norm3 = nn.Sequential() 39 | 40 | if stride == 1: 41 | self.downsample = None 42 | 43 | else: 44 | self.downsample = nn.Sequential( 45 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) 46 | 47 | 48 | def forward(self, x): 49 | y = x 50 | y = self.relu(self.norm1(self.conv1(y))) 51 | y = self.relu(self.norm2(self.conv2(y))) 52 | 53 | if self.downsample is not None: 54 | x = self.downsample(x) 55 | 56 | return self.relu(x+y) 57 | 58 | 59 | 60 | class BottleneckBlock(nn.Module): 61 | def __init__(self, in_planes, planes, norm_fn='group', stride=1): 62 | super(BottleneckBlock, self).__init__() 63 | 64 | self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) 65 | self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) 66 | self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) 67 | self.relu = nn.ReLU(inplace=True) 68 | 69 | num_groups = planes // 8 70 | 71 | if norm_fn == 'group': 72 | self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 73 | self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) 74 | self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 75 | if not stride == 1: 76 | self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) 77 | 78 | elif norm_fn == 'batch': 79 | 
self.norm1 = nn.BatchNorm2d(planes//4) 80 | self.norm2 = nn.BatchNorm2d(planes//4) 81 | self.norm3 = nn.BatchNorm2d(planes) 82 | if not stride == 1: 83 | self.norm4 = nn.BatchNorm2d(planes) 84 | 85 | elif norm_fn == 'instance': 86 | self.norm1 = nn.InstanceNorm2d(planes//4) 87 | self.norm2 = nn.InstanceNorm2d(planes//4) 88 | self.norm3 = nn.InstanceNorm2d(planes) 89 | if not stride == 1: 90 | self.norm4 = nn.InstanceNorm2d(planes) 91 | 92 | elif norm_fn == 'none': 93 | self.norm1 = nn.Sequential() 94 | self.norm2 = nn.Sequential() 95 | self.norm3 = nn.Sequential() 96 | if not stride == 1: 97 | self.norm4 = nn.Sequential() 98 | 99 | if stride == 1: 100 | self.downsample = None 101 | 102 | else: 103 | self.downsample = nn.Sequential( 104 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) 105 | 106 | 107 | def forward(self, x): 108 | y = x 109 | y = self.relu(self.norm1(self.conv1(y))) 110 | y = self.relu(self.norm2(self.conv2(y))) 111 | y = self.relu(self.norm3(self.conv3(y))) 112 | 113 | if self.downsample is not None: 114 | x = self.downsample(x) 115 | 116 | return self.relu(x+y) 117 | 118 | class BasicEncoder(nn.Module): 119 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 120 | super(BasicEncoder, self).__init__() 121 | self.norm_fn = norm_fn 122 | 123 | if self.norm_fn == 'group': 124 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) 125 | 126 | elif self.norm_fn == 'batch': 127 | self.norm1 = nn.BatchNorm2d(64) 128 | 129 | elif self.norm_fn == 'instance': 130 | self.norm1 = nn.InstanceNorm2d(64) 131 | 132 | elif self.norm_fn == 'none': 133 | self.norm1 = nn.Sequential() 134 | 135 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) 136 | self.relu1 = nn.ReLU(inplace=True) 137 | 138 | self.in_planes = 64 139 | self.layer1 = self._make_layer(64, stride=1) 140 | self.layer2 = self._make_layer(96, stride=2) 141 | self.layer3 = self._make_layer(128, stride=2) 142 | 143 | # output convolution 144 | self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) 145 | 146 | self.dropout = None 147 | if dropout > 0: 148 | self.dropout = nn.Dropout2d(p=dropout) 149 | 150 | for m in self.modules(): 151 | if isinstance(m, nn.Conv2d): 152 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 153 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 154 | if m.weight is not None: 155 | nn.init.constant_(m.weight, 1) 156 | if m.bias is not None: 157 | nn.init.constant_(m.bias, 0) 158 | 159 | def _make_layer(self, dim, stride=1): 160 | layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) 161 | layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) 162 | layers = (layer1, layer2) 163 | 164 | self.in_planes = dim 165 | return nn.Sequential(*layers) 166 | 167 | 168 | def forward(self, x): 169 | 170 | # if input is list, combine batch dimension 171 | is_list = isinstance(x, tuple) or isinstance(x, list) 172 | if is_list: 173 | batch_dim = x[0].shape[0] 174 | x = torch.cat(x, dim=0) 175 | 176 | x = self.conv1(x) 177 | x = self.norm1(x) 178 | x = self.relu1(x) 179 | 180 | x = self.layer1(x) 181 | x = self.layer2(x) 182 | x = self.layer3(x) 183 | 184 | x = self.conv2(x) 185 | 186 | if self.training and self.dropout is not None: 187 | x = self.dropout(x) 188 | 189 | if is_list: 190 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 191 | 192 | return x 193 | 194 | 195 | class SmallEncoder(nn.Module): 196 | def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): 197 | 
super(SmallEncoder, self).__init__() 198 | self.norm_fn = norm_fn 199 | 200 | if self.norm_fn == 'group': 201 | self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) 202 | 203 | elif self.norm_fn == 'batch': 204 | self.norm1 = nn.BatchNorm2d(32) 205 | 206 | elif self.norm_fn == 'instance': 207 | self.norm1 = nn.InstanceNorm2d(32) 208 | 209 | elif self.norm_fn == 'none': 210 | self.norm1 = nn.Sequential() 211 | 212 | self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) 213 | self.relu1 = nn.ReLU(inplace=True) 214 | 215 | self.in_planes = 32 216 | self.layer1 = self._make_layer(32, stride=1) 217 | self.layer2 = self._make_layer(64, stride=2) 218 | self.layer3 = self._make_layer(96, stride=2) 219 | 220 | self.dropout = None 221 | if dropout > 0: 222 | self.dropout = nn.Dropout2d(p=dropout) 223 | 224 | self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) 225 | 226 | for m in self.modules(): 227 | if isinstance(m, nn.Conv2d): 228 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 229 | elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): 230 | if m.weight is not None: 231 | nn.init.constant_(m.weight, 1) 232 | if m.bias is not None: 233 | nn.init.constant_(m.bias, 0) 234 | 235 | def _make_layer(self, dim, stride=1): 236 | layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) 237 | layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) 238 | layers = (layer1, layer2) 239 | 240 | self.in_planes = dim 241 | return nn.Sequential(*layers) 242 | 243 | 244 | def forward(self, x): 245 | 246 | # if input is list, combine batch dimension 247 | is_list = isinstance(x, tuple) or isinstance(x, list) 248 | if is_list: 249 | batch_dim = x[0].shape[0] 250 | x = torch.cat(x, dim=0) 251 | 252 | x = self.conv1(x) 253 | x = self.norm1(x) 254 | x = self.relu1(x) 255 | 256 | x = self.layer1(x) 257 | x = self.layer2(x) 258 | x = self.layer3(x) 259 | x = self.conv2(x) 260 | 261 | if self.training and self.dropout is not None: 262 | x = self.dropout(x) 263 | 264 | if is_list: 265 | x = torch.split(x, [batch_dim, batch_dim], dim=0) 266 | 267 | return x 268 | -------------------------------------------------------------------------------- /RAFT_core/raft-things.pth-no-zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/raft-things.pth-no-zip -------------------------------------------------------------------------------- /RAFT_core/raft.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import sys 6 | sys.path.append('RAFT_core') 7 | 8 | from update import BasicUpdateBlock, SmallUpdateBlock 9 | from extractor import BasicEncoder, SmallEncoder 10 | from corr import CorrBlock, AlternateCorrBlock 11 | from RAFT_core.utils.utils import bilinear_sampler, coords_grid, upflow8 12 | 13 | try: 14 | autocast = torch.cuda.amp.autocast 15 | except: 16 | # dummy autocast for PyTorch < 1.6 17 | class autocast: 18 | def __init__(self, enabled): 19 | pass 20 | def __enter__(self): 21 | pass 22 | def __exit__(self, *args): 23 | pass 24 | 25 | 26 | class RAFT(nn.Module): 27 | def __init__(self,requires_grad=False): 28 | super(RAFT, self).__init__() 29 | # self.args = args 30 | 31 | self.hidden_dim = hdim = 128 32 | self.context_dim = cdim = 128 33 | corr_levels = 4 34 | corr_radius 
= 4 35 | self.corr_radius = corr_radius 36 | 37 | 38 | 39 | # feature network, context network, and update block 40 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=0) 41 | self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=0) 42 | self.update_block = BasicUpdateBlock(corr_levels,corr_radius, hidden_dim=hdim) 43 | if not requires_grad: 44 | for param in self.parameters(): 45 | param.requires_grad = False 46 | 47 | def freeze_bn(self): 48 | for m in self.modules(): 49 | if isinstance(m, nn.BatchNorm2d): 50 | m.eval() 51 | 52 | def initialize_flow(self, img): 53 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 54 | N, C, H, W = img.shape 55 | coords0 = coords_grid(N, H//8, W//8).to(img.device) 56 | coords1 = coords_grid(N, H//8, W//8).to(img.device) 57 | 58 | # optical flow computed as difference: flow = coords1 - coords0 59 | return coords0, coords1 60 | 61 | def upsample_flow(self, flow, mask): 62 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 63 | N, _, H, W = flow.shape 64 | mask = mask.view(N, 1, 9, 8, 8, H, W) 65 | mask = torch.softmax(mask, dim=2) 66 | 67 | up_flow = F.unfold(8 * flow, [3,3], padding=1) 68 | up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) 69 | 70 | up_flow = torch.sum(mask * up_flow, dim=2) 71 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 72 | return up_flow.reshape(N, 2, 8*H, 8*W) 73 | 74 | 75 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 76 | """ Estimate optical flow between pair of frames """ 77 | 78 | image1 = 2 * (image1 / 255.0) - 1.0 79 | image2 = 2 * (image2 / 255.0) - 1.0 80 | 81 | image1 = image1.contiguous() 82 | image2 = image2.contiguous() 83 | 84 | hdim = self.hidden_dim 85 | cdim = self.context_dim 86 | 87 | # run the feature network 88 | fmap1, fmap2 = self.fnet([image1, image2]) 89 | 90 | fmap1 = fmap1.float() 91 | fmap2 = fmap2.float() 92 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.corr_radius) 93 | 94 | # run the context network 95 | cnet = self.cnet(image1) 96 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 97 | net = torch.tanh(net) 98 | inp = torch.relu(inp) 99 | 100 | coords0, coords1 = self.initialize_flow(image1) 101 | 102 | if flow_init is not None: 103 | coords1 = coords1 + flow_init 104 | 105 | flow_predictions = [] 106 | for itr in range(iters): 107 | coords1 = coords1.detach() 108 | corr = corr_fn(coords1) # index correlation volume 109 | 110 | flow = coords1 - coords0 111 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 112 | 113 | # F(t+1) = F(t) + \Delta(t) 114 | coords1 = coords1 + delta_flow 115 | 116 | # upsample predictions 117 | if up_mask is None: 118 | flow_up = upflow8(coords1 - coords0) 119 | else: 120 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 121 | 122 | flow_predictions.append(flow_up) 123 | 124 | if test_mode: 125 | return coords1 - coords0, flow_up 126 | 127 | return flow_predictions 128 | -------------------------------------------------------------------------------- /RAFT_core/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu 
= nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self,corr_levels,corr_radius): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = corr_levels * (2*corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self,corr_levels,corr_radius ): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = corr_levels * (2*corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, corr_levels,corr_radius, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = 
SmallMotionEncoder(corr_levels,corr_radius) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, corr_levels,corr_radius, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.encoder = BasicMotionEncoder(corr_levels,corr_radius) 118 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 119 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 120 | 121 | self.mask = nn.Sequential( 122 | nn.Conv2d(128, 256, 3, padding=1), 123 | nn.ReLU(inplace=True), 124 | nn.Conv2d(256, 64*9, 1, padding=0)) 125 | 126 | def forward(self, net, inp, corr, flow, upsample=True): 127 | motion_features = self.encoder(flow, corr) 128 | inp = torch.cat([inp, motion_features], dim=1) 129 | 130 | net = self.gru(net, inp) 131 | delta_flow = self.flow_head(net) 132 | 133 | # scale mask to balence gradients 134 | mask = .25 * self.mask(net) 135 | return net, mask, delta_flow 136 | 137 | 138 | 139 | -------------------------------------------------------------------------------- /RAFT_core/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/utils/__init__.py -------------------------------------------------------------------------------- /RAFT_core/utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/utils/__pycache__/utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/VSPW-dataset/VSPW_code/73b98ceb92cd9639da210da14f1bb0e02946f846/RAFT_core/utils/__pycache__/utils.cpython-37.pyc -------------------------------------------------------------------------------- /RAFT_core/utils/augmentor.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import random 3 | import math 4 | from PIL import Image 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | import torch 11 | from torchvision.transforms import ColorJitter 12 | import torch.nn.functional as F 13 | 14 | 15 | class FlowAugmentor: 16 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): 17 | 18 | # spatial augmentation params 19 | self.crop_size = crop_size 20 | self.min_scale = min_scale 21 | self.max_scale = max_scale 22 | self.spatial_aug_prob = 0.8 23 | self.stretch_prob = 0.8 24 | self.max_stretch = 0.2 25 | 26 | # flip augmentation params 27 | self.do_flip = do_flip 28 | self.h_flip_prob = 0.5 29 | self.v_flip_prob = 0.1 30 | 31 | # photometric augmentation params 32 | self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) 33 | self.asymmetric_color_aug_prob = 0.2 34 | self.eraser_aug_prob = 0.5 35 | 36 | def 
color_transform(self, img1, img2): 37 | """ Photometric augmentation """ 38 | 39 | # asymmetric 40 | if np.random.rand() < self.asymmetric_color_aug_prob: 41 | img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) 42 | img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) 43 | 44 | # symmetric 45 | else: 46 | image_stack = np.concatenate([img1, img2], axis=0) 47 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 48 | img1, img2 = np.split(image_stack, 2, axis=0) 49 | 50 | return img1, img2 51 | 52 | def eraser_transform(self, img1, img2, bounds=[50, 100]): 53 | """ Occlusion augmentation """ 54 | 55 | ht, wd = img1.shape[:2] 56 | if np.random.rand() < self.eraser_aug_prob: 57 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 58 | for _ in range(np.random.randint(1, 3)): 59 | x0 = np.random.randint(0, wd) 60 | y0 = np.random.randint(0, ht) 61 | dx = np.random.randint(bounds[0], bounds[1]) 62 | dy = np.random.randint(bounds[0], bounds[1]) 63 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 64 | 65 | return img1, img2 66 | 67 | def spatial_transform(self, img1, img2, flow): 68 | # randomly sample scale 69 | ht, wd = img1.shape[:2] 70 | min_scale = np.maximum( 71 | (self.crop_size[0] + 8) / float(ht), 72 | (self.crop_size[1] + 8) / float(wd)) 73 | 74 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 75 | scale_x = scale 76 | scale_y = scale 77 | if np.random.rand() < self.stretch_prob: 78 | scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 79 | scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) 80 | 81 | scale_x = np.clip(scale_x, min_scale, None) 82 | scale_y = np.clip(scale_y, min_scale, None) 83 | 84 | if np.random.rand() < self.spatial_aug_prob: 85 | # rescale the images 86 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 87 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 88 | flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 89 | flow = flow * [scale_x, scale_y] 90 | 91 | if self.do_flip: 92 | if np.random.rand() < self.h_flip_prob: # h-flip 93 | img1 = img1[:, ::-1] 94 | img2 = img2[:, ::-1] 95 | flow = flow[:, ::-1] * [-1.0, 1.0] 96 | 97 | if np.random.rand() < self.v_flip_prob: # v-flip 98 | img1 = img1[::-1, :] 99 | img2 = img2[::-1, :] 100 | flow = flow[::-1, :] * [1.0, -1.0] 101 | 102 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) 103 | x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) 104 | 105 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 106 | img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 107 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 108 | 109 | return img1, img2, flow 110 | 111 | def __call__(self, img1, img2, flow): 112 | img1, img2 = self.color_transform(img1, img2) 113 | img1, img2 = self.eraser_transform(img1, img2) 114 | img1, img2, flow = self.spatial_transform(img1, img2, flow) 115 | 116 | img1 = np.ascontiguousarray(img1) 117 | img2 = np.ascontiguousarray(img2) 118 | flow = np.ascontiguousarray(flow) 119 | 120 | return img1, img2, flow 121 | 122 | class SparseFlowAugmentor: 123 | def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): 124 | # spatial augmentation params 125 | self.crop_size = crop_size 126 | self.min_scale = min_scale 127 | self.max_scale = max_scale 128 | self.spatial_aug_prob = 0.8 129 | self.stretch_prob = 0.8 130 | 
self.max_stretch = 0.2 131 | 132 | # flip augmentation params 133 | self.do_flip = do_flip 134 | self.h_flip_prob = 0.5 135 | self.v_flip_prob = 0.1 136 | 137 | # photometric augmentation params 138 | self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) 139 | self.asymmetric_color_aug_prob = 0.2 140 | self.eraser_aug_prob = 0.5 141 | 142 | def color_transform(self, img1, img2): 143 | image_stack = np.concatenate([img1, img2], axis=0) 144 | image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) 145 | img1, img2 = np.split(image_stack, 2, axis=0) 146 | return img1, img2 147 | 148 | def eraser_transform(self, img1, img2): 149 | ht, wd = img1.shape[:2] 150 | if np.random.rand() < self.eraser_aug_prob: 151 | mean_color = np.mean(img2.reshape(-1, 3), axis=0) 152 | for _ in range(np.random.randint(1, 3)): 153 | x0 = np.random.randint(0, wd) 154 | y0 = np.random.randint(0, ht) 155 | dx = np.random.randint(50, 100) 156 | dy = np.random.randint(50, 100) 157 | img2[y0:y0+dy, x0:x0+dx, :] = mean_color 158 | 159 | return img1, img2 160 | 161 | def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): 162 | ht, wd = flow.shape[:2] 163 | coords = np.meshgrid(np.arange(wd), np.arange(ht)) 164 | coords = np.stack(coords, axis=-1) 165 | 166 | coords = coords.reshape(-1, 2).astype(np.float32) 167 | flow = flow.reshape(-1, 2).astype(np.float32) 168 | valid = valid.reshape(-1).astype(np.float32) 169 | 170 | coords0 = coords[valid>=1] 171 | flow0 = flow[valid>=1] 172 | 173 | ht1 = int(round(ht * fy)) 174 | wd1 = int(round(wd * fx)) 175 | 176 | coords1 = coords0 * [fx, fy] 177 | flow1 = flow0 * [fx, fy] 178 | 179 | xx = np.round(coords1[:,0]).astype(np.int32) 180 | yy = np.round(coords1[:,1]).astype(np.int32) 181 | 182 | v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) 183 | xx = xx[v] 184 | yy = yy[v] 185 | flow1 = flow1[v] 186 | 187 | flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) 188 | valid_img = np.zeros([ht1, wd1], dtype=np.int32) 189 | 190 | flow_img[yy, xx] = flow1 191 | valid_img[yy, xx] = 1 192 | 193 | return flow_img, valid_img 194 | 195 | def spatial_transform(self, img1, img2, flow, valid): 196 | # randomly sample scale 197 | 198 | ht, wd = img1.shape[:2] 199 | min_scale = np.maximum( 200 | (self.crop_size[0] + 1) / float(ht), 201 | (self.crop_size[1] + 1) / float(wd)) 202 | 203 | scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) 204 | scale_x = np.clip(scale, min_scale, None) 205 | scale_y = np.clip(scale, min_scale, None) 206 | 207 | if np.random.rand() < self.spatial_aug_prob: 208 | # rescale the images 209 | img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 210 | img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) 211 | flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) 212 | 213 | if self.do_flip: 214 | if np.random.rand() < 0.5: # h-flip 215 | img1 = img1[:, ::-1] 216 | img2 = img2[:, ::-1] 217 | flow = flow[:, ::-1] * [-1.0, 1.0] 218 | valid = valid[:, ::-1] 219 | 220 | margin_y = 20 221 | margin_x = 50 222 | 223 | y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) 224 | x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) 225 | 226 | y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) 227 | x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) 228 | 229 | img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 230 | img2 = 
img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 231 | flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 232 | valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] 233 | return img1, img2, flow, valid 234 | 235 | 236 | def __call__(self, img1, img2, flow, valid): 237 | img1, img2 = self.color_transform(img1, img2) 238 | img1, img2 = self.eraser_transform(img1, img2) 239 | img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) 240 | 241 | img1 = np.ascontiguousarray(img1) 242 | img2 = np.ascontiguousarray(img2) 243 | flow = np.ascontiguousarray(flow) 244 | valid = np.ascontiguousarray(valid) 245 | 246 | return img1, img2, flow, valid 247 | -------------------------------------------------------------------------------- /RAFT_core/utils/flow_viz.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | 4 | # MIT License 5 | # 6 | # Copyright (c) 2018 Tom Runia 7 | # 8 | # Permission is hereby granted, free of charge, to any person obtaining a copy 9 | # of this software and associated documentation files (the "Software"), to deal 10 | # in the Software without restriction, including without limitation the rights 11 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 12 | # copies of the Software, and to permit persons to whom the Software is 13 | # furnished to do so, subject to conditions. 14 | # 15 | # Author: Tom Runia 16 | # Date Created: 2018-08-03 17 | 18 | import numpy as np 19 | 20 | def make_colorwheel(): 21 | """ 22 | Generates a color wheel for optical flow visualization as presented in: 23 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 24 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 25 | 26 | Code follows the original C++ source code of Daniel Scharstein. 27 | Code follows the the Matlab source code of Deqing Sun. 28 | 29 | Returns: 30 | np.ndarray: Color wheel 31 | """ 32 | 33 | RY = 15 34 | YG = 6 35 | GC = 4 36 | CB = 11 37 | BM = 13 38 | MR = 6 39 | 40 | ncols = RY + YG + GC + CB + BM + MR 41 | colorwheel = np.zeros((ncols, 3)) 42 | col = 0 43 | 44 | # RY 45 | colorwheel[0:RY, 0] = 255 46 | colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) 47 | col = col+RY 48 | # YG 49 | colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) 50 | colorwheel[col:col+YG, 1] = 255 51 | col = col+YG 52 | # GC 53 | colorwheel[col:col+GC, 1] = 255 54 | colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) 55 | col = col+GC 56 | # CB 57 | colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) 58 | colorwheel[col:col+CB, 2] = 255 59 | col = col+CB 60 | # BM 61 | colorwheel[col:col+BM, 2] = 255 62 | colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) 63 | col = col+BM 64 | # MR 65 | colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) 66 | colorwheel[col:col+MR, 0] = 255 67 | return colorwheel 68 | 69 | 70 | def flow_uv_to_colors(u, v, convert_to_bgr=False): 71 | """ 72 | Applies the flow color wheel to (possibly clipped) flow components u and v. 73 | 74 | According to the C++ source code of Daniel Scharstein 75 | According to the Matlab source code of Deqing Sun 76 | 77 | Args: 78 | u (np.ndarray): Input horizontal flow of shape [H,W] 79 | v (np.ndarray): Input vertical flow of shape [H,W] 80 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
81 | 82 | Returns: 83 | np.ndarray: Flow visualization image of shape [H,W,3] 84 | """ 85 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 86 | colorwheel = make_colorwheel() # shape [55x3] 87 | ncols = colorwheel.shape[0] 88 | rad = np.sqrt(np.square(u) + np.square(v)) 89 | a = np.arctan2(-v, -u)/np.pi 90 | fk = (a+1) / 2*(ncols-1) 91 | k0 = np.floor(fk).astype(np.int32) 92 | k1 = k0 + 1 93 | k1[k1 == ncols] = 0 94 | f = fk - k0 95 | for i in range(colorwheel.shape[1]): 96 | tmp = colorwheel[:,i] 97 | col0 = tmp[k0] / 255.0 98 | col1 = tmp[k1] / 255.0 99 | col = (1-f)*col0 + f*col1 100 | idx = (rad <= 1) 101 | col[idx] = 1 - rad[idx] * (1-col[idx]) 102 | col[~idx] = col[~idx] * 0.75 # out of range 103 | # Note the 2-i => BGR instead of RGB 104 | ch_idx = 2-i if convert_to_bgr else i 105 | flow_image[:,:,ch_idx] = np.floor(255 * col) 106 | return flow_image 107 | 108 | 109 | def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): 110 | """ 111 | Expects a two dimensional flow image of shape. 112 | 113 | Args: 114 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 115 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 116 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 117 | 118 | Returns: 119 | np.ndarray: Flow visualization image of shape [H,W,3] 120 | """ 121 | assert flow_uv.ndim == 3, 'input flow must have three dimensions' 122 | assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' 123 | if clip_flow is not None: 124 | flow_uv = np.clip(flow_uv, 0, clip_flow) 125 | u = flow_uv[:,:,0] 126 | v = flow_uv[:,:,1] 127 | rad = np.sqrt(np.square(u) + np.square(v)) 128 | rad_max = np.max(rad) 129 | epsilon = 1e-5 130 | u = u / (rad_max + epsilon) 131 | v = v / (rad_max + epsilon) 132 | return flow_uv_to_colors(u, v, convert_to_bgr) -------------------------------------------------------------------------------- /RAFT_core/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | from os.path import * 4 | import re 5 | 6 | import cv2 7 | cv2.setNumThreads(0) 8 | cv2.ocl.setUseOpenCL(False) 9 | 10 | TAG_CHAR = np.array([202021.25], np.float32) 11 | 12 | def readFlow(fn): 13 | """ Read .flo file in Middlebury format""" 14 | # Code adapted from: 15 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 16 | 17 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 18 | # print 'fn = %s'%(fn) 19 | with open(fn, 'rb') as f: 20 | magic = np.fromfile(f, np.float32, count=1) 21 | if 202021.25 != magic: 22 | print('Magic number incorrect. 
Invalid .flo file') 23 | return None 24 | else: 25 | w = np.fromfile(f, np.int32, count=1) 26 | h = np.fromfile(f, np.int32, count=1) 27 | # print 'Reading %d x %d flo file\n' % (w, h) 28 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 29 | # Reshape data into 3D array (columns, rows, bands) 30 | # The reshape here is for visualization, the original code is (w,h,2) 31 | return np.resize(data, (int(h), int(w), 2)) 32 | 33 | def readPFM(file): 34 | file = open(file, 'rb') 35 | 36 | color = None 37 | width = None 38 | height = None 39 | scale = None 40 | endian = None 41 | 42 | header = file.readline().rstrip() 43 | if header == b'PF': 44 | color = True 45 | elif header == b'Pf': 46 | color = False 47 | else: 48 | raise Exception('Not a PFM file.') 49 | 50 | dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) 51 | if dim_match: 52 | width, height = map(int, dim_match.groups()) 53 | else: 54 | raise Exception('Malformed PFM header.') 55 | 56 | scale = float(file.readline().rstrip()) 57 | if scale < 0: # little-endian 58 | endian = '<' 59 | scale = -scale 60 | else: 61 | endian = '>' # big-endian 62 | 63 | data = np.fromfile(file, endian + 'f') 64 | shape = (height, width, 3) if color else (height, width) 65 | 66 | data = np.reshape(data, shape) 67 | data = np.flipud(data) 68 | return data 69 | 70 | def writeFlow(filename,uv,v=None): 71 | """ Write optical flow to file. 72 | 73 | If v is None, uv is assumed to contain both u and v channels, 74 | stacked in depth. 75 | Original code by Deqing Sun, adapted from Daniel Scharstein. 76 | """ 77 | nBands = 2 78 | 79 | if v is None: 80 | assert(uv.ndim == 3) 81 | assert(uv.shape[2] == 2) 82 | u = uv[:,:,0] 83 | v = uv[:,:,1] 84 | else: 85 | u = uv 86 | 87 | assert(u.shape == v.shape) 88 | height,width = u.shape 89 | f = open(filename,'wb') 90 | # write the header 91 | f.write(TAG_CHAR) 92 | np.array(width).astype(np.int32).tofile(f) 93 | np.array(height).astype(np.int32).tofile(f) 94 | # arrange into matrix form 95 | tmp = np.zeros((height, width*nBands)) 96 | tmp[:,np.arange(width)*2] = u 97 | tmp[:,np.arange(width)*2 + 1] = v 98 | tmp.astype(np.float32).tofile(f) 99 | f.close() 100 | 101 | 102 | def readFlowKITTI(filename): 103 | flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) 104 | flow = flow[:,:,::-1].astype(np.float32) 105 | flow, valid = flow[:, :, :2], flow[:, :, 2] 106 | flow = (flow - 2**15) / 64.0 107 | return flow, valid 108 | 109 | def readDispKITTI(filename): 110 | disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 111 | valid = disp > 0.0 112 | flow = np.stack([-disp, np.zeros_like(disp)], -1) 113 | return flow, valid 114 | 115 | 116 | def writeFlowKITTI(filename, uv): 117 | uv = 64.0 * uv + 2**15 118 | valid = np.ones([uv.shape[0], uv.shape[1], 1]) 119 | uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) 120 | cv2.imwrite(filename, uv[..., ::-1]) 121 | 122 | 123 | def read_gen(file_name, pil=False): 124 | ext = splitext(file_name)[-1] 125 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 126 | return Image.open(file_name) 127 | elif ext == '.bin' or ext == '.raw': 128 | return np.load(file_name) 129 | elif ext == '.flo': 130 | return readFlow(file_name).astype(np.float32) 131 | elif ext == '.pfm': 132 | flow = readPFM(file_name).astype(np.float32) 133 | if len(flow.shape) == 2: 134 | return flow 135 | else: 136 | return flow[:, :, :-1] 137 | return [] -------------------------------------------------------------------------------- /RAFT_core/utils/utils.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, x): 19 | #return F.pad(x, self._pad, mode='replicate') 20 | return F.pad(x, self._pad, mode='constant') 21 | 22 | def unpad(self,x): 23 | ht, wd = x.shape[-2:] 24 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 25 | return x[:,:, c[0]:c[1], c[2]:c[3]] 26 | 27 | def forward_interpolate(flow): 28 | flow = flow.detach().cpu().numpy() 29 | dx, dy = flow[0], flow[1] 30 | 31 | ht, wd = dx.shape 32 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 33 | 34 | x1 = x0 + dx 35 | y1 = y0 + dy 36 | 37 | x1 = x1.reshape(-1) 38 | y1 = y1.reshape(-1) 39 | dx = dx.reshape(-1) 40 | dy = dy.reshape(-1) 41 | 42 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 43 | x1 = x1[valid] 44 | y1 = y1[valid] 45 | dx = dx[valid] 46 | dy = dy[valid] 47 | 48 | flow_x = interpolate.griddata( 49 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 50 | 51 | flow_y = interpolate.griddata( 52 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 53 | 54 | flow = np.stack([flow_x, flow_y], axis=0) 55 | return torch.from_numpy(flow).float() 56 | 57 | 58 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 59 | """ Wrapper for grid_sample, uses pixel coordinates """ 60 | H, W = img.shape[-2:] 61 | xgrid, ygrid = coords.split([1,1], dim=-1) 62 | xgrid = 2*xgrid/(W-1) - 1 63 | ygrid = 2*ygrid/(H-1) - 1 64 | 65 | grid = torch.cat([xgrid, ygrid], dim=-1) 66 | img = F.grid_sample(img, grid, align_corners=True) 67 | 68 | if mask: 69 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 70 | return img, mask.float() 71 | 72 | return img 73 | 74 | 75 | def coords_grid(batch, ht, wd): 76 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 77 | coords = torch.stack(coords[::-1], dim=0).float() 78 | return coords[None].repeat(batch, 1, 1, 1) 79 | 80 | 81 | def upflow8(flow, mode='bilinear'): 82 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 83 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 84 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VSPW_code 2 | CVPR 2021 VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild 3 | 4 | 5 | ### If you need the training and testing code, please contact me (jiaxumiao@zju.edu.cn). 6 | ### You can also refer to a good re-implementation code at [this link](https://github.com/sssdddwww2/CVPR2021_VSPW_Implement). 
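### Note on the evaluation scripts: `TC_cal.py` reads the dataset root (a folder containing `data/` and `val.txt`) from its first command-line argument and the directory of predicted segmentation PNGs from its second, so a typical invocation (placeholder paths) looks like:

```
python TC_cal.py /path/to/VSPW_480p /path/to/predictions
```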
7 | 8 | 9 | # Evaluation 10 | 11 | ``` 12 | sh evaluation.sh 13 | ``` 14 | 15 | 16 | 17 | # Citation 18 | 19 | ``` 20 | @inproceedings{miao2021vspw, 21 | 22 | title={VSPW: A Large-scale Dataset for Video Scene Parsing in the Wild}, 23 | 24 | author={Miao, Jiaxu and Wei, Yunchao and Wu, Yu and Liang, Chen and Li, Guangrui and Yang, Yi}, 25 | 26 | booktitle={Proceedings of the {IEEE} Conference on Computer Vision and Pattern Recognition}, 27 | 28 | year={2021} 29 | 30 | } 31 | ``` 32 | -------------------------------------------------------------------------------- /TC_cal.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from RAFT_core.raft import RAFT 4 | from RAFT_core.utils.utils import InputPadder 5 | from collections import OrderedDict 6 | from utils import Evaluator 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | import sys 11 | 12 | def flowwarp(x, flo): 13 | """ 14 | warp an image/tensor (im2) back to im1, according to the optical flow 15 | x: [B, C, H, W] (im2) 16 | flo: [B, 2, H, W] flow 17 | """ 18 | B, C, H, W = x.size() 19 | # mesh grid 20 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 21 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 22 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 23 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 24 | grid = torch.cat((xx,yy),1).float() 25 | 26 | if x.is_cuda: 27 | grid = grid.to(x.device) 28 | vgrid = grid + flo 29 | 30 | # scale grid to [-1,1] 31 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 32 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 33 | 34 | vgrid = vgrid.permute(0,2,3,1) 35 | output = nn.functional.grid_sample(x, vgrid,mode='nearest',align_corners=False) 36 | 37 | return output 38 | 39 | 40 | 41 | num_class=124 42 | 43 | DIR_=sys.argv[1] 44 | 45 | data_dir=DIR_+'/data' 46 | result_dir=sys.argv[2] 47 | #list_=['1001_5z_ijQjUf_0','1002_QXQ_QoswLOs'] 48 | 49 | split='val.txt' 50 | with open(os.path.join(DIR_,split),'r') as f: 51 | 52 | list_ = f.readlines() 53 | list_ = [v[:-1] for v in list_] 54 | 55 | ### 56 | gpu=0 57 | model_raft = RAFT() 58 | to_load = torch.load('./RAFT_core/raft-things.pth-no-zip') 59 | new_state_dict = OrderedDict() 60 | for k, v in to_load.items(): 61 | name = k[7:] # remove `module.`: take the key from its 7th character to the end, which exactly strips the `module.` prefix 62 | new_state_dict[name] = v # each key of the new dict maps to its original corresponding value 63 | model_raft.load_state_dict(new_state_dict) 64 | model_raft = model_raft.cuda(gpu) 65 | ### 66 | total_TC=0.
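# The loop below computes the temporal consistency (TC) score: for every pair of consecutive
# frames in a video, RAFT estimates the optical flow between them, flowwarp() warps the next
# frame's predicted segmentation back to the current frame with that flow, and the Evaluator
# accumulates the mIoU between the current prediction and the warped next prediction.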
67 | evaluator = Evaluator(num_class) 68 | for video in list_[:100]: 69 | imglist_ = sorted(os.listdir(os.path.join(data_dir,video,'origin'))) 70 | for i,img in enumerate(imglist_[:-1]): 71 | #print('processing video : {} image: {}'.format(video,img)) 72 | next_img = imglist_[i+1] 73 | imgname = img 74 | next_imgname = next_img 75 | img = Image.open(os.path.join(data_dir,video,'origin',img)) 76 | next_img =Image.open(os.path.join(data_dir,video,'origin',next_img)) 77 | image1 = torch.from_numpy(np.array(img)) 78 | image2 = torch.from_numpy(np.array(next_img)) 79 | padder = InputPadder(image1.size()[:2]) 80 | image1 = image1.unsqueeze(0).permute(0,3,1,2) 81 | image2 = image2.unsqueeze(0).permute(0,3,1,2) 82 | image1 = padder.pad(image1) 83 | image2 = padder.pad(image2) 84 | image1 = image1.cuda(gpu) 85 | image2 = image2.cuda(gpu) 86 | with torch.no_grad(): 87 | model_raft.eval() 88 | _,flow = model_raft(image1,image2,iters=20, test_mode=True) 89 | flow = padder.unpad(flow) 90 | 91 | flow = flow.data.cpu() 92 | pred = Image.open(os.path.join(result_dir,video,imgname.split('.')[0]+'.png')) 93 | next_pred = Image.open(os.path.join(result_dir,video,next_imgname.split('.')[0]+'.png')) 94 | pred =torch.from_numpy(np.array(pred)) 95 | next_pred = torch.from_numpy(np.array(next_pred)) 96 | next_pred = next_pred.unsqueeze(0).unsqueeze(0).float() 97 | # print(next_pred) 98 | 99 | warp_pred = flowwarp(next_pred,flow) 100 | # print(warp_pred) 101 | warp_pred = warp_pred.int().squeeze(1).numpy() 102 | pred = pred.unsqueeze(0).numpy() 103 | evaluator.add_batch(pred, warp_pred) 104 | # v_mIoU = evaluator.Mean_Intersection_over_Union() 105 | # total_TC+=v_mIoU 106 | # print('processed video : {} score:{}'.format(video,v_mIoU)) 107 | 108 | #TC = total_TC/len(list_) 109 | TC = evaluator.Mean_Intersection_over_Union() 110 | 111 | print("TC score is {}".format(TC)) 112 | 113 | print(split) 114 | print(result_dir) 115 | 116 | 117 | 118 | 119 | 120 | 121 | -------------------------------------------------------------------------------- /TC_cal_total.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | from RAFT_core.raft import RAFT 4 | from RAFT_core.utils.utils import InputPadder 5 | from collections import OrderedDict 6 | from utils import Evaluator 7 | import numpy as np 8 | import torch 9 | import torch.nn as nn 10 | 11 | def flowwarp(x, flo): 12 | """ 13 | warp an image/tensor (im2) back to im1, according to the optical flow 14 | x: [B, C, H, W] (im2) 15 | flo: [B, 2, H, W] flow 16 | """ 17 | B, C, H, W = x.size() 18 | # mesh grid 19 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 20 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 21 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 22 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 23 | grid = torch.cat((xx,yy),1).float() 24 | 25 | if x.is_cuda: 26 | grid = grid.to(x.device) 27 | vgrid = grid + flo 28 | 29 | # scale grid to [-1,1] 30 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 31 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 32 | 33 | vgrid = vgrid.permute(0,2,3,1) 34 | output = nn.functional.grid_sample(x, vgrid,mode='nearest',align_corners=False) 35 | 36 | return output 37 | 38 | 39 | 40 | num_class=124 41 | data_dir='/home/miaojiaxu/jiaxu_2/semantic_seg/LVSP_plus_data_label124_480p/data' 42 | #tar_dir ='/home/miaojiaxu/jiaxu_2/semantic_seg/newsaveimages' 43 | tar_dir ='/home/miaojiaxu/jiaxu_2/semantic_seg/newsaveimages_ab/ab' 44 | 45 | gpu=5 46 | ### 47 | 
model_raft = RAFT() 48 | to_load = torch.load('./RAFT/models/raft-things.pth-no-zip') 49 | new_state_dict = OrderedDict() 50 | for k, v in to_load.items(): 51 | name = k[7:] # remove the leading `module.` prefix (keep the key from character 7 onward) 52 | new_state_dict[name] = v # store each value under the stripped key 53 | model_raft.load_state_dict(new_state_dict) 54 | model_raft = model_raft.cuda(gpu) 55 | ### 56 | 57 | #for split in ['val.txt','test.txt']: 58 | for split in ['val.txt']: 59 | 60 | #split='val.txt' 61 | with open('/home/miaojiaxu/jiaxu_2/semantic_seg/LVSP_plus_data_label124_480p/'+split,'r') as f: 62 | list_ = f.readlines() 63 | list_ = [v[:-1] for v in list_] 64 | 65 | 66 | 67 | for fold in os.listdir(tar_dir): 68 | 69 | if fold =='__pycache__': 70 | continue 71 | if os.path.isdir(os.path.join(tar_dir,fold)): 72 | 73 | result_dir=os.path.join(tar_dir,fold) 74 | else: 75 | continue 76 | #list_=['1001_5z_ijQjUf_0','1002_QXQ_QoswLOs'] 77 | 78 | 79 | total_TC=0. 80 | evaluator = Evaluator(num_class) 81 | for video in list_[:100]: 82 | imglist_ = sorted(os.listdir(os.path.join(data_dir,video,'origin'))) 83 | for i,img in enumerate(imglist_[:-1]): 84 | #print('processing video : {} image: {}'.format(video,img)) 85 | next_img = imglist_[i+1] 86 | imgname = img 87 | next_imgname = next_img 88 | img = Image.open(os.path.join(data_dir,video,'origin',img)) 89 | next_img =Image.open(os.path.join(data_dir,video,'origin',next_img)) 90 | image1 = torch.from_numpy(np.array(img)) 91 | image2 = torch.from_numpy(np.array(next_img)) 92 | padder = InputPadder(image1.size()[:2]) 93 | image1 = image1.unsqueeze(0).permute(0,3,1,2) 94 | image2 = image2.unsqueeze(0).permute(0,3,1,2) 95 | image1 = padder.pad(image1) 96 | image2 = padder.pad(image2) 97 | image1 = image1.cuda(gpu) 98 | image2 = image2.cuda(gpu) 99 | with torch.no_grad(): 100 | model_raft.eval() 101 | _,flow = model_raft(image1,image2,iters=20, test_mode=True) 102 | flow = padder.unpad(flow) 103 | 104 | flow = flow.data.cpu() 105 | pred = Image.open(os.path.join(result_dir,video,imgname.split('.')[0]+'.png')) 106 | next_pred = Image.open(os.path.join(result_dir,video,next_imgname.split('.')[0]+'.png')) 107 | pred =torch.from_numpy(np.array(pred)) 108 | next_pred = torch.from_numpy(np.array(next_pred)) 109 | next_pred = next_pred.unsqueeze(0).unsqueeze(0).float() 110 | # print(next_pred) 111 | 112 | warp_pred = flowwarp(next_pred,flow) 113 | # print(warp_pred) 114 | warp_pred = warp_pred.int().squeeze(1).numpy() 115 | pred = pred.unsqueeze(0).numpy() 116 | evaluator.add_batch(pred, warp_pred) 117 | v_mIoU = evaluator.Mean_Intersection_over_Union() 118 | # total_TC+=v_mIoU 119 | # print('processed video : {} score:{}'.format(video,v_mIoU)) 120 | 121 | # TC = total_TC/len(list_) 122 | TC = v_mIoU 123 | print('*'*100) 124 | print("TC score is {}".format(TC)) 125 | 126 | print(split) 127 | print(result_dir) 128 | print('*'*100) 129 | 130 | 131 | 132 | 133 | 134 | 135 | -------------------------------------------------------------------------------- /VC_perclip.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | #from utils import Evaluator 5 | import sys 6 | 7 | def get_common(list_,predlist,clip_num,h,w): 8 | accs = [] 9 | for i in range(len(list_)-clip_num): 10 | global_common = np.ones((h,w)) 11 | predglobal_common = np.ones((h,w)) 12 | 13 | 14 | for j in range(1,clip_num): 15 | common = (list_[i] == list_[i+j]) 16 | global_common = np.logical_and(global_common,common) 17 |
pred_common = (predlist[i]==predlist[i+j]) 18 | predglobal_common = np.logical_and(predglobal_common,pred_common) 19 | pred = (predglobal_common*global_common) 20 | 21 | acc = pred.sum()/global_common.sum() 22 | accs.append(acc) 23 | return accs 24 | 25 | 26 | 27 | DIR=sys.argv[1] 28 | 29 | Pred=sys.argv[2] 30 | split = 'val.txt' 31 | 32 | with open(os.path.join(DIR,split),'r') as f: 33 | lines = f.readlines() 34 | for line in lines: 35 | videolist = [line[:-1] for line in lines] 36 | total_acc=[] 37 | 38 | clip_num=16 39 | 40 | 41 | for video in videolist: 42 | imglist = [] 43 | predlist = [] 44 | 45 | images = sorted(os.listdir(os.path.join(DIR,'data',video,'mask'))) 46 | 47 | if len(images)<=clip_num: 48 | continue 49 | for imgname in images: 50 | img = Image.open(os.path.join(DIR,'data',video,'mask',imgname)) 51 | w,h = img.size 52 | img = np.array(img) 53 | imglist.append(img) 54 | pred = Image.open(os.path.join(Pred,video,imgname)) 55 | pred = np.array(pred) 56 | predlist.append(pred) 57 | 58 | accs = get_common(imglist,predlist,clip_num,h,w) 59 | print(sum(accs)/len(accs)) 60 | total_acc.extend(accs) 61 | Acc = np.array(total_acc) 62 | Acc = np.nanmean(Acc) 63 | print(Pred) 64 | print('*'*10) 65 | print('VC{} score: {} on {} set'.format(clip_num,Acc,split)) 66 | print('*'*10) 67 | 68 | -------------------------------------------------------------------------------- /VC_perclip_total.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | #from utils import Evaluator 5 | 6 | 7 | def get_common(list_,predlist,clip_num,h,w): 8 | accs = [] 9 | for i in range(len(list_)-clip_num): 10 | global_common = np.ones((h,w)) 11 | predglobal_common = np.ones((h,w)) 12 | 13 | 14 | for j in range(1,clip_num): 15 | common = (list_[i] == list_[i+j]) 16 | global_common = np.logical_and(global_common,common) 17 | pred_common = (predlist[i]==predlist[i+j]) 18 | predglobal_common = np.logical_and(predglobal_common,pred_common) 19 | pred = (predglobal_common*global_common) 20 | 21 | acc = pred.sum()/global_common.sum() 22 | accs.append(acc) 23 | return accs 24 | 25 | 26 | 27 | DIR='/home/miaojiaxu/jiaxu_2/semantic_seg/LVSP_plus_data_label124_480p' 28 | #tar_dir = '/home/miaojiaxu/jiaxu_2/semantic_seg/newsaveimages' 29 | tar_dir = '/home/miaojiaxu/jiaxu_2/semantic_seg/newsaveimages_ab/abpsp' 30 | #for clip_num in [8,16]: 31 | for clip_num in [16]: 32 | # for split in ['val.txt','test.txt']: 33 | for split in ['val.txt']: 34 | with open(os.path.join(DIR,split),'r') as f: 35 | lines = f.readlines() 36 | for line in lines: 37 | videolist = [line[:-1] for line in lines] 38 | for fold in os.listdir(tar_dir): 39 | if fold =='__pycache__': 40 | continue 41 | if os.path.isdir(os.path.join(tar_dir,fold)): 42 | Pred = os.path.join(tar_dir,fold) 43 | else: 44 | continue 45 | #Pred='/home/miaojiaxu/jiaxu_2/semantic_seg/VSP_124_saveimg/clip_ocr_369_result1' 46 | 47 | 48 | 49 | 50 | total_acc=[] 51 | for video in videolist: 52 | imglist = [] 53 | predlist = [] 54 | 55 | images = sorted(os.listdir(os.path.join(DIR,'data',video,'mask'))) 56 | 57 | if len(images)<=clip_num: 58 | continue 59 | for imgname in images: 60 | img = Image.open(os.path.join(DIR,'data',video,'mask',imgname)) 61 | w,h = img.size 62 | img = np.array(img) 63 | imglist.append(img) 64 | pred = Image.open(os.path.join(Pred,video,imgname)) 65 | pred = np.array(pred) 66 | predlist.append(pred) 67 | 68 | accs = get_common(imglist,predlist,clip_num,h,w) 69 | # 
print(sum(accs)/len(accs)) 70 | total_acc.extend(accs) 71 | Acc = np.array(total_acc) 72 | Acc = np.nanmean(Acc) 73 | print('*'*10) 74 | print(Acc) 75 | print(clip_num) 76 | print(Pred) 77 | print(split) 78 | 79 | -------------------------------------------------------------------------------- /evaluation.sh: -------------------------------------------------------------------------------- 1 | DATA_DIR='/your/path/to/VSPW' 2 | PRED_DIR='/your/path/to/predictions' 3 | 4 | ###mIoU 5 | 6 | python evaluator_test.py $DATA_DIR $PRED_DIR 7 | 8 | ###TC score 9 | python TC_cal.py $DATA_DIR $PRED_DIR 10 | 11 | 12 | ##VC score 13 | 14 | python VC_perclip.py $DATA_DIR $PRED_DIR 15 | -------------------------------------------------------------------------------- /evaluator_test.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from PIL import Image 4 | from utils import Evaluator 5 | import sys 6 | eval_ = Evaluator(124) 7 | eval_.reset() 8 | 9 | DIR=sys.argv[1] 10 | split = 'val.txt' 11 | 12 | with open(os.path.join(DIR,split),'r') as f: 13 | lines = f.readlines() 14 | for line in lines: 15 | videolist = [line[:-1] for line in lines] 16 | PRED=sys.argv[2] 17 | for video in videolist: 18 | for tar in os.listdir(os.path.join(DIR,'data',video,'mask')): 19 | pred = os.path.join(PRED,video,tar) 20 | tar_ = Image.open(os.path.join(DIR,'data',video,'mask',tar)) 21 | tar_ = np.array(tar_) 22 | tar_ = tar_[np.newaxis,:] 23 | pred_ = Image.open(pred) 24 | pred_ = np.array(pred_) 25 | pred_ = pred_[np.newaxis,:] 26 | eval_.add_batch(tar_,pred_) 27 | 28 | Acc = eval_.Pixel_Accuracy() 29 | Acc_class = eval_.Pixel_Accuracy_Class() 30 | mIoU = eval_.Mean_Intersection_over_Union() 31 | FWIoU = eval_.Frequency_Weighted_Intersection_over_Union() 32 | print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(Acc, Acc_class, mIoU, FWIoU)) 33 | 34 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import torch 3 | import torch.nn as nn 4 | import os 5 | import logging 6 | import re 7 | import functools 8 | import fnmatch 9 | import numpy as np 10 | def flowwarp(x, flo): 11 | """ 12 | warp an image/tensor (im2) back to im1, according to the optical flow 13 | x: [B, C, H, W] (im2) 14 | flo: [B, 2, H, W] flow 15 | """ 16 | B, C, H, W = x.size() 17 | # mesh grid 18 | xx = torch.arange(0, W).view(1,-1).repeat(H,1) 19 | yy = torch.arange(0, H).view(-1,1).repeat(1,W) 20 | xx = xx.view(1,1,H,W).repeat(B,1,1,1) 21 | yy = yy.view(1,1,H,W).repeat(B,1,1,1) 22 | grid = torch.cat((xx,yy),1).float() 23 | 24 | if x.is_cuda: 25 | grid = grid.to(x.device) 26 | vgrid = grid + flo 27 | 28 | # scale grid to [-1,1] 29 | vgrid[:,0,:,:] = 2.0*vgrid[:,0,:,:].clone() / max(W-1,1)-1.0 30 | vgrid[:,1,:,:] = 2.0*vgrid[:,1,:,:].clone() / max(H-1,1)-1.0 31 | 32 | vgrid = vgrid.permute(0,2,3,1) 33 | output = nn.functional.grid_sample(x, vgrid,align_corners=False) 34 | 35 | return output 36 | 37 | def get_common(list_,predlist,clip_num,h,w): 38 | accs = [] 39 | for i in range(len(list_)-clip_num): 40 | global_common = np.ones((h,w)) 41 | predglobal_common = np.ones((h,w)) 42 | 43 | 44 | for j in range(1,clip_num): 45 | common = (list_[i] == list_[i+j]) 46 | global_common = np.logical_and(global_common,common) 47 | pred_common = (predlist[i]==predlist[i+j]) 48 | predglobal_common = np.logical_and(predglobal_common,pred_common) 49 | pred = 
(predglobal_common*global_common) 50 | 51 | acc = pred.sum()/global_common.sum() 52 | accs.append(acc) 53 | return accs 54 | 55 | class Evaluator(object): 56 | def __init__(self, num_class): 57 | self.num_class = num_class 58 | self.confusion_matrix = np.zeros((self.num_class,)*2) 59 | 60 | def beforeval(self): 61 | isval = np.sum(self.confusion_matrix,axis=1)>0 62 | self.confusion_matrix = self.confusion_matrix*isval 63 | 64 | def Pixel_Accuracy(self): 65 | Acc = np.diag(self.confusion_matrix).sum() / self.confusion_matrix.sum() 66 | return Acc 67 | 68 | def Pixel_Accuracy_Class(self): 69 | Acc = np.diag(self.confusion_matrix) / self.confusion_matrix.sum(axis=1) 70 | Acc = np.nanmean(Acc) 71 | return Acc 72 | 73 | 74 | def Mean_Intersection_over_Union(self): 75 | MIoU = np.diag(self.confusion_matrix) / ( 76 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 77 | np.diag(self.confusion_matrix)) 78 | isval = np.sum(self.confusion_matrix,axis=1)>0 79 | MIoU = np.nansum(MIoU*isval)/isval.sum() 80 | return MIoU 81 | 82 | def Frequency_Weighted_Intersection_over_Union(self): 83 | freq = np.sum(self.confusion_matrix, axis=1) / np.sum(self.confusion_matrix) 84 | iu = np.diag(self.confusion_matrix) / ( 85 | np.sum(self.confusion_matrix, axis=1) + np.sum(self.confusion_matrix, axis=0) - 86 | np.diag(self.confusion_matrix)) 87 | 88 | FWIoU = (freq[freq > 0] * iu[freq > 0]).sum() 89 | return FWIoU 90 | 91 | def _generate_matrix(self, gt_image, pre_image): 92 | mask = (gt_image >= 0) & (gt_image < self.num_class) 93 | #print(mask) 94 | #print(gt_image.shape) 95 | #print(gt_image[mask]) 96 | label = self.num_class * gt_image[mask].astype('int') + pre_image[mask] 97 | # print(label.shape) 98 | count = np.bincount(label, minlength=self.num_class**2) 99 | confusion_matrix = count.reshape(self.num_class, self.num_class) 100 | return confusion_matrix 101 | 102 | def add_batch(self, gt_image, pre_image): 103 | assert gt_image.shape == pre_image.shape 104 | self.confusion_matrix += self._generate_matrix(gt_image, pre_image) 105 | 106 | def reset(self): 107 | self.confusion_matrix = np.zeros((self.num_class,) * 2) 108 | 109 | 110 | def setup_logger(distributed_rank=0, filename="log.txt"): 111 | logger = logging.getLogger("Logger") 112 | logger.setLevel(logging.DEBUG) 113 | # don't log results for the non-master process 114 | if distributed_rank > 0: 115 | return logger 116 | ch = logging.StreamHandler(stream=sys.stdout) 117 | ch.setLevel(logging.DEBUG) 118 | fmt = "[%(asctime)s %(levelname)s %(filename)s line %(lineno)d %(process)d] %(message)s" 119 | ch.setFormatter(logging.Formatter(fmt)) 120 | logger.addHandler(ch) 121 | 122 | return logger 123 | 124 | 125 | def find_recursive(root_dir, ext='.jpg'): 126 | files = [] 127 | for root, dirnames, filenames in os.walk(root_dir): 128 | for filename in fnmatch.filter(filenames, '*' + ext): 129 | if filename[0]=='.': 130 | continue 131 | files.append(os.path.join(root, filename)) 132 | return files 133 | 134 | 135 | class AverageMeter(object): 136 | """Computes and stores the average and current value""" 137 | def __init__(self): 138 | self.initialized = False 139 | self.val = None 140 | self.avg = None 141 | self.sum = None 142 | self.count = None 143 | 144 | def initialize(self, val, weight): 145 | self.val = val 146 | self.avg = val 147 | self.sum = val * weight 148 | self.count = weight 149 | self.initialized = True 150 | 151 | def update(self, val, weight=1): 152 | if not self.initialized: 153 | self.initialize(val, weight) 154 | 
else: 155 | self.add(val, weight) 156 | 157 | def add(self, val, weight): 158 | self.val = val 159 | self.sum += val * weight 160 | self.count += weight 161 | self.avg = self.sum / self.count 162 | 163 | def value(self): 164 | return self.val 165 | 166 | def average(self): 167 | return self.avg 168 | 169 | 170 | def unique(ar, return_index=False, return_inverse=False, return_counts=False): 171 | ar = np.asanyarray(ar).flatten() 172 | 173 | optional_indices = return_index or return_inverse 174 | optional_returns = optional_indices or return_counts 175 | 176 | if ar.size == 0: 177 | if not optional_returns: 178 | ret = ar 179 | else: 180 | ret = (ar,) 181 | if return_index: 182 | ret += (np.empty(0, np.bool),) 183 | if return_inverse: 184 | ret += (np.empty(0, np.bool),) 185 | if return_counts: 186 | ret += (np.empty(0, np.intp),) 187 | return ret 188 | if optional_indices: 189 | perm = ar.argsort(kind='mergesort' if return_index else 'quicksort') 190 | aux = ar[perm] 191 | else: 192 | ar.sort() 193 | aux = ar 194 | flag = np.concatenate(([True], aux[1:] != aux[:-1])) 195 | 196 | if not optional_returns: 197 | ret = aux[flag] 198 | else: 199 | ret = (aux[flag],) 200 | if return_index: 201 | ret += (perm[flag],) 202 | if return_inverse: 203 | iflag = np.cumsum(flag) - 1 204 | inv_idx = np.empty(ar.shape, dtype=np.intp) 205 | inv_idx[perm] = iflag 206 | ret += (inv_idx,) 207 | if return_counts: 208 | idx = np.concatenate(np.nonzero(flag) + ([ar.size],)) 209 | ret += (np.diff(idx),) 210 | return ret 211 | 212 | 213 | def colorEncode(labelmap, colors, mode='RGB'): 214 | labelmap = labelmap.astype('int') 215 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 216 | dtype=np.uint8) 217 | for label in unique(labelmap): 218 | if label < 0: 219 | continue 220 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 221 | np.tile(colors[label], 222 | (labelmap.shape[0], labelmap.shape[1], 1)) 223 | 224 | if mode == 'BGR': 225 | return labelmap_rgb[:, :, ::-1] 226 | else: 227 | return labelmap_rgb 228 | 229 | 230 | def accuracy(preds, label): 231 | valid = (label >= 0) 232 | acc_sum = (valid * (preds == label)).sum() 233 | valid_sum = valid.sum() 234 | acc = float(acc_sum) / (valid_sum + 1e-10) 235 | return acc, valid_sum 236 | 237 | 238 | def intersectionAndUnion(imPred, imLab, numClass): 239 | imPred = np.asarray(imPred).copy() 240 | imLab = np.asarray(imLab).copy() 241 | 242 | imPred += 1 243 | imLab += 1 244 | # Remove classes from unlabeled pixels in gt image. 245 | # We should not penalize detections in unlabeled portions of the image. 
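# Note: labels were shifted by +1 above so that unlabeled pixels become 0; zeroing
# imPred wherever imLab == 0 (next line) and histogramming with range=(1, numClass)
# keeps unlabeled pixels out of both the intersection and the union counts.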
246 | imPred = imPred * (imLab > 0) 247 | 248 | # Compute area intersection: 249 | intersection = imPred * (imPred == imLab) 250 | (area_intersection, _) = np.histogram( 251 | intersection, bins=numClass, range=(1, numClass)) 252 | 253 | # Compute area union: 254 | (area_pred, _) = np.histogram(imPred, bins=numClass, range=(1, numClass)) 255 | (area_lab, _) = np.histogram(imLab, bins=numClass, range=(1, numClass)) 256 | area_union = area_pred + area_lab - area_intersection 257 | 258 | return (area_intersection, area_union) 259 | 260 | 261 | class NotSupportedCliException(Exception): 262 | pass 263 | 264 | 265 | def process_range(xpu, inp): 266 | start, end = map(int, inp) 267 | if start > end: 268 | end, start = start, end 269 | return map(lambda x: '{}{}'.format(xpu, x), range(start, end+1)) 270 | 271 | 272 | REGEX = [ 273 | (re.compile(r'^gpu(\d+)$'), lambda x: ['gpu%s' % x[0]]), 274 | (re.compile(r'^(\d+)$'), lambda x: ['gpu%s' % x[0]]), 275 | (re.compile(r'^gpu(\d+)-(?:gpu)?(\d+)$'), 276 | functools.partial(process_range, 'gpu')), 277 | (re.compile(r'^(\d+)-(\d+)$'), 278 | functools.partial(process_range, 'gpu')), 279 | ] 280 | 281 | 282 | def parse_devices(input_devices): 283 | 284 | """Parse user's devices input str to standard format. 285 | e.g. [gpu0, gpu1, ...] 286 | 287 | """ 288 | ret = [] 289 | for d in input_devices.split(','): 290 | for regex, func in REGEX: 291 | m = regex.match(d.lower().strip()) 292 | if m: 293 | tmp = func(m.groups()) 294 | # prevent duplicate 295 | for x in tmp: 296 | if x not in ret: 297 | ret.append(x) 298 | break 299 | else: 300 | raise NotSupportedCliException( 301 | 'Can not recognize device: "{}"'.format(d)) 302 | return ret 303 | --------------------------------------------------------------------------------
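As a reading aid for the VC metric: for a window of clip_num consecutive frames, get_common (defined in utils.py above, with an identical copy in VC_perclip.py) measures the fraction of pixels whose ground-truth label is constant across the window for which the predicted label is also constant; the VC score reported by VC_perclip.py is this ratio averaged over all sliding windows and videos. A small self-contained check on made-up 2x2 masks — toy values only; it assumes the snippet is run from the repository root with the repo's dependencies installed so that utils.py is importable:

```
import numpy as np
from utils import get_common  # utils.py defines get_common; VC_perclip.py carries an identical copy

h, w, clip_num = 2, 2, 2
# Four toy ground-truth masks: the top row never changes, one bottom pixel changes once.
gt = [np.array([[1, 2], [3, 3]]),
      np.array([[1, 2], [3, 3]]),
      np.array([[1, 2], [4, 3]]),
      np.array([[1, 2], [4, 3]])]
# Toy predictions: same as the ground truth except one extra change in frame 2.
pred = [np.array([[1, 2], [3, 3]]),
        np.array([[1, 2], [3, 3]]),
        np.array([[1, 2], [4, 4]]),
        np.array([[1, 2], [4, 3]])]

accs = get_common(gt, pred, clip_num, h, w)
print(accs)                      # [1.0, 0.666...] for the two sliding windows
print(float(np.nanmean(accs)))   # the VC2 score of this toy clip, ~0.83
```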