├── .gitignore
├── LICENSE
├── README.md
├── config
│   └── datasets.yaml
├── data
│   └── datasets
│       └── .gitkeep
├── docs
│   ├── _config.yml
│   ├── bibtex.txt
│   ├── comp_video_sota.png
│   ├── index.md
│   ├── motorcross-jump.gif
│   ├── pseudo_label_example.png
│   ├── pseudo_label_generator.png
│   ├── static_model.png
│   ├── training_instruction.md
│   └── video_model.png
├── flownet2
│   ├── .gitignore
│   ├── Dockerfile
│   ├── LICENSE
│   ├── README.md
│   ├── __init__.py
│   ├── convert.py
│   ├── datasets.py
│   ├── download_caffe_models.sh
│   ├── image.png
│   ├── install.sh
│   ├── launch_docker.sh
│   ├── losses.py
│   ├── main.py
│   ├── models.py
│   ├── networks
│   │   ├── FlowNetC.py
│   │   ├── FlowNetFusion.py
│   │   ├── FlowNetS.py
│   │   ├── FlowNetSD.py
│   │   ├── __init__.py
│   │   ├── channelnorm_package
│   │   │   ├── __init__.py
│   │   │   ├── channelnorm.py
│   │   │   ├── channelnorm_cuda.cc
│   │   │   ├── channelnorm_kernel.cu
│   │   │   ├── channelnorm_kernel.cuh
│   │   │   └── setup.py
│   │   ├── correlation_package
│   │   │   ├── __init__.py
│   │   │   ├── correlation.py
│   │   │   ├── correlation_cuda.cc
│   │   │   ├── correlation_cuda_kernel.cu
│   │   │   ├── correlation_cuda_kernel.cuh
│   │   │   └── setup.py
│   │   ├── resample2d_package
│   │   │   ├── __init__.py
│   │   │   ├── resample2d.py
│   │   │   ├── resample2d_cuda.cc
│   │   │   ├── resample2d_kernel.cu
│   │   │   ├── resample2d_kernel.cuh
│   │   │   └── setup.py
│   │   └── submodules.py
│   ├── run-caffe2pytorch.sh
│   └── utils
│       ├── __init__.py
│       ├── flow_utils.py
│       ├── frame_utils.py
│       ├── param_utils.py
│       └── tools.py
├── generate_pseudo_labels.py
├── inference.py
├── libs
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── transforms.py
│   │   └── video_datasets.py
│   ├── modules
│   │   ├── __init__.py
│   │   ├── convgru.py
│   │   └── non_local_dot_product.py
│   ├── networks
│   │   ├── __init__.py
│   │   ├── models.py
│   │   ├── pseudo_label_generator.py
│   │   ├── rcrnet.py
│   │   └── resnet_dilation.py
│   └── utils
│       ├── __init__.py
│       ├── logger.py
│       ├── metric.py
│       └── pyt_utils.py
├── models
│   ├── .gitkeep
│   └── checkpoints
│       └── .gitkeep
├── train.py
└── train_fgplg.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Mac
2 | .DS_Store
3 | # Editor
4 | .vscode/
5 | # Data
6 | *.pth
7 | *.pth.tar
8 | data/results
9 | data/pseudo-labels
10 |
11 | # Byte-compiled / optimized / DLL files
12 | __pycache__/
13 | *.py[cod]
14 | *$py.class
15 |
16 | # C extensions
17 | *.so
18 |
19 | # Distribution / packaging
20 | .Python
21 | build/
22 | develop-eggs/
23 | dist/
24 | downloads/
25 | eggs/
26 | .eggs/
27 | lib/
28 | lib64/
29 | parts/
30 | sdist/
31 | var/
32 | wheels/
33 | *.egg-info/
34 | .installed.cfg
35 | *.egg
36 | MANIFEST
37 |
38 | # PyInstaller
39 | # Usually these files are written by a python script from a template
40 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
41 | *.manifest
42 | *.spec
43 |
44 | # Installer logs
45 | pip-log.txt
46 | pip-delete-this-directory.txt
47 |
48 | # Unit test / coverage reports
49 | htmlcov/
50 | .tox/
51 | .coverage
52 | .coverage.*
53 | .cache
54 | nosetests.xml
55 | coverage.xml
56 | *.cover
57 | .hypothesis/
58 | .pytest_cache/
59 |
60 | # Translations
61 | *.mo
62 | *.pot
63 |
64 | # Django stuff:
65 | *.log
66 | local_settings.py
67 | db.sqlite3
68 |
69 | # Flask stuff:
70 | instance/
71 | .webassets-cache
72 |
73 | # Scrapy stuff:
74 | .scrapy
75 |
76 | # Sphinx documentation
77 | docs/_build/
78 |
79 | # PyBuilder
80 | target/
81 |
82 | # Jupyter Notebook
83 | .ipynb_checkpoints
84 |
85 | # pyenv
86 | .python-version
87 |
88 | # celery beat schedule file
89 | celerybeat-schedule
90 |
91 | # SageMath parsed files
92 | *.sage.py
93 |
94 | # Environments
95 | .env
96 | .venv
97 | env/
98 | venv/
99 | ENV/
100 | env.bak/
101 | venv.bak/
102 |
103 | # Spyder project settings
104 | .spyderproject
105 | .spyproject
106 |
107 | # Rope project settings
108 | .ropeproject
109 |
110 | # mkdocs documentation
111 | /site
112 |
113 | # mypy
114 | .mypy_cache/
115 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Kinpzz
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # RCRNet-Pytorch
2 |
3 |
4 | This repository contains the PyTorch implementation for
5 |
6 | **Semi-Supervised Video Salient Object Detection Using Pseudo-Labels**
7 | Pengxiang Yan, Guanbin Li, Yuan Xie, Zhen Li, Chuan Wang, Tianshui Chen, Liang Lin
8 | **ICCV 2019** | [[Project Page](https://kinpzz.com/publication/iccv19_semi_vsod/)] | [[Arxiv](https://arxiv.org/abs/1908.04051)] | [[CVF-Open-Access](http://openaccess.thecvf.com/content_ICCV_2019/html/Yan_Semi-Supervised_Video_Salient_Object_Detection_Using_Pseudo-Labels_ICCV_2019_paper.html)]
9 |
10 | ## Usage
11 |
12 | ### Requirements
13 |
14 | This code is tested on Ubuntu 16.04, Python=3.6 (via Anaconda3), PyTorch=0.4.1, CUDA=9.0.
15 |
16 | ```
17 | # Install PyTorch=0.4.1
18 | $ conda install pytorch==0.4.1 torchvision==0.2.1 cuda90 -c pytorch
19 |
20 | # Install other packages
21 | $ pip install pyyaml==3.13 addict==2.2.0 tqdm==4.28.1 scipy==1.1.0
22 | ```
23 | ### Datasets
24 | Our proposed RCRNet is evaluated on three public benchmark VSOD datasets including [VOS](http://cvteam.net/projects/TIP18-VOS/VOS.html), [DAVIS](https://davischallenge.org/) (version: 2016, 480p), and [FBMS](https://lmb.informatik.uni-freiburg.de/resources/datasets/). Please organize the datasets according to `config/datasets.yaml` and put them in `data/datasets`. Alternatively, you can set the argument `--data` to the path of your dataset folder.
25 |
26 | ### Evaluation
27 | #### Comparison with State-of-the-Art
28 | 
29 | If you want to compare with our method:
30 |
31 | **Option 1:** you can download the saliency maps predicted by our model from [Google Drive](https://drive.google.com/open?id=1feY3GdNBS-LUBt0UDWwpA3fl9yHI4Vxr) / [Baidu Pan](https://pan.baidu.com/s/1oXBr9qxyF-8vvilvV5kcPg) (passwd: u079).
32 |
33 | **Option 2:** you can use our trained model for inference. The weights of the trained model are available at [Google Drive](https://drive.google.com/open?id=1TSmi1DyKIvuzuXE1aw7t_ygmcUmjYnN_) / [Baidu Pan](https://pan.baidu.com/s/1PLoajL6X_s29I-4mreSuSQ) (passwd: 6pi3). Then run the following commands for inference.
34 | ```
35 | # VOS
36 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset VOS --split test
37 |
38 | # DAVIS
39 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset DAVIS --split val
40 |
41 | # FBMS
42 | $ CUDA_VISIBLE_DEVICES=0 python inference.py --data data/datasets --dataset FBMS --split test
43 | ```
44 |
45 | Then, you can evaluate the saliency maps with your own evaluation code (a minimal metric sketch is provided at the end of this README).
46 |
47 | ### Training
48 | If you want to train our proposed model from scratch (including using pseudo-labels), please refer to our paper and the [training instruction](docs/training_instruction.md) carefully.
49 |
50 | ## Citation
51 | If you find this work helpful, please consider citing
52 | ```
53 | @inproceedings{yan2019semi,
54 | title={Semi-Supervised Video Salient Object Detection Using Pseudo-Labels},
55 | author={Yan, Pengxiang and Li, Guanbin and Xie, Yuan and Li, Zhen and Wang, Chuan and Chen, Tianshui and Lin, Liang},
56 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
57 | pages={7284--7293},
58 | year={2019}
59 | }
60 | ```
61 |
62 | ## Acknowledgments
63 | Thanks to the following third-party libraries:
64 | * [deeplab-pytorch](https://github.com/kazuto1011/deeplab-pytorch) by kazuto1011
65 | * [flownet2-pytorch](https://github.com/NVIDIA/flownet2-pytorch) by NVIDIA
66 | * [pytorch-segmentation-toolbox](https://github.com/speedinghzl/pytorch-segmentation-toolbox) by speedinghzl
67 | * [Non-local_pytorch](https://github.com/AlexHex7/Non-local_pytorch) by AlexHex7
68 |
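69 | ## Evaluation Metric Sketch
70 |
71 | For reference, below is a minimal sketch of two metrics commonly used for VSOD, MAE and F-measure at an adaptive threshold. This is **not** the evaluation code used in the paper; it assumes the predicted saliency maps and ground-truth masks are grayscale images and that you pair up their file paths yourself.
72 |
73 | ```
74 | import numpy as np
75 | from PIL import Image
76 |
77 | def load_gray(path):
78 |     """Load an image as a float array scaled to [0, 1]."""
79 |     return np.asarray(Image.open(path).convert('L'), dtype=np.float64) / 255.0
80 |
81 | def mae(pred, gt):
82 |     """Mean absolute error between a saliency map and its ground-truth mask."""
83 |     return np.abs(pred - gt).mean()
84 |
85 | def f_measure(pred, gt, beta2=0.3):
86 |     """F-measure at an adaptive threshold (twice the mean saliency value)."""
87 |     thr = min(2 * pred.mean(), 1.0)
88 |     binary, gt_mask = pred >= thr, gt >= 0.5
89 |     tp = np.logical_and(binary, gt_mask).sum()
90 |     precision = tp / (binary.sum() + 1e-8)
91 |     recall = tp / (gt_mask.sum() + 1e-8)
92 |     return (1 + beta2) * precision * recall / (beta2 * precision + recall + 1e-8)
93 |
94 | def evaluate(pairs):
95 |     """pairs: an iterable of (prediction_path, ground_truth_path) tuples."""
96 |     maes, fms = [], []
97 |     for pred_path, gt_path in pairs:
98 |         pred, gt = load_gray(pred_path), load_gray(gt_path)
99 |         maes.append(mae(pred, gt))
100 |         fms.append(f_measure(pred, gt))
101 |     return np.mean(maes), np.mean(fms)
102 | ```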
--------------------------------------------------------------------------------
/config/datasets.yaml:
--------------------------------------------------------------------------------
1 | # Video Saliency Dataset Config files
2 | #
3 | # Author: Pengxiang Yan
4 | # Email: yanpx (at) mail2.sysu.edu.cn
5 |
6 | # MSRA-B Dataset
7 | # url: https://mmcheng.net/msra10k/
8 | MSRA-B:
9 | image_dir: imgs
10 | label_dir: gt
11 | split_dir: ImageSets
12 | image_ext: .jpg
13 | label_ext: .png
14 | # HKU-IS Dataset
15 | # url: https://i.cs.hku.hk/~gbli/deep_saliency.html
16 |
17 | HKU-IS:
18 | image_dir: imgs
19 | label_dir: gt
20 | split_dir: ImageSets
21 | image_ext: .png
22 | label_ext: .png
23 |
24 | # DAVIS 2016 Dataset: Densely Annotated VIdeo Segmentation
25 | # url: https://davischallenge.org/
26 | DAVIS2016:
27 | image_dir: JPEGImages/480p
28 | label_dir: Annotations/480p
29 | image_ext: .jpg
30 | label_ext: .png
31 | default_label_interval: 1 # every "default_label_interval"-th frame has a ground-truth label
32 | video_split:
33 | train: ['bear', 'bmx-bumps', 'boat', 'breakdance-flare',
34 | 'bus', 'car-turn', 'dance-jump', 'dog-agility',
35 | 'drift-turn', 'elephant', 'flamingo', 'hike',
36 | 'hockey', 'horsejump-low', 'kite-walk', 'lucia',
37 | 'mallard-fly', 'motocross-bumps', 'motorbike',
38 | 'paragliding', 'rhino', 'rollerblade', 'scooter-gray',
39 | 'soccerball', 'stroller', 'surf', 'swing', 'tennis',
40 | 'train']
41 | val: ['blackswan', 'bmx-trees', 'breakdance', 'camel',
42 | 'car-roundabout', 'car-shadow', 'cows', 'dance-twirl',
43 | 'dog', 'drift-chicane', 'drift-straight', 'goat',
44 | 'horsejump-high', 'kite-surf', 'libby', 'motocross-jump',
45 | 'paragliding-launch', 'parkour', 'scooter-black',
46 | 'soapbox']
47 |
48 | # FBMS: Freiburg-Berkeley Motion Segmentation Dataset
49 | # url: https://lmb.informatik.uni-freiburg.de/resources/datasets/
50 | FBMS:
51 | image_dir: JPEGImages
52 | label_dir: Annotations
53 | image_ext: .jpg
54 | label_ext: .png
55 | default_label_interval: 1
56 | video_split:
57 | train: ['bear01', 'bear02', 'cars2', 'cars3',
58 | 'cars6', 'cars7', 'cars8', 'cars9',
59 | 'cats02', 'cats04', 'cats05', 'cats07',
60 | 'ducks01', 'horses01', 'horses03', 'horses06',
61 | 'lion02', 'marple1', 'marple10', 'marple11',
62 | 'marple13', 'marple3', 'marple5', 'marple8',
63 | 'meerkats01', 'people04', 'people05', 'rabbits01',
64 | 'rabbits05']
65 | test: ['camel01', 'cars1', 'cars10', 'cars4',
66 | 'cars5', 'cats01', 'cats03', 'cats06',
67 | 'dogs01', 'dogs02', 'farm01', 'giraffes01',
68 | 'goats01', 'horses02', 'horses04', 'horses05',
69 | 'lion01', 'marple12', 'marple2', 'marple4',
70 | 'marple6', 'marple7', 'marple9', 'people03',
71 | 'people1', 'people2', 'rabbits02', 'rabbits03',
72 | 'rabbits04', 'tennis']
73 |
74 | # SegTrack v2 is a video segmentation dataset with full pixel-level annotations
75 | # on multiple objects at each frame within each video.
76 | # url: http://web.engr.oregonstate.edu/~lif/SegTrack2/dataset.html
77 | SegTrackv2:
78 | image_dir: JPEGImages
79 | image_ext: .png
80 | label_ext: .png
81 | default_label_interval: 1
82 | video_split:
83 | trainval: ['bird_of_paradise', 'birdfall', 'frog',
84 | 'monkey', 'parachute', 'soldier', 'worm']
85 |
86 | # TIP18: A Benchmark Dataset and Saliency-Guided Stacked Autoencoders for Video-Based Salient Object Detection
87 | # url: http://cvteam.net/projects/TIP18-VOS/VOS.html
88 | VOS:
89 | image_dir: JPEGImages
90 | label_dir: Annotations
91 | image_ext: .jpg
92 | label_ext: .png
93 | default_label_interval: 15
94 | video_split:
95 | train: ['1', '102', '103', '104', '107', '11', '111', '112', '114', '115',
96 | '117', '118', '119', '12', '123', '124', '125', '126', '127', '130',
97 | '131', '143', '145', '146', '147', '15', '156', '164', '17', '171',
98 | '176', '192', '197', '198', '199', '2', '20', '200', '201', '202',
99 | '205', '206', '207', '212', '215', '216', '217', '22', '220', '221',
100 | '222', '225', '226', '229', '23', '230', '231', '233', '235', '236',
101 | '25', '250', '251', '252', '255', '256', '257', '258', '259', '26',
102 | '261', '262', '263', '265', '267', '268', '27', '270', '271', '272',
103 | '273', '275', '276', '30', '32', '33', '34', '35', '38', '4', '40',
104 | '44', '45', '46', '48', '50', '51', '52', '53', '55', '6', '61', '64',
105 | '66', '67', '69', '70', '71', '73', '78', '80', '81', '83', '87', '88',
106 | '9', '90', '96', '99']
107 | val: ['10', '101', '105', '109', '113', '120', '13', '133', '14', '148', '158',
108 | '18', '180', '196', '203', '204', '208', '209', '213', '219', '223', '228',
109 | '24','260', '269', '28', '31', '37', '39', '5', '57', '62', '7', '72', '77',
110 | '84', '92','94', '95', '97']
111 | test: ['100', '106', '108', '110', '121', '132', '134', '16', '172', '189', '19',
112 | '194', '195', '21', '210', '211', '214', '224', '227', '232', '254', '264',
113 | '266', '274', '29', '3', '36', '42', '43', '47', '49', '58', '65', '68', '74',
114 | '76', '8', '85', '93', '98']
115 | easy: ['1', '10', '101', '11', '12', '13', '130', '131', '132', '133', '134',
116 | '14', '143', '15', '16', '17', '18', '19', '192', '194', '195', '196', '197',
117 | '198', '199', '2', '20', '200', '201', '202', '203', '204', '205', '206', '207',
118 | '208', '209', '21', '210', '211', '22', '23', '233', '24', '25', '254', '255',
119 | '256', '257', '258', '259', '26', '260', '261', '262', '263', '264', '265', '266',
120 | '267', '268', '269', '27', '270', '271', '272', '273', '274', '275', '276', '28',
121 | '29', '3', '30', '31', '32', '33', '34', '4', '42', '5', '50', '51', '6', '68',
122 | '7', '76', '78', '8', '88', '9', '90', '92', '94', '96', '98']
123 | normal: ['100', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111',
124 | '112', '113', '114', '115', '117', '118', '119', '120', '121', '123', '124', '125',
125 | '126', '127', '145', '146', '147', '148', '156', '158', '164', '171', '172', '176',
126 | '180', '189', '212', '213', '214', '215', '216', '217', '219', '220', '221', '222',
127 | '223', '224', '225', '226', '227', '228', '229', '230', '231', '232', '235', '236',
128 | '250', '251', '252', '35', '36', '37', '38', '39', '40', '43', '44', '45', '46', '47',
129 | '48', '49', '52', '53', '55', '57', '58', '61', '62', '64', '65', '66', '67', '69',
130 | '70', '71', '72', '73', '74', '77', '80', '81', '83', '84', '85', '87', '93', '95',
131 | '97', '99']
132 | easy-train: ['206', '90', '268', '30', '26', '201', '131', '271', '255', '276', '270', '15', '25',
133 | '1', '50', '20', '51', '88', '96', '130', '197', '27', '4', '205', '199', '78', '261',
134 | '2', '207', '198', '272', '192', '202', '258', '262', '32', '12', '6', '265', '263',
135 | '200', '11', '143', '267', '9', '22', '259', '275', '33', '17', '257', '34', '23', '233',
136 | '273', '256']
137 | easy-val: ['31', '94', '204', '92', '133', '196', '203', '209', '208', '5', '101', '24', '260',
138 | '18', '13', '269', '28', '10', '7', '14']
139 | easy-test: ['29', '42', '194', '211', '254', '21', '264', '274', '19', '134', '76', '8', '98',
140 | '210', '266', '16', '68', '132', '3', '195']
141 | normal-train: ['115', '35', '127', '217', '87', '55', '147', '126', '125', '114', '123', '111',
142 | '235', '250', '103', '71', '81', '212', '231', '38', '252', '222', '69', '145', '176',
143 | '221', '225', '44', '45', '171', '124', '118', '229', '66', '112', '156', '220', '64',
144 | '216', '107', '80', '73', '48', '251', '236', '61', '102', '40', '83', '46', '67', '99',
145 | '164', '119', '53', '146', '226', '215', '230', '117', '52', '104', '70']
146 | normal-val: ['148', '57', '77', '158', '228', '37', '39', '62', '105', '180', '97', '109', '219',
147 | '72', '84', '113', '120', '213', '95', '223']
148 | normal-test: ['189', '85', '108', '110', '224', '121', '172', '43', '232', '36', '93', '47', '74',
149 | '227', '214', '49', '65', '106', '100', '58']
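150 |
151 | # ----------------------------------------------------------------------
152 | # Template (illustration only, not a real dataset entry): a new dataset
153 | # added to this config is expected to follow the same schema as the
154 | # entries above. All folder names, extensions, and video lists below are
155 | # placeholders.
156 | #
157 | # MyDataset:
158 | #   image_dir: JPEGImages          # frame folder inside the dataset root
159 | #   label_dir: Annotations         # ground-truth mask folder
160 | #   image_ext: .jpg
161 | #   label_ext: .png
162 | #   default_label_interval: 1     # every n-th frame has a ground-truth label
163 | #   video_split:
164 | #     train: ['video_a', 'video_b']
165 | #     test: ['video_c']
166 | # ----------------------------------------------------------------------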
--------------------------------------------------------------------------------
/data/datasets/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/data/datasets/.gitkeep
--------------------------------------------------------------------------------
/docs/_config.yml:
--------------------------------------------------------------------------------
1 | theme: jekyll-theme-cayman
--------------------------------------------------------------------------------
/docs/bibtex.txt:
--------------------------------------------------------------------------------
1 | @inproceedings{yan2019semi,
2 | title={Semi-Supervised Video Salient Object Detection Using Pseudo-Labels},
3 | author={Yan, Pengxiang and Li, Guanbin and Xie, Yuan and Li, Zhen and Wang, Chuan and Chen, Tianshui and Lin, Liang},
4 | booktitle={Proceedings of the IEEE International Conference on Computer Vision},
5 | pages={7284--7293},
6 | year={2019}
7 | }
8 |
--------------------------------------------------------------------------------
/docs/comp_video_sota.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/comp_video_sota.png
--------------------------------------------------------------------------------
/docs/index.md:
--------------------------------------------------------------------------------
1 | ## Abstract
2 |
3 | Deep learning-based video salient object detection has recently achieved great success with its performance significantly outperforming any other unsupervised methods. However, existing data-driven approaches heavily rely on a large quantity of pixel-wise annotated video frames to deliver such promising results. In this paper, we address the semi-supervised video salient object detection task using pseudo-labels. Specifically, we present an effective video saliency detector that consists of a spatial refinement network and a spatiotemporal module. Based on the same refinement network and motion information in terms of optical flow, we further propose a novel method for generating pixel-level pseudo-labels from sparsely annotated frames. By utilizing the generated pseudo-labels together with a part of manual annotations, our video saliency detector learns spatial and temporal cues for both contrast inference and coherence enhancement, thus producing accurate saliency maps. Experimental results demonstrate that our proposed semi-supervised method even greatly outperforms all the state-of-the-art fully supervised methods across three public benchmarks of VOS, DAVIS, and FBMS.
4 |
5 |
6 |
7 | ## Paper
8 |
9 | Pengxiang Yan, Guanbin Li, Yuan Xie, Zhen Li, Chuan Wang, Tianshui Chen, Liang Lin, **Semi-Supervised Video Salient Object Detection Using Pseudo-Labels**, the IEEE International Conference on Computer Vision (ICCV), 2019, pp. 7284-7293. [[Arxiv](https://arxiv.org/abs/1908.04051)] [[Code](https://github.com/Kinpzz/RCRNet-Pytorch)] [[BibTex](https://github.com/Kinpzz/RCRNet-Pytorch/raw/master/docs/bibtex.txt)]
10 |
11 | ## Motivation
12 |
13 | * Existing data-driven approaches of video salient object detection heavily rely on a large quantity of densely annotated video frames to deliver promising results.
14 |
15 | * Consecutive video frames differ only slightly, yet densely annotating them takes a great deal of effort, and labeling consistency is hard to guarantee.
16 |
17 | ## Contributions
18 |
19 | - We propose a refinement network (RCRNet) equipped with a nonlocally enhanced recurrent (NER) module for spatiotemporal coherence modeling.
20 |
21 | - We propose a flow-guided pseudo-label generator (FGPLG) to generate pseudo-labels for interval frames based on sparse annotations.
22 |
23 | - As shown in Figure 1, our model can produce reasonable and consistent pseudo-labels, which can even improve the boundary details (Example a) and overcome the labeling ambiguity between frames (Example b).
24 |
25 |
26 |
27 | - Experimental results show that utilizing the joint supervision of pseudo-labels and sparse annotations can further improve the model performance.
28 |
29 |
30 |
31 | ## Architecture
32 |
33 | ### RCRNet
34 |
35 |
36 |
37 | ### RCRNet+NER
38 |
39 | 
40 |
41 | ### Flow-Guided Pseudo-Label Generator
42 |
43 | 
44 |
45 | ## Results
46 |
47 | ### Quantitative Comparison
48 | 
49 |
50 | ## Downloads
51 |
52 | * Pre-computed saliency maps for the validation set of DAVIS2016, the test set of FBMS, and the test set of VOS. [[Google Drive](https://drive.google.com/open?id=1feY3GdNBS-LUBt0UDWwpA3fl9yHI4Vxr)] [[Baidu Pan](https://pan.baidu.com/s/1oXBr9qxyF-8vvilvV5kcPg) (passwd: u079)]
53 |
54 | ## Q&A
55 |
56 | Q1: What is the difference between the semi-supervised strategies mentioned in semi-supervised video salient object detection (VSOD) and semi-supervised video object segmentation (VOS)?
57 |
58 | A1: VOS can be categorized into **semi-supervised** and **unsupervised** methods with respect to different **testing** schemes. Semi-supervised VOS provides the annotation of the first frame at test time. Video salient object detection (VSOD) is more similar to unsupervised VOS, as neither resorts to labeled frames during testing. Our proposed method uses only a part of the labeled frames for **training**, which is why we call it a semi-supervised VSOD method.
59 |
60 | Q2: Are all the V+D methods in Table 1 fully-supervised? It might lack a comparison with other semi-supervised video salient object detection (VSOD) methods and mask-propagation-based segmentation methods.
61 |
62 | A2: As far as we know, when referring to the training scheme, we are the first to adopt a semi-supervised strategy for VSOD. Thus, all the V+D methods in Table 1 (except ours) are trained under full supervision. Since most mask propagation based methods are designed for semi-supervised VOS, we only compare with unsupervised VOS methods in Section 5 of the paper.
63 |
64 | Q3: Is PDB fully-supervised or unsupervised?
65 |
66 | A3: PDB is fully-supervised during training but unsupervised during testing. Specifically, PDB is an algorithm for both video salient object detection and unsupervised video object segmentation, and it is trained under full supervision.
67 |
68 | ## Contact
69 |
70 | If you have any questions, please send us an email at yanpx(AT)mail2.sysu.edu.cn
--------------------------------------------------------------------------------
/docs/motorcross-jump.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/motorcross-jump.gif
--------------------------------------------------------------------------------
/docs/pseudo_label_example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/pseudo_label_example.png
--------------------------------------------------------------------------------
/docs/pseudo_label_generator.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/pseudo_label_generator.png
--------------------------------------------------------------------------------
/docs/static_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/static_model.png
--------------------------------------------------------------------------------
/docs/training_instruction.md:
--------------------------------------------------------------------------------
1 | ## Training Instruction
2 |
3 | ### Training RCRNet+NER
4 | If you want to train the proposed RCRNet from scratch, please refer to our paper and the following instruction carefully.
5 |
6 | The proposed RCRNet is built upon a ResNet-50 pretrained on ImageNet.
7 |
8 |
9 |
10 | **First**, we use two image saliency datasets, i.e., [MSRA-B](https://mmcheng.net/msra10k/) and [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html), to pretrain the RCRNet (Figure 2), which contains a spatial feature extractor and a pixel-wise classifier. Here, we provide the weights of RCRNet pretrained on the image saliency datasets at [Google Drive](https://drive.google.com/open?id=1S7nao9WEhIiTmTC-E0nujMxm5Emypti9) or [Baidu Pan](https://pan.baidu.com/s/196cUbTInWJKd8FmiP9Jv_A) (passwd: j839). For simplicity, we do not provide the training code for this step; if you want to reproduce it, you can implement your own training code (a minimal pretraining sketch is appended at the end of this document).
11 |
12 | 
13 |
14 | **Second**, we use the RCRNet pretrained on image saliency datasets as the backbone. Then we combine the training sets of three video saliency datasets, namely VOS, DAVIS, and FBMS, to train the full video model, i.e., RCRNet equipped with the NER module (Figure 3). You can run the following commands to train RCRNet+NER.
15 | ```
16 | $ CUDA_VISIBLE_DEVICES=0 python train.py \
17 | --data data/datasets \
18 | --checkpoint models/image_pretrained_model.pth
19 | ```
20 |
21 | ### Using pseudo-labels for training
22 |
23 | 
24 |
25 | As for the second step, if you want to train RCRNet+NER using generated pseudo-labels for joint supervision, you can use our proposed flow-guided pseudo-label generator (FGPLG, Figure 4) to generate the pseudo-labels from a part of the ground truth images.
26 |
27 | Note that the FGPLG requires FlowNet 2.0 for optical flow estimation. Please install the PyTorch implementation of FlowNet 2.0 using the following commands.
28 | ```
29 | # Install FlowNet 2.0 (implemented by NVIDIA)
30 | $ cd flownet2
31 | $ bash install.sh
32 | ```
33 |
34 | #### Generating pseudo-labels using FGPLG
35 | We provide the weights of the FGPLG trained under the supervision of 20% of the ground truth images at [Baidu Pan](https://pan.baidu.com/s/1dw8O2Ua5pKmOKYVgKRyADQ) (passwd: hbsu). You can generate the pseudo-labels by
36 | ```
37 | $ CUDA_VISIBLE_DEVICES=0 python generate_pseudo_labels.py \
38 | --data data/datasets \
39 | --checkpoint models/pseudo_label_generator_5.pth \
40 | --pseudo-label-folder data/pseudo-labels \
41 | --label_interval 5 \
42 | --frame_between_label_num 1
43 | ```
44 |
45 | Then you can train the video model under the joint supervision of pseudo-labels.
46 | ```
47 | $ CUDA_VISIBLE_DEVICES=0 python train.py \
48 | --data data/datasets \
49 | --checkpoint models/image_pretrained_model.pth \
50 | --pseudo-label-folder data/pseudo-labels/1_5
51 | ```
52 |
53 | #### (Optional) Training FGPLG
54 | You can also train the FGPLG using other proportions of ground truth images by
55 |
56 | (Note that you need to download the pretrained model of [Flownet2](https://github.com/NVIDIA/flownet2-pytorch#converted-caffe-pre-trained-models) [620MB] first.)
57 | ```
58 | # set l
59 | $ CUDA_VISIBLE_DEVICES=0 python train_fgplg.py \
60 | --data data/datasets \
61 | --label_interval l \
62 | --checkpoint models/image_pretrained_model.pth \
63 | --flownet-checkpoint models/FlowNet2_checkpoint.pth.tar
64 | ```
65 |
66 | Then you can use the trained FGPLG to generate pseudo-labels based on different numbers of ground truth images.
67 | ```
68 | # set l and m
69 | $ CUDA_VISIBLE_DEVICES=0 python generate_pseudo_labels.py \
70 | --data data/datasets \
71 | --checkpoint models/pseudo_label_generator_m.pth \
72 | --label_interval l \
73 | --frame_between_label_num m \
74 | --pseudo-label-folder data/pseudo-labels
75 | ```
76 | Finally, you can train the video model under the joint supervision of pseudo-labels.
77 | ```
78 | # set l and m
79 | $ CUDA_VISIBLE_DEVICES=0 python train.py \
80 | --data data/datasets \
81 | --checkpoint models/image_pretrained_model.pth \
82 | --pseudo-label-folder data/pseudo-labels/m_l
83 | ```
84 |
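85 | #### (For reference) A minimal image-pretraining sketch
86 |
87 | Since the training code for the first step (image pretraining) is not included in this repository, below is a minimal sketch of what such a pretraining loop can look like. It is only an illustration: the stand-in model is a plain torchvision ResNet-50 with a 1x1 pixel-wise classifier rather than the actual RCRNet in `libs/networks/rcrnet.py`, and the data loader for MSRA-B / HKU-IS is assumed to yield `(image, mask)` batches.
88 |
89 | ```
90 | import torch
91 | import torch.nn as nn
92 | import torch.nn.functional as F
93 | from torchvision import models
94 |
95 | class ToySaliencyNet(nn.Module):
96 |     """Stand-in for RCRNet: a spatial feature extractor + a pixel-wise classifier."""
97 |     def __init__(self):
98 |         super(ToySaliencyNet, self).__init__()
99 |         resnet = models.resnet50(pretrained=True)            # ImageNet-pretrained backbone
100 |         self.backbone = nn.Sequential(*list(resnet.children())[:-2])
101 |         self.classifier = nn.Conv2d(2048, 1, kernel_size=1)  # 1x1 pixel-wise classifier
102 |
103 |     def forward(self, x):
104 |         logits = self.classifier(self.backbone(x))           # 1/32-resolution logits
105 |         return F.interpolate(logits, size=x.shape[2:], mode='bilinear', align_corners=False)
106 |
107 | def pretrain(model, loader, epochs=10, lr=1e-4):
108 |     """loader yields (image, mask) batches; mask is a float tensor with values in {0, 1}."""
109 |     optimizer = torch.optim.Adam(model.parameters(), lr=lr)
110 |     criterion = nn.BCEWithLogitsLoss()
111 |     model.train()
112 |     for _ in range(epochs):
113 |         for image, mask in loader:
114 |             optimizer.zero_grad()
115 |             loss = criterion(model(image), mask)
116 |             loss.backward()
117 |             optimizer.step()
118 | ```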
--------------------------------------------------------------------------------
/docs/video_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/docs/video_model.png
--------------------------------------------------------------------------------
/flownet2/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | .torch
3 | _ext
4 | *.o
5 | work
6 | work/*
7 | _ext/
--------------------------------------------------------------------------------
/flownet2/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:8.0-cudnn6-devel-ubuntu16.04
2 |
3 | RUN apt-get update && apt-get install -y rsync htop git openssh-server python-pip
4 |
5 | RUN pip install --upgrade pip
6 |
7 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl
8 | RUN pip install torchvision cffi tensorboardX
9 |
10 | RUN pip install tqdm scipy scikit-image colorama==0.3.7
11 | RUN pip install setproctitle pytz ipython
--------------------------------------------------------------------------------
/flownet2/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright 2017 NVIDIA CORPORATION
2 |
3 | Licensed under the Apache License, Version 2.0 (the "License");
4 | you may not use this file except in compliance with the License.
5 | You may obtain a copy of the License at
6 |
7 | http://www.apache.org/licenses/LICENSE-2.0
8 |
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
--------------------------------------------------------------------------------
/flownet2/README.md:
--------------------------------------------------------------------------------
1 | # flownet2-pytorch
2 |
3 | Pytorch implementation of [FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks](https://arxiv.org/abs/1612.01925).
4 |
5 | Multiple GPU training is supported, and the code provides examples for training or inference on [MPI-Sintel](http://sintel.is.tue.mpg.de/) clean and final datasets. The same commands can be used for training or inference with other datasets. See below for more detail.
6 |
7 | Inference using fp16 (half-precision) is also supported.
8 |
9 | For more help, type
10 |
11 | python main.py --help
12 |
13 | ## Network architectures
14 | Below are the different flownet neural network architectures that are provided.
15 | A batchnorm version for each network is available.
16 |
17 | - **FlowNet2S**
18 | - **FlowNet2C**
19 | - **FlowNet2CS**
20 | - **FlowNet2CSS**
21 | - **FlowNet2SD**
22 | - **FlowNet2**
23 |
24 | ## Custom layers
25 |
26 | `FlowNet2` or `FlowNet2C*` architectures rely on the custom layers `Resample2d` or `Correlation`.
27 | A pytorch implementation of these layers with cuda kernels is available at [./networks](./networks).
28 | Note: Currently, half precision kernels are not available for these layers.
29 |
30 | ## Data Loaders
31 |
32 | Dataloaders for FlyingChairs, FlyingThings, ChairsSDHom and ImagesFromFolder are available in [datasets.py](./datasets.py).
33 |
34 | ## Loss Functions
35 |
36 | L1 and L2 losses with multi-scale support are available in [losses.py](./losses.py).
37 |
38 | ## Installation
39 |
40 | # get flownet2-pytorch source
41 | git clone https://github.com/NVIDIA/flownet2-pytorch.git
42 | cd flownet2-pytorch
43 |
44 | # install custom layers
45 | bash install.sh
46 |
47 | ### Python requirements
48 | Currently, the code supports python 3
49 | * numpy
50 | * PyTorch ( == 0.4.1, for <= 0.4.0 see branch [python36-PyTorch0.4](https://github.com/NVIDIA/flownet2-pytorch/tree/python36-PyTorch0.4))
51 | * scipy
52 | * scikit-image
53 | * tensorboardX
54 | * colorama, tqdm, setproctitle
55 |
56 | ## Converted Caffe Pre-trained Models
57 | We've included caffe pre-trained models. Should you use these pre-trained weights, please adhere to the [license agreements](https://drive.google.com/file/d/1TVv0BnNFh3rpHZvD-easMb9jYrPE2Eqd/view?usp=sharing).
58 |
59 | * [FlowNet2](https://drive.google.com/file/d/1hF8vS6YeHkx3j2pfCeQqqZGwA_PJq_Da/view?usp=sharing)[620MB]
60 | * [FlowNet2-C](https://drive.google.com/file/d/1BFT6b7KgKJC8rA59RmOVAXRM_S7aSfKE/view?usp=sharing)[149MB]
61 | * [FlowNet2-CS](https://drive.google.com/file/d/1iBJ1_o7PloaINpa8m7u_7TsLCX0Dt_jS/view?usp=sharing)[297MB]
62 | * [FlowNet2-CSS](https://drive.google.com/file/d/157zuzVf4YMN6ABAQgZc8rRmR5cgWzSu8/view?usp=sharing)[445MB]
63 | * [FlowNet2-CSS-ft-sd](https://drive.google.com/file/d/1R5xafCIzJCXc8ia4TGfC65irmTNiMg6u/view?usp=sharing)[445MB]
64 | * [FlowNet2-S](https://drive.google.com/file/d/1V61dZjFomwlynwlYklJHC-TLfdFom3Lg/view?usp=sharing)[148MB]
65 | * [FlowNet2-SD](https://drive.google.com/file/d/1QW03eyYG_vD-dT-Mx4wopYvtPu_msTKn/view?usp=sharing)[173MB]
66 |
67 | ## Inference
68 | # Example on MPISintel Clean
69 | python main.py --inference --model FlowNet2 --save_flow --inference_dataset MpiSintelClean \
70 | --inference_dataset_root /path/to/mpi-sintel/clean/dataset \
71 | --resume /path/to/checkpoints
72 |
73 | ## Training and validation
74 |
75 | # Example on MPISintel Final and Clean, with L1Loss on FlowNet2 model
76 | python main.py --batch_size 8 --model FlowNet2 --loss=L1Loss --optimizer=Adam --optimizer_lr=1e-4 \
77 | --training_dataset MpiSintelFinal --training_dataset_root /path/to/mpi-sintel/final/dataset \
78 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset
79 |
80 | # Example on MPISintel Final and Clean, with MultiScale loss on FlowNet2C model
81 | python main.py --batch_size 8 --model FlowNet2C --optimizer=Adam --optimizer_lr=1e-4 --loss=MultiScale --loss_norm=L1 \
82 | --loss_numScales=5 --loss_startScale=4 --optimizer_lr=1e-4 --crop_size 384 512 \
83 | --training_dataset FlyingChairs --training_dataset_root /path/to/flying-chairs/dataset \
84 | --validation_dataset MpiSintelClean --validation_dataset_root /path/to/mpi-sintel/clean/dataset
85 |
86 | ## Results on MPI-Sintel
87 | [](https://www.youtube.com/watch?v=HtBmabY8aeU "Predicted flows on MPI-Sintel")
88 |
89 | ## Reference
90 | If you find this implementation useful in your work, please acknowledge it appropriately and cite the paper:
91 | ````
92 | @InProceedings{IMKDB17,
93 | author = "E. Ilg and N. Mayer and T. Saikia and M. Keuper and A. Dosovitskiy and T. Brox",
94 | title = "FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks",
95 | booktitle = "IEEE Conference on Computer Vision and Pattern Recognition (CVPR)",
96 | month = "Jul",
97 | year = "2017",
98 | url = "http://lmb.informatik.uni-freiburg.de//Publications/2017/IMKDB17"
99 | }
100 | ````
101 | ```
102 | @misc{flownet2-pytorch,
103 | author = {Fitsum Reda and Robert Pottorff and Jon Barker and Bryan Catanzaro},
104 | title = {flownet2-pytorch: Pytorch implementation of FlowNet 2.0: Evolution of Optical Flow Estimation with Deep Networks},
105 | year = {2017},
106 | publisher = {GitHub},
107 | journal = {GitHub repository},
108 | howpublished = {\url{https://github.com/NVIDIA/flownet2-pytorch}}
109 | }
110 | ```
111 | ## Related Optical Flow Work from Nvidia
112 | Code (in Caffe and Pytorch): [PWC-Net](https://github.com/NVlabs/PWC-Net)
113 | Paper : [PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume](https://arxiv.org/abs/1709.02371).
114 |
115 | ## Acknowledgments
116 | Parts of this code were derived, as noted in the code, from [ClementPinard/FlowNetPytorch](https://github.com/ClementPinard/FlowNetPytorch).
117 |
--------------------------------------------------------------------------------
/flownet2/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/__init__.py
--------------------------------------------------------------------------------
/flownet2/convert.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python2.7
2 |
3 | import caffe
4 | from caffe.proto import caffe_pb2
5 | import sys, os
6 |
7 | import torch
8 | import torch.nn as nn
9 |
10 | import argparse, tempfile
11 | import numpy as np
12 |
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('caffe_model', help='input model in hdf5 or caffemodel format')
15 | parser.add_argument('prototxt_template',help='prototxt template')
16 | parser.add_argument('flownet2_pytorch', help='path to flownet2-pytorch')
17 |
18 | args = parser.parse_args()
19 |
20 | args.rgb_max = 255
21 | args.fp16 = False
22 | args.grads = {}
23 |
24 | # load models
25 | sys.path.append(args.flownet2_pytorch)
26 |
27 | import models
28 | from utils.param_utils import *
29 |
30 | width = 256
31 | height = 256
32 | keys = {'TARGET_WIDTH': width,
33 | 'TARGET_HEIGHT': height,
34 | 'ADAPTED_WIDTH':width,
35 | 'ADAPTED_HEIGHT':height,
36 | 'SCALE_WIDTH':1.,
37 | 'SCALE_HEIGHT':1.,}
38 |
39 | template = '\n'.join(np.loadtxt(args.prototxt_template, dtype=str, delimiter='\n'))
40 | for k in keys:
41 | template = template.replace('$%s$'%(k),str(keys[k]))
42 |
43 | prototxt = tempfile.NamedTemporaryFile(mode='w', delete=True)
44 | prototxt.write(template)
45 | prototxt.flush()
46 |
47 | net = caffe.Net(prototxt.name, args.caffe_model, caffe.TEST)
48 |
49 | weights = {}
50 | biases = {}
51 |
52 | for k, v in list(net.params.items()):
53 | weights[k] = np.array(v[0].data).reshape(v[0].data.shape)
54 | biases[k] = np.array(v[1].data).reshape(v[1].data.shape)
55 | print((k, weights[k].shape, biases[k].shape))
56 |
57 | if 'FlowNet2/' in args.caffe_model:
58 | model = models.FlowNet2(args)
59 |
60 | parse_flownetc(model.flownetc.modules(), weights, biases)
61 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
62 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
63 | parse_flownetsd(model.flownets_d.modules(), weights, biases, param_prefix='netsd_')
64 | parse_flownetfusion(model.flownetfusion.modules(), weights, biases, param_prefix='fuse_')
65 |
66 | state = {'epoch': 0,
67 | 'state_dict': model.state_dict(),
68 | 'best_EPE': 1e10}
69 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2_checkpoint.pth.tar'))
70 |
71 | elif 'FlowNet2-C/' in args.caffe_model:
72 | model = models.FlowNet2C(args)
73 |
74 | parse_flownetc(model.modules(), weights, biases)
75 | state = {'epoch': 0,
76 | 'state_dict': model.state_dict(),
77 | 'best_EPE': 1e10}
78 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-C_checkpoint.pth.tar'))
79 |
80 | elif 'FlowNet2-CS/' in args.caffe_model:
81 | model = models.FlowNet2CS(args)
82 |
83 | parse_flownetc(model.flownetc.modules(), weights, biases)
84 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
85 |
86 | state = {'epoch': 0,
87 | 'state_dict': model.state_dict(),
88 | 'best_EPE': 1e10}
89 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CS_checkpoint.pth.tar'))
90 |
91 | elif 'FlowNet2-CSS/' in args.caffe_model:
92 | model = models.FlowNet2CSS(args)
93 |
94 | parse_flownetc(model.flownetc.modules(), weights, biases)
95 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
96 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
97 |
98 | state = {'epoch': 0,
99 | 'state_dict': model.state_dict(),
100 | 'best_EPE': 1e10}
101 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS_checkpoint.pth.tar'))
102 |
103 | elif 'FlowNet2-CSS-ft-sd/' in args.caffe_model:
104 | model = models.FlowNet2CSS(args)
105 |
106 | parse_flownetc(model.flownetc.modules(), weights, biases)
107 | parse_flownets(model.flownets_1.modules(), weights, biases, param_prefix='net2_')
108 | parse_flownets(model.flownets_2.modules(), weights, biases, param_prefix='net3_')
109 |
110 | state = {'epoch': 0,
111 | 'state_dict': model.state_dict(),
112 | 'best_EPE': 1e10}
113 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-CSS-ft-sd_checkpoint.pth.tar'))
114 |
115 | elif 'FlowNet2-S/' in args.caffe_model:
116 | model = models.FlowNet2S(args)
117 |
118 | parse_flownetsonly(model.modules(), weights, biases, param_prefix='')
119 | state = {'epoch': 0,
120 | 'state_dict': model.state_dict(),
121 | 'best_EPE': 1e10}
122 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-S_checkpoint.pth.tar'))
123 |
124 | elif 'FlowNet2-SD/' in args.caffe_model:
125 | model = models.FlowNet2SD(args)
126 |
127 | parse_flownetsd(model.modules(), weights, biases, param_prefix='')
128 |
129 | state = {'epoch': 0,
130 | 'state_dict': model.state_dict(),
131 | 'best_EPE': 1e10}
132 | torch.save(state, os.path.join(args.flownet2_pytorch, 'FlowNet2-SD_checkpoint.pth.tar'))
133 |
134 | else:
135 | print(('model type could not be determined from input caffe model %s'%(args.caffe_model)))
136 | quit()
137 | print(("done converting ", args.caffe_model))
--------------------------------------------------------------------------------
/flownet2/datasets.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.utils.data as data
3 |
4 | import os, math, random
5 | from os.path import *
6 | import numpy as np
7 |
8 | from glob import glob
9 | import utils.frame_utils as frame_utils
10 |
11 | from scipy.misc import imread, imresize
12 |
13 | class StaticRandomCrop(object):
14 | def __init__(self, image_size, crop_size):
15 | self.th, self.tw = crop_size
16 | h, w = image_size
17 | self.h1 = random.randint(0, h - self.th)
18 | self.w1 = random.randint(0, w - self.tw)
19 |
20 | def __call__(self, img):
21 | return img[self.h1:(self.h1+self.th), self.w1:(self.w1+self.tw),:]
22 |
23 | class StaticCenterCrop(object):
24 | def __init__(self, image_size, crop_size):
25 | self.th, self.tw = crop_size
26 | self.h, self.w = image_size
27 | def __call__(self, img):
28 | return img[(self.h-self.th)//2:(self.h+self.th)//2, (self.w-self.tw)//2:(self.w+self.tw)//2,:]
29 |
30 | class MpiSintel(data.Dataset):
31 | def __init__(self, args, is_cropped = False, root = '', dstype = 'clean', replicates = 1):
32 | self.args = args
33 | self.is_cropped = is_cropped
34 | self.crop_size = args.crop_size
35 | self.render_size = args.inference_size
36 | self.replicates = replicates
37 |
38 | flow_root = join(root, 'flow')
39 | image_root = join(root, dstype)
40 |
41 | file_list = sorted(glob(join(flow_root, '*/*.flo')))
42 |
43 | self.flow_list = []
44 | self.image_list = []
45 |
46 | for file in file_list:
47 | if 'test' in file:
48 | # print file
49 | continue
50 |
51 | fbase = file[len(flow_root)+1:]
52 | fprefix = fbase[:-8]
53 | fnum = int(fbase[-8:-4])
54 |
55 | img1 = join(image_root, fprefix + "%04d"%(fnum+0) + '.png')
56 | img2 = join(image_root, fprefix + "%04d"%(fnum+1) + '.png')
57 |
58 | if not isfile(img1) or not isfile(img2) or not isfile(file):
59 | continue
60 |
61 | self.image_list += [[img1, img2]]
62 | self.flow_list += [file]
63 |
64 | self.size = len(self.image_list)
65 |
66 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
67 |
68 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64):
69 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64
70 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64
71 |
72 | args.inference_size = self.render_size
73 |
74 | assert (len(self.image_list) == len(self.flow_list))
75 |
76 | def __getitem__(self, index):
77 |
78 | index = index % self.size
79 |
80 | img1 = frame_utils.read_gen(self.image_list[index][0])
81 | img2 = frame_utils.read_gen(self.image_list[index][1])
82 |
83 | flow = frame_utils.read_gen(self.flow_list[index])
84 |
85 | images = [img1, img2]
86 | image_size = img1.shape[:2]
87 |
88 | if self.is_cropped:
89 | cropper = StaticRandomCrop(image_size, self.crop_size)
90 | else:
91 | cropper = StaticCenterCrop(image_size, self.render_size)
92 | images = list(map(cropper, images))
93 | flow = cropper(flow)
94 |
95 | images = np.array(images).transpose(3,0,1,2)
96 | flow = flow.transpose(2,0,1)
97 |
98 | images = torch.from_numpy(images.astype(np.float32))
99 | flow = torch.from_numpy(flow.astype(np.float32))
100 |
101 | return [images], [flow]
102 |
103 | def __len__(self):
104 | return self.size * self.replicates
105 |
106 | class MpiSintelClean(MpiSintel):
107 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
108 | super(MpiSintelClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'clean', replicates = replicates)
109 |
110 | class MpiSintelFinal(MpiSintel):
111 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
112 | super(MpiSintelFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'final', replicates = replicates)
113 |
114 | class FlyingChairs(data.Dataset):
115 | def __init__(self, args, is_cropped, root = '/path/to/FlyingChairs_release/data', replicates = 1):
116 | self.args = args
117 | self.is_cropped = is_cropped
118 | self.crop_size = args.crop_size
119 | self.render_size = args.inference_size
120 | self.replicates = replicates
121 |
122 | images = sorted( glob( join(root, '*.ppm') ) )
123 |
124 | self.flow_list = sorted( glob( join(root, '*.flo') ) )
125 |
126 | assert (len(images)//2 == len(self.flow_list))
127 |
128 | self.image_list = []
129 | for i in range(len(self.flow_list)):
130 | im1 = images[2*i]
131 | im2 = images[2*i + 1]
132 | self.image_list += [ [ im1, im2 ] ]
133 |
134 | assert len(self.image_list) == len(self.flow_list)
135 |
136 | self.size = len(self.image_list)
137 |
138 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
139 |
140 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64):
141 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64
142 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64
143 |
144 | args.inference_size = self.render_size
145 |
146 | def __getitem__(self, index):
147 | index = index % self.size
148 |
149 | img1 = frame_utils.read_gen(self.image_list[index][0])
150 | img2 = frame_utils.read_gen(self.image_list[index][1])
151 |
152 | flow = frame_utils.read_gen(self.flow_list[index])
153 |
154 | images = [img1, img2]
155 | image_size = img1.shape[:2]
156 | if self.is_cropped:
157 | cropper = StaticRandomCrop(image_size, self.crop_size)
158 | else:
159 | cropper = StaticCenterCrop(image_size, self.render_size)
160 | images = list(map(cropper, images))
161 | flow = cropper(flow)
162 |
163 |
164 | images = np.array(images).transpose(3,0,1,2)
165 | flow = flow.transpose(2,0,1)
166 |
167 | images = torch.from_numpy(images.astype(np.float32))
168 | flow = torch.from_numpy(flow.astype(np.float32))
169 |
170 | return [images], [flow]
171 |
172 | def __len__(self):
173 | return self.size * self.replicates
174 |
175 | class FlyingThings(data.Dataset):
176 | def __init__(self, args, is_cropped, root = '/path/to/flyingthings3d', dstype = 'frames_cleanpass', replicates = 1):
177 | self.args = args
178 | self.is_cropped = is_cropped
179 | self.crop_size = args.crop_size
180 | self.render_size = args.inference_size
181 | self.replicates = replicates
182 |
183 | image_dirs = sorted(glob(join(root, dstype, 'TRAIN/*/*')))
184 | image_dirs = sorted([join(f, 'left') for f in image_dirs] + [join(f, 'right') for f in image_dirs])
185 |
186 | flow_dirs = sorted(glob(join(root, 'optical_flow_flo_format/TRAIN/*/*')))
187 | flow_dirs = sorted([join(f, 'into_future/left') for f in flow_dirs] + [join(f, 'into_future/right') for f in flow_dirs])
188 |
189 | assert (len(image_dirs) == len(flow_dirs))
190 |
191 | self.image_list = []
192 | self.flow_list = []
193 |
194 | for idir, fdir in zip(image_dirs, flow_dirs):
195 | images = sorted( glob(join(idir, '*.png')) )
196 | flows = sorted( glob(join(fdir, '*.flo')) )
197 | for i in range(len(flows)):
198 | self.image_list += [ [ images[i], images[i+1] ] ]
199 | self.flow_list += [flows[i]]
200 |
201 | assert len(self.image_list) == len(self.flow_list)
202 |
203 | self.size = len(self.image_list)
204 |
205 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
206 |
207 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64):
208 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64
209 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64
210 |
211 | args.inference_size = self.render_size
212 |
213 | def __getitem__(self, index):
214 | index = index % self.size
215 |
216 | img1 = frame_utils.read_gen(self.image_list[index][0])
217 | img2 = frame_utils.read_gen(self.image_list[index][1])
218 |
219 | flow = frame_utils.read_gen(self.flow_list[index])
220 |
221 | images = [img1, img2]
222 | image_size = img1.shape[:2]
223 | if self.is_cropped:
224 | cropper = StaticRandomCrop(image_size, self.crop_size)
225 | else:
226 | cropper = StaticCenterCrop(image_size, self.render_size)
227 | images = list(map(cropper, images))
228 | flow = cropper(flow)
229 |
230 |
231 | images = np.array(images).transpose(3,0,1,2)
232 | flow = flow.transpose(2,0,1)
233 |
234 | images = torch.from_numpy(images.astype(np.float32))
235 | flow = torch.from_numpy(flow.astype(np.float32))
236 |
237 | return [images], [flow]
238 |
239 | def __len__(self):
240 | return self.size * self.replicates
241 |
242 | class FlyingThingsClean(FlyingThings):
243 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
244 | super(FlyingThingsClean, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_cleanpass', replicates = replicates)
245 |
246 | class FlyingThingsFinal(FlyingThings):
247 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
248 | super(FlyingThingsFinal, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'frames_finalpass', replicates = replicates)
249 |
250 | class ChairsSDHom(data.Dataset):
251 | def __init__(self, args, is_cropped, root = '/path/to/chairssdhom/data', dstype = 'train', replicates = 1):
252 | self.args = args
253 | self.is_cropped = is_cropped
254 | self.crop_size = args.crop_size
255 | self.render_size = args.inference_size
256 | self.replicates = replicates
257 |
258 | image1 = sorted( glob( join(root, dstype, 't0/*.png') ) )
259 | image2 = sorted( glob( join(root, dstype, 't1/*.png') ) )
260 | self.flow_list = sorted( glob( join(root, dstype, 'flow/*.flo') ) )
261 |
262 | assert (len(image1) == len(self.flow_list))
263 |
264 | self.image_list = []
265 | for i in range(len(self.flow_list)):
266 | im1 = image1[i]
267 | im2 = image2[i]
268 | self.image_list += [ [ im1, im2 ] ]
269 |
270 | assert len(self.image_list) == len(self.flow_list)
271 |
272 | self.size = len(self.image_list)
273 |
274 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
275 |
276 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64):
277 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64
278 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64
279 |
280 | args.inference_size = self.render_size
281 |
282 | def __getitem__(self, index):
283 | index = index % self.size
284 |
285 | img1 = frame_utils.read_gen(self.image_list[index][0])
286 | img2 = frame_utils.read_gen(self.image_list[index][1])
287 |
288 | flow = frame_utils.read_gen(self.flow_list[index])
289 | flow = flow[::-1,:,:]
290 |
291 | images = [img1, img2]
292 | image_size = img1.shape[:2]
293 | if self.is_cropped:
294 | cropper = StaticRandomCrop(image_size, self.crop_size)
295 | else:
296 | cropper = StaticCenterCrop(image_size, self.render_size)
297 | images = list(map(cropper, images))
298 | flow = cropper(flow)
299 |
300 |
301 | images = np.array(images).transpose(3,0,1,2)
302 | flow = flow.transpose(2,0,1)
303 |
304 | images = torch.from_numpy(images.astype(np.float32))
305 | flow = torch.from_numpy(flow.astype(np.float32))
306 |
307 | return [images], [flow]
308 |
309 | def __len__(self):
310 | return self.size * self.replicates
311 |
312 | class ChairsSDHomTrain(ChairsSDHom):
313 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
314 | super(ChairsSDHomTrain, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'train', replicates = replicates)
315 |
316 | class ChairsSDHomTest(ChairsSDHom):
317 | def __init__(self, args, is_cropped = False, root = '', replicates = 1):
318 | super(ChairsSDHomTest, self).__init__(args, is_cropped = is_cropped, root = root, dstype = 'test', replicates = replicates)
319 |
320 | class ImagesFromFolder(data.Dataset):
321 | def __init__(self, args, is_cropped, root = '/path/to/frames/only/folder', iext = 'png', replicates = 1):
322 | self.args = args
323 | self.is_cropped = is_cropped
324 | self.crop_size = args.crop_size
325 | self.render_size = args.inference_size
326 | self.replicates = replicates
327 |
328 | images = sorted( glob( join(root, '*.' + iext) ) )
329 | self.image_list = []
330 | for i in range(len(images)-1):
331 | im1 = images[i]
332 | im2 = images[i+1]
333 | self.image_list += [ [ im1, im2 ] ]
334 |
335 | self.size = len(self.image_list)
336 |
337 | self.frame_size = frame_utils.read_gen(self.image_list[0][0]).shape
338 |
339 | if (self.render_size[0] < 0) or (self.render_size[1] < 0) or (self.frame_size[0]%64) or (self.frame_size[1]%64):
340 | self.render_size[0] = ( (self.frame_size[0])//64 ) * 64
341 | self.render_size[1] = ( (self.frame_size[1])//64 ) * 64
342 |
343 | args.inference_size = self.render_size
344 |
345 | def __getitem__(self, index):
346 | index = index % self.size
347 |
348 | img1 = frame_utils.read_gen(self.image_list[index][0])
349 | img2 = frame_utils.read_gen(self.image_list[index][1])
350 |
351 | images = [img1, img2]
352 | image_size = img1.shape[:2]
353 | if self.is_cropped:
354 | cropper = StaticRandomCrop(image_size, self.crop_size)
355 | else:
356 | cropper = StaticCenterCrop(image_size, self.render_size)
357 | images = list(map(cropper, images))
358 |
359 | images = np.array(images).transpose(3,0,1,2)
360 | images = torch.from_numpy(images.astype(np.float32))
361 |
362 | return [images], [torch.zeros(images.size()[0:1] + (2,) + images.size()[-2:])]
363 |
364 | def __len__(self):
365 | return self.size * self.replicates
366 |
367 | '''
368 | import argparse
369 | import sys, os
370 | import importlib
371 | from scipy.misc import imsave
372 | import numpy as np
373 |
374 | import datasets
375 | reload(datasets)
376 |
377 | parser = argparse.ArgumentParser()
378 | args = parser.parse_args()
379 | args.inference_size = [1080, 1920]
380 | args.crop_size = [384, 512]
381 | args.effective_batch_size = 1
382 |
383 | index = 500
384 | v_dataset = datasets.MpiSintelClean(args, True, root='../MPI-Sintel/flow/training')
385 | a, b = v_dataset[index]
386 | im1 = a[0].numpy()[:,0,:,:].transpose(1,2,0)
387 | im2 = a[0].numpy()[:,1,:,:].transpose(1,2,0)
388 | imsave('./img1.png', im1)
389 | imsave('./img2.png', im2)
390 | flow_utils.writeFlow('./flow.flo', b[0].numpy().transpose(1,2,0))
391 |
392 | '''
393 |
--------------------------------------------------------------------------------
/flownet2/download_caffe_models.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | sudo rm -rf flownet2-docker
3 | sudo git clone https://github.com/lmb-freiburg/flownet2-docker
4 | cd flownet2-docker
5 |
6 | sudo sed -i '$ a RUN apt-get update && apt-get install -y python-pip \
7 | RUN pip install --upgrade pip \
8 | RUN pip install numpy -I \
9 | RUN pip install http://download.pytorch.org/whl/cu80/torch-0.2.0.post3-cp27-cp27mu-manylinux1_x86_64.whl \
10 | RUN pip install cffi ipython' Dockerfile
11 |
12 | sudo make
13 |
14 |
--------------------------------------------------------------------------------
/flownet2/image.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/image.png
--------------------------------------------------------------------------------
/flownet2/install.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | cd ./networks/correlation_package
3 | python setup.py install
4 | cd ../resample2d_package
5 | python setup.py install
6 | cd ../channelnorm_package
7 | python setup.py install
8 | cd ..
9 |
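After the three builds succeed, the compiled extensions should be importable by name; a quick sanity check in Python, assuming the interpreter running it is the same one used for `setup.py install`:

```python
# Import the custom CUDA ops built by install.sh; an ImportError means the
# corresponding package did not build or was installed into another environment.
import channelnorm_cuda
import correlation_cuda
import resample2d_cuda

print("flownet2 custom CUDA extensions are available")
```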
--------------------------------------------------------------------------------
/flownet2/launch_docker.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 .
3 | sudo nvidia-docker run --rm -ti --volume=$(pwd):/flownet2-pytorch:rw --workdir=/flownet2-pytorch --ipc=host $USER/pytorch:CUDA8-py27 /bin/bash
4 |
--------------------------------------------------------------------------------
/flownet2/losses.py:
--------------------------------------------------------------------------------
1 | '''
2 | Portions of this code copyright 2017, Clement Pinard
3 | '''
4 |
5 | # freda (todo) : adversarial loss
6 |
7 | import torch
8 | import torch.nn as nn
9 | import math
10 |
11 | def EPE(input_flow, target_flow):
12 | return torch.norm(target_flow-input_flow,p=2,dim=1).mean()
13 |
14 | class L1(nn.Module):
15 | def __init__(self):
16 | super(L1, self).__init__()
17 | def forward(self, output, target):
18 | lossvalue = torch.abs(output - target).mean()
19 | return lossvalue
20 |
21 | class L2(nn.Module):
22 | def __init__(self):
23 | super(L2, self).__init__()
24 | def forward(self, output, target):
25 | lossvalue = torch.norm(output-target,p=2,dim=1).mean()
26 | return lossvalue
27 |
28 | class L1Loss(nn.Module):
29 | def __init__(self, args):
30 | super(L1Loss, self).__init__()
31 | self.args = args
32 | self.loss = L1()
33 | self.loss_labels = ['L1', 'EPE']
34 |
35 | def forward(self, output, target):
36 | lossvalue = self.loss(output, target)
37 | epevalue = EPE(output, target)
38 | return [lossvalue, epevalue]
39 |
40 | class L2Loss(nn.Module):
41 | def __init__(self, args):
42 | super(L2Loss, self).__init__()
43 | self.args = args
44 | self.loss = L2()
45 | self.loss_labels = ['L2', 'EPE']
46 |
47 | def forward(self, output, target):
48 | lossvalue = self.loss(output, target)
49 | epevalue = EPE(output, target)
50 | return [lossvalue, epevalue]
51 |
52 | class MultiScale(nn.Module):
53 | def __init__(self, args, startScale = 4, numScales = 5, l_weight= 0.32, norm= 'L1'):
54 | super(MultiScale,self).__init__()
55 |
56 | self.startScale = startScale
57 | self.numScales = numScales
58 | self.loss_weights = torch.FloatTensor([(l_weight / 2 ** scale) for scale in range(self.numScales)])
59 | self.args = args
60 | self.l_type = norm
61 | self.div_flow = 0.05
62 | assert(len(self.loss_weights) == self.numScales)
63 |
64 | if self.l_type == 'L1':
65 | self.loss = L1()
66 | else:
67 | self.loss = L2()
68 |
69 | self.multiScales = [nn.AvgPool2d(self.startScale * (2**scale), self.startScale * (2**scale)) for scale in range(self.numScales)]
70 |         self.loss_labels = ['MultiScale-'+self.l_type, 'EPE']
71 |
72 | def forward(self, output, target):
73 | lossvalue = 0
74 | epevalue = 0
75 |
76 | if type(output) is tuple:
77 | target = self.div_flow * target
78 | for i, output_ in enumerate(output):
79 | target_ = self.multiScales[i](target)
80 | epevalue += self.loss_weights[i]*EPE(output_, target_)
81 | lossvalue += self.loss_weights[i]*self.loss(output_, target_)
82 | return [lossvalue, epevalue]
83 | else:
84 | epevalue += EPE(output, target)
85 | lossvalue += self.loss(output, target)
86 | return [lossvalue, epevalue]
87 |
88 |
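A minimal sketch of how these criteria are called, assuming this file is importable as `losses`; the random tensors mimic the five training-mode flow predictions (flow2..flow6) produced at `startScale`-strided resolutions.

```python
import torch
from argparse import Namespace

from losses import MultiScale  # assumed import path

criterion = MultiScale(Namespace(), startScale=4, numScales=5, norm='L1')
target = torch.randn(2, 2, 256, 320)                       # full-resolution flow
preds = tuple(torch.randn(2, 2, 256 // (4 * 2 ** s), 320 // (4 * 2 ** s))
              for s in range(5))                           # flow2 ... flow6
loss, epe = criterion(preds, target)                       # weighted loss and EPE
```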
--------------------------------------------------------------------------------
/flownet2/networks/FlowNetC.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .correlation_package.correlation import Correlation
9 |
10 | from .submodules import *
11 | 'Parameter count , 39,175,298 '
12 |
13 | class FlowNetC(nn.Module):
14 | def __init__(self,args, batchNorm=True, div_flow = 20):
15 | super(FlowNetC,self).__init__()
16 |
17 | self.batchNorm = batchNorm
18 | self.div_flow = div_flow
19 |
20 | self.conv1 = conv(self.batchNorm, 3, 64, kernel_size=7, stride=2)
21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
23 | self.conv_redir = conv(self.batchNorm, 256, 32, kernel_size=1, stride=1)
24 |
25 | if args.fp16:
26 | self.corr = nn.Sequential(
27 | tofp32(),
28 | Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1),
29 | tofp16())
30 | else:
31 | self.corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20, stride1=1, stride2=2, corr_multiply=1)
32 |
33 | self.corr_activation = nn.LeakyReLU(0.1,inplace=True)
34 | self.conv3_1 = conv(self.batchNorm, 473, 256)
35 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
36 | self.conv4_1 = conv(self.batchNorm, 512, 512)
37 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
38 | self.conv5_1 = conv(self.batchNorm, 512, 512)
39 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
40 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
41 |
42 | self.deconv5 = deconv(1024,512)
43 | self.deconv4 = deconv(1026,256)
44 | self.deconv3 = deconv(770,128)
45 | self.deconv2 = deconv(386,64)
46 |
47 | self.predict_flow6 = predict_flow(1024)
48 | self.predict_flow5 = predict_flow(1026)
49 | self.predict_flow4 = predict_flow(770)
50 | self.predict_flow3 = predict_flow(386)
51 | self.predict_flow2 = predict_flow(194)
52 |
53 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
54 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
55 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
56 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=True)
57 |
58 | for m in self.modules():
59 | if isinstance(m, nn.Conv2d):
60 | if m.bias is not None:
61 | init.uniform_(m.bias)
62 | init.xavier_uniform_(m.weight)
63 |
64 | if isinstance(m, nn.ConvTranspose2d):
65 | if m.bias is not None:
66 | init.uniform_(m.bias)
67 | init.xavier_uniform_(m.weight)
68 | # init_deconv_bilinear(m.weight)
69 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
70 |
71 | def forward(self, x):
72 | x1 = x[:,0:3,:,:]
73 | x2 = x[:,3::,:,:]
74 |
75 | out_conv1a = self.conv1(x1)
76 | out_conv2a = self.conv2(out_conv1a)
77 | out_conv3a = self.conv3(out_conv2a)
78 |
79 | # FlownetC bottom input stream
80 | out_conv1b = self.conv1(x2)
81 |
82 | out_conv2b = self.conv2(out_conv1b)
83 | out_conv3b = self.conv3(out_conv2b)
84 |
85 | # Merge streams
86 | out_corr = self.corr(out_conv3a, out_conv3b) # False
87 | out_corr = self.corr_activation(out_corr)
88 |
89 | # Redirect top input stream and concatenate
90 | out_conv_redir = self.conv_redir(out_conv3a)
91 |
92 | in_conv3_1 = torch.cat((out_conv_redir, out_corr), 1)
93 |
94 | # Merged conv layers
95 | out_conv3_1 = self.conv3_1(in_conv3_1)
96 |
97 | out_conv4 = self.conv4_1(self.conv4(out_conv3_1))
98 |
99 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
100 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
101 |
102 | flow6 = self.predict_flow6(out_conv6)
103 | flow6_up = self.upsampled_flow6_to_5(flow6)
104 | out_deconv5 = self.deconv5(out_conv6)
105 |
106 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
107 |
108 | flow5 = self.predict_flow5(concat5)
109 | flow5_up = self.upsampled_flow5_to_4(flow5)
110 | out_deconv4 = self.deconv4(concat5)
111 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
112 |
113 | flow4 = self.predict_flow4(concat4)
114 | flow4_up = self.upsampled_flow4_to_3(flow4)
115 | out_deconv3 = self.deconv3(concat4)
116 | concat3 = torch.cat((out_conv3_1,out_deconv3,flow4_up),1)
117 |
118 | flow3 = self.predict_flow3(concat3)
119 | flow3_up = self.upsampled_flow3_to_2(flow3)
120 | out_deconv2 = self.deconv2(concat3)
121 | concat2 = torch.cat((out_conv2a,out_deconv2,flow3_up),1)
122 |
123 | flow2 = self.predict_flow2(concat2)
124 |
125 | if self.training:
126 | return flow2,flow3,flow4,flow5,flow6
127 | else:
128 | return flow2,
129 |
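A forward-pass sketch, assuming the `correlation_cuda` extension is built and a CUDA device is available (the correlation layer has no CPU path); the import path assumes the `networks` package layout of this repository.

```python
import torch
from argparse import Namespace

from networks.FlowNetC import FlowNetC  # assumed package-style import

model = FlowNetC(Namespace(fp16=False), batchNorm=False).cuda().eval()
x = torch.randn(1, 6, 384, 512, device='cuda')   # two RGB frames stacked on channels
with torch.no_grad():
    flow2, = model(x)                            # inference mode returns only flow2
print(flow2.shape)                               # torch.Size([1, 2, 96, 128])
```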
--------------------------------------------------------------------------------
/flownet2/networks/FlowNetFusion.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .submodules import *
9 | 'Parameter count = 581,226'
10 |
11 | class FlowNetFusion(nn.Module):
12 | def __init__(self,args, batchNorm=True):
13 | super(FlowNetFusion,self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv0 = conv(self.batchNorm, 11, 64)
17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2)
18 | self.conv1_1 = conv(self.batchNorm, 64, 128)
19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2)
20 | self.conv2_1 = conv(self.batchNorm, 128, 128)
21 |
22 | self.deconv1 = deconv(128,32)
23 | self.deconv0 = deconv(162,16)
24 |
25 | self.inter_conv1 = i_conv(self.batchNorm, 162, 32)
26 | self.inter_conv0 = i_conv(self.batchNorm, 82, 16)
27 |
28 | self.predict_flow2 = predict_flow(128)
29 | self.predict_flow1 = predict_flow(32)
30 | self.predict_flow0 = predict_flow(16)
31 |
32 | self.upsampled_flow2_to_1 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
33 | self.upsampled_flow1_to_0 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
34 |
35 | for m in self.modules():
36 | if isinstance(m, nn.Conv2d):
37 | if m.bias is not None:
38 | init.uniform_(m.bias)
39 | init.xavier_uniform_(m.weight)
40 |
41 | if isinstance(m, nn.ConvTranspose2d):
42 | if m.bias is not None:
43 | init.uniform_(m.bias)
44 | init.xavier_uniform_(m.weight)
45 | # init_deconv_bilinear(m.weight)
46 |
47 | def forward(self, x):
48 | out_conv0 = self.conv0(x)
49 | out_conv1 = self.conv1_1(self.conv1(out_conv0))
50 | out_conv2 = self.conv2_1(self.conv2(out_conv1))
51 |
52 | flow2 = self.predict_flow2(out_conv2)
53 | flow2_up = self.upsampled_flow2_to_1(flow2)
54 | out_deconv1 = self.deconv1(out_conv2)
55 |
56 | concat1 = torch.cat((out_conv1,out_deconv1,flow2_up),1)
57 | out_interconv1 = self.inter_conv1(concat1)
58 | flow1 = self.predict_flow1(out_interconv1)
59 | flow1_up = self.upsampled_flow1_to_0(flow1)
60 | out_deconv0 = self.deconv0(concat1)
61 |
62 | concat0 = torch.cat((out_conv0,out_deconv0,flow1_up),1)
63 | out_interconv0 = self.inter_conv0(concat0)
64 | flow0 = self.predict_flow0(out_interconv0)
65 |
66 | return flow0
67 |
68 |
--------------------------------------------------------------------------------
/flownet2/networks/FlowNetS.py:
--------------------------------------------------------------------------------
1 | '''
2 | Portions of this code copyright 2017, Clement Pinard
3 | '''
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch.nn import init
8 |
9 | import math
10 | import numpy as np
11 |
12 | from .submodules import *
13 | 'Parameter count : 38,676,504 '
14 |
15 | class FlowNetS(nn.Module):
16 | def __init__(self, args, input_channels = 12, batchNorm=True):
17 | super(FlowNetS,self).__init__()
18 |
19 | self.batchNorm = batchNorm
20 | self.conv1 = conv(self.batchNorm, input_channels, 64, kernel_size=7, stride=2)
21 | self.conv2 = conv(self.batchNorm, 64, 128, kernel_size=5, stride=2)
22 | self.conv3 = conv(self.batchNorm, 128, 256, kernel_size=5, stride=2)
23 | self.conv3_1 = conv(self.batchNorm, 256, 256)
24 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
25 | self.conv4_1 = conv(self.batchNorm, 512, 512)
26 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
27 | self.conv5_1 = conv(self.batchNorm, 512, 512)
28 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
29 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
30 |
31 | self.deconv5 = deconv(1024,512)
32 | self.deconv4 = deconv(1026,256)
33 | self.deconv3 = deconv(770,128)
34 | self.deconv2 = deconv(386,64)
35 |
36 | self.predict_flow6 = predict_flow(1024)
37 | self.predict_flow5 = predict_flow(1026)
38 | self.predict_flow4 = predict_flow(770)
39 | self.predict_flow3 = predict_flow(386)
40 | self.predict_flow2 = predict_flow(194)
41 |
42 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
43 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
44 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
45 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1, bias=False)
46 |
47 | for m in self.modules():
48 | if isinstance(m, nn.Conv2d):
49 | if m.bias is not None:
50 | init.uniform_(m.bias)
51 | init.xavier_uniform_(m.weight)
52 |
53 | if isinstance(m, nn.ConvTranspose2d):
54 | if m.bias is not None:
55 | init.uniform_(m.bias)
56 | init.xavier_uniform_(m.weight)
57 | # init_deconv_bilinear(m.weight)
58 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
59 |
60 | def forward(self, x):
61 | out_conv1 = self.conv1(x)
62 |
63 | out_conv2 = self.conv2(out_conv1)
64 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
65 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
66 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
67 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
68 |
69 | flow6 = self.predict_flow6(out_conv6)
70 | flow6_up = self.upsampled_flow6_to_5(flow6)
71 | out_deconv5 = self.deconv5(out_conv6)
72 |
73 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
74 | flow5 = self.predict_flow5(concat5)
75 | flow5_up = self.upsampled_flow5_to_4(flow5)
76 | out_deconv4 = self.deconv4(concat5)
77 |
78 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
79 | flow4 = self.predict_flow4(concat4)
80 | flow4_up = self.upsampled_flow4_to_3(flow4)
81 | out_deconv3 = self.deconv3(concat4)
82 |
83 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
84 | flow3 = self.predict_flow3(concat3)
85 | flow3_up = self.upsampled_flow3_to_2(flow3)
86 | out_deconv2 = self.deconv2(concat3)
87 |
88 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
89 | flow2 = self.predict_flow2(concat2)
90 |
91 | if self.training:
92 | return flow2,flow3,flow4,flow5,flow6
93 | else:
94 | return flow2,
95 |
96 |
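Unlike FlowNetC, this network uses no custom CUDA ops, so a CPU forward pass is enough to check shapes; a minimal sketch, assuming the `networks` package imports resolve and noting that `args` is not read inside this class:

```python
import torch
from argparse import Namespace

from networks.FlowNetS import FlowNetS  # assumed package-style import

model = FlowNetS(Namespace(), input_channels=6, batchNorm=False).eval()
x = torch.randn(1, 6, 384, 512)          # H and W must be divisible by 64
with torch.no_grad():
    flow2, = model(x)                    # inference mode returns only flow2
print(flow2.shape)                       # torch.Size([1, 2, 96, 128])
```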
--------------------------------------------------------------------------------
/flownet2/networks/FlowNetSD.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn import init
4 |
5 | import math
6 | import numpy as np
7 |
8 | from .submodules import *
9 | 'Parameter count = 45,371,666'
10 |
11 | class FlowNetSD(nn.Module):
12 | def __init__(self, args, batchNorm=True):
13 | super(FlowNetSD,self).__init__()
14 |
15 | self.batchNorm = batchNorm
16 | self.conv0 = conv(self.batchNorm, 6, 64)
17 | self.conv1 = conv(self.batchNorm, 64, 64, stride=2)
18 | self.conv1_1 = conv(self.batchNorm, 64, 128)
19 | self.conv2 = conv(self.batchNorm, 128, 128, stride=2)
20 | self.conv2_1 = conv(self.batchNorm, 128, 128)
21 | self.conv3 = conv(self.batchNorm, 128, 256, stride=2)
22 | self.conv3_1 = conv(self.batchNorm, 256, 256)
23 | self.conv4 = conv(self.batchNorm, 256, 512, stride=2)
24 | self.conv4_1 = conv(self.batchNorm, 512, 512)
25 | self.conv5 = conv(self.batchNorm, 512, 512, stride=2)
26 | self.conv5_1 = conv(self.batchNorm, 512, 512)
27 | self.conv6 = conv(self.batchNorm, 512, 1024, stride=2)
28 | self.conv6_1 = conv(self.batchNorm,1024, 1024)
29 |
30 | self.deconv5 = deconv(1024,512)
31 | self.deconv4 = deconv(1026,256)
32 | self.deconv3 = deconv(770,128)
33 | self.deconv2 = deconv(386,64)
34 |
35 | self.inter_conv5 = i_conv(self.batchNorm, 1026, 512)
36 | self.inter_conv4 = i_conv(self.batchNorm, 770, 256)
37 | self.inter_conv3 = i_conv(self.batchNorm, 386, 128)
38 | self.inter_conv2 = i_conv(self.batchNorm, 194, 64)
39 |
40 | self.predict_flow6 = predict_flow(1024)
41 | self.predict_flow5 = predict_flow(512)
42 | self.predict_flow4 = predict_flow(256)
43 | self.predict_flow3 = predict_flow(128)
44 | self.predict_flow2 = predict_flow(64)
45 |
46 | self.upsampled_flow6_to_5 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
47 | self.upsampled_flow5_to_4 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
48 | self.upsampled_flow4_to_3 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
49 | self.upsampled_flow3_to_2 = nn.ConvTranspose2d(2, 2, 4, 2, 1)
50 |
51 | for m in self.modules():
52 | if isinstance(m, nn.Conv2d):
53 | if m.bias is not None:
54 | init.uniform_(m.bias)
55 | init.xavier_uniform_(m.weight)
56 |
57 | if isinstance(m, nn.ConvTranspose2d):
58 | if m.bias is not None:
59 | init.uniform_(m.bias)
60 | init.xavier_uniform_(m.weight)
61 | # init_deconv_bilinear(m.weight)
62 | self.upsample1 = nn.Upsample(scale_factor=4, mode='bilinear')
63 |
64 |
65 |
66 | def forward(self, x):
67 | out_conv0 = self.conv0(x)
68 | out_conv1 = self.conv1_1(self.conv1(out_conv0))
69 | out_conv2 = self.conv2_1(self.conv2(out_conv1))
70 |
71 | out_conv3 = self.conv3_1(self.conv3(out_conv2))
72 | out_conv4 = self.conv4_1(self.conv4(out_conv3))
73 | out_conv5 = self.conv5_1(self.conv5(out_conv4))
74 | out_conv6 = self.conv6_1(self.conv6(out_conv5))
75 |
76 | flow6 = self.predict_flow6(out_conv6)
77 | flow6_up = self.upsampled_flow6_to_5(flow6)
78 | out_deconv5 = self.deconv5(out_conv6)
79 |
80 | concat5 = torch.cat((out_conv5,out_deconv5,flow6_up),1)
81 | out_interconv5 = self.inter_conv5(concat5)
82 | flow5 = self.predict_flow5(out_interconv5)
83 |
84 | flow5_up = self.upsampled_flow5_to_4(flow5)
85 | out_deconv4 = self.deconv4(concat5)
86 |
87 | concat4 = torch.cat((out_conv4,out_deconv4,flow5_up),1)
88 | out_interconv4 = self.inter_conv4(concat4)
89 | flow4 = self.predict_flow4(out_interconv4)
90 | flow4_up = self.upsampled_flow4_to_3(flow4)
91 | out_deconv3 = self.deconv3(concat4)
92 |
93 | concat3 = torch.cat((out_conv3,out_deconv3,flow4_up),1)
94 | out_interconv3 = self.inter_conv3(concat3)
95 | flow3 = self.predict_flow3(out_interconv3)
96 | flow3_up = self.upsampled_flow3_to_2(flow3)
97 | out_deconv2 = self.deconv2(concat3)
98 |
99 | concat2 = torch.cat((out_conv2,out_deconv2,flow3_up),1)
100 | out_interconv2 = self.inter_conv2(concat2)
101 | flow2 = self.predict_flow2(out_interconv2)
102 |
103 | if self.training:
104 | return flow2,flow3,flow4,flow5,flow6
105 | else:
106 | return flow2,
107 |
--------------------------------------------------------------------------------
/flownet2/networks/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/__init__.py
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/channelnorm_package/__init__.py
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/channelnorm.py:
--------------------------------------------------------------------------------
1 | from torch.autograd import Function, Variable
2 | from torch.nn.modules.module import Module
3 | import channelnorm_cuda
4 |
5 | class ChannelNormFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, input1, norm_deg=2):
9 | assert input1.is_contiguous()
10 | b, _, h, w = input1.size()
11 | output = input1.new(b, 1, h, w).zero_()
12 |
13 | channelnorm_cuda.forward(input1, output, norm_deg)
14 | ctx.save_for_backward(input1, output)
15 | ctx.norm_deg = norm_deg
16 |
17 | return output
18 |
19 | @staticmethod
20 | def backward(ctx, grad_output):
21 | input1, output = ctx.saved_tensors
22 |
23 | grad_input1 = Variable(input1.new(input1.size()).zero_())
24 |
25 |         channelnorm_cuda.backward(input1, output, grad_output.data,
26 |                                   grad_input1.data, ctx.norm_deg)
27 |
28 | return grad_input1, None
29 |
30 |
31 | class ChannelNorm(Module):
32 |
33 | def __init__(self, norm_deg=2):
34 | super(ChannelNorm, self).__init__()
35 | self.norm_deg = norm_deg
36 |
37 | def forward(self, input1):
38 | return ChannelNormFunction.apply(input1, self.norm_deg)
39 |
40 |
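A usage sketch, assuming the `channelnorm_cuda` extension from `setup.py` is installed and a CUDA device is available; the op reduces the channel dimension to a single per-pixel L2 norm.

```python
import torch

from channelnorm import ChannelNorm  # assumed import path

norm = ChannelNorm().cuda()
flow = torch.randn(1, 2, 64, 64, device='cuda').contiguous()
magnitude = norm(flow)                # shape (1, 1, 64, 64): per-pixel flow magnitude
```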
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/channelnorm_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 |
4 | #include "channelnorm_kernel.cuh"
5 |
6 | int channelnorm_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& output,
9 | int norm_deg) {
10 |
11 | channelnorm_kernel_forward(input1, output, norm_deg);
12 | return 1;
13 | }
14 |
15 |
16 | int channelnorm_cuda_backward(
17 | at::Tensor& input1,
18 | at::Tensor& output,
19 | at::Tensor& gradOutput,
20 | at::Tensor& gradInput1,
21 | int norm_deg) {
22 |
23 | channelnorm_kernel_backward(input1, output, gradOutput, gradInput1, norm_deg);
24 | return 1;
25 | }
26 |
27 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
28 | m.def("forward", &channelnorm_cuda_forward, "Channel norm forward (CUDA)");
29 | m.def("backward", &channelnorm_cuda_backward, "Channel norm backward (CUDA)");
30 | }
31 |
32 |
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/channelnorm_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/Context.h>
3 |
4 | #include "channelnorm_kernel.cuh"
5 |
6 | #define CUDA_NUM_THREADS 512
7 |
8 | #define DIM0(TENSOR) ((TENSOR).x)
9 | #define DIM1(TENSOR) ((TENSOR).y)
10 | #define DIM2(TENSOR) ((TENSOR).z)
11 | #define DIM3(TENSOR) ((TENSOR).w)
12 |
13 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
14 |
15 | using at::Half;
16 |
17 | template <typename scalar_t>
18 | __global__ void kernel_channelnorm_update_output(
19 | const int n,
20 | const scalar_t* __restrict__ input1,
21 | const long4 input1_size,
22 | const long4 input1_stride,
23 | scalar_t* __restrict__ output,
24 | const long4 output_size,
25 | const long4 output_stride,
26 | int norm_deg) {
27 |
28 | int index = blockIdx.x * blockDim.x + threadIdx.x;
29 |
30 | if (index >= n) {
31 | return;
32 | }
33 |
34 | int dim_b = DIM0(output_size);
35 | int dim_c = DIM1(output_size);
36 | int dim_h = DIM2(output_size);
37 | int dim_w = DIM3(output_size);
38 | int dim_chw = dim_c * dim_h * dim_w;
39 |
40 | int b = ( index / dim_chw ) % dim_b;
41 | int y = ( index / dim_w ) % dim_h;
42 | int x = ( index ) % dim_w;
43 |
44 | int i1dim_c = DIM1(input1_size);
45 | int i1dim_h = DIM2(input1_size);
46 | int i1dim_w = DIM3(input1_size);
47 | int i1dim_chw = i1dim_c * i1dim_h * i1dim_w;
48 | int i1dim_hw = i1dim_h * i1dim_w;
49 |
50 | float result = 0.0;
51 |
52 | for (int c = 0; c < i1dim_c; ++c) {
53 | int i1Index = b * i1dim_chw + c * i1dim_hw + y * i1dim_w + x;
54 | scalar_t val = input1[i1Index];
55 |         result += static_cast<float>(val * val);
56 |     }
57 |     result = sqrt(result);
58 |     output[index] = static_cast<scalar_t>(result);
59 | }
60 |
61 |
62 | template <typename scalar_t>
63 | __global__ void kernel_channelnorm_backward_input1(
64 | const int n,
65 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
66 | const scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride,
67 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
68 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride,
69 | int norm_deg) {
70 |
71 | int index = blockIdx.x * blockDim.x + threadIdx.x;
72 |
73 | if (index >= n) {
74 | return;
75 | }
76 |
77 | float val = 0.0;
78 |
79 | int dim_b = DIM0(gradInput_size);
80 | int dim_c = DIM1(gradInput_size);
81 | int dim_h = DIM2(gradInput_size);
82 | int dim_w = DIM3(gradInput_size);
83 | int dim_chw = dim_c * dim_h * dim_w;
84 | int dim_hw = dim_h * dim_w;
85 |
86 | int b = ( index / dim_chw ) % dim_b;
87 | int y = ( index / dim_w ) % dim_h;
88 | int x = ( index ) % dim_w;
89 |
90 |
91 | int outIndex = b * dim_hw + y * dim_w + x;
92 |     val = static_cast<float>(gradOutput[outIndex]) * static_cast<float>(input1[index]) / (static_cast<float>(output[outIndex])+1e-9);
93 |     gradInput[index] = static_cast<scalar_t>(val);
94 |
95 | }
96 |
97 | void channelnorm_kernel_forward(
98 | at::Tensor& input1,
99 | at::Tensor& output,
100 | int norm_deg) {
101 |
102 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
103 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
104 |
105 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
106 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
107 |
108 | int n = output.numel();
109 |
110 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_forward", ([&] {
111 |
112 |       kernel_channelnorm_update_output<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
113 |           n,
114 |           input1.data<scalar_t>(),
115 |           input1_size,
116 |           input1_stride,
117 |           output.data<scalar_t>(),
118 | output_size,
119 | output_stride,
120 | norm_deg);
121 |
122 | }));
123 |
124 | // TODO: ATen-equivalent check
125 |
126 | // THCudaCheck(cudaGetLastError());
127 | }
128 |
129 | void channelnorm_kernel_backward(
130 | at::Tensor& input1,
131 | at::Tensor& output,
132 | at::Tensor& gradOutput,
133 | at::Tensor& gradInput1,
134 | int norm_deg) {
135 |
136 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
137 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
138 |
139 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
140 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
141 |
142 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
143 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
144 |
145 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
146 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
147 |
148 | int n = gradInput1.numel();
149 |
150 | AT_DISPATCH_FLOATING_TYPES_AND_HALF(input1.type(), "channelnorm_backward_input1", ([&] {
151 |
152 |       kernel_channelnorm_backward_input1<scalar_t><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
153 |           n,
154 |           input1.data<scalar_t>(),
155 |           input1_size,
156 |           input1_stride,
157 |           output.data<scalar_t>(),
158 |           output_size,
159 |           output_stride,
160 |           gradOutput.data<scalar_t>(),
161 |           gradOutput_size,
162 |           gradOutput_stride,
163 |           gradInput1.data<scalar_t>(),
164 | gradInput1_size,
165 | gradInput1_stride,
166 | norm_deg
167 | );
168 |
169 | }));
170 |
171 | // TODO: Add ATen-equivalent check
172 |
173 | // THCudaCheck(cudaGetLastError());
174 | }
175 |
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/channelnorm_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void channelnorm_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& output,
8 | int norm_deg);
9 |
10 |
11 | void channelnorm_kernel_backward(
12 | at::Tensor& input1,
13 | at::Tensor& output,
14 | at::Tensor& gradOutput,
15 | at::Tensor& gradInput1,
16 | int norm_deg);
17 |
--------------------------------------------------------------------------------
/flownet2/networks/channelnorm_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_52,code=sm_52',
12 | '-gencode', 'arch=compute_60,code=sm_60',
13 | '-gencode', 'arch=compute_61,code=sm_61',
14 | '-gencode', 'arch=compute_70,code=sm_70',
15 | '-gencode', 'arch=compute_70,code=compute_70'
16 | ]
17 |
18 | setup(
19 | name='channelnorm_cuda',
20 | ext_modules=[
21 | CUDAExtension('channelnorm_cuda', [
22 | 'channelnorm_cuda.cc',
23 | 'channelnorm_kernel.cu'
24 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
25 | ],
26 | cmdclass={
27 | 'build_ext': BuildExtension
28 | })
29 |
--------------------------------------------------------------------------------
/flownet2/networks/correlation_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/correlation_package/__init__.py
--------------------------------------------------------------------------------
/flownet2/networks/correlation_package/correlation.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn.modules.module import Module
3 | from torch.autograd import Function
4 | import correlation_cuda
5 |
6 | class CorrelationFunction(Function):
7 |
8 | def __init__(self, pad_size=3, kernel_size=3, max_displacement=20, stride1=1, stride2=2, corr_multiply=1):
9 | super(CorrelationFunction, self).__init__()
10 | self.pad_size = pad_size
11 | self.kernel_size = kernel_size
12 | self.max_displacement = max_displacement
13 | self.stride1 = stride1
14 | self.stride2 = stride2
15 | self.corr_multiply = corr_multiply
16 | # self.out_channel = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1)
17 |
18 | def forward(self, input1, input2):
19 | self.save_for_backward(input1, input2)
20 |
21 | with torch.cuda.device_of(input1):
22 | rbot1 = input1.new()
23 | rbot2 = input2.new()
24 | output = input1.new()
25 |
26 | correlation_cuda.forward(input1, input2, rbot1, rbot2, output,
27 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
28 |
29 | return output
30 |
31 | def backward(self, grad_output):
32 | input1, input2 = self.saved_tensors
33 |
34 | with torch.cuda.device_of(input1):
35 | rbot1 = input1.new()
36 | rbot2 = input2.new()
37 |
38 | grad_input1 = input1.new()
39 | grad_input2 = input2.new()
40 |
41 | correlation_cuda.backward(input1, input2, rbot1, rbot2, grad_output, grad_input1, grad_input2,
42 | self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)
43 |
44 | return grad_input1, grad_input2
45 |
46 |
47 | class Correlation(Module):
48 | def __init__(self, pad_size=0, kernel_size=0, max_displacement=0, stride1=1, stride2=2, corr_multiply=1):
49 | super(Correlation, self).__init__()
50 | self.pad_size = pad_size
51 | self.kernel_size = kernel_size
52 | self.max_displacement = max_displacement
53 | self.stride1 = stride1
54 | self.stride2 = stride2
55 | self.corr_multiply = corr_multiply
56 |
57 | def forward(self, input1, input2):
58 |
59 | result = CorrelationFunction(self.pad_size, self.kernel_size, self.max_displacement,self.stride1, self.stride2, self.corr_multiply)(input1, input2)
60 |
61 | return result
62 |
63 |
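A usage sketch with the same hyper-parameters FlowNetC passes in, assuming the `correlation_cuda` extension is built, a CUDA device is available, and the PyTorch version this repository targets (the legacy `Function` call style used above):

```python
import torch

from correlation import Correlation  # assumed import path

corr = Correlation(pad_size=20, kernel_size=1, max_displacement=20,
                   stride1=1, stride2=2, corr_multiply=1)
feat_a = torch.randn(1, 256, 48, 64, device='cuda').contiguous()
feat_b = torch.randn(1, 256, 48, 64, device='cuda').contiguous()
cost = corr(feat_a, feat_b)   # (1, 441, 48, 64): (2 * 20/2 + 1)**2 displacement channels
```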
--------------------------------------------------------------------------------
/flownet2/networks/correlation_package/correlation_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <ATen/ATen.h>
3 | #include <ATen/Context.h>
4 | #include <stdio.h>
5 |
6 | #include "correlation_cuda_kernel.cuh"
7 |
8 | int correlation_forward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& output,
9 | int pad_size,
10 | int kernel_size,
11 | int max_displacement,
12 | int stride1,
13 | int stride2,
14 | int corr_type_multiply)
15 | {
16 |
17 | int batchSize = input1.size(0);
18 |
19 | int nInputChannels = input1.size(1);
20 | int inputHeight = input1.size(2);
21 | int inputWidth = input1.size(3);
22 |
23 | int kernel_radius = (kernel_size - 1) / 2;
24 | int border_radius = kernel_radius + max_displacement;
25 |
26 | int paddedInputHeight = inputHeight + 2 * pad_size;
27 | int paddedInputWidth = inputWidth + 2 * pad_size;
28 |
29 | int nOutputChannels = ((max_displacement/stride2)*2 + 1) * ((max_displacement/stride2)*2 + 1);
30 |
31 |   int outputHeight = ceil(static_cast<float>(paddedInputHeight - 2 * border_radius) / static_cast<float>(stride1));
32 |   int outputwidth = ceil(static_cast<float>(paddedInputWidth - 2 * border_radius) / static_cast<float>(stride1));
33 |
34 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
35 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
36 | output.resize_({batchSize, nOutputChannels, outputHeight, outputwidth});
37 |
38 | rInput1.fill_(0);
39 | rInput2.fill_(0);
40 | output.fill_(0);
41 |
42 | int success = correlation_forward_cuda_kernel(
43 | output,
44 | output.size(0),
45 | output.size(1),
46 | output.size(2),
47 | output.size(3),
48 | output.stride(0),
49 | output.stride(1),
50 | output.stride(2),
51 | output.stride(3),
52 | input1,
53 | input1.size(1),
54 | input1.size(2),
55 | input1.size(3),
56 | input1.stride(0),
57 | input1.stride(1),
58 | input1.stride(2),
59 | input1.stride(3),
60 | input2,
61 | input2.size(1),
62 | input2.stride(0),
63 | input2.stride(1),
64 | input2.stride(2),
65 | input2.stride(3),
66 | rInput1,
67 | rInput2,
68 | pad_size,
69 | kernel_size,
70 | max_displacement,
71 | stride1,
72 | stride2,
73 | corr_type_multiply,
74 | at::globalContext().getCurrentCUDAStream()
75 | );
76 |
77 | //check for errors
78 | if (!success) {
79 | AT_ERROR("CUDA call failed");
80 | }
81 |
82 | return 1;
83 |
84 | }
85 |
86 | int correlation_backward_cuda(at::Tensor& input1, at::Tensor& input2, at::Tensor& rInput1, at::Tensor& rInput2, at::Tensor& gradOutput,
87 | at::Tensor& gradInput1, at::Tensor& gradInput2,
88 | int pad_size,
89 | int kernel_size,
90 | int max_displacement,
91 | int stride1,
92 | int stride2,
93 | int corr_type_multiply)
94 | {
95 |
96 | int batchSize = input1.size(0);
97 | int nInputChannels = input1.size(1);
98 | int paddedInputHeight = input1.size(2)+ 2 * pad_size;
99 | int paddedInputWidth = input1.size(3)+ 2 * pad_size;
100 |
101 | int height = input1.size(2);
102 | int width = input1.size(3);
103 |
104 | rInput1.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
105 | rInput2.resize_({batchSize, paddedInputHeight, paddedInputWidth, nInputChannels});
106 | gradInput1.resize_({batchSize, nInputChannels, height, width});
107 | gradInput2.resize_({batchSize, nInputChannels, height, width});
108 |
109 | rInput1.fill_(0);
110 | rInput2.fill_(0);
111 | gradInput1.fill_(0);
112 | gradInput2.fill_(0);
113 |
114 | int success = correlation_backward_cuda_kernel(gradOutput,
115 | gradOutput.size(0),
116 | gradOutput.size(1),
117 | gradOutput.size(2),
118 | gradOutput.size(3),
119 | gradOutput.stride(0),
120 | gradOutput.stride(1),
121 | gradOutput.stride(2),
122 | gradOutput.stride(3),
123 | input1,
124 | input1.size(1),
125 | input1.size(2),
126 | input1.size(3),
127 | input1.stride(0),
128 | input1.stride(1),
129 | input1.stride(2),
130 | input1.stride(3),
131 | input2,
132 | input2.stride(0),
133 | input2.stride(1),
134 | input2.stride(2),
135 | input2.stride(3),
136 | gradInput1,
137 | gradInput1.stride(0),
138 | gradInput1.stride(1),
139 | gradInput1.stride(2),
140 | gradInput1.stride(3),
141 | gradInput2,
142 | gradInput2.size(1),
143 | gradInput2.stride(0),
144 | gradInput2.stride(1),
145 | gradInput2.stride(2),
146 | gradInput2.stride(3),
147 | rInput1,
148 | rInput2,
149 | pad_size,
150 | kernel_size,
151 | max_displacement,
152 | stride1,
153 | stride2,
154 | corr_type_multiply,
155 | at::globalContext().getCurrentCUDAStream()
156 | );
157 |
158 | if (!success) {
159 | AT_ERROR("CUDA call failed");
160 | }
161 |
162 | return 1;
163 | }
164 |
165 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
166 | m.def("forward", &correlation_forward_cuda, "Correlation forward (CUDA)");
167 | m.def("backward", &correlation_backward_cuda, "Correlation backward (CUDA)");
168 | }
169 |
170 |
--------------------------------------------------------------------------------
/flownet2/networks/correlation_package/correlation_cuda_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 | #include <ATen/Context.h>
5 | #include <cuda_runtime.h>
6 |
7 | int correlation_forward_cuda_kernel(at::Tensor& output,
8 | int ob,
9 | int oc,
10 | int oh,
11 | int ow,
12 | int osb,
13 | int osc,
14 | int osh,
15 | int osw,
16 |
17 | at::Tensor& input1,
18 | int ic,
19 | int ih,
20 | int iw,
21 | int isb,
22 | int isc,
23 | int ish,
24 | int isw,
25 |
26 | at::Tensor& input2,
27 | int gc,
28 | int gsb,
29 | int gsc,
30 | int gsh,
31 | int gsw,
32 |
33 | at::Tensor& rInput1,
34 | at::Tensor& rInput2,
35 | int pad_size,
36 | int kernel_size,
37 | int max_displacement,
38 | int stride1,
39 | int stride2,
40 | int corr_type_multiply,
41 | cudaStream_t stream);
42 |
43 |
44 | int correlation_backward_cuda_kernel(
45 | at::Tensor& gradOutput,
46 | int gob,
47 | int goc,
48 | int goh,
49 | int gow,
50 | int gosb,
51 | int gosc,
52 | int gosh,
53 | int gosw,
54 |
55 | at::Tensor& input1,
56 | int ic,
57 | int ih,
58 | int iw,
59 | int isb,
60 | int isc,
61 | int ish,
62 | int isw,
63 |
64 | at::Tensor& input2,
65 | int gsb,
66 | int gsc,
67 | int gsh,
68 | int gsw,
69 |
70 | at::Tensor& gradInput1,
71 | int gisb,
72 | int gisc,
73 | int gish,
74 | int gisw,
75 |
76 | at::Tensor& gradInput2,
77 | int ggc,
78 | int ggsb,
79 | int ggsc,
80 | int ggsh,
81 | int ggsw,
82 |
83 | at::Tensor& rInput1,
84 | at::Tensor& rInput2,
85 | int pad_size,
86 | int kernel_size,
87 | int max_displacement,
88 | int stride1,
89 | int stride2,
90 | int corr_type_multiply,
91 | cudaStream_t stream);
92 |
--------------------------------------------------------------------------------
/flownet2/networks/correlation_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup, find_packages
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='correlation_cuda',
21 | ext_modules=[
22 | CUDAExtension('correlation_cuda', [
23 | 'correlation_cuda.cc',
24 | 'correlation_cuda_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/networks/resample2d_package/__init__.py
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/resample2d.py:
--------------------------------------------------------------------------------
1 | from torch.nn.modules.module import Module
2 | from torch.autograd import Function, Variable
3 | import resample2d_cuda
4 |
5 | class Resample2dFunction(Function):
6 |
7 | @staticmethod
8 | def forward(ctx, input1, input2, kernel_size=1):
9 | assert input1.is_contiguous()
10 | assert input2.is_contiguous()
11 |
12 | ctx.save_for_backward(input1, input2)
13 | ctx.kernel_size = kernel_size
14 |
15 | _, d, _, _ = input1.size()
16 | b, _, h, w = input2.size()
17 | output = input1.new(b, d, h, w).zero_()
18 |
19 | resample2d_cuda.forward(input1, input2, output, kernel_size)
20 |
21 | return output
22 |
23 | @staticmethod
24 | def backward(ctx, grad_output):
25 | assert grad_output.is_contiguous()
26 |
27 | input1, input2 = ctx.saved_tensors
28 |
29 | grad_input1 = Variable(input1.new(input1.size()).zero_())
30 | grad_input2 = Variable(input1.new(input2.size()).zero_())
31 |
32 | resample2d_cuda.backward(input1, input2, grad_output.data,
33 | grad_input1.data, grad_input2.data,
34 | ctx.kernel_size)
35 |
36 | return grad_input1, grad_input2, None
37 |
38 | class Resample2d(Module):
39 |
40 | def __init__(self, kernel_size=1):
41 | super(Resample2d, self).__init__()
42 | self.kernel_size = kernel_size
43 |
44 | def forward(self, input1, input2):
45 | input1_c = input1.contiguous()
46 | return Resample2dFunction.apply(input1_c, input2, self.kernel_size)
47 |
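A usage sketch, assuming the `resample2d_cuda` extension is built and a CUDA device is available; the op bilinearly samples `input1` at positions displaced by the flow in `input2`.

```python
import torch

from resample2d import Resample2d  # assumed import path

warp = Resample2d()
frame = torch.randn(1, 3, 128, 160, device='cuda')              # image to warp
flow = torch.randn(1, 2, 128, 160, device='cuda').contiguous()  # (dx, dy) per pixel
warped = warp(frame, flow)                                      # same shape as frame
```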
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/resample2d_cuda.cc:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <torch/torch.h>
3 |
4 | #include "resample2d_kernel.cuh"
5 |
6 | int resample2d_cuda_forward(
7 | at::Tensor& input1,
8 | at::Tensor& input2,
9 | at::Tensor& output,
10 | int kernel_size) {
11 | resample2d_kernel_forward(input1, input2, output, kernel_size);
12 | return 1;
13 | }
14 |
15 | int resample2d_cuda_backward(
16 | at::Tensor& input1,
17 | at::Tensor& input2,
18 | at::Tensor& gradOutput,
19 | at::Tensor& gradInput1,
20 | at::Tensor& gradInput2,
21 | int kernel_size) {
22 | resample2d_kernel_backward(input1, input2, gradOutput, gradInput1, gradInput2, kernel_size);
23 | return 1;
24 | }
25 |
26 |
27 |
28 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
29 | m.def("forward", &resample2d_cuda_forward, "Resample2D forward (CUDA)");
30 | m.def("backward", &resample2d_cuda_backward, "Resample2D backward (CUDA)");
31 | }
32 |
33 |
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/resample2d_kernel.cu:
--------------------------------------------------------------------------------
1 | #include <ATen/ATen.h>
2 | #include <ATen/Context.h>
3 |
4 | #define CUDA_NUM_THREADS 512
5 | #define THREADS_PER_BLOCK 64
6 |
7 | #define DIM0(TENSOR) ((TENSOR).x)
8 | #define DIM1(TENSOR) ((TENSOR).y)
9 | #define DIM2(TENSOR) ((TENSOR).z)
10 | #define DIM3(TENSOR) ((TENSOR).w)
11 |
12 | #define DIM3_INDEX(TENSOR, xx, yy, zz, ww) ((TENSOR)[((xx) * (TENSOR##_stride.x)) + ((yy) * (TENSOR##_stride.y)) + ((zz) * (TENSOR##_stride.z)) + ((ww) * (TENSOR##_stride.w))])
13 |
14 | template <typename scalar_t>
15 | __global__ void kernel_resample2d_update_output(const int n,
16 | const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
17 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
18 | scalar_t* __restrict__ output, const long4 output_size, const long4 output_stride, int kernel_size) {
19 | int index = blockIdx.x * blockDim.x + threadIdx.x;
20 |
21 | if (index >= n) {
22 | return;
23 | }
24 |
25 | scalar_t val = 0.0f;
26 |
27 | int dim_b = DIM0(output_size);
28 | int dim_c = DIM1(output_size);
29 | int dim_h = DIM2(output_size);
30 | int dim_w = DIM3(output_size);
31 | int dim_chw = dim_c * dim_h * dim_w;
32 | int dim_hw = dim_h * dim_w;
33 |
34 | int b = ( index / dim_chw ) % dim_b;
35 | int c = ( index / dim_hw ) % dim_c;
36 | int y = ( index / dim_w ) % dim_h;
37 | int x = ( index ) % dim_w;
38 |
39 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
40 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
41 |
42 |     scalar_t xf = static_cast<scalar_t>(x) + dx;
43 |     scalar_t yf = static_cast<scalar_t>(y) + dy;
44 | scalar_t alpha = xf - floor(xf); // alpha
45 | scalar_t beta = yf - floor(yf); // beta
46 |
47 | int xL = max(min( int (floor(xf)), dim_w-1), 0);
48 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
49 | int yT = max(min( int (floor(yf)), dim_h-1), 0);
50 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
51 |
52 | for (int fy = 0; fy < kernel_size; fy += 1) {
53 | for (int fx = 0; fx < kernel_size; fx += 1) {
54 |             val += static_cast<float>((1. - alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xL + fx));
55 |             val += static_cast<float>((alpha)*(1. - beta) * DIM3_INDEX(input1, b, c, yT + fy, xR + fx));
56 |             val += static_cast<float>((1. - alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xL + fx));
57 |             val += static_cast<float>((alpha)*(beta) * DIM3_INDEX(input1, b, c, yB + fy, xR + fx));
58 | }
59 | }
60 |
61 | output[index] = val;
62 |
63 | }
64 |
65 |
66 | template <typename scalar_t>
67 | __global__ void kernel_resample2d_backward_input1(
68 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
69 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
70 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
71 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) {
72 |
73 | int index = blockIdx.x * blockDim.x + threadIdx.x;
74 |
75 | if (index >= n) {
76 | return;
77 | }
78 |
79 | int dim_b = DIM0(gradOutput_size);
80 | int dim_c = DIM1(gradOutput_size);
81 | int dim_h = DIM2(gradOutput_size);
82 | int dim_w = DIM3(gradOutput_size);
83 | int dim_chw = dim_c * dim_h * dim_w;
84 | int dim_hw = dim_h * dim_w;
85 |
86 | int b = ( index / dim_chw ) % dim_b;
87 | int c = ( index / dim_hw ) % dim_c;
88 | int y = ( index / dim_w ) % dim_h;
89 | int x = ( index ) % dim_w;
90 |
91 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
92 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
93 |
94 |     scalar_t xf = static_cast<scalar_t>(x) + dx;
95 |     scalar_t yf = static_cast<scalar_t>(y) + dy;
96 | scalar_t alpha = xf - int(xf); // alpha
97 | scalar_t beta = yf - int(yf); // beta
98 |
99 | int idim_h = DIM2(input1_size);
100 | int idim_w = DIM3(input1_size);
101 |
102 | int xL = max(min( int (floor(xf)), idim_w-1), 0);
103 | int xR = max(min( int (floor(xf)+1), idim_w -1), 0);
104 | int yT = max(min( int (floor(yf)), idim_h-1), 0);
105 | int yB = max(min( int (floor(yf)+1), idim_h-1), 0);
106 |
107 | for (int fy = 0; fy < kernel_size; fy += 1) {
108 | for (int fx = 0; fx < kernel_size; fx += 1) {
109 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xL + fx)), (1-alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
110 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yT + fy), (xR + fx)), (alpha)*(1-beta) * DIM3_INDEX(gradOutput, b, c, y, x));
111 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xL + fx)), (1-alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
112 | atomicAdd(&DIM3_INDEX(gradInput, b, c, (yB + fy), (xR + fx)), (alpha)*(beta) * DIM3_INDEX(gradOutput, b, c, y, x));
113 | }
114 | }
115 |
116 | }
117 |
118 | template <typename scalar_t>
119 | __global__ void kernel_resample2d_backward_input2(
120 | const int n, const scalar_t* __restrict__ input1, const long4 input1_size, const long4 input1_stride,
121 | const scalar_t* __restrict__ input2, const long4 input2_size, const long4 input2_stride,
122 | const scalar_t* __restrict__ gradOutput, const long4 gradOutput_size, const long4 gradOutput_stride,
123 | scalar_t* __restrict__ gradInput, const long4 gradInput_size, const long4 gradInput_stride, int kernel_size) {
124 |
125 | int index = blockIdx.x * blockDim.x + threadIdx.x;
126 |
127 | if (index >= n) {
128 | return;
129 | }
130 |
131 | scalar_t output = 0.0;
132 | int kernel_rad = (kernel_size - 1)/2;
133 |
134 | int dim_b = DIM0(gradInput_size);
135 | int dim_c = DIM1(gradInput_size);
136 | int dim_h = DIM2(gradInput_size);
137 | int dim_w = DIM3(gradInput_size);
138 | int dim_chw = dim_c * dim_h * dim_w;
139 | int dim_hw = dim_h * dim_w;
140 |
141 | int b = ( index / dim_chw ) % dim_b;
142 | int c = ( index / dim_hw ) % dim_c;
143 | int y = ( index / dim_w ) % dim_h;
144 | int x = ( index ) % dim_w;
145 |
146 | int odim_c = DIM1(gradOutput_size);
147 |
148 | scalar_t dx = DIM3_INDEX(input2, b, 0, y, x);
149 | scalar_t dy = DIM3_INDEX(input2, b, 1, y, x);
150 |
151 |     scalar_t xf = static_cast<scalar_t>(x) + dx;
152 |     scalar_t yf = static_cast<scalar_t>(y) + dy;
153 |
154 | int xL = max(min( int (floor(xf)), dim_w-1), 0);
155 | int xR = max(min( int (floor(xf)+1), dim_w -1), 0);
156 | int yT = max(min( int (floor(yf)), dim_h-1), 0);
157 | int yB = max(min( int (floor(yf)+1), dim_h-1), 0);
158 |
159 | if (c % 2) {
160 | float gamma = 1 - (xf - floor(xf)); // alpha
161 | for (int i = 0; i <= 2*kernel_rad; ++i) {
162 | for (int j = 0; j <= 2*kernel_rad; ++j) {
163 | for (int ch = 0; ch < odim_c; ++ch) {
164 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
165 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
166 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
167 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
168 | }
169 | }
170 | }
171 | }
172 | else {
173 | float gamma = 1 - (yf - floor(yf)); // alpha
174 | for (int i = 0; i <= 2*kernel_rad; ++i) {
175 | for (int j = 0; j <= 2*kernel_rad; ++j) {
176 | for (int ch = 0; ch < odim_c; ++ch) {
177 | output += (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xR + i));
178 | output -= (gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yT + j), (xL + i));
179 | output += (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xR + i));
180 | output -= (1-gamma) * DIM3_INDEX(gradOutput, b, ch, y, x) * DIM3_INDEX(input1, b, ch, (yB + j), (xL + i));
181 | }
182 | }
183 | }
184 |
185 | }
186 |
187 | gradInput[index] = output;
188 |
189 | }
190 |
191 | void resample2d_kernel_forward(
192 | at::Tensor& input1,
193 | at::Tensor& input2,
194 | at::Tensor& output,
195 | int kernel_size) {
196 |
197 | int n = output.numel();
198 |
199 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
200 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
201 |
202 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
203 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
204 |
205 | const long4 output_size = make_long4(output.size(0), output.size(1), output.size(2), output.size(3));
206 | const long4 output_stride = make_long4(output.stride(0), output.stride(1), output.stride(2), output.stride(3));
207 |
208 | // TODO: when atomicAdd gets resolved, change to AT_DISPATCH_FLOATING_TYPES_AND_HALF
209 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_forward_kernel", ([&] {
210 |
211 |     kernel_resample2d_update_output<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
212 |         n,
213 |         input1.data<float>(),
214 |         input1_size,
215 |         input1_stride,
216 |         input2.data<float>(),
217 |         input2_size,
218 |         input2_stride,
219 |         output.data<float>(),
220 | output_size,
221 | output_stride,
222 | kernel_size);
223 |
224 | // }));
225 |
226 | // TODO: ATen-equivalent check
227 |
228 | // THCudaCheck(cudaGetLastError());
229 |
230 | }
231 |
232 | void resample2d_kernel_backward(
233 | at::Tensor& input1,
234 | at::Tensor& input2,
235 | at::Tensor& gradOutput,
236 | at::Tensor& gradInput1,
237 | at::Tensor& gradInput2,
238 | int kernel_size) {
239 |
240 | int n = gradOutput.numel();
241 |
242 | const long4 input1_size = make_long4(input1.size(0), input1.size(1), input1.size(2), input1.size(3));
243 | const long4 input1_stride = make_long4(input1.stride(0), input1.stride(1), input1.stride(2), input1.stride(3));
244 |
245 | const long4 input2_size = make_long4(input2.size(0), input2.size(1), input2.size(2), input2.size(3));
246 | const long4 input2_stride = make_long4(input2.stride(0), input2.stride(1), input2.stride(2), input2.stride(3));
247 |
248 | const long4 gradOutput_size = make_long4(gradOutput.size(0), gradOutput.size(1), gradOutput.size(2), gradOutput.size(3));
249 | const long4 gradOutput_stride = make_long4(gradOutput.stride(0), gradOutput.stride(1), gradOutput.stride(2), gradOutput.stride(3));
250 |
251 | const long4 gradInput1_size = make_long4(gradInput1.size(0), gradInput1.size(1), gradInput1.size(2), gradInput1.size(3));
252 | const long4 gradInput1_stride = make_long4(gradInput1.stride(0), gradInput1.stride(1), gradInput1.stride(2), gradInput1.stride(3));
253 |
254 | // AT_DISPATCH_FLOATING_TYPES(input1.type(), "resample_backward_input1", ([&] {
255 |
256 |     kernel_resample2d_backward_input1<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
257 |         n,
258 |         input1.data<float>(),
259 |         input1_size,
260 |         input1_stride,
261 |         input2.data<float>(),
262 |         input2_size,
263 |         input2_stride,
264 |         gradOutput.data<float>(),
265 |         gradOutput_size,
266 |         gradOutput_stride,
267 |         gradInput1.data<float>(),
268 | gradInput1_size,
269 | gradInput1_stride,
270 | kernel_size
271 | );
272 |
273 | // }));
274 |
275 | const long4 gradInput2_size = make_long4(gradInput2.size(0), gradInput2.size(1), gradInput2.size(2), gradInput2.size(3));
276 | const long4 gradInput2_stride = make_long4(gradInput2.stride(0), gradInput2.stride(1), gradInput2.stride(2), gradInput2.stride(3));
277 |
278 | n = gradInput2.numel();
279 |
280 | // AT_DISPATCH_FLOATING_TYPES(gradInput2.type(), "resample_backward_input2", ([&] {
281 |
282 |
283 |     kernel_resample2d_backward_input2<float><<< (n + CUDA_NUM_THREADS - 1)/CUDA_NUM_THREADS, CUDA_NUM_THREADS, 0, at::globalContext().getCurrentCUDAStream() >>>(
284 |         n,
285 |         input1.data<float>(),
286 |         input1_size,
287 |         input1_stride,
288 |         input2.data<float>(),
289 |         input2_size,
290 |         input2_stride,
291 |         gradOutput.data<float>(),
292 |         gradOutput_size,
293 |         gradOutput_stride,
294 |         gradInput2.data<float>(),
295 | gradInput2_size,
296 | gradInput2_stride,
297 | kernel_size
298 | );
299 |
300 | // }));
301 |
302 | // TODO: Use the ATen equivalent to get last error
303 |
304 | // THCudaCheck(cudaGetLastError());
305 |
306 | }
307 |
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/resample2d_kernel.cuh:
--------------------------------------------------------------------------------
1 | #pragma once
2 |
3 | #include <ATen/ATen.h>
4 |
5 | void resample2d_kernel_forward(
6 | at::Tensor& input1,
7 | at::Tensor& input2,
8 | at::Tensor& output,
9 | int kernel_size);
10 |
11 | void resample2d_kernel_backward(
12 | at::Tensor& input1,
13 | at::Tensor& input2,
14 | at::Tensor& gradOutput,
15 | at::Tensor& gradInput1,
16 | at::Tensor& gradInput2,
17 | int kernel_size);
18 |
--------------------------------------------------------------------------------
/flownet2/networks/resample2d_package/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | import os
3 | import torch
4 |
5 | from setuptools import setup
6 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
7 |
8 | cxx_args = ['-std=c++11']
9 |
10 | nvcc_args = [
11 | '-gencode', 'arch=compute_50,code=sm_50',
12 | '-gencode', 'arch=compute_52,code=sm_52',
13 | '-gencode', 'arch=compute_60,code=sm_60',
14 | '-gencode', 'arch=compute_61,code=sm_61',
15 | '-gencode', 'arch=compute_70,code=sm_70',
16 | '-gencode', 'arch=compute_70,code=compute_70'
17 | ]
18 |
19 | setup(
20 | name='resample2d_cuda',
21 | ext_modules=[
22 | CUDAExtension('resample2d_cuda', [
23 | 'resample2d_cuda.cc',
24 | 'resample2d_kernel.cu'
25 | ], extra_compile_args={'cxx': cxx_args, 'nvcc': nvcc_args})
26 | ],
27 | cmdclass={
28 | 'build_ext': BuildExtension
29 | })
30 |
--------------------------------------------------------------------------------
/flownet2/networks/submodules.py:
--------------------------------------------------------------------------------
1 | # freda (todo) :
2 |
3 | import torch.nn as nn
4 | import torch
5 | import numpy as np
6 |
7 | def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
8 | if batchNorm:
9 | return nn.Sequential(
10 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
11 | nn.BatchNorm2d(out_planes),
12 | nn.LeakyReLU(0.1,inplace=True)
13 | )
14 | else:
15 | return nn.Sequential(
16 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
17 | nn.LeakyReLU(0.1,inplace=True)
18 | )
19 |
20 | def i_conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, bias = True):
21 | if batchNorm:
22 | return nn.Sequential(
23 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
24 | nn.BatchNorm2d(out_planes),
25 | )
26 | else:
27 | return nn.Sequential(
28 | nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=bias),
29 | )
30 |
31 | def predict_flow(in_planes):
32 | return nn.Conv2d(in_planes,2,kernel_size=3,stride=1,padding=1,bias=True)
33 |
34 | def deconv(in_planes, out_planes):
35 | return nn.Sequential(
36 | nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
37 | nn.LeakyReLU(0.1,inplace=True)
38 | )
39 |
40 | class tofp16(nn.Module):
41 | def __init__(self):
42 | super(tofp16, self).__init__()
43 |
44 | def forward(self, input):
45 | return input.half()
46 |
47 |
48 | class tofp32(nn.Module):
49 | def __init__(self):
50 | super(tofp32, self).__init__()
51 |
52 | def forward(self, input):
53 | return input.float()
54 |
55 |
56 | def init_deconv_bilinear(weight):
57 | f_shape = weight.size()
58 | height, width = f_shape[-2], f_shape[-1]
59 | f = np.ceil(width/2.0)
60 | c = (2 * f - 1 - f % 2) / (2.0 * f)
61 | bilinear = np.zeros([height, width])
62 | for x in range(width):
63 | for y in range(height):
64 | value = (1 - abs(x / f - c)) * (1 - abs(y / f - c))
65 | bilinear[y, x] = value  # index as [row, col] so non-square kernels stay in bounds
66 | weight.data.fill_(0.)
67 | for i in range(f_shape[0]):
68 | for j in range(f_shape[1]):
69 | weight.data[i,j,:,:] = torch.from_numpy(bilinear)
70 |
71 |
72 | def save_grad(grads, name):
73 | def hook(grad):
74 | grads[name] = grad
75 | return hook
76 |
77 | '''
78 | def save_grad(grads, name):
79 | def hook(grad):
80 | grads[name] = grad
81 | return hook
82 | import torch
83 | from channelnorm_package.modules.channelnorm import ChannelNorm
84 | model = ChannelNorm().cuda()
85 | grads = {}
86 | a = 100*torch.autograd.Variable(torch.randn((1,3,5,5)).cuda(), requires_grad=True)
87 | a.register_hook(save_grad(grads, 'a'))
88 | b = model(a)
89 | y = torch.mean(b)
90 | y.backward()
91 |
92 | '''
93 |
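A minimal usage sketch for the block builders above (not part of the repository; the import path assumes the repo root is on PYTHONPATH and that torch is installed):

import torch
from flownet2.networks.submodules import conv, deconv, predict_flow

x = torch.randn(1, 6, 64, 64)                                      # e.g. two RGB frames stacked channel-wise
enc = conv(batchNorm=False, in_planes=6, out_planes=64, stride=2)  # 3x3 conv + LeakyReLU, 64x64 -> 32x32
dec = deconv(64, 32)                                               # 4x4 transposed conv, 32x32 -> 64x64
flow_head = predict_flow(32)                                       # 3x3 conv to a 2-channel flow map

y = flow_head(dec(enc(x)))
print(y.shape)                                                     # torch.Size([1, 2, 64, 64])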
--------------------------------------------------------------------------------
/flownet2/run-caffe2pytorch.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | FN2PYTORCH=${1:-/}
4 |
5 | # install custom layers
6 | sudo nvidia-docker build -t $USER/pytorch:CUDA8-py27 .
7 | sudo nvidia-docker run --rm -ti --volume=${FN2PYTORCH}:/flownet2-pytorch:rw --workdir=/flownet2-pytorch $USER/pytorch:CUDA8-py27 /bin/bash -c "./install.sh"
8 |
9 | # convert FlowNet2-C, CS, CSS, CSS-ft-sd, SD, S and 2 to PyTorch
10 | sudo nvidia-docker run -ti --volume=${FN2PYTORCH}:/fn2pytorch:rw flownet2:latest /bin/bash -c "source /flownet2/flownet2/set-env.sh && cd /flownet2/flownet2/models && \
11 | python /fn2pytorch/convert.py ./FlowNet2-C/FlowNet2-C_weights.caffemodel ./FlowNet2-C/FlowNet2-C_deploy.prototxt.template /fn2pytorch &&
12 | python /fn2pytorch/convert.py ./FlowNet2-CS/FlowNet2-CS_weights.caffemodel ./FlowNet2-CS/FlowNet2-CS_deploy.prototxt.template /fn2pytorch && \
13 | python /fn2pytorch/convert.py ./FlowNet2-CSS/FlowNet2-CSS_weights.caffemodel.h5 ./FlowNet2-CSS/FlowNet2-CSS_deploy.prototxt.template /fn2pytorch && \
14 | python /fn2pytorch/convert.py ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_weights.caffemodel.h5 ./FlowNet2-CSS-ft-sd/FlowNet2-CSS-ft-sd_deploy.prototxt.template /fn2pytorch && \
15 | python /fn2pytorch/convert.py ./FlowNet2-SD/FlowNet2-SD_weights.caffemodel.h5 ./FlowNet2-SD/FlowNet2-SD_deploy.prototxt.template /fn2pytorch && \
16 | python /fn2pytorch/convert.py ./FlowNet2-S/FlowNet2-S_weights.caffemodel.h5 ./FlowNet2-S/FlowNet2-S_deploy.prototxt.template /fn2pytorch && \
17 | python /fn2pytorch/convert.py ./FlowNet2/FlowNet2_weights.caffemodel.h5 ./FlowNet2/FlowNet2_deploy.prototxt.template /fn2pytorch"
--------------------------------------------------------------------------------
/flownet2/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/flownet2/utils/__init__.py
--------------------------------------------------------------------------------
/flownet2/utils/flow_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | TAG_CHAR = np.array([202021.25], np.float32)
4 |
5 | def readFlow(fn):
6 | """ Read .flo file in Middlebury format"""
7 | # Code adapted from:
8 | # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
9 |
10 | # WARNING: this will work on little-endian architectures (eg Intel x86) only!
11 | # print 'fn = %s'%(fn)
12 | with open(fn, 'rb') as f:
13 | magic = np.fromfile(f, np.float32, count=1)
14 | if 202021.25 != magic:
15 | print('Magic number incorrect. Invalid .flo file')
16 | return None
17 | else:
18 | w = np.fromfile(f, np.int32, count=1)
19 | h = np.fromfile(f, np.int32, count=1)
20 | # print 'Reading %d x %d flo file\n' % (w, h)
21 | data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
22 | # Reshape data into 3D array (columns, rows, bands)
23 | # The reshape here is for visualization, the original code is (w,h,2)
24 | return np.resize(data, (int(h), int(w), 2))
25 |
26 | def writeFlow(filename,uv,v=None):
27 | """ Write optical flow to file.
28 |
29 | If v is None, uv is assumed to contain both u and v channels,
30 | stacked in depth.
31 | Original code by Deqing Sun, adapted from Daniel Scharstein.
32 | """
33 | nBands = 2
34 |
35 | if v is None:
36 | assert(uv.ndim == 3)
37 | assert(uv.shape[2] == 2)
38 | u = uv[:,:,0]
39 | v = uv[:,:,1]
40 | else:
41 | u = uv
42 |
43 | assert(u.shape == v.shape)
44 | height,width = u.shape
45 | f = open(filename,'wb')
46 | # write the header
47 | f.write(TAG_CHAR)
48 | np.array(width).astype(np.int32).tofile(f)
49 | np.array(height).astype(np.int32).tofile(f)
50 | # arrange into matrix form
51 | tmp = np.zeros((height, width*nBands))
52 | tmp[:,np.arange(width)*2] = u
53 | tmp[:,np.arange(width)*2 + 1] = v
54 | tmp.astype(np.float32).tofile(f)
55 | f.close()
56 |
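A small round-trip sketch for the .flo helpers above (not part of the repository; as noted in readFlow, it assumes a little-endian machine):

import numpy as np
from flownet2.utils.flow_utils import readFlow, writeFlow

flow = np.random.randn(120, 160, 2).astype(np.float32)  # (height, width, 2) displacement field
writeFlow('/tmp/example.flo', flow)                      # header tag, width, height, then interleaved u/v
restored = readFlow('/tmp/example.flo')                  # (height, width, 2) float32 array
assert restored.shape == (120, 160, 2)
assert np.allclose(restored, flow)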
--------------------------------------------------------------------------------
/flownet2/utils/frame_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from os.path import *
3 | from scipy.misc import imread
4 | from . import flow_utils
5 |
6 | def read_gen(file_name):
7 | ext = splitext(file_name)[-1]
8 | if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
9 | im = imread(file_name)
10 | if im.shape[2] > 3:
11 | return im[:,:,:3]
12 | else:
13 | return im
14 | elif ext == '.bin' or ext == '.raw':
15 | return np.load(file_name)
16 | elif ext == '.flo':
17 | return flow_utils.readFlow(file_name).astype(np.float32)
18 | return []
19 |
--------------------------------------------------------------------------------
/flownet2/utils/param_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 |
5 | def parse_flownetc(modules, weights, biases):
6 | keys = [
7 | 'conv1',
8 | 'conv2',
9 | 'conv3',
10 | 'conv_redir',
11 | 'conv3_1',
12 | 'conv4',
13 | 'conv4_1',
14 | 'conv5',
15 | 'conv5_1',
16 | 'conv6',
17 | 'conv6_1',
18 |
19 | 'deconv5',
20 | 'deconv4',
21 | 'deconv3',
22 | 'deconv2',
23 |
24 | 'Convolution1',
25 | 'Convolution2',
26 | 'Convolution3',
27 | 'Convolution4',
28 | 'Convolution5',
29 |
30 | 'upsample_flow6to5',
31 | 'upsample_flow5to4',
32 | 'upsample_flow4to3',
33 | 'upsample_flow3to2',
34 |
35 | ]
36 | i = 0
37 | for m in modules:
38 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
39 | weight = weights[keys[i]].copy()
40 | bias = biases[keys[i]].copy()
41 | if keys[i] == 'conv1':
42 | m.weight.data[:,:,:,:] = torch.from_numpy(np.flip(weight, axis=1).copy())
43 | m.bias.data[:] = torch.from_numpy(bias)
44 | else:
45 | m.weight.data[:,:,:,:] = torch.from_numpy(weight)
46 | m.bias.data[:] = torch.from_numpy(bias)
47 |
48 | i = i + 1
49 | return
50 |
51 | def parse_flownets(modules, weights, biases, param_prefix='net2_'):
52 | keys = [
53 | 'conv1',
54 | 'conv2',
55 | 'conv3',
56 | 'conv3_1',
57 | 'conv4',
58 | 'conv4_1',
59 | 'conv5',
60 | 'conv5_1',
61 | 'conv6',
62 | 'conv6_1',
63 |
64 | 'deconv5',
65 | 'deconv4',
66 | 'deconv3',
67 | 'deconv2',
68 |
69 | 'predict_conv6',
70 | 'predict_conv5',
71 | 'predict_conv4',
72 | 'predict_conv3',
73 | 'predict_conv2',
74 |
75 | 'upsample_flow6to5',
76 | 'upsample_flow5to4',
77 | 'upsample_flow4to3',
78 | 'upsample_flow3to2',
79 | ]
80 | for i, k in enumerate(keys):
81 | if 'upsample' in k:
82 | keys[i] = param_prefix + param_prefix + k
83 | else:
84 | keys[i] = param_prefix + k
85 | i = 0
86 | for m in modules:
87 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
88 | weight = weights[keys[i]].copy()
89 | bias = biases[keys[i]].copy()
90 | if keys[i] == param_prefix+'conv1':
91 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy())
92 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy())
93 | m.weight.data[:,6:9,:,:] = torch.from_numpy(np.flip(weight[:,6:9,:,:], axis=1).copy())
94 | m.weight.data[:,9::,:,:] = torch.from_numpy(weight[:,9:,:,:].copy())
95 | if m.bias is not None:
96 | m.bias.data[:] = torch.from_numpy(bias)
97 | else:
98 | m.weight.data[:,:,:,:] = torch.from_numpy(weight)
99 | if m.bias is not None:
100 | m.bias.data[:] = torch.from_numpy(bias)
101 | i = i + 1
102 | return
103 |
104 | def parse_flownetsonly(modules, weights, biases, param_prefix=''):
105 | keys = [
106 | 'conv1',
107 | 'conv2',
108 | 'conv3',
109 | 'conv3_1',
110 | 'conv4',
111 | 'conv4_1',
112 | 'conv5',
113 | 'conv5_1',
114 | 'conv6',
115 | 'conv6_1',
116 |
117 | 'deconv5',
118 | 'deconv4',
119 | 'deconv3',
120 | 'deconv2',
121 |
122 | 'Convolution1',
123 | 'Convolution2',
124 | 'Convolution3',
125 | 'Convolution4',
126 | 'Convolution5',
127 |
128 | 'upsample_flow6to5',
129 | 'upsample_flow5to4',
130 | 'upsample_flow4to3',
131 | 'upsample_flow3to2',
132 | ]
133 | for i, k in enumerate(keys):
134 | if 'upsample' in k:
135 | keys[i] = param_prefix + param_prefix + k
136 | else:
137 | keys[i] = param_prefix + k
138 | i = 0
139 | for m in modules:
140 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
141 | weight = weights[keys[i]].copy()
142 | bias = biases[keys[i]].copy()
143 | if keys[i] == param_prefix+'conv1':
144 | # print ("%s :"%(keys[i]), m.weight.size(), m.bias.size(), tf_w[keys[i]].shape[::-1])
145 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy())
146 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy())
147 | if m.bias is not None:
148 | m.bias.data[:] = torch.from_numpy(bias)
149 | else:
150 | m.weight.data[:,:,:,:] = torch.from_numpy(weight)
151 | if m.bias is not None:
152 | m.bias.data[:] = torch.from_numpy(bias)
153 | i = i + 1
154 | return
155 |
156 | def parse_flownetsd(modules, weights, biases, param_prefix='netsd_'):
157 | keys = [
158 | 'conv0',
159 | 'conv1',
160 | 'conv1_1',
161 | 'conv2',
162 | 'conv2_1',
163 | 'conv3',
164 | 'conv3_1',
165 | 'conv4',
166 | 'conv4_1',
167 | 'conv5',
168 | 'conv5_1',
169 | 'conv6',
170 | 'conv6_1',
171 |
172 | 'deconv5',
173 | 'deconv4',
174 | 'deconv3',
175 | 'deconv2',
176 |
177 | 'interconv5',
178 | 'interconv4',
179 | 'interconv3',
180 | 'interconv2',
181 |
182 | 'Convolution1',
183 | 'Convolution2',
184 | 'Convolution3',
185 | 'Convolution4',
186 | 'Convolution5',
187 |
188 | 'upsample_flow6to5',
189 | 'upsample_flow5to4',
190 | 'upsample_flow4to3',
191 | 'upsample_flow3to2',
192 | ]
193 | for i, k in enumerate(keys):
194 | keys[i] = param_prefix + k
195 |
196 | i = 0
197 | for m in modules:
198 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
199 | weight = weights[keys[i]].copy()
200 | bias = biases[keys[i]].copy()
201 | if keys[i] == param_prefix+'conv0':
202 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy())
203 | m.weight.data[:,3:6,:,:] = torch.from_numpy(np.flip(weight[:,3:6,:,:], axis=1).copy())
204 | if m.bias is not None:
205 | m.bias.data[:] = torch.from_numpy(bias)
206 | else:
207 | m.weight.data[:,:,:,:] = torch.from_numpy(weight)
208 | if m.bias is not None:
209 | m.bias.data[:] = torch.from_numpy(bias)
210 | i = i + 1
211 |
212 | return
213 |
214 | def parse_flownetfusion(modules, weights, biases, param_prefix='fuse_'):
215 | keys = [
216 | 'conv0',
217 | 'conv1',
218 | 'conv1_1',
219 | 'conv2',
220 | 'conv2_1',
221 |
222 | 'deconv1',
223 | 'deconv0',
224 |
225 | 'interconv1',
226 | 'interconv0',
227 |
228 | '_Convolution5',
229 | '_Convolution6',
230 | '_Convolution7',
231 |
232 | 'upsample_flow2to1',
233 | 'upsample_flow1to0',
234 | ]
235 | for i, k in enumerate(keys):
236 | keys[i] = param_prefix + k
237 |
238 | i = 0
239 | for m in modules:
240 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
241 | weight = weights[keys[i]].copy()
242 | bias = biases[keys[i]].copy()
243 | if keys[i] == param_prefix+'conv0':
244 | m.weight.data[:,0:3,:,:] = torch.from_numpy(np.flip(weight[:,0:3,:,:], axis=1).copy())
245 | m.weight.data[:,3::,:,:] = torch.from_numpy(weight[:,3:,:,:].copy())
246 | if m.bias is not None:
247 | m.bias.data[:] = torch.from_numpy(bias)
248 | else:
249 | m.weight.data[:,:,:,:] = torch.from_numpy(weight)
250 | if m.bias is not None:
251 | m.bias.data[:] = torch.from_numpy(bias)
252 | i = i + 1
253 |
254 | return
255 |
--------------------------------------------------------------------------------
/flownet2/utils/tools.py:
--------------------------------------------------------------------------------
1 | # freda (todo) :
2 |
3 | import os, time, sys, math
4 | import subprocess, shutil
5 | from os.path import *
6 | import numpy as np
7 | from inspect import isclass
8 | from pytz import timezone
9 | from datetime import datetime
10 | import inspect
11 | import torch
12 |
13 | def datestr():
14 | pacific = timezone('US/Pacific')
15 | now = datetime.now(pacific)
16 | return '{}{:02}{:02}_{:02}{:02}'.format(now.year, now.month, now.day, now.hour, now.minute)
17 |
18 | def module_to_dict(module, exclude=[]):
19 | return dict([(x, getattr(module, x)) for x in dir(module)
20 | if isclass(getattr(module, x))
21 | and x not in exclude
22 | and getattr(module, x) not in exclude])
23 |
24 | class TimerBlock:
25 | def __init__(self, title):
26 | print(("{}".format(title)))
27 |
28 | def __enter__(self):
29 | self.start = time.perf_counter()  # time.clock() was removed in Python 3.8
30 | return self
31 |
32 | def __exit__(self, exc_type, exc_value, traceback):
33 | self.end = time.perf_counter()
34 | self.interval = self.end - self.start
35 |
36 | if exc_type is not None:
37 | self.log("Operation failed\n")
38 | else:
39 | self.log("Operation finished\n")
40 |
41 |
42 | def log(self, string):
43 | duration = time.perf_counter() - self.start
44 | units = 's'
45 | if duration > 60:
46 | duration = duration / 60.
47 | units = 'm'
48 | print((" [{:.3f}{}] {}".format(duration, units, string)))
49 |
50 | def log2file(self, fid, string):
51 | fid = open(fid, 'a')
52 | fid.write("%s\n"%(string))
53 | fid.close()
54 |
55 | def add_arguments_for_module(parser, module, argument_for_class, default, skip_params=[], parameter_defaults={}):
56 | argument_group = parser.add_argument_group(argument_for_class.capitalize())
57 |
58 | module_dict = module_to_dict(module)
59 | argument_group.add_argument('--' + argument_for_class, type=str, default=default, choices=list(module_dict.keys()))
60 |
61 | args, unknown_args = parser.parse_known_args()
62 | class_obj = module_dict[vars(args)[argument_for_class]]
63 |
64 | argspec = inspect.getargspec(class_obj.__init__)
65 |
66 | defaults = argspec.defaults[::-1] if argspec.defaults else None
67 |
68 | args = argspec.args[::-1]
69 | for i, arg in enumerate(args):
70 | cmd_arg = '{}_{}'.format(argument_for_class, arg)
71 | if arg not in skip_params + ['self', 'args']:
72 | if arg in list(parameter_defaults.keys()):
73 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(parameter_defaults[arg]), default=parameter_defaults[arg])
74 | elif (defaults is not None and i < len(defaults)):
75 | argument_group.add_argument('--{}'.format(cmd_arg), type=type(defaults[i]), default=defaults[i])
76 | else:
77 | print(("[Warning]: non-default argument '{}' detected on class '{}'. This argument cannot be modified via the command line"
78 | .format(arg, module.__class__.__name__)))
79 | # We don't have a good way of dealing with inferring the type of the argument
80 | # TODO: try creating a custom action and using ast's infer type?
81 | # else:
82 | # argument_group.add_argument('--{}'.format(cmd_arg), required=True)
83 |
84 | def kwargs_from_args(args, argument_for_class):
85 | argument_for_class = argument_for_class + '_'
86 | return {key[len(argument_for_class):]: value for key, value in list(vars(args).items()) if argument_for_class in key and key != argument_for_class + 'class'}
87 |
88 | def format_dictionary_of_losses(labels, values):
89 | try:
90 | string = ', '.join([('{}: {:' + ('.3f' if value >= 0.001 else '.1e') +'}').format(name, value) for name, value in zip(labels, values)])
91 | except (TypeError, ValueError) as e:
92 | print((list(zip(labels, values))))
93 | string = '[Log Error] ' + str(e)
94 |
95 | return string
96 |
97 |
98 | class IteratorTimer():
99 | def __init__(self, iterable):
100 | self.iterable = iterable
101 | self.iterator = self.iterable.__iter__()
102 |
103 | def __iter__(self):
104 | return self
105 |
106 | def __len__(self):
107 | return len(self.iterable)
108 |
109 | def __next__(self):
110 | start = time.time()
111 | n = next(self.iterator)
112 | self.last_duration = (time.time() - start)
113 | return n
114 |
115 | next = __next__
116 |
117 | def gpumemusage():
118 | gpu_mem = subprocess.check_output("nvidia-smi | grep MiB | cut -f 3 -d '|'", shell=True).decode().replace(' ', '').replace('\n', '').replace('i', '')
119 | all_stat = [float(a) for a in gpu_mem.replace('/','').split('MB')[:-1]]
120 |
121 | gpu_mem = ''
122 | for i in range(len(all_stat)//2):
123 | curr, tot = all_stat[2*i], all_stat[2*i+1]
124 | util = "%1.2f"%(100*curr/tot)+'%'
125 | cmem = str(int(math.ceil(curr/1024.)))+'GB'
126 | gmem = str(int(math.ceil(tot/1024.)))+'GB'
127 | gpu_mem += util + '--' + join(cmem, gmem) + ' '
128 | return gpu_mem
129 |
130 |
131 | def update_hyperparameter_schedule(args, epoch, global_iteration, optimizer):
132 | if args.schedule_lr_frequency > 0:
133 | for param_group in optimizer.param_groups:
134 | if (global_iteration + 1) % args.schedule_lr_frequency == 0:
135 | param_group['lr'] /= float(args.schedule_lr_fraction)
136 | param_group['lr'] = float(np.maximum(param_group['lr'], 0.000001))
137 |
138 | def save_checkpoint(state, is_best, path, prefix, filename='checkpoint.pth.tar'):
139 | prefix_save = os.path.join(path, prefix)
140 | name = prefix_save + '_' + filename
141 | torch.save(state, name)
142 | if is_best:
143 | shutil.copyfile(name, prefix_save + '_model_best.pth.tar')
144 |
145 |
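A short sketch of the TimerBlock context manager above (not part of the repository; it assumes pytz and torch are installed, since the module imports both at the top):

import time
from flownet2.utils.tools import TimerBlock

with TimerBlock("Initializing") as block:
    time.sleep(0.2)               # stand-in for real work
    block.log("dataset ready")    # prints the elapsed time next to each message
    block.log("model ready")      # __exit__ then logs "Operation finished"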
--------------------------------------------------------------------------------
/generate_pseudo_labels.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from __future__ import absolute_import, division, print_function
8 | import os
9 | import sys
10 | sys.path.append('flownet2')
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils import data
15 | from torchvision.transforms import functional as TF
16 |
17 | import argparse
18 | from tqdm import tqdm
19 |
20 | from libs.datasets import get_transforms, get_datasets
21 | from libs.networks.pseudo_label_generator import FGPLG
22 | from libs.utils.pyt_utils import load_model
23 |
24 | parser = argparse.ArgumentParser()
25 |
26 | # Dataloading-related settings
27 | parser.add_argument('--data', type=str, default='data/datasets/',
28 | help='path to datasets folder')
29 | parser.add_argument('--checkpoint', default='models/pseudo_label_generator_5.pth',
30 | help='path to the pretrained checkpoint')
31 | parser.add_argument('--dataset-config', default='config/datasets.yaml',
32 | help='dataset config file')
33 | parser.add_argument('--pseudo-label-folder', default='data/pseudo-labels',
34 | help='location to save generated pseudo-labels')
35 | parser.add_argument("--label_interval", default=5, type=int,
36 | help="the interval of ground truth labels")
37 | parser.add_argument("--frame_between_label_num", default=1, type=int,
38 | help="the number of generated pseudo-labels in each interval")
39 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N',
40 | help='number of data loading workers.')
41 |
42 | # Model settings
43 | parser.add_argument('--size', default=448, type=int,
44 | help='image size')
45 | parser.add_argument('--os', default=16, type=int,
46 | help='output stride.')
47 |
48 | # FlowNet setting
49 | parser.add_argument("--fp16", action="store_true",
50 | help="Run model in pseudo-fp16 mode (fp16 storage fp32 math).")
51 | parser.add_argument("--rgb_max", type=float, default=1.)
52 |
53 | args = parser.parse_args()
54 |
55 | cuda = torch.cuda.is_available()
56 | device = torch.device("cuda" if cuda else "cpu")
57 |
58 | if cuda:
59 | torch.backends.cudnn.benchmark = True
60 | current_device = torch.cuda.current_device()
61 | print("Running on", torch.cuda.get_device_name(current_device))
62 | else:
63 | print("Running on CPU")
64 |
65 | data_transforms = get_transforms(
66 | input_size=(args.size, args.size),
67 | image_mode=False
68 | )
69 | dataset = get_datasets(
70 | name_list=["DAVIS2016", "FBMS", "VOS"],
71 | split_list=["train", "train", "train"],
72 | config_path=args.dataset_config,
73 | root=args.data,
74 | training=True, # provide labels
75 | transforms=data_transforms['test'],
76 | read_clip=True,
77 | random_reverse_clip=False,
78 | label_interval=args.label_interval,
79 | frame_between_label_num=args.frame_between_label_num,
80 | clip_len=args.frame_between_label_num+2
81 | )
82 |
83 | dataloader = data.DataLoader(
84 | dataset=dataset,
85 | batch_size=1, # only support 1 video clip
86 | num_workers=args.num_workers,
87 | shuffle=False,
88 | drop_last=True
89 | )
90 |
91 | pseudo_label_generator = FGPLG(args=args, output_stride=args.os)
92 |
93 | # load pretrained models
94 | if os.path.exists(args.checkpoint):
95 | print('Loading state dict from: {0}'.format(args.checkpoint))
96 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=True)
97 | else:
98 | raise ValueError("Cannot find model file at {}".format(args.checkpoint))
99 |
100 | pseudo_label_generator.to(device)
101 |
102 | pseudo_label_folder = os.path.join(args.pseudo_label_folder, "{}_{}".format(args.frame_between_label_num, args.label_interval))
103 | if not os.path.exists(pseudo_label_folder):
104 | os.makedirs(pseudo_label_folder)
105 |
106 | def generate_pseudo_label():
107 | pseudo_label_generator.eval()
108 |
109 | for data in tqdm(dataloader):
110 | images = []
111 | labels = []
112 | for frame in data:
113 | images.append(frame['image'].to(device))
114 | labels.append(frame['label'].to(device) if 'label' in frame else None)
115 | with torch.no_grad():
116 | for i in range(1, args.frame_between_label_num+1):
117 | pseudo_label = pseudo_label_generator.generate_pseudo_label(images[i], images[0], images[-1], labels[0], labels[-1])
118 | labels[i] = torch.sigmoid(pseudo_label).detach()
119 | # save pseudo-labels
120 | for i, label_ in enumerate(labels):
121 | for j, label in enumerate(label_.detach().cpu()):
122 | dataset = data[i]['dataset'][j]
123 | image_id = data[i]['image_id'][j]
124 | pseudo_label_path = os.path.join(pseudo_label_folder, "{}/{}.png".format(dataset, image_id))
125 |
126 | height = data[i]['height'].item()
127 | width = data[i]['width'].item()
128 | result = TF.to_pil_image(label)
129 | result = result.resize((height, width))
130 | dirname = os.path.dirname(pseudo_label_path)
131 | if not os.path.exists(dirname):
132 | os.makedirs(dirname)
133 | result.save(pseudo_label_path)
134 |
135 | if __name__ == "__main__":
136 | print("Generating pseudo-labels at {}".format(args.pseudo_label_folder))
137 | print("label interval: {}".format(args.label_interval))
138 | print("frame between label num: {}".format(args.frame_between_label_num))
139 | generate_pseudo_label()
140 |
--------------------------------------------------------------------------------
/inference.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from __future__ import absolute_import, division, print_function
8 | import os
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils import data
13 | from torchvision.transforms import functional as TF
14 |
15 | import argparse
16 | from tqdm import tqdm
17 |
18 | from libs.datasets import get_transforms, get_datasets
19 | from libs.networks import VideoModel
20 | from libs.utils.pyt_utils import load_model
21 |
22 | parser = argparse.ArgumentParser()
23 |
24 | # Dataloading-related settings
25 | parser.add_argument('--data', type=str, default='data/datasets/',
26 | help='path to datasets folder')
27 | parser.add_argument('--dataset', default='VOS', type=str,
28 | help='dataset name for inference')
29 | parser.add_argument('--split', default='test', type=str,
30 | help='dataset split for inference')
31 | parser.add_argument('--checkpoint', default='models/video_best_model.pth',
32 | help='path to the pretrained checkpoint')
33 | parser.add_argument('--dataset-config', default='config/datasets.yaml',
34 | help='dataset config file')
35 | parser.add_argument('--results-folder', default='data/results/',
36 | help='location to save predicted saliency maps')
37 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N',
38 | help='number of data loading workers.')
39 |
40 | # Model settings
41 | parser.add_argument('--size', default=448, type=int,
42 | help='image size')
43 | parser.add_argument('--os', default=16, type=int,
44 | help='output stride.')
45 | parser.add_argument("--clip_len", type=int, default=4,
46 | help="the number of frames in a video clip.")
47 |
48 | args = parser.parse_args()
49 |
50 | cuda = torch.cuda.is_available()
51 | device = torch.device("cuda" if cuda else "cpu")
52 |
53 | if cuda:
54 | torch.backends.cudnn.benchmark = True
55 | current_device = torch.cuda.current_device()
56 | print("Running on", torch.cuda.get_device_name(current_device))
57 | else:
58 | print("Running on CPU")
59 |
60 | data_transforms = get_transforms(
61 | input_size=(args.size, args.size),
62 | image_mode=False
63 | )
64 | dataset = get_datasets(
65 | name_list=args.dataset,
66 | split_list=args.split,
67 | config_path=args.dataset_config,
68 | root=args.data,
69 | training=False,
70 | transforms=data_transforms['test'],
71 | read_clip=True,
72 | random_reverse_clip=False,
73 | label_interval=1,
74 | frame_between_label_num=0,
75 | clip_len=args.clip_len
76 | )
77 |
78 | dataloader = data.DataLoader(
79 | dataset=dataset,
80 | batch_size=1, # only support 1 video clip
81 | num_workers=args.num_workers,
82 | shuffle=False
83 | )
84 |
85 | model = VideoModel(output_stride=args.os)
86 |
87 | # load pretrained models
88 | if os.path.exists(args.checkpoint):
89 | print('Loading state dict from: {0}'.format(args.checkpoint))
90 | model = load_model(model=model, model_file=args.checkpoint, is_restore=True)
91 | else:
92 | raise ValueError("Cannot find model file at {}".format(args.checkpoint))
93 |
94 | model.to(device)
95 |
96 | def inference():
97 | model.eval()
98 | print("Begin inference on {} {}.".format(args.dataset, args.split))
99 | for data in tqdm(dataloader):
100 | images = [frame['image'].to(device) for frame in data]
101 | with torch.no_grad():
102 | preds = model(images)
103 | preds = [torch.sigmoid(pred) for pred in preds]
104 | # save predicted saliency maps
105 | for i, pred_ in enumerate(preds):
106 | for j, pred in enumerate(pred_.detach().cpu()):
107 | dataset = data[i]['dataset'][j]
108 | image_id = data[i]['image_id'][j]
109 | height = data[i]['height'].item()
110 | width = data[i]['width'].item()
111 | result_path = os.path.join(args.results_folder, "{}/{}.png".format(dataset, image_id))
112 |
113 | result = TF.to_pil_image(pred)
114 | result = result.resize((height, width))
115 | dirname = os.path.dirname(result_path)
116 | if not os.path.exists(dirname):
117 | os.makedirs(dirname)
118 | result.save(result_path)
119 |
120 | if __name__ == "__main__":
121 | inference()
122 |
--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/libs/__init__.py
--------------------------------------------------------------------------------
/libs/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from .transforms import get_transforms
2 | from .video_datasets import get_datasets
--------------------------------------------------------------------------------
/libs/datasets/transforms.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | import random
8 |
9 | from PIL import Image
10 | import torch
11 | from torchvision import transforms
12 | from torchvision.transforms import functional as F
13 | from torch.utils import data
14 | import numpy as np
15 |
16 | def get_transforms(image_mode, input_size, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
17 | data_transforms = {
18 | 'train': transforms.Compose([
19 | ColorJitter(brightness=.3, contrast=.3, saturation=.3, hue=.3, image_mode=image_mode),
20 | RandomResizedCrop(input_size, image_mode),
21 | RandomFlip(image_mode),
22 | ToTensor(),
23 | Normalize(mean=mean,
24 | std=std)
25 | ]) if image_mode else transforms.Compose([
26 | Resize(input_size),
27 | ToTensor(),
28 | Normalize(mean=mean,
29 | std=std)
30 | ]),
31 | 'val': transforms.Compose([
32 | Resize(input_size),
33 | ToTensor(),
34 | Normalize(mean=mean,
35 | std=std)
36 | ]),
37 | 'test': transforms.Compose([
38 | Resize(input_size),
39 | ToTensor(),
40 | Normalize(mean=mean,
41 | std=std)
42 | ]),
43 | }
44 | return data_transforms
45 |
46 | class ColorJitter(transforms.ColorJitter):
47 | def __init__(self, image_mode, **kwargs):
48 | super(ColorJitter, self).__init__(**kwargs)
49 | self.transform = None
50 | self.image_mode = image_mode
51 | def __call__(self, sample):
52 | if self.transform is None or self.image_mode:
53 | self.transform = self.get_params(self.brightness, self.contrast,
54 | self.saturation, self.hue)
55 | sample['image'] = self.transform(sample['image'])
56 | return sample
57 |
58 | class RandomResizedCrop(object):
59 | """
60 | A crop of random size (default: of 0.08 to 1.0) of the original size and a random
61 | aspect ratio (default: of 3/4 to 4/3) of the original aspect ratio is made. This crop
62 | is finally resized to given size.
63 | This is popularly used to train the Inception networks.
64 | Args:
65 | size: expected output size of each edge
66 | scale: range of size of the origin size cropped
67 | ratio: range of aspect ratio of the origin aspect ratio cropped
68 | """
69 |
70 | def __init__(self, size, image_mode, scale=(0.7, 1.0), ratio=(3. / 4., 4. / 3.)):
71 | self.size = size
72 | self.scale = scale
73 | self.ratio = ratio
74 | self.i, self.j, self.h, self.w = None, None, None, None
75 | self.image_mode = image_mode
76 | def __call__(self, sample):
77 | image, label = sample['image'], sample['label']
78 | if self.i is None or self.image_mode:
79 | self.i, self.j, self.h, self.w = transforms.RandomResizedCrop.get_params(image, self.scale, self.ratio)
80 | image = F.resized_crop(image, self.i, self.j, self.h, self.w, self.size, Image.BILINEAR)
81 | label = F.resized_crop(label, self.i, self.j, self.h, self.w, self.size, Image.BILINEAR)
82 | sample['image'], sample['label'] = image, label
83 | return sample
84 |
85 | class RandomFlip(object):
86 | """Horizontally flip the given PIL Image randomly with a given probability.
87 | """
88 | def __init__(self, image_mode):
89 | self.rand_flip_index = None
90 | self.image_mode = image_mode
91 | def __call__(self, sample):
92 | image, label = sample['image'], sample['label']
93 | if self.rand_flip_index is None or self.image_mode:
94 | self.rand_flip_index = random.randint(-1,2)
95 | # 0: horizontal flip, 1: vertical flip, -1: horizontal and vertical flip
96 | if self.rand_flip_index == 0:
97 | image = F.hflip(image)
98 | label = F.hflip(label)
99 | elif self.rand_flip_index == 1:
100 | image = F.vflip(image)
101 | label = F.vflip(label)
102 | elif self.rand_flip_index == 2:
103 | image = F.vflip(F.hflip(image))
104 | label = F.vflip(F.hflip(label))
105 | sample['image'], sample['label'] = image, label
106 | return sample
107 |
108 | class Resize(object):
109 | """ Resize PIL image use both for training and inference"""
110 | def __init__(self, size):
111 | self.size = size
112 |
113 | def __call__(self, sample):
114 | image, label = sample['image'], sample['label']
115 | image = F.resize(image, self.size, Image.BILINEAR)
116 | if label is not None:
117 | label = F.resize(label, self.size, Image.BILINEAR)
118 | sample['image'], sample['label'] = image, label
119 | return sample
120 |
121 | class ToTensor(object):
122 | """Convert ndarrays in sample to Tensors."""
123 |
124 | def __call__(self, sample):
125 | image, label = sample['image'], sample['label']
126 |
127 | # swap color axis because
128 | # numpy image: H x W x C
129 | # torch image: C X H X W
130 | # Image range from [0~255] to [0.0 ~ 1.0]
131 | image = F.to_tensor(image)
132 | if label is not None:
133 | label = torch.from_numpy(np.array(label)).unsqueeze(0).float()
134 | return {'image': image, 'label': label}
135 |
136 | class Normalize(object):
137 | """ Normalize a tensor image with mean and standard deviation.
138 | args: tensor (Tensor) – Tensor image of size (C, H, W) to be normalized.
139 | Returns: Normalized Tensor image.
140 | """
141 | # default caffe mode
142 | def __init__(self, mean, std):
143 | self.mean = mean
144 | self.std = std
145 |
146 | def __call__(self, sample):
147 | image, label = sample['image'], sample['label']
148 | image = F.normalize(image, self.mean, self.std)
149 | return {'image': image, 'label': label}
150 |
151 |
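A minimal sketch of how the transform dictionary above is consumed; a sample is a dict holding a PIL 'image' and 'label'. This is not part of the repository, and the sizes are illustrative (image_mode=False is the video setting, so 'train' and 'test' both reduce to Resize, ToTensor, Normalize):

from PIL import Image
from libs.datasets.transforms import get_transforms

data_transforms = get_transforms(image_mode=False, input_size=(448, 448))
sample = {
    'image': Image.new('RGB', (854, 480)),     # stand-in video frame
    'label': Image.new('L', (854, 480)),       # stand-in saliency mask
}
out = data_transforms['test'](sample)          # Resize -> ToTensor -> Normalize
print(out['image'].shape, out['label'].shape)  # torch.Size([3, 448, 448]) torch.Size([1, 448, 448])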
--------------------------------------------------------------------------------
/libs/modules/__init__.py:
--------------------------------------------------------------------------------
1 | from .convgru import ConvGRUCell
2 | from .non_local_dot_product import NONLocalBlock3D
--------------------------------------------------------------------------------
/libs/modules/convgru.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | import torch
8 | from torch import nn
9 |
10 | class ConvGRUCell(nn.Module):
11 | """
12 | ICLR2016: Delving Deeper into Convolutional Networks for Learning Video Representations
13 | url: https://arxiv.org/abs/1511.06432
14 | """
15 | def __init__(self, input_channels, hidden_channels, kernel_size, cuda_flag=True):
16 | super(ConvGRUCell, self).__init__()
17 | self.input_channels = input_channels
18 | self.cuda_flag = cuda_flag
19 | self.hidden_channels = hidden_channels
20 | self.kernel_size = kernel_size
21 |
22 | padding = self.kernel_size // 2
23 | self.reset_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, self.kernel_size, padding=padding)
24 | self.update_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, self.kernel_size, padding=padding)
25 | self.output_gate = nn.Conv2d(input_channels + hidden_channels, hidden_channels, self.kernel_size, padding=padding)
26 | # init (iterate over modules, not state_dict keys, so the isinstance check can match)
27 | for m in self.modules():
28 | if isinstance(m, nn.Conv2d):
29 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
30 | nn.init.constant_(m.bias, 0)
31 |
32 | def forward(self, x, hidden):
33 | if hidden is None:
34 | size_h = [x.data.size()[0], self.hidden_channels] + list(x.data.size()[2:])
35 | if self.cuda_flag:
36 | hidden = torch.zeros(size_h).cuda()
37 | else:
38 | hidden = torch.zeros(size_h)
39 |
40 | inputs = torch.cat((x, hidden), dim=1)
41 | reset_gate = torch.sigmoid(self.reset_gate(inputs))
42 | update_gate = torch.sigmoid(self.update_gate(inputs))
43 |
44 | reset_hidden = reset_gate * hidden
45 | reset_inputs = torch.tanh(self.output_gate(torch.cat((x, reset_hidden), dim=1)))
46 | new_hidden = (1 - update_gate)*reset_inputs + update_gate*hidden
47 |
48 | return new_hidden
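A brief sketch of stepping the cell over a short feature sequence (not part of the repository; CPU here, so cuda_flag=False, and the shapes are illustrative):

import torch
from libs.modules.convgru import ConvGRUCell

cell = ConvGRUCell(input_channels=256, hidden_channels=256, kernel_size=3, cuda_flag=False)
frames = [torch.randn(1, 256, 28, 28) for _ in range(4)]  # e.g. backbone features for 4 frames

hidden = None                # the first step initializes the hidden state with zeros
for feat in frames:
    hidden = cell(feat, hidden)
print(hidden.shape)          # torch.Size([1, 256, 28, 28])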
--------------------------------------------------------------------------------
/libs/modules/non_local_dot_product.py:
--------------------------------------------------------------------------------
1 |
2 | #!/usr/bin/env python
3 | # coding: utf-8
4 | #
5 | # Author: AlexHex7
6 | # URL: https://github.com/AlexHex7/Non-local_pytorch
7 |
8 | import torch
9 | from torch import nn
10 | from torch.nn import functional as F
11 |
12 |
13 | class _NonLocalBlockND(nn.Module):
14 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample=True, bn_layer=True):
15 | super(_NonLocalBlockND, self).__init__()
16 |
17 | assert dimension in [1, 2, 3]
18 |
19 | self.dimension = dimension
20 | self.sub_sample = sub_sample
21 |
22 | self.in_channels = in_channels
23 | self.inter_channels = inter_channels
24 |
25 | if self.inter_channels is None:
26 | self.inter_channels = in_channels // 2
27 | if self.inter_channels == 0:
28 | self.inter_channels = 1
29 |
30 | if dimension == 3:
31 | conv_nd = nn.Conv3d
32 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
33 | bn = nn.BatchNorm3d
34 | elif dimension == 2:
35 | conv_nd = nn.Conv2d
36 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
37 | bn = nn.BatchNorm2d
38 | else:
39 | conv_nd = nn.Conv1d
40 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
41 | bn = nn.BatchNorm1d
42 |
43 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
44 | kernel_size=1, stride=1, padding=0)
45 |
46 | if bn_layer:
47 | self.W = nn.Sequential(
48 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
49 | kernel_size=1, stride=1, padding=0),
50 | bn(self.in_channels)
51 | )
52 | nn.init.constant_(self.W[1].weight, 0)
53 | nn.init.constant_(self.W[1].bias, 0)
54 | else:
55 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
56 | kernel_size=1, stride=1, padding=0)
57 | nn.init.constant_(self.W.weight, 0)
58 | nn.init.constant_(self.W.bias, 0)
59 |
60 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
61 | kernel_size=1, stride=1, padding=0)
62 |
63 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
64 | kernel_size=1, stride=1, padding=0)
65 |
66 | if sub_sample:
67 | self.g = nn.Sequential(self.g, max_pool_layer)
68 | self.phi = nn.Sequential(self.phi, max_pool_layer)
69 |
70 | def forward(self, x):
71 | '''
72 | :param x: (b, c, t, h, w)
73 | :return:
74 | '''
75 |
76 | batch_size = x.size(0)
77 |
78 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
79 | g_x = g_x.permute(0, 2, 1)
80 |
81 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
82 | theta_x = theta_x.permute(0, 2, 1)
83 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
84 | f = torch.matmul(theta_x, phi_x)
85 | N = f.size(-1)
86 | f_div_C = f / N
87 |
88 | y = torch.matmul(f_div_C, g_x)
89 | y = y.permute(0, 2, 1).contiguous()
90 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
91 | W_y = self.W(y)
92 | z = W_y + x
93 |
94 | return z
95 |
96 |
97 | class NONLocalBlock1D(_NonLocalBlockND):
98 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
99 | super(NONLocalBlock1D, self).__init__(in_channels,
100 | inter_channels=inter_channels,
101 | dimension=1, sub_sample=sub_sample,
102 | bn_layer=bn_layer)
103 |
104 |
105 | class NONLocalBlock2D(_NonLocalBlockND):
106 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
107 | super(NONLocalBlock2D, self).__init__(in_channels,
108 | inter_channels=inter_channels,
109 | dimension=2, sub_sample=sub_sample,
110 | bn_layer=bn_layer)
111 |
112 |
113 | class NONLocalBlock3D(_NonLocalBlockND):
114 | def __init__(self, in_channels, inter_channels=None, sub_sample=True, bn_layer=True):
115 | super(NONLocalBlock3D, self).__init__(in_channels,
116 | inter_channels=inter_channels,
117 | dimension=3, sub_sample=sub_sample,
118 | bn_layer=bn_layer)
119 |
120 |
121 | if __name__ == '__main__':
122 | import torch
123 |
124 | for (sub_sample, bn_layer) in [(True, True), (False, False), (True, False), (False, True)]:
125 | img = torch.zeros(2, 3, 20)
126 | net = NONLocalBlock1D(3, sub_sample=sub_sample, bn_layer=bn_layer)
127 | out = net(img)
128 | print(out.size())
129 |
130 | img = torch.zeros(2, 3, 20, 20)
131 | net = NONLocalBlock2D(3, sub_sample=sub_sample, bn_layer=bn_layer)
132 | out = net(img)
133 | print(out.size())
134 |
135 | img = torch.randn(2, 3, 8, 20, 20)
136 | net = NONLocalBlock3D(3, sub_sample=sub_sample, bn_layer=bn_layer)
137 | out = net(img)
138 | print(out.size())
--------------------------------------------------------------------------------
/libs/networks/__init__.py:
--------------------------------------------------------------------------------
1 | from .models import ImageModel, VideoModel
--------------------------------------------------------------------------------
/libs/networks/models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from libs.networks.rcrnet import RCRNet
8 | from libs.modules.convgru import ConvGRUCell
9 | from libs.modules.non_local_dot_product import NONLocalBlock3D
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 | class ImageModel(nn.Module):
16 | '''
17 | RCRNet
18 | '''
19 | def __init__(self, pretrained=False):
20 | super(ImageModel, self).__init__()
21 | self.backbone = RCRNet(
22 | n_classes=1,
23 | output_stride=16,
24 | pretrained=pretrained
25 | )
26 | def forward(self, frame):
27 | seg = self.backbone(frame)
28 | return seg
29 |
30 | class VideoModel(nn.Module):
31 | '''
32 | RCRNet+NER
33 | '''
34 | def __init__(self, output_stride=16):
35 | super(VideoModel, self).__init__()
36 | # video mode + video dataset
37 | self.backbone = RCRNet(
38 | n_classes=1,
39 | output_stride=output_stride,
40 | pretrained=False,
41 | input_channels=3
42 | )
43 | self.convgru_forward = ConvGRUCell(256, 256, 3)
44 | self.convgru_backward = ConvGRUCell(256, 256, 3)
45 | self.bidirection_conv = nn.Conv2d(512, 256, 3, 1, 1)
46 |
47 | self.non_local_block = NONLocalBlock3D(256, sub_sample=False, bn_layer=False)
48 | self.non_local_block2 = NONLocalBlock3D(256, sub_sample=False, bn_layer=False)
49 |
50 | self.freeze_bn()
51 |
52 | def freeze_bn(self):
53 | for m in self.backbone.named_modules():
54 | if isinstance(m[1], nn.BatchNorm2d):
55 | m[1].eval()
56 |
57 | def forward(self, clip):
58 | clip_feats = [self.backbone.feat_conv(frame) for frame in clip]
59 | feats_time = [feats[-1] for feats in clip_feats]
60 | feats_time = torch.stack(feats_time, dim=2)
61 | feats_time = self.non_local_block(feats_time)
62 |
63 | # Deep Bidirectional ConvGRU
64 | frame = clip[0]
65 | feat = feats_time[:,:,0,:,:]
66 | feats_forward = []
67 | # forward
68 | for i in range(len(clip)):
69 | feat = self.convgru_forward(feats_time[:,:,i,:,:], feat)
70 | feats_forward.append(feat)
71 | # backward
72 | feat = feats_forward[-1]
73 | feats_backward = []
74 | for i in range(len(clip)):
75 | feat = self.convgru_backward(feats_forward[len(clip)-1-i], feat)
76 | feats_backward.append(feat)
77 |
78 | feats_backward = feats_backward[::-1]
79 | feats = []
80 | for i in range(len(clip)):
81 | feat = torch.tanh(self.bidirection_conv(torch.cat((feats_forward[i], feats_backward[i]), dim=1)))
82 | feats.append(feat)
83 | feats = torch.stack(feats, dim=2)
84 |
85 | feats = self.non_local_block2(feats)
86 | preds = []
87 | for i, frame in enumerate(clip):
88 | seg = self.backbone.seg_conv(clip_feats[i][0], clip_feats[i][1], clip_feats[i][2], feats[:,:,i,:,:], frame.shape[2:])
89 | preds.append(seg)
90 | return preds
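A compact sketch of running the video model on a clip of random frames (not part of the repository; CPU, no pretrained weights, and 448x448 matches the inference script's default size):

import torch
from libs.networks.models import VideoModel

model = VideoModel(output_stride=16).eval()
clip = [torch.randn(1, 3, 448, 448) for _ in range(4)]  # 4 frames, batch size 1

with torch.no_grad():
    preds = model(clip)                # one logit map per input frame
print(len(preds), preds[0].shape)      # 4 torch.Size([1, 1, 448, 448])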
--------------------------------------------------------------------------------
/libs/networks/pseudo_label_generator.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from libs.networks.rcrnet import RCRNet
8 |
9 | from flownet2.models import FlowNet2
10 | from flownet2.networks.resample2d_package.resample2d import Resample2d
11 |
12 | import torch
13 | import torch.nn as nn
14 | import torch.nn.functional as F
15 | _mean = [0.485, 0.456, 0.406]
16 | _std = [0.229, 0.224, 0.225]
17 |
18 | def normalize_flow(flow):
19 | origin_size = flow.shape[2:]
20 | flow[:, 0, :, :] /= origin_size[1] # dx
21 | flow[:, 1, :, :] /= origin_size[0] # dy
22 | norm_flow = (flow[:, 0, :, :] ** 2 + flow[:, 1, :, :] ** 2) ** 0.5
23 | return norm_flow.unsqueeze(1)
24 |
25 | def compute_flow(flownet, data, data_ref):
26 | # flow from data_ref to data
27 | images = [data[0].clone(), data_ref[0].clone()]
28 | for image in images:
29 | for i, (mean, std) in enumerate(zip(_mean, _std)):
30 | image[i].mul_(std).add_(mean)
31 | images = torch.stack(images)
32 | images = images.permute((1, 0, 2, 3)) # to channel, 2, h, w
33 | im = images.unsqueeze(0).float() # add batch_size = 1 [batch_size, channel, 2, h, w]
34 | return flownet(im)
35 |
36 | def resize_flow(flow, size):
37 | origin_size = flow.shape[2:]
38 | flow = F.interpolate(flow.clone(), size=size, mode="nearest")
39 | flow[:, 0, :, :] /= origin_size[1] / size[1] # dx
40 | flow[:, 1, :, :] /= origin_size[0] / size[0] # dy
41 | return flow
42 |
43 | class FGPLG(nn.Module):
44 | def __init__(self, args, output_stride=16):
45 | super(FGPLG, self).__init__()
46 | self.flownet = FlowNet2(args)
47 | self.warp = Resample2d()
48 | channels = 7
49 |
50 | self.backbone = RCRNet(
51 | n_classes=1,
52 | output_stride=output_stride,
53 | pretrained=False,
54 | input_channels=channels
55 | )
56 |
57 | self.freeze_bn()
58 | self.freeze_layer()
59 |
60 | def freeze_bn(self):
61 | for m in self.backbone.named_modules():
62 | if isinstance(m[1], nn.BatchNorm2d):
63 | m[1].eval()
64 | def freeze_layer(self):
65 | if hasattr(self, 'flownet'):
66 | for p in self.flownet.parameters():
67 | p.requires_grad = False
68 |
69 | def generate_pseudo_label(self, frame, frame_l, frame_r, label_l, label_r):
70 | flow_forward = compute_flow(self.flownet, frame, frame_l)
71 | flow_backward = compute_flow(self.flownet, frame, frame_r)
72 | warp_label_l = self.warp(label_l, flow_forward)
73 | warp_label_r = self.warp(label_r, flow_backward)
74 | inputs = torch.cat((
75 | frame,
76 | warp_label_l,
77 | warp_label_r,
78 | normalize_flow(flow_forward),
79 | normalize_flow(flow_backward)
80 | ), 1)
81 | pseudo_label = self.backbone(inputs)
82 | return pseudo_label
83 |
84 | def forward(self, clip, clip_label):
85 | pseudo_label = self.generate_pseudo_label(clip[1], clip[0], clip[2], clip_label[0], clip_label[2])
86 | return pseudo_label, clip_label[1]
87 |
--------------------------------------------------------------------------------
/libs/networks/rcrnet.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from collections import OrderedDict
8 |
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 |
13 | from .resnet_dilation import resnet50, Bottleneck, conv1x1
14 |
15 | class _ConvBatchNormReLU(nn.Sequential):
16 | def __init__(self,
17 | in_channels,
18 | out_channels,
19 | kernel_size,
20 | stride,
21 | padding,
22 | dilation,
23 | relu=True,
24 | ):
25 | super(_ConvBatchNormReLU, self).__init__()
26 | self.add_module(
27 | "conv",
28 | nn.Conv2d(
29 | in_channels=in_channels,
30 | out_channels=out_channels,
31 | kernel_size=kernel_size,
32 | stride=stride,
33 | padding=padding,
34 | dilation=dilation,
35 | bias=False,
36 | ),
37 | )
38 | self.add_module(
39 | "bn",
40 | nn.BatchNorm2d(out_channels),
41 | )
42 |
43 | if relu:
44 | self.add_module("relu", nn.ReLU())
45 |
46 | def forward(self, x):
47 | return super(_ConvBatchNormReLU, self).forward(x)
48 |
49 | class _ASPPModule(nn.Module):
50 | """Atrous Spatial Pyramid Pooling with image pool"""
51 |
52 | def __init__(self, in_channels, out_channels, output_stride):
53 | super(_ASPPModule, self).__init__()
54 | if output_stride == 8:
55 | pyramids = [12, 24, 36]
56 | elif output_stride == 16:
57 | pyramids = [6, 12, 18]
58 | self.stages = nn.Module()
59 | self.stages.add_module(
60 | "c0", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)
61 | )
62 | for i, (dilation, padding) in enumerate(zip(pyramids, pyramids)):
63 | self.stages.add_module(
64 | "c{}".format(i + 1),
65 | _ConvBatchNormReLU(in_channels, out_channels, 3, 1, padding, dilation),
66 | )
67 | self.imagepool = nn.Sequential(
68 | OrderedDict(
69 | [
70 | ("pool", nn.AdaptiveAvgPool2d((1,1))),
71 | ("conv", _ConvBatchNormReLU(in_channels, out_channels, 1, 1, 0, 1)),
72 | ]
73 | )
74 | )
75 | self.fire = nn.Sequential(
76 | OrderedDict(
77 | [
78 | ("conv", _ConvBatchNormReLU(out_channels * 5, out_channels, 3, 1, 1, 1)),
79 | ("dropout", nn.Dropout2d(0.1))
80 | ]
81 | )
82 | )
83 |
84 | def forward(self, x):
85 | h = self.imagepool(x)
86 | h = [F.interpolate(h, size=x.shape[2:], mode="bilinear", align_corners=False)]
87 | for stage in self.stages.children():
88 | h += [stage(x)]
89 | h = torch.cat(h, dim=1)
90 | h = self.fire(h)
91 | return h
92 |
93 | class _RefinementModule(nn.Module):
94 | """ Reduce channels and refinment module"""
95 |
96 | def __init__(self,
97 | bottom_up_channels,
98 | reduce_channels,
99 | top_down_channels,
100 | refinement_channels,
101 | expansion=2
102 | ):
103 | super(_RefinementModule, self).__init__()
104 | downsample = None
105 | if bottom_up_channels != reduce_channels:
106 | downsample = nn.Sequential(
107 | conv1x1(bottom_up_channels, reduce_channels),
108 | nn.BatchNorm2d(reduce_channels),
109 | )
110 | self.skip = Bottleneck(bottom_up_channels, reduce_channels // expansion, 1, 1, downsample, expansion)
111 | self.refine = _ConvBatchNormReLU(reduce_channels + top_down_channels, refinement_channels, 3, 1, 1, 1)
112 | def forward(self, td, bu):
113 | td = self.skip(td)
114 | x = torch.cat((bu, td), dim=1)
115 | x = self.refine(x)
116 | return x
117 |
118 | class RCRNet(nn.Module):
119 |
120 | def __init__(self, n_classes, output_stride, input_channels=3, pretrained=False):
121 | super(RCRNet, self).__init__()
122 | self.resnet = resnet50(pretrained=pretrained, output_stride=output_stride, input_channels=input_channels)
123 | self.aspp = _ASPPModule(2048, 256, output_stride)
124 | # Decoder
125 | self.decoder = nn.Sequential(
126 | OrderedDict(
127 | [
128 | ("conv1", _ConvBatchNormReLU(128, 256, 3, 1, 1, 1)),
129 | ("conv2", nn.Conv2d(256, n_classes, kernel_size=1)),
130 | ]
131 | )
132 | )
133 | self.add_module("refinement1", _RefinementModule(1024, 96, 256, 128, 2))
134 | self.add_module("refinement2", _RefinementModule(512, 96, 128, 128, 2))
135 | self.add_module("refinement3", _RefinementModule(256, 96, 128, 128, 2))
136 |
137 | if pretrained:
138 | for key in self.state_dict():
139 | if 'resnet' not in key:
140 | self.init_layer(key)
141 |
142 | def init_layer(self, key):
143 | if key.split('.')[-1] == 'weight':
144 | if 'conv' in key:
145 | if self.state_dict()[key].ndimension() >= 2:
146 | nn.init.kaiming_normal_(self.state_dict()[key], mode='fan_out', nonlinearity='relu')
147 | elif 'bn' in key:
148 | self.state_dict()[key][...] = 1
149 | elif key.split('.')[-1] == 'bias':
150 | self.state_dict()[key][...] = 0.001
151 |
152 | def feat_conv(self, x):
153 | '''
154 | Spatial feature extractor
155 | '''
156 | block0 = self.resnet.conv1(x)
157 | block0 = self.resnet.bn1(block0)
158 | block0 = self.resnet.relu(block0)
159 | block0 = self.resnet.maxpool(block0)
160 |
161 | block1 = self.resnet.layer1(block0)
162 | block2 = self.resnet.layer2(block1)
163 | block3 = self.resnet.layer3(block2)
164 | block4 = self.resnet.layer4(block3)
165 | block4 = self.aspp(block4)
166 | return block1, block2, block3, block4
167 |
168 | def seg_conv(self, block1, block2, block3, block4, shape):
169 | '''
170 | Pixel-wise classifier
171 | '''
172 | bu1 = self.refinement1(block3, block4)
173 | bu1 = F.interpolate(bu1, size=block2.shape[2:], mode="bilinear", align_corners=False)
174 | bu2 = self.refinement2(block2, bu1)
175 | bu2 = F.interpolate(bu2, size=block1.shape[2:], mode="bilinear", align_corners=False)
176 | bu3 = self.refinement3(block1, bu2)
177 | bu3 = F.interpolate(bu3, size=shape, mode="bilinear", align_corners=False)
178 | seg = self.decoder(bu3)
179 | return seg
180 |
181 | def forward(self, x):
182 | block1, block2, block3, block4 = self.feat_conv(x)
183 | seg = self.seg_conv(block1, block2, block3, block4, x.shape[2:])
184 | return seg
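A shape sketch for the spatial feature extractor and the pixel-wise classifier above (not part of the repository; random input, no pretrained weights):

import torch
from libs.networks.rcrnet import RCRNet

net = RCRNet(n_classes=1, output_stride=16, pretrained=False).eval()
x = torch.randn(1, 3, 448, 448)

with torch.no_grad():
    block1, block2, block3, block4 = net.feat_conv(x)              # strides 4, 8, 16, 16
    seg = net.seg_conv(block1, block2, block3, block4, x.shape[2:])
print(block4.shape, seg.shape)                                     # torch.Size([1, 256, 28, 28]) torch.Size([1, 1, 448, 448])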
--------------------------------------------------------------------------------
/libs/networks/resnet_dilation.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # This code is based on torchvison resnet
5 | # URL: https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
6 |
7 | import torch.nn as nn
8 | import torch.utils.model_zoo as model_zoo
9 |
10 |
11 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
12 | 'resnet152']
13 |
14 |
15 | model_urls = {
16 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
17 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
18 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
19 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
20 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
21 | }
22 |
23 |
24 | def conv3x3(in_planes, out_planes, stride=1, padding=1, dilation=1):
25 | """3x3 convolution with padding"""
26 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
27 | padding=padding, dilation=dilation, bias=False)
28 |
29 |
30 | def conv1x1(in_planes, out_planes, stride=1):
31 | """1x1 convolution"""
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
33 |
34 |
35 | class BasicBlock(nn.Module):
36 | expansion = 1
37 |
38 | def __init__(self, inplanes, planes, stride, dilation, downsample=None):
39 | super(BasicBlock, self).__init__()
40 | self.conv1 = conv3x3(inplanes, planes, stride)
41 | self.bn1 = nn.BatchNorm2d(planes)
42 | self.relu = nn.ReLU(inplace=True)
43 | self.conv2 = conv3x3(planes, planes, 1, dilation, dilation)
44 | self.bn2 = nn.BatchNorm2d(planes)
45 | self.downsample = downsample
46 | self.stride = stride
47 |
48 | def forward(self, x):
49 | residual = x
50 |
51 | out = self.conv1(x)
52 | out = self.bn1(out)
53 | out = self.relu(out)
54 |
55 | out = self.conv2(out)
56 | out = self.bn2(out)
57 |
58 | if self.downsample is not None:
59 | residual = self.downsample(x)
60 |
61 | out += residual
62 | out = self.relu(out)
63 |
64 | return out
65 |
66 |
67 | class Bottleneck(nn.Module):
68 | expansion = 4
69 |
70 | def __init__(self, inplanes, planes, stride, dilation, downsample=None, expansion=4):
71 | super(Bottleneck, self).__init__()
72 | self.expansion = expansion
73 | self.conv1 = conv1x1(inplanes, planes)
74 | self.bn1 = nn.BatchNorm2d(planes)
75 | self.conv2 = conv3x3(planes, planes, stride, dilation, dilation)
76 | self.bn2 = nn.BatchNorm2d(planes)
77 | self.conv3 = conv1x1(planes, planes * self.expansion)
78 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
79 | self.relu = nn.ReLU(inplace=True)
80 | self.downsample = downsample
81 | self.stride = stride
82 |
83 | def forward(self, x):
84 | residual = x
85 |
86 | out = self.conv1(x)
87 | out = self.bn1(out)
88 | out = self.relu(out)
89 |
90 | out = self.conv2(out)
91 | out = self.bn2(out)
92 | out = self.relu(out)
93 |
94 | out = self.conv3(out)
95 | out = self.bn3(out)
96 |
97 | if self.downsample is not None:
98 | residual = self.downsample(x)
99 |
100 | out += residual
101 | out = self.relu(out)
102 |
103 | return out
104 |
105 |
106 | class ResNet(nn.Module):
107 |
108 | def __init__(self, block, layers, output_stride, num_classes=1000, input_channels=3):
109 | super(ResNet, self).__init__()
110 | if output_stride == 8:
111 | stride = [1, 2, 1, 1]
112 | dilation = [1, 1, 2, 2]
113 | elif output_stride == 16:
114 | stride = [1, 2, 2, 1]
115 | dilation = [1, 1, 1, 2]
116 |
117 | self.inplanes = 64
118 | self.conv1 = nn.Conv2d(input_channels, 64, kernel_size=7, stride=2, padding=3,
119 | bias=False)
120 | self.bn1 = nn.BatchNorm2d(64)
121 | self.relu = nn.ReLU(inplace=True)
122 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
123 | self.layer1 = self._make_layer(block, 64, layers[0], stride=stride[0], dilation=dilation[0])
124 | self.layer2 = self._make_layer(block, 128, layers[1], stride=stride[1], dilation=dilation[1])
125 | self.layer3 = self._make_layer(block, 256, layers[2], stride=stride[2], dilation=dilation[2])
126 | self.layer4 = self._make_layer(block, 512, layers[3], stride=stride[3], dilation=dilation[3])
127 | self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
128 | self.fc = nn.Linear(512 * block.expansion, num_classes)
129 |
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
133 | elif isinstance(m, nn.BatchNorm2d):
134 | nn.init.constant_(m.weight, 1)
135 | nn.init.constant_(m.bias, 0)
136 |
137 | def _make_layer(self, block, planes, blocks, stride, dilation):
138 | downsample = None
139 | if stride != 1 or self.inplanes != planes * block.expansion:
140 | downsample = nn.Sequential(
141 | conv1x1(self.inplanes, planes * block.expansion, stride),
142 | nn.BatchNorm2d(planes * block.expansion),
143 | )
144 |
145 | layers = []
146 | layers.append(block(self.inplanes, planes, stride, dilation, downsample))
147 | self.inplanes = planes * block.expansion
148 | for _ in range(1, blocks):
149 | layers.append(block(self.inplanes, planes, 1, dilation))
150 |
151 | return nn.Sequential(*layers)
152 |
153 | def forward(self, x):
154 | x = self.conv1(x)
155 | x = self.bn1(x)
156 | x = self.relu(x)
157 | x = self.maxpool(x)
158 |
159 | x = self.layer1(x)
160 | x = self.layer2(x)
161 | x = self.layer3(x)
162 | x = self.layer4(x)
163 |
164 | x = self.avgpool(x)
165 | x = x.view(x.size(0), -1)
166 | x = self.fc(x)
167 |
168 | return x
169 |
170 |
171 | def resnet18(pretrained=False, **kwargs):
172 | """Constructs a ResNet-18 model.
173 | Args:
174 | pretrained (bool): If True, returns a model pre-trained on ImageNet
175 | """
176 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
177 | if pretrained:
178 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
179 | return model
180 |
181 |
182 | def resnet34(pretrained=False, **kwargs):
183 | """Constructs a ResNet-34 model.
184 | Args:
185 | pretrained (bool): If True, returns a model pre-trained on ImageNet
186 | """
187 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
188 | if pretrained:
189 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
190 | return model
191 |
192 |
193 | def resnet50(pretrained=False, **kwargs):
194 | """Constructs a ResNet-50 model.
195 | Args:
196 | pretrained (bool): If True, returns a model pre-trained on ImageNet
197 | """
198 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
199 | if pretrained:
200 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
201 | return model
202 |
203 |
204 | def resnet101(pretrained=False, **kwargs):
205 | """Constructs a ResNet-101 model.
206 | Args:
207 | pretrained (bool): If True, returns a model pre-trained on ImageNet
208 | """
209 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
210 | if pretrained:
211 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
212 | return model
213 |
214 |
215 | def resnet152(pretrained=False, **kwargs):
216 | """Constructs a ResNet-152 model.
217 | Args:
218 | pretrained (bool): If True, returns a model pre-trained on ImageNet
219 | """
220 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
221 | if pretrained:
222 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
223 | return model
224 |
--------------------------------------------------------------------------------
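Note: a minimal usage sketch of the dilated backbone defined above, assuming the repository root is on PYTHONPATH so that libs.networks.resnet_dilation is importable. output_stride must be 8 or 16; for any other value the stride/dilation lists in ResNet.__init__ are left undefined.

import torch
from libs.networks.resnet_dilation import resnet50

# Dilated ResNet-50 backbone with output stride 16 (the default used by train.py via --os).
backbone = resnet50(pretrained=False, output_stride=16)
backbone.eval()

with torch.no_grad():
    x = torch.randn(1, 3, 448, 448)  # dummy RGB input at the 448x448 training resolution
    logits = backbone(x)             # output of the plain classification forward()
print(logits.shape)                  # torch.Size([1, 1000])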
/libs/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/libs/utils/__init__.py
--------------------------------------------------------------------------------
/libs/utils/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: speedinghzl
5 | # URL: https://github.com/speedinghzl/pytorch-segmentation-toolbox
6 |
7 | import os
8 | import sys
9 | import logging
10 |
11 | _default_level_name = os.getenv('ENGINE_LOGGING_LEVEL', 'INFO')
12 | _default_level = logging.getLevelName(_default_level_name.upper())
13 |
14 |
15 | class LogFormatter(logging.Formatter):
16 | log_fout = None
17 | date_full = '[%(asctime)s %(lineno)d@%(filename)s:%(name)s] '
18 | date = '%(asctime)s '
19 | msg = '%(message)s'
20 |
21 | def format(self, record):
22 | if record.levelno == logging.DEBUG:
23 | mcl, mtxt = self._color_dbg, 'DBG'
24 | elif record.levelno == logging.WARNING:
25 | mcl, mtxt = self._color_warn, 'WRN'
26 | elif record.levelno == logging.ERROR:
27 | mcl, mtxt = self._color_err, 'ERR'
28 | else:
29 | mcl, mtxt = self._color_normal, ''
30 |
31 | if mtxt:
32 | mtxt += ' '
33 |
34 | if self.log_fout:
35 | self.__set_fmt(self.date_full + mtxt + self.msg)
36 | formatted = super(LogFormatter, self).format(record)
37 | # self.log_fout.write(formatted)
38 | # self.log_fout.write('\n')
39 | # self.log_fout.flush()
40 | return formatted
41 |
42 | self.__set_fmt(self._color_date(self.date) + mcl(mtxt + self.msg))
43 | formatted = super(LogFormatter, self).format(record)
44 |
45 | return formatted
46 |
47 | if sys.version_info.major < 3:
48 | def __set_fmt(self, fmt):
49 | self._fmt = fmt
50 | else:
51 | def __set_fmt(self, fmt):
52 | self._style._fmt = fmt
53 |
54 | @staticmethod
55 | def _color_dbg(msg):
56 | return '\x1b[36m{}\x1b[0m'.format(msg)
57 |
58 | @staticmethod
59 | def _color_warn(msg):
60 | return '\x1b[1;31m{}\x1b[0m'.format(msg)
61 |
62 | @staticmethod
63 | def _color_err(msg):
64 | return '\x1b[1;4;31m{}\x1b[0m'.format(msg)
65 |
66 | @staticmethod
67 | def _color_omitted(msg):
68 | return '\x1b[35m{}\x1b[0m'.format(msg)
69 |
70 | @staticmethod
71 | def _color_normal(msg):
72 | return msg
73 |
74 | @staticmethod
75 | def _color_date(msg):
76 | return '\x1b[32m{}\x1b[0m'.format(msg)
77 |
78 |
79 | def get_logger(log_dir=None, log_file=None, formatter=LogFormatter):
80 | logger = logging.getLogger()
81 | logger.setLevel(_default_level)
82 | del logger.handlers[:]
83 |
84 | if log_dir and log_file:
85 | if not os.path.isdir(log_dir):
86 | os.makedirs(log_dir)
87 | LogFormatter.log_fout = True
88 | file_handler = logging.FileHandler(log_file, mode='a')
89 | file_handler.setLevel(logging.INFO)
90 |         file_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))  # instantiate the formatter (setFormatter expects a Formatter instance)
91 | logger.addHandler(file_handler)
92 |
93 | stream_handler = logging.StreamHandler()
94 | stream_handler.setFormatter(formatter(datefmt='%d %H:%M:%S'))
95 | stream_handler.setLevel(0)
96 | logger.addHandler(stream_handler)
97 | return logger
--------------------------------------------------------------------------------
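Note: a minimal sketch of using the logger above (console output only; passing both log_dir and log_file additionally attaches a FileHandler). Assumes the repository root is on PYTHONPATH.

from libs.utils.logger import get_logger

logger = get_logger()  # stream handler with the colored LogFormatter
logger.info("training started")
logger.warning("missing key(s) in state_dict")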
/libs/utils/metric.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 | # This is the python implementation of S-measure
7 |
8 | import numpy as np
9 |
10 | eps = np.finfo(np.float32).eps
11 | def StructureMeasure(prediction, GT):
12 | """
13 | StructureMeasure computes the similarity between the foreground map and
14 |     ground truth (as proposed in "Structure-measure: A new way to evaluate
15 |     foreground maps" [Deng-Ping Fan et al. - ICCV 2017])
16 | Usage:
17 | Q = StructureMeasure(prediction,GT)
18 | Input:
19 | prediction - Binary/Non binary foreground map with values in the range
20 | [0 1]. Type: np.float32
21 | GT - Binary ground truth. Type: np.bool
22 | Output:
23 | Q - The computed similarity score
24 | """
25 | # check input
26 | if prediction.dtype != np.float32:
27 | raise ValueError("prediction should be of type: np.float32")
28 | if np.amax(prediction) > 1 or np.amin(prediction) < 0:
29 | raise ValueError("prediction should be in the range of [0 1]")
30 | if GT.dtype != np.bool:
31 |         raise ValueError("GT should be of type: np.bool")
32 |
33 | y = np.mean(GT)
34 |
35 | if y == 0: # if the GT is completely black
36 | x = np.mean(prediction)
37 | Q = 1.0 - x
38 | elif y == 1: # if the GT is completely white
39 | x = np.mean(prediction)
40 | Q = x
41 | else:
42 | alpha = 0.5
43 | Q = alpha * S_object(prediction, GT) + (1 - alpha) * S_region(prediction, GT)
44 | if Q < 0:
45 | Q = 0
46 |
47 | return Q
48 |
49 | def S_object(prediction, GT):
50 | """
51 |     S_object computes the object similarity between foreground maps and ground
52 |     truth (as proposed in "Structure-measure: A new way to evaluate foreground
53 |     maps" [Deng-Ping Fan et al. - ICCV 2017])
54 | Usage:
55 | Q = S_object(prediction,GT)
56 | Input:
57 | prediction - Binary/Non binary foreground map with values in the range
58 | [0 1]. Type: np.float32
59 | GT - Binary ground truth. Type: np.bool
60 | Output:
61 | Q - The object similarity score
62 | """
63 |     # compute the similarity of the foreground at the object level
64 |     # Notice: the in-place masking below requires a copy of the prediction
65 | prediction_fg = prediction.copy()
66 | prediction_fg[~GT] = 0
67 | O_FG = Object(prediction_fg, GT)
68 |
69 | # compute the similarity of the background
70 |     prediction_bg = 1.0 - prediction
71 | prediction_bg[GT] = 0
72 | O_BG = Object(prediction_bg, ~GT)
73 |
74 | # combine the foreground measure and background measure together
75 | u = np.mean(GT)
76 | Q = u * O_FG + (1 - u) * O_BG
77 |
78 | return Q
79 |
80 | def Object(prediction, GT):
81 | # compute the mean of the foreground or background in prediction
82 | x = np.mean(prediction[GT])
83 | # compute the standard deviations of the foreground or background in prediction
84 | sigma_x = np.std(prediction[GT])
85 |
86 | score = 2.0 * x / (x * x + 1.0 + sigma_x + eps)
87 | return score
88 |
89 | def S_region(prediction, GT):
90 | """
91 | S_region computes the region similarity between the foreground map and
92 |     ground truth (as proposed in "Structure-measure: A new way to evaluate
93 |     foreground maps" [Deng-Ping Fan et al. - ICCV 2017])
94 | Usage:
95 | Q = S_region(prediction,GT)
96 | Input:
97 | prediction - Binary/Non binary foreground map with values in the range
98 | [0 1]. Type: np.float32
99 | GT - Binary ground truth. Type: np.bool
100 | Output:
101 | Q - The region similarity score
102 | """
103 | # find the centroid of the GT
104 | X, Y = centroid(GT)
105 | # divide GT into 4 regions
106 | GT_1, GT_2, GT_3, GT_4, w1, w2, w3, w4 = divideGT(GT, X, Y)
107 |     # Divide the prediction into 4 regions
108 | prediction_1, prediction_2, prediction_3, prediction_4 = Divideprediction(prediction, X, Y)
109 |     # Compute the ssim score for each region
110 | Q1 = ssim(prediction_1, GT_1)
111 | Q2 = ssim(prediction_2, GT_2)
112 | Q3 = ssim(prediction_3, GT_3)
113 | Q4 = ssim(prediction_4, GT_4)
114 | #Sum the 4 scores
115 | Q = w1 * Q1 + w2 * Q2 + w3 * Q3 + w4 * Q4
116 | return Q
117 |
118 | def centroid(GT):
119 | """
120 |     centroid computes the centroid of the GT
121 |     Usage:
122 |         X, Y = centroid(GT)
123 | Input:
124 |         GT - Binary ground truth. Type: np.bool
125 | Output:
126 | X,Y - The coordinates of centroid.
127 | """
128 | rows, cols = GT.shape
129 |
130 | total = np.sum(GT)
131 | if total == 0:
132 | X = round(float(cols) / 2)
133 | Y = round(float(rows) / 2)
134 | else:
135 | i = np.arange(1, cols + 1).astype(np.float)
136 |         j = np.arange(1, rows + 1).astype(np.float)
137 | X = round(np.sum(np.sum(GT, axis=0) * i) / total)
138 | Y = round(np.sum(np.sum(GT, axis=1) * j) / total)
139 | return int(X), int(Y)
140 |
141 | def divideGT(GT, X, Y):
142 | """
143 | LT - left top;
144 | RT - right top;
145 | LB - left bottom;
146 | RB - right bottom;
147 | """
148 | # width and height of the GT
149 | hei, wid = GT.shape
150 | area = float(wid * hei)
151 |
152 | # copy 4 regions
153 | LT = GT[0:Y, 0:X]
154 | RT = GT[0:Y, X:wid]
155 | LB = GT[Y:hei, 0:X]
156 | RB = GT[Y:hei, X:wid]
157 |
158 |     # The different weights (each proportional to the area of its block).
159 | w1 = (X * Y) / area
160 | w2 = ((wid - X) * Y) / area
161 | w3 = (X * (hei-Y)) / area
162 | w4 = 1.0 - w1 - w2 - w3
163 | return LT, RT, LB, RB, w1, w2, w3, w4
164 |
165 | def Divideprediction(prediction, X, Y):
166 | """
167 | Divide the prediction into 4 regions according to the centroid of the GT
168 | """
169 | hei, wid = prediction.shape
170 | # copy 4 regions
171 | LT = prediction[0:Y, 0:X]
172 | RT = prediction[0:Y, X:wid]
173 | LB = prediction[Y:hei, 0:X]
174 | RB = prediction[Y:hei, X:wid]
175 |
176 | return LT, RT, LB, RB
177 |
178 | def ssim(prediction, GT):
179 | """
180 | ssim computes the region similarity between foreground maps and ground
181 |     truth (as proposed in "Structure-measure: A new way to evaluate foreground
182 |     maps" [Deng-Ping Fan et al. - ICCV 2017])
183 | Usage:
184 | Q = ssim(prediction,GT)
185 | Input:
186 | prediction - Binary/Non binary foreground map with values in the range
187 | [0 1]. Type: np.float32
188 | GT - Binary ground truth. Type: np.bool
189 | Output:
190 | Q - The region similarity score
191 | """
192 | dGT = GT.astype(np.float32)
193 |
194 | hei, wid = prediction.shape
195 | N = float(wid * hei)
196 |
197 | # Compute the mean of SM,GT
198 | x = np.mean(prediction)
199 | y = np.mean(dGT)
200 |
201 | # Compute the variance of SM,GT
202 | dx = prediction - x
203 | dy = dGT - y
204 | total = N - 1 + eps
205 | sigma_x2 = np.sum(dx * dx) / total
206 | sigma_y2 = np.sum(dy * dy) / total
207 |
208 | # Compute the covariance between SM and GT
209 | sigma_xy = np.sum(dx * dy) / total
210 |
211 | alpha = 4 * x * y * sigma_xy
212 | beta = (x * x + y * y) * (sigma_x2 + sigma_y2)
213 |
214 | if alpha != 0:
215 | Q = alpha / (beta + eps)
216 | elif beta == 0:
217 | Q = 1.0
218 | else:
219 | Q = 0
220 | return Q
221 |
222 |
--------------------------------------------------------------------------------
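Note: a small, self-contained sketch of scoring a saliency map with the S-measure implementation above, using random data purely for illustration. The type checks in metric.py rely on the legacy np.bool alias, so an older NumPy (< 1.24) is assumed.

import numpy as np
from libs.utils.metric import StructureMeasure

rng = np.random.RandomState(0)
prediction = rng.rand(224, 224).astype(np.float32)  # saliency map in [0, 1]
gt = np.zeros((224, 224), dtype=bool)               # binary ground-truth mask
gt[64:160, 64:160] = True

score = StructureMeasure(prediction, gt)
print("S-measure: {:.4f}".format(score))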
/libs/utils/pyt_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: speedinghzl
5 | # URL: https://github.com/speedinghzl/pytorch-segmentation-toolbox
6 |
7 | import time
8 | import logging
9 |
10 | import torch
11 |
12 | from .logger import get_logger
13 |
14 | logger = get_logger()
15 |
16 | def load_model(model, model_file, is_restore=False):
17 | t_start = time.time()
18 | if isinstance(model_file, str):
19 | device = torch.device('cpu')
20 | state_dict = torch.load(model_file, map_location=device)
21 | if 'state_dict' in state_dict.keys():
22 | state_dict = state_dict['state_dict']
23 | else:
24 | state_dict = model_file
25 | t_ioend = time.time()
26 |
27 | if not is_restore:
28 | # extend the input channels of FGPLG from 3 to 7
29 | v2 = model.backbone.resnet.conv1.weight
30 | if v2.size(1) > 3:
31 | v = state_dict['backbone.resnet.conv1.weight']
32 | v = torch.cat((v,v2[:,3:,:,:]), dim=1)
33 | state_dict['backbone.resnet.conv1.weight'] = v
34 |
35 | model.load_state_dict(state_dict, strict=False)
36 | ckpt_keys = set(state_dict.keys())
37 | own_keys = set(model.state_dict().keys())
38 | missing_keys = own_keys - ckpt_keys
39 | unexpected_keys = ckpt_keys - own_keys
40 |
41 | if len(missing_keys) > 0:
42 | logger.warning('Missing key(s) in state_dict: {}'.format(
43 | ', '.join('{}'.format(k) for k in missing_keys)))
44 |
45 | if len(unexpected_keys) > 0:
46 | logger.warning('Unexpected key(s) in state_dict: {}'.format(
47 | ', '.join('{}'.format(k) for k in unexpected_keys)))
48 |
49 | del state_dict
50 | t_end = time.time()
51 | logger.info(
52 | "Load model, Time usage:\n\tIO: {}, initialize parameters: {}".format(
53 | t_ioend - t_start, t_end - t_ioend))
54 |
55 | return model
--------------------------------------------------------------------------------
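Note: a minimal sketch of restoring weights with load_model above, using a toy module and is_restore=True (which skips the FGPLG-specific extension of backbone.resnet.conv1 from 3 to 7 input channels). The checkpoint path is illustrative only.

import os
import tempfile

import torch
import torch.nn as nn

from libs.utils.pyt_utils import load_model

def make_net():
    # Toy network standing in for a real checkpointed model.
    return nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 1, 1))

ckpt_path = os.path.join(tempfile.gettempdir(), "toy_checkpoint.pth")
torch.save(make_net().state_dict(), ckpt_path)  # stand-in for a real checkpoint file

restored = load_model(make_net(), ckpt_path, is_restore=True)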
/models/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/models/.gitkeep
--------------------------------------------------------------------------------
/models/checkpoints/.gitkeep:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Kinpzz/RCRNet-Pytorch/8d9f0fe0c7ad651db7578b2d96741de11036ef82/models/checkpoints/.gitkeep
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from __future__ import absolute_import, division, print_function
8 | import os
9 |
10 | import torch
11 | import torch.nn as nn
12 | from torch.utils import data
13 |
14 | import argparse
15 | from tqdm import tqdm
16 | import numpy as np
17 |
18 | from libs.datasets import get_transforms, get_datasets
19 | from libs.networks import VideoModel
20 | from libs.utils.metric import StructureMeasure
21 | from libs.utils.pyt_utils import load_model
22 |
23 | parser = argparse.ArgumentParser()
24 |
25 | # Dataloading-related settings
26 | parser.add_argument('--data', type=str, default='data/datasets/',
27 | help='path to datasets folder')
28 | parser.add_argument('--checkpoint', default='models/image_pretrained_model.pth',
29 | help='path to the pretrained checkpoint')
30 | parser.add_argument('--dataset-config', default='config/datasets.yaml',
31 | help='dataset config file')
32 | parser.add_argument('--save-folder', default='models/checkpoints',
33 | help='location to save checkpoint models')
34 | parser.add_argument('--pseudo-label-folder', default='',
35 | help='location to load pseudo-labels')
36 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N',
37 | help='number of data loading workers.')
38 |
39 | # Training procedure settings
40 | parser.add_argument('--batch-size', default=1, type=int,
41 | help='batch size for each gpu. Only support 1 for video clips.')
42 | parser.add_argument('--backup-epochs', type=int, default=1,
43 | help='iteration epoch to perform state backups')
44 | parser.add_argument('--epochs', type=int, default=50,
45 | help='upper epoch limit')
46 | parser.add_argument('--start-epoch', type=int, default=0,
47 | help='epoch number to resume')
48 | parser.add_argument('--eval-first', default=False, action='store_true',
49 | help='evaluate model weights before training')
50 | parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
51 | help='initial learning rate')
52 |
53 | # Model settings
54 | parser.add_argument('--size', default=448, type=int,
55 | help='image size')
56 | parser.add_argument('--os', default=16, type=int,
57 | help='output stride.')
58 | parser.add_argument("--clip_len", type=int, default=4,
59 | help="the number of frames in a video clip.")
60 |
61 | args = parser.parse_args()
62 |
63 | cuda = torch.cuda.is_available()
64 | device = torch.device("cuda" if cuda else "cpu")
65 |
66 | if cuda:
67 | torch.backends.cudnn.benchmark = True
68 | current_device = torch.cuda.current_device()
69 | print("Running on", torch.cuda.get_device_name(current_device))
70 | else:
71 | print("Running on CPU")
72 |
73 | data_transforms = get_transforms(
74 | input_size=(args.size, args.size),
75 | image_mode=False
76 | )
77 |
78 | train_dataset = get_datasets(
79 | name_list=["DAVIS2016", "FBMS", "VOS"],
80 | split_list=["train", "train", "train"],
81 | config_path=args.dataset_config,
82 | root=args.data,
83 | training=True,
84 | transforms=data_transforms['train'],
85 | read_clip=True,
86 | random_reverse_clip=True,
87 | clip_len=args.clip_len
88 | )
89 | val_dataset = get_datasets(
90 | name_list="VOS",
91 | split_list="val",
92 | config_path=args.dataset_config,
93 | root=args.data,
94 | training=True,
95 | transforms=data_transforms['val'],
96 | read_clip=True,
97 | random_reverse_clip=False,
98 | clip_len=args.clip_len
99 | )
100 |
101 | train_dataloader = data.DataLoader(
102 | dataset=train_dataset,
103 | batch_size=args.batch_size,
104 | num_workers=args.num_workers,
105 | shuffle=True,
106 | drop_last=True
107 | )
108 | val_dataloader = data.DataLoader(
109 | dataset=val_dataset,
110 | batch_size=args.batch_size,
111 | num_workers=args.num_workers,
112 | shuffle=False
113 | )
114 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
115 |
116 | # loading pseudo-labels for training
117 | if os.path.exists(args.pseudo_label_folder):
118 | print("Loading pseudo-labels from {}".format(args.pseudo_label_folder))
119 | datasets = dataloaders['train'].dataset
120 | for dataset in datasets.datasets:
121 | dataset._reset_files(clip_len=args.clip_len, label_dir=os.path.join(args.pseudo_label_folder, dataset.name))
122 | if isinstance(datasets, data.ConcatDataset):
123 | datasets.cumulative_sizes = datasets.cumsum(datasets.datasets)
124 |
125 | model = VideoModel(output_stride=args.os)
126 | # load pretrained models
127 | if os.path.exists(args.checkpoint):
128 | print('Loading state dict from: {0}'.format(args.checkpoint))
129 | if args.start_epoch == 0:
130 | model = load_model(model=model, model_file=args.checkpoint, is_restore=False)
131 | else:
132 | model = load_model(model=model, model_file=args.checkpoint, is_restore=True)
133 | else:
134 | raise ValueError("Cannot find model file at {}".format(args.checkpoint))
135 |
136 | model = nn.DataParallel(model)
137 | model.to(device)
138 |
139 | criterion = nn.BCEWithLogitsLoss()
140 |
141 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.module.parameters()), lr=args.lr)
142 |
143 | if not os.path.exists(args.save_folder):
144 | os.makedirs(args.save_folder)
145 |
146 | def train():
147 |
148 | best_smeasure = 0.0
149 | best_epoch = 0
150 | for epoch in range(args.start_epoch, args.epochs+1):
151 | # Each epoch has a training and validation phase
152 | if args.eval_first:
153 | phases = ['val']
154 | else:
155 | phases = ['train', 'val']
156 |
157 | for phase in phases:
158 | if phase == 'train':
159 | model.train() # Set model to training mode
160 | model.module.freeze_bn()
161 | else:
162 | model.eval() # Set model to evaluate mode
163 |
164 | running_loss = 0.0
165 | running_mae = 0.0
166 | running_smean = 0.0
167 | print("{} epoch {}...".format(phase, epoch))
168 | # Iterate over data.
169 | for data in tqdm(dataloaders[phase]):
170 | images, labels = [], []
171 | for frame in data:
172 | images.append(frame['image'].to(device))
173 | labels.append(frame['label'].to(device))
174 | # zero the parameter gradients
175 | optimizer.zero_grad()
176 |             # track history only if in train
177 | with torch.set_grad_enabled(phase == 'train'):
178 | # read clips
179 | preds = model(images)
180 | loss = []
181 | for pred, label in zip(preds, labels):
182 | loss.append(criterion(pred, label))
183 | # backward + optimize only if in training phase
184 | if phase == 'train':
185 | torch.autograd.backward(loss)
186 | optimizer.step()
187 | # statistics
188 | for _loss in loss:
189 | running_loss += _loss.item()
190 | preds = [torch.sigmoid(pred) for pred in preds] # activation
191 |
192 |                 # iterate over the frames in the clip
193 |                 for i, (label_, pred_) in enumerate(zip(labels, preds)):
194 |                     # iterate over the batch
195 | for j, (label, pred) in enumerate(zip(label_.detach().cpu(), pred_.detach().cpu())):
196 | pred_idx = pred[0,:,:].numpy()
197 | label_idx = label[0,:,:].numpy()
198 | if phase == 'val':
199 | running_smean += StructureMeasure(pred_idx.astype(np.float32), (label_idx>=0.5).astype(np.bool))
200 | running_mae += np.abs(pred_idx - label_idx).mean()
201 |
202 | samples_num = len(dataloaders[phase].dataset)
203 | samples_num *= args.clip_len
204 | epoch_loss = running_loss / samples_num
205 | epoch_mae = running_mae / samples_num
206 | print('{} Loss: {:.4f}'.format(phase, epoch_loss))
207 | print('{} MAE: {:.4f}'.format(phase, epoch_mae))
208 |
209 | # save current best epoch
210 | if phase == 'val':
211 | epoch_smeasure = running_smean / samples_num
212 | print('{} S-measure: {:.4f}'.format(phase, epoch_smeasure))
213 | if epoch_smeasure > best_smeasure:
214 | best_smeasure = epoch_smeasure
215 | best_epoch = epoch
216 | model_path = os.path.join(args.save_folder, "video_current_best_model.pth")
217 | print("Saving current best model at: {}".format(model_path) )
218 | torch.save(
219 | model.module.state_dict(),
220 | model_path,
221 | )
222 | if epoch > 0 and epoch % args.backup_epochs == 0:
223 | # save model
224 | model_path = os.path.join(args.save_folder, "video_epoch-{}.pth".format(epoch))
225 | print("Backup model at: {}".format(model_path))
226 | torch.save(
227 | model.module.state_dict(),
228 | model_path,
229 | )
230 |
231 | print('Best S-measure: {} at epoch {}'.format(best_smeasure, best_epoch))
232 |
233 | if __name__ == "__main__":
234 | train()
235 |
--------------------------------------------------------------------------------
/train_fgplg.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | #
4 | # Author: Pengxiang Yan
5 | # Email: yanpx (at) mail2.sysu.edu.cn
6 |
7 | from __future__ import absolute_import, division, print_function
8 | import os
9 | import sys
10 | sys.path.append('flownet2')
11 |
12 | import torch
13 | import torch.nn as nn
14 | from torch.utils import data
15 |
16 | import argparse
17 | from tqdm import tqdm
18 | import numpy as np
19 |
20 | from libs.datasets import get_transforms, get_datasets
21 | from libs.networks.pseudo_label_generator import FGPLG
22 | from libs.utils.metric import StructureMeasure
23 | from libs.utils.pyt_utils import load_model
24 |
25 | parser = argparse.ArgumentParser()
26 |
27 | # Dataloading-related settings
28 | parser.add_argument('--data', type=str, default='data/datasets/',
29 | help='path to datasets folder')
30 | parser.add_argument('--checkpoint', default='models/image_pretrained_model.pth',
31 | help='path to the pretrained checkpoint')
32 | parser.add_argument('--flownet-checkpoint', default='models/FlowNet2_checkpoint.pth.tar',
33 | help='path to the checkpoint of pretrained flownet2')
34 | parser.add_argument('--dataset-config', default='config/datasets.yaml',
35 | help='dataset config file')
36 | parser.add_argument('--save-folder', default='models/checkpoints',
37 | help='location to save checkpoint models')
38 | parser.add_argument("--label_interval", default=5, type=int,
39 | help="the interval of ground truth labels")
40 | parser.add_argument('-j', '--num_workers', default=1, type=int, metavar='N',
41 | help='number of data loading workers.')
42 |
43 | # Training procedure settings
44 | parser.add_argument('--batch-size', default=1, type=int,
45 | help='batch size for each gpu. Only support 1 for video clips.')
46 | parser.add_argument('--backup-epochs', type=int, default=1,
47 | help='iteration epoch to perform state backups')
48 | parser.add_argument('--epochs', type=int, default=50,
49 | help='upper epoch limit')
50 | parser.add_argument('--start-epoch', type=int, default=0,
51 | help='epoch number to resume')
52 | parser.add_argument('--eval-first', default=False, action='store_true',
53 | help='evaluate model weights before training')
54 | parser.add_argument('--lr', '--learning-rate', default=1e-5, type=float,
55 | help='initial learning rate')
56 |
57 | # Model settings
58 | parser.add_argument('--size', default=448, type=int,
59 | help='image size')
60 | parser.add_argument('--os', default=16, type=int,
61 | help='output stride.')
62 | parser.add_argument("--clip_len", type=int, default=3,
63 | help="the number of frames in a video clip.")
64 |
65 | # FlowNet setting
66 | parser.add_argument("--fp16", action="store_true",
67 | help="Run model in pseudo-fp16 mode (fp16 storage fp32 math).")
68 | parser.add_argument("--rgb_max", type=float, default=1.)
69 |
70 |
71 | args = parser.parse_args()
72 |
73 | cuda = torch.cuda.is_available()
74 | device = torch.device("cuda" if cuda else "cpu")
75 |
76 | if cuda:
77 | torch.backends.cudnn.benchmark = True
78 | current_device = torch.cuda.current_device()
79 | print("Running on", torch.cuda.get_device_name(current_device))
80 | else:
81 | print("Running on CPU")
82 |
83 | data_transforms = get_transforms(
84 | input_size=(args.size, args.size),
85 | image_mode=False
86 | )
87 |
88 | train_dataset = get_datasets(
89 | name_list=["DAVIS2016", "FBMS", "VOS"],
90 | split_list=["train", "train", "train"],
91 | config_path=args.dataset_config,
92 | root=args.data,
93 | training=True,
94 | transforms=data_transforms['train'],
95 | read_clip=True,
96 | random_reverse_clip=True,
97 | label_interval=args.label_interval,
98 | clip_len=args.clip_len
99 | )
100 | val_dataset = get_datasets(
101 | name_list="VOS",
102 | split_list="val",
103 | config_path=args.dataset_config,
104 | root=args.data,
105 | training=True,
106 | transforms=data_transforms['val'],
107 | read_clip=True,
108 | random_reverse_clip=False,
109 | label_interval=args.label_interval,
110 | clip_len=args.clip_len
111 | )
112 |
113 | train_dataloader = data.DataLoader(
114 | dataset=train_dataset,
115 | batch_size=args.batch_size,
116 | num_workers=args.num_workers,
117 | shuffle=True,
118 | drop_last=True
119 | )
120 | val_dataloader = data.DataLoader(
121 | dataset=val_dataset,
122 | batch_size=args.batch_size,
123 | num_workers=args.num_workers,
124 | shuffle=False
125 | )
126 | dataloaders = {'train': train_dataloader, 'val': val_dataloader}
127 |
128 | pseudo_label_generator = FGPLG(args=args, output_stride=args.os)
129 |
130 | if os.path.exists(args.checkpoint):
131 | print('Loading state dict from: {0}'.format(args.checkpoint))
132 | if args.start_epoch == 0:
133 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=False)
134 | if os.path.exists(args.flownet_checkpoint):
135 | pseudo_label_generator.flownet = load_model(model=pseudo_label_generator.flownet, model_file=args.flownet_checkpoint, is_restore=True)
136 | else:
137 |             raise ValueError("Cannot find pretrained flownet model file at {}".format(args.flownet_checkpoint))
138 | else:
139 | pseudo_label_generator = load_model(model=pseudo_label_generator, model_file=args.checkpoint, is_restore=True)
140 | else:
141 | raise ValueError("Cannot find model file at {}".format(args.checkpoint))
142 |
143 | pseudo_label_generator = nn.DataParallel(pseudo_label_generator)
144 | pseudo_label_generator.to(device)
145 |
146 | criterion = nn.BCEWithLogitsLoss()
147 |
148 | optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, pseudo_label_generator.module.parameters()), lr=args.lr)
149 |
150 | if not os.path.exists(args.save_folder):
151 | os.makedirs(args.save_folder)
152 |
153 | def train():
154 | best_smeasure = 0.0
155 | best_epoch = 0
156 | for epoch in range(args.start_epoch, args.epochs+1):
157 | # Each epoch has a training and validation phase
158 | if args.eval_first:
159 | phases = ['val']
160 | else:
161 | phases = ['train', 'val']
162 |
163 | for phase in phases:
164 | if phase == 'train':
165 | pseudo_label_generator.train() # Set model to training mode
166 | pseudo_label_generator.module.freeze_bn()
167 | else:
168 | pseudo_label_generator.eval() # Set model to evaluate mode
169 |
170 | running_loss = 0.0
171 | running_acc = 0.0
172 | running_iou = 0.0
173 | running_mae = 0.0
174 | running_smean = 0.0
175 | print("{} epoch {}...".format(phase, epoch))
176 | # Iterate over data.
177 | for data in tqdm(dataloaders[phase]):
178 | images, labels = [], []
179 | for frame in data:
180 | images.append(frame['image'].to(device))
181 | labels.append(frame['label'].to(device))
182 | # zero the parameter gradients
183 | optimizer.zero_grad()
184 |             # track history only if in train
185 | with torch.set_grad_enabled(phase == 'train'):
186 | # read clips
187 | preds, labels = pseudo_label_generator(images, labels)
188 | loss = criterion(preds, labels)
189 | # backward + optimize only if in training phase
190 | if phase == 'train':
191 | torch.autograd.backward(loss)
192 | optimizer.step()
193 | # statistics
194 | running_loss += loss.item()
195 | preds = torch.sigmoid(preds) # activation
196 |
197 | pred_idx = preds.squeeze().detach().cpu().numpy()
198 | label_idx = labels.squeeze().detach().cpu().numpy()
199 | if phase == 'val':
200 | running_smean += StructureMeasure(pred_idx.astype(np.float32), (label_idx>=0.5).astype(np.bool))
201 | running_mae += np.abs(pred_idx - label_idx).mean()
202 |
203 | samples_num = len(dataloaders[phase].dataset)
204 | epoch_loss = running_loss / samples_num
205 | epoch_mae = running_mae / samples_num
206 | print('{} Loss: {:.4f}'.format(phase, epoch_loss))
207 | print('{} MAE: {:.4f}'.format(phase, epoch_mae))
208 |
209 | # save current best epoch
210 | if phase == 'val':
211 | epoch_smeasure = running_smean / samples_num
212 | print('{} S-measure: {:.4f}'.format(phase, epoch_smeasure))
213 | if epoch_smeasure > best_smeasure:
214 | best_smeasure = epoch_smeasure
215 | best_epoch = epoch
216 | model_path = os.path.join(args.save_folder, "fgplg_current_best_model.pth")
217 | print("Saving current best model at: {}".format(model_path) )
218 | torch.save(
219 | pseudo_label_generator.module.state_dict(),
220 | model_path,
221 | )
222 | if epoch > 0 and epoch % args.backup_epochs == 0:
223 | # save model
224 | model_path = os.path.join(args.save_folder, "fgplg_epoch-{}.pth".format(epoch))
225 | print("Backup model at: {}".format(model_path))
226 | torch.save(
227 | pseudo_label_generator.module.state_dict(),
228 | model_path,
229 | )
230 |
231 | print('Best S-measure: {} at epoch {}'.format(best_smeasure, best_epoch))
232 |
233 | if __name__ == "__main__":
234 | train()
235 |
--------------------------------------------------------------------------------