├── .gitattributes ├── .github └── workflows │ └── tests.yml ├── .gitignore ├── LICENSE ├── README.md ├── StereoNet.ipynb ├── hydra_conf ├── config.example.yaml ├── training_paths.example.txt └── validation_paths.example.txt ├── pyproject.toml ├── requirements.txt ├── requirements_dev.txt ├── setup.cfg ├── setup.py ├── src └── stereonet │ ├── __init__.py │ ├── datasets.py │ ├── model.py │ ├── py.typed │ ├── train.py │ ├── types.py │ ├── utils.py │ └── utils_io.py ├── tests ├── __init__.py ├── conftest.py └── test_model.py └── tox.ini /.gitattributes: -------------------------------------------------------------------------------- 1 | *.ipynb linguist-vendored -------------------------------------------------------------------------------- /.github/workflows/tests.yml: -------------------------------------------------------------------------------- 1 | name: Tests 2 | 3 | on: 4 | - push 5 | - pull_request 6 | 7 | jobs: 8 | test: 9 | runs-on: ${{ matrix.os }} 10 | strategy: 11 | matrix: 12 | os: [ubuntu-latest, windows-latest] 13 | python-version: ['3.8'] 14 | 15 | steps: 16 | - uses: actions/checkout@v2 17 | - name: Set up Python ${{ matrix.python-version }} 18 | uses: actions/setup-python@v2 19 | with: 20 | python-version: ${{ matrix.python-version }} 21 | - name: Install dependencies 22 | run: | 23 | python -m pip install --upgrade pip 24 | pip install tox tox-gh-actions 25 | - name: Test with tox 26 | run: tox 27 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | 131 | # VSCode 132 | .vscode/ 133 | 134 | # PyTorch Lightning loggers 135 | logging/ 136 | 137 | # Permanently saved models 138 | saved_models/ 139 | 140 | # Data folder 141 | data/ 142 | 143 | # Hiding config file just to make the configs easier to understand and self-generate 144 | hydra_conf/config.yaml 145 | hydra_conf/training_paths.txt 146 | hydra_conf/validation_paths.txt 147 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | This is free and unencumbered software released into the public domain. 2 | 3 | Anyone is free to copy, modify, publish, use, compile, sell, or 4 | distribute this software, either in source code form or as a compiled 5 | binary, for any purpose, commercial or non-commercial, and by any 6 | means. 7 | 8 | In jurisdictions that recognize copyright laws, the author or authors 9 | of this software dedicate any and all copyright interest in the 10 | software to the public domain. We make this dedication for the benefit 11 | of the public at large and to the detriment of our heirs and 12 | successors. We intend this dedication to be an overt act of 13 | relinquishment in perpetuity of all present and future rights to this 14 | software under copyright law. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 17 | EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF 18 | MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 19 | IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR 20 | OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, 21 | ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 22 | OTHER DEALINGS IN THE SOFTWARE. 23 | 24 | For more information, please refer to -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # StereoNet implemented in PyTorch 2 | 3 | ![Tests](https://github.com/andrewlstewart/StereoNet_PyTorch/actions/workflows/tests.yml/badge.svg) 4 | 5 | [![Open All Collab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/andrewlstewart/StereoNet_PyTorch/blob/main/StereoNet.ipynb) 6 | 7 | Refer to the above Colab notebook for an example of model inference. 
8 | 9 | Install with: 10 | ``` 11 | pip install "git+https://github.com/andrewlstewart/StereoNet_PyTorch" 12 | ``` 13 | 14 | 15 | Example left image disparity generated from stereo image pair: 16 | 17 | ![output](https://user-images.githubusercontent.com/7529012/235552448-b96f6023-7349-4f97-899f-fcdc9276e573.png) 18 | 19 | How to perform the most basic inference: 20 | 21 | ``` 22 | import numpy as np 23 | import torch 24 | import torchvision.transforms 25 | from stereonet.model import StereoNet 26 | import stereonet.utils_io 27 | 28 | # Load in the image pair as numpy uint8 arrays, ensure the shapes are the same for both images for concatenation [Height, Width, Channels] 29 | # left = stereonet.utils_io.image_loader(path_to_left_rgb_image_file) # [Height, Width, Channel] [0, 255] uint8 30 | # right = stereonet.utils_io.image_loader(path_to_right_rgb_image_file) # [Height, Width, Channel] [0, 255] uint8 31 | # min_height = min(left.shape[0], right.shape[0]) 32 | # min_width = min(left.shape[1], right.shape[1]) 33 | # tensored = [torch.permute(torch.from_numpy(array).to(torch.float32).to(device), (2, 0, 1)) for array in (left, right)] # [Channel, Height, Width] [0, 255] float32 34 | # cropper = torchvision.transforms.CenterCrop((min_height, min_width)) 35 | # stack = torch.concatenate(list(map(cropper, tensored)), dim=0) # [Stacked left/right channels, Height, Width] [0, 255] float32 36 | 37 | # Here just creating a random image 38 | # stack = torch.randint(0, 256, size=(6, 540, 960), dtype=torch.float32) # [Stacked left/right channels, Height, Width] 6 for 3 RGB x 2 images 39 | stack = torch.randint(0, 256, size=(2, 540, 960), dtype=torch.float32) # [Stacked left/right channels, Height, Width] 2 for 1 grayscale x 2 images 40 | 41 | normalizer = torchvision.transforms.Normalize((111.5684, 113.6528), (61.9625, 62.0313)) 42 | normalized = normalizer(stack) 43 | 44 | batch = torch.unsqueeze(normalized, dim=0) # [Batch, Stacked left/right channels, Height, Width] 45 | 46 | # Load in the model from the trained checkpoint 47 | # model = StereoNet.load_from_checkpoint(path_to_checkpoint) # "C:\\users\\name\\Downloads\\epoch=21-step=696366.ckpt" 48 | 49 | # Here just instantiate the model with random weights 50 | model = StereoNet(in_channels=1) # 3 channels for RGB, 1 channel for grayscale 51 | 52 | # Set the model to eval and run the forward method without tracking gradients 53 | model.eval() 54 | with torch.no_grad(): 55 | batched_prediction = model(batch) 56 | 57 | # Remove the batch dimension and switch back to channels last notation 58 | single_prediction = batched_prediction[0].detach().cpu().numpy() # [batch, channel, height, width] -> [channel, height, width] 59 | single_prediction = np.moveaxis(single_prediction, 0, 2) # [channel, height, width] -> [height, width, channel] 60 | 61 | assert single_prediction.shape == (540, 960, 1) 62 | ``` 63 | 64 | ## Weights 65 | KeystoneDepth checkpoint: https://www.dropbox.com/s/ffgeqyzk4kec9cf/epoch%3D21-step%3D696366.ckpt?dl=0 66 | 67 | * Trained with this mean/std normalizer for left/right grayscale images: torchvision.transforms.Normalize((111.5684, 113.6528), (61.9625, 62.0313)) 68 | * Model was trained on grayscale images and has in_channels=1 69 | * Train/val split ratio of 85% 70 | * Max disparity parameter during training = 256 with the mask applied 71 | * 3 downsampling (1/8 resolution) and 3 refinement layers 72 | * Batch size of 1 73 | * Trained for a maximum of 25 epochs (lowest validation loss was at epoch 21) 74 | * RMSProp with a learning rate of 2.54e-4 75 | * Exponential LR scheduler with gamma=0.9
76 | * Maximum image side length of 625 with aspect ratio preserving resizing 77 | * Validation EPE of 1.543 for all pixels (including >256). 78 | 79 | Train and validation loss curves for the KeystoneDepth training run: 80 | 81 | KeystoneDepth_train_loss 82 | KeystoneDepth_val_loss 83 | 84 | On a GTX 1070, wall clock time was about 5-6 hours per epoch, or ~6-7 days of wall clock time to train. 85 | 86 | Older model checkpoint trained on Sceneflow corresponding with this [commit](https://github.com/andrewlstewart/StereoNet_PyTorch/tree/9c0260f270547d8001e9d637cf3a94658f805bae): https://www.dropbox.com/s/9gpjfe3r1rfch02/epoch%3D20-step%3D744533.ckpt?dl=0 87 | 88 | * Model was trained on RGB images and has in_channels=3 89 | * Max disparity parameter during training = 256 with the mask applied 90 | * 3 downsampling (1/8 resolution) and 3 refinement layers 91 | * Validation EPE of 3.93 for all pixels (including >256). 92 | 93 | ## Notes 94 | 95 | Implementation of the StereoNet network to compute a disparity map using stereo images. 96 | 97 | This project was implemented using PyTorch Lightning + Hydra as a learning exercise in stereo networks, PyTorch, PyTorch Lightning, and Hydra. Feel free to make any comments or recommendations for better coding practice. 98 | 99 | Currently implemented: 100 | 101 | * Downsampling feature network with `k_downsampling_layers` 102 | * Cost volume filtering 103 | * When training, both a left *and* a right cost volume are computed, with the final loss being the mean of the left and right disparity losses against ground truth. 104 | * Hierarchical refinement with cascading `k_refinement_layers` 105 | * Robust loss function [A General and Adaptive Robust Loss Function, Barron (2019)](https://arxiv.org/abs/1701.03077) 106 | 107 | Two repos were relied on heavily to inform the network (along with the actual paper): 108 | 109 | Original paper: https://arxiv.org/abs/1807.08865 110 | 111 | X-StereoLab: https://github.com/meteorshowers/X-StereoLab/blob/9ae8c1413307e7df91b14a7f31e8a95f9e5754f9/disparity/models/stereonet_disp.py 112 | 113 | ZhiXuanLi: https://github.com/zhixuanli/StereoNet/blob/f5576689e66e8370b78d9646c00b7e7772db0394/models/stereonet.py 114 | 115 | I believe ZhiXuanLi's repo follows the paper best up until line 107 (note that their CostVolume computation is incorrect): 116 | https://github.com/zhixuanli/StereoNet/issues/12#issuecomment-508327106 117 | 118 | X-StereoLab is good up until line 180. X-StereoLab returns the upsampled and refined disparities independently and doesn't perform the final ReLU. 119 | 120 | I believe the implementation that I have written takes the best of both repos and follows the paper most closely. 121 | 122 | Notably, the argmin'd disparity is computed prior to the bilinear interpolation (this follows X-StereoLab but not ZhiXuanLi; the latter does it in the reverse order). 123 | 124 | Further, neither repo had a cascade of refinement networks and neither repo trained on both the left *and* right disparities. I believe my repo has both of these correctly implemented. 125 | 126 | The paper clearly states they use (many) batch norm layers while simultaneously using a batch size of 1, which is interesting. I naively tried training on random 50% crops (the same crop applied to the left/right images and disparities) so that I could get more samples into a batch, but I think I was losing too many features so the EPE was consistently high. Currently, training uses a single sample (left/right images and left/right disparity); a minimal sketch of the paired random crop is shown below.
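The sketch below is illustrative only (it is not code from this repo) and assumes the `[C, H, W]` stacked left/right/disparity layout used in `src/stereonet/datasets.py`: a single crop window is sampled and then applied to the whole stack so that the images and disparities stay aligned.

```
import torch
import torchvision.transforms as T
import torchvision.transforms.functional as TF

def random_paired_crop(stack: torch.Tensor, scale: float = 0.5) -> torch.Tensor:
    # stack: [C, H, W] with the left image, right image, and both disparities concatenated on C
    height, width = stack.shape[-2:]
    top, left, h, w = T.RandomCrop.get_params(stack, output_size=(int(height * scale), int(width * scale)))
    return TF.crop(stack, top, left, h, w)  # one window for every channel, so the pair stays aligned
```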
I still needed to crop down to 513x912 images in order to not run into GPU memory issues. 127 | -------------------------------------------------------------------------------- /hydra_conf/config.example.yaml: -------------------------------------------------------------------------------- 1 | hydra: 2 | run: 3 | dir: ./logging/${now:%Y-%m-%d}/${now:%H-%M-%S}/hydra 4 | 5 | stereonet_config: 6 | _target_: stereonet.types.StereoNetConfig 7 | global_settings: 8 | _target_: stereonet.types.GlobalSettings 9 | devices: 10 | _target_: stereonet.types.Devices 11 | num: 1 # type=int, help="Number of devices to use during training" 12 | type: "gpu" # type=str, help="Type of device to use during training" 13 | 14 | logging: 15 | _target_: stereonet.types.Logging 16 | lightning_log_root: "./logging/${now:%Y-%m-%d}/${now:%H-%M-%S}" # type=str, help="Root path to output logs files to" 17 | 18 | model: 19 | # _target_: stereonet.types.CheckpointModel 20 | # model_checkpoint_path: "E:\\StereoNet_PyTorch_Refactor\\logging\\2023-04-10\\10-31-22\\lightning_logs\\version_0\\checkpoints\\epoch=25-step=822978.ckpt" 21 | _target_: stereonet.types.StereoNetModel 22 | in_channels: 1 23 | k_downsampling_layers: 3 24 | k_refinement_layers: 3 25 | candidate_disparities: 256 26 | 27 | training: 28 | _target_: stereonet.types.Training 29 | fast_dev_run: false # type=bool, help="Whether or not to run a fast dev run of the training loop" 30 | random_seed: null # type=Optional[int], help="Seed to set the torch random generator for reproducible training runs" 31 | deterministic: null # type=bool, help="Whether or not to set the PyTorch Lightning Trainer to deterministic mode" 32 | mask: true # type=bool, help="Whether or not to compute the loss for disparities larger than the model's max_disparities" 33 | min_epochs: 15 # type=int, help="Number of iterations of training over the full train dataset" 34 | max_epochs: 25 # type=int, help="Number of iterations of training over the full train dataset" 35 | optimizer_partial: 36 | _target_: torch.optim.RMSprop 37 | _partial_: true 38 | lr: 2.54e-4 # type=float, help="Starting learning rate" 39 | weight_decay: 0.0001 40 | scheduler_partial: 41 | _target_: torch.optim.lr_scheduler.ExponentialLR 42 | _partial_: true 43 | gamma: 0.9 44 | loader: 45 | _target_: stereonet.types.Loader 46 | batch_size: 1 # type=int, help='Number of examples for data loading.' 47 | data: 48 | # - _target_: stereonet.types.SceneflowData 49 | # root_path: "E:\\Sceneflow" # type=str, help='Path to the root of the Sceneflow depth files' 50 | # transforms: 51 | # - name: rescale 52 | # - name: center_crop 53 | # properties: 54 | # scale: 0.925 # type=float, help='Center crop percentage to decrease GPU memory' 55 | - _target_: stereonet.types.KeystoneDepthData 56 | root_path: "E:\\Keystone\\projects\\grail\\slowglass\\2-BBox\\annotation_results" # type=str, help='Path to the root of the Keystone depth files' 57 | split_ratio: 0.85 # type=float, help='Train:Test split ratio.' 
58 | max_size: 59 | - 625 60 | - 625 61 | transforms: 62 | - _target_: torchvision.transforms.Normalize 63 | mean: 64 | - 111.5684 65 | - 113.6528 66 | - 4.3221 67 | - 4.2296 68 | std: 69 | - 61.9625 70 | - 62.0313 71 | - 10.8142 72 | - 9.9528 73 | debug: 74 | _target_: stereonet.types.DataDebug 75 | enabled: False # type=bool, help="Whether or not to use the following debugging flags" 76 | limit_train_batches: 10_000 # type=int, help="Debugging, how many batches to train on" 77 | 78 | validation: 79 | _target_: stereonet.types.Validation 80 | loader: 81 | _target_: stereonet.types.Loader 82 | batch_size: 1 # type=int, help='Number of examples for data loading.' 83 | data: 84 | - _target_: stereonet.types.KeystoneDepthData 85 | root_path: "E:\\Keystone\\projects\\grail\\slowglass\\2-BBox\\annotation_results" # type=str, help='Path to the root of the Keystone depth files' 86 | split_ratio: 0.9 # type=float, help='Train:Test split ratio.' 87 | max_size: 88 | - 625 89 | - 625 90 | transforms: 91 | - _target_: torchvision.transforms.Normalize 92 | mean: 93 | - 111.5684 94 | - 113.6528 95 | - 4.3221 96 | - 4.2296 97 | std: 98 | - 61.9625 99 | - 62.0313 100 | - 10.8142 101 | - 9.9528 -------------------------------------------------------------------------------- /hydra_conf/training_paths.example.txt: -------------------------------------------------------------------------------- 1 | cropped\cropped_rectified_LR\04733L.png,cropped\cropped_rectified_LR\04733R.png,processed\rectified\disp_info_LR\04733L.exr,processed\rectified\disp_info_LR\04733R.exr 2 | cropped\cropped_rectified_LR\04771L.png,cropped\cropped_rectified_LR\04771R.png,processed\rectified\disp_info_LR\04771L.exr,processed\rectified\disp_info_LR\04771R.exr 3 | cropped\cropped_rectified_LR\04775L.png,cropped\cropped_rectified_LR\04775R.png,processed\rectified\disp_info_LR\04775L.exr,processed\rectified\disp_info_LR\04775R.exr -------------------------------------------------------------------------------- /hydra_conf/validation_paths.example.txt: -------------------------------------------------------------------------------- 1 | cropped\cropped_full_LR\51498L.png,cropped\cropped_full_LR\51498R.png,processed\rectified\disp_info_LR\51498L.exr,processed\rectified\disp_info_LR\51498R.exr 2 | cropped\cropped_full_LR\33563L.png,cropped\cropped_full_LR\33563R.png,processed\rectified\disp_info_LR\33563L.exr,processed\rectified\disp_info_LR\33563R.exr 3 | cropped\cropped_full_LR\33567L.png,cropped\cropped_full_LR\33567R.png,processed\rectified\disp_info_LR\33567L.exr,processed\rectified\disp_info_LR\33567R.exr -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [build-system] 2 | requires = ["setuptools>=67.6.1", "wheel"] 3 | build-backend = "setuptools.build_meta" 4 | 5 | [tool.pytest.ini_options] 6 | addopts = "--cov=stereonet" 7 | testpaths = [ 8 | "tests", 9 | ] 10 | 11 | [tool.mypy] 12 | mypy_path = "src" 13 | check_untyped_defs = true 14 | disallow_any_generics = true 15 | disallow_untyped_defs = true 16 | ignore_missing_imports = true 17 | no_implicit_optional = true 18 | show_error_codes = true 19 | strict_equality = true 20 | warn_redundant_casts = true 21 | warn_return_any = true 22 | warn_unreachable = true 23 | warn_unused_configs = true 24 | no_implicit_reexport = false -------------------------------------------------------------------------------- /requirements.txt: 
-------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | torchaudio==2.0.1 4 | tensorboard==2.12.1 5 | lightning==2.0.1 6 | scikit_image==0.20.0 7 | setuptools==67.6.1 8 | hydra-core==1.3.2 9 | opencv-python>=4.7.0.72 10 | matplotlib>=3.7.1 -------------------------------------------------------------------------------- /requirements_dev.txt: -------------------------------------------------------------------------------- 1 | torch==2.0.0 2 | torchvision==0.15.1 3 | torchaudio==2.0.1 4 | tensorboard==2.12.1 5 | lightning==2.0.1 6 | scikit_image==0.20.0 7 | matplotlib==3.7.1 8 | setuptools==67.6.1 9 | hydra-core==1.3.2 10 | opencv-python==4.7.0.72 11 | flake8==6.0.0 12 | tox==4.4.8 13 | pytest==7.2.2 14 | pytest-cov==4.0.0 15 | mypy==1.1.1 -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | name = stereonet 3 | description = PyTorch Lightning implementation of StereoNet 4 | author = Andrew Stewart 5 | license = The Unlicense 6 | license_file = LICENSE 7 | platforms = unix, linux, osx, cygwin, win32 8 | classifiers = 9 | Programming Language :: Python :: 3 10 | Programming Language :: Python :: 3 :: Only 11 | Programming Language :: Python :: 3.8 12 | 13 | [options] 14 | packages = 15 | stereonet 16 | install_requires = 17 | torch>=2.0.0 18 | torchvision>=0.15.1 19 | torchaudio>=2.0.1 20 | tensorboard>=2.12.1 21 | lightning>=2.0.1 22 | scikit_image>=0.20.0 23 | setuptools>=67.6.1 24 | hydra-core>=1.3.2 25 | opencv-python>=4.7.0.72 26 | matplotlib>=3.7.1 27 | python_requires = >=3.8 28 | package_dir = 29 | =src 30 | zip_safe = no 31 | 32 | [options.extras_require] 33 | testing = 34 | pytest>=7.2.2 35 | pytest-cov>=4.0.0 36 | mypy>=1.1.1 37 | flake8>=6.0.0 38 | tox>=4.4.8 39 | 40 | [options.package_data] 41 | stereonet = py.typed 42 | 43 | [flake8] 44 | max-line-length = 240 -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | if __name__ == "__main__": 4 | setup() -------------------------------------------------------------------------------- /src/stereonet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewlstewart/StereoNet_PyTorch/3cb3a0215a3c4690d74d93ae8973feaaf9f65bc1/src/stereonet/__init__.py -------------------------------------------------------------------------------- /src/stereonet/datasets.py: -------------------------------------------------------------------------------- 1 | """ 2 | """ 3 | 4 | from typing import Optional, List, Set, Tuple, Union, Any 5 | from pathlib import Path 6 | import os 7 | 8 | import hydra 9 | from hydra.core.hydra_config import HydraConfig 10 | import numpy as np 11 | import torch 12 | from torch.utils.data import Dataset, DataLoader 13 | import torchvision.transforms as transforms 14 | 15 | import stereonet.types as stt 16 | import stereonet.utils_io as stu 17 | 18 | 19 | RNG = np.random.default_rng() 20 | 21 | 22 | class SceneflowDataset(Dataset[torch.Tensor]): 23 | """ 24 | Sceneflow dataset composed of FlyingThings3D, Driving, and Monkaa 25 | https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html 26 | 27 | Download the RGB (cleanpass) PNG image and the Disparity files 28 | 
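    Assumed on-disk layout (inferred from get_paths below, not an official spec): images under a
    'frames_cleanpass' tree with 'left'/'right' sub-folders, and matching '.pfm' disparities under the
    corresponding 'disparity' tree.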
29 | The train set includes FlyingThings3D Train folder and all files in Driving and Monkaa folders 30 | The test set includes FlyingThings3D Test folder 31 | """ 32 | 33 | def __init__(self, 34 | root_path: Path, 35 | transforms: Optional[List[torch.nn.Module]] = None, 36 | string_exclude: Optional[str] = None, 37 | string_include: Optional[str] = None 38 | ): 39 | self.root_path = root_path 40 | self.string_exclude = string_exclude 41 | self.string_include = string_include 42 | 43 | self.transforms = transforms 44 | 45 | self.left_image_path, self.right_image_path, self.left_disp_path, self.right_disp_path = self.get_paths(self.root_path, self.string_include, self.string_exclude) 46 | 47 | def __len__(self) -> int: 48 | return len(self.left_image_path) 49 | 50 | def __getitem__(self, index: int) -> torch.Tensor: 51 | left = stu.image_loader(self.left_image_path[index]) 52 | right = stu.image_loader(self.right_image_path[index]) 53 | 54 | disp_left, _ = stu.pfm_loader(self.left_disp_path[index]) 55 | disp_left = disp_left[..., np.newaxis] 56 | disp_left = np.ascontiguousarray(disp_left) 57 | 58 | disp_right, _ = stu.pfm_loader(self.right_disp_path[index]) 59 | disp_right = disp_right[..., np.newaxis] 60 | disp_right = np.ascontiguousarray(disp_right) 61 | 62 | assert left.dtype == np.uint8 63 | assert right.dtype == np.uint8 64 | assert disp_left.dtype != np.uint8 65 | assert disp_right.dtype != np.uint8 66 | 67 | # ToTensor works differently for dtypes, for uint8 it scales to [0,1], for float32 it does not scale 68 | tensorer = transforms.ToTensor() 69 | stack = torch.concatenate(list(map(tensorer, (left, right, disp_left, disp_right))), dim=0) # C, H, W 70 | 71 | if self.transforms is not None: 72 | for transform in self.transforms: 73 | stack = transform(stack) 74 | 75 | return stack 76 | 77 | @staticmethod 78 | def get_paths(root_path: Path, string_include: Optional[str] = None, string_exclude: Optional[str] = None) -> Tuple[List[Path], List[Path], List[Path], List[Path]]: 79 | """ 80 | string_exclude: If this string appears in the parent path of an image, don't add them to the dataset (ie. 'TEST' will exclude any path with 'TEST' in Path.parts) 81 | string_include: If this string DOES NOT appear in the parent path of an image, don't add them to the dataset (ie. 'TEST' will require 'TEST' to be in the Path.parts) 82 | if shuffle is None, don't shuffle, else shuffle. 83 | """ 84 | 85 | left_image_path = [] 86 | right_image_path = [] 87 | left_disp_path = [] 88 | right_disp_path = [] 89 | 90 | # For each left image, do some path manipulation to find the corresponding right 91 | # image and left disparity. 
92 | for path in root_path.rglob('*.png'): 93 | if 'left' not in path.parts: 94 | continue 95 | 96 | if string_exclude and string_exclude in path.parts: 97 | continue 98 | if string_include and string_include not in path.parts: 99 | continue 100 | 101 | r_path = Path("\\".join(['right' if 'left' in part else part for part in path.parts])) 102 | dl_path = Path("\\".join([f'{part.replace("frames_cleanpass","")}disparity' if 'frames_cleanpass' in part else part for part in path.parts])).with_suffix('.pfm') 103 | dr_path = Path("\\".join([f'{part.replace("frames_cleanpass","")}disparity' if 'frames_cleanpass' in part else part for part in r_path.parts])).with_suffix('.pfm') 104 | # assert r_path.exists() 105 | # assert d_path.exists() 106 | 107 | if not r_path.exists() or not dl_path.exists(): 108 | continue 109 | 110 | left_image_path.append(path) 111 | right_image_path.append(r_path) 112 | left_disp_path.append(dl_path) 113 | right_disp_path.append(dr_path) 114 | 115 | return (left_image_path, right_image_path, left_disp_path, right_disp_path) 116 | 117 | 118 | class KeystoneDataset(Dataset[torch.Tensor]): 119 | """ 120 | https://keystonedepth.cs.washington.edu/download 121 | """ 122 | 123 | def __init__(self, 124 | root_path: str, 125 | image_paths: str, 126 | transforms: Optional[List[torch.nn.Module]] = None, 127 | max_size: Optional[List[int]] = None, 128 | ): 129 | 130 | self.root_path = root_path 131 | 132 | self.transforms = transforms 133 | 134 | self.left_image_path, self.right_image_path, self.left_disp_path, self.right_disp_path = [], [], [], [] 135 | with open(image_paths, 'r') as f: 136 | for line in f: 137 | left, right, disp_left, disp_right = line.rstrip().split(',') 138 | self.left_image_path.append(left) 139 | self.right_image_path.append(right) 140 | self.left_disp_path.append(disp_left) 141 | self.right_disp_path.append(disp_right) 142 | 143 | self.max_size = max_size 144 | 145 | def __len__(self) -> int: 146 | return len(self.left_image_path) 147 | 148 | def __getitem__(self, index: int) -> torch.Tensor: 149 | left = stu.image_loader(os.path.join(self.root_path, self.left_image_path[index])) 150 | right = stu.image_loader(os.path.join(self.root_path, self.right_image_path[index])) 151 | 152 | disp_left = stu.exr_loader(os.path.join(self.root_path, self.left_disp_path[index])) 153 | disp_left = np.ascontiguousarray(disp_left) 154 | 155 | disp_right = stu.exr_loader(os.path.join(self.root_path, self.right_disp_path[index])) 156 | disp_right = np.ascontiguousarray(disp_right) 157 | 158 | assert left.dtype == np.uint8 159 | assert right.dtype == np.uint8 160 | assert disp_left.dtype == np.float32 161 | assert disp_right.dtype == np.float32 162 | 163 | min_height = min(left.shape[0], right.shape[0], disp_left.shape[0], disp_right.shape[0]) 164 | min_width = min(left.shape[1], right.shape[1], disp_left.shape[1], disp_right.shape[1]) 165 | 166 | tensored = [torch.permute(torch.from_numpy(array).to(torch.float32), (2, 0, 1)) for array in (left, right, disp_left, disp_right)] 167 | 168 | # Not sure if this is the best way to do this... 
169 | # Keystone dataset sizes between left/right/disp_left/disp_right are inconsistent 170 | cropper = transforms.CenterCrop((min_height, min_width)) 171 | stack = torch.concatenate(list(map(cropper, tensored)), dim=0) # C, H, W 172 | 173 | if self.transforms is not None: 174 | for transform in self.transforms: 175 | stack = transform(stack) 176 | 177 | height, width = stack.size()[-2:] 178 | 179 | # preserve aspect ratio 180 | if self.max_size and (height > self.max_size[0] or width > self.max_size[1]): 181 | original_aspect_ratio = width / height 182 | new_height = int(min(self.max_size[0], self.max_size[0] / original_aspect_ratio)) 183 | new_width = int(min(self.max_size[1], original_aspect_ratio * self.max_size[1])) 184 | resizer = transforms.Resize(size=(new_height, new_width), antialias=True) 185 | stack = resizer(stack) 186 | 187 | return stack 188 | 189 | @staticmethod 190 | def get_paths(root_path: str, image_extensions: Set[str]) -> Tuple[List[str], List[str], List[str], List[str]]: 191 | """ 192 | if shuffle is None, don't shuffle, else shuffle. 193 | """ 194 | 195 | left_image_path = [] 196 | right_image_path = [] 197 | left_disp_path = [] 198 | right_disp_path = [] 199 | 200 | # For each left image, do some path manipulation to find the corresponding right 201 | # image and left/right disparity. 202 | for root, dirs, files in os.walk(root_path): 203 | for file_ in files: 204 | name, ext = os.path.splitext(file_) 205 | if ext not in image_extensions: 206 | continue 207 | 208 | if name[-1] != 'L': 209 | continue 210 | 211 | l_path = os.path.join(root, file_) 212 | r_path = os.path.join(root, name[:-1] + 'R' + ext) 213 | assert os.path.exists(r_path) 214 | 215 | parent_path = os.path.dirname(os.path.dirname(root)) # idempotent 216 | parent_path = os.path.join(os.path.join(os.path.join(parent_path, 'processed'), 'rectified'), 'disp_info_LR') 217 | dl_path = os.path.join(parent_path, name[:-1] + 'L' + '.exr') 218 | dr_path = os.path.join(parent_path, name[:-1] + 'R' + '.exr') 219 | 220 | assert os.path.exists(dl_path) 221 | assert os.path.exists(dr_path) 222 | 223 | # if not r_path.exists() or not dl_path.exists(): 224 | # continue 225 | 226 | left_image_path.append(l_path) 227 | right_image_path.append(r_path) 228 | left_disp_path.append(dl_path) 229 | right_disp_path.append(dr_path) 230 | 231 | return (left_image_path, right_image_path, left_disp_path, right_disp_path) 232 | 233 | 234 | def construct_sceneflow_dataset(cfg: stt.SceneflowData, is_training: bool) -> SceneflowDataset: 235 | dataset = SceneflowDataset(root_path=Path(cfg.root_path), 236 | transforms=cfg.transforms, 237 | string_exclude='TEST' if is_training else None, 238 | string_include=None if is_training else 'TEST', 239 | ) 240 | return dataset 241 | 242 | 243 | def construct_keystone_dataset(cfg: stt.KeystoneDepthData, is_training: bool) -> KeystoneDataset: 244 | root_path = Path(HydraConfig.get().runtime.cwd) / 'hydra_conf' 245 | 246 | training_paths = root_path / 'training_paths.txt' 247 | validation_paths = root_path / 'validation_paths.txt' 248 | 249 | if (is_training and not training_paths.exists()) or (not is_training and not validation_paths.exists()): 250 | assert not training_paths.exists() and not validation_paths.exists(), "Either both training and validation paths should exist, or neither should exist." 
251 | left_image_path, right_image_path, left_disp_path, right_disp_path = KeystoneDataset.get_paths(cfg.root_path, image_extensions={'.png'}) 252 | train_indices = set(RNG.choice(len(left_image_path), size=int(cfg.split_ratio*len(left_image_path)), replace=False)) 253 | val_indices = set(range(len(left_image_path))) - train_indices 254 | 255 | for path, indices in [(training_paths, train_indices), (validation_paths, val_indices)]: 256 | with open(path, 'w') as f: 257 | root = Path(left_image_path[0]).parents[2] 258 | rows = [f'{Path(left_image_path[i]).relative_to(root)},{Path(right_image_path[i]).relative_to(root)},{Path(left_disp_path[i]).relative_to(root)},{Path(right_disp_path[i]).relative_to(root)}' 259 | for i in indices] 260 | f.write("\n".join(rows)) 261 | 262 | dataset = KeystoneDataset(root_path=cfg.root_path, 263 | image_paths=str(training_paths) if is_training else str(validation_paths), 264 | transforms=cfg.transforms, 265 | max_size=cfg.max_size 266 | ) 267 | return dataset 268 | 269 | 270 | def construct_dataloaders(data_config: Union[stt.Training, stt.Validation], 271 | is_training: bool, 272 | **kwargs: Any 273 | ) -> DataLoader[torch.Tensor]: 274 | for datum_config in data_config.data: 275 | if isinstance(datum_config, stt.KeystoneDepthData): 276 | dataset: Dataset[torch.Tensor] = construct_keystone_dataset(datum_config, is_training) 277 | elif isinstance(datum_config, stt.SceneflowData): 278 | dataset = construct_sceneflow_dataset(datum_config, is_training) 279 | 280 | return DataLoader(dataset, batch_size=data_config.loader.batch_size, **kwargs) 281 | 282 | 283 | def get_normalization_values(dataloader: DataLoader[torch.Tensor]) -> Tuple[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]: 284 | """ 285 | Get the min/max values and mean/std values for each channel in the training dataset. 
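    Per-batch means of X and X**2 are accumulated and combined at the end as std = sqrt(E[X**2] - E[X]**2).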
286 | """ 287 | mins: List[torch.Tensor] = [] 288 | maxs: List[torch.Tensor] = [] 289 | means: List[torch.Tensor] = [] 290 | squared_means: List[torch.Tensor] = [] 291 | 292 | for data in dataloader: 293 | mins.append(data.min(0)[0].min(1)[0].min(1)[0]) 294 | maxs.append(data.max(0)[0].max(1)[0].max(1)[0]) 295 | means.append(data.mean(dim=(0, 2, 3))) 296 | squared_means.append((data**2).mean(dim=(0, 2, 3))) 297 | 298 | all_mins = torch.vstack(mins).min(dim=0)[0] 299 | all_maxs = torch.vstack(maxs).max(dim=0)[0] 300 | 301 | all_mean = torch.vstack(means).mean(dim=0) 302 | all_std = torch.sqrt(torch.vstack(squared_means).mean(dim=0) - all_mean**2) 303 | 304 | return (all_mins, all_maxs), (all_mean, all_std) 305 | 306 | 307 | @hydra.main(version_base=None, config_name="config") 308 | def main(cfg: stt.StereoNetConfig) -> int: 309 | config: stt.StereoNetConfig = hydra.utils.instantiate(cfg, _convert_="all")['stereonet_config'] 310 | 311 | if config.training is None: 312 | raise Exception("Need to provide training arguments to get normalization values.") 313 | 314 | # Get training dataset 315 | train_loader = construct_dataloaders(data_config=config.training, 316 | is_training=True, 317 | shuffle=False, num_workers=8, drop_last=False) 318 | (mins, maxs), (mean, std) = get_normalization_values(train_loader) 319 | 320 | print(f"{mins=}") 321 | print(f"{maxs=}") 322 | 323 | print(f"{mean=}") 324 | print(f"{std=}") 325 | 326 | return 0 327 | 328 | 329 | if __name__ == "__main__": 330 | raise SystemExit(main()) 331 | -------------------------------------------------------------------------------- /src/stereonet/model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Classes and functions to instantiate, and train, a StereoNet model (https://arxiv.org/abs/1807.08865). 3 | 4 | StereoNet model is decomposed into a feature extractor, cost volume creation, and a cascade of refiner networks. 5 | 6 | Loss function is the Robust Loss function (https://arxiv.org/abs/1701.03077) 7 | """ 8 | 9 | from typing import Tuple, List, Optional, Dict, Any, Callable 10 | from collections import OrderedDict 11 | 12 | import numpy as np 13 | import torch 14 | from torch import nn 15 | import torch.nn.functional as F 16 | import lightning.pytorch as pl 17 | from lightning.pytorch import loggers as pl_loggers 18 | 19 | import stereonet.utils as utils 20 | 21 | 22 | class StereoNet(pl.LightningModule): 23 | """ 24 | StereoNet model. During training, takes in a torch.Tensor dimensions [batch, left/right/disp_left/disp_right channels, height, width]. 25 | At inference, (ie. calling the forward method), only the predicted left disparity is returned. 26 | 27 | Trained with RMSProp + Exponentially decaying learning rate scheduler. 
28 | """ 29 | 30 | def __init__(self, in_channels: int, 31 | k_downsampling_layers: int = 3, 32 | k_refinement_layers: int = 3, 33 | candidate_disparities: int = 256, 34 | feature_extractor_filters: int = 32, 35 | cost_volumizer_filters: int = 32, 36 | mask: bool = True, 37 | optimizer_partial: Optional[Callable[[torch.nn.Module], torch.optim.Optimizer]] = None, 38 | scheduler_partial: Optional[Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler]] = None) -> None: 39 | super().__init__() 40 | self.save_hyperparameters() 41 | 42 | self.in_channels = in_channels 43 | self.k_downsampling_layers = k_downsampling_layers 44 | self.k_refinement_layers = k_refinement_layers 45 | self.candidate_disparities = candidate_disparities 46 | self.mask = mask 47 | 48 | self.feature_extractor_filters = feature_extractor_filters 49 | self.cost_volumizer_filters = cost_volumizer_filters 50 | 51 | self._max_downsampled_disps = (candidate_disparities+1) // (2**k_downsampling_layers) 52 | 53 | # Feature network 54 | self.feature_extractor = FeatureExtractor(in_channels=in_channels, out_channels=self.feature_extractor_filters, k_downsampling_layers=self.k_downsampling_layers) 55 | 56 | # Cost volume 57 | self.cost_volumizer = CostVolume(in_channels=self.feature_extractor_filters, out_channels=self.cost_volumizer_filters, max_downsampled_disps=self._max_downsampled_disps) 58 | 59 | # Hierarchical Refinement: Edge-Aware Upsampling 60 | self.refiners = nn.ModuleList() 61 | for _ in range(self.k_refinement_layers): 62 | self.refiners.append(Refinement(in_channels=in_channels+1)) 63 | 64 | self.optimizer_partial = optimizer_partial 65 | self.scheduler_partial = scheduler_partial 66 | 67 | def forward_pyramid(self, sample: torch.Tensor, side: str = 'left') -> List[torch.Tensor]: 68 | """ 69 | This is the heart of the forward pass. Given a torch.Tensor of shape [Batch, left/right, Height, Width], perform the feature extraction, cost volume estimation, cascading 70 | refiners to return a list of the disparities. First entry of the returned list is the lowest resolution while the last is the full resolution disparity. 71 | 72 | For clarity, the zeroth element of the first dimension is the left image and the first element of the first dimension is the right image. 73 | 74 | The idea with reference/shifting is that when computing the cost volume, one image is effectively held stationary while the other image 75 | sweeps across. If the provided tuple (x) is (left/right) stereo pair with the argument side='left', then the stationary image will be the left 76 | image and the sweeping image will be the right image and vice versa. 77 | """ 78 | if side == 'left': 79 | reference = sample[:, :self.in_channels, ...] 80 | shifting = sample[:, self.in_channels:self.in_channels*2, ...] 81 | elif side == 'right': 82 | reference = sample[:, self.in_channels:self.in_channels*2, ...] 83 | shifting = sample[:, :self.in_channels, ...] 
84 | 85 | reference_embedding = self.feature_extractor(reference) 86 | shifting_embedding = self.feature_extractor(shifting) 87 | 88 | cost = self.cost_volumizer((reference_embedding, shifting_embedding), side=side) 89 | 90 | disparity_pyramid = [soft_argmin(cost, self.candidate_disparities)] 91 | 92 | for idx, refiner in enumerate(self.refiners, start=1): 93 | scale = (2**self.k_refinement_layers) / (2**idx) 94 | new_h, new_w = int(reference.size()[2]//scale), int(reference.size()[3]//scale) 95 | reference_rescaled = F.interpolate(reference, [new_h, new_w], mode='bilinear', align_corners=True) 96 | disparity_low_rescaled = F.interpolate(disparity_pyramid[-1], [new_h, new_w], mode='bilinear', align_corners=True) 97 | refined_disparity = F.relu(refiner(torch.cat((reference_rescaled, disparity_low_rescaled), dim=1)) + disparity_low_rescaled) 98 | disparity_pyramid.append(refined_disparity) 99 | 100 | return disparity_pyramid 101 | 102 | def forward(self, batch: torch.Tensor) -> torch.Tensor: 103 | """ 104 | Do the forward pass using forward_pyramid (for the left disparity map) and return only the full resolution map. 105 | """ 106 | disparities = self.forward_pyramid(batch, side='left') 107 | return disparities[-1] # Ultimately, only output the last refined disparity 108 | 109 | def training_step(self, batch: torch.Tensor, _: int) -> torch.Tensor: 110 | """ 111 | Compute the disparities for both the left and right volumes then compute the loss for each. Finally take the mean between the two losses and 112 | return that as the final loss. 113 | 114 | Log at each step the Robust Loss and log the L1 loss (End-point-error) at each epoch. 115 | """ 116 | 117 | height, width = batch.size()[-2:] 118 | 119 | # Non-uniform because the sizes of each of the list entries returned from the forward_pyramid aren't the same 120 | disp_pred_left_nonuniform = self.forward_pyramid(batch, side='left') 121 | disp_pred_right_nonuniform = self.forward_pyramid(batch, side='right') 122 | 123 | for idx, (disparity_left, disparity_right) in enumerate(zip(disp_pred_left_nonuniform, disp_pred_right_nonuniform)): 124 | disp_pred_left_nonuniform[idx] = F.interpolate(disparity_left, [height, width], mode='bilinear', align_corners=True) 125 | disp_pred_right_nonuniform[idx] = F.interpolate(disparity_right, [height, width], mode='bilinear', align_corners=True) 126 | 127 | disp_pred_left = torch.stack(disp_pred_left_nonuniform, dim=0) 128 | disp_pred_right = torch.stack(disp_pred_right_nonuniform, dim=0) 129 | 130 | def _tiler(tensor: torch.Tensor, matching_size: Optional[List[int]] = None) -> torch.Tensor: 131 | if matching_size is None: 132 | matching_size = [disp_pred_left.size()[0], 1, 1, 1, 1] 133 | return tensor.tile(matching_size) 134 | 135 | disp_gt_left = _tiler(batch[:, -2, ...]) 136 | disp_gt_right = _tiler(batch[:, -1, ...]) 137 | 138 | if self.mask: 139 | left_mask = (disp_gt_left < self.candidate_disparities).detach() 140 | right_mask = (disp_gt_right < self.candidate_disparities).detach() 141 | 142 | loss_left = torch.mean(robust_loss(disp_gt_left[left_mask] - disp_pred_left[left_mask], alpha=1, c=2)) 143 | loss_right = torch.mean(robust_loss(disp_gt_right[right_mask] - disp_pred_right[right_mask], alpha=1, c=2)) 144 | else: 145 | loss_left = torch.mean(robust_loss(disp_gt_left - disp_pred_left, alpha=1, c=2)) 146 | loss_right = torch.mean(robust_loss(disp_gt_right - disp_pred_right, alpha=1, c=2)) 147 | 148 | loss = (loss_left + loss_right) / 2 149 | 150 | self.log("train_loss_step", loss, on_step=True, 
on_epoch=False, prog_bar=True, logger=True) 151 | self.log("train_loss_epoch", F.l1_loss(disp_pred_left[-1], disp_gt_left[-1]), on_step=False, on_epoch=True, prog_bar=False, logger=True) 152 | return loss 153 | 154 | def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None: 155 | """ 156 | Compute the L1 loss (End-point-error) over the validation set for the left disparity map. 157 | 158 | Log a figure of the left/right RGB images and the grount truth disparity + predicted disparity to the logger. 159 | """ 160 | disp_pred = self(batch[:, :self.in_channels*2, ...]) 161 | disp_gt = batch[:, self.in_channels*2:self.in_channels*2+1, ...] 162 | 163 | loss = F.l1_loss(disp_pred, disp_gt) 164 | self.log("val_loss_epoch", loss, on_epoch=True, logger=True) 165 | if batch_idx == 0: 166 | fig = utils.plot_figure(batch[0, :self.in_channels, ...].detach().cpu(), 167 | batch[0, self.in_channels:self.in_channels*2, ...].detach().cpu(), 168 | batch[0, -2:-1, ...].detach().cpu(), 169 | disp_pred[0].detach().cpu()) 170 | fig.canvas.draw() 171 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 172 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 173 | data = np.moveaxis(data, 2, 0) 174 | 175 | tensorboard: pl_loggers.TensorBoardLogger = self.logger 176 | tensorboard.experiment.add_image("generated_images", data, self.current_epoch) 177 | 178 | def configure_optimizers(self) -> Dict[str, Any]: 179 | """ 180 | RMSProp optimizer + Exponentially decaying learning rate. 181 | Original authors trained with a batch size of 1 and a decaying learning rate. If we randomly crop down the image 182 | to, lets say, a rescale factor of 2, 1/2 width 1/2 height, (total reduction of 2**2) then each epoch will only train on 1/4 the number of 183 | image patches. Therefore, to keep the learning rate similar, delay the decay by the square of the rescale factor. 184 | """ 185 | if self.optimizer_partial is None: 186 | raise Exception("Need to provide optimizer arguments.") 187 | 188 | optimizer = self.optimizer_partial(self.parameters()) 189 | config: Dict[str, Any] = {'optimizer': optimizer} 190 | 191 | if self.scheduler_partial is not None: 192 | scheduler = self.scheduler_partial(optimizer) 193 | lr_dict = {"scheduler": scheduler, 194 | "interval": "epoch", 195 | "frequency": 1, 196 | "name": "ExponentialDecayLR"} 197 | config["lr_scheduler"] = lr_dict 198 | 199 | return config 200 | 201 | 202 | class FeatureExtractor(torch.nn.Module): 203 | """ 204 | Feature extractor network with 'K' downsampling layers. Refer to the original paper for full discussion. 
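    With the defaults used in this repo (out_channels=32, k_downsampling_layers=3), an input of shape
    [B, C, H, W] is mapped to features of shape [B, 32, H/8, W/8] (each of the K stride-2 convolutions halves
    the spatial resolution).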
205 | """ 206 | 207 | def __init__(self, in_channels: int, out_channels: int, k_downsampling_layers: int): 208 | super().__init__() 209 | self.k = k_downsampling_layers 210 | 211 | net: OrderedDict[str, nn.Module] = OrderedDict() 212 | 213 | for block_idx in range(self.k): 214 | net[f'segment_0_conv_{block_idx}'] = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=5, stride=2, padding=2) 215 | in_channels = out_channels 216 | 217 | for block_idx in range(6): 218 | net[f'segment_1_res_{block_idx}'] = ResBlock(in_channels=out_channels, out_channels=out_channels, kernel_size=3, padding=1) 219 | 220 | net['segment_2_conv_0'] = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1) 221 | 222 | self.net = nn.Sequential(net) 223 | 224 | def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name, missing-function-docstring 225 | x = self.net(x) 226 | return x 227 | 228 | 229 | class CostVolume(torch.nn.Module): 230 | """ 231 | Computes the cost volume and filters it using the 3D convolutional network. Refer to original paper for a full discussion. 232 | """ 233 | 234 | def __init__(self, in_channels: int, out_channels: int, max_downsampled_disps: int): 235 | super().__init__() 236 | 237 | self.in_channels = in_channels 238 | self.out_channels = out_channels 239 | 240 | self._max_downsampled_disps = max_downsampled_disps 241 | 242 | net: OrderedDict[str, nn.Module] = OrderedDict() 243 | 244 | for block_idx in range(4): 245 | net[f'segment_0_conv_{block_idx}'] = nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1) 246 | net[f'segment_0_bn_{block_idx}'] = nn.BatchNorm3d(num_features=out_channels) 247 | net[f'segment_0_act_{block_idx}'] = nn.LeakyReLU(negative_slope=0.2) # Not clear in paper if default or implied to be 0.2 like the rest 248 | 249 | in_channels = out_channels 250 | 251 | net['segment_1_conv_0'] = nn.Conv3d(in_channels=out_channels, out_channels=1, kernel_size=3, padding=1) 252 | 253 | self.net = nn.Sequential(net) 254 | 255 | def forward(self, x: Tuple[torch.Tensor, torch.Tensor], side: str = 'left') -> torch.Tensor: # pylint: disable=invalid-name 256 | """ 257 | The cost volume effectively holds one of the left/right images constant (albeit clipping) and computes the difference with a 258 | shifting (left/right) portion of the corresponding image. By default, this method holds the left image stationary and sweeps the right image. 259 | 260 | To compute the cost volume for holding the right image stationary and sweeping the left image, use side='right'. 261 | """ 262 | reference_embedding, target_embedding = x 263 | 264 | cost = compute_volume(reference_embedding, target_embedding, max_downsampled_disps=self._max_downsampled_disps, side=side) 265 | 266 | cost = self.net(cost) 267 | cost = torch.squeeze(cost, dim=1) 268 | 269 | return cost 270 | 271 | 272 | def compute_volume(reference_embedding: torch.Tensor, target_embedding: torch.Tensor, max_downsampled_disps: int, side: str = 'left') -> torch.Tensor: 273 | """ 274 | Refer to the doc string in CostVolume.forward. 
275 | Refer to https://github.com/meteorshowers/X-StereoLab/blob/9ae8c1413307e7df91b14a7f31e8a95f9e5754f9/disparity/models/stereonet_disp.py 276 | 277 | This difference based cost volume is also reflected in an implementation of the popular DispNetCorr: 278 | Line 81 https://github.com/wyf2017/DSMnet/blob/b61652dfb3ee84b996f0ad4055eaf527dc6b965f/models/util_conv.py 279 | """ 280 | batch, channel, height, width = reference_embedding.size() 281 | cost = torch.Tensor(batch, channel, max_downsampled_disps, height, width).zero_() 282 | cost = cost.type_as(reference_embedding) # PyTorch Lightning handles the devices 283 | cost[:, :, 0, :, :] = reference_embedding - target_embedding 284 | for idx in range(1, max_downsampled_disps): 285 | if side == 'left': 286 | cost[:, :, idx, :, idx:] = reference_embedding[:, :, :, idx:] - target_embedding[:, :, :, :-idx] 287 | if side == 'right': 288 | cost[:, :, idx, :, :-idx] = reference_embedding[:, :, :, :-idx] - target_embedding[:, :, :, idx:] 289 | cost = cost.contiguous() 290 | 291 | return cost 292 | 293 | 294 | class Refinement(torch.nn.Module): 295 | """ 296 | Several of these classes will be instantiated to perform the *cascading* refinement. Refer to the original paper for a full discussion. 297 | """ 298 | 299 | def __init__(self, in_channels: int = 3) -> None: 300 | super().__init__() 301 | 302 | dilations = [1, 2, 4, 8, 1, 1] 303 | 304 | net: OrderedDict[str, nn.Module] = OrderedDict() 305 | 306 | net['segment_0_conv_0'] = nn.Conv2d(in_channels=in_channels, out_channels=32, kernel_size=3, padding=1) 307 | 308 | for block_idx, dilation in enumerate(dilations): 309 | net[f'segment_1_res_{block_idx}'] = ResBlock(in_channels=32, out_channels=32, kernel_size=3, padding=dilation, dilation=dilation) 310 | 311 | net['segment_2_conv_0'] = nn.Conv2d(in_channels=32, out_channels=1, kernel_size=3, padding=1) 312 | 313 | self.net = nn.Sequential(net) 314 | 315 | def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name, missing-function-docstring 316 | x = self.net(x) 317 | return x 318 | 319 | 320 | class ResBlock(torch.nn.Module): 321 | """ 322 | Just a note, in the original paper, there is no discussion about padding; however, both the ZhiXuanLi and the X-StereoLab implementation using padding. 323 | This does make sense to maintain the image size after the feature extraction has occured. 324 | 325 | X-StereoLab uses a simple Res unit with a single conv and summation while ZhiXuanLi uses the original residual unit implementation. 326 | This class also uses the original implementation with 2 layers of convolutions. 
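    Layer order here is conv -> BatchNorm -> LeakyReLU -> conv -> BatchNorm, after which the input is added
    back (identity skip) and a final LeakyReLU is applied (see forward).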
327 | """ 328 | 329 | def __init__(self, 330 | in_channels: int, 331 | out_channels: int, 332 | kernel_size: int, 333 | stride: int = 1, 334 | padding: int = 0, 335 | dilation: int = 1): 336 | super().__init__() 337 | 338 | self.conv_1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False) 339 | self.batch_norm_1 = nn.BatchNorm2d(num_features=out_channels) 340 | self.activation_1 = nn.LeakyReLU(negative_slope=0.2) 341 | 342 | self.conv_2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=False) 343 | self.batch_norm_2 = nn.BatchNorm2d(num_features=out_channels) 344 | self.activation_2 = nn.LeakyReLU(negative_slope=0.2) 345 | 346 | def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name 347 | """ 348 | Original Residual Unit: https://arxiv.org/pdf/1603.05027.pdf (Fig 1. Left) 349 | """ 350 | 351 | res = self.conv_1(x) 352 | res = self.batch_norm_1(res) 353 | res = self.activation_1(res) 354 | res = self.conv_2(res) 355 | res = self.batch_norm_2(res) 356 | 357 | # I'm not really sure why the type definition is required here... nn.Conv2d already returns type Tensor... 358 | # So res should be of type torch.Tensor AND x is already defined as type torch.Tensor. 359 | out: torch.Tensor = res + x 360 | out = self.activation_2(out) 361 | 362 | return out 363 | 364 | 365 | def soft_argmin(cost: torch.Tensor, max_downsampled_disps: int) -> torch.Tensor: 366 | """ 367 | Soft argmin function described in the original paper. The disparity grid creates the first 'd' value in equation 2 while 368 | cost is the C_i(d) term. The exp/sum(exp) == softmax function. 369 | """ 370 | disparity_softmax = F.softmax(-cost, dim=1) 371 | # TODO: Bilinear interpolate the disparity dimension back to D to perform the proper d*exp(-C_i(d)) 372 | 373 | disparity_grid = torch.linspace(0, max_downsampled_disps, disparity_softmax.size(1)).reshape(1, -1, 1, 1) 374 | disparity_grid = disparity_grid.type_as(disparity_softmax) 375 | 376 | disp = torch.sum(disparity_softmax * disparity_grid, dim=1, keepdim=True) 377 | 378 | return disp 379 | 380 | 381 | def robust_loss(x: torch.Tensor, alpha: float, c: float) -> torch.Tensor: # pylint: disable=invalid-name 382 | """ 383 | A General and Adaptive Robust Loss Function (https://arxiv.org/abs/1701.03077) 384 | """ 385 | f: torch.Tensor = (abs(alpha - 2) / alpha) * (torch.pow(torch.pow(x / c, 2)/abs(alpha - 2) + 1, alpha/2) - 1) # pylint: disable=invalid-name 386 | return f 387 | -------------------------------------------------------------------------------- /src/stereonet/py.typed: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewlstewart/StereoNet_PyTorch/3cb3a0215a3c4690d74d93ae8973feaaf9f65bc1/src/stereonet/py.typed -------------------------------------------------------------------------------- /src/stereonet/train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to instantiate a StereoNet model + train on the SceneFlow or KeystoneDepth dataset. 
3 | """ 4 | 5 | from typing import Optional 6 | 7 | import hydra 8 | 9 | import lightning.pytorch as pl 10 | from lightning.pytorch.loggers import TensorBoardLogger 11 | from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor 12 | 13 | from stereonet.model import StereoNet 14 | import stereonet.datasets as std 15 | import stereonet.types as stt 16 | 17 | 18 | @hydra.main(version_base=None, config_name="config") 19 | def main(cfg: stt.StereoNetConfig) -> int: 20 | config: stt.StereoNetConfig = hydra.utils.instantiate(cfg, _convert_="all")['stereonet_config'] 21 | 22 | if config.training is None: 23 | raise Exception("Need to provide training arguments to train the model.") 24 | 25 | if config.training.random_seed is not None: 26 | pl.seed_everything(config.training.random_seed) 27 | 28 | checkpoint_path: Optional[str] = None 29 | # Instantiate model with built in optimizer 30 | if isinstance(config.model, stt.CheckpointModel): 31 | checkpoint_path = config.model.model_checkpoint_path 32 | model = StereoNet.load_from_checkpoint(checkpoint_path) 33 | elif isinstance(config.model, stt.StereoNetModel): 34 | model = StereoNet(in_channels=config.model.in_channels, 35 | k_downsampling_layers=config.model.k_downsampling_layers, 36 | k_refinement_layers=config.model.k_refinement_layers, 37 | candidate_disparities=config.model.candidate_disparities, 38 | mask=config.training.mask, 39 | optimizer_partial=config.training.optimizer_partial, 40 | scheduler_partial=config.training.scheduler_partial) 41 | else: 42 | raise Exception("Unknown model type") 43 | 44 | # Get datasets 45 | train_loader = std.construct_dataloaders(data_config=config.training, 46 | is_training=True, 47 | shuffle=True, num_workers=8, drop_last=False) 48 | 49 | val_loader = None 50 | if config.validation is not None: 51 | val_loader = std.construct_dataloaders(data_config=config.validation, 52 | is_training=False, 53 | shuffle=False, num_workers=8, drop_last=False) 54 | 55 | checkpoint_callback = ModelCheckpoint(monitor='val_loss_epoch', save_top_k=-1, mode='min') 56 | 57 | lr_monitor = LearningRateMonitor(logging_interval='epoch') 58 | logger = TensorBoardLogger(save_dir=config.logging.lightning_log_root, name="lightning_logs") 59 | trainer = pl.Trainer(devices=config.global_settings.devices.num, 60 | accelerator=config.global_settings.devices.type, 61 | min_epochs=config.training.min_epochs, 62 | max_epochs=config.training.max_epochs, 63 | logger=logger, 64 | callbacks=[lr_monitor, checkpoint_callback], 65 | deterministic=config.training.deterministic) 66 | 67 | trainer.fit(model=model, train_dataloaders=train_loader, val_dataloaders=val_loader, ckpt_path=checkpoint_path) 68 | 69 | return 0 70 | 71 | 72 | if __name__ == "__main__": 73 | raise SystemExit(main()) 74 | -------------------------------------------------------------------------------- /src/stereonet/types.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, List, Any, Callable 2 | 3 | from abc import ABC 4 | from dataclasses import dataclass 5 | 6 | import torch.nn 7 | import torch.optim 8 | 9 | 10 | @dataclass 11 | class Run: 12 | dir: str 13 | 14 | 15 | @dataclass 16 | class Hydra: 17 | run: Run 18 | 19 | 20 | @dataclass 21 | class Devices: 22 | num: int 23 | type: str 24 | 25 | 26 | @dataclass 27 | class GlobalSettings: 28 | devices: Devices 29 | 30 | 31 | @dataclass 32 | class Logging: 33 | lightning_log_root: str 34 | 35 | 36 | @dataclass 37 | class Model(ABC): 38 | ... 
39 | 
40 | 
41 | @dataclass
42 | class CheckpointModel(Model):
43 |     model_checkpoint_path: str
44 | 
45 | 
46 | @dataclass
47 | class StereoNetModel(Model):
48 |     in_channels: int
49 |     k_downsampling_layers: int
50 |     k_refinement_layers: int
51 |     candidate_disparities: int
52 | 
53 | 
54 | @dataclass
55 | class Loader:
56 |     batch_size: int
57 | 
58 | 
59 | # https://stackoverflow.com/a/69822584
60 | class Data(ABC):
61 |     def __init__(self, root_path: str,
62 |                  max_size: Optional[List[int]] = None,
63 |                  transforms: Optional[List[torch.nn.Module]] = None):
64 |         self.root_path = root_path
65 |         self.max_size = max_size
66 |         self.transforms = transforms
67 | 
68 | 
69 | class SceneflowData(Data):
70 |     def __init__(self, *args: Any, **kwargs: Any):
71 |         super().__init__(*args, **kwargs)
72 | 
73 | 
74 | class KeystoneDepthData(Data):
75 |     def __init__(self, split_ratio: float, *args: Any, **kwargs: Any):
76 |         super().__init__(*args, **kwargs)
77 |         self.split_ratio = split_ratio
78 | 
79 | 
80 | @dataclass
81 | class DataDebug:
82 |     enabled: bool
83 |     limit_train_batches: int
84 | 
85 | 
86 | @dataclass
87 | class Training:
88 |     min_epochs: int
89 |     max_epochs: int
90 |     mask: bool
91 |     data: List[Data]
92 |     loader: Loader
93 |     debug: DataDebug
94 |     optimizer_partial: Callable[[torch.nn.Module], torch.optim.Optimizer]
95 |     scheduler_partial: Optional[Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler.LRScheduler]] = None
96 |     random_seed: Optional[int] = None
97 |     deterministic: Optional[bool] = None
98 |     fast_dev_run: bool = False
99 | 
100 | 
101 | @dataclass
102 | class Validation:
103 |     data: List[Data]
104 |     loader: Loader
105 | 
106 | 
107 | @dataclass
108 | class StereoNetConfig:
109 |     global_settings: GlobalSettings
110 |     logging: Logging
111 |     model: Model
112 |     training: Optional[Training] = None
113 |     validation: Optional[Validation] = None
114 | 
--------------------------------------------------------------------------------
/src/stereonet/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Helper functions for StereoNet training.
3 | 
4 | Includes a plotting helper to visualize a left/right image pair alongside the ground truth and predicted disparities.
5 | """
6 | 
7 | import torch
8 | import matplotlib.pyplot as plt
9 | 
10 | 
11 | def plot_figure(left: torch.Tensor, right: torch.Tensor, disp_gt: torch.Tensor, disp_pred: torch.Tensor) -> plt.Figure:
12 |     """
13 |     Helper function to plot the left/right image pair from the dataset (i.e. normalized between -1/+1 and c,h,w) and the
14 |     ground truth disparity and the predicted disparity. The disparity colour scale spans the ground truth disparity's min and max.
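
    Example usage (the tensors are assumed to be CPU tensors taken from a single batch element):

        fig = plot_figure(left, right, disp_gt, disp_pred)
        fig.savefig('disparity_comparison.png')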
15 |     """
16 |     plt.close('all')
17 |     fig, ax = plt.subplots(ncols=2, nrows=2)
18 |     left = (torch.moveaxis(left, 0, 2) + 1) / 2
19 |     right = (torch.moveaxis(right, 0, 2) + 1) / 2
20 |     disp_gt = torch.moveaxis(disp_gt, 0, 2)
21 |     disp_pred = torch.moveaxis(disp_pred, 0, 2)
22 |     ax[0, 0].imshow(left)
23 |     ax[0, 1].imshow(right)
24 |     ax[1, 0].imshow(disp_gt, vmin=disp_gt.min(), vmax=disp_gt.max())
25 |     im = ax[1, 1].imshow(disp_pred, vmin=disp_gt.min(), vmax=disp_gt.max())
26 |     ax[0, 0].title.set_text('Left')
27 |     ax[0, 1].title.set_text('Right')
28 |     ax[1, 0].title.set_text('Ground truth disparity')
29 |     ax[1, 1].title.set_text('Predicted disparity')
30 |     fig.subplots_adjust(right=0.8)
31 |     cbar_ax = fig.add_axes([0.85, 0.15, 0.05, 0.27])
32 |     fig.colorbar(im, cax=cbar_ax)
33 |     return fig
34 | 
--------------------------------------------------------------------------------
/src/stereonet/utils_io.py:
--------------------------------------------------------------------------------
1 | """
2 | I/O helpers for loading stereo images and disparity maps (PFM, standard image formats, and EXR).
3 | """
4 | from typing import Tuple, Union
5 | from pathlib import Path
6 | import re
7 | import os
8 | 
9 | import numpy as np
10 | import numpy.typing as npt
11 | from skimage import io
12 | os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
13 | import cv2  # noqa: E402
14 | 
15 | 
16 | def pfm_loader(path: Union[Path, str]) -> Tuple[npt.NDArray[np.float32], float]:
17 |     """
18 |     This function was entirely written by the Freiburg group
19 |     https://lmb.informatik.uni-freiburg.de/resources/datasets/IO.py
20 | 
21 |     Read in a PFM formatted file and return an image/disparity array along with its scale.
22 |     """
23 |     with open(path, 'rb') as file:
24 |         header = file.readline().rstrip()
25 |         if header.decode("ascii") == 'PF':
26 |             color = True
27 |         elif header.decode("ascii") == 'Pf':
28 |             color = False
29 |         else:
30 |             raise Exception('Not a PFM file.')
31 | 
32 |         dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii"))
33 |         if dim_match:
34 |             width, height = list(map(int, dim_match.groups()))
35 |         else:
36 |             raise Exception('Malformed PFM header.')
37 | 
38 |         scale = float(file.readline().decode("ascii").rstrip())
39 |         if scale < 0:  # little-endian
40 |             endian = '<'
41 |             scale = -scale
42 |         else:
43 |             endian = '>'  # big-endian
44 | 
45 |         data: npt.NDArray[np.float32] = np.fromfile(file, endian + 'f')
46 |         shape = (height, width, 3) if color else (height, width)
47 | 
48 |         data = np.reshape(data, shape)
49 |         data = np.flipud(data)
50 |         return data, scale
51 | 
52 | 
53 | def image_loader(path: Union[Path, str]) -> npt.NDArray[np.uint8]:
54 |     """
55 |     Load an image from a path using skimage.io and return a np.uint8 numpy array.
56 |     """
57 |     img: npt.NDArray[np.uint8] = io.imread(path)
58 |     if img.ndim == 2:
59 |         img = np.expand_dims(img, axis=2)
60 |     return img
61 | 
62 | 
63 | def exr_loader(path: Union[Path, str]) -> npt.NDArray[np.float32]:
64 |     """
65 |     Load an image from a path using opencv and return a np.float32 numpy array.
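
    Reading .exr files requires the OPENCV_IO_ENABLE_OPENEXR environment variable, which this
    module sets before importing cv2. Example usage (the path is illustrative):

        disparity = exr_loader('data/sample_disparity.exr')  # float32 array, channel-last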
66 | """ 67 | img: npt.NDArray[np.float32] = cv2.imread(path, cv2.IMREAD_ANYDEPTH) 68 | if img.ndim == 2: 69 | img = np.expand_dims(img, axis=2) 70 | return img 71 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andrewlstewart/StereoNet_PyTorch/3cb3a0215a3c4690d74d93ae8973feaaf9f65bc1/tests/__init__.py -------------------------------------------------------------------------------- /tests/conftest.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | 3 | import pytest 4 | 5 | from stereonet.model import StereoNet 6 | 7 | 8 | @pytest.fixture(scope="session") 9 | def models() -> List[Tuple[StereoNet, int]]: 10 | """ 11 | Instantiates a bunch of models so we can test their output sizes. 12 | """ 13 | models = [ 14 | (StereoNet(in_channels=3, k_downsampling_layers=3, k_refinement_layers=1), 398978), # Tuple of (model, number of trainable parameters) 15 | (StereoNet(in_channels=3, k_downsampling_layers=4, k_refinement_layers=1), 424610), 16 | (StereoNet(in_channels=3, k_downsampling_layers=3, k_refinement_layers=3), 624644), 17 | (StereoNet(in_channels=3, k_downsampling_layers=4, k_refinement_layers=3), 650276) 18 | ] 19 | 20 | for (model, _) in models: 21 | model.eval() 22 | 23 | return models 24 | -------------------------------------------------------------------------------- /tests/test_model.py: -------------------------------------------------------------------------------- 1 | """ 2 | Test suite for instantiating StereoNet models and performing simple forward passes. 3 | """ 4 | 5 | from typing import List, Tuple 6 | 7 | import torch 8 | 9 | from stereonet.model import StereoNet 10 | 11 | 12 | def test_model_trainable_parameters(models: List[Tuple[StereoNet, int]]): 13 | """ 14 | Test to see if the number of trainable parameters matches the expected number. 15 | """ 16 | for model, n_params in models: 17 | assert (count_parameters(model) == n_params) 18 | 19 | 20 | def test_forward_sizes(models: List[Tuple[StereoNet, int]]): 21 | """ 22 | Test to see if each of the networks produces the correct shape. 
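
    The dummy input stacks a left/right pair along the channel axis (2 x 3 channels) at the
    540 x 960 SceneFlow resolution; each model is expected to return a single-channel disparity
    map with the same spatial size, i.e. (batch, 1, 540, 960).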
23 | """ 24 | input_data = torch.rand((2, 6, 540, 960)) 25 | 26 | with torch.no_grad(): 27 | for model, _ in models: 28 | assert (model(input_data).size() == (2, 1, 540, 960)) 29 | 30 | 31 | def count_parameters(model: StereoNet) -> int: 32 | """ 33 | Counts the number of trainable parameters in a torch model 34 | https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/9 35 | """ 36 | return sum(p.numel() for p in model.parameters() if p.requires_grad) 37 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | minversion = 3.24.4 3 | envlist = py38, flake8, mypy 4 | isolated_build = true 5 | 6 | [gh-actions] 7 | python = 8 | 3.8: mypy, py38 9 | 10 | [testenv] 11 | setenv = 12 | PYTHONPATH = {toxinidir} 13 | deps = 14 | -r{toxinidir}/requirements_dev.txt 15 | commands = 16 | pytest --basetemp={envtmpdir} 17 | 18 | [testenv:flake8] 19 | basepython = python3.8 20 | deps = flake8 21 | commands = flake8 src tests 22 | 23 | [testenv:mypy] 24 | basepython = python3.8 25 | deps = 26 | -r{toxinidir}/requirements_dev.txt 27 | commands = mypy src 28 | 29 | --------------------------------------------------------------------------------