├── ezflow ├── __init__.py ├── utils │ ├── viz.py │ ├── __init__.py │ ├── warp.py │ ├── metrics.py │ ├── registry.py │ └── resampling.py ├── modules │ ├── models │ │ ├── __init__.py │ │ └── recurrent.py │ ├── __init__.py │ ├── base_module.py │ ├── build.py │ └── dap.py ├── decoder │ ├── iterative │ │ └── __init__.py │ ├── noniterative │ │ ├── __init__.py │ │ └── operators.py │ ├── __init__.py │ ├── build.py │ └── context.py ├── model_zoo │ ├── __init__.py │ └── model_zoo.py ├── config │ ├── __init__.py │ └── retrieve.py ├── data │ ├── dataloader │ │ ├── __init__.py │ │ └── device_dataloader.py │ ├── __init__.py │ ├── dataset │ │ ├── __init__.py │ │ ├── hd1k.py │ │ ├── autoflow.py │ │ └── kitti.py │ └── build.py ├── functional │ ├── data_augmentation │ │ └── __init__.py │ ├── criterion │ │ ├── __init__.py │ │ └── sequence.py │ ├── __init__.py │ ├── registry.py │ ├── scheduler.py │ └── weight_annealers.py ├── similarity │ ├── __init__.py │ ├── correlation │ │ ├── __init__.py │ │ ├── layer.py │ │ └── pairwise.py │ └── build.py ├── engine │ ├── __init__.py │ ├── pruning.py │ ├── registry.py │ ├── retrieve.py │ └── profiler.py ├── encoder │ ├── __init__.py │ ├── build.py │ ├── pyramid.py │ ├── dcvnet.py │ └── ganet.py └── models │ ├── __init__.py │ ├── build.py │ ├── flownet_s.py │ ├── pwcnet.py │ ├── dcvnet.py │ ├── flownet_c.py │ ├── raft.py │ └── predictor.py ├── tests ├── __init__.py ├── utils │ ├── __init__.py │ ├── mock_model.py │ └── mock_data.py ├── configs │ ├── custom_loss_trainer.yaml │ └── base_trainer_test.yaml ├── test_similarity.py ├── test_utils.py ├── test_models.py └── test_modules.py ├── docs ├── authors.rst ├── readme.rst ├── contributing.rst ├── assets │ ├── logo.png │ └── logo_name.png ├── api │ ├── data │ │ ├── index.rst │ │ ├── ezflow.data.dataloader.rst │ │ └── ezflow.data.dataset.rst │ ├── modules │ │ ├── ezflow.modules.units.rst │ │ ├── ezflow.modules.blocks.rst │ │ ├── ezflow.modules.build.rst │ │ ├── ezflow.modules.dap.rst │ │ ├── 
ezflow.modules.models.rst │ │ └── index.rst │ ├── decoder │ │ ├── ezflow.decoder.builder.rst │ │ ├── ezflow.decoder.conv_decoder.rst │ │ ├── ezflow.decoder.separable_conv.rst │ │ ├── ezflow.decoder.iterative.rst │ │ ├── index.rst │ │ └── ezflow.decoder.noniterative.rst │ ├── similarity │ │ ├── ezflow.similarity.builder.rst │ │ ├── ezflow.similarity.learnable_cost.rst │ │ ├── index.rst │ │ └── ezflow.similarity.correlation.rst │ ├── functional │ │ ├── ezflow.functional.registry.rst │ │ ├── ezflow.functional.scheduler.rst │ │ ├── index.rst │ │ ├── ezflow.functional.criterion.rst │ │ └── ezflow.functional.data_augmentation.rst │ ├── config │ │ └── index.rst │ ├── engine │ │ └── index.rst │ ├── encoder │ │ └── index.rst │ ├── utils │ │ └── index.rst │ └── models │ │ └── index.rst ├── tutorials │ ├── index.rst │ └── using.rst ├── index.rst ├── Makefile ├── make.bat ├── installation.rst └── requirements.txt ├── .flake8 ├── AUTHORS.rst ├── .isort.cfg ├── configs ├── models │ ├── flownet_s.yaml │ ├── pwcnet.yaml │ ├── flownet_c.yaml │ ├── vcn.yaml │ ├── raft.yaml │ ├── raft_small.yaml │ ├── dicl.yaml │ └── dcvnet.yaml └── trainers │ ├── pwcnet │ ├── pwcnet_chairs_baseline.yaml │ ├── pwcnet_kubric_improved_aug.yaml │ └── pwcnet_things_baseline.yaml │ ├── flownetc │ ├── flownetc_chairs_baseline.yaml │ ├── flownetc_kubric_improved_aug.yaml │ └── flownetc_things_baseline.yaml │ ├── raft │ ├── raft_chairs_baseline.yaml │ ├── raft_kubric_improved_aug.yaml │ └── raft_things_baseline.yaml │ ├── dcvnet │ └── dcvnet_sceneflow_baseline.yaml │ └── _base_ │ ├── chairs_baseline.yaml │ ├── kubric_baseline.yaml │ ├── things_baseline.yaml │ ├── kubric_improved_aug.yaml │ └── sceneflow_baseline.yaml ├── .codecov.yml ├── MANIFEST.in ├── .editorconfig ├── .pre-commit-config.yaml ├── tox.ini ├── CITATION.cff ├── setup.cfg ├── .coveragerc ├── .readthedocs.yaml ├── README.rst ├── .github └── workflows │ ├── linting.yml │ ├── publish.yml │ ├── package-test.yml │ └── codecov.yml ├── tools ├── 
evaluate.py └── train.py ├── LICENSE ├── .gitignore ├── requirements.txt ├── Makefile ├── generate_dir_structure.py ├── CONTRIBUTING.rst ├── setup.py └── environment.yml /ezflow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /tests/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ezflow/utils/viz.py: -------------------------------------------------------------------------------- 1 | pass 2 | -------------------------------------------------------------------------------- /docs/authors.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../AUTHORS.rst 2 | -------------------------------------------------------------------------------- /docs/readme.rst: -------------------------------------------------------------------------------- 1 | .. include:: ../README.rst 2 | -------------------------------------------------------------------------------- /docs/contributing.rst: -------------------------------------------------------------------------------- 1 | .. 
include:: ../CONTRIBUTING.rst 2 | -------------------------------------------------------------------------------- /.flake8: -------------------------------------------------------------------------------- 1 | [flake8] 2 | # exclude = .git, 3 | max-line-length = 90 -------------------------------------------------------------------------------- /ezflow/modules/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .recurrent import * 2 | -------------------------------------------------------------------------------- /ezflow/decoder/iterative/__init__.py: -------------------------------------------------------------------------------- 1 | from .recurrent_lookup import * 2 | -------------------------------------------------------------------------------- /ezflow/model_zoo/__init__.py: -------------------------------------------------------------------------------- 1 | from .model_zoo import _ModelZooConfigs 2 | -------------------------------------------------------------------------------- /docs/assets/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neu-vi/ezflow/HEAD/docs/assets/logo.png -------------------------------------------------------------------------------- /docs/assets/logo_name.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/neu-vi/ezflow/HEAD/docs/assets/logo_name.png -------------------------------------------------------------------------------- /ezflow/config/__init__.py: -------------------------------------------------------------------------------- 1 | from .config import CfgNode, configurable 2 | from .retrieve import get_cfg, get_cfg_obj, get_cfg_path 3 | -------------------------------------------------------------------------------- /ezflow/data/dataloader/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .dataloader_creator import DataloaderCreator 2 | from .device_dataloader import DeviceDataLoader 3 | -------------------------------------------------------------------------------- /ezflow/functional/data_augmentation/__init__.py: -------------------------------------------------------------------------------- 1 | from .augmentor import FlowAugmentor, SparseFlowAugmentor 2 | from .operations import * 3 | -------------------------------------------------------------------------------- /ezflow/decoder/noniterative/__init__.py: -------------------------------------------------------------------------------- 1 | from .operators import * 2 | from .soft_regression import Soft4DFlowRegression, SoftArg2DFlowRegression 3 | -------------------------------------------------------------------------------- /tests/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mock_data import MockDataloaderCreator, MockOpticalFlowDataset 2 | from .mock_model import MockOpticalFlowModel 3 | -------------------------------------------------------------------------------- /ezflow/data/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import DATASET_REGISTRY, build_dataloader, get_dataset_list 2 | from .dataloader import * 3 | from .dataset import * 4 | -------------------------------------------------------------------------------- /ezflow/similarity/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import SIMILARITY_REGISTRY, build_similarity 2 | from .correlation import * 3 | from .learnable_cost import * 4 | -------------------------------------------------------------------------------- /AUTHORS.rst: -------------------------------------------------------------------------------- 1 | ======== 2 | Authors 3 | ======== 4 | 5 | 
6 | * Neelay Shah 7 | * Prajnan Goswami 8 | * Huaizu Jiang 9 | 10 | -------------------------------------------------------------------------------- /docs/api/data/index.rst: -------------------------------------------------------------------------------- 1 | Data 2 | ======= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | ezflow.data.dataloader 9 | ezflow.data.dataset 10 | 11 | -------------------------------------------------------------------------------- /docs/tutorials/index.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | Tutorials 3 | =========== 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | using 9 | constructing 10 | training -------------------------------------------------------------------------------- /docs/api/modules/ezflow.modules.units.rst: -------------------------------------------------------------------------------- 1 | Units 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.modules.units 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /ezflow/functional/criterion/__init__.py: -------------------------------------------------------------------------------- 1 | from .multiscale import MultiScaleLoss 2 | from .offset import FlowOffsetLoss, OffsetCrossEntropyLoss 3 | from .sequence import SequenceLoss 4 | -------------------------------------------------------------------------------- /ezflow/similarity/correlation/__init__.py: -------------------------------------------------------------------------------- 1 | from .layer import CorrelationLayer 2 | from .pairwise import MutliScalePairwise4DCorr 3 | from .sampler import IterSpatialCorrelationSampler 4 | -------------------------------------------------------------------------------- /docs/api/decoder/ezflow.decoder.builder.rst: -------------------------------------------------------------------------------- 1 | Builder 2 | 
========================================================= 3 | 4 | .. automodule:: ezflow.decoder.build 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/modules/ezflow.modules.blocks.rst: -------------------------------------------------------------------------------- 1 | Blocks 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.modules.blocks 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/modules/ezflow.modules.build.rst: -------------------------------------------------------------------------------- 1 | Builder 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.modules.build 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/similarity/ezflow.similarity.builder.rst: -------------------------------------------------------------------------------- 1 | Builder 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.similarity.build 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/functional/ezflow.functional.registry.rst: -------------------------------------------------------------------------------- 1 | Registry 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.functional.registry 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/modules/ezflow.modules.dap.rst: -------------------------------------------------------------------------------- 1 | Displacement-aware Projection 2 | ========================================================= 3 | 4 | .. 
automodule:: ezflow.modules.dap 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /ezflow/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .common import * 2 | from .io import * 3 | from .metrics import * 4 | from .registry import * 5 | from .resampling import * 6 | from .viz import * 7 | from .warp import * 8 | -------------------------------------------------------------------------------- /docs/api/functional/ezflow.functional.scheduler.rst: -------------------------------------------------------------------------------- 1 | Schedulers 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.functional.scheduler 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/decoder/ezflow.decoder.conv_decoder.rst: -------------------------------------------------------------------------------- 1 | Convolutional Decoder 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.decoder.conv_decoder 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/decoder/ezflow.decoder.separable_conv.rst: -------------------------------------------------------------------------------- 1 | Separable Convolutions 2 | ========================================================= 3 | 4 | .. automodule:: ezflow.decoder.separable_conv 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /docs/api/similarity/ezflow.similarity.learnable_cost.rst: -------------------------------------------------------------------------------- 1 | Learnable Cost 2 | ========================================================= 3 | 4 | .. 
automodule:: ezflow.similarity.learnable_cost 5 | :members: 6 | 7 | -------------------------------------------------------------------------------- /.isort.cfg: -------------------------------------------------------------------------------- 1 | [settings] 2 | known_third_party = PIL,cv2,fvcore,numpy,pkg_resources,scipy,setuptools,torch,torchvision 3 | multi_line_output=3 4 | include_trailing_comma=True 5 | force_grid_wrap=0 6 | use_parentheses=True 7 | line_length=88 -------------------------------------------------------------------------------- /docs/api/similarity/index.rst: -------------------------------------------------------------------------------- 1 | Similarity 2 | ================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | ezflow.similarity.correlation 9 | ezflow.similarity.learnable_cost 10 | ezflow.similarity.builder 11 | -------------------------------------------------------------------------------- /docs/api/modules/ezflow.modules.models.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ========================================================= 3 | 4 | Recurrent 5 | --------------------- 6 | 7 | .. 
automodule:: ezflow.modules.models.recurrent 8 | :members: 9 | 10 | 11 | -------------------------------------------------------------------------------- /ezflow/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .criterion import * 2 | from .data_augmentation import * 3 | from .registry import FUNCTIONAL_REGISTRY, get_functional 4 | from .scheduler import CosineWarmupScheduler 5 | from .weight_annealers import CosineAnnealer, PolyAnnealer 6 | -------------------------------------------------------------------------------- /ezflow/modules/__init__.py: -------------------------------------------------------------------------------- 1 | from .base_module import BaseModule 2 | from .blocks import * 3 | from .build import MODULE_REGISTRY, build_module 4 | from .dap import DisplacementAwareProjection 5 | from .models import * 6 | from .unet import * 7 | from .units import * 8 | -------------------------------------------------------------------------------- /docs/api/modules/index.rst: -------------------------------------------------------------------------------- 1 | Modules 2 | ================= 3 | 4 | 5 | .. 
toctree:: 6 | :maxdepth: 2 7 | 8 | ezflow.modules.models 9 | ezflow.modules.blocks 10 | ezflow.modules.units 11 | ezflow.modules.dap 12 | ezflow.modules.build 13 | 14 | 15 | -------------------------------------------------------------------------------- /ezflow/engine/__init__.py: -------------------------------------------------------------------------------- 1 | from .eval import eval_model 2 | from .profiler import Profiler 3 | from .pruning import prune_l1_structured, prune_l1_unstructured 4 | from .registry import * 5 | from .retrieve import get_training_cfg 6 | from .trainer import DistributedTrainer, Trainer 7 | -------------------------------------------------------------------------------- /docs/api/decoder/ezflow.decoder.iterative.rst: -------------------------------------------------------------------------------- 1 | Iterative Decoders 2 | ========================================================= 3 | 4 | Recurrent Lookup 5 | --------------------- 6 | 7 | .. automodule:: ezflow.decoder.iterative.recurrent_lookup 8 | :members: 9 | 10 | 11 | -------------------------------------------------------------------------------- /docs/api/functional/index.rst: -------------------------------------------------------------------------------- 1 | Functional 2 | ================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | ezflow.functional.criterion 9 | ezflow.functional.data_augmentation 10 | ezflow.functional.registry 11 | ezflow.functional.scheduler 12 | 13 | 14 | -------------------------------------------------------------------------------- /docs/api/config/index.rst: -------------------------------------------------------------------------------- 1 | Config 2 | ======= 3 | 4 | 5 | Config 6 | ------- 7 | 8 | .. automodule:: ezflow.config.config 9 | :members: 10 | 11 | 12 | 13 | Retrieve 14 | --------- 15 | 16 | ..
automodule:: ezflow.config.retrieve 17 | :members: 18 | 19 | 20 | -------------------------------------------------------------------------------- /docs/api/decoder/index.rst: -------------------------------------------------------------------------------- 1 | Decoder 2 | ================= 3 | 4 | 5 | .. toctree:: 6 | :maxdepth: 2 7 | 8 | ezflow.decoder.iterative 9 | ezflow.decoder.noniterative 10 | ezflow.decoder.conv_decoder 11 | ezflow.decoder.separable_conv 12 | ezflow.decoder.builder 13 | -------------------------------------------------------------------------------- /configs/models/flownet_s.yaml: -------------------------------------------------------------------------------- 1 | NAME: FlowNetS 2 | ENCODER: 3 | NAME: FlowNetConvEncoder 4 | IN_CHANNELS: 6 5 | CONFIG: [64, 128, 256, 256, 512, 512, 512, 512, 1024, 1024] 6 | NORM: batch 7 | DECODER: 8 | NAME: FlowNetConvDecoder 9 | IN_CHANNELS: 1024 10 | CONFIG: [512, 256, 128, 64] 11 | INTERPOLATE_FLOW: True 12 | -------------------------------------------------------------------------------- /configs/models/pwcnet.yaml: -------------------------------------------------------------------------------- 1 | NAME: PWCNet 2 | ENCODER: 3 | NAME: PyramidEncoder 4 | IN_CHANNELS: 3 5 | CONFIG: [16, 32, 64, 96, 128, 196] 6 | DECODER: 7 | NAME: PyramidDecoder 8 | CONFIG: [128, 128, 96, 64, 32] 9 | TO_FLOW: True 10 | SIMILARITY: 11 | PAD_SIZE: 0 12 | MAX_DISPLACEMENT: 4 13 | FLOW_SCALE_FACTOR: 20.0 -------------------------------------------------------------------------------- /.codecov.yml: -------------------------------------------------------------------------------- 1 | codecov: 2 | require_ci_to_pass: yes 3 | 4 | coverage: 5 | precision: 2 6 | round: down 7 | range: "70...100" 8 | 9 | comment: 10 | layout: "header, diff, changes, tree" 11 | behavior: default 12 | require_changes: no 13 | 14 | ignore: 15 | - "ezflow/data" 16 | - "ezflow/config" 17 | - "ezflow/model_zoo" 
-------------------------------------------------------------------------------- /ezflow/encoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import ENCODER_REGISTRY, build_encoder 2 | from .conv_encoder import BasicConvEncoder, FlowNetConvEncoder 3 | from .dcvnet import DCVNetBackbone 4 | from .ganet import GANetBackbone 5 | from .pspnet import PSPNetBackbone 6 | from .pyramid import PyramidEncoder 7 | from .raft import RAFTBackbone 8 | from .residual import * 9 | -------------------------------------------------------------------------------- /ezflow/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import MODEL_REGISTRY, build_model, get_default_model_cfg, get_model_list 2 | from .dcvnet import DCVNet 3 | from .dicl import DICL 4 | from .flownet_c import FlowNetC 5 | from .flownet_s import FlowNetS 6 | from .predictor import Predictor 7 | from .pwcnet import PWCNet 8 | from .raft import RAFT 9 | from .vcn import VCN 10 | -------------------------------------------------------------------------------- /MANIFEST.in: -------------------------------------------------------------------------------- 1 | include AUTHORS.rst 2 | include CONTRIBUTING.rst 3 | include LICENSE 4 | include README.rst 5 | include README.md 6 | include requirements.txt 7 | 8 | recursive-include configs * 9 | recursive-include tests * 10 | recursive-exclude * __pycache__ 11 | recursive-exclude * *.py[co] 12 | 13 | recursive-include docs *.rst conf.py Makefile make.bat *.jpg *.png *.gif 14 | -------------------------------------------------------------------------------- /docs/api/data/ezflow.data.dataloader.rst: -------------------------------------------------------------------------------- 1 | Dataloader 2 | ======================== 3 | 4 | Dataloader Creator 5 | ---------------------- 6 | 7 | .. 
automodule:: ezflow.data.dataloader.dataloader_creator 8 | :members: 9 | 10 | 11 | 12 | Device Dataloader 13 | ---------------------- 14 | 15 | .. automodule:: ezflow.data.dataloader.device_dataloader 16 | :members: 17 | 18 | -------------------------------------------------------------------------------- /.editorconfig: -------------------------------------------------------------------------------- 1 | # http://editorconfig.org 2 | 3 | root = true 4 | 5 | [*] 6 | indent_style = space 7 | indent_size = 4 8 | trim_trailing_whitespace = true 9 | insert_final_newline = true 10 | charset = utf-8 11 | end_of_line = lf 12 | 13 | [*.bat] 14 | indent_style = tab 15 | end_of_line = crlf 16 | 17 | [LICENSE] 18 | insert_final_newline = false 19 | 20 | [Makefile] 21 | indent_style = tab 22 | -------------------------------------------------------------------------------- /configs/models/flownet_c.yaml: -------------------------------------------------------------------------------- 1 | NAME: FlowNetC 2 | ENCODER: 3 | NAME: FlowNetConvEncoder 4 | IN_CHANNELS: 3 5 | CONFIG: [64, 128, 256, 256, 512, 512, 512, 512, 1024, 1024] 6 | NORM: batch 7 | SIMILARITY: 8 | NAME: IterSpatialCorrelationSampler 9 | PAD_SIZE: 0 10 | MAX_DISPLACEMENT: 10 11 | DECODER: 12 | NAME: FlowNetConvDecoder 13 | IN_CHANNELS: 1024 14 | CONFIG: [512, 256, 128, 64] -------------------------------------------------------------------------------- /configs/models/vcn.yaml: -------------------------------------------------------------------------------- 1 | NAME: VCN 2 | ENCODER: 3 | NAME: PSPNetBackbone 4 | IS_PROJ: False 5 | GROUPS: 1 6 | IN_CHANNELS: 3 7 | NORM: True 8 | DECODER: 9 | F_DIM_A1: 128 10 | F_DIM_A2: 64 11 | F_DIM_B1: 16 12 | F_DIM_B2: 12 13 | NORM: True 14 | ENTROPY: True 15 | SIZE: [16, 256, 256] # batch_size x H x W 16 | MAX_DISPLACEMENTS: [2, 2, 2, 2, 2] 17 | FACTORIZATION: 1 -------------------------------------------------------------------------------- /ezflow/data/dataset/__init__.py: 
-------------------------------------------------------------------------------- 1 | from .autoflow import AutoFlow 2 | from .base_dataset import BaseDataset 3 | from .driving import Driving 4 | from .flying_chairs import FlyingChairs 5 | from .flying_things3d import FlyingThings3D, FlyingThings3DSubset 6 | from .hd1k import HD1K 7 | from .kitti import Kitti 8 | from .kubric import Kubric 9 | from .monkaa import Monkaa 10 | from .mpi_sintel import MPISintel 11 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | repos: 2 | - repo: https://github.com/timothycrosley/isort 3 | rev: 5.11.5 4 | hooks: 5 | - id: isort 6 | 7 | - repo: https://github.com/python/black 8 | rev: 22.3.0 9 | hooks: 10 | - id: black 11 | language_version: python3 12 | 13 | # - repo: https://gitlab.com/pycqa/flake8 14 | # rev: 3.8.3 15 | # hooks: 16 | # - id: flake8 -------------------------------------------------------------------------------- /docs/api/functional/ezflow.functional.criterion.rst: -------------------------------------------------------------------------------- 1 | Criterion 2 | ========================================================= 3 | 4 | Sequence Loss 5 | --------------------- 6 | 7 | .. automodule:: ezflow.functional.criterion.sequence 8 | :members: 9 | 10 | 11 | 12 | Multi-scale Loss 13 | --------------------- 14 | 15 | .. automodule:: ezflow.functional.criterion.multiscale 16 | :members: 17 | 18 | -------------------------------------------------------------------------------- /docs/api/decoder/ezflow.decoder.noniterative.rst: -------------------------------------------------------------------------------- 1 | Non-iterative Decoders 2 | ========================================================= 3 | 4 | Soft Regression 5 | --------------------- 6 | 7 | ..
automodule:: ezflow.decoder.noniterative.soft_regression 8 | :members: 9 | 10 | 11 | 12 | Operators 13 | --------------------- 14 | 15 | .. automodule:: ezflow.decoder.noniterative.operators 16 | :members: 17 | 18 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py36, py37, py38, flake8 3 | 4 | [travis] 5 | python = 6 | 3.8: py38 7 | 3.7: py37 8 | 9 | [testenv:flake8] 10 | basepython = python 11 | deps = flake8 12 | commands = flake8 ezflow tests 13 | 14 | [testenv] 15 | setenv = 16 | PYTHONPATH = {toxinidir} 17 | deps = 18 | -r{toxinidir}/requirements.txt 19 | commands = 20 | pip install -U pip 21 | pytest --basetemp={envtmpdir} 22 | 23 | -------------------------------------------------------------------------------- /docs/api/functional/ezflow.functional.data_augmentation.rst: -------------------------------------------------------------------------------- 1 | Data Augmentation 2 | ========================================================= 3 | 4 | Augmentor 5 | --------------------- 6 | 7 | .. automodule:: ezflow.functional.data_augmentation.augmentor 8 | :members: 9 | 10 | 11 | 12 | Operations 13 | --------------------- 14 | 15 | .. 
automodule:: ezflow.functional.data_augmentation.operations 16 | :members: 17 | 18 | -------------------------------------------------------------------------------- /ezflow/decoder/__init__.py: -------------------------------------------------------------------------------- 1 | from .build import DECODER_REGISTRY, build_decoder 2 | from .context import ContextNetwork 3 | from .conv_decoder import ConvDecoder, FlowNetConvDecoder 4 | from .dilated_flow_stack_filter import ( 5 | DCVDilatedFlowStackFilterDecoder, 6 | DCVFilterGroupConvStemJoint, 7 | ) 8 | from .iterative import * 9 | from .noniterative import * 10 | from .pyramid import PyramidDecoder 11 | from .separable_conv import Butterfly4D, SeparableConv4D 12 | -------------------------------------------------------------------------------- /CITATION.cff: -------------------------------------------------------------------------------- 1 | cff-version: 1.2.0 2 | message: "If you find this software useful, please cite it as below." 3 | authors: 4 | - family-names: "Shah" 5 | given-names: "Neelay" 6 | - family-names: "Goswami" 7 | given-names: "Prajnan" 8 | - family-names: "Jiang" 9 | given-names: "Huaizu" 10 | title: "EzFlow: A modular PyTorch library for optical flow estimation using neural networks" 11 | date-released: 2021-11-18 12 | url: "https://github.com/neu-vig/ezflow" 13 | license: MIT -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [bumpversion] 2 | current_version = 0.2.5 3 | commit = True 4 | tag = True 5 | 6 | [bumpversion:file:setup.py] 7 | search = version='{current_version}' 8 | replace = version='{new_version}' 9 | 10 | [bumpversion:file:ezflow/__init__.py] 11 | search = __version__ = '{current_version}' 12 | replace = __version__ = '{new_version}' 13 | 14 | [bdist_wheel] 15 | universal = 1 16 | 17 | [flake8] 18 | exclude = docs 19 | [tool:pytest] 20 | collect_ignore = 
['setup.py'] 21 | -------------------------------------------------------------------------------- /.coveragerc: -------------------------------------------------------------------------------- 1 | [report] 2 | # Regexes for lines to exclude from consideration 3 | exclude_lines = 4 | # Have to re-enable the standard pragma 5 | pragma: no cover 6 | 7 | # Don't complain about missing debug-only code: 8 | def __repr__ 9 | if self\.debug 10 | 11 | # Don't complain if tests don't hit defensive assertion code: 12 | raise AssertionError 13 | raise NotImplementedError 14 | raise ValueError 15 | 16 | # Don't complain if non-runnable code isn't run: 17 | if 0: 18 | if __name__ == .__main__.: -------------------------------------------------------------------------------- /docs/api/similarity/ezflow.similarity.correlation.rst: -------------------------------------------------------------------------------- 1 | Correlation 2 | ========================================================= 3 | 4 | 5 | Correlation Layer 6 | -------------------- 7 | 8 | .. automodule:: ezflow.similarity.correlation.layer 9 | :members: 10 | 11 | 12 | 13 | 14 | Pairwise Correlation 15 | ---------------------- 16 | 17 | .. automodule:: ezflow.similarity.correlation.pairwise 18 | :members: 19 | 20 | 21 | 22 | 23 | Correlation Sampler 24 | ---------------------- 25 | 26 | .. automodule:: ezflow.similarity.correlation.sampler 27 | :members: 28 | 29 | -------------------------------------------------------------------------------- /docs/api/engine/index.rst: -------------------------------------------------------------------------------- 1 | Engine 2 | ================ 3 | 4 | Training 5 | --------------- 6 | 7 | .. automodule:: ezflow.engine.trainer 8 | :members: 9 | 10 | 11 | 12 | Evaluation 13 | --------------- 14 | 15 | .. automodule:: ezflow.engine.eval 16 | :members: 17 | 18 | 19 | 20 | Pruning 21 | --------------- 22 | 23 | .. 
automodule:: ezflow.engine.pruning 24 | :members: 25 | 26 | 27 | Profiler 28 | --------------- 29 | 30 | .. automodule:: ezflow.engine.profiler 31 | :members: 32 | 33 | 34 | Retrieve 35 | --------------- 36 | 37 | .. automodule:: ezflow.engine.retrieve 38 | :members: 39 | 40 | -------------------------------------------------------------------------------- /configs/models/raft.yaml: -------------------------------------------------------------------------------- 1 | NAME: RAFT 2 | ENCODER: 3 | FEATURE: 4 | NAME: RAFTBackbone 5 | IN_CHANNELS: 3 6 | OUT_CHANNELS: 256 7 | NORM: instance 8 | P_DROPOUT: 0.0 9 | LAYER_CONFIG: [64, 96, 128] 10 | CONTEXT: 11 | NAME: RAFTBackbone 12 | IN_CHANNELS: 3 13 | OUT_CHANNELS: 256 14 | NORM: batch 15 | P_DROPOUT: 0.0 16 | LAYER_CONFIG: [64, 96, 128] 17 | HIDDEN_DIM: 128 18 | CONTEXT_DIM: 128 19 | SIMILARITY: 20 | NAME: MutliScalePairwise4DCorr 21 | NUM_LEVELS: 4 22 | DECODER: 23 | NAME: RecurrentLookupUpdateBlock 24 | INPUT_DIM: 128 25 | CORR_RADIUS: 4 26 | CORR_LEVELS: 4 27 | MIXED_PRECISION: False 28 | UPDATE_ITERS: 12 -------------------------------------------------------------------------------- /ezflow/modules/base_module.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class BaseModule(nn.Module): 5 | """ 6 | A wrapper for torch.nn.Module to maintain common 7 | module functionalities. 8 | 9 | """ 10 | 11 | def __init__(self): 12 | super(BaseModule, self).__init__() 13 | 14 | def forward(self): 15 | pass 16 | 17 | def freeze_batch_norm(self): 18 | """ 19 | Set Batch Norm layers to evaluation state. 20 | This method can be used for fine tuning. 
21 | 22 | """ 23 | for module in self.modules(): 24 | if isinstance(module, nn.BatchNorm2d): 25 | module.eval() 26 | -------------------------------------------------------------------------------- /configs/models/raft_small.yaml: -------------------------------------------------------------------------------- 1 | NAME: RAFT 2 | ENCODER: 3 | FEATURE: 4 | NAME: RAFTBackboneSmall 5 | IN_CHANNELS: 3 6 | OUT_CHANNELS: 128 7 | NORM: instance 8 | P_DROPOUT: 0.0 9 | LAYER_CONFIG: [32, 64, 96] 10 | CONTEXT: 11 | NAME: RAFTBackboneSmall 12 | IN_CHANNELS: 3 13 | OUT_CHANNELS: 160 14 | NORM: batch 15 | P_DROPOUT: 0.0 16 | LAYER_CONFIG: [32, 64, 96] 17 | HIDDEN_DIM: 96 18 | CONTEXT_DIM: 64 19 | SIMILARITY: 20 | NAME: MutliScalePairwise4DCorr 21 | NUM_LEVELS: 4 22 | DECODER: 23 | NAME: RecurrentLookupUpdateBlock 24 | INPUT_DIM: 96 25 | CORR_RADIUS: 4 26 | CORR_LEVELS: 3 27 | MIXED_PRECISION: False 28 | UPDATE_ITERS: 12 -------------------------------------------------------------------------------- /docs/index.rst: -------------------------------------------------------------------------------- 1 | Welcome to EzFlow's documentation! 2 | ====================================== 3 | 4 | .. toctree:: 5 | :maxdepth: 2 6 | :caption: Contents 7 | 8 | readme 9 | installation 10 | tutorials/index 11 | 12 | .. 
toctree:: 13 | :maxdepth: 2 14 | :caption: API Reference 15 | 16 | api/config/index 17 | api/data/index 18 | api/decoder/index 19 | api/encoder/index 20 | api/engine/index 21 | api/functional/index 22 | api/models/index 23 | api/modules/index 24 | api/similarity/index 25 | api/utils/index 26 | 27 | Indices and tables 28 | ================== 29 | * :ref:`genindex` 30 | * :ref:`modindex` 31 | * :ref:`search` 32 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = python -msphinx 7 | SPHINXPROJ = ezflow 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 21 | -------------------------------------------------------------------------------- /.readthedocs.yaml: -------------------------------------------------------------------------------- 1 | # .readthedocs.yaml 2 | # Read the Docs configuration file 3 | # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details 4 | 5 | # Required 6 | version: 2 7 | 8 | # Set the version of Python and other tools you might need 9 | build: 10 | os: ubuntu-20.04 11 | tools: 12 | python: "3.7" 13 | 14 | # Build documentation in the docs/ directory with Sphinx 15 | sphinx: 16 | configuration: docs/conf.py 17 | 18 | # If using Sphinx, optionally build your docs in additional formats such as PDF 19 | # formats: 20 | # - pdf 21 | 22 | # Optionally declare the Python requirements required to build your docs 23 | python: 24 | install: 25 | - requirements: docs/requirements.txt -------------------------------------------------------------------------------- /tests/utils/mock_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | from ezflow.modules import BaseModule 6 | 7 | 8 | class MockOpticalFlowModel(BaseModule): 9 | def __init__(self, img_channels): 10 | super().__init__() 11 | 12 | self.model = nn.Conv2d(img_channels * 2, 2, kernel_size=1) 13 | 14 | def forward(self, img1, img2): 15 | 16 | x = torch.cat([img1, img2], dim=-3) 17 | mock_flow_prediction = self.model(x) 18 | 19 | flow_up = F.interpolate( 20 | mock_flow_prediction, img1.shape[-2:], mode="bilinear", align_corners=True 21 | ) 22 | output = {"flow_preds": [mock_flow_prediction], "flow_upsampled": flow_up} 23 | return output 24 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 
1 | ================= 2 | EzFlow 3 | ================= 4 | 5 | 6 | .. image:: https://github.com/neu-vig/ezflow/actions/workflows/package-test.yml/badge.svg 7 | :target: https://github.com/neu-vig/ezflow/actions/workflows/package-test.yml 8 | :alt: Tests 9 | 10 | .. image:: https://readthedocs.org/projects/ezflow/badge/?version=latest 11 | :target: https://ezflow.readthedocs.io/en/latest/?version=latest 12 | :alt: Documentation Status 13 | 14 | .. image:: https://img.shields.io/pypi/v/ezflow.svg 15 | :target: https://pypi.python.org/pypi/ezflow 16 | 17 | 18 | A modular PyTorch library for optical flow estimation using neural networks 19 | 20 | * Free software: MIT license 21 | * Documentation: https://ezflow.readthedocs.io. 22 | -------------------------------------------------------------------------------- /docs/api/encoder/index.rst: -------------------------------------------------------------------------------- 1 | Encoder 2 | ================ 3 | 4 | Convolution Encoder 5 | -------------------- 6 | 7 | .. automodule:: ezflow.encoder.conv_encoder 8 | :members: 9 | 10 | 11 | 12 | GANet 13 | --------------- 14 | 15 | .. automodule:: ezflow.encoder.ganet 16 | :members: 17 | 18 | 19 | 20 | PSPNet 21 | --------------- 22 | 23 | .. automodule:: ezflow.encoder.pspnet 24 | :members: 25 | 26 | 27 | 28 | Pyramid 29 | --------------- 30 | 31 | .. automodule:: ezflow.encoder.pyramid 32 | :members: 33 | 34 | 35 | 36 | Residual 37 | --------------- 38 | 39 | .. automodule:: ezflow.encoder.residual 40 | :members: 41 | 42 | 43 | 44 | Builder 45 | --------------- 46 | 47 | .. 
automodule:: ezflow.encoder.build 48 | :members: 49 | 50 | -------------------------------------------------------------------------------- /configs/models/dicl.yaml: -------------------------------------------------------------------------------- 1 | NAME: DICL 2 | ENCODER: 3 | NAME: GANetBackbone 4 | IN_CHANNELS: 3 5 | OUT_CHANNELS: 32 6 | SIMILARITY: 7 | NAME: LearnableMatchingCost 8 | MAX_U: 3 9 | MAX_V: 3 10 | CONFIG: [64, 96, 128, 64, 32, 1] 11 | REMOVE_WARP_HOLE: True 12 | CUDA_COST_COMPUTE: False 13 | MATCHING_NET: 14 | NAME: Conv2DMatching 15 | CONFIG: [64, 96, 128, 64, 32, 1] 16 | DECODER: 17 | NAME: SoftArg2DFlowRegression 18 | MAX_U: 3 19 | MAX_V: 3 20 | OPERATION: argmax 21 | SEARCH_RANGE: [3, 3, 3, 3, 3] 22 | CONTEXT_NET: True 23 | SUP_RAW_FLOW: False 24 | SCALE_FACTORS: [0.25, 0.125, 0.0625, 0.03125, 0.015625] 25 | SCALE_CONTEXTS: [1.0, 1.0, 1.0, 1.0, 1.0] 26 | DAP: 27 | USE_DAP: True 28 | INIT_ID: True 29 | MAX_DISPLACEMENT: 3 30 | TEMPERATURE: False 31 | TEMP_FACTOR: 1.e-6 32 | 33 | -------------------------------------------------------------------------------- /tests/configs/custom_loss_trainer.yaml: -------------------------------------------------------------------------------- 1 | OPTIMIZER: 2 | NAME: Adam 3 | LR: 0.0003 4 | PARAMS: 5 | betas: [0.9, 0.999] 6 | eps: 1.e-08 7 | SCHEDULER: 8 | USE: True 9 | NAME: StepLR 10 | PARAMS: 11 | step_size: 10 12 | gamma: 0.1 13 | CRITERION: 14 | CUSTOM: True 15 | NAME: "MultiScaleLoss" 16 | PARAMS: 17 | weights: [1, 0.5, 0.25, 0.125, 0.0625] 18 | GRAD_CLIP: 19 | USE: True 20 | VALUE: 1.0 21 | TARGET_SCALE_FACTOR: 1 22 | APPEND_VALID_MASK: False 23 | MIXED_PRECISION: False 24 | FREEZE_BATCH_NORM: False 25 | DEVICE: "cpu" 26 | LOG_DIR: "./logs" 27 | LOG_ITERATIONS_INTERVAL: 1 28 | VALIDATE_INTERVAL: 1 29 | VALIDATE_ON: metric 30 | CKPT_DIR: "./ckpts" 31 | CKPT_INTERVAL: 1 32 | EPOCHS: 1 33 | NUM_STEPS: null 34 | RESUME_TRAINING: 35 | CONSOLIDATED_CKPT: null 36 | EPOCHS: 1 37 | START_EPOCH: null 
-------------------------------------------------------------------------------- /configs/trainers/pwcnet/pwcnet_chairs_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/chairs_baseline.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 8 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | max_lr: 0.0004 11 | total_steps: 1200100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | EPOCHS: null 24 | NUM_STEPS: 1200100 25 | LOG_DIR: "./logs" 26 | CKPT_DIR: "./ckpts" 27 | LOG_ITERATIONS_INTERVAL: 100 28 | CKPT_INTERVAL: 100000 29 | VALIDATE_INTERVAL: 10000 30 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /configs/trainers/pwcnet/pwcnet_kubric_improved_aug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/kubric_improved_aug.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 8 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | max_lr: 0.0004 11 | total_steps: 1200100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | EPOCHS: null 24 | NUM_STEPS: 1200100 25 | LOG_DIR: "./logs" 26 | CKPT_DIR: "./ckpts" 27 | LOG_ITERATIONS_INTERVAL: 100 28 | CKPT_INTERVAL: 100000 29 | VALIDATE_INTERVAL: 10000 30 | VALIDATE_ON: metric 
-------------------------------------------------------------------------------- /docs/api/utils/index.rst: -------------------------------------------------------------------------------- 1 | Utilities 2 | ================ 3 | 4 | IO 5 | --- 6 | 7 | .. automodule:: ezflow.utils.io 8 | :members: 9 | 10 | 11 | 12 | Metrics 13 | -------- 14 | 15 | .. automodule:: ezflow.utils.metrics 16 | :members: 17 | 18 | 19 | 20 | Other Utilities 21 | ---------------- 22 | 23 | .. automodule:: ezflow.utils.other_utils 24 | :members: 25 | 26 | 27 | Registry 28 | ------------ 29 | 30 | .. automodule:: ezflow.utils.registry 31 | :members: 32 | 33 | 34 | 35 | Resampling 36 | --------------- 37 | 38 | .. automodule:: ezflow.utils.resampling 39 | :members: 40 | 41 | 42 | 43 | Visualization 44 | --------------- 45 | 46 | .. automodule:: ezflow.utils.viz 47 | :members: 48 | 49 | 50 | 51 | Warping 52 | --------------- 53 | 54 | .. automodule:: ezflow.utils.warp 55 | :members: 56 | 57 | -------------------------------------------------------------------------------- /.github/workflows/linting.yml: -------------------------------------------------------------------------------- 1 | name: Linting and code style 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: ['3.7', '3.8'] 17 | 18 | steps: 19 | 20 | - uses: actions/checkout@v2 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 31 | 32 | - name: Run pre-commit hooks 33 | run: | 34 | pre-commit install 35 | pre-commit run -a -------------------------------------------------------------------------------- 
/docs/api/data/ezflow.data.dataset.rst: -------------------------------------------------------------------------------- 1 | Dataset 2 | ======================== 3 | 4 | Base Dataset 5 | ---------------------- 6 | 7 | .. automodule:: ezflow.data.dataset.base_dataset 8 | :members: 9 | 10 | 11 | 12 | Flying Chairs 13 | ---------------------- 14 | 15 | .. automodule:: ezflow.data.dataset.flying_chairs 16 | :members: 17 | 18 | 19 | 20 | Flying Things 3D 21 | ---------------------- 22 | 23 | .. automodule:: ezflow.data.dataset.flying_things3d 24 | :members: 25 | 26 | 27 | 28 | HD1K 29 | ---------------------- 30 | 31 | .. automodule:: ezflow.data.dataset.hd1k 32 | :members: 33 | 34 | 35 | 36 | KITTI 37 | ---------------------- 38 | 39 | .. automodule:: ezflow.data.dataset.kitti 40 | :members: 41 | 42 | 43 | 44 | MPI Sintel 45 | ---------------------- 46 | 47 | .. automodule:: ezflow.data.dataset.mpi_sintel 48 | :members: 49 | 50 | -------------------------------------------------------------------------------- /.github/workflows/publish.yml: -------------------------------------------------------------------------------- 1 | name: Publish to PyPI 2 | 3 | on: 4 | release: 5 | types: [published] 6 | 7 | jobs: 8 | 9 | deploy: 10 | runs-on: ubuntu-latest 11 | strategy: 12 | matrix: 13 | python-version: ['3.8'] 14 | 15 | steps: 16 | 17 | - uses: actions/checkout@v2 18 | 19 | - name: Set up Python ${{ matrix.python-version }} 20 | uses: actions/setup-python@v2 21 | with: 22 | python-version: ${{ matrix.python-version }} 23 | 24 | - name: Install dependencies 25 | run: | 26 | python -m pip install --upgrade pip 27 | pip install build 28 | 29 | - name: Build package 30 | run: python -m build 31 | 32 | - name: Publish package to PyPI 33 | uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 34 | with: 35 | user: __token__ 36 | password: ${{ secrets.PYPI_API_TOKEN }} -------------------------------------------------------------------------------- 
/ezflow/functional/registry.py: -------------------------------------------------------------------------------- 1 | from ..utils.registry import Registry 2 | 3 | FUNCTIONAL_REGISTRY = Registry("FUNCTIONAL") 4 | 5 | 6 | def get_functional(cfg_grp=None, name=None, **kwargs): 7 | """ 8 | Retrieve a component from the functional registry and call it with the given arguments 9 | 10 | Parameters 11 | ---------- 12 | cfg_grp : :class:`CfgNode` 13 | Configuration for the component 14 | name : str 15 | Name of the component 16 | kwargs : dict 17 | Additional keyword arguments forwarded to the component 18 | """ 19 | 20 | if cfg_grp is None: 21 | assert name is not None, "Must provide name or cfg_grp" 22 | assert dict(**kwargs) is not None, "Must provide either cfg_grp or kwargs"  # NOTE(review): always true (dict() never returns None), so this check never fires 23 | 24 | if name is None: 25 | name = cfg_grp.NAME 26 | 27 | fn = FUNCTIONAL_REGISTRY.get(name) 28 | 29 | if cfg_grp is None: 30 | return fn(**kwargs) 31 | 32 | return fn(cfg_grp, **kwargs) 33 | -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if "%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=python -msphinx 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=ezflow 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The Sphinx module was not found. Make sure you have Sphinx installed, 20 | echo.then set the SPHINXBUILD environment variable to point to the full 21 | echo.path of the 'sphinx-build' executable. Alternatively you may add the 22 | echo.Sphinx directory to PATH. 23 | echo. 
24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /tests/configs/base_trainer_test.yaml: -------------------------------------------------------------------------------- 1 | OPTIMIZER: 2 | NAME: Adam 3 | LR: 0.0003 4 | PARAMS: 5 | betas: [0.9, 0.999] 6 | eps: 1.e-08 7 | SCHEDULER: 8 | USE: True 9 | NAME: StepLR 10 | PARAMS: 11 | step_size: 10 12 | gamma: 0.1 13 | CRITERION: 14 | CUSTOM: False 15 | NAME: L1Loss 16 | PARAMS: null 17 | GRAD_CLIP: 18 | USE: True 19 | VALUE: 1.0 20 | TARGET_SCALE_FACTOR: 1 21 | APPEND_VALID_MASK: False 22 | MIXED_PRECISION: False 23 | FREEZE_BATCH_NORM: False 24 | SYNC_BATCH_NORM: False 25 | DEVICE: "cpu" 26 | LOG_DIR: "./logs" 27 | LOG_ITERATIONS_INTERVAL: 1 28 | VALIDATE_INTERVAL: 1 29 | VALIDATE_ON: metric 30 | CKPT_DIR: "./ckpts" 31 | CKPT_INTERVAL: 1 32 | EPOCHS: 1 33 | NUM_STEPS: null 34 | DISTRIBUTED: 35 | USE: False 36 | WORLD_SIZE: 2 37 | RANK: 0 38 | BACKEND: nccl 39 | MASTER_ADDR: localhost 40 | MASTER_PORT: "12355" 41 | RESUME_TRAINING: 42 | CONSOLIDATED_CKPT: null 43 | EPOCHS: 1 44 | START_EPOCH: null -------------------------------------------------------------------------------- /configs/trainers/pwcnet/pwcnet_things_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/things_baseline.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 4 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | max_lr: 0.000125 11 | total_steps: 1200100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 
16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | use_valid_range: True 24 | valid_range: [[1000,1000],[1000,1000],[1000,1000],[1000,1000],[1000,1000]] 25 | EPOCHS: null 26 | NUM_STEPS: 1200100 27 | LOG_DIR: "./logs" 28 | CKPT_DIR: "./ckpts" 29 | LOG_ITERATIONS_INTERVAL: 100 30 | CKPT_INTERVAL: 100000 31 | VALIDATE_INTERVAL: 10000 32 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /ezflow/model_zoo/model_zoo.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | """ 4 | 5 | 6 | class _ModelZooConfigs: 7 | 8 | MODEL_NAME_TO_CONFIG = { 9 | "RAFT": "raft.yaml", 10 | "RAFT_SMALL": "raft_small.yaml", 11 | "DICL": "dicl.yaml", 12 | "DCVNet": "dcvnet.yaml", 13 | "PWCNet": "pwcnet.yaml", 14 | "VCN": "vcn.yaml", 15 | "FlowNetS": "flownet_s.yaml", 16 | "FlowNetC": "flownet_c.yaml", 17 | } 18 | 19 | @staticmethod 20 | def get_names(): 21 | return list(_ModelZooConfigs.MODEL_NAME_TO_CONFIG.keys()) 22 | 23 | @staticmethod 24 | def query(model_name): 25 | 26 | if model_name in _ModelZooConfigs.MODEL_NAME_TO_CONFIG: 27 | 28 | cfg_file = _ModelZooConfigs.MODEL_NAME_TO_CONFIG[model_name] 29 | return cfg_file 30 | 31 | raise ValueError(f"Model name '{model_name}' not found in model zoo") 32 | -------------------------------------------------------------------------------- /.github/workflows/package-test.yml: -------------------------------------------------------------------------------- 1 | name: Package Tests 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: ['3.7', '3.8'] 17 | 18 | steps: 19 | 20 | - uses: 
actions/checkout@v2 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install black pytest 31 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 32 | pip install build 33 | 34 | - name: Build package 35 | run: python -m build 36 | 37 | - name: Test with pytest 38 | run: | 39 | pytest -------------------------------------------------------------------------------- /docs/api/models/index.rst: -------------------------------------------------------------------------------- 1 | Models 2 | ================= 3 | 4 | 5 | Builder 6 | ---------- 7 | 8 | .. automodule:: ezflow.models.build 9 | :members: 10 | 11 | 12 | 13 | DICL 14 | ---------- 15 | 16 | .. automodule:: ezflow.models.dicl 17 | :members: 18 | 19 | 20 | 21 | FlowNetS 22 | ---------- 23 | 24 | .. automodule:: ezflow.models.flownet_s 25 | :members: 26 | 27 | 28 | 29 | FlowNetC 30 | ---------- 31 | 32 | .. automodule:: ezflow.models.flownet_c 33 | :members: 34 | 35 | 36 | 37 | Predictor 38 | ---------- 39 | 40 | .. automodule:: ezflow.models.predictor 41 | :members: 42 | 43 | 44 | 45 | PWCNet 46 | ---------- 47 | 48 | .. automodule:: ezflow.models.pwcnet 49 | :members: 50 | 51 | 52 | 53 | RAFT 54 | ---------- 55 | 56 | .. automodule:: ezflow.models.raft 57 | :members: 58 | 59 | 60 | 61 | VCN 62 | ---------- 63 | 64 | .. automodule:: ezflow.models.vcn 65 | :members: 66 | 67 | 68 | -------------------------------------------------------------------------------- /docs/installation.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Installation 5 | ============ 6 | 7 | 8 | From source (recommended) 9 | ------------------------- 10 | 11 | EzFlow can be installed from the `GitHub repo`_. 
12 | 13 | Clone the public repository: 14 | 15 | .. code-block:: console 16 | 17 | $ git clone https://github.com/neu-vig/ezflow.git 18 | 19 | and then run the following command to install EzFlow: 20 | 21 | .. code-block:: console 22 | 23 | $ python setup.py install 24 | 25 | 26 | .. _Github repo: https://github.com/neu-vig/ezflow 27 | 28 | 29 | Stable release 30 | -------------- 31 | 32 | To install EzFlow, run this command in your terminal: 33 | 34 | .. code-block:: console 35 | 36 | $ pip install ezflow 37 | 38 | Note that EzFlow is an active project and routinely publishes new releases. In order to upgrade EzFlow to the latest version, use pip as follows: 39 | 40 | .. code-block:: console 41 | 42 | $ pip install -U ezflow 43 | 44 | 45 | 46 | -------------------------------------------------------------------------------- /configs/trainers/flownetc/flownetc_chairs_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/chairs_baseline.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 8 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | max_lr: 0.0004 11 | total_steps: 1200100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | DEVICE: "all" 24 | DISTRIBUTED: 25 | USE: True 26 | WORLD_SIZE: 4 27 | BACKEND: nccl 28 | MASTER_ADDR: localhost 29 | MASTER_PORT: "12355" 30 | SYNC_BATCH_NORM: True 31 | EPOCHS: null 32 | NUM_STEPS: 1200100 33 | LOG_DIR: "./logs" 34 | CKPT_DIR: "./ckpts" 35 | LOG_ITERATIONS_INTERVAL: 100 36 | CKPT_INTERVAL: 100000 37 | VALIDATE_INTERVAL: 10000 38 | VALIDATE_ON: metric -------------------------------------------------------------------------------- 
/configs/trainers/raft/raft_chairs_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/chairs_baseline.yaml" 2 | DATA: 3 | BATCH_SIZE: 10 4 | NORM_PARAMS: {"use": True, "mean":[127.5, 127.5, 127.5], "std":[127.5, 127.5, 127.5]} 5 | APPEND_VALID_MASK: True 6 | TRAIN_DATASET: 7 | FlyingChairs: 8 | CROP: 9 | USE: True 10 | SIZE: [368, 496] 11 | TYPE: "random" 12 | VAL_DATASET: 13 | FlyingChairs: 14 | PADDING: 1 15 | CROP: 16 | USE: True 17 | SIZE: [368, 496] 18 | TYPE: "center" 19 | SCHEDULER: 20 | USE: True 21 | NAME: OneCycleLR 22 | PARAMS: 23 | max_lr: 0.0004 24 | total_steps: 100100 25 | pct_start: 0.05 26 | cycle_momentum: False 27 | anneal_strategy: linear 28 | CRITERION: 29 | CUSTOM: True 30 | NAME: SequenceLoss 31 | PARAMS: 32 | gamma: 0.8 33 | max_flow: 400.0 34 | EPOCHS: null 35 | NUM_STEPS: 100100 36 | LOG_DIR: "./logs" 37 | CKPT_DIR: "./ckpts" 38 | LOG_ITERATIONS_INTERVAL: 100 39 | CKPT_INTERVAL: 20000 40 | VALIDATE_INTERVAL: 1000 41 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /configs/trainers/raft/raft_kubric_improved_aug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/kubric_improved_aug.yaml" 2 | DATA: 3 | BATCH_SIZE: 10 4 | NORM_PARAMS: {"use": True, "mean":[127.5, 127.5, 127.5], "std":[127.5, 127.5, 127.5]} 5 | APPEND_VALID_MASK: True 6 | TRAIN_DATASET: 7 | Kubric: 8 | CROP: 9 | USE: True 10 | SIZE: [368, 496] 11 | TYPE: "random" 12 | VAL_DATASET: 13 | Kubric: 14 | PADDING: 1 15 | CROP: 16 | USE: True 17 | SIZE: [368, 496] 18 | TYPE: "center" 19 | SCHEDULER: 20 | USE: True 21 | NAME: OneCycleLR 22 | PARAMS: 23 | max_lr: 0.0004 24 | total_steps: 100100 25 | pct_start: 0.05 26 | cycle_momentum: False 27 | anneal_strategy: linear 28 | CRITERION: 29 | CUSTOM: True 30 | NAME: SequenceLoss 31 | PARAMS: 32 | gamma: 0.8 33 | max_flow: 400.0 34 | EPOCHS: null 35 | NUM_STEPS: 
100100 36 | LOG_DIR: "./logs" 37 | CKPT_DIR: "./ckpts" 38 | LOG_ITERATIONS_INTERVAL: 100 39 | CKPT_INTERVAL: 20000 40 | VALIDATE_INTERVAL: 1000 41 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /configs/trainers/flownetc/flownetc_kubric_improved_aug.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/kubric_improved_aug.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 2 # Effective Batch Size = 2 x 4 GPUs = 8 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | max_lr: 0.0004 11 | total_steps: 1200100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | DEVICE: "all" 24 | DISTRIBUTED: 25 | USE: True 26 | WORLD_SIZE: 4 27 | BACKEND: nccl 28 | MASTER_ADDR: localhost 29 | MASTER_PORT: "12355" 30 | SYNC_BATCH_NORM: True 31 | EPOCHS: null 32 | NUM_STEPS: 1200100 33 | LOG_DIR: "./logs" 34 | CKPT_DIR: "./ckpts" 35 | LOG_ITERATIONS_INTERVAL: 100 36 | CKPT_INTERVAL: 100000 37 | VALIDATE_INTERVAL: 10000 38 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /ezflow/data/dataloader/device_dataloader.py: -------------------------------------------------------------------------------- 1 | def to_device(data, device): 2 | if isinstance(data, (list, tuple)): 3 | return [to_device(x, device) for x in data] 4 | 5 | return data.to(device) 6 | 7 | 8 | class DeviceDataLoader: 9 | """ 10 | A data loader wrapper to move data to a specific compute device. 
11 | 12 | Parameters 13 | ---------- 14 | data_loader : DataLoader 15 | The PyTorch DataLoader from torch.utils.data.dataloader 16 | device : torch.device 17 | The compute device 18 | """ 19 | 20 | def __init__(self, data_loader, device): 21 | 22 | self.data_loader = data_loader 23 | self.device = device 24 | 25 | def __iter__(self): 26 | """ 27 | Yield a batch of data after moving it to a device. 28 | 29 | """ 30 | for batch in self.data_loader: 31 | yield to_device(batch, self.device) 32 | 33 | def __len__(self): 34 | """ 35 | Return the number of batches. 36 | 37 | """ 38 | return len(self.data_loader) 39 | -------------------------------------------------------------------------------- /ezflow/decoder/build.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 3 | DECODER_REGISTRY = Registry("DECODER") 4 | 5 | 6 | def build_decoder(cfg_grp=None, name=None, instantiate=True, **kwargs): 7 | 8 | """ 9 | Build a decoder from a registered decoder name. 10 | 11 | Parameters 12 | ---------- 13 | cfg_grp : CfgNode 14 | Config to pass to the decoder. 15 | name : str 16 | Name of the registered decoder. 17 | instantiate : bool 18 | Whether to instantiate the decoder. 19 | 20 | Returns 21 | ------- 22 | decoder : object 23 | The decoder object. 
24 | """ 25 | 26 | if cfg_grp is None: 27 | assert name is not None, "Must provide name or cfg_grp" 28 | assert dict(**kwargs) is not None, "Must provide either cfg_grp or kwargs" 29 | 30 | if name is None: 31 | name = cfg_grp.NAME 32 | 33 | decoder = DECODER_REGISTRY.get(name) 34 | 35 | if not instantiate: 36 | return decoder 37 | 38 | if cfg_grp is None: 39 | return decoder(**kwargs) 40 | 41 | return decoder(cfg_grp, **kwargs) 42 | -------------------------------------------------------------------------------- /ezflow/decoder/noniterative/operators.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class FlowEntropy(nn.Module): 8 | """ 9 | Computes entropy from matching cost 10 | 11 | """ 12 | 13 | def __init__(self): 14 | super(FlowEntropy, self).__init__() 15 | 16 | def forward(self, x): 17 | """ 18 | Performs forward pass. 19 | 20 | Parameters 21 | ---------- 22 | x : torch.Tensor 23 | A tensor of shape B x U x V x H x W representing the cost 24 | 25 | Returns 26 | ------- 27 | torch.Tensor 28 | A tensor of shape B x 1 x H x W 29 | """ 30 | 31 | x = torch.squeeze(x, 1) 32 | B, U, V, H, W = x.shape 33 | x = x.view(B, -1, H, W) 34 | x = F.softmax(x, dim=1).view(B, U, V, H, W) 35 | 36 | global_entropy = ( 37 | (-x * torch.clamp(x, 1e-9, 1 - 1e-9).log()).sum(1).sum(1)[:, np.newaxis] 38 | ) 39 | global_entropy /= np.log(x.shape[1] * x.shape[2]) 40 | 41 | return global_entropy 42 | -------------------------------------------------------------------------------- /tools/evaluate.py: -------------------------------------------------------------------------------- 1 | from ezflow.data import DataloaderCreator 2 | from ezflow.engine import eval_model 3 | from ezflow.models import build_model 4 | 5 | if __name__ == "__main__": 6 | model = build_model( 7 | "DCVNet", 8 | default=True, 9 | 
weights_path="./pretrained_weights/dcvnet_sceneflow_step800k.pth", 10 | ) 11 | dataloader_creator = DataloaderCreator( 12 | batch_size=8, shuffle=False, num_workers=4, pin_memory=True 13 | ) 14 | dataloader_creator.add_Kitti( 15 | root_dir="./Datasets/KITTI2015/", 16 | split="training", 17 | crop=True, 18 | crop_type="center", 19 | crop_size=[370, 1224], 20 | norm_params={ 21 | "use": True, 22 | "mean": (127.5, 127.5, 127.5), 23 | "std": (127.5, 127.5, 127.5), 24 | }, 25 | ) 26 | 27 | kitti_data_loader = dataloader_creator.get_dataloader() 28 | eval_model( 29 | model, 30 | kitti_data_loader, 31 | device="0", 32 | pad_divisor=8, 33 | flow_scale=1.0, 34 | ) 35 | 36 | print("Evaluation Complete!!") 37 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021, Neelay Shah 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | 23 | -------------------------------------------------------------------------------- /ezflow/modules/build.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 3 | MODULE_REGISTRY = Registry("MODULE") 4 | 5 | 6 | def build_module(cfg_grp=None, name=None, instantiate=True, **kwargs): 7 | """ 8 | Build a module from a registered module name 9 | 10 | Parameters 11 | ---------- 12 | cfg_grp : :class:`CfgNode` 13 | Config to pass to the module 14 | name : str 15 | Name of the registered module 16 | instantiate : bool 17 | Whether to instantiate the module or not 18 | kwargs : dict 19 | Keyword arguments to pass to the module 20 | 21 | Returns 22 | ------- 23 | object 24 | The module object 25 | """ 26 | 27 | if cfg_grp is None: 28 | assert name is not None, "Must provide name or cfg_grp" 29 | assert dict(**kwargs) is not None, "Must provide either cfg_grp or kwargs" 30 | 31 | if name is None: 32 | name = cfg_grp.NAME 33 | 34 | module = MODULE_REGISTRY.get(name) 35 | 36 | if not instantiate: 37 | return module 38 | 39 | if cfg_grp is None: 40 | return module(**kwargs) 41 | 42 | return module(cfg_grp, **kwargs) 43 | -------------------------------------------------------------------------------- /configs/trainers/flownetc/flownetc_things_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/things_baseline.yaml" 2 | TARGET_SCALE_FACTOR: 20.0 3 | DATA: 4 | BATCH_SIZE: 2 # Effective Batch Size = 2 x 2 GPUs = 4 5 | NORM_PARAMS: {"use": True, "mean":[0.0, 0.0, 0.0], "std":[255.0, 255.0, 255.0]} 6 | SCHEDULER: 7 | USE: True 8 | NAME: OneCycleLR 9 | PARAMS: 10 | 
max_lr: 0.000125 11 | total_steps: 380100 12 | pct_start: 0.05 13 | cycle_momentum: False 14 | anneal_strategy: linear 15 | CRITERION: 16 | CUSTOM: True 17 | NAME: MultiScaleLoss 18 | PARAMS: 19 | norm: "l2" 20 | weights: [0.32, 0.08, 0.02, 0.01, 0.005] 21 | average: "sum" 22 | resize_flow: "downsample" 23 | use_valid_range: True 24 | valid_range: [[1000,1000],[1000,1000],[1000,1000],[1000,1000],[1000,1000]] 25 | DEVICE: "all" 26 | DISTRIBUTED: 27 | USE: True 28 | WORLD_SIZE: 2 29 | BACKEND: nccl 30 | MASTER_ADDR: localhost 31 | MASTER_PORT: "12355" 32 | SYNC_BATCH_NORM: True 33 | EPOCHS: null 34 | NUM_STEPS: 380100 35 | LOG_DIR: "./logs" 36 | CKPT_DIR: "./ckpts" 37 | LOG_ITERATIONS_INTERVAL: 100 38 | CKPT_INTERVAL: 100000 39 | VALIDATE_INTERVAL: 10000 40 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /ezflow/encoder/build.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 3 | ENCODER_REGISTRY = Registry("ENCODER") 4 | 5 | 6 | def build_encoder(cfg_grp=None, name=None, instantiate=True, **kwargs): 7 | 8 | """ 9 | Build an encoder from a registered encoder name 10 | 11 | Parameters 12 | ---------- 13 | cfg : :class:`CfgNode` 14 | Config to pass to the encoder 15 | name : str 16 | Name of the registered encoder 17 | instantiate : bool 18 | Whether to instantiate the encoder 19 | kwargs : dict 20 | Additional keyword arguments to pass to the encoder 21 | 22 | Returns 23 | ------- 24 | torch.nn.Module 25 | The encoder object 26 | """ 27 | 28 | if cfg_grp is None: 29 | assert name is not None, "Must provide name or cfg_grp" 30 | assert dict(**kwargs) is not None, "Must provide either cfg_grp or kwargs" 31 | 32 | if name is None: 33 | name = cfg_grp.NAME 34 | 35 | encoder = ENCODER_REGISTRY.get(name) 36 | 37 | if not instantiate: 38 | return encoder 39 | 40 | if cfg_grp is None: 41 | return encoder(**kwargs) 42 | 43 | return encoder(cfg_grp, 
**kwargs) 44 | -------------------------------------------------------------------------------- /configs/models/dcvnet.yaml: -------------------------------------------------------------------------------- 1 | NAME: DCVNet 2 | ENCODER: 3 | NAME: DCVNetBackbone 4 | IN_CHANNELS: 3 5 | OUT_CHANNELS: 256 6 | NORM: instance 7 | P_DROPOUT: 0.0 8 | LAYER_CONFIG: [64, 96, 128] 9 | SIMILARITY: 10 | NAME: MatryoshkaDilatedCostVolumeList 11 | NUM_GROUPS: 1 12 | MAX_DISPLACEMENT: 4 13 | ENCODER_OUTPUT_STRIDES: [2, 8] 14 | DILATIONS: [[1],[1, 2, 3, 5, 9, 16]] 15 | NORMALIZE_FEAT_L2: False 16 | USE_RELU: False 17 | DECODER: 18 | NAME: DCVDilatedFlowStackFilterDecoder 19 | FEAT_STRIDES: [2, 8] 20 | DILATIONS: [[1],[1, 2, 3, 5, 9, 16]] 21 | COST_VOLUME_FILTER: 22 | NAME: DCVFilterGroupConvStemJoint 23 | NUM_GROUPS: 1 24 | HIDDEN_DIM: 96 25 | FEAT_IN_PLANES: 128 26 | OUT_CHANNELS: 567 27 | USE_FILTER_RESIDUAL: True 28 | USE_GROUP_CONV_STEM: True 29 | NORM: none 30 | UNET: 31 | NAME: UNetBase 32 | NUM_GROUPS: 1 33 | IN_CHANNELS: 695 34 | HIDDEN_DIM: 96 35 | OUT_CHANNELS: 96 36 | NORM: none 37 | BOTTLE_NECK: 38 | NAME: ASPPConv2D 39 | IN_CHANNELS: 192 40 | HIDDEN_DIM: 192 41 | OUT_CHANNELS: 192 42 | DILATIONS: [2, 4, 8] 43 | NUM_GROUPS: 1 44 | NORM: none -------------------------------------------------------------------------------- /ezflow/similarity/build.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 3 | SIMILARITY_REGISTRY = Registry("SIMILARITY") 4 | 5 | 6 | def build_similarity(cfg_grp=None, name=None, instantiate=True, **kwargs): 7 | 8 | """ 9 | Build a similarity function from a registered similarity function name. 10 | 11 | Parameters 12 | ---------- 13 | cfg : CfgNode 14 | Config to pass to the similarity function. 15 | name : str 16 | Name of the registered similarity function. 17 | instantiate : bool 18 | Whether to instantiate the similarity function. 
19 | 20 | Returns 21 | ------- 22 | similarity_fn : object 23 | The similarity function object. 24 | """ 25 | 26 | if cfg_grp is None: 27 | assert name is not None, "Must provide name or cfg_grp" 28 | assert dict(**kwargs) is not None, "Must provide either cfg_grp or kwargs" 29 | 30 | if name is None: 31 | name = cfg_grp.NAME 32 | 33 | similarity_fn = SIMILARITY_REGISTRY.get(name) 34 | 35 | if not instantiate: 36 | return similarity_fn 37 | 38 | if cfg_grp is None: 39 | return similarity_fn(**kwargs) 40 | 41 | return similarity_fn(cfg_grp, **kwargs) 42 | -------------------------------------------------------------------------------- /configs/trainers/raft/raft_things_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/things_baseline.yaml" 2 | FREEZE_BATCH_NORM: True 3 | DATA: 4 | BATCH_SIZE: 6 5 | NORM_PARAMS: {"use": True, "mean":[127.5, 127.5, 127.5], "std":[127.5, 127.5, 127.5]} 6 | APPEND_VALID_MASK: True 7 | TRAIN_DATASET: 8 | FlyingThings3DClean: &TRAIN_DS_CONFIG 9 | CROP: 10 | USE: True 11 | SIZE: [400, 720] 12 | TYPE: "random" 13 | FlyingThings3DFinal: *TRAIN_DS_CONFIG 14 | VAL_DATASET: 15 | FlyingThings3DClean: &VAL_DS_CONFIG 16 | PADDING: 1 17 | CROP: 18 | USE: True 19 | SIZE: [368, 496] 20 | TYPE: "center" 21 | FlyingThings3DFinal: *VAL_DS_CONFIG 22 | SCHEDULER: 23 | USE: True 24 | NAME: OneCycleLR 25 | PARAMS: 26 | max_lr: 0.000125 27 | total_steps: 100100 28 | pct_start: 0.05 29 | cycle_momentum: False 30 | anneal_strategy: linear 31 | CRITERION: 32 | CUSTOM: True 33 | NAME: SequenceLoss 34 | PARAMS: 35 | gamma: 0.8 36 | max_flow: 400.0 37 | EPOCHS: null 38 | NUM_STEPS: 100100 39 | LOG_DIR: "./logs" 40 | CKPT_DIR: "./ckpts" 41 | LOG_ITERATIONS_INTERVAL: 100 42 | CKPT_INTERVAL: 20000 43 | VALIDATE_INTERVAL: 1000 44 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /ezflow/utils/warp.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def warp(x, flow): 6 | """ 7 | Warps an image x according to the optical flow field specified 8 | 9 | Parameters 10 | ---------- 11 | x : torch.Tensor 12 | Image to be warped 13 | flow : torch.Tensor 14 | Optical flow field 15 | 16 | Returns 17 | ------- 18 | torch.Tensor 19 | Warped image 20 | """ 21 | 22 | B, _, H, W = x.size() 23 | 24 | xx = torch.arange(0, W).view(1, -1).repeat(H, 1) 25 | yy = torch.arange(0, H).view(-1, 1).repeat(1, W) 26 | xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) 27 | yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) 28 | 29 | grid = torch.cat((xx, yy), 1).float() 30 | vgrid = torch.Tensor(grid).to(x.device) + flow 31 | vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :] / max(W - 1, 1) - 1.0 32 | vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :] / max(H - 1, 1) - 1.0 33 | vgrid = vgrid.permute(0, 2, 3, 1) 34 | 35 | output = nn.functional.grid_sample(x, vgrid, align_corners=True) 36 | 37 | mask = torch.ones_like(x) 38 | mask = nn.functional.grid_sample(mask, vgrid, align_corners=True) 39 | mask[mask < 0.9999] = 0 40 | mask[mask > 0] = 1 41 | 42 | return output * mask 43 | -------------------------------------------------------------------------------- /tests/utils/mock_data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | from ezflow.data import DataloaderCreator 5 | 6 | 7 | class MockOpticalFlowDataset(Dataset): 8 | def __init__(self, size, channels, length): 9 | 10 | self.length = length 11 | 12 | if not isinstance(size, list) and not isinstance(size, tuple): 13 | size = (size, size) 14 | 15 | self.imgs = torch.randn(length, channels, *size) 16 | self.flow = torch.randn(length, 2, *size) 17 | self.valid = torch.ones(1, *size) 18 | self.offset_labs = torch.randint(0, 1, (1, 567, 32, 32)) 19 | 20 | def __len__(self): 21 | 
return self.length 22 | 23 | def __getitem__(self, idx): 24 | target = {} 25 | target["flow_gt"] = self.flow[idx] 26 | target["valid"] = self.valid 27 | target["offset_labs"] = self.offset_labs 28 | return (self.imgs[idx], self.imgs[idx]), target 29 | 30 | 31 | class MockDataloaderCreator(DataloaderCreator): 32 | def __init__(self): 33 | super(MockDataloaderCreator, self).__init__(batch_size=1) 34 | 35 | self.dataset_list = [] 36 | self.dataset_list.append( 37 | MockOpticalFlowDataset(size=(64, 64), channels=3, length=4) 38 | ) 39 | -------------------------------------------------------------------------------- /ezflow/engine/pruning.py: -------------------------------------------------------------------------------- 1 | from torch.nn.utils import prune 2 | 3 | 4 | def prune_l1_unstructured(model, layer_type, proportion): 5 | """ 6 | L1 unstructured pruning 7 | 8 | Parameters 9 | ---------- 10 | model : torch.nn.Module 11 | The model to prune 12 | layer_type : torch.nn.Module 13 | The layer type to prune 14 | proportion : float 15 | The proportion of weights to prune 16 | """ 17 | 18 | for module in model.modules(): 19 | if isinstance(module, layer_type): 20 | prune.l1_unstructured(module, "weight", proportion) 21 | prune.remove(module, "weight") 22 | 23 | return model 24 | 25 | 26 | def prune_l1_structured(model, layer_type, proportion): 27 | """ 28 | L1 structured pruning 29 | 30 | Parameters 31 | ---------- 32 | model : torch.nn.Module 33 | The model to prune 34 | layer_type : torch.nn.Module 35 | The layer type to prune 36 | proportion : float 37 | The proportion of weights to prune 38 | """ 39 | 40 | for module in model.modules(): 41 | if isinstance(module, layer_type): 42 | prune.ln_structured(module, "weight", proportion, n=1, dim=1) 43 | prune.remove(module, "weight") 44 | 45 | return model 46 | -------------------------------------------------------------------------------- /.github/workflows/codecov.yml: 
-------------------------------------------------------------------------------- 1 | name: Code coverage 2 | 3 | on: 4 | push: 5 | branches: [ main ] 6 | pull_request: 7 | branches: [ main ] 8 | 9 | jobs: 10 | 11 | build: 12 | runs-on: ubuntu-latest 13 | strategy: 14 | fail-fast: false 15 | matrix: 16 | python-version: ['3.7', '3.8'] 17 | 18 | steps: 19 | 20 | - uses: actions/checkout@v2 21 | 22 | - name: Set up Python ${{ matrix.python-version }} 23 | uses: actions/setup-python@v2 24 | with: 25 | python-version: ${{ matrix.python-version }} 26 | 27 | - name: Install dependencies 28 | run: | 29 | python -m pip install --upgrade pip 30 | python -m pip install black pytest 31 | pip install pytest-cov codecov 32 | if [ -f requirements.txt ]; then pip install -r requirements.txt; fi 33 | pip install build 34 | 35 | - name: Build package 36 | run: python -m build 37 | 38 | - name: Tests 39 | run: pytest --cov-report xml --cov='./ezflow/' --cov-config=.coveragerc 40 | 41 | - name: Code Coverage Report 42 | uses: codecov/codecov-action@v1 43 | if: always() 44 | with: 45 | token: ${{ secrets.CODECOV_TOKEN }} 46 | fail_ci_if_error: false 47 | file: coverage.xml 48 | env_vars: OS,PYTHON -------------------------------------------------------------------------------- /ezflow/functional/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.optim as optim 3 | 4 | from .registry import FUNCTIONAL_REGISTRY 5 | 6 | 7 | @FUNCTIONAL_REGISTRY.register() 8 | class CosineWarmupScheduler(optim.lr_scheduler._LRScheduler): 9 | """ 10 | Cosine learning rate warmup scheduler 11 | 12 | Parameters 13 | ---------- 14 | optimizer : torch.optim.Optimizer 15 | Optimizer to be used with the scheduler 16 | warmup : int 17 | Number of epochs to warmup the learning rate 18 | max_iters : int 19 | Maximum number of iterations to train the model 20 | """ 21 | 22 | def __init__(self, optimizer, warmup=100, max_iters=200): 23 
| super().__init__(optimizer) 24 | 25 | self.warmup = warmup 26 | self.max_num_iters = max_iters 27 | 28 | def get_lr(self): 29 | 30 | lr_factor = self.get_lr_factor(epoch=self.last_epoch) 31 | 32 | return [base_lr * lr_factor for base_lr in self.base_lrs] 33 | 34 | def get_lr_factor(self, epoch): 35 | """ 36 | Parameters 37 | ---------- 38 | epoch : int 39 | Current epoch 40 | 41 | Returns 42 | ------- 43 | float 44 | Learning rate factor 45 | """ 46 | 47 | lr_factor = 0.5 * (1 + np.cos(np.pi * epoch / self.max_num_iters)) 48 | if epoch <= self.warmup: 49 | lr_factor *= epoch * 1.0 / self.warmup 50 | 51 | return lr_factor 52 | -------------------------------------------------------------------------------- /ezflow/engine/registry.py: -------------------------------------------------------------------------------- 1 | from torch.nn import CrossEntropyLoss, L1Loss, MSELoss 2 | from torch.optim import SGD, Adadelta, Adagrad, Adam, AdamW, RMSprop 3 | from torch.optim.lr_scheduler import ( 4 | CosineAnnealingLR, 5 | CosineAnnealingWarmRestarts, 6 | CyclicLR, 7 | MultiStepLR, 8 | OneCycleLR, 9 | ReduceLROnPlateau, 10 | StepLR, 11 | ) 12 | 13 | from ..functional import CosineWarmupScheduler 14 | from ..utils import Registry 15 | 16 | loss_functions = Registry("loss_functions") 17 | optimizers = Registry("optimizers") 18 | schedulers = Registry("schedulers") 19 | 20 | loss_functions.register(CrossEntropyLoss, "CrossEntropyLoss") 21 | loss_functions.register(MSELoss, "MSELoss") 22 | loss_functions.register(L1Loss, "L1Loss") 23 | 24 | optimizers.register(SGD, "SGD") 25 | optimizers.register(Adam, "Adam") 26 | optimizers.register(AdamW, "AdamW") 27 | optimizers.register(Adagrad, "Adagrad") 28 | optimizers.register(Adadelta, "Adadelta") 29 | optimizers.register(RMSprop, "RMSprop") 30 | 31 | schedulers.register(CosineAnnealingLR, "CosineAnnealingLR") 32 | schedulers.register(CosineAnnealingWarmRestarts, "CosineAnnealingWarmRestarts") 33 | schedulers.register(CyclicLR, 
"CyclicLR") 34 | schedulers.register(MultiStepLR, "MultiStepLR") 35 | schedulers.register(ReduceLROnPlateau, "ReduceLROnPlateau") 36 | schedulers.register(StepLR, "StepLR") 37 | schedulers.register(OneCycleLR, "OneCycleLR") 38 | schedulers.register(CosineWarmupScheduler, "CosineWarmupScheduler") 39 | -------------------------------------------------------------------------------- /ezflow/utils/metrics.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def endpointerror(pred, flow_gt, valid=None, multi_magnitude=False, **kwargs): 5 | """ 6 | Endpoint error 7 | 8 | Parameters 9 | ---------- 10 | pred : torch.Tensor 11 | Predicted flow 12 | flow_gt : torch.Tensor 13 | flow_gt flow 14 | valid : torch.Tensor 15 | Valid flow vectors 16 | 17 | Returns 18 | ------- 19 | torch.Tensor 20 | Endpoint error 21 | """ 22 | if isinstance(pred, tuple) or isinstance(pred, list): 23 | pred = pred[-1] 24 | 25 | epe = torch.norm(pred - flow_gt, p=2, dim=1) 26 | f1 = None 27 | 28 | if valid is not None: 29 | mag = torch.sum(flow_gt**2, dim=1).sqrt() 30 | 31 | epe = epe.view(-1) 32 | mag = mag.view(-1) 33 | val = valid.reshape(-1) >= 0.5 34 | 35 | f1 = ((epe > 3.0) & ((epe / mag) > 0.05)).float() 36 | 37 | epe = epe[val] 38 | f1 = f1[val].cpu().numpy() 39 | 40 | if not multi_magnitude: 41 | if f1 is not None: 42 | return epe.mean().item(), f1 43 | 44 | return epe.mean().item() 45 | 46 | epe = epe.view(-1) 47 | multi_magnitude_epe = { 48 | "epe": epe.mean().item(), 49 | "1px": (epe < 1).float().mean().item(), 50 | "3px": (epe < 3).float().mean().item(), 51 | "5px": (epe < 5).float().mean().item(), 52 | } 53 | 54 | if f1 is not None: 55 | return multi_magnitude_epe, f1 56 | 57 | return multi_magnitude_epe 58 | -------------------------------------------------------------------------------- /ezflow/data/build.py: -------------------------------------------------------------------------------- 1 | from ..utils import Registry 2 | 
3 | DATASET_REGISTRY = Registry("DATASET_REGISTRY") 4 | 5 | 6 | def build_dataloader(cfg, split="training", is_distributed=False, world_size=None): 7 | """ 8 | Build a dataloader from registered datasets 9 | 10 | Parameters 11 | ---------- 12 | cfg: :class:`CfgNode` 13 | Dataset configuration to instantiate DatalaoderCreator 14 | 15 | 16 | Returns 17 | ------- 18 | ezflow.data.DataloaderCreator 19 | 20 | """ 21 | from .dataloader import DataloaderCreator 22 | 23 | # TODO: assert mandatory config in cfg.data 24 | 25 | dataloader_creator = DataloaderCreator( 26 | batch_size=cfg.BATCH_SIZE, 27 | pin_memory=cfg.PIN_MEMORY, 28 | shuffle=cfg.SHUFFLE, 29 | num_workers=cfg.NUM_WORKERS, 30 | drop_last=cfg.DROP_LAST, 31 | init_seed=cfg.INIT_SEED, 32 | append_valid_mask=cfg.APPEND_VALID_MASK, 33 | distributed=is_distributed, 34 | world_size=world_size, 35 | ) 36 | 37 | data_cfg = cfg.TRAIN_DATASET if split.lower() == "training" else cfg.VAL_DATASET 38 | 39 | for key in data_cfg: 40 | data_cfg[key].INIT_SEED = cfg.INIT_SEED 41 | data_cfg[key].NORM_PARAMS = cfg.NORM_PARAMS 42 | data_cfg[key].APPEND_VALID_MASK = cfg.APPEND_VALID_MASK 43 | 44 | dataset = DATASET_REGISTRY.get(key)(data_cfg[key]) 45 | dataloader_creator.add_dataset(dataset) 46 | 47 | return dataloader_creator 48 | 49 | 50 | def get_dataset_list(): 51 | return DATASET_REGISTRY.get_list() 52 | -------------------------------------------------------------------------------- /ezflow/functional/criterion/sequence.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ...config import configurable 5 | from ..registry import FUNCTIONAL_REGISTRY 6 | 7 | 8 | @FUNCTIONAL_REGISTRY.register() 9 | class SequenceLoss(nn.Module): 10 | """ 11 | Sequence loss for optical flow estimation. 
12 | Used in **RAFT** (https://arxiv.org/abs/2003.12039) 13 | 14 | Parameters 15 | ---------- 16 | gamma : float 17 | Weight for the loss 18 | max_flow : float 19 | Maximum flow magnitude 20 | """ 21 | 22 | @configurable 23 | def __init__(self, gamma=0.8, max_flow=400, **kwargs): 24 | super(SequenceLoss, self).__init__() 25 | 26 | self.gamma = gamma 27 | self.max_flow = max_flow 28 | 29 | @classmethod 30 | def from_config(cls, cfg): 31 | return {"gamma": cfg.GAMMA, "max_flow": cfg.MAX_FLOW} 32 | 33 | def forward(self, flow_preds, flow_gt, valid, **kwargs): 34 | # detect NaN 35 | nan_mask = (~torch.isnan(flow_gt)).float() 36 | flow_gt[torch.isnan(flow_gt)] = 0.0 37 | 38 | n_preds = len(flow_preds) 39 | flow_loss = 0.0 40 | valid = torch.squeeze(valid, dim=1) 41 | 42 | mag = torch.sqrt(torch.sum(flow_gt**2, dim=1)) 43 | valid = (valid >= 0.5) & (mag < self.max_flow) 44 | 45 | for i in range(n_preds): 46 | 47 | i_weight = self.gamma ** (n_preds - i - 1) 48 | i_loss = torch.abs(flow_preds[i] - flow_gt) 49 | flow_loss += i_weight * torch.mean((valid[:, None] * i_loss)) 50 | 51 | return flow_loss 52 | -------------------------------------------------------------------------------- /ezflow/modules/models/recurrent.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ...config import configurable 5 | from ..build import MODULE_REGISTRY 6 | 7 | 8 | @MODULE_REGISTRY.register() 9 | class ConvGRU(nn.Module): 10 | """ 11 | Convolutinal GRU layer 12 | 13 | Parameters 14 | ---------- 15 | hidden_dim : int, optional 16 | Hidden dimension of the GRU 17 | input_dim : int, optional 18 | Input dimension of the GRU 19 | kernel_size : int, optional 20 | Kernel size of the convolutional layers 21 | """ 22 | 23 | @configurable 24 | def __init__(self, hidden_dim=128, input_dim=192 + 128, kernel_size=3): 25 | super(ConvGRU, self).__init__() 26 | 27 | self.convz = nn.Conv2d( 28 | hidden_dim + input_dim, 
hidden_dim, kernel_size, padding=1 29 | ) 30 | self.convr = nn.Conv2d( 31 | hidden_dim + input_dim, hidden_dim, kernel_size, padding=1 32 | ) 33 | self.convq = nn.Conv2d( 34 | hidden_dim + input_dim, hidden_dim, kernel_size, padding=1 35 | ) 36 | 37 | @classmethod 38 | def from_config(cls, cfg): 39 | return { 40 | "hidden_dim": cfg.HIDDEN_DIM, 41 | "input_dim": cfg.INPUT_DIM, 42 | "kernel_size": cfg.KERNEL_SIZE, 43 | } 44 | 45 | def forward(self, h, x): 46 | hx = torch.cat([h, x], dim=1) 47 | 48 | z = torch.sigmoid(self.convz(hx)) 49 | r = torch.sigmoid(self.convr(hx)) 50 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 51 | 52 | h = (1 - z) * h + z * q 53 | 54 | return h 55 | -------------------------------------------------------------------------------- /ezflow/functional/weight_annealers.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | from ..config import configurable 4 | from .registry import FUNCTIONAL_REGISTRY 5 | 6 | 7 | @FUNCTIONAL_REGISTRY.register() 8 | class CosineAnnealer(object): 9 | @configurable 10 | def __init__(self, init_weight, min_weight, max_iter): 11 | self.init_weight = init_weight 12 | self.min_weight = min_weight 13 | self.max_iter = max_iter 14 | 15 | @classmethod 16 | def from_config(cls, cfg): 17 | return { 18 | "init_weight": cfg.INIT_WEIGHT, 19 | "min_weight": cfg.MAX_weight, 20 | "max_iter": cfg.MAX_ITER, 21 | } 22 | 23 | def __call__(self, cur_iter): 24 | wt = ( 25 | self.min_weight 26 | + (self.init_weight - self.min_weight) 27 | * (1 + math.cos(math.pi * cur_iter / self.max_iter)) 28 | / 2 29 | ) 30 | return wt 31 | 32 | 33 | @FUNCTIONAL_REGISTRY.register() 34 | class PolyAnnealer(object): 35 | @configurable 36 | def __init__(self, init_weight, min_weight, max_iter, power): 37 | self.init_weight = init_weight 38 | self.min_weight = min_weight 39 | self.max_iter = max_iter 40 | self.power = power 41 | 42 | @classmethod 43 | def from_config(cls, cfg): 44 | return 
{ 45 | "init_weight": cfg.INIT_WEIGHT, 46 | "min_weight": cfg.MAX_weight, 47 | "max_iter": cfg.MAX_ITER, 48 | "power": cfg.POWER, 49 | } 50 | 51 | def __call__(self, cur_iter): 52 | wt = (self.init_weight - self.min_weight) * ( 53 | (1 - cur_iter / self.max_iter) ** (self.power) 54 | ) + self.min_weight 55 | return wt 56 | -------------------------------------------------------------------------------- /ezflow/engine/retrieve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | """ 4 | 5 | from ..config import get_cfg 6 | 7 | 8 | class _TrainerConfigs: 9 | """ 10 | Container class for training configurations. 11 | """ 12 | 13 | NAME_TO_TRAINER_CONFIG = { 14 | "BASE": "base.yaml", 15 | "RAFT": "raft_default.yaml", 16 | "DICL": "dicl_default.yaml", 17 | } 18 | 19 | @staticmethod 20 | def query(trainer_name): 21 | 22 | trainer_name = trainer_name.upper() 23 | 24 | if trainer_name in _TrainerConfigs.NAME_TO_TRAINER_CONFIG: 25 | 26 | cfg_file = _TrainerConfigs.NAME_TO_TRAINER_CONFIG[trainer_name] 27 | return cfg_file 28 | 29 | raise ValueError( 30 | f"Trainer name '{trainer_name}' not found in the training configs" 31 | ) 32 | 33 | 34 | def get_training_cfg(cfg_path=None, cfg_name=None, custom=True): 35 | 36 | """ 37 | Parameters 38 | ---------- 39 | cfg_path : str 40 | Path to the config file. 41 | cfg_name : str 42 | Name of the config file. 43 | custom : bool 44 | If True, the config file is assumed to be a custom config file. 45 | If False, the config file is assumed to be a standard config file present in ezflow/configs/trainers. 
46 | 47 | Returns 48 | ------- 49 | cfg : CfgNode 50 | The config object 51 | """ 52 | 53 | assert ( 54 | cfg_path is not None or cfg_name is not None 55 | ), "Either cfg_path or cfg_name must be provided" 56 | 57 | if cfg_path is None: 58 | cfg_path = _TrainerConfigs.query(cfg_name) 59 | return get_cfg(cfg_path, custom=False, grp="trainers") 60 | 61 | return get_cfg(cfg_path, custom=custom, grp="trainers") 62 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | env/ 12 | build/ 13 | develop-eggs/ 14 | dist/ 15 | downloads/ 16 | eggs/ 17 | .eggs/ 18 | lib/ 19 | lib64/ 20 | parts/ 21 | sdist/ 22 | var/ 23 | wheels/ 24 | *.egg-info/ 25 | .installed.cfg 26 | *.egg 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | 58 | # Flask stuff: 59 | instance/ 60 | .webassets-cache 61 | 62 | # Scrapy stuff: 63 | .scrapy 64 | 65 | # Sphinx documentation 66 | docs/_build/ 67 | 68 | # PyBuilder 69 | target/ 70 | 71 | # Jupyter Notebook 72 | .ipynb_checkpoints 73 | 74 | # pyenv 75 | .python-version 76 | 77 | # celery beat schedule file 78 | celerybeat-schedule 79 | 80 | # SageMath parsed files 81 | *.sage.py 82 | 83 | # dotenv 84 | .env 85 | 86 | # virtualenv 87 | .venv 88 | venv/ 89 | ENV/ 90 | 91 | # Spyder project settings 92 | .spyderproject 93 | .spyproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | # mypy 102 | .mypy_cache/ 103 | 104 | # IDE settings 105 | .vscode/ 106 | .idea/ 107 | 108 | # model_zoo 109 | ezflow/model_zoo/configs 110 | 111 | # unit test outputs 112 | ckpts/ 113 | logs/ -------------------------------------------------------------------------------- /configs/trainers/dcvnet/dcvnet_sceneflow_baseline.yaml: -------------------------------------------------------------------------------- 1 | _BASE_: "../_base_/sceneflow_baseline.yaml" 2 | DATA: 3 | BATCH_SIZE: 8 4 | NORM_PARAMS: {"use": True, "mean":[127.5, 127.5, 127.5], "std":[127.5, 127.5, 127.5]} 5 | TRAIN_DATASET: 6 | FlyingThings3DSubset: 7 | APPEND_VALID_MASK: True 8 | CROP: &TRAIN_CROP_CONFIG 9 | USE: True 10 | SIZE: [400, 720] 11 | TYPE: "random" 12 | FLOW_OFFSET_PARAMS: &FLOW_OFFSET_PARAMS { 13 | "use": True, 14 | "dilations": [[1], [1, 2, 3, 5, 9, 16]], 15 | "feat_strides": [2, 8], 16 | "search_radius": 4, 17 | } 18 | Driving: 19 | 
APPEND_VALID_MASK: True 20 | CROP: *TRAIN_CROP_CONFIG 21 | FLOW_OFFSET_PARAMS: *FLOW_OFFSET_PARAMS 22 | Monkaa: 23 | APPEND_VALID_MASK: True 24 | CROP: *TRAIN_CROP_CONFIG 25 | FLOW_OFFSET_PARAMS: *FLOW_OFFSET_PARAMS 26 | VAL_DATASET: 27 | MPISintelClean: 28 | APPEND_VALID_MASK: True 29 | PADDING: 1 30 | CROP: 31 | USE: True 32 | SIZE: [384, 1024] 33 | TYPE: "center" 34 | OPTIMIZER: 35 | NAME: AdamW 36 | LR: 0.0002 37 | PARAMS: 38 | weight_decay: 0.0001 39 | betas: [0.9, 0.999] 40 | eps: 1.e-08 41 | amsgrad: False 42 | GRAD_CLIP: 43 | USE: True 44 | VALUE: 1.0 45 | SCHEDULER: 46 | USE: True 47 | NAME: OneCycleLR 48 | PARAMS: 49 | max_lr: 0.0002 50 | epochs: 50 51 | pct_start: 0.05 52 | cycle_momentum: False 53 | anneal_strategy: linear 54 | final_div_factor: 10000 55 | CRITERION: 56 | CUSTOM: True 57 | NAME: FlowOffsetLoss 58 | PARAMS: 59 | gamma: 0.25 60 | max_flow: 500.0 61 | stride: 8 62 | weight_anneal_fn: CosineAnnealer 63 | min_weight: 0 64 | offset_loss_weight: [0, 1] 65 | EPOCHS: 50 66 | NUM_STEPS: null 67 | LOG_DIR: "./logs" 68 | CKPT_DIR: "./ckpts" 69 | LOG_ITERATIONS_INTERVAL: 100 70 | CKPT_INTERVAL: 20000 71 | VALIDATE_INTERVAL: 1000 72 | VALIDATE_ON: metric -------------------------------------------------------------------------------- /ezflow/config/retrieve.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | """ 4 | 5 | import os 6 | 7 | import pkg_resources 8 | 9 | from .config import CfgNode 10 | 11 | 12 | def get_cfg_obj(): 13 | """Returns a new :class:`CfgNode` created with new_allowed=True""" 14 | return CfgNode(new_allowed=True) 15 | 16 | 17 | def get_cfg_path(cfg_path, grp="models"): 18 | 19 | """ 20 | Returns the complete path to a config file present in ezflow 21 | 22 | Parameters 23 | ---------- 24 | cfg_path : str 25 | Config file path relative to ezflow's "configs/{grp}" directory 26 | grp : str, optional -- "models" (default) or "trainers", case-insensitive 27 | Returns 28 | ------- 29 | str 30 | The complete path to
the config file (a RuntimeError is raised if it does not exist). 31 | """ 32 | 33 | grp = grp.lower() 34 | assert grp in ("models", "trainers"), "Grp must be either 'models' or 'trainers' " 35 | 36 | if grp == "models": 37 | cfg_complete_path = pkg_resources.resource_filename( 38 | "ezflow.model_zoo", os.path.join("configs", grp, cfg_path) 39 | ) 40 | 41 | elif grp == "trainers": 42 | cfg_complete_path = pkg_resources.resource_filename( 43 | "ezflow.model_zoo", os.path.join("configs", grp, cfg_path) 44 | ) 45 | 46 | if not os.path.exists(cfg_complete_path): 47 | raise RuntimeError( 48 | f"{grp}/{cfg_path} is not available in ezflow's model zoo or trainer configs!" 49 | ) 50 | 51 | return cfg_complete_path 52 | 53 | 54 | def get_cfg(cfg_path, custom=False, grp="models"): 55 | 56 | """ 57 | Returns a config object loaded from an ezflow-bundled or custom config file. 58 | 59 | Parameters 60 | ---------- 61 | cfg_path : str 62 | Config file path; resolved with get_cfg_path(cfg_path, grp) unless custom is True 63 | custom : bool, optional -- if True, cfg_path is used as given (default False) 64 | grp : str, optional -- "models" (default) or "trainers"; only used when custom is False 65 | Returns 66 | CfgNode: a config object from get_cfg_obj() merged with the file at cfg_path 67 | """ 68 | 69 | if not custom: 70 | cfg_path = get_cfg_path(cfg_path, grp=grp) 71 | 72 | cfg = get_cfg_obj() 73 | cfg.merge_from_file(cfg_path) 74 | 75 | return cfg 76 | -------------------------------------------------------------------------------- /ezflow/similarity/correlation/layer.py: -------------------------------------------------------------------------------- 1 | # Adapted from https://github.com/oblime/CorrelationLayer 2 | 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from ...config import configurable 8 | from ..build import SIMILARITY_REGISTRY 9 | 10 | 11 | @SIMILARITY_REGISTRY.register() 12 | class CorrelationLayer(nn.Module): 13 | """ 14 | Correlation layer in pure PyTorch 15 | Only supports specific values of the following parameters: 16 | - kernel_size: 1 17 | - stride_1: 1 18 | - stride_2: 1 19 | - corr_multiply: 1 20 | 21 | Parameters 22 | ---------- 23 | pad_size : int 24 | Padding size for the correlation layer 25 | max_displacement : int 26 | Maximum displacement for the correlation
computation (forward() expects pad_size >= max_displacement; equal values center the search window) 27 | """ 28 | 29 | @configurable 30 | def __init__(self, pad_size=4, max_displacement=4): 31 | super().__init__() 32 | 33 | self.max_h_disp = max_displacement 34 | self.padlayer = nn.ConstantPad2d(pad_size, 0) 35 | 36 | @classmethod 37 | def from_config(cls, cfg): 38 | return { 39 | "pad_size": cfg.PAD_SIZE, 40 | "max_displacement": cfg.MAX_DISPLACEMENT, 41 | } 42 | 43 | def forward(self, features1, features2): 44 | 45 | features2_pad = self.padlayer(features2) 46 | offsety, offsetx = torch.meshgrid( 47 | [ 48 | torch.arange(0, 2 * self.max_h_disp + 1), 49 | torch.arange(0, 2 * self.max_h_disp + 1), 50 | ], 51 | indexing="ij", 52 | ) 53 | 54 | H, W = features1.shape[2], features1.shape[3] 55 | output = torch.cat( 56 | [ 57 | torch.mean( 58 | features1 * features2_pad[:, :, dy : dy + H, dx : dx + W], 59 | 1, 60 | keepdim=True, 61 | ) 62 | for dx, dy in zip(offsetx.reshape(-1), offsety.reshape(-1)) 63 | ], 64 | 1, 65 | ) 66 | 67 | return output 68 | -------------------------------------------------------------------------------- /ezflow/modules/dap.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..config import configurable 4 | from .build import MODULE_REGISTRY 5 | from .units import ConvNormRelu 6 | 7 | 8 | @MODULE_REGISTRY.register() 9 | class DisplacementAwareProjection(nn.Module): 10 | """ 11 | Displacement-aware projection layer 12 | 13 | Parameters 14 | ---------- 15 | max_displacement : int, optional 16 | Maximum displacement; the layer operates on (2 * max_displacement + 1) ** 2 displacement channels 17 | temperature : bool, optional 18 | If True, a 1x1 ConvNormRelu predicts a one-channel temperature map that multiplies the input; otherwise a 1x1 ConvNormRelu projects across the displacement channels 19 | temp_factor : float, optional 20 | Constant added to the predicted temperature map (only used when temperature is True) 21 | """ 22 | 23 | @configurable 24 | def __init__(self, max_displacement=3, temperature=False, temp_factor=1e-6): 25 | super(DisplacementAwareProjection, self).__init__() 26 | 27 | self.temperature = temperature 28 | self.temp_factor = temp_factor 29 | 30 | dim_c = (2 * max_displacement + 1) ** 2 31 | 32 | if
self.temperature: 33 | self.dap_layer = ConvNormRelu( 34 | dim_c, 1, kernel_size=1, padding=0, stride=1, norm=None, activation=None 35 | ) 36 | 37 | else: 38 | self.dap_layer = ConvNormRelu( 39 | dim_c, 40 | dim_c, 41 | kernel_size=1, 42 | padding=0, 43 | stride=1, 44 | norm=None, 45 | activation=None, 46 | ) 47 | 48 | @classmethod 49 | def from_config(cls, cfg): 50 | return { 51 | "max_displacement": cfg.MAX_DISPLACEMENT, 52 | "temperature": cfg.TEMPERATURE, 53 | "temp_factor": cfg.TEMP_FACTOR, 54 | } 55 | 56 | def forward(self, x): 57 | 58 | x = x.squeeze(1) 59 | bs, du, dv, h, w = x.shape 60 | x = x.view(bs, du * dv, h, w) 61 | 62 | if self.temperature: 63 | temp = self.dap_layer(x) + self.temp_factor 64 | x = x * temp 65 | else: 66 | x = self.dap_layer(x) 67 | 68 | return x.view(bs, du, dv, h, w).unsqueeze(1) 69 | -------------------------------------------------------------------------------- /ezflow/encoder/pyramid.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | from ..config import configurable 4 | from ..modules import conv 5 | from .build import ENCODER_REGISTRY 6 | 7 | 8 | @ENCODER_REGISTRY.register() 9 | class PyramidEncoder(nn.Module): 10 | """ 11 | Pyramid encoder which returns a hierarchy of features 12 | Used in **PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume** (https://arxiv.org/abs/1709.02371) 13 | 14 | Parameters 15 | ---------- 16 | in_channels : int 17 | Number of input channels 18 | config : list or tuple of int 19 | Output channels of each pyramid level; every level is a stride-2 conv followed by two stride-1 convs 20 | """ 21 | 22 | @configurable 23 | def __init__(self, in_channels=3, config=[16, 32, 64, 96, 128, 196]): 24 | super().__init__() 25 | 26 | if isinstance(config, tuple): 27 | config = list(config) 28 | config = [in_channels] + config 29 | 30 | self.encoder = nn.ModuleList() 31 | 32 | for i in range(len(config) - 1): 33 | self.encoder.append( 34 | nn.Sequential( 35 | conv(config[i], config[i + 1],
kernel_size=3, stride=2), 36 | conv(config[i + 1], config[i + 1], kernel_size=3, stride=1), 37 | conv(config[i + 1], config[i + 1], kernel_size=3, stride=1), 38 | ) 39 | ) 40 | 41 | @classmethod 42 | def from_config(self, cfg): 43 | return { 44 | "config": cfg.CONFIG, 45 | } 46 | 47 | def forward(self, img): 48 | """ 49 | Performs forward pass. 50 | 51 | Parameters 52 | ---------- 53 | img : torch.Tensor 54 | Input tensor 55 | 56 | Returns 57 | ------- 58 | List[torch.Tensor] 59 | Output of every encoder level in order, from the first (highest resolution) to the last (lowest resolution) 60 | """ 61 | 62 | feature_pyramid = [] 63 | x = img 64 | 65 | for i in range(len(self.encoder)): 66 | 67 | x = self.encoder[i](x) 68 | feature_pyramid.append(x) 69 | 70 | return feature_pyramid 71 | -------------------------------------------------------------------------------- /ezflow/decoder/context.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..config import configurable 5 | from ..modules import conv 6 | from .build import DECODER_REGISTRY 7 | 8 | 9 | @DECODER_REGISTRY.register() 10 | class ContextNetwork(nn.Module): 11 | """ 12 | PWCNet Context Network decoder 13 | 14 | Parameters 15 | ---------- 16 | in_channels: int, default: 565 17 | Number of input channels 18 | config : List[int], default : [128, 128, 96, 64, 32] 19 | Output channels of the dilated conv stack (dilations 1, 2, 4, 8, 16, 1; config[0] is used by the first two convs); a final 3x3 conv outputs 2 channels.
20 | """ 21 | 22 | @configurable 23 | def __init__(self, in_channels=565, config=[128, 128, 96, 64, 32]): 24 | super(ContextNetwork, self).__init__() 25 | 26 | self.context_net = nn.ModuleList( 27 | [ 28 | conv( 29 | in_channels, 30 | config[0], 31 | kernel_size=3, 32 | stride=1, 33 | padding=1, 34 | dilation=1, 35 | ), 36 | ] 37 | ) 38 | self.context_net.append( 39 | conv(config[0], config[0], kernel_size=3, stride=1, padding=2, dilation=2) 40 | ) 41 | self.context_net.append( 42 | conv(config[0], config[1], kernel_size=3, stride=1, padding=4, dilation=4) 43 | ) 44 | self.context_net.append( 45 | conv(config[1], config[2], kernel_size=3, stride=1, padding=8, dilation=8) 46 | ) 47 | self.context_net.append( 48 | conv(config[2], config[3], kernel_size=3, stride=1, padding=16, dilation=16) 49 | ) 50 | self.context_net.append( 51 | conv(config[3], config[4], kernel_size=3, stride=1, padding=1, dilation=1) 52 | ) 53 | self.context_net.append( 54 | nn.Conv2d(config[4], 2, kernel_size=3, stride=1, padding=1, bias=True) 55 | ) 56 | self.context_net = nn.Sequential(*self.context_net) 57 | 58 | @classmethod 59 | def from_config(self, cfg): 60 | return {"in_channels": cfg.IN_CHANNELS, "config": cfg.CONFIG} 61 | 62 | def forward(self, x): 63 | return self.context_net(x) 64 | -------------------------------------------------------------------------------- /ezflow/models/build.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ..config import get_cfg 4 | from ..model_zoo import _ModelZooConfigs 5 | from ..utils import Registry 6 | 7 | MODEL_REGISTRY = Registry("MODEL") 8 | 9 | 10 | def get_default_model_cfg(model_name): 11 | 12 | cfg_path = _ModelZooConfigs.query(model_name) 13 | 14 | return get_cfg(cfg_path) 15 | 16 | 17 | def get_model_list(): 18 | return _ModelZooConfigs.get_names() 19 | 20 | 21 | def build_model( 22 | name, cfg_path=None, custom_cfg=False, cfg=None, default=False, weights_path=None 23 | ): 24 |
""" 25 | Builds a registered model from its name and a config, optionally loading weights. Raises ValueError if name is not in MODEL_REGISTRY. 26 | 27 | Parameters 28 | ---------- 29 | name : str 30 | Name of the model to build 31 | cfg_path : str, optional 32 | Path to a config file. Required when cfg is None and default is False; 33 | ignored otherwise 34 | custom_cfg : bool, optional 35 | If True, cfg_path is read as given; if False, it is resolved inside 36 | the "configs/models" directory of ezflow.model_zoo 37 | cfg : CfgNode object, optional 38 | Config object. If provided, it is used directly and cfg_path, 39 | custom_cfg and default are ignored 40 | default : bool, optional 41 | If True (and cfg is None), load the model-zoo config returned by _ModelZooConfigs.query(name) 42 | weights_path : str, optional 43 | Path to a weights file, loaded onto the CPU; its "model_state_dict" entry is used if present 44 | 45 | Returns 46 | ------- 47 | torch.nn.Module 48 | The model 49 | """ 50 | 51 | if name not in MODEL_REGISTRY: 52 | raise ValueError(f"Model {name} not found in registry.") 53 | 54 | if cfg is None: 55 | 56 | if default: 57 | cfg_path = _ModelZooConfigs.query(name) 58 | cfg = get_cfg(cfg_path) 59 | 60 | else: 61 | assert cfg_path is not None, "Please provide a config path."
62 | cfg = get_cfg(cfg_path, custom=custom_cfg) 63 | 64 | model = MODEL_REGISTRY.get(name) 65 | model = model(cfg) 66 | 67 | if weights_path is not None: 68 | state_dict = torch.load(weights_path, map_location=torch.device("cpu")) 69 | if "model_state_dict" in state_dict: 70 | state_dict = state_dict["model_state_dict"] 71 | model.load_state_dict(state_dict) 72 | 73 | return model 74 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | alabaster==0.7.12 3 | antlr4-python3-runtime==4.8 4 | appdirs==1.4.4 5 | argh==0.26.2 6 | arrow==0.15.1 7 | attrs 8 | Babel==2.9.1 9 | backports.entry-points-selectable==1.1.0 10 | binaryornot==0.4.4 11 | black==21.7b0 12 | bleach==3.3.0 13 | brotlipy==0.7.0 14 | bump2version==0.5.11 15 | cachetools==4.2.2 16 | certifi==2022.12.7 17 | cffi==1.14.6 18 | cfgv==3.3.0 19 | chardet==4.0.0 20 | charset-normalizer==2.0.1 21 | click==8.0.1 22 | colorama==0.4.4 23 | coverage==5.5 24 | cryptography==3.4.8 25 | cycler==0.10.0 26 | distlib==0.3.6 27 | dnspython==1.16.0 28 | docopt==0.6.2 29 | docutils==0.17.1 30 | easydict==1.9 31 | entrypoints==0.3 32 | fett==0.3.2 33 | filelock==3.12.2 34 | flake8==3.9.2 35 | fvcore==0.1.5.post20210915 36 | grpcio==1.39.0 37 | identify==2.2.13 38 | idna==3.2 39 | imagesize==1.2.0 40 | importlib-metadata==6.7.0 41 | iniconfig==1.1.1 42 | iopath==0.1.9 43 | Jinja2==3.0.1 44 | jinja2-time==0.2.0 45 | kiwisolver 46 | Markdown==3.3.4 47 | MarkupSafe==2.0.1 48 | matplotlib==3.4.2 49 | mccabe==0.6.1 50 | mypy-extensions==0.4.3 51 | networkx==2.6.2 52 | nodeenv==1.6.0 53 | numpy==1.20.2 54 | oauthlib==3.2.2 55 | olefile==0.46 56 | omegaconf 57 | opencv-python==4.5.3.56 58 | packaging==21.0 59 | pathspec==0.9.0 60 | pathtools==0.1.2 61 | Pillow 62 | pkginfo==1.7.1 63 | platformdirs==3.8.0 64 | pluggy==0.13.1 65 | portalocker 66 | poyo==0.5.0 67 | pre-commit 68 |
protobuf==3.18.3 69 | py==1.10.0 70 | pyasn1==0.4.8 71 | pyasn1-modules==0.2.8 72 | pycodestyle==2.7.0 73 | pycparser==2.20 74 | pyflakes==2.3.1 75 | Pygments==2.9.0 76 | pymongo==3.11.4 77 | pyOpenSSL==20.0.1 78 | pyparsing==2.4.7 79 | PySocks==1.7.1 80 | pytest==6.2.4 81 | python-dateutil==2.8.1 82 | python-slugify==5.0.2 83 | pytz==2021.1 84 | PyYAML==5.4.1 85 | readme-renderer==29.0 86 | regex 87 | requests==2.26.0 88 | requests-oauthlib==1.3.0 89 | requests-toolbelt==0.9.1 90 | rsa==4.7.2 91 | scipy==1.7.0 92 | six 93 | snooty-lextudio==1.11.1.dev0 94 | snowballstemmer==2.1.0 95 | tabulate==0.8.9 96 | tensorboard==2.6.0 97 | tensorboard-data-server==0.6.1 98 | tensorboard-plugin-wit==1.8.0 99 | termcolor==1.1.0 100 | text-unidecode==1.3 101 | toml 102 | tomli==1.2.1 103 | torch>=1.9.0 104 | torchmetrics>=0.5.0 105 | torchvision>=0.10.0 106 | tornado 107 | tqdm 108 | twine==1.14.0 109 | typed-ast 110 | typing-extensions==4.6.3 111 | Unidecode==1.3.2 112 | urllib3==1.26.6 113 | virtualenv==20.23.1 114 | watchdog==1.0.2 115 | webencodings==0.5.1 116 | whichcraft==0.6.1 117 | yacs==0.1.8 118 | zipp==3.5.0 -------------------------------------------------------------------------------- /ezflow/encoder/dcvnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..config import configurable 5 | from .build import ENCODER_REGISTRY 6 | from .residual import BasicEncoder 7 | 8 | 9 | @ENCODER_REGISTRY.register() 10 | class DCVNetBackbone(nn.Module): 11 | """ 12 | ResNet-style encoder that outputs feature maps of size (H/2,W/2) and (H/8,W/8) 13 | used in `DCVNet: Dilated Cost Volume Networks for Fast Optical Flow `_ 14 | 15 | Parameters 16 | ---------- 17 | in_channels : int 18 | Number of input channels 19 | out_channels : int 20 | Channels of the (H/8,W/8) output; the (H/2,W/2) output has out_channels // 2 21 | norm : str 22 | Normalization layer to use.
One of "batch", "instance", "group", or None 23 | p_dropout : float 24 | Dropout probability 25 | layer_config : list of int or tuple of int 26 | Number of output features per layer 27 | 28 | """ 29 | 30 | @configurable 31 | def __init__( 32 | self, 33 | in_channels=3, 34 | out_channels=256, 35 | norm="instance", 36 | p_dropout=0.0, 37 | layer_config=(64, 96, 128), 38 | ): 39 | super(DCVNetBackbone, self).__init__() 40 | assert len(layer_config) == 3, "Invalid number of layers for DCVNetBackbone." 41 | 42 | self.encoder = BasicEncoder( 43 | in_channels=in_channels, 44 | norm=norm, 45 | p_dropout=p_dropout, 46 | layer_config=layer_config, 47 | num_residual_layers=(1, 2, 2), 48 | intermediate_features=True, 49 | ) 50 | 51 | self.conv_stride2 = nn.Conv2d(layer_config[0], out_channels // 2, kernel_size=1) 52 | self.conv_stride8 = nn.Conv2d(layer_config[2], out_channels, kernel_size=1) 53 | 54 | @classmethod 55 | def from_config(cls, cfg): 56 | return { 57 | "in_channels": cfg.IN_CHANNELS, 58 | "out_channels": cfg.OUT_CHANNELS, 59 | "norm": cfg.NORM, 60 | "p_dropout": cfg.P_DROPOUT, 61 | "layer_config": cfg.LAYER_CONFIG, 62 | } 63 | 64 | def forward(self, x): 65 | if isinstance(x, tuple) or isinstance(x, list): 66 | x = torch.cat(x, dim=0) 67 | 68 | feature_pyramid = self.encoder(x) 69 | 70 | # Use feature maps of downsampling size (H/2,W/2) and (H/8, W/8) 71 | context = [feature_pyramid[0], feature_pyramid[2]] 72 | 73 | feat = [] 74 | for x_i, conv_i in zip(context, [self.conv_stride2, self.conv_stride8]): 75 | feat.append(conv_i(x_i)) 76 | 77 | return feat, context 78 | -------------------------------------------------------------------------------- /ezflow/utils/registry.py: -------------------------------------------------------------------------------- 1 | """ 2 | Adapted from Detectron2 (https://github.com/facebookresearch/detectron2) 3 | """ 4 | 5 | 6 | class Registry: 7 | """ 8 | Class to register objects and then retrieve them by name.
9 | 10 | Parameters 11 | ---------- 12 | name : str 13 | Name of the registry 14 | """ 15 | 16 | def __init__(self, name): 17 | 18 | self._name = name 19 | self._obj_map = {} 20 | 21 | def _do_register(self, name, obj): 22 | 23 | assert ( 24 | name not in self._obj_map 25 | ), f"An object named '{name}' was already registered in '{self._name}' registry!" 26 | 27 | self._obj_map[name] = obj 28 | 29 | def register(self, obj=None, name=None): 30 | """ 31 | Method to register an object in the registry 32 | 33 | Parameters 34 | ---------- 35 | obj : object, optional 36 | Object to register. If None (default), a decorator is returned that registers and returns its argument; otherwise obj is registered and None is returned 37 | name : str, optional 38 | Name to register under, defaults to None (which uses the object's __name__); a duplicate name fails the assertion in _do_register 39 | """ 40 | 41 | if obj is None: 42 | 43 | def deco(func_or_class, name=name): 44 | if name is None: 45 | name = func_or_class.__name__ 46 | self._do_register(name, func_or_class) 47 | return func_or_class 48 | 49 | return deco 50 | 51 | if name is None: 52 | name = obj.__name__ 53 | 54 | self._do_register(name, obj) 55 | 56 | def get(self, name): 57 | """ 58 | Method to retrieve an object from the registry 59 | 60 | Parameters 61 | ---------- 62 | name : str 63 | Name of the object to retrieve 64 | 65 | Returns 66 | ------- 67 | object 68 | Object registered under the given name 69 | """ 70 | 71 | ret = self._obj_map.get(name) 72 | if ret is None: 73 | raise KeyError( 74 | f"No object named '{name}' found in '{self._name}' registry!"
75 | ) 76 | 77 | return ret 78 | 79 | def get_list(self): 80 | """ 81 | Method to retrieve the names of all objects in the registry 82 | 83 | Returns 84 | ------- 85 | list 86 | Names of all objects registered in the registry, in registration order 87 | """ 88 | 89 | return list(self._obj_map.keys()) 90 | 91 | def __contains__(self, name): 92 | return name in self._obj_map 93 | 94 | def __iter__(self): 95 | return iter(self._obj_map.items()) 96 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | .PHONY: clean clean-test clean-pyc clean-build docs help 2 | .DEFAULT_GOAL := help 3 | 4 | define BROWSER_PYSCRIPT 5 | import os, webbrowser, sys 6 | 7 | from urllib.request import pathname2url 8 | 9 | webbrowser.open("file://" + pathname2url(os.path.abspath(sys.argv[1]))) 10 | endef 11 | export BROWSER_PYSCRIPT 12 | 13 | define PRINT_HELP_PYSCRIPT 14 | import re, sys 15 | 16 | for line in sys.stdin: 17 | match = re.match(r'^([a-zA-Z_-]+):.*?## (.*)$$', line) 18 | if match: 19 | target, help = match.groups() 20 | print("%-20s %s" % (target, help)) 21 | endef 22 | export PRINT_HELP_PYSCRIPT 23 | 24 | BROWSER := python -c "$$BROWSER_PYSCRIPT" 25 | 26 | help: 27 | @python -c "$$PRINT_HELP_PYSCRIPT" < $(MAKEFILE_LIST) 28 | 29 | clean: clean-build clean-pyc clean-test ## remove all build, test, coverage and Python artifacts 30 | 31 | clean-build: ## remove build artifacts 32 | rm -fr build/ 33 | rm -fr dist/ 34 | rm -fr .eggs/ 35 | find . -name '*.egg-info' -exec rm -fr {} + 36 | find . -name '*.egg' -exec rm -f {} + 37 | 38 | clean-pyc: ## remove Python file artifacts 39 | find . -name '*.pyc' -exec rm -f {} + 40 | find . -name '*.pyo' -exec rm -f {} + 41 | find . -name '*~' -exec rm -f {} + 42 | find .
-name '__pycache__' -exec rm -fr {} + 43 | 44 | clean-test: ## remove test and coverage artifacts 45 | rm -fr .tox/ 46 | rm -f .coverage 47 | rm -fr htmlcov/ 48 | rm -fr .pytest_cache 49 | 50 | lint/flake8: ## check style with flake8 51 | flake8 ezflow tests 52 | lint/black: ## check style with black 53 | black --check ezflow tests 54 | 55 | lint: lint/flake8 lint/black ## check style 56 | 57 | test: ## run tests quickly with the default Python 58 | pytest 59 | 60 | test-all: ## run tests on every Python version with tox 61 | tox 62 | 63 | coverage: ## check code coverage quickly with the default Python 64 | coverage run --source ezflow -m pytest 65 | coverage report -m 66 | coverage html 67 | $(BROWSER) htmlcov/index.html 68 | 69 | docs: ## generate Sphinx HTML documentation, including API docs 70 | rm -f docs/ezflow.rst 71 | rm -f docs/modules.rst 72 | sphinx-apidoc -o docs/ ezflow 73 | $(MAKE) -C docs clean 74 | $(MAKE) -C docs html 75 | $(BROWSER) docs/_build/html/index.html 76 | 77 | servedocs: docs ## compile the docs watching for changes 78 | watchmedo shell-command -p '*.rst' -c '$(MAKE) -C docs html' -R -D .
79 | 80 | release: dist ## package and upload a release 81 | twine upload dist/* 82 | 83 | dist: clean ## builds source and wheel package 84 | python setup.py sdist 85 | python setup.py bdist_wheel 86 | ls -l dist 87 | 88 | install: clean ## install the package to the active Python's site-packages 89 | python setup.py install 90 | -------------------------------------------------------------------------------- /ezflow/models/flownet_s.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.init import constant_, kaiming_normal_ 5 | 6 | from ..decoder import build_decoder 7 | from ..encoder import build_encoder 8 | from ..modules import BaseModule 9 | from .build import MODEL_REGISTRY 10 | 11 | 12 | @MODEL_REGISTRY.register() 13 | class FlowNetS(BaseModule): 14 | """ 15 | Implementation of **FlowNetSimple** from the paper 16 | `FlowNet: Learning Optical Flow with Convolutional Networks `_ 17 | 18 | Parameters 19 | ---------- 20 | cfg : :class:`CfgNode` 21 | Configuration for the model 22 | """ 23 | 24 | def __init__(self, cfg): 25 | super(FlowNetS, self).__init__() 26 | 27 | self.cfg = cfg 28 | 29 | self.encoder = build_encoder(cfg.ENCODER) 30 | 31 | self.decoder = build_decoder(cfg.DECODER) 32 | 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 35 | kaiming_normal_(m.weight, 0.1) 36 | if m.bias is not None: 37 | constant_(m.bias, 0) 38 | elif isinstance(m, nn.BatchNorm2d): 39 | constant_(m.weight, 1) 40 | constant_(m.bias, 0) 41 | 42 | def forward(self, img1, img2): 43 | """ 44 | Performs forward pass of the network 45 | 46 | Parameters 47 | ---------- 48 | img1 : torch.Tensor 49 | Image to predict flow from 50 | img2 : torch.Tensor 51 | Image to predict flow to 52 | 53 | Returns 54 | ------- 55 | :class:`dict` 56 | "flow_preds" : list of intermediate flow predictions from img1 to img2 (decoder output, reversed) 57 |
"flow_upsampled" : torch.Tensor, only in eval mode; flow_preds[0] bilinearly resized to the input size with u and v rescaled accordingly 58 | """ 59 | 60 | H, W = img1.shape[-2:] 61 | 62 | x = torch.cat([img1, img2], axis=1) 63 | 64 | conv_outputs = self.encoder(x) 65 | 66 | flow_preds = self.decoder(conv_outputs) 67 | flow_preds.reverse() 68 | 69 | output = {"flow_preds": flow_preds} 70 | 71 | if self.training: 72 | return output 73 | 74 | flow = flow_preds[0] 75 | 76 | H_, W_ = flow.shape[-2:] 77 | flow = F.interpolate( 78 | flow, img1.shape[-2:], mode="bilinear", align_corners=False 79 | ) 80 | flow_u = flow[:, 0, :, :] * (W / W_) 81 | flow_v = flow[:, 1, :, :] * (H / H_) 82 | flow = torch.stack([flow_u, flow_v], dim=1) 83 | 84 | output["flow_upsampled"] = flow 85 | return output 86 | -------------------------------------------------------------------------------- /ezflow/engine/profiler.py: -------------------------------------------------------------------------------- 1 | from torch.profiler import ProfilerActivity, schedule, tensorboard_trace_handler 2 | 3 | 4 | class Profiler: 5 | """ 6 | This class is a wrapper to initialize the parameters of PyTorch profiler. 7 | An instance of this class can be passed as an argument to ezflow.engine.eval_model 8 | to enable profiling of the model during inference.
9 | 10 | `Official documentation on torch.profiler `_ 11 | 12 | Parameters 13 | ---------- 14 | model_name : str 15 | Name of the model (stored upper-cased) 16 | log_dir : str 17 | Directory passed to tensorboard_trace_handler for the profiling traces; must not be None 18 | profile_cpu : bool, optional 19 | Enable CPU profiling, by default False 20 | profile_cuda : bool, optional 21 | Enable CUDA profiling, by default False 22 | profile_memory : bool, optional 23 | Enable memory profiling, by default False 24 | record_shapes : bool, optional 25 | Enable shape recording for tensors, by default False 26 | skip_first : int, optional 27 | Number of initial profiler steps to ignore before the first cycle, by default 0 28 | wait : int, optional 29 | Number of idle (not recorded) steps at the start of each cycle, by default 0 30 | warmup : int, optional 31 | Number of warmup steps per cycle; must be non-zero, by default 1 32 | active : int, optional 33 | Number of recorded steps per cycle, by default 1 34 | repeat : int, optional 35 | Number of wait/warmup/active cycles to run, by default 10 36 | """ 37 | 38 | def __init__( 39 | self, 40 | model_name, 41 | log_dir, 42 | profile_cpu=False, 43 | profile_cuda=False, 44 | profile_memory=False, 45 | record_shapes=False, 46 | skip_first=0, 47 | wait=0, 48 | warmup=1, 49 | active=1, 50 | repeat=10, 51 | ): 52 | 53 | assert warmup != 0, "warmup cannot be 0, this can skew profiler results" 54 | assert ( 55 | log_dir is not None 56 | ), "log_dir path is not provided to save profiling logs" 57 | 58 | self.activites = [] 59 | self.model_name = model_name.upper() 60 | if profile_cpu: 61 | self.activites.append(ProfilerActivity.CPU) 62 | 63 | if profile_cuda: 64 | self.activites.append(ProfilerActivity.CUDA) 65 | 66 | self.profile_memory = profile_memory 67 | self.record_shapes = record_shapes 68 | 69 | self.schedule = schedule( 70 | skip_first=skip_first, 71 | wait=wait, 72 | warmup=warmup, 73 | active=active, 74 | repeat=repeat, 75 | ) 76 | 77 | self.on_trace_ready = tensorboard_trace_handler(log_dir) 78 |
-------------------------------------------------------------------------------- /ezflow/models/pwcnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..decoder import ContextNetwork, build_decoder 6 | from ..encoder import build_encoder 7 | from ..modules import BaseModule 8 | from .build import MODEL_REGISTRY 9 | 10 | 11 | @MODEL_REGISTRY.register() 12 | class PWCNet(BaseModule): 13 | """ 14 | Implementation of the paper 15 | `PWC-Net: CNNs for Optical Flow Using Pyramid, Warping, and Cost Volume `_ 16 | 17 | Parameters 18 | ---------- 19 | cfg : :class:`CfgNode` 20 | Configuration for the model 21 | """ 22 | 23 | def __init__(self, cfg): 24 | super(PWCNet, self).__init__() 25 | 26 | self.cfg = cfg 27 | self.encoder = build_encoder(cfg.ENCODER) 28 | 29 | self.decoder = build_decoder(cfg.DECODER) 30 | 31 | search_range = (2 * cfg.DECODER.SIMILARITY.MAX_DISPLACEMENT + 1) ** 2 32 | self.context_net = ContextNetwork( 33 | in_channels=search_range 34 | + cfg.DECODER.SIMILARITY.MAX_DISPLACEMENT 35 | + cfg.DECODER.CONFIG[-1] 36 | + sum(cfg.DECODER.CONFIG), 37 | config=cfg.DECODER.CONFIG, 38 | ) 39 | 40 | self._init_weights() 41 | 42 | def _init_weights(self): 43 | 44 | for m in self.modules(): 45 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 46 | nn.init.kaiming_normal_(m.weight.data, mode="fan_in") 47 | if m.bias is not None: 48 | m.bias.data.zero_() 49 | 50 | def forward(self, img1, img2): 51 | """ 52 | Performs forward pass of the network 53 | 54 | Parameters 55 | ---------- 56 | img1 : torch.Tensor 57 | Image to predict flow from 58 | img2 : torch.Tensor 59 | Image to predict flow to 60 | 61 | Returns 62 | ------- 63 | :class:`dict` 64 | "flow_preds" : list of intermediate flow predictions from img1 to img2; the last one is refined by adding the context network output 65 | "flow_upsampled" : torch.Tensor, only in eval mode; the last prediction bilinearly resized to the input size (flow values are not rescaled here) 66 | """ 67 | 68 | H, W = img1.shape[-2:] 69 | 70
| feature_pyramid1 = self.encoder(img1) 71 | feature_pyramid2 = self.encoder(img2) 72 | 73 | feature_pyramid1.reverse() 74 | feature_pyramid2.reverse() 75 | 76 | flow_preds, features = self.decoder(feature_pyramid1, feature_pyramid2) 77 | 78 | flow_preds[-1] += self.context_net(features) 79 | 80 | output = {"flow_preds": flow_preds} 81 | 82 | if self.training: 83 | return output 84 | 85 | flow_up = flow_preds[-1] 86 | 87 | flow_up = F.interpolate( 88 | flow_up, size=(H, W), mode="bilinear", align_corners=False 89 | ) 90 | 91 | output["flow_upsampled"] = flow_up 92 | 93 | return output 94 | -------------------------------------------------------------------------------- /tests/test_similarity.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ezflow.similarity import SIMILARITY_REGISTRY 4 | 5 | features1 = features2 = torch.rand(2, 32, 16, 16) 6 | 7 | 8 | def test_CorrelationLayer(): 9 | 10 | features1 = torch.rand(2, 8, 32, 32) 11 | features2 = torch.rand(2, 8, 32, 32) 12 | 13 | corr_fn = SIMILARITY_REGISTRY.get("CorrelationLayer")() 14 | _ = corr_fn(features1, features2) 15 | 16 | del corr_fn, features1, features2 17 | 18 | 19 | def test_IterSpatialCorrelationSampler(): 20 | 21 | features1 = torch.rand(2, 8, 32, 32) 22 | features2 = torch.rand(2, 8, 32, 32) 23 | 24 | corr_fn = SIMILARITY_REGISTRY.get("IterSpatialCorrelationSampler")() 25 | _ = corr_fn(features1, features2) 26 | 27 | del corr_fn, features1, features2 28 | 29 | 30 | def test_LearnableMatchingCost(): 31 | 32 | similarity_fn = SIMILARITY_REGISTRY.get("LearnableMatchingCost")() 33 | _ = similarity_fn(features1, features2) 34 | del similarity_fn 35 | 36 | 37 | def test_MultiScalePairwise4DCorr(): 38 | 39 | _ = SIMILARITY_REGISTRY.get("MutliScalePairwise4DCorr")(features1, features2) 40 | 41 | 42 | def test_MatryoshkaDilatedCostVolume(): 43 | 44 | features1 = torch.rand(2, 256, 32, 32) 45 | features2 = torch.rand(2, 256, 32, 32) 46 |
dilations = [1, 2, 3, 5, 9, 16] 47 | 48 | corr_fn = SIMILARITY_REGISTRY.get("MatryoshkaDilatedCostVolume")( 49 | max_displacement=4, dilations=dilations 50 | ) 51 | 52 | search_range = corr_fn.get_search_range() 53 | assert search_range == 9 54 | 55 | offsets = corr_fn.get_relative_offsets() 56 | assert offsets.shape == (len(dilations), search_range) 57 | 58 | _ = corr_fn(features1, features2) 59 | 60 | corr_fn = SIMILARITY_REGISTRY.get("MatryoshkaDilatedCostVolume")(use_relu=True) 61 | _ = corr_fn(features1, features2) 62 | 63 | del corr_fn, features1, features2 64 | 65 | 66 | def test_MatryoshkaDilatedCostVolumeList(): 67 | 68 | features1 = [torch.rand(1, 128, 128, 128), torch.rand(1, 256, 32, 32)] 69 | features2 = [torch.rand(1, 128, 128, 128), torch.rand(1, 256, 32, 32)] 70 | 71 | strides = [2, 8] 72 | dilations = [[1], [1, 2, 3, 5, 9, 16]] 73 | 74 | corr_fn = SIMILARITY_REGISTRY.get("MatryoshkaDilatedCostVolumeList")( 75 | max_displacement=4, encoder_output_strides=strides, dilations=dilations 76 | ) 77 | 78 | search_range = corr_fn.get_search_range() 79 | assert search_range == 9 80 | 81 | offsets = corr_fn.get_global_flow_offsets() 82 | assert offsets.shape == ( 83 | len(dilations[0]) + len(dilations[1]), 84 | search_range, 85 | search_range, 86 | 2, 87 | ) 88 | 89 | _ = corr_fn(features1, features2) 90 | 91 | corr_fn = SIMILARITY_REGISTRY.get("MatryoshkaDilatedCostVolumeList")( 92 | normalize_feat_l2=True 93 | ) 94 | _ = corr_fn(features1, features2) 95 | 96 | del corr_fn, features1, features2 97 | -------------------------------------------------------------------------------- /ezflow/similarity/correlation/pairwise.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from ...config import configurable 5 | from ...utils import bilinear_sampler 6 | from ..build import SIMILARITY_REGISTRY 7 | 8 | 9 | @SIMILARITY_REGISTRY.register() 10 | class
MutliScalePairwise4DCorr: 11 | """ 12 | Pairwise 4D correlation at multiple scales. Used in **RAFT** (https://arxiv.org/abs/2003.12039) 13 | 14 | Parameters 15 | ---------- 16 | fmap1 : torch.Tensor 17 | First feature map 18 | fmap2 : torch.Tensor 19 | Second feature map 20 | num_levels : int 21 | Number of levels in the feature pyramid 22 | corr_radius : int 23 | Radius of the correlation window 24 | """ 25 | 26 | @configurable 27 | def __init__(self, fmap1, fmap2, num_levels=4, corr_radius=4): 28 | 29 | self.num_levels = num_levels 30 | self.corr_radius = corr_radius 31 | self.corr_pyramid = [] 32 | 33 | corr = MutliScalePairwise4DCorr.corr(fmap1, fmap2) 34 | 35 | batch, h1, w1, dim, h2, w2 = corr.shape 36 | corr = corr.reshape(batch * h1 * w1, dim, h2, w2) 37 | 38 | self.corr_pyramid.append(corr) 39 | for _ in range(self.num_levels - 1): 40 | corr = F.avg_pool2d(corr, 2, stride=2) 41 | self.corr_pyramid.append(corr) 42 | 43 | def __call__(self, coords): 44 | 45 | r = self.corr_radius 46 | coords = coords.permute(0, 2, 3, 1) 47 | batch, h1, w1, _ = coords.shape 48 | 49 | out_pyramid = [] 50 | for i in range(self.num_levels): 51 | 52 | corr = self.corr_pyramid[i] 53 | dx = torch.linspace(-r, r, 2 * r + 1) 54 | dy = torch.linspace(-r, r, 2 * r + 1) 55 | delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to( 56 | coords.device 57 | ) 58 | 59 | centroid_lvl = coords.reshape(batch * h1 * w1, 1, 1, 2) / 2**i 60 | delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2) 61 | coords_lvl = centroid_lvl + delta_lvl 62 | 63 | corr = bilinear_sampler(corr, coords_lvl) 64 | corr = corr.view(batch, h1, w1, -1) 65 | out_pyramid.append(corr) 66 | 67 | out = torch.cat(out_pyramid, dim=-1) 68 | 69 | return out.permute(0, 3, 1, 2).contiguous().float() 70 | 71 | @classmethod 72 | def from_config(cls, cfg): 73 | return { 74 | "num_levels": cfg.NUM_LEVELS, 75 | "corr_radius": cfg.CORR_RADIUS, 76 | } 77 | 78 | @staticmethod 79 | def corr(fmap1, fmap2): 80 | 81 | batch, dim,
ht, wd = fmap1.shape 82 | fmap1 = fmap1.view(batch, dim, ht * wd) 83 | fmap2 = fmap2.view(batch, dim, ht * wd) 84 | 85 | corr = torch.matmul(fmap1.transpose(1, 2), fmap2) 86 | corr = corr.view(batch, ht, wd, 1, ht, wd) 87 | 88 | return corr / torch.sqrt(torch.tensor(dim).float()) 89 | -------------------------------------------------------------------------------- /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | absl-py==0.13.0 2 | alabaster==0.7.12 3 | antlr4-python3-runtime==4.8 4 | appdirs==1.4.4 5 | argh==0.26.2 6 | arrow==0.15.1 7 | attrs 8 | Babel==2.9.1 9 | backports.entry-points-selectable==1.1.0 10 | binaryornot==0.4.4 11 | black==21.7b0 12 | bleach==3.3.0 13 | brotlipy==0.7.0 14 | bump2version==0.5.11 15 | cachetools==4.2.2 16 | certifi==2022.12.7 17 | cffi==1.14.6 18 | cfgv==3.3.0 19 | chardet==4.0.0 20 | charset-normalizer==2.0.1 21 | click==8.0.1 22 | colorama==0.4.4 23 | coverage==5.5 24 | cryptography==3.4.8 25 | cycler==0.10.0 26 | distlib==0.3.2 27 | dnspython==1.16.0 28 | docopt==0.6.2 29 | docutils==0.17.1 30 | easydict==1.9 31 | entrypoints==0.3 32 | fett==0.3.2 33 | filelock==3.0.12 34 | flake8==3.9.2 35 | fvcore==0.1.5.post20210915 36 | google-auth==1.35.0 37 | google-auth-oauthlib==0.4.5 38 | grpcio==1.39.0 39 | identify==2.2.13 40 | idna==3.2 41 | imagesize==1.2.0 42 | importlib-metadata==0.23 43 | iniconfig==1.1.1 44 | iopath==0.1.9 45 | Jinja2==3.0.1 46 | jinja2-time==0.2.0 47 | kiwisolver 48 | Markdown==3.3.4 49 | MarkupSafe==2.0.1 50 | matplotlib==3.4.2 51 | mccabe==0.6.1 52 | mypy-extensions==0.4.3 53 | networkx==2.6.2 54 | nodeenv==1.6.0 55 | numpy==1.20.2 56 | oauthlib==3.2.2 57 | olefile==0.46 58 | omegaconf 59 | opencv-python==4.5.3.56 60 | packaging==21.0 61 | pathspec==0.9.0 62 | pathtools==0.1.2 63 | Pillow 64 | pkginfo==1.7.1 65 | platformdirs==2.0.0 66 | pluggy==0.13.1 67 | portalocker 68 | poyo==0.5.0 69 | pre-commit==2.14.0 70 | protobuf==3.18.3 71 | py==1.10.0 72 | 
pyasn1==0.4.8 73 | pyasn1-modules==0.2.8 74 | pycodestyle==2.7.0 75 | pycparser==2.20 76 | pyflakes==2.3.1 77 | Pygments==2.9.0 78 | pymongo==3.11.4 79 | pyOpenSSL==20.0.1 80 | pyparsing==2.4.7 81 | PySocks==1.7.1 82 | pytest==6.2.4 83 | python-dateutil==2.8.1 84 | python-slugify==5.0.2 85 | pytz==2021.1 86 | PyYAML==5.4.1 87 | readme-renderer==29.0 88 | regex 89 | requests==2.26.0 90 | requests-oauthlib==1.3.0 91 | requests-toolbelt==0.9.1 92 | rsa==4.7.2 93 | scipy==1.7.0 94 | six 95 | snooty-lextudio==1.11.1.dev0 96 | snowballstemmer==2.1.0 97 | Sphinx==4.1.2 98 | sphinx-rtd-theme==1.0.0 99 | sphinxcontrib-applehelp==1.0.2 100 | sphinxcontrib-devhelp==1.0.2 101 | sphinxcontrib-htmlhelp==2.0.0 102 | sphinxcontrib-jsmath==1.0.1 103 | sphinxcontrib-qthelp==1.0.3 104 | sphinxcontrib-serializinghtml==1.1.5 105 | sphinxcontrib-websupport==1.2.4 106 | tabulate==0.8.9 107 | tensorboard==2.6.0 108 | tensorboard-data-server==0.6.1 109 | tensorboard-plugin-wit==1.8.0 110 | termcolor==1.1.0 111 | text-unidecode==1.3 112 | toml 113 | tomli==1.2.1 114 | torch>=1.9.0 115 | torchmetrics>=0.5.0 116 | torchvision>=0.10.0 117 | tornado 118 | tox==3.14.0 119 | tqdm 120 | twine==1.14.0 121 | typed-ast 122 | typing-extensions 123 | ujson 124 | Unidecode==1.3.2 125 | urllib3==1.26.6 126 | virtualenv==20.5.0 127 | watchdog==1.0.2 128 | webencodings==0.5.1 129 | whichcraft==0.6.1 130 | yacs==0.1.8 131 | zipp 132 | -------------------------------------------------------------------------------- /ezflow/models/dcvnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..decoder import build_decoder 6 | from ..encoder import build_encoder 7 | from ..modules import BaseModule, build_module 8 | from ..similarity import build_similarity 9 | from ..utils import replace_relu 10 | from .build import MODEL_REGISTRY 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class 
DCVNet(BaseModule): 15 | """ 16 | Implementation of **DCVNet** from the paper 17 | `DCVNet: Dilated Cost Volume Networks for Fast Optical Flow `_ 18 | 19 | Parameters 20 | ---------- 21 | cfg : :class:`CfgNode` 22 | Configuration for the model 23 | """ 24 | 25 | def __init__(self, cfg): 26 | super(DCVNet, self).__init__() 27 | 28 | self.cfg = cfg 29 | 30 | self.encoder = build_encoder(self.cfg.ENCODER) 31 | self.cost_volume_list = build_similarity(self.cfg.SIMILARITY) 32 | 33 | if "DILATIONS" not in self.cfg.DECODER: 34 | self.cfg.DECODER.DILATIONS = self.cfg.SIMILARITY.DILATIONS 35 | 36 | if "SEARCH_RANGE" not in self.cfg.DECODER.COST_VOLUME_FILTER: 37 | self.cfg.DECODER.COST_VOLUME_FILTER.SEARCH_RANGE = ( 38 | self.cost_volume_list.get_search_range() 39 | ) 40 | 41 | self.decoder = build_decoder(self.cfg.DECODER) 42 | self = replace_relu(self, nn.LeakyReLU(negative_slope=0.1)) 43 | 44 | def forward(self, img1, img2): 45 | """ 46 | Performs forward pass of the network 47 | 48 | Parameters 49 | ---------- 50 | img1 : torch.Tensor 51 | Image to predict flow from 52 | img2 : torch.Tensor 53 | Image to predict flow to 54 | 55 | Returns 56 | ------- 57 | :class:`dict` 58 | torch.Tensor : intermediate flow predications from img1 to img2 59 | torch.Tensor : interpolated flow logits 60 | torch.Tensor : if model is in eval state, return upsampled flow 61 | """ 62 | N, C, H, W = img1.shape 63 | feat_map, context_map = self.encoder([img1, img2]) 64 | fmap1 = [feat_i[:N] for feat_i in feat_map] 65 | fmap2 = [feat_i[N:] for feat_i in feat_map] 66 | context_fmap1 = [context_i[:N] for context_i in context_map] 67 | 68 | assert len(fmap1) == len(self.cfg.SIMILARITY.DILATIONS) 69 | assert len(fmap1) == len(self.cfg.SIMILARITY.DILATIONS) 70 | 71 | cost = self.cost_volume_list(fmap1, fmap2) 72 | flow_offsets = self.cost_volume_list.get_global_flow_offsets().view( 73 | 1, -1, 2, 1, 1 74 | ) 75 | 76 | flow_list, flow_logits_list = self.decoder(cost, context_fmap1, flow_offsets) 77 | 
78 | output = {"flow_preds": flow_list, "flow_logits": flow_logits_list} 79 | 80 | if self.training: 81 | return output 82 | 83 | output["flow_upsampled"] = flow_list[-1] 84 | return output 85 | -------------------------------------------------------------------------------- /generate_dir_structure.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | 4 | class DisplayablePath(object): 5 | display_filename_prefix_middle = "├──" 6 | display_filename_prefix_last = "└──" 7 | display_parent_prefix_middle = " " 8 | display_parent_prefix_last = "│ " 9 | 10 | def __init__(self, path, parent_path, is_last): 11 | self.path = Path(str(path)) 12 | self.parent = parent_path 13 | self.is_last = is_last 14 | if self.parent: 15 | self.depth = self.parent.depth + 1 16 | else: 17 | self.depth = 0 18 | 19 | @classmethod 20 | def make_tree(cls, root, parent=None, is_last=False, criteria=None): 21 | root = Path(str(root)) 22 | criteria = criteria or cls._default_criteria 23 | 24 | displayable_root = cls(root, parent, is_last) 25 | yield displayable_root 26 | 27 | children = sorted( 28 | list(path for path in root.iterdir() if criteria(path)), 29 | key=lambda s: str(s).lower(), 30 | ) 31 | count = 1 32 | for path in children: 33 | is_last = count == len(children) 34 | if path.is_dir(): 35 | yield from cls.make_tree( 36 | path, 37 | parent=displayable_root, 38 | is_last=is_last, 39 | criteria=criteria, 40 | ) 41 | else: 42 | yield cls(path, displayable_root, is_last) 43 | count += 1 44 | 45 | @classmethod 46 | def _default_criteria(cls, path): 47 | return True 48 | 49 | @property 50 | def displayname(self): 51 | if self.path.is_dir(): 52 | return self.path.name + "/" 53 | return self.path.name 54 | 55 | def displayable(self): 56 | if self.parent is None: 57 | return self.displayname 58 | 59 | _filename_prefix = ( 60 | self.display_filename_prefix_last 61 | if self.is_last 62 | else self.display_filename_prefix_middle 63 
| ) 64 | 65 | parts = ["{!s} {!s}".format(_filename_prefix, self.displayname)] 66 | 67 | parent = self.parent 68 | while parent and parent.parent is not None: 69 | parts.append( 70 | self.display_parent_prefix_middle 71 | if parent.is_last 72 | else self.display_parent_prefix_last 73 | ) 74 | parent = parent.parent 75 | 76 | return "".join(reversed(parts)) 77 | 78 | 79 | if __name__ == "__main__": 80 | 81 | from argparse import ArgumentParser 82 | 83 | parser = ArgumentParser("Utility for displaying directory structure") 84 | parser.add_argument( 85 | "--dir", 86 | type=str, 87 | required=True, 88 | help="Name of the directory whose structure is to be generated", 89 | ) 90 | args = parser.parse_args() 91 | 92 | paths = DisplayablePath.make_tree( 93 | Path(args.dir), 94 | criteria=lambda path: True 95 | if path.name not in (".git", "__pycache__", "__init__.py") 96 | else False, 97 | ) 98 | for path in paths: 99 | print(path.displayable()) 100 | -------------------------------------------------------------------------------- /configs/trainers/_base_/chairs_baseline.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | BATCH_SIZE: 8 3 | NUM_WORKERS: 4 4 | PIN_MEMORY: True 5 | SHUFFLE: True 6 | INIT_SEED: False 7 | DROP_LAST: True 8 | TRAIN_DATASET: 9 | FlyingChairs: 10 | ROOT_DIR: "./Datasets/FlyingChairs_release/data" 11 | SPLIT: "training" 12 | IS_PREDICTION: False 13 | APPEND_VALID_MASK: False 14 | CROP: 15 | USE: True 16 | SIZE: [384, 448] 17 | TYPE: "random" 18 | FLOW_OFFSET_PARAMS: {"use": False} 19 | AUGMENTATION: 20 | # Augmentation Settings borrowed from RAFT 21 | USE: True 22 | PARAMS: 23 | color_aug_params: { 24 | "enabled": True, 25 | "asymmetric_color_aug_prob": 0.2, 26 | "brightness": 0.4, 27 | "contrast": 0.4, 28 | "saturation": 0.4, 29 | "hue": 0.15915494309189535 30 | } 31 | eraser_aug_params: { 32 | "enabled": True, 33 | "aug_prob": 0.5, 34 | "bounds": [50, 100] 35 | } 36 | noise_aug_params: { 37 | 
"enabled": False, 38 | "aug_prob": 0.5, 39 | "noise_std_range": 0.06 40 | } 41 | flip_aug_params: { 42 | "enabled": True, 43 | "h_flip_prob": 0.5, 44 | "v_flip_prob": 0.1 45 | } 46 | spatial_aug_params: { 47 | "enabled": True, 48 | "aug_prob": 0.8, 49 | "stretch_prob": 0.8, 50 | "min_scale": -0.1, 51 | "max_scale": 1.0, 52 | "max_stretch": 0.2, 53 | } 54 | advanced_spatial_aug_params: { 55 | "enabled": False, 56 | "scale1": 0.0, 57 | "scale2": 0.0, 58 | "stretch": 0.0, 59 | "rotate": 0.0, 60 | "translate": 0.0, 61 | "enable_out_of_boundary_crop": False 62 | } 63 | VAL_DATASET: 64 | FlyingChairs: 65 | ROOT_DIR: "./Datasets/FlyingChairs_release/data" 66 | SPLIT: "validation" 67 | APPEND_VALID_MASK: False 68 | IS_PREDICTION: False 69 | PADDING: 1 70 | CROP: 71 | USE: True 72 | SIZE: [384, 448] 73 | TYPE: "center" 74 | FLOW_OFFSET_PARAMS: {"use": False} 75 | AUGMENTATION: 76 | USE: False 77 | PARAMS: 78 | color_aug_params: {"enabled": False} 79 | eraser_aug_params: {"enabled": False} 80 | noise_aug_params: {"enabled": False} 81 | flip_aug_params: {"enabled": False} 82 | spatial_aug_params: {"enabled": False} 83 | advanced_spatial_aug_params: {"enabled": False} 84 | OPTIMIZER: 85 | NAME: AdamW 86 | LR: 0.0004 87 | PARAMS: 88 | weight_decay: 0.0001 89 | betas: [0.9, 0.999] 90 | eps: 1.e-08 91 | amsgrad: False 92 | GRAD_CLIP: 93 | USE: True 94 | VALUE: 1.0 95 | FREEZE_BATCH_NORM: False 96 | TARGET_SCALE_FACTOR: 1.0 97 | MIXED_PRECISION: False 98 | DEVICE: "0" 99 | DISTRIBUTED: 100 | USE: False 101 | WORLD_SIZE: 2 102 | BACKEND: nccl 103 | MASTER_ADDR: localhost 104 | MASTER_PORT: "12355" 105 | EPOCHS: null 106 | NUM_STEPS: null 107 | RESUME_TRAINING: 108 | CONSOLIDATED_CKPT: null 109 | EPOCHS: null 110 | START_EPOCH: null -------------------------------------------------------------------------------- /configs/trainers/_base_/kubric_baseline.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | BATCH_SIZE: 6 3 | 
APPEND_VALID_MASK: False 4 | NUM_WORKERS: 4 5 | PIN_MEMORY: True 6 | SHUFFLE: True 7 | INIT_SEED: False 8 | DROP_LAST: True 9 | TRAIN_DATASET: 10 | Kubric: 11 | ROOT_DIR: "./Datasets/KubricFlow" 12 | SPLIT: "training" 13 | USE_BACKWARD_FLOW: False 14 | SWAP_COLUMN_TO_ROW: True 15 | IS_PREDICTION: False 16 | APPEND_VALID_MASK: False 17 | CROP: 18 | USE: True 19 | SIZE: [384, 448] 20 | TYPE: "random" 21 | FLOW_OFFSET_PARAMS: {"use": False} 22 | AUGMENTATION: 23 | # Augmentation Settings borrowed from RAFT 24 | USE: True 25 | PARAMS: 26 | color_aug_params: { 27 | "enabled": True, 28 | "asymmetric_color_aug_prob": 0.2, 29 | "brightness": 0.4, 30 | "contrast": 0.4, 31 | "saturation": 0.4, 32 | "hue": 0.15915494309189535 33 | } 34 | eraser_aug_params: { 35 | "enabled": True, 36 | "aug_prob": 0.5, 37 | "bounds": [50, 100] 38 | } 39 | noise_aug_params: { 40 | "enabled": False, 41 | "aug_prob": 0.5, 42 | "noise_std_range": 0.06 43 | } 44 | flip_aug_params: { 45 | "enabled": True, 46 | "h_flip_prob": 0.5, 47 | "v_flip_prob": 0.1 48 | } 49 | spatial_aug_params: { 50 | "enabled": True, 51 | "aug_prob": 0.8, 52 | "stretch_prob": 0.8, 53 | "min_scale": -0.1, 54 | "max_scale": 1.0, 55 | "max_stretch": 0.2, 56 | } 57 | advanced_spatial_aug_params: { 58 | "enabled": False, 59 | "scale1": 0.0, 60 | "scale2": 0.0, 61 | "stretch": 0.0, 62 | "rotate": 0.0, 63 | "translate": 0.0, 64 | "enable_out_of_boundary_crop": False 65 | } 66 | VAL_DATASET: 67 | Kubric: 68 | ROOT_DIR: "./Datasets/KubricFlow" 69 | SPLIT: "validation" 70 | USE_BACKWARD_FLOW: False 71 | SWAP_COLUMN_TO_ROW: True 72 | APPEND_VALID_MASK: False 73 | IS_PREDICTION: False 74 | PADDING: 1 75 | CROP: 76 | USE: True 77 | SIZE: [384, 448] 78 | TYPE: "center" 79 | FLOW_OFFSET_PARAMS: {"use": False} 80 | AUGMENTATION: 81 | USE: False 82 | PARAMS: 83 | color_aug_params: {"enabled": False} 84 | eraser_aug_params: {"enabled": False} 85 | noise_aug_params: {"enabled": False} 86 | flip_aug_params: {"enabled": False} 87 | 
spatial_aug_params: {"enabled": False} 88 | advanced_spatial_aug_params: {"enabled": False} 89 | OPTIMIZER: 90 | NAME: AdamW 91 | LR: 0.0004 92 | PARAMS: 93 | weight_decay: 0.0001 94 | betas: [0.9, 0.999] 95 | eps: 1.e-08 96 | amsgrad: False 97 | GRAD_CLIP: 98 | USE: True 99 | VALUE: 1.0 100 | TARGET_SCALE_FACTOR: 1.0 101 | MIXED_PRECISION: False 102 | DEVICE: "0" 103 | DISTRIBUTED: 104 | USE: False 105 | WORLD_SIZE: 2 106 | BACKEND: nccl 107 | MASTER_ADDR: localhost 108 | MASTER_PORT: "12355" 109 | EPOCHS: null 110 | NUM_STEPS: null 111 | RESUME_TRAINING: 112 | CONSOLIDATED_CKPT: null 113 | EPOCHS: null 114 | START_EPOCH: null -------------------------------------------------------------------------------- /configs/trainers/_base_/things_baseline.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | BATCH_SIZE: 6 3 | APPEND_VALID_MASK: False 4 | NUM_WORKERS: 4 5 | PIN_MEMORY: True 6 | SHUFFLE: True 7 | INIT_SEED: False 8 | DROP_LAST: True 9 | TRAIN_DATASET: 10 | FlyingThings3DClean: &TRAIN_DS_CONFIG 11 | ROOT_DIR: "./Datasets/SceneFlow/FlyingThings3D" 12 | SPLIT: "training" 13 | IS_PREDICTION: False 14 | APPEND_VALID_MASK: False 15 | CROP: 16 | USE: True 17 | SIZE: [384, 768] 18 | TYPE: "random" 19 | FLOW_OFFSET_PARAMS: {"use": False} 20 | AUGMENTATION: 21 | # Augmentation Settings borrowed from RAFT 22 | USE: True 23 | PARAMS: 24 | color_aug_params: { 25 | "enabled": True, 26 | "asymmetric_color_aug_prob": 0.2, 27 | "brightness": 0.4, 28 | "contrast": 0.4, 29 | "saturation": 0.4, 30 | "hue": 0.15915494309189535 31 | } 32 | eraser_aug_params: { 33 | "enabled": True, 34 | "aug_prob": 0.5, 35 | "bounds": [50, 100] 36 | } 37 | noise_aug_params: { 38 | "enabled": False, 39 | "aug_prob": 0.5, 40 | "noise_std_range": 0.06 41 | } 42 | flip_aug_params: { 43 | "enabled": True, 44 | "h_flip_prob": 0.5, 45 | "v_flip_prob": 0.1 46 | } 47 | spatial_aug_params: { 48 | "enabled": True, 49 | "aug_prob": 0.8, 50 | "stretch_prob": 0.8, 
51 | "min_scale": -0.1, 52 | "max_scale": 1.0, 53 | "max_stretch": 0.2, 54 | } 55 | advanced_spatial_aug_params: { 56 | "enabled": False, 57 | "scale1": 0.0, 58 | "scale2": 0.0, 59 | "stretch": 0.0, 60 | "rotate": 0.0, 61 | "translate": 0.0, 62 | "enable_out_of_boundary_crop": False 63 | } 64 | FlyingThings3DFinal: *TRAIN_DS_CONFIG 65 | VAL_DATASET: 66 | FlyingThings3DClean: &VAL_DS_CONFIG 67 | ROOT_DIR: "./Datasets/SceneFlow/FlyingThings3D" 68 | SPLIT: "validation" 69 | APPEND_VALID_MASK: False 70 | IS_PREDICTION: False 71 | PADDING: 1 72 | CROP: 73 | USE: True 74 | SIZE: [384, 768] 75 | TYPE: "center" 76 | FLOW_OFFSET_PARAMS: {"use": False} 77 | AUGMENTATION: 78 | USE: False 79 | PARAMS: 80 | color_aug_params: {"enabled": False} 81 | eraser_aug_params: {"enabled": False} 82 | noise_aug_params: {"enabled": False} 83 | flip_aug_params: {"enabled": False} 84 | spatial_aug_params: {"enabled": False} 85 | advanced_spatial_aug_params: {"enabled": False} 86 | FlyingThings3DFinal: *VAL_DS_CONFIG 87 | OPTIMIZER: 88 | NAME: AdamW 89 | LR: 0.000125 90 | PARAMS: 91 | weight_decay: 0.0001 92 | betas: [0.9, 0.999] 93 | eps: 1.e-08 94 | amsgrad: False 95 | GRAD_CLIP: 96 | USE: True 97 | VALUE: 1.0 98 | FREEZE_BATCH_NORM: False 99 | TARGET_SCALE_FACTOR: 1.0 100 | MIXED_PRECISION: False 101 | DEVICE: "0" 102 | DISTRIBUTED: 103 | USE: False 104 | WORLD_SIZE: 2 105 | BACKEND: nccl 106 | MASTER_ADDR: localhost 107 | MASTER_PORT: "12355" 108 | EPOCHS: null 109 | NUM_STEPS: null 110 | RESUME_TRAINING: 111 | CONSOLIDATED_CKPT: null 112 | EPOCHS: null 113 | START_EPOCH: null -------------------------------------------------------------------------------- /configs/trainers/_base_/kubric_improved_aug.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | BATCH_SIZE: 6 3 | APPEND_VALID_MASK: False 4 | NUM_WORKERS: 4 5 | PIN_MEMORY: True 6 | SHUFFLE: True 7 | INIT_SEED: False 8 | DROP_LAST: True 9 | TRAIN_DATASET: 10 | Kubric: 11 | ROOT_DIR: 
"./Datasets/KubricFlow" 12 | SPLIT: "training" 13 | USE_BACKWARD_FLOW: False 14 | SWAP_COLUMN_TO_ROW: True 15 | IS_PREDICTION: False 16 | APPEND_VALID_MASK: False 17 | CROP: 18 | USE: True 19 | SIZE: [384, 448] 20 | TYPE: "random" 21 | FLOW_OFFSET_PARAMS: {"use": False} 22 | AUGMENTATION: 23 | # Spatial Augmentation Settings borrowed from AutoFlow: https://github.com/google-research/opticalflow-autoflow/blob/main/src/dataset_lib/augmentations/aug_params.py 24 | USE: True 25 | PARAMS: 26 | color_aug_params: { 27 | "enabled": True, 28 | "asymmetric_color_aug_prob": 0.2, 29 | "brightness": 0.4, 30 | "contrast": 0.4, 31 | "saturation": 0.4, 32 | "hue": 0.15915494309189535 33 | } 34 | eraser_aug_params: { 35 | "enabled": True, 36 | "aug_prob": 0.5, 37 | "bounds": [50, 100] 38 | } 39 | noise_aug_params: { 40 | "enabled": True, 41 | "aug_prob": 0.5, 42 | "noise_std_range": 0.06 43 | } 44 | flip_aug_params: { 45 | "enabled": True, 46 | "h_flip_prob": 0.5, 47 | "v_flip_prob": 0.1 48 | } 49 | spatial_aug_params: { 50 | "enabled": False, 51 | "aug_prob": 0.0, 52 | "stretch_prob": 0.0, 53 | "min_scale": 0, 54 | "max_scale": 0, 55 | "max_stretch": 0 56 | } 57 | advanced_spatial_aug_params: { 58 | "enabled": True, 59 | "scale1": 0.3, 60 | "scale2": 0.1, 61 | "rotate": 0.4, 62 | "translate": 0.4, 63 | "stretch": 0.3, 64 | "enable_out_of_boundary_crop": False 65 | } 66 | VAL_DATASET: 67 | Kubric: 68 | ROOT_DIR: "./Datasets/KubricFlow" 69 | SPLIT: "validation" 70 | USE_BACKWARD_FLOW: False 71 | SWAP_COLUMN_TO_ROW: True 72 | APPEND_VALID_MASK: False 73 | IS_PREDICTION: False 74 | PADDING: 1 75 | CROP: 76 | USE: True 77 | SIZE: [384, 448] 78 | TYPE: "center" 79 | FLOW_OFFSET_PARAMS: {"use": False} 80 | AUGMENTATION: 81 | USE: False 82 | PARAMS: 83 | color_aug_params: {"enabled": False} 84 | eraser_aug_params: {"enabled": False} 85 | noise_aug_params: {"enabled": False} 86 | flip_aug_params: {"enabled": False} 87 | spatial_aug_params: {"enabled": False} 88 | 
advanced_spatial_aug_params: {"enabled": False} 89 | OPTIMIZER: 90 | NAME: AdamW 91 | LR: 0.0004 92 | PARAMS: 93 | weight_decay: 0.0001 94 | betas: [0.9, 0.999] 95 | eps: 1.e-08 96 | amsgrad: False 97 | GRAD_CLIP: 98 | USE: True 99 | VALUE: 1.0 100 | FREEZE_BATCH_NORM: False 101 | TARGET_SCALE_FACTOR: 1.0 102 | MIXED_PRECISION: False 103 | DEVICE: "0" 104 | DISTRIBUTED: 105 | USE: False 106 | WORLD_SIZE: 2 107 | BACKEND: nccl 108 | MASTER_ADDR: localhost 109 | MASTER_PORT: "12355" 110 | EPOCHS: null 111 | NUM_STEPS: null 112 | RESUME_TRAINING: 113 | CONSOLIDATED_CKPT: null 114 | EPOCHS: null 115 | START_EPOCH: null -------------------------------------------------------------------------------- /ezflow/utils/resampling.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | from scipy import interpolate 5 | 6 | 7 | def forward_interpolate(flow): 8 | """ 9 | Forward interpolation of flow field 10 | 11 | Parameters 12 | ---------- 13 | flow : torch.Tensor 14 | Flow field to be interpolated 15 | 16 | Returns 17 | ------- 18 | torch.Tensor 19 | Forward interpolated flow field 20 | 21 | """ 22 | 23 | flow = flow.detach().cpu().numpy() 24 | dx, dy = flow[0], flow[1] 25 | 26 | ht, wd = dx.shape 27 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 28 | 29 | x1 = x0 + dx 30 | y1 = y0 + dy 31 | 32 | x1 = x1.reshape(-1) 33 | y1 = y1.reshape(-1) 34 | dx = dx.reshape(-1) 35 | dy = dy.reshape(-1) 36 | 37 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 38 | x1 = x1[valid] 39 | y1 = y1[valid] 40 | dx = dx[valid] 41 | dy = dy[valid] 42 | 43 | flow_x = interpolate.griddata( 44 | (x1, y1), dx, (x0, y0), method="nearest", fill_value=0 45 | ) 46 | 47 | flow_y = interpolate.griddata( 48 | (x1, y1), dy, (x0, y0), method="nearest", fill_value=0 49 | ) 50 | 51 | flow = np.stack([flow_x, flow_y], axis=0) 52 | 53 | return torch.from_numpy(flow).float() 54 | 55 | 56 | def 
bilinear_sampler(img, coords, mask=False): 57 | """ 58 | Biliear sampler for images 59 | 60 | Parameters 61 | ---------- 62 | img : torch.Tensor 63 | Image to be sampled 64 | coords : torch.Tensor 65 | Coordinates to be sampled 66 | 67 | Returns 68 | ------- 69 | torch.Tensor 70 | Sampled image 71 | """ 72 | 73 | H, W = img.shape[-2:] 74 | xgrid, ygrid = coords.split([1, 1], dim=-1) 75 | xgrid = 2 * xgrid / (W - 1) - 1 76 | ygrid = 2 * ygrid / (H - 1) - 1 77 | 78 | grid = torch.cat([xgrid, ygrid], dim=-1) 79 | img = F.grid_sample(img, grid, align_corners=True) 80 | 81 | if mask: 82 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 83 | return img, mask.float() 84 | 85 | return img 86 | 87 | 88 | def upflow(flow, scale=8, mode="bilinear"): 89 | """ 90 | Interpolate flow field 91 | 92 | Parameters 93 | ---------- 94 | flow : torch.Tensor 95 | Flow field to be interpolated 96 | scale : int 97 | Scale of the interpolated flow field 98 | mode : str 99 | Interpolation mode 100 | 101 | Returns 102 | ------- 103 | torch.Tensor 104 | Interpolated flow field 105 | """ 106 | 107 | new_size = (scale * flow.shape[-2], scale * flow.shape[-1]) 108 | 109 | return scale * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 110 | 111 | 112 | def convex_upsample_flow(flow, mask_logits, out_stride): # adapted from RAFT 113 | """ 114 | Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination 115 | 116 | Parameters 117 | ---------- 118 | flow : torch.Tensor 119 | Flow field to be upsampled 120 | mask_logits : torch.Tensor 121 | Mask logits 122 | out_stride : int 123 | Output stride 124 | 125 | Returns 126 | ------- 127 | torch.Tensor 128 | Upsampled flow field 129 | """ 130 | 131 | N, C, H, W = flow.shape 132 | mask_logits = mask_logits.view(N, 1, 9, out_stride, out_stride, H, W) 133 | mask_probs = torch.softmax(mask_logits, dim=2) 134 | 135 | up_flow = F.unfold(flow, [3, 3], padding=1) 136 | up_flow = up_flow.view(N, C, 9, 1, 1, H, W) 137 | 
138 | up_flow = torch.sum(mask_probs * up_flow, dim=2) 139 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 140 | 141 | return up_flow.reshape(N, C, out_stride * H, out_stride * W) 142 | -------------------------------------------------------------------------------- /CONTRIBUTING.rst: -------------------------------------------------------------------------------- 1 | .. highlight:: shell 2 | 3 | ============ 4 | Contributing 5 | ============ 6 | 7 | Contributions are welcome, and they are greatly appreciated! Every little bit 8 | helps, and credit will always be given. 9 | 10 | You can contribute in many ways: 11 | 12 | Types of Contributions 13 | ---------------------- 14 | 15 | Report Bugs 16 | ~~~~~~~~~~~ 17 | 18 | Report bugs at https://github.com/neu-vig/ezflow/issues. 19 | 20 | If you are reporting a bug, please include: 21 | 22 | * Your operating system name and version. 23 | * Any details about your local setup that might be helpful in troubleshooting. 24 | * Detailed steps to reproduce the bug. 25 | 26 | Fix Bugs 27 | ~~~~~~~~ 28 | 29 | Look through the GitHub issues for bugs. Anything tagged with "bug" and "help 30 | wanted" is open to whoever wants to implement it. 31 | 32 | Implement Features 33 | ~~~~~~~~~~~~~~~~~~ 34 | 35 | Look through the GitHub issues for features. Anything tagged with "enhancement" 36 | and "help wanted" is open to whoever wants to implement it. 37 | 38 | Write Documentation 39 | ~~~~~~~~~~~~~~~~~~~ 40 | 41 | EzFlow could always use more documentation, whether as part of the 42 | official EzFlow docs, in docstrings, or even on the web in blog posts, 43 | articles, and such. 44 | 45 | Submit Feedback 46 | ~~~~~~~~~~~~~~~ 47 | 48 | The best way to send feedback is to file an issue at https://github.com/neu-vig/ezflow/issues. 49 | 50 | If you are proposing a feature: 51 | 52 | * Explain in detail how it would work. 53 | * Keep the scope as narrow as possible, to make it easier to implement. 
54 | * Remember that this is a volunteer-driven project, and that contributions 55 | are welcome :) 56 | 57 | Get Started! 58 | ------------ 59 | 60 | Ready to contribute? Here's how to set up `ezflow` for local development. 61 | 62 | 1. Fork the `ezflow` repo on GitHub. 63 | 2. Clone your fork locally:: 64 | 65 | $ git clone git@github.com:your_name_here/ezflow.git 66 | 67 | 3. Create a Conda virtual environment using the `environment.yml` file. Install your local copy of the package into the environment:: 68 | 69 | $ conda env create -f environment.yml 70 | $ conda activate ezflow 71 | $ python setup.py develop 72 | 73 | 4. Set up pre-commit hooks:: 74 | 75 | $ pip install pre-commit 76 | $ pre-commit install 77 | 78 | 5. Create a branch for local development:: 79 | 80 | $ git checkout -b name-of-your-bugfix-or-feature 81 | 82 | Now you can make your changes locally. 83 | 84 | 6. Ensure you write tests for the code you add and run the tests before you commit. You can run tests locally using `pytest` from the root directory of the repository:: 85 | 86 | $ pytest 87 | 88 | 7. Commit your changes and push your branch to GitHub:: 89 | 90 | $ git add . 91 | $ git commit -m "Your detailed description of your changes." 92 | $ git push origin name-of-your-bugfix-or-feature 93 | 94 | Note: You might need to add and commit twice if pre-commit hooks modify your code. 95 | 96 | 8. Submit a pull request through the GitHub website. 97 | 98 | Pull Request Guidelines 99 | ----------------------- 100 | 101 | Before you submit a pull request, check that it meets these guidelines: 102 | 103 | 1. The pull request should include tests. 104 | 2. If the pull request adds functionality, appropriate docstrings should be added. 105 | 106 | Tips 107 | ---- 108 | 109 | To run a subset of tests, pass the corresponding test file to pytest:: 110 | 111 | $ pytest tests/test_models.py 112 | 113 | 114 | Deploying 115 | --------- 116 | 117 | A reminder for the maintainers on how to deploy.
118 | Make sure all your changes are committed (including an entry in HISTORY.rst). 119 | Then run:: 120 | 121 | $ bump2version patch # possible: major / minor / patch 122 | $ git push 123 | $ git push --tags 124 | 125 | Travis will then deploy to PyPI if tests pass. 126 | -------------------------------------------------------------------------------- /tools/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | from ezflow.data import build_dataloader, get_dataset_list 4 | from ezflow.engine import DistributedTrainer, Trainer, get_training_cfg 5 | from ezflow.models import build_model, get_model_list 6 | 7 | 8 | def main(args): 9 | 10 | # Load training configuration 11 | cfg = get_training_cfg(args.train_cfg) 12 | 13 | if args.device: 14 | cfg.DEVICE = args.device 15 | 16 | if args.train_ds is not None and args.train_data_dir is not None: 17 | cfg.DATA.TRAIN_DATASET[args.train_ds].ROOT_DIR = args.train_data_dir 18 | 19 | if args.val_ds is not None and args.val_data_dir is not None: 20 | cfg.DATA.VAL_DATASET[args.val_ds].ROOT_DIR = args.val_data_dir 21 | 22 | if args.n_epochs is not None: 23 | cfg.EPOCHS = args.n_epochs 24 | cfg.SCHEDULER.PARAMS.epochs = args.n_epochs 25 | 26 | cfg.LOG_DIR = args.log_dir 27 | cfg.CKPT_DIR = args.ckpt_dir 28 | 29 | # Create dataloader 30 | train_loader = build_dataloader( 31 | cfg.DATA, 32 | split="training", 33 | is_distributed=cfg.DISTRIBUTED.USE, 34 | world_size=cfg.DISTRIBUTED.WORLD_SIZE, 35 | ) 36 | 37 | val_loader = build_dataloader(cfg.DATA, split="validation") 38 | 39 | # Build model 40 | model = build_model(args.model, default=True) 41 | 42 | # Create trainer 43 | if cfg.DISTRIBUTED.USE is True: 44 | trainer = DistributedTrainer( 45 | cfg, 46 | model, 47 | train_loader_creator=train_loader, 48 | val_loader_creator=val_loader, 49 | ) 50 | else: 51 | trainer = Trainer( 52 | cfg, 53 | model, 54 | train_loader_creator=train_loader, 55 | 
val_loader_creator=val_loader, 56 | ) 57 | 58 | # Train model 59 | trainer.train(total_epochs=args.n_epochs) 60 | 61 | 62 | if __name__ == "__main__": 63 | 64 | parser = argparse.ArgumentParser( 65 | description="Train an optical flow model using EzFlow" 66 | ) 67 | parser.add_argument( 68 | "--train_cfg", 69 | type=str, 70 | required=True, 71 | help="Path to the training configuration file", 72 | ) 73 | parser.add_argument( 74 | "--train_ds", 75 | type=str, 76 | default=None, 77 | choices=get_dataset_list(), 78 | help="Name of the training dataset.", 79 | ) 80 | parser.add_argument( 81 | "--train_data_dir", 82 | type=str, 83 | default=None, 84 | help="Path to the root data directory", 85 | ) 86 | parser.add_argument( 87 | "--val_ds", 88 | type=str, 89 | default=None, 90 | choices=get_dataset_list(), 91 | help="Name of the validation dataset.", 92 | ) 93 | parser.add_argument( 94 | "--val_data_dir", 95 | type=str, 96 | default=None, 97 | help="Path to the root data directory", 98 | ) 99 | parser.add_argument( 100 | "--model", 101 | type=str, 102 | required=True, 103 | choices=get_model_list(), 104 | help="Name of the model to train", 105 | ) 106 | parser.add_argument( 107 | "--log_dir", 108 | type=str, 109 | required=True, 110 | help="Path to the log directory", 111 | ) 112 | parser.add_argument( 113 | "--ckpt_dir", 114 | type=str, 115 | required=True, 116 | help="Path to the log directory", 117 | ) 118 | parser.add_argument( 119 | "--n_epochs", type=int, default=None, help="Number of epochs to train" 120 | ) 121 | parser.add_argument( 122 | "--device", 123 | type=str, 124 | default="0", 125 | help="Device(s) to train on separated by commas. 
-1 for CPU", 126 | ) 127 | 128 | args = parser.parse_args() 129 | main(args) 130 | -------------------------------------------------------------------------------- /configs/trainers/_base_/sceneflow_baseline.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | BATCH_SIZE: 6 3 | APPEND_VALID_MASK: False 4 | NUM_WORKERS: 4 5 | PIN_MEMORY: True 6 | SHUFFLE: True 7 | INIT_SEED: False 8 | DROP_LAST: True 9 | TRAIN_DATASET: 10 | FlyingThings3DSubset: 11 | ROOT_DIR: "./Datasets/SceneFlow/FlyingThings3D_subset" 12 | SPLIT: "training" 13 | IS_PREDICTION: False 14 | APPEND_VALID_MASK: False 15 | CROP: &TRAIN_CROP_CONFIG 16 | USE: True 17 | SIZE: [400, 720] 18 | TYPE: "random" 19 | FLOW_OFFSET_PARAMS: {"use": False} 20 | AUGMENTATION: &TRAIN_AUGMENTATION_CONFIG 21 | # Augmentation Settings borrowed from RAFT 22 | USE: True 23 | PARAMS: 24 | color_aug_params: { 25 | "enabled": True, 26 | "asymmetric_color_aug_prob": 0.2, 27 | "brightness": 0.4, 28 | "contrast": 0.4, 29 | "saturation": 0.4, 30 | "hue": 0.15915494309189535 31 | } 32 | eraser_aug_params: { 33 | "enabled": True, 34 | "aug_prob": 0.5, 35 | "bounds": [50, 100] 36 | } 37 | noise_aug_params: { 38 | "enabled": False, 39 | "aug_prob": 0.5, 40 | "noise_std_range": 0.06 41 | } 42 | flip_aug_params: { 43 | "enabled": True, 44 | "h_flip_prob": 0.5, 45 | "v_flip_prob": 0.1 46 | } 47 | spatial_aug_params: { 48 | "enabled": True, 49 | "aug_prob": 0.8, 50 | "stretch_prob": 0.8, 51 | "min_scale": -0.1, 52 | "max_scale": 1.0, 53 | "max_stretch": 0.2, 54 | } 55 | advanced_spatial_aug_params: { 56 | "enabled": False, 57 | "scale1": 0.0, 58 | "scale2": 0.0, 59 | "stretch": 0.0, 60 | "rotate": 0.0, 61 | "translate": 0.0, 62 | "enable_out_of_boundary_crop": False 63 | } 64 | Driving: 65 | ROOT_DIR: "./Datasets/SceneFlow/Driving" 66 | IS_PREDICTION: False 67 | APPEND_VALID_MASK: False 68 | CROP: *TRAIN_CROP_CONFIG 69 | AUGMENTATION: *TRAIN_AUGMENTATION_CONFIG 70 | FLOW_OFFSET_PARAMS: 
{"use": False} 71 | Monkaa: 72 | ROOT_DIR: "./Datasets/SceneFlow/Monkaa" 73 | IS_PREDICTION: False 74 | APPEND_VALID_MASK: False 75 | CROP: *TRAIN_CROP_CONFIG 76 | AUGMENTATION: *TRAIN_AUGMENTATION_CONFIG 77 | FLOW_OFFSET_PARAMS: {"use": False} 78 | VAL_DATASET: 79 | MPISintelClean: 80 | ROOT_DIR: "./Datasets/MPI_Sintel/" 81 | SPLIT: "training" 82 | APPEND_VALID_MASK: False 83 | IS_PREDICTION: False 84 | PADDING: 1 85 | CROP: 86 | USE: True 87 | SIZE: [384, 1024] 88 | TYPE: "center" 89 | FLOW_OFFSET_PARAMS: {"use": False} 90 | AUGMENTATION: 91 | USE: False 92 | PARAMS: 93 | color_aug_params: {"enabled": False} 94 | eraser_aug_params: {"enabled": False} 95 | noise_aug_params: {"enabled": False} 96 | flip_aug_params: {"enabled": False} 97 | spatial_aug_params: {"enabled": False} 98 | advanced_spatial_aug_params: {"enabled": False} 99 | OPTIMIZER: 100 | NAME: AdamW 101 | LR: 0.000125 102 | PARAMS: 103 | weight_decay: 0.0001 104 | betas: [0.9, 0.999] 105 | eps: 1.e-08 106 | amsgrad: False 107 | GRAD_CLIP: 108 | USE: True 109 | VALUE: 1.0 110 | FREEZE_BATCH_NORM: False 111 | TARGET_SCALE_FACTOR: 1.0 112 | MIXED_PRECISION: False 113 | DEVICE: "0" 114 | DISTRIBUTED: 115 | USE: False 116 | WORLD_SIZE: 2 117 | BACKEND: nccl 118 | MASTER_ADDR: localhost 119 | MASTER_PORT: "12355" 120 | EPOCHS: null 121 | NUM_STEPS: null 122 | RESUME_TRAINING: 123 | CONSOLIDATED_CKPT: null 124 | EPOCHS: null 125 | START_EPOCH: null -------------------------------------------------------------------------------- /ezflow/models/flownet_c.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torch.nn.init import constant_, kaiming_normal_ 5 | 6 | from ..decoder import build_decoder 7 | from ..encoder import BasicConvEncoder, build_encoder 8 | from ..modules import BaseModule, conv 9 | from ..similarity import IterSpatialCorrelationSampler as SpatialCorrelationSampler 10 | from 
.build import MODEL_REGISTRY 11 | 12 | 13 | @MODEL_REGISTRY.register() 14 | class FlowNetC(BaseModule): 15 | """ 16 | Implementation of **FlowNetCorrelation** from the paper 17 | `FlowNet: Learning Optical Flow with Convolutional Networks `_ 18 | 19 | Parameters 20 | ---------- 21 | cfg : :class:`CfgNode` 22 | Configuration for the model 23 | """ 24 | 25 | def __init__(self, cfg): 26 | super(FlowNetC, self).__init__() 27 | 28 | self.cfg = cfg 29 | 30 | channels = cfg.ENCODER.CONFIG 31 | cfg.ENCODER.CONFIG = channels[:3] 32 | 33 | self.feature_encoder = build_encoder(cfg.ENCODER) 34 | 35 | self.correlation_layer = SpatialCorrelationSampler( 36 | kernel_size=1, 37 | patch_size=2 * cfg.SIMILARITY.MAX_DISPLACEMENT + 1, 38 | padding=cfg.SIMILARITY.PAD_SIZE, 39 | dilation_patch=2, 40 | ) 41 | 42 | self.corr_activation = nn.LeakyReLU(negative_slope=0.1, inplace=True) 43 | 44 | self.conv_redirect = conv( 45 | in_channels=cfg.ENCODER.CONFIG[-1], out_channels=32, norm=cfg.ENCODER.NORM 46 | ) 47 | 48 | self.corr_encoder = BasicConvEncoder( 49 | in_channels=473, config=channels[3:], norm=cfg.ENCODER.NORM 50 | ) 51 | 52 | self.decoder = build_decoder(cfg.DECODER) 53 | 54 | for m in self.modules(): 55 | if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d): 56 | kaiming_normal_(m.weight, 0.1) 57 | if m.bias is not None: 58 | constant_(m.bias, 0) 59 | elif isinstance(m, nn.BatchNorm2d): 60 | constant_(m.weight, 1) 61 | constant_(m.bias, 0) 62 | 63 | def forward(self, img1, img2): 64 | """ 65 | Performs forward pass of the network 66 | 67 | Parameters 68 | ---------- 69 | img1 : torch.Tensor 70 | Image to predict flow from 71 | img2 : torch.Tensor 72 | Image to predict flow to 73 | 74 | Returns 75 | ------- 76 | :class:`dict` 77 | torch.Tensor : intermediate flow predications from img1 to img2 78 | torch.Tensor : if model is in eval state, return upsampled flow 79 | """ 80 | 81 | H, W = img1.shape[-2:] 82 | 83 | conv_outputs1 = self.feature_encoder(img1) 84 | 
conv_outputs2 = self.feature_encoder(img2) 85 | 86 | corr_output = self.correlation_layer(conv_outputs1[-1], conv_outputs2[-1]) 87 | corr_output = corr_output.view( 88 | corr_output.shape[0], -1, corr_output.shape[3], corr_output.shape[4] 89 | ) 90 | corr_output = self.corr_activation(corr_output) 91 | 92 | # Redirect final feature output of img1 93 | conv_redirect_output = self.conv_redirect(conv_outputs1[-1]) 94 | 95 | x = torch.cat([conv_redirect_output, corr_output], dim=1) 96 | 97 | conv_outputs = self.corr_encoder(x) 98 | 99 | # Add first two convolution output from img1 100 | conv_outputs = [conv_outputs1[0], conv_outputs1[1]] + conv_outputs 101 | 102 | flow_preds = self.decoder(conv_outputs) 103 | 104 | output = {"flow_preds": flow_preds} 105 | 106 | if self.training: 107 | return output 108 | 109 | flow_up = flow_preds[-1] 110 | 111 | flow_up = F.interpolate( 112 | flow_up, size=(H, W), mode="bilinear", align_corners=False 113 | ) 114 | 115 | output["flow_upsampled"] = flow_up 116 | 117 | return output 118 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import codecs 2 | import glob 3 | import os 4 | import shutil 5 | 6 | from setuptools import find_packages, setup 7 | 8 | # Basic information 9 | NAME = "ezflow" 10 | DESCRIPTION = "A PyTorch library for optical flow estimation using neural networks" 11 | VERSION = "0.2.5" 12 | AUTHOR = "EzFlow Contributors" 13 | EMAIL = "shahnh19@gmail.com" 14 | LICENSE = "MIT" 15 | REPOSITORY = "https://github.com/neu-vig/ezflow" 16 | PACKAGE = "ezflow" 17 | 18 | with open("README.md", "r") as f: 19 | LONG_DESCRIPTION = f.read() 20 | 21 | # Define the keywords 22 | KEYWORDS = ["optical flow", "pytorch", "machine learning", "deep learning"] 23 | 24 | # Define the classifiers 25 | # See https://pypi.python.org/pypi?%3Aaction=list_classifiers 26 | CLASSIFIERS = [ 27 | "Development Status :: 2 - 
Pre-Alpha", 28 | "Intended Audience :: Developers", 29 | "License :: OSI Approved :: MIT License", 30 | "Natural Language :: English", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.6", 33 | "Programming Language :: Python :: 3.7", 34 | "Programming Language :: Python :: 3.8", 35 | ] 36 | 37 | # Important Paths 38 | PROJECT = os.path.abspath(os.path.dirname(__file__)) 39 | REQUIRE_PATH = "requirements.txt" 40 | PKG_DESCRIBE = "README.md" 41 | 42 | # Directories to ignore in find_packages 43 | EXCLUDES = () 44 | 45 | 46 | # helper functions 47 | def read(*parts): 48 | """ 49 | returns contents of file 50 | """ 51 | with codecs.open(os.path.join(PROJECT, *parts), "rb", "utf-8") as file: 52 | return file.read() 53 | 54 | 55 | def get_requires(path=REQUIRE_PATH): 56 | """ 57 | generates requirements from file path given as REQUIRE_PATH 58 | """ 59 | for line in read(path).splitlines(): 60 | line = line.strip() 61 | if line and not line.startswith("#"): 62 | yield line 63 | 64 | 65 | def get_model_zoo_configs(): 66 | """ 67 | Return a list of configs to include in package for model zoo. Copy over these configs inside 68 | ezflow/model_zoo. 69 | """ 70 | 71 | # Use absolute paths while symlinking. 72 | source_configs_dir = os.path.join( 73 | os.path.dirname(os.path.realpath(__file__)), "configs" 74 | ) 75 | destination = os.path.join( 76 | os.path.dirname(os.path.realpath(__file__)), 77 | "ezflow", 78 | "model_zoo", 79 | "configs", 80 | ) 81 | # Symlink the config directory inside package to have a cleaner pip install. 82 | 83 | # Remove stale symlink/directory from a previous build. 84 | if os.path.exists(source_configs_dir): 85 | if os.path.islink(destination): 86 | os.unlink(destination) 87 | elif os.path.isdir(destination): 88 | shutil.rmtree(destination) 89 | 90 | if not os.path.exists(destination): 91 | try: 92 | os.symlink(source_configs_dir, destination) 93 | except OSError: 94 | # Fall back to copying if symlink fails: ex. 
on Windows. 95 | shutil.copytree(source_configs_dir, destination) 96 | 97 | config_paths = glob.glob("configs/**/*.yaml", recursive=True) + glob.glob( 98 | "configs/**/*.py", recursive=True 99 | ) 100 | return config_paths 101 | 102 | 103 | # Define the configuration 104 | CONFIG = { 105 | "name": NAME, 106 | "version": VERSION, 107 | "description": DESCRIPTION, 108 | "long_description": LONG_DESCRIPTION, 109 | "long_description_content_type": "text/markdown", 110 | "classifiers": CLASSIFIERS, 111 | "keywords": KEYWORDS, 112 | "license": LICENSE, 113 | "author": AUTHOR, 114 | "author_email": EMAIL, 115 | "url": REPOSITORY, 116 | "project_urls": {"Source": REPOSITORY}, 117 | "packages": find_packages( 118 | where=PROJECT, include=["ezflow", "ezflow.*"], exclude=EXCLUDES 119 | ), 120 | "package_data": {"ezflow.model_zoo": get_model_zoo_configs()}, 121 | "install_requires": list(get_requires()), 122 | "python_requires": ">=3.6", 123 | "test_suite": "tests", 124 | "tests_require": ["pytest>=3"], 125 | "include_package_data": True, 126 | } 127 | 128 | if __name__ == "__main__": 129 | setup(**CONFIG) 130 | -------------------------------------------------------------------------------- /ezflow/data/dataset/hd1k.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | from ...config import configurable 5 | from ...functional import SparseFlowAugmentor 6 | from ..build import DATASET_REGISTRY 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class HD1K(BaseDataset): 12 | """ 13 | Dataset Class for preparing the HD1K dataset for training and validation. 
14 | 15 | Parameters 16 | ---------- 17 | root_dir : str 18 | path of the root directory for the HD1K dataset 19 | is_prediction : bool, default : False 20 | If True, only image data are loaded for prediction otherwise both images and flow data are loaded 21 | init_seed : bool, default : False 22 | If True, sets random seed to worker 23 | append_valid_mask : bool, default : False 24 | If True, appends the valid flow mask to the original flow mask at dim=0 25 | crop: bool, default : True 26 | Whether to perform cropping 27 | crop_size : :obj:`tuple` of :obj:`int` 28 | The size of the image crop 29 | crop_type : :obj:`str`, default : 'center' 30 | The type of croppping to be performed, one of "center", "random" 31 | augment : bool, default : True 32 | If True, applies data augmentation 33 | aug_params : :obj:`dict`, optional 34 | The parameters for data augmentation 35 | norm_params : :obj:`dict`, optional 36 | The parameters for normalization 37 | flow_offset_params: :obj:`dict`, optional 38 | The parameters for adding bilinear interpolated weights surrounding each ground truth flow values. 
39 | """ 40 | 41 | @configurable 42 | def __init__( 43 | self, 44 | root_dir, 45 | is_prediction=False, 46 | init_seed=False, 47 | append_valid_mask=False, 48 | crop=False, 49 | crop_size=(256, 256), 50 | crop_type="center", 51 | augment=True, 52 | aug_params={ 53 | "eraser_aug_params": {"enabled": False}, 54 | "noise_aug_params": {"enabled": False}, 55 | "flip_aug_params": {"enabled": False}, 56 | "color_aug_params": {"enabled": False}, 57 | "spatial_aug_params": {"enabled": False}, 58 | "advanced_spatial_aug_params": {"enabled": False}, 59 | }, 60 | norm_params={"use": False}, 61 | flow_offset_params={"use": False}, 62 | ): 63 | super(HD1K, self).__init__( 64 | init_seed=init_seed, 65 | is_prediction=is_prediction, 66 | append_valid_mask=append_valid_mask, 67 | crop=crop, 68 | crop_size=crop_size, 69 | crop_type=crop_type, 70 | augment=augment, 71 | aug_params=aug_params, 72 | sparse_transform=True, 73 | norm_params=norm_params, 74 | flow_offset_params=flow_offset_params, 75 | ) 76 | 77 | self.is_prediction = is_prediction 78 | self.append_valid_mask = append_valid_mask 79 | 80 | if augment: 81 | self.augmentor = SparseFlowAugmentor(crop_size=crop_size, **aug_params) 82 | 83 | seq_ix = 0 84 | while 1: 85 | flows = sorted( 86 | glob(osp.join(root_dir, "hd1k_flow_gt", "flow_occ/%06d_*.png" % seq_ix)) 87 | ) 88 | images = sorted( 89 | glob(osp.join(root_dir, "hd1k_input", "image_2/%06d_*.png" % seq_ix)) 90 | ) 91 | 92 | if len(flows) == 0: 93 | break 94 | 95 | for i in range(len(flows) - 1): 96 | self.flow_list += [flows[i]] 97 | self.image_list += [[images[i], images[i + 1]]] 98 | 99 | seq_ix += 1 100 | 101 | @classmethod 102 | def from_config(cls, cfg): 103 | return { 104 | "root_dir": cfg.ROOT_DIR, 105 | "is_prediction": cfg.IS_PREDICTION, 106 | "init_seed": cfg.INIT_SEED, 107 | "append_valid_mask": cfg.APPEND_VALID_MASK, 108 | "crop": cfg.CROP.USE, 109 | "crop_size": cfg.CROP.SIZE, 110 | "crop_type": cfg.CROP.TYPE, 111 | "augment": cfg.AUGMENTATION.USE, 112 | 
"aug_params": cfg.AUGMENTATION.PARAMS, 113 | "norm_params": cfg.NORM_PARAMS, 114 | "flow_offset_params": cfg.FLOW_OFFSET_PARAMS, 115 | } 116 | -------------------------------------------------------------------------------- /tests/test_utils.py: -------------------------------------------------------------------------------- 1 | from gettext import find 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | 7 | from ezflow.utils import ( 8 | AverageMeter, 9 | concentric_offsets, 10 | coords_grid, 11 | endpointerror, 12 | find_free_port, 13 | flow_to_bilinear_interpolation_weights, 14 | forward_interpolate, 15 | get_flow_offsets, 16 | is_port_available, 17 | replace_relu, 18 | upflow, 19 | ) 20 | 21 | 22 | def test_endpointerror(): 23 | 24 | pred = torch.rand(4, 2, 256, 256) 25 | target = torch.rand(4, 2, 256, 256) 26 | epe = endpointerror(pred, target) 27 | 28 | multi_magnitude_epe = endpointerror(pred, target, multi_magnitude=True) 29 | assert isinstance(multi_magnitude_epe, dict) 30 | 31 | valid = torch.rand(4, 1, 256, 256) 32 | epe, f1 = endpointerror(pred, target, valid) 33 | 34 | multi_magnitude_epe, f1 = endpointerror(pred, target, valid, multi_magnitude=True) 35 | assert isinstance(multi_magnitude_epe, dict) 36 | 37 | 38 | def test_forward_interpolate(): 39 | 40 | flow = torch.rand(2, 256, 256) 41 | _ = forward_interpolate(flow) 42 | 43 | 44 | def test_upflow(): 45 | 46 | flow = torch.rand(2, 2, 256, 256) 47 | _ = upflow(flow) 48 | 49 | 50 | def test_coords_grid(): 51 | 52 | _ = coords_grid(2, 256, 256) 53 | 54 | 55 | def test_AverageMeter(): 56 | 57 | meter = AverageMeter() 58 | meter.update(1) 59 | assert meter.avg == 1 60 | 61 | meter.reset() 62 | assert meter.avg == 0 63 | 64 | 65 | def test_find_free_port(): 66 | assert len(find_free_port()) == 5 67 | 68 | 69 | def test_is_port_available(): 70 | port = find_free_port() 71 | assert is_port_available(int(port)) is True 72 | 73 | 74 | def test_replace_relu(): 75 | model = nn.Sequential( 
76 | nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=True) 77 | ) 78 | 79 | model = replace_relu(model, nn.LeakyReLU(negative_slope=0.1)) 80 | 81 | assert isinstance(model[1], nn.LeakyReLU) 82 | 83 | del model 84 | 85 | 86 | def test_concentric_offsets(): 87 | 88 | offset_matrix = [ 89 | [-4, -3, -2, -1, 0, 1, 2, 3, 4], 90 | [-20, -15, -10, -5, 0, 5, 10, 15, 20], 91 | [-36, -27, -18, -9, 0, 9, 18, 27, 36], 92 | [-64, -48, -32, -16, 0, 16, 32, 48, 64], 93 | ] 94 | 95 | offset_matrix = np.array(offset_matrix) 96 | offsets = concentric_offsets(dilations=[1, 5, 9, 16], radius=4) 97 | 98 | assert offsets.shape == (4, 9) 99 | assert (offsets == offset_matrix).all() 100 | 101 | del offsets, offset_matrix 102 | 103 | 104 | def test_get_flow_offsets(): 105 | 106 | flow_offsets_matrix = [ 107 | [-8, -6, -4, -2, 0, 2, 4, 6, 8], 108 | [-32, -24, -16, -8, 0, 8, 16, 24, 32], 109 | [-64, -48, -32, -16, 0, 16, 32, 48, 64], 110 | [-96, -72, -48, -24, 0, 24, 48, 72, 96], 111 | [-160, -120, -80, -40, 0, 40, 80, 120, 160], 112 | [-288, -216, -144, -72, 0, 72, 144, 216, 288], 113 | [-512, -384, -256, -128, 0, 128, 256, 384, 512], 114 | ] 115 | 116 | flow_offsets = get_flow_offsets( 117 | dilations=[[1], [1, 2, 3, 5, 9, 16]], 118 | feat_strides=[2, 8], 119 | radius=4, 120 | offset_bias=[0, 0], 121 | offset_fn=concentric_offsets, 122 | ) 123 | 124 | assert flow_offsets.shape == (7, 9) 125 | assert (flow_offsets == flow_offsets_matrix).all() 126 | 127 | del flow_offsets, flow_offsets_matrix 128 | 129 | 130 | def test_flow_to_bilinear_interpolation_weights(): 131 | 132 | flow = np.ones((32, 32, 2)) 133 | flow[:, :, 1] = flow[:, :, 1] * 2 134 | 135 | valid = np.ones(flow.shape[:2]) > 0 136 | 137 | flow_offsets = get_flow_offsets( 138 | dilations=[[1], [1, 2, 3]], 139 | feat_strides=[2, 8], 140 | radius=4, 141 | offset_bias=[0, 0], 142 | offset_fn=concentric_offsets, 143 | ) 144 | 145 | offset_labs, dilation_labs = flow_to_bilinear_interpolation_weights( 146 | flow, 
valid, flow_offsets 147 | ) 148 | 149 | assert offset_labs.shape == (32, 32, 4, 9, 9) 150 | assert dilation_labs.shape == (32, 32) 151 | 152 | offset_labs_reshape = offset_labs.reshape(32, 32, -1) 153 | err = np.sum(np.abs(np.sum(offset_labs_reshape, axis=2) - 1)) 154 | assert err < 1e-10, err 155 | 156 | del flow, valid, offset_labs, dilation_labs 157 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import transforms as T 3 | 4 | from ezflow.models import Predictor, build_model 5 | 6 | img1 = torch.randn(2, 3, 256, 256) 7 | img2 = torch.randn(2, 3, 256, 256) 8 | 9 | 10 | def test_Predictor(): 11 | 12 | predictor = Predictor("RAFT", (0.0, 0.0, 0.0), (255.0, 255.0, 255.0), "raft.yaml") 13 | flow = predictor(img1, img2) 14 | assert flow.shape == (2, 2, 256, 256) 15 | 16 | transform = T.Compose([T.Resize((224, 224))]) 17 | 18 | predictor = Predictor( 19 | "RAFT", 20 | (0.0, 0.0, 0.0), 21 | (255.0, 255.0, 255.0), 22 | "raft.yaml", 23 | data_transform=transform, 24 | pad_divisor=32, 25 | ) 26 | flow = predictor(img1, img2) 27 | assert flow.shape == (2, 2, 224, 224) 28 | 29 | 30 | def test_RAFT(): 31 | 32 | model = build_model("RAFT", "raft.yaml") 33 | output = model(img1, img2) 34 | assert isinstance(output, dict) 35 | assert isinstance(output["flow_preds"], tuple) or isinstance( 36 | output["flow_preds"], list 37 | ) 38 | 39 | model.eval() 40 | output = model(img1, img2) 41 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 42 | 43 | del model, output 44 | 45 | _ = build_model("RAFT", default=True) 46 | 47 | 48 | def test_DICL(): 49 | 50 | model = build_model("DICL", "dicl.yaml") 51 | output = model(img1, img2) 52 | assert isinstance(output, dict) 53 | assert isinstance(output["flow_preds"], tuple) or isinstance( 54 | output["flow_preds"], list 55 | ) 56 | 57 | model.eval() 58 | output = 
model(img1, img2) 59 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 60 | 61 | del model, output 62 | 63 | _ = build_model("DICL", default=True) 64 | 65 | 66 | def test_PWCNet(): 67 | 68 | model = build_model("PWCNet", "pwcnet.yaml") 69 | output = model(img1, img2) 70 | assert isinstance(output, dict) 71 | assert isinstance(output["flow_preds"], tuple) or isinstance( 72 | output["flow_preds"], list 73 | ) 74 | 75 | model.eval() 76 | output = model(img1, img2) 77 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 78 | 79 | del model, output 80 | 81 | _ = build_model("PWCNet", default=True) 82 | 83 | 84 | def test_FlowNetS(): 85 | 86 | model = build_model("FlowNetS", "flownet_s.yaml") 87 | output = model(img1, img2) 88 | assert isinstance(output, dict) 89 | assert isinstance(output["flow_preds"], tuple) or isinstance( 90 | output["flow_preds"], list 91 | ) 92 | 93 | model.eval() 94 | output = model(img1, img2) 95 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 96 | 97 | del model, output 98 | 99 | _ = build_model("FlowNetS", default=True) 100 | 101 | 102 | def test_FlowNetC(): 103 | 104 | model = build_model("FlowNetC", "flownet_c.yaml") 105 | output = model(img1, img2) 106 | assert isinstance(output, dict) 107 | assert isinstance(output["flow_preds"], tuple) or isinstance( 108 | output["flow_preds"], list 109 | ) 110 | 111 | model.eval() 112 | output = model(img1, img2) 113 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 114 | 115 | del model, output 116 | 117 | _ = build_model("FlowNetC", default=True) 118 | 119 | 120 | def test_VCN(): 121 | 122 | model = build_model("VCN", "vcn.yaml") 123 | 124 | img = torch.randn(16, 3, 256, 256) 125 | 126 | output = model(img, img) 127 | assert isinstance(output, dict) 128 | assert isinstance(output["flow_preds"], tuple) or isinstance( 129 | output["flow_preds"], list 130 | ) 131 | 132 | model.eval() 133 | output = model(img, img) 134 | assert output["flow_upsampled"].shape == (16, 2, 
256, 256) 135 | 136 | del model, output 137 | 138 | 139 | def test_DCVNet(): 140 | 141 | model = build_model("DCVNet", "dcvnet.yaml") 142 | output = model(img1, img2) 143 | assert isinstance(output, dict) 144 | assert isinstance(output["flow_preds"], tuple) or isinstance( 145 | output["flow_preds"], list 146 | ) 147 | assert isinstance(output["flow_logits"], tuple) or isinstance( 148 | output["flow_logits"], list 149 | ) 150 | 151 | model.eval() 152 | output = model(img1, img2) 153 | assert output["flow_upsampled"].shape == (2, 2, 256, 256) 154 | 155 | del model, output 156 | 157 | model = build_model("DCVNet", default=True) 158 | del model 159 | -------------------------------------------------------------------------------- /ezflow/models/raft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..decoder import build_decoder 6 | from ..encoder import build_encoder 7 | from ..modules import BaseModule 8 | from ..similarity import build_similarity 9 | from ..utils import convex_upsample_flow, coords_grid, upflow 10 | from .build import MODEL_REGISTRY 11 | 12 | try: 13 | autocast = torch.cuda.amp.autocast 14 | except: 15 | 16 | class autocast: 17 | def __init__(self, enabled): 18 | pass 19 | 20 | def __enter__(self): 21 | pass 22 | 23 | def __exit__(self, *args): 24 | pass 25 | 26 | 27 | @MODEL_REGISTRY.register() 28 | class RAFT(BaseModule): 29 | """ 30 | Implementation of the paper 31 | `RAFT: Recurrent All-Pairs Field Transforms for Optical Flow `_ 32 | 33 | Parameters 34 | ---------- 35 | cfg : :class:`CfgNode` 36 | Configuration for the model 37 | """ 38 | 39 | def __init__(self, cfg): 40 | super(RAFT, self).__init__() 41 | 42 | self.cfg = cfg 43 | 44 | self.fnet = build_encoder(cfg.ENCODER.FEATURE) 45 | self.cnet = build_encoder( 46 | cfg.ENCODER.CONTEXT, out_channels=cfg.HIDDEN_DIM + cfg.CONTEXT_DIM 47 | ) 48 | 49 | self.similarity_fn = 
build_similarity(cfg.SIMILARITY, instantiate=False) 50 | self.corr_radius = cfg.CORR_RADIUS 51 | self.corr_levels = cfg.CORR_LEVELS 52 | 53 | self.update_block = build_decoder( 54 | name=cfg.DECODER.NAME, 55 | corr_radius=self.corr_radius, 56 | corr_levels=self.corr_levels, 57 | hidden_dim=cfg.HIDDEN_DIM, 58 | input_dim=cfg.DECODER.INPUT_DIM, 59 | ) 60 | 61 | def _initialize_flow(self, img): 62 | 63 | N, _, H, W = img.shape 64 | coords0 = coords_grid(N, H // 8, W // 8).to(img.device) 65 | coords1 = coords_grid(N, H // 8, W // 8).to(img.device) 66 | 67 | return coords0, coords1 68 | 69 | def forward(self, img1, img2, flow_init=None): 70 | """ 71 | Performs forward pass of the network 72 | 73 | Parameters 74 | ---------- 75 | img1 : torch.Tensor 76 | Image to predict flow from 77 | img2 : torch.Tensor 78 | Image to predict flow to 79 | 80 | Returns 81 | ------- 82 | :class:`dict` 83 | torch.Tensor : intermediate flow predications from img1 to img2 84 | torch.Tensor : if model is in eval state, return upsampled flow 85 | """ 86 | 87 | img1 = img1.contiguous() 88 | img2 = img2.contiguous() 89 | 90 | with autocast(enabled=self.cfg.MIXED_PRECISION): 91 | fmap1, fmap2 = self.fnet([img1, img2]) 92 | 93 | fmap1 = fmap1.float() 94 | fmap2 = fmap2.float() 95 | 96 | corr_fn = self.similarity_fn( 97 | fmap1, fmap2, num_levels=self.corr_levels, corr_radius=self.corr_radius 98 | ) 99 | 100 | with autocast(enabled=self.cfg.MIXED_PRECISION): 101 | cnet = self.cnet(img1) 102 | net, inp = torch.split( 103 | cnet, [self.cfg.HIDDEN_DIM, self.cfg.CONTEXT_DIM], dim=1 104 | ) 105 | net = torch.tanh(net) 106 | inp = torch.relu(inp) 107 | 108 | coords0, coords1 = self._initialize_flow(img1) 109 | 110 | if flow_init is not None: 111 | coords1 = coords1 + flow_init 112 | 113 | flow_predictions = [] 114 | for _ in range(self.cfg.UPDATE_ITERS): 115 | coords1 = coords1.detach() 116 | corr = corr_fn(coords1) 117 | 118 | flow = coords1 - coords0 119 | with 
autocast(enabled=self.cfg.MIXED_PRECISION): 120 | net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) 121 | 122 | coords1 = coords1 + delta_flow 123 | flow = coords1 - coords0 124 | if up_mask is None: 125 | flow_up = upflow(flow) 126 | else: 127 | flow_up = convex_upsample_flow(8 * flow, up_mask, out_stride=8) 128 | 129 | flow_predictions.append(flow_up) 130 | 131 | output = {"flow_preds": flow_predictions} 132 | 133 | if self.training: 134 | return output 135 | 136 | output["flow_upsampled"] = flow_up 137 | return output 138 | -------------------------------------------------------------------------------- /ezflow/data/dataset/autoflow.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | from ...config import configurable 5 | from ...functional import FlowAugmentor 6 | from ..build import DATASET_REGISTRY 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class AutoFlow(BaseDataset): 12 | """ 13 | Dataset Class for preparing the AutoFlow Synthetic dataset for training and validation. 
14 | 15 | Parameters 16 | ---------- 17 | root_dir : str 18 | path of the root directory for the Monkaa dataset 19 | is_prediction : bool, default : False 20 | If True, only image data are loaded for prediction otherwise both images and flow data are loaded 21 | init_seed : bool, default : False 22 | If True, sets random seed to worker 23 | append_valid_mask : bool, default : False 24 | If True, appends the valid flow mask to the original flow mask at dim=0 25 | crop: bool, default : True 26 | Whether to perform cropping 27 | crop_size : :obj:`tuple` of :obj:`int` 28 | The size of the image crop 29 | crop_type : :obj:`str`, default : 'center' 30 | The type of croppping to be performed, one of "center", "random" 31 | augment : bool, default : True 32 | If True, applies data augmentation 33 | aug_params : :obj:`dict`, optional 34 | The parameters for data augmentation 35 | norm_params : :obj:`dict`, optional 36 | The parameters for normalization 37 | flow_offset_params: :obj:`dict`, optional 38 | The parameters for adding bilinear interpolated weights surrounding each ground truth flow values. 
39 | """ 40 | 41 | @configurable 42 | def __init__( 43 | self, 44 | root_dir, 45 | is_prediction=False, 46 | init_seed=False, 47 | append_valid_mask=False, 48 | crop=False, 49 | crop_size=(256, 256), 50 | crop_type="center", 51 | augment=True, 52 | aug_params={ 53 | "eraser_aug_params": {"enabled": False}, 54 | "noise_aug_params": {"enabled": False}, 55 | "flip_aug_params": {"enabled": False}, 56 | "color_aug_params": {"enabled": False}, 57 | "spatial_aug_params": {"enabled": False}, 58 | "advanced_spatial_aug_params": {"enabled": False}, 59 | }, 60 | norm_params={"use": False}, 61 | flow_offset_params={"use": False}, 62 | ): 63 | super(AutoFlow, self).__init__( 64 | init_seed=init_seed, 65 | is_prediction=is_prediction, 66 | append_valid_mask=append_valid_mask, 67 | crop=crop, 68 | crop_size=crop_size, 69 | crop_type=crop_type, 70 | augment=augment, 71 | aug_params=aug_params, 72 | sparse_transform=False, 73 | norm_params=norm_params, 74 | flow_offset_params=flow_offset_params, 75 | ) 76 | 77 | self.is_prediction = is_prediction 78 | self.append_valid_mask = append_valid_mask 79 | 80 | if augment: 81 | self.augmentor = FlowAugmentor(crop_size=crop_size, **aug_params) 82 | 83 | scenes = [ 84 | "static_40k_png_1_of_4", 85 | "static_40k_png_2_of_4", 86 | "static_40k_png_2_of_4", 87 | "static_40k_png_2_of_4", 88 | ] 89 | 90 | for scene in scenes: 91 | seqs = glob(osp.join(root_dir, scene, "*")) 92 | for s in seqs: 93 | images = sorted(glob(osp.join(s, "*.png"))) 94 | flows = sorted(glob(osp.join(s, "*.flo"))) 95 | if len(images) == 2: 96 | assert len(flows) == 1 97 | for i in range(len(flows)): 98 | self.flow_list += [flows[i]] 99 | self.image_list += [[images[i], images[i + 1]]] 100 | 101 | @classmethod 102 | def from_config(cls, cfg): 103 | return { 104 | "root_dir": cfg.ROOT_DIR, 105 | "is_prediction": cfg.IS_PREDICTION, 106 | "init_seed": cfg.INIT_SEED, 107 | "append_valid_mask": cfg.APPEND_VALID_MASK, 108 | "crop": cfg.CROP.USE, 109 | "crop_size": cfg.CROP.SIZE, 
110 | "crop_type": cfg.CROP.TYPE, 111 | "augment": cfg.AUGMENTATION.USE, 112 | "aug_params": cfg.AUGMENTATION.PARAMS, 113 | "norm_params": cfg.NORM_PARAMS, 114 | "flow_offset_params": cfg.FLOW_OFFSET_PARAMS, 115 | } 116 | -------------------------------------------------------------------------------- /ezflow/encoder/ganet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from ..config import configurable 5 | from ..modules import Conv2x, ConvNormRelu 6 | from .build import ENCODER_REGISTRY 7 | 8 | 9 | @ENCODER_REGISTRY.register() 10 | class GANetBackbone(nn.Module): 11 | """ 12 | Feature extractor backbone used in **GA-Net: Guided Aggregation Net for End-to-end Stereo Matching** (https://arxiv.org/abs/1904.06587) 13 | 14 | Parameters 15 | ---------- 16 | in_channels : int 17 | Number of input channels 18 | out_channels : int 19 | Number of output channels of ``outconv_2``, i.e. of the second tensor in the list returned by ``forward`` 20 | """ 21 | 22 | @configurable 23 | def __init__(self, in_channels=3, out_channels=32): 24 | super(GANetBackbone, self).__init__() 25 | 26 | self.conv_start = nn.Sequential( # stem; its middle conv uses stride 2 27 | ConvNormRelu(in_channels, 32, kernel_size=3, padding=1), 28 | ConvNormRelu(32, 32, kernel_size=3, stride=2, padding=1), 29 | ConvNormRelu(32, 32, kernel_size=3, padding=1), 30 | ) 31 | self.conv1a = ConvNormRelu(32, 48, kernel_size=3, stride=2, padding=1) # encoder "a": conv1a..conv6a each use stride 2 32 | self.conv2a = ConvNormRelu(48, 64, kernel_size=3, stride=2, padding=1) 33 | self.conv3a = ConvNormRelu(64, 96, kernel_size=3, stride=2, padding=1) 34 | self.conv4a = ConvNormRelu(96, 128, kernel_size=3, stride=2, padding=1) 35 | self.conv5a = ConvNormRelu(128, 160, kernel_size=3, stride=2, padding=1) 36 | self.conv6a = ConvNormRelu(160, 192, kernel_size=3, stride=2, padding=1) 37 | 38 | self.deconv6a = Conv2x(192, 160, deconv=True) # decoder "a": channel widths mirror the encoder back down to 32 39 | self.deconv5a = Conv2x(160, 128, deconv=True) 40 | self.deconv4a = Conv2x(128, 96, deconv=True) 41 | self.deconv3a = Conv2x(96, 64, deconv=True) 42 |
self.deconv2a = Conv2x(64, 48, deconv=True) 43 | self.deconv1a = Conv2x(48, 32, deconv=True) 44 | 45 | self.conv1b = Conv2x(32, 48) # encoder "b": same widths as encoder "a" (32 -> 192); skip tensors are supplied in forward() 46 | self.conv2b = Conv2x(48, 64) 47 | self.conv3b = Conv2x(64, 96) 48 | self.conv4b = Conv2x(96, 128) 49 | self.conv5b = Conv2x(128, 160) 50 | self.conv6b = Conv2x(160, 192) 51 | 52 | self.deconv6b = Conv2x(192, 160, deconv=True) # decoder "b": every stage is paired with an outconv_* projection 53 | self.outconv_6 = ConvNormRelu(160, 32, kernel_size=3, padding=1) 54 | 55 | self.deconv5b = Conv2x(160, 128, deconv=True) 56 | self.outconv_5 = ConvNormRelu(128, 32, kernel_size=3, padding=1) 57 | 58 | self.deconv4b = Conv2x(128, 96, deconv=True) 59 | self.outconv_4 = ConvNormRelu(96, 32, kernel_size=3, padding=1) 60 | 61 | self.deconv3b = Conv2x(96, 64, deconv=True) 62 | self.outconv_3 = ConvNormRelu(64, 32, kernel_size=3, padding=1) 63 | 64 | self.deconv2b = Conv2x(64, 48, deconv=True) 65 | self.outconv_2 = ConvNormRelu(48, out_channels, kernel_size=3, padding=1) # only this projection uses out_channels; outconv_3..outconv_6 emit 32 channels 66 | 67 | @classmethod 68 | def from_config(cls, cfg): 69 | return { 70 | "in_channels": cfg.IN_CHANNELS, 71 | "out_channels": cfg.OUT_CHANNELS, 72 | } 73 | 74 | def forward(self, x): # returns [x, x2, x3, x4, x5, x6]: x is the deconv2b output, x6..x2 come from outconv_6..outconv_2 75 | 76 | x = self.conv_start(x) 77 | rem0 = x 78 | x = self.conv1a(x) 79 | rem1 = x 80 | x = self.conv2a(x) 81 | rem2 = x 82 | x = self.conv3a(x) 83 | rem3 = x 84 | x = self.conv4a(x) 85 | rem4 = x 86 | x = self.conv5a(x) 87 | rem5 = x 88 | x = self.conv6a(x) 89 | rem6 = x 90 | # decoder "a": deconv stages; each output overwrites the matching skip tensor rem* 91 | x = self.deconv6a(x, rem5) 92 | rem5 = x 93 | x = self.deconv5a(x, rem4) 94 | rem4 = x 95 | x = self.deconv4a(x, rem3) 96 | rem3 = x 97 | x = self.deconv3a(x, rem2) 98 | rem2 = x 99 | x = self.deconv2a(x, rem1) 100 | rem1 = x 101 | x = self.deconv1a(x, rem0) 102 | rem0 = x # NOTE(review): rem0 is reassigned here but never read afterwards 103 | # encoder "b": reuse the updated rem* tensors as skips 104 | x = self.conv1b(x, rem1) 105 | rem1 = x 106 | x = self.conv2b(x, rem2) 107 | rem2 = x 108 | x = self.conv3b(x, rem3) 109 | rem3 = x 110 | x = self.conv4b(x, rem4) 111 | rem4 = x 112 | x = self.conv5b(x, rem5) 113 | rem5 = x 114 | x = self.conv6b(x, rem6) 115 | # decoder "b": deconv stages, each followed by an outconv_* projection 116 | x = self.deconv6b(x, rem5) 117 | x6 =
self.outconv_6(x) 118 | x = self.deconv5b(x, rem4) 119 | x5 = self.outconv_5(x) 120 | x = self.deconv4b(x, rem3) 121 | x4 = self.outconv_4(x) 122 | x = self.deconv3b(x, rem2) 123 | x3 = self.outconv_3(x) 124 | x = self.deconv2b(x, rem1) 125 | x2 = self.outconv_2(x) 126 | 127 | return [x, x2, x3, x4, x5, x6] 128 | -------------------------------------------------------------------------------- /ezflow/models/predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision import io 3 | from torchvision.transforms import Normalize 4 | 5 | from ..utils import InputPadder 6 | from .build import build_model 7 | 8 | 9 | class Predictor: 10 | """ 11 | A class that uses an instance of an optical flow estimation model to predict flow between two images 12 | 13 | Parameters 14 | ---------- 15 | model_name : str 16 | The name of the optical flow estimation model to use 17 | mean : tuple of float, 18 | Sequence of mean for normalizing each image channel 19 | std : tuple of float, 20 | Sequence of standard deviations for normalizing each image channel 21 | model_cfg_path : str, optional 22 | The path to the config file for the optical flow estimation model, by default None in which case the default config is used 23 | model_cfg : CfgNode object, optional 24 | The config object for the optical flow estimation model, by default None 25 | model_weights_path : str, optional 26 | The path to the weights file for the optical flow estimation model 27 | custom_cfg_file : bool, optional 28 | Whether the config file is a custom config file or one one of the configs included in EzFlow, by default False 29 | default : bool, optional 30 | Whether to use the default config for the model 31 | data_transform : torchvision.transforms object, optional 32 | The data transform to apply to the images before passing them to the model, by default None 33 | device : str, optional 34 | The device to use for the model, by default "cpu" 
35 | flow_scale : float, optional 36 | The scale to apply to the predicted flow, by default 1.0 37 | pad_divisor : int, optional 38 | The divisor to make the image dimensions evenly divisible by using padding, by default 1 39 | """ 40 | 41 | def __init__( 42 | self, 43 | model_name, 44 | mean, 45 | std, 46 | model_cfg_path=None, 47 | model_cfg=None, 48 | model_weights_path=None, 49 | custom_cfg_file=False, 50 | default=False, 51 | device="cpu", 52 | data_transform=None, 53 | flow_scale=1.0, 54 | pad_divisor=1, 55 | ): 56 | 57 | self.flow_scale = flow_scale 58 | self.pad_divisor = pad_divisor 59 | 60 | if model_cfg_path is not None: 61 | self.model = build_model( 62 | model_name, 63 | cfg_path=model_cfg_path, 64 | custom_cfg=custom_cfg_file, 65 | default=default, 66 | weights_path=model_weights_path, 67 | ) 68 | 69 | elif default: 70 | self.model = build_model( 71 | model_name, default=True, weights_path=model_weights_path 72 | ) 73 | 74 | else: 75 | assert ( 76 | model_cfg is not None 77 | ), "Must provide either a path to a config file or a config object" 78 | self.model = build_model( 79 | model_name, cfg=model_cfg, weights_path=model_weights_path 80 | ) 81 | 82 | self.model = self.model.eval() 83 | self.norm = Normalize(mean=mean, std=std) 84 | self.data_transform = data_transform 85 | self.device = torch.device(device) 86 | 87 | def __call__(self, img1, img2): 88 | """ 89 | Runs the prediction on the two images 90 | 91 | Parameters 92 | ---------- 93 | img1 : torch.Tensor or str 94 | The first image to predict flow from 95 | img2 : torch.Tensor or str 96 | The second image to predict flow to 97 | 98 | Returns 99 | ------- 100 | torch.Tensor 101 | The predicted flow 102 | """ 103 | 104 | if type(img1) == str: 105 | img1 = io.read_image(img1).float() 106 | img1 = img1.unsqueeze(dim=0) 107 | 108 | if type(img2) == str: 109 | img2 = io.read_image(img2).float() 110 | img2 = img2.unsqueeze(dim=0) 111 | 112 | if self.data_transform: 113 | img1 = 
self.data_transform(img1) 114 | img2 = self.data_transform(img2) 115 | 116 | img1 = self.norm(img1) 117 | img2 = self.norm(img2) 118 | 119 | padder = InputPadder(img1.shape, divisor=self.pad_divisor) 120 | img1, img2 = padder.pad(img1, img2) 121 | 122 | output = self.model(img1, img2) 123 | flow_pred = padder.unpad(output["flow_upsampled"]) 124 | flow_pred = flow_pred * self.flow_scale 125 | return flow_pred 126 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: ezflow 2 | channels: 3 | - fvcore 4 | - pytorch 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1 9 | - attrs=21.2.0 10 | - blas=1.0 11 | - brotlipy=0.7.0 12 | - bzip2=1.0.8 13 | - ca-certificates=2021.10.8 14 | - cffi=1.14.6 15 | - click=8.0.1 16 | - colorama=0.4.4 17 | - cudatoolkit=11.3 18 | - cycler=0.10.0 19 | - dnspython=1.16.0 20 | - docopt=0.6.2 21 | - einops=0.3.2 22 | - expat=2.4.1 23 | - ffmpeg=4.3 24 | - fontconfig=2.13.1 25 | - freetype=2.10.4 26 | - future=0.18.2 27 | - icu=58.2 28 | - intel-openmp=2021.2.0 29 | - jpeg=9b 30 | - kiwisolver=1.3.1 31 | - lcms2=2.12 32 | - libffi=3.3 33 | - libiconv=1.15 34 | - libpng=1.6.37 35 | - libtiff=4.2.0 36 | - libuv=1.40.0 37 | - libwebp-base=1.2.0 38 | - libxcb=1.13 39 | - libxml2=2.9.12 40 | - lz4-c=1.9.3 41 | - matplotlib=3.4.2 42 | - matplotlib-base=3.4.2 43 | - mkl=2021.2.0 44 | - mkl-service=2.3.0 45 | - mkl_fft=1.3.0 46 | - mkl_random=1.2.1 47 | - mypy_extensions=0.4.3 48 | - networkx=2.6.2 49 | - ninja=1.10.2 50 | - numpy=1.20.2 51 | - numpy-base=1.20.2 52 | - olefile=0.46 53 | - openjpeg=2.3.0 54 | - openssl=1.1.1l 55 | - pcre=8.45 56 | - pillow=8.3.1 57 | - pip=21.2.4 58 | - pthread-stubs=0.4 59 | - pycparser=2.20 60 | - pymongo=3.11.4 61 | - pyopenssl=20.0.1 62 | - pyparsing=2.4.7 63 | - pyqt=5.9.2 64 | - pysocks=1.7.1 65 | - python=3.7.10 66 | - python-dateutil=2.8.1 67 | - 
python-jsonrpc-server=0.3.4 68 | - python_abi=3.7 69 | - pytorch=1.10.1 70 | - qt=5.9.7 71 | - regex=2021.7.6 72 | - setuptools=52.0.0 73 | - sip=4.19.8 74 | - six=1.16.0 75 | - sqlite=3.36.0 76 | - tk=8.6.10 77 | - toml=0.10.2 78 | - torchvision=0.11.2 79 | - tornado=6.1 80 | - typed-ast=1.4.3 81 | - typing_extensions=3.10.0.0 82 | - ujson=1.35 83 | - whichcraft=0.6.1 84 | - xorg-libxau=1.0.9 85 | - xorg-libxdmcp=1.1.3 86 | - xz=5.2.5 87 | - yaml=0.2.5 88 | - zipp=3.5.0 89 | - zlib=1.2.11 90 | - zstd=1.4.9 91 | - pip: 92 | - absl-py==0.13.0 93 | - alabaster==0.7.12 94 | - antlr4-python3-runtime==4.8 95 | - appdirs==1.4.4 96 | - argh==0.26.2 97 | - arrow==0.13.1 98 | - babel==2.9.1 99 | - backports-entry-points-selectable==1.1.0 100 | - binaryornot==0.4.4 101 | - black==21.7b0 102 | - bleach==3.3.0 103 | - bump2version==0.5.11 104 | - cachetools==4.2.2 105 | - certifi==2021.5.30 106 | - cfgv==3.3.0 107 | - chardet==4.0.0 108 | - charset-normalizer==2.0.1 109 | - cookiecutter==1.7.3 110 | - coverage==5.5 111 | - cryptography==3.4.8 112 | - distlib==0.3.2 113 | - docutils==0.17.1 114 | - easydict==1.9 115 | - entrypoints==0.3 116 | - filelock==3.0.12 117 | - flake8==3.9.2 118 | - fvcore==0.1.5.post20210722 119 | - google-auth==1.35.0 120 | - google-auth-oauthlib==0.4.5 121 | - grpcio==1.40.0 122 | - identify==2.2.13 123 | - idna==3.2 124 | - imagesize==1.2.0 125 | - importlib-metadata==0.23 126 | - iniconfig==1.1.1 127 | - iopath==0.1.9 128 | - jinja2==3.0.1 129 | - jinja2-time==0.2.0 130 | - markdown==3.3.4 131 | - markupsafe==2.0.1 132 | - mccabe==0.6.1 133 | - nodeenv==1.6.0 134 | - oauthlib==3.1.1 135 | - omegaconf==2.1.1 136 | - opencv-python==4.5.3.56 137 | - packaging==21.0 138 | - pathspec==0.9.0 139 | - pathtools==0.1.2 140 | - pip==21.2.4 141 | - pkginfo==1.7.1 142 | - platformdirs==2.0.0 143 | - pluggy==0.13.1 144 | - portalocker==2.3.0 145 | - poyo==0.5.0 146 | - pre-commit==2.14.0 147 | - protobuf==3.17.3 148 | - py==1.10.0 149 | - pyasn1==0.4.8 150 | - 
pyasn1-modules==0.2.8 151 | - pycodestyle==2.7.0 152 | - pyflakes==2.3.1 153 | - pygments==2.9.0 154 | - pytest==6.2.4 155 | - python-slugify==5.0.2 156 | - pytz==2021.1 157 | - pyyaml==5.4.1 158 | - readme-renderer==29.0 159 | - requests==2.26.0 160 | - requests-oauthlib==1.3.0 161 | - requests-toolbelt==0.9.1 162 | - rsa==4.7.2 163 | - scipy==1.7.0 164 | - snowballstemmer==2.1.0 165 | - tabulate==0.8.9 166 | - tensorboard==2.6.0 167 | - tensorboard-data-server==0.6.1 168 | - tensorboard-plugin-wit==1.8.0 169 | - termcolor==1.1.0 170 | - text-unidecode==1.3 171 | - tomli==1.2.1 172 | - torchmetrics==0.5.0 173 | - tox==3.14.0 174 | - tqdm==4.61.2 175 | - twine==1.14.0 176 | - unidecode==1.3.2 177 | - urllib3==1.26.6 178 | - virtualenv==20.5.0 179 | - watchdog==0.9.0 180 | - webencodings==0.5.1 181 | - wheel==0.37.0 182 | - yacs==0.1.8 -------------------------------------------------------------------------------- /docs/tutorials/using.rst: -------------------------------------------------------------------------------- 1 | Using one of the implemented optical flow estimation models 2 | ============================================================================== 3 | 4 | **EzFlow** contains easy-to-use implementations of a number of eminent models for optical flow estimation. 5 | `PWC-Net `_, `RAFT `_, and 6 | `VCN `_ to name a few. 7 | 8 | These models can be accessed with the help of builder functions. For example, to build a **RAFT** model, the following code snippet can be used: 9 | 10 | .. code-block:: python 11 | 12 | from ezflow.models import build_model 13 | 14 | model = build_model("RAFT", default=True) 15 | 16 | This snippet will return a **RAFT** model with the default configuration and parameters. 17 | 18 | Let's now talk about how the models are implemented and how they can be accessed using the builder functions. 19 | 20 | Each model is a composite of sub-modules like encoders and decoders which are present in the library. 
Every implementation takes in 21 | a `YACS <https://github.com/rbgirshick/yacs>`_ configuration object (:class:`CfgNode`) as input and returns a model object. This configuration object is 22 | used to supply the various parameters for the encoder, decoder, and other modules and hence to build the model. 23 | 24 | **EzFlow** packages default configurations for each model, which contain the appropriate parameters for the respective model. To access these default configurations, 25 | the following function can be used: 26 | 27 | .. code-block:: python 28 | 29 | from ezflow.models import get_default_model_cfg 30 | 31 | raft_cfg = get_default_model_cfg("RAFT") 32 | 33 | The above-mentioned getter function reads the YAML configuration file supplied with the library for a model and returns a :class:`CfgNode` object. 34 | This configuration object can be used to build the model. 35 | 36 | In the example provided above about using the builder function to access **RAFT**, under the hood, the getter function is used to fetch the default configuration 37 | object for the specified model name, which is then passed to the model class constructor to build the model. 38 | However, this is not the only way to access the models present in **EzFlow**. The builder functions also accept a :class:`CfgNode` object as input and return a model object. 39 | 40 | .. code-block:: python 41 | 42 | raft_cfg = get_default_model_cfg("RAFT") 43 | model = build_model("RAFT", cfg=raft_cfg) 44 | 45 | This way you can also make a few modifications to the default configuration if required. 46 | To view all the parameters present in a configuration object, the :func:`.to_dict()` method of the object can be used. 47 | 48 | .. code-block:: python 49 | 50 | raft_cfg = get_default_model_cfg("RAFT") 51 | raft_cfg.ENCODER.CORR_RADIUS = 5 52 | model = build_model("RAFT", cfg=raft_cfg) 53 | 54 | 55 | Additionally, you can also supply YAML configuration file paths to the builder function. These can further be of two types.
56 | **EzFlow** stores config files in the ``configs/models`` directory in the `root <https://github.com/neu-vi/ezflow>`_ of the library. Files present in this directory can be accessed by specifying the path to the file 57 | relative to ``configs/models``. For example, to build **RAFT** this way: 58 | 59 | .. code-block:: python 60 | 61 | model = build_model("RAFT", cfg_path="raft.yaml") 62 | 63 | Furthermore, you can also supply a path to a custom YAML configuration file which you may have created for a model. 64 | 65 | .. code-block:: python 66 | 67 | model = build_model("RAFT", cfg_path="my_raft_cfg.yaml", custom_cfg=True) 68 | 69 | Lastly, the builder function can also be used to load a model with pretrained weights. 70 | 71 | .. code-block:: python 72 | 73 | model = build_model("RAFT", default=True, weights_path="raft_weights.pth") 74 | 75 | 76 | Along with the above-described ways to access models, **EzFlow** also provides a higher-level API to use these models for prediction. 77 | This can be done using the :class:`Predictor` class. 78 | 79 | .. code-block:: python 80 | 81 | from ezflow.models import Predictor 82 | from torchvision.transforms import Resize 83 | 84 | predictor = Predictor("RAFT", 85 | mean=(127.5, 127.5, 127.5), 86 | std=(127.5, 127.5, 127.5), 87 | default=True, 88 | model_weights_path="raft_weights.pth", 89 | data_transform=Resize((256, 256)) 90 | ) 91 | flow = predictor("img1.png", "img2.png") 92 | 93 | Please refer to the API documentation for more details. 94 | Also, do check out the other tutorials for details on how to use **EzFlow** to build custom models 95 | and how to train them using the training pipeline provided by the library.
96 | -------------------------------------------------------------------------------- /ezflow/data/dataset/kitti.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | from glob import glob 3 | 4 | from ...config import configurable 5 | from ...functional import SparseFlowAugmentor 6 | from ..build import DATASET_REGISTRY 7 | from .base_dataset import BaseDataset 8 | 9 | 10 | @DATASET_REGISTRY.register() 11 | class Kitti(BaseDataset): 12 | """ 13 | Dataset Class for preparing the Kitti dataset for training and validation. 14 | 15 | Parameters 16 | ---------- 17 | root_dir : str 18 | path of the root directory for the KITTI dataset (expects "training" and "testing" subdirectories) 19 | split : str, default : "training" 20 | specify the training or validation split; "validation" reads the "testing" directory and forces prediction mode (images only) 21 | is_prediction : bool, default : False 22 | If True, only image data are loaded for prediction otherwise both images and flow data are loaded 23 | init_seed : bool, default : False 24 | If True, sets random seed to worker 25 | append_valid_mask : bool, default : False 26 | If True, appends the valid flow mask to the original flow mask at dim=0 27 | crop : bool, default : False 28 | Whether to perform cropping 29 | crop_size : :obj:`tuple` of :obj:`int`, default : (256, 256) 30 | The size of the image crop 31 | crop_type : :obj:`str`, default : 'center' 32 | The type of cropping to be performed, one of "center", "random" 33 | augment : bool, default : True 34 | If True, applies data augmentation 35 | aug_params : :obj:`dict`, optional 36 | The parameters for data augmentation 37 | norm_params : :obj:`dict`, optional 38 | The parameters for normalization 39 | flow_offset_params: :obj:`dict`, optional 40 | The parameters for adding bilinear interpolated weights surrounding each ground truth flow value.
41 | """ 42 | 43 | @configurable 44 | def __init__( 45 | self, 46 | root_dir, 47 | split="training", 48 | is_prediction=False, 49 | init_seed=False, 50 | append_valid_mask=False, 51 | crop=False, 52 | crop_size=(256, 256), 53 | crop_type="center", 54 | augment=True, 55 | aug_params={ 56 | "eraser_aug_params": {"enabled": False}, 57 | "noise_aug_params": {"enabled": False}, 58 | "flip_aug_params": {"enabled": False}, 59 | "color_aug_params": {"enabled": False}, 60 | "spatial_aug_params": {"enabled": False}, 61 | "advanced_spatial_aug_params": {"enabled": False}, 62 | }, 63 | norm_params={"use": False}, 64 | flow_offset_params={"use": False}, 65 | ): 66 | super(Kitti, self).__init__( 67 | init_seed=init_seed, 68 | is_prediction=is_prediction, 69 | append_valid_mask=append_valid_mask, 70 | crop=crop, 71 | crop_size=crop_size, 72 | crop_type=crop_type, 73 | augment=augment, 74 | aug_params=aug_params, 75 | sparse_transform=True, # sparse ground truth; consistent with the SparseFlowAugmentor used below 76 | norm_params=norm_params, 77 | flow_offset_params=flow_offset_params, 78 | ) 79 | assert ( # NOTE(review): assert is stripped under python -O; consider raising ValueError instead 80 | split.lower() == "training" or split.lower() == "validation" 81 | ), "Incorrect split values.
Accepted split values: training, validation" 82 | 83 | self.is_prediction = is_prediction 84 | self.append_valid_mask = append_valid_mask 85 | 86 | if augment: 87 | self.augmentor = SparseFlowAugmentor(crop_size=crop_size, **aug_params) 88 | 89 | split = split.lower() 90 | if split == "validation": # "validation" reads the "testing" directory in prediction mode, so no flow is loaded below 91 | split = "testing" 92 | self.is_prediction = True 93 | 94 | root_dir = osp.join(root_dir, split) 95 | images1 = sorted(glob(osp.join(root_dir, "image_2/*_10.png"))) # first frame of each pair 96 | images2 = sorted(glob(osp.join(root_dir, "image_2/*_11.png"))) # second frame; paired with images1 by sorted order 97 | 98 | for img1, img2 in zip(images1, images2): 99 | self.image_list += [[img1, img2]] 100 | 101 | if not self.is_prediction: 102 | self.flow_list += sorted(glob(osp.join(root_dir, "flow_occ/*_10.png"))) # assumes one flow_occ file per *_10.png frame, aligned by sorted order — TODO confirm 103 | 104 | @classmethod 105 | def from_config(cls, cfg): 106 | return { 107 | "root_dir": cfg.ROOT_DIR, 108 | "split": cfg.SPLIT, 109 | "is_prediction": cfg.IS_PREDICTION, 110 | "init_seed": cfg.INIT_SEED, 111 | "append_valid_mask": cfg.APPEND_VALID_MASK, 112 | "crop": cfg.CROP.USE, 113 | "crop_size": cfg.CROP.SIZE, 114 | "crop_type": cfg.CROP.TYPE, 115 | "augment": cfg.AUGMENTATION.USE, 116 | "aug_params": cfg.AUGMENTATION.PARAMS, 117 | "norm_params": cfg.NORM_PARAMS, 118 | "flow_offset_params": cfg.FLOW_OFFSET_PARAMS, 119 | } 120 | -------------------------------------------------------------------------------- /tests/test_modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from ezflow.config import CfgNode 4 | from ezflow.modules import MODULE_REGISTRY 5 | 6 | 7 | def test_ConvGRU(): 8 | 9 | inp_x = torch.rand(2, 8, 32, 32) 10 | inp_h = torch.rand(2, 8, 32, 32) 11 | 12 | module = MODULE_REGISTRY.get("ConvGRU")(hidden_dim=8, input_dim=8) 13 | _ = module(inp_h, inp_x) 14 | 15 | 16 | def test_BasicBlock(): 17 | 18 | inp = torch.randn(2, 3, 256, 256) 19 | 20 | module = MODULE_REGISTRY.get("BasicBlock")( 21 | inp.shape[1], 32, norm="group", activation="relu", stride=3 22
| ) 23 | _ = module(inp) 24 | del module 25 | 26 | module = MODULE_REGISTRY.get("BasicBlock")( 27 | inp.shape[1], 32, norm="batch", activation="leakyrelu", stride=3 28 | ) 29 | _ = module(inp) 30 | del module 31 | 32 | module = MODULE_REGISTRY.get("BasicBlock")( 33 | inp.shape[1], 32, norm="instance", activation="relu", stride=3 34 | ) 35 | _ = module(inp) 36 | del module 37 | 38 | module = MODULE_REGISTRY.get("BasicBlock")( 39 | inp.shape[1], 32, norm="none", activation="relu", stride=3 40 | ) 41 | _ = module(inp) 42 | del module 43 | 44 | module = MODULE_REGISTRY.get("BasicBlock")( 45 | inp.shape[1], 32, norm=None, activation="relu", stride=3 46 | ) 47 | _ = module(inp) 48 | del module 49 | 50 | 51 | def test_BottleneckBlock(): # smoke test: forward pass for each norm option ("group", "batch", "instance", "none", None) 52 | 53 | inp = torch.randn(2, 3, 256, 256) 54 | 55 | module = MODULE_REGISTRY.get("BottleneckBlock")( 56 | inp.shape[1], 32, norm="group", activation="relu", stride=3 57 | ) 58 | _ = module(inp) 59 | del module 60 | 61 | module = MODULE_REGISTRY.get("BottleneckBlock")( 62 | inp.shape[1], 32, norm="batch", activation="leakyrelu", stride=3 63 | ) 64 | _ = module(inp) 65 | del module 66 | 67 | module = MODULE_REGISTRY.get("BottleneckBlock")( 68 | inp.shape[1], 32, norm="instance", activation="relu", stride=3 69 | ) 70 | _ = module(inp) 71 | del module 72 | 73 | module = MODULE_REGISTRY.get("BottleneckBlock")( 74 | inp.shape[1], 32, norm="none", activation="relu", stride=3 75 | ) 76 | _ = module(inp) 77 | del module 78 | 79 | module = MODULE_REGISTRY.get("BottleneckBlock")( 80 | inp.shape[1], 32, norm=None, activation="relu", stride=3 81 | ) 82 | _ = module(inp) 83 | del module 84 | 85 | 86 | def test_DAP(): # smoke test with temperature disabled and enabled 87 | 88 | inp = torch.randn(2, 1, 7, 7, 16, 16) # 6-D input; axis meaning is not shown here — presumably (N, C, U, V, H, W), TODO confirm 89 | 90 | module = MODULE_REGISTRY.get("DisplacementAwareProjection")(temperature=False) 91 | _ = module(inp) 92 | 93 | module = MODULE_REGISTRY.get("DisplacementAwareProjection")(temperature=True) 94 | _ = module(inp) 95 | 96 | 97 | def test_ASPPConv2D(): # checks shape preservation with norm="none"; the norm="batch" case only runs the forward pass 98 | inp = torch.randn(2, 256, 32,
32) 99 | 100 | module = MODULE_REGISTRY.get("ASPPConv2D")( 101 | in_channels=256, hidden_dim=256, out_channels=256, norm="none" 102 | ) 103 | out = module(inp) 104 | assert out.shape == (2, 256, 32, 32) 105 | del module 106 | 107 | module = MODULE_REGISTRY.get("ASPPConv2D")( 108 | in_channels=256, hidden_dim=256, out_channels=256, norm="batch" 109 | ) 110 | out = module(inp) 111 | del module 112 | 113 | 114 | def test_UNetBase(): # UNetBase with an ASPPConv2D bottleneck config; checks the (2, 96, 32, 32) output shape 115 | inp = torch.randn(2, 695, 32, 32) 116 | 117 | bottleneck_config = CfgNode( 118 | init_dict={ 119 | "NAME": "ASPPConv2D", 120 | "IN_CHANNELS": 192, 121 | "HIDDEN_DIM": 192, 122 | "OUT_CHANNELS": 192, 123 | "DILATIONS": [2, 4, 8], 124 | "NUM_GROUPS": 1, 125 | "NORM": "none", 126 | }, 127 | new_allowed=True, 128 | ) 129 | 130 | module = MODULE_REGISTRY.get("UNetBase")( 131 | in_channels=695, 132 | hidden_dim=96, 133 | out_channels=96, 134 | bottle_neck_cfg=bottleneck_config, 135 | ) 136 | out = module(inp) 137 | assert out.shape == (2, 96, 32, 32) 138 | del module 139 | 140 | 141 | def test_UNetLight(): # same setup and shape check as test_UNetBase, for UNetLight 142 | inp = torch.randn(2, 695, 32, 32) 143 | 144 | bottleneck_config = CfgNode( 145 | init_dict={ 146 | "NAME": "ASPPConv2D", 147 | "IN_CHANNELS": 192, 148 | "HIDDEN_DIM": 192, 149 | "OUT_CHANNELS": 192, 150 | "DILATIONS": [2, 4, 8], 151 | "NUM_GROUPS": 1, 152 | "NORM": "none", 153 | }, 154 | new_allowed=True, 155 | ) 156 | 157 | module = MODULE_REGISTRY.get("UNetLight")( 158 | in_channels=695, 159 | hidden_dim=96, 160 | out_channels=96, 161 | bottle_neck_cfg=bottleneck_config, 162 | ) 163 | out = module(inp) 164 | assert out.shape == (2, 96, 32, 32) 165 | del module 166 | --------------------------------------------------------------------------------