├── Dockerfile ├── LICENSE ├── README.md ├── configs ├── __pycache__ │ └── resnet101_p2t.cpython-36.pyc ├── resnet101_rpcm_ytb_stage_1.py └── resnet101_rpcm_ytb_stage_2.py ├── dataloaders ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── custom_transforms.cpython-36.pyc │ └── datasets_m.cpython-36.pyc ├── custom_transforms.py └── datasets.py ├── networks ├── __init__.py ├── __pycache__ │ └── __init__.cpython-36.pyc ├── deeplab │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── aspp.cpython-36.pyc │ │ ├── decoder.cpython-36.pyc │ │ └── deeplab.cpython-36.pyc │ ├── aspp.py │ ├── backbone │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── mobilenet.cpython-36.pyc │ │ │ └── resnet.cpython-36.pyc │ │ ├── mobilenet.py │ │ └── resnet.py │ ├── decoder.py │ └── deeplab.py ├── engine │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── __init__.cpython-36.pyc.140354633045472 │ │ ├── eval_manager_mm.cpython-36.pyc │ │ └── eval_manager_mm_rpa.cpython-36.pyc │ ├── eval_manager.py │ ├── eval_manager_rpa.py │ └── train_manager.py ├── layers │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-36.pyc │ │ ├── aspp.cpython-36.pyc │ │ ├── attention.cpython-36.pyc │ │ ├── gct.cpython-36.pyc │ │ ├── loss.cpython-36.pyc │ │ ├── matching.cpython-36.pyc │ │ ├── normalization.cpython-36.pyc │ │ └── shannon_entropy.cpython-36.pyc │ ├── aspp.py │ ├── attention.py │ ├── gct.py │ ├── loss.py │ ├── matching.py │ ├── normalization.py │ └── shannon_entropy.py └── rpcm │ ├── __init__.py │ ├── __pycache__ │ ├── __init__.cpython-36.pyc │ ├── p2t_base.cpython-36.pyc │ └── prop_module.cpython-36.pyc │ ├── prop_module.py │ └── rpcm.py ├── requirements.txt ├── scripts ├── ytb_eval_with_RPA.sh ├── ytb_eval_without_RPA.sh └── ytb_train.sh ├── tools ├── eval.py ├── eval_rpa.py └── train.py └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-36.pyc ├── checkpoint.cpython-36.pyc ├── eval.cpython-36.pyc ├── image.cpython-36.pyc └── meters.cpython-36.pyc ├── checkpoint.py ├── eval.py ├── image.py ├── learning.py ├── meters.py └── metric.py /Dockerfile: -------------------------------------------------------------------------------- 1 | # main image 2 | FROM nvidia/cuda:10.1-cudnn7-devel-ubuntu18.04 3 | 4 | # tweaked azureml pytorch image 5 | # as of Aug 31, 2020 pt 1.6 doesn't seem to work with horovod on native mixed precision 6 | 7 | LABEL maintainer="Albert" 8 | LABEL maintainer_email="alsadovn@microsoft.com" 9 | LABEL version="0.1" 10 | 11 | USER root:root 12 | 13 | ENV com.nvidia.cuda.version $CUDA_VERSION 14 | ENV com.nvidia.volumes.needed nvidia_driver 15 | ENV LANG=C.UTF-8 LC_ALL=C.UTF-8 16 | ENV DEBIAN_FRONTEND noninteractive 17 | ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda/lib64:/usr/local/cuda/extras/CUPTI/lib64 18 | ENV NCCL_DEBUG=INFO 19 | ENV HOROVOD_GPU_ALLREDUCE=NCCL 20 | 21 | # Install Common Dependencies 22 | RUN apt-get update && \ 23 | apt-get install -y --no-install-recommends \ 24 | # SSH and RDMA 25 | libmlx4-1 \ 26 | libmlx5-1 \ 27 | librdmacm1 \ 28 | libibverbs1 \ 29 | libmthca1 \ 30 | libdapl2 \ 31 | dapl2-utils \ 32 | openssh-client \ 33 | openssh-server \ 34 | iproute2 && \ 35 | # Others 36 | apt-get install -y --no-install-recommends \ 37 | build-essential \ 38 | bzip2=1.0.6-8.1ubuntu0.2 \ 39 | libbz2-1.0=1.0.6-8.1ubuntu0.2 \ 40 | systemd \ 41 | git=1:2.17.1-1ubuntu0.7 \ 42 | wget \ 43 | cpio \ 44 | libsm6 \ 45 | libxext6 \ 46 | libxrender-dev \ 47 | fuse && \ 48 | apt-get clean -y 
&& \ 49 | rm -rf /var/lib/apt/lists/* 50 | 51 | # Conda Environment 52 | ENV MINICONDA_VERSION 4.7.12.1 53 | ENV PATH /opt/miniconda/bin:$PATH 54 | RUN wget -qO /tmp/miniconda.sh https://repo.continuum.io/miniconda/Miniconda3-${MINICONDA_VERSION}-Linux-x86_64.sh && \ 55 | bash /tmp/miniconda.sh -bf -p /opt/miniconda && \ 56 | conda clean -ay && \ 57 | rm -rf /opt/miniconda/pkgs && \ 58 | rm /tmp/miniconda.sh && \ 59 | find / -type d -name __pycache__ | xargs rm -rf 60 | 61 | # To resolve horovod hangs due to a known NCCL issue in version 2.4. 62 | # Can remove it once we upgrade NCCL to 2.5+. 63 | # https://github.com/horovod/horovod/issues/893 64 | # ENV NCCL_TREE_THRESHOLD=0 65 | ENV PIP="pip install --no-cache-dir" 66 | 67 | RUN conda install -y conda=4.8.5 python=3.6.2 && conda clean -ay && \ 68 | conda install -y mkl=2020.1 && \ 69 | conda install -y numpy scipy scikit-learn scikit-image imageio protobuf && \ 70 | conda install -y ruamel.yaml==0.16.10 && \ 71 | # ruamel_yaml is a copy of ruamel.yaml package 72 | # conda installs version ruamel_yaml v0.15.87 which is vulnerable 73 | # force uninstall it leaving other packages intact 74 | conda remove --force -y ruamel_yaml && \ 75 | conda clean -ay && \ 76 | # Install AzureML SDK 77 | ${PIP} azureml-defaults && \ 78 | # Install PyTorch 79 | ${PIP} torch==1.4.0 && \ 80 | ${PIP} torchvision==0.2.1 && \ 81 | ${PIP} wandb && \ 82 | # # Install Horovod 83 | # HOROVOD_WITH_PYTORCH=1 ${PIP} horovod[pytorch]==0.19.5 && \ 84 | # ldconfig && \ 85 | ${PIP} tensorboard==1.15.0 && \ 86 | ${PIP} future==0.17.1 && \ 87 | ${PIP} onnxruntime==1.4.0 && \ 88 | ${PIP} pytorch-lightning && \ 89 | ${PIP} opencv-python-headless~=4.4.0 && \ 90 | ${PIP} imgaug==0.4.0 --no-deps && \ 91 | # hydra 92 | ${PIP} hydra-core --upgrade && \ 93 | ${PIP} lmdb pyarrow 94 | 95 | RUN pip3 install --upgrade pip 96 | RUN pip3 install pipreqs 97 | 98 | RUN apt-get update 99 | RUN apt-get install -y --no-install-recommends libglib2.0-dev 100 | RUN apt-get install -y --no-install-recommends vim 101 | 102 | WORKDIR / 103 | RUN apt-get install -y --no-install-recommends libunwind8 104 | RUN apt-get install -y --no-install-recommends libicu-dev 105 | RUN apt-get install -y --no-install-recommends htop 106 | RUN apt-get install -y --no-install-recommends net-tools 107 | RUN apt-get install -y --no-install-recommends rsync 108 | RUN apt-get install -y --no-install-recommends tree 109 | 110 | RUN wget -O azcopy.tar.gz https://aka.ms/downloadazcopylinux64 111 | RUN tar -xf azcopy.tar.gz 112 | RUN ./install.sh 113 | 114 | # put the requirements file for your own repo under /app for pip-based installation!!! 
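# NOTE: nothing in this Dockerfile copies a requirements file in, so the `pip3 install -r requirements.txt`
# step below assumes the file already exists under /app at build time, e.g. via an added line such as
# `COPY requirements.txt /app/requirements.txt` (path assumed; adjust to your own repo), or you can drop
# that step and install the requirements at run time instead.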
115 | WORKDIR /app 116 | RUN pip3 install -r requirements.txt 117 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Xiaohao Xu 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Reliable Propagation-Correction Modulation for Video Object Segmentation (AAAI22 Oral) 2 | 3 | ![Picture1](https://user-images.githubusercontent.com/65257938/145016835-3c4be820-c55d-4eb4-b7f5-b8a012ee0f8c.png) 4 | 5 | A preview version of the paper is available at [arXiv](https://arxiv.org/abs/2112.02853). 6 | 7 | The AAAI [long paper presentation ppt](https://docs.google.com/presentation/d/1szmHc-2s1RpEfxr5cCbIRys85rpF85CR/edit?usp=sharing&ouid=113055027221294032292&rtpof=true&sd=true), the [short one-minute paper presentation ppt](https://docs.google.com/presentation/d/1mv5xCWJQ0G5nVsrdQ5mY78RTb8HYlVMT/edit?usp=sharing&ouid=113055027221294032292&rtpof=true&sd=true), and the [poster](https://drive.google.com/file/d/1xKf0MvxxTqgbDGCBOkRSKdAuix0mVGOz/view?usp=sharing) are available! 8 | 9 | Qualitative results and comparisons with previous SOTAs are available on both [YouTube](https://youtu.be/X6BsS3t3wnc) and [Bilibili](https://www.bilibili.com/video/BV1pr4y1D7TQ?spm_id_from=333.999.0.0). 10 | 11 | [Thanks to someone (I don't know) who reposted the video to Bilibili😀.] 12 | 13 | **This repo is a preview version. More details will be added later. Stars ⭐, comments 💹, and collaboration 😀 are welcome!!** 14 | 15 | ```diff 16 | - 2023.4.1: The link for the pretrained backbone ckpt is updated (as the previous one had expired). 17 | - 2022.7.9: Our complete code is re-released! 18 | - 2022.3.9: Dockerfile is added for easy env setup and modification. 19 | - 2022.3.6: Our presentation PPT and poster for AAAI22 are now available on GoogleDrive! 20 | - 2022.2.16 😀: Our paper has been selected as an **Oral Presentation** at AAAI22! (The oral acceptance rate is about 4.5% this year (15% x 30%).) 21 | - 2021.12.25 🎅🎄: Precomputed results on YouTube-VOS18/19 and DAVIS17 Val/Test-dev are available on both GoogleDrive and BaiduDisk! 22 | - 2021.12.14: Stay tuned for the code release!
23 | ``` 24 | --- 25 | 26 | 27 | 28 | ## Abstract 29 | **Error propagation** is a general but crucial problem in **online semi-supervised video object segmentation**. We aim to **suppress error propagation through a correction mechanism with high reliability**. 30 | 31 | The key insight is **to disentangle the correction from the conventional mask propagation process with reliable cues**. 32 | 33 | We **introduce two modulators, propagation and correction modulators,** to separately perform channel-wise re-calibration on the target frame embeddings according to local temporal correlations and reliable references, respectively. Specifically, we assemble the modulators with a cascaded propagation-correction scheme. This prevents the propagation modulator from overriding the effects of the reliable correction modulator. 34 | 35 | Although the reference frame with the ground-truth label provides reliable cues, it can be very different from the target frame and introduce uncertain or incomplete correlations. We **augment the reference cues by supplementing reliable feature patches to a maintained pool**, thus offering more comprehensive and expressive object representations to the modulators. In addition, a reliability filter is designed to retrieve reliable patches and pass them on to subsequent frames. 36 | 37 | Our model achieves **state-of-the-art performance on the YouTube-VOS18/19 and DAVIS17-Val/Test** benchmarks. Extensive experiments demonstrate that the correction mechanism provides considerable performance gains by fully utilizing reliable guidance. 38 | 39 | ## Requirements 40 | * Python3 41 | * pytorch >= 1.4.0 42 | * torchvision 43 | * opencv-python 44 | * Pillow 45 | 46 | You can also use the docker image below to set up your env directly. However, this docker image may contain some redundant packages. 47 | 48 | ``` 49 | docker image: xxiaoh/vos:10.1-cudnn7-torch1.4_v3 50 | ``` 51 | 52 | A more lightweight version can be created by modifying the provided [Dockerfile](https://github.com/JerryX1110/RPCMVOS/blob/main/Dockerfile). 53 | 54 | ## Preparation 55 | * Datasets 56 | 57 | * **YouTube-VOS** 58 | 59 | A commonly-used large-scale VOS dataset. 60 | 61 | [datasets/YTB/2019](datasets/YTB/2019): version 2019, download [link](https://drive.google.com/drive/folders/1BWzrCWyPEmBEKm0lOHe5KLuBuQxUSwqz?usp=sharing). `train` is required for training. `valid` (6fps) and `valid_all_frames` (30fps, optional) are used for evaluation. 62 | 63 | [datasets/YTB/2018](datasets/YTB/2018): version 2018, download [link](https://drive.google.com/drive/folders/1bI5J1H3mxsIGo7Kp-pPZU8i6rnykOw7f?usp=sharing). Only `valid` (6fps) and `valid_all_frames` (30fps, optional) are required for this project and used for evaluation. 64 | 65 | * **DAVIS** 66 | 67 | A commonly-used small-scale VOS dataset. 68 | 69 | [datasets/DAVIS](datasets/DAVIS): [TrainVal](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-trainval-480p.zip) (480p) contains both the training and validation splits. [Test-Dev](https://data.vision.ee.ethz.ch/csergi/share/davis/DAVIS-2017-test-dev-480p.zip) (480p) contains the Test-dev split. The [full-resolution version](https://davischallenge.org/davis2017/code.html) is also supported for training and evaluation but not required.
70 | 71 | * Pretrained weights for the backbone 72 | 73 | [resnet101-deeplabv3p](https://drive.google.com/file/d/101jYpeGGG58Kywk03331PKKzv1wN5DL0/view?usp=sharing) 74 | 75 | 76 | ## Training 77 | Training for YouTube-VOS: 78 | 79 | sh ../scripts/ytb_train.sh 80 | 81 | * Notice that some training parameters need to be changed according to your hardware environment, such as the interval at which checkpoints are saved (e.g., `TRAIN_SAVE_STEP`, `TRAIN_GPUS`, and `TRAIN_BATCH_SIZE` in the config files under `configs/`). 82 | * More details will be added soon. 83 | 84 | ## Inference 85 | Using **r**eliable object **p**roxy **a**ugmentation (RPA): 86 | 87 | sh ../scripts/ytb_eval_with_RPA.sh 88 | 89 | Without using **r**eliable object **p**roxy **a**ugmentation (RPA): 90 | 91 | sh ../scripts/ytb_eval_without_RPA.sh 92 | 93 | * For evaluation, please use the official YouTube-VOS servers ([2018 server](https://competitions.codalab.org/competitions/19544) and [2019 server](https://competitions.codalab.org/competitions/20127)), the official [DAVIS toolkit](https://github.com/davisvideochallenge/davis-2017) (for Val), and the official [DAVIS server](https://competitions.codalab.org/competitions/20516#learn_the_details) (for Test-dev). 94 | 95 | * More details will be added soon. 96 | 97 | ## Precomputed Results 98 | 99 | Precomputed results on both YouTube-VOS18/19 and DAVIS17 Val/Test-dev are available on [Google Drive](https://drive.google.com/drive/folders/1RaffnMvmQF4Nct30UBXqwrfOXTZ8rvQf?usp=sharing) and [Baidu Disk](https://pan.baidu.com/s/1WqB-SsbT7W-a6DbLIz8Lzw) (BaiduDisk password: 6666). 100 | 101 | ## Limitation & Directions for further exploration in VOS! 102 | 103 | Although the numbers on some semi-VOS benchmarks are already very high, many problems remain open for further exploration. 104 | 105 | I think those who take a look at this repo are likely to be doing research related to segmentation or tracking. 106 | 107 | So I would like to share some directions worth exploring in VOS from my point of view. Hopefully, I can see some nice solutions in the near future! 108 | 109 | * What about leveraging the propagation-then-correction mechanism in other tracking tasks such as MOT and pose tracking? 110 | * How about using a learning-based method to measure the prediction uncertainty? 111 | * How to tackle VOS in long-term videos? Maybe due to the lack of a good dataset for long-term VOS evaluation, this problem is still a hard nut to crack. 112 | * How to update the memory pool containing historical information during propagation? 113 | * How to judge whether some information is useful for further frames or not? 114 | * Will some data augmentations used in training lead to some bias in the final prediction? 115 | 116 | (to be continued...)
117 | 118 | ## Citation 119 | If you find this work useful for your research, please consider giving us a star 🌟 and citing it with the following BibTeX entry: 120 | 121 | ```latex 122 | @inproceedings{xu2022reliable, 123 | title={Reliable propagation-correction modulation for video object segmentation}, 124 | author={Xu, Xiaohao and Wang, Jinglu and Li, Xiao and Lu, Yan}, 125 | booktitle={Proceedings of the AAAI Conference on Artificial Intelligence}, 126 | volume={36}, 127 | number={3}, 128 | pages={2946--2954}, 129 | year={2022} 130 | } 131 | ``` 132 | 133 | If you find the implementation helpful, please also consider citing: 134 | 135 | ```latex 136 | @misc{xu2022RPCMVOS, 137 | title={RPCMVOS-REPO}, 138 | author={Xiaohao, Xu}, 139 | publisher = {GitHub}, 140 | journal = {GitHub repository}, 141 | howpublished={\url{https://github.com/JerryX1110/RPCMVOS/}}, 142 | year={2022} 143 | } 144 | ``` 145 | 146 | 147 | 148 | ## Credit 149 | 150 | **CFBI**: 151 | 152 | **Deeplab**: 153 | 154 | **GCT**: 155 | 156 | ## Related Works in VOS 157 | **Semi-supervised video object segmentation repo/paper links:** 158 | 159 | **PAOT [IJCAI 2023]**: 160 | 161 | **ARKitTrack [CVPR 2023]**: 162 | 163 | **MobileVOS [CVPR 2023]**: 164 | 165 | **Two-ShotVOS [CVPR 2023]**: 166 | 167 | **UNINEXT [CVPR 2023]**: 168 | 169 | **ISVOS [CVPR 2023]**: 170 | 171 | **TarVis [CVPR 2023]**: 172 | 173 | **LBLVOS [AAAI 2023]**: 174 | 175 | **DeAOT [NeurIPS 2022]**: 176 | 177 | **RobustVOS [ACM MM 2022]**: 178 | 179 | **BATMAN [ECCV 2022 Oral]**: 180 | 181 | **TBD [ECCV 2022]**: 182 | 183 | **XMEM [ECCV 2022]**: 184 | 185 | **QDMN [ECCV 2022]**: 186 | 187 | **GSFM [ECCV 2022]**: 188 | 189 | **SWEM [CVPR 2022]**: 190 | 191 | **RDE [CVPR 2022]**: 192 | 193 | **COVOS [CVPR 2022]**: 194 | 195 | **AOT [NeurIPS 2021]**: 196 | 197 | **STCN [NeurIPS 2021]**: 198 | 199 | **JOINT [ICCV 2021]**: 200 | 201 | **HMMN [ICCV 2021]**: 202 | 203 | **DMN-AOA [ICCV 2021]**: 204 | 205 | **MiVOS [CVPR 2021]**: 206 | 207 | **SSTVOS [CVPR 2021]**: 208 | 209 | **GraphMemVOS [ECCV 2020]**: 210 | 211 | **AFB-URR [NeurIPS 2020]**: 212 | 213 | **CFBI [ECCV 2020]**: 214 | 215 | **FRTM-VOS [CVPR 2020]**: 216 | 217 | **STM [ICCV 2019]**: 218 | 219 | **FEELVOS [CVPR 2019]**: 220 | 221 | (The list may be incomplete; feel free to contact me by opening an issue and I'll add them!) 222 | 223 | ## Useful websites for VOS 224 | **The 1st Large-scale Video Object Segmentation Challenge**: 225 | 226 | **The 2nd Large-scale Video Object Segmentation Challenge - Track 1: Video Object Segmentation**: 227 | 228 | **The Semi-Supervised DAVIS Challenge on Video Object Segmentation @ CVPR 2020**: 229 | 230 | **DAVIS**: 231 | 232 | **YouTube-VOS**: 233 | 234 | **Papers with code for Semi-VOS**: 235 | 236 | ## Q&A 237 | Some Q&As from readers about the project are listed below. 238 | 239 | **Q1: I have noticed that the performance on YouTube-VOS is very good, and I wonder what you think might be the reason?** 240 | 241 | **Error propagation** is a critical problem for most models in VOS as well as in other tracking-related fields. The main reason for the improvement of our model is a set of designs that suppress errors arising from propagation. Specifically, we propose an assembly of propagation and correction modulators to fully leverage the reference guidance during propagation. Apart from the reliable guidance from the reference, we also consider leveraging reliable cues from the historical predictions. To be specific, we use Shannon entropy as a measure of prediction uncertainty for further reliable object cue augmentation.
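For readers who want a more concrete picture of that last step, below is a minimal sketch of how per-pixel Shannon entropy can be turned into a reliability mask. It only illustrates the idea described above and is not the repository's actual implementation (see `networks/layers/shannon_entropy.py` and the RPA evaluation code for that); the tensor shapes and the entropy threshold here are assumptions made for the example.

```python
import math

import torch
import torch.nn.functional as F


def reliability_mask(logits: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
    """Mark pixels whose prediction is confident enough to be treated as reliable.

    logits: (N, C, H, W) segmentation logits (background + object channels).
    threshold: assumed cut-off on the normalized entropy; lower entropy = more reliable.
    """
    probs = F.softmax(logits, dim=1)                             # (N, C, H, W)
    entropy = -(probs * probs.clamp_min(1e-8).log()).sum(dim=1)  # (N, H, W)
    entropy = entropy / math.log(logits.shape[1])                # normalize to [0, 1] by log(C)
    return entropy < threshold                                   # True = reliable pixel


# Toy usage: keep only reliable feature positions, e.g. as candidates for a maintained pool.
logits = torch.randn(1, 3, 4, 4)       # dummy prediction: background + 2 objects
feats = torch.randn(1, 16, 4, 4)       # dummy frame embeddings
mask = reliability_mask(logits)        # (1, 4, 4) boolean mask
reliable_feats = feats[:, :, mask[0]]  # (1, 16, num_reliable_pixels)
```

In the paper, patches that pass this kind of reliability test are the ones added to the maintained pool and reused as correction cues in later frames.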
242 | 243 | **Q2: When training, did you randomly crop the images to 465x465, consistent with CFBI?** 244 | 245 | Yes. We mainly follow the training protocol used in CFBI. (Based on some observations, I think certain data augmentation methods may introduce some bias into the training samples, which may further lead to a gap between training and inference. However, I haven't verified this viewpoint carefully.) 246 | 247 | ## Acknowledgement ❤️ 248 | Firstly, the author would like to thank Rex for his insightful viewpoints about VOS during our e-mail discussions! 249 | Also, this work is built upon CFBI. Thanks to the authors of CFBI for releasing such a wonderful code repo for further work to build upon! 250 | 251 | ## Comments and discussions are welcome!! 252 | Xiaohao Xu: 253 | 254 | ## License 255 | This project is released under the MIT license. See [LICENSE](LICENSE) for additional details. 256 | -------------------------------------------------------------------------------- /configs/__pycache__/resnet101_p2t.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/configs/__pycache__/resnet101_p2t.cpython-36.pyc -------------------------------------------------------------------------------- /configs/resnet101_rpcm_ytb_stage_1.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import sys 5 | import cv2 6 | import time 7 | import socket 8 | import random 9 | 10 | class Configuration(): 11 | def __init__(self): 12 | self.EXP_NAME = 'resnet101_rpcm_ytb' 13 | 14 | self.EVAL_AUTO_RESUME = False 15 | self.UNC_RATIO = 1.0 16 | self.MEM_EVERY = 5 17 | self.USE_SF = False 18 | self.PAST_FRAME_NUM = 4 19 | self.COMPRESSION_RATE = 4 20 | self.BLOCK_NUM = 2 21 | 22 | self.DIR_ROOT = '../datasets/' ## your custom dataset path 23 | self.DIR_DAVIS = os.path.join(self.DIR_ROOT, 'DAVIS') ## davis dataset path 24 | self.DIR_YTB = os.path.join(self.DIR_ROOT, '/YouTube/train/') ## ytb19 dataset path (for training) 25 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, '/YouTube2018/valid/') ## ytb18 dataset path (for evaluation) 26 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'/YouTube/valid/') ## ytb19 dataset path (for evaluation) 27 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME) ## result saving path 28 | 29 | 30 | 31 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 32 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 33 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 34 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 35 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 36 | 37 | self.DATASETS = ['youtubevos'] 38 | self.DATA_WORKERS = 4 39 | self.DATA_RANDOMCROP = (465, 465) 40 | self.DATA_RANDOMFLIP = 0.5 41 | self.DATA_MAX_CROP_STEPS = 5 42 | self.DATA_MIN_SCALE_FACTOR = 1.
43 | self.DATA_MAX_SCALE_FACTOR = 1.3 44 | self.DATA_SHORT_EDGE_LEN = 480 45 | self.DATA_RANDOM_REVERSE_SEQ = True 46 | self.DATA_DAVIS_REPEAT = 30 47 | self.DATA_CURR_SEQ_LEN = 3 48 | self.DATA_RANDOM_GAP_DAVIS = 3 49 | self.DATA_RANDOM_GAP_YTB = 3 50 | 51 | 52 | self.PRETRAIN = True 53 | self.PRETRAIN_FULL = False 54 | self.PRETRAIN_MODEL = '' # path to resnet101-deeplabv3p.pth.tar 55 | 56 | self.MODEL_BACKBONE = 'resnet' 57 | self.MODEL_MODULE = 'networks.rpcm.rpcm' 58 | self.MODEL_OUTPUT_STRIDE = 16 59 | self.MODEL_ASPP_OUTDIM = 256 60 | self.MODEL_SHORTCUT_DIM = 48 61 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100 62 | self.MODEL_HEAD_EMBEDDING_DIM = 256 63 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64 64 | self.MODEL_GN_GROUPS = 32 65 | self.MODEL_GN_EMB_GROUPS = 25 66 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12] 67 | self.MODEL_LOCAL_DOWNSAMPLE = True 68 | self.MODEL_REFINE_CHANNELS = 64 # n * 32 69 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24 70 | self.MODEL_RELATED_CHANNELS = 64 71 | self.MODEL_EPSILON = 1e-5 72 | self.MODEL_MATCHING_BACKGROUND = True 73 | self.MODEL_GCT_BETA_WD = True 74 | self.MODEL_FLOAT16_MATCHING = False 75 | self.MODEL_FREEZE_BN = True 76 | self.MODEL_FREEZE_BACKBONE = False 77 | 78 | self.TRAIN_TOTAL_STEPS = 200000 79 | self.TRAIN_START_STEP = 0 80 | self.TRAIN_LR = 0.02 81 | self.TRAIN_MOMENTUM = 0.9 82 | self.TRAIN_COSINE_DECAY = False 83 | self.TRAIN_WARM_UP_STEPS = 1000 84 | self.TRAIN_WEIGHT_DECAY = 15e-5 85 | self.TRAIN_POWER = 0.9 86 | self.TRAIN_GPUS = 8 87 | self.TRAIN_BATCH_SIZE = 8 88 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2 89 | self.TRAIN_TBLOG = False 90 | self.TRAIN_TBLOG_STEP = 60 91 | self.TRAIN_LOG_STEP = 20 92 | self.TRAIN_IMG_LOG = False 93 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 94 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2 95 | self.TRAIN_CLIP_GRAD_NORM = 5. 96 | self.TRAIN_SAVE_STEP = 20000 97 | self.TRAIN_MAX_KEEP_CKPT = 8 98 | self.TRAIN_RESUME = False 99 | self.TRAIN_RESUME_CKPT = None 100 | self.TRAIN_RESUME_STEP = 0 101 | self.TRAIN_AUTO_RESUME = True 102 | self.TRAIN_GLOBAL_ATROUS_RATE = 1 103 | self.TRAIN_LOCAL_ATROUS_RATE = 1 104 | self.TRAIN_LOCAL_PARALLEL = True 105 | self.TRAIN_GLOBAL_CHUNKS = 1 106 | self.TRAIN_DATASET_FULL_RESOLUTION = True 107 | 108 | 109 | self.TEST_GPU_ID = 0 110 | self.TEST_DATASET = 'youtubevos' 111 | self.TEST_DATASET_FULL_RESOLUTION = True 112 | self.TEST_DATASET_SPLIT = ['val'] 113 | self.TEST_CKPT_PATH = None 114 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint. 115 | self.TEST_FLIP = False 116 | self.TEST_MULTISCALE = [1] 117 | self.TEST_MIN_SIZE = None 118 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] 
else 800 119 | self.TEST_WORKERS = 4 120 | self.TEST_GLOBAL_CHUNKS = 4 121 | self.TEST_GLOBAL_ATROUS_RATE = 1 122 | self.TEST_LOCAL_ATROUS_RATE = 1 123 | self.TEST_LOCAL_PARALLEL = True 124 | 125 | # dist 126 | self.DIST_ENABLE = True 127 | self.DIST_BACKEND = "nccl" 128 | 129 | myname = socket.getfqdn(socket.gethostname()) 130 | myaddr = socket.gethostbyname(myname) 131 | 132 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000)) 133 | self.DIST_START_GPU = 0 134 | 135 | self.__check() 136 | 137 | def __check(self): 138 | if not torch.cuda.is_available(): 139 | raise ValueError('config.py: cuda is not available') 140 | if self.TRAIN_GPUS == 0: 141 | raise ValueError('config.py: the number of GPUs is 0') 142 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]: 143 | if not os.path.isdir(path): 144 | os.makedirs(path) 145 | 146 | 147 | 148 | cfg = Configuration() 149 | -------------------------------------------------------------------------------- /configs/resnet101_rpcm_ytb_stage_2.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import os 4 | import sys 5 | import cv2 6 | import time 7 | import socket 8 | import random 9 | 10 | class Configuration(): 11 | def __init__(self): 12 | self.EXP_NAME = 'resnet101_rpcm_ytb' 13 | 14 | self.EVAL_AUTO_RESUME = False 15 | self.UNC_RATIO = 1.0 16 | self.MEM_EVERY = 5 17 | self.USE_SF = False 18 | self.PAST_FRAME_NUM = 4 19 | self.COMPRESSION_RATE = 4 20 | self.BLOCK_NUM = 2 21 | 22 | self.DIR_ROOT = '../datasets/' ## your custom dataset path 23 | self.DIR_DAVIS = os.path.join(self.DIR_ROOT, 'DAVIS') ## davis dataset path 24 | self.DIR_YTB = os.path.join(self.DIR_ROOT, '/YouTube/train/') ## ytb19 dataset path (for training) 25 | self.DIR_YTB_EVAL18 = os.path.join(self.DIR_ROOT, '/YouTube2018/valid/') ## ytb18 dataset path (for evaluation) 26 | self.DIR_YTB_EVAL19 = os.path.join(self.DIR_ROOT,'/YouTube/valid/') ## ytb19 dataset path (for evaluation) 27 | self.DIR_RESULT = os.path.join(self.DIR_ROOT, 'result', self.EXP_NAME) ## result saving path 28 | 29 | self.DIR_CKPT = os.path.join(self.DIR_RESULT, 'ckpt') 30 | self.DIR_LOG = os.path.join(self.DIR_RESULT, 'log') 31 | self.DIR_IMG_LOG = os.path.join(self.DIR_RESULT, 'log', 'img') 32 | self.DIR_TB_LOG = os.path.join(self.DIR_RESULT, 'log', 'tensorboard') 33 | self.DIR_EVALUATION = os.path.join(self.DIR_RESULT, 'eval') 34 | 35 | self.DATASETS = ['youtubevos'] 36 | self.DATA_WORKERS = 4 37 | self.DATA_RANDOMCROP = (465, 465) 38 | self.DATA_RANDOMFLIP = 0.5 39 | self.DATA_MAX_CROP_STEPS = 5 40 | self.DATA_MIN_SCALE_FACTOR = 1.
41 | self.DATA_MAX_SCALE_FACTOR = 1.3 42 | self.DATA_SHORT_EDGE_LEN = 480 43 | self.DATA_RANDOM_REVERSE_SEQ = True 44 | self.DATA_DAVIS_REPEAT = 30 45 | self.DATA_CURR_SEQ_LEN = 5 46 | self.DATA_RANDOM_GAP_DAVIS = 3 47 | self.DATA_RANDOM_GAP_YTB = 3 48 | 49 | 50 | self.PRETRAIN = True 51 | self.PRETRAIN_FULL = False 52 | self.PRETRAIN_MODEL = '' # path to resnet101-deeplabv3p.pth.tar 53 | 54 | self.MODEL_BACKBONE = 'resnet' 55 | self.MODEL_MODULE = 'networks.p2t.p2t_base' 56 | self.MODEL_OUTPUT_STRIDE = 16 57 | self.MODEL_ASPP_OUTDIM = 256 58 | self.MODEL_SHORTCUT_DIM = 48 59 | self.MODEL_SEMANTIC_EMBEDDING_DIM = 100 60 | self.MODEL_HEAD_EMBEDDING_DIM = 256 61 | self.MODEL_PRE_HEAD_EMBEDDING_DIM = 64 62 | self.MODEL_GN_GROUPS = 32 63 | self.MODEL_GN_EMB_GROUPS = 25 64 | self.MODEL_MULTI_LOCAL_DISTANCE = [2, 4, 6, 8, 10, 12] 65 | self.MODEL_LOCAL_DOWNSAMPLE = True 66 | self.MODEL_REFINE_CHANNELS = 64 # n * 32 67 | self.MODEL_LOW_LEVEL_INPLANES = 256 if self.MODEL_BACKBONE == 'resnet' else 24 68 | self.MODEL_RELATED_CHANNELS = 64 69 | self.MODEL_EPSILON = 1e-5 70 | self.MODEL_MATCHING_BACKGROUND = True 71 | self.MODEL_GCT_BETA_WD = True 72 | self.MODEL_FLOAT16_MATCHING = False 73 | self.MODEL_FREEZE_BN = True 74 | self.MODEL_FREEZE_BACKBONE = False 75 | 76 | self.TRAIN_TOTAL_STEPS = 400000 77 | self.TRAIN_START_STEP = 0 78 | self.TRAIN_LR = 0.02 79 | self.TRAIN_MOMENTUM = 0.9 80 | self.TRAIN_COSINE_DECAY = False 81 | self.TRAIN_WARM_UP_STEPS = 1000 82 | self.TRAIN_WEIGHT_DECAY = 15e-5 83 | self.TRAIN_POWER = 0.9 84 | self.TRAIN_GPUS = 8 85 | self.TRAIN_BATCH_SIZE = 8 86 | self.TRAIN_START_SEQ_TRAINING_STEPS = self.TRAIN_TOTAL_STEPS / 2 87 | self.TRAIN_TBLOG = False 88 | self.TRAIN_TBLOG_STEP = 60 89 | self.TRAIN_LOG_STEP = 20 90 | self.TRAIN_IMG_LOG = False 91 | self.TRAIN_TOP_K_PERCENT_PIXELS = 0.15 92 | self.TRAIN_HARD_MINING_STEP = self.TRAIN_TOTAL_STEPS / 2 93 | self.TRAIN_CLIP_GRAD_NORM = 5. 94 | self.TRAIN_SAVE_STEP = 20000 95 | self.TRAIN_MAX_KEEP_CKPT = 8 96 | self.TRAIN_RESUME = False 97 | self.TRAIN_RESUME_CKPT = None 98 | self.TRAIN_RESUME_STEP = 0 99 | self.TRAIN_AUTO_RESUME = True 100 | self.TRAIN_GLOBAL_ATROUS_RATE = 1 101 | self.TRAIN_LOCAL_ATROUS_RATE = 1 102 | self.TRAIN_LOCAL_PARALLEL = True 103 | self.TRAIN_GLOBAL_CHUNKS = 20 104 | self.TRAIN_DATASET_FULL_RESOLUTION = True 105 | 106 | 107 | self.TEST_GPU_ID = 0 108 | self.TEST_DATASET = 'youtubevos' 109 | self.TEST_DATASET_FULL_RESOLUTION = False 110 | self.TEST_DATASET_SPLIT = ['val'] 111 | self.TEST_CKPT_PATH = None 112 | self.TEST_CKPT_STEP = None # if "None", evaluate the latest checkpoint. 113 | self.TEST_FLIP = False 114 | self.TEST_MULTISCALE = [1] 115 | self.TEST_MIN_SIZE = None 116 | self.TEST_MAX_SIZE = 800 * 1.3 if self.TEST_MULTISCALE == [1.] 
else 800 117 | self.TEST_WORKERS = 4 118 | self.TEST_GLOBAL_CHUNKS = 4 119 | self.TEST_GLOBAL_ATROUS_RATE = 1 120 | self.TEST_LOCAL_ATROUS_RATE = 1 121 | self.TEST_LOCAL_PARALLEL = True 122 | 123 | # dist 124 | self.DIST_ENABLE = True 125 | self.DIST_BACKEND = "nccl" 126 | 127 | myname = socket.getfqdn(socket.gethostname()) 128 | myaddr = socket.gethostbyname(myname) 129 | 130 | self.DIST_URL = "tcp://"+myaddr+":"+str(random.randint(30000,50000)) 131 | self.DIST_START_GPU = 0 132 | 133 | self.__check() 134 | 135 | def __check(self): 136 | if not torch.cuda.is_available(): 137 | raise ValueError('config.py: cuda is not available') 138 | if self.TRAIN_GPUS == 0: 139 | raise ValueError('config.py: the number of GPUs is 0') 140 | for path in [self.DIR_RESULT, self.DIR_CKPT, self.DIR_LOG, self.DIR_EVALUATION, self.DIR_IMG_LOG, self.DIR_TB_LOG]: 141 | if not os.path.isdir(path): 142 | os.makedirs(path) 143 | 144 | 145 | 146 | cfg = Configuration() 147 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/dataloaders/__init__.py -------------------------------------------------------------------------------- /dataloaders/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/dataloaders/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/custom_transforms.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/dataloaders/__pycache__/custom_transforms.cpython-36.pyc -------------------------------------------------------------------------------- /dataloaders/__pycache__/datasets_m.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/dataloaders/__pycache__/datasets_m.cpython-36.pyc -------------------------------------------------------------------------------- /dataloaders/custom_transforms.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import cv2 4 | import numpy as np 5 | import torch 6 | 7 | cv2.setNumThreads(0) 8 | 9 | class Resize(object): 10 | """Rescale the image in a sample to a given size. 11 | 12 | Args: 13 | output_size (tuple or int): Desired output size. If tuple, output is 14 | matched to output_size. If int, smaller of image edges is matched 15 | to output_size keeping aspect ratio the same.
16 | """ 17 | 18 | def __init__(self, output_size): 19 | assert isinstance(output_size, (int, tuple)) 20 | if isinstance(output_size, int): 21 | self.output_size = (output_size, output_size) 22 | else: 23 | self.output_size = output_size 24 | 25 | def __call__(self, sample): 26 | prev_img = sample['prev_img'] 27 | h, w = prev_img.shape[:2] 28 | if self.output_size == (h, w): 29 | return sample 30 | else: 31 | new_h, new_w = self.output_size 32 | 33 | for elem in sample.keys(): 34 | if 'meta' in elem: 35 | continue 36 | tmp = sample[elem] 37 | 38 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': 39 | flagval = cv2.INTER_CUBIC 40 | else: 41 | flagval = cv2.INTER_NEAREST 42 | 43 | if elem == 'curr_img' or elem == 'curr_label': 44 | new_tmp = [] 45 | all_tmp = tmp 46 | for tmp in all_tmp: 47 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), 48 | interpolation=flagval) 49 | new_tmp.append(tmp) 50 | tmp = new_tmp 51 | else: 52 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), 53 | interpolation=flagval) 54 | 55 | sample[elem] = tmp 56 | 57 | return sample 58 | 59 | class BalancedRandomCrop(object): 60 | """Crop randomly the image in a sample. 61 | 62 | Args: 63 | output_size (tuple or int): Desired output size. If int, square crop 64 | is made. 65 | """ 66 | 67 | def __init__(self, output_size, max_step=5, max_obj_num=5, min_obj_pixel_num=100): 68 | assert isinstance(output_size, (int, tuple)) 69 | if isinstance(output_size, int): 70 | self.output_size = (output_size, output_size) 71 | else: 72 | assert len(output_size) == 2 73 | self.output_size = output_size 74 | self.max_step = max_step 75 | self.max_obj_num = max_obj_num 76 | self.min_obj_pixel_num = min_obj_pixel_num 77 | 78 | def __call__(self, sample): 79 | 80 | image = sample['prev_img'] 81 | h, w = image.shape[:2] 82 | new_h, new_w = self.output_size 83 | new_h = h if new_h >= h else new_h 84 | new_w = w if new_w >= w else new_w 85 | ref_label = sample["ref_label"] 86 | prev_label = sample["prev_label"] 87 | curr_label = sample["curr_label"] 88 | 89 | is_contain_obj = False 90 | step = 0 91 | while (not is_contain_obj) and (step < self.max_step): 92 | step += 1 93 | top = np.random.randint(0, h - new_h + 1) 94 | left = np.random.randint(0, w - new_w + 1) 95 | after_crop = [] 96 | contains = [] 97 | for elem in ([ref_label, prev_label] + curr_label): 98 | tmp = elem[top: top + new_h, left:left + new_w] 99 | contains.append(np.unique(tmp)) 100 | after_crop.append(tmp) 101 | 102 | 103 | all_obj = list(np.sort(contains[0])) 104 | 105 | if all_obj[-1] == 0: 106 | continue 107 | 108 | # remove background 109 | if all_obj[0] == 0: 110 | all_obj = all_obj[1:] 111 | # remove small obj 112 | new_all_obj = [] 113 | for obj_id in all_obj: 114 | after_crop_pixels = np.sum(after_crop[0] == obj_id) 115 | if after_crop_pixels > self.min_obj_pixel_num: 116 | new_all_obj.append(obj_id) 117 | 118 | if len(new_all_obj) == 0: 119 | is_contain_obj = False 120 | else: 121 | is_contain_obj = True 122 | 123 | if len(new_all_obj) > self.max_obj_num: 124 | random.shuffle(new_all_obj) 125 | new_all_obj = new_all_obj[:self.max_obj_num] 126 | 127 | all_obj = [0] + new_all_obj 128 | 129 | 130 | post_process = [] 131 | for elem in after_crop: 132 | new_elem = elem * 0 133 | for idx in range(len(all_obj)): 134 | obj_id = all_obj[idx] 135 | if obj_id == 0: 136 | continue 137 | mask = elem == obj_id 138 | 139 | new_elem += (mask * idx).astype(np.uint8) 140 | post_process.append(new_elem.astype(np.uint8)) 141 | 142 | sample["ref_label"] = post_process[0] 143 | 
sample["prev_label"] = post_process[1] 144 | curr_len = len(sample["curr_img"]) 145 | sample["curr_label"] = [] 146 | for idx in range(curr_len): 147 | sample["curr_label"].append(post_process[idx + 2]) 148 | 149 | for elem in sample.keys(): 150 | if 'meta' in elem or 'label' in elem: 151 | continue 152 | if elem == 'curr_img': 153 | new_tmp = [] 154 | for tmp_ in sample[elem]: 155 | tmp_ = tmp_[top: top + new_h, left:left + new_w] 156 | new_tmp.append(tmp_) 157 | sample[elem] = new_tmp 158 | else: 159 | tmp = sample[elem] 160 | tmp = tmp[top: top + new_h, left:left + new_w] 161 | sample[elem] = tmp 162 | 163 | obj_num = len(all_obj) - 1 164 | 165 | sample['meta']['obj_num'] = obj_num 166 | 167 | return sample 168 | 169 | class RandomScale(object): 170 | """Randomly resize the image and the ground truth to specified scales. 171 | Args: 172 | scales (list): the list of scales 173 | """ 174 | 175 | def __init__(self, min_scale=1., max_scale=1.3, short_edge=None): 176 | self.min_scale = min_scale 177 | self.max_scale = max_scale 178 | self.short_edge = short_edge 179 | 180 | def __call__(self, sample): 181 | # Fixed range of scales 182 | sc = np.random.uniform(self.min_scale, self.max_scale) 183 | # Align short edge 184 | if not (self.short_edge is None): 185 | image = sample['prev_img'] 186 | h, w = image.shape[:2] 187 | if h > w: 188 | sc *= float(self.short_edge) / w 189 | else: 190 | sc *= float(self.short_edge) / h 191 | 192 | 193 | for elem in sample.keys(): 194 | if 'meta' in elem: 195 | continue 196 | tmp = sample[elem] 197 | 198 | if elem == 'prev_img' or elem == 'curr_img' or elem == 'ref_img': 199 | flagval = cv2.INTER_CUBIC 200 | else: 201 | flagval = cv2.INTER_NEAREST 202 | 203 | if elem == 'curr_img' or elem == 'curr_label': 204 | new_tmp = [] 205 | for tmp_ in tmp: 206 | tmp_ = cv2.resize(tmp_, None, fx=sc, fy=sc, interpolation=flagval) 207 | new_tmp.append(tmp_) 208 | tmp = new_tmp 209 | else: 210 | tmp = cv2.resize(tmp, None, fx=sc, fy=sc, interpolation=flagval) 211 | 212 | sample[elem] = tmp 213 | 214 | return sample 215 | 216 | class RestrictSize(object): 217 | """Randomly resize the image and the ground truth to specified scales. 
218 | Args: 219 | scales (list): the list of scales 220 | """ 221 | 222 | def __init__(self, min_size=None, max_size=800*1.3): 223 | self.min_size = min_size 224 | self.max_size = max_size 225 | assert ((min_size is None)) or ((max_size is None)) 226 | 227 | def __call__(self, sample): 228 | 229 | # Fixed range of scales 230 | sc = None 231 | image = sample['ref_img'] 232 | h, w = image.shape[:2] 233 | # Align short edge 234 | if not (self.min_size is None): 235 | if h > w: 236 | short_edge = w 237 | else: 238 | short_edge = h 239 | if short_edge < self.min_size: 240 | sc = float(self.min_size) / short_edge 241 | else: 242 | if h > w: 243 | long_edge = h 244 | else: 245 | long_edge = w 246 | if long_edge > self.max_size: 247 | sc = float(self.max_size) / long_edge 248 | 249 | if sc is None: 250 | new_h = h 251 | new_w = w 252 | else: 253 | new_h = int(sc * h) 254 | new_w = int(sc * w) 255 | new_h = new_h - (new_h - 1) % 4 256 | new_w = new_w - (new_w - 1) % 4 257 | if new_h == h and new_w == w: 258 | return sample 259 | 260 | 261 | for elem in sample.keys(): 262 | if 'meta' in elem: 263 | continue 264 | tmp = sample[elem] 265 | 266 | if 'label' in elem: 267 | flagval = cv2.INTER_NEAREST 268 | else: 269 | flagval = cv2.INTER_CUBIC 270 | 271 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval) 272 | 273 | sample[elem] = tmp 274 | 275 | return sample 276 | 277 | class RandomHorizontalFlip(object): 278 | """Horizontally flip the given image and ground truth randomly with a probability of 0.5.""" 279 | 280 | def __init__(self, prob): 281 | self.p = prob 282 | 283 | def __call__(self, sample): 284 | 285 | if random.random() < self.p: 286 | for elem in sample.keys(): 287 | if 'meta' in elem: 288 | continue 289 | if elem == 'curr_img' or elem == 'curr_label': 290 | new_tmp = [] 291 | for tmp_ in sample[elem]: 292 | tmp_ = cv2.flip(tmp_, flipCode=1) 293 | new_tmp.append(tmp_) 294 | sample[elem] = new_tmp 295 | else: 296 | tmp = sample[elem] 297 | tmp = cv2.flip(tmp, flipCode=1) 298 | sample[elem] = tmp 299 | 300 | return sample 301 | 302 | class RandomGaussianBlur(object): 303 | 304 | def __init__(self, prob=0.2): 305 | self.p = prob 306 | 307 | def __call__(self, sample): 308 | 309 | 310 | for elem in sample.keys(): 311 | if 'meta' in elem or 'label' in elem: 312 | continue 313 | 314 | if elem == 'curr_img': 315 | new_tmp = [] 316 | for tmp_ in sample[elem]: 317 | if random.random() < self.p: 318 | std = random.random() * 1.9 + 0.1 # [0.1, 2] 319 | tmp_ = cv2.GaussianBlur(tmp_, (9, 9), sigmaX=std, sigmaY=std) 320 | new_tmp.append(tmp_) 321 | sample[elem] = new_tmp 322 | else: 323 | tmp = sample[elem] 324 | if random.random() < self.p: 325 | std = random.random() * 1.9 + 0.1 # [0.1, 2] 326 | tmp = cv2.GaussianBlur(tmp, (9, 9), sigmaX=std, sigmaY=std) 327 | sample[elem] = tmp 328 | 329 | return sample 330 | 331 | class SubtractMeanImage(object): 332 | def __init__(self, mean, change_channels=False): 333 | self.mean = mean 334 | self.change_channels = change_channels 335 | 336 | def __call__(self, sample): 337 | for elem in sample.keys(): 338 | if 'image' in elem: 339 | if self.change_channels: 340 | sample[elem] = sample[elem][:, :, [2, 1, 0]] 341 | sample[elem] = np.subtract( 342 | sample[elem], np.array(self.mean, dtype=np.float32)) 343 | return sample 344 | 345 | def __str__(self): 346 | return 'SubtractMeanImage' + str(self.mean) 347 | 348 | class ToTensor(object): 349 | """Convert ndarrays in sample to Tensors.""" 350 | 351 | def __call__(self, sample): 352 | 353 | for elem in 
sample.keys(): 354 | if 'meta' in elem: 355 | continue 356 | tmp = sample[elem] 357 | 358 | if elem == 'curr_img' or elem == 'curr_label': 359 | new_tmp = [] 360 | for tmp_ in tmp: 361 | if tmp_.ndim == 2: 362 | tmp_ = tmp_[:, :, np.newaxis] 363 | else: 364 | tmp_ = tmp_ / 255. 365 | tmp_ -= (0.485, 0.456, 0.406) 366 | tmp_ /= (0.229, 0.224, 0.225) 367 | tmp_ = tmp_.transpose((2, 0, 1)) 368 | new_tmp.append(torch.from_numpy(tmp_)) 369 | tmp = new_tmp 370 | else: 371 | if tmp.ndim == 2: 372 | tmp = tmp[:, :, np.newaxis] 373 | else: 374 | tmp = tmp / 255. 375 | tmp -= (0.485, 0.456, 0.406) 376 | tmp /= (0.229, 0.224, 0.225) 377 | tmp = tmp.transpose((2, 0, 1)) 378 | tmp = torch.from_numpy(tmp) 379 | sample[elem] = tmp 380 | 381 | return sample 382 | 383 | class MultiRestrictSize(object): 384 | def __init__(self, min_size=None, max_size=800, flip=False, multi_scale=[1.3]): 385 | self.min_size = min_size 386 | self.max_size = max_size 387 | self.multi_scale = multi_scale 388 | self.flip = flip 389 | assert ((min_size is None)) or ((max_size is None)) 390 | 391 | def __call__(self, sample): 392 | samples = [] 393 | image = sample['current_img'] 394 | h, w = image.shape[:2] 395 | for scale in self.multi_scale: 396 | # Fixed range of scales 397 | sc = None 398 | # Align short edge 399 | if not (self.min_size is None): 400 | if h > w: 401 | short_edge = w 402 | else: 403 | short_edge = h 404 | if short_edge > self.min_size: 405 | sc = float(self.min_size) / short_edge 406 | else: 407 | if h > w: 408 | long_edge = h 409 | else: 410 | long_edge = w 411 | if long_edge > self.max_size: 412 | sc = float(self.max_size) / long_edge 413 | 414 | if sc is None: 415 | new_h = h 416 | new_w = w 417 | else: 418 | new_h = sc * h 419 | new_w = sc * w 420 | new_h = int(new_h * scale) 421 | new_w = int(new_w * scale) 422 | 423 | if (new_h - 1) % 16 != 0: 424 | new_h = int(np.around((new_h - 1) / 16.) * 16 + 1) 425 | if (new_w - 1) % 16 != 0: 426 | new_w = int(np.around((new_w - 1) / 16.) * 16 + 1) 427 | 428 | if new_h == h and new_w == w: 429 | samples.append(sample) 430 | else: 431 | new_sample = {} 432 | for elem in sample.keys(): 433 | if 'meta' in elem: 434 | new_sample[elem] = sample[elem] 435 | continue 436 | tmp = sample[elem] 437 | if 'label' in elem: 438 | new_sample[elem] = sample[elem] 439 | continue 440 | else: 441 | flagval = cv2.INTER_CUBIC 442 | tmp = cv2.resize(tmp, dsize=(new_w, new_h), interpolation=flagval) 443 | new_sample[elem] = tmp 444 | samples.append(new_sample) 445 | 446 | if self.flip: 447 | now_sample = samples[-1] 448 | new_sample = {} 449 | for elem in now_sample.keys(): 450 | if 'meta' in elem: 451 | new_sample[elem] = now_sample[elem].copy() 452 | new_sample[elem]['flip'] = True 453 | continue 454 | tmp = now_sample[elem] 455 | tmp = tmp[:, ::-1].copy() 456 | new_sample[elem] = tmp 457 | samples.append(new_sample) 458 | 459 | return samples 460 | 461 | class MultiToTensor(object): 462 | def __call__(self, samples): 463 | for idx in range(len(samples)): 464 | sample = samples[idx] 465 | for elem in sample.keys(): 466 | if 'meta' in elem: 467 | continue 468 | tmp = sample[elem] 469 | if tmp is None: 470 | continue 471 | 472 | if tmp.ndim == 2: 473 | tmp = tmp[:, :, np.newaxis] 474 | else: 475 | tmp = tmp / 255. 
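                    # scale to [0, 1] above; the tuples below are the standard ImageNet mean/std normalization constants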
476 | tmp -= (0.485, 0.456, 0.406) 477 | tmp /= (0.229, 0.224, 0.225) 478 | 479 | tmp = tmp.transpose((2, 0, 1)) 480 | samples[idx][elem] = torch.from_numpy(tmp) 481 | 482 | return samples 483 | 484 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/__init__.py -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/__init__.py -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/aspp.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/__pycache__/aspp.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/decoder.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/__pycache__/decoder.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/__pycache__/deeplab.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/__pycache__/deeplab.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/aspp.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class _ASPPModule(nn.Module): 7 | def __init__(self, inplanes, planes, kernel_size, padding, dilation, BatchNorm): 8 | super(_ASPPModule, self).__init__() 9 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 10 | stride=1, padding=padding, dilation=dilation, bias=False) 11 | self.bn = BatchNorm(planes) 12 | self.relu = nn.ReLU(inplace=True) 13 | 14 | self._init_weight() 15 | 16 | def forward(self, x): 17 | x = self.atrous_conv(x) 18 | x = self.bn(x) 19 | 20 | return self.relu(x) 21 | 22 | def _init_weight(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | torch.nn.init.kaiming_normal_(m.weight) 26 
| elif isinstance(m, nn.BatchNorm2d): 27 | m.weight.data.fill_(1) 28 | m.bias.data.zero_() 29 | 30 | class ASPP(nn.Module): 31 | def __init__(self, backbone, output_stride, BatchNorm): 32 | super(ASPP, self).__init__() 33 | if backbone == 'drn': 34 | inplanes = 512 35 | elif backbone == 'mobilenet': 36 | inplanes = 320 37 | else: 38 | inplanes = 2048 39 | if output_stride == 16: 40 | dilations = [1, 6, 12, 18] 41 | elif output_stride == 8: 42 | dilations = [1, 12, 24, 36] 43 | else: 44 | raise NotImplementedError 45 | 46 | self.aspp1 = _ASPPModule(inplanes, 256, 1, padding=0, dilation=dilations[0], BatchNorm=BatchNorm) 47 | self.aspp2 = _ASPPModule(inplanes, 256, 3, padding=dilations[1], dilation=dilations[1], BatchNorm=BatchNorm) 48 | self.aspp3 = _ASPPModule(inplanes, 256, 3, padding=dilations[2], dilation=dilations[2], BatchNorm=BatchNorm) 49 | self.aspp4 = _ASPPModule(inplanes, 256, 3, padding=dilations[3], dilation=dilations[3], BatchNorm=BatchNorm) 50 | 51 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 52 | nn.Conv2d(inplanes, 256, 1, stride=1, bias=False), 53 | BatchNorm(256), 54 | nn.ReLU(inplace=True)) 55 | self.conv1 = nn.Conv2d(1280, 256, 1, bias=False) 56 | self.bn1 = BatchNorm(256) 57 | self.relu = nn.ReLU(inplace=True) 58 | self.dropout = nn.Dropout(0.1) 59 | self._init_weight() 60 | 61 | def forward(self, x): 62 | x1 = self.aspp1(x) 63 | x2 = self.aspp2(x) 64 | x3 = self.aspp3(x) 65 | x4 = self.aspp4(x) 66 | x5 = self.global_avg_pool(x) 67 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 68 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 69 | 70 | x = self.conv1(x) 71 | x = self.bn1(x) 72 | x = self.relu(x) 73 | 74 | return self.dropout(x) 75 | 76 | def _init_weight(self): 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv2d): 79 | torch.nn.init.kaiming_normal_(m.weight) 80 | elif isinstance(m, nn.BatchNorm2d): 81 | m.weight.data.fill_(1) 82 | m.bias.data.zero_() 83 | 84 | 85 | def build_aspp(backbone, output_stride, BatchNorm): 86 | return ASPP(backbone, output_stride, BatchNorm) 87 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from networks.deeplab.backbone import resnet, mobilenet 2 | 3 | def build_backbone(backbone, output_stride, BatchNorm): 4 | if backbone == 'resnet': 5 | return resnet.ResNet101(output_stride, BatchNorm) 6 | elif backbone == 'mobilenet': 7 | return mobilenet.MobileNetV2(output_stride, BatchNorm) 8 | else: 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/backbone/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/backbone/__pycache__/mobilenet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/deeplab/backbone/__pycache__/resnet.cpython-36.pyc -------------------------------------------------------------------------------- /networks/deeplab/backbone/mobilenet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | import math 5 | import torch.utils.model_zoo as model_zoo 6 | 7 | def conv_bn(inp, oup, stride, BatchNorm): 8 | return nn.Sequential( 9 | nn.Conv2d(inp, oup, 3, stride, 1, bias=False), 10 | BatchNorm(oup), 11 | nn.ReLU6(inplace=True) 12 | ) 13 | 14 | 15 | def fixed_padding(inputs, kernel_size, dilation): 16 | kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1) 17 | pad_total = kernel_size_effective - 1 18 | pad_beg = pad_total // 2 19 | pad_end = pad_total - pad_beg 20 | padded_inputs = F.pad(inputs, (pad_beg, pad_end, pad_beg, pad_end)) 21 | return padded_inputs 22 | 23 | 24 | class InvertedResidual(nn.Module): 25 | def __init__(self, inp, oup, stride, dilation, expand_ratio, BatchNorm): 26 | super(InvertedResidual, self).__init__() 27 | self.stride = stride 28 | assert stride in [1, 2] 29 | 30 | hidden_dim = round(inp * expand_ratio) 31 | self.use_res_connect = self.stride == 1 and inp == oup 32 | self.kernel_size = 3 33 | self.dilation = dilation 34 | 35 | if expand_ratio == 1: 36 | self.conv = nn.Sequential( 37 | # dw 38 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 39 | BatchNorm(hidden_dim), 40 | nn.ReLU6(inplace=True), 41 | # pw-linear 42 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, 1, bias=False), 43 | BatchNorm(oup), 44 | ) 45 | else: 46 | self.conv = nn.Sequential( 47 | # pw 48 | nn.Conv2d(inp, hidden_dim, 1, 1, 0, 1, bias=False), 49 | BatchNorm(hidden_dim), 50 | nn.ReLU6(inplace=True), 51 | # dw 52 | nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 0, dilation, groups=hidden_dim, bias=False), 53 | BatchNorm(hidden_dim), 54 | nn.ReLU6(inplace=True), 55 | # pw-linear 56 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, 1, bias=False), 57 | BatchNorm(oup), 58 | ) 59 | 60 | def forward(self, x): 61 | x_pad = fixed_padding(x, self.kernel_size, dilation=self.dilation) 62 | if self.use_res_connect: 63 | x = x + self.conv(x_pad) 64 | else: 65 | x = self.conv(x_pad) 66 | return x 67 | 68 | 69 | class MobileNetV2(nn.Module): 70 | def __init__(self, output_stride=8, BatchNorm=None, width_mult=1., pretrained=False): 71 | super(MobileNetV2, self).__init__() 72 | block = InvertedResidual 73 | input_channel = 32 74 | current_stride = 1 75 | rate = 1 76 | interverted_residual_setting = [ 77 | # t, c, n, s 78 | [1, 16, 1, 1], 79 | [6, 24, 2, 2], 80 | [6, 32, 3, 2], 81 | [6, 64, 4, 2], 82 | [6, 96, 3, 1], 83 | [6, 160, 3, 2], 84 | [6, 320, 1, 1], 85 | ] 86 | 87 | # building first layer 88 | input_channel = int(input_channel * width_mult) 89 | self.features = [conv_bn(3, input_channel, 2, BatchNorm)] 90 | current_stride *= 2 91 | # building inverted residual blocks 92 | for t, c, n, s in interverted_residual_setting: 93 | if current_stride == output_stride: 94 | stride = 1 95 | dilation = rate 96 | rate *= s 97 | else: 98 | stride = s 99 | dilation = 1 100 | current_stride *= s 101 | output_channel = int(c * width_mult) 102 | for i in range(n): 103 | if i == 0: 104 | self.features.append(block(input_channel, output_channel, stride, 
dilation, t, BatchNorm)) 105 | else: 106 | self.features.append(block(input_channel, output_channel, 1, rate, t, BatchNorm)) 107 | input_channel = output_channel 108 | self.features = nn.Sequential(*self.features) 109 | self._initialize_weights() 110 | 111 | if pretrained: 112 | self._load_pretrained_model() 113 | 114 | self.low_level_features = self.features[0:4] 115 | self.high_level_features = self.features[4:] 116 | 117 | self.feautre_8x = self.features[4:7] 118 | self.feature_16x = self.features[7:14] 119 | self.feature_32x = self.features[14:] 120 | 121 | def forward(self, x, return_mid_level=False): 122 | if return_mid_level: 123 | low_level_feat = self.low_level_features(x) 124 | mid_level_feat = self.feautre_8x(low_level_feat) 125 | x = self.feature_16x(mid_level_feat) 126 | x = self.feature_32x(x) 127 | return x, low_level_feat, mid_level_feat 128 | else: 129 | low_level_feat = self.low_level_features(x) 130 | x = self.high_level_features(low_level_feat) 131 | return x, low_level_feat 132 | 133 | def _load_pretrained_model(self): 134 | pretrain_dict = model_zoo.load_url('http://jeff95.me/models/mobilenet_v2-6a65762b.pth') 135 | model_dict = {} 136 | state_dict = self.state_dict() 137 | for k, v in pretrain_dict.items(): 138 | if k in state_dict: 139 | model_dict[k] = v 140 | state_dict.update(model_dict) 141 | self.load_state_dict(state_dict) 142 | 143 | def _initialize_weights(self): 144 | for m in self.modules(): 145 | if isinstance(m, nn.Conv2d): 146 | torch.nn.init.kaiming_normal_(m.weight) 147 | elif isinstance(m, nn.BatchNorm2d): 148 | m.weight.data.fill_(1) 149 | m.bias.data.zero_() 150 | 151 | if __name__ == "__main__": 152 | input = torch.rand(1, 3, 512, 512) 153 | model = MobileNetV2(output_stride=16, BatchNorm=nn.BatchNorm2d) 154 | output, low_level_feat = model(input) 155 | print(output.size()) 156 | print(low_level_feat.size()) 157 | -------------------------------------------------------------------------------- /networks/deeplab/backbone/resnet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch.nn as nn 3 | import torch.utils.model_zoo as model_zoo 4 | 5 | class Bottleneck(nn.Module): 6 | expansion = 4 7 | 8 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, BatchNorm=None): 9 | super(Bottleneck, self).__init__() 10 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 11 | self.bn1 = BatchNorm(planes) 12 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 13 | dilation=dilation, padding=dilation, bias=False) 14 | self.bn2 = BatchNorm(planes) 15 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 16 | self.bn3 = BatchNorm(planes * 4) 17 | self.relu = nn.ReLU(inplace=True) 18 | self.downsample = downsample 19 | self.stride = stride 20 | self.dilation = dilation 21 | 22 | def forward(self, x): 23 | residual = x 24 | 25 | out = self.conv1(x) 26 | out = self.bn1(out) 27 | out = self.relu(out) 28 | 29 | out = self.conv2(out) 30 | out = self.bn2(out) 31 | out = self.relu(out) 32 | 33 | out = self.conv3(out) 34 | out = self.bn3(out) 35 | 36 | if self.downsample is not None: 37 | residual = self.downsample(x) 38 | 39 | out += residual 40 | out = self.relu(out) 41 | 42 | return out 43 | 44 | class ResNet(nn.Module): 45 | 46 | def __init__(self, block, layers, output_stride, BatchNorm, pretrained=False): 47 | self.inplanes = 64 48 | super(ResNet, self).__init__() 49 | blocks = [1, 2, 4] 50 | if output_stride == 16: 51 | strides 
= [1, 2, 2, 1] 52 | dilations = [1, 1, 1, 2] 53 | elif output_stride == 8: 54 | strides = [1, 2, 1, 1] 55 | dilations = [1, 1, 2, 4] 56 | else: 57 | raise NotImplementedError 58 | 59 | # Modules 60 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 61 | bias=False) 62 | self.bn1 = BatchNorm(64) 63 | self.relu = nn.ReLU(inplace=True) 64 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 65 | 66 | self.layer1 = self._make_layer(block, 64, layers[0], stride=strides[0], dilation=dilations[0], BatchNorm=BatchNorm) 67 | self.layer2 = self._make_layer(block, 128, layers[1], stride=strides[1], dilation=dilations[1], BatchNorm=BatchNorm) 68 | self.layer3 = self._make_layer(block, 256, layers[2], stride=strides[2], dilation=dilations[2], BatchNorm=BatchNorm) 69 | self.layer4 = self._make_MG_unit(block, 512, blocks=blocks, stride=strides[3], dilation=dilations[3], BatchNorm=BatchNorm) 70 | self._init_weight() 71 | 72 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 73 | downsample = None 74 | if stride != 1 or self.inplanes != planes * block.expansion: 75 | downsample = nn.Sequential( 76 | nn.Conv2d(self.inplanes, planes * block.expansion, 77 | kernel_size=1, stride=stride, bias=False), 78 | BatchNorm(planes * block.expansion), 79 | ) 80 | 81 | layers = [] 82 | layers.append(block(self.inplanes, planes, stride, dilation, downsample, BatchNorm)) 83 | self.inplanes = planes * block.expansion 84 | for i in range(1, blocks): 85 | layers.append(block(self.inplanes, planes, dilation=dilation, BatchNorm=BatchNorm)) 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def _make_MG_unit(self, block, planes, blocks, stride=1, dilation=1, BatchNorm=None): 90 | downsample = None 91 | if stride != 1 or self.inplanes != planes * block.expansion: 92 | downsample = nn.Sequential( 93 | nn.Conv2d(self.inplanes, planes * block.expansion, 94 | kernel_size=1, stride=stride, bias=False), 95 | BatchNorm(planes * block.expansion), 96 | ) 97 | 98 | layers = [] 99 | layers.append(block(self.inplanes, planes, stride, dilation=blocks[0]*dilation, 100 | downsample=downsample, BatchNorm=BatchNorm)) 101 | self.inplanes = planes * block.expansion 102 | for i in range(1, len(blocks)): 103 | layers.append(block(self.inplanes, planes, stride=1, 104 | dilation=blocks[i]*dilation, BatchNorm=BatchNorm)) 105 | 106 | return nn.Sequential(*layers) 107 | 108 | def forward(self, input, return_mid_level=False): 109 | x = self.conv1(input) 110 | x = self.bn1(x) 111 | x = self.relu(x) 112 | x = self.maxpool(x) 113 | 114 | x = self.layer1(x) 115 | low_level_feat = x 116 | x = self.layer2(x) 117 | mid_level_feat = x 118 | x = self.layer3(x) 119 | x = self.layer4(x) 120 | if return_mid_level: 121 | return x, low_level_feat, mid_level_feat 122 | else: 123 | return x, low_level_feat 124 | def _init_weight(self): 125 | for m in self.modules(): 126 | if isinstance(m, nn.Conv2d): 127 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 128 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 129 | elif isinstance(m, nn.BatchNorm2d): 130 | m.weight.data.fill_(1) 131 | m.bias.data.zero_() 132 | 133 | def _load_pretrained_model(self): 134 | pretrain_dict = model_zoo.load_url('https://download.pytorch.org/models/resnet101-5d3b4d8f.pth') 135 | model_dict = {} 136 | state_dict = self.state_dict() 137 | for k, v in pretrain_dict.items(): 138 | if k in state_dict: 139 | model_dict[k] = v 140 | state_dict.update(model_dict) 141 | self.load_state_dict(state_dict) 142 | 143 | def ResNet101(output_stride, BatchNorm, pretrained=True): 144 | """Constructs a ResNet-101 model. 145 | Args: 146 | pretrained (bool): If True, returns a model pre-trained on ImageNet 147 | """ 148 | model = ResNet(Bottleneck, [3, 4, 23, 3], output_stride, BatchNorm, pretrained=pretrained) 149 | return model 150 | 151 | if __name__ == "__main__": 152 | import torch 153 | model = ResNet101(BatchNorm=nn.BatchNorm2d, pretrained=True, output_stride=8) 154 | input = torch.rand(1, 3, 512, 512) 155 | output, low_level_feat = model(input) 156 | print(output.size()) 157 | print(low_level_feat.size()) 158 | -------------------------------------------------------------------------------- /networks/deeplab/decoder.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | class Decoder(nn.Module): 7 | def __init__(self, backbone, BatchNorm): 8 | super(Decoder, self).__init__() 9 | if backbone == 'resnet': 10 | low_level_inplanes = 256 11 | elif backbone == 'mobilenet': 12 | low_level_inplanes = 24 13 | else: 14 | raise NotImplementedError 15 | 16 | self.conv1 = nn.Conv2d(low_level_inplanes, 48, 1, bias=False) 17 | self.bn1 = BatchNorm(48) 18 | self.relu = nn.ReLU(inplace=True) 19 | 20 | self.last_conv = nn.Sequential(nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1, bias=False), 21 | BatchNorm(256), 22 | nn.ReLU(inplace=True), 23 | nn.Sequential(), 24 | nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False), 25 | BatchNorm(256), 26 | nn.ReLU(inplace=True), 27 | nn.Sequential()) 28 | 29 | self._init_weight() 30 | 31 | 32 | def forward(self, x, low_level_feat): 33 | low_level_feat = self.conv1(low_level_feat) 34 | low_level_feat = self.bn1(low_level_feat) 35 | low_level_feat = self.relu(low_level_feat) 36 | 37 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bilinear', align_corners=True) 38 | x = torch.cat((x, low_level_feat), dim=1) 39 | x = self.last_conv(x) 40 | 41 | return x 42 | 43 | def _init_weight(self): 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d): 46 | torch.nn.init.kaiming_normal_(m.weight) 47 | elif isinstance(m, nn.BatchNorm2d): 48 | m.weight.data.fill_(1) 49 | m.bias.data.zero_() 50 | 51 | def build_decoder(backbone, BatchNorm): 52 | return Decoder(backbone, BatchNorm) 53 | -------------------------------------------------------------------------------- /networks/deeplab/deeplab.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.deeplab.aspp import build_aspp 5 | from networks.deeplab.decoder import build_decoder 6 | from networks.deeplab.backbone import build_backbone 7 | from networks.layers.normalization import FrozenBatchNorm2d 8 | 9 | class DeepLab(nn.Module): 10 | def __init__(self, 11 | backbone='resnet', 12 | output_stride=16, 13 | freeze_bn=True): 14 | super(DeepLab, self).__init__() 15 | 16 | if 
freeze_bn == True: 17 | print("Use frozen BN in DeepLab!") 18 | BatchNorm = FrozenBatchNorm2d 19 | else: 20 | BatchNorm = nn.BatchNorm2d 21 | 22 | self.backbone = build_backbone(backbone, output_stride, BatchNorm) 23 | self.aspp = build_aspp(backbone, output_stride, BatchNorm) 24 | self.decoder = build_decoder(backbone, BatchNorm) 25 | 26 | 27 | def forward(self, input, return_aspp=False): 28 | if return_aspp: 29 | x, low_level_feat, mid_level_feat = self.backbone(input, True) 30 | else: 31 | x, low_level_feat = self.backbone(input) 32 | aspp_x = self.aspp(x) 33 | x = self.decoder(aspp_x, low_level_feat) 34 | 35 | if return_aspp: 36 | return x, aspp_x, low_level_feat, mid_level_feat 37 | else: 38 | return x, low_level_feat 39 | 40 | 41 | def get_1x_lr_params(self): 42 | modules = [self.backbone] 43 | for i in range(len(modules)): 44 | for m in modules[i].named_modules(): 45 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 46 | for p in m[1].parameters(): 47 | if p.requires_grad: 48 | yield p 49 | 50 | def get_10x_lr_params(self): 51 | modules = [self.aspp, self.decoder] 52 | for i in range(len(modules)): 53 | for m in modules[i].named_modules(): 54 | if isinstance(m[1], nn.Conv2d) or isinstance(m[1], nn.BatchNorm2d): 55 | for p in m[1].parameters(): 56 | if p.requires_grad: 57 | yield p 58 | 59 | 60 | if __name__ == "__main__": 61 | model = DeepLab(backbone='resnet', output_stride=16) 62 | model.eval() 63 | input = torch.rand(2, 3, 513, 513) 64 | output = model(input) 65 | print(output.size()) 66 | 67 | 68 | -------------------------------------------------------------------------------- /networks/engine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/engine/__init__.py -------------------------------------------------------------------------------- /networks/engine/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/engine/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/engine/__pycache__/__init__.cpython-36.pyc.140354633045472: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/engine/__pycache__/__init__.cpython-36.pyc.140354633045472 -------------------------------------------------------------------------------- /networks/engine/__pycache__/eval_manager_mm.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/engine/__pycache__/eval_manager_mm.cpython-36.pyc -------------------------------------------------------------------------------- /networks/engine/__pycache__/eval_manager_mm_rpa.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/engine/__pycache__/eval_manager_mm_rpa.cpython-36.pyc -------------------------------------------------------------------------------- /networks/engine/eval_manager.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import time 4 | import datetime as datetime 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | import numpy as np 11 | from dataloaders.datasets import YOUTUBE_VOS_Test,DAVIS_Test 12 | import dataloaders.custom_transforms as tr 13 | from networks.deeplab.deeplab import DeepLab 14 | from utils.meters import AverageMeter 15 | from utils.image import flip_tensor, save_mask 16 | from utils.checkpoint import load_network 17 | from utils.eval import zip_folder 18 | 19 | class Evaluator(object): 20 | def __init__(self, cfg): 21 | self.gpu = cfg.TEST_GPU_ID 22 | self.cfg = cfg 23 | self.print_log(cfg.__dict__) 24 | print("Use GPU {} for evaluating".format(self.gpu)) 25 | torch.cuda.set_device(self.gpu) 26 | 27 | self.print_log('Build backbone.') 28 | self.feature_extracter = DeepLab( 29 | backbone=cfg.MODEL_BACKBONE, 30 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu) 31 | 32 | self.print_log('Build VOS model.') 33 | RPCM = importlib.import_module(cfg.MODEL_MODULE) 34 | self.model = RPCM.get_module()( 35 | cfg, 36 | self.feature_extracter).cuda(self.gpu) 37 | 38 | self.process_pretrained_model() 39 | 40 | self.prepare_dataset() 41 | 42 | def process_pretrained_model(self): 43 | cfg = self.cfg 44 | if cfg.TEST_CKPT_PATH == 'test': 45 | self.ckpt = 'test' 46 | self.print_log('Test evaluation.') 47 | return 48 | if cfg.TEST_CKPT_PATH is None: 49 | if cfg.TEST_CKPT_STEP is not None: 50 | ckpt = str(cfg.TEST_CKPT_STEP) 51 | else: 52 | ckpts = os.listdir(cfg.DIR_CKPT) 53 | if len(ckpts) > 0: 54 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts)) 55 | ckpt = np.sort(ckpts)[-1] 56 | else: 57 | self.print_log('No checkpoint in {}.'.format(cfg.DIR_CKPT)) 58 | exit() 59 | self.ckpt = ckpt 60 | cfg.TEST_CKPT_PATH = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % ckpt) 61 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 62 | if len(removed_dict) > 0: 63 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 64 | self.print_log('Load latest checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 65 | else: 66 | self.ckpt = 'unknown' 67 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 68 | if len(removed_dict) > 0: 69 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 70 | self.print_log('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 71 | 72 | def prepare_dataset(self): 73 | cfg = self.cfg 74 | self.print_log('Process dataset...') 75 | eval_transforms = transforms.Compose([ 76 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, cfg.TEST_FLIP, cfg.TEST_MULTISCALE), 77 | tr.MultiToTensor()]) 78 | 79 | eval_name = '{}_{}_ckpt_{}'.format(cfg.TEST_DATASET, cfg.EXP_NAME, self.ckpt) 80 | 81 | if cfg.TEST_FLIP: 82 | eval_name += '_flip' 83 | if len(cfg.TEST_MULTISCALE) > 1: 84 | eval_name += '_ms' 85 | for scale in cfg.TEST_MULTISCALE: 86 | eval_name +="_" 87 | eval_name += str(scale) 88 | 89 | eval_name += "_ORG" 90 | 91 | if cfg.TEST_DATASET == 'youtubevos19': 92 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 93 | self.dataset = YOUTUBE_VOS_Test( 94 | root=cfg.DIR_YTB_EVAL19, 95 | transform=eval_transforms, 96 | result_root=self.result_root) 97 | 98 | elif cfg.TEST_DATASET == 'youtubevos18': 99 | 
self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 100 | self.dataset = YOUTUBE_VOS_Test( 101 | root=cfg.DIR_YTB_EVAL18, 102 | transform=eval_transforms, 103 | result_root=self.result_root) 104 | else: 105 | print('Unknown dataset!') 106 | exit() 107 | 108 | print('Eval {} on {}:'.format(cfg.EXP_NAME, cfg.TEST_DATASET)) 109 | self.source_folder = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 110 | self.zip_dir = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, '{}.zip'.format(eval_name)) 111 | if not os.path.exists(self.result_root): 112 | os.makedirs(self.result_root) 113 | self.print_log('Done!') 114 | 115 | def evaluating(self): 116 | cfg = self.cfg 117 | self.model.eval() 118 | video_num = 0 119 | total_time = 0 120 | total_frame = 0 121 | total_sfps = 0 122 | total_video_num = len(self.dataset) 123 | PlaceHolder=[] 124 | for i in range(cfg.BLOCK_NUM): 125 | PlaceHolder.append(None) 126 | 127 | for seq_idx, seq_dataset in enumerate(self.dataset): 128 | video_num += 1 129 | seq_name = seq_dataset.seq_name 130 | print('Prcessing Seq {} [{}/{}]:'.format(seq_name, video_num, total_video_num)) 131 | torch.cuda.empty_cache() 132 | 133 | seq_dataloader=DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=cfg.TEST_WORKERS, pin_memory=True) 134 | 135 | seq_total_time = 0 136 | seq_total_frame = 0 137 | ref_embeddings = [] 138 | ref_masks = [] 139 | prev_embedding = [] 140 | prev_mask = [] 141 | memory_prev_all_list=[] 142 | memory_cur_all_list=[] 143 | memory_prev_list=[] 144 | memory_cur_list=[] 145 | with torch.no_grad(): 146 | for frame_idx, samples in enumerate(seq_dataloader): 147 | time_start = time.time() 148 | all_preds = [] 149 | join_label = None 150 | 151 | if frame_idx==0: 152 | for aug_idx in range(len(samples)): 153 | memory_prev_all_list.append([PlaceHolder]) 154 | 155 | else: 156 | memory_prev_all_list=memory_cur_all_list 157 | 158 | memory_cur_all_list=[] 159 | 160 | for aug_idx in range(len(samples)): 161 | if len(ref_embeddings) <= aug_idx: 162 | ref_embeddings.append([]) 163 | ref_masks.append([]) 164 | prev_embedding.append(None) 165 | prev_mask.append(None) 166 | 167 | sample = samples[aug_idx] 168 | ref_emb = ref_embeddings[aug_idx] 169 | ref_m = ref_masks[aug_idx] 170 | prev_emb = prev_embedding[aug_idx] 171 | prev_m = prev_mask[aug_idx] 172 | 173 | current_img = sample['current_img'] 174 | if 'current_label' in sample.keys(): 175 | current_label = sample['current_label'].cuda(self.gpu) 176 | else: 177 | current_label = None 178 | 179 | obj_num = sample['meta']['obj_num'] 180 | imgname = sample['meta']['current_name'] 181 | ori_height = sample['meta']['height'] 182 | ori_width = sample['meta']['width'] 183 | current_img = current_img.cuda(self.gpu) 184 | obj_num = obj_num.cuda(self.gpu) 185 | bs, _, h, w = current_img.size() 186 | 187 | all_pred, current_embedding,memory_cur_list = self.model.forward_for_eval(memory_prev_all_list[aug_idx], ref_emb, 188 | ref_m, prev_emb, prev_m, 189 | current_img, gt_ids=obj_num, 190 | pred_size=[ori_height,ori_width]) 191 | memory_cur_all_list.append(memory_cur_list) 192 | 193 | if frame_idx == 0: 194 | if current_label is None: 195 | print("No first frame label in Seq {}.".format(seq_name)) 196 | ref_embeddings[aug_idx].append(current_embedding) 197 | ref_masks[aug_idx].append(current_label) 198 | 199 | prev_embedding[aug_idx] = current_embedding 200 | prev_mask[aug_idx] = current_label 201 | else: 202 | if sample['meta']['flip']: 203 | all_pred = 
flip_tensor(all_pred, 3) 204 | # In YouTube-VOS, not all the objects appear in the first frame for the first time. Thus, we 205 | # have to introduce new labels for new objects, if necessary. 206 | if not sample['meta']['flip'] and not(current_label is None) and join_label is None: 207 | join_label = current_label 208 | all_preds.append(all_pred) 209 | if current_label is not None: 210 | ref_embeddings[aug_idx].append(current_embedding) 211 | prev_embedding[aug_idx] = current_embedding 212 | 213 | if frame_idx > 0: 214 | all_preds = torch.cat(all_preds, dim=0) 215 | all_preds = torch.mean(all_preds, dim=0) 216 | pred_label = torch.argmax(all_preds, dim=0) 217 | if join_label is not None: 218 | join_label = join_label.squeeze(0).squeeze(0) 219 | keep = (join_label == 0).long() 220 | pred_label = pred_label * keep + join_label * (1 - keep) 221 | pred_label = pred_label 222 | current_label = pred_label.view(1, 1, ori_height, ori_width) 223 | flip_pred_label = flip_tensor(pred_label, 1) 224 | flip_current_label = flip_pred_label.view(1, 1, ori_height, ori_width) 225 | 226 | for aug_idx in range(len(samples)): 227 | if join_label is not None: 228 | if samples[aug_idx]['meta']['flip']: 229 | ref_masks[aug_idx].append(flip_current_label) 230 | else: 231 | ref_masks[aug_idx].append(current_label) 232 | if samples[aug_idx]['meta']['flip']: 233 | prev_mask[aug_idx] = flip_current_label 234 | else: 235 | prev_mask[aug_idx] = current_label 236 | 237 | one_frametime = time.time() - time_start 238 | seq_total_time += one_frametime 239 | seq_total_frame += 1 240 | obj_num = obj_num[0].item() 241 | print('Frame: {}, Obj Num: {}, Time: {}'.format(imgname[0], obj_num, one_frametime)) 242 | # Save result 243 | save_mask(pred_label, os.path.join(self.result_root, seq_name, imgname[0].split('.')[0]+'.png')) 244 | else: 245 | one_frametime = time.time() - time_start 246 | seq_total_time += one_frametime 247 | print('Ref Frame: {}, Time: {}'.format(imgname[0], one_frametime)) 248 | 249 | del(ref_embeddings) 250 | del(ref_masks) 251 | del(prev_embedding) 252 | del(prev_mask) 253 | del(seq_dataset) 254 | del(seq_dataloader) 255 | 256 | seq_avg_time_per_frame = seq_total_time / seq_total_frame 257 | total_time += seq_total_time 258 | total_frame += seq_total_frame 259 | total_avg_time_per_frame = total_time / total_frame 260 | total_sfps += seq_avg_time_per_frame 261 | avg_sfps = total_sfps / (seq_idx + 1) 262 | print("Seq {} FPS: {}, Total FPS: {}, FPS per Seq: {}".format(seq_name, 1./seq_avg_time_per_frame, 1./total_avg_time_per_frame, 1./avg_sfps)) 263 | 264 | zip_folder(self.source_folder, self.zip_dir) 265 | self.print_log('Save result to {}.'.format(self.zip_dir)) 266 | 267 | 268 | def print_log(self, string): 269 | print(string) 270 | 271 | 272 | 273 | 274 | 275 | -------------------------------------------------------------------------------- /networks/engine/eval_manager_rpa.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import time 4 | import datetime as datetime 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torchvision import transforms 10 | import numpy as np 11 | from dataloaders.datasets import YOUTUBE_VOS_Test,DAVIS_Test 12 | import dataloaders.custom_transforms as tr 13 | from networks.deeplab.deeplab import DeepLab 14 | from utils.meters import AverageMeter 15 | from utils.image import flip_tensor, save_mask 16 | from utils.checkpoint import 
load_network 17 | from utils.eval import zip_folder 18 | from networks.layers.shannon_entropy import cal_shannon_entropy 19 | import math 20 | 21 | class Evaluator(object): 22 | def __init__(self, cfg): 23 | 24 | self.mem_every = cfg.MEM_EVERY 25 | self.unc_ratio = cfg.UNC_RATIO 26 | 27 | self.gpu = cfg.TEST_GPU_ID 28 | self.cfg = cfg 29 | self.print_log(cfg.__dict__) 30 | print("Use GPU {} for evaluating".format(self.gpu)) 31 | torch.cuda.set_device(self.gpu) 32 | 33 | self.print_log('Build backbone.') 34 | self.feature_extracter = DeepLab( 35 | backbone=cfg.MODEL_BACKBONE, 36 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu) 37 | 38 | self.print_log('Build VOS model.') 39 | RPCM = importlib.import_module(cfg.MODEL_MODULE) 40 | self.model = RPCM.get_module()( 41 | cfg, 42 | self.feature_extracter).cuda(self.gpu) 43 | 44 | self.process_pretrained_model() 45 | 46 | self.prepare_dataset() 47 | 48 | def process_pretrained_model(self): 49 | cfg = self.cfg 50 | if cfg.TEST_CKPT_PATH == 'test': 51 | self.ckpt = 'test' 52 | self.print_log('Test evaluation.') 53 | return 54 | if cfg.TEST_CKPT_PATH is None: 55 | if cfg.TEST_CKPT_STEP is not None: 56 | ckpt = str(cfg.TEST_CKPT_STEP) 57 | else: 58 | ckpts = os.listdir(cfg.DIR_CKPT) 59 | if len(ckpts) > 0: 60 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts)) 61 | ckpt = np.sort(ckpts)[-1] 62 | else: 63 | self.print_log('No checkpoint in {}.'.format(cfg.DIR_CKPT)) 64 | exit() 65 | self.ckpt = ckpt 66 | cfg.TEST_CKPT_PATH = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % ckpt) 67 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 68 | if len(removed_dict) > 0: 69 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 70 | self.print_log('Load latest checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 71 | else: 72 | self.ckpt = 'unknown' 73 | self.model, removed_dict = load_network(self.model, cfg.TEST_CKPT_PATH, self.gpu) 74 | if len(removed_dict) > 0: 75 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 76 | self.print_log('Load checkpoint from {}'.format(cfg.TEST_CKPT_PATH)) 77 | 78 | def prepare_dataset(self): 79 | cfg = self.cfg 80 | self.print_log('Process dataset...') 81 | eval_transforms = transforms.Compose([ 82 | tr.MultiRestrictSize(cfg.TEST_MIN_SIZE, cfg.TEST_MAX_SIZE, cfg.TEST_FLIP, cfg.TEST_MULTISCALE), 83 | tr.MultiToTensor()]) 84 | 85 | eval_name = '{}_{}_ckpt_{}'.format(cfg.TEST_DATASET, cfg.EXP_NAME, self.ckpt) 86 | if cfg.TEST_FLIP: 87 | eval_name += '_flip' 88 | if len(cfg.TEST_MULTISCALE) > 1: 89 | eval_name += '_ms' 90 | for scale in cfg.TEST_MULTISCALE: 91 | eval_name +="_" 92 | eval_name +=str(scale) 93 | 94 | eval_name += "_mem_"+str(self.mem_every)+"_unc_"+str(self.unc_ratio)+"_res_"+str(cfg.TEST_MAX_SIZE)+"_wRPA" 95 | 96 | if cfg.TEST_DATASET == 'youtubevos19': 97 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 98 | self.dataset = YOUTUBE_VOS_Test( 99 | root=cfg.DIR_YTB_EVAL19, 100 | transform=eval_transforms, 101 | result_root=self.result_root) 102 | 103 | elif cfg.TEST_DATASET == 'youtubevos18': 104 | self.result_root = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 105 | self.dataset = YOUTUBE_VOS_Test( 106 | root=cfg.DIR_YTB_EVAL18, 107 | transform=eval_transforms, 108 | result_root=self.result_root) 109 | else: 110 | print('Unknown dataset!') 111 | exit() 112 | 113 | print('Eval {} on {}:'.format(cfg.EXP_NAME, cfg.TEST_DATASET)) 114 | 
self.source_folder = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, eval_name, 'Annotations') 115 | self.zip_dir = os.path.join(cfg.DIR_EVALUATION, cfg.TEST_DATASET, '{}.zip'.format(eval_name)) 116 | if not os.path.exists(self.result_root): 117 | os.makedirs(self.result_root) 118 | self.print_log('Done!') 119 | 120 | def evaluating(self): 121 | cfg = self.cfg 122 | self.model.eval() 123 | video_num = 0 124 | total_time = 0 125 | total_frame = 0 126 | total_sfps = 0 127 | total_video_num = len(self.dataset) 128 | PlaceHolder=[] 129 | for i in range(cfg.BLOCK_NUM): 130 | PlaceHolder.append(None) 131 | 132 | for seq_idx, seq_dataset in enumerate(self.dataset): 133 | video_num += 1 134 | seq_name = seq_dataset.seq_name 135 | 136 | print('Prcessing Seq {} [{}/{}]:'.format(seq_name, video_num, total_video_num)) 137 | 138 | torch.cuda.empty_cache() 139 | 140 | seq_dataloader=DataLoader(seq_dataset, batch_size=1, shuffle=False, num_workers=cfg.TEST_WORKERS, pin_memory=True) 141 | 142 | seq_total_time = 0 143 | seq_total_frame = 0 144 | ref_embeddings = [] 145 | ref_masks = [] 146 | prev_embedding = [] 147 | prev_mask = [] 148 | ref_mask_confident = [] 149 | memory_prev_all_list=[] 150 | memory_cur_all_list=[] 151 | memory_prev_list=[] 152 | memory_cur_list=[] 153 | label_all_list=[] 154 | 155 | with torch.no_grad(): 156 | for frame_idx, samples in enumerate(seq_dataloader): 157 | 158 | time_start = time.time() 159 | all_preds = [] 160 | 161 | join_label = None 162 | UPDATE=False 163 | 164 | 165 | if frame_idx==0: 166 | for aug_idx in range(len(samples)): 167 | memory_prev_all_list.append([PlaceHolder]) 168 | else: 169 | memory_prev_all_list=memory_cur_all_list 170 | 171 | memory_cur_all_list=[] 172 | for aug_idx in range(len(samples)): 173 | if len(ref_embeddings) <= aug_idx: 174 | ref_embeddings.append([]) 175 | ref_masks.append([]) 176 | prev_embedding.append(None) 177 | prev_mask.append(None) 178 | ref_mask_confident.append([]) 179 | 180 | sample = samples[aug_idx] 181 | ref_emb = ref_embeddings[aug_idx] 182 | 183 | ## use confident mask for correlation 184 | ref_m = ref_mask_confident[aug_idx] 185 | 186 | prev_emb = prev_embedding[aug_idx] 187 | prev_m = prev_mask[aug_idx] 188 | 189 | 190 | current_img = sample['current_img'] 191 | if 'current_label' in sample.keys(): 192 | current_label = sample['current_label'].cuda(self.gpu) 193 | else: 194 | current_label = None 195 | 196 | obj_list = sample['meta']['obj_list'] 197 | obj_num = sample['meta']['obj_num'] 198 | imgname = sample['meta']['current_name'] 199 | ori_height = sample['meta']['height'] 200 | ori_width = sample['meta']['width'] 201 | current_img = current_img.cuda(self.gpu) 202 | obj_num = obj_num.cuda(self.gpu) 203 | bs, _, h, w = current_img.size() 204 | 205 | all_pred, current_embedding,memory_cur_list = self.model.forward_for_eval(memory_prev_all_list[aug_idx], ref_emb, 206 | ref_m, prev_emb, prev_m, 207 | current_img, gt_ids=obj_num, 208 | pred_size=[ori_height,ori_width]) 209 | memory_cur_all_list.append(memory_cur_list) 210 | 211 | # delete the label that hasn't existed in the GT label for YTB-VOS 212 | all_pred_remake = [] 213 | all_pred_exist = [] 214 | if all_pred!=None: 215 | all_pred_split = all_pred.split(all_pred.size()[1],dim=1)[0] 216 | 217 | for i in range(all_pred.size()[1]): 218 | if i not in label_all_list: 219 | all_pred_remake.append(torch.zeros_like(all_pred_split[0][i]).unsqueeze(0)) 220 | else: 221 | all_pred_remake.append(all_pred_split[0][i].unsqueeze(0)) 222 | 
all_pred_exist.append(all_pred_split[0][i].unsqueeze(0)) 223 | all_pred = torch.cat(all_pred_remake,dim=0).unsqueeze(0) 224 | all_pred_exist = torch.cat(all_pred_exist,dim=0).unsqueeze(0) 225 | 226 | 227 | if 'current_label' in sample.keys(): 228 | label_cur_list = np.unique(sample['current_label'].cpu().detach().numpy()).tolist() 229 | for i in label_cur_list: 230 | if i not in label_all_list: 231 | label_all_list.append(i) 232 | 233 | if frame_idx == 0: 234 | if current_label is None: 235 | print("No first frame label in Seq {}.".format(seq_name)) 236 | ref_embeddings[aug_idx].append(current_embedding) 237 | ref_masks[aug_idx].append(current_label) 238 | ref_mask_confident[aug_idx].append(current_label) 239 | 240 | prev_embedding[aug_idx] = current_embedding 241 | prev_mask[aug_idx] = current_label 242 | 243 | else: 244 | if sample['meta']['flip']: 245 | all_pred = flip_tensor(all_pred, 3) 246 | 247 | # In YouTube-VOS, not all the objects appear in the first frame for the first time. Thus, we 248 | # have to introduce new labels for new objects, if necessary. 249 | if not sample['meta']['flip'] and not(current_label is None) and join_label is None: # gt exists here 250 | join_label = current_label 251 | all_preds.append(all_pred) 252 | 253 | all_pred_org = all_pred 254 | current_label_0 = None 255 | 256 | if current_label is not None: 257 | ref_embeddings[aug_idx].append(current_embedding) 258 | 259 | else: 260 | all_preds_0 = torch.cat(all_preds, dim=0) 261 | all_preds_0 = torch.mean(all_preds_0, dim=0) 262 | pred_label_0 = torch.argmax(all_preds_0, dim=0) 263 | current_label_0 = pred_label_0.view(1, 1, ori_height, ori_width) 264 | 265 | # uncertainty region filter 266 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist) 267 | 268 | # we set mem_every == -1 to indicate we don't use extra confident candidate pool 269 | if self.mem_every>-1 and frame_idx % self.mem_every==0 and frame_idx!=0 and current_embedding!=None and current_label_0!=None: 270 | ref_embeddings[aug_idx].append(current_embedding) 271 | ref_masks[aug_idx].append(current_label_0) 272 | UPDATE=True 273 | 274 | 275 | prev_embedding[aug_idx] = current_embedding 276 | 277 | if frame_idx > 0: 278 | all_preds = torch.cat(all_preds, dim=0) 279 | all_preds = torch.mean(all_preds, dim=0) 280 | pred_label = torch.argmax(all_preds, dim=0) 281 | if join_label is not None: 282 | join_label = join_label.squeeze(0).squeeze(0) 283 | keep = (join_label == 0).long() 284 | pred_label = pred_label * keep + join_label * (1 - keep) 285 | pred_label = pred_label 286 | current_label = pred_label.view(1, 1, ori_height, ori_width) 287 | if samples[aug_idx]['meta']['flip']: 288 | flip_pred_label = flip_tensor(pred_label, 1) 289 | flip_current_label = flip_pred_label.view(1, 1, ori_height, ori_width) 290 | 291 | for aug_idx in range(len(samples)): 292 | if join_label is not None: 293 | if samples[aug_idx]['meta']['flip']: 294 | ref_masks[aug_idx].append(flip_current_label) 295 | ref_mask_confident[aug_idx].append(flip_current_label) 296 | else: 297 | ref_masks[aug_idx].append(current_label) 298 | 299 | uncertainty_org,uncertainty_norm = cal_shannon_entropy(all_pred_exist) 300 | join_label = join_label.squeeze(0).squeeze(0) 301 | keep = (join_label == 0).long() 302 | join_uncertainty_map = (join_label <0).long() 303 | uncertainty_org = uncertainty_org * keep + join_uncertainty_map * (1 - keep) 304 | 305 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long() 306 | 307 | # we use 125 to represent the filtered patches 308 | 
pred_label_c = pred_label* (1 - uncertainty_region) + (125) * uncertainty_region 309 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width) 310 | 311 | ref_mask_confident[aug_idx].append(pred_label_c) 312 | 313 | if samples[aug_idx]['meta']['flip']: 314 | prev_mask[aug_idx] = flip_current_label 315 | else: 316 | prev_mask[aug_idx] = current_label 317 | 318 | if UPDATE: 319 | if self.mem_every>-1 and frame_idx%self.mem_every==0 and frame_idx!=0 and current_embedding!=None and current_label_0!=None : 320 | uncertainty_region = (uncertainty_org>self.unc_ratio ).long() 321 | pred_label_c = pred_label* (1 - uncertainty_region) + (125) * uncertainty_region 322 | pred_label_c = pred_label_c.view(1, 1, ori_height, ori_width) 323 | ref_mask_confident[aug_idx].append(pred_label_c) 324 | 325 | one_frametime = time.time() - time_start 326 | seq_total_time += one_frametime 327 | seq_total_frame += 1 328 | obj_num = obj_num[0].item() 329 | print('Frame: {}, Obj Num: {}, Time: {}'.format(imgname[0], obj_num, one_frametime)) 330 | # Save result 331 | save_mask(pred_label, os.path.join(self.result_root, seq_name, imgname[0].split('.')[0]+'.png')) 332 | 333 | else: 334 | one_frametime = time.time() - time_start 335 | seq_total_time += one_frametime 336 | print('Ref Frame: {}, Time: {}'.format(imgname[0], one_frametime)) 337 | 338 | del(ref_embeddings) 339 | del(ref_masks) 340 | del(prev_embedding) 341 | del(prev_mask) 342 | del(seq_dataset) 343 | del(seq_dataloader) 344 | del(memory_cur_all_list) 345 | 346 | 347 | seq_avg_time_per_frame = seq_total_time / seq_total_frame 348 | total_time += seq_total_time 349 | total_frame += seq_total_frame 350 | total_avg_time_per_frame = total_time / total_frame 351 | total_sfps += seq_avg_time_per_frame 352 | avg_sfps = total_sfps / (seq_idx + 1) 353 | print("Seq {} FPS: {}, Total FPS: {}, FPS per Seq: {}".format(seq_name, 1./seq_avg_time_per_frame, 1./total_avg_time_per_frame, 1./avg_sfps)) 354 | 355 | zip_folder(self.source_folder, self.zip_dir) 356 | self.print_log('Save result to {}.'.format(self.zip_dir)) 357 | 358 | 359 | def print_log(self, string): 360 | print(string) 361 | 362 | 363 | 364 | 365 | 366 | -------------------------------------------------------------------------------- /networks/engine/train_manager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import importlib 3 | import time 4 | import datetime as datetime 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | import torch.distributed as dist 9 | from torch.utils.data import DataLoader 10 | from torchvision import transforms 11 | import numpy as np 12 | from dataloaders.datasets import DAVIS2017_Train, YOUTUBE_VOS_Train, TEST 13 | import dataloaders.custom_transforms as tr 14 | from networks.deeplab.deeplab import DeepLab 15 | from utils.meters import AverageMeter 16 | from utils.image import label2colormap, masked_image, save_image 17 | from utils.checkpoint import load_network_and_optimizer, load_network, save_network 18 | from utils.learning import adjust_learning_rate, get_trainable_params 19 | from utils.metric import pytorch_iou 20 | #torch.backends.cudnn.enabled = True 21 | #torch.backends.cudnn.benchmark = True 22 | class Trainer(object): 23 | def __init__(self , rank, cfg): 24 | self.gpu = rank + cfg.DIST_START_GPU 25 | self.rank = rank 26 | self.cfg = cfg 27 | self.print_log(cfg.__dict__) 28 | print("Use GPU {} for training".format(self.gpu)) 29 | torch.cuda.set_device(self.gpu) 30 | 31 | 
self.print_log('Build backbone.') 32 | self.feature_extracter = DeepLab( 33 | backbone=cfg.MODEL_BACKBONE, 34 | freeze_bn=cfg.MODEL_FREEZE_BN).cuda(self.gpu) 35 | 36 | if cfg.MODEL_FREEZE_BACKBONE: 37 | for param in self.feature_extracter.parameters(): 38 | param.requires_grad = False 39 | 40 | self.print_log('Build VOS model.') 41 | RPCM = importlib.import_module(cfg.MODEL_MODULE) 42 | 43 | self.model = RPCM.get_module()( 44 | cfg, 45 | self.feature_extracter).cuda(self.gpu) 46 | 47 | if cfg.DIST_ENABLE: 48 | dist.init_process_group( 49 | backend=cfg.DIST_BACKEND, 50 | init_method=cfg.DIST_URL, 51 | world_size=cfg.TRAIN_GPUS, 52 | rank=rank, 53 | timeout=datetime.timedelta(seconds=300)) 54 | self.dist_model = torch.nn.parallel.DistributedDataParallel( 55 | self.model, 56 | device_ids=[self.gpu], 57 | find_unused_parameters=True) 58 | else: 59 | self.dist_model = self.model 60 | 61 | self.print_log('Build optimizer.') 62 | trainable_params = get_trainable_params( 63 | model=self.dist_model, 64 | base_lr=cfg.TRAIN_LR, 65 | weight_decay=cfg.TRAIN_WEIGHT_DECAY, 66 | beta_wd=cfg.MODEL_GCT_BETA_WD) 67 | 68 | self.optimizer = optim.SGD( 69 | trainable_params, 70 | lr=cfg.TRAIN_LR, 71 | momentum=cfg.TRAIN_MOMENTUM, 72 | nesterov=True) 73 | 74 | self.prepare_dataset() 75 | self.process_pretrained_model() 76 | 77 | if cfg.TRAIN_TBLOG and self.rank == 0: 78 | from tensorboardX import SummaryWriter 79 | self.tblogger = SummaryWriter(cfg.DIR_TB_LOG) 80 | 81 | def process_pretrained_model(self): 82 | cfg = self.cfg 83 | 84 | self.step = cfg.TRAIN_START_STEP 85 | self.epoch = 0 86 | 87 | if cfg.TRAIN_AUTO_RESUME: 88 | ckpts = os.listdir(cfg.DIR_CKPT) 89 | if len(ckpts) > 0: 90 | ckpts = list(map(lambda x: int(x.split('_')[-1].split('.')[0]), ckpts)) 91 | ckpt = np.sort(ckpts)[-1] 92 | cfg.TRAIN_RESUME = True 93 | cfg.TRAIN_RESUME_CKPT = ckpt 94 | cfg.TRAIN_RESUME_STEP = ckpt + 1 95 | else: 96 | cfg.TRAIN_RESUME = False 97 | 98 | if cfg.TRAIN_RESUME: 99 | resume_ckpt = os.path.join(cfg.DIR_CKPT, 'save_step_%s.pth' % (cfg.TRAIN_RESUME_CKPT)) 100 | 101 | self.model, self.optimizer, removed_dict = load_network_and_optimizer(self.model, self.optimizer, resume_ckpt, self.gpu) 102 | 103 | if len(removed_dict) > 0: 104 | self.print_log('Remove {} from checkpoint.'.format(removed_dict)) 105 | 106 | self.step = cfg.TRAIN_RESUME_STEP 107 | if cfg.TRAIN_TOTAL_STEPS <= self.step: 108 | self.print_log("Your training has finished!") 109 | exit() 110 | self.epoch = int(np.ceil(self.step / len(self.trainloader))) 111 | 112 | self.print_log('Resume from step {}'.format(self.step)) 113 | 114 | elif cfg.PRETRAIN: 115 | if cfg.PRETRAIN_FULL: 116 | self.model, removed_dict = load_network(self.model, cfg.PRETRAIN_MODEL, self.gpu) 117 | if len(removed_dict) > 0: 118 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 119 | self.print_log('Load pretrained VOS model from {}.'.format(cfg.PRETRAIN_MODEL)) 120 | else: 121 | feature_extracter, removed_dict = load_network(self.feature_extracter, cfg.PRETRAIN_MODEL, self.gpu) 122 | if len(removed_dict) > 0: 123 | self.print_log('Remove {} from pretrained model.'.format(removed_dict)) 124 | self.print_log('Load pretrained backbone model from {}.'.format(cfg.PRETRAIN_MODEL)) 125 | 126 | def prepare_dataset(self): 127 | cfg = self.cfg 128 | self.print_log('Process dataset...') 129 | composed_transforms = transforms.Compose([ 130 | tr.RandomScale(cfg.DATA_MIN_SCALE_FACTOR, cfg.DATA_MAX_SCALE_FACTOR, cfg.DATA_SHORT_EDGE_LEN), 131 | 
tr.BalancedRandomCrop(cfg.DATA_RANDOMCROP), 132 | tr.RandomHorizontalFlip(cfg.DATA_RANDOMFLIP), 133 | tr.Resize(cfg.DATA_RANDOMCROP), 134 | tr.ToTensor()]) 135 | 136 | train_datasets = [] 137 | if 'davis2017' in cfg.DATASETS: 138 | train_davis_dataset = DAVIS2017_Train( 139 | root=cfg.DIR_DAVIS, 140 | full_resolution=cfg.TRAIN_DATASET_FULL_RESOLUTION, 141 | transform=composed_transforms, 142 | repeat_time=cfg.DATA_DAVIS_REPEAT, 143 | curr_len=cfg.DATA_CURR_SEQ_LEN, 144 | rand_gap=cfg.DATA_RANDOM_GAP_DAVIS, 145 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ) 146 | train_datasets.append(train_davis_dataset) 147 | 148 | if 'youtubevos' in cfg.DATASETS: 149 | train_ytb_dataset = YOUTUBE_VOS_Train( 150 | root=cfg.DIR_YTB, 151 | transform=composed_transforms, 152 | curr_len=cfg.DATA_CURR_SEQ_LEN, 153 | rand_gap=cfg.DATA_RANDOM_GAP_YTB, 154 | rand_reverse=cfg.DATA_RANDOM_REVERSE_SEQ) 155 | train_datasets.append(train_ytb_dataset) 156 | 157 | if 'test' in cfg.DATASETS: 158 | test_dataset = TEST( 159 | transform=composed_transforms, 160 | curr_len=cfg.DATA_CURR_SEQ_LEN) 161 | train_datasets.append(test_dataset) 162 | 163 | if len(train_datasets) > 1: 164 | train_dataset = torch.utils.data.ConcatDataset(train_datasets) 165 | elif len(train_datasets) == 1: 166 | train_dataset = train_datasets[0] 167 | else: 168 | self.print_log('No dataset!') 169 | exit(0) 170 | 171 | self.train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 172 | self.trainloader = DataLoader( 173 | train_dataset, 174 | batch_size=int(cfg.TRAIN_BATCH_SIZE / cfg.TRAIN_GPUS), 175 | shuffle=False, 176 | num_workers=cfg.DATA_WORKERS, 177 | pin_memory=True, 178 | sampler=self.train_sampler) 179 | 180 | self.print_log('Done!') 181 | 182 | def sequential_training(self): 183 | 184 | cfg = self.cfg 185 | 186 | running_losses = [] 187 | running_ious = [] 188 | for _ in range(cfg.DATA_CURR_SEQ_LEN): 189 | running_losses.append(AverageMeter()) 190 | running_ious.append(AverageMeter()) 191 | batch_time = AverageMeter() 192 | avg_obj = AverageMeter() 193 | 194 | optimizer = self.optimizer 195 | model = self.dist_model 196 | train_sampler = self.train_sampler 197 | trainloader = self.trainloader 198 | step = self.step 199 | epoch = self.epoch 200 | max_itr = cfg.TRAIN_TOTAL_STEPS 201 | 202 | PlaceHolder=[] 203 | for i in range(cfg.BLOCK_NUM): 204 | PlaceHolder.append(None) 205 | 206 | self.print_log('Start training.') 207 | model.train() 208 | while step < cfg.TRAIN_TOTAL_STEPS: 209 | train_sampler.set_epoch(epoch) 210 | epoch += 1 211 | last_time = time.time() 212 | for frame_idx, sample in enumerate(trainloader): 213 | now_lr = adjust_learning_rate( 214 | optimizer=optimizer, 215 | base_lr=cfg.TRAIN_LR, 216 | p=cfg.TRAIN_POWER, 217 | itr=step, 218 | max_itr=max_itr, 219 | warm_up_steps=cfg.TRAIN_WARM_UP_STEPS, 220 | is_cosine_decay=cfg.TRAIN_COSINE_DECAY) 221 | 222 | ref_imgs = sample['ref_img'] # batch_size * 3 * h * w 223 | prev_imgs = sample['prev_img'] 224 | curr_imgs = sample['curr_img'][0] 225 | ref_labels = sample['ref_label'] # batch_size * 1 * h * w 226 | prev_labels = sample['prev_label'] 227 | curr_labels = sample['curr_label'][0] 228 | obj_nums = sample['meta']['obj_num'] 229 | bs, _, h, w = curr_imgs.size() 230 | 231 | ref_labels = ref_labels.cuda(self.gpu) 232 | prev_labels = prev_labels.cuda(self.gpu) 233 | curr_labels = curr_labels.cuda(self.gpu) 234 | obj_nums = obj_nums.cuda(self.gpu) 235 | 236 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0 and cfg.TRAIN_TBLOG: 237 | tf_board = True 238 | else: 
239 | tf_board = False 240 | 241 | # Sequential training 242 | all_boards = [] 243 | curr_imgs = prev_imgs 244 | curr_labels = prev_labels 245 | all_pred = prev_labels.squeeze(1) 246 | optimizer.zero_grad() 247 | memory_cur_list=[] 248 | memory_prev_list=[] 249 | for iii in range(int(cfg.TRAIN_BATCH_SIZE//cfg.TRAIN_GPUS)): 250 | memory_cur_list.append(PlaceHolder) 251 | memory_prev_list.append(PlaceHolder) 252 | 253 | for idx in range(cfg.DATA_CURR_SEQ_LEN): 254 | prev_imgs = curr_imgs 255 | curr_imgs = sample['curr_img'][idx] 256 | inputs = torch.cat((ref_imgs, prev_imgs, curr_imgs), 0).cuda(self.gpu) 257 | if step > cfg.TRAIN_START_SEQ_TRAINING_STEPS: 258 | # Use previous prediction instead of ground-truth mask 259 | prev_labels = all_pred.unsqueeze(1) 260 | else: 261 | # Use previous ground-truth mask 262 | prev_labels = curr_labels 263 | curr_labels = sample['curr_label'][idx].cuda(self.gpu) 264 | 265 | loss, all_pred, boards,memory_cur_list = model( 266 | inputs, 267 | memory_prev_list, 268 | ref_labels, 269 | prev_labels, 270 | curr_labels, 271 | gt_ids=obj_nums, 272 | step=step, 273 | tf_board=tf_board) 274 | 275 | memory_prev_list = memory_cur_list 276 | 277 | iou = pytorch_iou(all_pred.unsqueeze(1), curr_labels, obj_nums) 278 | loss = torch.mean(loss) / cfg.DATA_CURR_SEQ_LEN 279 | loss.backward() 280 | all_boards.append(boards) 281 | running_losses[idx].update(loss.item() * cfg.DATA_CURR_SEQ_LEN) 282 | running_ious[idx].update(iou.item()) 283 | torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.TRAIN_CLIP_GRAD_NORM) 284 | optimizer.step() 285 | batch_time.update(time.time() - last_time) 286 | avg_obj.update(obj_nums.float().mean().item()) 287 | last_time = time.time() 288 | 289 | if step % cfg.TRAIN_TBLOG_STEP == 0 and self.rank == 0: 290 | self.process_log( 291 | ref_imgs, prev_imgs, curr_imgs, 292 | ref_labels, prev_labels, curr_labels, 293 | all_pred, all_boards, running_losses, running_ious, now_lr, step) 294 | 295 | if step % cfg.TRAIN_LOG_STEP == 0 and self.rank == 0: 296 | strs = 'Itr:{}, LR:{:.7f}, Time:{:.3f}, Obj:{:.1f}'.format(step, now_lr, batch_time.avg, avg_obj.avg) 297 | batch_time.reset() 298 | avg_obj.reset() 299 | for idx in range(cfg.DATA_CURR_SEQ_LEN): 300 | strs += ', S{}: L {:.3f}({:.3f}) IoU {:.3f}({:.3f})'.format(idx, running_losses[idx].val, running_losses[idx].avg, 301 | running_ious[idx].val, running_ious[idx].avg) 302 | running_losses[idx].reset() 303 | running_ious[idx].reset() 304 | 305 | self.print_log(strs) 306 | 307 | if step % cfg.TRAIN_SAVE_STEP == 0 and step != 0 and self.rank == 0: 308 | self.print_log('Save CKPT (Step {}).'.format(step)) 309 | save_network(self.model, optimizer, step, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT) 310 | 311 | step += 1 312 | if step > cfg.TRAIN_TOTAL_STEPS: 313 | break 314 | 315 | if self.rank == 0: 316 | self.print_log('Save final CKPT (Step {}).'.format(step - 1)) 317 | save_network(self.model, optimizer, step - 1, cfg.DIR_CKPT, cfg.TRAIN_MAX_KEEP_CKPT) 318 | 319 | def print_log(self, string): 320 | if self.rank == 0: 321 | print(string) 322 | 323 | 324 | def process_log(self, 325 | ref_imgs, prev_imgs, curr_imgs, 326 | ref_labels, prev_labels, curr_labels, 327 | curr_pred, all_boards, running_losses, running_ious, now_lr, step): 328 | cfg = self.cfg 329 | 330 | mean = np.array([[[0.485]], [[0.456]], [[0.406]]]) 331 | sigma = np.array([[[0.229]], [[0.224]], [[0.225]]]) 332 | 333 | show_ref_img, show_prev_img, show_curr_img = [img.cpu().numpy()[0] * sigma + mean for img in [ref_imgs, prev_imgs, curr_imgs]] 334 | 
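# Qualitative logging step (descriptive note, not original code): the labels below are
# squeezed to H x W index maps, converted to color maps with label2colormap, overlaid on
# the de-normalized images prepared above (sigma/mean are the standard ImageNet channel
# statistics), and then written as JPEGs and/or TensorBoard images.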
335 | show_gt, show_prev_gt, show_ref_gt, show_preds_s = [label.cpu()[0].squeeze(0).numpy() for label in [curr_labels, prev_labels, ref_labels, curr_pred]] 336 | 337 | show_gtf, show_prev_gtf, show_ref_gtf, show_preds_sf = [label2colormap(label).transpose((2,0,1)) for label in [show_gt, show_prev_gt, show_ref_gt, show_preds_s]] 338 | 339 | if cfg.TRAIN_IMG_LOG or cfg.TRAIN_TBLOG: 340 | 341 | show_ref_img = masked_image(show_ref_img, show_ref_gtf, show_ref_gt) 342 | if cfg.TRAIN_IMG_LOG: 343 | save_image(show_ref_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_ref_img.jpeg' % (step))) 344 | 345 | show_prev_img = masked_image(show_prev_img, show_prev_gtf, show_prev_gt) 346 | if cfg.TRAIN_IMG_LOG: 347 | save_image(show_prev_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_prev_img.jpeg' % (step))) 348 | 349 | show_img_pred = masked_image(show_curr_img, show_preds_sf, show_preds_s) 350 | if cfg.TRAIN_IMG_LOG: 351 | save_image(show_img_pred, os.path.join(cfg.DIR_IMG_LOG, '%06d_prediction.jpeg' % (step))) 352 | 353 | show_curr_img = masked_image(show_curr_img, show_gtf, show_gt) 354 | if cfg.TRAIN_IMG_LOG: 355 | save_image(show_curr_img, os.path.join(cfg.DIR_IMG_LOG, '%06d_groundtruth.jpeg' % (step))) 356 | 357 | if cfg.TRAIN_TBLOG: 358 | for seq_step, running_loss, running_iou in zip(range(len(running_losses)), running_losses, running_ious): 359 | self.tblogger.add_scalar('S{}/Loss'.format(seq_step), running_loss.avg, step) 360 | self.tblogger.add_scalar('S{}/IoU'.format(seq_step), running_iou.avg, step) 361 | 362 | self.tblogger.add_scalar('LR', now_lr, step) 363 | self.tblogger.add_image('Ref/Image', show_ref_img, step) 364 | self.tblogger.add_image('Ref/GT', show_ref_gtf, step) 365 | 366 | self.tblogger.add_image('Prev/Image', show_prev_img, step) 367 | self.tblogger.add_image('Prev/GT', show_prev_gtf, step) 368 | 369 | self.tblogger.add_image('Curr/Image_GT', show_curr_img, step) 370 | self.tblogger.add_image('Curr/Image_Pred', show_img_pred, step) 371 | 372 | self.tblogger.add_image('Curr/Mask_GT', show_gtf, step) 373 | self.tblogger.add_image('Curr/Mask_Pred', show_preds_sf, step) 374 | 375 | for seq_step, boards in enumerate(all_boards): 376 | for key in boards['image'].keys(): 377 | tmp = boards['image'][key].cpu().numpy() 378 | self.tblogger.add_image('S{}/' + key, tmp, step) 379 | for key in boards['scalar'].keys(): 380 | tmp = boards['scalar'][key].cpu().numpy() 381 | self.tblogger.add_scalar('S{}/' + key, tmp, step) 382 | 383 | self.tblogger.flush() 384 | 385 | del(all_boards) 386 | 387 | 388 | -------------------------------------------------------------------------------- /networks/layers/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__init__.py -------------------------------------------------------------------------------- /networks/layers/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/aspp.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/aspp.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/attention.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/attention.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/gct.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/gct.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/matching.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/matching.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/normalization.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/normalization.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/__pycache__/shannon_entropy.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/layers/__pycache__/shannon_entropy.cpython-36.pyc -------------------------------------------------------------------------------- /networks/layers/aspp.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from torch import nn 5 | from networks.layers.gct import GCT 6 | 7 | class _ASPPModule(nn.Module): 8 | def __init__(self, inplanes, planes, kernel_size, padding, dilation): 9 | super(_ASPPModule, self).__init__() 10 | self.GCT = GCT(inplanes) 11 | self.atrous_conv = nn.Conv2d(inplanes, planes, kernel_size=kernel_size, 12 | stride=1, padding=padding, dilation=dilation, bias=False) 13 | self.bn = nn.GroupNorm(int(planes / 4), planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | self._init_weight() 17 | 18 | def forward(self, x): 19 | x = self.GCT(x) 20 | x = self.atrous_conv(x) 21 | x = self.bn(x) 22 | 23 | return self.relu(x) 24 | 25 | def _init_weight(self): 26 | for m in self.modules(): 27 | if isinstance(m, nn.Conv2d): 28 | torch.nn.init.kaiming_normal_(m.weight) 29 | elif isinstance(m, nn.BatchNorm2d): 30 | m.weight.data.fill_(1) 31 | m.bias.data.zero_() 32 | 33 | class ASPP(nn.Module): 34 | def __init__(self): 35 | super(ASPP, self).__init__() 36 | 
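# Descriptive note on this module (networks/layers/aspp.py, distinct from the ASPP in
# networks/deeplab/aspp.py): it takes a 512-channel input and runs four parallel atrous
# branches with dilations 1, 6, 12 and 18 plus a global-average-pooling branch, each
# producing 128 channels; the five outputs are concatenated to 640 channels, gated by GCT
# and projected to 256 channels by the 1x1 convolution below. Each branch uses a GCT gate
# and GroupNorm rather than BatchNorm.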
37 | inplanes = 512 38 | dilations = [1, 6, 12, 18] 39 | 40 | 41 | self.aspp1 = _ASPPModule(inplanes, 128, 1, padding=0, dilation=dilations[0]) 42 | self.aspp2 = _ASPPModule(inplanes, 128, 3, padding=dilations[1], dilation=dilations[1]) 43 | self.aspp3 = _ASPPModule(inplanes, 128, 3, padding=dilations[2], dilation=dilations[2]) 44 | self.aspp4 = _ASPPModule(inplanes, 128, 3, padding=dilations[3], dilation=dilations[3]) 45 | 46 | self.global_avg_pool = nn.Sequential(nn.AdaptiveAvgPool2d((1, 1)), 47 | nn.Conv2d(inplanes, 128, 1, stride=1, bias=False), 48 | nn.ReLU(inplace=True)) 49 | 50 | self.GCT = GCT(640) 51 | self.conv1 = nn.Conv2d(640, 256, 1, bias=False) 52 | self.bn1 = nn.GroupNorm(32, 256) 53 | self.relu = nn.ReLU(inplace=True) 54 | self._init_weight() 55 | 56 | def forward(self, x): 57 | x1 = self.aspp1(x) 58 | x2 = self.aspp2(x) 59 | x3 = self.aspp3(x) 60 | x4 = self.aspp4(x) 61 | x5 = self.global_avg_pool(x) 62 | x5 = F.interpolate(x5, size=x4.size()[2:], mode='bilinear', align_corners=True) 63 | x = torch.cat((x1, x2, x3, x4, x5), dim=1) 64 | 65 | x = self.GCT(x) 66 | x = self.conv1(x) 67 | x = self.bn1(x) 68 | x = self.relu(x) 69 | 70 | return x 71 | 72 | def _init_weight(self): 73 | for m in self.modules(): 74 | if isinstance(m, nn.Conv2d): 75 | torch.nn.init.kaiming_normal_(m.weight) 76 | elif isinstance(m, nn.BatchNorm2d): 77 | m.weight.data.fill_(1) 78 | m.bias.data.zero_() -------------------------------------------------------------------------------- /networks/layers/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from torch import nn 5 | 6 | 7 | class IA_gate(nn.Module): 8 | def __init__(self, in_dim, out_dim): 9 | super(IA_gate, self).__init__() 10 | self.IA = nn.Linear(in_dim, out_dim) 11 | 12 | def forward(self, x, IA_head): 13 | a = self.IA(IA_head) 14 | a = 1. + torch.tanh(a) 15 | a = a.unsqueeze(-1).unsqueeze(-1) 16 | x = a * x 17 | return x 18 | 19 | def calculate_attention_head(ref_embedding, ref_label, prev_embedding, prev_label, epsilon=1e-5): 20 | 21 | ref_head = ref_embedding * ref_label 22 | ref_head_pos = torch.sum(ref_head, dim=(2,3)) 23 | ref_head_neg = torch.sum(ref_embedding, dim=(2,3)) - ref_head_pos 24 | ref_pos_num = torch.sum(ref_label, dim=(2,3)) 25 | ref_neg_num = torch.sum(1. - ref_label, dim=(2,3)) 26 | ref_head_pos = ref_head_pos / (ref_pos_num + epsilon) 27 | ref_head_neg = ref_head_neg / (ref_neg_num + epsilon) 28 | 29 | prev_head = prev_embedding * prev_label 30 | prev_head_pos = torch.sum(prev_head, dim=(2,3)) 31 | prev_head_neg = torch.sum(prev_embedding, dim=(2,3)) - prev_head_pos 32 | prev_pos_num = torch.sum(prev_label, dim=(2,3)) 33 | prev_neg_num = torch.sum(1. - prev_label, dim=(2,3)) 34 | prev_head_pos = prev_head_pos / (prev_pos_num + epsilon) 35 | prev_head_neg = prev_head_neg / (prev_neg_num + epsilon) 36 | 37 | total_head = torch.cat([ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg], dim=1) 38 | 39 | return total_head 40 | 41 | def calculate_attention_head_for_eval(ref_embeddings, ref_labels, prev_embedding, prev_label, epsilon=1e-5): 42 | total_ref_head_pos = 0. 43 | total_ref_head_neg = 0. 44 | total_ref_pos_num = 0. 45 | total_ref_neg_num = 0. 
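# Descriptive note: the loop below pools foreground/background statistics over every stored
# reference frame. For each reference, the positive vector is the mask-weighted sum of the
# embedding over H x W and the negative vector is the total sum minus it; pixel counts are
# accumulated alongside, so the division after the loop averages over all reference frames
# jointly before the result is concatenated with the previous-frame statistics.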
46 | 47 | for idx in range(len(ref_embeddings)): 48 | ref_embedding = ref_embeddings[idx] 49 | ref_label = ref_labels[idx] 50 | ref_head = ref_embedding * ref_label 51 | ref_head_pos = torch.sum(ref_head, dim=(2,3)) 52 | ref_head_neg = torch.sum(ref_embedding, dim=(2,3)) - ref_head_pos 53 | ref_pos_num = torch.sum(ref_label, dim=(2,3)) 54 | ref_neg_num = torch.sum(1. - ref_label, dim=(2,3)) 55 | total_ref_head_pos = total_ref_head_pos + ref_head_pos 56 | total_ref_head_neg = total_ref_head_neg + ref_head_neg 57 | total_ref_pos_num = total_ref_pos_num + ref_pos_num 58 | total_ref_neg_num = total_ref_neg_num + ref_neg_num 59 | ref_head_pos = total_ref_head_pos / (total_ref_pos_num + epsilon) 60 | ref_head_neg = total_ref_head_neg / (total_ref_neg_num + epsilon) 61 | 62 | prev_head = prev_embedding * prev_label 63 | prev_head_pos = torch.sum(prev_head, dim=(2,3)) 64 | prev_head_neg = torch.sum(prev_embedding, dim=(2,3)) - prev_head_pos 65 | prev_pos_num = torch.sum(prev_label, dim=(2,3)) 66 | prev_neg_num = torch.sum(1. - prev_label, dim=(2,3)) 67 | prev_head_pos = prev_head_pos / (prev_pos_num + epsilon) 68 | prev_head_neg = prev_head_neg / (prev_neg_num + epsilon) 69 | 70 | total_head = torch.cat([ref_head_pos, ref_head_neg, prev_head_pos, prev_head_neg], dim=1) 71 | return total_head 72 | 73 | -------------------------------------------------------------------------------- /networks/layers/gct.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import math 4 | from torch import nn 5 | 6 | 7 | class GCT(nn.Module): 8 | def __init__(self, num_channels, epsilon=1e-5, mode='l2', after_relu=False): 9 | super(GCT, self).__init__() 10 | self.alpha = nn.Parameter(torch.ones(1, num_channels, 1, 1)) 11 | self.gamma = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 12 | self.beta = nn.Parameter(torch.zeros(1, num_channels, 1, 1)) 13 | self.epsilon = epsilon 14 | self.mode = mode 15 | self.after_relu = after_relu 16 | 17 | def forward(self, x): 18 | 19 | if self.mode == 'l2': 20 | embedding = (x.pow(2).sum((2,3), keepdim=True) + self.epsilon).pow(0.5) * self.alpha 21 | norm = self.gamma / (embedding.pow(2).mean(dim=1, keepdim=True) + self.epsilon).pow(0.5) 22 | 23 | elif self.mode == 'l1': 24 | if not self.after_relu: 25 | _x = torch.abs(x) 26 | else: 27 | _x = x 28 | embedding = _x.sum((2,3), keepdim=True) * self.alpha 29 | norm = self.gamma / (torch.abs(embedding).mean(dim=1, keepdim=True) + self.epsilon) 30 | else: 31 | print('Unknown mode!') 32 | exit() 33 | 34 | gate = 1. 
+ torch.tanh(embedding * norm + self.beta) 35 | 36 | return x * gate 37 | 38 | class Bottleneck(nn.Module): 39 | def __init__(self, inplanes, outplanes, stride=1, dilation=1): 40 | super(Bottleneck, self).__init__() 41 | expansion = 4 42 | planes = int(outplanes / expansion) 43 | self.GCT1 = GCT(inplanes) 44 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 45 | self.bn1 = nn.GroupNorm(32, planes) 46 | 47 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 48 | dilation=dilation, padding=dilation, bias=False) 49 | self.bn2 = nn.GroupNorm(32, planes) 50 | 51 | self.conv3 = nn.Conv2d(planes, planes * expansion, kernel_size=1, bias=False) 52 | self.bn3 = nn.GroupNorm(32, planes * expansion) 53 | self.relu = nn.ReLU(inplace=True) 54 | if stride != 1 or inplanes != planes * expansion: 55 | downsample = nn.Sequential( 56 | nn.Conv2d(inplanes, planes * expansion, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.GroupNorm(32, planes * expansion), 59 | ) 60 | else: 61 | downsample = None 62 | self.downsample = downsample 63 | self.stride = stride 64 | self.dilation = dilation 65 | 66 | for m in self.modules(): 67 | if isinstance(m, nn.Conv2d): 68 | nn.init.kaiming_normal_(m.weight,mode='fan_out', nonlinearity='relu') 69 | 70 | def forward(self, x): 71 | residual = x 72 | 73 | out = self.GCT1(x) 74 | out = self.conv1(out) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out -------------------------------------------------------------------------------- /networks/layers/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | 5 | class Concat_BCEWithLogitsLoss(nn.Module): 6 | def __init__(self, top_k_percent_pixels=None, 7 | hard_example_mining_step=100000): 8 | super(Concat_BCEWithLogitsLoss, self).__init__() 9 | self.top_k_percent_pixels = top_k_percent_pixels 10 | if top_k_percent_pixels is not None: 11 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 12 | self.hard_example_mining_step = hard_example_mining_step 13 | if self.top_k_percent_pixels == None: 14 | self.bceloss = nn.BCEWithLogitsLoss(reduction='mean') 15 | else: 16 | self.bceloss = nn.BCEWithLogitsLoss(reduction='none') 17 | 18 | def forward(self, dic_tmp, y, step): 19 | total_loss = [] 20 | for i in range(len(dic_tmp)): 21 | pred_logits = dic_tmp[i] 22 | gts = y[i] 23 | if self.top_k_percent_pixels == None: 24 | final_loss = self.bceloss(pred_logits, gts) 25 | else: 26 | # Only compute the loss for top k percent pixels. 27 | # First, compute the loss for all pixels. Note we do not put the loss 28 | # to loss_collection and set reduction = None to keep the shape. 
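                    # The kept fraction is annealed linearly over `hard_example_mining_step`
                    # training steps: start from all pixels and end with only the hardest
                    # `top_k_percent_pixels` of them.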
29 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 30 | pred_logits = pred_logits.view(-1, pred_logits.size( 31 | 1), pred_logits.size(2) * pred_logits.size(3)) 32 | gts = gts.view(-1, gts.size(1), gts.size(2) * gts.size(3)) 33 | pixel_losses = self.bceloss(pred_logits, gts) 34 | if self.hard_example_mining_step == 0: 35 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 36 | else: 37 | ratio = min( 38 | 1.0, step / float(self.hard_example_mining_step)) 39 | top_k_pixels = int( 40 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels) 41 | top_k_loss, top_k_indices = torch.topk( 42 | pixel_losses, k=top_k_pixels, dim=2) 43 | 44 | # Average only the selected hardest pixels. 45 | final_loss = torch.mean(top_k_loss) 46 | final_loss = final_loss.unsqueeze(0) 47 | total_loss.append(final_loss) 48 | total_loss = torch.cat(total_loss, dim=0) 49 | return total_loss 50 | 51 | 52 | class Concat_CrossEntropyLoss(nn.Module): 53 | def __init__(self, top_k_percent_pixels=None, 54 | hard_example_mining_step=100000): 55 | super(Concat_CrossEntropyLoss, self).__init__() 56 | self.top_k_percent_pixels = top_k_percent_pixels 57 | if top_k_percent_pixels is not None: 58 | assert(top_k_percent_pixels > 0 and top_k_percent_pixels < 1) 59 | self.hard_example_mining_step = hard_example_mining_step 60 | if self.top_k_percent_pixels == None: 61 | self.celoss = nn.CrossEntropyLoss( 62 | ignore_index=255, reduction='mean') 63 | else: 64 | self.celoss = nn.CrossEntropyLoss( 65 | ignore_index=255, reduction='none') 66 | 67 | def forward(self, dic_tmp, y, step): 68 | total_loss = [] 69 | for i in range(len(dic_tmp)): 70 | pred_logits = dic_tmp[i] 71 | gts = y[i] 72 | if self.top_k_percent_pixels == None: 73 | final_loss = self.celoss(pred_logits, gts) 74 | else: 75 | # Only compute the loss for top k percent pixels. 76 | # First, compute the loss for all pixels. Note we do not put the loss 77 | # to loss_collection and set reduction = None to keep the shape.
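                    # Same top-k schedule as Concat_BCEWithLogitsLoss above; here the two
                    # spatial dimensions are flattened so torch.topk can pick the hardest
                    # pixels along dim=1 before they are averaged.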
78 | num_pixels = float(pred_logits.size(2) * pred_logits.size(3)) 79 | pred_logits = pred_logits.view(-1, pred_logits.size( 80 | 1), pred_logits.size(2) * pred_logits.size(3)) 81 | gts = gts.view(-1, gts.size(1) * gts.size(2)) 82 | pixel_losses = self.celoss(pred_logits, gts) 83 | if self.hard_example_mining_step == 0: 84 | top_k_pixels = int(self.top_k_percent_pixels * num_pixels) 85 | else: 86 | ratio = min( 87 | 1.0, step / float(self.hard_example_mining_step)) 88 | top_k_pixels = int( 89 | (ratio * self.top_k_percent_pixels + (1.0 - ratio)) * num_pixels) 90 | top_k_loss, top_k_indices = torch.topk( 91 | pixel_losses, k=top_k_pixels, dim=1) 92 | 93 | final_loss = torch.mean(top_k_loss) 94 | final_loss = final_loss.unsqueeze(0) 95 | total_loss.append(final_loss) 96 | total_loss = torch.cat(total_loss, dim=0) 97 | return total_loss 98 | -------------------------------------------------------------------------------- /networks/layers/matching.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | def foreground2background(dis, obj_num): 7 | if obj_num == 1: 8 | return dis 9 | bg_dis = [] 10 | for i in range(obj_num): 11 | obj_back = [] 12 | for j in range(obj_num): 13 | if i == j: 14 | continue 15 | obj_back.append(dis[j].unsqueeze(0)) 16 | obj_back = torch.cat(obj_back, dim=1) 17 | obj_back, _ = torch.min(obj_back, dim=1, keepdim=True) 18 | bg_dis.append(obj_back) 19 | bg_dis = torch.cat(bg_dis, dim=0) 20 | return bg_dis 21 | 22 | WRONG_LABEL_PADDING_DISTANCE = 5e4 23 | #############################################################GLOBAL_DIST_MAP 24 | def _pairwise_distances(x, x2, y, y2): 25 | """ 26 | Computes pairwise squared l2 distances between tensors x and y. 27 | Args: 28 | x: [n, feature_dim]. 29 | y: [m, feature_dim]. 30 | Returns: 31 | d: [n, m]. 32 | """ 33 | xs = x2 34 | ys = y2 35 | 36 | xs = xs.unsqueeze(1) 37 | ys = ys.unsqueeze(0) 38 | d = xs + ys - 2. * torch.matmul(x, torch.t(y)) 39 | return d 40 | 41 | ################## 42 | def _flattened_pairwise_distances(reference_embeddings, ref_square, query_embeddings, query_square): 43 | """ 44 | Calculates flattened tensor of pairwise distances between ref and query. 45 | Args: 46 | reference_embeddings: [..., embedding_dim], 47 | the embedding vectors for the reference frame 48 | query_embeddings: [..., embedding_dim], 49 | the embedding vectors for the query frames. 50 | Returns: 51 | dists: [reference_embeddings.size / embedding_dim, query_embeddings.size / embedding_dim] 52 | """ 53 | dists = _pairwise_distances(query_embeddings, query_square, reference_embeddings, ref_square) 54 | return dists 55 | 56 | def _nn_features_per_object_for_chunk( 57 | reference_embeddings, ref_square, query_embeddings, query_square, wrong_label_mask): 58 | """Extracts features for each object using nearest neighbor attention. 59 | Args: 60 | reference_embeddings: [n_chunk, embedding_dim], 61 | the embedding vectors for the reference frame. 62 | query_embeddings: [m_chunk, embedding_dim], 63 | the embedding vectors for the query frames. 64 | wrong_label_mask: [n_objects, n_chunk], 65 | the mask for pixels not used for matching. 66 | Returns: 67 | nn_features: A float32 tensor of nearest neighbor features of shape 68 | [m_chunk, n_objects, n_chunk]. 
69 | """ 70 | if reference_embeddings.dtype == torch.float16: 71 | wrong_label_mask = wrong_label_mask.half() 72 | else: 73 | wrong_label_mask = wrong_label_mask.float() 74 | 75 | reference_embeddings_key = reference_embeddings 76 | query_embeddings_key = query_embeddings 77 | dists = _flattened_pairwise_distances(reference_embeddings_key, ref_square, query_embeddings_key, query_square) 78 | 79 | dists = (torch.unsqueeze(dists, 1) + 80 | torch.unsqueeze(wrong_label_mask, 0) * 81 | WRONG_LABEL_PADDING_DISTANCE) 82 | 83 | features, _ = torch.min(dists, 2, keepdim=True) 84 | return features 85 | 86 | def _nearest_neighbor_features_per_object_in_chunks( 87 | reference_embeddings_flat, query_embeddings_flat, reference_labels_flat, n_chunks): 88 | """Calculates the nearest neighbor features per object in chunks to save mem. 89 | Uses chunking to bound the memory use. 90 | Args: 91 | reference_embeddings_flat: [n, embedding_dim], 92 | the embedding vectors for the reference frame. 93 | query_embeddings_flat: [m, embedding_dim], 94 | the embedding vectors for the query frames. 95 | reference_labels_flat: [n, n_objects], 96 | the class labels of the reference frame. 97 | n_chunks: Integer, the number of chunks to use to save memory 98 | (set to 1 for no chunking). 99 | Returns: 100 | nn_features: [m, n_objects, n]. 101 | """ 102 | 103 | feature_dim, embedding_dim = query_embeddings_flat.size() 104 | chunk_size = int(np.ceil(float(feature_dim) / n_chunks)) 105 | wrong_label_mask = reference_labels_flat < 0.1 106 | wrong_label_mask = wrong_label_mask.permute(1, 0) 107 | ref_square = reference_embeddings_flat.pow(2).sum(1) 108 | query_square = query_embeddings_flat.pow(2).sum(1) 109 | 110 | all_features = [] 111 | for n in range(n_chunks): 112 | if n_chunks == 1: 113 | query_embeddings_flat_chunk = query_embeddings_flat 114 | query_square_chunk = query_square 115 | chunk_start = 0 116 | else: 117 | chunk_start = n * chunk_size 118 | chunk_end = (n + 1) * chunk_size 119 | query_square_chunk = query_square[chunk_start:chunk_end] 120 | if query_square_chunk.size(0) == 0: 121 | continue 122 | query_embeddings_flat_chunk = query_embeddings_flat[chunk_start:chunk_end] 123 | features = _nn_features_per_object_for_chunk( 124 | reference_embeddings_flat, ref_square, query_embeddings_flat_chunk, query_square_chunk, 125 | wrong_label_mask) 126 | all_features.append(features) 127 | if n_chunks == 1: 128 | nn_features = all_features[0] 129 | else: 130 | nn_features = torch.cat(all_features, dim=0) 131 | 132 | 133 | return nn_features 134 | 135 | 136 | def global_matching( 137 | reference_embeddings, query_embeddings, reference_labels, 138 | n_chunks=100, dis_bias=0., ori_size=None, atrous_rate=1, use_float16=True, atrous_obj_pixel_num=0): 139 | """ 140 | Calculates the distance to the nearest neighbor per object. 141 | For every pixel of query_embeddings calculate the distance to the 142 | nearest neighbor in the (possibly subsampled) reference_embeddings per object. 143 | Args: 144 | reference_embeddings: [height, width, embedding_dim], 145 | the embedding vectors for the reference frame. 146 | query_embeddings: [height, width, 147 | embedding_dim], the embedding vectors for the query frames. 148 | reference_labels: [height, width, obj_nums], 149 | the class labels of the reference frame. 150 | n_chunks: Integer, the number of chunks to use to save memory 151 | (set to 1 for no chunking). 
152 | dis_bias: [n_objects], foreground and background bias 153 | ori_size: (ori_height, ori_width), 154 | the original spatial size. If "None", (ori_height, ori_width) = (height, width). 155 | atrous_rate: Integer, the atrous rate of reference_embeddings. 156 | use_float16: Bool, if "True", use float16 type for matching. 157 | Returns: 158 | nn_features: [1, ori_height, ori_width, n_objects, feature_dim]. 159 | """ 160 | 161 | assert (reference_embeddings.size()[:2] == reference_labels.size()[:2]) 162 | if use_float16: 163 | query_embeddings = query_embeddings.half() 164 | reference_embeddings = reference_embeddings.half() 165 | h, w, embedding_dim = query_embeddings.size() 166 | obj_nums = reference_labels.size(2) 167 | 168 | if atrous_rate > 1: 169 | h_pad = (atrous_rate - h % atrous_rate) % atrous_rate 170 | w_pad = (atrous_rate - w % atrous_rate) % atrous_rate 171 | selected_points = torch.zeros(h + h_pad, w + w_pad, device=query_embeddings.device) 172 | selected_points = selected_points.view((h + h_pad) // atrous_rate, atrous_rate, 173 | (w + w_pad) // atrous_rate, atrous_rate) 174 | selected_points[:, 0, :, 0] = 1. 175 | selected_points = selected_points.view(h + h_pad, w + w_pad, 1)[:h, :w] 176 | is_big_obj = reference_labels.sum(dim=(0, 1)) > (atrous_obj_pixel_num * atrous_rate ** 2) 177 | reference_labels[:, :, is_big_obj] = reference_labels[:, :, is_big_obj] * selected_points 178 | 179 | reference_embeddings_flat = reference_embeddings.view(-1, embedding_dim) 180 | reference_labels_flat = reference_labels.view(-1, obj_nums) 181 | query_embeddings_flat = query_embeddings.view(-1, embedding_dim) 182 | 183 | all_ref_fg = torch.sum(reference_labels_flat, dim=1, keepdim=True) > 0.9 184 | reference_labels_flat = torch.masked_select(reference_labels_flat, 185 | all_ref_fg.expand(-1, obj_nums)).view(-1, obj_nums) 186 | if reference_labels_flat.size(0) == 0: 187 | return torch.ones(1, h, w, obj_nums, 1, device=all_ref_fg.device) 188 | reference_embeddings_flat = torch.masked_select(reference_embeddings_flat, 189 | all_ref_fg.expand(-1, embedding_dim)).view(-1, embedding_dim) 190 | 191 | nn_features = _nearest_neighbor_features_per_object_in_chunks( 192 | reference_embeddings_flat, query_embeddings_flat, reference_labels_flat, 193 | n_chunks) 194 | 195 | nn_features_reshape = nn_features.view(1, h, w, obj_nums, 1) 196 | nn_features_reshape = (torch.sigmoid(nn_features_reshape + dis_bias.view(1, 1, 1, -1, 1)) - 0.5) * 2 197 | 198 | if ori_size is not None: 199 | nn_features_reshape = nn_features_reshape.view(h, w, obj_nums, 1).permute(2, 3, 0, 1) 200 | nn_features_reshape = F.interpolate(nn_features_reshape, size=ori_size, 201 | mode='bilinear', align_corners=True).permute(2, 3, 0, 1).view(1, ori_size[0], ori_size[1], obj_nums, 1) 202 | 203 | if use_float16: 204 | nn_features_reshape = nn_features_reshape.float() 205 | return nn_features_reshape 206 | 207 | 208 | def global_matching_for_eval( 209 | all_reference_embeddings, query_embeddings, all_reference_labels, 210 | n_chunks=20, dis_bias=0., ori_size=None, atrous_rate=1, use_float16=True, atrous_obj_pixel_num=0): 211 | """ 212 | Calculates the distance to the nearest neighbor per object. 213 | For every pixel of query_embeddings calculate the distance to the 214 | nearest neighbor in the (possibly subsampled) reference_embeddings per object. 215 | Args: 216 | all_reference_embeddings: A list of reference_embeddings, 217 | each with size [height, width, embedding_dim], 218 | the embedding vectors for the reference frame. 
219 | query_embeddings: [n_query_images, height, width, 220 | embedding_dim], the embedding vectors for the query frames. 221 | all_reference_labels: A list of reference_labels, 222 | each with size [height, width, obj_nums], 223 | the class labels of the reference frame. 224 | n_chunks: Integer, the number of chunks to use to save memory 225 | (set to 1 for no chunking). 226 | dis_bias: [n_objects], foreground and background bias 227 | ori_size: (ori_height, ori_width), 228 | the original spatial size. If "None", (ori_height, ori_width) = (height, width). 229 | atrous_rate: Integer, the atrous rate of reference_embeddings. 230 | use_float16: Bool, if "True", use float16 type for matching. 231 | Returns: 232 | nn_features: [n_query_images, ori_height, ori_width, n_objects, feature_dim]. 233 | """ 234 | 235 | h, w, embedding_dim = query_embeddings.size() 236 | obj_nums = all_reference_labels[0].size(2) 237 | all_reference_embeddings_flat = [] 238 | all_reference_labels_flat = [] 239 | ref_num = len(all_reference_labels) 240 | n_chunks *= ref_num 241 | if atrous_obj_pixel_num > 0: 242 | if atrous_rate > 1: 243 | h_pad = (atrous_rate - h % atrous_rate) % atrous_rate 244 | w_pad = (atrous_rate - w % atrous_rate) % atrous_rate 245 | selected_points = torch.zeros(h + h_pad, w + w_pad, device=query_embeddings.device) 246 | selected_points = selected_points.view((h + h_pad) // atrous_rate, atrous_rate, 247 | (w + w_pad) // atrous_rate, atrous_rate) 248 | selected_points[:, 0, :, 0] = 1. 249 | selected_points = selected_points.view(h + h_pad, w + w_pad, 1)[:h, :w] 250 | 251 | for reference_embeddings, reference_labels, idx in zip(all_reference_embeddings, all_reference_labels, range(ref_num)): 252 | if atrous_rate > 1: 253 | is_big_obj = reference_labels.sum(dim=(0, 1)) > (atrous_obj_pixel_num * atrous_rate ** 2) 254 | reference_labels[:, :, is_big_obj] = reference_labels[:, :, is_big_obj] * selected_points 255 | 256 | reference_embeddings_flat = reference_embeddings.view(-1, embedding_dim) 257 | reference_labels_flat = reference_labels.view(-1, obj_nums) 258 | 259 | all_reference_embeddings_flat.append(reference_embeddings_flat) 260 | all_reference_labels_flat.append(reference_labels_flat) 261 | 262 | reference_embeddings_flat = torch.cat(all_reference_embeddings_flat, dim=0) 263 | reference_labels_flat = torch.cat(all_reference_labels_flat, dim=0) 264 | else: 265 | if ref_num == 1: 266 | reference_embeddings, reference_labels = all_reference_embeddings[0], all_reference_labels[0] 267 | if atrous_rate > 1: 268 | h_pad = (atrous_rate - h % atrous_rate) % atrous_rate 269 | w_pad = (atrous_rate - w % atrous_rate) % atrous_rate 270 | if h_pad > 0 or w_pad > 0: 271 | reference_embeddings = F.pad(reference_embeddings, (0, 0, 0, w_pad, 0, h_pad)) 272 | reference_labels = F.pad(reference_labels, (0, 0, 0, w_pad, 0, h_pad)) 273 | 274 | reference_embeddings = reference_embeddings.view((h + h_pad) // atrous_rate, atrous_rate, 275 | (w + w_pad) // atrous_rate, atrous_rate, -1) 276 | reference_labels = reference_labels.view((h + h_pad) // atrous_rate, atrous_rate, 277 | (w + w_pad) // atrous_rate, atrous_rate, -1) 278 | reference_embeddings = reference_embeddings[:, 0, :, 0, :].contiguous() 279 | reference_labels = reference_labels[:, 0, :, 0, :].contiguous() 280 | reference_embeddings_flat = reference_embeddings.view(-1, embedding_dim) 281 | reference_labels_flat = reference_labels.view(-1, obj_nums) 282 | else: 283 | 284 | for reference_embeddings, reference_labels, idx in zip(all_reference_embeddings, 
all_reference_labels, range(ref_num)): 285 | if atrous_rate > 1: 286 | h_pad = (atrous_rate - h % atrous_rate) % atrous_rate 287 | w_pad = (atrous_rate - w % atrous_rate) % atrous_rate 288 | if h_pad > 0 or w_pad > 0: 289 | reference_embeddings = F.pad(reference_embeddings, (0, 0, 0, w_pad, 0, h_pad)) 290 | reference_labels = F.pad(reference_labels, (0, 0, 0, w_pad, 0, h_pad)) 291 | 292 | reference_embeddings = reference_embeddings.view((h + h_pad) // atrous_rate, atrous_rate, 293 | (w + w_pad) // atrous_rate, atrous_rate, -1) 294 | reference_labels = reference_labels.view((h + h_pad) // atrous_rate, atrous_rate, 295 | (w + w_pad) // atrous_rate, atrous_rate, -1) 296 | reference_embeddings = reference_embeddings[:, 0, :, 0, :].contiguous() 297 | reference_labels = reference_labels[:, 0, :, 0, :].contiguous() 298 | 299 | 300 | reference_embeddings_flat = reference_embeddings.view(-1, embedding_dim) 301 | reference_labels_flat = reference_labels.view(-1, obj_nums) 302 | 303 | all_reference_embeddings_flat.append(reference_embeddings_flat) 304 | all_reference_labels_flat.append(reference_labels_flat) 305 | 306 | reference_embeddings_flat = torch.cat(all_reference_embeddings_flat, dim=0) 307 | reference_labels_flat = torch.cat(all_reference_labels_flat, dim=0) 308 | 309 | 310 | 311 | query_embeddings_flat = query_embeddings.view(-1, embedding_dim) 312 | 313 | all_ref_fg = torch.sum(reference_labels_flat, dim=1, keepdim=True) > 0.9 314 | reference_labels_flat = torch.masked_select(reference_labels_flat, 315 | all_ref_fg.expand(-1, obj_nums)).view(-1, obj_nums) 316 | if reference_labels_flat.size(0) == 0: 317 | return torch.ones(1, h, w, obj_nums, 1, device=all_ref_fg.device) 318 | reference_embeddings_flat = torch.masked_select(reference_embeddings_flat, 319 | all_ref_fg.expand(-1, embedding_dim)).view(-1, embedding_dim) 320 | 321 | if use_float16: 322 | query_embeddings_flat = query_embeddings_flat.half() 323 | reference_embeddings_flat = reference_embeddings_flat.half() 324 | nn_features = _nearest_neighbor_features_per_object_in_chunks( 325 | reference_embeddings_flat, query_embeddings_flat, reference_labels_flat, n_chunks) 326 | 327 | nn_features_reshape = nn_features.view(1, h, w, obj_nums, 1) 328 | nn_features_reshape = (torch.sigmoid(nn_features_reshape + dis_bias.view(1, 1, 1, -1, 1)) - 0.5) * 2 329 | 330 | if ori_size is not None: 331 | nn_features_reshape = nn_features_reshape.view(h, w, obj_nums, 1).permute(2, 3, 0, 1) 332 | nn_features_reshape = F.interpolate(nn_features_reshape, size=ori_size, 333 | mode='bilinear', align_corners=True).permute(2, 3, 0, 1).view(1, ori_size[0], ori_size[1], obj_nums, 1) 334 | 335 | if use_float16: 336 | nn_features_reshape = nn_features_reshape.float() 337 | return nn_features_reshape 338 | 339 | ########################################################################LOCAL_DIST_MAP 340 | def local_pairwise_distances( 341 | x, y, max_distance=9, atrous_rate=1, allow_downsample=False): 342 | """Computes pairwise squared l2 distances using a local search window. 343 | Use for-loop for saving memory. 344 | Args: 345 | x: Float32 tensor of shape [height, width, feature_dim]. 346 | y: Float32 tensor of shape [height, width, feature_dim]. 347 | max_distance: Integer, the maximum distance in pixel coordinates 348 | per dimension which is considered to be in the search window. 349 | atrous_rate: Integer, the atrous rate of local matching. 350 | allow_downsample: Bool, if "True", downsample x and y 351 | with a stride of 2. 
352 | Returns: 353 | Float32 distances tensor of shape [height, width, (2 * max_distance + 1) ** 2]. 354 | """ 355 | if allow_downsample: 356 | ori_height, ori_width, _ = x.size() 357 | x = x.permute(2, 0, 1).unsqueeze(0) 358 | y = y.permute(2, 0, 1).unsqueeze(0) 359 | down_size = (int(ori_height/2) + 1, int(ori_width/2) + 1) 360 | x = F.interpolate(x, size=down_size, mode='bilinear', align_corners=True) 361 | y = F.interpolate(y, size=down_size, mode='bilinear', align_corners=True) 362 | x = x.squeeze(0).permute(1, 2, 0) 363 | y = y.squeeze(0).permute(1, 2, 0) 364 | 365 | pad_max_distance = max_distance - max_distance % atrous_rate 366 | padded_y =nn.functional.pad(y, 367 | (0, 0, pad_max_distance, pad_max_distance, pad_max_distance, pad_max_distance), 368 | mode='constant', value=WRONG_LABEL_PADDING_DISTANCE) 369 | 370 | height, width, _ = x.size() 371 | dists = [] 372 | for y in range(2 * pad_max_distance // atrous_rate + 1): 373 | y_start = y * atrous_rate 374 | y_end = y_start + height 375 | y_slice = padded_y[y_start:y_end] 376 | for x in range(2 * max_distance + 1): 377 | x_start = x * atrous_rate 378 | x_end = x_start + width 379 | offset_y = y_slice[:, x_start:x_end] 380 | dist = torch.sum(torch.pow((x-offset_y),2), dim=2) 381 | dists.append(dist) 382 | dists = torch.stack(dists, dim=2) 383 | 384 | return dists 385 | 386 | def local_pairwise_distances_parallel( 387 | x, y, max_distance=9, atrous_rate=1, allow_downsample=True): 388 | """Computes pairwise squared l2 distances using a local search window. 389 | Args: 390 | x: Float32 tensor of shape [height, width, feature_dim]. 391 | y: Float32 tensor of shape [height, width, feature_dim]. 392 | max_distance: Integer, the maximum distance in pixel coordinates 393 | per dimension which is considered to be in the search window. 394 | atrous_rate: Integer, the atrous rate of local matching. 395 | allow_downsample: Bool, if "True", downsample x and y 396 | with a stride of 2. 397 | Returns: 398 | Float32 distances tensor of shape [height, width, (2 * max_distance + 1) ** 2]. 399 | """ 400 | ori_height, ori_width, _ = x.size() 401 | x = x.permute(2, 0, 1).unsqueeze(0) 402 | y = y.permute(2, 0, 1).unsqueeze(0) 403 | if allow_downsample: 404 | down_size = (int(ori_height/2) + 1, int(ori_width/2) + 1) 405 | x = F.interpolate(x, size=down_size, mode='bilinear', align_corners=True) 406 | y = F.interpolate(y, size=down_size, mode='bilinear', align_corners=True) 407 | 408 | _, channels, height, width = x.size() 409 | 410 | x2 = x.pow(2).sum(1).view(height, width, 1) 411 | 412 | y2 = y.pow(2).sum(1).view(1, 1, height, width) 413 | 414 | pad_max_distance = max_distance - max_distance % atrous_rate 415 | 416 | padded_y = F.pad(y, (pad_max_distance, pad_max_distance, pad_max_distance, pad_max_distance)) 417 | padded_y2 = F.pad(y2, (pad_max_distance, pad_max_distance, pad_max_distance, pad_max_distance), 418 | mode='constant', value=WRONG_LABEL_PADDING_DISTANCE) 419 | 420 | offset_y = F.unfold(padded_y, kernel_size=(height, width), 421 | stride=(atrous_rate, atrous_rate)).view(channels, height * width, -1).permute(1, 0, 2) 422 | offset_y2 = F.unfold(padded_y2, kernel_size=(height, width), 423 | stride=(atrous_rate, atrous_rate)).view(height, width, -1) 424 | x = x.view(channels, height * width, -1).permute(1, 2, 0) 425 | 426 | dists = x2 + offset_y2 - 2. 
* torch.matmul(x, offset_y).view(height, width, -1) 427 | 428 | return dists 429 | 430 | 431 | 432 | 433 | def local_matching( 434 | prev_frame_embedding, query_embedding, prev_frame_labels, 435 | dis_bias=0., multi_local_distance=[15], 436 | ori_size=None, atrous_rate=1, use_float16=True, allow_downsample=True, allow_parallel=True): 437 | """Computes nearest neighbor features while only allowing local matches. 438 | Args: 439 | prev_frame_embedding: [height, width, embedding_dim], 440 | the embedding vectors for the last frame. 441 | query_embedding: [height, width, embedding_dim], 442 | the embedding vectors for the query frames. 443 | prev_frame_labels: [height, width, n_objects], 444 | the class labels of the previous frame. 445 | multi_local_distance: A list of Integer, 446 | a list of maximum distance allowed for local matching. 447 | ori_size: (ori_height, ori_width), 448 | the original spatial size. If "None", (ori_height, ori_width) = (height, width). 449 | atrous_rate: Integer, the atrous rate of local matching. 450 | use_float16: Bool, if "True", use float16 type for matching. 451 | allow_downsample: Bool, if "True", downsample prev_frame_embedding and query_embedding 452 | with a stride of 2. 453 | allow_parallel: Bool, if "True", do matching in a parallel way. If "False", do matching in 454 | a for-loop way, which will save GPU memory. 455 | Returns: 456 | nn_features: A float32 np.array of nearest neighbor features of shape 457 | [1, height, width, n_objects, 1]. 458 | """ 459 | max_distance = multi_local_distance[-1] 460 | 461 | if ori_size is None: 462 | height, width = prev_frame_embedding.size()[:2] 463 | ori_size = (height, width) 464 | 465 | obj_num = prev_frame_labels.size(2) 466 | pad = torch.ones(1, device=prev_frame_embedding.device) * WRONG_LABEL_PADDING_DISTANCE 467 | if use_float16: 468 | query_embedding = query_embedding.half() 469 | prev_frame_embedding = prev_frame_embedding.half() 470 | pad = pad.half() 471 | 472 | if allow_parallel: 473 | d = local_pairwise_distances_parallel(query_embedding, prev_frame_embedding, 474 | max_distance=max_distance, atrous_rate=atrous_rate, allow_downsample=allow_downsample) 475 | else: 476 | d = local_pairwise_distances(query_embedding, prev_frame_embedding, 477 | max_distance=max_distance, atrous_rate=atrous_rate, allow_downsample=allow_downsample) 478 | 479 | height, width = d.size()[:2] 480 | 481 | labels = prev_frame_labels.permute(2, 0, 1).unsqueeze(1) 482 | if (height, width) != ori_size: 483 | labels = F.interpolate(labels, size=(height, width), mode='nearest') 484 | 485 | pad_max_distance = max_distance - max_distance % atrous_rate 486 | atrous_max_distance = pad_max_distance // atrous_rate 487 | 488 | padded_labels = F.pad(labels, 489 | (pad_max_distance, pad_max_distance, 490 | pad_max_distance, pad_max_distance, 491 | ), mode='constant', value=0) 492 | offset_masks = F.unfold(padded_labels, kernel_size=(height, width), 493 | stride=(atrous_rate, atrous_rate)).view(obj_num, height, width, -1).permute(1, 2, 3, 0) > 0.9 494 | 495 | d_tiled = d.unsqueeze(-1).expand((-1,-1,-1,obj_num)) # h, w, num_local_pos, obj_num 496 | 497 | d_masked = torch.where(offset_masks, d_tiled, pad) 498 | dists, pos = torch.min(d_masked, dim=2) 499 | multi_dists = [dists.permute(2, 0, 1).unsqueeze(1)] # n_objects, num_multi_local, h, w 500 | 501 | reshaped_d_masked = d_masked.view(height, width, 2 * atrous_max_distance + 1, 502 | 2 * atrous_max_distance + 1, obj_num) 503 | for local_dis in multi_local_distance[:-1]: 504 | local_dis = 
local_dis // atrous_rate 505 | start_idx = atrous_max_distance - local_dis 506 | end_idx = atrous_max_distance + local_dis + 1 507 | new_d_masked = reshaped_d_masked[:, :, start_idx:end_idx, start_idx:end_idx, :].contiguous() 508 | new_d_masked = new_d_masked.view(height, width, -1, obj_num) 509 | new_dists, _ = torch.min(new_d_masked, dim=2) 510 | new_dists = new_dists.permute(2, 0, 1).unsqueeze(1) 511 | multi_dists.append(new_dists) 512 | 513 | multi_dists = torch.cat(multi_dists, dim=1) 514 | multi_dists = (torch.sigmoid(multi_dists + dis_bias.view(-1, 1, 1, 1)) - 0.5) * 2 515 | 516 | if use_float16: 517 | multi_dists = multi_dists.float() 518 | 519 | if (height, width) != ori_size: 520 | multi_dists = F.interpolate(multi_dists, size=ori_size, 521 | mode='bilinear', align_corners=True) 522 | multi_dists = multi_dists.permute(2, 3, 0, 1) 523 | multi_dists = multi_dists.view(1, ori_size[0], ori_size[1], obj_num, -1) 524 | 525 | return multi_dists -------------------------------------------------------------------------------- /networks/layers/normalization.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class FrozenBatchNorm2d(nn.Module): 7 | """ 8 | BatchNorm2d where the batch statistics and the affine parameters 9 | are fixed 10 | """ 11 | def __init__(self, n, epsilon=1e-5): 12 | super(FrozenBatchNorm2d, self).__init__() 13 | self.register_buffer("weight", torch.ones(n)) 14 | self.register_buffer("bias", torch.zeros(n)) 15 | self.register_buffer("running_mean", torch.zeros(n)) 16 | self.register_buffer("running_var", torch.ones(n)) 17 | self.epsilon = epsilon 18 | 19 | def forward(self, x): 20 | scale = self.weight * (self.running_var + self.epsilon).rsqrt() 21 | bias = self.bias - self.running_mean * scale 22 | scale = scale.reshape(1, -1, 1, 1) 23 | bias = bias.reshape(1, -1, 1, 1) 24 | return x * scale + bias -------------------------------------------------------------------------------- /networks/layers/shannon_entropy.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def normalize(image, MIN_BOUND, MAX_BOUND): 5 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 6 | reverse_image = 1 - image 7 | return reverse_image 8 | 9 | def cal_shannon_entropy(preds): 10 | uncertainty = -1.0 * torch.sum(preds * torch.log(preds + 1e-6), dim=1, keepdim=True) 11 | uncertainty_norm = normalize(uncertainty, 0, np.log(2)) * 7 12 | return uncertainty,uncertainty_norm 13 | 14 | 15 | def normalize_train(image, MIN_BOUND, MAX_BOUND): 16 | image = (image - MIN_BOUND) / (MAX_BOUND - MIN_BOUND) 17 | return image 18 | -------------------------------------------------------------------------------- /networks/rpcm/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/rpcm/__init__.py -------------------------------------------------------------------------------- /networks/rpcm/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/rpcm/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- 
/networks/rpcm/__pycache__/p2t_base.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/rpcm/__pycache__/p2t_base.cpython-36.pyc -------------------------------------------------------------------------------- /networks/rpcm/__pycache__/prop_module.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/networks/rpcm/__pycache__/prop_module.cpython-36.pyc -------------------------------------------------------------------------------- /networks/rpcm/prop_module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from networks.layers.attention import IA_gate 5 | from networks.layers.gct import Bottleneck, GCT 6 | from networks.layers.aspp import ASPP 7 | 8 | class P2C(nn.Module): 9 | def __init__(self, 10 | in_dim=256, 11 | attention_dim=400, 12 | embed_dim=100, 13 | refine_dim=48, 14 | low_level_dim=256): 15 | super(P2C,self).__init__() 16 | self.embed_dim = embed_dim 17 | IA_in_dim = attention_dim 18 | 19 | # Memory Encoder 20 | self.IA1 = IA_gate(IA_in_dim, in_dim) 21 | self.layer1 = Bottleneck(in_dim, embed_dim) 22 | 23 | self.IA2 = IA_gate(IA_in_dim, embed_dim) 24 | self.layer2 = Bottleneck(embed_dim, embed_dim, 1, 2) 25 | 26 | self.IA3 = IA_gate(IA_in_dim, embed_dim) 27 | self.layer3 = Bottleneck(embed_dim, embed_dim * 2, 2) 28 | 29 | self.IA4 = IA_gate(IA_in_dim, embed_dim * 2) 30 | self.layer4 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 2) 31 | 32 | self.IA5 = IA_gate(IA_in_dim, embed_dim * 2) 33 | self.layer5 = Bottleneck(embed_dim * 2, embed_dim * 2, 1, 4) 34 | 35 | self.IA9 = IA_gate(IA_in_dim, embed_dim * 2) 36 | self.ASPP = ASPP() 37 | 38 | # Propagation Modulator 39 | self.M1_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2) 40 | self.M1_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1) 41 | 42 | self.M1_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2) 43 | self.M1_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1) 44 | 45 | self.M1_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim * 1) 46 | self.M1_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1) 47 | 48 | # Correction Modulator 49 | self.M2_Reweight_Layer_1 = IA_gate(IA_in_dim, embed_dim * 2) 50 | self.M2_Bottleneck_1 = Bottleneck(embed_dim*2, embed_dim * 2, 1) 51 | 52 | self.M2_Reweight_Layer_2 = IA_gate(IA_in_dim, embed_dim * 2) 53 | self.M2_Bottleneck_2 = Bottleneck(embed_dim*2, embed_dim * 1, 1) 54 | 55 | self.M2_Reweight_Layer_3 = IA_gate(IA_in_dim, embed_dim *1) 56 | self.M2_Bottleneck_3 = Bottleneck(embed_dim*1, embed_dim * 1, 1) 57 | 58 | # Decoder 59 | self.GCT_sc = GCT(low_level_dim + embed_dim) 60 | self.conv_sc = nn.Conv2d(low_level_dim + embed_dim, refine_dim, 1, bias=False) 61 | self.bn_sc = nn.GroupNorm(int(refine_dim / 4), refine_dim) 62 | self.relu = nn.ReLU(inplace=True) 63 | 64 | self.IA10 = IA_gate(IA_in_dim, embed_dim + refine_dim) 65 | self.conv1 = nn.Conv2d(embed_dim + refine_dim, int(embed_dim / 2), kernel_size=3, padding=1, bias=False) 66 | self.bn1 = nn.GroupNorm(32, int(embed_dim / 2)) 67 | 68 | 69 | self.IA11 = IA_gate(IA_in_dim, int(embed_dim / 2)) 70 | self.conv2 = nn.Conv2d(int(embed_dim / 2), int(embed_dim / 2), kernel_size=3, padding=1, bias=False) 71 | self.bn2 = 
nn.GroupNorm(32, int(embed_dim / 2)) 72 | 73 | # Output 74 | self.IA_final_fg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1) 75 | self.IA_final_bg = nn.Linear(IA_in_dim, int(embed_dim / 2) + 1) 76 | 77 | nn.init.kaiming_normal_(self.conv_sc.weight,mode='fan_out', nonlinearity='relu') 78 | nn.init.kaiming_normal_(self.conv1.weight,mode='fan_out', nonlinearity='relu') 79 | nn.init.kaiming_normal_(self.conv2.weight,mode='fan_out', nonlinearity='relu') 80 | 81 | 82 | def forward(self, x, IA_head=None,memory_list=None,low_level_feat=None,to_cat_previous_frame=None): 83 | # Memory Encoder 84 | x = self.IA1(x, IA_head) 85 | x = self.layer1(x) 86 | 87 | x = self.IA2(x, IA_head) 88 | x = self.layer2(x) 89 | 90 | low_level_feat = torch.cat([low_level_feat.expand(x.size()[0], -1, -1, -1), x], dim=1) 91 | 92 | x = self.IA3(x, IA_head) 93 | x = self.layer3(x) 94 | 95 | x = self.IA4(x, IA_head) 96 | x = self.layer4(x) 97 | 98 | x = self.IA5(x, IA_head) 99 | x = self.layer5(x) 100 | 101 | x = self.IA9(x, IA_head) 102 | x = self.ASPP(x) 103 | 104 | # Propagation Modulator 105 | x_emb_cur_1 = x.detach() 106 | if memory_list[0]==None or x_emb_cur_1.size()!=memory_list[0].size(): 107 | # reference changes or it's first frame 108 | memory_list[0] = x_emb_cur_1 109 | x = self.prop_modu(x, memory_list[0].cuda(x.device),IA_head) 110 | 111 | # Correction Modulator 112 | x_emb_cur_2 = x.detach() 113 | if memory_list[1]==None or x_emb_cur_2.size()!=memory_list[1].size(): 114 | # reference changes or it's first frame 115 | memory_list[1] = x_emb_cur_2 116 | x = self.corr_modu(x, memory_list[1].cuda(x.device),IA_head) 117 | 118 | # Decoder 119 | x = self.decoder(x, low_level_feat, IA_head) 120 | 121 | fg_logit = self.IA_logit(x, IA_head, self.IA_final_fg) 122 | bg_logit = self.IA_logit(x, IA_head, self.IA_final_bg) 123 | 124 | pred = self.augment_background_logit(fg_logit, bg_logit) 125 | 126 | memory_list =[x_emb_cur_1.cpu(),memory_list[1].cpu()] 127 | return pred,memory_list 128 | 129 | def IA_logit(self, x, IA_head, IA_final): 130 | n, c, h, w = x.size() 131 | x = x.view(1, n * c, h, w) 132 | IA_output = IA_final(IA_head) 133 | IA_weight = IA_output[:, :c] 134 | IA_bias = IA_output[:, -1] 135 | IA_weight = IA_weight.view(n, c, 1, 1) 136 | IA_bias = IA_bias.view(-1) 137 | logit = F.conv2d(x, weight=IA_weight, bias=IA_bias, groups=n).view(n, 1, h, w) 138 | return logit 139 | 140 | def decoder(self, x, low_level_feat, IA_head): 141 | x = F.interpolate(x, size=low_level_feat.size()[2:], mode='bicubic', align_corners=True) 142 | 143 | low_level_feat = self.GCT_sc(low_level_feat) 144 | low_level_feat = self.conv_sc(low_level_feat) 145 | low_level_feat = self.bn_sc(low_level_feat) 146 | low_level_feat = self.relu(low_level_feat) 147 | 148 | x = torch.cat([x, low_level_feat], dim=1) 149 | 150 | x = self.IA10(x, IA_head) 151 | x = self.conv1(x) 152 | x = self.bn1(x) 153 | x = self.relu(x) 154 | 155 | x = self.IA11(x, IA_head) 156 | x = self.conv2(x) 157 | x = self.bn2(x) 158 | x = self.relu(x) 159 | 160 | return x 161 | 162 | def prop_modu(self, x, x_memory,IA_head): 163 | x = torch.cat([x, x_memory], dim=1) 164 | x = self.M1_Reweight_Layer_1(x, IA_head) 165 | x = self.M1_Bottleneck_1(x) 166 | x = self.M1_Reweight_Layer_2(x, IA_head) 167 | x = self.M1_Bottleneck_2(x) 168 | x = self.M1_Reweight_Layer_3(x, IA_head) 169 | x = self.M1_Bottleneck_3(x) 170 | return x 171 | 172 | def corr_modu(self, x, x_memory,IA_head): 173 | x = torch.cat([x, x_memory], dim=1) 174 | x = self.M2_Reweight_Layer_1(x, IA_head) 175 | x = 
self.M2_Bottleneck_1(x) 176 | x = self.M2_Reweight_Layer_2(x, IA_head) 177 | x = self.M2_Bottleneck_2(x) 178 | x = self.M2_Reweight_Layer_3(x, IA_head) 179 | x = self.M2_Bottleneck_3(x) 180 | return x 181 | 182 | 183 | def augment_background_logit(self, fg_logit, bg_logit): 184 | # Augment the logit of absolute background by using the relative background logit of all the 185 | # foreground objects. 186 | obj_num = fg_logit.size(0) 187 | pred = fg_logit 188 | if obj_num > 1: 189 | bg_logit = bg_logit[1:obj_num, :, :, :] 190 | aug_bg_logit, _ = torch.min(bg_logit, dim=0, keepdim=True) 191 | pad = torch.zeros(aug_bg_logit.size(), device=aug_bg_logit.device).expand(obj_num - 1, -1, -1, -1) 192 | aug_bg_logit = torch.cat([aug_bg_logit, pad], dim=0) 193 | pred = pred + aug_bg_logit 194 | pred = pred.permute(1,0,2,3) 195 | return pred 196 | 197 | class DynamicPreHead(nn.Module): 198 | def __init__(self, in_dim=3, embed_dim=100, kernel_size=1): 199 | super(DynamicPreHead,self).__init__() 200 | self.conv=nn.Conv2d(in_dim,embed_dim,kernel_size=kernel_size,stride=1,padding=int((kernel_size-1)/2)) 201 | self.bn = nn.GroupNorm(int(embed_dim / 4), embed_dim) 202 | self.relu = nn.ReLU(True) 203 | nn.init.kaiming_normal_(self.conv.weight,mode='fan_out',nonlinearity='relu') 204 | 205 | def forward(self, x): 206 | x = self.conv(x) 207 | x = self.bn(x) 208 | x = self.relu(x) 209 | return x 210 | -------------------------------------------------------------------------------- /networks/rpcm/rpcm.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from networks.layers.loss import Concat_CrossEntropyLoss 6 | from networks.layers.matching import global_matching, global_matching_for_eval, local_matching, foreground2background 7 | from networks.layers.attention import calculate_attention_head, calculate_attention_head_for_eval 8 | from networks.rpcm.prop_module import P2C,DynamicPreHead 9 | 10 | class RPCM(nn.Module): 11 | def __init__(self, cfg, feature_extracter): 12 | super(RPCM, self).__init__() 13 | self.cfg = cfg 14 | self.epsilon = cfg.MODEL_EPSILON 15 | 16 | self.feature_extracter=feature_extracter 17 | 18 | self.seperate_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_ASPP_OUTDIM, kernel_size=3, stride=1, padding=1, groups=cfg.MODEL_ASPP_OUTDIM) 19 | self.bn1 = nn.GroupNorm(cfg.MODEL_GN_GROUPS, cfg.MODEL_ASPP_OUTDIM) 20 | self.relu1 = nn.ReLU(True) 21 | self.embedding_conv = nn.Conv2d(cfg.MODEL_ASPP_OUTDIM, cfg.MODEL_SEMANTIC_EMBEDDING_DIM, 1, 1) 22 | self.bn2 = nn.GroupNorm(cfg.MODEL_GN_EMB_GROUPS, cfg.MODEL_SEMANTIC_EMBEDDING_DIM) 23 | self.relu2 = nn.ReLU(True) 24 | self.semantic_embedding=nn.Sequential(*[self.seperate_conv, self.bn1, self.relu1, self.embedding_conv, self.bn2, self.relu2]) 25 | 26 | self.bg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 27 | self.fg_bias = nn.Parameter(torch.zeros(1, 1, 1, 1)) 28 | 29 | self.criterion = Concat_CrossEntropyLoss(cfg.TRAIN_TOP_K_PERCENT_PIXELS, cfg.TRAIN_HARD_MINING_STEP) 30 | 31 | for m in self.semantic_embedding: 32 | if isinstance(m, nn.Conv2d): 33 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 34 | 35 | self.dynamic_seghead = P2C( 36 | in_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM + cfg.MODEL_PRE_HEAD_EMBEDDING_DIM, 37 | attention_dim=cfg.MODEL_SEMANTIC_EMBEDDING_DIM * 4, 38 | embed_dim=cfg.MODEL_HEAD_EMBEDDING_DIM, 39 | refine_dim=cfg.MODEL_REFINE_CHANNELS, 40 | 
low_level_dim=cfg.MODEL_LOW_LEVEL_INPLANES) 41 | 42 | in_dim = 2 + len(cfg.MODEL_MULTI_LOCAL_DISTANCE) 43 | if cfg.MODEL_MATCHING_BACKGROUND: 44 | in_dim += 1 +len(cfg.MODEL_MULTI_LOCAL_DISTANCE) 45 | self.dynamic_prehead = DynamicPreHead( 46 | in_dim=in_dim, 47 | embed_dim=cfg.MODEL_PRE_HEAD_EMBEDDING_DIM) 48 | 49 | def forward(self, input,memory_prev_list, ref_frame_label, previous_frame_mask, current_frame_mask, 50 | gt_ids, step=0, tf_board=False): 51 | x, low_level = self.extract_feature(input) 52 | ref_frame_embedding, previous_frame_embedding, current_frame_embedding = torch.split(x, split_size_or_sections=int(x.size(0)/3), dim=0) 53 | _, _, current_low_level = torch.split(low_level, split_size_or_sections=int(x.size(0)/3), dim=0) 54 | bs, c, h, w = current_frame_embedding.size() 55 | tmp_dic, boards, memory_cur_list= self.before_seghead_process( 56 | memory_prev_list, 57 | ref_frame_embedding, 58 | previous_frame_embedding, 59 | current_frame_embedding, 60 | ref_frame_label, 61 | previous_frame_mask, 62 | gt_ids, 63 | current_low_level=current_low_level,tf_board=tf_board) 64 | label_dic=[] 65 | all_pred = [] 66 | for i in range(bs): 67 | tmp_pred_logits = tmp_dic[i] 68 | tmp_pred_logits = nn.functional.interpolate(tmp_pred_logits, size=(input.shape[2],input.shape[3]), mode='bilinear', align_corners=True) 69 | tmp_dic[i] = tmp_pred_logits 70 | label_tmp, obj_num = current_frame_mask[i], gt_ids[i] 71 | label_dic.append(label_tmp.long()) 72 | pred = tmp_pred_logits 73 | preds_s = torch.argmax(pred,dim=1) 74 | all_pred.append(preds_s) 75 | all_pred = torch.cat(all_pred, dim=0) 76 | 77 | return self.criterion(tmp_dic, label_dic, step), all_pred, boards, memory_cur_list 78 | 79 | def forward_for_eval(self,memory_prev_list, ref_embeddings, ref_masks, prev_embedding, prev_mask, current_frame, pred_size, gt_ids): 80 | current_frame_embedding, current_low_level = self.extract_feature(current_frame) 81 | if prev_embedding is None: 82 | return None, current_frame_embedding,memory_prev_list 83 | else: 84 | bs,c,h,w = current_frame_embedding.size() 85 | tmp_dic, _ ,memory_cur_list= self.before_seghead_process( 86 | memory_prev_list, 87 | ref_embeddings, 88 | prev_embedding, 89 | current_frame_embedding, 90 | ref_masks, 91 | prev_mask, 92 | gt_ids, 93 | current_low_level=current_low_level, 94 | tf_board=False) 95 | all_pred = [] 96 | for i in range(bs): 97 | pred = tmp_dic[i] 98 | pred = nn.functional.interpolate(pred, size=(pred_size[0],pred_size[1]), mode='bilinear',align_corners=True) 99 | all_pred.append(pred) 100 | all_pred = torch.cat(all_pred, dim=0) 101 | all_pred = torch.softmax(all_pred, dim=1) 102 | return all_pred, current_frame_embedding, memory_cur_list 103 | 104 | def extract_feature(self, x): 105 | x, low_level=self.feature_extracter(x) 106 | x = self.semantic_embedding(x) 107 | return x, low_level 108 | 109 | def before_seghead_process(self, memory_prev_list=None, 110 | ref_frame_embedding=None, previous_frame_embedding=None, current_frame_embedding=None, 111 | ref_frame_label=None, previous_frame_mask=None, 112 | gt_ids=None, current_low_level=None, tf_board=False): 113 | 114 | cfg = self.cfg 115 | 116 | dic_tmp=[] 117 | bs,c,h,w = current_frame_embedding.size() 118 | 119 | if self.training: 120 | scale_ref_frame_label = torch.nn.functional.interpolate(ref_frame_label.float(),size=(h,w),mode='nearest') 121 | scale_ref_frame_label = scale_ref_frame_label.int() 122 | else: 123 | scale_ref_frame_labels = [] 124 | for each_ref_frame_label in ref_frame_label: 125 | 
each_scale_ref_frame_label = torch.nn.functional.interpolate(each_ref_frame_label.float(),size=(h,w),mode='nearest') 126 | each_scale_ref_frame_label = each_scale_ref_frame_label.int() 127 | scale_ref_frame_labels.append(each_scale_ref_frame_label) 128 | 129 | scale_previous_frame_label=torch.nn.functional.interpolate(previous_frame_mask.float(),size=(h,w),mode='nearest') 130 | scale_previous_frame_label=scale_previous_frame_label.int() 131 | 132 | boards = {'image': {}, 'scalar': {}} 133 | memory_cur_list = [] 134 | 135 | for n in range(bs): 136 | ref_obj_ids = torch.arange(0, gt_ids[n] + 1, device=current_frame_embedding.device).int().view(-1, 1, 1, 1) 137 | obj_num = ref_obj_ids.size(0) 138 | if gt_ids[n] > 0: 139 | dis_bias = torch.cat([self.bg_bias, self.fg_bias.expand(gt_ids[n], -1, -1, -1)], dim=0) 140 | else: 141 | dis_bias = self.bg_bias 142 | 143 | seq_current_frame_embedding = current_frame_embedding[n] 144 | seq_current_frame_embedding = seq_current_frame_embedding.permute(1,2,0) 145 | 146 | if self.training: 147 | seq_ref_frame_embedding = ref_frame_embedding[n] 148 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 149 | 150 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 151 | to_cat_ref_frame = seq_ref_frame_label 152 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 153 | 154 | global_matching_fg = global_matching( 155 | reference_embeddings=seq_ref_frame_embedding, 156 | query_embeddings=seq_current_frame_embedding, 157 | reference_labels=seq_ref_frame_label, 158 | n_chunks=cfg.TRAIN_GLOBAL_CHUNKS, 159 | dis_bias=dis_bias, 160 | atrous_rate=cfg.TRAIN_GLOBAL_ATROUS_RATE, 161 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 162 | else: 163 | all_reference_embeddings = [] 164 | all_reference_labels = [] 165 | seq_ref_frame_labels = [] 166 | for idx in range(len(scale_ref_frame_labels)): 167 | each_ref_frame_embedding = ref_frame_embedding[idx] 168 | scale_ref_frame_label = scale_ref_frame_labels[idx] 169 | 170 | seq_ref_frame_embedding = each_ref_frame_embedding[n] 171 | seq_ref_frame_embedding = seq_ref_frame_embedding.permute(1,2,0) 172 | all_reference_embeddings.append(seq_ref_frame_embedding) 173 | 174 | seq_ref_frame_label = (scale_ref_frame_label[n].int() == ref_obj_ids).float() 175 | seq_ref_frame_labels.append(seq_ref_frame_label) 176 | seq_ref_frame_label = seq_ref_frame_label.squeeze(1).permute(1,2,0) 177 | all_reference_labels.append(seq_ref_frame_label) 178 | global_matching_fg = global_matching_for_eval( 179 | all_reference_embeddings=all_reference_embeddings, 180 | query_embeddings=seq_current_frame_embedding, 181 | all_reference_labels=all_reference_labels, 182 | n_chunks=cfg.TEST_GLOBAL_CHUNKS, 183 | dis_bias=dis_bias, 184 | atrous_rate=cfg.TEST_GLOBAL_ATROUS_RATE, 185 | use_float16=cfg.MODEL_FLOAT16_MATCHING) 186 | 187 | seq_prev_frame_embedding = previous_frame_embedding[n] 188 | seq_prev_frame_embedding = seq_prev_frame_embedding.permute(1,2,0) 189 | seq_previous_frame_label = (scale_previous_frame_label[n].int() == ref_obj_ids).float() 190 | to_cat_previous_frame = seq_previous_frame_label 191 | seq_previous_frame_label = seq_previous_frame_label.squeeze(1).permute(1,2,0) 192 | local_matching_fg = local_matching( 193 | prev_frame_embedding=seq_prev_frame_embedding, 194 | query_embedding=seq_current_frame_embedding, 195 | prev_frame_labels=seq_previous_frame_label, 196 | multi_local_distance=cfg.MODEL_MULTI_LOCAL_DISTANCE, 197 | dis_bias=dis_bias, 198 | use_float16=cfg.MODEL_FLOAT16_MATCHING, 199 
| atrous_rate=cfg.TRAIN_LOCAL_ATROUS_RATE if self.training else cfg.TEST_LOCAL_ATROUS_RATE, 200 | allow_downsample=cfg.MODEL_LOCAL_DOWNSAMPLE, 201 | allow_parallel=cfg.TRAIN_LOCAL_PARALLEL if self.training else cfg.TEST_LOCAL_PARALLEL) 202 | 203 | to_cat_current_frame_embedding = current_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)) 204 | to_cat_prev_frame_embedding = previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)) 205 | 206 | to_cat_global_matching_fg = global_matching_fg.squeeze(0).permute(2,3,0,1) 207 | to_cat_local_matching_fg = local_matching_fg.squeeze(0).permute(2,3,0,1) 208 | 209 | if cfg.MODEL_MATCHING_BACKGROUND: 210 | to_cat_global_matching_bg = foreground2background(to_cat_global_matching_fg, gt_ids[n] + 1) 211 | reshaped_prev_nn_feature_n = to_cat_local_matching_fg.permute(0, 2, 3, 1).unsqueeze(1) 212 | to_cat_local_matching_bg = foreground2background(reshaped_prev_nn_feature_n, gt_ids[n] + 1) 213 | to_cat_local_matching_bg = to_cat_local_matching_bg.permute(0, 4, 2, 3, 1).squeeze(-1) 214 | 215 | pre_to_cat = torch.cat((to_cat_global_matching_fg, to_cat_local_matching_fg, to_cat_previous_frame), 1) 216 | 217 | if cfg.MODEL_MATCHING_BACKGROUND: 218 | pre_to_cat = torch.cat([pre_to_cat, to_cat_local_matching_bg, to_cat_global_matching_bg], 1) 219 | 220 | pre_to_cat = self.dynamic_prehead(pre_to_cat) 221 | 222 | to_cat = torch.cat((to_cat_current_frame_embedding, pre_to_cat),1) 223 | if self.training: 224 | attention_head = calculate_attention_head( 225 | ref_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 226 | to_cat_ref_frame, 227 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 228 | to_cat_previous_frame, 229 | epsilon=self.epsilon) 230 | else: 231 | attention_head = calculate_attention_head_for_eval( 232 | ref_frame_embedding, 233 | seq_ref_frame_labels, 234 | previous_frame_embedding[n].unsqueeze(0).expand((obj_num,-1,-1,-1)), 235 | to_cat_previous_frame, 236 | epsilon=self.epsilon) 237 | 238 | low_level_feat = current_low_level[n].unsqueeze(0) 239 | 240 | pred,memory_tmp_list = self.dynamic_seghead(to_cat, attention_head, memory_prev_list[n], low_level_feat,to_cat_previous_frame) 241 | memory_cur_list.append(memory_tmp_list) 242 | dic_tmp.append(pred) 243 | 244 | 245 | return dic_tmp, boards, memory_cur_list 246 | 247 | def get_module(): 248 | return RPCM 249 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pandas==0.23.4 2 | numpy==1.16.3 3 | torchvision==0.2.1 4 | torch==1.4.0 5 | opencv_python==3.4.2.17 6 | #blocks==0.1.1 7 | matplotlib==3.3.3 8 | Pillow==8.1.0 9 | seaborn==0.11.1 10 | tensorboardX==2.1 11 | -------------------------------------------------------------------------------- /scripts/ytb_eval_with_RPA.sh: -------------------------------------------------------------------------------- 1 | config="configs.resnet101_rpcm_ytb_stage_1" 2 | 3 | # using both propagation-correction modulator and reliable proxy augmentation 4 | 5 | # eval YTB19 6 | datasets="youtubevos19" 7 | python ../tools/eval_rpa.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 4 --gpu_id 0 8 | 9 | # eval YTB19 10 | datasets="youtubevos18" 11 | python ../tools/eval_rpa.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 4 --gpu_id 0 12 | -------------------------------------------------------------------------------- 
/scripts/ytb_eval_without_RPA.sh: -------------------------------------------------------------------------------- 1 | config="configs.resnet101_rpcm_ytb_stage_1" 2 | 3 | # only use propagation-correction modulator 4 | 5 | # eval on YTB18 6 | datasets="youtubevos18" 7 | python ../tools/eval.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 4 --gpu_id 1 8 | 9 | # eval on YTB19 10 | datasets="youtubevos19" 11 | python ../tools/eval.py --config ${config} --dataset ${datasets} --ckpt_step 400000 --global_chunks 4 --gpu_id 1 12 | -------------------------------------------------------------------------------- /scripts/ytb_train.sh: -------------------------------------------------------------------------------- 1 | config="configs.resnet101_rpcm_ytb_stage_1" 2 | datasets="youtubevos" 3 | # first stage training 4 | python ../tools/train.py --config ${config} --datasets ${datasets} --global_chunks 1 5 | 6 | # second stage training 7 | config="configs.resnet101_rpcm_ytb_stage_2" 8 | python ../tools/train.py --config ${config} --datasets ${datasets} --global_chunks 1 9 | 10 | 11 | -------------------------------------------------------------------------------- /tools/eval.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from networks.engine.eval_manager import Evaluator 5 | import importlib 6 | 7 | def main(): 8 | import argparse 9 | parser = argparse.ArgumentParser(description="Eval RPCM") 10 | parser.add_argument('--exp_name', type=str, default='') 11 | 12 | parser.add_argument('--config', type=str, default='') 13 | 14 | parser.add_argument('--gpu_id', type=int, default=0) 15 | 16 | parser.add_argument('--ckpt_path', type=str, default='') 17 | parser.add_argument('--ckpt_step', type=int, default=-1) 18 | 19 | parser.add_argument('--dataset', type=str, default='') 20 | 21 | parser.add_argument('--flip', action='store_true') 22 | parser.set_defaults(flip=False) 23 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 24 | parser.add_argument('--max_long_edge', type=int, default=-1) 25 | 26 | parser.add_argument('--float16', action='store_true') 27 | parser.set_defaults(float16=False) 28 | parser.add_argument('--global_atrous_rate', type=int, default=1) 29 | parser.add_argument('--global_chunks', type=int, default=4) 30 | parser.add_argument('--min_matching_pixels', type=int, default=0) 31 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false') 32 | parser.set_defaults(local_parallel=True) 33 | args = parser.parse_args() 34 | 35 | config = importlib.import_module(args.config) 36 | cfg = config.cfg 37 | 38 | cfg.TEST_GPU_ID = args.gpu_id 39 | if args.exp_name != '': 40 | cfg.EXP_NAME = args.exp_name 41 | 42 | if args.ckpt_path != '': 43 | cfg.TEST_CKPT_PATH = args.ckpt_path 44 | if args.ckpt_step > 0: 45 | cfg.TEST_CKPT_STEP = args.ckpt_step 46 | 47 | if args.dataset != '': 48 | cfg.TEST_DATASET = args.dataset 49 | 50 | cfg.TEST_FLIP = args.flip 51 | cfg.TEST_MULTISCALE = args.ms 52 | if args.max_long_edge > 0: 53 | cfg.TEST_MAX_SIZE = args.max_long_edge 54 | else: 55 | cfg.TEST_MAX_SIZE = 800 * 1.3 if cfg.TEST_MULTISCALE == [1.] 
else 800 56 | 57 | cfg.MODEL_FLOAT16_MATCHING = args.float16 58 | if 'RPCMp' in cfg.MODEL_MODULE: 59 | cfg.TEST_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1] 60 | else: 61 | cfg.TEST_GLOBAL_ATROUS_RATE = args.global_atrous_rate 62 | cfg.TEST_GLOBAL_CHUNKS = args.global_chunks 63 | cfg.TEST_LOCAL_PARALLEL = args.local_parallel 64 | 65 | if args.min_matching_pixels > 0: 66 | cfg.TEST_GLOBAL_MATCHING_MIN_PIXEL = args.min_matching_pixels 67 | 68 | evaluator = Evaluator(cfg=cfg) 69 | evaluator.evaluating() 70 | 71 | if __name__ == '__main__': 72 | main() 73 | 74 | -------------------------------------------------------------------------------- /tools/eval_rpa.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from networks.engine.eval_manager_rpa import Evaluator 5 | import importlib 6 | 7 | def main(): 8 | import argparse 9 | parser = argparse.ArgumentParser(description="Eval RPCM") 10 | parser.add_argument('--exp_name', type=str, default='') 11 | 12 | parser.add_argument('--config', type=str, default='') 13 | 14 | parser.add_argument('--gpu_id', type=int, default=0) 15 | 16 | parser.add_argument('--ckpt_path', type=str, default='') 17 | parser.add_argument('--ckpt_step', type=int, default=-1) 18 | 19 | parser.add_argument('--dataset', type=str, default='') 20 | 21 | parser.add_argument('--flip', action='store_true') 22 | parser.set_defaults(flip=False) 23 | parser.add_argument('--ms', nargs='+', type=float, default=[1.]) 24 | parser.add_argument('--max_long_edge', type=int, default=-1) 25 | parser.add_argument('--mem_every', type=int, default=5) 26 | parser.add_argument('--ucr', type=float, default=1.0) 27 | parser.add_argument('--float16', action='store_true') 28 | parser.add_argument('--vis', action='store_true') 29 | parser.set_defaults(float16=False) 30 | parser.add_argument('--global_atrous_rate', type=int, default=1) 31 | parser.add_argument('--global_chunks', type=int, default=4) 32 | parser.add_argument('--min_matching_pixels', type=int, default=0) 33 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false') 34 | parser.set_defaults(local_parallel=True) 35 | args = parser.parse_args() 36 | 37 | config = importlib.import_module(args.config) 38 | cfg = config.cfg 39 | 40 | cfg.TEST_GPU_ID = args.gpu_id 41 | if args.exp_name != '': 42 | cfg.EXP_NAME = args.exp_name 43 | if args.mem_every > 0: 44 | cfg.MEM_EVERY = args.mem_every 45 | if args.ucr > 0: 46 | cfg.UNC_RATIO = args.ucr 47 | if args.ckpt_path != '': 48 | cfg.TEST_CKPT_PATH = args.ckpt_path 49 | if args.ckpt_step > 0: 50 | cfg.TEST_CKPT_STEP = args.ckpt_step 51 | if args.dataset != '': 52 | cfg.TEST_DATASET = args.dataset 53 | 54 | cfg.UNC_VIS = args.vis 55 | 56 | cfg.TEST_FLIP = args.flip 57 | cfg.TEST_MULTISCALE = args.ms 58 | if args.max_long_edge > 0: 59 | cfg.TEST_MAX_SIZE = args.max_long_edge 60 | else: 61 | cfg.TEST_MAX_SIZE = 800 * 1.3 if cfg.TEST_MULTISCALE == [1.]
else 800 62 | 63 | cfg.MODEL_FLOAT16_MATCHING = args.float16 64 | if 'RPCMp' in cfg.MODEL_MODULE: 65 | cfg.TEST_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1] 66 | else: 67 | cfg.TEST_GLOBAL_ATROUS_RATE = args.global_atrous_rate 68 | cfg.TEST_GLOBAL_CHUNKS = args.global_chunks 69 | cfg.TEST_LOCAL_PARALLEL = args.local_parallel 70 | 71 | if args.min_matching_pixels > 0: 72 | cfg.TEST_GLOBAL_MATCHING_MIN_PIXEL = args.min_matching_pixels 73 | 74 | evaluator = Evaluator(cfg=cfg) 75 | evaluator.evaluating() 76 | 77 | if __name__ == '__main__': 78 | main() 79 | 80 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.append('.') 3 | sys.path.append('..') 4 | from networks.engine.train_manager import Trainer 5 | import torch.multiprocessing as mp 6 | import importlib 7 | 8 | def main_worker(gpu, cfg): 9 | # Initiate a training manager 10 | trainer = Trainer(rank=gpu, cfg=cfg) 11 | # Start Training 12 | trainer.sequential_training() 13 | 14 | def main(): 15 | import argparse 16 | parser = argparse.ArgumentParser(description="Train RPCM") 17 | parser.add_argument('--exp_name', type=str, default='') 18 | parser.add_argument('--config', type=str, default='') 19 | 20 | parser.add_argument('--start_gpu', type=int, default=0) 21 | parser.add_argument('--gpu_num', type=int, default=-1) 22 | parser.add_argument('--batch_size', type=int, default=-1) 23 | 24 | parser.add_argument('--pretrained_path', type=str, default='') 25 | 26 | parser.add_argument('--datasets', nargs='+', type=str, default=['youtubevos']) 27 | parser.add_argument('--lr', type=float, default=-1.) 28 | parser.add_argument('--total_step', type=int, default=-1.) 29 | parser.add_argument('--start_step', type=int, default=-1.) 
30 | 31 | parser.add_argument('--float16', action='store_true') 32 | parser.set_defaults(float16=False) 33 | parser.add_argument('--global_atrous_rate', type=int, default=1) 34 | parser.add_argument('--global_chunks', type=int, default=20) 35 | parser.add_argument('--no_local_parallel', dest='local_parallel', action='store_false') 36 | parser.set_defaults(local_parallel=True) 37 | args = parser.parse_args() 38 | 39 | config = importlib.import_module(args.config) 40 | cfg = config.cfg 41 | 42 | if args.exp_name != '': 43 | cfg.EXP_NAME = args.exp_name 44 | 45 | cfg.DIST_START_GPU = args.start_gpu 46 | if args.gpu_num > 0: 47 | cfg.TRAIN_GPUS = args.gpu_num 48 | if args.batch_size > 0: 49 | cfg.TRAIN_BATCH_SIZE = args.batch_size 50 | 51 | if args.pretrained_path != '': 52 | cfg.PRETRAIN_MODEL = args.pretrained_path 53 | 54 | if args.lr > 0: 55 | cfg.TRAIN_LR = args.lr 56 | if args.total_step > 0: 57 | cfg.TRAIN_TOTAL_STEPS = args.total_step 58 | cfg.TRAIN_START_SEQ_TRAINING_STEPS = int(args.total_step / 2) 59 | cfg.TRAIN_HARD_MINING_STEP = int(args.total_step / 2) 60 | if args.start_step > 0: 61 | cfg.TRAIN_START_STEP = args.start_step 62 | 63 | cfg.MODEL_FLOAT16_MATCHING = args.float16 64 | if 'RPCMp' in cfg.MODEL_MODULE: 65 | cfg.TRAIN_GLOBAL_ATROUS_RATE = [args.global_atrous_rate, 1, 1] 66 | else: 67 | cfg.TRAIN_GLOBAL_ATROUS_RATE = args.global_atrous_rate 68 | cfg.TRAIN_GLOBAL_CHUNKS = args.global_chunks 69 | cfg.TRAIN_LOCAL_PARALLEL = args.local_parallel 70 | 71 | # Use torch.multiprocessing.spawn to launch distributed processes 72 | mp.spawn(main_worker, nprocs=cfg.TRAIN_GPUS, args=(cfg,)) 73 | 74 | if __name__ == '__main__': 75 | main() 76 | 77 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/checkpoint.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__pycache__/checkpoint.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/eval.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__pycache__/eval.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__pycache__/image.cpython-36.pyc -------------------------------------------------------------------------------- /utils/__pycache__/meters.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/JerryX1110/RPCMVOS/8be9e7dbc0b8bab9b6aceb4f9ff862596b1da6c8/utils/__pycache__/meters.cpython-36.pyc -------------------------------------------------------------------------------- /utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import numpy as np 4 | 5 | def load_network_and_optimizer(net, opt, pretrained_dir, gpu): 6 | pretrained = torch.load( 7 | pretrained_dir, 8 | map_location=torch.device("cuda:"+str(gpu))) 9 | pretrained_dict = pretrained['state_dict'] 10 | model_dict = net.state_dict() 11 | pretrained_dict_update = {} 12 | pretrained_dict_remove = [] 13 | for k, v in pretrained_dict.items(): 14 | if k in model_dict: 15 | pretrained_dict_update[k] = v 16 | elif k[:7] == 'module.': 17 | if k[7:] in model_dict: 18 | pretrained_dict_update[k[7:]] = v 19 | else: 20 | pretrained_dict_remove.append(k) 21 | model_dict.update(pretrained_dict_update) 22 | net.load_state_dict(model_dict) 23 | opt.load_state_dict(pretrained['optimizer']) 24 | del(pretrained) 25 | return net.cuda(gpu), opt, pretrained_dict_remove 26 | 27 | def load_network_and_not_optimizer(net, opt, pretrained_dir, gpu): 28 | pretrained = torch.load( 29 | pretrained_dir, 30 | map_location=torch.device("cuda:"+str(gpu))) 31 | pretrained_dict = pretrained['state_dict'] 32 | model_dict = net.state_dict() 33 | pretrained_dict_update = {} 34 | pretrained_dict_remove = [] 35 | for k, v in pretrained_dict.items(): 36 | if k in model_dict: 37 | pretrained_dict_update[k] = v 38 | elif k[:7] == 'module.': 39 | if k[7:] in model_dict: 40 | pretrained_dict_update[k[7:]] = v 41 | else: 42 | pretrained_dict_remove.append(k) 43 | model_dict.update(pretrained_dict_update) 44 | net.load_state_dict(model_dict) 45 | del(pretrained) 46 | return net.cuda(gpu), opt, pretrained_dict_remove 47 | 48 | def load_network(net, pretrained_dir, gpu): 49 | pretrained = torch.load( 50 | pretrained_dir, 51 | map_location=torch.device("cuda:"+str(gpu))) 52 | pretrained_dict = pretrained['state_dict'] 53 | model_dict = net.state_dict() 54 | pretrained_dict_update = {} 55 | pretrained_dict_remove = [] 56 | for k, v in pretrained_dict.items(): 57 | if k in model_dict: 58 | pretrained_dict_update[k] = v 59 | elif k[:7] == 'module.': 60 | if k[7:] in model_dict: 61 | pretrained_dict_update[k[7:]] = v 62 | else: 63 | pretrained_dict_remove.append(k) 64 | model_dict.update(pretrained_dict_update) 65 | net.load_state_dict(model_dict) 66 | del(pretrained) 67 | return net.cuda(gpu), pretrained_dict_remove 68 | 69 | def save_network(net, opt, step, save_path, max_keep=8): 70 | try: 71 | if not os.path.exists(save_path): 72 | os.makedirs(save_path) 73 | save_file = 'save_step_%s.pth' % (step) 74 | save_dir = os.path.join(save_path, save_file) 75 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir) 76 | except: 77 | save_path = './saved_models' 78 | if not os.path.exists(save_path): 79 | os.makedirs(save_path) 80 | save_file = 'save_step_%s.pth' % (step) 81 | save_dir = os.path.join(save_path, save_file) 82 | torch.save({'state_dict': net.state_dict(), 'optimizer': opt.state_dict()}, save_dir) 83 | 84 | all_ckpt = os.listdir(save_path) 85 | if len(all_ckpt) > max_keep: 86 | all_step = [] 87 | for ckpt_name in all_ckpt: 88 | step = int(ckpt_name.split('_')[-1].split('.')[0]) 89 | all_step.append(step) 90 | all_step = list(np.sort(all_step))[:-max_keep] 91 | for step in all_step: 92 | ckpt_path = os.path.join(save_path, 
'save_step_%s.pth' % (step)) 93 | os.system('rm {}'.format(ckpt_path)) 94 | -------------------------------------------------------------------------------- /utils/eval.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import zipfile 3 | import os 4 | 5 | def zip_folder(source_folder, zip_dir): 6 | f = zipfile.ZipFile(zip_dir, 'w', zipfile.ZIP_DEFLATED) 7 | pre_len = len(os.path.dirname(source_folder)) 8 | for dirpath, dirnames, filenames in os.walk(source_folder): 9 | for filename in filenames: 10 | pathfile = os.path.join(dirpath, filename) 11 | arcname = pathfile[pre_len:].strip(os.path.sep) 12 | f.write(pathfile, arcname) 13 | f.close() -------------------------------------------------------------------------------- /utils/image.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import torch 4 | 5 | ## for visulization 6 | import cv2 7 | import time 8 | import gc 9 | import matplotlib.pyplot as plt 10 | import seaborn as sns 11 | 12 | plt.rcParams['font.sans-serif']=['SimHei'] 13 | plt.rcParams['axes.unicode_minus'] = False 14 | ## 15 | 16 | _palette = [0, 0, 0, 128, 0, 0, 0, 128, 0, 128, 128, 0, 0, 0, 128, 128, 0, 128, 0, 128, 128, 128, 128, 128, 64, 0, 0, 191, 0, 0, 64, 128, 0, 191, 128, 0, 64, 0, 128, 191, 0, 128, 64, 128, 128, 191, 128, 128, 0, 64, 0, 128, 64, 0, 0, 191, 0, 128, 191, 0, 0, 64, 128, 128, 64, 128, 22, 22, 22, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 29, 29, 29, 30, 30, 30, 31, 31, 31, 32, 32, 32, 33, 33, 33, 34, 34, 34, 35, 35, 35, 36, 36, 36, 37, 37, 37, 38, 38, 38, 39, 39, 39, 40, 40, 40, 41, 41, 41, 42, 42, 42, 43, 43, 43, 44, 44, 44, 45, 45, 45, 46, 46, 46, 47, 47, 47, 48, 48, 48, 49, 49, 49, 50, 50, 50, 51, 51, 51, 52, 52, 52, 53, 53, 53, 54, 54, 54, 55, 55, 55, 56, 56, 56, 57, 57, 57, 58, 58, 58, 59, 59, 59, 60, 60, 60, 61, 61, 61, 62, 62, 62, 63, 63, 63, 64, 64, 64, 65, 65, 65, 66, 66, 66, 67, 67, 67, 68, 68, 68, 69, 69, 69, 70, 70, 70, 71, 71, 71, 72, 72, 72, 73, 73, 73, 74, 74, 74, 75, 75, 75, 76, 76, 76, 77, 77, 77, 78, 78, 78, 79, 79, 79, 80, 80, 80, 81, 81, 81, 82, 82, 82, 83, 83, 83, 84, 84, 84, 85, 85, 85, 86, 86, 86, 87, 87, 87, 88, 88, 88, 89, 89, 89, 90, 90, 90, 91, 91, 91, 92, 92, 92, 93, 93, 93, 94, 94, 94, 95, 95, 95, 96, 96, 96, 97, 97, 97, 98, 98, 98, 99, 99, 99, 100, 100, 100, 101, 101, 101, 102, 102, 102, 103, 103, 103, 104, 104, 104, 105, 105, 105, 106, 106, 106, 107, 107, 107, 108, 108, 108, 109, 109, 109, 110, 110, 110, 111, 111, 111, 112, 112, 112, 113, 113, 113, 114, 114, 114, 115, 115, 115, 116, 116, 116, 117, 117, 117, 118, 118, 118, 119, 119, 119, 120, 120, 120, 121, 121, 121, 122, 122, 122, 123, 123, 123, 124, 124, 124, 125, 125, 125, 126, 126, 126, 127, 127, 127, 128, 128, 128, 129, 129, 129, 130, 130, 130, 131, 131, 131, 132, 132, 132, 133, 133, 133, 134, 134, 134, 135, 135, 135, 136, 136, 136, 137, 137, 137, 138, 138, 138, 139, 139, 139, 140, 140, 140, 141, 141, 141, 142, 142, 142, 143, 143, 143, 144, 144, 144, 145, 145, 145, 146, 146, 146, 147, 147, 147, 148, 148, 148, 149, 149, 149, 150, 150, 150, 151, 151, 151, 152, 152, 152, 153, 153, 153, 154, 154, 154, 155, 155, 155, 156, 156, 156, 157, 157, 157, 158, 158, 158, 159, 159, 159, 160, 160, 160, 161, 161, 161, 162, 162, 162, 163, 163, 163, 164, 164, 164, 165, 165, 165, 166, 166, 166, 167, 167, 167, 168, 168, 168, 169, 169, 169, 170, 170, 170, 171, 171, 171, 172, 172, 172, 173, 173, 173, 174, 174, 174, 175, 
175, 175, 176, 176, 176, 177, 177, 177, 178, 178, 178, 179, 179, 179, 180, 180, 180, 181, 181, 181, 182, 182, 182, 183, 183, 183, 184, 184, 184, 185, 185, 185, 186, 186, 186, 187, 187, 187, 188, 188, 188, 189, 189, 189, 190, 190, 190, 191, 191, 191, 192, 192, 192, 193, 193, 193, 194, 194, 194, 195, 195, 195, 196, 196, 196, 197, 197, 197, 198, 198, 198, 199, 199, 199, 200, 200, 200, 201, 201, 201, 202, 202, 202, 203, 203, 203, 204, 204, 204, 205, 205, 205, 206, 206, 206, 207, 207, 207, 208, 208, 208, 209, 209, 209, 210, 210, 210, 211, 211, 211, 212, 212, 212, 213, 213, 213, 214, 214, 214, 215, 215, 215, 216, 216, 216, 217, 217, 217, 218, 218, 218, 219, 219, 219, 220, 220, 220, 221, 221, 221, 222, 222, 222, 223, 223, 223, 224, 224, 224, 225, 225, 225, 226, 226, 226, 227, 227, 227, 228, 228, 228, 229, 229, 229, 230, 230, 230, 231, 231, 231, 232, 232, 232, 233, 233, 233, 234, 234, 234, 235, 235, 235, 236, 236, 236, 237, 237, 237, 238, 238, 238, 239, 239, 239, 240, 240, 240, 241, 241, 241, 242, 242, 242, 243, 243, 243, 244, 244, 244, 245, 245, 245, 246, 246, 246, 247, 247, 247, 248, 248, 248, 249, 249, 249, 250, 250, 250, 251, 251, 251, 252, 252, 252, 253, 253, 253, 254, 254, 254, 255, 255, 255] 17 | 18 | 19 | def label2colormap(label): 20 | 21 | m = label.astype(np.uint8) 22 | r,c = m.shape 23 | cmap = np.zeros((r,c,3), dtype=np.uint8) 24 | cmap[:,:,0] = (m&1)<<7 | (m&8)<<3 | (m&64)>>1 25 | cmap[:,:,1] = (m&2)<<6 | (m&16)<<2 | (m&128)>>2 26 | cmap[:,:,2] = (m&4)<<5 | (m&32)<<1 27 | return cmap 28 | 29 | def masked_image(image, colored_mask, mask, alpha = 0.7): 30 | mask = np.expand_dims(mask > 0, axis=0) 31 | mask = np.repeat(mask, 3, axis=0) 32 | show_img = (image * alpha + colored_mask * (1 - alpha)) * mask + image * (1 - mask) 33 | return show_img 34 | 35 | def save_image(image, path): 36 | im = Image.fromarray(np.uint8(image * 255.).transpose((1, 2, 0))) 37 | im.save(path) 38 | 39 | def save_mask(mask_tensor, path): 40 | mask = mask_tensor.cpu().numpy().astype('uint8') 41 | mask = Image.fromarray(mask).convert('P') 42 | mask.putpalette(_palette) 43 | mask.save(path) 44 | 45 | def flip_tensor(tensor, dim=0): 46 | inv_idx = torch.arange(tensor.size(dim) - 1, -1, -1, device=tensor.device).long() 47 | tensor = tensor.index_select(dim, inv_idx) 48 | return tensor 49 | 50 | -------------------------------------------------------------------------------- /utils/learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | 4 | def adjust_learning_rate(optimizer, base_lr, p, itr, max_itr, warm_up_steps=1000, is_cosine_decay=False, min_lr=1e-5): 5 | 6 | if itr < warm_up_steps: 7 | now_lr = base_lr * itr / warm_up_steps 8 | else: 9 | itr = itr - warm_up_steps 10 | max_itr = max_itr - warm_up_steps 11 | if is_cosine_decay: 12 | now_lr = base_lr * (math.cos(math.pi * itr / (max_itr + 1)) + 1.) * 0.5 13 | else: 14 | now_lr = base_lr * (1 - itr / (max_itr + 1)) ** p 15 | 16 | if now_lr < min_lr: 17 | now_lr = min_lr 18 | 19 | for param_group in optimizer.param_groups: 20 | param_group['lr'] = now_lr 21 | return now_lr 22 | 23 | 24 | def get_trainable_params(model, base_lr, weight_decay, beta_wd=True): 25 | params = [] 26 | for key, value in model.named_parameters(): 27 | if not value.requires_grad: 28 | continue 29 | wd = weight_decay 30 | if 'beta' in key: 31 | if not beta_wd: 32 | wd = 0. 
33 | params += [{"params": [value], "lr": base_lr, "weight_decay": wd}] 34 | return params -------------------------------------------------------------------------------- /utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /utils/metric.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | def pytorch_iou(pred, target, obj_num, epsilon=1e-6): 4 | ''' 5 | pred: [bs, h, w] 6 | target: [bs, h, w] 7 | obj_num: [bs] 8 | ''' 9 | bs = obj_num.size(0) 10 | all_iou = [] 11 | for idx in range(bs): 12 | now_pred = pred[idx].unsqueeze(0) 13 | now_target = target[idx].unsqueeze(0) 14 | now_obj_num = obj_num[idx] 15 | 16 | obj_ids = torch.arange(0, now_obj_num + 1, device=now_pred.device).int().view(-1, 1, 1) 17 | if obj_ids.size(0) == 1: # only contain background 18 | continue 19 | else: 20 | obj_ids = obj_ids[1:] 21 | now_pred = (now_pred == obj_ids).float() 22 | now_target = (now_target == obj_ids).float() 23 | 24 | intersection = (now_pred * now_target).sum((1, 2)) 25 | union = ((now_pred + now_target) > 0).float().sum((1, 2)) 26 | 27 | now_iou = (intersection + epsilon) / (union + epsilon) 28 | 29 | all_iou.append(now_iou.mean()) 30 | if len(all_iou) > 0: 31 | all_iou = torch.stack(all_iou).mean() 32 | else: 33 | all_iou = torch.ones((1), device=pred.device) 34 | return all_iou 35 | --------------------------------------------------------------------------------
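For reference, the snippet below is a minimal, self-contained sketch of how the helpers in utils/metric.py, utils/meters.py and utils/learning.py can be combined in a validation loop. The tensor shapes follow the pytorch_iou docstring; the dummy tensors and the SGD optimizer are illustrative placeholders, not part of the repository, and the imports assume the repository root is on PYTHONPATH.

import torch

from utils.learning import adjust_learning_rate
from utils.meters import AverageMeter
from utils.metric import pytorch_iou

# Dummy label maps of shape [bs, h, w]; 0 is background, 1..obj_num are objects.
pred = torch.randint(0, 3, (2, 4, 4))
target = torch.randint(0, 3, (2, 4, 4))
obj_num = torch.tensor([2, 2])  # number of foreground objects per sample

iou_meter = AverageMeter()
iou = pytorch_iou(pred, target, obj_num)  # mean IoU over foreground objects
iou_meter.update(iou.item(), n=pred.size(0))
print('running mean IoU:', iou_meter.avg)

# Warm-up followed by cosine decay, as implemented in adjust_learning_rate.
dummy_param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.SGD([dummy_param], lr=0.01)
now_lr = adjust_learning_rate(optimizer, base_lr=0.01, p=0.9, itr=500,
                              max_itr=100000, warm_up_steps=1000, is_cosine_decay=True)
print('current lr:', now_lr)  # 0.005 here, since itr is still inside the warm-up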