├── .gitignore ├── LICENSE ├── README.md ├── config └── datasets.yaml ├── data └── datasets │ └── .gitkeep ├── docs ├── _config.yml ├── bibtex.txt ├── comp_video_sota.png ├── index.md ├── motorcross-jump.gif ├── pseudo_label_example.png ├── pseudo_label_generator.png ├── static_model.png ├── training_instruction.md └── video_model.png ├── flownet2 ├── .gitignore ├── Dockerfile ├── LICENSE ├── README.md ├── __init__.py ├── convert.py ├── datasets.py ├── download_caffe_models.sh ├── image.png ├── install.sh ├── launch_docker.sh ├── losses.py ├── main.py ├── models.py ├── networks │ ├── FlowNetC.py │ ├── FlowNetFusion.py │ ├── FlowNetS.py │ ├── FlowNetSD.py │ ├── __init__.py │ ├── channelnorm_package │ │ ├── __init__.py │ │ ├── channelnorm.py │ │ ├── channelnorm_cuda.cc │ │ ├── channelnorm_kernel.cu │ │ ├── channelnorm_kernel.cuh │ │ └── setup.py │ ├── correlation_package │ │ ├── __init__.py │ │ ├── correlation.py │ │ ├── correlation_cuda.cc │ │ ├── correlation_cuda_kernel.cu │ │ ├── correlation_cuda_kernel.cuh │ │ └── setup.py │ ├── resample2d_package │ │ ├── __init__.py │ │ ├── resample2d.py │ │ ├── resample2d_cuda.cc │ │ ├── resample2d_kernel.cu │ │ ├── resample2d_kernel.cuh │ │ └── setup.py │ └── submodules.py ├── run-caffe2pytorch.sh └── utils │ ├── __init__.py │ ├── flow_utils.py │ ├── frame_utils.py │ ├── param_utils.py │ └── tools.py ├── generate_pseudo_labels.py ├── inference.py ├── libs ├── __init__.py ├── datasets │ ├── __init__.py │ ├── transforms.py │ └── video_datasets.py ├── modules │ ├── __init__.py │ ├── convgru.py │ └── non_local_dot_product.py ├── networks │ ├── __init__.py │ ├── models.py │ ├── pseudo_label_generator.py │ ├── rcrnet.py │ └── resnet_dilation.py └── utils │ ├── __init__.py │ ├── logger.py │ ├── metric.py │ └── pyt_utils.py ├── models ├── .gitkeep └── checkpoints │ └── .gitkeep ├── train.py └── train_fgplg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Mac 2 | .DS_Store 3 | # Editor 4 | .vscode/ 5 | # Data 6 | *.pth 7 | *.pth.tar 8 | data/results 9 | data/pseudo-labels 10 | 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | .hypothesis/ 58 | .pytest_cache/ 59 | 60 | # Translations 61 | *.mo 62 | *.pot 63 | 64 | # Django stuff: 65 | *.log 66 | local_settings.py 67 | db.sqlite3 68 | 69 | # Flask stuff: 70 | instance/ 71 | .webassets-cache 72 | 73 | # Scrapy stuff: 74 | .scrapy 75 | 76 | # Sphinx documentation 77 | docs/_build/ 78 | 79 | # PyBuilder 80 | target/ 81 | 82 | # Jupyter Notebook 83 | .ipynb_checkpoints 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Kinpzz 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # RCRNet-Pytorch 2 | 3 | 4 | This repository contains the PyTorch implementation for 5 | 6 | **Semi-Supervised Video Salient Object Detection Using Pseudo-Labels**
7 | Pengxiang Yan, Guanbin Li, Yuan Xie, Zhen Li, Chuan Wang, Tianshui Chen, Liang Lin
8 | **ICCV 2019** | [[Project Page](https://kinpzz.com/publication/iccv19_semi_vsod/)] | [[Arxiv](https://arxiv.org/abs/1908.04051)] | [[CVF-Open-Access](http://openaccess.thecvf.com/content_ICCV_2019/html/Yan_Semi-Supervised_Video_Salient_Object_Detection_Using_Pseudo-Labels_ICCV_2019_paper.html)]
9 | 
10 | ## Usage 
11 | 
12 | ### Requirements 
13 | 
14 | This code is tested on Ubuntu 16.04, Python=3.6 (via Anaconda3), PyTorch=0.4.1, CUDA=9.0. 
15 | 
16 | ``` 
17 | # Install PyTorch=0.4.1 
18 | $ conda install pytorch==0.4.1 torchvision==0.2.1 cuda90 -c pytorch 
19 | 
20 | # Install other packages 
21 | $ pip install pyyaml==3.13 addict==2.2.0 tqdm==4.28.1 scipy==1.1.0 
22 | ``` 
23 | ### Datasets 
24 | Our proposed RCRNet is evaluated on three public VSOD benchmark datasets: [VOS](http://cvteam.net/projects/TIP18-VOS/VOS.html), [DAVIS](https://davischallenge.org/) (version: 2016, 480p), and [FBMS](https://lmb.informatik.uni-freiburg.de/resources/datasets/). Please organize the datasets according to `config/datasets.yaml` and put them in `data/datasets`. Alternatively, you can set the argument `--data` to the path of your dataset folder. 
25 | 
26 | ### Evaluation 
27 | #### Comparison with State-of-the-Art 
28 | ![comp_video_sota](docs/comp_video_sota.png) 
29 | If you want to compare with our method: 
30 | 
31 | **Option 1:** You can download the saliency maps predicted by our model from [Google Drive](https://drive.google.com/open?id=1feY3GdNBS-LUBt0UDWwpA3fl9yHI4Vxr) / [Baidu Pan](https://pan.baidu.com/s/1oXBr9qxyF-8vvilvV5kcPg) (passwd: u079). 
32 | 
33 | **Option 2:** You can use our trained model for inference. The weights of the trained model are available at [Google Drive](https://drive.google.com/open?id=1TSmi1DyKIvuzuXE1aw7t_ygmcUmjYnN_) / [Baidu Pan](https://pan.baidu.com/s/1PLoajL6X_s29I-4mreSuSQ) (passwd: 6pi3). Then run the following commands for inference. 
34 | ``` 
35 | # VOS 
36 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset VOS --split test 
37 | 
38 | # DAVIS 
39 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset DAVIS --split val 
40 | 
41 | # FBMS 
42 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset FBMS --split test 
43 | ``` 
44 | 
45 | Then, you can evaluate the saliency maps using your own evaluation code (a minimal metric sketch is given at the end of this section). 
46 | 
47 | ### Training 
48 | If you want to train our proposed model from scratch (including using pseudo-labels), please refer to our paper and the [training instruction](docs/training_instruction.md) carefully. 
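For reference, below is a minimal sketch of two metrics commonly used for VSOD evaluation (MAE and F-measure with an adaptive threshold). This is not the evaluation code used in the paper, and the random arrays only stand in for a real prediction / ground-truth pair.
```
# Minimal metric sketch (not the official evaluation code).
# Assumes the predicted saliency map and the ground-truth mask are
# already loaded as float arrays scaled to [0, 1].
import numpy as np

def mae(pred, gt):
    # Mean absolute error between the saliency map and the binary ground truth
    return np.abs(pred - gt).mean()

def f_measure(pred, gt, beta2=0.3):
    # F-measure with the widely used adaptive threshold (2 x mean saliency value)
    thr = min(2 * pred.mean(), 1.0)
    binary = pred >= thr
    tp = np.logical_and(binary, gt > 0.5).sum()
    precision = tp / (binary.sum() + 1e-8)
    recall = tp / ((gt > 0.5).sum() + 1e-8)
    return (1 + beta2) * precision * recall / (beta2 * precision + recall + 1e-8)

# Random data standing in for one frame's prediction and annotation
pred = np.random.rand(480, 854)
gt = (np.random.rand(480, 854) > 0.5).astype(np.float32)
print(mae(pred, gt), f_measure(pred, gt))
```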
49 | 50 | ## Citation 51 | If you find this work helpful, please consider citing 52 | ``` 53 | @inproceedings{yan2019semi, 54 | title={Semi-Supervised Video Salient Object Detection Using Pseudo-Labels}, 55 | author={Yan, Pengxiang and Li, Guanbin and Xie, Yuan and Li, Zhen and Wang, Chuan and Chen, Tianshui and Lin, Liang}, 56 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 57 | pages={7284--7293}, 58 | year={2019} 59 | } 60 | ``` 61 | 62 | ## Acknowledge 63 | Thanks to the third-party libraries: 64 | * [deeplab-pytorch](https://github.com/kazuto1011/deeplab-pytorch) by kazuto1011 65 | * [flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch) by NVIDIA 66 | * [pytorch-segmentation-toolbox](https://github.com/speedinghzl/pytorch-segmentation-toolbox) by speedinghzl 67 | * [Non-local_pytorch](https://github.com/AlexHex7/Non-local_pytorch) by AlexHex7 68 | -------------------------------------------------------------------------------- /config/datasets.yaml: -------------------------------------------------------------------------------- 1 | # Video Saliency Dataset Config files 2 | # 3 | # Author: Pengxiang Yan 4 | # Email: yanpx (at) mail2.sysu.edu.cn 5 | 6 | # MSRA-B Dataset 7 | # url: https://mmcheng.net/msra10k/ 8 | MSRA-B: 9 | image_dir: imgs 10 | label_dir: gt 11 | split_dir: ImageSets 12 | image_ext: .jpg 13 | label_ext: .png 14 | # HKU-IS Dataset 15 | # url: https://i.cs.hku.hk/~gbli/deep_saliency.html 16 | 17 | HKU-IS: 18 | image_dir: imgs 19 | label_dir: gt 20 | split_dir: ImageSets 21 | image_ext: .png 22 | label_ext: .png 23 | 24 | # DAVIS 2016 Dataset: Densely Annotated VIdeo Segmentation 25 | # url: https://davischallenge.org/ 26 | DAVIS2016: 27 | image_dir: JPEGImages/480p 28 | label_dir: Annotations/480p 29 | image_ext: .jpg 30 | label_ext: .png 31 | default_label_interval: 1 # every "default_label_interval" image gives a label 32 | video_split: 33 | train: ['bear', 'bmx-bumps', 'boat', 'breakdance-flare', 34 | 'bus', 'car-turn', 'dance-jump', 'dog-agility', 35 | 'drift-turn', 'elephant', 'flamingo', 'hike', 36 | 'hockey', 'horsejump-low', 'kite-walk', 'lucia', 37 | 'mallard-fly', 'motocross-bumps', 'motorbike', 38 | 'paragliding', 'rhino', 'rollerblade', 'scooter-gray', 39 | 'soccerball', 'stroller', 'surf', 'swing', 'tennis', 40 | 'train'] 41 | val: ['blackswan', 'bmx-trees', 'breakdance', 'camel', 42 | 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl', 43 | 'dog', 'drift-chicane', 'drift-straight', 'goat', 44 | 'horsejump-high', 'kite-surf', 'libby', 'motocross-jump', 45 | 'paragliding-launch', 'parkour', 'scooter-black', 46 | 'soapbox'] 47 | 48 | # FBMS: Freiburg-Berkeley Motion Segmentation Dataset 49 | # url: https://lmb.informatik.uni-freiburg.de/resources/datasets/ 50 | FBMS: 51 | image_dir: JPEGImages 52 | label_dir: Annotations 53 | image_ext: .jpg 54 | label_ext: .png 55 | default_label_interval: 1 56 | video_split: 57 | train: ['bear01', 'bear02', 'cars2', 'cars3', 58 | 'cars6', 'cars7', 'cars8', 'cars9', 59 | 'cats02', 'cats04' ,'cats05', 'cats07', 60 | 'ducks01', 'horses01', 'horses03', 'horses06', 61 | 'lion02', 'marple1', 'marple10', 'marple11', 62 | 'marple13', 'marple3', 'marple5', 'marple8', 63 | 'meerkats01', 'people04', 'people05', 'rabbits01', 64 | 'rabbits05'] 65 | test: ['camel01', 'cars1', 'cars10', 'cars4', 66 | 'cars5', 'cats01', 'cats03', 'cats06', 67 | 'dogs01', 'dogs02', 'farm01', 'giraffes01', 68 | 'goats01', 'horses02', 'horses04', 'horses05', 69 | 'lion01', 'marple12', 'marple2', 'marple4', 70 | 
'marple6', 'marple7', 'marple9', 'people03', 71 | 'people1', 'people2', 'rabbits02', 'rabbits03', 72 | 'rabbits04', 'tennis'] 73 | 74 | # SegTrack v2 is a video segmentation dataset with full pixel-level annotations 75 | # on multiple objects at each frame within each video. 76 | # url: http://web.engr.oregonstate.edu/~lif/SegTrack2/dataset.html 77 | SegTrackv2: 78 | image_dir: JPEGImages 79 | label_dir: .png 80 | label_dir: .png 81 | default_label_interval: 1 82 | video_split: 83 | trainval: ['bird_of_paradise', 'birdfall', 'frog', 84 | 'monkey', 'parachute', 'soldier', 'worm'] 85 | 86 | # TIP18: A Benchmark Dataset and Saliency-Guided Stacked Autoencoders for Video-Based Salient Object Detection 87 | # url: http://cvteam.net/projects/TIP18-VOS/VOS.html 88 | VOS: 89 | image_dir: JPEGImages 90 | label_dir: Annotations 91 | image_ext: .jpg 92 | label_ext: .png 93 | default_label_interval: 15 94 | video_split: 95 | train: ['1', '102', '103', '104', '107', '11', '111', '112', '114', '115', 96 | '117', '118', '119', '12', '123', '124', '125', '126', '127', '130', 97 | '131', '143', '145', '146', '147', '15', '156', '164', '17', '171', 98 | '176', '192', '197', '198', '199', '2', '20', '200', '201', '202', 99 | '205', '206', '207', '212', '215', '216', '217', '22', '220', '221', 100 | '222', '225', '226', '229', '23', '230', '231', '233', '235', '236', 101 | '25', '250', '251', '252', '255', '256', '257', '258', '259', '26', 102 | '261', '262', '263', '265', '267', '268', '27', '270', '271', '272', 103 | '273', '275', '276', '30', '32', '33', '34', '35', '38', '4', '40', 104 | '44', '45', '46', '48', '50', '51', '52', '53', '55', '6', '61', '64', 105 | '66', '67', '69', '70', '71', '73', '78', '80', '81', '83', '87', '88', 106 | '9', '90', '96', '99'] 107 | val: ['10', '101', '105', '109', '113', '120', '13', '133', '14', '148', '158', 108 | '18', '180', '196', '203', '204', '208', '209', '213', '219', '223', '228', 109 | '24','260', '269', '28', '31', '37', '39', '5', '57', '62', '7', '72', '77', 110 | '84', '92','94', '95', '97'] 111 | test: ['100', '106', '108', '110', '121', '132', '134', '16', '172', '189', '19', 112 | '194', '195', '21', '210', '211', '214', '224', '227', '232', '254', '264', 113 | '266', '274', '29', '3', '36', '42', '43', '47', '49', '58', '65', '68', '74', 114 | '76', '8', '85', '93', '98'] 115 | easy: ['1', '10', '101', '11', '12', '13', '130', '131', '132', '133', '134', 116 | '14', '143', '15', '16', '17', '18', '19', '192', '194', '195', '196', '197', 117 | '198', '199', '2', '20', '200', '201', '202', '203', '204', '205', '206', '207', 118 | '208', '209', '21', '210', '211', '22', '23', '233', '24', '25', '254', '255', 119 | '256', '257', '258', '259', '26', '260', '261', '262', '263', '264', '265', '266', 120 | '267', '268', '269', '27', '270', '271', '272', '273', '274', '275', '276', '28', 121 | '29', '3', '30', '31', '32', '33', '34', '4', '42', '5', '50', '51', '6', '68', 122 | '7', '76', '78', '8', '88', '9', '90', '92', '94', '96', '98'] 123 | normal: ['100', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', 124 | '112', '113', '114', '115', '117', '118', '119', '120', '121', '123', '124', '125', 125 | '126', '127', '145', '146', '147', '148', '156', '158', '164', '171', '172', '176', 126 | '180', '189', '212', '213', '214', '215', '216', '217', '219', '220', '221', '222', 127 | '223', '224', '225', '226', '227', '228', '229', '230', '231', '232', '235', '236', 128 | '250', '251', '252', '35', '36', '37', '38', '39', '40', '43', '44', 
'45', '46', '47', 129 | '48', '49', '52', '53', '55', '57', '58', '61', '62', '64', '65', '66', '67', '69', 130 | '70', '71', '72', '73', '74', '77', '80', '81', '83', '84', '85', '87', '93', '95', 131 | '97', '99'] 132 | easy-train: ['206', '90', '268', '30', '26', '201', '131', '271', '255', '276', '270', '15', '25', 133 | '1', '50', '20', '51', '88', '96', '130', '197', '27', '4', '205', '199', '78', '261', 134 | '2', '207', '198', '272', '192', '202', '258', '262', '32', '12', '6', '265', '263', 135 | '200', '11', '143', '267', '9', '22', '259', '275', '33', '17', '257', '34', '23', '233', 136 | '273', '256'] 137 | easy-val: ['31', '94', '204', '92', '133', '196', '203', '209', '208', '5', '101', '24', '260', 138 | '18', '13', '269', '28', '10', '7', '14'] 139 | easy-test: ['29', '42', '194', '211', '254', '21', '264', '274', '19', '134', '76', '8', '98', 140 | '210', '266', '16', '68', '132', '3', '195'] 141 | normal-train: ['115', '35', '127', '217', '87', '55', '147', '126', '125', '114', '123', '111', 142 | '235', '250', '103', '71', '81', '212', '231', '38', '252', '222', '69', '145', '176', 143 | '221', '225', '44', '45', '171', '124', '118', '229', '66', '112', '156', '220', '64', 144 | '216', '107', '80', '73', '48', '251', '236', '61', '102', '40', '83', '46', '67', '99', 145 | '164', '119', '53', '146', '226', '215', '230', '117', '52', '104', '70'] 146 | normal-val: ['148', '57', '77', '158', '228', '37', '39', '62', '105', '180', '97', '109', '219', 147 | '72', '84', '113', '120', '213', '95', '223'] 148 | normal-test: ['189', '85', '108', '110', '224', '121', '172', '43', '232', '36', '93', '47', '74', 149 | '227', '214', '49', '65', '106', '100', '58'] -------------------------------------------------------------------------------- /data/datasets/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/data/datasets/.gitkeep -------------------------------------------------------------------------------- /docs/_config.yml: -------------------------------------------------------------------------------- 1 | theme: jekyll-theme-cayman -------------------------------------------------------------------------------- /docs/bibtex.txt: -------------------------------------------------------------------------------- 1 | @inproceedings{yan2019semi, 2 | title={Semi-Supervised Video Salient Object Detection Using Pseudo-Labels}, 3 | author={Yan, Pengxiang and Li, Guanbin and Xie, Yuan and Li, Zhen and Wang, Chuan and Chen, Tianshui and Lin, Liang}, 4 | booktitle={Proceedings of the IEEE International Conference on Computer Vision}, 5 | pages={7284--7293}, 6 | year={2019} 7 | } 8 | -------------------------------------------------------------------------------- /docs/comp_video_sota.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/comp_video_sota.png -------------------------------------------------------------------------------- /docs/index.md: -------------------------------------------------------------------------------- 1 | ## Abstract 2 | 3 | Deep learning-based video salient object detection has recently achieved great success with its performance significantly outperforming any other unsupervised methods. 
However, existing data-driven approaches heavily rely on a large quantity of pixel-wise annotated video frames to deliver such promising results. In this paper, we address the semi-supervised video salient object detection task using pseudo-labels. Specifically, we present an effective video saliency detector that consists of a spatial refinement network and a spatiotemporal module. Based on the same refinement network and motion information in terms of optical flow, we further propose a novel method for generating pixel-level pseudo-labels from sparsely annotated frames. By utilizing the generated pseudo-labels together with a part of manual annotations, our video saliency detector learns spatial and temporal cues for both contrast inference and coherence enhancement, thus producing accurate saliency maps. Experimental results demonstrate that our proposed semi-supervised method even greatly outperforms all the state-of-the-art fully supervised methods across three public benchmarks of VOS, DAVIS, and FBMS. 4 | 5 |
6 | 
7 | ## Paper 
8 | 
9 | Pengxiang Yan, Guanbin Li, Yuan Xie, Zhen Li, Chuan Wang, Tianshui Chen, Liang Lin, **Semi-Supervised Video Salient Object Detection Using Pseudo-Labels**, in the IEEE International Conference on Computer Vision (ICCV), 2019, pp. 7284-7293. [[Arxiv](https://arxiv.org/abs/1908.04051)] [[Code](https://github.com/Kinpzz/RCRNet-Pytorch)] [[BibTex](https://github.com/Kinpzz/RCRNet-Pytorch/raw/master/docs/bibtex.txt)] 
10 | 
11 | ## Motivation 
12 | 
13 | * Existing data-driven approaches of video salient object detection heavily rely on a large quantity of densely annotated video frames to deliver promising results. 
14 | 
15 | * Consecutive video frames share only small differences, yet densely annotating them takes a lot of effort and the labeling consistency is hard to guarantee. 
16 | 
17 | ## Contributions 
18 | 
19 | - We propose a refinement network (RCRNet) equipped with a nonlocally enhanced recurrent (NER) module for spatiotemporal coherence modeling. 
20 | 
21 | - We propose a flow-guided pseudo-label generator (FGPLG) to generate pseudo-labels of intervals based on sparse annotations. 
22 | 
23 | - As shown in Figure 1, our model can produce reasonable and consistent pseudo-labels, which can even improve the boundary details (Example a) and overcome the labeling ambiguity between frames (Example b). 
24 | 
25 | 
26 | 
27 | - Experimental results show that utilizing the joint supervision of pseudo-labels and sparse annotations can further improve the model performance. 
28 | 
29 | 
30 | 
31 | ## Architecture 
32 | 
33 | ### RCRNet 
34 | 
35 | 
36 | 
37 | ### RCRNet+NER 
38 | 
39 | ![video_model](video_model.png) 
40 | 
41 | ### Flow-Guided Pseudo-Label Generator 
42 | 
43 | ![pseudo_label_generator](pseudo_label_generator.png) 
44 | 
45 | ## Results 
46 | 
47 | ### Quantitative Comparison 
48 | ![comp_video_sota](comp_video_sota.png) 
49 | 
50 | ## Downloads 
51 | 
52 | * Pre-computed saliency maps for the validation set of DAVIS2016, the test set of FBMS, and the test set of VOS. [[Google Drive](https://drive.google.com/open?id=1feY3GdNBS-LUBt0UDWwpA3fl9yHI4Vxr)] [[Baidu Pan](https://pan.baidu.com/s/1oXBr9qxyF-8vvilvV5kcPg) (passwd: u079)] 
53 | 
54 | ## Q&A 
55 | 
56 | Q1: What is the difference between the semi-supervised strategies mentioned in semi-supervised video salient object detection (VSOD) and semi-supervised video object segmentation (VOS)? 
57 | 
58 | A1: VOS can be categorized into **semi-supervised** and **unsupervised** methods when referring to different **testing** schemes. Semi-supervised VOS provides the annotation of the first frame at test time. Video salient object detection (VSOD) is more similar to unsupervised VOS, as neither of them resorts to labeled frames during testing. Here, our proposed method uses only a part of the labeled frames for **training**, which is why we call it a semi-supervised VSOD method. 
59 | 
60 | Q2: Are all the V+D methods in Table 1 fully-supervised? It might lack comparison with other semi-supervised video salient object detection (VSOD) methods and mask-propagation-based segmentation methods. 
61 | 
62 | A2: As far as we know, when referring to the training scheme, we are the first to adopt a semi-supervised strategy for VSOD. Thus, all the V+D methods in Table 1 (except ours) are trained under full supervision. Since most mask-propagation-based methods are designed for semi-supervised VOS, we only compare with unsupervised VOS methods in Section 5 of the paper. 
63 | 
64 | Q3: Is PDB fully-supervised or unsupervised? 
65 | 
66 | A3: PDB is fully-supervised during training but unsupervised during testing. Specifically, PDB is an algorithm for both video salient object detection and unsupervised video object segmentation, and it is trained under full supervision. 
67 | 
68 | ## Contact 
69 | 
70 | If you have any questions, please send us an email at yanpx(AT)mail2.sysu.edu.cn -------------------------------------------------------------------------------- /docs/motorcross-jump.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/motorcross-jump.gif -------------------------------------------------------------------------------- /docs/pseudo_label_example.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/pseudo_label_example.png -------------------------------------------------------------------------------- /docs/pseudo_label_generator.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/pseudo_label_generator.png -------------------------------------------------------------------------------- /docs/static_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/static_model.png -------------------------------------------------------------------------------- /docs/training_instruction.md: -------------------------------------------------------------------------------- 
1 | ## Training Instruction 
2 | 
3 | ### Training RCRNet+NER 
4 | If you want to train the proposed RCRNet from scratch, please refer to our paper and the following instructions carefully. 
5 | 
6 | The proposed RCRNet is built upon a ResNet-50 pretrained on ImageNet. 
7 | 
8 | 
9 | 
10 | **First**, we use two image saliency datasets, i.e., [MSRA-B](https://mmcheng.net/msra10k/) and [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html), to pretrain the RCRNet (Figure 2), which contains a spatial feature extractor and a pixel-wise classifier. Here, we provide the weights of RCRNet pretrained on the image saliency datasets at [Google Drive](https://drive.google.com/open?id=1S7nao9WEhIiTmTC-E0nujMxm5Emypti9) or [Baidu Pan](https://pan.baidu.com/s/196cUbTInWJKd8FmiP9Jv_A) (passwd: j839). For simplicity, we do not provide the training code for this step; if you want to reproduce it, you can implement your own training code. 
11 | 
12 | ![video_model](video_model.png) 
13 | 
14 | **Second**, we use the RCRNet pretrained on image saliency datasets as the backbone. Then we combine the training sets of three video saliency datasets, including VOS, DAVIS, and FBMS, to train the full video model, i.e., RCRNet equipped with the NER module (Figure 3). You can run the following commands to train the RCRNet+NER. 
15 | ``` 
16 | $ CUDA_VISIBLE_DEVICES=0 python train.py \ 
17 |     --data data/datasets \ 
18 |     --checkpoint models/image_pretrained_model.pth 
19 | ``` 
20 | 
21 | ### Using pseudo-labels for training 
22 | 
23 | ![pseudo_label_generator](pseudo_label_generator.png) 
24 | 
25 | As for the second step, if you want to train the RCRNet+NER using generated pseudo-labels for joint supervision, you can use our proposed flow-guided pseudo-label generator (FGPLG, Figure 4) to generate the pseudo-labels from a part of the ground truth images. 
26 | 
27 | Note that the FGPLG requires FlowNet 2.0 for flow estimation. Thus, please install the PyTorch implementation of FlowNet 2.0 using the following commands. 
28 | ``` 
29 | # Install FlowNet 2.0 (implemented by NVIDIA) 
30 | $ cd flownet2 
31 | $ bash install.sh 
32 | ``` 
33 | 
34 | #### Generating pseudo-labels using FGPLG 
35 | We provide the weights of the FGPLG trained under the supervision of 20% of the ground truth images at [Baidu Pan](https://pan.baidu.com/s/1dw8O2Ua5pKmOKYVgKRyADQ) (passwd: hbsu). You can generate the pseudo-labels by 
36 | ``` 
37 | $ CUDA_VISIBLE_DEVICES=0 python generate_pseudo_labels.py \ 
38 |     --data data/datasets \ 
39 |     --checkpoint models/pseudo_label_generator_5.pth \ 
40 |     --pseudo-label-folder data/pseudo-labels \ 
41 |     --label_interval 5 \ 
42 |     --frame_between_label_num 1 
43 | ``` 
44 | 
45 | Then you can train the video model under the joint supervision of pseudo-labels. 
46 | ``` 
47 | $ CUDA_VISIBLE_DEVICES=0 python train.py \ 
48 |     --data data/datasets \ 
49 |     --checkpoint models/image_pretrained_model.pth \ 
50 |     --pseudo-label-folder data/pseudo-labels/1_5 
51 | ``` 
52 | 
53 | #### (Optional) Training FGPLG 
54 | You can also train the FGPLG using other proportions of ground truth images by 
55 | 
56 | (Note that you need to download the pretrained model of [Flownet2](https://github.com/NVIDIA/flownet2-pytorch#converted-caffe-pre-trained-models)[620MB]) 
57 | ``` 
58 | # set l 
59 | $ CUDA_VISIBLE_DEVICES=0 python train_fgplg.py \ 
60 |     --data data/datasets \ 
61 |     --label_interval l \ 
62 |     --checkpoint models/image_pretrained_model.pth \ 
63 |     --flownet-checkpoint models/FlowNet2_checkpoint.pth.tar 
64 | ``` 
65 | 
66 | Then you can use the trained FGPLG to generate pseudo-labels based on different numbers of GT images. 
67 | ``` 
68 | # set l and m 
69 | $ CUDA_VISIBLE_DEVICES=0 python generate_pseudo_labels.py \ 
70 |     --data data/datasets \ 
71 |     --checkpoint models/pseudo_label_generator_m.pth \ 
72 |     --label_interval l \ 
73 |     --frame_between_label_num m \ 
74 |     --pseudo-label-folder data/pseudo-labels 
75 | ``` 
76 | Finally, you can train the video model under the joint supervision of pseudo-labels. 
77 | ``` 78 | # set l and m 79 | $ CUDA_VISIBLE_DEVICES=0 python train.py \ 80 | --data data/datasets \ 81 | --checkpoint models/image_pretrained_model.pth \ 82 | --pseudo-label-folder data/pseudo-labels/m_l 83 | ``` 84 | -------------------------------------------------------------------------------- /docs/video_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/video_model.png -------------------------------------------------------------------------------- /flownet2/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | .torch 3 | _ext 4 | *.o 5 | work 6 | work/* 7 | _ext/ -------------------------------------------------------------------------------- /flownet2/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04 2 | 3 | RUN apt-get update && apt-get install -y rsync htop git openssh-server python-pip 4 | 5 | RUN pip install --upgrade pip 6 | 7 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl 8 | RUN pip install torchvision cffi tensorboardX 9 | 10 | RUN pip install tqdm scipy scikit-image colorama==0.3.7 11 | RUN pip install setproctitle pytz ipython -------------------------------------------------------------------------------- /flownet2/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright 2017 NVIDIA CORPORATION 2 | 3 | Licensed under the Apache License, Version 2.0 (the "License"); 4 | you may not use this file except in compliance with the License. 5 | You may obtain a copy of the License at 6 | 7 | http://www.apache.org/licenses/LICENSE-2.0 8 | 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. -------------------------------------------------------------------------------- /flownet2/README.md: -------------------------------------------------------------------------------- 1 | # flownet2-pytorch 2 | 3 | Pytorch implementation of [FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks](https://arxiv.org/abs/1612.01925). 4 | 5 | Multiple GPU training is supported, and the code provides examples for training or inference on [MPI-Sintel](http://sintel.is.tue.mpg.de/) clean and final datasets. The same commands can be used for training or inference with other datasets. See below for more detail. 6 | 7 | Inference using fp16 (half-precision) is also supported. 8 | 9 | For more help, type
10 | 11 | python main.py --help 12 | 13 | ## Network architectures 14 | Below are the different flownet neural network architectures that are provided.
15 | A batchnorm version for each network is available. 
16 | 
17 | - **FlowNet2S** 
18 | - **FlowNet2C** 
19 | - **FlowNet2CS** 
20 | - **FlowNet2CSS** 
21 | - **FlowNet2SD** 
22 | - **FlowNet2** 
23 | 
24 | ## Custom layers 
25 | 
26 | `FlowNet2` or `FlowNet2C*` architectures rely on custom layers `Resample2d` or `Correlation`. 
27 | A pytorch implementation of these layers with cuda kernels is available at [./networks](./networks). 
28 | Note: Currently, half precision kernels are not available for these layers. 
29 | 
30 | ## Data Loaders 
31 | 
32 | Dataloaders for FlyingChairs, FlyingThings, ChairsSDHom and ImagesFromFolder are available in [datasets.py](./datasets.py). 
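Below is a minimal usage sketch for these dataloaders, loosely based on the commented example at the bottom of [datasets.py](./datasets.py). The folder path is a placeholder, and the snippet assumes it is run from the `flownet2` directory with the dependencies installed.
```
# Minimal sketch: build a dataset from a plain folder of frames.
# The root path is a placeholder; args only needs the fields the loaders read.
import argparse
import datasets  # flownet2/datasets.py

args = argparse.Namespace(crop_size=[384, 512], inference_size=[-1, -1])

# ImagesFromFolder pairs consecutive frames (frame i, frame i+1) from one folder
dataset = datasets.ImagesFromFolder(args, is_cropped=False,
                                    root='/path/to/frames/only/folder', iext='png')

(images,), (flow,) = dataset[0]
print(images.shape)  # (3, 2, H, W): two RGB frames stacked along dim 1
```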
33 | 34 | ## Loss Functions 35 | 36 | L1 and L2 losses with multi-scale support are available in [losses.py](./losses.py).
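A short sketch of the loss interface defined in [losses.py](./losses.py); the tensor shapes are placeholders, and `args` is only stored by the loss modules, so `None` is enough for a standalone test.
```
# Minimal sketch of the loss interface in losses.py.
import torch
import losses  # flownet2/losses.py

# Each loss returns [loss_value, EPE]; MultiScale(args, ...) works the same way
# and additionally handles tuples of multi-scale flow predictions.
criterion = losses.L1Loss(args=None)

pred_flow = torch.randn(1, 2, 384, 512)  # (N, 2, H, W) predicted flow
gt_flow = torch.randn(1, 2, 384, 512)    # ground-truth flow of the same shape

loss_value, epe_value = criterion(pred_flow, gt_flow)
print(loss_value.item(), epe_value.item())
```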
37 | 38 | ## Installation 39 | 40 | # get flownet2-pytorch source 41 | git clone https://github.com/NVIDIA/flownet2-pytorch.git 42 | cd flownet2-pytorch 43 | 44 | # install custom layers 45 | bash install.sh 46 | 47 | ### Python requirements 48 | Currently, the code supports python 3 49 | * numpy 50 | * PyTorch ( == 0.4.1, for <= 0.4.0 see branch [python36-PyTorch0.4](https://github.com/NVIDIA/flownet2-pytorch/tree/python36-PyTorch0.4)) 51 | * scipy 52 | * scikit-image 53 | * tensorboardX 54 | * colorama, tqdm, setproctitle 55 | 56 | ## Converted Caffe Pre-trained Models 57 | We've included caffe pre-trained models. Should you use these pre-trained weights, please adhere to the [license agreements](https://drive.google.com/file/d/1TVv0BnNFh3rpHZvD-easMb9jYrPE2Eqd/view?usp=sharing). 58 | 59 | * [FlowNet2](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB] 60 | * [FlowNet2-C](https://drive.google.com/file/d/1BFT6b7KgKJC8rA59RmOVAXRM_S7aSfKE/view?usp=sharing)[149MB] 61 | * [FlowNet2-CS](https://drive.google.com/file/d/1iBJ1_o7PloaINpa8m7u_7TsLCX0Dt_jS/view?usp=sharing)[297MB] 62 | * [FlowNet2-CSS](https://drive.google.com/file/d/157zuzVf4YMN6ABAQgZc8rRmR5cgWzSu8/view?usp=sharing)[445MB] 63 | * [FlowNet2-CSS-ft-sd](https://drive.google.com/file/d/1R5xafCIzJCXc8ia4TGfC65irmTNiMg6u/view?usp=sharing)[445MB] 64 | * [FlowNet2-S](https://drive.google.com/file/d/1V61dZjFomwlynwlYklJHC-TLfdFom3Lg/view?usp=sharing)[148MB] 65 | * [FlowNet2-SD](https://drive.google.com/file/d/1QW03eyYG_vD-dT-Mx4wopYvtPu_msTKn/view?usp=sharing)[173MB] 66 | 67 | ## Inference 68 | # Example on MPISintel Clean 69 | python main.py --inference --model FlowNet2 --save_flow --inference_dataset MpiSintelClean \ 70 | --inference_dataset_root /path/to/mpi-sintel/clean/dataset \ 71 | --resume /path/to/checkpoints 72 | 73 | ## Training and validation 74 | 75 | # Example on MPISintel Final and Clean, with L1Loss on FlowNet2 model 76 | python main.py --batch_size 8 --model FlowNet2 --loss=L1Loss --optimizer=Adam --optimizer_lr=1e-4 \ 77 | --training_dataset MpiSintelFinal --training_dataset_root /path/to/mpi-sintel/final/dataset \ 78 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset 79 | 80 | # Example on MPISintel Final and Clean, with MultiScale loss on FlowNet2C model 81 | python main.py --batch_size 8 --model FlowNet2C --optimizer=Adam --optimizer_lr=1e-4 --loss=MultiScale --loss_norm=L1 \ 82 | --loss_numScales=5 --loss_startScale=4 --optimizer_lr=1e-4 --crop_size 384 512 \ 83 | --training_dataset FlyingChairs --training_dataset_root /path/to/flying-chairs/dataset \ 84 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset 85 | 86 | ## Results on MPI-Sintel 87 | [![Predicted flows on MPI-Sintel](./image.png)](https://www.youtube.com/watch?v=HtBmabY8aeU "Predicted flows on MPI-Sintel") 88 | 89 | ## Reference 90 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper: 91 | ```` 92 | @InProceedings{IMKDB17, 93 | author = "E. Ilg and N. Mayer and T. Saikia and M. Keuper and A. Dosovitskiy and T. 
Brox", 94 | title = "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks", 95 | booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)", 96 | month = "Jul", 97 | year = "2017", 98 | url = "http://lmb.informatik.uni-freiburg.de//Publications/2017/IMKDB17" 99 | } 100 | ```` 101 | ``` 102 | @misc{flownet2-pytorch, 103 | author = {Fitsum Reda and Robert Pottorff and Jon Barker and Bryan Catanzaro}, 104 | title = {flownet2-pytorch: Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks}, 105 | year = {2017}, 106 | publisher = {GitHub}, 107 | journal = {GitHub repository}, 108 | howpublished = {\url{https://github.com/NVIDIA/flownet2-pytorch}} 109 | } 110 | ``` 111 | ## Related Optical Flow Work from Nvidia 112 | Code (in Caffe and Pytorch): [PWC-Net](https://github.com/NVlabs/PWC-Net)
113 | Paper : [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/abs/1709.02371). 114 | 115 | ## Acknowledgments 116 | Parts of this code were derived, as noted in the code, from [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch). 117 | -------------------------------------------------------------------------------- /flownet2/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/__init__.py -------------------------------------------------------------------------------- /flownet2/convert.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python2.7 2 | 3 | import caffe 4 | from caffe.proto import caffe_pb2 5 | import sys, os 6 | 7 | import torch 8 | import torch.nn as nn 9 | 10 | import argparse, tempfile 11 | import numpy as np 12 | 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument('caffe_model', help='input model in hdf5 or caffemodel format') 15 | parser.add_argument('prototxt_template',help='prototxt template') 16 | parser.add_argument('flownet2_pytorch', help='path to flownet2-pytorch') 17 | 18 | args = parser.parse_args() 19 | 20 | args.rgb_max = 255 21 | args.fp16 = False 22 | args.grads = {} 23 | 24 | # load models 25 | sys.path.append(args.flownet2_pytorch) 26 | 27 | import models 28 | from utils.param_utils import * 29 | 30 | width = 256 31 | height = 256 32 | keys = {'TARGET_WIDTH': width, 33 | 'TARGET_HEIGHT': height, 34 | 'ADAPTED_WIDTH':width, 35 | 'ADAPTED_HEIGHT':height, 36 | 'SCALE_WIDTH':1., 37 | 'SCALE_HEIGHT':1.,} 38 | 39 | template = '\n'.join(np.loadtxt(args.prototxt_template, dtype=str, delimiter='\n')) 40 | for k in keys: 41 | template = template.replace('$%s$'%(k),str(keys[k])) 42 | 43 | prototxt = tempfile.NamedTemporaryFile(mode='w', delete=True) 44 | prototxt.write(template) 45 | prototxt.flush() 46 | 47 | net = caffe.Net(prototxt.name, args.caffe_model, caffe.TEST) 48 | 49 | weights = {} 50 | biases = {} 51 | 52 | for k, v in list(net.params.items()): 53 | weights[k] = np.array(v[0].data).reshape(v[0].data.shape) 54 | biases[k] = np.array(v[1].data).reshape(v[1].data.shape) 55 | print((k, weights[k].shape, biases[k].shape)) 56 | 57 | if 'FlowNet2/' in args.caffe_model: 58 | model = models.FlowNet2(args) 59 | 60 | parse_flownetc(model.flownetc.modules(), weights, biases) 61 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 62 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 63 | parse_flownetsd(model.flownets_d.modules(), weights, biases, param_prefix='netsd_') 64 | parse_flownetfusion(model.flownetfusion.modules(), weights, biases, param_prefix='fuse_') 65 | 66 | state = {'epoch': 0, 67 | 'state_dict': model.state_dict(), 68 | 'best_EPE': 1e10} 69 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2_checkpoint.pth.tar')) 70 | 71 | elif 'FlowNet2-C/' in args.caffe_model: 72 | model = models.FlowNet2C(args) 73 | 74 | parse_flownetc(model.modules(), weights, biases) 75 | state = {'epoch': 0, 76 | 'state_dict': model.state_dict(), 77 | 'best_EPE': 1e10} 78 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-C_checkpoint.pth.tar')) 79 | 80 | elif 'FlowNet2-CS/' in args.caffe_model: 81 | model = models.FlowNet2CS(args) 82 | 83 | parse_flownetc(model.flownetc.modules(), weights, biases) 84 | 
parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 85 | 86 | state = {'epoch': 0, 87 | 'state_dict': model.state_dict(), 88 | 'best_EPE': 1e10} 89 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CS_checkpoint.pth.tar')) 90 | 91 | elif 'FlowNet2-CSS/' in args.caffe_model: 92 | model = models.FlowNet2CSS(args) 93 | 94 | parse_flownetc(model.flownetc.modules(), weights, biases) 95 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 96 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 97 | 98 | state = {'epoch': 0, 99 | 'state_dict': model.state_dict(), 100 | 'best_EPE': 1e10} 101 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS_checkpoint.pth.tar')) 102 | 103 | elif 'FlowNet2-CSS-ft-sd/' in args.caffe_model: 104 | model = models.FlowNet2CSS(args) 105 | 106 | parse_flownetc(model.flownetc.modules(), weights, biases) 107 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_') 108 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_') 109 | 110 | state = {'epoch': 0, 111 | 'state_dict': model.state_dict(), 112 | 'best_EPE': 1e10} 113 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS-ft-sd_checkpoint.pth.tar')) 114 | 115 | elif 'FlowNet2-S/' in args.caffe_model: 116 | model = models.FlowNet2S(args) 117 | 118 | parse_flownetsonly(model.modules(), weights, biases, param_prefix='') 119 | state = {'epoch': 0, 120 | 'state_dict': model.state_dict(), 121 | 'best_EPE': 1e10} 122 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-S_checkpoint.pth.tar')) 123 | 124 | elif 'FlowNet2-SD/' in args.caffe_model: 125 | model = models.FlowNet2SD(args) 126 | 127 | parse_flownetsd(model.modules(), weights, biases, param_prefix='') 128 | 129 | state = {'epoch': 0, 130 | 'state_dict': model.state_dict(), 131 | 'best_EPE': 1e10} 132 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-SD_checkpoint.pth.tar')) 133 | 134 | else: 135 | print(('model type cound not be determined from input caffe model %s'%(args.caffe_model))) 136 | quit() 137 | print(("done converting ", args.caffe_model)) -------------------------------------------------------------------------------- /flownet2/datasets.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.utils.data as data 3 | 4 | import os, math, random 5 | from os.path import * 6 | import numpy as np 7 | 8 | from glob import glob 9 | import utils.frame_utils as frame_utils 10 | 11 | from scipy.misc import imread, imresize 12 | 13 | class StaticRandomCrop(object): 14 | def __init__(self, image_size, crop_size): 15 | self.th, self.tw = crop_size 16 | h, w = image_size 17 | self.h1 = random.randint(0, h - self.th) 18 | self.w1 = random.randint(0, w - self.tw) 19 | 20 | def __call__(self, img): 21 | return img[self.h1:(self.h1+self.th), self.w1:(self.w1+self.tw),:] 22 | 23 | class StaticCenterCrop(object): 24 | def __init__(self, image_size, crop_size): 25 | self.th, self.tw = crop_size 26 | self.h, self.w = image_size 27 | def __call__(self, img): 28 | return img[(self.h-self.th)//2:(self.h+self.th)//2, (self.w-self.tw)//2:(self.w+self.tw)//2,:] 29 | 30 | class MpiSintel(data.Dataset): 31 | def __init__(self, args, is_cropped = False, root = '', dstype = 'clean', replicates = 1): 32 | self.args = args 33 | self.is_cropped = is_cropped 34 | self.crop_size = args.crop_size 35 | 
self.render_size = args.inference_size 36 | self.replicates = replicates 37 | 38 | flow_root = join(root, 'flow') 39 | image_root = join(root, dstype) 40 | 41 | file_list = sorted(glob(join(flow_root, '*/*.flo'))) 42 | 43 | self.flow_list = [] 44 | self.image_list = [] 45 | 46 | for file in file_list: 47 | if 'test' in file: 48 | # print file 49 | continue 50 | 51 | fbase = file[len(flow_root)+1:] 52 | fprefix = fbase[:-8] 53 | fnum = int(fbase[-8:-4]) 54 | 55 | img1 = join(image_root, fprefix + "%04d"%(fnum+0) + '.png') 56 | img2 = join(image_root, fprefix + "%04d"%(fnum+1) + '.png') 57 | 58 | if not isfile(img1) or not isfile(img2) or not isfile(file): 59 | continue 60 | 61 | self.image_list += [[img1, img2]] 62 | self.flow_list += [file] 63 | 64 | self.size = len(self.image_list) 65 | 66 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 67 | 68 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 69 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 70 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 71 | 72 | args.inference_size = self.render_size 73 | 74 | assert (len(self.image_list) == len(self.flow_list)) 75 | 76 | def __getitem__(self, index): 77 | 78 | index = index % self.size 79 | 80 | img1 = frame_utils.read_gen(self.image_list[index][0]) 81 | img2 = frame_utils.read_gen(self.image_list[index][1]) 82 | 83 | flow = frame_utils.read_gen(self.flow_list[index]) 84 | 85 | images = [img1, img2] 86 | image_size = img1.shape[:2] 87 | 88 | if self.is_cropped: 89 | cropper = StaticRandomCrop(image_size, self.crop_size) 90 | else: 91 | cropper = StaticCenterCrop(image_size, self.render_size) 92 | images = list(map(cropper, images)) 93 | flow = cropper(flow) 94 | 95 | images = np.array(images).transpose(3,0,1,2) 96 | flow = flow.transpose(2,0,1) 97 | 98 | images = torch.from_numpy(images.astype(np.float32)) 99 | flow = torch.from_numpy(flow.astype(np.float32)) 100 | 101 | return [images], [flow] 102 | 103 | def __len__(self): 104 | return self.size * self.replicates 105 | 106 | class MpiSintelClean(MpiSintel): 107 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 108 | super(MpiSintelClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'clean', replicates = replicates) 109 | 110 | class MpiSintelFinal(MpiSintel): 111 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 112 | super(MpiSintelFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'final', replicates = replicates) 113 | 114 | class FlyingChairs(data.Dataset): 115 | def __init__(self, args, is_cropped, root = '/path/to/FlyingChairs_release/data', replicates = 1): 116 | self.args = args 117 | self.is_cropped = is_cropped 118 | self.crop_size = args.crop_size 119 | self.render_size = args.inference_size 120 | self.replicates = replicates 121 | 122 | images = sorted( glob( join(root, '*.ppm') ) ) 123 | 124 | self.flow_list = sorted( glob( join(root, '*.flo') ) ) 125 | 126 | assert (len(images)//2 == len(self.flow_list)) 127 | 128 | self.image_list = [] 129 | for i in range(len(self.flow_list)): 130 | im1 = images[2*i] 131 | im2 = images[2*i + 1] 132 | self.image_list += [ [ im1, im2 ] ] 133 | 134 | assert len(self.image_list) == len(self.flow_list) 135 | 136 | self.size = len(self.image_list) 137 | 138 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 139 | 140 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or 
(self.frame_size[0]%64) or (self.frame_size[1]%64): 141 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 142 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 143 | 144 | args.inference_size = self.render_size 145 | 146 | def __getitem__(self, index): 147 | index = index % self.size 148 | 149 | img1 = frame_utils.read_gen(self.image_list[index][0]) 150 | img2 = frame_utils.read_gen(self.image_list[index][1]) 151 | 152 | flow = frame_utils.read_gen(self.flow_list[index]) 153 | 154 | images = [img1, img2] 155 | image_size = img1.shape[:2] 156 | if self.is_cropped: 157 | cropper = StaticRandomCrop(image_size, self.crop_size) 158 | else: 159 | cropper = StaticCenterCrop(image_size, self.render_size) 160 | images = list(map(cropper, images)) 161 | flow = cropper(flow) 162 | 163 | 164 | images = np.array(images).transpose(3,0,1,2) 165 | flow = flow.transpose(2,0,1) 166 | 167 | images = torch.from_numpy(images.astype(np.float32)) 168 | flow = torch.from_numpy(flow.astype(np.float32)) 169 | 170 | return [images], [flow] 171 | 172 | def __len__(self): 173 | return self.size * self.replicates 174 | 175 | class FlyingThings(data.Dataset): 176 | def __init__(self, args, is_cropped, root = '/path/to/flyingthings3d', dstype = 'frames_cleanpass', replicates = 1): 177 | self.args = args 178 | self.is_cropped = is_cropped 179 | self.crop_size = args.crop_size 180 | self.render_size = args.inference_size 181 | self.replicates = replicates 182 | 183 | image_dirs = sorted(glob(join(root, dstype, 'TRAIN/*/*'))) 184 | image_dirs = sorted([join(f, 'left') for f in image_dirs] + [join(f, 'right') for f in image_dirs]) 185 | 186 | flow_dirs = sorted(glob(join(root, 'optical_flow_flo_format/TRAIN/*/*'))) 187 | flow_dirs = sorted([join(f, 'into_future/left') for f in flow_dirs] + [join(f, 'into_future/right') for f in flow_dirs]) 188 | 189 | assert (len(image_dirs) == len(flow_dirs)) 190 | 191 | self.image_list = [] 192 | self.flow_list = [] 193 | 194 | for idir, fdir in zip(image_dirs, flow_dirs): 195 | images = sorted( glob(join(idir, '*.png')) ) 196 | flows = sorted( glob(join(fdir, '*.flo')) ) 197 | for i in range(len(flows)): 198 | self.image_list += [ [ images[i], images[i+1] ] ] 199 | self.flow_list += [flows[i]] 200 | 201 | assert len(self.image_list) == len(self.flow_list) 202 | 203 | self.size = len(self.image_list) 204 | 205 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 206 | 207 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 208 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 209 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 210 | 211 | args.inference_size = self.render_size 212 | 213 | def __getitem__(self, index): 214 | index = index % self.size 215 | 216 | img1 = frame_utils.read_gen(self.image_list[index][0]) 217 | img2 = frame_utils.read_gen(self.image_list[index][1]) 218 | 219 | flow = frame_utils.read_gen(self.flow_list[index]) 220 | 221 | images = [img1, img2] 222 | image_size = img1.shape[:2] 223 | if self.is_cropped: 224 | cropper = StaticRandomCrop(image_size, self.crop_size) 225 | else: 226 | cropper = StaticCenterCrop(image_size, self.render_size) 227 | images = list(map(cropper, images)) 228 | flow = cropper(flow) 229 | 230 | 231 | images = np.array(images).transpose(3,0,1,2) 232 | flow = flow.transpose(2,0,1) 233 | 234 | images = torch.from_numpy(images.astype(np.float32)) 235 | flow = torch.from_numpy(flow.astype(np.float32)) 236 | 237 | return [images], 
[flow] 238 | 239 | def __len__(self): 240 | return self.size * self.replicates 241 | 242 | class FlyingThingsClean(FlyingThings): 243 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 244 | super(FlyingThingsClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_cleanpass', replicates = replicates) 245 | 246 | class FlyingThingsFinal(FlyingThings): 247 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 248 | super(FlyingThingsFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_finalpass', replicates = replicates) 249 | 250 | class ChairsSDHom(data.Dataset): 251 | def __init__(self, args, is_cropped, root = '/path/to/chairssdhom/data', dstype = 'train', replicates = 1): 252 | self.args = args 253 | self.is_cropped = is_cropped 254 | self.crop_size = args.crop_size 255 | self.render_size = args.inference_size 256 | self.replicates = replicates 257 | 258 | image1 = sorted( glob( join(root, dstype, 't0/*.png') ) ) 259 | image2 = sorted( glob( join(root, dstype, 't1/*.png') ) ) 260 | self.flow_list = sorted( glob( join(root, dstype, 'flow/*.flo') ) ) 261 | 262 | assert (len(image1) == len(self.flow_list)) 263 | 264 | self.image_list = [] 265 | for i in range(len(self.flow_list)): 266 | im1 = image1[i] 267 | im2 = image2[i] 268 | self.image_list += [ [ im1, im2 ] ] 269 | 270 | assert len(self.image_list) == len(self.flow_list) 271 | 272 | self.size = len(self.image_list) 273 | 274 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 275 | 276 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 277 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 278 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 279 | 280 | args.inference_size = self.render_size 281 | 282 | def __getitem__(self, index): 283 | index = index % self.size 284 | 285 | img1 = frame_utils.read_gen(self.image_list[index][0]) 286 | img2 = frame_utils.read_gen(self.image_list[index][1]) 287 | 288 | flow = frame_utils.read_gen(self.flow_list[index]) 289 | flow = flow[::-1,:,:] 290 | 291 | images = [img1, img2] 292 | image_size = img1.shape[:2] 293 | if self.is_cropped: 294 | cropper = StaticRandomCrop(image_size, self.crop_size) 295 | else: 296 | cropper = StaticCenterCrop(image_size, self.render_size) 297 | images = list(map(cropper, images)) 298 | flow = cropper(flow) 299 | 300 | 301 | images = np.array(images).transpose(3,0,1,2) 302 | flow = flow.transpose(2,0,1) 303 | 304 | images = torch.from_numpy(images.astype(np.float32)) 305 | flow = torch.from_numpy(flow.astype(np.float32)) 306 | 307 | return [images], [flow] 308 | 309 | def __len__(self): 310 | return self.size * self.replicates 311 | 312 | class ChairsSDHomTrain(ChairsSDHom): 313 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 314 | super(ChairsSDHomTrain, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'train', replicates = replicates) 315 | 316 | class ChairsSDHomTest(ChairsSDHom): 317 | def __init__(self, args, is_cropped = False, root = '', replicates = 1): 318 | super(ChairsSDHomTest, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'test', replicates = replicates) 319 | 320 | class ImagesFromFolder(data.Dataset): 321 | def __init__(self, args, is_cropped, root = '/path/to/frames/only/folder', iext = 'png', replicates = 1): 322 | self.args = args 323 | self.is_cropped = is_cropped 324 | self.crop_size = 
args.crop_size 325 | self.render_size = args.inference_size 326 | self.replicates = replicates 327 | 328 | images = sorted( glob( join(root, '*.' + iext) ) ) 329 | self.image_list = [] 330 | for i in range(len(images)-1): 331 | im1 = images[i] 332 | im2 = images[i+1] 333 | self.image_list += [ [ im1, im2 ] ] 334 | 335 | self.size = len(self.image_list) 336 | 337 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape 338 | 339 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64): 340 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64 341 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64 342 | 343 | args.inference_size = self.render_size 344 | 345 | def __getitem__(self, index): 346 | index = index % self.size 347 | 348 | img1 = frame_utils.read_gen(self.image_list[index][0]) 349 | img2 = frame_utils.read_gen(self.image_list[index][1]) 350 | 351 | images = [img1, img2] 352 | image_size = img1.shape[:2] 353 | if self.is_cropped: 354 | cropper = StaticRandomCrop(image_size, self.crop_size) 355 | else: 356 | cropper = StaticCenterCrop(image_size, self.render_size) 357 | images = list(map(cropper, images)) 358 | 359 | images = np.array(images).transpose(3,0,1,2) 360 | images = torch.from_numpy(images.astype(np.float32)) 361 | 362 | return [images], [torch.zeros(images.size()[0:1] + (2,) + images.size()[-2:])] 363 | 364 | def __len__(self): 365 | return self.size * self.replicates 366 | 367 | ''' 368 | import argparse 369 | import sys, os 370 | import importlib 371 | from scipy.misc import imsave 372 | import numpy as np 373 | 374 | import datasets 375 | reload(datasets) 376 | 377 | parser = argparse.ArgumentParser() 378 | args = parser.parse_args() 379 | args.inference_size = [1080, 1920] 380 | args.crop_size = [384, 512] 381 | args.effective_batch_size = 1 382 | 383 | index = 500 384 | v_dataset = datasets.MpiSintelClean(args, True, root='../MPI-Sintel/flow/training') 385 | a, b = v_dataset[index] 386 | im1 = a[0].numpy()[:,0,:,:].transpose(1,2,0) 387 | im2 = a[0].numpy()[:,1,:,:].transpose(1,2,0) 388 | imsave('./img1.png', im1) 389 | imsave('./img2.png', im2) 390 | flow_utils.writeFlow('./flow.flo', b[0].numpy().transpose(1,2,0)) 391 | 392 | ''' 393 | -------------------------------------------------------------------------------- /flownet2/download_caffe_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo rm -rf flownet2-docker 3 | sudo git clone https://github.com/lmb-freiburg/flownet2-docker 4 | cd flownet2-docker 5 | 6 | sudo sed -i '$ a RUN apt-get update && apt-get install -y python-pip \ 7 | RUN pip install --upgrade pip \ 8 | RUN pip install numpy -I \ 9 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl \ 10 | RUN pip install cffi ipython' Dockerfile 11 | 12 | sudo make 13 | 14 | -------------------------------------------------------------------------------- /flownet2/image.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/image.png -------------------------------------------------------------------------------- /flownet2/install.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | cd ./networks/correlation_package 3 | python setup.py install 4 | cd ../resample2d_package 5 | python 
setup.py install 6 | cd ../channelnorm_package 7 | python setup.py install 8 | cd .. 9 | -------------------------------------------------------------------------------- /flownet2/launch_docker.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 . 3 | sudo nvidia-docker run --rm -ti --volume=$(pwd):/flownet2-pytorch:rw --workdir=/flownet2-pytorch --ipc=host $USER/pytorch:CUDA8-py27 /bin/bash 4 | -------------------------------------------------------------------------------- /flownet2/losses.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | # freda (todo) : adversarial loss 6 | 7 | import torch 8 | import torch.nn as nn 9 | import math 10 | 11 | def EPE(input_flow, target_flow): 12 | return torch.norm(target_flow-input_flow,p=2,dim=1).mean() 13 | 14 | class L1(nn.Module): 15 | def __init__(self): 16 | super(L1, self).__init__() 17 | def forward(self, output, target): 18 | lossvalue = torch.abs(output - target).mean() 19 | return lossvalue 20 | 21 | class L2(nn.Module): 22 | def __init__(self): 23 | super(L2, self).__init__() 24 | def forward(self, output, target): 25 | lossvalue = torch.norm(output-target,p=2,dim=1).mean() 26 | return lossvalue 27 | 28 | class L1Loss(nn.Module): 29 | def __init__(self, args): 30 | super(L1Loss, self).__init__() 31 | self.args = args 32 | self.loss = L1() 33 | self.loss_labels = ['L1', 'EPE'] 34 | 35 | def forward(self, output, target): 36 | lossvalue = self.loss(output, target) 37 | epevalue = EPE(output, target) 38 | return [lossvalue, epevalue] 39 | 40 | class L2Loss(nn.Module): 41 | def __init__(self, args): 42 | super(L2Loss, self).__init__() 43 | self.args = args 44 | self.loss = L2() 45 | self.loss_labels = ['L2', 'EPE'] 46 | 47 | def forward(self, output, target): 48 | lossvalue = self.loss(output, target) 49 | epevalue = EPE(output, target) 50 | return [lossvalue, epevalue] 51 | 52 | class MultiScale(nn.Module): 53 | def __init__(self, args, startScale = 4, numScales = 5, l_weight= 0.32, norm= 'L1'): 54 | super(MultiScale,self).__init__() 55 | 56 | self.startScale = startScale 57 | self.numScales = numScales 58 | self.loss_weights = torch.FloatTensor([(l_weight / 2 ** scale) for scale in range(self.numScales)]) 59 | self.args = args 60 | self.l_type = norm 61 | self.div_flow = 0.05 62 | assert(len(self.loss_weights) == self.numScales) 63 | 64 | if self.l_type == 'L1': 65 | self.loss = L1() 66 | else: 67 | self.loss = L2() 68 | 69 | self.multiScales = [nn.AvgPool2d(self.startScale * (2**scale), self.startScale * (2**scale)) for scale in range(self.numScales)] 70 | self.loss_labels = ['MultiScale-'+self.l_type, 'EPE'], 71 | 72 | def forward(self, output, target): 73 | lossvalue = 0 74 | epevalue = 0 75 | 76 | if type(output) is tuple: 77 | target = self.div_flow * target 78 | for i, output_ in enumerate(output): 79 | target_ = self.multiScales[i](target) 80 | epevalue += self.loss_weights[i]*EPE(output_, target_) 81 | lossvalue += self.loss_weights[i]*self.loss(output_, target_) 82 | return [lossvalue, epevalue] 83 | else: 84 | epevalue += EPE(output, target) 85 | lossvalue += self.loss(output, target) 86 | return [lossvalue, epevalue] 87 | 88 | -------------------------------------------------------------------------------- /flownet2/networks/FlowNetC.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .correlation_package.correlation import Correlation 9 | 10 | from .submodules import * 11 | 'Parameter count , 39,175,298 ' 12 | 13 | class FlowNetC(nn.Module): 14 | def __init__(self,args, batchNorm=True, div_flow = 20): 15 | super(FlowNetC,self).__init__() 16 | 17 | self.batchNorm = batchNorm 18 | self.div_flow = div_flow 19 | 20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1) 24 | 25 | if args.fp16: 26 | self.corr = nn.Sequential( 27 | tofp32(), 28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1), 29 | tofp16()) 30 | else: 31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1) 32 | 33 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True) 34 | self.conv3_1 = conv(self.batchNorm, 473, 256) 35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 36 | self.conv4_1 = conv(self.batchNorm, 512, 512) 37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 38 | self.conv5_1 = conv(self.batchNorm, 512, 512) 39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 40 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 41 | 42 | self.deconv5 = deconv(1024,512) 43 | self.deconv4 = deconv(1026,256) 44 | self.deconv3 = deconv(770,128) 45 | self.deconv2 = deconv(386,64) 46 | 47 | self.predict_flow6 = predict_flow(1024) 48 | self.predict_flow5 = predict_flow(1026) 49 | self.predict_flow4 = predict_flow(770) 50 | self.predict_flow3 = predict_flow(386) 51 | self.predict_flow2 = predict_flow(194) 52 | 53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True) 57 | 58 | for m in self.modules(): 59 | if isinstance(m, nn.Conv2d): 60 | if m.bias is not None: 61 | init.uniform_(m.bias) 62 | init.xavier_uniform_(m.weight) 63 | 64 | if isinstance(m, nn.ConvTranspose2d): 65 | if m.bias is not None: 66 | init.uniform_(m.bias) 67 | init.xavier_uniform_(m.weight) 68 | # init_deconv_bilinear(m.weight) 69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 70 | 71 | def forward(self, x): 72 | x1 = x[:,0:3,:,:] 73 | x2 = x[:,3::,:,:] 74 | 75 | out_conv1a = self.conv1(x1) 76 | out_conv2a = self.conv2(out_conv1a) 77 | out_conv3a = self.conv3(out_conv2a) 78 | 79 | # FlownetC bottom input stream 80 | out_conv1b = self.conv1(x2) 81 | 82 | out_conv2b = self.conv2(out_conv1b) 83 | out_conv3b = self.conv3(out_conv2b) 84 | 85 | # Merge streams 86 | out_corr = self.corr(out_conv3a, out_conv3b) # False 87 | out_corr = self.corr_activation(out_corr) 88 | 89 | # Redirect top input stream and concatenate 90 | out_conv_redir = self.conv_redir(out_conv3a) 91 | 92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1) 93 | 94 | # Merged conv layers 95 | out_conv3_1 = self.conv3_1(in_conv3_1) 96 | 97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1)) 98 | 99 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 
100 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 101 | 102 | flow6 = self.predict_flow6(out_conv6) 103 | flow6_up = self.upsampled_flow6_to_5(flow6) 104 | out_deconv5 = self.deconv5(out_conv6) 105 | 106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 107 | 108 | flow5 = self.predict_flow5(concat5) 109 | flow5_up = self.upsampled_flow5_to_4(flow5) 110 | out_deconv4 = self.deconv4(concat5) 111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 112 | 113 | flow4 = self.predict_flow4(concat4) 114 | flow4_up = self.upsampled_flow4_to_3(flow4) 115 | out_deconv3 = self.deconv3(concat4) 116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1) 117 | 118 | flow3 = self.predict_flow3(concat3) 119 | flow3_up = self.upsampled_flow3_to_2(flow3) 120 | out_deconv2 = self.deconv2(concat3) 121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1) 122 | 123 | flow2 = self.predict_flow2(concat2) 124 | 125 | if self.training: 126 | return flow2,flow3,flow4,flow5,flow6 127 | else: 128 | return flow2, 129 | -------------------------------------------------------------------------------- /flownet2/networks/FlowNetFusion.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 581,226' 10 | 11 | class FlowNetFusion(nn.Module): 12 | def __init__(self,args, batchNorm=True): 13 | super(FlowNetFusion,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 11, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | 22 | self.deconv1 = deconv(128,32) 23 | self.deconv0 = deconv(162,16) 24 | 25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32) 26 | self.inter_conv0 = i_conv(self.batchNorm, 82, 16) 27 | 28 | self.predict_flow2 = predict_flow(128) 29 | self.predict_flow1 = predict_flow(32) 30 | self.predict_flow0 = predict_flow(16) 31 | 32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 34 | 35 | for m in self.modules(): 36 | if isinstance(m, nn.Conv2d): 37 | if m.bias is not None: 38 | init.uniform_(m.bias) 39 | init.xavier_uniform_(m.weight) 40 | 41 | if isinstance(m, nn.ConvTranspose2d): 42 | if m.bias is not None: 43 | init.uniform_(m.bias) 44 | init.xavier_uniform_(m.weight) 45 | # init_deconv_bilinear(m.weight) 46 | 47 | def forward(self, x): 48 | out_conv0 = self.conv0(x) 49 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 50 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 51 | 52 | flow2 = self.predict_flow2(out_conv2) 53 | flow2_up = self.upsampled_flow2_to_1(flow2) 54 | out_deconv1 = self.deconv1(out_conv2) 55 | 56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1) 57 | out_interconv1 = self.inter_conv1(concat1) 58 | flow1 = self.predict_flow1(out_interconv1) 59 | flow1_up = self.upsampled_flow1_to_0(flow1) 60 | out_deconv0 = self.deconv0(concat1) 61 | 62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1) 63 | out_interconv0 = self.inter_conv0(concat0) 64 | flow0 = self.predict_flow0(out_interconv0) 65 | 66 | return flow0 67 | 68 | -------------------------------------------------------------------------------- /flownet2/networks/FlowNetS.py: 
-------------------------------------------------------------------------------- 1 | ''' 2 | Portions of this code copyright 2017, Clement Pinard 3 | ''' 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.nn import init 8 | 9 | import math 10 | import numpy as np 11 | 12 | from .submodules import * 13 | 'Parameter count : 38,676,504 ' 14 | 15 | class FlowNetS(nn.Module): 16 | def __init__(self, args, input_channels = 12, batchNorm=True): 17 | super(FlowNetS,self).__init__() 18 | 19 | self.batchNorm = batchNorm 20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2) 21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2) 22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2) 23 | self.conv3_1 = conv(self.batchNorm, 256, 256) 24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 25 | self.conv4_1 = conv(self.batchNorm, 512, 512) 26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 27 | self.conv5_1 = conv(self.batchNorm, 512, 512) 28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 29 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 30 | 31 | self.deconv5 = deconv(1024,512) 32 | self.deconv4 = deconv(1026,256) 33 | self.deconv3 = deconv(770,128) 34 | self.deconv2 = deconv(386,64) 35 | 36 | self.predict_flow6 = predict_flow(1024) 37 | self.predict_flow5 = predict_flow(1026) 38 | self.predict_flow4 = predict_flow(770) 39 | self.predict_flow3 = predict_flow(386) 40 | self.predict_flow2 = predict_flow(194) 41 | 42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 44 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False) 46 | 47 | for m in self.modules(): 48 | if isinstance(m, nn.Conv2d): 49 | if m.bias is not None: 50 | init.uniform_(m.bias) 51 | init.xavier_uniform_(m.weight) 52 | 53 | if isinstance(m, nn.ConvTranspose2d): 54 | if m.bias is not None: 55 | init.uniform_(m.bias) 56 | init.xavier_uniform_(m.weight) 57 | # init_deconv_bilinear(m.weight) 58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 59 | 60 | def forward(self, x): 61 | out_conv1 = self.conv1(x) 62 | 63 | out_conv2 = self.conv2(out_conv1) 64 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 65 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 66 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 67 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 68 | 69 | flow6 = self.predict_flow6(out_conv6) 70 | flow6_up = self.upsampled_flow6_to_5(flow6) 71 | out_deconv5 = self.deconv5(out_conv6) 72 | 73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 74 | flow5 = self.predict_flow5(concat5) 75 | flow5_up = self.upsampled_flow5_to_4(flow5) 76 | out_deconv4 = self.deconv4(concat5) 77 | 78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 79 | flow4 = self.predict_flow4(concat4) 80 | flow4_up = self.upsampled_flow4_to_3(flow4) 81 | out_deconv3 = self.deconv3(concat4) 82 | 83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 84 | flow3 = self.predict_flow3(concat3) 85 | flow3_up = self.upsampled_flow3_to_2(flow3) 86 | out_deconv2 = self.deconv2(concat3) 87 | 88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 89 | flow2 = self.predict_flow2(concat2) 90 | 91 | if self.training: 92 | return flow2,flow3,flow4,flow5,flow6 93 | else: 94 | return flow2, 95 | 96 | 
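Note on the FlowNet modules in this directory: FlowNetS above (like FlowNetC and FlowNetSD) returns the full five-level flow pyramid in training mode and a one-element tuple holding a 1/4-resolution flow map in eval mode. The following minimal smoke test is a sketch only: the import path, the empty args namespace, and the 6-channel configuration are illustrative assumptions (FlowNetS accepts args but never reads it), not something prescribed by the repository.

import argparse
import torch
from flownet2.networks.FlowNetS import FlowNetS  # assumes the repository root is on PYTHONPATH

args = argparse.Namespace()                      # FlowNetS takes args but does not read any field from it
net = FlowNetS(args, input_channels=6, batchNorm=True).eval()

frames = torch.randn(1, 6, 384, 512)             # two RGB frames stacked channel-wise; H and W divisible by 64
with torch.no_grad():
    (flow,) = net(frames)                        # eval mode returns a 1-tuple with flow at 1/4 resolution
print(flow.shape)                                # torch.Size([1, 2, 96, 128])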
-------------------------------------------------------------------------------- /flownet2/networks/FlowNetSD.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import init 4 | 5 | import math 6 | import numpy as np 7 | 8 | from .submodules import * 9 | 'Parameter count = 45,371,666' 10 | 11 | class FlowNetSD(nn.Module): 12 | def __init__(self, args, batchNorm=True): 13 | super(FlowNetSD,self).__init__() 14 | 15 | self.batchNorm = batchNorm 16 | self.conv0 = conv(self.batchNorm, 6, 64) 17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2) 18 | self.conv1_1 = conv(self.batchNorm, 64, 128) 19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2) 20 | self.conv2_1 = conv(self.batchNorm, 128, 128) 21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2) 22 | self.conv3_1 = conv(self.batchNorm, 256, 256) 23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2) 24 | self.conv4_1 = conv(self.batchNorm, 512, 512) 25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2) 26 | self.conv5_1 = conv(self.batchNorm, 512, 512) 27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2) 28 | self.conv6_1 = conv(self.batchNorm,1024, 1024) 29 | 30 | self.deconv5 = deconv(1024,512) 31 | self.deconv4 = deconv(1026,256) 32 | self.deconv3 = deconv(770,128) 33 | self.deconv2 = deconv(386,64) 34 | 35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512) 36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256) 37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128) 38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64) 39 | 40 | self.predict_flow6 = predict_flow(1024) 41 | self.predict_flow5 = predict_flow(512) 42 | self.predict_flow4 = predict_flow(256) 43 | self.predict_flow3 = predict_flow(128) 44 | self.predict_flow2 = predict_flow(64) 45 | 46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1) 50 | 51 | for m in self.modules(): 52 | if isinstance(m, nn.Conv2d): 53 | if m.bias is not None: 54 | init.uniform_(m.bias) 55 | init.xavier_uniform_(m.weight) 56 | 57 | if isinstance(m, nn.ConvTranspose2d): 58 | if m.bias is not None: 59 | init.uniform_(m.bias) 60 | init.xavier_uniform_(m.weight) 61 | # init_deconv_bilinear(m.weight) 62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear') 63 | 64 | 65 | 66 | def forward(self, x): 67 | out_conv0 = self.conv0(x) 68 | out_conv1 = self.conv1_1(self.conv1(out_conv0)) 69 | out_conv2 = self.conv2_1(self.conv2(out_conv1)) 70 | 71 | out_conv3 = self.conv3_1(self.conv3(out_conv2)) 72 | out_conv4 = self.conv4_1(self.conv4(out_conv3)) 73 | out_conv5 = self.conv5_1(self.conv5(out_conv4)) 74 | out_conv6 = self.conv6_1(self.conv6(out_conv5)) 75 | 76 | flow6 = self.predict_flow6(out_conv6) 77 | flow6_up = self.upsampled_flow6_to_5(flow6) 78 | out_deconv5 = self.deconv5(out_conv6) 79 | 80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1) 81 | out_interconv5 = self.inter_conv5(concat5) 82 | flow5 = self.predict_flow5(out_interconv5) 83 | 84 | flow5_up = self.upsampled_flow5_to_4(flow5) 85 | out_deconv4 = self.deconv4(concat5) 86 | 87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1) 88 | out_interconv4 = self.inter_conv4(concat4) 89 | flow4 = self.predict_flow4(out_interconv4) 90 | flow4_up = self.upsampled_flow4_to_3(flow4) 91 | out_deconv3 = 
self.deconv3(concat4) 92 | 93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1) 94 | out_interconv3 = self.inter_conv3(concat3) 95 | flow3 = self.predict_flow3(out_interconv3) 96 | flow3_up = self.upsampled_flow3_to_2(flow3) 97 | out_deconv2 = self.deconv2(concat3) 98 | 99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1) 100 | out_interconv2 = self.inter_conv2(concat2) 101 | flow2 = self.predict_flow2(out_interconv2) 102 | 103 | if self.training: 104 | return flow2,flow3,flow4,flow5,flow6 105 | else: 106 | return flow2, 107 | -------------------------------------------------------------------------------- /flownet2/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/__init__.py -------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/channelnorm_package/__init__.py -------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/channelnorm.py: -------------------------------------------------------------------------------- 1 | from torch.autograd import Function, Variable 2 | from torch.nn.modules.module import Module 3 | import channelnorm_cuda 4 | 5 | class ChannelNormFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, norm_deg=2): 9 | assert input1.is_contiguous() 10 | b, _, h, w = input1.size() 11 | output = input1.new(b, 1, h, w).zero_() 12 | 13 | channelnorm_cuda.forward(input1, output, norm_deg) 14 | ctx.save_for_backward(input1, output) 15 | ctx.norm_deg = norm_deg 16 | 17 | return output 18 | 19 | @staticmethod 20 | def backward(ctx, grad_output): 21 | input1, output = ctx.saved_tensors 22 | 23 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 24 | 25 | channelnorm_cuda.backward(input1, output, grad_output.data, 26 | grad_input1.data, ctx.norm_deg) 27 | 28 | return grad_input1, None 29 | 30 | 31 | class ChannelNorm(Module): 32 | 33 | def __init__(self, norm_deg=2): 34 | super(ChannelNorm, self).__init__() 35 | self.norm_deg = norm_deg 36 | 37 | def forward(self, input1): 38 | return ChannelNormFunction.apply(input1, self.norm_deg) 39 | 40 | -------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/channelnorm_cuda.cc: -------------------------------------------------------------------------------- 1 | #include <torch/torch.h> 2 | #include <ATen/ATen.h> 3 | 4 | #include "channelnorm_kernel.cuh" 5 | 6 | int channelnorm_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& output, 9 | int norm_deg) { 10 | 11 | channelnorm_kernel_forward(input1, output, norm_deg); 12 | return 1; 13 | } 14 | 15 | 16 | int channelnorm_cuda_backward( 17 | at::Tensor& input1, 18 | at::Tensor& output, 19 | at::Tensor& gradOutput, 20 | at::Tensor& gradInput1, 21 | int norm_deg) { 22 | 23 | channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg); 24 | return 1; 25 | } 26 | 27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 28 | m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)"); 29 | m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)"); 30 | } 31 | 32 |
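The ChannelNorm module above depends on the channelnorm_cuda extension being compiled first (python setup.py install inside channelnorm_package, which install.sh runs along with the correlation and resample2d packages). A minimal usage sketch follows, assuming a CUDA device, a built extension, and that the package path below is importable from the repository root; these details are assumptions for illustration only.

import torch
from flownet2.networks.channelnorm_package.channelnorm import ChannelNorm  # assumed import path

norm = ChannelNorm(norm_deg=2)                              # the CUDA kernel computes a per-pixel L2 norm over channels
x = torch.randn(2, 8, 64, 64, device='cuda').contiguous()   # forward() asserts a contiguous (CUDA) tensor
y = norm(x)
print(y.shape)                                              # torch.Size([2, 1, 64, 64])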
-------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/channelnorm_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "channelnorm_kernel.cuh" 5 | 6 | #define CUDA_NUM_THREADS 512 7 | 8 | #define DIM0(TENSOR) ((TENSOR).x) 9 | #define DIM1(TENSOR) ((TENSOR).y) 10 | #define DIM2(TENSOR) ((TENSOR).z) 11 | #define DIM3(TENSOR) ((TENSOR).w) 12 | 13 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 14 | 15 | using at::Half; 16 | 17 | template 18 | __global__ void kernel_channelnorm_update_output( 19 | const int n, 20 | const scalar_t* __restrict__ input1, 21 | const long4 input1_size, 22 | const long4 input1_stride, 23 | scalar_t* __restrict__ output, 24 | const long4 output_size, 25 | const long4 output_stride, 26 | int norm_deg) { 27 | 28 | int index = blockIdx.x * blockDim.x + threadIdx.x; 29 | 30 | if (index >= n) { 31 | return; 32 | } 33 | 34 | int dim_b = DIM0(output_size); 35 | int dim_c = DIM1(output_size); 36 | int dim_h = DIM2(output_size); 37 | int dim_w = DIM3(output_size); 38 | int dim_chw = dim_c * dim_h * dim_w; 39 | 40 | int b = ( index / dim_chw ) % dim_b; 41 | int y = ( index / dim_w ) % dim_h; 42 | int x = ( index ) % dim_w; 43 | 44 | int i1dim_c = DIM1(input1_size); 45 | int i1dim_h = DIM2(input1_size); 46 | int i1dim_w = DIM3(input1_size); 47 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w; 48 | int i1dim_hw = i1dim_h * i1dim_w; 49 | 50 | float result = 0.0; 51 | 52 | for (int c = 0; c < i1dim_c; ++c) { 53 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x; 54 | scalar_t val = input1[i1Index]; 55 | result += static_cast(val * val); 56 | } 57 | result = sqrt(result); 58 | output[index] = static_cast(result); 59 | } 60 | 61 | 62 | template 63 | __global__ void kernel_channelnorm_backward_input1( 64 | const int n, 65 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 66 | const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, 67 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 68 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, 69 | int norm_deg) { 70 | 71 | int index = blockIdx.x * blockDim.x + threadIdx.x; 72 | 73 | if (index >= n) { 74 | return; 75 | } 76 | 77 | float val = 0.0; 78 | 79 | int dim_b = DIM0(gradInput_size); 80 | int dim_c = DIM1(gradInput_size); 81 | int dim_h = DIM2(gradInput_size); 82 | int dim_w = DIM3(gradInput_size); 83 | int dim_chw = dim_c * dim_h * dim_w; 84 | int dim_hw = dim_h * dim_w; 85 | 86 | int b = ( index / dim_chw ) % dim_b; 87 | int y = ( index / dim_w ) % dim_h; 88 | int x = ( index ) % dim_w; 89 | 90 | 91 | int outIndex = b * dim_hw + y * dim_w + x; 92 | val = static_cast(gradOutput[outIndex]) * static_cast(input1[index]) / (static_cast(output[outIndex])+1e-9); 93 | gradInput[index] = static_cast(val); 94 | 95 | } 96 | 97 | void channelnorm_kernel_forward( 98 | at::Tensor& input1, 99 | at::Tensor& output, 100 | int norm_deg) { 101 | 102 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 103 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 104 | 105 | const long4 
output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 106 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 107 | 108 | int n = output.numel(); 109 | 110 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] { 111 | 112 | kernel_channelnorm_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( 113 | n, 114 | input1.data(), 115 | input1_size, 116 | input1_stride, 117 | output.data(), 118 | output_size, 119 | output_stride, 120 | norm_deg); 121 | 122 | })); 123 | 124 | // TODO: ATen-equivalent check 125 | 126 | // THCudaCheck(cudaGetLastError()); 127 | } 128 | 129 | void channelnorm_kernel_backward( 130 | at::Tensor& input1, 131 | at::Tensor& output, 132 | at::Tensor& gradOutput, 133 | at::Tensor& gradInput1, 134 | int norm_deg) { 135 | 136 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 137 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 138 | 139 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 140 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 141 | 142 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 143 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 144 | 145 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 146 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 147 | 148 | int n = gradInput1.numel(); 149 | 150 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] { 151 | 152 | kernel_channelnorm_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( 153 | n, 154 | input1.data(), 155 | input1_size, 156 | input1_stride, 157 | output.data(), 158 | output_size, 159 | output_stride, 160 | gradOutput.data(), 161 | gradOutput_size, 162 | gradOutput_stride, 163 | gradInput1.data(), 164 | gradInput1_size, 165 | gradInput1_stride, 166 | norm_deg 167 | ); 168 | 169 | })); 170 | 171 | // TODO: Add ATen-equivalent check 172 | 173 | // THCudaCheck(cudaGetLastError()); 174 | } 175 | -------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/channelnorm_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void channelnorm_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& output, 8 | int norm_deg); 9 | 10 | 11 | void channelnorm_kernel_backward( 12 | at::Tensor& input1, 13 | at::Tensor& output, 14 | at::Tensor& gradOutput, 15 | at::Tensor& gradInput1, 16 | int norm_deg); 17 | -------------------------------------------------------------------------------- /flownet2/networks/channelnorm_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from 
torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_52,code=sm_52', 12 | '-gencode', 'arch=compute_60,code=sm_60', 13 | '-gencode', 'arch=compute_61,code=sm_61', 14 | '-gencode', 'arch=compute_70,code=sm_70', 15 | '-gencode', 'arch=compute_70,code=compute_70' 16 | ] 17 | 18 | setup( 19 | name='channelnorm_cuda', 20 | ext_modules=[ 21 | CUDAExtension('channelnorm_cuda', [ 22 | 'channelnorm_cuda.cc', 23 | 'channelnorm_kernel.cu' 24 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 25 | ], 26 | cmdclass={ 27 | 'build_ext': BuildExtension 28 | }) 29 | -------------------------------------------------------------------------------- /flownet2/networks/correlation_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/correlation_package/__init__.py -------------------------------------------------------------------------------- /flownet2/networks/correlation_package/correlation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn.modules.module import Module 3 | from torch.autograd import Function 4 | import correlation_cuda 5 | 6 | class CorrelationFunction(Function): 7 | 8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1): 9 | super(CorrelationFunction, self).__init__() 10 | self.pad_size = pad_size 11 | self.kernel_size = kernel_size 12 | self.max_displacement = max_displacement 13 | self.stride1 = stride1 14 | self.stride2 = stride2 15 | self.corr_multiply = corr_multiply 16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1) 17 | 18 | def forward(self, input1, input2): 19 | self.save_for_backward(input1, input2) 20 | 21 | with torch.cuda.device_of(input1): 22 | rbot1 = input1.new() 23 | rbot2 = input2.new() 24 | output = input1.new() 25 | 26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output, 27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 28 | 29 | return output 30 | 31 | def backward(self, grad_output): 32 | input1, input2 = self.saved_tensors 33 | 34 | with torch.cuda.device_of(input1): 35 | rbot1 = input1.new() 36 | rbot2 = input2.new() 37 | 38 | grad_input1 = input1.new() 39 | grad_input2 = input2.new() 40 | 41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2, 42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply) 43 | 44 | return grad_input1, grad_input2 45 | 46 | 47 | class Correlation(Module): 48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1): 49 | super(Correlation, self).__init__() 50 | self.pad_size = pad_size 51 | self.kernel_size = kernel_size 52 | self.max_displacement = max_displacement 53 | self.stride1 = stride1 54 | self.stride2 = stride2 55 | self.corr_multiply = corr_multiply 56 | 57 | def forward(self, input1, input2): 58 | 59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2) 60 | 61 | return result 62 | 63 | -------------------------------------------------------------------------------- 
/flownet2/networks/correlation_package/correlation_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include 5 | 6 | #include "correlation_cuda_kernel.cuh" 7 | 8 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output, 9 | int pad_size, 10 | int kernel_size, 11 | int max_displacement, 12 | int stride1, 13 | int stride2, 14 | int corr_type_multiply) 15 | { 16 | 17 | int batchSize = input1.size(0); 18 | 19 | int nInputChannels = input1.size(1); 20 | int inputHeight = input1.size(2); 21 | int inputWidth = input1.size(3); 22 | 23 | int kernel_radius = (kernel_size - 1) / 2; 24 | int border_radius = kernel_radius + max_displacement; 25 | 26 | int paddedInputHeight = inputHeight + 2 * pad_size; 27 | int paddedInputWidth = inputWidth + 2 * pad_size; 28 | 29 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1); 30 | 31 | int outputHeight = ceil(static_cast(paddedInputHeight - 2 * border_radius) / static_cast(stride1)); 32 | int outputwidth = ceil(static_cast(paddedInputWidth - 2 * border_radius) / static_cast(stride1)); 33 | 34 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 35 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 36 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth}); 37 | 38 | rInput1.fill_(0); 39 | rInput2.fill_(0); 40 | output.fill_(0); 41 | 42 | int success = correlation_forward_cuda_kernel( 43 | output, 44 | output.size(0), 45 | output.size(1), 46 | output.size(2), 47 | output.size(3), 48 | output.stride(0), 49 | output.stride(1), 50 | output.stride(2), 51 | output.stride(3), 52 | input1, 53 | input1.size(1), 54 | input1.size(2), 55 | input1.size(3), 56 | input1.stride(0), 57 | input1.stride(1), 58 | input1.stride(2), 59 | input1.stride(3), 60 | input2, 61 | input2.size(1), 62 | input2.stride(0), 63 | input2.stride(1), 64 | input2.stride(2), 65 | input2.stride(3), 66 | rInput1, 67 | rInput2, 68 | pad_size, 69 | kernel_size, 70 | max_displacement, 71 | stride1, 72 | stride2, 73 | corr_type_multiply, 74 | at::globalContext().getCurrentCUDAStream() 75 | ); 76 | 77 | //check for errors 78 | if (!success) { 79 | AT_ERROR("CUDA call failed"); 80 | } 81 | 82 | return 1; 83 | 84 | } 85 | 86 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput, 87 | at::Tensor& gradInput1, at::Tensor& gradInput2, 88 | int pad_size, 89 | int kernel_size, 90 | int max_displacement, 91 | int stride1, 92 | int stride2, 93 | int corr_type_multiply) 94 | { 95 | 96 | int batchSize = input1.size(0); 97 | int nInputChannels = input1.size(1); 98 | int paddedInputHeight = input1.size(2)+ 2 * pad_size; 99 | int paddedInputWidth = input1.size(3)+ 2 * pad_size; 100 | 101 | int height = input1.size(2); 102 | int width = input1.size(3); 103 | 104 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 105 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels}); 106 | gradInput1.resize_({batchSize, nInputChannels, height, width}); 107 | gradInput2.resize_({batchSize, nInputChannels, height, width}); 108 | 109 | rInput1.fill_(0); 110 | rInput2.fill_(0); 111 | gradInput1.fill_(0); 112 | gradInput2.fill_(0); 113 | 114 | int success = correlation_backward_cuda_kernel(gradOutput, 115 | 
gradOutput.size(0), 116 | gradOutput.size(1), 117 | gradOutput.size(2), 118 | gradOutput.size(3), 119 | gradOutput.stride(0), 120 | gradOutput.stride(1), 121 | gradOutput.stride(2), 122 | gradOutput.stride(3), 123 | input1, 124 | input1.size(1), 125 | input1.size(2), 126 | input1.size(3), 127 | input1.stride(0), 128 | input1.stride(1), 129 | input1.stride(2), 130 | input1.stride(3), 131 | input2, 132 | input2.stride(0), 133 | input2.stride(1), 134 | input2.stride(2), 135 | input2.stride(3), 136 | gradInput1, 137 | gradInput1.stride(0), 138 | gradInput1.stride(1), 139 | gradInput1.stride(2), 140 | gradInput1.stride(3), 141 | gradInput2, 142 | gradInput2.size(1), 143 | gradInput2.stride(0), 144 | gradInput2.stride(1), 145 | gradInput2.stride(2), 146 | gradInput2.stride(3), 147 | rInput1, 148 | rInput2, 149 | pad_size, 150 | kernel_size, 151 | max_displacement, 152 | stride1, 153 | stride2, 154 | corr_type_multiply, 155 | at::globalContext().getCurrentCUDAStream() 156 | ); 157 | 158 | if (!success) { 159 | AT_ERROR("CUDA call failed"); 160 | } 161 | 162 | return 1; 163 | } 164 | 165 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 166 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)"); 167 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)"); 168 | } 169 | 170 | -------------------------------------------------------------------------------- /flownet2/networks/correlation_package/correlation_cuda_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | #include 5 | #include 6 | 7 | int correlation_forward_cuda_kernel(at::Tensor& output, 8 | int ob, 9 | int oc, 10 | int oh, 11 | int ow, 12 | int osb, 13 | int osc, 14 | int osh, 15 | int osw, 16 | 17 | at::Tensor& input1, 18 | int ic, 19 | int ih, 20 | int iw, 21 | int isb, 22 | int isc, 23 | int ish, 24 | int isw, 25 | 26 | at::Tensor& input2, 27 | int gc, 28 | int gsb, 29 | int gsc, 30 | int gsh, 31 | int gsw, 32 | 33 | at::Tensor& rInput1, 34 | at::Tensor& rInput2, 35 | int pad_size, 36 | int kernel_size, 37 | int max_displacement, 38 | int stride1, 39 | int stride2, 40 | int corr_type_multiply, 41 | cudaStream_t stream); 42 | 43 | 44 | int correlation_backward_cuda_kernel( 45 | at::Tensor& gradOutput, 46 | int gob, 47 | int goc, 48 | int goh, 49 | int gow, 50 | int gosb, 51 | int gosc, 52 | int gosh, 53 | int gosw, 54 | 55 | at::Tensor& input1, 56 | int ic, 57 | int ih, 58 | int iw, 59 | int isb, 60 | int isc, 61 | int ish, 62 | int isw, 63 | 64 | at::Tensor& input2, 65 | int gsb, 66 | int gsc, 67 | int gsh, 68 | int gsw, 69 | 70 | at::Tensor& gradInput1, 71 | int gisb, 72 | int gisc, 73 | int gish, 74 | int gisw, 75 | 76 | at::Tensor& gradInput2, 77 | int ggc, 78 | int ggsb, 79 | int ggsc, 80 | int ggsh, 81 | int ggsw, 82 | 83 | at::Tensor& rInput1, 84 | at::Tensor& rInput2, 85 | int pad_size, 86 | int kernel_size, 87 | int max_displacement, 88 | int stride1, 89 | int stride2, 90 | int corr_type_multiply, 91 | cudaStream_t stream); 92 | -------------------------------------------------------------------------------- /flownet2/networks/correlation_package/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup, find_packages 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 
'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='correlation_cuda', 21 | ext_modules=[ 22 | CUDAExtension('correlation_cuda', [ 23 | 'correlation_cuda.cc', 24 | 'correlation_cuda_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/resample2d_package/__init__.py -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/resample2d.py: -------------------------------------------------------------------------------- 1 | from torch.nn.modules.module import Module 2 | from torch.autograd import Function, Variable 3 | import resample2d_cuda 4 | 5 | class Resample2dFunction(Function): 6 | 7 | @staticmethod 8 | def forward(ctx, input1, input2, kernel_size=1): 9 | assert input1.is_contiguous() 10 | assert input2.is_contiguous() 11 | 12 | ctx.save_for_backward(input1, input2) 13 | ctx.kernel_size = kernel_size 14 | 15 | _, d, _, _ = input1.size() 16 | b, _, h, w = input2.size() 17 | output = input1.new(b, d, h, w).zero_() 18 | 19 | resample2d_cuda.forward(input1, input2, output, kernel_size) 20 | 21 | return output 22 | 23 | @staticmethod 24 | def backward(ctx, grad_output): 25 | assert grad_output.is_contiguous() 26 | 27 | input1, input2 = ctx.saved_tensors 28 | 29 | grad_input1 = Variable(input1.new(input1.size()).zero_()) 30 | grad_input2 = Variable(input1.new(input2.size()).zero_()) 31 | 32 | resample2d_cuda.backward(input1, input2, grad_output.data, 33 | grad_input1.data, grad_input2.data, 34 | ctx.kernel_size) 35 | 36 | return grad_input1, grad_input2, None 37 | 38 | class Resample2d(Module): 39 | 40 | def __init__(self, kernel_size=1): 41 | super(Resample2d, self).__init__() 42 | self.kernel_size = kernel_size 43 | 44 | def forward(self, input1, input2): 45 | input1_c = input1.contiguous() 46 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size) 47 | -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/resample2d_cuda.cc: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #include "resample2d_kernel.cuh" 5 | 6 | int resample2d_cuda_forward( 7 | at::Tensor& input1, 8 | at::Tensor& input2, 9 | at::Tensor& output, 10 | int kernel_size) { 11 | resample2d_kernel_forward(input1, input2, output, kernel_size); 12 | return 1; 13 | } 14 | 15 | int resample2d_cuda_backward( 16 | at::Tensor& input1, 17 | at::Tensor& input2, 18 | at::Tensor& gradOutput, 19 | at::Tensor& gradInput1, 20 | at::Tensor& gradInput2, 21 | int kernel_size) { 22 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size); 23 | return 1; 24 | } 25 | 26 | 27 | 28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 29 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)"); 30 | m.def("backward", &resample2d_cuda_backward, 
"Resample2D backward (CUDA)"); 31 | } 32 | 33 | -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/resample2d_kernel.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | #define CUDA_NUM_THREADS 512 5 | #define THREADS_PER_BLOCK 64 6 | 7 | #define DIM0(TENSOR) ((TENSOR).x) 8 | #define DIM1(TENSOR) ((TENSOR).y) 9 | #define DIM2(TENSOR) ((TENSOR).z) 10 | #define DIM3(TENSOR) ((TENSOR).w) 11 | 12 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))]) 13 | 14 | template 15 | __global__ void kernel_resample2d_update_output(const int n, 16 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 17 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 18 | scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size) { 19 | int index = blockIdx.x * blockDim.x + threadIdx.x; 20 | 21 | if (index >= n) { 22 | return; 23 | } 24 | 25 | scalar_t val = 0.0f; 26 | 27 | int dim_b = DIM0(output_size); 28 | int dim_c = DIM1(output_size); 29 | int dim_h = DIM2(output_size); 30 | int dim_w = DIM3(output_size); 31 | int dim_chw = dim_c * dim_h * dim_w; 32 | int dim_hw = dim_h * dim_w; 33 | 34 | int b = ( index / dim_chw ) % dim_b; 35 | int c = ( index / dim_hw ) % dim_c; 36 | int y = ( index / dim_w ) % dim_h; 37 | int x = ( index ) % dim_w; 38 | 39 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 40 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 41 | 42 | scalar_t xf = static_cast(x) + dx; 43 | scalar_t yf = static_cast(y) + dy; 44 | scalar_t alpha = xf - floor(xf); // alpha 45 | scalar_t beta = yf - floor(yf); // beta 46 | 47 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 48 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 49 | int yT = max(min( int (floor(yf)), dim_h-1), 0); 50 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 51 | 52 | for (int fy = 0; fy < kernel_size; fy += 1) { 53 | for (int fx = 0; fx < kernel_size; fx += 1) { 54 | val += static_cast((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx)); 55 | val += static_cast((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx)); 56 | val += static_cast((1. 
- alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx)); 57 | val += static_cast((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx)); 58 | } 59 | } 60 | 61 | output[index] = val; 62 | 63 | } 64 | 65 | 66 | template 67 | __global__ void kernel_resample2d_backward_input1( 68 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 69 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 70 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 71 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { 72 | 73 | int index = blockIdx.x * blockDim.x + threadIdx.x; 74 | 75 | if (index >= n) { 76 | return; 77 | } 78 | 79 | int dim_b = DIM0(gradOutput_size); 80 | int dim_c = DIM1(gradOutput_size); 81 | int dim_h = DIM2(gradOutput_size); 82 | int dim_w = DIM3(gradOutput_size); 83 | int dim_chw = dim_c * dim_h * dim_w; 84 | int dim_hw = dim_h * dim_w; 85 | 86 | int b = ( index / dim_chw ) % dim_b; 87 | int c = ( index / dim_hw ) % dim_c; 88 | int y = ( index / dim_w ) % dim_h; 89 | int x = ( index ) % dim_w; 90 | 91 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 92 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 93 | 94 | scalar_t xf = static_cast(x) + dx; 95 | scalar_t yf = static_cast(y) + dy; 96 | scalar_t alpha = xf - int(xf); // alpha 97 | scalar_t beta = yf - int(yf); // beta 98 | 99 | int idim_h = DIM2(input1_size); 100 | int idim_w = DIM3(input1_size); 101 | 102 | int xL = max(min( int (floor(xf)), idim_w-1), 0); 103 | int xR = max(min( int (floor(xf)+1), idim_w -1), 0); 104 | int yT = max(min( int (floor(yf)), idim_h-1), 0); 105 | int yB = max(min( int (floor(yf)+1), idim_h-1), 0); 106 | 107 | for (int fy = 0; fy < kernel_size; fy += 1) { 108 | for (int fx = 0; fx < kernel_size; fx += 1) { 109 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 110 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 111 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 112 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x)); 113 | } 114 | } 115 | 116 | } 117 | 118 | template 119 | __global__ void kernel_resample2d_backward_input2( 120 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride, 121 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride, 122 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride, 123 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) { 124 | 125 | int index = blockIdx.x * blockDim.x + threadIdx.x; 126 | 127 | if (index >= n) { 128 | return; 129 | } 130 | 131 | scalar_t output = 0.0; 132 | int kernel_rad = (kernel_size - 1)/2; 133 | 134 | int dim_b = DIM0(gradInput_size); 135 | int dim_c = DIM1(gradInput_size); 136 | int dim_h = DIM2(gradInput_size); 137 | int dim_w = DIM3(gradInput_size); 138 | int dim_chw = dim_c * dim_h * dim_w; 139 | int dim_hw = dim_h * dim_w; 140 | 141 | int b = ( index / dim_chw ) % dim_b; 142 | int c = ( index / dim_hw ) % dim_c; 143 | int y = ( index / dim_w ) % dim_h; 144 | int x = ( index ) % 
dim_w; 145 | 146 | int odim_c = DIM1(gradOutput_size); 147 | 148 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x); 149 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x); 150 | 151 | scalar_t xf = static_cast(x) + dx; 152 | scalar_t yf = static_cast(y) + dy; 153 | 154 | int xL = max(min( int (floor(xf)), dim_w-1), 0); 155 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0); 156 | int yT = max(min( int (floor(yf)), dim_h-1), 0); 157 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0); 158 | 159 | if (c % 2) { 160 | float gamma = 1 - (xf - floor(xf)); // alpha 161 | for (int i = 0; i <= 2*kernel_rad; ++i) { 162 | for (int j = 0; j <= 2*kernel_rad; ++j) { 163 | for (int ch = 0; ch < odim_c; ++ch) { 164 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 165 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 166 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 167 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 168 | } 169 | } 170 | } 171 | } 172 | else { 173 | float gamma = 1 - (yf - floor(yf)); // alpha 174 | for (int i = 0; i <= 2*kernel_rad; ++i) { 175 | for (int j = 0; j <= 2*kernel_rad; ++j) { 176 | for (int ch = 0; ch < odim_c; ++ch) { 177 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i)); 178 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i)); 179 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i)); 180 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i)); 181 | } 182 | } 183 | } 184 | 185 | } 186 | 187 | gradInput[index] = output; 188 | 189 | } 190 | 191 | void resample2d_kernel_forward( 192 | at::Tensor& input1, 193 | at::Tensor& input2, 194 | at::Tensor& output, 195 | int kernel_size) { 196 | 197 | int n = output.numel(); 198 | 199 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 200 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 201 | 202 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 203 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 204 | 205 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3)); 206 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3)); 207 | 208 | // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF 209 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] { 210 | 211 | kernel_resample2d_update_output<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( 212 | n, 213 | input1.data(), 214 | input1_size, 215 | input1_stride, 216 | input2.data(), 217 | input2_size, 218 | input2_stride, 219 | output.data(), 220 | output_size, 221 | output_stride, 222 | kernel_size); 223 | 224 | // })); 225 | 226 | // TODO: ATen-equivalent check 227 | 228 | // THCudaCheck(cudaGetLastError()); 229 | 230 | } 231 | 232 | void resample2d_kernel_backward( 233 | at::Tensor& input1, 
234 | at::Tensor& input2, 235 | at::Tensor& gradOutput, 236 | at::Tensor& gradInput1, 237 | at::Tensor& gradInput2, 238 | int kernel_size) { 239 | 240 | int n = gradOutput.numel(); 241 | 242 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3)); 243 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3)); 244 | 245 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3)); 246 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3)); 247 | 248 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3)); 249 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3)); 250 | 251 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3)); 252 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3)); 253 | 254 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] { 255 | 256 | kernel_resample2d_backward_input1<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( 257 | n, 258 | input1.data(), 259 | input1_size, 260 | input1_stride, 261 | input2.data(), 262 | input2_size, 263 | input2_stride, 264 | gradOutput.data(), 265 | gradOutput_size, 266 | gradOutput_stride, 267 | gradInput1.data(), 268 | gradInput1_size, 269 | gradInput1_stride, 270 | kernel_size 271 | ); 272 | 273 | // })); 274 | 275 | const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3)); 276 | const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3)); 277 | 278 | n = gradInput2.numel(); 279 | 280 | // AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] { 281 | 282 | 283 | kernel_resample2d_backward_input2<<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>( 284 | n, 285 | input1.data(), 286 | input1_size, 287 | input1_stride, 288 | input2.data(), 289 | input2_size, 290 | input2_stride, 291 | gradOutput.data(), 292 | gradOutput_size, 293 | gradOutput_stride, 294 | gradInput2.data(), 295 | gradInput2_size, 296 | gradInput2_stride, 297 | kernel_size 298 | ); 299 | 300 | // })); 301 | 302 | // TODO: Use the ATen equivalent to get last error 303 | 304 | // THCudaCheck(cudaGetLastError()); 305 | 306 | } 307 | -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/resample2d_kernel.cuh: -------------------------------------------------------------------------------- 1 | #pragma once 2 | 3 | #include 4 | 5 | void resample2d_kernel_forward( 6 | at::Tensor& input1, 7 | at::Tensor& input2, 8 | at::Tensor& output, 9 | int kernel_size); 10 | 11 | void resample2d_kernel_backward( 12 | at::Tensor& input1, 13 | at::Tensor& input2, 14 | at::Tensor& gradOutput, 15 | at::Tensor& gradInput1, 16 | at::Tensor& gradInput2, 17 | int kernel_size); 18 | -------------------------------------------------------------------------------- /flownet2/networks/resample2d_package/setup.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | import os 3 | import torch 4 | 5 | from setuptools import setup 6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 7 | 8 | cxx_args = ['-std=c++11'] 9 | 10 | nvcc_args = [ 11 | '-gencode', 'arch=compute_50,code=sm_50', 12 | '-gencode', 'arch=compute_52,code=sm_52', 13 | '-gencode', 'arch=compute_60,code=sm_60', 14 | '-gencode', 'arch=compute_61,code=sm_61', 15 | '-gencode', 'arch=compute_70,code=sm_70', 16 | '-gencode', 'arch=compute_70,code=compute_70' 17 | ] 18 | 19 | setup( 20 | name='resample2d_cuda', 21 | ext_modules=[ 22 | CUDAExtension('resample2d_cuda', [ 23 | 'resample2d_cuda.cc', 24 | 'resample2d_kernel.cu' 25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args}) 26 | ], 27 | cmdclass={ 28 | 'build_ext': BuildExtension 29 | }) 30 | -------------------------------------------------------------------------------- /flownet2/networks/submodules.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import torch.nn as nn 4 | import torch 5 | import numpy as np 6 | 7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1): 8 | if batchNorm: 9 | return nn.Sequential( 10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False), 11 | nn.BatchNorm2d(out_planes), 12 | nn.LeakyReLU(0.1,inplace=True) 13 | ) 14 | else: 15 | return nn.Sequential( 16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True), 17 | nn.LeakyReLU(0.1,inplace=True) 18 | ) 19 | 20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True): 21 | if batchNorm: 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 24 | nn.BatchNorm2d(out_planes), 25 | ) 26 | else: 27 | return nn.Sequential( 28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias), 29 | ) 30 | 31 | def predict_flow(in_planes): 32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True) 33 | 34 | def deconv(in_planes, out_planes): 35 | return nn.Sequential( 36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True), 37 | nn.LeakyReLU(0.1,inplace=True) 38 | ) 39 | 40 | class tofp16(nn.Module): 41 | def __init__(self): 42 | super(tofp16, self).__init__() 43 | 44 | def forward(self, input): 45 | return input.half() 46 | 47 | 48 | class tofp32(nn.Module): 49 | def __init__(self): 50 | super(tofp32, self).__init__() 51 | 52 | def forward(self, input): 53 | return input.float() 54 | 55 | 56 | def init_deconv_bilinear(weight): 57 | f_shape = weight.size() 58 | heigh, width = f_shape[-2], f_shape[-1] 59 | f = np.ceil(width/2.0) 60 | c = (2 * f - 1 - f % 2) / (2.0 * f) 61 | bilinear = np.zeros([heigh, width]) 62 | for x in range(width): 63 | for y in range(heigh): 64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c)) 65 | bilinear[x, y] = value 66 | weight.data.fill_(0.) 
67 | for i in range(f_shape[0]): 68 | for j in range(f_shape[1]): 69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear) 70 | 71 | 72 | def save_grad(grads, name): 73 | def hook(grad): 74 | grads[name] = grad 75 | return hook 76 | 77 | ''' 78 | def save_grad(grads, name): 79 | def hook(grad): 80 | grads[name] = grad 81 | return hook 82 | import torch 83 | from channelnorm_package.modules.channelnorm import ChannelNorm 84 | model = ChannelNorm().cuda() 85 | grads = {} 86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True) 87 | a.register_hook(save_grad(grads, 'a')) 88 | b = model(a) 89 | y = torch.mean(b) 90 | y.backward() 91 | 92 | ''' 93 | -------------------------------------------------------------------------------- /flownet2/run-caffe2pytorch.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | FN2PYTORCH=${1:-/} 4 | 5 | # install custom layers 6 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 . 7 | sudo nvidia-docker run --rm -ti --volume=${FN2PYTORCH}:/flownet2-pytorch:rw --workdir=/flownet2-pytorch $USER/pytorch:CUDA8-py27 /bin/bash -c "./install.sh" 8 | 9 | # convert FlowNet2-C, CS, CSS, CSS-ft-sd, SD, S and 2 to PyTorch 10 | sudo nvidia-docker run -ti --volume=${FN2PYTORCH}:/fn2pytorch:rw flownet2:latest /bin/bash -c "source /flownet2/flownet2/set-env.sh && cd /flownet2/flownet2/models && \ 11 | python /fn2pytorch/convert.py ./FlowNet2-C/FlowNet2-C_weights.caffemodel ./FlowNet2-C/FlowNet2-C_deploy.prototxt.template /fn2pytorch && 12 | python /fn2pytorch/convert.py ./FlowNet2-CS/FlowNet2-CS_weights.caffemodel ./FlowNet2-CS/FlowNet2-CS_deploy.prototxt.template /fn2pytorch && \ 13 | python /fn2pytorch/convert.py ./FlowNet2-CSS/FlowNet2-CSS_weights.caffemodel.h5 ./FlowNet2-CSS/FlowNet2-CSS_deploy.prototxt.template /fn2pytorch && \ 14 | python /fn2pytorch/convert.py ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_weights.caffemodel.h5 ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_deploy.prototxt.template /fn2pytorch && \ 15 | python /fn2pytorch/convert.py ./FlowNet2-SD/FlowNet2-SD_weights.caffemodel.h5 ./FlowNet2-SD/FlowNet2-SD_deploy.prototxt.template /fn2pytorch && \ 16 | python /fn2pytorch/convert.py ./FlowNet2-S/FlowNet2-S_weights.caffemodel.h5 ./FlowNet2-S/FlowNet2-S_deploy.prototxt.template /fn2pytorch && \ 17 | python /fn2pytorch/convert.py ./FlowNet2/FlowNet2_weights.caffemodel.h5 ./FlowNet2/FlowNet2_deploy.prototxt.template /fn2pytorch" -------------------------------------------------------------------------------- /flownet2/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/utils/__init__.py -------------------------------------------------------------------------------- /flownet2/utils/flow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | TAG_CHAR = np.array([202021.25], np.float32) 4 | 5 | def readFlow(fn): 6 | """ Read .flo file in Middlebury format""" 7 | # Code adapted from: 8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy 9 | 10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only! 11 | # print 'fn = %s'%(fn) 12 | with open(fn, 'rb') as f: 13 | magic = np.fromfile(f, np.float32, count=1) 14 | if 202021.25 != magic: 15 | print('Magic number incorrect. 
Invalid .flo file') 16 | return None 17 | else: 18 | w = np.fromfile(f, np.int32, count=1) 19 | h = np.fromfile(f, np.int32, count=1) 20 | # print 'Reading %d x %d flo file\n' % (w, h) 21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) 22 | # Reshape data into 3D array (columns, rows, bands) 23 | # The reshape here is for visualization, the original code is (w,h,2) 24 | return np.resize(data, (int(h), int(w), 2)) 25 | 26 | def writeFlow(filename,uv,v=None): 27 | """ Write optical flow to file. 28 | 29 | If v is None, uv is assumed to contain both u and v channels, 30 | stacked in depth. 31 | Original code by Deqing Sun, adapted from Daniel Scharstein. 32 | """ 33 | nBands = 2 34 | 35 | if v is None: 36 | assert(uv.ndim == 3) 37 | assert(uv.shape[2] == 2) 38 | u = uv[:,:,0] 39 | v = uv[:,:,1] 40 | else: 41 | u = uv 42 | 43 | assert(u.shape == v.shape) 44 | height,width = u.shape 45 | f = open(filename,'wb') 46 | # write the header 47 | f.write(TAG_CHAR) 48 | np.array(width).astype(np.int32).tofile(f) 49 | np.array(height).astype(np.int32).tofile(f) 50 | # arrange into matrix form 51 | tmp = np.zeros((height, width*nBands)) 52 | tmp[:,np.arange(width)*2] = u 53 | tmp[:,np.arange(width)*2 + 1] = v 54 | tmp.astype(np.float32).tofile(f) 55 | f.close() 56 | -------------------------------------------------------------------------------- /flownet2/utils/frame_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from os.path import * 3 | from scipy.misc import imread 4 | from . import flow_utils 5 | 6 | def read_gen(file_name): 7 | ext = splitext(file_name)[-1] 8 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': 9 | im = imread(file_name) 10 | if im.shape[2] > 3: 11 | return im[:,:,:3] 12 | else: 13 | return im 14 | elif ext == '.bin' or ext == '.raw': 15 | return np.load(file_name) 16 | elif ext == '.flo': 17 | return flow_utils.readFlow(file_name).astype(np.float32) 18 | return [] 19 | -------------------------------------------------------------------------------- /flownet2/utils/param_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | 5 | def parse_flownetc(modules, weights, biases): 6 | keys = [ 7 | 'conv1', 8 | 'conv2', 9 | 'conv3', 10 | 'conv_redir', 11 | 'conv3_1', 12 | 'conv4', 13 | 'conv4_1', 14 | 'conv5', 15 | 'conv5_1', 16 | 'conv6', 17 | 'conv6_1', 18 | 19 | 'deconv5', 20 | 'deconv4', 21 | 'deconv3', 22 | 'deconv2', 23 | 24 | 'Convolution1', 25 | 'Convolution2', 26 | 'Convolution3', 27 | 'Convolution4', 28 | 'Convolution5', 29 | 30 | 'upsample_flow6to5', 31 | 'upsample_flow5to4', 32 | 'upsample_flow4to3', 33 | 'upsample_flow3to2', 34 | 35 | ] 36 | i = 0 37 | for m in modules: 38 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 39 | weight = weights[keys[i]].copy() 40 | bias = biases[keys[i]].copy() 41 | if keys[i] == 'conv1': 42 | m.weight.data[:,:,:,:] = torch.from_numpy(np.flip(weight, axis=1).copy()) 43 | m.bias.data[:] = torch.from_numpy(bias) 44 | else: 45 | m.weight.data[:,:,:,:] = torch.from_numpy(weight) 46 | m.bias.data[:] = torch.from_numpy(bias) 47 | 48 | i = i + 1 49 | return 50 | 51 | def parse_flownets(modules, weights, biases, param_prefix='net2_'): 52 | keys = [ 53 | 'conv1', 54 | 'conv2', 55 | 'conv3', 56 | 'conv3_1', 57 | 'conv4', 58 | 'conv4_1', 59 | 'conv5', 60 | 'conv5_1', 61 | 'conv6', 62 | 'conv6_1', 63 | 64 | 'deconv5', 65 | 'deconv4', 66 
| 'deconv3', 67 | 'deconv2', 68 | 69 | 'predict_conv6', 70 | 'predict_conv5', 71 | 'predict_conv4', 72 | 'predict_conv3', 73 | 'predict_conv2', 74 | 75 | 'upsample_flow6to5', 76 | 'upsample_flow5to4', 77 | 'upsample_flow4to3', 78 | 'upsample_flow3to2', 79 | ] 80 | for i, k in enumerate(keys): 81 | if 'upsample' in k: 82 | keys[i] = param_prefix + param_prefix + k 83 | else: 84 | keys[i] = param_prefix + k 85 | i = 0 86 | for m in modules: 87 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 88 | weight = weights[keys[i]].copy() 89 | bias = biases[keys[i]].copy() 90 | if keys[i] == param_prefix+'conv1': 91 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy()) 92 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy()) 93 | m.weight.data[:,6:9,:,:] = torch.from_numpy(np.flip(weight[:,6:9,:,:], axis=1).copy()) 94 | m.weight.data[:,9::,:,:] = torch.from_numpy(weight[:,9:,:,:].copy()) 95 | if m.bias is not None: 96 | m.bias.data[:] = torch.from_numpy(bias) 97 | else: 98 | m.weight.data[:,:,:,:] = torch.from_numpy(weight) 99 | if m.bias is not None: 100 | m.bias.data[:] = torch.from_numpy(bias) 101 | i = i + 1 102 | return 103 | 104 | def parse_flownetsonly(modules, weights, biases, param_prefix=''): 105 | keys = [ 106 | 'conv1', 107 | 'conv2', 108 | 'conv3', 109 | 'conv3_1', 110 | 'conv4', 111 | 'conv4_1', 112 | 'conv5', 113 | 'conv5_1', 114 | 'conv6', 115 | 'conv6_1', 116 | 117 | 'deconv5', 118 | 'deconv4', 119 | 'deconv3', 120 | 'deconv2', 121 | 122 | 'Convolution1', 123 | 'Convolution2', 124 | 'Convolution3', 125 | 'Convolution4', 126 | 'Convolution5', 127 | 128 | 'upsample_flow6to5', 129 | 'upsample_flow5to4', 130 | 'upsample_flow4to3', 131 | 'upsample_flow3to2', 132 | ] 133 | for i, k in enumerate(keys): 134 | if 'upsample' in k: 135 | keys[i] = param_prefix + param_prefix + k 136 | else: 137 | keys[i] = param_prefix + k 138 | i = 0 139 | for m in modules: 140 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 141 | weight = weights[keys[i]].copy() 142 | bias = biases[keys[i]].copy() 143 | if keys[i] == param_prefix+'conv1': 144 | # print ("%s :"%(keys[i]), m.weight.size(), m.bias.size(), tf_w[keys[i]].shape[::-1]) 145 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy()) 146 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy()) 147 | if m.bias is not None: 148 | m.bias.data[:] = torch.from_numpy(bias) 149 | else: 150 | m.weight.data[:,:,:,:] = torch.from_numpy(weight) 151 | if m.bias is not None: 152 | m.bias.data[:] = torch.from_numpy(bias) 153 | i = i + 1 154 | return 155 | 156 | def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'): 157 | keys = [ 158 | 'conv0', 159 | 'conv1', 160 | 'conv1_1', 161 | 'conv2', 162 | 'conv2_1', 163 | 'conv3', 164 | 'conv3_1', 165 | 'conv4', 166 | 'conv4_1', 167 | 'conv5', 168 | 'conv5_1', 169 | 'conv6', 170 | 'conv6_1', 171 | 172 | 'deconv5', 173 | 'deconv4', 174 | 'deconv3', 175 | 'deconv2', 176 | 177 | 'interconv5', 178 | 'interconv4', 179 | 'interconv3', 180 | 'interconv2', 181 | 182 | 'Convolution1', 183 | 'Convolution2', 184 | 'Convolution3', 185 | 'Convolution4', 186 | 'Convolution5', 187 | 188 | 'upsample_flow6to5', 189 | 'upsample_flow5to4', 190 | 'upsample_flow4to3', 191 | 'upsample_flow3to2', 192 | ] 193 | for i, k in enumerate(keys): 194 | keys[i] = param_prefix + k 195 | 196 | i = 0 197 | for m in modules: 198 | if isinstance(m, nn.Conv2d) or isinstance(m, 
nn.ConvTranspose2d): 199 | weight = weights[keys[i]].copy() 200 | bias = biases[keys[i]].copy() 201 | if keys[i] == param_prefix+'conv0': 202 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy()) 203 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy()) 204 | if m.bias is not None: 205 | m.bias.data[:] = torch.from_numpy(bias) 206 | else: 207 | m.weight.data[:,:,:,:] = torch.from_numpy(weight) 208 | if m.bias is not None: 209 | m.bias.data[:] = torch.from_numpy(bias) 210 | i = i + 1 211 | 212 | return 213 | 214 | def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'): 215 | keys = [ 216 | 'conv0', 217 | 'conv1', 218 | 'conv1_1', 219 | 'conv2', 220 | 'conv2_1', 221 | 222 | 'deconv1', 223 | 'deconv0', 224 | 225 | 'interconv1', 226 | 'interconv0', 227 | 228 | '_Convolution5', 229 | '_Convolution6', 230 | '_Convolution7', 231 | 232 | 'upsample_flow2to1', 233 | 'upsample_flow1to0', 234 | ] 235 | for i, k in enumerate(keys): 236 | keys[i] = param_prefix + k 237 | 238 | i = 0 239 | for m in modules: 240 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 241 | weight = weights[keys[i]].copy() 242 | bias = biases[keys[i]].copy() 243 | if keys[i] == param_prefix+'conv0': 244 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy()) 245 | m.weight.data[:,3::,:,:] = torch.from_numpy(weight[:,3:,:,:].copy()) 246 | if m.bias is not None: 247 | m.bias.data[:] = torch.from_numpy(bias) 248 | else: 249 | m.weight.data[:,:,:,:] = torch.from_numpy(weight) 250 | if m.bias is not None: 251 | m.bias.data[:] = torch.from_numpy(bias) 252 | i = i + 1 253 | 254 | return 255 | -------------------------------------------------------------------------------- /flownet2/utils/tools.py: -------------------------------------------------------------------------------- 1 | # freda (todo) : 2 | 3 | import os, time, sys, math 4 | import subprocess, shutil 5 | from os.path import * 6 | import numpy as np 7 | from inspect import isclass 8 | from pytz import timezone 9 | from datetime import datetime 10 | import inspect 11 | import torch 12 | 13 | def datestr(): 14 | pacific = timezone('US/Pacific') 15 | now = datetime.now(pacific) 16 | return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute) 17 | 18 | def module_to_dict(module, exclude=[]): 19 | return dict([(x, getattr(module, x)) for x in dir(module) 20 | if isclass(getattr(module, x)) 21 | and x not in exclude 22 | and getattr(module, x) not in exclude]) 23 | 24 | class TimerBlock: 25 | def __init__(self, title): 26 | print(("{}".format(title))) 27 | 28 | def __enter__(self): 29 | self.start = time.clock() 30 | return self 31 | 32 | def __exit__(self, exc_type, exc_value, traceback): 33 | self.end = time.clock() 34 | self.interval = self.end - self.start 35 | 36 | if exc_type is not None: 37 | self.log("Operation failed\n") 38 | else: 39 | self.log("Operation finished\n") 40 | 41 | 42 | def log(self, string): 43 | duration = time.clock() - self.start 44 | units = 's' 45 | if duration > 60: 46 | duration = duration / 60. 
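            # durations longer than a minute are reported in minutes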
47 | units = 'm' 48 | print((" [{:.3f}{}] {}".format(duration, units, string))) 49 | 50 | def log2file(self, fid, string): 51 | fid = open(fid, 'a') 52 | fid.write("%s\n"%(string)) 53 | fid.close() 54 | 55 | def add_arguments_for_module(parser, module, argument_for_class, default, skip_params=[], parameter_defaults={}): 56 | argument_group = parser.add_argument_group(argument_for_class.capitalize()) 57 | 58 | module_dict = module_to_dict(module) 59 | argument_group.add_argument('--' + argument_for_class, type=str, default=default, choices=list(module_dict.keys())) 60 | 61 | args, unknown_args = parser.parse_known_args() 62 | class_obj = module_dict[vars(args)[argument_for_class]] 63 | 64 | argspec = inspect.getargspec(class_obj.__init__) 65 | 66 | defaults = argspec.defaults[::-1] if argspec.defaults else None 67 | 68 | args = argspec.args[::-1] 69 | for i, arg in enumerate(args): 70 | cmd_arg = '{}_{}'.format(argument_for_class, arg) 71 | if arg not in skip_params + ['self', 'args']: 72 | if arg in list(parameter_defaults.keys()): 73 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(parameter_defaults[arg]), default=parameter_defaults[arg]) 74 | elif (defaults is not None and i < len(defaults)): 75 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(defaults[i]), default=defaults[i]) 76 | else: 77 | print(("[Warning]: non-default argument '{}' detected on class '{}'. This argument cannot be modified via the command line" 78 | .format(arg, module.__class__.__name__))) 79 | # We don't have a good way of dealing with inferring the type of the argument 80 | # TODO: try creating a custom action and using ast's infer type? 81 | # else: 82 | # argument_group.add_argument('--{}'.format(cmd_arg), required=True) 83 | 84 | def kwargs_from_args(args, argument_for_class): 85 | argument_for_class = argument_for_class + '_' 86 | return {key[len(argument_for_class):]: value for key, value in list(vars(args).items()) if argument_for_class in key and key != argument_for_class + 'class'} 87 | 88 | def format_dictionary_of_losses(labels, values): 89 | try: 90 | string = ', '.join([('{}: {:' + ('.3f' if value >= 0.001 else '.1e') +'}').format(name, value) for name, value in zip(labels, values)]) 91 | except (TypeError, ValueError) as e: 92 | print((list(zip(labels, values)))) 93 | string = '[Log Error] ' + str(e) 94 | 95 | return string 96 | 97 | 98 | class IteratorTimer(): 99 | def __init__(self, iterable): 100 | self.iterable = iterable 101 | self.iterator = self.iterable.__iter__() 102 | 103 | def __iter__(self): 104 | return self 105 | 106 | def __len__(self): 107 | return len(self.iterable) 108 | 109 | def __next__(self): 110 | start = time.time() 111 | n = next(self.iterator) 112 | self.last_duration = (time.time() - start) 113 | return n 114 | 115 | next = __next__ 116 | 117 | def gpumemusage(): 118 | gpu_mem = subprocess.check_output("nvidia-smi | grep MiB | cut -f 3 -d '|'", shell=True).replace(' ', '').replace('\n', '').replace('i', '') 119 | all_stat = [float(a) for a in gpu_mem.replace('/','').split('MB')[:-1]] 120 | 121 | gpu_mem = '' 122 | for i in range(len(all_stat)/2): 123 | curr, tot = all_stat[2*i], all_stat[2*i+1] 124 | util = "%1.2f"%(100*curr/tot)+'%' 125 | cmem = str(int(math.ceil(curr/1024.)))+'GB' 126 | gmem = str(int(math.ceil(tot/1024.)))+'GB' 127 | gpu_mem += util + '--' + join(cmem, gmem) + ' ' 128 | return gpu_mem 129 | 130 | 131 | def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer): 132 | if args.schedule_lr_frequency > 0: 133 | 
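        # every schedule_lr_frequency iterations, divide each group's lr by schedule_lr_fraction (floored at 1e-6)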
for param_group in optimizer.param_groups: 134 | if (global_iteration + 1) % args.schedule_lr_frequency == 0: 135 | param_group['lr'] /= float(args.schedule_lr_fraction) 136 | param_group['lr'] = float(np.maximum(param_group['lr'], 0.000001)) 137 | 138 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'): 139 | prefix_save = os.path.join(path, prefix) 140 | name = prefix_save + '_' + filename 141 | torch.save(state, name) 142 | if is_best: 143 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar') 144 | 145 | -------------------------------------------------------------------------------- /generate_pseudo_labels.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import sys 10 | sys.path.append('flownet2') 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils import data 15 | from torchvision.transforms import functional as TF 16 | 17 | import argparse 18 | from tqdm import tqdm 19 | 20 | from libs.datasets import get_transforms, get_datasets 21 | from libs.networks.pseudo_label_generator import FGPLG 22 | from libs.utils.pyt_utils import load_model 23 | 24 | parser = argparse.ArgumentParser() 25 | 26 | # Dataloading-related settings 27 | parser.add_argument('--data', type=str, default='data/datasets/', 28 | help='path to datasets folder') 29 | parser.add_argument('--checkpoint', default='models/pseudo_label_generator_5.pth', 30 | help='path to the pretrained checkpoint') 31 | parser.add_argument('--dataset-config', default='config/datasets.yaml', 32 | help='dataset config file') 33 | parser.add_argument('--pseudo-label-folder', default='data/pseudo-labels', 34 | help='location to save generated pseudo-labels') 35 | parser.add_argument("--label_interval", default=5, type=int, 36 | help="the interval of ground truth labels") 37 | parser.add_argument("--frame_between_label_num", default=1, type=int, 38 | help="the number of generated pseudo-labels in each interval") 39 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N', 40 | help='number of data loading workers.') 41 | 42 | # Model settings 43 | parser.add_argument('--size', default=448, type=int, 44 | help='image size') 45 | parser.add_argument('--os', default=16, type=int, 46 | help='output stride.') 47 | 48 | # FlowNet setting 49 | parser.add_argument("--fp16", action="store_true", 50 | help="Run model in pseudo-fp16 mode (fp16 storage fp32 math).") 51 | parser.add_argument("--rgb_max", type=float, default=1.) 
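# --fp16 and --rgb_max are passed through to the FlowNet2 sub-network built inside FGPLG;
# frames are un-normalized back to the [0, 1] range before optical flow is computed, hence rgb_max defaults to 1.
# Illustrative invocation (the flags shown are the defaults):
#   python generate_pseudo_labels.py --label_interval 5 --frame_between_label_num 1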
52 | 53 | args = parser.parse_args() 54 | 55 | cuda = torch.cuda.is_available() 56 | device = torch.device("cuda" if cuda else "cpu") 57 | 58 | if cuda: 59 | torch.backends.cudnn.benchmark = True 60 | current_device = torch.cuda.current_device() 61 | print("Running on", torch.cuda.get_device_name(current_device)) 62 | else: 63 | print("Running on CPU") 64 | 65 | data_transforms = get_transforms( 66 | input_size=(args.size, args.size), 67 | image_mode=False 68 | ) 69 | dataset = get_datasets( 70 | name_list=["DAVIS2016", "FBMS", "VOS"], 71 | split_list=["train", "train", "train"], 72 | config_path=args.dataset_config, 73 | root=args.data, 74 | training=True, # provide labels 75 | transforms=data_transforms['test'], 76 | read_clip=True, 77 | random_reverse_clip=False, 78 | label_interval=args.label_interval, 79 | frame_between_label_num=args.frame_between_label_num, 80 | clip_len=args.frame_between_label_num+2 81 | ) 82 | 83 | dataloader = data.DataLoader( 84 | dataset=dataset, 85 | batch_size=1, # only support 1 video clip 86 | num_workers=args.num_workers, 87 | shuffle=False, 88 | drop_last=True 89 | ) 90 | 91 | pseudo_label_generator = FGPLG(args=args, output_stride=args.os) 92 | 93 | # load pretrained models 94 | if os.path.exists(args.checkpoint): 95 | print('Loading state dict from: {0}'.format(args.checkpoint)) 96 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=True) 97 | else: 98 | raise ValueError("Cannot find model file at {}".format(args.checkpoint)) 99 | 100 | pseudo_label_generator.to(device) 101 | 102 | pseudo_label_folder = os.path.join(args.pseudo_label_folder, "{}_{}".format(args.frame_between_label_num, args.label_interval)) 103 | if not os.path.exists(pseudo_label_folder): 104 | os.makedirs(pseudo_label_folder) 105 | 106 | def generate_pseudo_label(): 107 | pseudo_label_generator.eval() 108 | 109 | for data in tqdm(dataloader): 110 | images = [] 111 | labels = [] 112 | for frame in data: 113 | images.append(frame['image'].to(device)) 114 | labels.append(frame['label'].to(device) if 'label' in frame else None) 115 | with torch.no_grad(): 116 | for i in range(1, args.frame_between_label_num+1): 117 | pseudo_label = pseudo_label_generator.generate_pseudo_label(images[i], images[0], images[-1], labels[0], labels[-1]) 118 | labels[i] = torch.sigmoid(pseudo_label).detach() 119 | # save pseudo-labels 120 | for i, label_ in enumerate(labels): 121 | for j, label in enumerate(label_.detach().cpu()): 122 | dataset = data[i]['dataset'][j] 123 | image_id = data[i]['image_id'][j] 124 | pseudo_label_path = os.path.join(pseudo_label_folder, "{}/{}.png".format(dataset, image_id)) 125 | 126 | height = data[i]['height'].item() 127 | width = data[i]['width'].item() 128 | result = TF.to_pil_image(label) 129 | result = result.resize((height, width)) 130 | dirname = os.path.dirname(pseudo_label_path) 131 | if not os.path.exists(dirname): 132 | os.makedirs(dirname) 133 | result.save(pseudo_label_path) 134 | 135 | if __name__ == "__main__": 136 | print("Generating pseudo-labels at {}".format(args.pseudo_label_folder)) 137 | print("label interval: {}".format(args.label_interval)) 138 | print("frame between label num: {}".format(args.frame_between_label_num)) 139 | generate_pseudo_label() 140 | -------------------------------------------------------------------------------- /inference.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: 
Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils import data 13 | from torchvision.transforms import functional as TF 14 | 15 | import argparse 16 | from tqdm import tqdm 17 | 18 | from libs.datasets import get_transforms, get_datasets 19 | from libs.networks import VideoModel 20 | from libs.utils.pyt_utils import load_model 21 | 22 | parser = argparse.ArgumentParser() 23 | 24 | # Dataloading-related settings 25 | parser.add_argument('--data', type=str, default='data/datasets/', 26 | help='path to datasets folder') 27 | parser.add_argument('--dataset', default='VOS', type=str, 28 | help='dataset name for inference') 29 | parser.add_argument('--split', default='test', type=str, 30 | help='dataset split for inference') 31 | parser.add_argument('--checkpoint', default='models/video_best_model.pth', 32 | help='path to the pretrained checkpoint') 33 | parser.add_argument('--dataset-config', default='config/datasets.yaml', 34 | help='dataset config file') 35 | parser.add_argument('--results-folder', default='data/results/', 36 | help='location to save predicted saliency maps') 37 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N', 38 | help='number of data loading workers.') 39 | 40 | # Model settings 41 | parser.add_argument('--size', default=448, type=int, 42 | help='image size') 43 | parser.add_argument('--os', default=16, type=int, 44 | help='output stride.') 45 | parser.add_argument("--clip_len", type=int, default=4, 46 | help="the number of frames in a video clip.") 47 | 48 | args = parser.parse_args() 49 | 50 | cuda = torch.cuda.is_available() 51 | device = torch.device("cuda" if cuda else "cpu") 52 | 53 | if cuda: 54 | torch.backends.cudnn.benchmark = True 55 | current_device = torch.cuda.current_device() 56 | print("Running on", torch.cuda.get_device_name(current_device)) 57 | else: 58 | print("Running on CPU") 59 | 60 | data_transforms = get_transforms( 61 | input_size=(args.size, args.size), 62 | image_mode=False 63 | ) 64 | dataset = get_datasets( 65 | name_list=args.dataset, 66 | split_list=args.split, 67 | config_path=args.dataset_config, 68 | root=args.data, 69 | training=False, 70 | transforms=data_transforms['test'], 71 | read_clip=True, 72 | random_reverse_clip=False, 73 | label_interval=1, 74 | frame_between_label_num=0, 75 | clip_len=args.clip_len 76 | ) 77 | 78 | dataloader = data.DataLoader( 79 | dataset=dataset, 80 | batch_size=1, # only support 1 video clip 81 | num_workers=args.num_workers, 82 | shuffle=False 83 | ) 84 | 85 | model = VideoModel(output_stride=args.os) 86 | 87 | # load pretrained models 88 | if os.path.exists(args.checkpoint): 89 | print('Loading state dict from: {0}'.format(args.checkpoint)) 90 | model = load_model(model=model, model_file=args.checkpoint, is_restore=True) 91 | else: 92 | raise ValueError("Cannot find model file at {}".format(args.checkpoint)) 93 | 94 | model.to(device) 95 | 96 | def inference(): 97 | model.eval() 98 | print("Begin inference on {} {}.".format(args.dataset, args.split)) 99 | for data in tqdm(dataloader): 100 | images = [frame['image'].to(device) for frame in data] 101 | with torch.no_grad(): 102 | preds = model(images) 103 | preds = [torch.sigmoid(pred) for pred in preds] 104 | # save predicted saliency maps 105 | for i, pred_ in enumerate(preds): 106 | for j, pred in enumerate(pred_.detach().cpu()): 107 | dataset = 
data[i]['dataset'][j] 108 | image_id = data[i]['image_id'][j] 109 | height = data[i]['height'].item() 110 | width = data[i]['width'].item() 111 | result_path = os.path.join(args.results_folder, "{}/{}.png".format(dataset, image_id)) 112 | 113 | result = TF.to_pil_image(pred) 114 | result = result.resize((height, width)) 115 | dirname = os.path.dirname(result_path) 116 | if not os.path.exists(dirname): 117 | os.makedirs(dirname) 118 | result.save(result_path) 119 | 120 | if __name__ == "__main__": 121 | inference() 122 | -------------------------------------------------------------------------------- /libs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/libs/__init__.py -------------------------------------------------------------------------------- /libs/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .transforms import get_transforms 2 | from .video_datasets import get_datasets -------------------------------------------------------------------------------- /libs/datasets/transforms.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | import random 8 | 9 | from PIL import Image 10 | import torch 11 | from torchvision import transforms 12 | from torchvision.transforms import functional as F 13 | from torch.utils import data 14 | import numpy as np 15 | 16 | def get_transforms(image_mode, input_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]): 17 | data_transforms = { 18 | 'train': transforms.Compose([ 19 | ColorJitter(brightness=.3, contrast=.3, saturation=.3, hue=.3, image_mode=image_mode), 20 | RandomResizedCrop(input_size, image_mode), 21 | RandomFlip(image_mode), 22 | ToTensor(), 23 | Normalize(mean=mean, 24 | std=std) 25 | ]) if image_mode else transforms.Compose([ 26 | Resize(input_size), 27 | ToTensor(), 28 | Normalize(mean=mean, 29 | std=std) 30 | ]), 31 | 'val': transforms.Compose([ 32 | Resize(input_size), 33 | ToTensor(), 34 | Normalize(mean=mean, 35 | std=std) 36 | ]), 37 | 'test': transforms.Compose([ 38 | Resize(input_size), 39 | ToTensor(), 40 | Normalize(mean=mean, 41 | std=std) 42 | ]), 43 | } 44 | return data_transforms 45 | 46 | class ColorJitter(transforms.ColorJitter): 47 | def __init__(self, image_mode, **kwargs): 48 | super(ColorJitter, self).__init__(**kwargs) 49 | self.transform = None 50 | self.image_mode = image_mode 51 | def __call__(self, sample): 52 | if self.transform is None or self.image_mode: 53 | self.transform = self.get_params(self.brightness, self.contrast, 54 | self.saturation, self.hue) 55 | sample['image'] = self.transform(sample['image']) 56 | return sample 57 | 58 | class RandomResizedCrop(object): 59 | """ 60 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random 61 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop 62 | is finally resized to given size. 63 | This is popularly used to train the Inception networks. 64 | Args: 65 | size: expected output size of each edge 66 | scale: range of size of the origin size cropped 67 | ratio: range of aspect ratio of the origin aspect ratio cropped 68 | """ 69 | 70 | def __init__(self, size, image_mode, scale=(0.7, 1.0), ratio=(3. / 4., 4. 
/ 3.)): 71 | self.size = size 72 | self.scale = scale 73 | self.ratio = ratio 74 | self.i, self.j, self.h, self.w = None, None, None, None 75 | self.image_mode = image_mode 76 | def __call__(self, sample): 77 | image, label = sample['image'], sample['label'] 78 | if self.i is None or self.image_mode: 79 | self.i, self.j, self.h, self.w = transforms.RandomResizedCrop.get_params(image, self.scale, self.ratio) 80 | image = F.resized_crop(image, self.i, self.j, self.h, self.w, self.size, Image.BILINEAR) 81 | label = F.resized_crop(label, self.i, self.j, self.h, self.w, self.size, Image.BILINEAR) 82 | sample['image'], sample['label'] = image, label 83 | return sample 84 | 85 | class RandomFlip(object): 86 | """Horizontally flip the given PIL Image randomly with a given probability. 87 | """ 88 | def __init__(self, image_mode): 89 | self.rand_flip_index = None 90 | self.image_mode = image_mode 91 | def __call__(self, sample): 92 | image, label = sample['image'], sample['label'] 93 | if self.rand_flip_index is None or self.image_mode: 94 | self.rand_flip_index = random.randint(-1,2) 95 | # 0: horizontal flip, 1: vertical flip, -1: horizontal and vertical flip 96 | if self.rand_flip_index == 0: 97 | image = F.hflip(image) 98 | label = F.hflip(label) 99 | elif self.rand_flip_index == 1: 100 | image = F.vflip(image) 101 | label = F.vflip(label) 102 | elif self.rand_flip_index == 2: 103 | image = F.vflip(F.hflip(image)) 104 | label = F.vflip(F.hflip(label)) 105 | sample['image'], sample['label'] = image, label 106 | return sample 107 | 108 | class Resize(object): 109 | """ Resize PIL image use both for training and inference""" 110 | def __init__(self, size): 111 | self.size = size 112 | 113 | def __call__(self, sample): 114 | image, label = sample['image'], sample['label'] 115 | image = F.resize(image, self.size, Image.BILINEAR) 116 | if label is not None: 117 | label = F.resize(label, self.size, Image.BILINEAR) 118 | sample['image'], sample['label'] = image, label 119 | return sample 120 | 121 | class ToTensor(object): 122 | """Convert ndarrays in sample to Tensors.""" 123 | 124 | def __call__(self, sample): 125 | image, label = sample['image'], sample['label'] 126 | 127 | # swap color axis because 128 | # numpy image: H x W x C 129 | # torch image: C X H X W 130 | # Image range from [0~255] to [0.0 ~ 1.0] 131 | image = F.to_tensor(image) 132 | if label is not None: 133 | label = torch.from_numpy(np.array(label)).unsqueeze(0).float() 134 | return {'image': image, 'label': label} 135 | 136 | class Normalize(object): 137 | """ Normalize a tensor image with mean and standard deviation. 138 | args: tensor (Tensor) – Tensor image of size (C, H, W) to be normalized. 139 | Returns: Normalized Tensor image. 
140 | """ 141 | # default caffe mode 142 | def __init__(self, mean, std): 143 | self.mean = mean 144 | self.std = std 145 | 146 | def __call__(self, sample): 147 | image, label = sample['image'], sample['label'] 148 | image = F.normalize(image, self.mean, self.std) 149 | return {'image': image, 'label': label} 150 | 151 | -------------------------------------------------------------------------------- /libs/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .convgru import ConvGRUCell 2 | from .non_local_dot_product import NONLocalBlock3D -------------------------------------------------------------------------------- /libs/modules/convgru.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | import torch 8 | from torch import nn 9 | 10 | class ConvGRUCell(nn.Module): 11 | """ 12 | ICLR2016: Delving Deeper into Convolutional Networks for Learning Video Representations 13 | url: https://arxiv.org/abs/1511.06432 14 | """ 15 | def __init__(self, input_channels, hidden_channels, kernel_size, cuda_flag=True): 16 | super(ConvGRUCell, self).__init__() 17 | self.input_channels = input_channels 18 | self.cuda_flag = cuda_flag 19 | self.hidden_channels = hidden_channels 20 | self.kernel_size = kernel_size 21 | 22 | padding = self.kernel_size // 2 23 | self.reset_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 24 | self.update_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 25 | self.output_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, 3, padding=padding) 26 | # init 27 | for m in self.state_dict(): 28 | if isinstance(m, nn.Conv2d): 29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 30 | nn.init.constant_(m.bias, 0) 31 | 32 | def forward(self, x, hidden): 33 | if hidden is None: 34 | size_h = [x.data.size()[0], self.hidden_channels] + list(x.data.size()[2:]) 35 | if self.cuda_flag: 36 | hidden = torch.zeros(size_h).cuda() 37 | else: 38 | hidden = torch.zeros(size_h) 39 | 40 | inputs = torch.cat((x, hidden), dim=1) 41 | reset_gate = torch.sigmoid(self.reset_gate(inputs)) 42 | update_gate = torch.sigmoid(self.update_gate(inputs)) 43 | 44 | reset_hidden = reset_gate * hidden 45 | reset_inputs = torch.tanh(self.output_gate(torch.cat((x, reset_hidden), dim=1))) 46 | new_hidden = (1 - update_gate)*reset_inputs + update_gate*hidden 47 | 48 | return new_hidden -------------------------------------------------------------------------------- /libs/modules/non_local_dot_product.py: -------------------------------------------------------------------------------- 1 | 2 | #!/usr/bin/env python 3 | # coding: utf-8 4 | # 5 | # Author: AlexHex7 6 | # URL: https://github.com/AlexHex7/Non-local_pytorch 7 | 8 | import torch 9 | from torch import nn 10 | from torch.nn import functional as F 11 | 12 | 13 | class _NonLocalBlockND(nn.Module): 14 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True): 15 | super(_NonLocalBlockND, self).__init__() 16 | 17 | assert dimension in [1, 2, 3] 18 | 19 | self.dimension = dimension 20 | self.sub_sample = sub_sample 21 | 22 | self.in_channels = in_channels 23 | self.inter_channels = inter_channels 24 | 25 | if self.inter_channels is None: 26 | self.inter_channels = in_channels // 2 27 | if 
self.inter_channels == 0: 28 | self.inter_channels = 1 29 | 30 | if dimension == 3: 31 | conv_nd = nn.Conv3d 32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2)) 33 | bn = nn.BatchNorm3d 34 | elif dimension == 2: 35 | conv_nd = nn.Conv2d 36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2)) 37 | bn = nn.BatchNorm2d 38 | else: 39 | conv_nd = nn.Conv1d 40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2)) 41 | bn = nn.BatchNorm1d 42 | 43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 44 | kernel_size=1, stride=1, padding=0) 45 | 46 | if bn_layer: 47 | self.W = nn.Sequential( 48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 49 | kernel_size=1, stride=1, padding=0), 50 | bn(self.in_channels) 51 | ) 52 | nn.init.constant_(self.W[1].weight, 0) 53 | nn.init.constant_(self.W[1].bias, 0) 54 | else: 55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels, 56 | kernel_size=1, stride=1, padding=0) 57 | nn.init.constant_(self.W.weight, 0) 58 | nn.init.constant_(self.W.bias, 0) 59 | 60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 61 | kernel_size=1, stride=1, padding=0) 62 | 63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels, 64 | kernel_size=1, stride=1, padding=0) 65 | 66 | if sub_sample: 67 | self.g = nn.Sequential(self.g, max_pool_layer) 68 | self.phi = nn.Sequential(self.phi, max_pool_layer) 69 | 70 | def forward(self, x): 71 | ''' 72 | :param x: (b, c, t, h, w) 73 | :return: 74 | ''' 75 | 76 | batch_size = x.size(0) 77 | 78 | g_x = self.g(x).view(batch_size, self.inter_channels, -1) 79 | g_x = g_x.permute(0, 2, 1) 80 | 81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1) 82 | theta_x = theta_x.permute(0, 2, 1) 83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1) 84 | f = torch.matmul(theta_x, phi_x) 85 | N = f.size(-1) 86 | f_div_C = f / N 87 | 88 | y = torch.matmul(f_div_C, g_x) 89 | y = y.permute(0, 2, 1).contiguous() 90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:]) 91 | W_y = self.W(y) 92 | z = W_y + x 93 | 94 | return z 95 | 96 | 97 | class NONLocalBlock1D(_NonLocalBlockND): 98 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 99 | super(NONLocalBlock1D, self).__init__(in_channels, 100 | inter_channels=inter_channels, 101 | dimension=1, sub_sample=sub_sample, 102 | bn_layer=bn_layer) 103 | 104 | 105 | class NONLocalBlock2D(_NonLocalBlockND): 106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 107 | super(NONLocalBlock2D, self).__init__(in_channels, 108 | inter_channels=inter_channels, 109 | dimension=2, sub_sample=sub_sample, 110 | bn_layer=bn_layer) 111 | 112 | 113 | class NONLocalBlock3D(_NonLocalBlockND): 114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True): 115 | super(NONLocalBlock3D, self).__init__(in_channels, 116 | inter_channels=inter_channels, 117 | dimension=3, sub_sample=sub_sample, 118 | bn_layer=bn_layer) 119 | 120 | 121 | if __name__ == '__main__': 122 | import torch 123 | 124 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]: 125 | img = torch.zeros(2, 3, 20) 126 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer) 127 | out = net(img) 128 | print(out.size()) 129 | 130 | img = torch.zeros(2, 3, 20, 20) 131 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer) 132 | out = 
net(img) 133 | print(out.size()) 134 | 135 | img = torch.randn(2, 3, 8, 20, 20) 136 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer) 137 | out = net(img) 138 | print(out.size()) -------------------------------------------------------------------------------- /libs/networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .models import ImageModel, VideoModel -------------------------------------------------------------------------------- /libs/networks/models.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from libs.networks.rcrnet import RCRNet 8 | from libs.modules.convgru import ConvGRUCell 9 | from libs.modules.non_local_dot_product import NONLocalBlock3D 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | class ImageModel(nn.Module): 16 | ''' 17 | RCRNet 18 | ''' 19 | def __init__(self, pretrained=False): 20 | super(ImageModel, self).__init__() 21 | self.backbone = RCRNet( 22 | n_classes=1, 23 | output_stride=16, 24 | pretrained=pretrained 25 | ) 26 | def forward(self, frame): 27 | seg = self.backbone(frame) 28 | return seg 29 | 30 | class VideoModel(nn.Module): 31 | ''' 32 | RCRNet+NER 33 | ''' 34 | def __init__(self, output_stride=16): 35 | super(VideoModel, self).__init__() 36 | # video mode + video dataset 37 | self.backbone = RCRNet( 38 | n_classes=1, 39 | output_stride=output_stride, 40 | pretrained=False, 41 | input_channels=3 42 | ) 43 | self.convgru_forward = ConvGRUCell(256, 256, 3) 44 | self.convgru_backward = ConvGRUCell(256, 256, 3) 45 | self.bidirection_conv = nn.Conv2d(512, 256, 3, 1, 1) 46 | 47 | self.non_local_block = NONLocalBlock3D(256, sub_sample=False, bn_layer=False) 48 | self.non_local_block2 = NONLocalBlock3D(256, sub_sample=False, bn_layer=False) 49 | 50 | self.freeze_bn() 51 | 52 | def freeze_bn(self): 53 | for m in self.backbone.named_modules(): 54 | if isinstance(m[1], nn.BatchNorm2d): 55 | m[1].eval() 56 | 57 | def forward(self, clip): 58 | clip_feats = [self.backbone.feat_conv(frame) for frame in clip] 59 | feats_time = [feats[-1] for feats in clip_feats] 60 | feats_time = torch.stack(feats_time, dim=2) 61 | feats_time = self.non_local_block(feats_time) 62 | 63 | # Deep Bidirectional ConvGRU 64 | frame = clip[0] 65 | feat = feats_time[:,:,0,:,:] 66 | feats_forward = [] 67 | # forward 68 | for i in range(len(clip)): 69 | feat = self.convgru_forward(feats_time[:,:,i,:,:], feat) 70 | feats_forward.append(feat) 71 | # backward 72 | feat = feats_forward[-1] 73 | feats_backward = [] 74 | for i in range(len(clip)): 75 | feat = self.convgru_backward(feats_forward[len(clip)-1-i], feat) 76 | feats_backward.append(feat) 77 | 78 | feats_backward = feats_backward[::-1] 79 | feats = [] 80 | for i in range(len(clip)): 81 | feat = torch.tanh(self.bidirection_conv(torch.cat((feats_forward[i], feats_backward[i]), dim=1))) 82 | feats.append(feat) 83 | feats = torch.stack(feats, dim=2) 84 | 85 | feats = self.non_local_block2(feats) 86 | preds = [] 87 | for i, frame in enumerate(clip): 88 | seg = self.backbone.seg_conv(clip_feats[i][0], clip_feats[i][1], clip_feats[i][2], feats[:,:,i,:,:], frame.shape[2:]) 89 | preds.append(seg) 90 | return preds -------------------------------------------------------------------------------- /libs/networks/pseudo_label_generator.py: 
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from libs.networks.rcrnet import RCRNet 8 | 9 | from flownet2.models import FlowNet2 10 | from flownet2.networks.resample2d_package.resample2d import Resample2d 11 | 12 | import torch 13 | import torch.nn as nn 14 | 15 | _mean = [0.485, 0.456, 0.406] 16 | _std = [0.229, 0.224, 0.225] 17 | 18 | def normalize_flow(flow): 19 | origin_size = flow.shape[2:] 20 | flow[:, 0, :, :] /= origin_size[1] # dx 21 | flow[:, 1, :, :] /= origin_size[0] # dy 22 | norm_flow = (flow[:, 0, :, :] ** 2 + flow[:, 1, :, :] ** 2) ** 0.5 23 | return norm_flow.unsqueeze(1) 24 | 25 | def compute_flow(flownet, data, data_ref): 26 | # flow from data_ref to data 27 | images = [data[0].clone(), data_ref[0].clone()] 28 | for image in images: 29 | for i, (mean, std) in enumerate(zip(_mean, _std)): 30 | image[i].mul_(std).add_(mean) 31 | images = torch.stack(images) 32 | images = images.permute((1, 0, 2, 3)) # to channel, 2, h, w 33 | im = images.unsqueeze(0).float() # add batch_size = 1 [batch_size, channel, 2, h, w] 34 | return flownet(im) 35 | 36 | def resize_flow(flow, size): 37 | origin_size = flow.shape[2:] 38 | flow = F.interpolate(flow.clone(), size=size, mode="near") 39 | flow[:, 0, :, :] /= origin_size[1] / size[1] # dx 40 | flow[:, 1, :, :] /= origin_size[0] / size[0] # dy 41 | return flow 42 | 43 | class FGPLG(nn.Module): 44 | def __init__(self, args, output_stride=16): 45 | super(FGPLG, self).__init__() 46 | self.flownet = FlowNet2(args) 47 | self.warp = Resample2d() 48 | channels = 7 49 | 50 | self.backbone = RCRNet( 51 | n_classes=1, 52 | output_stride=output_stride, 53 | pretrained=False, 54 | input_channels=channels 55 | ) 56 | 57 | self.freeze_bn() 58 | self.freeze_layer() 59 | 60 | def freeze_bn(self): 61 | for m in self.backbone.named_modules(): 62 | if isinstance(m[1], nn.BatchNorm2d): 63 | m[1].eval() 64 | def freeze_layer(self): 65 | if hasattr(self, 'flownet'): 66 | for p in self.flownet.parameters(): 67 | p.requires_grad = False 68 | 69 | def generate_pseudo_label(self, frame, frame_l, frame_r, label_l, label_r): 70 | flow_forward = compute_flow(self.flownet, frame, frame_l) 71 | flow_backward = compute_flow(self.flownet, frame, frame_r) 72 | warp_label_l = self.warp(label_l, flow_forward) 73 | warp_label_r = self.warp(label_r, flow_backward) 74 | inputs = torch.cat(( 75 | frame, 76 | warp_label_l, 77 | warp_label_r, 78 | normalize_flow(flow_forward), 79 | normalize_flow(flow_backward) 80 | ), 1) 81 | pseudo_label = self.backbone(inputs) 82 | return pseudo_label 83 | 84 | def forward(self, clip, clip_label): 85 | pseudo_label = self.generate_pseudo_label(clip[1], clip[0], clip[2], clip_label[0], clip_label[2]) 86 | return pseudo_label, clip_label[1] 87 | -------------------------------------------------------------------------------- /libs/networks/rcrnet.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from collections import OrderedDict 8 | 9 | import torch 10 | import torch.nn as nn 11 | import torch.nn.functional as F 12 | 13 | from .resnet_dilation import resnet50, Bottleneck, conv1x1 14 | 15 | class _ConvBatchNormReLU(nn.Sequential): 16 | def __init__(self, 17 | in_channels, 18 | out_channels, 19 | kernel_size, 20 | stride, 
21 | padding, 22 | dilation, 23 | relu=True, 24 | ): 25 | super(_ConvBatchNormReLU, self).__init__() 26 | self.add_module( 27 | "conv", 28 | nn.Conv2d( 29 | in_channels=in_channels, 30 | out_channels=out_channels, 31 | kernel_size=kernel_size, 32 | stride=stride, 33 | padding=padding, 34 | dilation=dilation, 35 | bias=False, 36 | ), 37 | ) 38 | self.add_module( 39 | "bn", 40 | nn.BatchNorm2d(out_channels), 41 | ) 42 | 43 | if relu: 44 | self.add_module("relu", nn.ReLU()) 45 | 46 | def forward(self, x): 47 | return super(_ConvBatchNormReLU, self).forward(x) 48 | 49 | class _ASPPModule(nn.Module): 50 | """Atrous Spatial Pyramid Pooling with image pool""" 51 | 52 | def __init__(self, in_channels, out_channels, output_stride): 53 | super(_ASPPModule, self).__init__() 54 | if output_stride == 8: 55 | pyramids = [12, 24, 36] 56 | elif output_stride == 16: 57 | pyramids = [6, 12, 18] 58 | self.stages = nn.Module() 59 | self.stages.add_module( 60 | "c0", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1) 61 | ) 62 | for i, (dilation, padding) in enumerate(zip(pyramids, pyramids)): 63 | self.stages.add_module( 64 | "c{}".format(i + 1), 65 | _ConvBatchNormReLU(in_channels, out_channels, 3, 1, padding, dilation), 66 | ) 67 | self.imagepool = nn.Sequential( 68 | OrderedDict( 69 | [ 70 | ("pool", nn.AdaptiveAvgPool2d((1,1))), 71 | ("conv", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)), 72 | ] 73 | ) 74 | ) 75 | self.fire = nn.Sequential( 76 | OrderedDict( 77 | [ 78 | ("conv", _ConvBatchNormReLU(out_channels * 5, out_channels, 3, 1, 1, 1)), 79 | ("dropout", nn.Dropout2d(0.1)) 80 | ] 81 | ) 82 | ) 83 | 84 | def forward(self, x): 85 | h = self.imagepool(x) 86 | h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)] 87 | for stage in self.stages.children(): 88 | h += [stage(x)] 89 | h = torch.cat(h, dim=1) 90 | h = self.fire(h) 91 | return h 92 | 93 | class _RefinementModule(nn.Module): 94 | """ Reduce channels and refinment module""" 95 | 96 | def __init__(self, 97 | bottom_up_channels, 98 | reduce_channels, 99 | top_down_channels, 100 | refinement_channels, 101 | expansion=2 102 | ): 103 | super(_RefinementModule, self).__init__() 104 | downsample = None 105 | if bottom_up_channels != reduce_channels: 106 | downsample = nn.Sequential( 107 | conv1x1(bottom_up_channels, reduce_channels), 108 | nn.BatchNorm2d(reduce_channels), 109 | ) 110 | self.skip = Bottleneck(bottom_up_channels, reduce_channels // expansion, 1, 1, downsample, expansion) 111 | self.refine = _ConvBatchNormReLU(reduce_channels + top_down_channels, refinement_channels, 3, 1, 1, 1) 112 | def forward(self, td, bu): 113 | td = self.skip(td) 114 | x = torch.cat((bu, td), dim=1) 115 | x = self.refine(x) 116 | return x 117 | 118 | class RCRNet(nn.Module): 119 | 120 | def __init__(self, n_classes, output_stride, input_channels=3, pretrained=False): 121 | super(RCRNet, self).__init__() 122 | self.resnet = resnet50(pretrained=pretrained, output_stride=output_stride, input_channels=input_channels) 123 | self.aspp = _ASPPModule(2048, 256, output_stride) 124 | # Decoder 125 | self.decoder = nn.Sequential( 126 | OrderedDict( 127 | [ 128 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)), 129 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)), 130 | ] 131 | ) 132 | ) 133 | self.add_module("refinement1", _RefinementModule(1024, 96, 256, 128, 2)) 134 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2)) 135 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2)) 136 | 
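        # with an ImageNet-pretrained ResNet backbone, only the non-ResNet layers added above are initialized below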
137 | if pretrained: 138 | for key in self.state_dict(): 139 | if 'resnet' not in key: 140 | self.init_layer(key) 141 | 142 | def init_layer(self, key): 143 | if key.split('.')[-1] == 'weight': 144 | if 'conv' in key: 145 | if self.state_dict()[key].ndimension() >= 2: 146 | nn.init.kaiming_normal_(self.state_dict()[key], mode='fan_out', nonlinearity='relu') 147 | elif 'bn' in key: 148 | self.state_dict()[key][...] = 1 149 | elif key.split('.')[-1] == 'bias': 150 | self.state_dict()[key][...] = 0.001 151 | 152 | def feat_conv(self, x): 153 | ''' 154 | Spatial feature extractor 155 | ''' 156 | block0 = self.resnet.conv1(x) 157 | block0 = self.resnet.bn1(block0) 158 | block0 = self.resnet.relu(block0) 159 | block0 = self.resnet.maxpool(block0) 160 | 161 | block1 = self.resnet.layer1(block0) 162 | block2 = self.resnet.layer2(block1) 163 | block3 = self.resnet.layer3(block2) 164 | block4 = self.resnet.layer4(block3) 165 | block4 = self.aspp(block4) 166 | return block1, block2, block3, block4 167 | 168 | def seg_conv(self, block1, block2, block3, block4, shape): 169 | ''' 170 | Pixel-wise classifer 171 | ''' 172 | bu1 = self.refinement1(block3, block4) 173 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False) 174 | bu2 = self.refinement2(block2, bu1) 175 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False) 176 | bu3 = self.refinement3(block1, bu2) 177 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False) 178 | seg = self.decoder(bu3) 179 | return seg 180 | 181 | def forward(self, x): 182 | block1, block2, block3, block4 = self.feat_conv(x) 183 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:]) 184 | return seg -------------------------------------------------------------------------------- /libs/networks/resnet_dilation.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # This code is based on torchvison resnet 5 | # URL: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 6 | 7 | import torch.nn as nn 8 | import torch.utils.model_zoo as model_zoo 9 | 10 | 11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 12 | 'resnet152'] 13 | 14 | 15 | model_urls = { 16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 21 | } 22 | 23 | 24 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1): 25 | """3x3 convolution with padding""" 26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 27 | padding=padding, dilation=dilation, bias=False) 28 | 29 | 30 | def conv1x1(in_planes, out_planes, stride=1): 31 | """1x1 convolution""" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 33 | 34 | 35 | class BasicBlock(nn.Module): 36 | expansion = 1 37 | 38 | def __init__(self, inplanes, planes, stride, dilation, downsample=None): 39 | super(BasicBlock, self).__init__() 40 | self.conv1 = conv3x3(inplanes, planes, stride) 41 | self.bn1 = nn.BatchNorm2d(planes) 42 | self.relu = nn.ReLU(inplace=True) 43 | self.conv2 = conv3x3(planes, planes, 1, dilation, dilation) 44 | self.bn2 = 
nn.BatchNorm2d(planes) 45 | self.downsample = downsample 46 | self.stride = stride 47 | 48 | def forward(self, x): 49 | residual = x 50 | 51 | out = self.conv1(x) 52 | out = self.bn1(out) 53 | out = self.relu(out) 54 | 55 | out = self.conv2(out) 56 | out = self.bn2(out) 57 | 58 | if self.downsample is not None: 59 | residual = self.downsample(x) 60 | 61 | out += residual 62 | out = self.relu(out) 63 | 64 | return out 65 | 66 | 67 | class Bottleneck(nn.Module): 68 | expansion = 4 69 | 70 | def __init__(self, inplanes, planes, stride, dilation, downsample=None, expansion=4): 71 | super(Bottleneck, self).__init__() 72 | self.expansion = expansion 73 | self.conv1 = conv1x1(inplanes, planes) 74 | self.bn1 = nn.BatchNorm2d(planes) 75 | self.conv2 = conv3x3(planes, planes, stride, dilation, dilation) 76 | self.bn2 = nn.BatchNorm2d(planes) 77 | self.conv3 = conv1x1(planes, planes * self.expansion) 78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 79 | self.relu = nn.ReLU(inplace=True) 80 | self.downsample = downsample 81 | self.stride = stride 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | 108 | def __init__(self, block, layers, output_stride, num_classes=1000, input_channels=3): 109 | super(ResNet, self).__init__() 110 | if output_stride == 8: 111 | stride = [1, 2, 1, 1] 112 | dilation = [1, 1, 2, 2] 113 | elif output_stride == 16: 114 | stride = [1, 2, 2, 1] 115 | dilation = [1, 1, 1, 2] 116 | 117 | self.inplanes = 64 118 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3, 119 | bias=False) 120 | self.bn1 = nn.BatchNorm2d(64) 121 | self.relu = nn.ReLU(inplace=True) 122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 123 | self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0], dilation=dilation[0]) 124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1], dilation=dilation[1]) 125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2], dilation=dilation[2]) 126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3], dilation=dilation[3]) 127 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 128 | self.fc = nn.Linear(512 * block.expansion, num_classes) 129 | 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv2d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, nn.BatchNorm2d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | def _make_layer(self, block, planes, blocks, stride, dilation): 138 | downsample = None 139 | if stride != 1 or self.inplanes != planes * block.expansion: 140 | downsample = nn.Sequential( 141 | conv1x1(self.inplanes, planes * block.expansion, stride), 142 | nn.BatchNorm2d(planes * block.expansion), 143 | ) 144 | 145 | layers = [] 146 | layers.append(block(self.inplanes, planes, stride, dilation, downsample)) 147 | self.inplanes = planes * block.expansion 148 | for _ in range(1, blocks): 149 | layers.append(block(self.inplanes, planes, 1, dilation)) 150 | 151 | return nn.Sequential(*layers) 152 | 153 | def forward(self, x): 154 | x = 
self.conv1(x) 155 | x = self.bn1(x) 156 | x = self.relu(x) 157 | x = self.maxpool(x) 158 | 159 | x = self.layer1(x) 160 | x = self.layer2(x) 161 | x = self.layer3(x) 162 | x = self.layer4(x) 163 | 164 | x = self.avgpool(x) 165 | x = x.view(x.size(0), -1) 166 | x = self.fc(x) 167 | 168 | return x 169 | 170 | 171 | def resnet18(pretrained=False, **kwargs): 172 | """Constructs a ResNet-18 model. 173 | Args: 174 | pretrained (bool): If True, returns a model pre-trained on ImageNet 175 | """ 176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 177 | if pretrained: 178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18'])) 179 | return model 180 | 181 | 182 | def resnet34(pretrained=False, **kwargs): 183 | """Constructs a ResNet-34 model. 184 | Args: 185 | pretrained (bool): If True, returns a model pre-trained on ImageNet 186 | """ 187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 188 | if pretrained: 189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34'])) 190 | return model 191 | 192 | 193 | def resnet50(pretrained=False, **kwargs): 194 | """Constructs a ResNet-50 model. 195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 199 | if pretrained: 200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50'])) 201 | return model 202 | 203 | 204 | def resnet101(pretrained=False, **kwargs): 205 | """Constructs a ResNet-101 model. 206 | Args: 207 | pretrained (bool): If True, returns a model pre-trained on ImageNet 208 | """ 209 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 210 | if pretrained: 211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101'])) 212 | return model 213 | 214 | 215 | def resnet152(pretrained=False, **kwargs): 216 | """Constructs a ResNet-152 model. 
217 | Args: 218 | pretrained (bool): If True, returns a model pre-trained on ImageNet 219 | """ 220 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 221 | if pretrained: 222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152'])) 223 | return model 224 | -------------------------------------------------------------------------------- /libs/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/libs/utils/__init__.py -------------------------------------------------------------------------------- /libs/utils/logger.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: speedinghzl 5 | # URL: https://github.com/speedinghzl/pytorch-segmentation-toolbox 6 | 7 | import os 8 | import sys 9 | import logging 10 | 11 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO') 12 | _default_level = logging.getLevelName(_default_level_name.upper()) 13 | 14 | 15 | class LogFormatter(logging.Formatter): 16 | log_fout = None 17 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] ' 18 | date = '%(asctime)s ' 19 | msg = '%(message)s' 20 | 21 | def format(self, record): 22 | if record.levelno == logging.DEBUG: 23 | mcl, mtxt = self._color_dbg, 'DBG' 24 | elif record.levelno == logging.WARNING: 25 | mcl, mtxt = self._color_warn, 'WRN' 26 | elif record.levelno == logging.ERROR: 27 | mcl, mtxt = self._color_err, 'ERR' 28 | else: 29 | mcl, mtxt = self._color_normal, '' 30 | 31 | if mtxt: 32 | mtxt += ' ' 33 | 34 | if self.log_fout: 35 | self.__set_fmt(self.date_full + mtxt + self.msg) 36 | formatted = super(LogFormatter, self).format(record) 37 | # self.log_fout.write(formatted) 38 | # self.log_fout.write('\n') 39 | # self.log_fout.flush() 40 | return formatted 41 | 42 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg)) 43 | formatted = super(LogFormatter, self).format(record) 44 | 45 | return formatted 46 | 47 | if sys.version_info.major < 3: 48 | def __set_fmt(self, fmt): 49 | self._fmt = fmt 50 | else: 51 | def __set_fmt(self, fmt): 52 | self._style._fmt = fmt 53 | 54 | @staticmethod 55 | def _color_dbg(msg): 56 | return '\x1b[36m{}\x1b[0m'.format(msg) 57 | 58 | @staticmethod 59 | def _color_warn(msg): 60 | return '\x1b[1;31m{}\x1b[0m'.format(msg) 61 | 62 | @staticmethod 63 | def _color_err(msg): 64 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg) 65 | 66 | @staticmethod 67 | def _color_omitted(msg): 68 | return '\x1b[35m{}\x1b[0m'.format(msg) 69 | 70 | @staticmethod 71 | def _color_normal(msg): 72 | return msg 73 | 74 | @staticmethod 75 | def _color_date(msg): 76 | return '\x1b[32m{}\x1b[0m'.format(msg) 77 | 78 | 79 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter): 80 | logger = logging.getLogger() 81 | logger.setLevel(_default_level) 82 | del logger.handlers[:] 83 | 84 | if log_dir and log_file: 85 | if not os.path.isdir(log_dir): 86 | os.makedirs(log_dir) 87 | LogFormatter.log_fout = True 88 | file_handler = logging.FileHandler(log_file, mode='a') 89 | file_handler.setLevel(logging.INFO) 90 | file_handler.setFormatter(formatter) 91 | logger.addHandler(file_handler) 92 | 93 | stream_handler = logging.StreamHandler() 94 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S')) 95 | stream_handler.setLevel(0) 96 | logger.addHandler(stream_handler) 97 | return logger 
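# A minimal usage sketch (assumes this module is importable as libs.utils.logger;
# the directory and file name below are illustrative placeholders):
#
#   from libs.utils.logger import get_logger
#   logger = get_logger(log_dir='logs', log_file='logs/train.log')  # file handler is added only when both args are given
#   logger.info('training started')  # INFO and above also reach the colored console handler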
-------------------------------------------------------------------------------- /libs/utils/metric.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | # This is the python implementation of S-measure 7 | 8 | import numpy as np 9 | 10 | eps = np.finfo(np.float32).eps 11 | def StructureMeasure(prediction, GT): 12 | """ 13 | StructureMeasure computes the similarity between the foreground map and 14 | ground truth(as proposed in "Structure-measure: A new way to evaluate 15 | foreground maps" [Deng-Ping Fan et. al - ICCV 2017]) 16 | Usage: 17 | Q = StructureMeasure(prediction,GT) 18 | Input: 19 | prediction - Binary/Non binary foreground map with values in the range 20 | [0 1]. Type: np.float32 21 | GT - Binary ground truth. Type: np.bool 22 | Output: 23 | Q - The computed similarity score 24 | """ 25 | # check input 26 | if prediction.dtype != np.float32: 27 | raise ValueError("prediction should be of type: np.float32") 28 | if np.amax(prediction) > 1 or np.amin(prediction) < 0: 29 | raise ValueError("prediction should be in the range of [0 1]") 30 | if GT.dtype != np.bool: 31 | raise ValueError("prediction should be of type: np.bool") 32 | 33 | y = np.mean(GT) 34 | 35 | if y == 0: # if the GT is completely black 36 | x = np.mean(prediction) 37 | Q = 1.0 - x 38 | elif y == 1: # if the GT is completely white 39 | x = np.mean(prediction) 40 | Q = x 41 | else: 42 | alpha = 0.5 43 | Q = alpha * S_object(prediction, GT) + (1 - alpha) * S_region(prediction, GT) 44 | if Q < 0: 45 | Q = 0 46 | 47 | return Q 48 | 49 | def S_object(prediction, GT): 50 | """ 51 | S_object Computes the object similarity between foreground maps and ground 52 | truth(as proposed in "Structure-measure:A new way to evaluate foreground 53 | maps" [Deng-Ping Fan et. al - ICCV 2017]) 54 | Usage: 55 | Q = S_object(prediction,GT) 56 | Input: 57 | prediction - Binary/Non binary foreground map with values in the range 58 | [0 1]. Type: np.float32 59 | GT - Binary ground truth. Type: np.bool 60 | Output: 61 | Q - The object similarity score 62 | """ 63 | # compute the similarity of the foreground in the object level 64 | # Notice: inplace operation need deep copy 65 | prediction_fg = prediction.copy() 66 | prediction_fg[~GT] = 0 67 | O_FG = Object(prediction_fg, GT) 68 | 69 | # compute the similarity of the background 70 | prediction_bg = 1.0 - prediction; 71 | prediction_bg[GT] = 0 72 | O_BG = Object(prediction_bg, ~GT) 73 | 74 | # combine the foreground measure and background measure together 75 | u = np.mean(GT) 76 | Q = u * O_FG + (1 - u) * O_BG 77 | 78 | return Q 79 | 80 | def Object(prediction, GT): 81 | # compute the mean of the foreground or background in prediction 82 | x = np.mean(prediction[GT]) 83 | # compute the standard deviations of the foreground or background in prediction 84 | sigma_x = np.std(prediction[GT]) 85 | 86 | score = 2.0 * x / (x * x + 1.0 + sigma_x + eps) 87 | return score 88 | 89 | def S_region(prediction, GT): 90 | """ 91 | S_region computes the region similarity between the foreground map and 92 | ground truth(as proposed in "Structure-measure:A new way to evaluate 93 | foreground maps" [Deng-Ping Fan et. al - ICCV 2017]) 94 | Usage: 95 | Q = S_region(prediction,GT) 96 | Input: 97 | prediction - Binary/Non binary foreground map with values in the range 98 | [0 1]. Type: np.float32 99 | GT - Binary ground truth. 
Type: np.bool 100 | Output: 101 | Q - The region similarity score 102 | """ 103 | # find the centroid of the GT 104 | X, Y = centroid(GT) 105 | # divide GT into 4 regions 106 | GT_1, GT_2, GT_3, GT_4, w1, w2, w3, w4 = divideGT(GT, X, Y) 107 | # Divede prediction into 4 regions 108 | prediction_1, prediction_2, prediction_3, prediction_4 = Divideprediction(prediction, X, Y) 109 | # Compute the ssim score for each regions 110 | Q1 = ssim(prediction_1, GT_1) 111 | Q2 = ssim(prediction_2, GT_2) 112 | Q3 = ssim(prediction_3, GT_3) 113 | Q4 = ssim(prediction_4, GT_4) 114 | #Sum the 4 scores 115 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4 116 | return Q 117 | 118 | def centroid(GT): 119 | """ 120 | Centroid Compute the centroid of the GT 121 | Usage: 122 | X,Y = Centroid(GT) 123 | Input: 124 | GT - Binary ground truth. Type: logical. 125 | Output: 126 | X,Y - The coordinates of centroid. 127 | """ 128 | rows, cols = GT.shape 129 | 130 | total = np.sum(GT) 131 | if total == 0: 132 | X = round(float(cols) / 2) 133 | Y = round(float(rows) / 2) 134 | else: 135 | i = np.arange(1, cols + 1).astype(np.float) 136 | j = (np.arange(1, rows + 1)[np.newaxis].T)[:,0].astype(np.float) 137 | X = round(np.sum(np.sum(GT, axis=0) * i) / total) 138 | Y = round(np.sum(np.sum(GT, axis=1) * j) / total) 139 | return int(X), int(Y) 140 | 141 | def divideGT(GT, X, Y): 142 | """ 143 | LT - left top; 144 | RT - right top; 145 | LB - left bottom; 146 | RB - right bottom; 147 | """ 148 | # width and height of the GT 149 | hei, wid = GT.shape 150 | area = float(wid * hei) 151 | 152 | # copy 4 regions 153 | LT = GT[0:Y, 0:X] 154 | RT = GT[0:Y, X:wid] 155 | LB = GT[Y:hei, 0:X] 156 | RB = GT[Y:hei, X:wid] 157 | 158 | # The different weight (each block proportional to the GT foreground region). 159 | w1 = (X * Y) / area 160 | w2 = ((wid - X) * Y) / area 161 | w3 = (X * (hei-Y)) / area 162 | w4 = 1.0 - w1 - w2 - w3 163 | return LT, RT, LB, RB, w1, w2, w3, w4 164 | 165 | def Divideprediction(prediction, X, Y): 166 | """ 167 | Divide the prediction into 4 regions according to the centroid of the GT 168 | """ 169 | hei, wid = prediction.shape 170 | # copy 4 regions 171 | LT = prediction[0:Y, 0:X] 172 | RT = prediction[0:Y, X:wid] 173 | LB = prediction[Y:hei, 0:X] 174 | RB = prediction[Y:hei, X:wid] 175 | 176 | return LT, RT, LB, RB 177 | 178 | def ssim(prediction, GT): 179 | """ 180 | ssim computes the region similarity between foreground maps and ground 181 | truth(as proposed in "Structure-measure: A new way to evaluate foreground 182 | maps" [Deng-Ping Fan et. al - ICCV 2017]) 183 | Usage: 184 | Q = ssim(prediction,GT) 185 | Input: 186 | prediction - Binary/Non binary foreground map with values in the range 187 | [0 1]. Type: np.float32 188 | GT - Binary ground truth. 
Type: np.bool 189 | Output: 190 | Q - The region similarity score 191 | """ 192 | dGT = GT.astype(np.float32) 193 | 194 | hei, wid = prediction.shape 195 | N = float(wid * hei) 196 | 197 | # Compute the mean of SM,GT 198 | x = np.mean(prediction) 199 | y = np.mean(dGT) 200 | 201 | # Compute the variance of SM,GT 202 | dx = prediction - x 203 | dy = dGT - y 204 | total = N - 1 + eps 205 | sigma_x2 = np.sum(dx * dx) / total 206 | sigma_y2 = np.sum(dy * dy) / total 207 | 208 | # Compute the covariance between SM and GT 209 | sigma_xy = np.sum(dx * dy) / total 210 | 211 | alpha = 4 * x * y * sigma_xy 212 | beta = (x * x + y * y) * (sigma_x2 + sigma_y2) 213 | 214 | if alpha != 0: 215 | Q = alpha / (beta + eps) 216 | elif beta == 0: 217 | Q = 1.0 218 | else: 219 | Q = 0 220 | return Q 221 | 222 | -------------------------------------------------------------------------------- /libs/utils/pyt_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: speedinghzl 5 | # URL: https://github.com/speedinghzl/pytorch-segmentation-toolbox 6 | 7 | import time 8 | import logging 9 | 10 | import torch 11 | 12 | from .logger import get_logger 13 | 14 | logger = get_logger() 15 | 16 | def load_model(model, model_file, is_restore=False): 17 | t_start = time.time() 18 | if isinstance(model_file, str): 19 | device = torch.device('cpu') 20 | state_dict = torch.load(model_file, map_location=device) 21 | if 'state_dict' in state_dict.keys(): 22 | state_dict = state_dict['state_dict'] 23 | else: 24 | state_dict = model_file 25 | t_ioend = time.time() 26 | 27 | if not is_restore: 28 | # extend the input channels of FGPLG from 3 to 7 29 | v2 = model.backbone.resnet.conv1.weight 30 | if v2.size(1) > 3: 31 | v = state_dict['backbone.resnet.conv1.weight'] 32 | v = torch.cat((v,v2[:,3:,:,:]), dim=1) 33 | state_dict['backbone.resnet.conv1.weight'] = v 34 | 35 | model.load_state_dict(state_dict, strict=False) 36 | ckpt_keys = set(state_dict.keys()) 37 | own_keys = set(model.state_dict().keys()) 38 | missing_keys = own_keys - ckpt_keys 39 | unexpected_keys = ckpt_keys - own_keys 40 | 41 | if len(missing_keys) > 0: 42 | logger.warning('Missing key(s) in state_dict: {}'.format( 43 | ', '.join('{}'.format(k) for k in missing_keys))) 44 | 45 | if len(unexpected_keys) > 0: 46 | logger.warning('Unexpected key(s) in state_dict: {}'.format( 47 | ', '.join('{}'.format(k) for k in unexpected_keys))) 48 | 49 | del state_dict 50 | t_end = time.time() 51 | logger.info( 52 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format( 53 | t_ioend - t_start, t_end - t_ioend)) 54 | 55 | return model -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/models/.gitkeep -------------------------------------------------------------------------------- /models/checkpoints/.gitkeep: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/models/checkpoints/.gitkeep -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: 
Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | 10 | import torch 11 | import torch.nn as nn 12 | from torch.utils import data 13 | 14 | import argparse 15 | from tqdm import tqdm 16 | import numpy as np 17 | 18 | from libs.datasets import get_transforms, get_datasets 19 | from libs.networks import VideoModel 20 | from libs.utils.metric import StructureMeasure 21 | from libs.utils.pyt_utils import load_model 22 | 23 | parser = argparse.ArgumentParser() 24 | 25 | # Dataloading-related settings 26 | parser.add_argument('--data', type=str, default='data/datasets/', 27 | help='path to datasets folder') 28 | parser.add_argument('--checkpoint', default='models/image_pretrained_model.pth', 29 | help='path to the pretrained checkpoint') 30 | parser.add_argument('--dataset-config', default='config/datasets.yaml', 31 | help='dataset config file') 32 | parser.add_argument('--save-folder', default='models/checkpoints', 33 | help='location to save checkpoint models') 34 | parser.add_argument('--pseudo-label-folder', default='', 35 | help='location to load pseudo-labels') 36 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N', 37 | help='number of data loading workers.') 38 | 39 | # Training procedure settings 40 | parser.add_argument('--batch-size', default=1, type=int, 41 | help='batch size for each gpu. Only support 1 for video clips.') 42 | parser.add_argument('--backup-epochs', type=int, default=1, 43 | help='iteration epoch to perform state backups') 44 | parser.add_argument('--epochs', type=int, default=50, 45 | help='upper epoch limit') 46 | parser.add_argument('--start-epoch', type=int, default=0, 47 | help='epoch number to resume') 48 | parser.add_argument('--eval-first', default=False, action='store_true', 49 | help='evaluate model weights before training') 50 | parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float, 51 | help='initial learning rate') 52 | 53 | # Model settings 54 | parser.add_argument('--size', default=448, type=int, 55 | help='image size') 56 | parser.add_argument('--os', default=16, type=int, 57 | help='output stride.') 58 | parser.add_argument("--clip_len", type=int, default=4, 59 | help="the number of frames in a video clip.") 60 | 61 | args = parser.parse_args() 62 | 63 | cuda = torch.cuda.is_available() 64 | device = torch.device("cuda" if cuda else "cpu") 65 | 66 | if cuda: 67 | torch.backends.cudnn.benchmark = True 68 | current_device = torch.cuda.current_device() 69 | print("Running on", torch.cuda.get_device_name(current_device)) 70 | else: 71 | print("Running on CPU") 72 | 73 | data_transforms = get_transforms( 74 | input_size=(args.size, args.size), 75 | image_mode=False 76 | ) 77 | 78 | train_dataset = get_datasets( 79 | name_list=["DAVIS2016", "FBMS", "VOS"], 80 | split_list=["train", "train", "train"], 81 | config_path=args.dataset_config, 82 | root=args.data, 83 | training=True, 84 | transforms=data_transforms['train'], 85 | read_clip=True, 86 | random_reverse_clip=True, 87 | clip_len=args.clip_len 88 | ) 89 | val_dataset = get_datasets( 90 | name_list="VOS", 91 | split_list="val", 92 | config_path=args.dataset_config, 93 | root=args.data, 94 | training=True, 95 | transforms=data_transforms['val'], 96 | read_clip=True, 97 | random_reverse_clip=False, 98 | clip_len=args.clip_len 99 | ) 100 | 101 | train_dataloader = data.DataLoader( 102 | dataset=train_dataset, 103 | batch_size=args.batch_size, 104 | 
num_workers=args.num_workers, 105 | shuffle=True, 106 | drop_last=True 107 | ) 108 | val_dataloader = data.DataLoader( 109 | dataset=val_dataset, 110 | batch_size=args.batch_size, 111 | num_workers=args.num_workers, 112 | shuffle=False 113 | ) 114 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 115 | 116 | # loading pseudo-labels for training 117 | if os.path.exists(args.pseudo_label_folder): 118 | print("Loading pseudo-labels from {}".format(args.pseudo_label_folder)) 119 | datasets = dataloaders['train'].dataset 120 | for dataset in datasets.datasets: 121 | dataset._reset_files(clip_len=args.clip_len, label_dir=os.path.join(args.pseudo_label_folder, dataset.name)) 122 | if isinstance(datasets, data.ConcatDataset): 123 | datasets.cumulative_sizes = datasets.cumsum(datasets.datasets) 124 | 125 | model = VideoModel(output_stride=args.os) 126 | # load pretrained models 127 | if os.path.exists(args.checkpoint): 128 | print('Loading state dict from: {0}'.format(args.checkpoint)) 129 | if args.start_epoch == 0: 130 | model = load_model(model=model, model_file=args.checkpoint, is_restore=False) 131 | else: 132 | model = load_model(model=model, model_file=args.checkpoint, is_restore=True) 133 | else: 134 | raise ValueError("Cannot find model file at {}".format(args.checkpoint)) 135 | 136 | model = nn.DataParallel(model) 137 | model.to(device) 138 | 139 | criterion = nn.BCEWithLogitsLoss() 140 | 141 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.module.parameters()), lr=args.lr) 142 | 143 | if not os.path.exists(args.save_folder): 144 | os.makedirs(args.save_folder) 145 | 146 | def train(): 147 | 148 | best_smeasure = 0.0 149 | best_epoch = 0 150 | for epoch in range(args.start_epoch, args.epochs+1): 151 | # Each epoch has a training and validation phase 152 | if args.eval_first: 153 | phases = ['val'] 154 | else: 155 | phases = ['train', 'val'] 156 | 157 | for phase in phases: 158 | if phase == 'train': 159 | model.train() # Set model to training mode 160 | model.module.freeze_bn() 161 | else: 162 | model.eval() # Set model to evaluate mode 163 | 164 | running_loss = 0.0 165 | running_mae = 0.0 166 | running_smean = 0.0 167 | print("{} epoch {}...".format(phase, epoch)) 168 | # Iterate over data. 
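            # Each `data` element from the loader is a clip: a list of `clip_len`
            # frame dicts, each carrying an 'image' tensor and a 'label' tensor.
            # Per-frame BCE losses are backpropagated jointly in the 'train' phase;
            # MAE is accumulated for every frame, S-measure only during 'val', and
            # both are normalised by clip_len * dataset size at the end of the epoch.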
169 | for data in tqdm(dataloaders[phase]): 170 | images, labels = [], [] 171 | for frame in data: 172 | images.append(frame['image'].to(device)) 173 | labels.append(frame['label'].to(device)) 174 | # zero the parameter gradients 175 | optimizer.zero_grad() 176 | # track history if only in train 177 | with torch.set_grad_enabled(phase == 'train'): 178 | # read clips 179 | preds = model(images) 180 | loss = [] 181 | for pred, label in zip(preds, labels): 182 | loss.append(criterion(pred, label)) 183 | # backward + optimize only if in training phase 184 | if phase == 'train': 185 | torch.autograd.backward(loss) 186 | optimizer.step() 187 | # statistics 188 | for _loss in loss: 189 | running_loss += _loss.item() 190 | preds = [torch.sigmoid(pred) for pred in preds] # activation 191 | 192 | # iterate list 193 | for i, (label_, pred_) in enumerate(zip(labels, preds)): 194 | # interate batch 195 | for j, (label, pred) in enumerate(zip(label_.detach().cpu(), pred_.detach().cpu())): 196 | pred_idx = pred[0,:,:].numpy() 197 | label_idx = label[0,:,:].numpy() 198 | if phase == 'val': 199 | running_smean += StructureMeasure(pred_idx.astype(np.float32), (label_idx>=0.5).astype(np.bool)) 200 | running_mae += np.abs(pred_idx - label_idx).mean() 201 | 202 | samples_num = len(dataloaders[phase].dataset) 203 | samples_num *= args.clip_len 204 | epoch_loss = running_loss / samples_num 205 | epoch_mae = running_mae / samples_num 206 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 207 | print('{} MAE: {:.4f}'.format(phase, epoch_mae)) 208 | 209 | # save current best epoch 210 | if phase == 'val': 211 | epoch_smeasure = running_smean / samples_num 212 | print('{} S-measure: {:.4f}'.format(phase, epoch_smeasure)) 213 | if epoch_smeasure > best_smeasure: 214 | best_smeasure = epoch_smeasure 215 | best_epoch = epoch 216 | model_path = os.path.join(args.save_folder, "video_current_best_model.pth") 217 | print("Saving current best model at: {}".format(model_path) ) 218 | torch.save( 219 | model.module.state_dict(), 220 | model_path, 221 | ) 222 | if epoch > 0 and epoch % args.backup_epochs == 0: 223 | # save model 224 | model_path = os.path.join(args.save_folder, "video_epoch-{}.pth".format(epoch)) 225 | print("Backup model at: {}".format(model_path)) 226 | torch.save( 227 | model.module.state_dict(), 228 | model_path, 229 | ) 230 | 231 | print('Best S-measure: {} at epoch {}'.format(best_smeasure, best_epoch)) 232 | 233 | if __name__ == "__main__": 234 | train() 235 | -------------------------------------------------------------------------------- /train_fgplg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | # 4 | # Author: Pengxiang Yan 5 | # Email: yanpx (at) mail2.sysu.edu.cn 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import sys 10 | sys.path.append('flownet2') 11 | 12 | import torch 13 | import torch.nn as nn 14 | from torch.utils import data 15 | 16 | import argparse 17 | from tqdm import tqdm 18 | import numpy as np 19 | 20 | from libs.datasets import get_transforms, get_datasets 21 | from libs.networks.pseudo_label_generator import FGPLG 22 | from libs.utils.metric import StructureMeasure 23 | from libs.utils.pyt_utils import load_model 24 | 25 | parser = argparse.ArgumentParser() 26 | 27 | # Dataloading-related settings 28 | parser.add_argument('--data', type=str, default='data/datasets/', 29 | help='path to datasets folder') 30 | parser.add_argument('--checkpoint', 
default='models/image_pretrained_model.pth', 31 | help='path to the pretrained checkpoint') 32 | parser.add_argument('--flownet-checkpoint', default='models/FlowNet2_checkpoint.pth.tar', 33 | help='path to the checkpoint of pretrained flownet2') 34 | parser.add_argument('--dataset-config', default='config/datasets.yaml', 35 | help='dataset config file') 36 | parser.add_argument('--save-folder', default='models/checkpoints', 37 | help='location to save checkpoint models') 38 | parser.add_argument("--label_interval", default=5, type=int, 39 | help="the interval of ground truth labels") 40 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N', 41 | help='number of data loading workers.') 42 | 43 | # Training procedure settings 44 | parser.add_argument('--batch-size', default=1, type=int, 45 | help='batch size for each gpu. Only support 1 for video clips.') 46 | parser.add_argument('--backup-epochs', type=int, default=1, 47 | help='iteration epoch to perform state backups') 48 | parser.add_argument('--epochs', type=int, default=50, 49 | help='upper epoch limit') 50 | parser.add_argument('--start-epoch', type=int, default=0, 51 | help='epoch number to resume') 52 | parser.add_argument('--eval-first', default=False, action='store_true', 53 | help='evaluate model weights before training') 54 | parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float, 55 | help='initial learning rate') 56 | 57 | # Model settings 58 | parser.add_argument('--size', default=448, type=int, 59 | help='image size') 60 | parser.add_argument('--os', default=16, type=int, 61 | help='output stride.') 62 | parser.add_argument("--clip_len", type=int, default=3, 63 | help="the number of frames in a video clip.") 64 | 65 | # FlowNet setting 66 | parser.add_argument("--fp16", action="store_true", 67 | help="Run model in pseudo-fp16 mode (fp16 storage fp32 math).") 68 | parser.add_argument("--rgb_max", type=float, default=1.) 
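# Example invocation (illustrative only; the paths simply echo the defaults declared above):
#
#   python train_fgplg.py --data data/datasets/ \
#       --checkpoint models/image_pretrained_model.pth \
#       --flownet-checkpoint models/FlowNet2_checkpoint.pth.tar \
#       --label_interval 5 --clip_len 3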
69 | 70 | 71 | args = parser.parse_args() 72 | 73 | cuda = torch.cuda.is_available() 74 | device = torch.device("cuda" if cuda else "cpu") 75 | 76 | if cuda: 77 | torch.backends.cudnn.benchmark = True 78 | current_device = torch.cuda.current_device() 79 | print("Running on", torch.cuda.get_device_name(current_device)) 80 | else: 81 | print("Running on CPU") 82 | 83 | data_transforms = get_transforms( 84 | input_size=(args.size, args.size), 85 | image_mode=False 86 | ) 87 | 88 | train_dataset = get_datasets( 89 | name_list=["DAVIS2016", "FBMS", "VOS"], 90 | split_list=["train", "train", "train"], 91 | config_path=args.dataset_config, 92 | root=args.data, 93 | training=True, 94 | transforms=data_transforms['train'], 95 | read_clip=True, 96 | random_reverse_clip=True, 97 | label_interval=args.label_interval, 98 | clip_len=args.clip_len 99 | ) 100 | val_dataset = get_datasets( 101 | name_list="VOS", 102 | split_list="val", 103 | config_path=args.dataset_config, 104 | root=args.data, 105 | training=True, 106 | transforms=data_transforms['val'], 107 | read_clip=True, 108 | random_reverse_clip=False, 109 | label_interval=args.label_interval, 110 | clip_len=args.clip_len 111 | ) 112 | 113 | train_dataloader = data.DataLoader( 114 | dataset=train_dataset, 115 | batch_size=args.batch_size, 116 | num_workers=args.num_workers, 117 | shuffle=True, 118 | drop_last=True 119 | ) 120 | val_dataloader = data.DataLoader( 121 | dataset=val_dataset, 122 | batch_size=args.batch_size, 123 | num_workers=args.num_workers, 124 | shuffle=False 125 | ) 126 | dataloaders = {'train': train_dataloader, 'val': val_dataloader} 127 | 128 | pseudo_label_generator = FGPLG(args=args, output_stride=args.os) 129 | 130 | if os.path.exists(args.checkpoint): 131 | print('Loading state dict from: {0}'.format(args.checkpoint)) 132 | if args.start_epoch == 0: 133 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=False) 134 | if os.path.exists(args.flownet_checkpoint): 135 | pseudo_label_generator.flownet = load_model(model=pseudo_label_generator.flownet, model_file=args.flownet_checkpoint, is_restore=True) 136 | else: 137 | raise ValueError("Cannot pretrained flownet model file at {}".format(args.flownet_checkpoint)) 138 | else: 139 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=True) 140 | else: 141 | raise ValueError("Cannot find model file at {}".format(args.checkpoint)) 142 | 143 | pseudo_label_generator = nn.DataParallel(pseudo_label_generator) 144 | pseudo_label_generator.to(device) 145 | 146 | criterion = nn.BCEWithLogitsLoss() 147 | 148 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, pseudo_label_generator.module.parameters()), lr=args.lr) 149 | 150 | if not os.path.exists(args.save_folder): 151 | os.makedirs(args.save_folder) 152 | 153 | def train(): 154 | best_smeasure = 0.0 155 | best_epoch = 0 156 | for epoch in range(args.start_epoch, args.epochs+1): 157 | # Each epoch has a training and validation phase 158 | if args.eval_first: 159 | phases = ['val'] 160 | else: 161 | phases = ['train', 'val'] 162 | 163 | for phase in phases: 164 | if phase == 'train': 165 | pseudo_label_generator.train() # Set model to training mode 166 | pseudo_label_generator.module.freeze_bn() 167 | else: 168 | pseudo_label_generator.eval() # Set model to evaluate mode 169 | 170 | running_loss = 0.0 171 | running_acc = 0.0 172 | running_iou = 0.0 173 | running_mae = 0.0 174 | running_smean = 0.0 175 | print("{} epoch 
{}...".format(phase, epoch)) 176 | # Iterate over data. 177 | for data in tqdm(dataloaders[phase]): 178 | images, labels = [], [] 179 | for frame in data: 180 | images.append(frame['image'].to(device)) 181 | labels.append(frame['label'].to(device)) 182 | # zero the parameter gradients 183 | optimizer.zero_grad() 184 | # track history if only in train 185 | with torch.set_grad_enabled(phase == 'train'): 186 | # read clips 187 | preds, labels = pseudo_label_generator(images, labels) 188 | loss = criterion(preds, labels) 189 | # backward + optimize only if in training phase 190 | if phase == 'train': 191 | torch.autograd.backward(loss) 192 | optimizer.step() 193 | # statistics 194 | running_loss += loss.item() 195 | preds = torch.sigmoid(preds) # activation 196 | 197 | pred_idx = preds.squeeze().detach().cpu().numpy() 198 | label_idx = labels.squeeze().detach().cpu().numpy() 199 | if phase == 'val': 200 | running_smean += StructureMeasure(pred_idx.astype(np.float32), (label_idx>=0.5).astype(np.bool)) 201 | running_mae += np.abs(pred_idx - label_idx).mean() 202 | 203 | samples_num = len(dataloaders[phase].dataset) 204 | epoch_loss = running_loss / samples_num 205 | epoch_mae = running_mae / samples_num 206 | print('{} Loss: {:.4f}'.format(phase, epoch_loss)) 207 | print('{} MAE: {:.4f}'.format(phase, epoch_mae)) 208 | 209 | # save current best epoch 210 | if phase == 'val': 211 | epoch_smeasure = running_smean / samples_num 212 | print('{} S-measure: {:.4f}'.format(phase, epoch_smeasure)) 213 | if epoch_smeasure > best_smeasure: 214 | best_smeasure = epoch_smeasure 215 | best_epoch = epoch 216 | model_path = os.path.join(args.save_folder, "fgplg_current_best_model.pth") 217 | print("Saving current best model at: {}".format(model_path) ) 218 | torch.save( 219 | pseudo_label_generator.module.state_dict(), 220 | model_path, 221 | ) 222 | if epoch > 0 and epoch % args.backup_epochs == 0: 223 | # save model 224 | model_path = os.path.join(args.save_folder, "fgplg_epoch-{}.pth".format(epoch)) 225 | print("Backup model at: {}".format(model_path)) 226 | torch.save( 227 | pseudo_label_generator.module.state_dict(), 228 | model_path, 229 | ) 230 | 231 | print('Best S-measure: {} at epoch {}'.format(best_smeasure, best_epoch)) 232 | 233 | if __name__ == "__main__": 234 | train() 235 | --------------------------------------------------------------------------------