├── .gitignore ├── LEStereo.png ├── LICENSE ├── README.md ├── config_utils ├── decode_args.py ├── predict_args.py ├── search_args.py └── train_args.py ├── create_link.sh ├── dataloaders ├── __init__.py ├── datasets │ ├── __init__.py │ └── stereo.py └── lists │ ├── kitti2012_test.list │ ├── kitti2012_train.list │ ├── kitti2012_train170.list │ ├── kitti2012_val24.list │ ├── kitti2015_test.list │ ├── kitti2015_train.list │ ├── kitti2015_train180.list │ ├── kitti2015_val20.list │ ├── middeval3_test.list │ ├── middeval3_train.list │ ├── sceneflow_search_trainA.list │ ├── sceneflow_search_trainB.list │ ├── sceneflow_search_val.list │ ├── sceneflow_test.list │ └── sceneflow_train.list ├── dataset ├── MiddEval3 ├── SceneFlow ├── kitti2012 └── kitti2015 ├── decode.py ├── decode.sh ├── models ├── __init__.py ├── build_model.py ├── build_model_2d.py ├── build_model_3d.py ├── cell_level_search_2d.py ├── cell_level_search_3d.py ├── decoding_formulas.py ├── genotypes_2d.py ├── genotypes_3d.py ├── operations_2d.py └── operations_3d.py ├── mypath.py ├── predict.py ├── predict_kitti12.sh ├── predict_kitti15.sh ├── predict_md.sh ├── predict_sf.sh ├── retrain ├── LEAStereo.py ├── new_model_2d.py ├── new_model_3d.py └── skip_model_3d.py ├── run ├── Kitti12 │ └── best │ │ └── best_1.16.pth ├── Kitti15 │ └── best │ │ ├── best.pth │ │ └── kitti15_best_1.65.pth ├── MiddEval3 │ └── best │ │ └── best.pth └── sceneflow │ ├── best │ ├── architecture │ │ ├── feature_genotype.npy │ │ ├── feature_network_path.npy │ │ ├── matching_genotype.npy │ │ └── matching_network_path.npy │ └── checkpoint │ │ └── best.pth │ └── experiment │ ├── feature_genotype.npy │ ├── feature_network_path.npy │ ├── matching_genotype.npy │ └── matching_network_path.npy ├── search.py ├── search.sh ├── thop ├── __init__.py ├── count_hooks.py ├── profile.py └── utils.py ├── train.py ├── train_kitti12.sh ├── train_kitti15.sh ├── train_md.sh ├── train_sf.sh └── utils ├── colorize.py ├── copy_state_dict.py ├── lr_scheduler.py ├── multadds_count.py ├── saver.py └── summaries.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | **/__pycache__/ 4 | *.py[cod] 5 | *$py.class 6 | /predict/* 7 | /apex/* 8 | /run/* 9 | 10 | # Packages # 11 | ############ 12 | # it's better to unpack these files and commit the raw source 13 | # git has its own built in compression methods 14 | *.7z 15 | *.dmg 16 | *.gz 17 | *.iso 18 | *.jar 19 | *.rar 20 | *.tar 21 | *.zip 22 | 23 | # C extensions 24 | *.so 25 | 26 | # Distribution / packaging 27 | .Python 28 | build/ 29 | develop-eggs/ 30 | dist/ 31 | downloads/ 32 | eggs/ 33 | .eggs/ 34 | lib/ 35 | lib64/ 36 | parts/ 37 | sdist/ 38 | var/ 39 | wheels/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | cover/ 70 | 71 | # Translations 72 | *.mo 73 | *.pot 74 | 75 | # Django stuff: 76 | *.log 77 | local_settings.py 78 | db.sqlite3 79 | db.sqlite3-journal 80 | 81 | # Flask stuff: 82 | instance/ 83 | .webassets-cache 84 | 85 | # Scrapy stuff: 86 | .scrapy 87 | 88 | # Sphinx documentation 89 | docs/_build/ 90 | 91 | # PyBuilder 92 | .pybuilder/ 93 | target/ 94 | 95 | # Jupyter Notebook 96 | .ipynb_checkpoints 97 | 98 | # IPython 99 | profile_default/ 100 | ipython_config.py 101 | 102 | # pyenv 103 | # For a library or package, you might want to ignore these files since the code is 104 | # intended to run in multiple environments; otherwise, check them in: 105 | # .python-version 106 | 107 | # pipenv 108 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 109 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 110 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 111 | # install all needed dependencies. 112 | #Pipfile.lock 113 | 114 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 115 | __pypackages__/ 116 | 117 | # Celery stuff 118 | celerybeat-schedule 119 | celerybeat.pid 120 | 121 | # SageMath parsed files 122 | *.sage.py 123 | 124 | # Environments 125 | .env 126 | .venv 127 | env/ 128 | venv/ 129 | ENV/ 130 | env.bak/ 131 | venv.bak/ 132 | 133 | # Spyder project settings 134 | .spyderproject 135 | .spyproject 136 | 137 | # Rope project settings 138 | .ropeproject 139 | 140 | # mkdocs documentation 141 | /site 142 | 143 | # mypy 144 | .mypy_cache/ 145 | .dmypy.json 146 | dmypy.json 147 | 148 | # Pyre type checker 149 | .pyre/ 150 | 151 | # pytype static type analyzer 152 | .pytype/ 153 | 154 | # Cython debug symbols 155 | cython_debug/ 156 | -------------------------------------------------------------------------------- /LEStereo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/LEStereo.png -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Xuelian Cheng 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## LEAStereo 2 | 3 | This repository contains the code for our NeurIPS 2020 paper `Hierarchical Neural Architecture Search for Deep Stereo Matching` [[NeurIPS 20](https://proceedings.neurips.cc/paper/2020/file/fc146be0b230d7e0a92e66a6114b840d-Paper.pdf)] 4 | 5 | ![alt text](./LEStereo.png) 6 | 7 | ## Requirements 8 | 9 | ### Environment 10 | 11 | 1. Python 3.8.* 12 | 2. CUDA 10.2 13 | 3. PyTorch 14 | 4. TorchVision 15 | 16 | ### Install 17 | Create a virtual environment and activate it. 18 | ```shell 19 | conda create -n leastereo python=3.8 20 | conda activate leastereo 21 | ``` 22 | The code has been tested with PyTorch 1.6 and CUDA 10.2. 23 | ```shell 24 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.2 -c pytorch 25 | conda install matplotlib path.py tqdm 26 | conda install tensorboard tensorboardX 27 | conda install scipy scikit-image opencv 28 | ``` 29 | 30 | Install NVIDIA Apex 31 | 32 | 33 | Follow the instructions [here](https://github.com/NVIDIA/apex#quick-start). Apex is required for mixed precision training. 34 | Please do not use `pip install apex`; this will not install the correct package. 35 | 36 | ### Dataset 37 | To evaluate or train our LEAStereo network, you will need to download the required datasets. 38 | 39 | * [SceneFlow](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 40 | 41 | * [KITTI2015](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo) 42 | 43 | * [KITTI2012](http://www.cvlibs.net/datasets/kitti/eval_stereo_flow.php?benchmark=stereo) 44 | 45 | * [Middlebury 2014](https://vision.middlebury.edu/stereo/submit3/) 46 | 47 | Replace the first-column paths in `create_link.sh` with the actual locations of your downloaded datasets. Then run `create_link.sh`; it creates symbolic links to the downloaded datasets inside the `dataset` folder. For the Middlebury 2014 dataset, we run our network on half-resolution images. 48 | 49 | 50 | ```Shell 51 | ├── dataset 52 | ├── SceneFlow 53 | ├── camera_data 54 | ├── disparity 55 | ├── frames_finalpass 56 | ├── kitti2012 57 | ├── testing 58 | ├── training 59 | ├── kitti2015 60 | ├── testing 61 | ├── training 62 | ├── MiddEval3 63 | ├── testH 64 | ├── trainingH 65 | ``` 66 | 67 | ### Prediction 68 | 69 | You can evaluate a trained model using the per-dataset `predict_*.sh` scripts, which generate the *.png or *.pfm disparity maps corresponding to each dataset. 70 | ```shell 71 | sh predict_sf.sh 72 | sh predict_md.sh 73 | sh predict_kitti12.sh 74 | sh predict_kitti15.sh 75 | ``` 76 | Results of our model on the three benchmark datasets can also be found [here](https://drive.google.com/file/d/1Wcv-WzQToTwAiBfWpONrtyQSgsHrWqWC/view?usp=sharing) 77 | 78 | 
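If you prefer to call `predict.py` directly instead of the shell scripts, its command-line flags are defined in `config_utils/predict_args.py`. The sketch below is only illustrative: the flags, dataset paths, crop size and checkpoint locations are taken from this repository (`config_utils/predict_args.py`, `dataloaders/__init__.py` and the files under `run/sceneflow/best/`), but the pairing of the architecture `.npy` files with `--net_arch_*`/`--cell_arch_*` is an assumption on our part, so defer to `predict_sf.sh` for the exact invocation the authors use.

```shell
# Hypothetical SceneFlow prediction run; see predict_sf.sh for the authors' actual flags.
python predict.py --sceneflow 1 \
                  --crop_height 576 --crop_width 960 --maxdisp 192 \
                  --data_path ./dataset/SceneFlow/ \
                  --test_list ./dataloaders/lists/sceneflow_test.list \
                  --save_path ./predict/sceneflow/ \
                  --resume ./run/sceneflow/best/checkpoint/best.pth \
                  --net_arch_fea ./run/sceneflow/best/architecture/feature_network_path.npy \
                  --cell_arch_fea ./run/sceneflow/best/architecture/feature_genotype.npy \
                  --net_arch_mat ./run/sceneflow/best/architecture/matching_network_path.npy \
                  --cell_arch_mat ./run/sceneflow/best/architecture/matching_genotype.npy
```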
79 | ### Architecture Search 80 | Three steps for the architecture search: 81 | 82 | #### 1. Search 83 | ```shell 84 | sh search.sh 85 | ``` 86 | #### 2. Decode 87 | ```shell 88 | sh decode.sh 89 | ``` 90 | #### 3. Retrain 91 | ```shell 92 | sh train_sf.sh 93 | sh train_md.sh 94 | sh train_kitti12.sh 95 | sh train_kitti15.sh 96 | ``` 97 | 98 | ### Acknowledgements 99 | This repository makes liberal use of code from [[AutoDeeplab](https://openaccess.thecvf.com/content_CVPR_2019/html/Liu_Auto-DeepLab_Hierarchical_Neural_Architecture_Search_for_Semantic_Image_Segmentation_CVPR_2019_paper.html)] [[pytorch code](https://github.com/NoamRosenberg/autodeeplab)(Non-official)]. 100 | 101 | ### Citing 102 | If you find this code useful, please consider citing our work. 103 | 104 | ``` 105 | @article{cheng2020hierarchical, 106 | title={Hierarchical Neural Architecture Search for Deep Stereo Matching}, 107 | author={Cheng, Xuelian and Zhong, Yiran and Harandi, Mehrtash and Dai, Yuchao and Chang, Xiaojun and Li, Hongdong and Drummond, Tom and Ge, Zongyuan}, 108 | journal={Advances in Neural Information Processing Systems}, 109 | volume={33}, 110 | year={2020} 111 | } 112 | ``` 113 | 114 | -------------------------------------------------------------------------------- /config_utils/decode_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def obtain_decode_args(): 4 | parser = argparse.ArgumentParser(description="LEStereo Decoding..") 5 | parser.add_argument('--dataset', type=str, default='sceneflow', 6 | choices=['sceneflow', 'kitti15', 'kitti12', 'middlebury'], 7 | help='dataset name (default: sceneflow)') 8 | parser.add_argument('--step', type=int, default=3) 9 | parser.add_argument('--resume', type=str, default=None, 10 | help='put the path to resuming file if needed') 11 | return parser.parse_args() 12 | -------------------------------------------------------------------------------- /config_utils/predict_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def obtain_predict_args(): 4 | 5 | parser = argparse.ArgumentParser(description='LEStereo Prediction') 6 | parser.add_argument('--crop_height', type=int, required=True, help="crop height") 7 | parser.add_argument('--crop_width', type=int, required=True, help="crop width") 8 | parser.add_argument('--maxdisp', type=int, default=192, help="max disp") 9 | parser.add_argument('--resume', type=str, default='', help="resume from saved model") 10 | parser.add_argument('--cuda', type=bool, default=True, help='use cuda?') 11 | parser.add_argument('--sceneflow', type=int, default=0, help='sceneflow dataset? Default=False') 12 | parser.add_argument('--kitti2012', type=int, default=0, help='kitti 2012? Default=False') 13 | parser.add_argument('--kitti2015', type=int, default=0, help='kitti 2015? Default=False') 14 | parser.add_argument('--middlebury', type=int, default=0, help='Middlebury? 
Default=False') 15 | parser.add_argument('--data_path', type=str, required=True, help="data root") 16 | parser.add_argument('--test_list', type=str, required=True, help="training list") 17 | parser.add_argument('--save_path', type=str, default='./result/', help="location to save result") 18 | ######### LEStereo params#################### 19 | parser.add_argument('--fea_num_layers', type=int, default=6) 20 | parser.add_argument('--mat_num_layers', type=int, default=12) 21 | parser.add_argument('--fea_filter_multiplier', type=int, default=8) 22 | parser.add_argument('--mat_filter_multiplier', type=int, default=8) 23 | parser.add_argument('--fea_block_multiplier', type=int, default=4) 24 | parser.add_argument('--mat_block_multiplier', type=int, default=4) 25 | parser.add_argument('--fea_step', type=int, default=3) 26 | parser.add_argument('--mat_step', type=int, default=3) 27 | parser.add_argument('--net_arch_fea', default=None, type=str) 28 | parser.add_argument('--cell_arch_fea', default=None, type=str) 29 | parser.add_argument('--net_arch_mat', default=None, type=str) 30 | parser.add_argument('--cell_arch_mat', default=None, type=str) 31 | 32 | args = parser.parse_args() 33 | return args 34 | -------------------------------------------------------------------------------- /config_utils/search_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def obtain_search_args(): 4 | parser = argparse.ArgumentParser(description="LEStereo Searching...") 5 | parser.add_argument('--clean-module', type=int, default=0) 6 | parser.add_argument('--dataset', type=str, default='sceneflow', 7 | choices=['sceneflow', 'kitti15', 'kitti12', 'middlebury'], 8 | help='dataset name (default: sceneflow)') 9 | parser.add_argument('--stage', type=str, default='search', 10 | choices=['search', 'train']) 11 | parser.add_argument('--fea_num_layers', type=int, default=6) 12 | parser.add_argument('--mat_num_layers', type=int, default=12) 13 | parser.add_argument('--fea_filter_multiplier', type=int, default=8) 14 | parser.add_argument('--mat_filter_multiplier', type=int, default=8) 15 | parser.add_argument('--fea_block_multiplier', type=int, default=4) 16 | parser.add_argument('--mat_block_multiplier', type=int, default=4) 17 | parser.add_argument('--fea_step', type=int, default=2) 18 | parser.add_argument('--mat_step', type=int, default=2) 19 | parser.add_argument('--workers', type=int, default=0, 20 | metavar='N', help='dataloader threads') 21 | parser.add_argument('--max_disp', type=int, default=192, help="max disp") 22 | parser.add_argument('--crop_height', type=int, default=384, 23 | help="crop height") 24 | parser.add_argument('--crop_width', type=int, default=576, 25 | help="crop width") 26 | parser.add_argument('--freeze-bn', type=bool, default=False, 27 | help='whether to freeze bn parameters (default: False)') 28 | # training hyper params 29 | parser.add_argument('--epochs', type=int, default=None, metavar='N', 30 | help='number of epochs to train (default: auto)') 31 | parser.add_argument('--start_epoch', type=int, default=0, 32 | metavar='N', help='start epochs (default:0)') 33 | parser.add_argument('--alpha_epoch', type=int, default=10, 34 | metavar='N', help='epoch to start training alphas') 35 | parser.add_argument('--batch-size', type=int, default=2, 36 | metavar='N', help='input batch size for \ 37 | training (default: auto)') 38 | parser.add_argument('--testBatchSize', type=int, default=1, 39 | metavar='N', help='input batch size for \ 40 | 
testing (default: auto)') 41 | # optimizer params 42 | parser.add_argument('--lr', type=float, default=0.025, metavar='LR', 43 | help='learning rate (default: auto)') 44 | parser.add_argument('--min_lr', type=float, default=0.001) 45 | parser.add_argument('--arch-lr', type=float, default=1e-3, metavar='LR', 46 | help='learning rate for alpha and beta in architect searching process') 47 | 48 | parser.add_argument('--lr-scheduler', type=str, default='cos', 49 | choices=['poly', 'step', 'cos'], 50 | help='lr scheduler mode') 51 | parser.add_argument('--momentum', type=float, default=0.9, 52 | metavar='M', help='momentum (default: 0.9)') 53 | parser.add_argument('--weight-decay', type=float, default=3e-4, 54 | metavar='M', help='w-decay (default: 5e-4)') 55 | parser.add_argument('--arch-weight-decay', type=float, default=1e-3, 56 | metavar='M', help='w-decay (default: 5e-4)') 57 | 58 | parser.add_argument('--nesterov', action='store_true', default=False, 59 | help='whether use nesterov (default: False)') 60 | # cuda, seed and logging 61 | parser.add_argument('--cuda', type=int, default=1, 62 | help='use cuda? Default=True') 63 | parser.add_argument('--gpu-ids', type=str, default='0', 64 | help='use which gpu to train, must be a \ 65 | comma-separated list of integers only (default=0)') 66 | parser.add_argument('--seed', type=int, default=1, metavar='S', 67 | help='random seed (default: 1)') 68 | # finetuning pre-trained models 69 | parser.add_argument('--ft', action='store_true', default=False, 70 | help='finetuning on a different dataset') 71 | # checking point 72 | parser.add_argument('--resume', type=str, default=None, 73 | help='put the path to resuming file if needed') 74 | 75 | parser.add_argument('--no-val', action='store_true', default=False, 76 | help='skip validation during training') 77 | args = parser.parse_args() 78 | return args 79 | -------------------------------------------------------------------------------- /config_utils/train_args.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def obtain_train_args(): 4 | 5 | # Training settings 6 | parser = argparse.ArgumentParser(description='LEStereo training...') 7 | parser.add_argument('--maxdisp', type=int, default=192, 8 | help="max disp") 9 | parser.add_argument('--crop_height', type=int, required=True, 10 | help="crop height") 11 | parser.add_argument('--crop_width', type=int, required=True, 12 | help="crop width") 13 | parser.add_argument('--resume', type=str, default='', 14 | help="resume from saved model") 15 | parser.add_argument('--batch_size', type=int, default=1, 16 | help='training batch size') 17 | parser.add_argument('--testBatchSize', type=int, default=8, 18 | help='testing batch size') 19 | parser.add_argument('--nEpochs', type=int, default=2048, 20 | help='number of epochs to train for') 21 | parser.add_argument('--solver', default='adam',choices=['adam','sgd'], 22 | help='solver algorithms') 23 | parser.add_argument('--lr', type=float, default=0.001, 24 | help='Learning Rate. Default=0.001') 25 | parser.add_argument('--cuda', type=int, default=1, 26 | help='use cuda? Default=True') 27 | parser.add_argument('--threads', type=int, default=1, 28 | help='number of threads for data loader to use') 29 | parser.add_argument('--seed', type=int, default=2019, 30 | help='random seed to use. Default=123') 31 | parser.add_argument('--shift', type=int, default=0, 32 | help='random shift of left image. 
Default=0') 33 | parser.add_argument('--save_path', type=str, default='./checkpoints/', 34 | help="location to save models") 35 | parser.add_argument('--milestones', default=[30,50,300], metavar='N', nargs='*', 36 | help='epochs at which learning rate is divided by 2') 37 | parser.add_argument('--stage', type=str, default='train', choices=['search', 'train']) 38 | parser.add_argument('--dataset', type=str, default='sceneflow', 39 | choices=['sceneflow', 'kitti15', 'kitti12', 'middlebury'], help='dataset name') 40 | 41 | ######### LEStereo params ################## 42 | parser.add_argument('--fea_num_layers', type=int, default=6) 43 | parser.add_argument('--mat_num_layers', type=int, default=12) 44 | parser.add_argument('--fea_filter_multiplier', type=int, default=8) 45 | parser.add_argument('--mat_filter_multiplier', type=int, default=8) 46 | parser.add_argument('--fea_block_multiplier', type=int, default=4) 47 | parser.add_argument('--mat_block_multiplier', type=int, default=4) 48 | parser.add_argument('--fea_step', type=int, default=2) 49 | parser.add_argument('--mat_step', type=int, default=2) 50 | parser.add_argument('--net_arch_fea', default=None, type=str) 51 | parser.add_argument('--cell_arch_fea', default=None, type=str) 52 | parser.add_argument('--net_arch_mat', default=None, type=str) 53 | parser.add_argument('--cell_arch_mat', default=None, type=str) 54 | 55 | args = parser.parse_args() 56 | return args 57 | -------------------------------------------------------------------------------- /create_link.sh: -------------------------------------------------------------------------------- 1 | ln -s /mnt/data/StereoDataset/dataset/kitti2015 ./dataset 2 | ln -s /mnt/data/StereoDataset/dataset/kitti2012 ./dataset 3 | ln -s /mnt/data/StereoDataset/dataset/SceneFlow ./dataset 4 | ln -s /mnt/data/StereoDataset/dataset/MiddEval3 ./dataset 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /dataloaders/__init__.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from dataloaders.datasets import stereo 3 | import pdb 4 | 5 | def make_data_loader(args, **kwargs): 6 | ############################ sceneflow ########################### 7 | if args.dataset == 'sceneflow': 8 | trainA_list= 'dataloaders/lists/sceneflow_search_trainA.list' #randomly select 10,000 from the original training set 9 | trainB_list= 'dataloaders/lists/sceneflow_search_trainB.list' #randomly select 10,000 from the original training set 10 | val_list = 'dataloaders/lists/sceneflow_search_val.list' #randomly select 1,000 from the original test set 11 | train_list = 'dataloaders/lists/sceneflow_train.list' #original training set: 35,454 12 | test_list = 'dataloaders/lists/sceneflow_test.list' #original test set:4,370 13 | trainA_set = stereo.DatasetFromList(args, trainA_list, [args.crop_height, args.crop_width], True) 14 | trainB_set = stereo.DatasetFromList(args, trainB_list, [args.crop_height, args.crop_width], True) 15 | train_set = stereo.DatasetFromList(args, train_list, [args.crop_height, args.crop_width], True) 16 | val_set = stereo.DatasetFromList(args, val_list, [576,960], False) 17 | test_set = stereo.DatasetFromList(args, test_list, [576,960], False) 18 | 19 | if args.stage == 'search': 20 | train_loaderA = DataLoader(trainA_set, batch_size=args.batch_size, shuffle=True, **kwargs) 21 | train_loaderB = DataLoader(trainB_set, batch_size=args.batch_size, shuffle=True, **kwargs) 22 | elif 
args.stage == 'train': 23 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 24 | else: 25 | raise Exception('parameters not set properly') 26 | 27 | val_loader = DataLoader(val_set, batch_size=args.testBatchSize, shuffle=False, **kwargs) 28 | test_loader = DataLoader(test_set, batch_size=args.testBatchSize, shuffle=False, **kwargs) 29 | 30 | if args.stage == 'search': 31 | return train_loaderA, train_loaderB, val_loader, test_loader 32 | elif args.stage == 'train': 33 | return train_loader, test_loader 34 | 35 | ############################ kitti15 ########################### 36 | elif args.dataset == 'kitti15': 37 | train_list= 'dataloaders/lists/kitti2015_train180.list' 38 | test_list = 'dataloaders/lists/kitti2015_val20.list' 39 | train_set = stereo.DatasetFromList(args, train_list, [args.crop_height, args.crop_width], True) 40 | test_set = stereo.DatasetFromList(args, test_list, [384,1248], False) 41 | 42 | train_loader= DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 43 | test_loader = DataLoader(test_set, batch_size=args.testBatchSize, shuffle=False, **kwargs) 44 | return train_loader, test_loader 45 | 46 | ############################ kitti12 ########################### 47 | elif args.dataset == 'kitti12': 48 | train_list= 'dataloaders/lists/kitti2012_train170.list' 49 | test_list = 'dataloaders/lists/kitti2012_val24.list' 50 | train_set = stereo.DatasetFromList(args, train_list, [args.crop_height, args.crop_width], True) 51 | test_set = stereo.DatasetFromList(args, test_list, [384,1248], False) 52 | 53 | train_loader= DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 54 | test_loader = DataLoader(test_set, batch_size=args.testBatchSize, shuffle=False, **kwargs) 55 | return train_loader, test_loader 56 | 57 | ############################ middlebury ########################### 58 | elif args.dataset == 'middlebury': 59 | train_list= 'dataloaders/lists/middeval3_train.list' 60 | test_list = 'dataloaders/lists/middeval3_train.list' 61 | train_set = stereo.DatasetFromList(args, train_list, [args.crop_height, args.crop_width], True) 62 | test_set = stereo.DatasetFromList(args, test_list, [1008,1512], False) 63 | 64 | train_loader = DataLoader(train_set, batch_size=args.batch_size, shuffle=True, **kwargs) 65 | test_loader = DataLoader(test_set, batch_size=args.testBatchSize, shuffle=False, **kwargs) 66 | return train_loader, test_loader 67 | else: 68 | raise NotImplementedError 69 | -------------------------------------------------------------------------------- /dataloaders/datasets/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/dataloaders/datasets/__init__.py -------------------------------------------------------------------------------- /dataloaders/datasets/stereo.py: -------------------------------------------------------------------------------- 1 | import torch.utils.data as data 2 | import skimage 3 | import skimage.io 4 | import skimage.transform 5 | from PIL import Image 6 | import numpy as np 7 | import random 8 | from struct import unpack 9 | import re 10 | import sys 11 | from mypath import Path 12 | import pdb 13 | 14 | def readPFM(file): 15 | with open(file, "rb") as f: 16 | # Line 1: PF=>RGB (3 channels), Pf=>Greyscale (1 channel) 17 | type = f.readline().decode('latin-1') 18 | if "PF" in type: 19 | channels = 3 20 | elif "Pf" in type: 21 
| channels = 1 22 | else: 23 | sys.exit(1) 24 | # Line 2: width height 25 | line = f.readline().decode('latin-1') 26 | width, height = re.findall('\d+', line) 27 | width = int(width) 28 | height = int(height) 29 | # Line 3: +ve number means big endian, negative means little endian 30 | line = f.readline().decode('latin-1') 31 | BigEndian = True 32 | if "-" in line: 33 | BigEndian = False 34 | # Slurp all binary data 35 | samples = width * height * channels; 36 | buffer = f.read(samples * 4) 37 | # Unpack floats with appropriate endianness 38 | if BigEndian: 39 | fmt = ">" 40 | else: 41 | fmt = "<" 42 | fmt = fmt + str(samples) + "f" 43 | img = unpack(fmt, buffer) 44 | img = np.reshape(img, (height, width)) 45 | img = np.flipud(img) 46 | return img, height, width 47 | 48 | def train_transform(temp_data, crop_height, crop_width, left_right=False, shift=0): 49 | _, h, w = np.shape(temp_data) 50 | 51 | if h > crop_height and w <= crop_width: 52 | temp = temp_data 53 | temp_data = np.zeros([8, h+shift, crop_width + shift], 'float32') 54 | temp_data[6:7,:,:] = 1000 55 | temp_data[:, h + shift - h: h + shift, crop_width + shift - w: crop_width + shift] = temp 56 | _, h, w = np.shape(temp_data) 57 | 58 | if h <= crop_height and w <= crop_width: 59 | temp = temp_data 60 | temp_data = np.zeros([8, crop_height + shift, crop_width + shift], 'float32') 61 | temp_data[6: 7, :, :] = 1000 62 | temp_data[:, crop_height + shift - h: crop_height + shift, crop_width + shift - w: crop_width + shift] = temp 63 | _, h, w = np.shape(temp_data) 64 | if shift > 0: 65 | start_x = random.randint(0, w - crop_width) 66 | shift_x = random.randint(-shift, shift) 67 | if shift_x + start_x < 0 or shift_x + start_x + crop_width > w: 68 | shift_x = 0 69 | start_y = random.randint(0, h - crop_height) 70 | left = temp_data[0: 3, start_y: start_y + crop_height, start_x + shift_x: start_x + shift_x + crop_width] 71 | right = temp_data[3: 6, start_y: start_y + crop_height, start_x: start_x + crop_width] 72 | target = temp_data[6: 7, start_y: start_y + crop_height, start_x + shift_x : start_x+shift_x + crop_width] 73 | target = target - shift_x 74 | return left, right, target 75 | if h <= crop_height and w <= crop_width: 76 | temp = temp_data 77 | temp_data = np.zeros([8, crop_height, crop_width], 'float32') 78 | temp_data[:, crop_height - h: crop_height, crop_width - w: crop_width] = temp 79 | else: 80 | start_x = random.randint(0, w - crop_width) 81 | start_y = random.randint(0, h - crop_height) 82 | temp_data = temp_data[:, start_y: start_y + crop_height, start_x: start_x + crop_width] 83 | if random.randint(0, 1) == 0 and left_right: 84 | right = temp_data[0: 3, :, :] 85 | left = temp_data[3: 6, :, :] 86 | target = temp_data[7: 8, :, :] 87 | return left, right, target 88 | else: 89 | left = temp_data[0: 3, :, :] 90 | right = temp_data[3: 6, :, :] 91 | target = temp_data[6: 7, :, :] 92 | return left, right, target 93 | 94 | def test_transform(temp_data, crop_height, crop_width, left_right=False): 95 | _, h, w = np.shape(temp_data) 96 | 97 | if h <= crop_height and w <= crop_width: 98 | temp = temp_data 99 | temp_data = np.zeros([8,crop_height,crop_width], 'float32') 100 | temp_data[6: 7, :, :] = 1000 101 | temp_data[:, crop_height - h: crop_height, crop_width - w: crop_width] = temp 102 | else: 103 | start_x = (w-crop_width)//2 104 | start_y = (h-crop_height)//2 105 | temp_data = temp_data[:, start_y: start_y + crop_height, start_x: start_x + crop_width] 106 | 107 | left = temp_data[0: 3, :, :] 108 | right = temp_data[3: 6, :, 
:] 109 | target = temp_data[6: 7, :, :] 110 | 111 | return left, right, target 112 | 113 | 114 | def load_data_sceneflow(data_path, current_file): 115 | A = current_file 116 | filename = data_path + 'frames_finalpass/' + A[0: len(A) - 1] 117 | left =Image.open(filename) 118 | filename = data_path + 'frames_finalpass/' + A[0: len(A) - 14] + 'right/' + A[len(A) - 9:len(A) - 1] 119 | right = Image.open(filename) 120 | filename = data_path + 'disparity/' + A[0: len(A) - 4] + 'pfm' 121 | disp_left, height, width = readPFM(filename) 122 | filename = data_path + 'disparity/' + A[0: len(A) - 14] + 'right/' + A[len(A) - 9: len(A) - 4] + 'pfm' 123 | disp_right, height, width = readPFM(filename) 124 | size = np.shape(left) 125 | height = size[0] 126 | width = size[1] 127 | temp_data = np.zeros([8, height, width], 'float32') 128 | left = np.asarray(left) 129 | right = np.asarray(right) 130 | r = left[:, :, 0] 131 | g = left[:, :, 1] 132 | b = left[:,:,2] 133 | temp_data[0, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 134 | temp_data[1, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 135 | temp_data[2, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 136 | r=right[:, :, 0] 137 | g=right[:, :, 1] 138 | b=right[:, :, 2] 139 | temp_data[3, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 140 | temp_data[4, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 141 | temp_data[5, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 142 | temp_data[6: 7, :, :] = width * 2 143 | temp_data[6, :, :] = disp_left 144 | temp_data[7, :, :] = disp_right 145 | return temp_data 146 | 147 | def load_kitti2015_data(file_path, current_file): 148 | """ load current file from the list""" 149 | filename = file_path + 'image_2/' + current_file[0: len(current_file) - 1] 150 | left = Image.open(filename) 151 | filename = file_path + 'image_3/' + current_file[0: len(current_file) - 1] 152 | right = Image.open(filename) 153 | filename = file_path + 'disp_occ_0/' + current_file[0: len(current_file) - 1] 154 | disp_left = Image.open(filename) 155 | temp = np.asarray(disp_left) 156 | size = np.shape(left) 157 | 158 | height = size[0] 159 | width = size[1] 160 | temp_data = np.zeros([8, height, width], 'float32') 161 | left = np.asarray(left) 162 | right = np.asarray(right) 163 | disp_left = np.asarray(disp_left) 164 | r = left[:, :, 0] 165 | g = left[:, :, 1] 166 | b = left[:, :, 2] 167 | 168 | temp_data[0, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 169 | temp_data[1, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 170 | temp_data[2, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 171 | r = right[:, :, 0] 172 | g = right[:, :, 1] 173 | b = right[:, :, 2] 174 | 175 | temp_data[3, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 176 | temp_data[4, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 177 | temp_data[5, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 178 | temp_data[6: 7, :, :] = width * 2 179 | temp_data[6, :, :] = disp_left[:, :] 180 | temp = temp_data[6, :, :] 181 | temp[temp < 0.1] = width * 2 * 256 182 | temp_data[6, :, :] = temp / 256. 
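    # KITTI ground-truth disparity PNGs store disparity * 256, with 0 marking pixels that have
    # no ground truth. The lines above push those invalid pixels to an out-of-range sentinel
    # (twice the image width) and rescale the remaining values back to real disparities.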
183 | 184 | return temp_data 185 | 186 | def load_kitti2012_data(file_path, current_file): 187 | """ load current file from the list""" 188 | filename = file_path + 'colored_0/' + current_file[0: len(current_file) - 1] 189 | left = Image.open(filename) 190 | filename = file_path+'colored_1/' + current_file[0: len(current_file) - 1] 191 | right = Image.open(filename) 192 | filename = file_path+'disp_noc/' + current_file[0: len(current_file) - 1] #disp_occ 193 | 194 | disp_left = Image.open(filename) 195 | temp = np.asarray(disp_left) 196 | size = np.shape(left) 197 | 198 | height = size[0] 199 | width = size[1] 200 | temp_data = np.zeros([8, height, width], 'float32') 201 | left = np.asarray(left) 202 | right = np.asarray(right) 203 | disp_left = np.asarray(disp_left) 204 | r = left[:, :, 0] 205 | g = left[:, :, 1] 206 | b = left[:, :, 2] 207 | 208 | temp_data[0, :, :] = (r-np.mean(r[:])) / np.std(r[:]) 209 | temp_data[1, :, :] = (g-np.mean(g[:])) / np.std(g[:]) 210 | temp_data[2, :, :] = (b-np.mean(b[:])) / np.std(b[:]) 211 | r=right[:, :, 0] 212 | g=right[:, :, 1] 213 | b=right[:, :, 2] 214 | 215 | temp_data[3, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 216 | temp_data[4, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 217 | temp_data[5, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 218 | temp_data[6: 7, :, :] = width * 2 219 | temp_data[6, :, :] = disp_left[:, :] 220 | temp = temp_data[6, :, :] 221 | temp[temp < 0.1] = width * 2 * 256 222 | temp_data[6, :, :] = temp / 256. 223 | 224 | return temp_data 225 | 226 | def load_data_md(file_path, current_file, eth=False): 227 | """ load current file from the list""" 228 | imgl = file_path + current_file[0: len(current_file) - 1] 229 | gt_l = imgl.replace('im0.png','disp0GT.pfm') 230 | imgr = imgl.replace('im0.png','im1.png') 231 | 232 | left = Image.open(imgl) 233 | right = Image.open(imgr) 234 | 235 | disp_left, height, width = readPFM(gt_l) 236 | pdb.set_trace() 237 | 238 | temp_data = np.zeros([8, height, width], 'float32') 239 | left = np.asarray(left) 240 | right = np.asarray(right) 241 | disp_left = np.asarray(disp_left) 242 | r = left[:, :, 0] 243 | g = left[:, :, 1] 244 | b = left[:, :, 2] 245 | 246 | temp_data[0, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 247 | temp_data[1, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 248 | temp_data[2, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 249 | r = right[:, :, 0] 250 | g = right[:, :, 1] 251 | b = right[:, :, 2] 252 | 253 | temp_data[3, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 254 | temp_data[4, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 255 | temp_data[5, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 256 | temp_data[6: 7, :, :] = width * 2 257 | temp_data[6, :, :] = disp_left[:, :] 258 | temp = temp_data[6, :, :] 259 | temp[temp < 0.1] = width * 2 * 256 260 | temp_data[6, :, :] = temp #/ 256. 
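    # Middlebury ground truth is read from PFM files and is already real-valued disparity, so the
    # /256 rescaling used for the KITTI PNG loaders is left commented out here; values below 0.1
    # are treated as invalid and pushed far outside the valid disparity range.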
261 | 262 | return temp_data 263 | 264 | class DatasetFromList(data.Dataset): 265 | def __init__(self, args, file_list, crop_size=[256, 256], training=True, left_right=False, shift=0): 266 | super(DatasetFromList, self).__init__() 267 | f = open(file_list, 'r') 268 | self.args = args 269 | self.file_list = f.readlines() 270 | self.training = training 271 | self.crop_height = crop_size[0] 272 | self.crop_width = crop_size[1] 273 | self.left_right = left_right 274 | self.shift = shift 275 | 276 | def __getitem__(self, index): 277 | if self.args.dataset == 'kitti12': #load kitti2012 dataset 278 | temp_data = load_kitti2012_data(Path.db_root_dir('kitti12'), self.file_list[index]) 279 | elif self.args.dataset == 'kitti15': #load kitti2015 dataset 280 | temp_data = load_kitti2015_data(Path.db_root_dir('kitti15'), self.file_list[index]) 281 | elif self.args.dataset == 'sceneflow': #load sceneflow dataset 282 | temp_data = load_data_sceneflow(Path.db_root_dir('sceneflow'), self.file_list[index]) 283 | elif self.args.dataset == 'middlebury': #load middbury dataset 284 | temp_data = load_data_md(Path.db_root_dir('middlebury'), self.file_list[index]) 285 | 286 | if self.training: 287 | input1, input2, target = train_transform(temp_data, self.crop_height, self.crop_width, self.left_right, self.shift) 288 | return input1, input2, target 289 | 290 | else: 291 | input1, input2, target = test_transform(temp_data, self.crop_height, self.crop_width) 292 | return input1, input2, target 293 | 294 | def __len__(self): 295 | return len(self.file_list) 296 | 297 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2012_test.list: -------------------------------------------------------------------------------- 1 | 000000_10.png 2 | 000001_10.png 3 | 000002_10.png 4 | 000003_10.png 5 | 000004_10.png 6 | 000005_10.png 7 | 000006_10.png 8 | 000007_10.png 9 | 000008_10.png 10 | 000009_10.png 11 | 000010_10.png 12 | 000011_10.png 13 | 000012_10.png 14 | 000013_10.png 15 | 000014_10.png 16 | 000015_10.png 17 | 000016_10.png 18 | 000017_10.png 19 | 000018_10.png 20 | 000019_10.png 21 | 000020_10.png 22 | 000021_10.png 23 | 000022_10.png 24 | 000023_10.png 25 | 000024_10.png 26 | 000025_10.png 27 | 000026_10.png 28 | 000027_10.png 29 | 000028_10.png 30 | 000029_10.png 31 | 000030_10.png 32 | 000031_10.png 33 | 000032_10.png 34 | 000033_10.png 35 | 000034_10.png 36 | 000035_10.png 37 | 000036_10.png 38 | 000037_10.png 39 | 000038_10.png 40 | 000039_10.png 41 | 000040_10.png 42 | 000041_10.png 43 | 000042_10.png 44 | 000043_10.png 45 | 000044_10.png 46 | 000045_10.png 47 | 000046_10.png 48 | 000047_10.png 49 | 000048_10.png 50 | 000049_10.png 51 | 000050_10.png 52 | 000051_10.png 53 | 000052_10.png 54 | 000053_10.png 55 | 000054_10.png 56 | 000055_10.png 57 | 000056_10.png 58 | 000057_10.png 59 | 000058_10.png 60 | 000059_10.png 61 | 000060_10.png 62 | 000061_10.png 63 | 000062_10.png 64 | 000063_10.png 65 | 000064_10.png 66 | 000065_10.png 67 | 000066_10.png 68 | 000067_10.png 69 | 000068_10.png 70 | 000069_10.png 71 | 000070_10.png 72 | 000071_10.png 73 | 000072_10.png 74 | 000073_10.png 75 | 000074_10.png 76 | 000075_10.png 77 | 000076_10.png 78 | 000077_10.png 79 | 000078_10.png 80 | 000079_10.png 81 | 000080_10.png 82 | 000081_10.png 83 | 000082_10.png 84 | 000083_10.png 85 | 000084_10.png 86 | 000085_10.png 87 | 000086_10.png 88 | 000087_10.png 89 | 000088_10.png 90 | 000089_10.png 91 | 000090_10.png 92 | 000091_10.png 93 | 000092_10.png 94 | 000093_10.png 95 | 
000094_10.png 96 | 000095_10.png 97 | 000096_10.png 98 | 000097_10.png 99 | 000098_10.png 100 | 000099_10.png 101 | 000100_10.png 102 | 000101_10.png 103 | 000102_10.png 104 | 000103_10.png 105 | 000104_10.png 106 | 000105_10.png 107 | 000106_10.png 108 | 000107_10.png 109 | 000108_10.png 110 | 000109_10.png 111 | 000110_10.png 112 | 000111_10.png 113 | 000112_10.png 114 | 000113_10.png 115 | 000114_10.png 116 | 000115_10.png 117 | 000116_10.png 118 | 000117_10.png 119 | 000118_10.png 120 | 000119_10.png 121 | 000120_10.png 122 | 000121_10.png 123 | 000122_10.png 124 | 000123_10.png 125 | 000124_10.png 126 | 000125_10.png 127 | 000126_10.png 128 | 000127_10.png 129 | 000128_10.png 130 | 000129_10.png 131 | 000130_10.png 132 | 000131_10.png 133 | 000132_10.png 134 | 000133_10.png 135 | 000134_10.png 136 | 000135_10.png 137 | 000136_10.png 138 | 000137_10.png 139 | 000138_10.png 140 | 000139_10.png 141 | 000140_10.png 142 | 000141_10.png 143 | 000142_10.png 144 | 000143_10.png 145 | 000144_10.png 146 | 000145_10.png 147 | 000146_10.png 148 | 000147_10.png 149 | 000148_10.png 150 | 000149_10.png 151 | 000150_10.png 152 | 000151_10.png 153 | 000152_10.png 154 | 000153_10.png 155 | 000154_10.png 156 | 000155_10.png 157 | 000156_10.png 158 | 000157_10.png 159 | 000158_10.png 160 | 000159_10.png 161 | 000160_10.png 162 | 000161_10.png 163 | 000162_10.png 164 | 000163_10.png 165 | 000164_10.png 166 | 000165_10.png 167 | 000166_10.png 168 | 000167_10.png 169 | 000168_10.png 170 | 000169_10.png 171 | 000170_10.png 172 | 000171_10.png 173 | 000172_10.png 174 | 000173_10.png 175 | 000174_10.png 176 | 000175_10.png 177 | 000176_10.png 178 | 000177_10.png 179 | 000178_10.png 180 | 000179_10.png 181 | 000180_10.png 182 | 000181_10.png 183 | 000182_10.png 184 | 000183_10.png 185 | 000184_10.png 186 | 000185_10.png 187 | 000186_10.png 188 | 000187_10.png 189 | 000188_10.png 190 | 000189_10.png 191 | 000190_10.png 192 | 000191_10.png 193 | 000192_10.png 194 | 000193_10.png 195 | 000194_10.png 196 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2012_train.list: -------------------------------------------------------------------------------- 1 | 000153_10.png 2 | 000174_10.png 3 | 000176_10.png 4 | 000022_10.png 5 | 000169_10.png 6 | 000057_10.png 7 | 000013_10.png 8 | 000185_10.png 9 | 000072_10.png 10 | 000182_10.png 11 | 000012_10.png 12 | 000054_10.png 13 | 000064_10.png 14 | 000061_10.png 15 | 000017_10.png 16 | 000155_10.png 17 | 000044_10.png 18 | 000001_10.png 19 | 000030_10.png 20 | 000106_10.png 21 | 000007_10.png 22 | 000092_10.png 23 | 000142_10.png 24 | 000144_10.png 25 | 000193_10.png 26 | 000080_10.png 27 | 000097_10.png 28 | 000143_10.png 29 | 000042_10.png 30 | 000081_10.png 31 | 000109_10.png 32 | 000171_10.png 33 | 000175_10.png 34 | 000047_10.png 35 | 000014_10.png 36 | 000100_10.png 37 | 000015_10.png 38 | 000180_10.png 39 | 000016_10.png 40 | 000019_10.png 41 | 000139_10.png 42 | 000184_10.png 43 | 000050_10.png 44 | 000115_10.png 45 | 000136_10.png 46 | 000188_10.png 47 | 000034_10.png 48 | 000033_10.png 49 | 000114_10.png 50 | 000038_10.png 51 | 000063_10.png 52 | 000121_10.png 53 | 000004_10.png 54 | 000062_10.png 55 | 000160_10.png 56 | 000148_10.png 57 | 000120_10.png 58 | 000052_10.png 59 | 000045_10.png 60 | 000126_10.png 61 | 000065_10.png 62 | 000127_10.png 63 | 000113_10.png 64 | 000131_10.png 65 | 000163_10.png 66 | 000146_10.png 67 | 000028_10.png 68 | 000048_10.png 69 | 000179_10.png 70 | 000135_10.png 71 | 
000053_10.png 72 | 000159_10.png 73 | 000117_10.png 74 | 000134_10.png 75 | 000181_10.png 76 | 000021_10.png 77 | 000067_10.png 78 | 000077_10.png 79 | 000177_10.png 80 | 000029_10.png 81 | 000070_10.png 82 | 000137_10.png 83 | 000090_10.png 84 | 000027_10.png 85 | 000123_10.png 86 | 000107_10.png 87 | 000026_10.png 88 | 000190_10.png 89 | 000133_10.png 90 | 000024_10.png 91 | 000119_10.png 92 | 000116_10.png 93 | 000023_10.png 94 | 000066_10.png 95 | 000152_10.png 96 | 000008_10.png 97 | 000079_10.png 98 | 000069_10.png 99 | 000187_10.png 100 | 000145_10.png 101 | 000124_10.png 102 | 000140_10.png 103 | 000150_10.png 104 | 000132_10.png 105 | 000025_10.png 106 | 000122_10.png 107 | 000110_10.png 108 | 000041_10.png 109 | 000098_10.png 110 | 000118_10.png 111 | 000084_10.png 112 | 000103_10.png 113 | 000086_10.png 114 | 000178_10.png 115 | 000056_10.png 116 | 000058_10.png 117 | 000006_10.png 118 | 000111_10.png 119 | 000186_10.png 120 | 000170_10.png 121 | 000141_10.png 122 | 000020_10.png 123 | 000073_10.png 124 | 000093_10.png 125 | 000009_10.png 126 | 000000_10.png 127 | 000094_10.png 128 | 000149_10.png 129 | 000151_10.png 130 | 000046_10.png 131 | 000002_10.png 132 | 000010_10.png 133 | 000003_10.png 134 | 000035_10.png 135 | 000183_10.png 136 | 000036_10.png 137 | 000112_10.png 138 | 000040_10.png 139 | 000192_10.png 140 | 000091_10.png 141 | 000156_10.png 142 | 000128_10.png 143 | 000164_10.png 144 | 000055_10.png 145 | 000130_10.png 146 | 000095_10.png 147 | 000162_10.png 148 | 000189_10.png 149 | 000088_10.png 150 | 000125_10.png 151 | 000078_10.png 152 | 000039_10.png 153 | 000157_10.png 154 | 000102_10.png 155 | 000138_10.png 156 | 000129_10.png 157 | 000165_10.png 158 | 000161_10.png 159 | 000089_10.png 160 | 000087_10.png 161 | 000011_10.png 162 | 000085_10.png 163 | 000191_10.png 164 | 000049_10.png 165 | 000168_10.png 166 | 000082_10.png 167 | 000154_10.png 168 | 000060_10.png 169 | 000018_10.png 170 | 000166_10.png 171 | 000173_10.png 172 | 000172_10.png 173 | 000147_10.png 174 | 000059_10.png 175 | 000043_10.png 176 | 000099_10.png 177 | 000031_10.png 178 | 000076_10.png 179 | 000032_10.png 180 | 000105_10.png 181 | 000101_10.png 182 | 000074_10.png 183 | 000071_10.png 184 | 000158_10.png 185 | 000051_10.png 186 | 000083_10.png 187 | 000037_10.png 188 | 000005_10.png 189 | 000104_10.png 190 | 000096_10.png 191 | 000068_10.png 192 | 000108_10.png 193 | 000075_10.png 194 | 000167_10.png 195 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2012_train170.list: -------------------------------------------------------------------------------- 1 | 000193_10.png 2 | 000080_10.png 3 | 000097_10.png 4 | 000143_10.png 5 | 000042_10.png 6 | 000081_10.png 7 | 000109_10.png 8 | 000171_10.png 9 | 000175_10.png 10 | 000047_10.png 11 | 000014_10.png 12 | 000100_10.png 13 | 000015_10.png 14 | 000180_10.png 15 | 000016_10.png 16 | 000019_10.png 17 | 000139_10.png 18 | 000184_10.png 19 | 000050_10.png 20 | 000115_10.png 21 | 000136_10.png 22 | 000188_10.png 23 | 000034_10.png 24 | 000033_10.png 25 | 000114_10.png 26 | 000038_10.png 27 | 000063_10.png 28 | 000121_10.png 29 | 000004_10.png 30 | 000062_10.png 31 | 000160_10.png 32 | 000148_10.png 33 | 000120_10.png 34 | 000052_10.png 35 | 000045_10.png 36 | 000126_10.png 37 | 000065_10.png 38 | 000127_10.png 39 | 000113_10.png 40 | 000131_10.png 41 | 000163_10.png 42 | 000146_10.png 43 | 000028_10.png 44 | 000048_10.png 45 | 000179_10.png 46 | 000135_10.png 47 | 000053_10.png 48 | 
000159_10.png 49 | 000117_10.png 50 | 000134_10.png 51 | 000181_10.png 52 | 000021_10.png 53 | 000067_10.png 54 | 000077_10.png 55 | 000177_10.png 56 | 000029_10.png 57 | 000070_10.png 58 | 000137_10.png 59 | 000090_10.png 60 | 000027_10.png 61 | 000123_10.png 62 | 000107_10.png 63 | 000026_10.png 64 | 000190_10.png 65 | 000133_10.png 66 | 000024_10.png 67 | 000119_10.png 68 | 000116_10.png 69 | 000023_10.png 70 | 000066_10.png 71 | 000152_10.png 72 | 000008_10.png 73 | 000079_10.png 74 | 000069_10.png 75 | 000187_10.png 76 | 000145_10.png 77 | 000124_10.png 78 | 000140_10.png 79 | 000150_10.png 80 | 000132_10.png 81 | 000025_10.png 82 | 000122_10.png 83 | 000110_10.png 84 | 000041_10.png 85 | 000098_10.png 86 | 000118_10.png 87 | 000084_10.png 88 | 000103_10.png 89 | 000086_10.png 90 | 000178_10.png 91 | 000056_10.png 92 | 000058_10.png 93 | 000006_10.png 94 | 000111_10.png 95 | 000186_10.png 96 | 000170_10.png 97 | 000141_10.png 98 | 000020_10.png 99 | 000073_10.png 100 | 000093_10.png 101 | 000009_10.png 102 | 000000_10.png 103 | 000094_10.png 104 | 000149_10.png 105 | 000151_10.png 106 | 000046_10.png 107 | 000002_10.png 108 | 000010_10.png 109 | 000003_10.png 110 | 000035_10.png 111 | 000183_10.png 112 | 000036_10.png 113 | 000112_10.png 114 | 000040_10.png 115 | 000192_10.png 116 | 000091_10.png 117 | 000156_10.png 118 | 000128_10.png 119 | 000164_10.png 120 | 000055_10.png 121 | 000130_10.png 122 | 000095_10.png 123 | 000162_10.png 124 | 000189_10.png 125 | 000088_10.png 126 | 000125_10.png 127 | 000078_10.png 128 | 000039_10.png 129 | 000157_10.png 130 | 000102_10.png 131 | 000138_10.png 132 | 000129_10.png 133 | 000165_10.png 134 | 000161_10.png 135 | 000089_10.png 136 | 000087_10.png 137 | 000011_10.png 138 | 000085_10.png 139 | 000191_10.png 140 | 000049_10.png 141 | 000168_10.png 142 | 000082_10.png 143 | 000154_10.png 144 | 000060_10.png 145 | 000018_10.png 146 | 000166_10.png 147 | 000173_10.png 148 | 000172_10.png 149 | 000147_10.png 150 | 000059_10.png 151 | 000043_10.png 152 | 000099_10.png 153 | 000031_10.png 154 | 000076_10.png 155 | 000032_10.png 156 | 000105_10.png 157 | 000101_10.png 158 | 000074_10.png 159 | 000071_10.png 160 | 000158_10.png 161 | 000051_10.png 162 | 000083_10.png 163 | 000037_10.png 164 | 000005_10.png 165 | 000104_10.png 166 | 000096_10.png 167 | 000068_10.png 168 | 000108_10.png 169 | 000075_10.png 170 | 000167_10.png 171 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2012_val24.list: -------------------------------------------------------------------------------- 1 | 000153_10.png 2 | 000174_10.png 3 | 000176_10.png 4 | 000022_10.png 5 | 000169_10.png 6 | 000057_10.png 7 | 000013_10.png 8 | 000185_10.png 9 | 000072_10.png 10 | 000182_10.png 11 | 000012_10.png 12 | 000054_10.png 13 | 000064_10.png 14 | 000061_10.png 15 | 000017_10.png 16 | 000155_10.png 17 | 000044_10.png 18 | 000001_10.png 19 | 000030_10.png 20 | 000106_10.png 21 | 000007_10.png 22 | 000092_10.png 23 | 000142_10.png 24 | 000144_10.png 25 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2015_test.list: -------------------------------------------------------------------------------- 1 | 000000_10.png 2 | 000001_10.png 3 | 000002_10.png 4 | 000003_10.png 5 | 000004_10.png 6 | 000005_10.png 7 | 000006_10.png 8 | 000007_10.png 9 | 000008_10.png 10 | 000009_10.png 11 | 000010_10.png 12 | 000011_10.png 13 | 000012_10.png 14 | 000013_10.png 15 | 000014_10.png 16 | 
000015_10.png 17 | 000016_10.png 18 | 000017_10.png 19 | 000018_10.png 20 | 000019_10.png 21 | 000020_10.png 22 | 000021_10.png 23 | 000022_10.png 24 | 000023_10.png 25 | 000024_10.png 26 | 000025_10.png 27 | 000026_10.png 28 | 000027_10.png 29 | 000028_10.png 30 | 000029_10.png 31 | 000030_10.png 32 | 000031_10.png 33 | 000032_10.png 34 | 000033_10.png 35 | 000034_10.png 36 | 000035_10.png 37 | 000036_10.png 38 | 000037_10.png 39 | 000038_10.png 40 | 000039_10.png 41 | 000040_10.png 42 | 000041_10.png 43 | 000042_10.png 44 | 000043_10.png 45 | 000044_10.png 46 | 000045_10.png 47 | 000046_10.png 48 | 000047_10.png 49 | 000048_10.png 50 | 000049_10.png 51 | 000050_10.png 52 | 000051_10.png 53 | 000052_10.png 54 | 000053_10.png 55 | 000054_10.png 56 | 000055_10.png 57 | 000056_10.png 58 | 000057_10.png 59 | 000058_10.png 60 | 000059_10.png 61 | 000060_10.png 62 | 000061_10.png 63 | 000062_10.png 64 | 000063_10.png 65 | 000064_10.png 66 | 000065_10.png 67 | 000066_10.png 68 | 000067_10.png 69 | 000068_10.png 70 | 000069_10.png 71 | 000070_10.png 72 | 000071_10.png 73 | 000072_10.png 74 | 000073_10.png 75 | 000074_10.png 76 | 000075_10.png 77 | 000076_10.png 78 | 000077_10.png 79 | 000078_10.png 80 | 000079_10.png 81 | 000080_10.png 82 | 000081_10.png 83 | 000082_10.png 84 | 000083_10.png 85 | 000084_10.png 86 | 000085_10.png 87 | 000086_10.png 88 | 000087_10.png 89 | 000088_10.png 90 | 000089_10.png 91 | 000090_10.png 92 | 000091_10.png 93 | 000092_10.png 94 | 000093_10.png 95 | 000094_10.png 96 | 000095_10.png 97 | 000096_10.png 98 | 000097_10.png 99 | 000098_10.png 100 | 000099_10.png 101 | 000100_10.png 102 | 000101_10.png 103 | 000102_10.png 104 | 000103_10.png 105 | 000104_10.png 106 | 000105_10.png 107 | 000106_10.png 108 | 000107_10.png 109 | 000108_10.png 110 | 000109_10.png 111 | 000110_10.png 112 | 000111_10.png 113 | 000112_10.png 114 | 000113_10.png 115 | 000114_10.png 116 | 000115_10.png 117 | 000116_10.png 118 | 000117_10.png 119 | 000118_10.png 120 | 000119_10.png 121 | 000120_10.png 122 | 000121_10.png 123 | 000122_10.png 124 | 000123_10.png 125 | 000124_10.png 126 | 000125_10.png 127 | 000126_10.png 128 | 000127_10.png 129 | 000128_10.png 130 | 000129_10.png 131 | 000130_10.png 132 | 000131_10.png 133 | 000132_10.png 134 | 000133_10.png 135 | 000134_10.png 136 | 000135_10.png 137 | 000136_10.png 138 | 000137_10.png 139 | 000138_10.png 140 | 000139_10.png 141 | 000140_10.png 142 | 000141_10.png 143 | 000142_10.png 144 | 000143_10.png 145 | 000144_10.png 146 | 000145_10.png 147 | 000146_10.png 148 | 000147_10.png 149 | 000148_10.png 150 | 000149_10.png 151 | 000150_10.png 152 | 000151_10.png 153 | 000152_10.png 154 | 000153_10.png 155 | 000154_10.png 156 | 000155_10.png 157 | 000156_10.png 158 | 000157_10.png 159 | 000158_10.png 160 | 000159_10.png 161 | 000160_10.png 162 | 000161_10.png 163 | 000162_10.png 164 | 000163_10.png 165 | 000164_10.png 166 | 000165_10.png 167 | 000166_10.png 168 | 000167_10.png 169 | 000168_10.png 170 | 000169_10.png 171 | 000170_10.png 172 | 000171_10.png 173 | 000172_10.png 174 | 000173_10.png 175 | 000174_10.png 176 | 000175_10.png 177 | 000176_10.png 178 | 000177_10.png 179 | 000178_10.png 180 | 000179_10.png 181 | 000180_10.png 182 | 000181_10.png 183 | 000182_10.png 184 | 000183_10.png 185 | 000184_10.png 186 | 000185_10.png 187 | 000186_10.png 188 | 000187_10.png 189 | 000188_10.png 190 | 000189_10.png 191 | 000190_10.png 192 | 000191_10.png 193 | 000192_10.png 194 | 000193_10.png 195 | 000194_10.png 196 | 000195_10.png 197 | 000196_10.png 198 
| 000197_10.png 199 | 000198_10.png 200 | 000199_10.png 201 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2015_train.list: -------------------------------------------------------------------------------- 1 | 000179_10.png 2 | 000128_10.png 3 | 000122_10.png 4 | 000178_10.png 5 | 000173_10.png 6 | 000100_10.png 7 | 000114_10.png 8 | 000037_10.png 9 | 000071_10.png 10 | 000076_10.png 11 | 000031_10.png 12 | 000130_10.png 13 | 000191_10.png 14 | 000086_10.png 15 | 000099_10.png 16 | 000195_10.png 17 | 000005_10.png 18 | 000131_10.png 19 | 000083_10.png 20 | 000196_10.png 21 | 000001_10.png 22 | 000022_10.png 23 | 000189_10.png 24 | 000030_10.png 25 | 000096_10.png 26 | 000082_10.png 27 | 000101_10.png 28 | 000121_10.png 29 | 000046_10.png 30 | 000106_10.png 31 | 000015_10.png 32 | 000056_10.png 33 | 000021_10.png 34 | 000111_10.png 35 | 000070_10.png 36 | 000145_10.png 37 | 000197_10.png 38 | 000040_10.png 39 | 000092_10.png 40 | 000186_10.png 41 | 000140_10.png 42 | 000127_10.png 43 | 000165_10.png 44 | 000141_10.png 45 | 000193_10.png 46 | 000184_10.png 47 | 000154_10.png 48 | 000060_10.png 49 | 000107_10.png 50 | 000029_10.png 51 | 000038_10.png 52 | 000135_10.png 53 | 000091_10.png 54 | 000156_10.png 55 | 000112_10.png 56 | 000054_10.png 57 | 000065_10.png 58 | 000113_10.png 59 | 000016_10.png 60 | 000050_10.png 61 | 000089_10.png 62 | 000166_10.png 63 | 000142_10.png 64 | 000097_10.png 65 | 000042_10.png 66 | 000194_10.png 67 | 000190_10.png 68 | 000134_10.png 69 | 000080_10.png 70 | 000155_10.png 71 | 000055_10.png 72 | 000116_10.png 73 | 000069_10.png 74 | 000020_10.png 75 | 000180_10.png 76 | 000068_10.png 77 | 000129_10.png 78 | 000027_10.png 79 | 000167_10.png 80 | 000081_10.png 81 | 000115_10.png 82 | 000048_10.png 83 | 000061_10.png 84 | 000011_10.png 85 | 000105_10.png 86 | 000161_10.png 87 | 000062_10.png 88 | 000088_10.png 89 | 000098_10.png 90 | 000093_10.png 91 | 000064_10.png 92 | 000032_10.png 93 | 000102_10.png 94 | 000058_10.png 95 | 000188_10.png 96 | 000148_10.png 97 | 000183_10.png 98 | 000072_10.png 99 | 000103_10.png 100 | 000126_10.png 101 | 000034_10.png 102 | 000024_10.png 103 | 000185_10.png 104 | 000108_10.png 105 | 000181_10.png 106 | 000153_10.png 107 | 000174_10.png 108 | 000157_10.png 109 | 000151_10.png 110 | 000170_10.png 111 | 000118_10.png 112 | 000117_10.png 113 | 000139_10.png 114 | 000143_10.png 115 | 000077_10.png 116 | 000168_10.png 117 | 000012_10.png 118 | 000146_10.png 119 | 000152_10.png 120 | 000039_10.png 121 | 000053_10.png 122 | 000162_10.png 123 | 000041_10.png 124 | 000051_10.png 125 | 000035_10.png 126 | 000075_10.png 127 | 000138_10.png 128 | 000044_10.png 129 | 000036_10.png 130 | 000000_10.png 131 | 000164_10.png 132 | 000004_10.png 133 | 000175_10.png 134 | 000094_10.png 135 | 000109_10.png 136 | 000052_10.png 137 | 000047_10.png 138 | 000028_10.png 139 | 000149_10.png 140 | 000026_10.png 141 | 000192_10.png 142 | 000049_10.png 143 | 000014_10.png 144 | 000125_10.png 145 | 000136_10.png 146 | 000177_10.png 147 | 000084_10.png 148 | 000007_10.png 149 | 000085_10.png 150 | 000063_10.png 151 | 000019_10.png 152 | 000176_10.png 153 | 000079_10.png 154 | 000187_10.png 155 | 000172_10.png 156 | 000199_10.png 157 | 000003_10.png 158 | 000124_10.png 159 | 000095_10.png 160 | 000160_10.png 161 | 000009_10.png 162 | 000137_10.png 163 | 000087_10.png 164 | 000150_10.png 165 | 000090_10.png 166 | 000006_10.png 167 | 000182_10.png 168 | 000132_10.png 169 | 000078_10.png 170 | 
000002_10.png 171 | 000169_10.png 172 | 000017_10.png 173 | 000110_10.png 174 | 000074_10.png 175 | 000119_10.png 176 | 000147_10.png 177 | 000159_10.png 178 | 000013_10.png 179 | 000045_10.png 180 | 000010_10.png 181 | 000198_10.png 182 | 000059_10.png 183 | 000067_10.png 184 | 000123_10.png 185 | 000133_10.png 186 | 000057_10.png 187 | 000073_10.png 188 | 000120_10.png 189 | 000018_10.png 190 | 000171_10.png 191 | 000033_10.png 192 | 000023_10.png 193 | 000043_10.png 194 | 000158_10.png 195 | 000163_10.png 196 | 000025_10.png 197 | 000144_10.png 198 | 000008_10.png 199 | 000066_10.png 200 | 000104_10.png 201 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2015_train180.list: -------------------------------------------------------------------------------- 1 | 000179_10.png 2 | 000128_10.png 3 | 000122_10.png 4 | 000178_10.png 5 | 000173_10.png 6 | 000100_10.png 7 | 000114_10.png 8 | 000037_10.png 9 | 000071_10.png 10 | 000076_10.png 11 | 000031_10.png 12 | 000130_10.png 13 | 000191_10.png 14 | 000086_10.png 15 | 000099_10.png 16 | 000195_10.png 17 | 000005_10.png 18 | 000131_10.png 19 | 000083_10.png 20 | 000196_10.png 21 | 000001_10.png 22 | 000022_10.png 23 | 000189_10.png 24 | 000030_10.png 25 | 000096_10.png 26 | 000082_10.png 27 | 000101_10.png 28 | 000121_10.png 29 | 000046_10.png 30 | 000106_10.png 31 | 000015_10.png 32 | 000056_10.png 33 | 000021_10.png 34 | 000111_10.png 35 | 000070_10.png 36 | 000145_10.png 37 | 000197_10.png 38 | 000040_10.png 39 | 000092_10.png 40 | 000186_10.png 41 | 000140_10.png 42 | 000127_10.png 43 | 000165_10.png 44 | 000141_10.png 45 | 000193_10.png 46 | 000184_10.png 47 | 000154_10.png 48 | 000060_10.png 49 | 000107_10.png 50 | 000029_10.png 51 | 000038_10.png 52 | 000135_10.png 53 | 000091_10.png 54 | 000156_10.png 55 | 000112_10.png 56 | 000054_10.png 57 | 000065_10.png 58 | 000113_10.png 59 | 000016_10.png 60 | 000050_10.png 61 | 000089_10.png 62 | 000166_10.png 63 | 000142_10.png 64 | 000097_10.png 65 | 000042_10.png 66 | 000194_10.png 67 | 000190_10.png 68 | 000134_10.png 69 | 000080_10.png 70 | 000155_10.png 71 | 000055_10.png 72 | 000116_10.png 73 | 000069_10.png 74 | 000020_10.png 75 | 000180_10.png 76 | 000068_10.png 77 | 000129_10.png 78 | 000027_10.png 79 | 000167_10.png 80 | 000081_10.png 81 | 000115_10.png 82 | 000048_10.png 83 | 000061_10.png 84 | 000011_10.png 85 | 000105_10.png 86 | 000161_10.png 87 | 000062_10.png 88 | 000088_10.png 89 | 000098_10.png 90 | 000093_10.png 91 | 000064_10.png 92 | 000032_10.png 93 | 000102_10.png 94 | 000058_10.png 95 | 000188_10.png 96 | 000148_10.png 97 | 000183_10.png 98 | 000072_10.png 99 | 000103_10.png 100 | 000126_10.png 101 | 000034_10.png 102 | 000024_10.png 103 | 000185_10.png 104 | 000108_10.png 105 | 000181_10.png 106 | 000153_10.png 107 | 000174_10.png 108 | 000157_10.png 109 | 000151_10.png 110 | 000170_10.png 111 | 000118_10.png 112 | 000117_10.png 113 | 000139_10.png 114 | 000143_10.png 115 | 000077_10.png 116 | 000168_10.png 117 | 000012_10.png 118 | 000146_10.png 119 | 000152_10.png 120 | 000039_10.png 121 | 000053_10.png 122 | 000162_10.png 123 | 000041_10.png 124 | 000051_10.png 125 | 000035_10.png 126 | 000075_10.png 127 | 000138_10.png 128 | 000044_10.png 129 | 000036_10.png 130 | 000000_10.png 131 | 000164_10.png 132 | 000004_10.png 133 | 000175_10.png 134 | 000094_10.png 135 | 000109_10.png 136 | 000052_10.png 137 | 000047_10.png 138 | 000028_10.png 139 | 000149_10.png 140 | 000026_10.png 141 | 000192_10.png 142 | 
000049_10.png 143 | 000014_10.png 144 | 000125_10.png 145 | 000136_10.png 146 | 000177_10.png 147 | 000084_10.png 148 | 000007_10.png 149 | 000085_10.png 150 | 000063_10.png 151 | 000019_10.png 152 | 000176_10.png 153 | 000079_10.png 154 | 000187_10.png 155 | 000172_10.png 156 | 000199_10.png 157 | 000003_10.png 158 | 000124_10.png 159 | 000095_10.png 160 | 000160_10.png 161 | 000009_10.png 162 | 000137_10.png 163 | 000087_10.png 164 | 000150_10.png 165 | 000090_10.png 166 | 000006_10.png 167 | 000182_10.png 168 | 000132_10.png 169 | 000078_10.png 170 | 000002_10.png 171 | 000169_10.png 172 | 000017_10.png 173 | 000110_10.png 174 | 000074_10.png 175 | 000119_10.png 176 | 000147_10.png 177 | 000159_10.png 178 | 000013_10.png 179 | 000045_10.png 180 | 000010_10.png 181 | -------------------------------------------------------------------------------- /dataloaders/lists/kitti2015_val20.list: -------------------------------------------------------------------------------- 1 | 000198_10.png 2 | 000059_10.png 3 | 000067_10.png 4 | 000123_10.png 5 | 000133_10.png 6 | 000057_10.png 7 | 000073_10.png 8 | 000120_10.png 9 | 000018_10.png 10 | 000171_10.png 11 | 000033_10.png 12 | 000023_10.png 13 | 000043_10.png 14 | 000158_10.png 15 | 000163_10.png 16 | 000025_10.png 17 | 000144_10.png 18 | 000008_10.png 19 | 000066_10.png 20 | 000104_10.png 21 | -------------------------------------------------------------------------------- /dataloaders/lists/middeval3_test.list: -------------------------------------------------------------------------------- 1 | Australia/im0.png 2 | AustraliaP/im0.png 3 | Bicycle2/im0.png 4 | Classroom2/im0.png 5 | Classroom2E/im0.png 6 | Computer/im0.png 7 | Crusade/im0.png 8 | CrusadeP/im0.png 9 | Djembe/im0.png 10 | DjembeL/im0.png 11 | Hoops/im0.png 12 | Livingroom/im0.png 13 | Newkuba/im0.png 14 | Plants/im0.png 15 | Staircase/im0.png 16 | -------------------------------------------------------------------------------- /dataloaders/lists/middeval3_train.list: -------------------------------------------------------------------------------- 1 | Playtable/im0.png 2 | ArtL/im0.png 3 | Jadeplant/im0.png 4 | PlaytableP/im0.png 5 | PianoL/im0.png 6 | Piano/im0.png 7 | Adirondack/im0.png 8 | Teddy/im0.png 9 | Recycle/im0.png 10 | Motorcycle/im0.png 11 | MotorcycleE/im0.png 12 | Vintage/im0.png 13 | Playroom/im0.png 14 | Shelves/im0.png 15 | Pipes/im0.png 16 | -------------------------------------------------------------------------------- /dataset/MiddEval3: -------------------------------------------------------------------------------- 1 | /mnt/data/StereoDataset/dataset/MiddEval3 -------------------------------------------------------------------------------- /dataset/SceneFlow: -------------------------------------------------------------------------------- 1 | /mnt/data/StereoDataset/dataset/SceneFlow -------------------------------------------------------------------------------- /dataset/kitti2012: -------------------------------------------------------------------------------- 1 | /mnt/data/StereoDataset/dataset/kitti2012 -------------------------------------------------------------------------------- /dataset/kitti2015: -------------------------------------------------------------------------------- 1 | /mnt/data/StereoDataset/dataset/kitti2015 -------------------------------------------------------------------------------- /decode.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import numpy as np 4 | 
import torch 5 | from models.decoding_formulas import Decoder 6 | from config_utils.decode_args import obtain_decode_args 7 | 8 | class Loader(object): 9 | def __init__(self, args): 10 | self.args = args 11 | # Resuming checkpoint 12 | self.best_pred = 0.0 13 | assert args.resume is not None, RuntimeError("No model to decode in resume path: '{:}'".format(args.resume)) 14 | assert os.path.isfile(args.resume), RuntimeError("=> no checkpoint found at '{}'".format(args.resume)) 15 | 16 | checkpoint = torch.load(args.resume) 17 | args.start_epoch = checkpoint['epoch'] 18 | 19 | self._alphas_fea = checkpoint['state_dict']['feature.alphas'] 20 | self._betas_fea = checkpoint['state_dict']['feature.betas'] 21 | self.decoder_fea = Decoder(alphas=self._alphas_fea, betas=self._betas_fea, steps=self.args.step) 22 | 23 | self._alphas_mat = checkpoint['state_dict']['matching.alphas'] 24 | self._betas_mat = checkpoint['state_dict']['matching.betas'] 25 | self.decoder_mat = Decoder(alphas=self._alphas_mat, betas=self._betas_mat, steps=self.args.step) 26 | 27 | def retreive_alphas_betas(self): 28 | return self._alphas_fea, self._betas_fea, self._alphas_mat, self._betas_mat 29 | 30 | def decode_architecture(self): 31 | fea_paths, fea_paths_space = self.decoder_fea.viterbi_decode() 32 | mat_paths, mat_paths_space = self.decoder_mat.viterbi_decode() 33 | return fea_paths, fea_paths_space, mat_paths, mat_paths_space 34 | 35 | def decode_cell(self): 36 | fea_genotype = self.decoder_fea.genotype_decode() 37 | mat_genotype = self.decoder_mat.genotype_decode() 38 | return fea_genotype, mat_genotype 39 | 40 | def get_new_network_cell(): 41 | args = obtain_decode_args() 42 | load_model = Loader(args) 43 | fea_net_paths, fea_net_paths_space, mat_net_paths, mat_net_paths_space = load_model.decode_architecture() 44 | fea_genotype, mat_genotype = load_model.decode_cell() 45 | print('Feature Net search results:', fea_net_paths) 46 | print('Matching Net search results:', mat_net_paths) 47 | print('Feature Net cell structure:', fea_genotype) 48 | print('Matching Net cell structure:', mat_genotype) 49 | 50 | dir_name = os.path.dirname(args.resume) 51 | fea_net_path_filename = os.path.join(dir_name, 'feature_network_path') 52 | fea_genotype_filename = os.path.join(dir_name, 'feature_genotype') 53 | np.save(fea_net_path_filename, fea_net_paths) 54 | np.save(fea_genotype_filename, fea_genotype) 55 | 56 | mat_net_path_filename = os.path.join(dir_name, 'matching_network_path') 57 | mat_genotype_filename = os.path.join(dir_name, 'matching_genotype') 58 | np.save(mat_net_path_filename, mat_net_paths) 59 | np.save(mat_genotype_filename, mat_genotype) 60 | 61 | fea_cell_name = os.path.join(dir_name, 'feature_cell_structure') 62 | mat_cell_name = os.path.join(dir_name, 'matching_cell_structure') 63 | 64 | if __name__ == '__main__': 65 | get_new_network_cell() 66 | -------------------------------------------------------------------------------- /decode.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python decode.py --dataset sceneflow --step 3 \ 2 | --resume ./run/sceneflow/experiment/checkpoint.pth.tar 3 | -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/models/__init__.py 
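[Editor's note — illustrative sketch, not repository code] decode.py above pulls the architecture parameters (cell-level alphas and network-level betas) out of a search checkpoint and hands them to Decoder, which softmax-normalises the betas, Viterbi-decodes the layer-by-layer resolution path, and keeps the two strongest input edges per cell node. The snippet below exercises that decoding step standalone with random weights; the shapes (one beta row of 4 resolution levels x 3 moves per layer, one alpha row per candidate edge) are inferred from decoding_formulas.py and the defaults in build_model.py, and are assumptions rather than values read from a real checkpoint.

# Editor's sketch: decode a network path and cell genotype from random weights.
import torch
from models.decoding_formulas import Decoder

steps, num_layers = 3, 6                        # Fea_Step / Fea_Layers defaults in build_model.py
num_edges = sum(2 + i for i in range(steps))    # 2 + 3 + 4 = 9 candidate edges per cell
alphas = torch.randn(num_edges, 2)              # 2 candidate ops per edge (skip_connect, conv_3x3)
betas = torch.randn(num_layers, 4, 3)           # per layer: 4 resolution levels x {down, keep, up}

decoder = Decoder(alphas=alphas, betas=betas, steps=steps)
net_path, net_space = decoder.viterbi_decode()  # net_path[l] = resolution level chosen at layer l
cell = decoder.genotype_decode()                # rows of [input_edge_index, op_index]
print(net_path)
print(cell)

On a real checkpoint, decode.py saves net_path and cell as the .npy files (feature_network_path.npy, feature_genotype.npy, and their matching counterparts) that the retrain and predict scripts later load.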
-------------------------------------------------------------------------------- /models/build_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from models.build_model_2d import AutoFeature, Disp 5 | from models.build_model_3d import AutoMatching 6 | import pdb 7 | from time import time 8 | 9 | class AutoStereo(nn.Module): 10 | def __init__(self, maxdisp=192, Fea_Layers=6, Fea_Filter=8, Fea_Block=4, Fea_Step=3, Mat_Layers=12, Mat_Filter=8, Mat_Block=4, Mat_Step=3): 11 | super(AutoStereo, self).__init__() 12 | self.maxdisp = maxdisp 13 | #define Feature parameters 14 | self.Fea_Layers = Fea_Layers 15 | self.Fea_Filter = Fea_Filter 16 | self.Fea_Block = Fea_Block 17 | self.Fea_Step = Fea_Step 18 | #define Matching parameters 19 | self.Mat_Layers = Mat_Layers 20 | self.Mat_Filter = Mat_Filter 21 | self.Mat_Block = Mat_Block 22 | self.Mat_Step = Mat_Step 23 | 24 | self.feature = AutoFeature(self.Fea_Layers, self.Fea_Filter, self.Fea_Block, self.Fea_Step) 25 | self.matching = AutoMatching(self.Mat_Layers, self.Mat_Filter, self.Mat_Block, self.Mat_Step) 26 | self.disp = Disp(self.maxdisp) 27 | 28 | for m in self.modules(): 29 | if isinstance(m, (nn.Conv2d)): 30 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 31 | elif isinstance(m, nn.Conv3d): 32 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 33 | elif isinstance(m, nn.BatchNorm3d): 34 | nn.init.constant_(m.weight, 1) 35 | nn.init.constant_(m.bias, 0) 36 | elif isinstance(m, (nn.BatchNorm2d)): 37 | nn.init.constant_(m.weight, 1) 38 | nn.init.constant_(m.bias, 0) 39 | 40 | def forward(self, x, y): 41 | 42 | x = self.feature(x) 43 | y = self.feature(y) 44 | 45 | with torch.cuda.device_of(x): 46 | cost = x.new().resize_(x.size()[0], x.size()[1]*2, int(self.maxdisp/3), x.size()[2], x.size()[3]).zero_() 47 | for i in range(int(self.maxdisp/3)): 48 | if i > 0 : 49 | cost[:,:x.size()[1], i,:,i:] = x[:,:,:,i:] 50 | cost[:,x.size()[1]:, i,:,i:] = y[:,:,:,:-i] 51 | else: 52 | cost[:,:x.size()[1],i,:,i:] = x 53 | cost[:,x.size()[1]:,i,:,i:] = y 54 | 55 | cost = self.matching(cost) 56 | disp0 = self.disp(cost) 57 | return disp0 58 | 59 | -------------------------------------------------------------------------------- /models/cell_level_search_2d.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from models.operations_2d import * 3 | from models.genotypes_2d import PRIMITIVES 4 | 5 | class MixedOp(nn.Module): 6 | 7 | def __init__(self, C, stride): 8 | super(MixedOp, self).__init__() 9 | self._ops = nn.ModuleList() 10 | for primitive in PRIMITIVES: 11 | op = OPS[primitive](C, stride) 12 | if 'pool' in primitive: 13 | op = nn.Sequential(op, nn.BatchNorm2d(C)) 14 | self._ops.append(op) 15 | 16 | def forward(self, x, weights): 17 | return sum(w * op(x) for w, op in zip(weights, self._ops)) 18 | 19 | 20 | class Cell(nn.Module): 21 | 22 | def __init__(self, steps, block_multiplier, prev_prev_fmultiplier, 23 | prev_fmultiplier_down, prev_fmultiplier_same, prev_fmultiplier_up, 24 | filter_multiplier): 25 | 26 | super(Cell, self).__init__() 27 | 28 | self.C_in = block_multiplier * filter_multiplier 29 | self.C_out = filter_multiplier 30 | 31 | self.C_prev_prev = int(prev_prev_fmultiplier * block_multiplier) 32 | self._prev_fmultiplier_same = prev_fmultiplier_same 33 | 34 | if prev_fmultiplier_down is not None: 35 | 
self.C_prev_down = int(prev_fmultiplier_down * block_multiplier) 36 | self.preprocess_down = ConvBR( 37 | self.C_prev_down, self.C_out, 1, 1, 0) 38 | if prev_fmultiplier_same is not None: 39 | self.C_prev_same = int(prev_fmultiplier_same * block_multiplier) 40 | self.preprocess_same = ConvBR( 41 | self.C_prev_same, self.C_out, 1, 1, 0) 42 | if prev_fmultiplier_up is not None: 43 | self.C_prev_up = int(prev_fmultiplier_up * block_multiplier) 44 | self.preprocess_up = ConvBR( 45 | self.C_prev_up, self.C_out, 1, 1, 0) 46 | 47 | if prev_prev_fmultiplier != -1: 48 | self.pre_preprocess = ConvBR( 49 | self.C_prev_prev, self.C_out, 1, 1, 0) 50 | 51 | self._steps = steps 52 | self.block_multiplier = block_multiplier 53 | self._ops = nn.ModuleList() 54 | 55 | for i in range(self._steps): 56 | for j in range(2 + i): 57 | stride = 1 58 | if prev_prev_fmultiplier == -1 and j == 0: 59 | op = None 60 | else: 61 | op = MixedOp(self.C_out, stride) 62 | self._ops.append(op) 63 | 64 | self._initialize_weights() 65 | 66 | def scale_dimension(self, dim, scale): 67 | assert isinstance(dim, int) 68 | return int((float(dim) - 1.0) * scale + 1.0) if dim % 2 else int(dim * scale) 69 | 70 | def prev_feature_resize(self, prev_feature, mode): 71 | if mode == 'down': 72 | feature_size_h = self.scale_dimension(prev_feature.shape[2], 0.5) 73 | feature_size_w = self.scale_dimension(prev_feature.shape[3], 0.5) 74 | elif mode == 'up': 75 | feature_size_h = self.scale_dimension(prev_feature.shape[2], 2) 76 | feature_size_w = self.scale_dimension(prev_feature.shape[3], 2) 77 | 78 | return F.interpolate(prev_feature, (feature_size_h, feature_size_w), mode='bilinear', align_corners=True) 79 | 80 | def forward(self, s0, s1_down, s1_same, s1_up, n_alphas): 81 | 82 | if s1_down is not None: 83 | s1_down = self.prev_feature_resize(s1_down, 'down') 84 | s1_down = self.preprocess_down(s1_down) 85 | size_h, size_w = s1_down.shape[2], s1_down.shape[3] 86 | if s1_same is not None: 87 | s1_same = self.preprocess_same(s1_same) 88 | size_h, size_w = s1_same.shape[2], s1_same.shape[3] 89 | if s1_up is not None: 90 | s1_up = self.prev_feature_resize(s1_up, 'up') 91 | s1_up = self.preprocess_up(s1_up) 92 | size_h, size_w = s1_up.shape[2], s1_up.shape[3] 93 | all_states = [] 94 | if s0 is not None: 95 | 96 | s0 = F.interpolate(s0, (size_h, size_w), mode='bilinear', align_corners=True) if (s0.shape[2] != size_h) or (s0.shape[3] != size_w) else s0 97 | s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0 98 | if s1_down is not None: 99 | states_down = [s0, s1_down] 100 | all_states.append(states_down) 101 | if s1_same is not None: 102 | states_same = [s0, s1_same] 103 | all_states.append(states_same) 104 | if s1_up is not None: 105 | states_up = [s0, s1_up] 106 | all_states.append(states_up) 107 | else: 108 | if s1_down is not None: 109 | states_down = [0, s1_down] 110 | all_states.append(states_down) 111 | if s1_same is not None: 112 | states_same = [0, s1_same] 113 | all_states.append(states_same) 114 | if s1_up is not None: 115 | states_up = [0, s1_up] 116 | all_states.append(states_up) 117 | 118 | final_concates = [] 119 | for states in all_states: 120 | offset = 0 121 | for i in range(self._steps): 122 | new_states = [] 123 | for j, h in enumerate(states): 124 | branch_index = offset + j 125 | if self._ops[branch_index] is None: 126 | continue 127 | new_state = self._ops[branch_index](h, n_alphas[branch_index]) 128 | new_states.append(new_state) 129 | 130 | s = sum(new_states) 131 | offset += len(states) 132 | 
states.append(s) 133 | 134 | concat_feature = torch.cat(states[-self.block_multiplier:], dim=1) 135 | final_concates.append(concat_feature) 136 | return final_concates 137 | 138 | 139 | def _initialize_weights(self): 140 | for m in self.modules(): 141 | if isinstance(m, nn.Conv2d): 142 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 143 | elif isinstance(m, nn.BatchNorm2d): 144 | nn.init.constant_(m.weight, 1) 145 | nn.init.constant_(m.bias, 0) 146 | -------------------------------------------------------------------------------- /models/cell_level_search_3d.py: -------------------------------------------------------------------------------- 1 | import torch.nn.functional as F 2 | from models.operations_3d import * 3 | from models.genotypes_3d import PRIMITIVES 4 | 5 | class MixedOp(nn.Module): 6 | 7 | def __init__(self, C, stride): 8 | super(MixedOp, self).__init__() 9 | self._ops = nn.ModuleList() 10 | for primitive in PRIMITIVES: 11 | op = OPS[primitive](C, stride) 12 | if 'pool' in primitive: 13 | op = nn.Sequential(op, nn.BatchNorm3d(C)) 14 | self._ops.append(op) 15 | 16 | def forward(self, x, weights): 17 | return sum(w * op(x) for w, op in zip(weights, self._ops)) 18 | 19 | 20 | class Cell(nn.Module): 21 | 22 | def __init__(self, steps, block_multiplier, prev_prev_fmultiplier, 23 | prev_fmultiplier_down, prev_fmultiplier_same, prev_fmultiplier_up, 24 | filter_multiplier): 25 | 26 | super(Cell, self).__init__() 27 | 28 | self.C_in = block_multiplier * filter_multiplier 29 | self.C_out = filter_multiplier 30 | 31 | self.C_prev_prev = int(prev_prev_fmultiplier * block_multiplier) 32 | self._prev_fmultiplier_same = prev_fmultiplier_same 33 | 34 | if prev_fmultiplier_down is not None: 35 | self.C_prev_down = int(prev_fmultiplier_down * block_multiplier) 36 | self.preprocess_down = ConvBR( 37 | self.C_prev_down, self.C_out, 1, 1, 0) 38 | if prev_fmultiplier_same is not None: 39 | self.C_prev_same = int(prev_fmultiplier_same * block_multiplier) 40 | self.preprocess_same = ConvBR( 41 | self.C_prev_same, self.C_out, 1, 1, 0) 42 | if prev_fmultiplier_up is not None: 43 | self.C_prev_up = int(prev_fmultiplier_up * block_multiplier) 44 | self.preprocess_up = ConvBR( 45 | self.C_prev_up, self.C_out, 1, 1, 0) 46 | 47 | if prev_prev_fmultiplier != -1: 48 | self.pre_preprocess = ConvBR( 49 | self.C_prev_prev, self.C_out, 1, 1, 0) 50 | 51 | self._steps = steps 52 | self.block_multiplier = block_multiplier 53 | self._ops = nn.ModuleList() 54 | 55 | for i in range(self._steps): 56 | for j in range(2 + i): 57 | stride = 1 58 | if prev_prev_fmultiplier == -1 and j == 0: 59 | op = None 60 | else: 61 | op = MixedOp(self.C_out, stride) 62 | self._ops.append(op) 63 | 64 | self._initialize_weights() 65 | 66 | def scale_dimension(self, dim, scale): 67 | assert isinstance(dim, int) 68 | return int((float(dim) - 1.0) * scale + 1.0) if dim % 2 else int(dim * scale) 69 | 70 | def prev_feature_resize(self, prev_feature, mode): 71 | if mode == 'down': 72 | feature_size_d = self.scale_dimension(prev_feature.shape[2], 0.5) 73 | feature_size_h = self.scale_dimension(prev_feature.shape[3], 0.5) 74 | feature_size_w = self.scale_dimension(prev_feature.shape[4], 0.5) 75 | elif mode == 'up': 76 | feature_size_d = self.scale_dimension(prev_feature.shape[2], 2) 77 | feature_size_h = self.scale_dimension(prev_feature.shape[3], 2) 78 | feature_size_w = self.scale_dimension(prev_feature.shape[4], 2) 79 | return F.interpolate(prev_feature, (feature_size_d, feature_size_h, feature_size_w), 
mode='trilinear', align_corners=True) 80 | 81 | def forward(self, s0, s1_down, s1_same, s1_up, n_alphas): 82 | 83 | if s1_down is not None: 84 | s1_down = self.prev_feature_resize(s1_down, 'down') 85 | s1_down = self.preprocess_down(s1_down) 86 | size_d, size_h, size_w = s1_down.shape[2], s1_down.shape[3], s1_down.shape[4] 87 | if s1_same is not None: 88 | s1_same = self.preprocess_same(s1_same) 89 | size_d, size_h, size_w = s1_same.shape[2], s1_same.shape[3], s1_same.shape[4] 90 | if s1_up is not None: 91 | s1_up = self.prev_feature_resize(s1_up, 'up') 92 | s1_up = self.preprocess_up(s1_up) 93 | size_d, size_h, size_w = s1_up.shape[2], s1_up.shape[3], s1_up.shape[4] 94 | all_states = [] 95 | if s0 is not None: 96 | s0 = F.interpolate(s0, (size_d, size_h, size_w), mode='trilinear', align_corners=True) if (s0.shape[3] != size_h) or (s0.shape[4] != size_w) or (s0.shape[2] != size_d) else s0 97 | s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0 98 | if s1_down is not None: 99 | states_down = [s0, s1_down] 100 | all_states.append(states_down) 101 | if s1_same is not None: 102 | states_same = [s0, s1_same] 103 | all_states.append(states_same) 104 | if s1_up is not None: 105 | states_up = [s0, s1_up] 106 | all_states.append(states_up) 107 | else: 108 | if s1_down is not None: 109 | states_down = [0, s1_down] 110 | all_states.append(states_down) 111 | if s1_same is not None: 112 | states_same = [0, s1_same] 113 | all_states.append(states_same) 114 | if s1_up is not None: 115 | states_up = [0, s1_up] 116 | all_states.append(states_up) 117 | 118 | final_concates = [] 119 | for states in all_states: 120 | offset = 0 121 | for i in range(self._steps): 122 | new_states = [] 123 | for j, h in enumerate(states): 124 | branch_index = offset + j 125 | if self._ops[branch_index] is None: 126 | continue 127 | new_state = self._ops[branch_index]( 128 | h, n_alphas[branch_index]) 129 | new_states.append(new_state) 130 | 131 | s = sum(new_states) 132 | offset += len(states) 133 | states.append(s) 134 | 135 | concat_feature = torch.cat(states[-self.block_multiplier:], dim=1) 136 | final_concates.append(concat_feature) 137 | return final_concates 138 | 139 | 140 | def _initialize_weights(self): 141 | for m in self.modules(): 142 | if isinstance(m, nn.Conv3d): 143 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 144 | elif isinstance(m, nn.BatchNorm3d): 145 | nn.init.constant_(m.weight, 1) 146 | nn.init.constant_(m.bias, 0) 147 | -------------------------------------------------------------------------------- /models/decoding_formulas.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pdb 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | def network_layer_to_space(net_arch): 7 | for i, layer in enumerate(net_arch): 8 | if i == 0: 9 | space = np.zeros((1, 4, 3)) 10 | space[0][layer][0] = 1 11 | prev = layer 12 | else: 13 | if layer == prev + 1: 14 | sample = 0 15 | elif layer == prev: 16 | sample = 1 17 | elif layer == prev - 1: 18 | sample = 2 19 | space1 = np.zeros((1, 4, 3)) 20 | space1[0][layer][sample] = 1 21 | space = np.concatenate([space, space1], axis=0) 22 | prev = layer 23 | """ 24 | return: 25 | network_space[layer][level][sample]: 26 | layer: 0 - 12 27 | level: sample_level {0: 1, 1: 2, 2: 4, 3: 8} 28 | sample: 0: down 1: None 2: Up 29 | """ 30 | return space 31 | 32 | class Decoder(object): 33 | def __init__(self, alphas, betas, steps): 34 | self._betas = betas 35 | self._alphas = 
alphas 36 | self._steps = steps 37 | self._num_layers = self._betas.shape[0] 38 | self.network_space = torch.zeros(self._num_layers, 4, 3) 39 | 40 | for layer in range(self._num_layers): 41 | if layer == 0: 42 | self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3) 43 | elif layer == 1: 44 | self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3) 45 | self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1) 46 | 47 | elif layer == 2: 48 | self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3) 49 | self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1) 50 | self.network_space[layer][2] = F.softmax(self._betas[layer][2], dim=-1) 51 | 52 | 53 | else: 54 | self.network_space[layer][0][1:] = F.softmax(self._betas[layer][0][1:], dim=-1) * (2/3) 55 | self.network_space[layer][1] = F.softmax(self._betas[layer][1], dim=-1) 56 | self.network_space[layer][2] = F.softmax(self._betas[layer][2], dim=-1) 57 | self.network_space[layer][3][:2] = F.softmax(self._betas[layer][3][:2], dim=-1) * (2/3) 58 | 59 | def viterbi_decode(self): 60 | 61 | prob_space = np.zeros((self.network_space.shape[:2])) 62 | path_space = np.zeros((self.network_space.shape[:2])).astype('int8') 63 | 64 | for layer in range(self.network_space.shape[0]): 65 | if layer == 0: 66 | prob_space[layer][0] = self.network_space[layer][0][1] 67 | prob_space[layer][1] = self.network_space[layer][0][2] 68 | path_space[layer][0] = 0 69 | path_space[layer][1] = -1 70 | else: 71 | for sample in range(self.network_space.shape[1]): 72 | if layer - sample < - 1: 73 | continue 74 | local_prob = [] 75 | for rate in range(self.network_space.shape[2]): # k[0 : ➚, 1: ➙, 2 : ➘] 76 | if (sample == 0 and rate == 2) or (sample == 3 and rate == 0): 77 | continue 78 | else: 79 | local_prob.append(prob_space[layer - 1][sample + 1 - rate] * 80 | self.network_space[layer][sample + 1 - rate][rate]) 81 | prob_space[layer][sample] = np.max(local_prob, axis=0) 82 | rate = np.argmax(local_prob, axis=0) 83 | path = 1 - rate if sample != 3 else -rate 84 | path_space[layer][sample] = path # path[1 : ➚, 0: ➙, -1 : ➘] 85 | 86 | output_sample = prob_space[-1, :].argmax(axis=-1) 87 | actual_path = np.zeros(self._num_layers).astype('uint8') 88 | actual_path[-1] = output_sample 89 | for i in range(1, self._num_layers): 90 | actual_path[-i - 1] = actual_path[-i] + path_space[self._num_layers - i, actual_path[-i]] 91 | return actual_path, network_layer_to_space(actual_path) 92 | 93 | def genotype_decode(self): 94 | def _parse(alphas, steps): 95 | gene = [] 96 | start = 0 97 | n = 2 98 | for i in range(steps): 99 | end = start + n 100 | edges = sorted(range(start, end), key=lambda x: -np.max(alphas[x, 1:])) # ignore none value 101 | top2edges = edges[:2] 102 | for j in top2edges: 103 | best_op_index = np.argmax(alphas[j]) # this can include none op 104 | gene.append([j, best_op_index]) 105 | start = end 106 | n += 1 107 | return np.array(gene) 108 | 109 | normalized_alphas = F.softmax(self._alphas, dim=-1).data.cpu().numpy() 110 | gene_cell = _parse(normalized_alphas, self._steps) 111 | return gene_cell 112 | -------------------------------------------------------------------------------- /models/genotypes_2d.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype_2D', 'cell cell_concat') 4 | 5 | PRIMITIVES = [ 6 | 'skip_connect', 7 | 'conv_3x3'] 
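[Editor's note — illustrative sketch, not repository code] The two 2D primitives above (skip_connect and conv_3x3) are the entire per-edge search space; during search, every edge of a cell is a MixedOp (cell_level_search_2d.py) that outputs a weighted sum of both candidate ops under relaxed architecture weights. A minimal example of one such edge, assuming the repository modules are importable:

# Editor's sketch: one search edge blending the two PRIMITIVES with softmaxed weights.
import torch
import torch.nn.functional as F
from models.cell_level_search_2d import MixedOp

edge = MixedOp(C=8, stride=1)                            # one op instance per primitive
x = torch.randn(2, 8, 24, 24)                            # N, C, H, W
alpha = F.softmax(torch.randn(len(edge._ops)), dim=-1)   # relaxed architecture weights
y = edge(x, alpha)                                       # sum_k alpha_k * op_k(x); same shape as x
print(y.shape)                                           # torch.Size([2, 8, 24, 24])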
8 | 9 | 10 | -------------------------------------------------------------------------------- /models/genotypes_3d.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'cell cell_concat') 4 | 5 | PRIMITIVES = [ 6 | 'skip_connect', 7 | '3d_conv_3x3' 8 | ] 9 | 10 | -------------------------------------------------------------------------------- /models/operations_2d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | OPS = { 6 | 'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C), 7 | 'conv_3x3': lambda C, stride: ConvBR(C, C, 3, stride, 1) 8 | } 9 | 10 | class NaiveBN(nn.Module): 11 | def __init__(self, C_out, momentum=0.1): 12 | super(NaiveBN, self).__init__() 13 | self.op = nn.Sequential( 14 | nn.BatchNorm2d(C_out), 15 | nn.ReLU() 16 | ) 17 | self._initialize_weights() 18 | 19 | def forward(self, x): 20 | return self.op(x) 21 | 22 | def _initialize_weights(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv2d): 25 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 26 | elif isinstance(m, nn.BatchNorm2d): 27 | nn.init.constant_(m.weight, 1) 28 | nn.init.constant_(m.bias, 0) 29 | 30 | 31 | class ConvBR(nn.Module): 32 | def __init__(self, C_in, C_out, kernel_size, stride, padding, bn=True, relu=True): 33 | super(ConvBR, self).__init__() 34 | self.relu = relu 35 | self.use_bn = bn 36 | 37 | self.conv = nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False) 38 | self.bn = nn.BatchNorm2d(C_out) 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | if self.use_bn: 44 | x = self.bn(x) 45 | if self.relu: 46 | x = F.relu(x, inplace=True) 47 | return x 48 | def _initialize_weights(self): 49 | for m in self.modules(): 50 | if isinstance(m, nn.Conv2d): 51 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 52 | elif isinstance(m, nn.BatchNorm2d): 53 | nn.init.constant_(m.weight, 1) 54 | nn.init.constant_(m.bias, 0) 55 | 56 | class SepConv(nn.Module): 57 | def __init__(self, C_in, C_out, kernel_size, stride, padding): 58 | super(SepConv, self).__init__() 59 | self.op = nn.Sequential( 60 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, 61 | bias=False), 62 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 63 | nn.BatchNorm2d(C_in), 64 | nn.ReLU(inplace=False), 65 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 66 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 67 | nn.BatchNorm2d(C_out), 68 | nn.ReLU(inplace=False) 69 | ) 70 | self._initialize_weights() 71 | 72 | def forward(self, x): 73 | return self.op(x) 74 | 75 | def _initialize_weights(self): 76 | for m in self.modules(): 77 | if isinstance(m, nn.Conv2d): 78 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 79 | elif isinstance(m, nn.BatchNorm2d): 80 | nn.init.constant_(m.weight, 1) 81 | nn.init.constant_(m.bias, 0) 82 | 83 | class Identity(nn.Module): 84 | def __init__(self): 85 | super(Identity, self).__init__() 86 | self._initialize_weights() 87 | 88 | def forward(self, x): 89 | return x 90 | 91 | def init_weight(self): 92 | for ly in self.children(): 93 | if isinstance(ly, nn.Conv2d): 94 | 
nn.init.kaiming_normal_(ly.weight, a=1) 95 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 96 | 97 | def _initialize_weights(self): 98 | for m in self.modules(): 99 | if isinstance(m, nn.Conv2d): 100 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 101 | elif isinstance(m, nn.BatchNorm2d): 102 | nn.init.constant_(m.weight, 1) 103 | nn.init.constant_(m.bias, 0) 104 | 105 | 106 | class FactorizedReduce(nn.Module): 107 | def __init__(self, C_in, C_out): 108 | super(FactorizedReduce, self).__init__() 109 | assert C_out % 2 == 0 110 | self.relu = nn.ReLU(inplace=False) 111 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 112 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 113 | self.bn = nn.BatchNorm2d(C_out) 114 | self._initialize_weights() 115 | 116 | def forward(self, x): 117 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)  # 2D input is NCHW, so shift by one pixel in H and W only 118 | out = self.bn(out) 119 | out = self.relu(out) 120 | return out 121 | 122 | def init_weight(self): 123 | for ly in self.children(): 124 | if isinstance(ly, nn.Conv2d): 125 | nn.init.kaiming_normal_(ly.weight, a=1) 126 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 127 | 128 | def _initialize_weights(self): 129 | for m in self.modules(): 130 | if isinstance(m, nn.Conv2d): 131 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 132 | elif isinstance(m, nn.BatchNorm2d): 133 | nn.init.constant_(m.weight, 1) 134 | nn.init.constant_(m.bias, 0) 135 | 136 | 137 | class DoubleFactorizedReduce(nn.Module): 138 | def __init__(self, C_in, C_out): 139 | super(DoubleFactorizedReduce, self).__init__() 140 | assert C_out % 2 == 0 141 | self.relu = nn.ReLU(inplace=False) 142 | self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 143 | self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 144 | self.bn = nn.BatchNorm2d(C_out) 145 | self._initialize_weights() 146 | 147 | def forward(self, x): 148 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:])], dim=1)  # 2D input is NCHW, so shift by one pixel in H and W only 149 | out = self.bn(out) 150 | out = self.relu(out) 151 | return out 152 | 153 | def _initialize_weights(self): 154 | for m in self.modules(): 155 | if isinstance(m, nn.Conv2d): 156 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 157 | elif isinstance(m, nn.BatchNorm2d): 158 | nn.init.constant_(m.weight, 1) 159 | nn.init.constant_(m.bias, 0) 160 | 161 | class FactorizedIncrease(nn.Module): 162 | def __init__(self, in_channel, out_channel): 163 | super(FactorizedIncrease, self).__init__() 164 | 165 | self._in_channel = in_channel 166 | self.op = nn.Sequential( 167 | nn.Upsample(scale_factor=2, mode="bilinear"), 168 | nn.Conv2d(self._in_channel, out_channel, 1, stride=1, padding=0), 169 | nn.BatchNorm2d(out_channel), 170 | nn.ReLU(inplace=False) 171 | ) 172 | self._initialize_weights() 173 | 174 | def forward(self, x): 175 | return self.op(x) 176 | 177 | def _initialize_weights(self): 178 | for m in self.modules(): 179 | if isinstance(m, nn.Conv2d): 180 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 181 | elif isinstance(m, nn.BatchNorm2d): 182 | nn.init.constant_(m.weight, 1) 183 | nn.init.constant_(m.bias, 0) 184 | 185 | 186 | class DoubleFactorizedIncrease(nn.Module): 187 | def __init__(self, in_channel, out_channel): 188 | super(DoubleFactorizedIncrease, self).__init__() 189 | 190 | self._in_channel = in_channel 191 | self.op = nn.Sequential( 192 | 
nn.Upsample(scale_factor=2, mode="bilinear"), 193 | nn.Conv2d(self._in_channel, out_channel, 1, stride=1, padding=0), 194 | nn.BatchNorm2d(out_channel), 195 | nn.ReLU(inplace=False), 196 | nn.Upsample(scale_factor=2, mode="bilinear"), 197 | nn.Conv2d(self._in_channel, out_channel, 1, stride=1, padding=0), 198 | nn.BatchNorm2d(out_channel), 199 | nn.ReLU(inplace=False) 200 | ) 201 | self._initialize_weights() 202 | 203 | def forward(self, x): 204 | return self.op(x) 205 | 206 | def _initialize_weights(self): 207 | for m in self.modules(): 208 | if isinstance(m, nn.Conv2d): 209 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 210 | elif isinstance(m, nn.BatchNorm2d): 211 | nn.init.constant_(m.weight, 1) 212 | nn.init.constant_(m.bias, 0) 213 | 214 | 215 | -------------------------------------------------------------------------------- /models/operations_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | OPS = { 6 | 'skip_connect': lambda C, stride: Identity() if stride == 1 else FactorizedReduce(C, C), 7 | '3d_conv_3x3': lambda C, stride: ConvBR(C, C, 3, stride, 1) 8 | } 9 | 10 | class NaiveBN(nn.Module): 11 | def __init__(self, C_out, momentum=0.1): 12 | super(NaiveBN, self).__init__() 13 | self.op = nn.Sequential( 14 | nn.BatchNorm3d(C_out), 15 | nn.ReLU() 16 | ) 17 | self._initialize_weights() 18 | 19 | def forward(self, x): 20 | return self.op(x) 21 | 22 | def _initialize_weights(self): 23 | for m in self.modules(): 24 | if isinstance(m, nn.Conv3d): 25 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 26 | elif isinstance(m, nn.BatchNorm3d): 27 | nn.init.constant_(m.weight, 1) 28 | nn.init.constant_(m.bias, 0) 29 | 30 | 31 | class ConvBR(nn.Module): 32 | def __init__(self, C_in, C_out, kernel_size, stride, padding, bn=True, relu=True): 33 | super(ConvBR, self).__init__() 34 | self.relu = relu 35 | self.use_bn = bn 36 | 37 | self.conv = nn.Conv3d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False) 38 | self.bn = nn.BatchNorm3d(C_out) 39 | self._initialize_weights() 40 | 41 | def forward(self, x): 42 | x = self.conv(x) 43 | if self.use_bn: 44 | x = self.bn(x) 45 | if self.relu: 46 | x = F.relu(x, inplace=True) 47 | return x 48 | 49 | def _initialize_weights(self): 50 | for m in self.modules(): 51 | if isinstance(m, nn.Conv3d): 52 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 53 | elif isinstance(m, nn.BatchNorm3d): 54 | nn.init.constant_(m.weight, 1) 55 | nn.init.constant_(m.bias, 0) 56 | 57 | class SepConv(nn.Module): 58 | def __init__(self, C_in, C_out, kernel_size, stride, padding): 59 | super(SepConv, self).__init__() 60 | self.op = nn.Sequential( 61 | nn.Conv3d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, groups=C_in, 62 | bias=False), 63 | nn.Conv3d(C_in, C_in, kernel_size=1, padding=0, bias=False), 64 | nn.BatchNorm3d(C_in), 65 | nn.ReLU(inplace=False), 66 | nn.Conv3d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, groups=C_in, bias=False), 67 | nn.Conv3d(C_in, C_out, kernel_size=1, padding=0, bias=False), 68 | nn.BatchNorm3d(C_out), 69 | nn.ReLU(inplace=False) 70 | ) 71 | self._initialize_weights() 72 | 73 | def forward(self, x): 74 | return self.op(x) 75 | 76 | def _initialize_weights(self): 77 | for m in self.modules(): 78 | if isinstance(m, nn.Conv3d): 79 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 
80 | elif isinstance(m, nn.BatchNorm3d): 81 | nn.init.constant_(m.weight, 1) 82 | nn.init.constant_(m.bias, 0) 83 | 84 | class Identity(nn.Module): 85 | def __init__(self): 86 | super(Identity, self).__init__() 87 | self._initialize_weights() 88 | 89 | def forward(self, x): 90 | return x 91 | 92 | def init_weight(self): 93 | for ly in self.children(): 94 | if isinstance(ly, nn.Conv3d): 95 | nn.init.kaiming_normal_(ly.weight, a=1) 96 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 97 | 98 | def _initialize_weights(self): 99 | for m in self.modules(): 100 | if isinstance(m, nn.Conv3d): 101 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 102 | elif isinstance(m, nn.BatchNorm3d): 103 | nn.init.constant_(m.weight, 1) 104 | nn.init.constant_(m.bias, 0) 105 | 106 | 107 | class FactorizedReduce(nn.Module): 108 | def __init__(self, C_in, C_out): 109 | super(FactorizedReduce, self).__init__() 110 | assert C_out % 2 == 0 111 | self.relu = nn.ReLU(inplace=False) 112 | self.conv_1 = nn.Conv3d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 113 | self.conv_2 = nn.Conv3d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False) 114 | self.bn = nn.BatchNorm3d(C_out) 115 | self._initialize_weights() 116 | 117 | def forward(self, x): 118 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) 119 | out = self.bn(out) 120 | out = self.relu(out) 121 | return out 122 | 123 | def init_weight(self): 124 | for ly in self.children(): 125 | if isinstance(ly, nn.Conv3d): 126 | nn.init.kaiming_normal_(ly.weight, a=1) 127 | if not ly.bias is None: nn.init.constant_(ly.bias, 0) 128 | 129 | def _initialize_weights(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.Conv3d): 132 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 133 | elif isinstance(m, nn.BatchNorm3d): 134 | nn.init.constant_(m.weight, 1) 135 | nn.init.constant_(m.bias, 0) 136 | 137 | 138 | class DoubleFactorizedReduce(nn.Module): 139 | def __init__(self, C_in, C_out): 140 | super(DoubleFactorizedReduce, self).__init__() 141 | assert C_out % 2 == 0 142 | self.relu = nn.ReLU(inplace=False) 143 | self.conv_1 = nn.Conv3d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 144 | self.conv_2 = nn.Conv3d(C_in, C_out // 2, 1, stride=4, padding=0, bias=False) 145 | self.bn = nn.BatchNorm3d(C_out) 146 | self._initialize_weights() 147 | 148 | def forward(self, x): 149 | out = torch.cat([self.conv_1(x), self.conv_2(x[:, :, 1:, 1:, 1:])], dim=1) 150 | out = self.bn(out) 151 | out = self.relu(out) 152 | return out 153 | 154 | def _initialize_weights(self): 155 | for m in self.modules(): 156 | if isinstance(m, nn.Conv3d): 157 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 158 | elif isinstance(m, nn.BatchNorm3d): 159 | nn.init.constant_(m.weight, 1) 160 | nn.init.constant_(m.bias, 0) 161 | 162 | class FactorizedIncrease(nn.Module): 163 | def __init__(self, in_channel, out_channel): 164 | super(FactorizedIncrease, self).__init__() 165 | 166 | self._in_channel = in_channel 167 | self.op = nn.Sequential( 168 | nn.Upsample(scale_factor=2, mode="trilinear"), 169 | nn.Conv3d(self._in_channel, out_channel, 1, stride=1, padding=0), 170 | nn.BatchNorm3d(out_channel), 171 | nn.ReLU(inplace=False) 172 | ) 173 | self._initialize_weights() 174 | 175 | def forward(self, x): 176 | return self.op(x) 177 | 178 | def _initialize_weights(self): 179 | for m in self.modules(): 180 | if isinstance(m, nn.Conv3d): 181 | nn.init.kaiming_normal_(m.weight, mode='fan_out', 
nonlinearity='relu') 182 | elif isinstance(m, nn.BatchNorm3d): 183 | nn.init.constant_(m.weight, 1) 184 | nn.init.constant_(m.bias, 0) 185 | 186 | 187 | class DoubleFactorizedIncrease(nn.Module): 188 | def __init__(self, in_channel, out_channel): 189 | super(DoubleFactorizedIncrease, self).__init__() 190 | 191 | self._in_channel = in_channel 192 | self.op = nn.Sequential( 193 | nn.Upsample(scale_factor=2, mode="trilinear"), 194 | nn.Conv3d(self._in_channel, out_channel, 1, stride=1, padding=0), 195 | nn.BatchNorm3d(out_channel), 196 | nn.ReLU(inplace=False), 197 | nn.Upsample(scale_factor=2, mode="trilinear"), 198 | nn.Conv3d(self._in_channel, out_channel, 1, stride=1, padding=0), 199 | nn.BatchNorm3d(out_channel), 200 | nn.ReLU(inplace=False) 201 | ) 202 | self._initialize_weights() 203 | 204 | def forward(self, x): 205 | return self.op(x) 206 | 207 | def _initialize_weights(self): 208 | for m in self.modules(): 209 | if isinstance(m, nn.Conv3d): 210 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 211 | elif isinstance(m, nn.BatchNorm3d): 212 | nn.init.constant_(m.weight, 1) 213 | nn.init.constant_(m.bias, 0) 214 | 215 | 216 | -------------------------------------------------------------------------------- /mypath.py: -------------------------------------------------------------------------------- 1 | class Path(object): 2 | @staticmethod 3 | def db_root_dir(dataset): 4 | if dataset == 'sceneflow': 5 | return './dataset/SceneFlow/' 6 | elif dataset == 'kitti15': 7 | return './dataset/kitti2015/training/' 8 | elif dataset == 'kitti12': 9 | return './dataset/kitti2012/training/' 10 | elif dataset == 'middlebury': 11 | return './dataset/MiddEval3/trainingH/' 12 | else: 13 | print('Dataset {} not available.'.format(dataset)) 14 | raise NotImplementedError 15 | -------------------------------------------------------------------------------- /predict.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | import skimage 4 | import skimage.io 5 | import skimage.transform 6 | from PIL import Image 7 | from math import log10 8 | 9 | import sys 10 | import shutil 11 | import os 12 | import torch 13 | import torch.nn as nn 14 | import torch.nn.parallel 15 | import torch.backends.cudnn as cudnn 16 | import torch.optim as optim 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | from retrain.LEAStereo import LEAStereo 20 | 21 | from config_utils.predict_args import obtain_predict_args 22 | from utils.colorize import get_color_map 23 | from utils.multadds_count import count_parameters_in_MB, comp_multadds 24 | from time import time 25 | from struct import unpack 26 | import matplotlib.pyplot as plt 27 | import re 28 | import numpy as np 29 | import pdb 30 | from path import Path 31 | 32 | opt = obtain_predict_args() 33 | print(opt) 34 | 35 | torch.backends.cudnn.benchmark = True 36 | 37 | cuda = opt.cuda 38 | if cuda and not torch.cuda.is_available(): 39 | raise Exception("No GPU found, please run without --cuda") 40 | 41 | print('===> Building LEAStereo model') 42 | model = LEAStereo(opt) 43 | 44 | print('Total Params = %.2fMB' % count_parameters_in_MB(model)) 45 | print('Feature Net Params = %.2fMB' % count_parameters_in_MB(model.feature)) 46 | print('Matching Net Params = %.2fMB' % count_parameters_in_MB(model.matching)) 47 | 48 | mult_adds = comp_multadds(model, input_size=(3,opt.crop_height, opt.crop_width)) #(3,192, 192)) 49 | 
print("compute_average_flops_cost = %.2fMB" % mult_adds) 50 | 51 | if cuda: 52 | model = torch.nn.DataParallel(model).cuda() 53 | 54 | if opt.resume: 55 | if os.path.isfile(opt.resume): 56 | print("=> loading checkpoint '{}'".format(opt.resume)) 57 | checkpoint = torch.load(opt.resume) 58 | model.load_state_dict(checkpoint['state_dict'], strict=True) 59 | else: 60 | print("=> no checkpoint found at '{}'".format(opt.resume)) 61 | 62 | turbo_colormap_data = get_color_map() 63 | 64 | def RGBToPyCmap(rgbdata): 65 | nsteps = rgbdata.shape[0] 66 | stepaxis = np.linspace(0, 1, nsteps) 67 | 68 | rdata=[]; gdata=[]; bdata=[] 69 | for istep in range(nsteps): 70 | r = rgbdata[istep,0] 71 | g = rgbdata[istep,1] 72 | b = rgbdata[istep,2] 73 | rdata.append((stepaxis[istep], r, r)) 74 | gdata.append((stepaxis[istep], g, g)) 75 | bdata.append((stepaxis[istep], b, b)) 76 | 77 | mpl_data = {'red': rdata, 78 | 'green': gdata, 79 | 'blue': bdata} 80 | 81 | return mpl_data 82 | 83 | mpl_data = RGBToPyCmap(turbo_colormap_data) 84 | plt.register_cmap(name='turbo', data=mpl_data, lut=turbo_colormap_data.shape[0]) 85 | 86 | def readPFM(file): 87 | with open(file, "rb") as f: 88 | # Line 1: PF=>RGB (3 channels), Pf=>Greyscale (1 channel) 89 | type = f.readline().decode('latin-1') 90 | if "PF" in type: 91 | channels = 3 92 | elif "Pf" in type: 93 | channels = 1 94 | else: 95 | sys.exit(1) 96 | # Line 2: width height 97 | line = f.readline().decode('latin-1') 98 | width, height = re.findall('\d+', line) 99 | width = int(width) 100 | height = int(height) 101 | 102 | # Line 3: +ve number means big endian, negative means little endian 103 | line = f.readline().decode('latin-1') 104 | BigEndian = True 105 | if "-" in line: 106 | BigEndian = False 107 | # Slurp all binary data 108 | samples = width * height * channels; 109 | buffer = f.read(samples * 4) 110 | # Unpack floats with appropriate endianness 111 | if BigEndian: 112 | fmt = ">" 113 | else: 114 | fmt = "<" 115 | fmt = fmt + str(samples) + "f" 116 | img = unpack(fmt, buffer) 117 | img = np.reshape(img, (height, width)) 118 | img = np.flipud(img) 119 | 120 | return img, height, width 121 | 122 | def save_pfm(filename, image, scale=1): 123 | ''' 124 | Save a Numpy array to a PFM file. 
125 | ''' 126 | color = None 127 | file = open(filename, "w") 128 | if image.dtype.name != 'float32': 129 | raise Exception('Image dtype must be float32.') 130 | 131 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 132 | color = True 133 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 134 | color = False 135 | else: 136 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 137 | 138 | file.write('PF\n' if color else 'Pf\n') 139 | file.write('%d %d\n' % (image.shape[1], image.shape[0])) 140 | 141 | endian = image.dtype.byteorder 142 | 143 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 144 | scale = -scale 145 | 146 | file.write('%f\n' % scale) 147 | 148 | image.tofile(file) 149 | 150 | def test_transform(temp_data, crop_height, crop_width): 151 | _, h, w=np.shape(temp_data) 152 | 153 | if h <= crop_height and w <= crop_width: 154 | # padding zero 155 | temp = temp_data 156 | temp_data = np.zeros([6, crop_height, crop_width], 'float32') 157 | temp_data[:, crop_height - h: crop_height, crop_width - w: crop_width] = temp 158 | else: 159 | start_x = int((w - crop_width) / 2) 160 | start_y = int((h - crop_height) / 2) 161 | temp_data = temp_data[:, start_y: start_y + crop_height, start_x: start_x + crop_width] 162 | left = np.ones([1, 3,crop_height,crop_width],'float32') 163 | left[0, :, :, :] = temp_data[0: 3, :, :] 164 | right = np.ones([1, 3, crop_height, crop_width], 'float32') 165 | right[0, :, :, :] = temp_data[3: 6, :, :] 166 | return torch.from_numpy(left).float(), torch.from_numpy(right).float(), h, w 167 | 168 | def load_data(leftname, rightname): 169 | left = Image.open(leftname) 170 | right = Image.open(rightname) 171 | size = np.shape(left) 172 | height = size[0] 173 | width = size[1] 174 | temp_data = np.zeros([6, height, width], 'float32') 175 | left = np.asarray(left) 176 | right = np.asarray(right) 177 | r = left[:, :, 0] 178 | g = left[:, :, 1] 179 | b = left[:, :, 2] 180 | temp_data[0, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 181 | temp_data[1, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 182 | temp_data[2, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 183 | r = right[:, :, 0] 184 | g = right[:, :, 1] 185 | b = right[:, :, 2] 186 | #r,g,b,_ = right.split() 187 | temp_data[3, :, :] = (r - np.mean(r[:])) / np.std(r[:]) 188 | temp_data[4, :, :] = (g - np.mean(g[:])) / np.std(g[:]) 189 | temp_data[5, :, :] = (b - np.mean(b[:])) / np.std(b[:]) 190 | return temp_data 191 | 192 | def test_md(leftname, rightname, savename, imgname): 193 | 194 | input1, input2, height, width = test_transform(load_data(leftname, rightname), opt.crop_height, opt.crop_width) 195 | 196 | input1 = Variable(input1, requires_grad = False) 197 | input2 = Variable(input2, requires_grad = False) 198 | 199 | model.eval() 200 | if cuda: 201 | input1 = input1.cuda() 202 | input2 = input2.cuda() 203 | torch.cuda.synchronize() 204 | start_time = time() 205 | with torch.no_grad(): 206 | prediction = model(input1, input2) 207 | torch.cuda.synchronize() 208 | end_time = time() 209 | 210 | print("Processing time: {:.4f}".format(end_time - start_time)) 211 | temp = prediction.cpu() 212 | temp = temp.detach().numpy() 213 | if height <= opt.crop_height or width <= opt.crop_width: 214 | temp = temp[0, opt.crop_height - height: opt.crop_height, opt.crop_width - width: opt.crop_width] 215 | else: 216 | temp = temp[0, :, :] 217 | plot_disparity(imgname, temp, 192) 218 | savepfm_path = savename.replace('.png','') 219 | temp = 
np.flipud(temp) 220 | 221 | disppath = Path(savepfm_path) 222 | disppath.makedirs_p() 223 | save_pfm(savepfm_path+'/disp0LEAStereo.pfm', temp, scale=1) 224 | ##########write time txt######## 225 | fp = open(savepfm_path+'/timeLEAStereo.txt', 'w') 226 | runtime = "XXs" 227 | fp.write(runtime) 228 | fp.close() 229 | 230 | def test_kitti(leftname, rightname, savename): 231 | input1, input2, height, width = test_transform(load_data(leftname, rightname), opt.crop_height, opt.crop_width) 232 | 233 | input1 = Variable(input1, requires_grad = False) 234 | input2 = Variable(input2, requires_grad = False) 235 | 236 | model.eval() 237 | if cuda: 238 | input1 = input1.cuda() 239 | input2 = input2.cuda() 240 | with torch.no_grad(): 241 | prediction = model(input1, input2) 242 | 243 | temp = prediction.cpu() 244 | temp = temp.detach().numpy() 245 | if height <= opt.crop_height and width <= opt.crop_width: 246 | temp = temp[0, opt.crop_height - height: opt.crop_height, opt.crop_width - width: opt.crop_width] 247 | else: 248 | temp = temp[0, :, :] 249 | skimage.io.imsave(savename, (temp * 256).astype('uint16')) 250 | 251 | 252 | def test(leftname, rightname, savename): 253 | input1, input2, height, width = test_transform(load_data(leftname, rightname), opt.crop_height, opt.crop_width) 254 | 255 | input1 = Variable(input1, requires_grad = False) 256 | input2 = Variable(input2, requires_grad = False) 257 | 258 | model.eval() 259 | if cuda: 260 | input1 = input1.cuda() 261 | input2 = input2.cuda() 262 | 263 | start_time = time() 264 | with torch.no_grad(): 265 | prediction = model(input1, input2) 266 | end_time = time() 267 | 268 | print("Processing time: {:.4f}".format(end_time - start_time)) 269 | temp = prediction.cpu() 270 | temp = temp.detach().numpy() 271 | if height <= opt.crop_height or width <= opt.crop_width: 272 | temp = temp[0, opt.crop_height - height: opt.crop_height, opt.crop_width - width: opt.crop_width] 273 | else: 274 | temp = temp[0, :, :] 275 | plot_disparity(savename, temp, 192) 276 | savename_pfm = savename.replace('png','pfm') 277 | temp = np.flipud(temp) 278 | 279 | def plot_disparity(savename, data, max_disp): 280 | plt.imsave(savename, data, vmin=0, vmax=max_disp, cmap='turbo') 281 | 282 | 283 | if __name__ == "__main__": 284 | file_path = opt.data_path 285 | file_list = opt.test_list 286 | f = open(file_list, 'r') 287 | filelist = f.readlines() 288 | for index in range(len(filelist)): 289 | current_file = filelist[index] 290 | if opt.kitti2015: 291 | leftname = file_path + 'image_2/' + current_file[0: len(current_file) - 1] 292 | rightname = file_path + 'image_3/' + current_file[0: len(current_file) - 1] 293 | savename = opt.save_path + current_file[0: len(current_file) - 1] 294 | test_kitti(leftname, rightname, savename) 295 | 296 | if opt.kitti2012: 297 | leftname = file_path + 'colored_0/' + current_file[0: len(current_file) - 1] 298 | rightname = file_path + 'colored_1/' + current_file[0: len(current_file) - 1] 299 | savename = opt.save_path + current_file[0: len(current_file) - 1] 300 | test_kitti(leftname, rightname, savename) 301 | 302 | if opt.sceneflow: 303 | leftname = file_path + 'frames_finalpass/' + current_file[0: len(current_file) - 1] 304 | rightname = file_path + 'frames_finalpass/' + current_file[0: len(current_file) - 14] + 'right/' + current_file[len(current_file) - 9:len(current_file) - 1] 305 | leftgtname = file_path + 'disparity/' + current_file[0: len(current_file) - 4] + 'pfm' 306 | disp_left_gt, height, width = readPFM(leftgtname) 307 | savenamegt = 
opt.save_path + "{:d}_gt.png".format(index) 308 | plot_disparity(savenamegt, disp_left_gt, 192) 309 | 310 | savename = opt.save_path + "{:d}.png".format(index) 311 | test(leftname, rightname, savename) 312 | 313 | if opt.middlebury: 314 | leftname = file_path + current_file[0: len(current_file) - 1] 315 | rightname = leftname.replace('im0','im1') 316 | 317 | temppath = opt.save_path.replace(opt.save_path.split("/")[-2], opt.save_path.split("/")[-2]+"/images") 318 | img_path = Path(temppath) 319 | img_path.makedirs_p() 320 | savename = opt.save_path + current_file[0: len(current_file) - 9] + ".png" 321 | img_name = img_path + current_file[0: len(current_file) - 9] + ".png" 322 | test_md(leftname, rightname, savename, img_name) 323 | 324 | -------------------------------------------------------------------------------- /predict_kitti12.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python predict.py \ 2 | --kitti2012=1 --maxdisp=192 \ 3 | --crop_height=384 --crop_width=1248 \ 4 | --data_path='./dataset/kitti2012/testing/' \ 5 | --test_list='./dataloaders/lists/kitti2012_test.list' \ 6 | --save_path='./predict/kitti2012/images/' \ 7 | --fea_num_layer 6 --mat_num_layers 12\ 8 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 9 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 10 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 11 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 12 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 13 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 14 | --resume './run/Kitti12/best/best_1.16.pth' 15 | 16 | -------------------------------------------------------------------------------- /predict_kitti15.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python predict.py \ 2 | --kitti2015=1 --maxdisp=192 \ 3 | --crop_height=384 --crop_width=1248 \ 4 | --data_path='./dataset/kitti2015/testing/' \ 5 | --test_list='./dataloaders/lists/kitti2015_test.list' \ 6 | --save_path='./predict/kitti2015/images/' \ 7 | --fea_num_layer 6 --mat_num_layers 12\ 8 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 9 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 10 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 11 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 12 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 13 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 14 | --resume './run/Kitti15/best/best.pth' 15 | 16 | -------------------------------------------------------------------------------- /predict_md.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python predict.py \ 2 | --middlebury=1 --maxdisp=408 \ 3 | --crop_height=1008 --crop_width=1512 \ 4 | --data_path='./dataset/MiddEval3/testH/' \ 5 | --test_list='./dataloaders/lists/middeval3_test.list' \ 6 | --save_path='./predict/middlebury/images/' \ 7 | --fea_num_layer 6 --mat_num_layers 12\ 8 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 9 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 10 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 11 | 
--cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 12 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 13 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 14 | --resume './run/MiddEval3/best/best.pth' 15 | 16 | -------------------------------------------------------------------------------- /predict_sf.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python predict.py \ 2 | --sceneflow=1 --maxdisp=192 \ 3 | --crop_height=576 --crop_width=960 \ 4 | --data_path='./dataset/SceneFlow/' \ 5 | --test_list='./dataloaders/lists/sceneflow_test.list' \ 6 | --save_path='./predict/sceneflow/images/' \ 7 | --fea_num_layer 6 --mat_num_layers 12\ 8 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 9 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 10 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 11 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 12 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 13 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 14 | --resume='./run/sceneflow/best/checkpoint/best.pth' 15 | 16 | -------------------------------------------------------------------------------- /retrain/LEAStereo.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from models.build_model_2d import Disp 7 | from models.decoding_formulas import network_layer_to_space 8 | from retrain.new_model_2d import newFeature 9 | from retrain.skip_model_3d import newMatching 10 | 11 | class LEAStereo(nn.Module): 12 | def __init__(self, args): 13 | super(LEAStereo, self).__init__() 14 | 15 | network_path_fea, cell_arch_fea = np.load(args.net_arch_fea), np.load(args.cell_arch_fea) 16 | network_path_mat, cell_arch_mat = np.load(args.net_arch_mat), np.load(args.cell_arch_mat) 17 | print('Feature network path:{}\nMatching network path:{} \n'.format(network_path_fea, network_path_mat)) 18 | 19 | network_arch_fea = network_layer_to_space(network_path_fea) 20 | network_arch_mat = network_layer_to_space(network_path_mat) 21 | 22 | self.maxdisp = args.maxdisp 23 | self.feature = newFeature(network_arch_fea, cell_arch_fea, args=args) 24 | self.matching= newMatching(network_arch_mat, cell_arch_mat, args=args) 25 | self.disp = Disp(self.maxdisp) 26 | 27 | def forward(self, x, y): 28 | x = self.feature(x) 29 | y = self.feature(y) 30 | 31 | with torch.cuda.device_of(x): 32 | cost = x.new().resize_(x.size()[0], x.size()[1]*2, int(self.maxdisp/3), x.size()[2], x.size()[3]).zero_() 33 | for i in range(int(self.maxdisp/3)): 34 | if i > 0 : 35 | cost[:,:x.size()[1], i,:,i:] = x[:,:,:,i:] 36 | cost[:,x.size()[1]:, i,:,i:] = y[:,:,:,:-i] 37 | else: 38 | cost[:,:x.size()[1],i,:,i:] = x 39 | cost[:,x.size()[1]:,i,:,i:] = y 40 | 41 | cost = self.matching(cost) 42 | disp = self.disp(cost) 43 | return disp 44 | 45 | def get_params(self): 46 | back_bn_params, back_no_bn_params = self.encoder.get_params() 47 | tune_wd_params = list(self.aspp.parameters()) \ 48 | + list(self.decoder.parameters()) \ 49 | + back_no_bn_params 50 | no_tune_wd_params = back_bn_params 51 | return tune_wd_params, no_tune_wd_params 52 | -------------------------------------------------------------------------------- /retrain/new_model_2d.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.genotypes_2d import PRIMITIVES 6 | from models.genotypes_2d import Genotype 7 | from models.operations_2d import * 8 | import torch.nn.functional as F 9 | import numpy as np 10 | import pdb 11 | 12 | class Cell(nn.Module): 13 | def __init__(self, steps, block_multiplier, prev_prev_fmultiplier, 14 | prev_filter_multiplier, cell_arch, network_arch, 15 | filter_multiplier, downup_sample, args=None): 16 | super(Cell, self).__init__() 17 | self.cell_arch = cell_arch 18 | 19 | self.C_in = block_multiplier * filter_multiplier 20 | self.C_out = filter_multiplier 21 | self.C_prev = int(block_multiplier * prev_filter_multiplier) 22 | self.C_prev_prev = int(block_multiplier * prev_prev_fmultiplier) 23 | self.downup_sample = downup_sample 24 | self.pre_preprocess = ConvBR(self.C_prev_prev, self.C_out, 1, 1, 0) 25 | self.preprocess = ConvBR(self.C_prev, self.C_out, 1, 1, 0) 26 | self._steps = steps 27 | self.block_multiplier = block_multiplier 28 | self._ops = nn.ModuleList() 29 | if downup_sample == -1: 30 | self.scale = 0.5 31 | elif downup_sample == 1: 32 | self.scale = 2 33 | for x in self.cell_arch: 34 | primitive = PRIMITIVES[x[1]] 35 | op = OPS[primitive](self.C_out, stride=1) 36 | self._ops.append(op) 37 | 38 | def scale_dimension(self, dim, scale): 39 | return (int((float(dim) - 1.0) * scale + 1.0) if dim % 2 == 1 else int((float(dim) * scale))) 40 | 41 | def forward(self, prev_prev_input, prev_input): 42 | s0 = prev_prev_input 43 | s1 = prev_input 44 | if self.downup_sample != 0: 45 | feature_size_h = self.scale_dimension(s1.shape[2], self.scale) 46 | feature_size_w = self.scale_dimension(s1.shape[3], self.scale) 47 | s1 = F.interpolate(s1, [feature_size_h, feature_size_w], mode='bilinear', align_corners=True) 48 | if (s0.shape[2] != s1.shape[2]) or (s0.shape[3] != s1.shape[3]): 49 | s0 = F.interpolate(s0, (s1.shape[2], s1.shape[3]), 50 | mode='bilinear', align_corners=True) 51 | 52 | s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0 53 | s1 = self.preprocess(s1) 54 | 55 | states = [s0, s1] 56 | offset = 0 57 | ops_index = 0 58 | for i in range(self._steps): 59 | new_states = [] 60 | for j, h in enumerate(states): 61 | branch_index = offset + j 62 | if branch_index in self.cell_arch[:, 0]: 63 | if prev_prev_input is None and j == 0: 64 | ops_index += 1 65 | continue 66 | new_state = self._ops[ops_index](h) 67 | new_states.append(new_state) 68 | ops_index += 1 69 | 70 | s = sum(new_states) 71 | offset += len(states) 72 | states.append(s) 73 | 74 | concat_feature = torch.cat(states[-self.block_multiplier:], dim=1) 75 | return prev_input, concat_feature 76 | 77 | 78 | class newFeature(nn.Module): 79 | def __init__(self, network_arch, cell_arch, cell=Cell, args=None): 80 | super(newFeature, self).__init__() 81 | self.args = args 82 | self.cells = nn.ModuleList() 83 | self.network_arch = torch.from_numpy(network_arch) 84 | self.cell_arch = torch.from_numpy(cell_arch) 85 | self._step = args.fea_step 86 | self._num_layers = args.fea_num_layers 87 | self._block_multiplier = args.fea_block_multiplier 88 | self._filter_multiplier = args.fea_filter_multiplier 89 | 90 | initial_fm = self._filter_multiplier * self._block_multiplier 91 | half_initial_fm = initial_fm // 2 92 | 93 | self.stem0 = ConvBR(3, half_initial_fm, 3, stride=1, padding=1) 94 | self.stem1 = ConvBR(half_initial_fm, 
initial_fm, 3, stride=3, padding=1) 95 | self.stem2 = ConvBR(initial_fm, initial_fm, 3, stride=1, padding=1) 96 | 97 | filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8} 98 | 99 | for i in range(self._num_layers): 100 | level_option = torch.sum(self.network_arch[i], dim=1) 101 | prev_level_option = torch.sum(self.network_arch[i - 1], dim=1) 102 | prev_prev_level_option = torch.sum(self.network_arch[i - 2], dim=1) 103 | level = torch.argmax(level_option).item() 104 | prev_level = torch.argmax(prev_level_option).item() 105 | prev_prev_level = torch.argmax(prev_prev_level_option).item() 106 | 107 | if i == 0: 108 | downup_sample = - torch.argmax(torch.sum(self.network_arch[0], dim=1)) 109 | _cell = cell(self._step, self._block_multiplier, initial_fm / self._block_multiplier, 110 | initial_fm / self._block_multiplier, 111 | self.cell_arch, self.network_arch[i], 112 | int(self._filter_multiplier * filter_param_dict[level]), 113 | downup_sample, self.args) 114 | 115 | else: 116 | three_branch_options = torch.sum(self.network_arch[i], dim=0) 117 | downup_sample = torch.argmax(three_branch_options).item() - 1 118 | if i == 1: 119 | _cell = cell(self._step, self._block_multiplier, 120 | initial_fm / self._block_multiplier, 121 | int(self._filter_multiplier * filter_param_dict[prev_level]), 122 | self.cell_arch, self.network_arch[i], 123 | int(self._filter_multiplier * filter_param_dict[level]), 124 | downup_sample, self.args) 125 | 126 | else: 127 | _cell = cell(self._step, self._block_multiplier, 128 | int(self._filter_multiplier * filter_param_dict[prev_prev_level]), 129 | int(self._filter_multiplier * filter_param_dict[prev_level]), 130 | self.cell_arch, self.network_arch[i], 131 | int(self._filter_multiplier * filter_param_dict[level]), downup_sample, self.args) 132 | 133 | self.cells += [_cell] 134 | 135 | self.last_3 = ConvBR(initial_fm , initial_fm, 1, 1, 0, bn=False, relu=False) 136 | self.last_6 = ConvBR(initial_fm*2 , initial_fm, 1, 1, 0) 137 | self.last_12 = ConvBR(initial_fm*4 , initial_fm*2, 1, 1, 0) 138 | self.last_24 = ConvBR(initial_fm*8 , initial_fm*4, 1, 1, 0) 139 | 140 | def forward(self, x): 141 | stem0 = self.stem0(x) 142 | stem1 = self.stem1(stem0) 143 | stem2 = self.stem2(stem1) 144 | out = (stem1, stem2) 145 | 146 | for i in range(self._num_layers): 147 | out = self.cells[i](out[0], out[1]) 148 | 149 | last_output = out[-1] 150 | 151 | h, w = stem2.size()[2], stem2.size()[3] 152 | upsample_6 = nn.Upsample(size=stem2.size()[2:], mode='bilinear', align_corners=True) 153 | upsample_12 = nn.Upsample(size=[h//2, w//2], mode='bilinear', align_corners=True) 154 | upsample_24 = nn.Upsample(size=[h//4, w//4], mode='bilinear', align_corners=True) 155 | 156 | if last_output.size()[2] == h: 157 | fea = self.last_3(last_output) 158 | elif last_output.size()[2] == h//2: 159 | fea = self.last_3(upsample_6(self.last_6(last_output))) 160 | elif last_output.size()[2] == h//4: 161 | fea = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(last_output))))) 162 | elif last_output.size()[2] == h//8: 163 | fea = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(upsample_24(self.last_24(last_output))))))) 164 | 165 | return fea 166 | 167 | def get_params(self): 168 | bn_params = [] 169 | non_bn_params = [] 170 | for name, param in self.named_parameters(): 171 | if 'bn' in name or 'downsample.1' in name: 172 | bn_params.append(param) 173 | else: 174 | bn_params.append(param) 175 | return bn_params, non_bn_params 176 | 177 | 
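# NOTE (reviewer sketch, not part of the original new_model_2d.py):
# In newFeature.get_params() above, both branches of the if/else append to
# bn_params, so non_bn_params is always returned empty. A minimal corrected
# sketch, assuming the intent is to split BatchNorm parameters from the rest
# (this helper does not appear to be called by train.py, which optimizes
# model.parameters() directly):
#
#     def get_params(self):
#         bn_params, non_bn_params = [], []
#         for name, param in self.named_parameters():
#             if 'bn' in name or 'downsample.1' in name:
#                 bn_params.append(param)
#             else:
#                 non_bn_params.append(param)
#         return bn_params, non_bn_params
#
# A related caveat: LEAStereo.get_params() in retrain/LEAStereo.py references
# self.encoder, self.aspp and self.decoder, attributes that are never defined
# on that module, so it would raise AttributeError if it were ever invoked.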
-------------------------------------------------------------------------------- /retrain/new_model_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from models.genotypes_3d import PRIMITIVES 5 | from models.genotypes_3d import Genotype 6 | from models.operations_3d import * 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import pdb 10 | 11 | class Cell(nn.Module): 12 | def __init__(self, steps, block_multiplier, prev_prev_fmultiplier, 13 | prev_filter_multiplier, cell_arch, network_arch, 14 | filter_multiplier, downup_sample, args=None): 15 | super(Cell, self).__init__() 16 | self.cell_arch = cell_arch 17 | 18 | self.C_in = block_multiplier * filter_multiplier 19 | self.C_out = filter_multiplier 20 | self.C_prev = int(block_multiplier * prev_filter_multiplier) 21 | self.C_prev_prev = int(block_multiplier * prev_prev_fmultiplier) 22 | self.downup_sample = downup_sample 23 | self.pre_preprocess = ConvBR(self.C_prev_prev, self.C_out, 1, 1, 0) 24 | self.preprocess = ConvBR(self.C_prev, self.C_out, 1, 1, 0) 25 | self._steps = steps 26 | self.block_multiplier = block_multiplier 27 | self._ops = nn.ModuleList() 28 | if downup_sample == -1: 29 | self.scale = 0.5 30 | elif downup_sample == 1: 31 | self.scale = 2 32 | for x in self.cell_arch: 33 | primitive = PRIMITIVES[x[1]] 34 | op = OPS[primitive](self.C_out, stride=1) 35 | self._ops.append(op) 36 | 37 | def scale_dimension(self, dim, scale): 38 | return (int((float(dim) - 1.0) * scale + 1.0) if dim % 2 == 1 else int((float(dim) * scale))) 39 | 40 | def forward(self, prev_prev_input, prev_input): 41 | s0 = prev_prev_input 42 | s1 = prev_input 43 | if self.downup_sample != 0: 44 | feature_size_d = self.scale_dimension(s1.shape[2], self.scale) 45 | feature_size_h = self.scale_dimension(s1.shape[3], self.scale) 46 | feature_size_w = self.scale_dimension(s1.shape[4], self.scale) 47 | s1 = F.interpolate(s1, [feature_size_d, feature_size_h, feature_size_w], mode='trilinear', align_corners=True) 48 | if (s0.shape[2] != s1.shape[2]) or (s0.shape[3] != s1.shape[3]) or (s0.shape[4] != s1.shape[4]): 49 | s0 = F.interpolate(s0, (s1.shape[2], s1.shape[3], s1.shape[4]), 50 | mode='trilinear', align_corners=True) 51 | s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0 52 | s1 = self.preprocess(s1) 53 | 54 | states = [s0, s1] 55 | offset = 0 56 | ops_index = 0 57 | for i in range(self._steps): 58 | new_states = [] 59 | for j, h in enumerate(states): 60 | branch_index = offset + j 61 | if branch_index in self.cell_arch[:, 0]: 62 | if prev_prev_input is None and j == 0: 63 | ops_index += 1 64 | continue 65 | new_state = self._ops[ops_index](h) 66 | new_states.append(new_state) 67 | ops_index += 1 68 | 69 | s = sum(new_states) 70 | offset += len(states) 71 | states.append(s) 72 | 73 | concat_feature = torch.cat(states[-self.block_multiplier:], dim=1) 74 | return prev_input, concat_feature 75 | 76 | class newMatching(nn.Module): 77 | def __init__(self, network_arch, cell_arch, cell=Cell, args=None): 78 | super(newMatching, self).__init__() 79 | self.args = args 80 | self.cells = nn.ModuleList() 81 | self.network_arch = torch.from_numpy(network_arch) 82 | self.cell_arch = torch.from_numpy(cell_arch) 83 | self._step = args.mat_step 84 | self._num_layers = args.mat_num_layers 85 | self._block_multiplier = args.mat_block_multiplier 86 | self._filter_multiplier = args.mat_filter_multiplier 87 | 88 | initial_fm = 
self._filter_multiplier * self._block_multiplier 89 | half_initial_fm = initial_fm // 2 90 | 91 | self.stem0 = ConvBR(initial_fm*2, initial_fm, 3, stride=1, padding=1) 92 | self.stem1 = ConvBR(initial_fm, initial_fm, 3, stride=1, padding=1) 93 | 94 | filter_param_dict = {0: 1, 1: 2, 2: 4, 3: 8} 95 | for i in range(self._num_layers): 96 | level_option = torch.sum(self.network_arch[i], dim=1) 97 | prev_level_option = torch.sum(self.network_arch[i - 1], dim=1) 98 | prev_prev_level_option = torch.sum(self.network_arch[i - 2], dim=1) 99 | level = torch.argmax(level_option).item() 100 | prev_level = torch.argmax(prev_level_option).item() 101 | prev_prev_level = torch.argmax(prev_prev_level_option).item() 102 | if i == 0: 103 | downup_sample = - torch.argmax(torch.sum(self.network_arch[0], dim=1)) 104 | _cell = cell(self._step, self._block_multiplier, initial_fm / self._block_multiplier, 105 | initial_fm / self._block_multiplier, 106 | self.cell_arch, self.network_arch[i], 107 | self._filter_multiplier * filter_param_dict[level], 108 | downup_sample, self.args) 109 | 110 | else: 111 | three_branch_options = torch.sum(self.network_arch[i], dim=0) 112 | downup_sample = torch.argmax(three_branch_options).item() - 1 113 | if i == 1: 114 | _cell = cell(self._step, self._block_multiplier, 115 | initial_fm / self._block_multiplier, 116 | self._filter_multiplier * filter_param_dict[prev_level], 117 | self.cell_arch, self.network_arch[i], 118 | self._filter_multiplier * filter_param_dict[level], 119 | downup_sample, self.args) 120 | else: 121 | _cell = cell(self._step, self._block_multiplier, 122 | self._filter_multiplier * filter_param_dict[prev_prev_level], 123 | self._filter_multiplier * 124 | filter_param_dict[prev_level], 125 | self.cell_arch, self.network_arch[i], 126 | self._filter_multiplier * filter_param_dict[level], downup_sample, self.args) 127 | 128 | self.cells += [_cell] 129 | 130 | self.last_3 = ConvBR(initial_fm, 1, 3, 1, 1, bn=False, relu=False) 131 | self.last_6 = ConvBR(initial_fm*2 , initial_fm, 1, 1, 0) 132 | self.last_12 = ConvBR(initial_fm*4 , initial_fm*2, 1, 1, 0) 133 | self.last_24 = ConvBR(initial_fm*8 , initial_fm*4, 1, 1, 0) 134 | 135 | 136 | def forward(self, x): 137 | stem0 = self.stem0(x) 138 | stem1 = self.stem1(stem0) 139 | out = (stem0, stem1) 140 | 141 | for i in range(self._num_layers): 142 | out = self.cells[i](out[0], out[1]) 143 | last_output = out[-1] 144 | 145 | #define upsampling 146 | d, h, w = x.size()[2], x.size()[3], x.size()[4] 147 | upsample_6 = nn.Upsample(size=x.size()[2:], mode='trilinear', align_corners=True) 148 | upsample_12 = nn.Upsample(size=[d//2, h//2, w//2], mode='trilinear', align_corners=True) 149 | upsample_24 = nn.Upsample(size=[d//4, h//4, w//4], mode='trilinear', align_corners=True) 150 | 151 | if last_output.size()[3] == h: 152 | mat = self.last_3(last_output) 153 | elif last_output.size()[3] == h//2: 154 | mat = self.last_3(upsample_6(self.last_6(last_output))) 155 | elif last_output.size()[3] == h//4: 156 | mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(last_output))))) 157 | elif last_output.size()[3] == h//8: 158 | mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(upsample_24(self.last_24(last_output))))))) 159 | return mat 160 | 161 | -------------------------------------------------------------------------------- /retrain/skip_model_3d.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from 
models.genotypes_3d import PRIMITIVES 5 | from models.genotypes_3d import Genotype 6 | from models.operations_3d import * 7 | import torch.nn.functional as F 8 | import numpy as np 9 | import pdb 10 | 11 | class Cell(nn.Module): 12 | def __init__(self, steps, block_multiplier, prev_prev_fmultiplier, 13 | prev_filter_multiplier, cell_arch, network_arch, 14 | filter_multiplier, downup_sample, args=None): 15 | super(Cell, self).__init__() 16 | self.cell_arch = cell_arch 17 | 18 | self.C_in = block_multiplier * filter_multiplier 19 | self.C_out = filter_multiplier 20 | self.C_prev = int(block_multiplier * prev_filter_multiplier) 21 | self.C_prev_prev = int(block_multiplier * prev_prev_fmultiplier) 22 | self.downup_sample = downup_sample 23 | self.pre_preprocess = ConvBR(self.C_prev_prev, self.C_out, 1, 1, 0) 24 | self.preprocess = ConvBR(self.C_prev, self.C_out, 1, 1, 0) 25 | self._steps = steps 26 | self.block_multiplier = block_multiplier 27 | self._ops = nn.ModuleList() 28 | if downup_sample == -1: 29 | self.scale = 0.5 30 | elif downup_sample == 1: 31 | self.scale = 2 32 | for x in self.cell_arch: 33 | primitive = PRIMITIVES[x[1]] 34 | op = OPS[primitive](self.C_out, stride=1) 35 | self._ops.append(op) 36 | 37 | def scale_dimension(self, dim, scale): 38 | return (int((float(dim) - 1.0) * scale + 1.0) if dim % 2 == 1 else int((float(dim) * scale))) 39 | 40 | def forward(self, prev_prev_input, prev_input): 41 | s0 = prev_prev_input 42 | s1 = prev_input 43 | if self.downup_sample != 0: 44 | feature_size_d = self.scale_dimension(s1.shape[2], self.scale) 45 | feature_size_h = self.scale_dimension(s1.shape[3], self.scale) 46 | feature_size_w = self.scale_dimension(s1.shape[4], self.scale) 47 | s1 = F.interpolate(s1, [feature_size_d, feature_size_h, feature_size_w], mode='trilinear', align_corners=True) 48 | if (s0.shape[2] != s1.shape[2]) or (s0.shape[3] != s1.shape[3]) or (s0.shape[4] != s1.shape[4]): 49 | s0 = F.interpolate(s0, (s1.shape[2], s1.shape[3], s1.shape[4]), 50 | mode='trilinear', align_corners=True) 51 | s0 = self.pre_preprocess(s0) if (s0.shape[1] != self.C_out) else s0 52 | s1 = self.preprocess(s1) 53 | 54 | states = [s0, s1] 55 | offset = 0 56 | ops_index = 0 57 | for i in range(self._steps): 58 | new_states = [] 59 | for j, h in enumerate(states): 60 | branch_index = offset + j 61 | if branch_index in self.cell_arch[:, 0]: 62 | if prev_prev_input is None and j == 0: 63 | ops_index += 1 64 | continue 65 | new_state = self._ops[ops_index](h) 66 | new_states.append(new_state) 67 | ops_index += 1 68 | 69 | s = sum(new_states) 70 | offset += len(states) 71 | states.append(s) 72 | 73 | concat_feature = torch.cat(states[-self.block_multiplier:], dim=1) 74 | return prev_input, concat_feature 75 | 76 | class newMatching(nn.Module): 77 | def __init__(self, network_arch, cell_arch, cell=Cell, args=None): 78 | super(newMatching, self).__init__() 79 | self.args = args 80 | self.cells = nn.ModuleList() 81 | self.network_arch = torch.from_numpy(network_arch) 82 | self.cell_arch = torch.from_numpy(cell_arch) 83 | self._step = args.mat_step 84 | self._num_layers = args.mat_num_layers 85 | self._block_multiplier = args.mat_block_multiplier 86 | self._filter_multiplier = args.mat_filter_multiplier 87 | 88 | initial_fm = self._filter_multiplier * self._block_multiplier 89 | half_initial_fm = initial_fm // 2 90 | 91 | self.stem0 = ConvBR(initial_fm*2, initial_fm, 3, stride=1, padding=1) 92 | self.stem1 = ConvBR(initial_fm, initial_fm, 3, stride=1, padding=1) 93 | 94 | filter_param_dict = {0: 1, 1: 
2, 2: 4, 3: 8} 95 | for i in range(self._num_layers): 96 | level_option = torch.sum(self.network_arch[i], dim=1) 97 | prev_level_option = torch.sum(self.network_arch[i - 1], dim=1) 98 | prev_prev_level_option = torch.sum(self.network_arch[i - 2], dim=1) 99 | level = torch.argmax(level_option).item() 100 | prev_level = torch.argmax(prev_level_option).item() 101 | prev_prev_level = torch.argmax(prev_prev_level_option).item() 102 | if i == 0: 103 | downup_sample = - torch.argmax(torch.sum(self.network_arch[0], dim=1)) 104 | _cell = cell(self._step, self._block_multiplier, initial_fm / self._block_multiplier, 105 | initial_fm / self._block_multiplier, 106 | self.cell_arch, self.network_arch[i], 107 | self._filter_multiplier * filter_param_dict[level], 108 | downup_sample, self.args) 109 | 110 | else: 111 | three_branch_options = torch.sum(self.network_arch[i], dim=0) 112 | downup_sample = torch.argmax(three_branch_options).item() - 1 113 | if i == 1: 114 | _cell = cell(self._step, self._block_multiplier, 115 | initial_fm / self._block_multiplier, 116 | self._filter_multiplier * filter_param_dict[prev_level], 117 | self.cell_arch, self.network_arch[i], 118 | self._filter_multiplier * filter_param_dict[level], 119 | downup_sample, self.args) 120 | else: 121 | _cell = cell(self._step, self._block_multiplier, 122 | self._filter_multiplier * filter_param_dict[prev_prev_level], 123 | self._filter_multiplier * 124 | filter_param_dict[prev_level], 125 | self.cell_arch, self.network_arch[i], 126 | self._filter_multiplier * filter_param_dict[level], downup_sample, self.args) 127 | 128 | self.cells += [_cell] 129 | 130 | self.last_3 = ConvBR(initial_fm, 1, 3, 1, 1, bn=False, relu=False) 131 | self.last_6 = ConvBR(initial_fm*2 , initial_fm, 1, 1, 0) 132 | self.last_12 = ConvBR(initial_fm*4 , initial_fm*2, 1, 1, 0) 133 | self.last_24 = ConvBR(initial_fm*8 , initial_fm*4, 1, 1, 0) 134 | 135 | self.conv1 = ConvBR(initial_fm*4, initial_fm*2, 3, 1, 1) 136 | self.conv2 = ConvBR(initial_fm*4, initial_fm*2, 3, 1, 1) 137 | 138 | def forward(self, x): 139 | stem0 = self.stem0(x) 140 | stem1 = self.stem1(stem0) 141 | out = (stem0, stem1) 142 | out0 = self.cells[0](out[0], out[1]) 143 | out1 = self.cells[1](out0[0], out0[1]) 144 | out2 = self.cells[2](out1[0], out1[1]) 145 | out3 = self.cells[3](out2[0], out2[1]) 146 | out4 = self.cells[4](out3[0], out3[1]) 147 | 148 | out4_cat = self.conv1(torch.cat((out1[-1], out4[-1]), 1)) 149 | out5 = self.cells[5](out4[0], out4_cat) 150 | out6 = self.cells[6](out5[0], out5[1]) 151 | out7 = self.cells[7](out6[0], out6[1]) 152 | out8 = self.cells[8](out7[0], out7[1]) 153 | out8_cat = self.conv2(torch.cat((out4[-1], out8[-1]), 1)) 154 | out9 = self.cells[9](out8[0], out8_cat) 155 | out10= self.cells[10](out9[0], out9[1]) 156 | out11= self.cells[11](out10[0],out10[1]) 157 | last_output = out11[-1] 158 | 159 | d, h, w = x.size()[2], x.size()[3], x.size()[4] 160 | upsample_6 = nn.Upsample(size=x.size()[2:], mode='trilinear', align_corners=True) 161 | upsample_12 = nn.Upsample(size=[d//2, h//2, w//2], mode='trilinear', align_corners=True) 162 | upsample_24 = nn.Upsample(size=[d//4, h//4, w//4], mode='trilinear', align_corners=True) 163 | 164 | if last_output.size()[3] == h: 165 | mat = self.last_3(last_output) 166 | elif last_output.size()[3] == h//2: 167 | mat = self.last_3(upsample_6(self.last_6(last_output))) 168 | elif last_output.size()[3] == h//4: 169 | mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(last_output))))) 170 | elif last_output.size()[3] == h//8: 171 
| mat = self.last_3(upsample_6(self.last_6(upsample_12(self.last_12(upsample_24(self.last_24(last_output))))))) 172 | return mat 173 | 174 | -------------------------------------------------------------------------------- /run/Kitti12/best/best_1.16.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/Kitti12/best/best_1.16.pth -------------------------------------------------------------------------------- /run/Kitti15/best/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/Kitti15/best/best.pth -------------------------------------------------------------------------------- /run/Kitti15/best/kitti15_best_1.65.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/Kitti15/best/kitti15_best_1.65.pth -------------------------------------------------------------------------------- /run/MiddEval3/best/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/MiddEval3/best/best.pth -------------------------------------------------------------------------------- /run/sceneflow/best/architecture/feature_genotype.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/best/architecture/feature_genotype.npy -------------------------------------------------------------------------------- /run/sceneflow/best/architecture/feature_network_path.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/best/architecture/feature_network_path.npy -------------------------------------------------------------------------------- /run/sceneflow/best/architecture/matching_genotype.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/best/architecture/matching_genotype.npy -------------------------------------------------------------------------------- /run/sceneflow/best/architecture/matching_network_path.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/best/architecture/matching_network_path.npy -------------------------------------------------------------------------------- /run/sceneflow/best/checkpoint/best.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/best/checkpoint/best.pth -------------------------------------------------------------------------------- /run/sceneflow/experiment/feature_genotype.npy: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/experiment/feature_genotype.npy -------------------------------------------------------------------------------- /run/sceneflow/experiment/feature_network_path.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/experiment/feature_network_path.npy -------------------------------------------------------------------------------- /run/sceneflow/experiment/matching_genotype.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/experiment/matching_genotype.npy -------------------------------------------------------------------------------- /run/sceneflow/experiment/matching_network_path.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/XuelianCheng/LEAStereo/b2921eaf2554901eab5e2c0e46b5763274a78400/run/sceneflow/experiment/matching_network_path.npy -------------------------------------------------------------------------------- /search.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import torch.nn as nn 5 | from tqdm import tqdm 6 | from collections import OrderedDict 7 | from mypath import Path 8 | from dataloaders import make_data_loader 9 | from utils.lr_scheduler import LR_Scheduler 10 | from utils.saver import Saver 11 | from utils.summaries import TensorboardSummary 12 | from utils.copy_state_dict import copy_state_dict 13 | from torch.autograd import Variable 14 | from time import time 15 | import imageio 16 | import apex 17 | import torch.nn.functional as F 18 | import pdb 19 | from config_utils.search_args import obtain_search_args 20 | from models.build_model import AutoStereo 21 | 22 | print('working with pytorch version {}'.format(torch.__version__)) 23 | print('with cuda version {}'.format(torch.version.cuda)) 24 | print('cudnn enabled: {}'.format(torch.backends.cudnn.enabled)) 25 | print('cudnn version: {}'.format(torch.backends.cudnn.version())) 26 | 27 | 28 | opt = obtain_search_args() 29 | print(opt) 30 | 31 | cuda = opt.cuda 32 | if cuda and not torch.cuda.is_available(): 33 | raise Exception("No GPU found, please run without --cuda") 34 | torch.manual_seed(opt.seed) 35 | if cuda: 36 | torch.cuda.manual_seed(opt.seed) 37 | 38 | # default settings for epochs, batch_size and lr 39 | if opt.epochs is None: 40 | epoches = {'sceneflow': 10} 41 | opt.epochs = epoches[opt.dataset.lower()] 42 | 43 | if opt.batch_size is None: 44 | opt.batch_size = 4 * len(opt.gpu_ids) 45 | 46 | if opt.testBatchSize is None: 47 | opt.testBatchSize = opt.batch_size 48 | 49 | 50 | class Trainer(object): 51 | def __init__(self, args): 52 | self.args = args 53 | 54 | # Define Saver 55 | self.saver = Saver(args) 56 | self.saver.save_experiment_config() 57 | # Define Tensorboard Summary 58 | self.summary = TensorboardSummary(self.saver.experiment_dir) 59 | self.writer = self.summary.create_summary() 60 | 61 | kwargs = {'num_workers': args.workers, 'pin_memory': True, 'drop_last':True} 62 | 63 | self.train_loaderA, self.train_loaderB, self.val_loader, self.test_loader = make_data_loader(args, **kwargs) 64 | 65 | # Define network 66 | model = 
AutoStereo(maxdisp = self.args.max_disp, 67 | Fea_Layers=self.args.fea_num_layers, Fea_Filter=self.args.fea_filter_multiplier, 68 | Fea_Block=self.args.fea_block_multiplier, Fea_Step=self.args.fea_step, 69 | Mat_Layers=self.args.mat_num_layers, Mat_Filter=self.args.mat_filter_multiplier, 70 | Mat_Block=self.args.mat_block_multiplier, Mat_Step=self.args.mat_step) 71 | 72 | optimizer_F = torch.optim.SGD( 73 | model.feature.weight_parameters(), 74 | args.lr, 75 | momentum=args.momentum, 76 | weight_decay=args.weight_decay 77 | ) 78 | optimizer_M = torch.optim.SGD( 79 | model.matching.weight_parameters(), 80 | args.lr, 81 | momentum=args.momentum, 82 | weight_decay=args.weight_decay 83 | ) 84 | 85 | 86 | self.model, self.optimizer_F, self.optimizer_M = model, optimizer_F, optimizer_M 87 | self.architect_optimizer_F = torch.optim.Adam(self.model.feature.arch_parameters(), 88 | lr=args.arch_lr, betas=(0.9, 0.999), 89 | weight_decay=args.arch_weight_decay) 90 | 91 | self.architect_optimizer_M = torch.optim.Adam(self.model.matching.arch_parameters(), 92 | lr=args.arch_lr, betas=(0.9, 0.999), 93 | weight_decay=args.arch_weight_decay) 94 | 95 | # Define lr scheduler 96 | self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, 97 | args.epochs, len(self.train_loaderA), min_lr=args.min_lr) 98 | # Using cuda 99 | if args.cuda: 100 | self.model = torch.nn.DataParallel(self.model).cuda() 101 | 102 | # Resuming checkpoint 103 | self.best_pred = 100.0 104 | if args.resume is not None: 105 | if not os.path.isfile(args.resume): 106 | raise RuntimeError("=> no checkpoint found at '{}'" .format(args.resume)) 107 | checkpoint = torch.load(args.resume) 108 | args.start_epoch = checkpoint['epoch'] 109 | 110 | # if the weights are wrapped in module object we have to clean it 111 | if args.clean_module: 112 | self.model.load_state_dict(checkpoint['state_dict']) 113 | state_dict = checkpoint['state_dict'] 114 | new_state_dict = OrderedDict() 115 | for k, v in state_dict.items(): 116 | if k.find('module') != -1: 117 | print(1) 118 | pdb.set_trace() 119 | name = k[7:] # remove 'module.' 
of dataparallel 120 | new_state_dict[name] = v 121 | # self.model.load_state_dict(new_state_dict) 122 | pdb.set_trace() 123 | copy_state_dict(self.model.state_dict(), new_state_dict) 124 | 125 | else: 126 | if torch.cuda.device_count() > 1:#or args.load_parallel: 127 | # self.model.module.load_state_dict(checkpoint['state_dict']) 128 | copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict']) 129 | else: 130 | # self.model.load_state_dict(checkpoint['state_dict']) 131 | copy_state_dict(self.model.module.state_dict(), checkpoint['state_dict']) 132 | 133 | 134 | if not args.ft: 135 | # self.optimizer.load_state_dict(checkpoint['optimizer']) 136 | copy_state_dict(self.optimizer_M.state_dict(), checkpoint['optimizer_M']) 137 | copy_state_dict(self.optimizer_F.state_dict(), checkpoint['optimizer_F']) 138 | self.best_pred = checkpoint['best_pred'] 139 | print("=> loaded checkpoint '{}' (epoch {})" 140 | .format(args.resume, checkpoint['epoch'])) 141 | 142 | # Clear start epoch if fine-tuning 143 | if args.ft: 144 | args.start_epoch = 0 145 | 146 | print('Total number of model parameters : {}'.format(sum([p.data.nelement() for p in self.model.parameters()]))) 147 | print('Number of Feature Net parameters: {}'.format(sum([p.data.nelement() for p in self.model.module.feature.parameters()]))) 148 | print('Number of Matching Net parameters: {}'.format(sum([p.data.nelement() for p in self.model.module.matching.parameters()]))) 149 | 150 | 151 | def training(self, epoch): 152 | train_loss = 0.0 153 | valid_iteration = 0 154 | self.model.train() 155 | tbar = tqdm(self.train_loaderA) 156 | num_img_tr = len(self.train_loaderA) 157 | 158 | for i, batch in enumerate(tbar): 159 | input1, input2, target = Variable(batch[0],requires_grad=True), Variable(batch[1], requires_grad=True), (batch[2]) 160 | if self.args.cuda: 161 | input1 = input1.cuda() 162 | input2 = input2.cuda() 163 | target = target.cuda() 164 | 165 | target=torch.squeeze(target,1) 166 | mask = target < self.args.max_disp 167 | mask.detach_() 168 | valid = target[mask].size()[0] 169 | if valid > 0: 170 | self.scheduler(self.optimizer_F, i, epoch, self.best_pred) 171 | self.scheduler(self.optimizer_M, i, epoch, self.best_pred) 172 | self.optimizer_F.zero_grad() 173 | self.optimizer_M.zero_grad() 174 | 175 | output = self.model(input1, input2) 176 | loss = F.smooth_l1_loss(output[mask], target[mask], reduction='mean') 177 | loss.backward() 178 | self.optimizer_F.step() 179 | self.optimizer_M.step() 180 | 181 | if epoch >= self.args.alpha_epoch: 182 | print("Start searching architecture!...........") 183 | search = next(iter(self.train_loaderB)) 184 | input1_search, input2_search, target_search = Variable(search[0],requires_grad=True), Variable(search[1], requires_grad=True), (search[2]) 185 | if self.args.cuda: 186 | input1_search = input1_search.cuda() 187 | input2_search = input2_search.cuda() 188 | target_search = target_search.cuda() 189 | 190 | target_search=torch.squeeze(target_search,1) 191 | mask_search = target_search < self.args.max_disp 192 | mask_search.detach_() 193 | 194 | self.architect_optimizer_F.zero_grad() 195 | self.architect_optimizer_M.zero_grad() 196 | output_search = self.model(input1_search, input2_search) 197 | arch_loss = F.smooth_l1_loss(output_search[mask_search], target_search[mask_search], reduction='mean') 198 | 199 | arch_loss.backward() 200 | self.architect_optimizer_F.step() 201 | self.architect_optimizer_M.step() 202 | 203 | train_loss += loss.item() 204 | valid_iteration += 1 205 | 
tbar.set_description('Train loss: %.3f' % (train_loss / (i + 1))) 206 | self.writer.add_scalar('train/total_loss_iter', loss.item(), i + num_img_tr * epoch) 207 | 208 | #Show 10 * 3 inference results each epoch 209 | if i % (num_img_tr // 10) == 0: 210 | global_step = i + num_img_tr * epoch 211 | self.summary.visualize_image_stereo(self.writer, input1, target, output, global_step) 212 | 213 | self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch) 214 | print("=== Train ===> Epoch :{} Error: {:.4f}".format(epoch, train_loss/valid_iteration)) 215 | print(self.model.module.feature.alphas) 216 | 217 | #save checkpoint every epoch 218 | is_best = False 219 | if torch.cuda.device_count() > 1: 220 | state_dict = self.model.module.state_dict() 221 | else: 222 | state_dict = self.model.state_dict() 223 | self.saver.save_checkpoint({ 224 | 'epoch': epoch + 1, 225 | 'state_dict': state_dict, 226 | 'optimizer_F': self.optimizer_F.state_dict(), 227 | 'optimizer_M': self.optimizer_M.state_dict(), 228 | 'best_pred': self.best_pred, 229 | }, is_best, filename='checkpoint_{}.pth.tar'.format(epoch)) 230 | 231 | def validation(self, epoch): 232 | self.model.eval() 233 | 234 | epoch_error = 0 235 | three_px_acc_all = 0 236 | valid_iteration = 0 237 | 238 | tbar = tqdm(self.val_loader, desc='\r') 239 | test_loss = 0.0 240 | 241 | for i, batch in enumerate(tbar): 242 | input1, input2, target = Variable(batch[0],requires_grad=False), Variable(batch[1], requires_grad=False), Variable(batch[2], requires_grad=False) 243 | if self.args.cuda: 244 | input1 = input1.cuda() 245 | input2 = input2.cuda() 246 | target = target.cuda() 247 | 248 | target=torch.squeeze(target,1) 249 | mask = target < self.args.max_disp 250 | mask.detach_() 251 | valid = target[mask].size()[0] 252 | 253 | if valid>0: 254 | with torch.no_grad(): 255 | output = self.model(input1, input2) 256 | 257 | error = torch.mean(torch.abs(output[mask] - target[mask])) 258 | epoch_error += error.item() 259 | 260 | valid_iteration += 1 261 | 262 | #computing 3-px error# 263 | pred_disp = output.cpu().detach() 264 | true_disp = target.cpu().detach() 265 | disp_true = true_disp 266 | index = np.argwhere(true_disp Test({}/{}): Error(EPE): ({:.4f} {:.4f})".format(i, len(self.val_loader), error.item(),three_px_acc)) 273 | 274 | self.writer.add_scalar('val/EPE', epoch_error/valid_iteration, epoch) 275 | self.writer.add_scalar('val/D1_all', three_px_acc_all/valid_iteration, epoch) 276 | 277 | print("===> Test: Avg. 
Error: ({:.4f} {:.4f})".format(epoch_error/valid_iteration, three_px_acc_all/valid_iteration)) 278 | 279 | 280 | # save model 281 | new_pred = epoch_error/valid_iteration # three_px_acc_all/valid_iteration 282 | if new_pred < self.best_pred: 283 | is_best = True 284 | self.best_pred = new_pred 285 | if torch.cuda.device_count() > 1: 286 | state_dict = self.model.module.state_dict() 287 | else: 288 | state_dict = self.model.state_dict() 289 | self.saver.save_checkpoint({ 290 | 'epoch': epoch + 1, 291 | 'state_dict': state_dict, 292 | 'optimizer_F': self.optimizer_F.state_dict(), 293 | 'optimizer_M': self.optimizer_M.state_dict(), 294 | 'best_pred': self.best_pred, 295 | }, is_best) 296 | 297 | if __name__ == "__main__": 298 | 299 | trainer = Trainer(opt) 300 | print('Starting Epoch:', trainer.args.start_epoch) 301 | print('Total Epoches:', trainer.args.epochs) 302 | for epoch in range(trainer.args.start_epoch, trainer.args.epochs): 303 | trainer.training(epoch) 304 | if not trainer.args.no_val: 305 | trainer.validation(epoch) 306 | 307 | trainer.writer.close() 308 | -------------------------------------------------------------------------------- /search.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=1 python search.py \ 2 | --batch-size 1 \ 3 | --dataset sceneflow \ 4 | --crop_height 192 \ 5 | --crop_width 384 \ 6 | --gpu-ids 0 \ 7 | --fea_num_layers 6 \ 8 | --mat_num_layers 12 \ 9 | --fea_filter_multiplier 4 --fea_block_multiplier 3 --fea_step 3 \ 10 | --mat_filter_multiplier 4 --mat_block_multiplier 3 --mat_step 3 \ 11 | --alpha_epoch 3 \ 12 | --lr 1e-3 \ 13 | --testBatchSize 8 \ 14 | #--resume './run/sceneflow/experiment_0/checkpoint.pth.tar' 15 | -------------------------------------------------------------------------------- /thop/__init__.py: -------------------------------------------------------------------------------- 1 | from .profile import profile -------------------------------------------------------------------------------- /thop/count_hooks.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | multiply_adds = 1 7 | 8 | 9 | def count_convNd(m, x, y): 10 | x = x[0] 11 | 12 | kernel_ops = m.weight.size()[2:].numel() 13 | bias_ops = 1 if m.bias is not None else 0 14 | 15 | # cout x oW x oH 16 | total_ops = y.nelement() * (m.in_channels//m.groups * kernel_ops + bias_ops) 17 | m.total_ops = torch.Tensor([int(total_ops)]) 18 | 19 | 20 | def count_conv2d(m, x, y): 21 | x = x[0] 22 | 23 | cin = m.in_channels 24 | cout = m.out_channels 25 | kh, kw = m.kernel_size 26 | batch_size = x.size()[0] 27 | 28 | out_h = y.size(2) 29 | out_w = y.size(3) 30 | 31 | # ops per output element 32 | # kernel_mul = kh * kw * cin 33 | # kernel_add = kh * kw * cin - 1 34 | kernel_ops = multiply_adds * kh * kw 35 | bias_ops = 1 if m.bias is not None else 0 36 | ops_per_element = kernel_ops + bias_ops 37 | 38 | # total ops 39 | # num_out_elements = y.numel() 40 | output_elements = batch_size * out_w * out_h * cout 41 | total_ops = output_elements * ops_per_element * cin // m.groups 42 | 43 | m.total_ops = torch.Tensor([int(total_ops)]) 44 | 45 | 46 | def count_convtranspose2d(m, x, y): 47 | x = x[0] 48 | 49 | cin = m.in_channels 50 | cout = m.out_channels 51 | kh, kw = m.kernel_size 52 | # batch_size = x.size()[0] 53 | 54 | out_h = y.size(2) 55 | out_w = y.size(3) 56 | 57 | # ops per output element 58 | # kernel_mul = kh * kw * cin 59 | # 
kernel_add = kh * kw * cin - 1 60 | kernel_ops = multiply_adds * kh * kw * cin // m.groups 61 | bias_ops = 1 if m.bias is not None else 0 62 | ops_per_element = kernel_ops + bias_ops 63 | 64 | # total ops 65 | # num_out_elements = y.numel() 66 | # output_elements = batch_size * out_w * out_h * cout 67 | ops_per_element = m.weight.nelement() 68 | output_elements = y.nelement() 69 | total_ops = output_elements * ops_per_element 70 | 71 | m.total_ops = torch.Tensor([int(total_ops)]) 72 | import pdb; pdb.set_trace() 73 | print(m.total_ops) 74 | 75 | 76 | def count_bn(m, x, y): 77 | x = x[0] 78 | 79 | nelements = x.numel() 80 | # subtract, divide, gamma, beta 81 | total_ops = 4 * nelements 82 | 83 | m.total_ops = torch.Tensor([int(total_ops)]) 84 | 85 | 86 | def count_relu(m, x, y): 87 | x = x[0] 88 | 89 | nelements = x.numel() 90 | total_ops = nelements 91 | 92 | m.total_ops = torch.Tensor([int(total_ops)]) 93 | 94 | 95 | def count_softmax(m, x, y): 96 | x = x[0] 97 | 98 | batch_size, nfeatures = x.size() 99 | 100 | total_exp = nfeatures 101 | total_add = nfeatures - 1 102 | total_div = nfeatures 103 | total_ops = batch_size * (total_exp + total_add + total_div) 104 | 105 | m.total_ops = torch.Tensor([int(total_ops)]) 106 | 107 | 108 | def count_maxpool(m, x, y): 109 | kernel_ops = torch.prod(torch.Tensor([m.kernel_size])) 110 | num_elements = y.numel() 111 | total_ops = kernel_ops * num_elements 112 | 113 | m.total_ops = torch.Tensor([int(total_ops)]) 114 | 115 | 116 | def count_adap_maxpool(m, x, y): 117 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 118 | kernel_ops = torch.prod(kernel) 119 | num_elements = y.numel() 120 | total_ops = kernel_ops * num_elements 121 | 122 | m.total_ops = torch.Tensor([int(total_ops)]) 123 | 124 | 125 | def count_avgpool(m, x, y): 126 | total_add = torch.prod(torch.Tensor([m.kernel_size])) 127 | total_div = 1 128 | kernel_ops = total_add + total_div 129 | num_elements = y.numel() 130 | total_ops = kernel_ops * num_elements 131 | 132 | m.total_ops = torch.Tensor([int(total_ops)]) 133 | 134 | 135 | def count_adap_avgpool(m, x, y): 136 | kernel = torch.Tensor([*(x[0].shape[2:])]) // torch.Tensor(list((m.output_size,))).squeeze() 137 | total_add = torch.prod(kernel) 138 | total_div = 1 139 | kernel_ops = total_add + total_div 140 | num_elements = y.numel() 141 | total_ops = kernel_ops * num_elements 142 | 143 | m.total_ops = torch.Tensor([int(total_ops)]) 144 | 145 | 146 | def count_linear(m, x, y): 147 | # per output element 148 | total_mul = m.in_features 149 | total_add = m.in_features - 1 150 | num_elements = y.numel() 151 | total_ops = (total_mul + total_add) * num_elements 152 | 153 | m.total_ops = torch.Tensor([int(total_ops)]) 154 | -------------------------------------------------------------------------------- /thop/profile.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.modules.conv import _ConvNd 6 | 7 | from .count_hooks import * 8 | 9 | register_hooks = { 10 | nn.Conv1d: count_convNd, 11 | nn.Conv2d: count_convNd, 12 | nn.Conv3d: count_convNd, 13 | nn.ConvTranspose2d: count_convNd, 14 | 15 | nn.BatchNorm1d: count_bn, 16 | nn.BatchNorm2d: count_bn, 17 | nn.BatchNorm3d: count_bn, 18 | 19 | nn.ReLU: count_relu, 20 | nn.ReLU6: count_relu, 21 | nn.LeakyReLU: count_relu, 22 | 23 | nn.MaxPool1d: count_maxpool, 24 | nn.MaxPool2d: count_maxpool, 25 | nn.MaxPool3d: count_maxpool, 26 | 
nn.AdaptiveMaxPool1d: count_adap_maxpool, 27 | nn.AdaptiveMaxPool2d: count_adap_maxpool, 28 | nn.AdaptiveMaxPool3d: count_adap_maxpool, 29 | 30 | nn.AvgPool1d: count_avgpool, 31 | nn.AvgPool2d: count_avgpool, 32 | nn.AvgPool3d: count_avgpool, 33 | 34 | nn.AdaptiveAvgPool1d: count_adap_avgpool, 35 | nn.AdaptiveAvgPool2d: count_adap_avgpool, 36 | nn.AdaptiveAvgPool3d: count_adap_avgpool, 37 | nn.Linear: count_linear, 38 | nn.Dropout: None, 39 | } 40 | 41 | 42 | def profile(model, input_size, custom_ops={}, device="cpu"): 43 | handler_collection = [] 44 | 45 | def add_hooks(m): 46 | if len(list(m.children())) > 0: 47 | return 48 | 49 | m.register_buffer('total_ops', torch.zeros(1)) 50 | m.register_buffer('total_params', torch.zeros(1)) 51 | 52 | for p in m.parameters(): 53 | m.total_params += torch.Tensor([p.numel()]) 54 | 55 | m_type = type(m) 56 | fn = None 57 | 58 | if m_type in custom_ops: 59 | fn = custom_ops[m_type] 60 | elif m_type in register_hooks: 61 | fn = register_hooks[m_type] 62 | else: 63 | print("Not implemented for ", m) 64 | 65 | if fn is not None: 66 | handler = m.register_forward_hook(fn) 67 | handler_collection.append(handler) 68 | 69 | original_device = model.parameters().__next__().device 70 | training = model.training 71 | 72 | model.eval().to(device) 73 | model.apply(add_hooks) 74 | 75 | x = torch.zeros(input_size).to(device) 76 | y = torch.zeros(input_size).to(device) 77 | with torch.no_grad(): 78 | model(x,y) 79 | 80 | total_ops = 0 81 | total_params = 0 82 | for m in model.modules(): 83 | if len(list(m.children())) > 0: # skip for non-leaf module 84 | #if 'butterfly' in str(m._get_name()): break 85 | print('-> %s'%(str(m._get_name()))) 86 | continue 87 | #if not '2d' in str(m._get_name()): continue 88 | #if not '3d' in str(m._get_name()): continue 89 | print("Registered FLOP counter (%.1f M/%.1f) for module %s" % (m.total_ops/1e6, m.total_params, str(m))) 90 | total_ops += m.total_ops 91 | total_params += m.total_params 92 | 93 | total_ops = total_ops.item() 94 | total_params = total_params.item() 95 | 96 | model.train(training).to(original_device) 97 | for handler in handler_collection: 98 | handler.remove() 99 | 100 | return total_ops, total_params 101 | -------------------------------------------------------------------------------- /thop/utils.py: -------------------------------------------------------------------------------- 1 | 2 | def clever_format(num, format="%.2f"): 3 | if num > 1e12: 4 | return format % (num / 1e12) + "T" 5 | if num > 1e9: 6 | return format % (num / 1e9) + "G" 7 | if num > 1e6: 8 | return format % (num / 1e6) + "M" 9 | if num > 1e3: 10 | return format % (num / 1e3) + "K" -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import argparse 3 | from math import log10 4 | 5 | import sys 6 | import shutil 7 | import os 8 | import torch 9 | import torch.nn as nn 10 | import torch.nn.parallel 11 | import torch.backends.cudnn as cudnn 12 | import torch.optim as optim 13 | import torch.nn.functional as F 14 | import skimage 15 | import pdb 16 | import numpy as np 17 | from torch.autograd import Variable 18 | from torch.utils.data import DataLoader 19 | from time import time 20 | from collections import OrderedDict 21 | from retrain.LEAStereo import LEAStereo 22 | 23 | from mypath import Path 24 | from dataloaders import make_data_loader 25 | from utils.multadds_count import 
count_parameters_in_MB, comp_multadds, comp_multadds_fw 26 | from config_utils.train_args import obtain_train_args 27 | 28 | 29 | opt = obtain_train_args() 30 | print(opt) 31 | 32 | cuda = opt.cuda 33 | 34 | if cuda and not torch.cuda.is_available(): 35 | raise Exception("No GPU found, please run without --cuda") 36 | 37 | torch.manual_seed(opt.seed) 38 | if cuda: 39 | torch.cuda.manual_seed(opt.seed) 40 | 41 | print('===> Loading datasets') 42 | kwargs = {'num_workers': opt.threads, 'pin_memory': True, 'drop_last':True} 43 | training_data_loader, testing_data_loader = make_data_loader(opt, **kwargs) 44 | 45 | print('===> Building model') 46 | model = LEAStereo(opt) 47 | 48 | ## compute parameters 49 | #print('Total number of model parameters : {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 50 | #print('Number of Feature Net parameters: {}'.format(sum([p.data.nelement() for p in model.feature.parameters()]))) 51 | #print('Number of Matching Net parameters: {}'.format(sum([p.data.nelement() for p in model.matching.parameters()]))) 52 | 53 | print('Total Params = %.2fMB' % count_parameters_in_MB(model)) 54 | print('Feature Net Params = %.2fMB' % count_parameters_in_MB(model.feature)) 55 | print('Matching Net Params = %.2fMB' % count_parameters_in_MB(model.matching)) 56 | 57 | #mult_adds = comp_multadds(model, input_size=(3,opt.crop_height, opt.crop_width)) #(3,192, 192)) 58 | #print("compute_average_flops_cost = %.2fMB" % mult_adds) 59 | 60 | if cuda: 61 | model = torch.nn.DataParallel(model).cuda() 62 | 63 | torch.backends.cudnn.benchmark = True 64 | 65 | if opt.solver == 'adam': 66 | optimizer = optim.Adam(model.parameters(), lr=opt.lr, betas=(0.9,0.999)) 67 | elif opt.solver == 'sgd': 68 | optimizer = optim.SGD(model.parameters(), lr=opt.lr, momentum=0.9) 69 | 70 | scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=opt.milestones, gamma=0.5) 71 | 72 | if opt.resume: 73 | if os.path.isfile(opt.resume): 74 | print("=> loading checkpoint '{}'".format(opt.resume)) 75 | checkpoint = torch.load(opt.resume) 76 | model.load_state_dict(checkpoint['state_dict'], strict=True) 77 | else: 78 | print("=> no checkpoint found at '{}'".format(opt.resume)) 79 | 80 | def train(epoch): 81 | epoch_loss = 0 82 | epoch_error = 0 83 | valid_iteration = 0 84 | 85 | for iteration, batch in enumerate(training_data_loader): 86 | input1, input2, target = Variable(batch[0], requires_grad=True), Variable(batch[1], requires_grad=True), (batch[2]) 87 | if cuda: 88 | input1 = input1.cuda() 89 | input2 = input2.cuda() 90 | target = target.cuda() 91 | 92 | target=torch.squeeze(target,1) 93 | mask = target < opt.maxdisp 94 | mask.detach_() 95 | valid = target[mask].size()[0] 96 | train_start_time = time() 97 | if valid > 0: 98 | model.train() 99 | 100 | optimizer.zero_grad() 101 | disp = model(input1,input2) 102 | loss = F.smooth_l1_loss(disp[mask], target[mask], reduction='mean') 103 | loss.backward() 104 | optimizer.step() 105 | 106 | error = torch.mean(torch.abs(disp[mask] - target[mask])) 107 | train_end_time = time() 108 | train_time = train_end_time - train_start_time 109 | 110 | epoch_loss += loss.item() 111 | valid_iteration += 1 112 | epoch_error += error.item() 113 | print("===> Epoch[{}]({}/{}): Loss: ({:.4f}), Error: ({:.4f}), Time: ({:.2f}s)".format(epoch, iteration, len(training_data_loader), loss.item(), error.item(), train_time)) 114 | sys.stdout.flush() 115 | print("===> Epoch {} Complete: Avg. Loss: ({:.4f}), Avg. 
Error: ({:.4f})".format(epoch, epoch_loss / valid_iteration, epoch_error/valid_iteration)) 116 | 117 | def val(): 118 | epoch_error = 0 119 | valid_iteration = 0 120 | three_px_acc_all = 0 121 | model.eval() 122 | for iteration, batch in enumerate(testing_data_loader): 123 | input1, input2, target = Variable(batch[0],requires_grad=False), Variable(batch[1], requires_grad=False), Variable(batch[2], requires_grad=False) 124 | if cuda: 125 | input1 = input1.cuda() 126 | input2 = input2.cuda() 127 | target = target.cuda() 128 | target=torch.squeeze(target,1) 129 | mask = target < opt.maxdisp 130 | mask.detach_() 131 | valid=target[mask].size()[0] 132 | if valid>0: 133 | with torch.no_grad(): 134 | disp = model(input1,input2) 135 | error = torch.mean(torch.abs(disp[mask] - target[mask])) 136 | 137 | valid_iteration += 1 138 | epoch_error += error.item() 139 | #computing 3-px error# 140 | pred_disp = disp.cpu().detach() 141 | true_disp = target.cpu().detach() 142 | disp_true = true_disp 143 | index = np.argwhere(true_disp Test({}/{}): Error: ({:.4f} {:.4f})".format(iteration, len(testing_data_loader), error.item(), three_px_acc)) 151 | sys.stdout.flush() 152 | 153 | print("===> Test: Avg. Error: ({:.4f} {:.4f})".format(epoch_error/valid_iteration, three_px_acc_all/valid_iteration)) 154 | return three_px_acc_all/valid_iteration 155 | 156 | def save_checkpoint(save_path, epoch,state, is_best): 157 | filename = save_path + "epoch_{}.pth".format(epoch) 158 | torch.save(state, filename) 159 | if is_best: 160 | shutil.copyfile(filename, save_path + 'best.pth') 161 | print("Checkpoint saved to {}".format(filename)) 162 | 163 | if __name__ == '__main__': 164 | error=100 165 | for epoch in range(1, opt.nEpochs + 1): 166 | train(epoch) 167 | is_best = False 168 | loss=val() 169 | if loss < error: 170 | error=loss 171 | is_best = True 172 | if opt.dataset == 'sceneflow': 173 | if epoch>=0: 174 | save_checkpoint(opt.save_path, epoch,{ 175 | 'epoch': epoch, 176 | 'state_dict': model.state_dict(), 177 | 'optimizer' : optimizer.state_dict(), 178 | }, is_best) 179 | else: 180 | if epoch%100 == 0 and epoch >= 3000: 181 | save_checkpoint(opt.save_path, epoch,{ 182 | 'epoch': epoch, 183 | 'state_dict': model.state_dict(), 184 | 'optimizer' : optimizer.state_dict(), 185 | }, is_best) 186 | if is_best: 187 | save_checkpoint(opt.save_path, epoch,{ 188 | 'epoch': epoch, 189 | 'state_dict': model.state_dict(), 190 | 'optimizer' : optimizer.state_dict(), 191 | }, is_best) 192 | 193 | scheduler.step() 194 | 195 | save_checkpoint(opt.save_path, opt.nEpochs,{ 196 | 'epoch': opt.nEpochs, 197 | 'state_dict': model.state_dict(), 198 | 'optimizer' : optimizer.state_dict(), 199 | }, is_best) 200 | -------------------------------------------------------------------------------- /train_kitti12.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --batch_size=4 \ 3 | --testBatchSize=1 \ 4 | --crop_height=288 \ 5 | --crop_width=576 \ 6 | --maxdisp=192 \ 7 | --threads=8 \ 8 | --dataset='kitti12' \ 9 | --save_path='./run/Kitti12/' \ 10 | --resume='./run/sceneflow/best/checkpoint/best.pth' \ 11 | --fea_num_layer 6 --mat_num_layers 12 \ 12 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 13 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 14 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 15 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 16 | 
--net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 17 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 18 | --nEpochs=800 2>&1 |tee ./run/Kitti12/log.txt 19 | 20 | #--resume='./run/Kitti12/best/best_1.16.pth' 21 | -------------------------------------------------------------------------------- /train_kitti15.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --batch_size=4 \ 3 | --testBatchSize=1 \ 4 | --crop_height=288 \ 5 | --crop_width=576 \ 6 | --maxdisp=192 \ 7 | --threads=8 \ 8 | --dataset='kitti15' \ 9 | --save_path='./run/Kitti15/' \ 10 | --resume='./run/sceneflow/best/checkpoint/best.pth' \ 11 | --fea_num_layer 6 --mat_num_layers 12 \ 12 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 13 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 14 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 15 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 16 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 17 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 18 | --nEpochs=800 2>&1 |tee ./run/Kitti15/log.txt 19 | 20 | #--resume='./run/Kitti15/best/best.pth' or './run/sceneflow/best/checkpoint/best.pth' 21 | -------------------------------------------------------------------------------- /train_md.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python train.py \ 2 | --batch_size=2 \ 3 | --testBatchSize=1 \ 4 | --crop_height=384 \ 5 | --crop_width=576 \ 6 | --maxdisp=408 \ 7 | --threads=8 \ 8 | --shift=3 \ 9 | --dataset='middlebury' \ 10 | --save_path='./run/MiddEval3/' \ 11 | --resume='./run/sceneflow/best/checkpoint/best.pth' \ 12 | --fea_num_layer 6 --mat_num_layers 12 \ 13 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 14 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 15 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 16 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 17 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 18 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 19 | --nEpochs=800 2>&1 |tee ./run/MiddEval3/log.txt 20 | 21 | -------------------------------------------------------------------------------- /train_sf.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 train.py --batch_size=4 \ 2 | --crop_height=384 \ 3 | --crop_width=576 \ 4 | --maxdisp=192 \ 5 | --threads=8 \ 6 | --save_path='./run/sceneflow/retrain/' \ 7 | --fea_num_layer 6 --mat_num_layers 12 \ 8 | --fea_filter_multiplier 8 --fea_block_multiplier 4 --fea_step 3 \ 9 | --mat_filter_multiplier 8 --mat_block_multiplier 4 --mat_step 3 \ 10 | --net_arch_fea='run/sceneflow/best/architecture/feature_network_path.npy' \ 11 | --cell_arch_fea='run/sceneflow/best/architecture/feature_genotype.npy' \ 12 | --net_arch_mat='run/sceneflow/best/architecture/matching_network_path.npy' \ 13 | --cell_arch_mat='run/sceneflow/best/architecture/matching_genotype.npy' \ 14 | --nEpochs=20 2>&1 |tee ./run/sceneflow/retrain/log.txt 15 | 16 | -------------------------------------------------------------------------------- /utils/colorize.py: 
-------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def get_color_map(): 4 | 5 | return np.array([[0.18995,0.07176,0.23217], 6 | [0.19483,0.08339,0.26149], 7 | [0.19956,0.09498,0.29024], 8 | [0.20415,0.10652,0.31844], 9 | [0.20860,0.11802,0.34607], 10 | [0.21291,0.12947,0.37314], 11 | [0.21708,0.14087,0.39964], 12 | [0.22111,0.15223,0.42558], 13 | [0.22500,0.16354,0.45096], 14 | [0.22875,0.17481,0.47578], 15 | [0.23236,0.18603,0.50004], 16 | [0.23582,0.19720,0.52373], 17 | [0.23915,0.20833,0.54686], 18 | [0.24234,0.21941,0.56942], 19 | [0.24539,0.23044,0.59142], 20 | [0.24830,0.24143,0.61286], 21 | [0.25107,0.25237,0.63374], 22 | [0.25369,0.26327,0.65406], 23 | [0.25618,0.27412,0.67381], 24 | [0.25853,0.28492,0.69300], 25 | [0.26074,0.29568,0.71162], 26 | [0.26280,0.30639,0.72968], 27 | [0.26473,0.31706,0.74718], 28 | [0.26652,0.32768,0.76412], 29 | [0.26816,0.33825,0.78050], 30 | [0.26967,0.34878,0.79631], 31 | [0.27103,0.35926,0.81156], 32 | [0.27226,0.36970,0.82624], 33 | [0.27334,0.38008,0.84037], 34 | [0.27429,0.39043,0.85393], 35 | [0.27509,0.40072,0.86692], 36 | [0.27576,0.41097,0.87936], 37 | [0.27628,0.42118,0.89123], 38 | [0.27667,0.43134,0.90254], 39 | [0.27691,0.44145,0.91328], 40 | [0.27701,0.45152,0.92347], 41 | [0.27698,0.46153,0.93309], 42 | [0.27680,0.47151,0.94214], 43 | [0.27648,0.48144,0.95064], 44 | [0.27603,0.49132,0.95857], 45 | [0.27543,0.50115,0.96594], 46 | [0.27469,0.51094,0.97275], 47 | [0.27381,0.52069,0.97899], 48 | [0.27273,0.53040,0.98461], 49 | [0.27106,0.54015,0.98930], 50 | [0.26878,0.54995,0.99303], 51 | [0.26592,0.55979,0.99583], 52 | [0.26252,0.56967,0.99773], 53 | [0.25862,0.57958,0.99876], 54 | [0.25425,0.58950,0.99896], 55 | [0.24946,0.59943,0.99835], 56 | [0.24427,0.60937,0.99697], 57 | [0.23874,0.61931,0.99485], 58 | [0.23288,0.62923,0.99202], 59 | [0.22676,0.63913,0.98851], 60 | [0.22039,0.64901,0.98436], 61 | [0.21382,0.65886,0.97959], 62 | [0.20708,0.66866,0.97423], 63 | [0.20021,0.67842,0.96833], 64 | [0.19326,0.68812,0.96190], 65 | [0.18625,0.69775,0.95498], 66 | [0.17923,0.70732,0.94761], 67 | [0.17223,0.71680,0.93981], 68 | [0.16529,0.72620,0.93161], 69 | [0.15844,0.73551,0.92305], 70 | [0.15173,0.74472,0.91416], 71 | [0.14519,0.75381,0.90496], 72 | [0.13886,0.76279,0.89550], 73 | [0.13278,0.77165,0.88580], 74 | [0.12698,0.78037,0.87590], 75 | [0.12151,0.78896,0.86581], 76 | [0.11639,0.79740,0.85559], 77 | [0.11167,0.80569,0.84525], 78 | [0.10738,0.81381,0.83484], 79 | [0.10357,0.82177,0.82437], 80 | [0.10026,0.82955,0.81389], 81 | [0.09750,0.83714,0.80342], 82 | [0.09532,0.84455,0.79299], 83 | [0.09377,0.85175,0.78264], 84 | [0.09287,0.85875,0.77240], 85 | [0.09267,0.86554,0.76230], 86 | [0.09320,0.87211,0.75237], 87 | [0.09451,0.87844,0.74265], 88 | [0.09662,0.88454,0.73316], 89 | [0.09958,0.89040,0.72393], 90 | [0.10342,0.89600,0.71500], 91 | [0.10815,0.90142,0.70599], 92 | [0.11374,0.90673,0.69651], 93 | [0.12014,0.91193,0.68660], 94 | [0.12733,0.91701,0.67627], 95 | [0.13526,0.92197,0.66556], 96 | [0.14391,0.92680,0.65448], 97 | [0.15323,0.93151,0.64308], 98 | [0.16319,0.93609,0.63137], 99 | [0.17377,0.94053,0.61938], 100 | [0.18491,0.94484,0.60713], 101 | [0.19659,0.94901,0.59466], 102 | [0.20877,0.95304,0.58199], 103 | [0.22142,0.95692,0.56914], 104 | [0.23449,0.96065,0.55614], 105 | [0.24797,0.96423,0.54303], 106 | [0.26180,0.96765,0.52981], 107 | [0.27597,0.97092,0.51653], 108 | [0.29042,0.97403,0.50321], 109 | [0.30513,0.97697,0.48987], 110 | [0.32006,0.97974,0.47654], 
111 | [0.33517,0.98234,0.46325], 112 | [0.35043,0.98477,0.45002], 113 | [0.36581,0.98702,0.43688], 114 | [0.38127,0.98909,0.42386], 115 | [0.39678,0.99098,0.41098], 116 | [0.41229,0.99268,0.39826], 117 | [0.42778,0.99419,0.38575], 118 | [0.44321,0.99551,0.37345], 119 | [0.45854,0.99663,0.36140], 120 | [0.47375,0.99755,0.34963], 121 | [0.48879,0.99828,0.33816], 122 | [0.50362,0.99879,0.32701], 123 | [0.51822,0.99910,0.31622], 124 | [0.53255,0.99919,0.30581], 125 | [0.54658,0.99907,0.29581], 126 | [0.56026,0.99873,0.28623], 127 | [0.57357,0.99817,0.27712], 128 | [0.58646,0.99739,0.26849], 129 | [0.59891,0.99638,0.26038], 130 | [0.61088,0.99514,0.25280], 131 | [0.62233,0.99366,0.24579], 132 | [0.63323,0.99195,0.23937], 133 | [0.64362,0.98999,0.23356], 134 | [0.65394,0.98775,0.22835], 135 | [0.66428,0.98524,0.22370], 136 | [0.67462,0.98246,0.21960], 137 | [0.68494,0.97941,0.21602], 138 | [0.69525,0.97610,0.21294], 139 | [0.70553,0.97255,0.21032], 140 | [0.71577,0.96875,0.20815], 141 | [0.72596,0.96470,0.20640], 142 | [0.73610,0.96043,0.20504], 143 | [0.74617,0.95593,0.20406], 144 | [0.75617,0.95121,0.20343], 145 | [0.76608,0.94627,0.20311], 146 | [0.77591,0.94113,0.20310], 147 | [0.78563,0.93579,0.20336], 148 | [0.79524,0.93025,0.20386], 149 | [0.80473,0.92452,0.20459], 150 | [0.81410,0.91861,0.20552], 151 | [0.82333,0.91253,0.20663], 152 | [0.83241,0.90627,0.20788], 153 | [0.84133,0.89986,0.20926], 154 | [0.85010,0.89328,0.21074], 155 | [0.85868,0.88655,0.21230], 156 | [0.86709,0.87968,0.21391], 157 | [0.87530,0.87267,0.21555], 158 | [0.88331,0.86553,0.21719], 159 | [0.89112,0.85826,0.21880], 160 | [0.89870,0.85087,0.22038], 161 | [0.90605,0.84337,0.22188], 162 | [0.91317,0.83576,0.22328], 163 | [0.92004,0.82806,0.22456], 164 | [0.92666,0.82025,0.22570], 165 | [0.93301,0.81236,0.22667], 166 | [0.93909,0.80439,0.22744], 167 | [0.94489,0.79634,0.22800], 168 | [0.95039,0.78823,0.22831], 169 | [0.95560,0.78005,0.22836], 170 | [0.96049,0.77181,0.22811], 171 | [0.96507,0.76352,0.22754], 172 | [0.96931,0.75519,0.22663], 173 | [0.97323,0.74682,0.22536], 174 | [0.97679,0.73842,0.22369], 175 | [0.98000,0.73000,0.22161], 176 | [0.98289,0.72140,0.21918], 177 | [0.98549,0.71250,0.21650], 178 | [0.98781,0.70330,0.21358], 179 | [0.98986,0.69382,0.21043], 180 | [0.99163,0.68408,0.20706], 181 | [0.99314,0.67408,0.20348], 182 | [0.99438,0.66386,0.19971], 183 | [0.99535,0.65341,0.19577], 184 | [0.99607,0.64277,0.19165], 185 | [0.99654,0.63193,0.18738], 186 | [0.99675,0.62093,0.18297], 187 | [0.99672,0.60977,0.17842], 188 | [0.99644,0.59846,0.17376], 189 | [0.99593,0.58703,0.16899], 190 | [0.99517,0.57549,0.16412], 191 | [0.99419,0.56386,0.15918], 192 | [0.99297,0.55214,0.15417], 193 | [0.99153,0.54036,0.14910], 194 | [0.98987,0.52854,0.14398], 195 | [0.98799,0.51667,0.13883], 196 | [0.98590,0.50479,0.13367], 197 | [0.98360,0.49291,0.12849], 198 | [0.98108,0.48104,0.12332], 199 | [0.97837,0.46920,0.11817], 200 | [0.97545,0.45740,0.11305], 201 | [0.97234,0.44565,0.10797], 202 | [0.96904,0.43399,0.10294], 203 | [0.96555,0.42241,0.09798], 204 | [0.96187,0.41093,0.09310], 205 | [0.95801,0.39958,0.08831], 206 | [0.95398,0.38836,0.08362], 207 | [0.94977,0.37729,0.07905], 208 | [0.94538,0.36638,0.07461], 209 | [0.94084,0.35566,0.07031], 210 | [0.93612,0.34513,0.06616], 211 | [0.93125,0.33482,0.06218], 212 | [0.92623,0.32473,0.05837], 213 | [0.92105,0.31489,0.05475], 214 | [0.91572,0.30530,0.05134], 215 | [0.91024,0.29599,0.04814], 216 | [0.90463,0.28696,0.04516], 217 | [0.89888,0.27824,0.04243], 218 | 
[0.89298,0.26981,0.03993], 219 | [0.88691,0.26152,0.03753], 220 | [0.88066,0.25334,0.03521], 221 | [0.87422,0.24526,0.03297], 222 | [0.86760,0.23730,0.03082], 223 | [0.86079,0.22945,0.02875], 224 | [0.85380,0.22170,0.02677], 225 | [0.84662,0.21407,0.02487], 226 | [0.83926,0.20654,0.02305], 227 | [0.83172,0.19912,0.02131], 228 | [0.82399,0.19182,0.01966], 229 | [0.81608,0.18462,0.01809], 230 | [0.80799,0.17753,0.01660], 231 | [0.79971,0.17055,0.01520], 232 | [0.79125,0.16368,0.01387], 233 | [0.78260,0.15693,0.01264], 234 | [0.77377,0.15028,0.01148], 235 | [0.76476,0.14374,0.01041], 236 | [0.75556,0.13731,0.00942], 237 | [0.74617,0.13098,0.00851], 238 | [0.73661,0.12477,0.00769], 239 | [0.72686,0.11867,0.00695], 240 | [0.71692,0.11268,0.00629], 241 | [0.70680,0.10680,0.00571], 242 | [0.69650,0.10102,0.00522], 243 | [0.68602,0.09536,0.00481], 244 | [0.67535,0.08980,0.00449], 245 | [0.66449,0.08436,0.00424], 246 | [0.65345,0.07902,0.00408], 247 | [0.64223,0.07380,0.00401], 248 | [0.63082,0.06868,0.00401], 249 | [0.61923,0.06367,0.00410], 250 | [0.60746,0.05878,0.00427], 251 | [0.59550,0.05399,0.00453], 252 | [0.58336,0.04931,0.00486], 253 | [0.57103,0.04474,0.00529], 254 | [0.55852,0.04028,0.00579], 255 | [0.54583,0.03593,0.00638], 256 | [0.53295,0.03169,0.00705], 257 | [0.51989,0.02756,0.00780], 258 | [0.50664,0.02354,0.00863], 259 | [0.49321,0.01963,0.00955], 260 | [0.47960,0.01583,0.01055]]) 261 | -------------------------------------------------------------------------------- /utils/copy_state_dict.py: -------------------------------------------------------------------------------- 1 | def copy_state_dict(cur_state_dict, pre_state_dict, prefix = ''): 2 | def _get_params(key): 3 | key = prefix + key 4 | if key in pre_state_dict: 5 | return pre_state_dict[key] 6 | return None 7 | 8 | for k in cur_state_dict.keys(): 9 | v = _get_params(k) 10 | try: 11 | if v is None: 12 | print('parameter {} not found'.format(k)) 13 | continue 14 | cur_state_dict[k].copy_(v) 15 | except: 16 | print('copy param {} failed'.format(k)) 17 | continue -------------------------------------------------------------------------------- /utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | import math 12 | 13 | 14 | class LR_Scheduler(object): 15 | """Learning Rate Scheduler 16 | 17 | Step mode: ``lr = baselr * 0.1 ^ {floor(epoch-1 / lr_step)}`` 18 | 19 | Cosine mode: ``lr = baselr * 0.5 * (1 + cos(iter/maxiter))`` 20 | 21 | Poly mode: ``lr = baselr * (1 - iter/maxiter) ^ 0.9`` 22 | 23 | Args: 24 | args: 25 | :attr:`args.lr_scheduler` lr scheduler mode (`cos`, `poly`), 26 | :attr:`args.lr` base learning rate, :attr:`args.epochs` number of epochs, 27 | :attr:`args.lr_step` 28 | 29 | iters_per_epoch: number of iterations per epoch 30 | """ 31 | 32 | def __init__(self, mode, base_lr, num_epochs, iters_per_epoch=0, 33 | lr_step=0, warmup_epochs=0, min_lr=None): 34 | self.mode = mode 35 | print('Using {} LR Scheduler!'.format(self.mode)) 36 | self.lr = base_lr 37 | if mode == 'step': 38 | assert lr_step 39 | self.lr_step = lr_step 40 
| self.iters_per_epoch = iters_per_epoch 41 | self.N = num_epochs * iters_per_epoch 42 | self.epoch = -1 43 | self.warmup_iters = warmup_epochs * iters_per_epoch 44 | self.min_lr = min_lr 45 | 46 | def __call__(self, optimizer, i, epoch, best_pred): 47 | T = epoch * self.iters_per_epoch + i 48 | if self.mode == 'cos': 49 | lr = 0.5 * self.lr * (1 + math.cos(1.0 * T / self.N * math.pi)) 50 | elif self.mode == 'poly': 51 | lr = self.lr * pow((1 - 1.0 * T / self.N), 0.9) 52 | elif self.mode == 'step': 53 | lr = self.lr * (0.1 ** (epoch // self.lr_step)) 54 | else: 55 | raise NotImplemented 56 | # warm up lr schedule 57 | if self.min_lr is not None: 58 | if lr < self.min_lr: 59 | lr = self.min_lr 60 | if self.warmup_iters > 0 and T < self.warmup_iters: 61 | lr = lr * 1.0 * T / self.warmup_iters 62 | if epoch > self.epoch: 63 | print('\n=>Epoches %i, learning rate = %.4f, \ 64 | previous best = %.4f' % (epoch, lr, best_pred)) 65 | self.epoch = epoch 66 | assert lr >= 0 67 | self._adjust_learning_rate(optimizer, lr) 68 | 69 | def _adjust_learning_rate(self, optimizer, lr): 70 | if len(optimizer.param_groups) == 1: 71 | optimizer.param_groups[0]['lr'] = lr 72 | else: 73 | # enlarge the lr at the head 74 | optimizer.param_groups[0]['lr'] = lr 75 | for i in range(1, len(optimizer.param_groups)): 76 | optimizer.param_groups[i]['lr'] = lr * 10 77 | -------------------------------------------------------------------------------- /utils/multadds_count.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | # Original implementation: 4 | # https://github.com/warmspringwinds/pytorch-segmentation-detection/blob/master/pytorch_segmentation_detection/utils/flops_benchmark.py 5 | 6 | # ---- Public functions 7 | 8 | def count_parameters_in_MB(model): 9 | return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "aux" not in name)/1e6 10 | 11 | 12 | def comp_multadds(model, input_size=(3,224,224), half=False): 13 | input_size = (1,) + tuple(input_size) 14 | model = model.cuda() 15 | input_data = torch.randn(input_size).cuda() 16 | model = add_flops_counting_methods(model) 17 | if half: 18 | input_data = input_data.half() 19 | model.start_flops_count() 20 | with torch.no_grad(): 21 | _ = model(input_data, input_data) 22 | 23 | mult_adds = model.compute_average_flops_cost() / 1e6 24 | return mult_adds 25 | 26 | 27 | def comp_multadds_fw(model, input_data): 28 | model = add_flops_counting_methods(model) 29 | model = model.cuda() 30 | model.start_flops_count() 31 | with torch.no_grad(): 32 | output_data = model(input_data, input_data) 33 | 34 | mult_adds = model.compute_average_flops_cost() / 1e6 35 | return mult_adds, output_data 36 | 37 | 38 | def add_flops_counting_methods(net_main_module): 39 | """Adds flops counting functions to an existing model. After that 40 | the flops count should be activated and the model should be run on an input 41 | image. 42 | Example: 43 | fcn = add_flops_counting_methods(fcn) 44 | fcn = fcn.cuda().train() 45 | fcn.start_flops_count() 46 | _ = fcn(batch) 47 | fcn.compute_average_flops_cost() / 1e9 / 2 # Result in GFLOPs per image in batch 48 | Important: dividing by 2 only works for resnet models -- see below for the details 49 | of flops computation. 50 | Attention: we are counting multiply-add as two flops in this work, because in 51 | most resnet models convolutions are bias-free (BN layers act as bias there) 52 | and it makes sense to count muliply and add as separate flops therefore. 
53 | This is why in the above example we divide by 2 in order to be consistent with 54 | most modern benchmarks. For example in "Spatially Adaptive Computatin Time for Residual 55 | Networks" by Figurnov et al multiply-add was counted as two flops. 56 | This module computes the average flops which is necessary for dynamic networks which 57 | have different number of executed layers. For static networks it is enough to run the network 58 | once and get statistics (above example). 59 | Implementation: 60 | The module works by adding batch_count to the main module which tracks the sum 61 | of all batch sizes that were run through the network. 62 | Also each convolutional layer of the network tracks the overall number of flops 63 | performed. 64 | The parameters are updated with the help of registered hook-functions which 65 | are being called each time the respective layer is executed. 66 | Parameters 67 | ---------- 68 | net_main_module : torch.nn.Module 69 | Main module containing network 70 | Returns 71 | ------- 72 | net_main_module : torch.nn.Module 73 | Updated main module with new methods/attributes that are used 74 | to compute flops. 75 | """ 76 | 77 | # adding additional methods to the existing module object, 78 | # this is done this way so that each function has access to self object 79 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 80 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 81 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 82 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__(net_main_module) 83 | 84 | net_main_module.reset_flops_count() 85 | 86 | # Adding varialbles necessary for masked flops computation 87 | net_main_module.apply(add_flops_mask_variable_or_reset) 88 | 89 | return net_main_module 90 | 91 | 92 | def compute_average_flops_cost(self): 93 | """ 94 | A method that will be available after add_flops_counting_methods() is called 95 | on a desired net object. 96 | Returns current mean flops consumption per image. 97 | """ 98 | 99 | batches_count = self.__batch_counter__ 100 | 101 | flops_sum = 0 102 | 103 | for module in self.modules(): 104 | 105 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 106 | flops_sum += module.__flops__ 107 | 108 | 109 | return flops_sum / batches_count 110 | 111 | 112 | def start_flops_count(self): 113 | """ 114 | A method that will be available after add_flops_counting_methods() is called 115 | on a desired net object. 116 | Activates the computation of mean flops consumption per image. 117 | Call it before you run the network. 118 | """ 119 | 120 | add_batch_counter_hook_function(self) 121 | 122 | self.apply(add_flops_counter_hook_function) 123 | 124 | 125 | def stop_flops_count(self): 126 | """ 127 | A method that will be available after add_flops_counting_methods() is called 128 | on a desired net object. 129 | Stops computing the mean flops consumption per image. 130 | Call whenever you want to pause the computation. 131 | """ 132 | 133 | remove_batch_counter_hook_function(self) 134 | 135 | self.apply(remove_flops_counter_hook_function) 136 | 137 | 138 | def reset_flops_count(self): 139 | """ 140 | A method that will be available after add_flops_counting_methods() is called 141 | on a desired net object. 142 | Resets statistics computed so far. 
143 | """ 144 | 145 | add_batch_counter_variables_or_reset(self) 146 | 147 | self.apply(add_flops_counter_variable_or_reset) 148 | 149 | 150 | def add_flops_mask(module, mask): 151 | def add_flops_mask_func(module): 152 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 153 | module.__mask__ = mask 154 | 155 | module.apply(add_flops_mask_func) 156 | 157 | 158 | def remove_flops_mask(module): 159 | module.apply(add_flops_mask_variable_or_reset) 160 | 161 | 162 | # ---- Internal functions 163 | 164 | 165 | def conv_flops_counter_hook(conv_module, input, output): 166 | # Can have multiple inputs, getting the first one 167 | input = input[0] 168 | 169 | batch_size = input.shape[0] 170 | output_height, output_width = output.shape[2:] 171 | 172 | kernel_height, kernel_width = conv_module.kernel_size 173 | in_channels = conv_module.in_channels 174 | out_channels = conv_module.out_channels 175 | 176 | conv_per_position_flops = (kernel_height * kernel_width * in_channels * out_channels) / conv_module.groups 177 | 178 | active_elements_count = batch_size * output_height * output_width 179 | 180 | if conv_module.__mask__ is not None: 181 | # (b, 1, h, w) 182 | flops_mask = conv_module.__mask__.expand(batch_size, 1, output_height, output_width) 183 | active_elements_count = flops_mask.sum() 184 | 185 | overall_conv_flops = conv_per_position_flops * active_elements_count 186 | 187 | bias_flops = 0 188 | 189 | if conv_module.bias is not None: 190 | bias_flops = out_channels * active_elements_count 191 | 192 | overall_flops = overall_conv_flops + bias_flops 193 | 194 | conv_module.__flops__ += overall_flops 195 | 196 | 197 | def linear_flops_counter_hook(linear_module, input, output): 198 | 199 | input = input[0] 200 | batch_size = input.shape[0] 201 | overall_flops = linear_module.in_features * linear_module.out_features * batch_size 202 | 203 | # bias_flops = 0 204 | 205 | # if conv_module.bias is not None: 206 | # bias_flops = out_channels * active_elements_count 207 | 208 | # overall_flops = overall_conv_flops + bias_flops 209 | 210 | linear_module.__flops__ += overall_flops 211 | 212 | 213 | def batch_counter_hook(module, input, output): 214 | # Can have multiple inputs, getting the first one 215 | input = input[0] 216 | 217 | batch_size = input.shape[0] 218 | 219 | module.__batch_counter__ += batch_size 220 | 221 | 222 | def add_batch_counter_variables_or_reset(module): 223 | module.__batch_counter__ = 0 224 | 225 | 226 | def add_batch_counter_hook_function(module): 227 | if hasattr(module, '__batch_counter_handle__'): 228 | return 229 | 230 | handle = module.register_forward_hook(batch_counter_hook) 231 | module.__batch_counter_handle__ = handle 232 | 233 | 234 | def remove_batch_counter_hook_function(module): 235 | if hasattr(module, '__batch_counter_handle__'): 236 | module.__batch_counter_handle__.remove() 237 | 238 | del module.__batch_counter_handle__ 239 | 240 | 241 | def add_flops_counter_variable_or_reset(module): 242 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 243 | module.__flops__ = 0 244 | 245 | 246 | def add_flops_counter_hook_function(module): 247 | if isinstance(module, torch.nn.Conv2d): 248 | if hasattr(module, '__flops_handle__'): 249 | return 250 | 251 | handle = module.register_forward_hook(conv_flops_counter_hook) 252 | module.__flops_handle__ = handle 253 | elif isinstance(module, torch.nn.Linear): 254 | 255 | if hasattr(module, '__flops_handle__'): 256 | return 257 | 258 | handle = 
module.register_forward_hook(linear_flops_counter_hook) 259 | module.__flops_handle__ = handle 260 | 261 | 262 | def remove_flops_counter_hook_function(module): 263 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 264 | 265 | if hasattr(module, '__flops_handle__'): 266 | module.__flops_handle__.remove() 267 | 268 | del module.__flops_handle__ 269 | 270 | 271 | # --- Masked flops counting 272 | 273 | 274 | # Also being run in the initialization 275 | def add_flops_mask_variable_or_reset(module): 276 | if isinstance(module, torch.nn.Conv2d) or isinstance(module, torch.nn.Linear): 277 | module.__mask__ = None 278 | 279 | -------------------------------------------------------------------------------- /utils/saver.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import torch 4 | from collections import OrderedDict 5 | import glob 6 | 7 | 8 | class Saver(object): 9 | 10 | def __init__(self, args): 11 | self.args = args 12 | self.directory = os.path.join('run', args.dataset) 13 | self.runs = sorted(glob.glob(os.path.join(self.directory, 'experiment_*'))) 14 | run_id = max([int(x.split('_')[-1]) for x in self.runs]) + 1 if self.runs else 0 15 | 16 | self.experiment_dir = os.path.join(self.directory, 'experiment_{}'.format(str(run_id))) 17 | if not os.path.exists(self.experiment_dir): 18 | os.makedirs(self.experiment_dir) 19 | 20 | def save_checkpoint(self, state, is_best, filename='checkpoint.pth.tar'): 21 | """Saves checkpoint to disk""" 22 | filename = os.path.join(self.experiment_dir, filename) 23 | torch.save(state, filename) 24 | if is_best: 25 | best_pred = state['best_pred'] 26 | with open(os.path.join(self.experiment_dir, 'best_pred.txt'), 'w') as f: 27 | f.write(str(best_pred)) 28 | if self.runs: 29 | previous_miou = [0.0] 30 | for run in self.runs: 31 | run_id = run.split('_')[-1] 32 | path = os.path.join(self.directory, 'experiment_{}'.format(str(run_id)), 'best_pred.txt') 33 | if os.path.exists(path): 34 | with open(path, 'r') as f: 35 | miou = float(f.readline()) 36 | previous_miou.append(miou) 37 | else: 38 | continue 39 | max_miou = max(previous_miou) 40 | if best_pred > max_miou: 41 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 42 | else: 43 | shutil.copyfile(filename, os.path.join(self.directory, 'model_best.pth.tar')) 44 | 45 | def save_experiment_config(self): 46 | logfile = os.path.join(self.experiment_dir, 'parameters.txt') 47 | log_file = open(logfile, 'w') 48 | p = OrderedDict() 49 | p['datset'] = self.args.dataset 50 | p['lr'] = self.args.lr 51 | p['lr_scheduler'] = self.args.lr_scheduler 52 | p['epoch'] = self.args.epochs 53 | p['crop_height'] = self.args.crop_height 54 | p['crop_width'] = self.args.crop_width 55 | 56 | for key, val in p.items(): 57 | log_file.write(key + ':' + str(val) + '\n') 58 | log_file.close() 59 | -------------------------------------------------------------------------------- /utils/summaries.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torchvision.utils import make_grid 4 | from tensorboardX import SummaryWriter 5 | import numpy as np 6 | import pdb 7 | 8 | 9 | class TensorboardSummary(object): 10 | def __init__(self, directory): 11 | self.directory = directory 12 | 13 | def create_summary(self): 14 | writer = SummaryWriter(logdir=os.path.join(self.directory)) 15 | return writer 16 | 17 | def tensor2array(self, tensor, 
max_value=255, colormap='rainbow'): 18 | if max_value is None: 19 | max_value = tensor.max().numpy() 20 | if tensor.ndimension() == 2: 21 | try: 22 | import cv2 23 | if cv2.__version__.startswith('2'):# 2.4 24 | color_cvt = cv2.cv.CV_BGR2RGB 25 | else: 26 | color_cvt = cv2.COLOR_BGR2RGB 27 | if colormap == 'rainbow': 28 | colormap = cv2.COLORMAP_RAINBOW 29 | elif colormap == 'bone': 30 | colormap = cv2.COLORMAP_BONE 31 | array = (255*tensor.numpy()/max_value).clip(0, 255).astype(np.uint8) 32 | colored_array = cv2.applyColorMap(array, colormap) 33 | array = cv2.cvtColor(colored_array, color_cvt).astype(np.float32)/255 34 | array = array.transpose(2,0,1) 35 | except ImportError: 36 | if tensor.ndimension() == 2: 37 | tensor.unsqueeze_(2) 38 | array = (tensor.expand(tensor.size(0), tensor.size(1), 3).numpy()/max_value).clip(0,1) 39 | 40 | elif tensor.ndimension() == 3: 41 | #assert(tensor.size(0) == 3) 42 | array = 0.5 + tensor.numpy().transpose(1, 2, 0)*0.5 43 | return array 44 | 45 | def visualize_image_stereo(self, writer, image, target, output, global_step): 46 | pr_image = self.tensor2array(output[0].cpu().data, max_value=144, colormap='bone') 47 | writer.add_image('Predicted disparity', pr_image, global_step) 48 | gt_image = self.tensor2array(target[0].cpu().data, max_value=144, colormap='bone') 49 | writer.add_image('Groundtruth disparity', gt_image, global_step) 50 | --------------------------------------------------------------------------------