├── DBVI_2x
├── README.md
├── configs
│   ├── configTest.py
│   └── configTrain.py
├── dataloader
│   ├── Vimeo.py
│   ├── cvtransform.py
│   ├── dataUtils.py
│   ├── dataloaderBase.py
│   ├── davis.py
│   ├── snufilm.py
│   └── ucf101.py
├── datasets
│   └── dataset
├── lib
│   ├── RAFT
│   │   ├── alt_cuda_corr
│   │   │   ├── correlation.cpp
│   │   │   ├── correlation_kernel.cu
│   │   │   └── setup.py
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   ├── raft-things.pth
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── checkTool.py
│   ├── distTool.py
│   ├── dlTool.py
│   ├── fileTool.py
│   ├── lossTool.py
│   ├── metrics.py
│   ├── softsplat.py
│   ├── videoTool.py
│   ├── visualTool.py
│   └── warp.py
├── model
│   ├── RRDBNet.py
│   ├── block.py
│   ├── invBlock
│   │   ├── __init__.py
│   │   └── permute.py
│   └── module_util.py
├── output
│   └── output
├── runTest.py
├── runTrain.py
├── test.py
└── train.py
├── DBVI_8x
├── configs
│   ├── configTest.py
│   └── configTrain.py
├── dataloader
│   ├── Adobe.py
│   ├── GoPro.py
│   ├── XVFI.py
│   ├── cvtransform.py
│   ├── dataUtils.py
│   └── dataloaderBase.py
├── datasets
│   └── dataset
├── lib
│   ├── RAFT
│   │   ├── alt_cuda_corr
│   │   │   ├── build
│   │   │   │   └── temp.linux-x86_64-3.6
│   │   │   │   │   ├── build.ninja
│   │   │   │   │   └── correlation.o
│   │   │   ├── correlation.cpp
│   │   │   ├── correlation.egg-info
│   │   │   │   ├── PKG-INFO
│   │   │   │   ├── SOURCES.txt
│   │   │   │   ├── dependency_links.txt
│   │   │   │   └── top_level.txt
│   │   │   ├── correlation_kernel.cu
│   │   │   └── setup.py
│   │   ├── corr.py
│   │   ├── extractor.py
│   │   ├── raft-things.pth
│   │   ├── raft.py
│   │   ├── update.py
│   │   └── utils.py
│   ├── __init__.py
│   ├── checkTool.py
│   ├── distTool.py
│   ├── dlTool.py
│   ├── fileTool.py
│   ├── lossTool.py
│   ├── metrics.py
│   ├── softsplat.py
│   ├── videoTool.py
│   ├── visualTool.py
│   └── warp.py
├── model
│   ├── RRDBNet.py
│   ├── block.py
│   ├── invBlock
│   │   ├── __init__.py
│   │   └── permute.py
│   └── module_util.py
├── output
│   └── output
├── readme.md
├── runTest.py
├── runTrain.py
├── test.py
└── train.py
├── LICENSE.md
├── README.md
└── mkDataset
├── forAdobe
├── getTestList.py
└── png2lmdb.py
├── forDavis
├── getTestList.py
└── png2lmdb.py
├── forGoPro
├── getTestList.py
├── getTrainList.py
└── png2lmdb.py
├── forSnufilm
├── png2lmdb.py
└── txt2sample.py
├── forUCF101
├── getTestList.py
└── png2lmdb.py
├── forVimeo
├── getTrainTestList.py
├── png2lmdb.py
└── splitTrainTest.py
├── forXVFI
├── getTestListXVFi.py
├── getTrainListXVFi.py
├── mp4_decoding.py
└── png2lmdbXVFI.py
└── lib
├── __init__.py
├── dataUtils.py
└── fileTool.py

--------------------------------------------------------------------------------
/DBVI_2x/README.md:
--------------------------------------------------------------------------------

# 2x Interpolation

## 1. Preparing Dataset
The training/testing datasets we used can either be downloaded from the links below and processed with the scripts in ../mkDataset/forXXX/ (changing each 'Path/to/' accordingly), or downloaded directly as ready-to-use lmdb files from [here](https://pan.baidu.com/s/1meK6lCXrwrBQ3KFgos1aDw?pwd=2022) (password: 2022).
#### Links:
[Vimeo_Septuplet](http://data.csail.mit.edu/tofu/dataset/vimeo_septuplet.zip),
[ucf101](https://sites.google.com/view/xiangyuxu/qvi_nips19),
[DAVIS](https://sites.google.com/view/xiangyuxu/qvi_nips19),
[SNU-FILM](https://myungsub.github.io/CAIN/)

The processed files should be put in ./datasets and organized as:
```
datasets/
    vimeo/
        test_lmdb/
            data.mdb
            lock.mdb
            sample.pkl
        train_lmdb/
            data.mdb
            lock.mdb
            sample.pkl
```
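For reference, a minimal sketch of how these lmdb files are read back (this mirrors the loaders in dataloader/, e.g. dataloader/Vimeo.py; the sample keys 'In'/'I0'/'It'/'I1'/'I2' are produced by the mkDataset scripts):
```python
import pickle
import lmdb

# each entry of sample.pkl is a dict of lmdb keys for one sample
with open('./datasets/vimeo/train_lmdb/sample.pkl', 'rb') as fs:
    samples = pickle.load(fs)

env = lmdb.open('./datasets/vimeo/train_lmdb/data.mdb', subdir=False,
                readonly=True, lock=False, readahead=False, meminit=False)
txn = env.begin(write=False)
I0 = pickle.loads(txn.get(samples[0]['I0'].encode('ascii')))  # one decoded frame
```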
## 2. Training
### Training with a single GPU:
(1) In configs/configTrain.py (lines 50~54), set whether to resume, the checkpoint directory, and the name of the pretrained weights (only needed if resume is true).

(2) Open a terminal and run ifconfig to get your IP address: XXX.XXX.XXX.XXX

(3) python train.py --initNode=XXX.XXX.XXX.XXX

### Distributed training with multiple GPUs (16 GPUs, 2 nodes) on a cluster managed by [slurm](https://slurm.schedmd.com/quickstart_admin.html):
(1) In configs/configTrain.py (lines 50~54), set the name of the training set (GoPro/X4K1000FPS), whether to resume, the checkpoint directory, and the name of the pretrained weights (only needed if resume is true).

(2) In runTrain.py (lines 3~14), set the partition name, the node names in the cluster, the number and indices of GPUs/CPUs per node, and so on.

The example in runTrain.py runs on one partition named Pixel, with two nodes named 'SH-IDC1-10-5-39-55' and 'SH-IDC1-10-5-31-54' and 8 GPUs per node.

(3) python runTrain.py

## 3. Testing with Pretrained Models
[Model](https://pan.baidu.com/s/1TOtVA8f7my5vzB0n_kOEnA) pretrained on Vimeo-Septuplet (password: 2022).
Download the model and put it under ./output/

### Testing with a single GPU:
(1) In configs/configTest.py (lines 50~55), set the name of the test set, the checkpoint directory, and the name of the pretrained weights.

(2) Open a terminal and run ifconfig to get your IP address: XXX.XXX.XXX.XXX

(3) python test.py --initNode=XXX.XXX.XXX.XXX

### Distributed testing with multiple GPUs (10) on a cluster managed by [slurm](https://slurm.schedmd.com/quickstart_admin.html):
(1) In configs/configTest.py (lines 50~55), set the name of the test set, the checkpoint directory, and the name of the pretrained weights.

(2) In runTest.py (lines 3~14), set the partition name, the node names in the cluster, the number/indices of GPUs/CPUs per node, and so on.

The example in runTest.py runs on one partition named Pixel with two nodes named 'SH-IDC1-10-5-39-55' and 'SH-IDC1-10-5-31-54' and 5 GPUs per node.

(3) python runTest.py
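The IP passed via --initNode is used as the rendezvous address for torch.distributed. A minimal sketch of the kind of initialization this corresponds to (the actual backend, port, and rank/world-size plumbing in train.py and lib/distTool.py may differ):
```python
import torch.distributed as dist

def initDist(initNode: str, port: int, rank: int, worldSize: int):
    # rendezvous at the IP printed by ifconfig (--initNode)
    dist.init_process_group(backend='nccl',
                            init_method='tcp://{}:{}'.format(initNode, port),
                            rank=rank, world_size=worldSize)
```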
--------------------------------------------------------------------------------
/DBVI_2x/dataloader/Vimeo.py:
--------------------------------------------------------------------------------
import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from dataloader import cvtransform
import torch
import torch.utils.data as data
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from dataloader.dataloaderBase import DistributedSamplerVali
from configs.configTrain import configMain
import pickle
import lmdb
import numpy as np
from lib.visualTool import visImg


# train-------------------------------------------------------------------
class Train(data.Dataset):
    def __init__(self, cfg: configMain):
        self.Sample = './datasets/vimeo/train_lmdb/sample.pkl'
        self.LMDB = './datasets/vimeo/train_lmdb/data.mdb'
        self.numIter = cfg.numIter

        with open(self.Sample, 'rb') as fs:
            self.Sample = pickle.load(fs)

        self.length = len(self.Sample)
        # self.length = 10
        self.transforms = cvtransform.Compose([
            cvtransform.RandomCrop(cfg.train.size),
            cvtransform.RandomHorizontalFlip(0.5),
            cvtransform.RandomHVerticalFlip(0.5),
            cvtransform.ColorJitter(0.05, 0.05, 0.05, 0.05),
            cvtransform.ToTensor()
        ])
        self.env = None
        self.txn = None
        self.outKeys = ['In', 'I0', 'It', 'I1', 'I2']

    def __getitem__(self, idx):
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)
        if any([self.txn is None, self.env is None]):
            self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False,
                                 meminit=False)
            self.txn = self.env.begin(write=False)

        sampleKeys = self.Sample[idx]

        In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii')))
        I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii')))
        I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii')))
        I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii')))

        It_1 = pickle.loads(self.txn.get(sampleKeys['It'].encode('ascii')))

        t = torch.tensor(0.5, dtype=torch.float32)

        # temporal flip augmentation: reverse the frame order and mirror t
        if np.random.rand() > 0.5:
            valueList = [In, I0, It_1, I1, I2]
        else:
            valueList = [I2, I1, It_1, I0, In]
            t = 1 - t

        valueList = self.transforms(valueList)

        outDict = {'t': t}
        for key, value in zip(self.outKeys, valueList):
            outDict[key] = value
        return outDict

    def __len__(self):
        return self.length

    def close(self):
        if self.env is not None:
            self.env.close()
            self.txn = None
            self.env = None


def creatTrainLoader(cfg: configMain):
    dataset = Train(cfg)

    sampler = DistributedSampler(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank)

    loader = data.DataLoader(dataset=dataset, batch_size=cfg.train.batchPerGPU,
                             shuffle=False, num_workers=4, pin_memory=False,  # False if memory is not enough
                             drop_last=True, sampler=sampler)
    return sampler, loader
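
# Usage sketch for the training loader (assumes distributed init has already
# happened, as in train.py; numEpoch is a placeholder):
#     cfg = configMain()
#     trainSampler, trainLoader = creatTrainLoader(cfg)
#     for epoch in range(numEpoch):
#         trainSampler.set_epoch(epoch)  # reshuffle across replicas each epoch
#         for batch in trainLoader:
#             In, I0, It, I1, I2 = (batch[k] for k in ['In', 'I0', 'It', 'I1', 'I2'])
#             t = batch['t']  # 0.5 for the midpoint frame (the flip maps it to 1 - t)
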
class Test(data.Dataset):
    def __init__(self, cfg: configMain):
        super(Test, self).__init__()
        self.Sample = './datasets/vimeo/test_lmdb/sample.pkl'
        self.LMDB = './datasets/vimeo/test_lmdb/data.mdb'
        self.numIter = cfg.numIter

        with open(self.Sample, 'rb') as fs:
            self.Sample = pickle.load(fs)

        self.length = len(self.Sample)  # 7824
        # self.length = 10
        self.transforms = cvtransform.Compose([
            cvtransform.ToTensor()
        ])
        self.env = None
        self.txn = None
        self.outKeys = ['In', 'I0', 'I1', 'I2', 'It']

    def __getitem__(self, idx):
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)
        if any([self.txn is None, self.env is None]):
            self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False,
                                 meminit=False)
            self.txn = self.env.begin(write=False)

        sampleKeys = self.Sample[idx]

        In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii')))
        I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii')))
        I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii')))
        I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii')))

        valueList = [In, I0, I1, I2]
        watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']]
        gtList = []

        It = pickle.loads(self.txn.get(sampleKeys['It'].encode('ascii')))
        ItName = sampleKeys['It']
        valueList.append(It)
        gtList.append(ItName)
        valueList = self.transforms(valueList)

        outDict = {}
        for key, value in zip(self.outKeys, valueList):
            outDict[key] = value

        return outDict, watchList, gtList

    def __len__(self):
        return self.length


def creatValiLoader(cfg: configMain):
    dataset = Test(cfg)

    # if cfg.dist.isDist:
    sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank)

    loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                             num_workers=4, pin_memory=True,  # False if memory is not enough
                             drop_last=False, sampler=sampler)
    return sampler, loader


if __name__ == "__main__":
    cfg = configMain()
    testSampler, testLoader = creatValiLoader(cfg)
    # Test.__getitem__ returns (outDict, watchList, gtList); for 2x there is a
    # single ground-truth key 'It'
    for valdict, watchList, gtList in testLoader:
        visImg(valdict['In'], wait=100)
        visImg(valdict['I0'], wait=100)
        visImg(valdict['I1'], wait=100)
        visImg(valdict['I2'], wait=100)

        visImg(valdict['It'], wait=100)
        pass
--------------------------------------------------------------------------------
/DBVI_2x/dataloader/dataUtils.py:
--------------------------------------------------------------------------------
import torch


def pixel_shuffle(input: torch.Tensor, scale_factor):
    batch_size, channels, in_height, in_width = input.size()

    out_channels = int(int(channels / scale_factor) / scale_factor)
    out_height = int(in_height * scale_factor)
    out_width = int(in_width * scale_factor)

    if scale_factor >= 1:
        input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width)
        shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous()
    else:
        block_size = int(1 / scale_factor)
        input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size)
        shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous()

    return shuffle_out.view(batch_size, out_channels, out_height, out_width)
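pixel_shuffle generalizes depth-to-space/space-to-depth into one helper: scale_factor >= 1 moves channels into space, a fractional factor (e.g. 0.5) does the inverse. A quick shape check of the function above:
```python
import torch
from dataloader.dataUtils import pixel_shuffle

x = torch.randn(1, 16, 8, 8)
print(pixel_shuffle(x, 2).shape)    # torch.Size([1, 4, 16, 16])   channels -> space
print(pixel_shuffle(x, 0.5).shape)  # torch.Size([1, 64, 4, 4])    space -> channels
```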
--------------------------------------------------------------------------------
/DBVI_2x/dataloader/dataloaderBase.py:
--------------------------------------------------------------------------------
from torch.utils.data import Sampler
import torch.distributed as dist
import math


class DistributedSamplerVali(Sampler):
    """Sampler that restricts data loading to a subset of the dataset.

    It is especially useful in conjunction with
    :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
    process can pass a DistributedSampler instance as a DataLoader sampler,
    and load a subset of the original dataset that is exclusive to it.

    .. note::
        Dataset is assumed to be of constant size.

    Arguments:
        dataset: Dataset used for sampling.
        num_replicas (optional): Number of processes participating in
            distributed training.
        rank (optional): Rank of the current process within num_replicas.
    """

    def __init__(self, dataset, num_replicas=None, rank=None):
        # super(DistributedSamplerVali, self).__init__()
        if num_replicas is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            num_replicas = dist.get_world_size()
        if rank is None:
            if not dist.is_available():
                raise RuntimeError("Requires distributed package to be available")
            rank = dist.get_rank()
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.epoch = 0
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas))
        self.total_size = self.num_samples * self.num_replicas

    def __iter__(self):
        # deterministically shuffle based on epoch

        indices = list(range(len(self.dataset)))

        # add extra samples to make it evenly divisible
        indices += indices[:(self.total_size - len(indices))]
        assert len(indices) == self.total_size

        # subsample
        indices = indices[self.rank:self.total_size:self.num_replicas]
        assert len(indices) == self.num_samples

        return iter(indices)

    def __len__(self):
        return self.num_samples

    def set_epoch(self, epoch):
        self.epoch = epoch
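Unlike the training sampler, DistributedSamplerVali never shuffles (validation order stays fixed); it pads the index list to a multiple of num_replicas and then deals indices round-robin, so every rank sees the same number of samples. A small worked example:
```python
# len(dataset) = 10, num_replicas = 4  ->  num_samples = 3, total_size = 12
# padded indices: [0, 1, ..., 9, 0, 1]
# rank 0 -> [0, 4, 8]   rank 1 -> [1, 5, 9]
# rank 2 -> [2, 6, 0]   rank 3 -> [3, 7, 1]   (the last two re-see padded samples)
```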
--------------------------------------------------------------------------------
/DBVI_2x/dataloader/davis.py:
--------------------------------------------------------------------------------
import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from dataloader import cvtransform
import torch.utils.data as data
from dataloader.dataloaderBase import DistributedSamplerVali
from configs.configTrain import configMain
import pickle
import lmdb
from lib.visualTool import visImg


class Test(data.Dataset):
    def __init__(self, cfg: configMain):
        super(Test, self).__init__()
        self.Sample = './datasets/davis/davis_lmdb/sample.pkl'
        self.LMDB = './datasets/davis/davis_lmdb/data.mdb'
        self.numIter = cfg.numIter

        with open(self.Sample, 'rb') as fs:
            self.Sample = pickle.load(fs)

        self.length = len(self.Sample)  # 2849

        self.transforms = cvtransform.Compose([
            cvtransform.CenterCrop((480, 840)),
            cvtransform.ToTensor()
        ])
        self.env = None
        self.txn = None
        self.outKeys = ['In', 'I0', 'I1', 'I2', 'It']

    def __getitem__(self, idx):
        cv2.setNumThreads(0)
        cv2.ocl.setUseOpenCL(False)
        if any([self.txn is None, self.env is None]):
            self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False,
                                 meminit=False)
            self.txn = self.env.begin(write=False)

        sampleKeys = self.Sample[idx]

        In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii')))
        I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii')))
        I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii')))
        I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii')))

        valueList = [In, I0, I1, I2]
        watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']]
        gtList = []

        It = pickle.loads(self.txn.get(sampleKeys['It'].encode('ascii')))
        ItName = sampleKeys['It']
        valueList.append(It)
        gtList.append(ItName)
        valueList = self.transforms(valueList)

        outDict = {}
        for key, value in zip(self.outKeys, valueList):
            outDict[key] = value

        return outDict, watchList, gtList

    def __len__(self):
        return self.length


def creatValiLoader(cfg: configMain):
    dataset = Test(cfg)

    # if cfg.dist.isDist:
    sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank)

    loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                             num_workers=4, pin_memory=False,  # False if memory is not enough
                             drop_last=False, sampler=sampler)
    return sampler, loader


if __name__ == "__main__":
    cfg = configMain()
    testSampler, testLoader = creatValiLoader(cfg)
    # Test.__getitem__ returns (outDict, watchList, gtList); only 'It' exists for 2x
    for valdict, watchList, gtList in testLoader:
        visImg(valdict['In'], wait=100)
        visImg(valdict['I0'], wait=100)
        visImg(valdict['I1'], wait=100)
        visImg(valdict['I2'], wait=100)

        visImg(valdict['It'], wait=100)
        pass
--------------------------------------------------------------------------------
/DBVI_2x/dataloader/snufilm.py:
--------------------------------------------------------------------------------
import cv2

cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from dataloader import cvtransform
import torch.utils.data as data
from dataloader.dataloaderBase import DistributedSamplerVali
from configs.configTest import configMain
import pickle
import lmdb
from lib.visualTool import visImg


class Test(data.Dataset):
    def __init__(self, cfg: configMain):
        super(Test, self).__init__()
        if 'easy' in cfg.dataset:
            self.Sample = './datasets/snufilm/snufilm_lmdb/test-easy.pkl'  # 310
        elif 'medium' in cfg.dataset:
            self.Sample = './datasets/snufilm/snufilm_lmdb/test-medium.pkl'  # 310
        elif 'snufilm-hard' in cfg.dataset:
            self.Sample = './datasets/snufilm/snufilm_lmdb/test-hard.pkl'  # 310
        else:
            self.Sample = './datasets/snufilm/snufilm_lmdb/test-extreme.pkl'  # 234

        self.LMDB = './datasets/snufilm/snufilm_lmdb/data.mdb'
        self.numIter = cfg.numIter

        with open(self.Sample, 'rb') as fs:
            self.Sample = pickle.load(fs)

        self.length = 
len(self.Sample) 33 | 34 | self.transforms = cvtransform.Compose([ 35 | cvtransform.ToTensor() 36 | ]) 37 | self.env = None 38 | self.txn = None 39 | self.outKeys = ['In', 'I0', 'I1', 'I2', 'It'] 40 | 41 | def __getitem__(self, idx): 42 | cv2.setNumThreads(0) 43 | cv2.ocl.setUseOpenCL(False) 44 | if any([self.txn is None, self.env is None]): 45 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 46 | meminit=False) 47 | self.txn = self.env.begin(write=False) 48 | 49 | sampleKeys = self.Sample[idx] 50 | 51 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 52 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 53 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 54 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 55 | 56 | valueList = [In, I0, I1, I2] 57 | watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']] 58 | gtList = [] 59 | 60 | It = pickle.loads(self.txn.get(sampleKeys['It'].encode('ascii'))) 61 | ItName = sampleKeys['It'] 62 | valueList.append(It) 63 | gtList.append(ItName) 64 | valueList = self.transforms(valueList) 65 | 66 | outDict = {} 67 | for key, value in zip(self.outKeys, valueList): 68 | outDict[key] = value 69 | 70 | return outDict, watchList, gtList 71 | 72 | def __len__(self): 73 | return self.length 74 | 75 | 76 | def creatValiLoader(cfg: configMain): 77 | dataset = Test(cfg) 78 | 79 | # if cfg.dist.isDist: 80 | sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 81 | 82 | loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, 83 | num_workers=4, pin_memory=False, # False if memory is not enough 84 | drop_last=False, sampler=sampler) 85 | return sampler, loader 86 | 87 | 88 | if __name__ == "__main__": 89 | cfg = configMain() 90 | testSampler, testLoader = creatValiLoader(cfg) 91 | for valdict, watchList, gtList in testLoader: 92 | visImg(valdict['In'], wait=100) 93 | visImg(valdict['I0'], wait=100) 94 | visImg(valdict['I1'], wait=100) 95 | visImg(valdict['I2'], wait=100) 96 | 97 | visImg(valdict['It'], wait=100) 98 | # visImg(valdict['I1'], wait=100) 99 | pass 100 | -------------------------------------------------------------------------------- /DBVI_2x/dataloader/ucf101.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | from dataloader import cvtransform 6 | import torch.utils.data as data 7 | from dataloader.dataloaderBase import DistributedSamplerVali 8 | from configs.configTrain import configMain 9 | import pickle 10 | import lmdb 11 | from lib.visualTool import visImg 12 | 13 | 14 | class Test(data.Dataset): 15 | def __init__(self, cfg: configMain): 16 | super(Test, self).__init__() 17 | self.Sample = './datasets/ucf101/ucf101_lmdb/sample.pkl' 18 | self.LMDB = './datasets/ucf101/ucf101_lmdb/data.mdb' 19 | self.numIter = cfg.numIter 20 | 21 | with open(self.Sample, 'rb') as fs: 22 | self.Sample = pickle.load(fs) 23 | 24 | self.length = len(self.Sample) # 100 25 | # self.length = 1 26 | self.transforms = cvtransform.Compose([ 27 | # cvtransform.CenterCrop((224, 224)), 28 | cvtransform.ToTensor() 29 | ]) 30 | self.env = None 31 | self.txn = None 32 | self.outKeys = ['In', 'I0', 'I1', 'I2', 'It'] 33 | 34 | def __getitem__(self, idx): 35 | cv2.setNumThreads(0) 36 | cv2.ocl.setUseOpenCL(False) 37 | if any([self.txn is None, self.env is None]): 38 | self.env 
= lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False,
                                 meminit=False)
            self.txn = self.env.begin(write=False)

        sampleKeys = self.Sample[idx]

        In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii')))
        I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii')))
        I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii')))
        I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii')))

        valueList = [In, I0, I1, I2]
        watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']]
        gtList = []

        It = pickle.loads(self.txn.get(sampleKeys['It'].encode('ascii')))
        ItName = sampleKeys['It']
        valueList.append(It)
        gtList.append(ItName)
        valueList = self.transforms(valueList)

        outDict = {}
        for key, value in zip(self.outKeys, valueList):
            outDict[key] = value

        return outDict, watchList, gtList

    def __len__(self):
        return self.length


def creatValiLoader(cfg: configMain):
    dataset = Test(cfg)

    # if cfg.dist.isDist:
    sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank)

    loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False,
                             num_workers=4, pin_memory=False,  # False if memory is not enough
                             drop_last=False, sampler=sampler)
    return sampler, loader


if __name__ == "__main__":
    cfg = configMain()
    testSampler, testLoader = creatValiLoader(cfg)
    # Test.__getitem__ returns (outDict, watchList, gtList); only 'It' exists for 2x
    for valdict, watchList, gtList in testLoader:
        visImg(valdict['In'], wait=100)
        visImg(valdict['I0'], wait=100)
        visImg(valdict['I1'], wait=100)
        visImg(valdict['I2'], wait=100)

        visImg(valdict['It'], wait=100)
        pass
--------------------------------------------------------------------------------
/DBVI_2x/datasets/dataset:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_2x/datasets/dataset
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/alt_cuda_corr/correlation.cpp:
--------------------------------------------------------------------------------
#include <torch/extension.h>
#include <vector>

// CUDA forward declarations
std::vector<torch::Tensor> corr_cuda_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius);

std::vector<torch::Tensor> corr_cuda_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius);

// C++ interface
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)

std::vector<torch::Tensor> corr_forward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);

  return corr_cuda_forward(fmap1, fmap2, coords, radius);
}


std::vector<torch::Tensor> corr_backward(
  torch::Tensor fmap1,
  torch::Tensor fmap2,
  torch::Tensor coords,
  torch::Tensor corr_grad,
  int radius) {
  CHECK_INPUT(fmap1);
  CHECK_INPUT(fmap2);
  CHECK_INPUT(coords);
  CHECK_INPUT(corr_grad);

  return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius);
}


PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &corr_forward, "CORR forward");
  m.def("backward", &corr_backward, "CORR backward");
}
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/alt_cuda_corr/setup.py:
--------------------------------------------------------------------------------
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension


setup(
    name='correlation',
    ext_modules=[
        CUDAExtension('alt_cuda_corr',
                      sources=['correlation.cpp', 'correlation_kernel.cu'],
                      extra_compile_args={'cxx': [], 'nvcc': ['-O3']}),
    ],
    cmdclass={
        'build_ext': BuildExtension
    })
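Building this extension is optional: corr.py only needs it for the commented-out AlternateCorrBlock. If wanted, it builds like any other PyTorch CUDA extension, e.g.:
```
cd DBVI_2x/lib/RAFT/alt_cuda_corr
python setup.py install
```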
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/corr.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
from lib.RAFT.utils import bilinear_sampler, coords_grid
# from lib.RAFT import alt_cuda_corr


class CorrBlock:
    def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
        self.num_levels = num_levels
        self.radius = radius
        self.corr_pyramid = []

        # all pairs correlation
        corr = CorrBlock.corr(fmap1, fmap2)

        batch, h1, w1, dim, h2, w2 = corr.shape
        corr = corr.reshape(batch * h1 * w1, dim, h2, w2)

        self.corr_pyramid.append(corr)
        for i in range(self.num_levels - 1):
            corr = F.avg_pool2d(corr, 2, stride=2)
            self.corr_pyramid.append(corr)

    def __call__(self, coords):
        r = self.radius
        coords = coords.permute(0, 2, 3, 1)
        batch, h1, w1, _ = coords.shape

        out_pyramid = []
        for i in range(self.num_levels):
            corr = self.corr_pyramid[i]
            dx = torch.linspace(-r, r, 2 * r + 1)
            dy = torch.linspace(-r, r, 2 * r + 1)
            delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)

            centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i
            delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
            coords_lvl = centroid_lvl + delta_lvl

            corr = bilinear_sampler(corr, coords_lvl)
            corr = corr.view(batch, h1, w1, -1)
            out_pyramid.append(corr)

        out = torch.cat(out_pyramid, dim=-1)
        return out.permute(0, 3, 1, 2).contiguous().float()

    @staticmethod
    def corr(fmap1, fmap2):
        batch, dim, ht, wd = fmap1.shape
        fmap1 = fmap1.view(batch, dim, ht * wd)
        fmap2 = fmap2.view(batch, dim, ht * wd)

        corr = torch.matmul(fmap1.transpose(1, 2), fmap2)
        corr = corr.view(batch, ht, wd, 1, ht, wd)
        return corr / torch.sqrt(torch.tensor(dim).float())

#
# class AlternateCorrBlock:
#     def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
#         self.num_levels = num_levels
#         self.radius = radius
#
#         self.pyramid = [(fmap1, fmap2)]
#         for i in range(self.num_levels):
#             fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
#             fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
#             self.pyramid.append((fmap1, fmap2))
#
#     def __call__(self, coords):
#         coords = coords.permute(0, 2, 3, 1)
#         B, H, W, _ = coords.shape
#         dim = self.pyramid[0][0].shape[1]
#
#         corr_list = []
#         for i in range(self.num_levels):
#             r = self.radius
#             fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous()
#             fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous()
#
#             coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous()
#             corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
#             corr_list.append(corr.squeeze(1))
#
#         corr = torch.stack(corr_list, dim=1)
#         corr = corr.reshape(B, -1, H, W)
#         return corr / torch.sqrt(torch.tensor(dim).float())
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/raft-things.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_2x/lib/RAFT/raft-things.pth
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/raft.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as F

from lib.RAFT.update import BasicUpdateBlock, SmallUpdateBlock
from lib.RAFT.extractor import BasicEncoder, SmallEncoder
from lib.RAFT.corr import CorrBlock
from lib.RAFT.utils import coords_grid, upflow8
from collections import OrderedDict
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
from lib.visualTool import visFlow
import argparse
from pathlib import Path


class RAFParam(object):
    def __init__(self):
        super(RAFParam, self).__init__()
        self.alternate_corr = False
        self.mixed_precision = False
        self.small = False


class RAFT(nn.Module):
    def __init__(self):
        super(RAFT, self).__init__()
        import warnings
        warnings.filterwarnings("ignore")
        self.args = RAFParam()

        if self.args.small:
            self.hidden_dim = hdim = 96
            self.context_dim = cdim = 64
            self.args.corr_levels = 4
            self.args.corr_radius = 3

        else:
            self.hidden_dim = hdim = 128
            self.context_dim = cdim = 128
            self.args.corr_levels = 4
            self.args.corr_radius = 4

        self.args.dropout = 0
        self.args.alternate_corr = False

        # feature network, context network, and update block
        if self.args.small:
            self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=self.args.dropout)
            self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=self.args.dropout)
            self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)

        else:
            self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.args.dropout)
            self.cnet = BasicEncoder(output_dim=hdim + cdim, norm_fn='batch', dropout=self.args.dropout)
            self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)

        pathPreWeight = str(Path(__file__).parent.absolute() / Path('raft-things.pth'))
        self.initPreweight(pathPreWeight)

    def freeze_bn(self):
        for m in self.modules():
            if isinstance(m, nn.BatchNorm2d):
                m.eval()

    def initialize_flow(self, img):
        """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
        N, C, H, W = img.shape
        coords0 = coords_grid(N, H // 8, W // 8).to(img.device)
        coords1 = coords_grid(N, H // 8, W // 8).to(img.device)

        # optical flow computed as difference: flow = coords1 - coords0
        return coords0, coords1

    def upsample_flow(self, flow, mask, scale=8, ksize=3):
        """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
        N, C, H, W = flow.shape
        mask = mask.view(N, 1, ksize ** 2, scale, scale, H, W)
        mask = torch.softmax(mask, dim=2)

        up_flow = F.unfold(8 * flow, [ksize, ksize], padding=1)
        up_flow = up_flow.view(N, C, ksize ** 2, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 2, scale * H, scale * W)

    def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False):
        """ Estimate optical flow between pair of frames """
        image1 = image1.contiguous()
        image2 = image2.contiguous()

        hdim = self.hidden_dim
        cdim = self.context_dim

        fmap1, fmap2 = self.fnet([image1, image2])

        fmap1 = fmap1.float()
        fmap2 = fmap2.float()

        corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)

        cnet = self.cnet(image1)
        net, inp = torch.split(cnet, [hdim, cdim], dim=1)
        net = torch.tanh(net)
        inp = torch.relu(inp)

        coords0, coords1 = self.initialize_flow(image1)

        if flow_init is not None:
            coords1 = coords1 + flow_init

        # flow_predictions = []
        for itr in range(iters):
            coords1 = coords1.detach()
            corr = corr_fn(coords1)  # index correlation volume

            flow = coords1 - coords0
            net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)

            # F(t+1) = F(t) + \Delta(t)
            coords1 = coords1 + delta_flow

            # upsample predictions
            # if up_mask is None:
            #     flow_up = upflow8(coords1 - coords0)
            # else:
            flow_up = self.upsample_flow(coords1 - coords0, up_mask)

        return flow_up

    def initPreweight(self, pathPreWeight: str = None, rmModule=True):
        preW = self.getWeight(pathPreWeight)
        assert preW is not None, 'weight in {} is empty'.format(pathPreWeight)
        modelW = self.state_dict()
        preWDict = OrderedDict()
        # modelWDict = OrderedDict()

        for k, v in preW.items():
            if rmModule:
                preWDict[k.replace('module.', "")] = v
            else:
                preWDict[k] = v

        shareW = {k: v for k, v in preWDict.items() if str(k) in modelW}
        assert shareW, 'shareW is empty'
        self.load_state_dict(preWDict, strict=False)

    @staticmethod
    def getWeight(pathPreWeight: str = None):
        if pathPreWeight is not None:
            return torch.load(pathPreWeight, map_location=torch.device('cpu'))
        else:
            return None


def main():
    model = RAFT()
    model.cuda()
    model.eval()

    with torch.no_grad():
        image1 = cv2.imread(
            '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/test/GOPR0384_11_00/000009.png')
        image2 = cv2.imread(
            '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/test/GOPR0384_11_00/000017.png')

        image1 = torch.from_numpy(image1).float().permute([2, 0, 1]).unsqueeze(0)[:, [2, 1, 0], ...].cuda()
        image2 = torch.from_numpy(image2).float().permute([2, 0, 1]).unsqueeze(0)[:, [2, 1, 0], ...].cuda()

        # forward returns a single upsampled flow field
        flow_up = model(image1, image2, 
iters=20, test_mode=True) 172 | visFlow(flow_up) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() -------------------------------------------------------------------------------- /DBVI_2x/lib/RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, 
padding=3)
        self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
        self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)

    def forward(self, flow, corr):
        cor = F.relu(self.convc1(corr))
        cor = F.relu(self.convc2(cor))
        flo = F.relu(self.convf1(flow))
        flo = F.relu(self.convf2(flo))

        cor_flo = torch.cat([cor, flo], dim=1)
        out = F.relu(self.conv(cor_flo))
        return torch.cat([out, flow], dim=1)

class SmallUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=96):
        super(SmallUpdateBlock, self).__init__()
        self.encoder = SmallMotionEncoder(args)
        self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=128)

    def forward(self, net, inp, corr, flow):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)
        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        return net, None, delta_flow

class BasicUpdateBlock(nn.Module):
    def __init__(self, args, hidden_dim=128, input_dim=128):
        super(BasicUpdateBlock, self).__init__()
        self.args = args
        self.encoder = BasicMotionEncoder(args)
        self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 64*9, 1, padding=0))

    def forward(self, net, inp, corr, flow, upsample=True):
        motion_features = self.encoder(flow, corr)
        inp = torch.cat([inp, motion_features], dim=1)

        net = self.gru(net, inp)
        delta_flow = self.flow_head(net)

        # scale mask to balance gradients
        mask = .25 * self.mask(net)
        return net, mask, delta_flow
--------------------------------------------------------------------------------
/DBVI_2x/lib/RAFT/utils.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as F
import numpy as np
from scipy import interpolate


class InputPadder:
    """ Pads images such that dimensions are divisible by 8 """
    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
        pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
        if mode == 'sintel':
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
        else:
            self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]

def forward_interpolate(flow):
    flow = flow.detach().cpu().numpy()
    dx, dy = flow[0], flow[1]

    ht, wd = dx.shape
    x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))

    x1 = x0 + dx
    y1 = y0 + dy

    x1 = x1.reshape(-1)
    y1 = y1.reshape(-1)
    dx = dx.reshape(-1)
    dy = dy.reshape(-1)

    valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
    x1 = x1[valid]
    y1 = y1[valid]
    dx = dx[valid]
    dy = dy[valid]

    flow_x = interpolate.griddata(
        (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)

    flow_y = interpolate.griddata(
        (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)

    flow = np.stack([flow_x, flow_y], axis=0)
    return torch.from_numpy(flow).float()


def bilinear_sampler(img, coords, mode='bilinear', mask=False):
    """ Wrapper for grid_sample, uses pixel coordinates """
    H, W = img.shape[-2:]
    xgrid, ygrid = coords.split([1, 1], dim=-1)
    xgrid = 2*xgrid/(W-1) - 1
    ygrid = 2*ygrid/(H-1) - 1

    grid = torch.cat([xgrid, ygrid], dim=-1)
    img = F.grid_sample(img, grid, align_corners=True)

    if mask:
        mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
        return img, mask.float()

    return img


def coords_grid(batch, ht, wd):
    coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
    coords = torch.stack(coords[::-1], dim=0).float()
    return coords[None].repeat(batch, 1, 1, 1)


def upflow8(flow, mode='bilinear'):
    new_size = (8 * flow.shape[2], 8 * flow.shape[3])
    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
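For arbitrary input sizes, InputPadder is the intended wrapper around a RAFT forward pass; a short usage sketch (image1/image2 are assumed to be (N, 3, H, W) tensors on the model's device):
```python
padder = InputPadder(image1.shape)
image1_p, image2_p = padder.pad(image1, image2)  # H and W now divisible by 8
flow_up = model(image1_p, image2_p, iters=12)    # RAFT.forward from raft.py
flow = padder.unpad(flow_up)                     # crop back to the original size
```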
--------------------------------------------------------------------------------
/DBVI_2x/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_2x/lib/__init__.py
--------------------------------------------------------------------------------
/DBVI_2x/lib/checkTool.py:
--------------------------------------------------------------------------------
import numpy as np
import cv2
cv2.setNumThreads(0)
cv2.ocl.setUseOpenCL(False)
import torch
import torch.nn.functional as F
from lib import fileTool as FT


def checkGrad(net):
    for parem in list(net.named_parameters()):
        if parem[1].grad is not None:
            print(parem[0] + ' \t shape={}, \t mean={}, \t std={}\n'.format(parem[1].shape,
                                                                            parem[1].grad.abs().mean().cpu().item(),
                                                                            parem[1].grad.abs().std().cpu().item()))


def write_video_cv2(allFrames, video_name, fps, sizes):
    out = cv2.VideoWriter(video_name, cv2.CAP_OPENCV_MJPEG, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, sizes)

    for outF in allFrames:
        # frameIn = cv2.imread(inF, cv2.IMREAD_COLOR)
        frameOut = cv2.imread(outF, cv2.IMREAD_COLOR)
        # frame = np.concatenate([frameIn, frameOut], axis=1)
        out.write(frameOut)
    out.release()


if __name__ == '__main__':
    videoPath = '/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40'

    allFrames = FT.getAllFiles(videoPath, 'png')
    inFrames = [a for a in allFrames if 'EVI' not in a]
    inFrames = [[a, a, a, a] for a in inFrames]
    inFrames = [item for sublist in inFrames for item in sublist]

    videoName = '/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40/couple.avi'
    write_video_cv2(inFrames, videoName, 24, (480, 176))
--------------------------------------------------------------------------------
/DBVI_2x/lib/distTool.py:
--------------------------------------------------------------------------------
import torch.distributed as dist
import os
from configs.configTrain import configMain


def synchronize():
    """
    Helper function to synchronize (barrier) among all processes when
    using distributed training
    """
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def reduceTensorMean(tensor):
    wordSize = os.environ['SLURM_NTASKS'] if 'SLURM_NTASKS' in os.environ else 1
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= float(wordSize)
    return rt


def reduceTensorSum(tensor):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt
--------------------------------------------------------------------------------
/DBVI_2x/lib/dlTool.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from torch.optim import lr_scheduler
from configs.configTrain import configMain
import math
from sklearn.mixture import GaussianMixture
import numpy as np
from torch.optim import Adam


def getOptimizer(gNet: nn.Module, cfg: configMain):
    core_network_params = []
    map_network_params = []
    for k, v in gNet.named_parameters():
        if v.requires_grad:
            if 'mapping' in k:
                map_network_params.append(v)
            else:
                core_network_params.append(v)
    optim = Adam([{'params': core_network_params, 'lr': cfg.optim.lrInit},
                  {'params': map_network_params, 'lr': 1e-2 * cfg.optim.lrInit}],
                 weight_decay=0, betas=(0.9, 0.999))
    for group in optim.param_groups:
        group.setdefault('initial_lr', group['lr'])
    maxPSNR = -1
    return optim, maxPSNR


def getScheduler(optimizer, epoch=-1):
    return lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[110, 135], gamma=0.4, last_epoch=epoch)


def compute_same_pad(kernel_size, stride):
    if isinstance(kernel_size, int):
        kernel_size = [kernel_size]

    if isinstance(stride, int):
        stride = [stride]

    assert len(stride) == len(
        kernel_size
    ), "Pass kernel size and stride both as int, or both as equal length iterable"

    return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)]


def fitFullConvGaussian(zAll: torch.Tensor):
    zAllNpy = zAll.detach().cpu().numpy()
    selectIdx = np.random.choice(zAllNpy.shape[0], size=100, replace=False)
    zAllNpySelect = zAllNpy[selectIdx, ...]
    gmm = GaussianMixture(n_components=10, covariance_type='full').fit(zAllNpySelect)
    return gmm
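Note that getOptimizer puts parameters whose name contains 'mapping' into a second param group with a 100x smaller learning rate. A quick check (assuming cfg.optim.lrInit = 1e-4):
```python
optim, maxPSNR = getOptimizer(gNet, cfg)
print(optim.param_groups[0]['lr'])  # 1e-4: core network parameters
print(optim.param_groups[1]['lr'])  # 1e-6: parameters with 'mapping' in their name
```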
--------------------------------------------------------------------------------
/DBVI_2x/lib/fileTool.py:
--------------------------------------------------------------------------------
import shutil
from pathlib import Path
from queue import Queue


def delPath(root):
    """
    remove dir trees
    :param root: root dir
    :return: True or False
    """
    root = Path(root)
    if root.is_file():
        try:
            root.unlink()
        except Exception as e:
            print(e)
    elif root.is_dir():
        for item in root.iterdir():
            delPath(item)
        try:
            root.rmdir()
            # print('Files in {} is removed'.format(root))
        except Exception as e:
            print(e)


def mkPath(path):
    p = Path(path)
    try:
        p.mkdir(parents=True, exist_ok=False)
        return True
    except Exception as e:
        return False


def getAllFiles(root, ext=None):
    p = Path(root)
    if ext is not None:
        pathnames = p.glob("**/*.{}".format(ext))

    else:
        pathnames = p.glob("**/*")
    filenames = sorted([x.as_posix() for x in pathnames])
    return filenames


def copyFile(src, dst):
    parent = Path(dst).parent
    mkPath(parent)
    try:
        shutil.copytree(str(src), str(dst))
    except:
        shutil.copy(str(src), str(dst))


def movFile(src, dst):
    parent = Path(dst).parent
    mkPath(parent)
    shutil.move(str(src), str(dst))


def getSubDirs(root):
    p = Path(root)
    dirs = [x for x in p.iterdir() if x.is_dir()]
    return sorted(dirs)


class fileBuffer(Queue):
    def __init__(self, capacity):
        super(fileBuffer, self).__init__()
        assert capacity > 0
        self.capacity = capacity

    def __call__(self, x):
        while self.qsize() >= self.capacity:
            delPath(self.get())
        self.put(x)


if __name__ == '__main__':
    path = '/home/sensetime/project/lib/a'
    # mkPath(path)
    filenames = delPath(path)
    pass
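fileBuffer acts as a fixed-size FIFO of paths on disk: once capacity is reached, pushing a new path deletes the oldest one, which is handy for rolling checkpoints. A sketch with hypothetical paths:
```python
ckptBuffer = fileBuffer(capacity=3)
for epoch in range(10):
    ckptPath = './output/ckpt_epoch{:03d}.pth'.format(epoch)  # hypothetical name
    # torch.save(state, ckptPath)
    ckptBuffer(ckptPath)  # keeps only the 3 most recent checkpoints on disk
```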
--------------------------------------------------------------------------------
/DBVI_2x/lib/lossTool.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
import torch
from torch import nn


class l1Loss(nn.Module):
    def __init__(self, reduction='mean'):
        super(l1Loss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, gt, mask=None):
        if mask is None:
            return F.l1_loss(pred, gt, reduction=self.reduction)
        else:
            return F.l1_loss(pred * mask, gt * mask, reduction=self.reduction)

    def __repr__(self):
        return 'l1Loss'


class l2Loss(nn.Module):
    def __init__(self, reduction='mean'):
        super(l2Loss, self).__init__()
        self.reduction = reduction

    def forward(self, pred, gt, mask=None):
        if mask is None:
            return F.mse_loss(pred, gt, reduction=self.reduction)
        else:
            return F.mse_loss(pred * mask, gt * mask, reduction=self.reduction)

    def __repr__(self):
        return 'l2Loss'


class totalLoss(nn.Module):
    def __init__(self):
        super(totalLoss, self).__init__()
        self.l1loss = l1Loss()
        self.scale = [1, 1, 1, 1, 1, 1]
        self.lossNames = ['l1']
        self.lossDict = {}

    def forward(self, rgbs, gts):
        self.lossDict.clear()
        l1loss = 0
        if len(rgbs) > 1:
            assert all([len(self.scale) == len(rgbs), len(rgbs) == len(gts)])
            scales = self.scale
        else:
            scales = self.scale[-len(rgbs)::]
        for scale, rgb, gt in zip(scales, rgbs, gts):
            l1loss += self.l1loss(rgb, gt.detach()).mean() * scale * 1

        lossSum = l1loss

        self.lossDict.setdefault('l1', l1loss.detach())
        self.lossDict.setdefault('Total', lossSum.detach())

        return lossSum


if __name__ == '__main__':
    a = torch.randn([1, 3, 128, 128])
    b = torch.randn([1, 3, 128, 128])

    loss = totalLoss()

    # forward expects lists of predictions and ground truths
    lossOut = loss([a], [b])
    lossOut2 = loss([a], [b])
    pass
--------------------------------------------------------------------------------
/DBVI_2x/lib/metrics.py:
--------------------------------------------------------------------------------
import torch
import numpy as np
from skimage.measure import compare_psnr, compare_ssim
import torch.nn.functional as F
from math import exp
from torch import nn


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class skiMetric(object):
    def __init__(self, drange=1.0):
        super(skiMetric, self).__init__()
        self.range = drange

    def psnr(self, img: np.ndarray, gt: np.ndarray):
        """
        :param img: H, W, C
        :param gt: H, W, C
        :return:
        """
        assert img.shape == gt.shape, 'img shape != gt shape'
        assert 3 == img.ndim
        return compare_psnr(img, gt, data_range=self.range)

    def ssim(self, img: np.ndarray, gt: np.ndarray):
        """
        :param img: H, W, C
        :param gt: H, W, C
        :return:
        """
        assert img.shape == gt.shape, 'img shape != gt shape'
        assert 3 == img.ndim
        return compare_ssim(img, gt, data_range=self.range, multichannel=True, gaussian_weights=True)


class torchMetric(nn.Module):
    def __init__(self, reduce='sum', device=torch.device('cpu')):
        super(torchMetric, self).__init__()
        self.range = 1.0
        self.reduce = reduce
        self.window_size = 11
        self.device = device
        self.register_buffer('window', self.initWindow())

    def forward(self, pred: torch.Tensor, gt: torch.Tensor, ifssim=True):
        N, C, H, W = pred.shape
        pred = self.reRange(pred)
        gt = self.reRange(gt)

        psnr = self.psnr(pred.detach(), gt.detach())
        if ifssim:
            ssim = self.ssim(pred.detach(), gt.detach())
        else:
            ssim = psnr

        return psnr, ssim, N

    # @torch.jit.script
    def psnr(self, img: torch.Tensor, gt: torch.Tensor):
        assert all([4 == img.ndimension(), 4 == gt.ndimension(), img.shape == gt.shape])

        mse = ((img - gt.detach()) ** 2).mean(dim=[1, 2, 3])  # N, 1
        psnrBatch = 10 * torch.log10(self.range ** 2 * mse.reciprocal())  # N, 1

        if self.reduce == 'sum':
            psnrBatch = psnrBatch.sum(dim=0)
        elif self.reduce == 'mean':
            psnrBatch = psnrBatch.mean()

        return psnrBatch

    # @torch.jit.script
    def ssim(self, img: torch.Tensor, gt: torch.Tensor):
        padd = 0
        img = 
img.unsqueeze(1) 98 | gt = gt.unsqueeze(1) 99 | 100 | mu1 = F.conv3d(F.pad(img, (5, 5, 5, 5, 5, 5), mode='replicate'), self.window, padding=padd, groups=1) 101 | mu2 = F.conv3d(F.pad(gt, (5, 5, 5, 5, 5, 5), mode='replicate'), self.window, padding=padd, groups=1) 102 | 103 | mu1_sq = mu1.pow(2) 104 | mu2_sq = mu2.pow(2) 105 | mu1_mu2 = mu1 * mu2 106 | 107 | sigma1_sq = F.conv3d(F.pad(img * img, (5, 5, 5, 5, 5, 5), 'replicate'), self.window, padding=padd, 108 | groups=1) - mu1_sq 109 | sigma2_sq = F.conv3d(F.pad(gt * gt, (5, 5, 5, 5, 5, 5), 'replicate'), self.window, padding=padd, 110 | groups=1) - mu2_sq 111 | sigma12 = F.conv3d(F.pad(img * gt, (5, 5, 5, 5, 5, 5), 'replicate'), self.window, padding=padd, 112 | groups=1) - mu1_mu2 113 | 114 | # C1 = (0.01 * self.range) ** 2 115 | C1 = 0.01 ** 2 116 | C2 = 0.03 ** 2 117 | 118 | v1 = 2.0 * sigma12 + C2 119 | v2 = sigma1_sq + sigma2_sq + C2 120 | cs = torch.mean(v1 / v2) 121 | 122 | ssim_map: torch.Tensor = (((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)).mean(dim=[1, 2, 3, 4]) 123 | 124 | if self.reduce == 'sum': 125 | ssim_map = ssim_map.sum(dim=0) 126 | elif self.reduce == 'mean': 127 | ssim_map = ssim_map.mean() 128 | return ssim_map 129 | 130 | def initWindow(self): 131 | wSize = self.window_size 132 | sigma = 1.5 133 | gaussList = [exp(-(x - wSize // 2) ** 2 / float(2 * sigma ** 2)) for x in range(wSize)] 134 | gauss = torch.tensor(gaussList) 135 | _1D_window = (gauss / gauss.sum()).unsqueeze(1) 136 | _2D_window = _1D_window.mm(_1D_window.t()) 137 | _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t()) 138 | return _3D_window.expand(1, 1, wSize, wSize, wSize).contiguous() 139 | 140 | def reRange(self, img: torch.Tensor): 141 | if img.max() > 128: 142 | img = img.float() / 255.0 143 | if img.min() < -0.5: 144 | img = (img.float() + 1.0) / 2.0 145 | return img 146 | -------------------------------------------------------------------------------- /DBVI_2x/lib/videoTool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lib.fileTool as FT 3 | 4 | ffmpegPath = '/usr/bin/ffmpeg' 5 | videoName = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/bbb.MP4' 6 | 7 | 8 | def video2Frame(vPath: str, fdir: str, H: int = None, W: int = None): 9 | FT.mkPath(fdir) 10 | if H is None or W is None: 11 | os.system('{} -y -i {} -vsync 0 -qscale:v 2 {}/%04d.png'.format(ffmpegPath, vPath, fdir)) 12 | else: 13 | os.system('{} -y -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg'.format(ffmpegPath, vPath, W, H, fdir)) 14 | 15 | 16 | def frame2Video(fdir: str, vPath: str, fps: int, H: int = None, W: int = None, ): 17 | if H is None or W is None: 18 | # os.system('{} -y -r {} -f image2 -i {}/%*.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 19 | # .format(ffmpegPath, fps, fdir, vPath)) 20 | 21 | os.system('{} -y -r {} -f image2 -i {}/%4d.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 22 | .format(ffmpegPath, fps, fdir, vPath)) 23 | else: 24 | os.system('{} -y -r {} -f image2 -s {}x{} -i {}/%*.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {}' 25 | .format(ffmpegPath, fps, W, H, fdir, vPath)) 26 | 27 | 28 | def slomo(vPath: str, dstPath: str, fps): 29 | os.system( 30 | '{} -y -r {} -i {} -strict -2 -vcodec libx264 -c:a aac -crf 18 {}'.format(ffmpegPath, fps, vPath, dstPath)) 31 | 32 | 33 | def downFPS(vPath: str, dstPath: str, fps): 34 | os.system( 35 | '{} -i {} -strict -2 -r {} {}'.format(ffmpegPath, vPath, fps, dstPath)) 36 | 37 | 38 | def downSample(vPath: str, dstPath: str, H, 
W): 39 | os.system( 40 | '{} -i {} -strict -2 -s {}x{} {}'.format(ffmpegPath, vPath, W, H, dstPath)) # ffmpeg -s expects WIDTHxHEIGHT, so W must come before H 41 | 42 | 43 | if __name__ == '__main__': 44 | framePath = '/data/2021_12_23/dstframes' 45 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/output' 46 | video = '/data/2021_12_23/02.mp4' 47 | # video2Frame(video, framePath) 48 | 49 | # video = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00/out.mp4' 50 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00' 51 | frame2Video(framePath, video, 30) 52 | 53 | # vPath = '/media/sensetime/Elements/0721 /0716_video/1.avi' 54 | # dstPath = '/media/sensetime/Elements/0721/0716_video/1_.mp4' 55 | # downFPS(vPath, dstPath, 8) 56 | -------------------------------------------------------------------------------- /DBVI_2x/lib/visualTool.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | import torch 6 | import numpy as np 7 | from pathlib import Path 8 | from lib import fileTool as FT 9 | 10 | 11 | class visFlow(): 12 | def __init__(self, flow_uv): 13 | super(visFlow, self).__init__() 14 | self.colorWheel = self.make_colorwheel() 15 | self.run(flow_uv) 16 | 17 | def run(self, flow_uv: torch.Tensor, clip_flow=None, convert_to_bgr=False): 18 | flow_uv = flow_uv[0].permute([1, 2, 0]).detach().cpu().numpy() 19 | if clip_flow is not None: 20 | flow_uv = np.clip(flow_uv, 0, clip_flow) 21 | u = flow_uv[:, :, 0] 22 | v = flow_uv[:, :, 1] 23 | rad = np.sqrt(np.square(u) + np.square(v)) 24 | rad_max = np.max(rad) 25 | epsilon = 1e-5 26 | u = u / (rad_max + epsilon) 27 | v = v / (rad_max + epsilon) 28 | rgb = self.flow_uv_to_colors(u, v, convert_to_bgr) 29 | cv2.namedWindow('flow', 0) 30 | cv2.imshow('flow', rgb[:, :, [2, 1, 0]]) 31 | cv2.waitKey(0) 32 | 33 | def make_colorwheel(self): 34 | RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6) 35 | ncols = RY + YG + GC + CB + BM + MR 36 | colorwheel = np.zeros((ncols, 3)) 37 | col = 0 38 | 39 | # RY 40 | colorwheel[0:RY, 0] = 255 41 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 42 | col = col + RY 43 | # YG 44 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 45 | colorwheel[col:col + YG, 1] = 255 46 | col = col + YG 47 | # GC 48 | colorwheel[col:col + GC, 1] = 255 49 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 50 | col = col + GC 51 | # CB 52 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 53 | colorwheel[col:col + CB, 2] = 255 54 | col = col + CB 55 | # BM 56 | colorwheel[col:col + BM, 2] = 255 57 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 58 | col = col + BM 59 | # MR 60 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 61 | colorwheel[col:col + MR, 0] = 255 62 | return colorwheel 63 | 64 | def flow_uv_to_colors(self, u, v, convert_to_bgr=False): 65 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 66 | ncols = self.colorWheel.shape[0] 67 | rad = np.sqrt(np.square(u) + np.square(v)) 68 | a = np.arctan2(-v, -u) / np.pi 69 | fk = (a + 1) / 2 * (ncols - 1) 70 | k0 = np.floor(fk).astype(np.int32) 71 | k1 = k0 + 1 72 | k1[k1 == ncols] = 0 73 | f = fk - k0 74 | for i in range(self.colorWheel.shape[1]): 75 | tmp = self.colorWheel[:, i] 76 | col0 = tmp[k0] / 255.0 77 | col1 = tmp[k1] / 255.0 78 | col = (1 - f) * col0 + f * col1 79 | idx = (rad <= 1) 80 | 
col[idx] = 1 - rad[idx] * (1 - col[idx]) 81 | col[~idx] = col[~idx] * 0.75 # out of range 82 | # Note the 2-i => BGR instead of RGB 83 | ch_idx = 2 - i if convert_to_bgr else i 84 | flow_image[:, :, ch_idx] = np.floor(255 * col) 85 | return flow_image 86 | 87 | 88 | class tensor2Video(object): 89 | def __init__(self, outPath, h, w, fps=24): 90 | super(tensor2Video, self).__init__() 91 | fourcc = cv2.VideoWriter.fourcc('I', '4', '2', '0') 92 | self.out = cv2.VideoWriter(outPath, fourcc, fps, (w, h)) 93 | 94 | def add(self, frame: torch.Tensor): 95 | frame = (frame + 1.0) / 2.0 96 | # frame = (frame - frame.min()) / (frame.max() - frame.min()) 97 | frame = (frame[0].permute([1, 2, 0]).cpu().numpy() * 255).astype(np.uint8) 98 | self.out.write(frame) 99 | 100 | def release(self): 101 | self.out.release() 102 | 103 | 104 | def makeGrid(batchImg: torch.Tensor, shape=(2, 1)): 105 | N, C, H, W = batchImg.shape 106 | batchImg = batchImg.permute([0, 2, 3, 1]) 107 | batchImg = batchImg.detach().cpu().numpy() 108 | nh = shape[0] 109 | nw = shape[1] 110 | batchImg = batchImg.reshape((nh, nw, H, W, C)).swapaxes(1, 2).reshape(nh * H, nw * W, C) 111 | return batchImg 112 | 113 | 114 | def visImg(batchImg: torch.Tensor, shape=(1, 1), wait=0, name='visImg'): 115 | """ 116 | :param img: tensor(N,3,H,W) or None 117 | :return: None 118 | """ 119 | N, C, H, W = batchImg.shape 120 | assert all([C == 3, N == shape[0] * shape[1]]) 121 | # batchImg = ((batchImg - batchImg.min()) / (batchImg.max() - batchImg.min()) * 255.0).byte() 122 | batchImg = ((batchImg.float() + 1) / 2.0 * 255).clamp(0, 255).byte() 123 | batchImgViz = makeGrid(batchImg, shape) 124 | cv2.namedWindow(name, 0) 125 | cv2.imshow(name, batchImgViz[:, :, ::-1]) 126 | cv2.waitKey(wait) 127 | 128 | 129 | def saveImg(x: torch.Tensor, outpath, isrgb=True): 130 | if Path(outpath).is_file(): 131 | return None 132 | rmax, rmin = x.max(), x.min() 133 | if rmin < -0.5: 134 | x = (x + 1.0) / 2.0 135 | x = x.clamp(0, 1) 136 | if isrgb: 137 | x = x[:, [2, 1, 0], :, :] 138 | xNpy = (x.squeeze(0).permute([1, 2, 0]) * 255).byte().detach().cpu().numpy() 139 | FT.mkPath(Path(outpath).parent) 140 | 141 | cv2.imwrite(str(outpath), xNpy) 142 | 143 | 144 | def saveTensor(x: torch.Tensor, srcName: str, dstDir: str, isInter=False): 145 | if isInter: 146 | dstName = str(Path(dstDir) / srcName.replace('.png', '_inter.pth')) 147 | else: 148 | dstName = str(Path(dstDir) / srcName.replace('.png', '.pth')) 149 | 150 | if Path(dstName).is_file(): 151 | return False 152 | if not Path(Path(dstName).parent).is_dir(): 153 | FT.mkPath(Path(dstName).parent) 154 | 155 | # xRGB = (x + 1.0) / 2.0 156 | xRGB = x.detach().cpu() 157 | torch.save(xRGB, dstName) 158 | return True 159 | -------------------------------------------------------------------------------- /DBVI_2x/lib/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.autograd import grad 5 | from lib.softsplat import FunctionSoftsplat 6 | 7 | 8 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 9 | device = img.device 10 | N, C, H, W = img.size() 11 | 12 | u = flow[:, 0, :, :] 13 | v = flow[:, 1, :, :] 14 | 15 | gridY, gridX = torch.meshgrid([torch.arange(start=0, end=H, device=device, requires_grad=False), 16 | torch.arange(start=0, end=W, device=device, requires_grad=False)]) 17 | 18 | x = gridX.unsqueeze(0).expand_as(u).float().detach() + u 19 | y = 
gridY.unsqueeze(0).expand_as(v).float().detach() + v 20 | 21 | # range -1 to 1 22 | x = 2 * x / (W - 1.0) - 1.0 23 | y = 2 * y / (H - 1.0) - 1.0 24 | # stacking X and Y 25 | grid = torch.stack((x, y), dim=3) 26 | 27 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 28 | 29 | mask = torch.ones_like(img, requires_grad=False) 30 | 31 | mask = F.grid_sample(mask, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 32 | 33 | mask[mask < 0.9999] = 0 34 | mask[mask > 0] = 1 35 | 36 | return imgOut * (mask.detach()), mask.detach() 37 | # return imgOut 38 | 39 | 40 | class ModuleSoftsplat(torch.nn.Module): 41 | def __init__(self, strType='average'): 42 | super().__init__() 43 | 44 | self.strType = strType 45 | 46 | def forward(self, tenInput, tenFlow, tenMetric): 47 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) 48 | 49 | 50 | class fidelityGradCuda(nn.Module): 51 | def __init__(self, res=True): 52 | super(fidelityGradCuda, self).__init__() 53 | self.fWarp = ModuleSoftsplat() 54 | 55 | def forward(self, It: torch.Tensor, I0: torch.Tensor, F0t: torch.Tensor): 56 | self.device = I0.device 57 | It0, mask = backWarp(It, F0t) 58 | grad_ll0 = (I0 - It0) # grad(y=0.5*(I0 - It0)^2, x=It0) 59 | 60 | totalGrad = grad_ll0 * mask 61 | warpGrad = self.fWarp(tenInput=totalGrad, tenFlow=F0t, tenMetric=None) 62 | 63 | return warpGrad 64 | 65 | 66 | class fidelityGradTorch(nn.Module): 67 | def __init__(self): 68 | super(fidelityGradTorch, self).__init__() 69 | 70 | def forward(self, It: torch.Tensor, I0: torch.Tensor, F0t: torch.Tensor): 71 | with torch.enable_grad(): 72 | It.requires_grad_() 73 | It0, mask = backWarp(It, F0t) 74 | loss = -(0.5 * (I0 * mask - It0 * mask) ** 2).sum() 75 | warpGrad = grad(loss, It, create_graph=True)[0] 76 | 77 | return warpGrad -------------------------------------------------------------------------------- /DBVI_2x/model/invBlock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_2x/model/invBlock/__init__.py -------------------------------------------------------------------------------- /DBVI_2x/model/invBlock/permute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class TransConv1x1(nn.Module): 8 | def __init__(self, inCh): 9 | super(TransConv1x1, self).__init__() 10 | self.w_shape = [inCh, inCh] 11 | w_init = np.linalg.qr(np.random.randn(*self.w_shape))[0].astype(np.float32) 12 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 13 | 14 | def get_weight(self, x, rev): 15 | b, c, h, w = x.shape 16 | dlogdet = torch.slogdet(self.weight)[1] * h * w # slogdet(A) = torch.log(torch.abs(torch.det(A))) 17 | 18 | if not rev: 19 | weight = self.weight 20 | else: 21 | weight = self.weight.t() 22 | 23 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 24 | 25 | def forward(self, x, logdet=None, rev=False): 26 | """ 27 | log-det = log|abs(|W|)| * pixels 28 | """ 29 | logdet = 0.0 if logdet is None else logdet 30 | weight, dlogdet = self.get_weight(x, rev) 31 | z = F.conv2d(x, weight) 32 | if not rev: 33 | logdet = logdet + dlogdet 34 | else: 35 | logdet = logdet - dlogdet 36 | 37 | return z, logdet 38 | 39 | 40 | class InvConv1x1(nn.Module): 41 | def __init__(self, inCh): 42 | 
super(InvConv1x1, self).__init__() 43 | self.w_shape = [inCh, inCh] 44 | w_init = np.linalg.qr(np.random.randn(*self.w_shape))[0].astype(np.float32) 45 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 46 | 47 | def get_weight(self, x, rev): 48 | b, c, h, w = x.shape 49 | dlogdet = torch.slogdet(self.weight)[1] * h * w # slogdet(A) = torch.log(torch.abs(torch.det(A))) 50 | 51 | if not rev: 52 | weight = self.weight 53 | else: 54 | weight = torch.inverse(self.weight) 55 | 56 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 57 | 58 | def forward(self, x, logdet=None, rev=False): 59 | """ 60 | log-det = log|abs(|W|)| * pixels 61 | """ 62 | logdet = 0.0 if logdet is None else logdet 63 | weight, dlogdet = self.get_weight(x, rev) 64 | z = F.conv2d(x, weight) 65 | if not rev: 66 | logdet = logdet + dlogdet 67 | else: 68 | logdet = logdet - dlogdet 69 | 70 | return z, logdet 71 | 72 | 73 | class Permute2d(nn.Module): 74 | def __init__(self, inCh, shuffle=True): 75 | super().__init__() 76 | self.inCh = inCh 77 | # self.indices = torch.arange(self.inCh - 1, -1, -1, dtype=torch.long) 78 | # self.indices_inverse = torch.zeros(self.inCh, dtype=torch.long) 79 | self.register_buffer('indices', torch.arange(self.inCh - 1, -1, -1, dtype=torch.long)) 80 | self.register_buffer('indices_inverse', torch.zeros(self.inCh, dtype=torch.long)) 81 | 82 | for i in range(self.inCh): 83 | self.indices_inverse[self.indices[i]] = i 84 | 85 | if shuffle: 86 | self.reset_indices() 87 | 88 | def reset_indices(self): 89 | shuffle_idx = torch.randperm(self.indices.shape[0]) 90 | self.indices = self.indices[shuffle_idx] 91 | 92 | for i in range(self.inCh): 93 | self.indices_inverse[self.indices[i]] = i 94 | 95 | def forward(self, x, logdet, rev=False): 96 | assert len(x.size()) == 4 97 | 98 | if not rev: 99 | x = x[:, self.indices, :, :] 100 | else: 101 | x = x[:, self.indices_inverse, :, :] 102 | 103 | return x, logdet 104 | 105 | 106 | class InvConvLU1x1(nn.Module): # is not recommended 107 | def __init__(self, inCh): 108 | super(InvConvLU1x1, self).__init__() 109 | w_shape = [inCh, inCh] 110 | w_init = torch.qr(torch.randn(*w_shape))[0] 111 | 112 | p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) 113 | s = torch.diag(upper) 114 | sign_s = torch.sign(s) 115 | log_s = torch.log(torch.abs(s)) 116 | upper = torch.triu(upper, 1) 117 | l_mask = torch.tril(torch.ones(w_shape), -1) 118 | eye = torch.eye(*w_shape) 119 | 120 | self.register_buffer('p', p) # .cuda() will work only on register_buffer 121 | self.register_buffer('sign_s', sign_s) 122 | self.register_buffer('l_mask', l_mask) 123 | self.register_buffer('eye', eye) 124 | 125 | self.lower = nn.Parameter(lower) 126 | self.log_s = nn.Parameter(log_s) 127 | self.upper = nn.Parameter(upper) 128 | 129 | self.w_shape = w_shape 130 | 131 | def get_weight(self, x, rev): 132 | b, c, h, w = x.shape 133 | 134 | lower = self.lower * self.l_mask + self.eye 135 | 136 | u = self.upper * self.l_mask.transpose(0, 1).contiguous() 137 | u += torch.diag(self.sign_s * torch.exp(self.log_s)) 138 | 139 | dlogdet = torch.sum(self.log_s) * h * w 140 | if not rev: 141 | weight = torch.matmul(self.p, torch.matmul(lower, u)) 142 | else: 143 | u_inv = torch.inverse(u) 144 | l_inv = torch.inverse(lower) 145 | p_inv = torch.inverse(self.p) 146 | 147 | weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) 148 | 149 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 150 | 151 | def forward(self, x, logdet=None, rev=False): 152 | """ 153 | 
log-det = log|abs(|W|)| * pixels 154 | """ 155 | weight, dlogdet = self.get_weight(x, rev) 156 | z = F.conv2d(x, weight) 157 | 158 | if not rev: 159 | logdet = logdet + dlogdet 160 | else: 161 | logdet = logdet - dlogdet 162 | return z, logdet 163 | 164 | 165 | if __name__ == '__main__': 166 | x = torch.randn([1, 3, 16, 16], requires_grad=True) 167 | conv = InvConv1x1(inCh=3) 168 | y, _ = conv(x, logdet=0.0, rev=False) 169 | y.retain_grad() 170 | loss = ((y - 1) ** 2).sum() 171 | loss.backward() 172 | 173 | grady = y.grad 174 | gradx = x.grad 175 | 176 | gradx2, _ = conv(grady, logdet=0.0, rev=True) 177 | 178 | print((gradx - gradx2).mean()) 179 | 180 | pass 181 | -------------------------------------------------------------------------------- /DBVI_2x/model/module_util.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import init 3 | 4 | 5 | def randomInitNet(net_l, iniType='kaiming', scale=1.0): 6 | if not isinstance(net_l, list): 7 | net_l = [net_l] 8 | for net in net_l: 9 | for m in net.modules(): 10 | if any([isinstance(m, nn.Conv2d), isinstance(m, nn.ConvTranspose2d), isinstance(m, nn.Linear)]): 11 | if iniType == 'normal': 12 | init.normal_(m.weight, 0.0, 0.2) 13 | elif iniType == 'xavier': 14 | init.xavier_normal_(m.weight, gain=0.2) 15 | elif iniType == 'kaiming': 16 | init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 17 | elif iniType == 'orthogonal': 18 | init.orthogonal_(m.weight, gain=0.2) 19 | elif iniType == 'default': 20 | pass 21 | 22 | if m.bias is not None: 23 | init.constant_(m.bias, 0.0) 24 | m.weight.data *= scale 25 | elif any([isinstance(m, nn.InstanceNorm2d), isinstance(m, nn.LocalResponseNorm), 26 | isinstance(m, nn.BatchNorm2d), isinstance(m, nn.GroupNorm)]): 27 | try: 28 | init.constant_(m.weight, 1.0) 29 | init.constant_(m.bias, 0.0) 30 | except Exception as e: 31 | pass 32 | -------------------------------------------------------------------------------- /DBVI_2x/output/output: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_2x/output/output -------------------------------------------------------------------------------- /DBVI_2x/runTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = 'Test' 4 | part = 'Pixel' 5 | 6 | freeNodes = ['SH-IDC1-10-5-39-55','SH-IDC1-10-5-31-54'] 7 | gpuDict = "\"{\'SH-IDC1-10-5-39-55\': \'0,1,2,3,4\', \'SH-IDC1-10-5-31-54\': \'0,1,2,3,4\'}\"" 8 | 9 | 10 | ntaskPerNode = 5 # number of GPUs per nodes 11 | cpus_per_task = 4 12 | reuseGPU = 1 13 | envDistributed = 1 14 | 15 | nodeNum = len(freeNodes) 16 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 17 | nodeList = ','.join(freeNodes) 18 | initNode = freeNodes[0] 19 | 20 | scrip = 'test' 21 | config = 'configTest' 22 | 23 | 24 | def runDist(): 25 | pyCode = [] 26 | pyCode.append('python') 27 | pyCode.append('-m') 28 | pyCode.append(scrip) 29 | pyCode.append('--initNode {}'.format(initNode)) 30 | pyCode.append('--config {}'.format(config)) 31 | pyCode.append('--gpuList {}'.format(gpuDict)) 32 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 33 | pyCode.append('--expName {}'.format(jobName)) 34 | pyCode = ' '.join(pyCode) 35 | 36 | srunCode = [] 37 | srunCode.append('srun') 38 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) 39 | 
srunCode.append('--job-name={}'.format(jobName)) 40 | srunCode.append('--partition={}'.format(part)) 41 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 42 | srunCode.append('--ntasks={}'.format(nTasks)) 43 | srunCode.append('--nodes={}'.format(nodeNum)) 44 | srunCode.append(f'--ntasks-per-node={ntaskPerNode}') if envDistributed else print( 45 | 'ntasks-per-node is 1') 46 | srunCode.append(f'--cpus-per-task={cpus_per_task}') 47 | srunCode.append('--kill-on-bad-exit=1') 48 | srunCode.append('--mpi=pmi2') 49 | srunCode.append(pyCode) 50 | 51 | srunCode = ' '.join(srunCode) 52 | print(srunCode) 53 | os.system(srunCode) 54 | 55 | if __name__ == '__main__': 56 | runDist() 57 | -------------------------------------------------------------------------------- /DBVI_2x/runTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = '2x_Vimeo' 4 | part = 'Pixel' 5 | 6 | freeNodes = ['SH-IDC1-10-5-39-55','SH-IDC1-10-5-31-54'] 7 | 8 | gpuDict = "\"{\'SH-IDC1-10-5-39-55\': \'0,1,2,3,4,5,6,7\', \'SH-IDC1-10-5-31-54\': \'0,1,2,3,4,5,6,7\'}\"" 9 | 10 | 11 | ntaskPerNode = 8 # number of GPUs per nodes 12 | cpus_per_task = 4 13 | reuseGPU = 1 14 | envDistributed = 1 15 | 16 | nodeNum = len(freeNodes) 17 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 18 | nodeList = ','.join(freeNodes) 19 | initNode = freeNodes[0] 20 | 21 | scrip = 'train' 22 | config = 'configTrain' 23 | 24 | 25 | def runDist(): 26 | pyCode = [] 27 | pyCode.append('python') 28 | pyCode.append('-m') 29 | pyCode.append(scrip) 30 | pyCode.append('--initNode {}'.format(initNode)) 31 | pyCode.append('--config {}'.format(config)) 32 | pyCode.append('--gpuList {}'.format(gpuDict)) 33 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 34 | pyCode.append('--expName {}'.format(jobName)) 35 | pyCode = ' '.join(pyCode) 36 | 37 | srunCode = [] 38 | srunCode.append('srun') 39 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) 40 | srunCode.append('--job-name={}'.format(jobName)) 41 | srunCode.append('--partition={}'.format(part)) 42 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 43 | srunCode.append('--ntasks={}'.format(nTasks)) 44 | srunCode.append('--nodes={}'.format(nodeNum)) 45 | srunCode.append(f'--ntasks-per-node={ntaskPerNode}') if envDistributed else print( 46 | 'ntasks-per-node is 1') 47 | srunCode.append(f'--cpus-per-task={cpus_per_task}') 48 | srunCode.append('--kill-on-bad-exit=1') 49 | srunCode.append('--mpi=pmi2') 50 | srunCode.append(pyCode) 51 | 52 | srunCode = ' '.join(srunCode) 53 | print(srunCode) 54 | 55 | os.system(srunCode) 56 | 57 | 58 | if __name__ == '__main__': 59 | runDist() 60 | -------------------------------------------------------------------------------- /DBVI_8x/dataloader/Adobe.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | from dataloader import cvtransform 6 | import torch.utils.data as data 7 | 8 | from dataloader.dataloaderBase import DistributedSamplerVali 9 | from configs.configTrain import configMain 10 | import pickle 11 | import lmdb 12 | from lib.visualTool import visImg 13 | 14 | 15 | class Test(data.Dataset): 16 | def __init__(self, cfg: configMain): 17 | super(Test, self).__init__() 18 | self.Sample = './datasets/Adobe/Adobe_lmdb/sample.pkl' 19 | self.LMDB = './datasets/Adobe/Adobe_lmdb/data.mdb' 
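# Layout note (inferred from __getitem__ and outKeys below): sample.pkl holds a list of per-sample dicts whose values ('In', 'I0', 'I1', 'I2', 'It1'..'It7') are LMDB keys, and data.mdb stores the corresponding pickled images.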
20 | self.numIter = cfg.numIter 21 | 22 | with open(self.Sample, 'rb') as fs: 23 | self.Sample = pickle.load(fs) 24 | 25 | self.length = len(self.Sample) # 630 26 | # self.length = 1 27 | self.transforms = cvtransform.Compose([ 28 | # cvtransform.CenterCrop(cfg.test.size), 29 | cvtransform.ToTensor() 30 | ]) 31 | self.env = None 32 | self.txn = None 33 | self.outKeys = ['In', 'I0', 'I1', 'I2', 'It1', 'It2', 'It3', 'It4', 'It5', 'It6', 'It7'] 34 | 35 | def __getitem__(self, idx): 36 | cv2.setNumThreads(0) 37 | cv2.ocl.setUseOpenCL(False) 38 | if any([self.txn is None, self.env is None]): 39 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 40 | meminit=False) 41 | self.txn = self.env.begin(write=False) 42 | 43 | sampleKeys = self.Sample[idx] 44 | 45 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 46 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 47 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 48 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 49 | 50 | valueList = [In, I0, I1, I2] 51 | watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']] 52 | 53 | gtList = [] 54 | for i in range(self.numIter): 55 | It = pickle.loads(self.txn.get(sampleKeys[f'It{i + 1}'].encode('ascii'))) 56 | valueList.append(It) 57 | gtList.append(sampleKeys[f'It{i + 1}']) 58 | valueList = self.transforms(valueList) 59 | 60 | outDict = {} 61 | for key, value in zip(self.outKeys, valueList): 62 | outDict[key] = value 63 | 64 | return outDict, watchList, gtList 65 | 66 | def __len__(self): 67 | return self.length 68 | 69 | 70 | def creatValiLoader(cfg: configMain): 71 | dataset = Test(cfg) 72 | 73 | # if cfg.dist.isDist: 74 | sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 75 | 76 | loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, 77 | num_workers=2, pin_memory=True, # False if memory is not enough 78 | drop_last=False, sampler=sampler) 79 | return sampler, loader 80 | 81 | 82 | if __name__ == "__main__": 83 | cfg = configMain() 84 | testSampler, testLoader = creatValiLoader(cfg) 85 | for valdict, watchList, gtList in testLoader: 86 | # visImg(valdict['In'], wait=100) 87 | # visImg(valdict['I0'], wait=100) 88 | # visImg(valdict['I1'], wait=100) 89 | # visImg(valdict['I2'], wait=100) 90 | 91 | visImg(valdict['I0'], wait=100) 92 | visImg(valdict['It1'], wait=100) 93 | visImg(valdict['It2'], wait=100) 94 | visImg(valdict['It3'], wait=100) 95 | visImg(valdict['It4'], wait=100) 96 | visImg(valdict['It5'], wait=100) 97 | visImg(valdict['It6'], wait=100) 98 | visImg(valdict['It7'], wait=100) 99 | # visImg(valdict['I1'], wait=100) 100 | pass 101 | -------------------------------------------------------------------------------- /DBVI_8x/dataloader/GoPro.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | from dataloader import cvtransform 6 | import torch 7 | import torch.utils.data as data 8 | import torch.nn.functional as F 9 | from torch.utils.data.distributed import DistributedSampler 10 | from dataloader.dataloaderBase import DistributedSamplerVali 11 | from configs.configTrain import configMain 12 | import pickle 13 | import lmdb 14 | import numpy as np 15 | 16 | 17 | # train------------------------------------------------------------------- 18 | class Train(data.Dataset): 19 | def 
__init__(self, cfg: configMain): 20 | self.Sample = './datasets/GoPro/gopro_train_lmdb/sample.pkl' 21 | self.LMDB = './datasets/GoPro/gopro_train_lmdb/data.mdb' 22 | 23 | self.numIter = cfg.numIter 24 | 25 | with open(self.Sample, 'rb') as fs: 26 | self.Sample = pickle.load(fs) 27 | 28 | self.length = len(self.Sample) #1500 29 | 30 | self.transforms = cvtransform.Compose([ 31 | cvtransform.RandomCrop(cfg.train.size), 32 | cvtransform.RandomHorizontalFlip(0.5), 33 | cvtransform.RandomHVerticalFlip(0.5), 34 | cvtransform.ColorJitter(0.05, 0.05, 0.05, 0.05), 35 | cvtransform.ToTensor() 36 | ]) 37 | self.env = None 38 | self.txn = None 39 | self.outKeys = ['In', 'I0', 'It', 'I1', 'I2'] 40 | 41 | def __getitem__(self, idx): 42 | cv2.setNumThreads(0) 43 | cv2.ocl.setUseOpenCL(False) 44 | if any([self.txn is None, self.env is None]): 45 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 46 | meminit=False) 47 | self.txn = self.env.begin(write=False) 48 | 49 | sampleKeys = self.Sample[idx] 50 | 51 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 52 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 53 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 54 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 55 | 56 | t = np.random.randint(1, self.numIter + 1) 57 | It = pickle.loads(self.txn.get(sampleKeys[f'It{t}'].encode('ascii'))) 58 | t = torch.tensor(t / float(self.numIter + 1.0), dtype=torch.float32) 59 | 60 | if np.random.rand() > 0.5: 61 | valueList = [In, I0, It, I1, I2] 62 | else: 63 | valueList = [I2, I1, It, I0, In] 64 | t = 1 - t 65 | 66 | valueList = self.transforms(valueList) 67 | 68 | outDict = {'t': t} 69 | for key, value in zip(self.outKeys, valueList): 70 | outDict[key] = value 71 | return outDict 72 | 73 | def __len__(self): 74 | return self.length 75 | 76 | def close(self): 77 | if self.env is not None: 78 | self.env.close() 79 | self.txn = None 80 | self.env = None 81 | 82 | 83 | def creatTrainLoader(cfg: configMain): 84 | dataset = Train(cfg) 85 | 86 | sampler = DistributedSampler(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 87 | 88 | loader = data.DataLoader(dataset=dataset, batch_size=cfg.train.batchPerGPU, 89 | shuffle=False, num_workers=4, pin_memory=True, # False if memory is not enough 90 | drop_last=True, sampler=sampler) 91 | return sampler, loader 92 | 93 | 94 | class celebAValidation(data.Dataset): 95 | def __init__(self, cfg: configMain): 96 | super(celebAValidation, self).__init__() 97 | self.Sample = './datasets/GoPro/gopro_test_lmdb/sample.pkl' 98 | # self.Sample = './datasets/GoPro/gopro_test_lmdb/sample4x.pkl' # for testing 4x interpolation 99 | self.LMDB = './datasets/GoPro/gopro_test_lmdb/data.mdb' 100 | self.numIter = cfg.numIter 101 | 102 | with open(self.Sample, 'rb') as fs: 103 | self.Sample = pickle.load(fs) 104 | self.length = len(self.Sample) 105 | self.transforms = cvtransform.Compose([ 106 | # cvtransform.CenterCrop(cfg.test.size), 107 | cvtransform.ToTensor() 108 | ]) 109 | self.env = None 110 | self.txn = None 111 | self.outKeys = ['In', 'I0', 'I1', 'I2', 'It1', 'It2', 'It3', 'It4', 'It5', 'It6', 'It7'] 112 | 113 | def __getitem__(self, idx): 114 | cv2.setNumThreads(0) 115 | cv2.ocl.setUseOpenCL(False) 116 | if any([self.txn is None, self.env is None]): 117 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 118 | meminit=False) 119 | self.txn = self.env.begin(write=False) 120 | 121 
| sampleKeys = self.Sample[idx] 122 | 123 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 124 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 125 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 126 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 127 | 128 | valueList = [In, I0, I1, I2] 129 | watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']] 130 | 131 | gtList = [] 132 | for i in range(self.numIter): 133 | It = pickle.loads(self.txn.get(sampleKeys[f'It{i + 1}'].encode('ascii'))) 134 | valueList.append(It) 135 | gtList.append(sampleKeys[f'It{i + 1}']) 136 | valueList = self.transforms(valueList) 137 | 138 | outDict = {} 139 | for key, value in zip(self.outKeys, valueList): 140 | outDict[key] = value 141 | 142 | return outDict, watchList, gtList 143 | 144 | def __len__(self): 145 | return self.length 146 | 147 | 148 | def creatValiLoader(cfg: configMain): 149 | dataset = celebAValidation(cfg) 150 | 151 | # if cfg.dist.isDist: 152 | sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 153 | 154 | loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, 155 | num_workers=2, pin_memory=True, # False if memory is not enough 156 | drop_last=False, sampler=sampler) 157 | return sampler, loader 158 | 159 | 160 | if __name__ == "__main__": 161 | cfg = configMain() 162 | testSampler, testLoader = creatTrainLoader(cfg) 163 | for valdict in testLoader: 164 | pass 165 | # visImg(valdict['In'], wait=100) 166 | # visImg(valdict['I0'], wait=100) 167 | # visImg(valdict['I1'], wait=100) 168 | # visImg(valdict['I2'], wait=100) 169 | # 170 | # visImg(valdict['I0'], wait=100) 171 | # visImg(valdict['It1'], wait=100) 172 | # visImg(valdict['It2'], wait=100) 173 | # visImg(valdict['It3'], wait=100) 174 | # visImg(valdict['It4'], wait=100) 175 | # visImg(valdict['It5'], wait=100) 176 | # visImg(valdict['It6'], wait=100) 177 | # visImg(valdict['It7'], wait=100) 178 | # visImg(valdict['I1'], wait=100) 179 | -------------------------------------------------------------------------------- /DBVI_8x/dataloader/XVFI.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch.nn.functional as F 3 | 4 | cv2.setNumThreads(0) 5 | cv2.ocl.setUseOpenCL(False) 6 | from dataloader import cvtransform 7 | import torch.utils.data as data 8 | from torch.utils.data.distributed import DistributedSampler 9 | from dataloader.dataloaderBase import DistributedSamplerVali 10 | from configs.configTrain import configMain 11 | import pickle 12 | import lmdb 13 | from lib.visualTool import visImg 14 | import torch 15 | import numpy as np 16 | 17 | 18 | class Train(data.Dataset): 19 | def __init__(self, cfg: configMain): 20 | self.LMDB = './datasets/X4K1000FPS/train_lmdb/data.mdb' 21 | self.Sample = './datasets/X4K1000FPS/train_lmdb/sample.pkl' 22 | 23 | self.numIter = cfg.numIter 24 | 25 | with open(self.Sample, 'rb') as fs: 26 | self.Sample = pickle.load(fs) 27 | 28 | self.length = len(self.Sample) # 1500 29 | 30 | self.transforms = cvtransform.Compose([ 31 | cvtransform.RandomCrop(cfg.train.size), 32 | cvtransform.RandomHorizontalFlip(0.5), 33 | cvtransform.RandomHVerticalFlip(0.5), 34 | cvtransform.ColorJitter(0.05, 0.05, 0.05, 0.05), 35 | cvtransform.ToTensor() 36 | ]) 37 | self.env = None 38 | self.txn = None 39 | self.outKeys = ['In', 'I0', 'It_1', 'I1', 'I2'] 40 | 41 | def __getitem__(self, idx): 42 | 
cv2.setNumThreads(0) 43 | cv2.ocl.setUseOpenCL(False) 44 | if any([self.txn is None, self.env is None]): 45 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 46 | meminit=False) 47 | self.txn = self.env.begin(write=False) 48 | 49 | sampleKeys = self.Sample[idx] 50 | 51 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 52 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 53 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 54 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 55 | 56 | t = np.random.randint(1, self.numIter + 1) 57 | It = pickle.loads(self.txn.get(sampleKeys[f'It{t}'].encode('ascii'))) 58 | t = torch.tensor(t / float(self.numIter + 1.0), dtype=torch.float32) 59 | 60 | if np.random.rand() > 0.5: 61 | valueList = [In, I0, It, I1, I2] 62 | else: 63 | valueList = [I2, I1, It, I0, In] 64 | t = 1 - t 65 | 66 | valueList = self.transforms(valueList) 67 | 68 | 69 | valueList = valueList 70 | 71 | outDict = {'t': t} 72 | for key, value in zip(self.outKeys, valueList): 73 | outDict[key] = value 74 | return outDict 75 | 76 | def __len__(self): 77 | return self.length 78 | 79 | def close(self): 80 | if self.env is not None: 81 | self.env.close() 82 | self.txn = None 83 | self.env = None 84 | 85 | 86 | def creatTrainLoader(cfg: configMain): 87 | dataset = Train(cfg) 88 | 89 | sampler = DistributedSampler(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 90 | 91 | loader = data.DataLoader(dataset=dataset, batch_size=cfg.train.batchPerGPU, 92 | shuffle=False, num_workers=4, pin_memory=True, # False if memory is not enough 93 | drop_last=True, sampler=sampler) 94 | return sampler, loader 95 | 96 | 97 | class Test(data.Dataset): 98 | def __init__(self, cfg: configMain): 99 | super(Test, self).__init__() 100 | self.Sample = './datasets/X4K1000FPS/test_lmdb/sample.pkl' 101 | self.LMDB = './datasets/X4K1000FPS/test_lmdb/data.mdb' 102 | self.numIter = cfg.numIter 103 | 104 | with open(self.Sample, 'rb') as fs: 105 | self.Sample = pickle.load(fs) 106 | 107 | self.length = len(self.Sample) 108 | 109 | self.transforms = cvtransform.Compose([ 110 | cvtransform.ToTensor() 111 | ]) 112 | self.env = None 113 | self.txn = None 114 | self.outKeys = ['In', 'I0', 'I1', 'I2', 'It1', 'It2', 'It3', 'It4', 'It5', 'It6', 'It7'] 115 | 116 | def __getitem__(self, idx): 117 | cv2.setNumThreads(0) 118 | cv2.ocl.setUseOpenCL(False) 119 | if any([self.txn is None, self.env is None]): 120 | self.env = lmdb.open(self.LMDB, subdir=False, readonly=True, lock=False, readahead=False, 121 | meminit=False) 122 | self.txn = self.env.begin(write=False) 123 | 124 | sampleKeys = self.Sample[idx] 125 | 126 | In = pickle.loads(self.txn.get(sampleKeys['In'].encode('ascii'))) 127 | I0 = pickle.loads(self.txn.get(sampleKeys['I0'].encode('ascii'))) 128 | I1 = pickle.loads(self.txn.get(sampleKeys['I1'].encode('ascii'))) 129 | I2 = pickle.loads(self.txn.get(sampleKeys['I2'].encode('ascii'))) 130 | 131 | valueList = [In, I0, I1, I2] 132 | watchList = [sampleKeys['In'], sampleKeys['I0'], sampleKeys['I1'], sampleKeys['I2']] 133 | gtList = [] 134 | 135 | for i in range(self.numIter): 136 | It = pickle.loads(self.txn.get(sampleKeys[f'It{i + 1}'].encode('ascii'))) 137 | valueList.append(It) 138 | gtList.append(sampleKeys[f'It{i + 1}']) 139 | valueList = self.transforms(valueList) 140 | 141 | outDict = {} 142 | for key, value in zip(self.outKeys, valueList): 143 | outDict[key] = value 144 | 145 | return outDict, 
watchList, gtList 146 | 147 | def __len__(self): 148 | return self.length 149 | 150 | 151 | def creatValiLoader(cfg: configMain): 152 | dataset = Test(cfg) 153 | 154 | # if cfg.dist.isDist: 155 | sampler = DistributedSamplerVali(dataset, num_replicas=cfg.dist.wordSize, rank=cfg.dist.gloRank) 156 | 157 | loader = data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, 158 | num_workers=2, pin_memory=False, # False if memory is not enough 159 | drop_last=False, sampler=sampler) 160 | return sampler, loader 161 | -------------------------------------------------------------------------------- /DBVI_8x/dataloader/dataUtils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def pixel_shuffle(input: torch.Tensor, scale_factor): 5 | batch_size, channels, in_height, in_width = input.size() 6 | 7 | out_channels = int(int(channels / scale_factor) / scale_factor) 8 | out_height = int(in_height * scale_factor) 9 | out_width = int(in_width * scale_factor) 10 | 11 | if scale_factor >= 1: 12 | input_view = input.contiguous().view(batch_size, out_channels, scale_factor, scale_factor, in_height, in_width) 13 | shuffle_out = input_view.permute(0, 1, 4, 2, 5, 3).contiguous() 14 | else: 15 | block_size = int(1 / scale_factor) 16 | input_view = input.contiguous().view(batch_size, channels, out_height, block_size, out_width, block_size) 17 | shuffle_out = input_view.permute(0, 1, 3, 5, 2, 4).contiguous() 18 | 19 | return shuffle_out.view(batch_size, out_channels, out_height, out_width) 20 | -------------------------------------------------------------------------------- /DBVI_8x/dataloader/dataloaderBase.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Sampler 2 | import torch.distributed as dist 3 | import math 4 | 5 | 6 | class DistributedSamplerVali(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | It is especially useful in conjunction with 10 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 11 | process can pass a DistributedSampler instance as a DataLoader sampler, 12 | and load a subset of the original dataset that is exclusive to it. 13 | 14 | .. note:: 15 | Dataset is assumed to be of constant size. 16 | 17 | Arguments: 18 | dataset: Dataset used for sampling. 19 | num_replicas (optional): Number of processes participating in 20 | distributed training. 21 | rank (optional): Rank of the current process within num_replicas. 
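Unlike torch's DistributedSampler, no shuffling is done here: __iter__ pads the index list up to num_samples * num_replicas and hands each rank the strided subset indices[rank::num_replicas], so the ranks jointly cover the whole validation set (the padding may repeat a few leading samples).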
22 | """ 23 | 24 | def __init__(self, dataset, num_replicas=None, rank=None): 25 | # super(DistributedSamplerVali, self).__init__() 26 | if num_replicas is None: 27 | if not dist.is_available(): 28 | raise RuntimeError("Requires distributed package to be available") 29 | num_replicas = dist.get_world_size() 30 | if rank is None: 31 | if not dist.is_available(): 32 | raise RuntimeError("Requires distributed package to be available") 33 | rank = dist.get_rank() 34 | self.dataset = dataset 35 | self.num_replicas = num_replicas 36 | self.rank = rank 37 | self.epoch = 0 38 | self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.num_replicas)) 39 | self.total_size = self.num_samples * self.num_replicas 40 | 41 | def __iter__(self): 42 | # deterministically shuffle based on epoch 43 | 44 | indices = list(range(len(self.dataset))) 45 | 46 | # add extra samples to make it evenly divisible 47 | indices += indices[:(self.total_size - len(indices))] 48 | assert len(indices) == self.total_size 49 | 50 | # subsample 51 | indices = indices[self.rank:self.total_size:self.num_replicas] 52 | assert len(indices) == self.num_samples 53 | 54 | return iter(indices) 55 | 56 | def __len__(self): 57 | return self.num_samples 58 | 59 | def set_epoch(self, epoch): 60 | self.epoch = epoch 61 | -------------------------------------------------------------------------------- /DBVI_8x/datasets/dataset: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/datasets/dataset -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/build/temp.linux-x86_64-3.6/build.ninja: -------------------------------------------------------------------------------- 1 | ninja_required_version = 1.3 2 | cxx = c++ 3 | nvcc = /mnt/lustre/yuzhiyang/cuda/cuda90_cudnn7501_pytorch11/bin/nvcc 4 | 5 | cflags = -pthread -B /mnt/lustre/yuzhiyang/anaconda3/envs/torch18/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/TH -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/THC -I/mnt/lustre/yuzhiyang/cuda/cuda90_cudnn7501_pytorch11/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/include/python3.6m -c 6 | post_cflags = -std=c++14 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1009"' -DTORCH_EXTENSION_NAME=alt_cuda_corr -D_GLIBCXX_USE_CXX11_ABI=0 7 | cuda_cflags = -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/torch/csrc/api/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/TH -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/lib/python3.6/site-packages/torch/include/THC -I/mnt/lustre/yuzhiyang/cuda/cuda90_cudnn7501_pytorch11/include -I/mnt/lustre/yuzhiyang/anaconda3/envs/torch18/include/python3.6m -c 8 | cuda_post_cflags = -D__CUDA_NO_HALF_OPERATORS__ -D__CUDA_NO_HALF_CONVERSIONS__ -D__CUDA_NO_BFLOAT16_CONVERSIONS__ 
-D__CUDA_NO_HALF2_OPERATORS__ --expt-relaxed-constexpr --compiler-options ''"'"'-fPIC'"'"'' -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -DTORCH_API_INCLUDE_EXTENSION_H '-DPYBIND11_COMPILER_TYPE="_gcc"' '-DPYBIND11_STDLIB="_libstdcpp"' '-DPYBIND11_BUILD_ABI="_cxxabi1009"' -DTORCH_EXTENSION_NAME=alt_cuda_corr -D_GLIBCXX_USE_CXX11_ABI=0 -std=c++14 9 | ldflags = 10 | 11 | rule compile 12 | command = $cxx -MMD -MF $out.d $cflags -c $in -o $out $post_cflags 13 | depfile = $out.d 14 | deps = gcc 15 | 16 | rule cuda_compile 17 | command = $nvcc $cuda_cflags -c $in -o $out $cuda_post_cflags 18 | 19 | 20 | 21 | build /mnt/lustre/yuzhiyang/research/IMLE/IMLE_VI/lib/RAFT/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o: compile /mnt/lustre/yuzhiyang/research/IMLE/IMLE_VI/lib/RAFT/alt_cuda_corr/correlation.cpp 22 | build /mnt/lustre/yuzhiyang/research/IMLE/IMLE_VI/lib/RAFT/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation_kernel.o: cuda_compile /mnt/lustre/yuzhiyang/research/IMLE/IMLE_VI/lib/RAFT/alt_cuda_corr/correlation_kernel.cu 23 | 24 | 25 | 26 | 27 | 28 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/lib/RAFT/alt_cuda_corr/build/temp.linux-x86_64-3.6/correlation.o -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | #include <torch/extension.h> 2 | #include <vector> 3 | 4 | // CUDA forward declarations 5 | std::vector<torch::Tensor> corr_cuda_forward( 6 | torch::Tensor fmap1, 7 | torch::Tensor fmap2, 8 | torch::Tensor coords, 9 | int radius); 10 | 11 | std::vector<torch::Tensor> corr_cuda_backward( 12 | torch::Tensor fmap1, 13 | torch::Tensor fmap2, 14 | torch::Tensor coords, 15 | torch::Tensor corr_grad, 16 | int radius); 17 | 18 | // C++ interface 19 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 20 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 21 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 22 | 23 | std::vector<torch::Tensor> corr_forward( 24 | torch::Tensor fmap1, 25 | torch::Tensor fmap2, 26 | torch::Tensor coords, 27 | int radius) { 28 | CHECK_INPUT(fmap1); 29 | CHECK_INPUT(fmap2); 30 | CHECK_INPUT(coords); 31 | 32 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 33 | } 34 | 35 | 36 | std::vector<torch::Tensor> corr_backward( 37 | torch::Tensor fmap1, 38 | torch::Tensor fmap2, 39 | torch::Tensor coords, 40 | torch::Tensor corr_grad, 41 | int radius) { 42 | CHECK_INPUT(fmap1); 43 | CHECK_INPUT(fmap2); 44 | CHECK_INPUT(coords); 45 | CHECK_INPUT(corr_grad); 46 | 47 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 48 | } 49 | 50 | 51 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 52 | m.def("forward", &corr_forward, "CORR forward"); 53 | m.def("backward", &corr_backward, "CORR backward"); 54 | } -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/correlation.egg-info/PKG-INFO: -------------------------------------------------------------------------------- 1 | Metadata-Version: 1.0 2 | Name: correlation 3 | Version: 0.0.0 4 | Summary: UNKNOWN 5 | 
Home-page: UNKNOWN 6 | Author: UNKNOWN 7 | Author-email: UNKNOWN 8 | License: UNKNOWN 9 | Description: UNKNOWN 10 | Platform: UNKNOWN 11 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/correlation.egg-info/SOURCES.txt: -------------------------------------------------------------------------------- 1 | correlation.cpp 2 | correlation_kernel.cu 3 | setup.py 4 | correlation.egg-info/PKG-INFO 5 | correlation.egg-info/SOURCES.txt 6 | correlation.egg-info/dependency_links.txt 7 | correlation.egg-info/top_level.txt -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/correlation.egg-info/dependency_links.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/correlation.egg-info/top_level.txt: -------------------------------------------------------------------------------- 1 | alt_cuda_corr 2 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | 5 | setup( 6 | name='correlation', 7 | ext_modules=[ 8 | CUDAExtension('alt_cuda_corr', 9 | sources=['correlation.cpp', 'correlation_kernel.cu'], 10 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) 15 | 16 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/corr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from lib.RAFT.utils import bilinear_sampler, coords_grid 4 | # from lib.RAFT import alt_cuda_corr 5 | 6 | 7 | class CorrBlock: 8 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 9 | self.num_levels = num_levels 10 | self.radius = radius 11 | self.corr_pyramid = [] 12 | 13 | # all pairs correlation 14 | corr = CorrBlock.corr(fmap1, fmap2) 15 | 16 | batch, h1, w1, dim, h2, w2 = corr.shape 17 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 18 | 19 | self.corr_pyramid.append(corr) 20 | for i in range(self.num_levels - 1): 21 | corr = F.avg_pool2d(corr, 2, stride=2) 22 | self.corr_pyramid.append(corr) 23 | 24 | def __call__(self, coords): 25 | r = self.radius 26 | coords = coords.permute(0, 2, 3, 1) 27 | batch, h1, w1, _ = coords.shape 28 | 29 | out_pyramid = [] 30 | for i in range(self.num_levels): 31 | corr = self.corr_pyramid[i] 32 | dx = torch.linspace(-r, r, 2 * r + 1) 33 | dy = torch.linspace(-r, r, 2 * r + 1) 34 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) 35 | 36 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2 ** i 37 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 38 | coords_lvl = centroid_lvl + delta_lvl 39 | 40 | corr = bilinear_sampler(corr, coords_lvl) 41 | corr = corr.view(batch, h1, w1, -1) 42 | out_pyramid.append(corr) 43 | 44 | out = torch.cat(out_pyramid, dim=-1) 45 | return out.permute(0, 3, 1, 2).contiguous().float() 46 | 47 | @staticmethod 48 | def corr(fmap1, fmap2): 49 | batch, dim, ht, wd = fmap1.shape 50 | fmap1 = fmap1.view(batch, dim, ht * wd) 51 | fmap2 = fmap2.view(batch, dim, ht * wd) 52 | 53 | corr = 
torch.matmul(fmap1.transpose(1, 2), fmap2) 54 | corr = corr.view(batch, ht, wd, 1, ht, wd) 55 | return corr / torch.sqrt(torch.tensor(dim).float()) 56 | 57 | # 58 | # class AlternateCorrBlock: 59 | # def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 60 | # self.num_levels = num_levels 61 | # self.radius = radius 62 | # 63 | # self.pyramid = [(fmap1, fmap2)] 64 | # for i in range(self.num_levels): 65 | # fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 66 | # fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 67 | # self.pyramid.append((fmap1, fmap2)) 68 | # 69 | # def __call__(self, coords): 70 | # coords = coords.permute(0, 2, 3, 1) 71 | # B, H, W, _ = coords.shape 72 | # dim = self.pyramid[0][0].shape[1] 73 | # 74 | # corr_list = [] 75 | # for i in range(self.num_levels): 76 | # r = self.radius 77 | # fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 78 | # fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 79 | # 80 | # coords_i = (coords / 2 ** i).reshape(B, 1, H, W, 2).contiguous() 81 | # corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 82 | # corr_list.append(corr.squeeze(1)) 83 | # 84 | # corr = torch.stack(corr_list, dim=1) 85 | # corr = corr.reshape(B, -1, H, W) 86 | # return corr / torch.sqrt(torch.tensor(dim).float()) 87 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/raft-things.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/lib/RAFT/raft-things.pth -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from lib.RAFT.update import BasicUpdateBlock, SmallUpdateBlock 6 | from lib.RAFT.extractor import BasicEncoder, SmallEncoder 7 | from lib.RAFT.corr import CorrBlock 8 | from lib.RAFT.utils import coords_grid, upflow8 9 | from collections import OrderedDict 10 | import cv2 11 | cv2.setNumThreads(0) 12 | cv2.ocl.setUseOpenCL(False) 13 | from lib.visualTool import visFlow 14 | import argparse 15 | from pathlib import Path 16 | 17 | 18 | class RAFParam(object): 19 | def __init__(self): 20 | super(RAFParam, self).__init__() 21 | self.alternate_corr=False 22 | self.mixed_precision=False 23 | self.small=False 24 | 25 | 26 | class RAFT(nn.Module): 27 | def __init__(self): 28 | super(RAFT, self).__init__() 29 | import warnings 30 | warnings.filterwarnings("ignore") 31 | self.args = RAFParam() 32 | 33 | if self.args.small: 34 | self.hidden_dim = hdim = 96 35 | self.context_dim = cdim = 64 36 | self.args.corr_levels = 4 37 | self.args.corr_radius = 3 38 | 39 | else: 40 | self.hidden_dim = hdim = 128 41 | self.context_dim = cdim = 128 42 | self.args.corr_levels = 4 43 | self.args.corr_radius = 4 44 | 45 | self.args.dropout = 0 46 | self.args.alternate_corr = False 47 | 48 | # feature network, context network, and update block 49 | if self.args.small: 50 | self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=self.args.dropout) 51 | self.cnet = SmallEncoder(output_dim=hdim + cdim, norm_fn='none', dropout=self.args.dropout) 52 | self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) 53 | 54 | else: 55 | self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=self.args.dropout) 56 | self.cnet = BasicEncoder(output_dim=hdim + 
cdim, norm_fn='batch', dropout=self.args.dropout) 57 | self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) 58 | 59 | pathPreWeight = str(Path(__file__).parent.absolute() / Path('raft-things.pth')) 60 | self.initPreweight(pathPreWeight) 61 | 62 | def freeze_bn(self): 63 | for m in self.modules(): 64 | if isinstance(m, nn.BatchNorm2d): 65 | m.eval() 66 | 67 | def initialize_flow(self, img): 68 | """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" 69 | N, C, H, W = img.shape 70 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 71 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 72 | 73 | # optical flow computed as difference: flow = coords1 - coords0 74 | return coords0, coords1 75 | 76 | def upsample_flow(self, flow, mask, scale=8, ksize=3): 77 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 78 | N, C, H, W = flow.shape 79 | mask = mask.view(N, 1, ksize ** 2, scale, scale, H, W) 80 | mask = torch.softmax(mask, dim=2) 81 | 82 | up_flow = F.unfold(8 * flow, [ksize, ksize], padding=1) 83 | up_flow = up_flow.view(N, C, ksize ** 2, 1, 1, H, W) 84 | 85 | up_flow = torch.sum(mask * up_flow, dim=2) 86 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 87 | return up_flow.reshape(N, 2, scale * H, scale * W) 88 | 89 | def forward(self, image1, image2, iters=12, flow_init=None, upsample=True, test_mode=False): 90 | """ Estimate optical flow between pair of frames """ 91 | image1 = image1.contiguous() 92 | image2 = image2.contiguous() 93 | 94 | hdim = self.hidden_dim 95 | cdim = self.context_dim 96 | 97 | fmap1, fmap2 = self.fnet([image1, image2]) 98 | 99 | fmap1 = fmap1.float() 100 | fmap2 = fmap2.float() 101 | 102 | corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) 103 | 104 | cnet = self.cnet(image1) 105 | net, inp = torch.split(cnet, [hdim, cdim], dim=1) 106 | net = torch.tanh(net) 107 | inp = torch.relu(inp) 108 | 109 | coords0, coords1 = self.initialize_flow(image1) 110 | 111 | if flow_init is not None: 112 | coords1 = coords1 + flow_init 113 | 114 | # flow_predictions = [] 115 | for itr in range(iters): 116 | coords1 = coords1.detach() 117 | corr = corr_fn(coords1) # index correlation volume 118 | 119 | flow = coords1 - coords0 120 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 121 | 122 | # F(t+1) = F(t) + \Delta(t) 123 | coords1 = coords1 + delta_flow 124 | 125 | # upsample predictions 126 | # if up_mask is None: 127 | # flow_up = upflow8(coords1 - coords0) 128 | # else: 129 | flow_up = self.upsample_flow(coords1 - coords0, up_mask) 130 | 131 | return flow_up 132 | 133 | def initPreweight(self, pathPreWeight: str = None, rmModule=True): 134 | preW = self.getWeight(pathPreWeight) 135 | assert preW is not None, 'weighth in {} is empty'.format(pathPreWeight) 136 | modelW = self.state_dict() 137 | preWDict = OrderedDict() 138 | # modelWDict = OrderedDict() 139 | 140 | for k, v in preW.items(): 141 | if rmModule: 142 | preWDict[k.replace('module.', "")] = v 143 | else: 144 | preWDict[k] = v 145 | 146 | shareW = {k: v for k, v in preWDict.items() if str(k) in modelW} 147 | assert shareW, 'shareW is empty' 148 | self.load_state_dict(preWDict, strict=False) 149 | 150 | @staticmethod 151 | def getWeight(pathPreWeight: str = None): 152 | if pathPreWeight is not None: 153 | return torch.load(pathPreWeight, map_location=torch.device('cpu')) 154 | else: 155 | return None 156 | 157 | def main(): 158 | model = RAFT() 159 | model.cuda() 160 | model.eval() 161 | 162 | with 
torch.no_grad(): 163 | image1 = cv2.imread( 164 | '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/test/GOPR0384_11_00/000009.png') 165 | image2 = cv2.imread( 166 | '/home/sensetime/data/VideoInterpolation/highfps/goPro/240fps/GoPro_public/test/GOPR0384_11_00/000017.png') 167 | 168 | image1 = torch.from_numpy(image1).float().permute([2, 0, 1]).unsqueeze(0)[:, [2, 1, 0], ...].cuda() 169 | image2 = torch.from_numpy(image2).float().permute([2, 0, 1]).unsqueeze(0)[:, [2, 1, 0], ...].cuda() 170 | 171 | flow_up = model(image1, image2, iters=20, test_mode=True)  # forward() returns only the upsampled flow 172 | visFlow(flow_up) 173 | 174 | 175 | if __name__ == '__main__': 176 | main() -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/update.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FlowHead(nn.Module): 7 | def __init__(self, input_dim=128, hidden_dim=256): 8 | super(FlowHead, self).__init__() 9 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 10 | self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) 11 | self.relu = nn.ReLU(inplace=True) 12 | 13 | def forward(self, x): 14 | return self.conv2(self.relu(self.conv1(x))) 15 | 16 | class ConvGRU(nn.Module): 17 | def __init__(self, hidden_dim=128, input_dim=192+128): 18 | super(ConvGRU, self).__init__() 19 | self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 20 | self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 21 | self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) 22 | 23 | def forward(self, h, x): 24 | hx = torch.cat([h, x], dim=1) 25 | 26 | z = torch.sigmoid(self.convz(hx)) 27 | r = torch.sigmoid(self.convr(hx)) 28 | q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) 29 | 30 | h = (1-z) * h + z * q 31 | return h 32 | 33 | class SepConvGRU(nn.Module): 34 | def __init__(self, hidden_dim=128, input_dim=192+128): 35 | super(SepConvGRU, self).__init__() 36 | self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 37 | self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 38 | self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) 39 | 40 | self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 41 | self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 42 | self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) 43 | 44 | 45 | def forward(self, h, x): 46 | # horizontal 47 | hx = torch.cat([h, x], dim=1) 48 | z = torch.sigmoid(self.convz1(hx)) 49 | r = torch.sigmoid(self.convr1(hx)) 50 | q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) 51 | h = (1-z) * h + z * q 52 | 53 | # vertical 54 | hx = torch.cat([h, x], dim=1) 55 | z = torch.sigmoid(self.convz2(hx)) 56 | r = torch.sigmoid(self.convr2(hx)) 57 | q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) 58 | h = (1-z) * h + z * q 59 | 60 | return h 61 | 62 | class SmallMotionEncoder(nn.Module): 63 | def __init__(self, args): 64 | super(SmallMotionEncoder, self).__init__() 65 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 66 | self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) 67 | self.convf1 = nn.Conv2d(2, 64, 7, padding=3) 68 | self.convf2 = nn.Conv2d(64, 32, 3, padding=1) 69 | self.conv = nn.Conv2d(128, 80, 3, padding=1) 70 | 71 | def forward(self, flow, corr): 72 | cor = 
F.relu(self.convc1(corr)) 73 | flo = F.relu(self.convf1(flow)) 74 | flo = F.relu(self.convf2(flo)) 75 | cor_flo = torch.cat([cor, flo], dim=1) 76 | out = F.relu(self.conv(cor_flo)) 77 | return torch.cat([out, flow], dim=1) 78 | 79 | class BasicMotionEncoder(nn.Module): 80 | def __init__(self, args): 81 | super(BasicMotionEncoder, self).__init__() 82 | cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 83 | self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) 84 | self.convc2 = nn.Conv2d(256, 192, 3, padding=1) 85 | self.convf1 = nn.Conv2d(2, 128, 7, padding=3) 86 | self.convf2 = nn.Conv2d(128, 64, 3, padding=1) 87 | self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) 88 | 89 | def forward(self, flow, corr): 90 | cor = F.relu(self.convc1(corr)) 91 | cor = F.relu(self.convc2(cor)) 92 | flo = F.relu(self.convf1(flow)) 93 | flo = F.relu(self.convf2(flo)) 94 | 95 | cor_flo = torch.cat([cor, flo], dim=1) 96 | out = F.relu(self.conv(cor_flo)) 97 | return torch.cat([out, flow], dim=1) 98 | 99 | class SmallUpdateBlock(nn.Module): 100 | def __init__(self, args, hidden_dim=96): 101 | super(SmallUpdateBlock, self).__init__() 102 | self.encoder = SmallMotionEncoder(args) 103 | self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) 104 | self.flow_head = FlowHead(hidden_dim, hidden_dim=128) 105 | 106 | def forward(self, net, inp, corr, flow): 107 | motion_features = self.encoder(flow, corr) 108 | inp = torch.cat([inp, motion_features], dim=1) 109 | net = self.gru(net, inp) 110 | delta_flow = self.flow_head(net) 111 | 112 | return net, None, delta_flow 113 | 114 | class BasicUpdateBlock(nn.Module): 115 | def __init__(self, args, hidden_dim=128, input_dim=128): 116 | super(BasicUpdateBlock, self).__init__() 117 | self.args = args 118 | self.encoder = BasicMotionEncoder(args) 119 | self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) 120 | self.flow_head = FlowHead(hidden_dim, hidden_dim=256) 121 | 122 | self.mask = nn.Sequential( 123 | nn.Conv2d(128, 256, 3, padding=1), 124 | nn.ReLU(inplace=True), 125 | nn.Conv2d(256, 64*9, 1, padding=0)) 126 | 127 | def forward(self, net, inp, corr, flow, upsample=True): 128 | motion_features = self.encoder(flow, corr) 129 | inp = torch.cat([inp, motion_features], dim=1) 130 | 131 | net = self.gru(net, inp) 132 | delta_flow = self.flow_head(net) 133 | 134 | # scale mask to balence gradients 135 | mask = .25 * self.mask(net) 136 | return net, mask, delta_flow 137 | 138 | 139 | 140 | -------------------------------------------------------------------------------- /DBVI_8x/lib/RAFT/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | from scipy import interpolate 5 | 6 | 7 | class InputPadder: 8 | """ Pads images such that dimensions are divisible by 8 """ 9 | def __init__(self, dims, mode='sintel'): 10 | self.ht, self.wd = dims[-2:] 11 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 12 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 13 | if mode == 'sintel': 14 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 15 | else: 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 17 | 18 | def pad(self, *inputs): 19 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 20 | 21 | def unpad(self,x): 22 | ht, wd = x.shape[-2:] 23 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 24 | return x[..., c[0]:c[1], c[2]:c[3]] 25 | 26 | def forward_interpolate(flow): 27 | 
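# Forward-splats a flow field: each source pixel (x0, y0) is moved by its flow
# vector to (x0 + dx, y0 + dy), and the scattered vectors are interpolated back
# onto the regular grid with nearest-neighbour griddata. In RAFT this is
# typically used to warm-start the next frame pair's estimate (see the
# flow_init argument of RAFT.forward in raft.py).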
flow = flow.detach().cpu().numpy() 28 | dx, dy = flow[0], flow[1] 29 | 30 | ht, wd = dx.shape 31 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 32 | 33 | x1 = x0 + dx 34 | y1 = y0 + dy 35 | 36 | x1 = x1.reshape(-1) 37 | y1 = y1.reshape(-1) 38 | dx = dx.reshape(-1) 39 | dy = dy.reshape(-1) 40 | 41 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 42 | x1 = x1[valid] 43 | y1 = y1[valid] 44 | dx = dx[valid] 45 | dy = dy[valid] 46 | 47 | flow_x = interpolate.griddata( 48 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 49 | 50 | flow_y = interpolate.griddata( 51 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 52 | 53 | flow = np.stack([flow_x, flow_y], axis=0) 54 | return torch.from_numpy(flow).float() 55 | 56 | 57 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 58 | """ Wrapper for grid_sample, uses pixel coordinates """ 59 | H, W = img.shape[-2:] 60 | xgrid, ygrid = coords.split([1,1], dim=-1) 61 | xgrid = 2*xgrid/(W-1) - 1 62 | ygrid = 2*ygrid/(H-1) - 1 63 | 64 | grid = torch.cat([xgrid, ygrid], dim=-1) 65 | img = F.grid_sample(img, grid, align_corners=True) 66 | 67 | if mask: 68 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 69 | return img, mask.float() 70 | 71 | return img 72 | 73 | 74 | def coords_grid(batch, ht, wd): 75 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 76 | coords = torch.stack(coords[::-1], dim=0).float() 77 | return coords[None].repeat(batch, 1, 1, 1) 78 | 79 | 80 | def upflow8(flow, mode='bilinear'): 81 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 82 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 83 | -------------------------------------------------------------------------------- /DBVI_8x/lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/lib/__init__.py -------------------------------------------------------------------------------- /DBVI_8x/lib/checkTool.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | cv2.setNumThreads(0) 4 | cv2.ocl.setUseOpenCL(False) 5 | import torch 6 | import torch.nn.functional as F 7 | from lib import fileTool as FT 8 | import torch 9 | 10 | 11 | def checkGrad(net): 12 | for parem in list(net.named_parameters()): 13 | if parem[1].grad is not None: 14 | print(parem[0] + ' \t shape={}, \t mean={}, \t std={}\n'.format(parem[1].shape, 15 | parem[1].grad.abs().mean().cpu().item(), 16 | parem[1].grad.abs().std().cpu().item())) 17 | 18 | 19 | def write_video_cv2(allFrames, video_name, fps, sizes): 20 | out = cv2.VideoWriter(video_name, cv2.CAP_OPENCV_MJPEG, cv2.VideoWriter_fourcc('M', 'J', 'P', 'G'), fps, sizes) 21 | 22 | for outF in allFrames: 23 | # frameIn = cv2.imread(inF, cv2.IMREAD_COLOR) 24 | frameOut = cv2.imread(outF, cv2.IMREAD_COLOR) 25 | # frame = np.concatenate([frameIn, frameOut], axis=1) 26 | out.write(frameOut) 27 | out.release() 28 | 29 | 30 | if __name__ == '__main__': 31 | videoPath = '/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40' 32 | 33 | allFrames = FT.getAllFiles(videoPath, 'png') 34 | inFrames = [a for a in allFrames if 'EVI' not in a] 35 | inFrames = [[a, a, a, a] for a in inFrames] 36 | inFrames = [item for sublist in inFrames for item in sublist] 37 | 38 | videoName = 
'/home/sensetime/data/ICCV2021/OurResults/slomoDVS34_16/Ours_S2/slomoDVS-2021_02_24_11_48_40/couple.avi' 39 | write_video_cv2(inFrames, videoName, 24, (480, 176)) 40 | -------------------------------------------------------------------------------- /DBVI_8x/lib/distTool.py: -------------------------------------------------------------------------------- 1 | import torch.distributed as dist 2 | import os 3 | from configs.configTrain import configMain 4 | 5 | 6 | def synchronize(): 7 | """ 8 | Helper function to synchronize (barrier) among all processes when 9 | using distributed training 10 | """ 11 | if not dist.is_available(): 12 | return 13 | if not dist.is_initialized(): 14 | return 15 | world_size = dist.get_world_size() 16 | if world_size == 1: 17 | return 18 | dist.barrier() 19 | 20 | 21 | def get_world_size(): 22 | if not dist.is_available(): 23 | return 1 24 | if not dist.is_initialized(): 25 | return 1 26 | return dist.get_world_size() 27 | 28 | 29 | def reduceTensorMean(tensor): 30 | worldSize = os.environ['SLURM_NTASKS'] if 'SLURM_NTASKS' in os.environ else 1 31 | rt = tensor.clone() 32 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 33 | rt /= float(worldSize) 34 | return rt 35 | 36 | def reduceTensorSum(tensor): 37 | rt = tensor.clone() 38 | dist.all_reduce(rt, op=dist.ReduceOp.SUM) 39 | return rt 40 | -------------------------------------------------------------------------------- /DBVI_8x/lib/dlTool.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torch.optim import lr_scheduler 3 | from configs.configTrain import configMain 4 | import math 5 | import torch 6 | from sklearn.mixture import GaussianMixture 7 | import numpy as np 8 | from torch.optim import Adam 9 | 10 | 11 | def getOptimizer(gNet: nn.Module, cfg: configMain): 12 | core_network_params = [] 13 | map_network_params = [] 14 | for k, v in gNet.named_parameters(): 15 | if v.requires_grad: 16 | if 'mapping' in k: 17 | map_network_params.append(v) 18 | else: 19 | core_network_params.append(v) 20 | optim = Adam([{'params': core_network_params, 'lr': cfg.optim.lrInit}, 21 | {'params': map_network_params, 'lr': 1e-2 * cfg.optim.lrInit}], 22 | weight_decay=0, betas=(0.9, 0.999)) 23 | for group in optim.param_groups: 24 | group.setdefault('initial_lr', group['lr']) 25 | maxPSNR = -1 26 | return optim, maxPSNR 27 | 28 | 29 | def getScheduler(optimizer, epoch=-1): 30 | return lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=[80, 120, 160], gamma=0.4, last_epoch=epoch) 31 | 32 | 33 | def compute_same_pad(kernel_size, stride): 34 | if isinstance(kernel_size, int): 35 | kernel_size = [kernel_size] 36 | 37 | if isinstance(stride, int): 38 | stride = [stride] 39 | 40 | assert len(stride) == len( 41 | kernel_size 42 | ), "Pass kernel size and stride both as int, or both as equal length iterable" 43 | 44 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel_size, stride)] 45 | 46 | 47 | def fitFullConvGaussian(zAll: torch.Tensor): 48 | zAllNpy = zAll.detach().cpu().numpy() 49 | selectIdx = np.random.choice(zAllNpy.shape[0], size=100, replace=False) 50 | zAllNpySellect = zAllNpy[selectIdx, ...] 
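# Fit a 10-component, full-covariance Gaussian mixture to the 100 randomly
# selected latent codes; the fitted GMM approximates the empirical latent
# distribution (e.g. for drawing new latent samples). Note that fit() expects
# a 2-D (n_samples, n_features) array.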
51 | gmm = GaussianMixture(n_components=10, covariance_type='full').fit(zAllNpySellect) 52 | return gmm 53 | -------------------------------------------------------------------------------- /DBVI_8x/lib/fileTool.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | from pathlib import Path 3 | from queue import Queue 4 | 5 | 6 | def delPath(root): 7 | """ 8 | remove dir trees 9 | :param path: root dir 10 | :return: True or False 11 | """ 12 | root = Path(root) 13 | if root.is_file(): 14 | try: 15 | root.unlink() 16 | except Exception as e: 17 | print(e) 18 | elif root.is_dir(): 19 | for item in root.iterdir(): 20 | delPath(item) 21 | try: 22 | root.rmdir() 23 | # print('Files in {} is removed'.format(root)) 24 | except Exception as e: 25 | print(e) 26 | 27 | 28 | def mkPath(path): 29 | p = Path(path) 30 | try: 31 | p.mkdir(parents=True, exist_ok=False) 32 | return True 33 | except Exception as e: 34 | return False 35 | 36 | 37 | def getAllFiles(root, ext=None): 38 | p = Path(root) 39 | if ext is not None: 40 | pathnames = p.glob("**/*.{}".format(ext)) 41 | 42 | else: 43 | pathnames = p.glob("**/*") 44 | filenames = sorted([x.as_posix() for x in pathnames]) 45 | return filenames 46 | 47 | 48 | def copyFile(src, dst): 49 | parent = Path(dst).parent 50 | mkPath(parent) 51 | try: 52 | shutil.copytree(str(src), str(dst)) 53 | except: 54 | shutil.copy(str(src), str(dst)) 55 | 56 | 57 | def movFile(src, dst): 58 | parent = Path(dst).parent 59 | mkPath(parent) 60 | shutil.move(str(src), str(dst)) 61 | 62 | 63 | def getSubDirs(root): 64 | p = Path(root) 65 | dirs = [x for x in p.iterdir() if x.is_dir()] 66 | return sorted(dirs) 67 | 68 | 69 | class fileBuffer(Queue): 70 | def __init__(self, capacity): 71 | super(fileBuffer, self).__init__() 72 | assert capacity > 0 73 | self.capacity = capacity 74 | 75 | def __call__(self, x): 76 | while self.qsize() >= self.capacity: 77 | delPath(self.get()) 78 | self.put(x) 79 | 80 | 81 | if __name__ == '__main__': 82 | path = '/home/sensetime/project/lib/a' 83 | # mkPath(path) 84 | filenames = delPath(path) 85 | pass 86 | -------------------------------------------------------------------------------- /DBVI_8x/lib/lossTool.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import numpy as np 6 | import scipy 7 | import torch 8 | from torch.nn.functional import avg_pool2d, pad 9 | from torch import pow, sqrt, clamp, mean 10 | from torch.nn import Module, L1Loss 11 | from torch import nn 12 | 13 | 14 | class l1Loss(nn.Module): 15 | def __init__(self, reduction='mean'): 16 | super(l1Loss, self).__init__() 17 | self.reduction = reduction 18 | 19 | def forward(self, pred, gt, mask=None): 20 | if mask is None: 21 | return F.l1_loss(pred, gt, reduction=self.reduction) 22 | else: 23 | return F.l1_loss(pred * mask, gt * mask, reduction=self.reduction) 24 | 25 | def __repr__(self): 26 | return 'l1Loss' 27 | 28 | 29 | class l2Loss(nn.Module): 30 | def __init__(self, reduction='mean'): 31 | super(l2Loss, self).__init__() 32 | self.reduction = reduction 33 | 34 | def forward(self, pred, gt, mask=None): 35 | if mask is None: 36 | return F.mse_loss(pred, gt, reduction=self.reduction) 37 | else: 38 | return F.mse_loss(pred * mask, gt * mask, reduction=self.reduction) 39 | 40 | def __repr__(self): 41 | return 'l2Loss' 42 | 43 | 44 | class totalLoss(nn.Module): 45 | def __init__(self): 46 | 
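# Weighted multi-scale L1: forward() compares each predicted RGB level with
# its ground truth, weighted by self.scale, so the finest predictions
# (weight 1) dominate and coarser early outputs contribute less.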
super(totalLoss, self).__init__() 47 | self.l1loss = l1Loss() 48 | self.scale = [0.2, 0.4, 1, 1] 49 | self.lossNames = ['l1'] 50 | self.lossDict = {} 51 | 52 | def forward(self, rgbs, gts): 53 | self.lossDict.clear() 54 | l1loss = 0 55 | 56 | if len(rgbs) > 1: 57 | assert all([len(self.scale) == len(rgbs), len(rgbs) == len(gts)]) 58 | scales = self.scale 59 | else: 60 | scales = self.scale[-len(rgbs)::] 61 | for scale, rgb, gt in zip(scales, rgbs, gts): 62 | l1loss += self.l1loss(rgb, gt.detach()).mean() * scale * 1 63 | 64 | lossSum = l1loss 65 | 66 | self.lossDict.setdefault('l1', l1loss.detach()) 67 | self.lossDict.setdefault('Total', lossSum.detach()) 68 | 69 | return lossSum 70 | 71 | 72 | if __name__ == '__main__': 73 | a = torch.randn([1, 3, 128, 128]) 74 | b = torch.randn([1, 3, 128, 128]) 75 | mask = torch.ones_like(a) 76 | 77 | loss = totalLoss() 78 | 79 | lossOut = loss(a, b, mask) 80 | lossOut2 = loss(a, b, mask) 81 | pass 82 | -------------------------------------------------------------------------------- /DBVI_8x/lib/videoTool.py: -------------------------------------------------------------------------------- 1 | import os 2 | import lib.fileTool as FT 3 | 4 | ffmpegPath = '/usr/bin/ffmpeg' 5 | videoName = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/bbb.MP4' 6 | 7 | 8 | def video2Frame(vPath: str, fdir: str, H: int = None, W: int = None): 9 | FT.mkPath(fdir) 10 | if H is None or W is None: 11 | os.system('{} -y -i {} -vsync 0 -qscale:v 2 {}/%04d.png'.format(ffmpegPath, vPath, fdir)) 12 | else: 13 | os.system('{} -y -i {} -vf scale={}:{} -vsync 0 -qscale:v 2 {}/%04d.jpg'.format(ffmpegPath, vPath, W, H, fdir)) 14 | 15 | 16 | def frame2Video(fdir: str, vPath: str, fps: int, H: int = None, W: int = None, ): 17 | if H is None or W is None: 18 | # os.system('{} -y -r {} -f image2 -i {}/%*.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 19 | # .format(ffmpegPath, fps, fdir, vPath)) 20 | 21 | os.system('{} -y -r {} -f image2 -i {}/%4d.png -vcodec libx264 -crf 18 -pix_fmt yuv420p {}' 22 | .format(ffmpegPath, fps, fdir, vPath)) 23 | else: 24 | os.system('{} -y -r {} -f image2 -s {}x{} -i {}/%*.png -vcodec libx264 -crf 25 -pix_fmt yuv420p {}' 25 | .format(ffmpegPath, fps, W, H, fdir, vPath)) 26 | 27 | 28 | def slomo(vPath: str, dstPath: str, fps): 29 | os.system( 30 | '{} -y -r {} -i {} -strict -2 -vcodec libx264 -c:a aac -crf 18 {}'.format(ffmpegPath, fps, vPath, dstPath)) 31 | 32 | 33 | def downFPS(vPath: str, dstPath: str, fps): 34 | os.system( 35 | '{} -i {} -strict -2 -r {} {}'.format(ffmpegPath, vPath, fps, dstPath)) 36 | 37 | 38 | def downSample(vPath: str, dstPath: str, H, W): 39 | os.system( 40 | '{} -i {} -strict -2 -s {}x{} {}'.format(ffmpegPath, vPath, H, W, dstPath)) 41 | 42 | 43 | if __name__ == '__main__': 44 | framePath = '/data/2021_12_23/dstframes' 45 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/gopro_yzy/output' 46 | video = '/data/2021_12_23/02.mp4' 47 | # video2Frame(video, framePath) 48 | 49 | # video = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00/out.mp4' 50 | # framePath = '/home/sensetime/data/VideoInterpolation/highfps/goPro_240fps/train/GOPR0372_07_00' 51 | frame2Video(framePath, video, 30) 52 | 53 | # vPath = '/media/sensetime/Elements/0721 /0716_video/1.avi' 54 | # dstPath = '/media/sensetime/Elements/0721/0716_video/1_.mp4' 55 | # downFPS(vPath, dstPath, 8) 56 | -------------------------------------------------------------------------------- /DBVI_8x/lib/visualTool.py: 
-------------------------------------------------------------------------------- 1 | import cv2 2 | cv2.setNumThreads(0) 3 | cv2.ocl.setUseOpenCL(False) 4 | import torch 5 | import numpy as np 6 | from pathlib import Path 7 | from lib import fileTool as FT 8 | 9 | 10 | class visFlow(): 11 | def __init__(self, flow_uv): 12 | super(visFlow, self).__init__() 13 | self.colorWheel = self.make_colorwheel() 14 | self.run(flow_uv) 15 | 16 | def run(self, flow_uv: torch.Tensor, clip_flow=None, convert_to_bgr=False): 17 | flow_uv = flow_uv[0].permute([1, 2, 0]).detach().cpu().numpy() 18 | if clip_flow is not None: 19 | flow_uv = np.clip(flow_uv, 0, clip_flow) 20 | u = flow_uv[:, :, 0] 21 | v = flow_uv[:, :, 1] 22 | rad = np.sqrt(np.square(u) + np.square(v)) 23 | rad_max = np.max(rad) 24 | epsilon = 1e-5 25 | u = u / (rad_max + epsilon) 26 | v = v / (rad_max + epsilon) 27 | rgb = self.flow_uv_to_colors(u, v, convert_to_bgr) 28 | cv2.namedWindow('flow', 0) 29 | cv2.imshow('flow', rgb[:, :, [2, 1, 0]]) 30 | cv2.waitKey(0) 31 | 32 | def make_colorwheel(self): 33 | RY, YG, GC, CB, BM, MR = (15, 6, 4, 11, 13, 6) 34 | ncols = RY + YG + GC + CB + BM + MR 35 | colorwheel = np.zeros((ncols, 3)) 36 | col = 0 37 | 38 | # RY 39 | colorwheel[0:RY, 0] = 255 40 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 41 | col = col + RY 42 | # YG 43 | colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 44 | colorwheel[col:col + YG, 1] = 255 45 | col = col + YG 46 | # GC 47 | colorwheel[col:col + GC, 1] = 255 48 | colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 49 | col = col + GC 50 | # CB 51 | colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 52 | colorwheel[col:col + CB, 2] = 255 53 | col = col + CB 54 | # BM 55 | colorwheel[col:col + BM, 2] = 255 56 | colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 57 | col = col + BM 58 | # MR 59 | colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 60 | colorwheel[col:col + MR, 0] = 255 61 | return colorwheel 62 | 63 | def flow_uv_to_colors(self, u, v, convert_to_bgr=False): 64 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 65 | ncols = self.colorWheel.shape[0] 66 | rad = np.sqrt(np.square(u) + np.square(v)) 67 | a = np.arctan2(-v, -u) / np.pi 68 | fk = (a + 1) / 2 * (ncols - 1) 69 | k0 = np.floor(fk).astype(np.int32) 70 | k1 = k0 + 1 71 | k1[k1 == ncols] = 0 72 | f = fk - k0 73 | for i in range(self.colorWheel.shape[1]): 74 | tmp = self.colorWheel[:, i] 75 | col0 = tmp[k0] / 255.0 76 | col1 = tmp[k1] / 255.0 77 | col = (1 - f) * col0 + f * col1 78 | idx = (rad <= 1) 79 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 80 | col[~idx] = col[~idx] * 0.75 # out of range 81 | # Note the 2-i => BGR instead of RGB 82 | ch_idx = 2 - i if convert_to_bgr else i 83 | flow_image[:, :, ch_idx] = np.floor(255 * col) 84 | return flow_image 85 | 86 | 87 | class tensor2Video(object): 88 | def __init__(self, outPath, h, w, fps=24): 89 | super(tensor2Video, self).__init__() 90 | fourcc = cv2.VideoWriter.fourcc('I', '4', '2', '0') 91 | self.out = cv2.VideoWriter(outPath, fourcc, fps, (w, h)) 92 | 93 | def add(self, frame: torch.Tensor): 94 | frame = (frame + 1.0) / 2.0 95 | # frame = (frame - frame.min()) / (frame.max() - frame.min()) 96 | frame = (frame[0].permute([1, 2, 0]).cpu().numpy() * 255).astype(np.uint8) 97 | self.out.write(frame) 98 | 99 | def release(self): 100 | self.out.release() 101 | 102 | 103 | def makeGrid(batchImg: torch.Tensor, shape=(2, 1)): 104 | N, C, H, W = 
batchImg.shape 105 | batchImg = batchImg.permute([0, 2, 3, 1]) 106 | batchImg = batchImg.detach().cpu().numpy() 107 | nh = shape[0] 108 | nw = shape[1] 109 | batchImg = batchImg.reshape((nh, nw, H, W, C)).swapaxes(1, 2).reshape(nh * H, nw * W, C) 110 | return batchImg 111 | 112 | 113 | def visImg(batchImg: torch.Tensor, shape=(1, 1), wait=0, name='visImg'): 114 | """ 115 | :param img: tensor(N,3,H,W) or None 116 | :return: None 117 | """ 118 | N, C, H, W = batchImg.shape 119 | assert all([C == 3, N == shape[0] * shape[1]]) 120 | # batchImg = ((batchImg - batchImg.min()) / (batchImg.max() - batchImg.min()) * 255.0).byte() 121 | batchImg = ((batchImg.float() + 1) / 2.0 * 255).clamp(0, 255).byte() 122 | batchImgViz = makeGrid(batchImg, shape) 123 | cv2.namedWindow(name, 0) 124 | cv2.imshow(name, batchImgViz[:, :, ::-1]) 125 | cv2.waitKey(wait) 126 | 127 | 128 | def saveImg(x: torch.Tensor, srcName: str, dstDir: str, isInter=False): 129 | if isInter: 130 | dstName = str(Path(dstDir) / srcName.replace('.png', '_inter.png')) 131 | else: 132 | dstName = str(Path(dstDir) / srcName) 133 | 134 | if Path(dstName).is_file(): 135 | return False 136 | if not Path(Path(dstName).parent).is_dir(): 137 | FT.mkPath(Path(dstName).parent) 138 | 139 | xRGB = (x.clamp(-1, 1) + 1.0) / 2.0 140 | xRGB = (xRGB[0].permute([1, 2, 0]).detach().cpu() * 255).byte().numpy() 141 | xBGR = cv2.cvtColor(xRGB, cv2.COLOR_RGB2BGR) 142 | cv2.imwrite(dstName, xBGR) 143 | return True 144 | 145 | 146 | def saveTensor(x: torch.Tensor, srcName: str, dstDir: str, isInter=False): 147 | if isInter: 148 | dstName = str(Path(dstDir) / srcName.replace('.png', '_inter.pth')) 149 | else: 150 | dstName = str(Path(dstDir) / srcName.replace('.png', '.pth')) 151 | 152 | if Path(dstName).is_file(): 153 | return False 154 | if not Path(Path(dstName).parent).is_dir(): 155 | FT.mkPath(Path(dstName).parent) 156 | 157 | # xRGB = (x + 1.0) / 2.0 158 | xRGB = x.detach().cpu() 159 | torch.save(xRGB, dstName) 160 | return True 161 | 162 | 163 | 164 | def visImg(batchImg: torch.Tensor, shape=(1, 1), wait=0, name='visImg'): 165 | """ 166 | :param img: tensor(N,3,H,W) or None 167 | :return: None 168 | """ 169 | N, C, H, W = batchImg.shape 170 | assert all([C == 3, N == shape[0] * shape[1]]) 171 | # batchImg = ((batchImg - batchImg.min()) / (batchImg.max() - batchImg.min()) * 255.0).byte() 172 | batchImg = ((batchImg.float() + 1) / 2.0 * 255).clamp(0, 255).byte()[:, [2, 1, 0], :, :] # rgb2bgr 173 | batchImgViz = makeGrid(batchImg, shape) 174 | cv2.namedWindow(name, 0) 175 | cv2.imshow(name, batchImgViz) 176 | cv2.waitKey(wait) 177 | 178 | 179 | def saveImg(x: torch.Tensor, outpath, isrgb=True): 180 | if Path(outpath).is_file(): 181 | return None 182 | rmax, rmin = x.max(), x.min() 183 | if rmin < -0.5: 184 | x = (x + 1.0) / 2.0 185 | x = x.clamp(0, 1) 186 | if isrgb: 187 | x = x[:, [2, 1, 0], :, :] 188 | xNpy = (x.squeeze(0).permute([1, 2, 0]) * 255).byte().detach().cpu().numpy() 189 | FT.mkPath(Path(outpath).parent) 190 | 191 | cv2.imwrite(str(outpath), xNpy) -------------------------------------------------------------------------------- /DBVI_8x/lib/warp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from lib.softsplat import FunctionSoftsplat 5 | from torch.autograd import grad 6 | 7 | 8 | def backWarp(img: torch.Tensor, flow: torch.Tensor): 9 | device = img.device 10 | N, C, H, W = img.size() 11 | 12 | u = flow[:, 0, :, :] 13 | v = 
flow[:, 1, :, :] 14 | 15 | gridY, gridX = torch.meshgrid([torch.arange(start=0, end=H, device=device, requires_grad=False), 16 | torch.arange(start=0, end=W, device=device, requires_grad=False)]) 17 | 18 | x = gridX.unsqueeze(0).expand_as(u).float().detach() + u 19 | y = gridY.unsqueeze(0).expand_as(v).float().detach() + v 20 | 21 | # range -1 to 1 22 | x = 2 * x / (W - 1.0) - 1.0 23 | y = 2 * y / (H - 1.0) - 1.0 24 | # stacking X and Y 25 | grid = torch.stack((x, y), dim=3) 26 | # Sample pixels using bilinear interpolation. 27 | imgOut = F.grid_sample(img, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 28 | 29 | mask = torch.ones_like(img, requires_grad=False) 30 | mask = F.grid_sample(mask, grid, mode='bilinear', padding_mode='zeros', align_corners=True) 31 | 32 | mask[mask < 0.9999] = 0 33 | mask[mask > 0] = 1 34 | 35 | return imgOut * (mask.detach()), mask.detach() 36 | 37 | 38 | class ModuleSoftsplat(torch.nn.Module): 39 | def __init__(self, strType='average'): 40 | super().__init__() 41 | 42 | self.strType = strType 43 | 44 | def forward(self, tenInput, tenFlow, tenMetric): 45 | return FunctionSoftsplat(tenInput, tenFlow, tenMetric, self.strType) 46 | 47 | 48 | class fidelityGradCuda(nn.Module): 49 | def __init__(self): 50 | super(fidelityGradCuda, self).__init__() 51 | self.fWarp = ModuleSoftsplat() 52 | 53 | def forward(self, It: torch.Tensor, I0: torch.Tensor, F0t: torch.Tensor): 54 | self.device = I0.device 55 | It0, mask = backWarp(It, F0t) 56 | grad_ll = (I0 - It0) # grad(y=-0.5*(I0 - It0)^2, x=It0) 57 | 58 | totalGrad = grad_ll * mask 59 | warpGrad = self.fWarp(tenInput=totalGrad, tenFlow=F0t, tenMetric=None) 60 | 61 | return warpGrad 62 | 63 | 64 | class fidelityGradTorch(nn.Module): 65 | def __init__(self): 66 | super(fidelityGradTorch, self).__init__() 67 | 68 | def forward(self, It: torch.Tensor, I0: torch.Tensor, F0t: torch.Tensor): 69 | with torch.enable_grad(): 70 | It.requires_grad_() 71 | It0, mask = backWarp(It, F0t) 72 | loss = -(0.5 * (I0 * mask - It0 * mask) ** 2).sum() 73 | warpGrad = grad(loss, It, create_graph=True)[0] 74 | 75 | return warpGrad 76 | -------------------------------------------------------------------------------- /DBVI_8x/model/RRDBNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from model import block as B 4 | from configs.configTrain import configMain 5 | from lib.warp import fidelityGradCuda as fidelity 6 | # from lib.warp import fidelityGradTorch as fidelity 7 | 8 | from torch.nn import functional as F 9 | from functools import partial 10 | import math 11 | from model.invBlock.permute import TransConv1x1 12 | 13 | 14 | class subNet(B.BaseNet): 15 | def __init__(self, inCh=3, outCh=16, gradCh=8, gradFlowIn=4): 16 | super(subNet, self).__init__() 17 | 18 | midMap = 512 19 | mapping = [nn.Linear(128, midMap), nn.LeakyReLU(0.2, True)] 20 | for i in range(7): 21 | mapping.append(nn.Linear(midMap, midMap)) 22 | mapping.append(nn.LeakyReLU(0.2, True)) 23 | self.mapping = nn.Sequential(*mapping) 24 | 25 | self.convUV = nn.Sequential(B.conv(outCh, outCh, kernel_size=3, order='nac'), 26 | B.conv(outCh, gradCh * 2, kernel_size=1, order='nac', zeroInit=True)) 27 | self.invConv = TransConv1x1(gradCh) 28 | 29 | self.compConv = B.conv(gradCh, gradFlowIn, kernel_size=1, order='nac') 30 | 31 | self.inConv = B.conv(inCh, outCh, kernel_size=3, norm=False, act=False, order='cna') 32 | 33 | self.style_block = B.StyleBlock(ndBlock=2, outCh=outCh, 
midMap=midMap) 34 | 35 | self.convNext = nn.Sequential(B.conv(outCh, outCh, kernel_size=3, norm=True, act=True, order='nac'), 36 | B.conv(outCh, outCh, kernel_size=3, norm=True, act=True, order='nac') 37 | ) 38 | self.randomInitNet() 39 | 40 | self.toRGB = B.conv(outCh, 3, kernel_size=3, order='nac', zeroInit=True) 41 | 42 | self.toFlow = nn.Sequential(B.conv(outCh, outCh, kernel_size=3, order='nac'), 43 | B.conv(outCh, gradCh, kernel_size=1, order='nac', zeroInit=True)) 44 | 45 | def forward(self, I0, I1, It, fid0, fid1, F0t, F1t, fea, d01): 46 | n, c, h, w = It.shape 47 | 48 | styleCode = torch.zeros([n, 128], dtype=torch.float32, device=It.device).detach() 49 | 50 | mu_invVar = self.convUV(fea) 51 | mu, invVar = torch.chunk(mu_invVar, chunks=2, dim=1) 52 | invVar = torch.sigmoid(invVar) 53 | 54 | d01Fea, _ = self.invConv(d01) 55 | gradFlow = self.negGradFlow(Fea=d01Fea, mean=mu, invVar=invVar) 56 | gradFlow = self.compConv(gradFlow) 57 | 58 | xCode = self.inConv(torch.cat([I0, I1, It, fid0, fid1, F0t, F1t, fea, gradFlow], dim=1)) 59 | 60 | affineCode = self.mapping(styleCode) 61 | 62 | backbone = self.style_block(xCode, affineCode) 63 | 64 | feaNext = self.convNext(backbone) 65 | 66 | ItOut = self.toRGB(feaNext) # [-1, 1] 67 | 68 | d01New = self.toFlow(feaNext) 69 | ItNew = It + ItOut 70 | 71 | return feaNext, ItNew, d01New 72 | 73 | def negGradFlow(self, Fea, mean=0.0, invVar=1.0): 74 | ngrad = (Fea - mean) * invVar 75 | gradFlow, _ = self.invConv(ngrad, rev=True) 76 | 77 | return gradFlow 78 | 79 | 80 | class IMRRDBNet(B.BaseNet): 81 | def __init__(self, cfg: configMain): 82 | super(IMRRDBNet, self).__init__() 83 | self.netScale = 8 84 | 85 | I01Ch = 3 + 3 86 | ItCh = 3 87 | fidICh = 3 + 3 # grad(I0 ,WIt), grad(I0 ,WIt) 88 | flowCh = 2 + 2 89 | gradFlowCh = 16 + 16 90 | gradFlowIn = 4 91 | FeaCh = [16, 16, 16, 16, 16] # Fea for next level 92 | 93 | self.scale = cfg.model.scale 94 | for idx, s in enumerate(self.scale): 95 | allInCh = I01Ch + ItCh + fidICh + flowCh + FeaCh[idx] + gradFlowIn 96 | allOutCh = FeaCh[idx + 1] 97 | 98 | self.add_module(f'fidelityGrad_{idx}', fidelity()) 99 | self.add_module(f'subNet_{idx}', subNet(inCh=allInCh, outCh=allOutCh, 100 | gradCh=gradFlowCh, gradFlowIn=gradFlowIn)) 101 | 102 | if cfg.resume: 103 | self.initPreweight(cfg.path.ckpt) 104 | 105 | def forward(self, batchDict, flowNet, t: float = None): 106 | In, I0, I1, I2 = batchDict['In'], batchDict['I0'], batchDict['I1'], batchDict['I2'] 107 | channelMean = sum([Ik.mean(dim=[2, 3], keepdim=True) for Ik in [In, I0, I1, I2]]) / 4.0 108 | In, I0, I1, I2 = [Ik - channelMean for Ik in [In, I0, I1, I2]] 109 | 110 | IB, IC, IH, IW = In.shape 111 | In, I0, I1, I2 = [self.padToScale(I, self.netScale) for I in [In, I0, I1, I2]] 112 | 113 | if t is None: 114 | t = batchDict['t'].view([In.shape[0], 1, 1, 1]) 115 | else: 116 | t = torch.tensor(data=t, device=In.device, dtype=torch.float32).view([In.shape[0], 1, 1, 1]) 117 | 118 | N, C, H, W = I0.shape 119 | 120 | F0t, F1t = self.getFt(In, I0, I1, I2, t, flowNet) 121 | 122 | output = [] 123 | level = len(self.scale) 124 | 125 | fea = torch.zeros([N, 16, H, W], dtype=torch.float32, device=I0.device).detach() 126 | It = torch.zeros([N, C, H, W], dtype=torch.float32, device=I0.device).detach() 127 | df = torch.zeros([N, 32, H, W], dtype=torch.float32, device=I0.device).detach() 128 | 129 | for idx, l in enumerate(range(level)): 130 | 131 | net = getattr(self, f'subNet_{idx}') 132 | getFidGrad = getattr(self, f'fidelityGrad_{idx}') 133 | 134 | fid0 = getFidGrad(It=It, 
I0=I0, F0t=F0t) 135 | fid1 = getFidGrad(It=It, I0=I1, F0t=F1t) 136 | 137 | fea, It, d01t = net(I0, I1, It, fid0 * (1.0 - t), fid1 * t, F0t, F1t, fea, df) 138 | df = df + d01t 139 | 140 | d0t, d1t = torch.chunk(d01t[:, 0:4, :, :], chunks=2, dim=1) 141 | F0t = F0t + d0t 142 | F1t = F1t + d1t 143 | output.append(It[:, :, 0:IH, 0:IW] + channelMean) 144 | 145 | return output 146 | 147 | def adap2Net(self, tensor: torch.Tensor): 148 | Height, Width = tensor.size(2), tensor.size(3) 149 | 150 | Height_ = int(math.floor(math.ceil(Height / self.netScale) * self.netScale)) 151 | Width_ = int(math.floor(math.ceil(Width / self.netScale) * self.netScale)) 152 | 153 | if any([Height_ != Height, Width_ != Width]): 154 | tensor = F.pad(tensor, [0, Width_ - Width, 0, Height_ - Height]) 155 | 156 | return tensor 157 | 158 | def getWeight(self, pathPreWeight: str = None): 159 | keyName = 'gNet' 160 | checkpoints = torch.load(pathPreWeight, map_location=torch.device('cpu')) 161 | try: 162 | weightDict = checkpoints[keyName] 163 | except Exception as e: 164 | weightDict = checkpoints['model_state_dict'] 165 | return weightDict 166 | 167 | def setRequiresGrad(self, nets, requires_grad=False): 168 | if not isinstance(nets, list): 169 | nets = [nets] 170 | for net in nets: 171 | if net is not None: 172 | for param in net.parameters(): 173 | param.requires_grad = requires_grad 174 | 175 | @torch.no_grad() 176 | def getFt(self, In, I0, I1, I2, t, flowNet): 177 | 178 | device0 = In.device 179 | N, C, H, W = I0.shape 180 | flowNet.eval() 181 | 182 | if (H * W) >= (1080 * 2048): 183 | down2x = partial(F.interpolate, scale_factor=0.5, mode='bilinear', align_corners=True) 184 | [In, I0, I1, I2] = [down2x(I) for I in [In, I0, I1, I2]] 185 | 186 | F0n = flowNet(I0, In) 187 | F01 = flowNet(I0, I1) 188 | a0 = (F01 + F0n) / 2.0 189 | b0 = (F01 - F0n) / 2.0 190 | F0t = a0 * (t ** 2) + b0 * t 191 | 192 | F12 = flowNet(I1, I2) 193 | F10 = flowNet(I1, I0) 194 | a1 = (F10 + F12) / 2.0 195 | b1 = (F10 - F12) / 2.0 196 | F1t = a1 * ((1 - t) ** 2) + b1 * (1 - t) 197 | 198 | if (H * W) >= (1080 * 2048): 199 | up2x = partial(F.interpolate, scale_factor=2.0, mode='bilinear', align_corners=True) 200 | [F0t, F1t] = [up2x(Fx * 2.0) for Fx in [F0t, F1t]] 201 | 202 | return F0t.to(device0).detach(), F1t.to(device0).detach() -------------------------------------------------------------------------------- /DBVI_8x/model/invBlock/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/model/invBlock/__init__.py -------------------------------------------------------------------------------- /DBVI_8x/model/invBlock/permute.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | class TransConv1x1(nn.Module): 8 | def __init__(self, inCh): 9 | super(TransConv1x1, self).__init__() 10 | self.w_shape = [inCh, inCh] 11 | w_init = np.linalg.qr(np.random.randn(*self.w_shape))[0].astype(np.float32) 12 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 13 | 14 | def get_weight(self, x, rev): 15 | b, c, h, w = x.shape 16 | dlogdet = torch.slogdet(self.weight)[1] * h * w # slogdet(A) = torch.log(torch.abs(torch.det(A))) 17 | 18 | if not rev: 19 | weight = self.weight 20 | else: 21 | weight = self.weight.t() 22 | 23 | return weight.view(self.w_shape[0], 
self.w_shape[1], 1, 1), dlogdet 24 | 25 | def forward(self, x, logdet=None, rev=False): 26 | """ 27 | log-det = log|abs(|W|)| * pixels 28 | """ 29 | logdet = 0.0 if logdet is None else logdet 30 | weight, dlogdet = self.get_weight(x, rev) 31 | z = F.conv2d(x, weight) 32 | if not rev: 33 | logdet = logdet + dlogdet 34 | else: 35 | logdet = logdet - dlogdet 36 | 37 | return z, logdet 38 | 39 | 40 | class InvConv1x1(nn.Module): 41 | def __init__(self, inCh): 42 | super(InvConv1x1, self).__init__() 43 | self.w_shape = [inCh, inCh] 44 | w_init = np.linalg.qr(np.random.randn(*self.w_shape))[0].astype(np.float32) 45 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 46 | 47 | def get_weight(self, x, rev): 48 | b, c, h, w = x.shape 49 | dlogdet = torch.slogdet(self.weight)[1] * h * w # slogdet(A) = torch.log(torch.abs(torch.det(A))) 50 | 51 | if not rev: 52 | weight = self.weight 53 | else: 54 | weight = torch.inverse(self.weight) 55 | 56 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 57 | 58 | def forward(self, x, logdet=None, rev=False): 59 | """ 60 | log-det = log|abs(|W|)| * pixels 61 | """ 62 | logdet = 0.0 if logdet is None else logdet 63 | weight, dlogdet = self.get_weight(x, rev) 64 | z = F.conv2d(x, weight) 65 | if not rev: 66 | logdet = logdet + dlogdet 67 | else: 68 | logdet = logdet - dlogdet 69 | 70 | return z, logdet 71 | 72 | 73 | class Permute2d(nn.Module): 74 | def __init__(self, inCh, shuffle=True): 75 | super().__init__() 76 | self.inCh = inCh 77 | # self.indices = torch.arange(self.inCh - 1, -1, -1, dtype=torch.long) 78 | # self.indices_inverse = torch.zeros(self.inCh, dtype=torch.long) 79 | self.register_buffer('indices', torch.arange(self.inCh - 1, -1, -1, dtype=torch.long)) 80 | self.register_buffer('indices_inverse', torch.zeros(self.inCh, dtype=torch.long)) 81 | 82 | for i in range(self.inCh): 83 | self.indices_inverse[self.indices[i]] = i 84 | 85 | if shuffle: 86 | self.reset_indices() 87 | 88 | def reset_indices(self): 89 | shuffle_idx = torch.randperm(self.indices.shape[0]) 90 | self.indices = self.indices[shuffle_idx] 91 | 92 | for i in range(self.inCh): 93 | self.indices_inverse[self.indices[i]] = i 94 | 95 | def forward(self, x, logdet, rev=False): 96 | assert len(x.size()) == 4 97 | 98 | if not rev: 99 | x = x[:, self.indices, :, :] 100 | else: 101 | x = x[:, self.indices_inverse, :, :] 102 | 103 | return x, logdet 104 | 105 | 106 | class InvConvLU1x1(nn.Module): # is not recommended 107 | def __init__(self, inCh): 108 | super(InvConvLU1x1, self).__init__() 109 | w_shape = [inCh, inCh] 110 | w_init = torch.qr(torch.randn(*w_shape))[0] 111 | 112 | p, lower, upper = torch.lu_unpack(*torch.lu(w_init)) 113 | s = torch.diag(upper) 114 | sign_s = torch.sign(s) 115 | log_s = torch.log(torch.abs(s)) 116 | upper = torch.triu(upper, 1) 117 | l_mask = torch.tril(torch.ones(w_shape), -1) 118 | eye = torch.eye(*w_shape) 119 | 120 | self.register_buffer('p', p) # .cuda() will work only on register_buffer 121 | self.register_buffer('sign_s', sign_s) 122 | self.register_buffer('l_mask', l_mask) 123 | self.register_buffer('eye', eye) 124 | 125 | self.lower = nn.Parameter(lower) 126 | self.log_s = nn.Parameter(log_s) 127 | self.upper = nn.Parameter(upper) 128 | 129 | self.w_shape = w_shape 130 | 131 | def get_weight(self, x, rev): 132 | b, c, h, w = x.shape 133 | 134 | lower = self.lower * self.l_mask + self.eye 135 | 136 | u = self.upper * self.l_mask.transpose(0, 1).contiguous() 137 | u += torch.diag(self.sign_s * torch.exp(self.log_s)) 
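# With W = P @ L @ (U + diag(sign_s * exp(log_s))), det(P) = ±1 and L is unit
# lower-triangular, so log|det W| = sum(log_s); the h * w factor below accounts
# for the 1x1 convolution being applied at every spatial position.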
138 | 139 | dlogdet = torch.sum(self.log_s) * h * w 140 | if not rev: 141 | weight = torch.matmul(self.p, torch.matmul(lower, u)) 142 | else: 143 | u_inv = torch.inverse(u) 144 | l_inv = torch.inverse(lower) 145 | p_inv = torch.inverse(self.p) 146 | 147 | weight = torch.matmul(u_inv, torch.matmul(l_inv, p_inv)) 148 | 149 | return weight.view(self.w_shape[0], self.w_shape[1], 1, 1), dlogdet 150 | 151 | def forward(self, x, logdet=None, rev=False): 152 | """ 153 | log-det = log|abs(|W|)| * pixels 154 | """ 155 | weight, dlogdet = self.get_weight(x, rev) 156 | z = F.conv2d(x, weight) 157 | 158 | if not rev: 159 | logdet = logdet + dlogdet 160 | else: 161 | logdet = logdet - dlogdet 162 | return z, logdet -------------------------------------------------------------------------------- /DBVI_8x/model/module_util.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.nn import init 3 | 4 | 5 | def randomInitNet(net_l, iniType='kaiming', scale=1.0): 6 | if not isinstance(net_l, list): 7 | net_l = [net_l] 8 | for net in net_l: 9 | for m in net.modules(): 10 | if any([isinstance(m, nn.Conv2d), isinstance(m, nn.ConvTranspose2d), isinstance(m, nn.Linear)]): 11 | if iniType == 'normal': 12 | init.normal_(m.weight, 0.0, 0.2) 13 | elif iniType == 'xavier': 14 | init.xavier_normal_(m.weight, gain=0.2) 15 | elif iniType == 'kaiming': 16 | init.kaiming_normal_(m.weight, a=0.2, mode='fan_in', nonlinearity='leaky_relu') 17 | elif iniType == 'orthogonal': 18 | init.orthogonal_(m.weight, gain=0.2) 19 | elif iniType == 'default': 20 | pass 21 | 22 | if m.bias is not None: 23 | init.constant_(m.bias, 0.0) 24 | m.weight.data *= scale 25 | elif any([isinstance(m, nn.InstanceNorm2d), isinstance(m, nn.LocalResponseNorm), 26 | isinstance(m, nn.BatchNorm2d), isinstance(m, nn.GroupNorm)]): 27 | try: 28 | init.constant_(m.weight, 1.0) 29 | init.constant_(m.bias, 0.0) 30 | except Exception as e: 31 | pass 32 | -------------------------------------------------------------------------------- /DBVI_8x/output/output: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/DBVI_8x/output/output -------------------------------------------------------------------------------- /DBVI_8x/readme.md: -------------------------------------------------------------------------------- 1 | # 8x Interpolation 2 | ## 1. Preparing Dataset 3 | The training/testing datasets we used can either be downloaded from the following links and processed with the code in ../mkDataset/forXXX/ (changing each 'Path/to/' accordingly), or downloaded directly from [here](https://pan.baidu.com/s/1meK6lCXrwrBQ3KFgos1aDw?pwd=2022)(password:2022) as ready-to-use lmdb files. 4 | #### Links: 5 | [GoPro](https://drive.google.com/file/d/1rJTmM9_mLCNzBUUhYIGldBYgup279E_f/view), 6 | [X4K1000FPS](https://github.com/JihyongOh/XVFI#X4K1000FPS), 7 | [Adobe240(official)](http://www.cs.ubc.ca/labs/imager/tr/2017/DeepVideoDeblurring/DeepVideoDeblurring_Dataset_Original_High_FPS_Videos.zip), 8 | [Adobe240_lmdb(selected and used in paper)](https://pan.baidu.com/s/1E5TAUAks_AzWEcmgwuR8oA?pwd=2022)(password:2022) 9 | 10 | The processed files should be put in ./datasets and organized as: 11 | ``` 12 | datasets/ 13 | GoPro/ 14 | gopro_test_lmdb/ 15 | data.mdb 16 | lock.mdb 17 | sample.pkl 18 | gopro_train_lmdb/ 19 | data.mdb 20 | lock.mdb 21 | sample.pkl 22 | ``` 23 | 24 | ## 2. Training
25 | ### Training with single gpu: 26 | (1) Set the name of the train set(GoPro/X4K1000FPS), whether to resume or not, the dir of checkpoints and the name of the pretrained weights(only needed if resume is true) in configs/configTrain.py(line50~54). 27 | 28 | (2) Open a terminal and run ifconfig to get your ip address: XXX.XXX.XXX.XXX 29 | 30 | (3) python train.py --initNode=XXX.XXX.XXX.XXX 31 | 32 | ### Distributed training with multi-gpus(16GPU,2Nodes) on a cluster managed by [slurm](https://slurm.schedmd.com/quickstart_admin.html): 33 | (1) Set the name of the train set(GoPro/X4K1000FPS), whether to resume or not, the dir of checkpoints and the name of the pretrained weights(only needed if resume is true) in configs/configTrain.py(line50~54). 34 | 35 | (2) Set the name of the partition and nodes in the cluster, the number and index of gpus/cpus per node and so on in runTrain.py(line3~14). 36 | 37 | The example in runTrain.py runs on one partition named Pixel, two nodes named 'SH-IDC1-10-5-39-55' and 'SH-IDC1-10-5-31-54', with 8 gpus per node. 38 | 39 | (3) python runTrain.py 40 | 41 | ## 3. Testing with Pretrained Models 42 | 43 | [Models](https://pan.baidu.com/s/1pxRFu29r56nDLgIHqFzHBA) pretrained on GoPro (password:2022) 44 | [Models](https://pan.baidu.com/s/1bXUaHN_n1F2YL8N9V5oMqw) pretrained on X4K1000FPS (password:2022) 45 | Download the models and put them under ./output/ 46 | 47 | ### Testing with single gpu: 48 | (1) Set the name of the test set, the dir of checkpoints and the name of the pretrained weights in configs/configTest.py(line50~54). 49 | 50 | (2) Open a terminal and run ifconfig to get your ip address: XXX.XXX.XXX.XXX 51 | 52 | (3) python test.py --initNode=XXX.XXX.XXX.XXX 53 | 54 | (4) PNG results on X4K1000FPS are provided [here](https://pan.baidu.com/s/1Quw5ToZ2itVmE-B0v6PKBA)(password:2022) for users with limited GPU memory (full-resolution inference needs at least 22GB). The results evaluated on the quantized PNGs may differ slightly from those reported in the paper (by less than 0.1dB). 55 | 56 | ### Distributed testing with multi-gpus(10) on a cluster managed by [slurm](https://slurm.schedmd.com/quickstart_admin.html): 57 | (1) Set the name of the test set, the dir of checkpoints and the name of the pretrained weights in configs/configTest.py(line50~54). 58 | 59 | (2) Set the name of the partition and nodes in the cluster, and the number and index of gpus/cpus per node in runTest.py(line3~14). 60 | 61 | The example in runTest.py runs on one partition named Pixel with two nodes named 'SH-IDC1-10-5-39-55' and 'SH-IDC1-10-5-31-38' and 5 gpus per node; the srun command it assembles is sketched below. 
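For reference, with the example settings in runTest.py(line3~14), runDist() prints and then executes an srun command of roughly the following shape (the partition, node names and GPU lists are the example values, not extra configuration):
```
srun --job-name=Test --partition=Pixel \
     --nodelist=SH-IDC1-10-5-39-55,SH-IDC1-10-5-31-38 \
     --ntasks=10 --nodes=2 --ntasks-per-node=5 --cpus-per-task=4 \
     --kill-on-bad-exit=1 --mpi=pmi2 \
     python -m test --initNode SH-IDC1-10-5-39-55 --config configTest \
     --gpuList "{'SH-IDC1-10-5-39-55': '0,1,2,3,4', 'SH-IDC1-10-5-31-38': '0,1,2,3,4'}" \
     --reuseGPU 1 --expName Test
```
(--gres is omitted here because the example sets reuseGPU=1 and reuses the GPUs listed in gpuDict.)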
62 | 63 | (3) python runTest.py 64 | 65 | -------------------------------------------------------------------------------- /DBVI_8x/runTest.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = 'Test' 4 | part = 'Pixel' 5 | 6 | freeNodes = ['SH-IDC1-10-5-39-55','SH-IDC1-10-5-31-38'] 7 | # gpuDict = "\"{\'SH-IDC1-10-5-31-37\': \'0,1,2,3,4,5,6,7\', \'SH-IDC1-10-5-31-38\': \'0,1,2,3,4,5,6,7\'}\"" 8 | gpuDict = "\"{\'SH-IDC1-10-5-39-55\': \'0,1,2,3,4\', \'SH-IDC1-10-5-31-38\': \'0,1,2,3,4\'}\"" 9 | 10 | 11 | ntaskPerNode = 5 # number of GPUs per nodes 12 | cpus_per_task = 4 13 | reuseGPU = 1 14 | envDistributed = 1 15 | 16 | nodeNum = len(freeNodes) 17 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 18 | nodeList = ','.join(freeNodes) 19 | initNode = freeNodes[0] 20 | 21 | 22 | scrip = 'test' 23 | config = 'configTest' 24 | 25 | 26 | def runDist(): 27 | pyCode = [] 28 | pyCode.append('python') 29 | pyCode.append('-m') 30 | pyCode.append(scrip) 31 | pyCode.append('--initNode {}'.format(initNode)) 32 | pyCode.append('--config {}'.format(config)) 33 | pyCode.append('--gpuList {}'.format(gpuDict)) 34 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 35 | pyCode.append('--expName {}'.format(jobName)) 36 | pyCode = ' '.join(pyCode) 37 | 38 | srunCode = [] 39 | srunCode.append('srun') 40 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) if not (reuseGPU and envDistributed) else print( 41 | 'Reuse GPUS of {}'.format(gpuDict)) 42 | srunCode.append('--job-name={}'.format(jobName)) 43 | srunCode.append('--partition={}'.format(part)) 44 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 45 | srunCode.append('--ntasks={}'.format(nTasks)) 46 | srunCode.append('--nodes={}'.format(nodeNum)) 47 | srunCode.append(f'--ntasks-per-node={ntaskPerNode}') if envDistributed else print( 48 | 'ntasks-per-node is 1') 49 | srunCode.append(f'--cpus-per-task={cpus_per_task}') 50 | srunCode.append('--kill-on-bad-exit=1') 51 | srunCode.append('--mpi=pmi2') 52 | srunCode.append(pyCode) 53 | srunCode = ' '.join(srunCode) 54 | print(srunCode) 55 | os.system(srunCode) 56 | 57 | 58 | if __name__ == '__main__': 59 | runDist() 60 | -------------------------------------------------------------------------------- /DBVI_8x/runTrain.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | jobName = '8x_GoPro' 4 | # jobName = '8x_XVFI' 5 | 6 | part = 'Pixel' 7 | 8 | 9 | freeNodes = ['SH-IDC1-10-5-39-55', 'SH-IDC1-10-5-31-54'] 10 | gpuDict = "\"{\'SH-IDC1-10-5-39-55\': \'0,1,2,3,4,5,6,7\', \'SH-IDC1-10-5-31-54\': \'0,1,2,3,4,5,6,7\'}\"" 11 | 12 | 13 | ntaskPerNode = 8 # number of GPUs per nodes 14 | cpus_per_task = 7 # number of CPUs per task 15 | reuseGPU = 1 16 | envDistributed = 1 17 | 18 | nodeNum = len(freeNodes) 19 | nTasks = ntaskPerNode * nodeNum if envDistributed else 1 20 | nodeList = ','.join(freeNodes) 21 | initNode = freeNodes[0] 22 | 23 | scrip = 'train' 24 | config = 'configTrain' 25 | 26 | 27 | def runDist(): 28 | pyCode = [] 29 | pyCode.append('python') 30 | pyCode.append('-m') 31 | pyCode.append(scrip) 32 | pyCode.append('--initNode {}'.format(initNode)) 33 | pyCode.append('--config {}'.format(config)) 34 | pyCode.append('--gpuList {}'.format(gpuDict)) 35 | pyCode.append('--reuseGPU {}'.format(reuseGPU)) 36 | pyCode.append('--expName {}'.format(jobName)) 37 | pyCode = ' '.join(pyCode) 38 | 39 | srunCode = [] 40 | srunCode.append('srun') 
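# Unlike runTest.py, the --gres=gpu:<ntaskPerNode> flag below is appended
# unconditionally, so slurm itself allocates ntaskPerNode GPUs on each node
# even when reuseGPU is set.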
41 | srunCode.append('--gres=gpu:{}'.format(ntaskPerNode)) 42 | srunCode.append('--job-name={}'.format(jobName)) 43 | srunCode.append('--partition={}'.format(part)) 44 | srunCode.append('--nodelist={}'.format(nodeList)) if freeNodes is not None else print('Get node by slurm') 45 | srunCode.append('--ntasks={}'.format(nTasks)) 46 | srunCode.append('--nodes={}'.format(nodeNum)) 47 | srunCode.append(f'--ntasks-per-node={ntaskPerNode}') if envDistributed else print( 48 | 'ntasks-per-node is 1') 49 | srunCode.append(f'--cpus-per-task={cpus_per_task}') 50 | srunCode.append('--kill-on-bad-exit=1') 51 | srunCode.append('--mpi=pmi2') 52 | srunCode.append(pyCode) 53 | 54 | srunCode = ' '.join(srunCode) 55 | print(srunCode) 56 | 57 | os.system(srunCode) 58 | 59 | 60 | if __name__ == '__main__': 61 | runDist() 62 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Ocean2022 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Deep Bayesian Video Frame Interpolation (ECCV2022) 2 | [[Paper&supp](https://www.ecva.net/papers/eccv_2022/papers_ECCV/html/1287_ECCV_2022_paper.php)], [[Demo](https://youtu.be/8KvFwN1_3DY)], [[Presentation](https://youtu.be/2quo-k0PcQ4)] 3 | 4 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-gopro)](https://paperswithcode.com/sota/video-frame-interpolation-on-gopro?p=deep-bayesian-video-frame-interpolation) 5 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-x4k1000fps)](https://paperswithcode.com/sota/video-frame-interpolation-on-x4k1000fps?p=deep-bayesian-video-frame-interpolation) 6 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-davis)](https://paperswithcode.com/sota/video-frame-interpolation-on-davis?p=deep-bayesian-video-frame-interpolation) 7 | 8 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-snu-film-easy)](https://paperswithcode.com/sota/video-frame-interpolation-on-snu-film-easy?p=deep-bayesian-video-frame-interpolation) 9 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-snu-film-medium)](https://paperswithcode.com/sota/video-frame-interpolation-on-snu-film-medium?p=deep-bayesian-video-frame-interpolation) 10 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-snu-film-hard)](https://paperswithcode.com/sota/video-frame-interpolation-on-snu-film-hard?p=deep-bayesian-video-frame-interpolation) 11 | [![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/deep-bayesian-video-frame-interpolation/video-frame-interpolation-on-snu-film-extreme)](https://paperswithcode.com/sota/video-frame-interpolation-on-snu-film-extreme?p=deep-bayesian-video-frame-interpolation) 12 | 13 | ## 1. Requirements 14 | 15 | 1) cuda 9.0, cudnn7.6.5 16 | 17 | 2) python 3.6.9 18 | 19 | 3) pytorch 1.8.1 20 | 21 | 4) numpy 1.17.2 22 | 23 | 5) cupy-90 24 | 25 | 6) tqdm 26 | 27 | 7) gcc 5.4.0 28 | 29 | 8) cmake 3.16.0 30 | 31 | 9) opencv_contrib_python 32 | 33 | 10) [Apex](https://github.com/NVIDIA/apex) 34 | 35 | 11) For distributed training with multi-gpus on cluster: slurm 15.08.11 36 | 37 | 38 | ## 2. How to use 39 | [For 8x interpolation](https://github.com/Oceanlib/DBVI/tree/main/DBVI_8x) 40 | 41 | [For 2x interpolation](https://github.com/Oceanlib/DBVI/tree/main/DBVI_2x) 42 | 43 | ## 3. Citation 44 | ``` 45 | @inproceedings{DBVI2022, 46 | title={Deep Bayesian Video Frame Interpolation}, 47 | author={Yu, Zhiyang and Zhang, Yu and Xiang, Xujie and Zou, Dongqing and Chen, Xijun and Ren, Jimmy S}, 48 | booktitle={European Conference on Computer Vision}, 49 | pages={144--160}, 50 | year={2022}, 51 | organization={Springer} 52 | } 53 | ``` 54 | 55 | ### 4. 
56 | [[opencv_torchvision](https://github.com/hityzy1122/opencv_transforms_torchvision)],
57 | [[ESRGAN](https://github.com/xinntao/ESRGAN)],
58 | [[CAM-Net](https://github.com/niopeng/CAM-Net/tree/main/code)],
59 | [[SoftSplat](https://github.com/sniklaus/softmax-splatting)],
60 | [[DeepView](https://github.com/Findeton/deepview)],
61 | [[FLAVR](https://github.com/tarun005/FLAVR)],
62 | [[superSlomo](https://github.com/avinashpaliwal/Super-SloMo)],
63 | [[QVI](https://sites.google.com/view/xiangyuxu/qvi_nips19)]
64 | 
65 | 
--------------------------------------------------------------------------------
/mkDataset/forAdobe/getTestList.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import re
4 | import pickle
5 | from lib import fileTool as FT
6 | from lib.dataUtils import sortFunc
7 | 
8 | 
9 | def extractKey(name: str):
10 |     parts = Path(name).parts
11 |     key = f'{parts[-2]}/{parts[-1]}'
12 |     return key
13 | 
14 | 
15 | def sortFunc(name: str):
16 |     idx = re.search(r'(\d*).png', str(name)).group(1)
17 |     return int(idx)
18 | 
19 | 
20 | def genSampleTest(srcDir, samplePath, numIter):
21 |     idxV = [0, numIter + 1, 2 * numIter + 2, 3 * numIter + 3]
22 |     idxT = [numIter + 1 + i + 1 for i in range(numIter)]
23 |     idxVT = idxV + idxT
24 | 
25 |     keyV = ['In', 'I0', 'I1', 'I2']
26 |     keyT = [f'It{i + 1}' for i in range(numIter)]
27 |     keyVT = keyV + keyT
28 | 
29 |     allSubDirs = FT.getSubDirs(srcDir)
30 |     # allSamplesNames = []
31 |     allSamples = []
32 |     for subDir in allSubDirs:
33 |         allFrames = FT.getAllFiles(subDir, 'png')
34 |         allFrames.sort(key=sortFunc)
35 | 
36 |         for startIdx in range(0, len(allFrames), numIter + 1):
37 |             sampleDict = {}
38 |             endIdx = startIdx + numIter * 3 + 4
39 |             if endIdx <= len(allFrames):
40 |                 asample = allFrames[startIdx:endIdx]
41 |                 asample = [extractKey(i) for i in asample]
42 |                 for idx, key in zip(idxVT, keyVT):
43 |                     sampleDict[key] = asample[idx]
44 |                 allSamples.append(sampleDict)
45 | 
46 |             # elif endIdx != len(allFrames) + 1:
47 |             #     asample = allFrames[-(numIter * 3 + 4)::]
48 |             #     asample = [extractKey(i) for i in asample]
49 |             #     for idx, key in zip(idxVT, keyVT):
50 |             #         sampleDict[key] = asample[idx]
51 |             #     allSamples.append(sampleDict)
52 |             #     break
53 |             # else:
54 |             #     break
55 |     with open(samplePath, 'wb') as f:
56 |         pickle.dump(allSamples, f)
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     numIter = 7
61 |     srcDir = 'Path/to/frames'
62 |     samplePath = 'Path/to/Adobe/Adobe_lmdb/sample.pkl'
63 | 
64 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
65 | 
--------------------------------------------------------------------------------
/mkDataset/forAdobe/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parts = Path(name).parts
15 |     key = f'{parts[-2]}/{parts[-1]}'
16 |     return key
17 | 
18 | 
19 | def png2LMDB(srcDir, lmdb_path):
20 |     allPNGs = FT.getAllFiles(srcDir, 'png')
21 | 
22 |     pbar = tqdm(total=len(allPNGs))
23 | 
24 |     isdir = os.path.isdir(lmdb_path)
25 |     write_frequency = 1000
26 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
27 |                    meminit=False, map_async=True)
28 |     txn = db.begin(write=True)
29 |     for idx, aPNG in enumerate(allPNGs):
30 |         key = extractKey(aPNG)
31 | 
32 |         img = cv2.imread(aPNG)
33 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
34 | 
35 |         sample = pickle.dumps(img, protocol=4)
36 |         txn.put(u'{}'.format(key).encode('ascii'), sample)
37 |         if idx % write_frequency == 0:
38 |             txn.commit()
39 |             txn = db.begin(write=True)
40 |         pbar.update(1)
41 | 
42 |     txn.commit()
43 |     keys = [extractKey(p).encode('ascii') for p in allPNGs]  # the actual string keys written above
44 |     with db.begin(write=True) as txn:
45 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
46 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
47 | 
48 |     print("Flushing database ...")
49 |     db.sync()
50 |     db.close()
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     srcDir = 'Path/to/frames'
55 |     lmdb_path = 'Path/to/Adobe/Adobe_lmdb'
56 |     png2LMDB(srcDir, lmdb_path)
57 | 
--------------------------------------------------------------------------------
/mkDataset/forDavis/getTestList.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 | import pickle
4 | from lib import fileTool as FT
5 | 
6 | 
7 | def sortFunc(name: str):
8 |     idx = re.search(r'(\d*).jpg', str(name)).group(1)
9 |     return int(idx)
10 | 
11 | 
12 | def extractKey(name: str):
13 |     parent = Path(name).parent.name
14 |     fileName = Path(name).name
15 |     key = f'{parent}/{fileName}'
16 |     return key
17 | 
18 | 
19 | def genSampleTest(srcDir, samplePath, numIter):
20 |     allSubDirs = FT.getSubDirs(srcDir)
21 |     allSamples = []
22 | 
23 |     for subDir in allSubDirs:
24 |         allFrames = FT.getAllFiles(subDir, 'jpg')
25 |         allFrames.sort(key=sortFunc)
26 | 
27 |         for startIdx in range(0, len(allFrames) - 6, 2):
28 |             endIdx = startIdx + numIter * 3 + 4
29 |             if endIdx <= len(allFrames):
30 |                 sampleDict = {}
31 |                 asample = allFrames[startIdx:endIdx]
32 |                 sampleDict['In'] = extractKey(asample[0])
33 |                 sampleDict['I0'] = extractKey(asample[2])
34 |                 sampleDict['I1'] = extractKey(asample[4])
35 |                 sampleDict['I2'] = extractKey(asample[6])
36 |                 sampleDict['It'] = extractKey(asample[3])
37 | 
38 |                 allSamples.append(sampleDict)
39 | 
40 |     with open(samplePath, 'wb') as f:
41 |         pickle.dump(allSamples, f)
42 | 
43 | 
44 | if __name__ == '__main__':
45 |     numIter = 1
46 |     srcDir = 'Path/to/davis_frame'
47 |     samplePath = 'Path/to/davis/davis_lmdb/sample.pkl'
48 | 
49 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
50 | 
--------------------------------------------------------------------------------
/mkDataset/forDavis/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parent = Path(name).parent.name
15 |     fileName = Path(name).name
16 |     key = f'{parent}/{fileName}'
17 |     return key
18 | 
19 | 
20 | def png2LMDB(srcDir, lmdb_path):
21 |     allPNGs = FT.getAllFiles(srcDir, 'jpg')
22 |     pbar = tqdm(total=len(allPNGs))
23 | 
24 |     isdir = os.path.isdir(lmdb_path)
25 |     write_frequency = 1000
26 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
27 |                    meminit=False, map_async=True)
28 |     txn = db.begin(write=True)
29 |     for idx, aPNG in enumerate(allPNGs):
30 |         key = extractKey(aPNG)
31 | 
32 |         img = cv2.imread(aPNG)
33 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
34 | 
35 |         sample = pickle.dumps(img, protocol=4)
36 |         txn.put(u'{}'.format(key).encode('ascii'), sample)
37 |         if idx % write_frequency == 0:
38 |             txn.commit()
39 |             txn = db.begin(write=True)
40 |         pbar.update(1)
41 | 
42 |     txn.commit()
43 |     keys = [extractKey(p).encode('ascii') for p in allPNGs]  # the actual string keys written above
44 |     with db.begin(write=True) as txn:
45 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
46 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
47 | 
48 |     print("Flushing database ...")
49 |     db.sync()
50 |     db.close()
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     srcDir = 'Path/to/davis_frame'
55 |     lmdb_path = 'Path/to/davis/davis_lmdb'
56 |     png2LMDB(srcDir, lmdb_path)
57 | 
--------------------------------------------------------------------------------
/mkDataset/forGoPro/getTestList.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import re
4 | import pickle
5 | from lib import fileTool as FT
6 | from lib.dataUtils import sortFunc
7 | 
8 | 
9 | def extractKey(name: str):
10 |     parts = Path(name).parts
11 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
12 |     return key
13 | 
14 | 
15 | def sortFunc(name: str):
16 |     idx = re.search(r'(\d*).png', str(name)).group(1)
17 |     return int(idx)
18 | 
19 | 
20 | def genSampleTest(srcDir, samplePath, numIter):
21 |     idxV = [0, numIter + 1, 2 * numIter + 2, 3 * numIter + 3]
22 |     idxT = [numIter + 1 + i + 1 for i in range(numIter)]
23 |     idxVT = idxV + idxT
24 | 
25 |     keyV = ['In', 'I0', 'I1', 'I2']
26 |     keyT = [f'It{i + 1}' for i in range(numIter)]
27 |     keyVT = keyV + keyT
28 | 
29 |     allSubDirs = FT.getSubDirs(srcDir)
30 |     # allSamplesNames = []
31 |     allSamples = []
32 |     for subDir in allSubDirs:
33 |         allFrames = FT.getAllFiles(subDir, 'png')
34 |         allFrames.sort(key=sortFunc)
35 | 
36 |         for startIdx in range(0, len(allFrames), numIter + 1):
37 |             sampleDict = {}
38 |             endIdx = startIdx + numIter * 3 + 4
39 |             if endIdx <= len(allFrames):
40 |                 asample = allFrames[startIdx:endIdx]
41 |                 asample = [extractKey(i) for i in asample]
42 |                 for idx, key in zip(idxVT, keyVT):
43 |                     sampleDict[key] = asample[idx]
44 |                 allSamples.append(sampleDict)
45 | 
46 |             # elif endIdx != len(allFrames) + 1:
47 |             #     asample = allFrames[-(numIter * 3 + 4)::]
48 |             #     asample = [extractKey(i) for i in asample]
49 |             #     for idx, key in zip(idxVT, keyVT):
50 |             #         sampleDict[key] = asample[idx]
51 |             #     allSamples.append(sampleDict)
52 |             #     break
53 |             # else:
54 |             #     break
55 |     with open(samplePath, 'wb') as f:
56 |         pickle.dump(allSamples, f)
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     numIter = 7
61 |     srcDir = 'Path/to/GoPro/test'
62 |     samplePath = 'Path/to/GoPro/gopro_test_lmdb/sample.pkl'
63 | 
64 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
65 | 
--------------------------------------------------------------------------------
/mkDataset/forGoPro/getTrainList.py:
--------------------------------------------------------------------------------
1 | import re
2 | import pickle
3 | import lib.fileTool as FT
4 | from pathlib import Path
5 | 
6 | 
7 | def extractKey(name: str):
8 |     parts = Path(name).parts
9 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
10 |     return key
11 | 
12 | 
13 | def sortFunc(name: str):
14 |     idx = re.search(r'(\d*).png', str(name)).group(1)
15 |     return int(idx)
16 | 
17 | 
18 | def genSampleTrain(srcDir, samplePath, numIter=7):
19 |     idxV = [0, numIter + 1, 2 * numIter + 2, 3 * numIter + 3]
20 |     idxT = [numIter + 1 + i + 1 for i in range(numIter)]
21 |     idxVT = idxV + idxT
22 | 
23 |     keyV = ['In', 'I0', 'I1', 'I2']
24 |     keyT = [f'It{i + 1}' for i in range(numIter)]
25 |     keyVT = keyV + keyT
26 | 
27 |     allSubDirs = FT.getSubDirs(srcDir)
28 |     allSamples = []
29 |     for subDir in allSubDirs:
30 |         allFrames = FT.getAllFiles(subDir, 'png')
31 |         allFrames.sort(key=sortFunc)
32 | 
33 |         for startIdx in range(len(allFrames)):
34 |             sampleDict = {}
35 |             endIdx = startIdx + numIter * 3 + 4
36 |             if endIdx <= len(allFrames):
37 |                 asample = allFrames[startIdx:endIdx]
38 |                 asample = [extractKey(i) for i in asample]
39 |                 for idx, key in zip(idxVT, keyVT):
40 |                     sampleDict[key] = asample[idx]
41 |                 allSamples.append(sampleDict)
42 | 
43 |             elif endIdx != len(allFrames) + 1:
44 |                 asample = allFrames[-(numIter * 3 + 4)::]
45 |                 asample = [extractKey(i) for i in asample]
46 |                 for idx, key in zip(idxVT, keyVT):
47 |                     sampleDict[key] = asample[idx]
48 |                 allSamples.append(sampleDict)
49 |                 break
50 |             else:
51 |                 break
52 |     with open(samplePath, 'wb') as f:
53 |         pickle.dump(allSamples, f)
54 | 
55 | 
56 | if __name__ == '__main__':
57 |     srcDir = 'Path/to/GoPro/train'
58 |     samplePath = 'Path/to/GoPro/gopro_train_lmdb/sample.pkl'
59 |     numIter = 7
60 |     genSampleTrain(srcDir, samplePath, numIter=numIter)
61 | 
--------------------------------------------------------------------------------
/mkDataset/forGoPro/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parts = Path(name).parts
15 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'  # keep the same 3-level keys used in sample.pkl
16 |     return key
17 | 
18 | 
19 | def png2LMDB(srcDir, samplePath, lmdb_path):
20 |     with open(samplePath, 'rb') as f:
21 |         samples = pickle.load(f)
22 |     values = []
23 |     for asample in samples:
24 |         values += list(asample.values())
25 |     allPNGs = list(set(values))
26 |     allPNGs = [str(Path(srcDir) / i) for i in allPNGs]
27 | 
28 |     pbar = tqdm(total=len(allPNGs))
29 | 
30 |     isdir = os.path.isdir(lmdb_path)
31 |     write_frequency = 1000
32 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
33 |                    meminit=False, map_async=True)
34 |     txn = db.begin(write=True)
35 |     for idx, aPNG in enumerate(allPNGs):
36 |         key = extractKey(aPNG)
37 | 
38 |         img = cv2.imread(aPNG)
39 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
40 | 
41 |         sample = pickle.dumps(img, protocol=4)
42 |         txn.put(u'{}'.format(key).encode('ascii'), sample)
43 |         if idx % write_frequency == 0:
44 |             txn.commit()
45 |             txn = db.begin(write=True)
46 |         pbar.update(1)
47 | 
48 |     txn.commit()
49 |     keys = [extractKey(p).encode('ascii') for p in allPNGs]  # the actual string keys written above
50 |     with db.begin(write=True) as txn:
51 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
52 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
53 | 
54 |     print("Flushing database ...")
55 |     db.sync()
56 |     db.close()
57 | 
58 | 
59 | if __name__ == '__main__':
60 |     srcDir = 'Path/to/GoPro/train'
61 |     samplePath = 'Path/to/GoPro/gopro_train_lmdb/sample.pkl'  # 'Path/to/GoPro/gopro_test_lmdb/sample.pkl'
62 |     lmdb_path = 'Path/to/GoPro/gopro_train_lmdb'  # 'Path/to/GoPro/gopro_test_lmdb'
63 |     png2LMDB(srcDir, samplePath, lmdb_path)
64 | 
--------------------------------------------------------------------------------
/mkDataset/forSnufilm/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parts = Path(name).parts
15 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
16 |     return key
17 | 
18 | 
19 | def png2LMDB(srcDir, lmdb_path):
20 |     allPNGs = FT.getAllFiles(srcDir, 'png')
21 |     # collect every key referenced by the four difficulty splits
22 |     allKeys = []
23 |     for split in ['easy', 'medium', 'hard', 'extreme']:
24 |         with open(f'Path/to/snufilm-test/snufilm_lmdb/test-{split}.pkl', 'rb') as fs:
25 |             samples = pickle.load(fs)
26 |         for asample in samples:
27 |             allKeys.extend(list(asample.values()))
28 |     allKeys = set(allKeys)
29 | 
30 |     pbar = tqdm(total=len(allPNGs))
31 | 
32 |     isdir = os.path.isdir(lmdb_path)
33 |     write_frequency = 1000
34 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
35 |                    meminit=False, map_async=True)
36 |     txn = db.begin(write=True)
37 |     writtenKeys = []  # bookkeeping for __keys__/__len__: only frames actually stored
38 |     for idx, aPNG in enumerate(allPNGs):
39 |         key = extractKey(aPNG)
40 |         if key in allKeys:
41 |             img = cv2.imread(aPNG)
42 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
43 | 
44 |             sample = pickle.dumps(img, protocol=4)
45 |             txn.put(u'{}'.format(key).encode('ascii'), sample)
46 |             writtenKeys.append(key.encode('ascii'))
47 |             if idx % write_frequency == 0:
48 |                 txn.commit()
49 |                 txn = db.begin(write=True)
50 |         pbar.update(1)
51 | 
52 |     txn.commit()
53 |     with db.begin(write=True) as txn:
54 |         txn.put(b'__keys__', pickle.dumps(writtenKeys, protocol=4))
55 |         txn.put(b'__len__', pickle.dumps(len(writtenKeys), protocol=4))
56 | 
57 |     print("Flushing database ...")
58 |     db.sync()
59 |     db.close()
60 | 
61 | 
62 | if __name__ == '__main__':
63 |     srcDir = 'Path/to/snufilm-test'
64 |     lmdb_path = 'Path/to/snufilm/snufilm_lmdb'
65 |     png2LMDB(srcDir, lmdb_path)
66 | 
--------------------------------------------------------------------------------
/mkDataset/forSnufilm/txt2sample.py:
--------------------------------------------------------------------------------
1 | import pickle
2 | import re
3 | from pathlib import Path
4 | 
5 | 
6 | def extractKey(name: str):
7 |     parts = Path(name).parts
8 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
9 |     return key
10 | 
11 | 
12 | def extendframe(I0: str, It: str, I1: str):
13 |     parent = Path(I0).parent
14 |     I0Idx = Path(I0).stem
15 |     ItIdx = Path(It).stem
16 |     I1Idx = Path(I1).stem
17 |     len0 = len(I0Idx)
18 |     gap = int(I1Idx) - int(I0Idx)
19 |     In = str(parent / f'{int(I0Idx) - gap:0{len0}d}.png')
20 |     I2 = str(parent / f'{int(I1Idx) + gap:0{len0}d}.png')
21 |     return In, I0, I1, I2, It
22 | 
23 | 
24 | def main(dataPath, txtPath, picklePath):
25 |     allSamples = []
26 |     with open(txtPath, mode='r') as f:
27 |         lines = f.readlines()
28 |     for line in lines:
29 |         sampleDict = {}
30 |         I0, It, I1 = line.strip('\n').split(' ')
31 |         I0, It, I1 = [extractKey(i) for i in [I0, It, I1]]
32 |         In, I0, I1, I2, It = extendframe(I0, It, I1)
33 | 
34 |         sampleDict['In'] = In
35 |         sampleDict['I0'] = I0
36 |         sampleDict['It'] = It
37 |         sampleDict['I1'] = I1
38 |         sampleDict['I2'] = I2
39 |         if all([(Path(dataPath) / i).is_file() for i in [In, I0, I1, I2, It]]):
40 |             allSamples.append(sampleDict)
41 | 
42 |     with open(picklePath, 'wb') as f:
43 |         pickle.dump(allSamples, f)
44 | 
45 | 
46 | if __name__ == '__main__':
47 |     dataPath = 'Path/to/snufilm-test'
48 | 
49 |     txtPath = 'Path/to/snufilm-test/eval_modes/test-extreme.txt'  # easy, medium, hard, extreme
50 |     picklePath = 'Path/to/snufilm/snufilm_lmdb/test-extreme.pkl'  # easy, medium, hard, extreme
51 |     main(dataPath, txtPath, picklePath)
52 | 
--------------------------------------------------------------------------------
/mkDataset/forUCF101/getTestList.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import pickle
4 | from lib import fileTool as FT
5 | 
6 | 
7 | def extractKey(name: str):
8 |     parent = Path(name).parent.name
9 |     fileName = Path(name).name
10 |     key = f'{parent}/{fileName}'
11 |     return key
12 | 
13 | 
14 | def genSampleTest(srcDir, samplePath, numIter):
15 |     allSubDirs = FT.getSubDirs(srcDir)
16 |     allSamples = []
17 | 
18 |     for subDir in allSubDirs:
19 |         sampleDict = {}
20 |         allFrames = FT.getAllFiles(subDir, 'png')
21 |         for fname in allFrames:
22 |             if 'frame0.png' in fname:
23 |                 sampleDict['In'] = extractKey(fname)
24 |             if 'frame1.png' in fname:
25 |                 sampleDict['I0'] = extractKey(fname)
26 |             if 'frame2.png' in fname:
27 |                 sampleDict['I1'] = extractKey(fname)
28 |             if 'frame3.png' in fname:
29 |                 sampleDict['I2'] = extractKey(fname)
30 |             if 'framet.png' in fname:
31 |                 sampleDict['It'] = extractKey(fname)
32 |         allSamples.append(sampleDict)
33 | 
34 |     with open(samplePath, 'wb') as f:
35 |         pickle.dump(allSamples, f)
36 | 
37 | 
38 | if __name__ == '__main__':
39 |     numIter = 1
40 |     srcDir = 'Path/to/ucf101/'
41 |     samplePath = 'Path/to/ucf101/ucf101_lmdb/sample.pkl'
42 | 
43 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
44 | 
--------------------------------------------------------------------------------
/mkDataset/forUCF101/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parent = Path(name).parent.name
15 |     fileName = Path(name).name
16 |     key = f'{parent}/{fileName}'
17 |     return key
18 | 
19 | 
20 | def png2LMDB(srcDir, lmdb_path):
21 |     allPNGs = FT.getAllFiles(srcDir, 'png')
22 |     pbar = tqdm(total=len(allPNGs))
23 | 
24 |     isdir = os.path.isdir(lmdb_path)
25 |     write_frequency = 1000
26 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
27 |                    meminit=False, map_async=True)
28 |     txn = db.begin(write=True)
29 |     for idx, aPNG in enumerate(allPNGs):
30 |         key = extractKey(aPNG)
31 | 
32 |         img = cv2.imread(aPNG)
33 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
34 | 
35 |         sample = pickle.dumps(img, protocol=4)
36 |         txn.put(u'{}'.format(key).encode('ascii'), sample)
37 |         if idx % write_frequency == 0:
38 |             txn.commit()
39 |             txn = db.begin(write=True)
40 |         pbar.update(1)
41 | 
42 |     txn.commit()
43 |     keys = [extractKey(p).encode('ascii') for p in allPNGs]  # the actual string keys written above
44 |     with db.begin(write=True) as txn:
45 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
46 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
47 | 
48 |     print("Flushing database ...")
49 |     db.sync()
50 |     db.close()
51 | 
52 | 
53 | if __name__ == '__main__':
54 |     srcDir = 'Path/to/ucf101/'
55 |     lmdb_path = 'Path/to/ucf101/ucf101_lmdb'
56 |     png2LMDB(srcDir, lmdb_path)
57 | 
--------------------------------------------------------------------------------
/mkDataset/forVimeo/getTrainTestList.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | import re
3 | import pickle
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | 
7 | 
8 | def extractKey(name: str):
9 |     parts = Path(name).parts
10 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
11 |     return key
12 | 
13 | 
14 | def genSampleTest(srcDir, samplePath, numIter):
15 |     allSubDirs1 = FT.getSubDirs(srcDir)
16 |     allSamples = []
17 | 
18 |     for subDir1 in tqdm(allSubDirs1):
19 |         allsubDirs2 = FT.getSubDirs(subDir1)
20 |         for subDir2 in allsubDirs2:
21 |             allFrames = FT.getAllFiles(subDir2, 'png')
22 |             sampleDict = {}
23 | 
24 |             for aPNG in allFrames:
25 |                 name = Path(aPNG).stem
26 |                 if name == 'im1':
27 |                     sampleDict['In'] = extractKey(aPNG)
28 |                 if name == 'im3':
29 |                     sampleDict['I0'] = extractKey(aPNG)
30 |                 if name == 'im5':
31 |                     sampleDict['I1'] = extractKey(aPNG)
32 |                 if name == 'im7':
33 |                     sampleDict['I2'] = extractKey(aPNG)
34 |                 if name == 'im4':
35 |                     sampleDict['It'] = extractKey(aPNG)
36 | 
37 |             allSamples.append(sampleDict)
38 | 
39 |     with open(samplePath, 'wb') as f:
40 |         pickle.dump(allSamples, f)
41 | 
42 | 
43 | if __name__ == '__main__':
44 |     numIter = 1
45 |     srcDir = 'Path/to/vimeo_septuplet/train'  # 'Path/to/vimeo_septuplet/test'
46 |     samplePath = 'Path/to/vimeo/train_lmdb/sample.pkl'  # 'Path/to/vimeo/test_lmdb/sample.pkl'
47 | 
48 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
49 | 
--------------------------------------------------------------------------------
/mkDataset/forVimeo/png2lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | 
3 | os.environ['OMP_NUM_THREADS'] = '1'
4 | from lib import fileTool as FT
5 | from tqdm import tqdm
6 | import os
7 | import lmdb
8 | import pickle
9 | import cv2
10 | from pathlib import Path
11 | 
12 | 
13 | def extractKey(name: str):
14 |     parts = Path(name).parts
15 |     key = f'{parts[-3]}/{parts[-2]}/{parts[-1]}'
16 |     return key
17 | 
18 | 
19 | def png2LMDB(srcDir, lmdb_path):
20 |     allPNGs = FT.getAllFiles(srcDir, 'png')
21 | 
22 |     pbar = tqdm(total=len(allPNGs))
23 | 
24 |     isdir = os.path.isdir(lmdb_path)
25 |     write_frequency = 1000
26 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
27 |                    meminit=False, map_async=True)
28 |     txn = db.begin(write=True)
29 |     for idx, aPNG in enumerate(allPNGs):
30 |         if Path(aPNG).stem in ['im1', 'im3', 'im5', 'im7', 'im4']:
31 |             key = extractKey(aPNG)
32 | 
33 |             img = cv2.imread(aPNG)
34 |             img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
35 | 
36 |             sample = pickle.dumps(img, protocol=4)
37 |             txn.put(u'{}'.format(key).encode('ascii'), sample)
38 |             if idx % write_frequency == 0:
39 |                 txn.commit()
40 |                 txn = db.begin(write=True)
41 |         pbar.update(1)
42 | 
43 |     txn.commit()
44 |     keys = [extractKey(p).encode('ascii') for p in allPNGs if Path(p).stem in ['im1', 'im3', 'im5', 'im7', 'im4']]  # only frames actually written
45 |     with db.begin(write=True) as txn:
46 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
47 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
48 | 
49 |     print("Flushing database ...")
50 |     db.sync()
51 |     db.close()
52 | 
53 | 
54 | if __name__ == '__main__':
55 |     srcDir = 'Path/to/vimeo_septuplet/train'
56 |     lmdb_path = 'Path/to/vimeo/train_lmdb'
57 |     png2LMDB(srcDir, lmdb_path)
58 | 
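All of the png2lmdb scripts above share one storage contract: every frame is written under the ASCII-encoded relative key returned by `extractKey`, and the value is the pickled RGB `uint8` array. A minimal read-back sketch of that contract (the database path and the sample key below are illustrative placeholders, not values shipped with the repo):

```python
import lmdb
import pickle

# Open the database read-only; lock=False avoids writer-style file locking,
# readahead=False tends to behave better for random access on large DBs.
env = lmdb.open('Path/to/vimeo/train_lmdb', readonly=True, lock=False,
                readahead=False, meminit=False)
with env.begin(write=False) as txn:
    buf = txn.get('00001/0001/im4.png'.encode('ascii'))  # hypothetical key
    assert buf is not None, 'key not found in LMDB'
    img = pickle.loads(buf)  # HxWx3 uint8 RGB array, as written by png2lmdb.py
print(img.shape)
env.close()
```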
--------------------------------------------------------------------------------
/mkDataset/forVimeo/splitTrainTest.py:
--------------------------------------------------------------------------------
1 | import lib.fileTool as FT
2 | from pathlib import Path
3 | from tqdm import tqdm
4 | 
5 | def split(srcDir, trainDir, testDir, sepTrain, sepTest):
6 |     with open(sepTrain, 'r') as f:
7 |         trainlist = f.read().splitlines()
8 |     with open(sepTest, 'r') as f:
9 |         testlist = f.read().splitlines()
10 |     for atrain in tqdm(trainlist, leave=False):
11 |         src = str(Path(srcDir) / atrain)
12 |         assert Path(src).is_dir()
13 |         dst = src.replace(srcDir, trainDir)
14 |         FT.movFile(src, dst)
15 |         assert Path(dst).is_dir()
16 |     for atest in tqdm(testlist):
17 |         src = str(Path(srcDir) / atest)
18 |         assert Path(src).is_dir()
19 |         dst = src.replace(srcDir, testDir)
20 |         FT.movFile(src, dst)
21 |         assert Path(dst).is_dir()
22 | 
23 | 
24 | if __name__ == '__main__':
25 |     srcDir = 'Path/to/vimeo_septuplet/sequences'
26 |     trainDir = 'Path/to/vimeo_septuplet/train'
27 |     testDir = 'Path/to/vimeo_septuplet/test'
28 |     sepTrain = 'Path/to/vimeo_septuplet/sep_trainlist.txt'
29 |     sepTest = 'Path/to/vimeo_septuplet/sep_testlist.txt'
30 | 
31 |     split(srcDir, trainDir, testDir, sepTrain, sepTest)
32 | 
--------------------------------------------------------------------------------
/mkDataset/forXVFI/getTestListXVFi.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import pickle
4 | from lib import fileTool as FT
5 | 
6 | 
7 | def extractKey(name: str):
8 |     parts = Path(name).parts
9 |     key = f'{parts[-2]}/{parts[-1]}'
10 |     return key
11 | 
12 | 
13 | def genSampleTest(srcDir, samplePath, numIter):
14 |     allSubDirs = FT.getSubDirs(srcDir)
15 |     allSamples = []
16 |     keyV = ['In', 'I0', 'I1', 'I2']
17 |     keyT = [f'It{i + 1}' for i in range(numIter)]
18 |     keyVT = keyV + keyT
19 | 
20 |     for subDir in allSubDirs:
21 |         sampleDict = {}
22 |         allFrames = FT.getAllFiles(subDir, 'png')
23 | 
24 |         for fname in allFrames:
25 |             for akey in keyVT:
26 |                 if f'{akey}.png' in fname:
27 |                     sampleDict[akey] = extractKey(fname)
28 |                     break
29 |         allSamples.append(sampleDict)
30 | 
31 |     with open(samplePath, 'wb') as f:
32 |         pickle.dump(allSamples, f)
33 | 
34 | 
35 | if __name__ == '__main__':
36 |     numIter = 7
37 |     srcDir = './datasets/X4K1000FPS/test_frames/'
38 |     samplePath = './datasets/X4K1000FPS/X4k_lmdb/sample.pkl'
39 | 
40 |     genSampleTest(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
41 | 
--------------------------------------------------------------------------------
/mkDataset/forXVFI/getTrainListXVFi.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | 
3 | import re
4 | import pickle
5 | from lib import fileTool as FT
6 | from lib.dataUtils import sortFunc
7 | 
8 | 
9 | def extractKey(name: str):
10 |     parts = Path(name).parts
11 |     key = f'{parts[-2]}/{parts[-1]}'
12 |     return key
13 | 
14 | 
15 | def sortFunc(name: str):
16 |     idx = re.search(r'(\d*).png', str(name)).group(1)
17 |     return int(idx)
18 | 
19 | 
20 | def genSampleTrain(srcDir, samplePath, numIter=7):
21 |     idxV = [0, (numIter + 1) * 2, (2 * numIter + 2) * 2, (3 * numIter + 3) * 2]
22 |     idxT = [numIter * 2 + 2 + i * 2 + 2 for i in range(numIter)]
23 |     idxVT = idxV + idxT
24 | 
25 |     keyV = ['In', 'I0', 'I1', 'I2']
26 |     keyT = [f'It{i + 1}' for i in range(numIter)]
27 |     keyVT = keyV + keyT
28 | 
29 |     allSubDirs = FT.getSubDirs(srcDir)
30 |     allSamples = []
31 |     for subDir in allSubDirs:
32 |         allFrames = FT.getAllFiles(subDir, 'png')
33 |         allFrames.sort(key=sortFunc)
34 | 
35 |         # for startIdx in range(len(allFrames)):
36 |         # for startIdx in [0, 3, 8, 11, 14, 16]:
37 |         for startIdx in [0, 3, 8, 11, 16]:
38 |             sampleDict = {}
39 |             endIdx = startIdx + 49
40 |             asample = allFrames[startIdx:endIdx]
41 |             asample = [extractKey(i) for i in asample]
42 |             for idx, key in zip(idxVT, keyVT):
43 |                 sampleDict[key] = asample[idx]
44 |             allSamples.append(sampleDict)
45 | 
46 |     with open(samplePath, 'wb') as f:
47 |         pickle.dump(allSamples, f)
48 | 
49 | 
50 | if __name__ == '__main__':
51 |     numIter = 7
52 |     srcDir = '/data/dataset/X4K1000FPS/train'
53 |     samplePath = '/data/dataset/X4K1000FPS/train_lmdb/sample.pkl'
54 | 
55 |     genSampleTrain(srcDir=srcDir, samplePath=samplePath, numIter=numIter)
56 | 
--------------------------------------------------------------------------------
/mkDataset/forXVFI/mp4_decoding.py:
--------------------------------------------------------------------------------
1 | ## You need ffmpeg version 4 to support the option '-pred mixed', which is new in version 4.
2 | ## The option '-pred mixed' gives smaller .png file sizes (lossless compression).
3 | ## Older versions of ffmpeg can also be used, without the option '-pred mixed'.
4 | ## To install ffmpeg version 4 on Ubuntu, run the line below in a terminal.
5 | 
6 | # conda install -c conda-forge ffmpeg
7 | 
8 | ## Please modify the code lines below if needed.
9 | 
10 | import os, glob, sys
11 | 
12 | def check_folder(log_dir):
13 |     if not os.path.exists(log_dir):
14 |         os.makedirs(log_dir)
15 |         # print(log_dir, " created")
16 |     return log_dir
17 | 
18 | try:
19 |     #################################################################
20 |     ## Decode test set. About 6 GB with the option '-pred mixed'
21 |     #################################################################
22 |     # test_types = sorted(glob.glob('./encoded_test/*/'))
23 |     # for test_type in test_types:
24 |     #     samples = sorted(glob.glob(test_type + '*.mp4'))
25 |     #     for sample in samples:
26 |     #         new_dir = sample.replace('encoded_test', 'test').replace('.mp4', '')
27 |     #         check_folder(new_dir)
28 |     #         cmd = "ffmpeg -i {} -pred mixed -start_number 0 {}/%04d.png".format(sample, new_dir)  # if ffmpeg version >= 4
29 |     #         # cmd = "ffmpeg -i {} -start_number 0 {}/%04d.png".format(sample, new_dir)  # if ffmpeg version < 4
30 |     #         print(cmd)
31 |     #         if os.system(cmd):
32 |     #             raise KeyboardInterrupt
33 |     #
34 | 
35 |     #################################################################
36 |     ## Decode training set. About 240 GB with the option '-pred mixed'
37 |     #################################################################
38 |     scenes = sorted(glob.glob('/data/dataset/X4K1000FPS/encoded_train/*/'))
39 |     for scene in scenes:
40 |         samples = sorted(glob.glob(os.path.join(scene, '*.mp4')))
41 |         for sample in samples:
42 |             new_dir = sample.replace('encoded_train', 'train').replace('.mp4', '')
43 |             check_folder(new_dir)
44 |             cmd = "ffmpeg -i {} -pred mixed -start_number 0 {}/%04d.png".format(sample, new_dir)  # if ffmpeg version >= 4
45 |             # cmd = "ffmpeg -i {} -start_number 0 {}/%04d.png".format(sample, new_dir)  # if ffmpeg version < 4
46 |             print(cmd)
47 |             if os.system(cmd):
48 |                 raise KeyboardInterrupt
49 | 
50 |     #################################################################
51 | except KeyboardInterrupt:
52 |     print("KeyboardInterrupt")
53 |     sys.exit(0)
54 | 
--------------------------------------------------------------------------------
/mkDataset/forXVFI/png2lmdbXVFI.py:
--------------------------------------------------------------------------------
1 | from lib import fileTool as FT
2 | from tqdm import tqdm
3 | import os
4 | import lmdb
5 | import pickle
6 | import cv2
7 | import re
8 | from pathlib import Path
9 | from collections import OrderedDict
10 | 
11 | 
12 | def extractKey(name: str):
13 |     parts = Path(name).parts
14 |     key = f'{parts[-2]}/{parts[-1]}'
15 |     return key
16 | 
17 | 
18 | def png2LMDB(srcDir, lmdb_path):
19 |     allPNGs = FT.getAllFiles(srcDir, 'png')
20 | 
21 |     pbar = tqdm(total=len(allPNGs))
22 | 
23 |     isdir = os.path.isdir(lmdb_path)
24 |     write_frequency = 20
25 |     db = lmdb.open(lmdb_path, subdir=isdir, map_size=1099511627776 * 2, readonly=False,
26 |                    meminit=False, map_async=True)
27 |     txn = db.begin(write=True)
28 |     for idx, aPNG in enumerate(allPNGs):
29 |         key = extractKey(aPNG)
30 | 
31 |         img = cv2.imread(aPNG)
32 |         img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
33 | 
34 |         sample = pickle.dumps(img, protocol=4)
35 |         txn.put(u'{}'.format(key).encode('ascii'), sample)
36 |         if idx % write_frequency == 0:
37 |             txn.commit()
38 |             txn = db.begin(write=True)
39 |         pbar.update(1)
40 | 
41 |     txn.commit()
42 |     keys = [extractKey(p).encode('ascii') for p in allPNGs]  # the actual string keys written above
43 |     with db.begin(write=True) as txn:
44 |         txn.put(b'__keys__', pickle.dumps(keys, protocol=4))
45 |         txn.put(b'__len__', pickle.dumps(len(keys), protocol=4))
46 | 
47 |     print("Flushing database ...")
48 |     db.sync()
49 |     db.close()
50 | 
51 | 
52 | def mergeData():
53 |     srcPath = '/home/sensetime/data/VideoInterpolation/highfps/X4K1000FPS/test'
54 |     dstPath = '/home/sensetime/data/VideoInterpolation/highfps/X4K1000FPS/mytest'
55 |     allTypes = FT.getSubDirs(srcPath)
56 |     total = 0
57 |     for aType in allTypes:
58 |         allseqs = FT.getSubDirs(aType)
59 |         allseqs = [str(i) for i in allseqs]
60 |         INames = OrderedDict({'0968.png': 'In.png', '0000.png': 'I0.png', '0032.png': 'I1.png', '1064.png': 'I2.png',
61 |                               '0004.png': 'It1.png', '0008.png': 'It2.png', '0012.png': 'It3.png',
62 |                               '0016.png': 'It4.png', '0020.png': 'It5.png', '0024.png': 'It6.png',
63 |                               '0028.png': 'It7.png'})
64 |         for aseq in allseqs:
65 |             Is = [str(Path(aseq) / i) for i in INames.keys()]
66 |             newDirName = str(Path(dstPath) / f'{Path(aType).stem}_{Path(aseq).stem}')
67 |             FT.mkPath(newDirName)
68 |             for I in Is:
69 |                 key = re.search(r'(\d){4}.png', I).group(0)
70 |                 value = INames[key]
71 |                 dstName = str(Path(newDirName) / value)
72 |                 FT.copyFile(I, dstName)
73 |             pass
74 | 
75 |     pass
76 | 
77 | 
78 | if __name__ == '__main__':
79 |     # mergeData()
80 |     srcDir = 'Path/to/X4K1000FPS/train'
81 |     lmdb_path = 'Path/to/train_lmdb'
82 |     samplePath = 'Path/to/train_lmdb/samples.pkl'
83 |     png2LMDB(srcDir, lmdb_path)
84 |     # checkTest()
85 | 
--------------------------------------------------------------------------------
/mkDataset/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Oceanlib/DBVI/7db961500656c8b37706d5c70547c82a547bb838/mkDataset/lib/__init__.py
--------------------------------------------------------------------------------
/mkDataset/lib/dataUtils.py:
--------------------------------------------------------------------------------
1 | import re
2 | 
3 | def sortFunc(name: str):
4 |     idx = re.search(r'(\d*).png', str(name)).group(1)
5 |     return int(idx)
6 | 
7 | 
8 | def sample2idx(mapdict, samplelist):
9 |     idxlist = []
10 |     for sample in samplelist:
11 |         idx = mapdict[sample]
12 |         idxlist.append(idx)
13 |     return tuple(idxlist)
--------------------------------------------------------------------------------
/mkDataset/lib/fileTool.py:
--------------------------------------------------------------------------------
1 | import shutil
2 | from pathlib import Path
3 | from queue import Queue
4 | 
5 | 
6 | def delPath(root):
7 |     """
8 |     Recursively remove a file or a whole directory tree.
9 |     :param root: file or root dir to delete
10 |     """
11 |     root = Path(root)
12 |     if root.is_file():
13 |         try:
14 |             root.unlink()
15 |         except Exception as e:
16 |             print(e)
17 |     elif root.is_dir():
18 |         for item in root.iterdir():
19 |             delPath(item)
20 |         try:
21 |             root.rmdir()
22 |             # print('Files in {} is removed'.format(root))
23 |         except Exception as e:
24 |             print(e)
25 | 
26 | 
27 | def mkPath(path):
28 |     p = Path(path)
29 |     try:
30 |         p.mkdir(parents=True, exist_ok=False)
31 |         return True
32 |     except Exception:
33 |         return False
34 | 
35 | 
36 | def getAllFiles(root, ext=None):
37 |     p = Path(root)
38 |     if ext is not None:
39 |         pathnames = p.glob("**/*.{}".format(ext))
40 |     else:
41 |         pathnames = p.glob("**/*")
42 |     filenames = sorted([x.as_posix() for x in pathnames])
43 |     return filenames
44 | 
45 | 
46 | def copyFile(src, dst):
47 |     parent = Path(dst).parent
48 |     mkPath(parent)
49 |     try:
50 |         shutil.copytree(str(src), str(dst))
51 |     except Exception:  # src is a regular file, not a directory
52 |         shutil.copy(str(src), str(dst))
53 | 
54 | 
55 | def movFile(src, dst):
56 |     parent = Path(dst).parent
57 |     mkPath(parent)
58 |     shutil.move(str(src), str(dst))
59 | 
60 | 
61 | def getSubDirs(root):
62 |     p = Path(root)
63 |     dirs = [x for x in p.iterdir() if x.is_dir()]
64 |     return sorted(dirs)
65 | 
66 | 
67 | class fileBuffer(Queue):
68 |     def __init__(self, capacity):
69 |         super(fileBuffer, self).__init__()
70 |         assert capacity > 0
71 |         self.capacity = capacity
72 | 
73 |     def __call__(self, x):
74 |         while self.qsize() >= self.capacity:
75 |             delPath(self.get())
76 |         self.put(x)
77 | 
--------------------------------------------------------------------------------
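Taken together, each mkDataset pipeline produces two artifacts per split: an LMDB of pickled RGB frames and a sample.pkl holding a list of dicts that map slot names ('In', 'I0', 'I1', 'I2', and 'It' or 'It1'..'It7') to LMDB keys. A minimal sketch of how the two files can be consumed together (the class below is illustrative only; the repo's real readers are the dataloaders under DBVI_2x/dataloader and DBVI_8x/dataloader):

```python
import lmdb
import pickle


class SampleReader:
    """Illustrative pairing of sample.pkl with its LMDB; not the repo's dataloader."""

    def __init__(self, lmdb_path, sample_path):
        self.env = lmdb.open(lmdb_path, readonly=True, lock=False,
                             readahead=False, meminit=False)
        with open(sample_path, 'rb') as f:
            self.samples = pickle.load(f)  # list of dicts: slot name -> LMDB key

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, i):
        sample = self.samples[i]
        with self.env.begin(write=False) as txn:
            # every stored value is a pickled HxWx3 RGB uint8 array
            return {slot: pickle.loads(txn.get(key.encode('ascii')))
                    for slot, key in sample.items()}
```

Usage would then look like `reader = SampleReader('Path/to/vimeo/test_lmdb', 'Path/to/vimeo/test_lmdb/sample.pkl')` followed by `frames = reader[0]`, which returns one dict of decoded frames per sample.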