├── LICENSE ├── README.md ├── dataset ├── __init__.py ├── davis.py ├── dutsv2.py ├── fbms.py ├── lvid.py ├── transforms.py ├── ytobj.py └── ytvos.py ├── evaluation ├── __init__.py ├── evaluator.py └── metrics.py ├── fakeflow.py ├── requirements.txt ├── run.py ├── trainer.py ├── utils.py └── weights └── empty.txt /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Suhwan Cho 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FakeFlow 2 | 3 | This is the official PyTorch implementation of our paper: 4 | 5 | > **Improving Unsupervised Video Object Segmentation via Fake Flow Generation**, *arXiv 2024*\ 6 | > Suhwan Cho, Minhyeok Lee, Jungho Lee, DongHyeong Kim, Seunghoon Lee, Sungmin Woo, Sangyoun Lee\ 7 | > Link: [[arXiv]](https://arxiv.org/pdf/2407.11714) 8 | 9 | 10 | 11 | 12 | 13 | You can also find other related papers at [awesome-video-object-segmentation](https://github.com/suhwan-cho/awesome-video-object-segmentation). 14 | 15 | 16 | ## Abstract 17 | In unsupervised VOS, the scarcity of training data has been a significant bottleneck in achieving high segmentation accuracy. Inspired by 18 | observations on two-stream approaches, we introduce a novel data generation method based on the **depth-to-flow conversion** process. With our fake flow generation protocol, 19 | large-scale image-flow pairs can be leveraged during network training. To facilitate future research, we also prepare the **DUTSv2** dataset, which is an extended version of DUTS, 20 | comprising pairs of the original images and the simulated flow maps. 21 | 22 | 23 | ## Preparation 24 | 1\. Download 25 | [DUTS](http://saliencydetection.net/duts/#org3aad434), 26 | [DAVIS](https://davischallenge.org/davis2017/code.html), 27 | [FBMS](https://lmb.informatik.uni-freiburg.de/resources/datasets), 28 | [YouTube-Objects](https://data.vision.ee.ethz.ch/cvl/youtube-objects), 29 | and [Long-Videos](https://www.kaggle.com/datasets/gvclsu/long-videos) 30 | from the official websites. 31 | 32 | 2\. Estimate and save optical flow maps from the videos using [RAFT](https://github.com/princeton-vl/RAFT). 33 | 34 | 3\. For DUTS, simulate optical flow maps using [DPT](https://github.com/isl-org/DPT). 35 | 36 | 4\. 
For convenience, I also provide the pre-processed 37 | [DUTSv2](https://drive.google.com/file/d/1P8_USG8CWlpWm5UEcfXgXdr3IYQnhAvi/view?usp=drive_link), 38 | [DAVIS](https://drive.google.com/file/d/1kx-Cs5qQU99dszJQJOGKNb-wD_090q6c/view?usp=drive_link), 39 | [FBMS](https://drive.google.com/file/d/1Zgt5ouwFeTpMTemfNeEFz7uEUo77e2ml/view?usp=drive_link), 40 | [YouTube-Objects](https://drive.google.com/file/d/1t_eeHXJ30TWBNmMzE7vfS0izEafiBfgn/view?usp=drive_link), 41 | and [Long-Videos](https://drive.google.com/file/d/1gZm1QBT_6JmHhphNrxuSztcqkm_eI6Sq/view?usp=drive_link). 42 | 43 | 5\. Replace dataset paths in "run.py" file with your dataset paths. 44 | 45 | 46 | ## Training 47 | 1\. Open the "run.py" file. 48 | 49 | 2\. Specify the model version. 50 | 51 | 3\. Verify the training settings. 52 | 53 | 4\. Start **FakeFlow** training! 54 | ``` 55 | python run.py --train 56 | ``` 57 | 58 | 59 | ## Testing 60 | 1\. Open the "run.py" file. 61 | 62 | 2\. Specify the model version. 63 | 64 | 3\. Choose a pre-trained model. 65 | 66 | 4\. Start **FakeFlow** testing! 67 | ``` 68 | python run.py --test 69 | ``` 70 | 71 | 72 | ## Attachments 73 | [pre-trained model (mitb0)](https://drive.google.com/file/d/1FFz9buCu5XCl1LUpwwIUZJNvcAH4V_QG/view?usp=drive_link)\ 74 | [pre-trained model (mitb1)](https://drive.google.com/file/d/1DhNsNoF2borozWU5JbrJHIixUoW4hs-Y/view?usp=drive_link)\ 75 | [pre-trained model (mitb2)](https://drive.google.com/file/d/1GYAlCt97kcNjtcXoZUGkhIQXgovwV0fz/view?usp=drive_link)\ 76 | [pre-computed results](https://drive.google.com/file/d/1OiIaVPf51kqAzGYqFtFl8lLsw-yuY5fi/view?usp=sharing) 77 | 78 | 79 | ## Note 80 | Code and models are only available for non-commercial research purposes.\ 81 | If you have any questions, please feel free to contact me :) 82 | ``` 83 | E-mail: suhwanx@gmail.com 84 | ``` 85 | -------------------------------------------------------------------------------- /dataset/__init__.py: -------------------------------------------------------------------------------- 1 | from .ytvos import * 2 | from .dutsv2 import * 3 | from .davis import * 4 | from .fbms import * 5 | from .ytobj import * 6 | from .lvid import * 7 | from .transforms import * 8 | -------------------------------------------------------------------------------- /dataset/davis.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | import random 4 | from glob import glob 5 | from PIL import Image 6 | import torchvision as tv 7 | import torchvision.transforms.functional as TF 8 | 9 | 10 | class TrainDAVIS(torch.utils.data.Dataset): 11 | def __init__(self, root, year, split, clip_n): 12 | self.root = root 13 | with open(os.path.join(root, 'ImageSets', '{}/{}.txt'.format(year, split)), 'r') as f: 14 | self.video_list = f.read().splitlines() 15 | self.clip_n = clip_n 16 | self.to_tensor = tv.transforms.ToTensor() 17 | self.to_mask = LabelToLongTensor() 18 | 19 | def __len__(self): 20 | return self.clip_n 21 | 22 | def __getitem__(self, idx): 23 | video_name = random.choice(self.video_list) 24 | img_dir = os.path.join(self.root, 'JPEGImages', '480p', video_name) 25 | flow_dir = os.path.join(self.root, 'JPEGFlows', '480p', video_name) 26 | mask_dir = os.path.join(self.root, 'Annotations', '480p', video_name) 27 | img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) 28 | flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 29 | mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 30 | 31 | # select training frame 32 | 
all_frames = list(range(len(img_list))) 33 | frame_id = random.choice(all_frames) 34 | img = Image.open(img_list[frame_id]).convert('RGB') 35 | flow = Image.open(flow_list[frame_id]).convert('RGB') 36 | mask = Image.open(mask_list[frame_id]).convert('P') 37 | 38 | # resize to 512p 39 | img = img.resize((512, 512), Image.BICUBIC) 40 | flow = flow.resize((512, 512), Image.BICUBIC) 41 | mask = mask.resize((512, 512), Image.NEAREST) 42 | 43 | # joint flip 44 | if random.random() > 0.5: 45 | img = TF.hflip(img) 46 | flow = TF.hflip(flow) 47 | mask = TF.hflip(mask) 48 | if random.random() > 0.5: 49 | img = TF.vflip(img) 50 | flow = TF.vflip(flow) 51 | mask = TF.vflip(mask) 52 | 53 | # convert formats 54 | imgs = self.to_tensor(img).unsqueeze(0) 55 | flows = self.to_tensor(flow).unsqueeze(0) 56 | masks = self.to_mask(mask).unsqueeze(0) 57 | masks = (masks != 0).long() 58 | return {'imgs': imgs, 'flows': flows, 'masks': masks} 59 | 60 | 61 | class TestDAVIS(torch.utils.data.Dataset): 62 | def __init__(self, root, year, split): 63 | self.root = root 64 | self.year = year 65 | self.split = split 66 | self.init_data() 67 | 68 | def read_img(self, path): 69 | pic = Image.open(path).convert('RGB') 70 | transform = tv.transforms.ToTensor() 71 | return transform(pic) 72 | 73 | def read_mask(self, path): 74 | pic = Image.open(path).convert('P') 75 | transform = LabelToLongTensor() 76 | return transform(pic) 77 | 78 | def init_data(self): 79 | with open(os.path.join(self.root, 'ImageSets', self.year, self.split + '.txt'), 'r') as f: 80 | self.video_list = sorted(f.read().splitlines()) 81 | print('--- DAVIS {} {} loaded for testing ---'.format(self.year, self.split)) 82 | 83 | def get_snippet(self, video_name, frame_ids): 84 | img_path = os.path.join(self.root, 'JPEGImages', '480p', video_name) 85 | flow_path = os.path.join(self.root, 'JPEGFlows', '480p', video_name) 86 | mask_path = os.path.join(self.root, 'Annotations', '480p', video_name) 87 | imgs = torch.stack([self.read_img(os.path.join(img_path, '{:05d}.jpg'.format(i))) for i in frame_ids]).unsqueeze(0) 88 | flows = torch.stack([self.read_img(os.path.join(flow_path, '{:05d}.jpg'.format(i))) for i in frame_ids]).unsqueeze(0) 89 | masks = torch.stack([self.read_mask(os.path.join(mask_path, '{:05d}.png'.format(i))) for i in frame_ids]).unsqueeze(0) 90 | if self.year == '2016': 91 | masks = (masks != 0).long() 92 | files = ['{:05d}.png'.format(i) for i in frame_ids] 93 | return {'imgs': imgs, 'flows': flows, 'masks': masks, 'files': files} 94 | 95 | def get_video(self, video_name): 96 | frame_ids = sorted([int(file[:5]) for file in os.listdir(os.path.join(self.root, 'JPEGImages', '480p', video_name))]) 97 | yield self.get_snippet(video_name, frame_ids) 98 | 99 | def get_videos(self): 100 | for video_name in self.video_list: 101 | yield video_name, self.get_video(video_name) 102 | -------------------------------------------------------------------------------- /dataset/dutsv2.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | import random 4 | from glob import glob 5 | from PIL import Image 6 | import torchvision as tv 7 | import torchvision.transforms.functional as TF 8 | 9 | 10 | class TrainDUTSv2(torch.utils.data.Dataset): 11 | def __init__(self, root, clip_n): 12 | self.root = root 13 | img_dir = os.path.join(root, 'JPEGImages') 14 | flow_dir = os.path.join(root, 'JPEGFlows') 15 | mask_dir = os.path.join(root, 'Annotations') 16 | self.img_list = 
sorted(glob(os.path.join(img_dir, '*.jpg'))) 17 | self.flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 18 | self.mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 19 | self.clip_n = clip_n 20 | self.to_tensor = tv.transforms.ToTensor() 21 | 22 | def __len__(self): 23 | return self.clip_n 24 | 25 | def __getitem__(self, idx): 26 | all_frames = list(range(len(self.img_list))) 27 | frame_id = random.choice(all_frames) 28 | img = Image.open(self.img_list[frame_id]).convert('RGB') 29 | flow = Image.open(self.flow_list[frame_id]).convert('RGB') 30 | mask = Image.open(self.mask_list[frame_id]).convert('L') 31 | 32 | # resize to 512p 33 | img = img.resize((512, 512), Image.BICUBIC) 34 | flow = flow.resize((512, 512), Image.BICUBIC) 35 | mask = mask.resize((512, 512), Image.BICUBIC) 36 | 37 | # joint flip 38 | if random.random() > 0.5: 39 | img = TF.hflip(img) 40 | flow = TF.hflip(flow) 41 | mask = TF.hflip(mask) 42 | if random.random() > 0.5: 43 | img = TF.vflip(img) 44 | flow = TF.vflip(flow) 45 | mask = TF.vflip(mask) 46 | 47 | # convert formats 48 | imgs = self.to_tensor(img).unsqueeze(0) 49 | flows = self.to_tensor(flow).unsqueeze(0) 50 | masks = self.to_tensor(mask).unsqueeze(0) 51 | masks = (masks > 0.5).long() 52 | return {'imgs': imgs, 'flows': flows, 'masks': masks} 53 | -------------------------------------------------------------------------------- /dataset/fbms.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | from glob import glob 4 | from PIL import Image 5 | import torchvision as tv 6 | 7 | 8 | class TestFBMS(torch.utils.data.Dataset): 9 | def __init__(self, root): 10 | self.root = root 11 | self.video_list = sorted(os.listdir(os.path.join(root, 'JPEGImages'))) 12 | self.to_tensor = tv.transforms.ToTensor() 13 | 14 | def __len__(self): 15 | return len(self.video_list) 16 | 17 | def __getitem__(self, idx): 18 | video_name = self.video_list[idx] 19 | img_dir = os.path.join(self.root, 'JPEGImages', video_name) 20 | flow_dir = os.path.join(self.root, 'JPEGFlows', video_name) 21 | mask_dir = os.path.join(self.root, 'Annotations', video_name) 22 | img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) 23 | flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 24 | mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 25 | 26 | # generate testing snippets 27 | imgs = [] 28 | flows = [] 29 | masks = [] 30 | for i in range(len(img_list)): 31 | img = Image.open(img_list[i]).convert('RGB') 32 | imgs.append(self.to_tensor(img)) 33 | for i in range(len(flow_list)): 34 | flow = Image.open(flow_list[i]).convert('RGB') 35 | flows.append(self.to_tensor(flow)) 36 | for i in range(len(mask_list)): 37 | mask = Image.open(mask_list[i]).convert('L') 38 | masks.append(self.to_tensor(mask)) 39 | 40 | # gather all frames 41 | imgs = torch.stack(imgs, dim=0) 42 | flows = torch.stack(flows, dim=0) 43 | masks = torch.stack(masks, dim=0) 44 | masks = (masks > 0.5).long() 45 | return {'imgs': imgs, 'flows': flows, 'masks': masks, 'video_name': video_name, 'files': mask_list} 46 | -------------------------------------------------------------------------------- /dataset/lvid.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | from glob import glob 4 | from PIL import Image 5 | import torchvision as tv 6 | 7 | 8 | class TestLVID(torch.utils.data.Dataset): 9 | def __init__(self, root): 10 | self.root = root 11 | self.video_list 
= sorted(os.listdir(os.path.join(root, 'JPEGImages'))) 12 | self.to_tensor = tv.transforms.ToTensor() 13 | self.to_mask = LabelToLongTensor() 14 | 15 | def __len__(self): 16 | return len(self.video_list) 17 | 18 | def __getitem__(self, idx): 19 | video_name = self.video_list[idx] 20 | img_dir = os.path.join(self.root, 'JPEGImages', video_name) 21 | flow_dir = os.path.join(self.root, 'JPEGFlows', video_name) 22 | mask_dir = os.path.join(self.root, 'Annotations', video_name) 23 | img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) 24 | flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 25 | mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 26 | 27 | # generate testing snippets 28 | imgs = [] 29 | flows = [] 30 | masks = [] 31 | for i in range(len(img_list)): 32 | img = Image.open(img_list[i]).convert('RGB') 33 | imgs.append(self.to_tensor(img)) 34 | for i in range(len(flow_list)): 35 | flow = Image.open(flow_list[i]).convert('RGB') 36 | flows.append(self.to_tensor(flow)) 37 | for i in range(len(mask_list)): 38 | mask = Image.open(mask_list[i]).convert('P') 39 | masks.append(self.to_mask(mask)) 40 | 41 | # gather all frames 42 | imgs = torch.stack(imgs, dim=0) 43 | flows = torch.stack(flows, dim=0) 44 | masks = torch.stack(masks, dim=0) 45 | masks = (masks != 0).long() 46 | return {'imgs': imgs, 'flows': flows, 'masks': masks, 'video_name': video_name, 'files': mask_list} 47 | -------------------------------------------------------------------------------- /dataset/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | class LabelToLongTensor(object): 6 | def __call__(self, pic): 7 | if isinstance(pic, np.ndarray): 8 | label = torch.from_numpy(pic).long() 9 | elif pic.mode == '1': 10 | label = torch.from_numpy(np.array(pic, np.uint8, copy=False)).long().view(1, pic.size[1], pic.size[0]) 11 | else: 12 | label = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes())) 13 | if pic.mode == 'LA': 14 | label = label.view(pic.size[1], pic.size[0], 2) 15 | label = label.transpose(0, 1).transpose(0, 2).contiguous().long()[0] 16 | label = label.view(1, label.size(0), label.size(1)) 17 | else: 18 | label = label.view(pic.size[1], pic.size[0], -1) 19 | label = label.transpose(0, 1).transpose(0, 2).contiguous().long() 20 | label[label == 255] = 0 21 | return label 22 | -------------------------------------------------------------------------------- /dataset/ytobj.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | from glob import glob 4 | from PIL import Image 5 | import torchvision as tv 6 | 7 | 8 | class TestYTOBJ(torch.utils.data.Dataset): 9 | def __init__(self, root): 10 | self.root = root 11 | self.video_list = [] 12 | class_list = sorted(os.listdir(os.path.join(root, 'JPEGImages'))) 13 | for class_name in class_list: 14 | video_list = sorted(os.listdir(os.path.join(root, 'JPEGImages', class_name))) 15 | for video_name in video_list: 16 | self.video_list.append(class_name + '_' + video_name) 17 | self.to_tensor = tv.transforms.ToTensor() 18 | 19 | def __len__(self): 20 | return len(self.video_list) 21 | 22 | def __getitem__(self, idx): 23 | class_name = self.video_list[idx].split('_')[0] 24 | video_name = self.video_list[idx].split('_')[1] 25 | img_dir = os.path.join(self.root, 'JPEGImages', class_name, video_name) 26 | flow_dir = os.path.join(self.root, 'JPEGFlows', class_name, video_name) 27 | mask_dir 
= os.path.join(self.root, 'Annotations', class_name, video_name) 28 | img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) 29 | flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 30 | mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 31 | 32 | # generate testing snippets 33 | imgs = [] 34 | flows = [] 35 | masks = [] 36 | for i in range(len(img_list)): 37 | img = Image.open(img_list[i]).convert('RGB') 38 | imgs.append(self.to_tensor(img)) 39 | for i in range(len(flow_list)): 40 | flow = Image.open(flow_list[i]).convert('RGB') 41 | flows.append(self.to_tensor(flow)) 42 | for i in range(len(mask_list)): 43 | mask = Image.open(mask_list[i]).convert('L') 44 | masks.append(self.to_tensor(mask)) 45 | 46 | # gather all frames 47 | imgs = torch.stack(imgs, dim=0) 48 | flows = torch.stack(flows, dim=0) 49 | masks = torch.stack(masks, dim=0) 50 | masks = (masks > 0.5).long() 51 | return {'imgs': imgs, 'flows': flows, 'masks': masks, 'class_name': class_name, 'video_name': video_name, 'files': mask_list} 52 | -------------------------------------------------------------------------------- /dataset/ytvos.py: -------------------------------------------------------------------------------- 1 | from .transforms import * 2 | import os 3 | import random 4 | from glob import glob 5 | from PIL import Image 6 | import torchvision as tv 7 | import torchvision.transforms.functional as TF 8 | 9 | 10 | class TrainYTVOS(torch.utils.data.Dataset): 11 | def __init__(self, root, split, clip_n): 12 | self.root = root 13 | self.split = split 14 | with open(os.path.join(root, 'ImageSets', '{}.txt'.format(split)), 'r') as f: 15 | self.video_list = f.read().splitlines() 16 | self.clip_n = clip_n 17 | self.to_tensor = tv.transforms.ToTensor() 18 | self.to_mask = LabelToLongTensor() 19 | 20 | def __len__(self): 21 | return self.clip_n 22 | 23 | def __getitem__(self, idx): 24 | video_name = random.choice(self.video_list) 25 | img_dir = os.path.join(self.root, self.split, 'JPEGImages', video_name) 26 | flow_dir = os.path.join(self.root, self.split, 'JPEGFlows', video_name) 27 | mask_dir = os.path.join(self.root, self.split, 'Annotations', video_name) 28 | img_list = sorted(glob(os.path.join(img_dir, '*.jpg'))) 29 | flow_list = sorted(glob(os.path.join(flow_dir, '*.jpg'))) 30 | mask_list = sorted(glob(os.path.join(mask_dir, '*.png'))) 31 | 32 | # select training frame 33 | all_frames = list(range(len(img_list))) 34 | frame_id = random.choice(all_frames) 35 | img = Image.open(img_list[frame_id]).convert('RGB') 36 | flow = Image.open(flow_list[frame_id]).convert('RGB') 37 | mask = Image.open(mask_list[frame_id]).convert('P') 38 | 39 | # resize to 512p 40 | img = img.resize((512, 512), Image.BICUBIC) 41 | flow = flow.resize((512, 512), Image.BICUBIC) 42 | mask = mask.resize((512, 512), Image.NEAREST) 43 | 44 | # joint flip 45 | if random.random() > 0.5: 46 | img = TF.hflip(img) 47 | flow = TF.hflip(flow) 48 | mask = TF.hflip(mask) 49 | if random.random() > 0.5: 50 | img = TF.vflip(img) 51 | flow = TF.vflip(flow) 52 | mask = TF.vflip(mask) 53 | 54 | # convert formats 55 | imgs = self.to_tensor(img).unsqueeze(0) 56 | flows = self.to_tensor(flow).unsqueeze(0) 57 | masks = self.to_mask(mask).unsqueeze(0) 58 | masks = (masks != 0).long() 59 | return {'imgs': imgs, 'flows': flows, 'masks': masks} 60 | -------------------------------------------------------------------------------- /evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | from .evaluator import 
Evaluator 2 | -------------------------------------------------------------------------------- /evaluation/evaluator.py: -------------------------------------------------------------------------------- 1 | import utils 2 | import os 3 | import time 4 | 5 | 6 | class Evaluator(object): 7 | def __init__(self, dataset): 8 | self.dataset = dataset 9 | self.img_saver = utils.ImageSaver() 10 | self.sdm = utils.DAVISLabels() 11 | 12 | def evaluate_video(self, model, video_name, video_parts, output_path): 13 | for vos_data in video_parts: 14 | imgs = vos_data['imgs'].cuda() 15 | flows = vos_data['flows'].cuda() 16 | files = vos_data['files'] 17 | 18 | # inference 19 | t0 = time.time() 20 | vos_out = model(imgs, flows) 21 | t1 = time.time() 22 | 23 | # save output 24 | for i in range(len(files)): 25 | fpath = os.path.join(output_path, video_name, files[i]) 26 | data = ((vos_out['masks'][0, i, 0, :, :].cpu().byte().numpy(), fpath), self.sdm) 27 | self.img_saver.enqueue(data) 28 | return t1 - t0, imgs.size(1) 29 | 30 | def evaluate(self, model, output_path): 31 | model.cuda() 32 | total_seconds, total_frames = 0, 0 33 | for video_name, video_parts in self.dataset.get_videos(): 34 | os.makedirs(os.path.join(output_path, video_name), exist_ok=True) 35 | seconds, frames = self.evaluate_video(model, video_name, video_parts, output_path) 36 | total_seconds = total_seconds + seconds 37 | total_frames = total_frames + frames 38 | print('{} done, {:.1f} fps'.format(video_name, frames / seconds)) 39 | print('total fps: {:.1f}\n'.format(total_frames / total_seconds)) 40 | self.img_saver.kill() 41 | -------------------------------------------------------------------------------- /evaluation/metrics.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import warnings 5 | 6 | 7 | def db_eval_iou(annotation, segmentation): 8 | annotation = annotation.astype(np.bool) 9 | segmentation = segmentation.astype(np.bool) 10 | void_pixels = np.zeros_like(segmentation) 11 | inters = np.sum((segmentation & annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 12 | union = np.sum((segmentation | annotation) & np.logical_not(void_pixels), axis=(-2, -1)) 13 | j = inters / union 14 | if j.ndim == 0: 15 | j = 1 if np.isclose(union, 0) else j 16 | else: 17 | j[np.isclose(union, 0)] = 1 18 | return j 19 | 20 | 21 | def db_eval_boundary(annotation, segmentation, bound_th=0.008): 22 | if annotation.ndim == 3: 23 | n_frames = annotation.shape[0] 24 | f_res = np.zeros(n_frames) 25 | for frame_id in range(n_frames): 26 | f_res[frame_id] = f_measure(segmentation[frame_id, :, :, ], annotation[frame_id, :, :], bound_th=bound_th) 27 | elif annotation.ndim == 2: 28 | f_res = f_measure(segmentation, annotation, bound_th=bound_th) 29 | return f_res 30 | 31 | 32 | def f_measure(foreground_mask, gt_mask, bound_th=0.008): 33 | void_pixels = np.zeros_like(foreground_mask).astype(np.bool) 34 | bound_pix = bound_th if bound_th >= 1 else np.ceil(bound_th * np.linalg.norm(foreground_mask.shape)) 35 | fg_boundary = _seg2bmap(foreground_mask * np.logical_not(void_pixels)) 36 | gt_boundary = _seg2bmap(gt_mask * np.logical_not(void_pixels)) 37 | from skimage.morphology import disk 38 | fg_dil = cv2.dilate(fg_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 39 | gt_dil = cv2.dilate(gt_boundary.astype(np.uint8), disk(bound_pix).astype(np.uint8)) 40 | gt_match = gt_boundary * fg_dil 41 | fg_match = fg_boundary * gt_dil 42 | n_fg = np.sum(fg_boundary) 43 
| n_gt = np.sum(gt_boundary) 44 | if n_fg == 0 and n_gt > 0: 45 | precision = 1 46 | recall = 0 47 | elif n_fg > 0 and n_gt == 0: 48 | precision = 0 49 | recall = 1 50 | elif n_fg == 0 and n_gt == 0: 51 | precision = 1 52 | recall = 1 53 | else: 54 | precision = np.sum(fg_match) / float(n_fg) 55 | recall = np.sum(gt_match) / float(n_gt) 56 | if precision + recall == 0: 57 | F = 0 58 | else: 59 | F = 2 * precision * recall / (precision + recall) 60 | return F 61 | 62 | 63 | def _seg2bmap(seg, width=None, height=None): 64 | seg = seg.astype(np.bool) 65 | seg[seg > 0] = 1 66 | width = seg.shape[1] if width is None else width 67 | height = seg.shape[0] if height is None else height 68 | h, w = seg.shape[:2] 69 | e = np.zeros_like(seg) 70 | s = np.zeros_like(seg) 71 | se = np.zeros_like(seg) 72 | e[:, :-1] = seg[:, 1:] 73 | s[:-1, :] = seg[1:, :] 74 | se[:-1, :-1] = seg[1:, 1:] 75 | b = seg ^ e | seg ^ s | seg ^ se 76 | b[-1, :] = seg[-1, :] ^ e[-1, :] 77 | b[:, -1] = seg[:, -1] ^ s[:, -1] 78 | b[-1, -1] = 0 79 | 80 | if w == width and h == height: 81 | bmap = b 82 | else: 83 | bmap = np.zeros((height, width)) 84 | for x in range(w): 85 | for y in range(h): 86 | if b[y, x]: 87 | j = 1 + math.floor((y - 1) + height / h) 88 | i = 1 + math.floor((x - 1) + width / h) 89 | bmap[j, i] = 1 90 | return bmap 91 | 92 | 93 | def db_statistics(per_frame_values): 94 | with warnings.catch_warnings(): 95 | warnings.simplefilter('ignore', category=RuntimeWarning) 96 | M = np.nanmean(per_frame_values) 97 | O = np.nanmean(per_frame_values > 0.5) 98 | N_bins = 4 99 | ids = np.round(np.linspace(1, len(per_frame_values), N_bins + 1) + 1e-10) - 1 100 | ids = ids.astype(np.uint8) 101 | D_bins = [per_frame_values[ids[i]:ids[i + 1] + 1] for i in range(0, 4)] 102 | with warnings.catch_warnings(): 103 | warnings.simplefilter('ignore', category=RuntimeWarning) 104 | D = np.nanmean(D_bins[0]) - np.nanmean(D_bins[3]) 105 | return M, O, D -------------------------------------------------------------------------------- /fakeflow.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from transformers import SegformerModel 5 | 6 | 7 | # basic modules 8 | class Conv(nn.Sequential): 9 | def __init__(self, *conv_args): 10 | super().__init__() 11 | self.add_module('conv', nn.Conv2d(*conv_args)) 12 | for m in self.children(): 13 | if isinstance(m, nn.Conv2d): 14 | nn.init.kaiming_uniform_(m.weight) 15 | if m.bias is not None: 16 | nn.init.constant_(m.bias, 0) 17 | 18 | 19 | class ConvRelu(nn.Sequential): 20 | def __init__(self, *conv_args): 21 | super().__init__() 22 | self.add_module('conv', nn.Conv2d(*conv_args)) 23 | self.add_module('relu', nn.ReLU()) 24 | for m in self.children(): 25 | if isinstance(m, nn.Conv2d): 26 | nn.init.kaiming_uniform_(m.weight) 27 | if m.bias is not None: 28 | nn.init.constant_(m.bias, 0) 29 | 30 | 31 | class CBAM(nn.Module): 32 | def __init__(self, c): 33 | super().__init__() 34 | self.conv1 = Conv(c, c, 3, 1, 1) 35 | self.conv2 = nn.Sequential(ConvRelu(c, c, 1, 1, 0), Conv(c, c, 1, 1, 0)) 36 | self.conv3 = nn.Sequential(ConvRelu(2, 16, 3, 1, 1), Conv(16, 1, 3, 1, 1)) 37 | 38 | def forward(self, x): 39 | x = self.conv1(x) 40 | c = torch.sigmoid(self.conv2(F.adaptive_avg_pool2d(x, output_size=(1, 1))) + self.conv2(F.adaptive_max_pool2d(x, output_size=(1, 1)))) 41 | x = x * c 42 | s = torch.sigmoid(self.conv3(torch.cat([torch.mean(x, dim=1, keepdim=True), torch.max(x, dim=1, 
keepdim=True)[0]], dim=1))) 43 | x = x * s 44 | return x 45 | 46 | 47 | # encoding module 48 | class Encoder(nn.Module): 49 | def __init__(self, ver): 50 | super().__init__() 51 | 52 | # MiT-b0 backbone 53 | if ver == 'mitb0': 54 | self.backbone = SegformerModel.from_pretrained('nvidia/mit-b0') 55 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 56 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 57 | 58 | # MiT-b1 backbone 59 | if ver == 'mitb1': 60 | self.backbone = SegformerModel.from_pretrained('nvidia/mit-b1') 61 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 62 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 63 | 64 | # MiT-b2 backbone 65 | if ver == 'mitb2': 66 | self.backbone = SegformerModel.from_pretrained('nvidia/mit-b2') 67 | self.register_buffer('mean', torch.FloatTensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) 68 | self.register_buffer('std', torch.FloatTensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) 69 | 70 | def forward(self, img): 71 | x = (img - self.mean) / self.std 72 | x = self.backbone(x, output_hidden_states=True).hidden_states 73 | s4 = x[0] 74 | s8 = x[1] 75 | s16 = x[2] 76 | s32 = x[3] 77 | return {'s4': s4, 's8': s8, 's16': s16, 's32': s32} 78 | 79 | 80 | # decoding module 81 | class Decoder(nn.Module): 82 | def __init__(self, ver): 83 | super().__init__() 84 | 85 | # MiT-b0 backbone 86 | if ver == 'mitb0': 87 | self.conv1 = ConvRelu(256, 256, 1, 1, 0) 88 | self.blend1 = ConvRelu(256, 256, 3, 1, 1) 89 | self.cbam1 = CBAM(256) 90 | self.conv2 = ConvRelu(160, 256, 1, 1, 0) 91 | self.blend2 = ConvRelu(256 + 256, 256, 3, 1, 1) 92 | self.cbam2 = CBAM(256) 93 | self.conv3 = ConvRelu(64, 256, 1, 1, 0) 94 | self.blend3 = ConvRelu(256 + 256, 256, 3, 1, 1) 95 | self.cbam3 = CBAM(256) 96 | self.conv4 = ConvRelu(32, 256, 1, 1, 0) 97 | self.blend4 = ConvRelu(256 + 256, 256, 3, 1, 1) 98 | self.cbam4 = CBAM(256) 99 | self.predictor = Conv(256, 2, 3, 1, 1) 100 | 101 | # MiT-b1 and MiT-b2 backbones 102 | else: 103 | self.conv1 = ConvRelu(512, 256, 1, 1, 0) 104 | self.blend1 = ConvRelu(256, 256, 3, 1, 1) 105 | self.cbam1 = CBAM(256) 106 | self.conv2 = ConvRelu(320, 256, 1, 1, 0) 107 | self.blend2 = ConvRelu(256 + 256, 256, 3, 1, 1) 108 | self.cbam2 = CBAM(256) 109 | self.conv3 = ConvRelu(128, 256, 1, 1, 0) 110 | self.blend3 = ConvRelu(256 + 256, 256, 3, 1, 1) 111 | self.cbam3 = CBAM(256) 112 | self.conv4 = ConvRelu(64, 256, 1, 1, 0) 113 | self.blend4 = ConvRelu(256 + 256, 256, 3, 1, 1) 114 | self.cbam4 = CBAM(256) 115 | self.predictor = Conv(256, 2, 3, 1, 1) 116 | 117 | def forward(self, app_feats, mo_feats): 118 | x = self.conv1(app_feats['s32'] + mo_feats['s32']) 119 | x = self.cbam1(self.blend1(x)) 120 | s16 = F.interpolate(x, scale_factor=2, mode='bicubic') 121 | x = torch.cat([self.conv2(app_feats['s16'] + mo_feats['s16']), s16], dim=1) 122 | x = self.cbam2(self.blend2(x)) 123 | s8 = F.interpolate(x, scale_factor=2, mode='bicubic') 124 | x = torch.cat([self.conv3(app_feats['s8'] + mo_feats['s8']), s8], dim=1) 125 | x = self.cbam3(self.blend3(x)) 126 | s4 = F.interpolate(x, scale_factor=2, mode='bicubic') 127 | x = torch.cat([self.conv4(app_feats['s4'] + mo_feats['s4']), s4], dim=1) 128 | x = self.predictor(self.cbam4(self.blend4(x))) 129 | score = F.interpolate(x, scale_factor=4, mode='bicubic') 130 | return score 131 | 132 | 133 | # VOS model 134 | class VOS(nn.Module): 135 | def __init__(self, ver): 136 | 
super().__init__() 137 | self.app_encoder = Encoder(ver) 138 | self.mo_encoder = Encoder(ver) 139 | self.decoder = Decoder(ver) 140 | 141 | 142 | # FakeFlow model 143 | class FakeFlow(nn.Module): 144 | def __init__(self, ver): 145 | super().__init__() 146 | self.vos = VOS(ver) 147 | 148 | def forward(self, imgs, flows): 149 | B, L, _, H1, W1 = imgs.size() 150 | _, _, _, H2, W2 = flows.size() 151 | 152 | # resize to 512p 153 | s = 512 154 | imgs = F.interpolate(imgs.view(B * L, -1, H1, W1), size=(s, s), mode='bicubic').view(B, L, -1, s, s) 155 | flows = F.interpolate(flows.view(B * L, -1, H2, W2), size=(s, s), mode='bicubic').view(B, L, -1, s, s) 156 | 157 | # for each frame 158 | score_lst = [] 159 | mask_lst = [] 160 | for i in range(L): 161 | 162 | # query frame prediction 163 | app_feats = self.vos.app_encoder(imgs[:, i]) 164 | mo_feats = self.vos.mo_encoder(flows[:, i]) 165 | score = self.vos.decoder(app_feats, mo_feats) 166 | score = F.interpolate(score, size=(H1, W1), mode='bicubic') 167 | 168 | # store soft scores 169 | if B != 1: 170 | score_lst.append(score) 171 | 172 | # store hard masks 173 | if B == 1: 174 | pred_seg = torch.softmax(score, dim=1) 175 | pred_mask = torch.max(pred_seg, dim=1, keepdim=True)[1] 176 | mask_lst.append(pred_mask) 177 | 178 | # generate output 179 | output = {} 180 | if B != 1: 181 | output['scores'] = torch.stack(score_lst, dim=1) 182 | if B == 1: 183 | output['masks'] = torch.stack(mask_lst, dim=1) 184 | return output 185 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.11.0 2 | torchvision==0.12.0 3 | pypng==0.0.21 4 | transformers==4.30.2 5 | opencv-python==4.7.0.72 6 | numpy==1.22.0 7 | scikit-image==0.19.3 8 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from dataset import * 2 | import evaluation 3 | from fakeflow import FakeFlow 4 | from trainer import Trainer 5 | from optparse import OptionParser 6 | import warnings 7 | warnings.filterwarnings('ignore') 8 | 9 | 10 | parser = OptionParser() 11 | parser.add_option('--train', action='store_true', default=None) 12 | parser.add_option('--test', action='store_true', default=None) 13 | options = parser.parse_args()[0] 14 | 15 | 16 | def train_ytvos(model): 17 | train_set = TrainYTVOS('../DB/YTVOS18', 'train', clip_n=512) 18 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4, pin_memory=True) 19 | val_set = TestDAVIS('../DB/DAVIS', '2016', 'val') 20 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) 21 | trainer = Trainer(model, optimizer, train_loader, val_set, save_name='ytvos', save_step=1000, val_step=100) 22 | trainer.train(4000) 23 | 24 | 25 | def train_dutsv2_davis(model): 26 | dutsv2_set = TrainDUTSv2('../DB/DUTSv2', clip_n=384) 27 | davis_set = TrainDAVIS('../DB/DAVIS', '2016', 'train', clip_n=128) 28 | train_set = torch.utils.data.ConcatDataset([dutsv2_set, davis_set]) 29 | train_loader = torch.utils.data.DataLoader(train_set, batch_size=16, shuffle=True, num_workers=4, pin_memory=True) 30 | val_set = TestDAVIS('../DB/DAVIS', '2016', 'val') 31 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) 32 | trainer = Trainer(model, optimizer, train_loader, val_set, save_name='ytvos_dutsv2_davis', save_step=1000, val_step=100) 33 | trainer.train(2000) 34 | 35 | 36 
| def test_davis(model): 37 | evaluator = evaluation.Evaluator(TestDAVIS('../DB/DAVIS', '2016', 'val')) 38 | evaluator.evaluate(model, os.path.join('outputs', 'DAVIS16_val')) 39 | 40 | 41 | def test_fbms(model): 42 | test_set = TestFBMS('../DB/FBMS/TestSet') 43 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=4) 44 | model.cuda() 45 | ious = [] 46 | 47 | # inference 48 | for vos_data in test_loader: 49 | imgs = vos_data['imgs'].cuda() 50 | flows = vos_data['flows'].cuda() 51 | masks = vos_data['masks'].cuda() 52 | video_name = vos_data['video_name'][0] 53 | files = vos_data['files'] 54 | os.makedirs('outputs/FBMS_test/{}'.format(video_name), exist_ok=True) 55 | vos_out = model(imgs, flows) 56 | 57 | # get iou of each sequence 58 | iou = 0 59 | count = 0 60 | for i in range(masks.size(1)): 61 | tv.utils.save_image(vos_out['masks'][0, i].float(), 'outputs/FBMS_test/{}/{}'.format(video_name, files[i][0].split('/')[-1])) 62 | if torch.sum(masks[0, i]) == 0: 63 | continue 64 | iou = iou + torch.sum(masks[0, i] * vos_out['masks'][0, i]) / torch.sum((masks[0, i] + vos_out['masks'][0, i]).clamp(0, 1)) 65 | count = count + 1 66 | print('{} iou: {:.5f}'.format(video_name, iou / count)) 67 | ious.append(iou / count) 68 | 69 | # calculate overall iou 70 | print('total seqs\' iou: {:.5f}\n'.format(sum(ious) / len(ious))) 71 | 72 | 73 | def test_ytobj(model): 74 | test_set = TestYTOBJ('../DB/YTOBJ') 75 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=4) 76 | model.cuda() 77 | ious = {'aeroplane': [], 'bird': [], 'boat': [], 'car': [], 'cat': [], 'cow': [], 'dog': [], 'horse': [], 'motorbike': [], 'train': []} 78 | total_iou = 0 79 | total_count = 0 80 | 81 | # inference 82 | for vos_data in test_loader: 83 | imgs = vos_data['imgs'].cuda() 84 | flows = vos_data['flows'].cuda() 85 | masks = vos_data['masks'].cuda() 86 | class_name = vos_data['class_name'][0] 87 | video_name = vos_data['video_name'][0] 88 | files = vos_data['files'] 89 | os.makedirs('outputs/YTOBJ/{}/{}'.format(class_name, video_name), exist_ok=True) 90 | vos_out = model(imgs, flows) 91 | 92 | # get iou of each sequence 93 | iou = 0 94 | count = 0 95 | for i in range(masks.size(1)): 96 | tv.utils.save_image(vos_out['masks'][0, i].float(), 'outputs/YTOBJ/{}/{}/{}'.format(class_name, video_name, files[i][0].split('/')[-1])) 97 | if torch.sum(masks[0, i]) == 0: 98 | continue 99 | iou = iou + torch.sum(masks[0, i] * vos_out['masks'][0, i]) / torch.sum((masks[0, i] + vos_out['masks'][0, i]).clamp(0, 1)) 100 | count = count + 1 101 | if count == 0: 102 | continue 103 | print('{}_{} iou: {:.5f}'.format(class_name, video_name, iou / count)) 104 | ious[class_name].append(iou / count) 105 | total_iou = total_iou + iou / count 106 | total_count = total_count + 1 107 | 108 | # calculate overall iou 109 | for class_name in ious.keys(): 110 | print('class: {} seqs\' iou: {:.5f}'.format(class_name, sum(ious[class_name]) / len(ious[class_name]))) 111 | print('total seqs\' iou: {:.5f}\n'.format(total_iou / total_count)) 112 | 113 | 114 | def test_lvid(model): 115 | test_set = TestLVID('../DB/LVID') 116 | test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, num_workers=4) 117 | model.cuda() 118 | ious = [] 119 | 120 | # inference 121 | for vos_data in test_loader: 122 | imgs = vos_data['imgs'].cuda() 123 | flows = vos_data['flows'].cuda() 124 | masks = vos_data['masks'].cuda() 125 | video_name = vos_data['video_name'][0] 126 | files = vos_data['files'] 127 | 
os.makedirs('outputs/LVID/{}'.format(video_name), exist_ok=True) 128 | vos_out = model(imgs, flows) 129 | 130 | # get iou of each sequence 131 | iou = 0 132 | count = 0 133 | for i in range(masks.size(1)): 134 | tv.utils.save_image(vos_out['masks'][0, i].float(), 'outputs/LVID/{}/{}'.format(video_name, files[i][0].split('/')[-1])) 135 | if torch.sum(masks[0, i]) == 0: 136 | continue 137 | iou = iou + torch.sum(masks[0, i] * vos_out['masks'][0, i]) / torch.sum((masks[0, i] + vos_out['masks'][0, i]).clamp(0, 1)) 138 | count = count + 1 139 | print('{} iou: {:.5f}'.format(video_name, iou / count)) 140 | ious.append(iou / count) 141 | 142 | # calculate overall iou 143 | print('total seqs\' iou: {:.5f}\n'.format(sum(ious) / len(ious))) 144 | 145 | 146 | if __name__ == '__main__': 147 | 148 | # set device 149 | torch.cuda.set_device(0) 150 | 151 | # define model 152 | ver = 'mitb2' 153 | model = FakeFlow(ver).eval() 154 | 155 | # training stage 156 | if options.train: 157 | model = torch.nn.DataParallel(model) 158 | train_ytvos(model) 159 | train_dutsv2_davis(model) 160 | 161 | # testing stage 162 | if options.test: 163 | model.load_state_dict(torch.load('weights/FakeFlow_{}.pth'.format(ver), map_location='cpu')) 164 | with torch.no_grad(): 165 | test_davis(model) 166 | test_fbms(model) 167 | test_ytobj(model) 168 | test_lvid(model) 169 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | from evaluation import metrics 2 | from utils import AverageMeter, get_iou 3 | import copy 4 | import numpy 5 | import torch 6 | 7 | 8 | class Trainer(object): 9 | def __init__(self, model, optimizer, train_loader, val_set, save_name, save_step, val_step): 10 | self.model = model.cuda() 11 | self.optimizer = optimizer 12 | self.train_loader = train_loader 13 | self.val_set = val_set 14 | self.save_name = save_name 15 | self.save_step = save_step 16 | self.val_step = val_step 17 | self.epoch = 1 18 | self.best_score = 0 19 | self.score = 0 20 | self.stats = {'loss': AverageMeter(), 'iou': AverageMeter()} 21 | 22 | def train(self, max_epochs): 23 | for epoch in range(self.epoch, max_epochs + 1): 24 | self.epoch = epoch 25 | self.train_epoch() 26 | if self.epoch % self.save_step == 0: 27 | print('saving checkpoint\n') 28 | self.save_checkpoint() 29 | if self.score > self.best_score: 30 | print('new best checkpoint, after epoch {}\n'.format(self.epoch)) 31 | self.save_checkpoint(alt_name='best') 32 | self.best_score = self.score 33 | print('finished training!\n', flush=True) 34 | 35 | def train_epoch(self): 36 | 37 | # train 38 | self.model.train() 39 | self.cycle_dataset(mode='train') 40 | 41 | # val 42 | self.model.eval() 43 | if self.epoch % self.val_step == 0: 44 | if self.val_set is not None: 45 | with torch.no_grad(): 46 | self.score = self.cycle_dataset(mode='val') 47 | 48 | # update stats 49 | for stat_value in self.stats.values(): 50 | stat_value.new_epoch() 51 | 52 | def cycle_dataset(self, mode): 53 | if mode == 'train': 54 | for vos_data in self.train_loader: 55 | imgs = vos_data['imgs'].cuda() 56 | flows = vos_data['flows'].cuda() 57 | masks = vos_data['masks'].cuda() 58 | B, L, _, H, W = imgs.size() 59 | 60 | # model run 61 | vos_out = self.model(imgs, flows) 62 | loss = torch.nn.CrossEntropyLoss()(vos_out['scores'].view(B * L, 2, H, W), masks.reshape(B * L, H, W)) 63 | 64 | # backward 65 | self.optimizer.zero_grad() 66 | loss.backward() 67 | self.optimizer.step() 68 | 69 | # 
loss, iou 70 | self.stats['loss'].update(loss.detach().cpu().item(), B) 71 | iou = torch.mean(get_iou(vos_out['scores'].view(B * L, 2, H, W), masks.reshape(B * L, H, W))[:, 1:]) 72 | self.stats['iou'].update(iou.detach().cpu().item(), B) 73 | 74 | print('[ep{:04d}] loss: {:.5f}, iou: {:.5f}'.format(self.epoch, self.stats['loss'].avg, self.stats['iou'].avg)) 75 | 76 | if mode == 'val': 77 | metrics_res = {} 78 | metrics_res['J'] = [] 79 | metrics_res['F'] = [] 80 | for video_name, video_parts in self.val_set.get_videos(): 81 | for vos_data in video_parts: 82 | imgs = vos_data['imgs'].cuda() 83 | flows = vos_data['flows'].cuda() 84 | masks = vos_data['masks'].cuda() 85 | 86 | # inference 87 | vos_out = self.model(imgs, flows) 88 | res_masks = vos_out['masks'][:, 1:-1].squeeze(2) 89 | gt_masks = masks[:, 1:-1].squeeze(2) 90 | B, L, H, W = res_masks.shape 91 | object_ids = numpy.unique(gt_masks.cpu()).tolist() 92 | object_ids.remove(0) 93 | 94 | # evaluate output 95 | all_res_masks = numpy.zeros((len(object_ids), L, H, W)) 96 | all_gt_masks = numpy.zeros((len(object_ids), L, H, W)) 97 | for k in object_ids: 98 | res_masks_k = copy.deepcopy(res_masks).cpu().numpy() 99 | res_masks_k[res_masks_k != k] = 0 100 | res_masks_k[res_masks_k != 0] = 1 101 | all_res_masks[k - 1] = res_masks_k[0] 102 | gt_masks_k = copy.deepcopy(gt_masks).cpu().numpy() 103 | gt_masks_k[gt_masks_k != k] = 0 104 | gt_masks_k[gt_masks_k != 0] = 1 105 | all_gt_masks[k - 1] = gt_masks_k[0] 106 | 107 | # calculate scores 108 | j_metrics_res = numpy.zeros(all_gt_masks.shape[:2]) 109 | f_metrics_res = numpy.zeros(all_gt_masks.shape[:2]) 110 | for i in range(all_gt_masks.shape[0]): 111 | j_metrics_res[i] = metrics.db_eval_iou(all_gt_masks[i], all_res_masks[i]) 112 | f_metrics_res[i] = metrics.db_eval_boundary(all_gt_masks[i], all_res_masks[i]) 113 | [JM, _, _] = metrics.db_statistics(j_metrics_res[i]) 114 | metrics_res['J'].append(JM) 115 | [FM, _, _] = metrics.db_statistics(f_metrics_res[i]) 116 | metrics_res['F'].append(FM) 117 | 118 | # gather scores 119 | J, F = metrics_res['J'], metrics_res['F'] 120 | final_mean = (numpy.mean(J) + numpy.mean(F)) / 2. 
121 | print('[ep{:04d}] J&F score: {:.5f}\n'.format(self.epoch, final_mean)) 122 | return final_mean 123 | 124 | def save_checkpoint(self, alt_name=None): 125 | if alt_name is not None: 126 | file_path = 'weights/{}_{}.pth'.format(self.save_name, alt_name) 127 | else: 128 | file_path = 'weights/{}_{:04d}.pth'.format(self.save_name, self.epoch) 129 | torch.save(self.model.module.state_dict(), file_path) 130 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import png 3 | import time 4 | import math 5 | import numpy 6 | import queue 7 | import threading 8 | 9 | 10 | DAVIS_PALETTE_4BIT = [[ 0, 0, 0], 11 | [128, 0, 0], 12 | [ 0, 128, 0], 13 | [128, 128, 0], 14 | [ 0, 0, 128], 15 | [128, 0, 128], 16 | [ 0, 128, 128], 17 | [128, 128, 128], 18 | [ 64, 0, 0], 19 | [191, 0, 0], 20 | [ 64, 128, 0], 21 | [191, 128, 0], 22 | [ 64, 0, 128], 23 | [191, 0, 128], 24 | [ 64, 128, 128], 25 | [191, 128, 128]] 26 | 27 | 28 | class ReadSaveImage(object): 29 | def __init__(self): 30 | super().__init__() 31 | 32 | def check_path(self, fullpath): 33 | path, filename = os.path.split(fullpath) 34 | if not os.path.exists(path): 35 | os.makedirs(path) 36 | 37 | 38 | class DAVISLabels(ReadSaveImage): 39 | def __init__(self): 40 | super().__init__() 41 | self._width = 0 42 | self._height = 0 43 | 44 | def save(self, image, path): 45 | self.check_path(path) 46 | bitdepth = int(math.log(len(DAVIS_PALETTE_4BIT)) / math.log(2)) 47 | height, width = image.shape 48 | file = open(path, 'wb') 49 | writer = png.Writer(width, height, palette=DAVIS_PALETTE_4BIT, bitdepth=bitdepth) 50 | writer.write(file, image) 51 | 52 | def read(self, path): 53 | try: 54 | reader = png.Reader(path) 55 | width, height, data, meta = reader.read() 56 | image = numpy.vstack(data) 57 | self._height, self._width = image.shape 58 | except png.FormatError: 59 | image = numpy.zeros((self._height, self._width)) 60 | self.save(image, path) 61 | return image 62 | 63 | 64 | class ImageSaver(threading.Thread): 65 | def __init__(self): 66 | super().__init__() 67 | self._alive = True 68 | self._queue = queue.Queue(2 ** 20) 69 | self.start() 70 | 71 | @property 72 | def alive(self): 73 | return self._alive 74 | 75 | @alive.setter 76 | def alive(self, alive): 77 | self._alive = alive 78 | 79 | @property 80 | def queue(self): 81 | return self._queue 82 | 83 | def kill(self): 84 | self._alive = False 85 | 86 | def enqueue(self, datatuple): 87 | ret = True 88 | try: 89 | self._queue.put(datatuple, block=False) 90 | except queue.Full: 91 | print('enqueue full') 92 | ret = False 93 | return ret 94 | 95 | def run(self): 96 | while True: 97 | while not self._queue.empty(): 98 | args, method = self._queue.get(block=False, timeout=2) 99 | method.save(*args) 100 | self._queue.task_done() 101 | if not self._alive and self._queue.empty(): 102 | break 103 | time.sleep(0.0001) 104 | 105 | 106 | class AverageMeter(object): 107 | def __init__(self): 108 | self.clear() 109 | 110 | def reset(self): 111 | self.avg = 0 112 | self.val = 0 113 | self.sum = 0 114 | self.count = 0 115 | 116 | def clear(self): 117 | self.reset() 118 | self.history = [] 119 | 120 | def update(self, val, n=1): 121 | self.val = val 122 | self.sum += val * n 123 | self.count += n 124 | if self.count > 0: 125 | self.avg = self.sum / self.count 126 | else: 127 | self.avg = 'nan' 128 | 129 | def new_epoch(self): 130 | self.history.append(self.avg) 131 | self.reset() 132 
| 133 | 134 | def get_iou(predictions, gt): 135 | nsamples, nclasses, height, width = predictions.size() 136 | prediction_max, prediction_argmax = predictions.max(-3) 137 | prediction_argmax = prediction_argmax.long() 138 | classes = gt.new_tensor([c for c in range(nclasses)]).view(1, nclasses, 1, 1) 139 | pred_bin = (prediction_argmax.view(nsamples, 1, height, width) == classes) 140 | gt_bin = (gt.view(nsamples, 1, height, width) == classes) 141 | intersection = (pred_bin * gt_bin).float().sum(dim=-2).sum(dim=-1) 142 | union = ((pred_bin + gt_bin) > 0).float().sum(dim=-2).sum(dim=-1) 143 | return (intersection + 1e-7) / (union + 1e-7) 144 | -------------------------------------------------------------------------------- /weights/empty.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/suhwan-cho/FakeFlow/abfb2a864a83524f2b173cf52826f92a1edf5b0c/weights/empty.txt --------------------------------------------------------------------------------
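A few usage sketches follow, tied to the README sections above. First, Preparation step 2 asks for per-video optical flow maps saved next to the frames; the dataset classes read them as RGB JPGs from a `JPEGFlows` directory that mirrors `JPEGImages`. The snippet below is a minimal sketch of that step using torchvision's bundled RAFT and `flow_to_image` as a stand-in for the linked princeton-vl RAFT repository, so the model call, preprocessing, and the helper name `save_flow_maps` are assumptions for illustration rather than the authors' exact pipeline; paths and the per-frame naming convention are placeholders.

```
import os
import shutil
from glob import glob

import torch
import torch.nn.functional as F
import torchvision as tv
from torchvision.models.optical_flow import raft_large
from torchvision.utils import flow_to_image
from PIL import Image


def save_flow_maps(video_dir, out_dir, device='cuda'):
    # e.g. video_dir='.../JPEGImages/480p/bear', out_dir='.../JPEGFlows/480p/bear'
    model = raft_large(pretrained=True).eval().to(device)
    to_tensor = tv.transforms.ToTensor()
    frames = sorted(glob(os.path.join(video_dir, '*.jpg')))
    os.makedirs(out_dir, exist_ok=True)
    for prev, curr in zip(frames[:-1], frames[1:]):
        img1 = to_tensor(Image.open(prev).convert('RGB')).unsqueeze(0).to(device)
        img2 = to_tensor(Image.open(curr).convert('RGB')).unsqueeze(0).to(device)
        # RAFT expects inputs in [-1, 1] with spatial size divisible by 8
        h, w = img1.shape[-2:]
        size = (h // 8 * 8, w // 8 * 8)
        img1 = F.interpolate(img1, size=size, mode='bilinear') * 2 - 1
        img2 = F.interpolate(img2, size=size, mode='bilinear') * 2 - 1
        with torch.no_grad():
            flow = model(img1, img2)[-1]      # last refinement, (1, 2, H, W)
        rgb = flow_to_image(flow)[0]          # color-coded flow, uint8 (3, H, W)
        tv.utils.save_image(rgb.float() / 255, os.path.join(out_dir, os.path.basename(prev)))
    # the dataset classes expect one flow map per frame, so the last frame's flow
    # is simply a copy of the previous one (a convention, not specified above)
    if len(frames) > 1:
        shutil.copy(os.path.join(out_dir, os.path.basename(frames[-2])),
                    os.path.join(out_dir, os.path.basename(frames[-1])))
```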
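Preparation step 3 and the abstract describe the core "fake flow" idea: a DPT depth estimate of a still image is converted into a simulated flow map, so that DUTS images can be paired with flow and used as large-scale training data (DUTSv2). The exact depth-to-flow conversion protocol is defined in the paper, not in this repository, so the function below is only a hypothetical illustration of the idea: it scales one random virtual motion direction by the normalized (disparity-like) depth and color-codes the result the same way as the real flows.

```
import torch
from torchvision.utils import flow_to_image


def fake_flow_from_depth(depth, max_motion=20.0):
    # depth: (H, W) tensor from a monocular depth network such as DPT,
    # assumed disparity-like (larger values = closer to the camera)
    d = (depth - depth.min()) / (depth.max() - depth.min() + 1e-7)
    # one random virtual motion direction shared by all pixels
    direction = torch.randn(2)
    direction = direction / (direction.norm() + 1e-7)
    # parallax-style displacement: nearer pixels move farther
    flow = torch.stack([d * direction[0], d * direction[1]], dim=0) * max_motion
    return flow_to_image(flow)    # uint8 (3, H, W), saved as a JPEGFlows image
```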
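Finally, a minimal single-video inference sketch that mirrors the testing entry points in "run.py": frames and flow maps are stacked into (B, L, 3, H, W) tensors (one flow per frame, as in the dataset loaders), and with batch size 1 the model returns hard masks under the 'masks' key. The input paths are placeholders.

```
import torch
import torchvision as tv
from glob import glob
from PIL import Image
from fakeflow import FakeFlow

to_tensor = tv.transforms.ToTensor()
imgs = torch.stack([to_tensor(Image.open(p).convert('RGB'))
                    for p in sorted(glob('some_video/JPEGImages/*.jpg'))]).unsqueeze(0).cuda()
flows = torch.stack([to_tensor(Image.open(p).convert('RGB'))
                     for p in sorted(glob('some_video/JPEGFlows/*.jpg'))]).unsqueeze(0).cuda()

model = FakeFlow('mitb2').eval()
model.load_state_dict(torch.load('weights/FakeFlow_mitb2.pth', map_location='cpu'))
model.cuda()
with torch.no_grad():
    out = model(imgs, flows)     # batch size 1 -> hard masks are returned
masks = out['masks']             # (1, L, 1, H, W), values in {0, 1}
```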