├── .gitignore ├── .vscode └── settings.json ├── CODE_OF_CONDUCT.md ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── core ├── config │ ├── checkpoint_config.py │ ├── dataset_config.py │ ├── eval_config.py │ ├── model_config.py │ ├── object_config.py │ ├── train_config.py │ └── workflow_config.py ├── data │ ├── __init__.py │ ├── blender_camera.py │ ├── depths.py │ ├── flow.py │ ├── homography.py │ ├── llff_camera.py │ ├── mask.py │ └── rgba_mask.py ├── hook │ ├── checkpoint.py │ └── validation.py ├── loss │ ├── __init__.py │ ├── alpha_reg_loss.py │ ├── depth_matching.py │ ├── distortion_loss.py │ ├── flow_recons_loss.py │ ├── l1_loss.py │ ├── mask_loss.py │ ├── mean_flow_match_loss.py │ ├── mse_loss.py │ ├── robust_depth_matching.py │ ├── zero_reg_loss.py │ └── zero_reg_loss_optional.py ├── model │ ├── matting_dataset.py │ └── render_context.py ├── module │ ├── __init__.py │ ├── dummy.py │ ├── omnimatte.py │ └── tensorf.py ├── scheduler │ ├── __init__.py │ └── exp_lr.py ├── trainer │ ├── __init__.py │ ├── matting_trainer.py │ └── omnimatte_trainer.py └── utils │ ├── omnimatte_utils.py │ ├── tensorf_utils.py │ └── trainer_utils.py ├── data_manager_example.json ├── docker ├── .bashrc ├── Dockerfile ├── config.fish └── docker-compose.yaml ├── docs ├── docker-images.md ├── using-the-cli.md └── using-your-video.md ├── init.fish ├── lib ├── hook.py ├── loss.py ├── registry.py └── trainer.py ├── licenses ├── MiDaS ├── Omnimatte ├── RAFT ├── RoDynRF └── TensoRF ├── preprocess ├── config │ ├── convert_segmentation.yaml │ ├── run_colmap.yaml │ ├── run_depth.yaml │ ├── run_flow.yaml │ ├── run_homography.yaml │ ├── run_motion_mask.yaml │ ├── run_segmentation.yaml │ ├── video_to_images.yaml │ └── visualize_segmentation.yaml ├── convert_segmentation.py ├── run_colmap.py ├── run_depth.py ├── run_flow.py ├── run_homography.py ├── run_segmentation.py ├── video_to_images.py └── visualize_segmentation.py ├── third_party ├── MiDaS │ ├── .gitignore │ ├── Dockerfile │ ├── LICENSE │ ├── README.md │ ├── README.third_party │ ├── environment.yaml │ ├── hubconf.py │ ├── midas │ │ ├── backbones │ │ │ ├── beit.py │ │ │ ├── levit.py │ │ │ ├── next_vit.py │ │ │ ├── swin.py │ │ │ ├── swin2.py │ │ │ ├── swin_common.py │ │ │ ├── utils.py │ │ │ └── vit.py │ │ ├── base_model.py │ │ ├── blocks.py │ │ ├── dpt_depth.py │ │ ├── midas_net.py │ │ ├── midas_net_custom.py │ │ ├── model_loader.py │ │ └── transforms.py │ ├── run.py │ └── utils.py ├── RAFT │ ├── .gitignore │ ├── LICENSE │ ├── METADATA │ ├── RAFT.png │ ├── README.md │ ├── alt_cuda_corr │ │ ├── correlation.cpp │ │ ├── correlation_kernel.cu │ │ └── setup.py │ ├── chairs_split.txt │ ├── core │ │ ├── __init__.py │ │ ├── corr.py │ │ ├── datasets.py │ │ ├── extractor.py │ │ ├── raft.py │ │ ├── update.py │ │ └── utils │ │ │ ├── __init__.py │ │ │ ├── augmentor.py │ │ │ ├── flow_viz.py │ │ │ ├── frame_utils.py │ │ │ └── utils.py │ ├── demo.py │ ├── download_models.sh │ ├── evaluate.py │ ├── train.py │ ├── train_mixed.sh │ └── train_standard.sh ├── RoDynRF │ ├── .gitignore │ ├── camera.py │ ├── configs │ │ ├── DAVIS_CAM.txt │ │ ├── REALWORLD_CAM.txt │ │ └── REALWORLD_CAM_NDC.txt │ ├── dataLoader │ │ ├── __init__.py │ │ ├── colmap2nerf.py │ │ ├── nvidia_pose.py │ │ └── ray_utils.py │ ├── flow_viz.py │ ├── models │ │ ├── __init__.py │ │ ├── sh.py │ │ ├── tensoRF.py │ │ └── tensorBase.py │ ├── opt.py │ ├── renderer.py │ ├── scripts │ │ ├── RAFT │ │ │ ├── __init__.py │ │ │ ├── corr.py │ │ │ ├── datasets.py │ │ │ ├── demo.py │ │ │ ├── extractor.py │ │ │ ├── raft.py │ │ │ ├── update.py │ │ │ └── utils │ │ │ │ ├── __init__.py │ │ │ │ ├── augmentor.py │ │ │ │ ├── flow_viz.py │ │ │ │ ├── frame_utils.py │ │ │ │ └── utils.py │ │ ├── flow_utils.py │ │ ├── generate_depth.py │ │ ├── generate_flow.py │ │ ├── midas │ │ │ ├── base_model.py │ │ │ ├── blocks.py │ │ │ ├── midas_net.py │ │ │ ├── transforms.py │ │ │ └── vit.py │ │ └── our_mask.py │ ├── train.py │ └── utils.py ├── TensoRF │ ├── LICENSE │ ├── METADATA │ ├── README.md │ ├── configs │ │ ├── flower.txt │ │ ├── lego.txt │ │ ├── truck.txt │ │ ├── wineholder.txt │ │ └── your_own_data.txt │ ├── dataLoader │ │ ├── __init__.py │ │ ├── blender.py │ │ ├── colmap2nerf.py │ │ ├── llff.py │ │ ├── nsvf.py │ │ ├── ray_utils.py │ │ ├── tankstemple.py │ │ └── your_own_data.py │ ├── extra │ │ ├── auto_run_paramsets.py │ │ └── compute_metrics.py │ ├── models │ │ ├── __init__.py │ │ ├── sh.py │ │ ├── tensoRF.py │ │ └── tensorBase.py │ ├── opt.py │ ├── renderer.py │ ├── train.py │ └── utils.py └── omnimatte │ ├── CITATION.cff │ ├── LICENSE │ ├── METADATA │ ├── README.md │ ├── data │ ├── __init__.py │ └── omnimatte_dataset.py │ ├── datasets │ ├── confidence.py │ └── homography.py │ ├── docs │ └── data.md │ ├── environment.yml │ ├── models │ ├── __init__.py │ ├── networks.py │ └── omnimatte_model.py │ ├── options │ ├── __init__.py │ ├── base_options.py │ ├── test_options.py │ └── train_options.py │ ├── requirements.txt │ ├── test.py │ ├── third_party │ ├── __init__.py │ ├── data │ │ ├── __init__.py │ │ ├── base_dataset.py │ │ └── image_folder.py │ ├── models │ │ ├── __init__.py │ │ ├── base_model.py │ │ ├── networks.py │ │ └── networks_lnr.py │ └── util │ │ ├── __init__.py │ │ ├── html.py │ │ ├── util.py │ │ └── visualizer.py │ ├── train.py │ └── utils.py ├── tools ├── blender_to_matting.py ├── convert_masks.py ├── make_video.py ├── matting_to_nerfies.py ├── matting_to_omnimatte.py ├── nerfies_to_blender.py └── simple_video.py ├── ui ├── cli.py ├── commands.py ├── common.py ├── data_manager.py └── data_model.py ├── utils ├── array_utils.py ├── colmap │ ├── colmap_read_model.py │ └── colmap_utils.py ├── dict_utils.py ├── eval_utils.py ├── image_utils.py ├── io_utils.py ├── json_utils.py ├── render_utils.py ├── string_utils.py └── torch_utils.py └── workflows ├── common.py ├── config ├── bg_losses │ ├── distortion.yaml │ ├── recons.yaml │ ├── recons_coarse.yaml │ ├── recons_om.yaml │ ├── robust_depth_matching.yaml │ └── tv_reg.yaml ├── bg_model │ ├── dummy.yaml │ └── tensorf.yaml ├── data_sources │ ├── blender.yaml │ ├── colmap.yaml │ ├── depths.yaml │ ├── flow.yaml │ ├── homography.yaml │ ├── mask.yaml │ ├── nerfies_camera.yaml │ ├── rgba_mask.yaml │ └── rodynrf.yaml ├── debug.yaml ├── eval.yaml ├── fg_losses │ ├── alpha_reg.yaml │ ├── bg_distortion.yaml │ ├── bg_tv_reg.yaml │ ├── brightness_reg.yaml │ ├── depth_matching.yaml │ ├── flow_recons.yaml │ ├── mask.yaml │ ├── mean_flow_match.yaml │ ├── offset_reg.yaml │ ├── recons.yaml │ ├── robust_depth_matching.yaml │ └── warped_alpha.yaml ├── fg_model │ ├── dummy.yaml │ ├── omnimatte.yaml │ └── omnimatte_noise.yaml ├── profile.yaml ├── train.yaml ├── train_bg.yaml ├── train_bg_rgba.yaml ├── train_bg_rgba_blender.yaml ├── train_both.yaml ├── train_both_davis.yaml ├── train_om.yaml ├── train_tf.yaml ├── trainer │ ├── matting.yaml │ └── omnimatte.yaml └── validation │ ├── render_all.yaml │ ├── render_all_bg.yaml │ └── training_dump.yaml ├── eval.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_Store 2 | __pycache__ 3 | .ipynb_checkpoints 4 | /outputs 5 | /scripts 6 | /data 7 | /data_manager.json 8 | -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "python.sortImports.args": [ 3 | "--src=${workspaceFolder}" 4 | ], 5 | "python.formatting.autopep8Args": [ 6 | "--max-line-length", 7 | "120", 8 | "--ignore", 9 | "E226,E24,W50,W690,E731", 10 | ], 11 | "jupyter.notebookFileRoot": "${workspaceFolder}" 12 | } 13 | -------------------------------------------------------------------------------- /CODE_OF_CONDUCT.md: -------------------------------------------------------------------------------- 1 | # Code of Conduct 2 | 3 | ## Our Pledge 4 | 5 | In the interest of fostering an open and welcoming environment, we as 6 | contributors and maintainers pledge to make participation in our project and 7 | our community a harassment-free experience for everyone, regardless of age, body 8 | size, disability, ethnicity, sex characteristics, gender identity and expression, 9 | level of experience, education, socio-economic status, nationality, personal 10 | appearance, race, religion, or sexual identity and orientation. 11 | 12 | ## Our Standards 13 | 14 | Examples of behavior that contributes to creating a positive environment 15 | include: 16 | 17 | * Using welcoming and inclusive language 18 | * Being respectful of differing viewpoints and experiences 19 | * Gracefully accepting constructive criticism 20 | * Focusing on what is best for the community 21 | * Showing empathy towards other community members 22 | 23 | Examples of unacceptable behavior by participants include: 24 | 25 | * The use of sexualized language or imagery and unwelcome sexual attention or 26 | advances 27 | * Trolling, insulting/derogatory comments, and personal or political attacks 28 | * Public or private harassment 29 | * Publishing others' private information, such as a physical or electronic 30 | address, without explicit permission 31 | * Other conduct which could reasonably be considered inappropriate in a 32 | professional setting 33 | 34 | ## Our Responsibilities 35 | 36 | Project maintainers are responsible for clarifying the standards of acceptable 37 | behavior and are expected to take appropriate and fair corrective action in 38 | response to any instances of unacceptable behavior. 39 | 40 | Project maintainers have the right and responsibility to remove, edit, or 41 | reject comments, commits, code, wiki edits, issues, and other contributions 42 | that are not aligned to this Code of Conduct, or to ban temporarily or 43 | permanently any contributor for other behaviors that they deem inappropriate, 44 | threatening, offensive, or harmful. 45 | 46 | ## Scope 47 | 48 | This Code of Conduct applies within all project spaces, and it also applies when 49 | an individual is representing the project or its community in public spaces. 50 | Examples of representing a project or community include using an official 51 | project e-mail address, posting via an official social media account, or acting 52 | as an appointed representative at an online or offline event. Representation of 53 | a project may be further defined and clarified by project maintainers. 54 | 55 | This Code of Conduct also applies outside the project spaces when there is a 56 | reasonable belief that an individual's behavior may have a negative impact on 57 | the project or its community. 58 | 59 | ## Enforcement 60 | 61 | Instances of abusive, harassing, or otherwise unacceptable behavior may be 62 | reported by contacting the project team at . All 63 | complaints will be reviewed and investigated and will result in a response that 64 | is deemed necessary and appropriate to the circumstances. The project team is 65 | obligated to maintain confidentiality with regard to the reporter of an incident. 66 | Further details of specific enforcement policies may be posted separately. 67 | 68 | Project maintainers who do not follow or enforce the Code of Conduct in good 69 | faith may face temporary or permanent repercussions as determined by other 70 | members of the project's leadership. 71 | 72 | ## Attribution 73 | 74 | This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, 75 | available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html 76 | 77 | [homepage]: https://www.contributor-covenant.org 78 | 79 | For answers to common questions about this code of conduct, see 80 | https://www.contributor-covenant.org/faq 81 | -------------------------------------------------------------------------------- /CONTRIBUTING.md: -------------------------------------------------------------------------------- 1 | # Contributing to OmnimatteRF 2 | We want to make contributing to this project as easy and transparent as 3 | possible. 4 | 5 | ## Pull Requests 6 | We actively welcome your pull requests. 7 | 8 | 1. Fork the repo and create your branch from `main`. 9 | 2. If you've added code that should be tested, add tests. 10 | 3. If you've changed APIs, update the documentation. 11 | 4. Ensure the test suite passes. 12 | 5. Make sure your code lints. 13 | 6. If you haven't already, complete the Contributor License Agreement ("CLA"). 14 | 15 | ## Contributor License Agreement ("CLA") 16 | In order to accept your pull request, we need you to submit a CLA. You only need 17 | to do this once to work on any of Facebook's open source projects. 18 | 19 | Complete your CLA here: 20 | 21 | ## Issues 22 | We use GitHub issues to track public bugs. Please ensure your description is 23 | clear and has sufficient instructions to be able to reproduce the issue. 24 | 25 | Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe 26 | disclosure of security bugs. In those cases, please go through the process 27 | outlined on that page and do not file a public issue. 28 | 29 | ## License 30 | By contributing to OmnimatteRF, you agree that your contributions will be licensed 31 | under the LICENSE file in the root directory of this source tree. 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /core/config/checkpoint_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | 5 | 6 | @dataclass 7 | class CheckpointConfig: 8 | step_size: int 9 | min_step: int 10 | folder: str 11 | -------------------------------------------------------------------------------- /core/config/dataset_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | from core.config.object_config import ObjectConfig 7 | 8 | 9 | @dataclass 10 | class DatasetConfig: 11 | path: str 12 | image_subpath: str 13 | scale: float 14 | source_configs: list[ObjectConfig] 15 | sources_injection: dict[str, Any] 16 | -------------------------------------------------------------------------------- /core/config/eval_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | 5 | from core.config.workflow_config import WorkflowConfig 6 | 7 | 8 | @dataclass 9 | class EvalConfig(WorkflowConfig): 10 | output: str 11 | 12 | checkpoint: str 13 | train_config_file: str 14 | data_root: str 15 | dataset_name: str | None 16 | experiment: str | None 17 | step: str | None 18 | write_videos: bool 19 | 20 | alpha_threshold: float 21 | """Alpha threshold when generating pred_fg masks""" 22 | eval_bg_layer: bool 23 | """Evaluate background layer against input image at input background mask""" 24 | 25 | debug_count: int 26 | """Debug: only render this number of frames""" 27 | raw_data_keys: list[str] 28 | """Debug: save raw data of these keys as npy""" 29 | raw_data_indices: list[int] 30 | """Debug: indices of frames to save raw data""" 31 | -------------------------------------------------------------------------------- /core/config/model_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | from core.config.object_config import ObjectConfig 7 | 8 | 9 | @dataclass 10 | class ModelConfig(ObjectConfig): 11 | train: bool 12 | """Whether this model is trained""" 13 | 14 | optim: dict[str, Any] 15 | """Adam optimizer constructor kwargs""" 16 | -------------------------------------------------------------------------------- /core/config/object_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | 7 | @dataclass 8 | class ObjectConfig: 9 | name: str 10 | config: dict[str, Any] 11 | -------------------------------------------------------------------------------- /core/config/train_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any 5 | 6 | from core.config.checkpoint_config import CheckpointConfig 7 | from core.config.workflow_config import WorkflowConfig 8 | from core.config.object_config import ObjectConfig 9 | 10 | 11 | @dataclass 12 | class SchedulerConfig: 13 | fg: ObjectConfig 14 | bg: ObjectConfig 15 | 16 | 17 | @dataclass 18 | class ValidationConfig: 19 | config: dict[str, Any] 20 | """ValidationHook constructor kwargs""" 21 | pretrain: bool 22 | """Whether to run this validation before training""" 23 | 24 | 25 | @dataclass 26 | class TrainConfig(WorkflowConfig): 27 | output: str 28 | checkpoint: str | None 29 | seed: int 30 | debug: bool 31 | 32 | n_steps: int 33 | validation: dict[str, ValidationConfig] 34 | 35 | # Optimization 36 | fg_losses: dict[str, ObjectConfig] 37 | bg_losses: dict[str, ObjectConfig] 38 | scheduler: SchedulerConfig 39 | 40 | # Checkpoint saving 41 | save_checkpoint: CheckpointConfig 42 | save_pretrain_checkpoint: bool 43 | save_final_checkpoint: bool 44 | 45 | # Checkpoint loading 46 | load_fg: bool 47 | load_bg: bool 48 | 49 | reset_global_step: bool 50 | """When loading checkpoint, set beginning global step to zero.""" 51 | reset_bg_optimization: bool 52 | """When loading checkpoint, reset bg model optimizer and scheduler states.""" 53 | -------------------------------------------------------------------------------- /core/config/workflow_config.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from typing import Any, Literal 5 | from core.config.dataset_config import DatasetConfig 6 | 7 | from core.config.object_config import ObjectConfig 8 | from core.config.model_config import ModelConfig 9 | 10 | @dataclass 11 | class WorkflowConfig: 12 | device: str 13 | 14 | data_sources: dict[str, dict[str, Any]] 15 | dataset: DatasetConfig 16 | fg_model: ModelConfig 17 | bg_model: ModelConfig 18 | trainer: ObjectConfig 19 | 20 | contraction: Literal["none", "ndc", "mipnerf"] 21 | """Ray contraction scheme to use""" 22 | -------------------------------------------------------------------------------- /core/data/blender_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import json 4 | import logging 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | 9 | from core.data import CameraDataSource, register_data_source 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @register_data_source("blender_camera") 15 | class BlenderCameraDataSource(CameraDataSource): 16 | def __init__( 17 | self, 18 | root: Path, 19 | image_hw: tuple[int, int], 20 | subpath: str, 21 | near_default: float, 22 | far_default: float, 23 | contraction: str, 24 | compute_ndc_aabb: bool = False, 25 | scene_scale: float = 1, 26 | process_poses: bool = True, 27 | ): 28 | with open(Path(root / subpath), "r") as f: 29 | meta = json.load(f) 30 | 31 | # read data 32 | frames = meta["frames"] 33 | poses = [] 34 | bounds = [] 35 | for i in range(len(frames)): 36 | frame = frames[i] 37 | 38 | poses.append(np.array(frame["transform_matrix"], dtype=np.float32)) 39 | bounds.append( 40 | [ 41 | frame.get("near", float(near_default)), 42 | frame.get("far", float(far_default)), 43 | ] 44 | ) 45 | 46 | poses = np.stack(poses, axis=0) 47 | bounds = np.array(bounds, dtype=np.float32) 48 | 49 | # intrinsics 50 | fov = float(meta["camera_angle_x"]) 51 | focal = 0.5 * image_hw[1] / np.tan(0.5 * fov) 52 | 53 | super().__init__( 54 | root, 55 | image_hw, 56 | contraction, 57 | focal, 58 | poses, 59 | bounds, 60 | compute_ndc_aabb, 61 | scene_scale, 62 | process_poses, 63 | ) 64 | -------------------------------------------------------------------------------- /core/data/depths.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | 10 | from core.data import DataSource, register_data_source 11 | from utils.io_utils import multi_glob_sorted 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | @register_data_source("depths") 17 | class DepthsDataSource(DataSource): 18 | def __init__( 19 | self, 20 | root: Path, 21 | image_hw: tuple[int, int], 22 | subpath: str, 23 | ): 24 | super().__init__(root, image_hw) 25 | 26 | files = multi_glob_sorted(root / subpath, "*.npy") 27 | 28 | # load depths and scale them to (0, 1) per image 29 | # we only care about smoothness so this makes sense? 30 | # note that MiDaS depths are in disparity space 31 | depths = [] 32 | for file in files: 33 | d = np.load(file) 34 | d = (d - d.min()) / (d.max() - d.min()) 35 | depths.append(d) 36 | 37 | depths = torch.from_numpy(np.stack(depths)) # (N, H, W) 38 | 39 | logger.info(f"Loaded depths: {depths.shape}") 40 | depths, _ = self._scale_to_image_size("depths", depths[:, None]) 41 | self.depths = depths[:, 0].view(len(depths), -1) 42 | 43 | def __len__(self) -> int: 44 | return len(self.depths) 45 | 46 | def __getitem__(self, idx: int) -> dict[str, Tensor]: 47 | return { 48 | "depths": self.depths[idx] 49 | } 50 | 51 | def get_keys(self) -> list[str]: 52 | """Get list of data keys provided by this source""" 53 | return ["depths"] 54 | -------------------------------------------------------------------------------- /core/data/flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from pathlib import Path 4 | from typing import Dict, List, Tuple 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from core.data import DataSource, register_data_source 10 | from third_party.RAFT.core.utils.frame_utils import readFlow 11 | from utils.image_utils import read_image_np 12 | from utils.io_utils import multi_glob_sorted 13 | 14 | 15 | @register_data_source("flow") 16 | class FlowDataSource(DataSource): 17 | def __init__( 18 | self, 19 | root: Path, 20 | image_hw: Tuple[int, int], 21 | flow_path: str, 22 | confidence_path: str, 23 | ): 24 | super().__init__(root, image_hw) 25 | self.flows = self.load_flows(root / flow_path) 26 | self.confidences = self.load_confidences(root / confidence_path) 27 | 28 | confidence_sum = self.confidences.sum(dim=(1, 2), keepdim=True) 29 | confidence_scaled_flow = self.flows * self.confidences[..., None] / confidence_sum[..., None] 30 | mean_flow = confidence_scaled_flow.sum(dim=(1, 2)) # (N, 2) 31 | 32 | self.mean_dist_map = (self.flows - mean_flow.view(len(mean_flow), 1, 1, 2)).abs().mean(dim=-1) 33 | 34 | def __len__(self) -> int: 35 | return len(self.flows) 36 | 37 | def __getitem__(self, idx: int) -> Dict[str, Tensor]: 38 | return { 39 | "flow": self.flows[idx], 40 | "flow_confidence": self.confidences[idx], 41 | "flow_mean_dist_map": self.mean_dist_map[idx], 42 | } 43 | 44 | def get_keys(self) -> List[str]: 45 | return ["flow", "flow_confidence", "flow_mean_dist_map"] 46 | 47 | def load_flows(self, path: Path) -> Tensor: 48 | files = multi_glob_sorted(path, ["*.flo"]) 49 | flows = torch.stack( 50 | [torch.from_numpy(readFlow(f)) for f in files], dim=0 51 | ) # [N-1, H, W, 2] 52 | 53 | flows, scale = self._scale_to_image_size( 54 | "flows", torch.permute(flows, (0, 3, 1, 2)) 55 | ) 56 | flows = torch.permute(flows, (0, 2, 3, 1)) 57 | flows *= scale 58 | 59 | return flows 60 | 61 | def load_confidences(self, path: Path) -> Tensor: 62 | files = multi_glob_sorted(path, ["*.png"]) 63 | confidences = torch.stack( 64 | [torch.from_numpy(read_image_np(f)) for f in files], dim=0 65 | ) # [N-1, H, W] 66 | 67 | confidences, _ = self._scale_to_image_size( 68 | "confidences", confidences[:, None]) 69 | confidences = confidences[:, 0] 70 | 71 | return confidences 72 | -------------------------------------------------------------------------------- /core/data/homography.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from pathlib import Path 4 | from typing import Any, Dict, List, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | 10 | from core.data import DataSource, register_data_source 11 | 12 | 13 | @register_data_source("homography") 14 | class HomographyDataSource(DataSource): 15 | def __init__( 16 | self, 17 | root: Path, 18 | image_hw: Tuple[int, int], 19 | subpath: str, 20 | ): 21 | super().__init__(root, image_hw) 22 | ( 23 | self.homography, 24 | self.homography_size, 25 | self.homography_bounds, 26 | ) = self.load_homography(root / subpath) 27 | 28 | def __len__(self) -> int: 29 | return len(self.homography) 30 | 31 | def __getitem__(self, idx: int) -> Dict[str, Tensor]: 32 | return {"homography": self.homography[idx]} 33 | 34 | def get_keys(self) -> List[str]: 35 | return ["homography"] 36 | 37 | def get_global_data(self) -> Dict[str, Any]: 38 | return { 39 | "homography": self.homography, 40 | "homography_size": self.homography_size, 41 | "homography_bounds": self.homography_bounds, 42 | } 43 | 44 | def load_homography(self, path: Path) -> Tensor: 45 | homographies = np.load(path / "homographies-first-frame.npy") 46 | H_homo, W_homo = np.load(path / "size.npy") 47 | 48 | with open(path / "homographies-first-frame.txt", "r") as f: 49 | _ = f.readline() 50 | bounds = f.readline() 51 | 52 | bounds = [float(v) for v in bounds.rstrip().split(" ")[1:]] 53 | assert ( 54 | len(bounds) == 4 55 | ), f"Failed to parse bounds (length is not 4 but {len(bounds)})" 56 | 57 | return torch.from_numpy(homographies).float(), [H_homo, W_homo], bounds 58 | -------------------------------------------------------------------------------- /core/data/llff_camera.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | 8 | from core.data import CameraDataSource, register_data_source 9 | from utils.array_utils import log_stats 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @register_data_source("llff_camera") 15 | class LlffCameraDataSource(CameraDataSource): 16 | def __init__( 17 | self, 18 | root: Path, 19 | image_hw: tuple[int, int], 20 | subpath: str, 21 | near_p: float, 22 | far_p: float, 23 | near_min: float, 24 | far_max: float, 25 | contraction: str, 26 | compute_ndc_aabb: bool = False, 27 | scene_scale: float = 1, 28 | process_poses: bool = True, 29 | ): 30 | poses_bounds = np.load(root / subpath / "poses_bounds.npy") # [N, 17] 31 | bounds = poses_bounds[:, 15:] # [N, 2] 32 | poses = poses_bounds[:, :15].reshape([-1, 3, 5]) # [N, 3, 5] 33 | 34 | # Scale focal 35 | Hpose = poses[0, 0, 4] 36 | focal = poses[0, 2, 4] * (image_hw[0] / Hpose) 37 | 38 | # Correct poses 39 | poses = np.concatenate( 40 | [poses[..., 1:2], -poses[..., :1], poses[..., 2:4]], -1 41 | ) # (N, 3, 4) 42 | 43 | # Clip bounds 44 | bounds = self.clip_bounds(bounds, near_p, far_p, near_min, far_max) 45 | super().__init__( 46 | root, 47 | image_hw, 48 | contraction, 49 | focal, 50 | poses, 51 | bounds, 52 | compute_ndc_aabb, 53 | scene_scale, 54 | process_poses, 55 | ) 56 | 57 | @staticmethod 58 | def clip_bounds( 59 | bounds: np.ndarray, near_p: float, far_p: float, near_min: float, far_max: float 60 | ) -> np.ndarray: 61 | # Limit range of bounds to specified percentiles of values within range 62 | near_bounds = bounds[:, 0] 63 | far_bounds = bounds[:, 1] 64 | log_stats(logger.info, "near bounds", near_bounds) 65 | log_stats(logger.info, "far bounds", far_bounds) 66 | 67 | near_bounds_in_range = near_bounds[near_bounds > near_min] 68 | far_bounds_in_range = far_bounds[far_bounds < far_max] 69 | log_stats(logger.info, "near bounds (in range)", near_bounds_in_range) 70 | log_stats(logger.info, "far bounds (in range)", far_bounds_in_range) 71 | 72 | bounds[:, 0] = np.clip( 73 | bounds[:, 0], a_min=np.percentile(near_bounds_in_range, near_p), a_max=None 74 | ) 75 | bounds[:, 1] = np.clip( 76 | bounds[:, 1], a_min=None, a_max=np.percentile(far_bounds_in_range, far_p) 77 | ) 78 | 79 | return bounds 80 | -------------------------------------------------------------------------------- /core/data/rgba_mask.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Dict, List, Tuple 6 | 7 | import torch 8 | from torch import Tensor 9 | 10 | from core.data import DataSource, register_data_source 11 | from utils.image_utils import read_image_np 12 | from utils.io_utils import multi_glob_sorted 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | @register_data_source("rgba_mask") 18 | class RGBAMaskDataSource(DataSource): 19 | def __init__( 20 | self, 21 | root: Path, 22 | image_hw: Tuple[int, int], 23 | subpath: str, 24 | ): 25 | super().__init__(root, image_hw) 26 | self.masks = self.load_masks(root / subpath) 27 | 28 | def __len__(self) -> int: 29 | return len(self.masks) 30 | 31 | def __getitem__(self, idx: int) -> Dict[str, Tensor]: 32 | return { 33 | "rgba_masks": self.masks[idx], 34 | } 35 | 36 | def get_keys(self) -> List[str]: 37 | return ["rgba_masks"] 38 | 39 | def load_masks(self, path: Path) -> Tensor: 40 | files = multi_glob_sorted(path, "*.png") 41 | logger.info(f"Load {len(files)} RGBA masks") 42 | 43 | result = [] 44 | for file in files: 45 | img = torch.from_numpy(read_image_np(file)) 46 | assert len(img.shape) == 3 and img.shape[2] == 4 47 | result.append(img) 48 | result = torch.stack(result, dim=0) 49 | 50 | result = torch.permute(result, (0, 3, 1, 2)) # to NCHW 51 | result, _ = self._scale_to_image_size("rgba_masks", result) 52 | result = torch.permute(result, (0, 2, 3, 1)) # to NHWC 53 | 54 | return result 55 | -------------------------------------------------------------------------------- /core/hook/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Optional 6 | 7 | import torch 8 | 9 | from lib.hook import Hook 10 | from lib.trainer import Trainer 11 | 12 | logger = logging.getLogger(__name__) 13 | 14 | 15 | class SaveCheckpointHook(Hook): 16 | def __init__( 17 | self, 18 | folder: str, 19 | step_size: Optional[int] = None, 20 | min_step: Optional[int] = None, 21 | ): 22 | super().__init__(step_size) 23 | self.folder = Path(folder) 24 | self.min_step = 0 if min_step is None else min_step 25 | 26 | def execute(self, trainer: Trainer): 27 | if trainer.global_step > 0 and trainer.global_step < self.min_step: 28 | logger.info(f"Checkpoint saving skipped by min_step ({self.min_step})") 29 | return 30 | torch.save(trainer.state_dict(), self.folder / f"checkpoint_{trainer.global_step}.pth") 31 | -------------------------------------------------------------------------------- /core/hook/validation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from collections import defaultdict 5 | from typing import Callable, Optional 6 | 7 | import torch 8 | 9 | from core.trainer import Trainer 10 | from lib.hook import Hook 11 | from utils.eval_utils import compute_frame_metrics 12 | from utils.io_utils import mkdir 13 | 14 | logger = logging.getLogger(__name__) 15 | 16 | 17 | class ValidationHook(Hook): 18 | def __init__( 19 | self, 20 | name: str, 21 | step_size: int, 22 | n_frames: int = 0, 23 | use_upload_callback: bool = False, 24 | upload_callback: Optional[Callable[[str, str], None]] = None, 25 | write_tensorboard: bool = False, 26 | ): 27 | """Hook for running validation steps in Matting project 28 | We want to have multiple validation hooks with different intervals, so we use name to distinguish them 29 | 30 | name: name of output folder and prefix of tensorboard tag 31 | first_step: whether to execute when global_step equals start_step 32 | n_frames: if positive, limit the number of frames to render 33 | write_tensorboard: if True, write the FIRST image to tensorboard after executing 34 | """ 35 | super().__init__(step_size) 36 | self.name = name 37 | self.n_frames = n_frames 38 | self.use_upload_callback = use_upload_callback 39 | self.upload_callback = upload_callback 40 | self.write_tensorboard = write_tensorboard 41 | 42 | def execute(self, trainer: Trainer): 43 | step = trainer.global_step 44 | out_dir = mkdir(trainer.output / self.name / f"step_{step}") 45 | 46 | trainer.set_model_eval() 47 | 48 | # pred_key -> metric_key -> List[float] 49 | all_metrics = defaultdict(lambda: defaultdict(list)) 50 | 51 | def add_metric(result, frame, pred_key, gt_key): 52 | metrics = compute_frame_metrics(result, frame, pred_key, gt_key) 53 | for key, value in metrics.items(): 54 | all_metrics[pred_key][key].append(value) 55 | 56 | written_to_tb = False 57 | for result, frame in trainer._test_full_sequence(out_dir, self.n_frames): 58 | add_metric(result, frame, "composite_rgb", "image") 59 | add_metric(result, frame, "detailed_composite_rgb", "image") 60 | 61 | if written_to_tb or not self.write_tensorboard: 62 | continue 63 | 64 | for key, image in result.items(): 65 | if len(image.shape) == 2: 66 | image = image[..., None] 67 | trainer.writer.add_image( 68 | f"{self.name}_{key}", 69 | torch.from_numpy(image), 70 | global_step=step, 71 | dataformats="HWC", 72 | ) 73 | written_to_tb = True 74 | 75 | # write metrics to tb 76 | for pred_k, pred_metrics in all_metrics.items(): 77 | for metric_k, values in pred_metrics.items(): 78 | mean = sum(values) / len(values) 79 | trainer.writer.add_scalar( 80 | f"{self.name}_metrics_{pred_k}_{metric_k}", 81 | mean, 82 | global_step=step, 83 | ) 84 | 85 | if self.use_upload_callback and self.upload_callback is not None: 86 | self.upload_callback(self.name, out_dir) 87 | 88 | trainer.set_model_train() 89 | -------------------------------------------------------------------------------- /core/loss/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from lib.loss import Loss 4 | from lib.registry import create_registry, import_children 5 | 6 | _, register_loss, build_loss = create_registry("Loss", Loss) 7 | import_children(__file__, __name__) 8 | -------------------------------------------------------------------------------- /core/loss/alpha_reg_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import List 4 | 5 | from core.loss import Loss, register_loss 6 | from third_party.omnimatte.third_party.models.networks_lnr import cal_alpha_reg 7 | 8 | 9 | @register_loss("alpha_reg") 10 | class AlphaRegLoss(Loss): 11 | def __init__( 12 | self, 13 | alpha: float, 14 | inputs: List[str], 15 | lambda_alpha_l1: float, 16 | lambda_alpha_l0: float, 17 | l1_end_step: int, 18 | ): 19 | super().__init__(alpha, inputs=inputs) 20 | self.lambda_alpha_l1 = lambda_alpha_l1 21 | self.lambda_alpha_l0 = lambda_alpha_l0 22 | self.l1_end_step = l1_end_step 23 | 24 | def forward(self, pred, global_step): 25 | lambda_alpha_l1 = self.lambda_alpha_l1 26 | if global_step >= self.l1_end_step: 27 | lambda_alpha_l1 = 0 28 | 29 | return self.alpha * cal_alpha_reg(pred, lambda_alpha_l1, self.lambda_alpha_l0) 30 | -------------------------------------------------------------------------------- /core/loss/distortion_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from torch import Tensor 4 | from torch_efficient_distloss import eff_distloss_native 5 | 6 | from core.loss import Loss, register_loss 7 | 8 | 9 | @register_loss("distortion") 10 | class DistortionLoss(Loss): 11 | def forward(self, weight: Tensor, z_vals: Tensor): 12 | interval = 1 / weight.shape[1] 13 | z_vals = z_vals.expand(len(weight), -1) 14 | return self.alpha * eff_distloss_native(weight, z_vals, interval) 15 | -------------------------------------------------------------------------------- /core/loss/flow_recons_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from torch.nn.functional import l1_loss 4 | 5 | from core.loss import Loss, register_loss 6 | 7 | 8 | @register_loss("flow_recons") 9 | class FlowReconsLoss(Loss): 10 | def forward(self, pred, target, confidence): 11 | return self.alpha * l1_loss(pred * confidence, target * confidence) 12 | -------------------------------------------------------------------------------- /core/loss/l1_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from torch.nn.functional import l1_loss 4 | 5 | from core.loss import Loss, register_loss 6 | 7 | 8 | @register_loss("l1") 9 | class L1Loss(Loss): 10 | def forward(self, pred, target): 11 | return self.alpha * l1_loss(pred, target) 12 | -------------------------------------------------------------------------------- /core/loss/mask_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from typing import List 5 | 6 | import torch 7 | 8 | from core.loss import Loss, register_loss 9 | from third_party.omnimatte.third_party.models.networks_lnr import MaskLoss 10 | 11 | logger = logging.getLogger(__name__) 12 | 13 | 14 | @register_loss("mask_loss") 15 | class MyMaskLoss(Loss): 16 | def __init__( 17 | self, 18 | alpha: float, 19 | inputs: List[str], 20 | reduce_threshold: float, 21 | ): 22 | """Compute the mask loss in Omnimatte. 23 | 24 | reduce_threshold: After the loss (pre alpha) is smaller than this threshold, alpha is reduced to 1/10 of original from the next step. 25 | Further, this loss is disabled after another the same number of steps it took to reach the threshold. 26 | """ 27 | super().__init__(alpha, inputs=inputs) 28 | self.loss = MaskLoss() 29 | 30 | # keep track of when the threshold is reached 31 | self.register_buffer("reduce_step", torch.zeros(1, dtype=torch.long)) 32 | self.reduce_threshold = reduce_threshold 33 | 34 | def forward(self, pred, target, global_step): 35 | reduce_step = int(self.reduce_step) 36 | 37 | if reduce_step > 0: 38 | # disable this loss when the reduced weight has been used long enough 39 | if global_step > reduce_step * 2: 40 | return 0 41 | 42 | mult = 0.1 43 | else: 44 | mult = 1 45 | 46 | # our alpha is in [0, 1] while omnimatte is [-1, 1] 47 | pred = pred * 2 - 1 48 | loss = self.loss(pred, target) 49 | 50 | if reduce_step == 0 and float(loss) < self.reduce_threshold: 51 | self.reduce_step[0] = global_step 52 | logger.info( 53 | f"Start reducing mask loss (step {global_step}), raw loss is {float(loss)} < {self.reduce_threshold}" 54 | ) 55 | 56 | return self.alpha * mult * self.loss(pred, target) 57 | -------------------------------------------------------------------------------- /core/loss/mean_flow_match_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | from core.loss import Loss, register_loss 7 | 8 | 9 | @register_loss("mean_flow_match") 10 | class MeanFlowMatchLoss(Loss): 11 | def forward(self, alpha_layers: Tensor, mean_dist_map: Tensor): 12 | """ 13 | alpha_layers: (B, L, N) layered alpha 14 | mean_dist_map: (B, N) flow error map 15 | """ 16 | return self.alpha * torch.mean(alpha_layers * mean_dist_map[:, None]) 17 | -------------------------------------------------------------------------------- /core/loss/mse_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from torch.nn.functional import mse_loss 4 | 5 | from core.loss import Loss, register_loss 6 | 7 | 8 | @register_loss("mse") 9 | class MseLoss(Loss): 10 | def forward(self, pred, target): 11 | return self.alpha * mse_loss(pred, target) 12 | -------------------------------------------------------------------------------- /core/loss/robust_depth_matching.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Optional 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | from core.loss import Loss, register_loss 9 | 10 | 11 | @register_loss("robust_depth_matching") 12 | class RobustDepthMatchingLoss(Loss): 13 | def __init__(self, alpha: float = 1, inputs: Optional[list[str]] = None, start_step: int = 0): 14 | super().__init__(alpha, inputs) 15 | self.start_step = start_step 16 | 17 | def forward(self, pred: Tensor, gt: Tensor, mask: Tensor, global_step: int): 18 | if global_step < self.start_step: 19 | return 0 20 | 21 | return self.alpha * compute_depth_loss(pred, gt, mask) 22 | 23 | 24 | def compute_depth_loss(dyn_depth, gt_depth, mask): 25 | # https://github.com/gaochen315/DynamicNeRF/blob/c417fb207ef352f7e97521a786c66680218a13af/run_nerf_helpers.py#L483 26 | 27 | t_d = torch.median(dyn_depth) 28 | s_d = torch.mean(torch.abs(dyn_depth - t_d)) 29 | dyn_depth_norm = (dyn_depth - t_d) / s_d 30 | 31 | t_gt = torch.median(gt_depth) 32 | s_gt = torch.mean(torch.abs(gt_depth - t_gt)) 33 | gt_depth_norm = (gt_depth - t_gt) / s_gt 34 | 35 | return torch.mean((mask * (dyn_depth_norm - gt_depth_norm)) ** 2) 36 | -------------------------------------------------------------------------------- /core/loss/zero_reg_loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from core.loss import Loss, register_loss 4 | 5 | 6 | @register_loss("zero_reg") 7 | class ZeroRegLoss(Loss): 8 | def forward(self, pred): 9 | return self.alpha * pred 10 | -------------------------------------------------------------------------------- /core/loss/zero_reg_loss_optional.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from core.loss import Loss, register_loss 4 | 5 | 6 | @register_loss("zero_reg_optional") 7 | class ZeroRegLossOptional(Loss): 8 | def __init__(self, alpha: float, optional_input: str): 9 | super().__init__(alpha) 10 | self.optional_input = optional_input 11 | 12 | def forward(self, *args, **kwargs): 13 | pred = kwargs.get(self.optional_input) 14 | if pred is None: 15 | return 0 16 | return self.alpha * pred 17 | -------------------------------------------------------------------------------- /core/model/matting_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from pathlib import Path 5 | from typing import Any, Dict, List 6 | 7 | import numpy as np 8 | import torch 9 | from torch import Tensor 10 | from torch.utils.data import Dataset 11 | 12 | from core.data import build_data_source 13 | from utils.dict_utils import inject_dict 14 | from utils.image_utils import read_image_np 15 | from utils.io_utils import multi_glob_sorted 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | 20 | class MattingDataset(Dataset): 21 | def __init__( 22 | self, 23 | path: str, 24 | image_subpath: str, 25 | scale: float, 26 | source_configs: List[Dict[str, Any]], 27 | sources_injection: Dict[str, Any], 28 | ): 29 | path = Path(path) 30 | self.images = self.load_images( 31 | path / image_subpath, scale) # [N, H, W, 3] 32 | self.length = len(self.images) 33 | self.image_hw = list(self.images.shape[1:3]) 34 | 35 | self.sources = [] 36 | self.data_keys = set() 37 | self.global_data = {"n_frames": self.length, "image_hw": self.image_hw} 38 | for name, config in source_configs.items(): 39 | logger.info(f"Create data source: {name}") 40 | config["root"] = path 41 | config["image_hw"] = self.image_hw 42 | if "n_images" in config: 43 | config["n_images"] = len(self.images) 44 | inject_dict(config, sources_injection) 45 | 46 | source = build_data_source(name, config) 47 | self.sources.append(source) 48 | self.length = min(self.length, len(source)) 49 | for key in source.get_keys(): 50 | if key in self.data_keys: 51 | raise ValueError( 52 | f"Duplicated data key {key} provided by data sources" 53 | ) 54 | self.data_keys.add(key) 55 | for key, value in source.get_global_data().items(): 56 | if key in self.global_data: 57 | raise ValueError( 58 | f"Duplicated global data {key} provided by data sources" 59 | ) 60 | self.global_data[key] = value 61 | 62 | logger.info(f"Data keys: {', '.join(list(self.data_keys))}") 63 | logger.info(f"Global data: {', '.join(list(self.global_data.keys()))}") 64 | 65 | def __len__(self) -> int: 66 | return self.length - 1 67 | 68 | def __getitem__(self, idx) -> Dict[str, Tensor]: 69 | result = { 70 | "image": self.images[idx], 71 | "data_idx": torch.tensor([idx], dtype=torch.long), 72 | } 73 | 74 | for source in self.sources: 75 | result.update(source[idx]) 76 | 77 | return result 78 | 79 | def load_images(self, path: Path, scale: float) -> Tensor: 80 | files = multi_glob_sorted(path, ["*.jpg", "*.png", "*.JPG", "*.PNG"]) 81 | assert len(files) > 0, "No image file is found" 82 | 83 | logger.info(f"Loading {len(files)} images") 84 | return torch.from_numpy( 85 | np.stack( 86 | [read_image_np(f, scale)[..., :3] for f in files], 87 | axis=0, 88 | ) 89 | ) 90 | -------------------------------------------------------------------------------- /core/model/render_context.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass, field 4 | from typing import Dict, Optional 5 | 6 | from torch import Tensor 7 | 8 | from core.model.matting_dataset import MattingDataset 9 | 10 | 11 | @dataclass 12 | class RenderContext: 13 | coords: Tensor 14 | 15 | dataset: MattingDataset 16 | device: str 17 | is_train: bool 18 | global_step: int 19 | 20 | ray_offset: Optional[list] = None 21 | 22 | output: Dict[str, Tensor] = field(default_factory=dict) 23 | -------------------------------------------------------------------------------- /core/module/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Any, Dict 4 | 5 | import torch 6 | from torch import Tensor 7 | from torch.nn import Module 8 | 9 | from core.model.render_context import RenderContext 10 | from core.scheduler import build_scheduler 11 | from lib.registry import create_registry, import_children 12 | 13 | 14 | class CommonModel(Module): 15 | def forward(self, data: Dict[str, Tensor], ctx: RenderContext) -> None: 16 | raise NotImplementedError() 17 | 18 | def post_training_step(self, global_step: int) -> Dict[str, Any]: 19 | return {} 20 | 21 | def create_optimizer(self, params: Dict[str, Any]) -> torch.optim.Optimizer: 22 | return torch.optim.Adam(self.parameters(), **params) 23 | 24 | def create_scheduler( 25 | self, params: Dict[str, Any], optimizer: torch.optim.Optimizer 26 | ): 27 | name = params["name"] 28 | config = params["config"] 29 | config["optimizer"] = optimizer 30 | return build_scheduler(name, config) 31 | 32 | def get_kwargs_override(self) -> Dict[str, Any]: 33 | return {} 34 | 35 | 36 | _, register_fg_model, build_fg_model = create_registry("FgModel", CommonModel) 37 | _, register_bg_model, build_bg_model = create_registry("BgModel", CommonModel) 38 | import_children(__file__, __name__) 39 | -------------------------------------------------------------------------------- /core/module/dummy.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Dict 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | from core.model.render_context import RenderContext 10 | from core.module import CommonModel, register_bg_model, register_fg_model 11 | 12 | 13 | @register_fg_model("dummy") 14 | @register_bg_model("dummy") 15 | class DummyModel(CommonModel): 16 | """A model that does not produce any output""" 17 | 18 | def __init__(self): 19 | super().__init__() 20 | self.dummy = nn.Parameter(torch.zeros(1), requires_grad=True) 21 | 22 | def forward(self, data: Dict[str, Tensor], ctx: RenderContext) -> None: 23 | return 24 | -------------------------------------------------------------------------------- /core/scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | 5 | from lib.registry import create_registry, import_children 6 | 7 | 8 | class Scheduler(_LRScheduler): 9 | def __init__(self, optimizer, last_epoch=-1, verbose=False): 10 | super().__init__(optimizer, last_epoch, verbose) 11 | 12 | 13 | _, register_scheduler, build_scheduler = create_registry("Scheduler", Scheduler) 14 | import_children(__file__, __name__) 15 | -------------------------------------------------------------------------------- /core/scheduler/exp_lr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from core.scheduler import Scheduler, register_scheduler 4 | 5 | 6 | @register_scheduler("exp_lr") 7 | class ExpLR(Scheduler): 8 | def __init__( 9 | self, 10 | optimizer, 11 | decay_start, 12 | decay_rate, 13 | decay_steps, 14 | min_rate, 15 | last_epoch=-1, 16 | verbose=False, 17 | ): 18 | self.decay_start = decay_start 19 | self.decay_rate = decay_rate 20 | self.decay_steps = decay_steps 21 | self.min_rate = min_rate 22 | super().__init__(optimizer, last_epoch, verbose) 23 | 24 | def get_lr(self): 25 | if self.last_epoch < self.decay_start: 26 | return list(self.base_lrs) 27 | rate = max( 28 | self.min_rate, 29 | ( 30 | self.decay_rate 31 | ** ((self.last_epoch - self.decay_start) / self.decay_steps) 32 | ), 33 | ) 34 | return [rate * base_lr for base_lr in self.base_lrs] 35 | -------------------------------------------------------------------------------- /core/utils/tensorf_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from typing import Dict 5 | 6 | import torch 7 | from torch import Tensor 8 | 9 | from third_party.TensoRF.models.tensorBase import (MLPRender_Fea, 10 | positional_encoding) 11 | from third_party.TensoRF.models.tensoRF import TensorVMSplit 12 | 13 | logger = logging.getLogger(__name__) 14 | 15 | 16 | class MLPRender_Fea_Switchable(MLPRender_Fea): 17 | def __init__(self, inChannel, viewpe_skip_steps, viewpe=6, feape=6, featureC=128): 18 | super().__init__(inChannel, viewpe, feape, featureC) 19 | self.global_step = 0 20 | self.viewpe_skip_steps = viewpe_skip_steps 21 | self.viewpe_ch = ( 22 | 0 23 | if viewpe <= 0 24 | else positional_encoding(torch.zeros(1, 3), self.viewpe).shape[-1] 25 | ) 26 | self.log_once_keys = set() 27 | logger.info(f"View PE channels: {self.viewpe_ch}") 28 | 29 | def _log_once(self, key, message): 30 | if key in self.log_once_keys: 31 | return 32 | self.log_once_keys.add(key) 33 | logger.info(message) 34 | 35 | def forward(self, pts, viewdirs, features): 36 | indata = [features] 37 | if self.global_step < self.viewpe_skip_steps: 38 | self._log_once("zeros_viewdir", "Using zeros for viewdirs") 39 | indata += [torch.zeros_like(viewdirs)] 40 | else: 41 | self._log_once("actual_viewdir", "Start using actual viewdirs") 42 | indata += [viewdirs] 43 | 44 | if self.feape > 0: 45 | indata += [positional_encoding(features, self.feape)] 46 | 47 | if self.viewpe > 0: 48 | if self.global_step < self.viewpe_skip_steps: 49 | self._log_once("zeros_viewdir_pe", 50 | "Using zeros for viewdir PE") 51 | indata += [ 52 | torch.zeros(len(features), self.viewpe_ch, 53 | device=features.device) 54 | ] 55 | else: 56 | self._log_once("actual_viewdir_pe", 57 | "Start using acutal viewdir PE") 58 | indata += [positional_encoding(viewdirs, self.viewpe)] 59 | 60 | mlp_in = torch.cat(indata, dim=-1) 61 | rgb = self.mlp(mlp_in) 62 | rgb = torch.sigmoid(rgb) 63 | 64 | return rgb 65 | 66 | 67 | class MyTensorVMSplit(TensorVMSplit): 68 | def __init__(self, viewpe_skip_steps: int, **kargs): 69 | self.global_step = 0 70 | self.viewpe_skip_steps = viewpe_skip_steps 71 | 72 | super().__init__(**kargs) 73 | 74 | def init_render_func(self, shadingMode, pos_pe, view_pe, fea_pe, featureC, device): 75 | if shadingMode == "MLP_Fea": 76 | self.renderModule = MLPRender_Fea_Switchable( 77 | self.app_dim, self.viewpe_skip_steps, view_pe, fea_pe, featureC 78 | ).to(device) 79 | else: 80 | raise NotImplementedError() 81 | 82 | def forward( 83 | self, 84 | rays_chunk: Tensor, 85 | is_train: bool, 86 | ray_contraction: str, 87 | N_samples: int, 88 | **kwargs, 89 | ) -> Dict[str, Tensor]: 90 | self.renderModule.global_step = self.global_step 91 | return super().forward( 92 | rays_chunk, 93 | white_bg=False, 94 | is_train=is_train, 95 | ray_contraction=ray_contraction, 96 | N_samples=N_samples, 97 | **kwargs, 98 | ) 99 | 100 | def query_render_rgb(self, points, viewdirs): 101 | app_features = self.compute_appfeature(points) 102 | rgbs = self.renderModule(points, viewdirs, app_features) 103 | return rgbs 104 | -------------------------------------------------------------------------------- /core/utils/trainer_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Dict 4 | 5 | from torch import Tensor 6 | 7 | 8 | def frame_to_batch(frame: Dict[str, Tensor]) -> Dict[str, Tensor]: 9 | """ 10 | Convert frame data to batch data by prepending a [1] dimension, e.g. 11 | (H, W, C) tensor -> (1, H, W, C) tensor 12 | """ 13 | return {k: v[None] for k, v in frame.items()} 14 | 15 | 16 | def batch_to_frame(batch: Dict[str, Tensor], idx: int) -> Dict[str, Tensor]: 17 | """ 18 | Convert batch data to frame data by taking the idx item, e.g. 19 | (B, H, W, C) tensor -> (H, W, C) tensor 20 | """ 21 | return {k: v[idx] for k, v in batch.items()} 22 | -------------------------------------------------------------------------------- /data_manager_example.json: -------------------------------------------------------------------------------- 1 | { 2 | "local": { 3 | "data_root": "/data/matting", 4 | "output_root": "/output/matting" 5 | }, 6 | "remote": { 7 | "endpoint": "", 8 | "bucket": "", 9 | "access_key": "", 10 | "secret_key": "" 11 | } 12 | } 13 | -------------------------------------------------------------------------------- /docker/.bashrc: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | case $- in 4 | *i*) ;; 5 | *) return;; 6 | esac 7 | 8 | exec fish 9 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM logchan/pyenv:20221225.01 2 | 3 | USER root 4 | RUN mkdir /python && chown user:user /python 5 | 6 | USER user 7 | ADD .bashrc /home/user/.bashrc 8 | RUN mkdir -p /home/user/.config/fish 9 | ADD config.fish /home/user/.config/fish/config.fish 10 | 11 | WORKDIR /python 12 | RUN python3 -m virtualenv env 13 | 14 | RUN . env/bin/activate && \ 15 | python -m pip install --upgrade pip 16 | 17 | RUN . env/bin/activate && \ 18 | pip install --no-cache-dir \ 19 | autopep8 \ 20 | configargparse \ 21 | dataclasses-json \ 22 | dominate \ 23 | easydict \ 24 | hydra-core \ 25 | imageio-ffmpeg \ 26 | matplotlib \ 27 | minio \ 28 | ninja \ 29 | notebook \ 30 | opencv-python \ 31 | pillow \ 32 | plotly \ 33 | plyfile \ 34 | pylint \ 35 | scikit-image \ 36 | scipy \ 37 | tqdm \ 38 | visdom 39 | 40 | RUN . env/bin/activate && \ 41 | pip uninstall ipywidgets && \ 42 | pip install --no-cache-dir ipywidgets==7.7.2 43 | 44 | RUN . env/bin/activate && \ 45 | pip install --no-cache-dir \ 46 | torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117 47 | 48 | RUN . env/bin/activate && \ 49 | pip install --no-cache-dir \ 50 | 'git+https://github.com/facebookresearch/detectron2.git' \ 51 | gradio \ 52 | kornia \ 53 | lpips \ 54 | tensorboard \ 55 | torch_efficient_distloss 56 | 57 | RUN . env/bin/activate && \ 58 | export MAKEFLAGS='-j 8' && \ 59 | export TCNN_CUDA_ARCHITECTURES=86 && \ 60 | pip install --no-cache-dir \ 61 | git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch \ 62 | timm 63 | -------------------------------------------------------------------------------- /docker/config.fish: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | set fish_greeting '' 4 | source /python/env/bin/activate.fish 5 | -------------------------------------------------------------------------------- /docker/docker-compose.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | version: '3.8' 4 | services: 5 | matting: 6 | image: 'logchan/matting:20221229.01' 7 | container_name: 'matting' 8 | deploy: 9 | resources: 10 | reservations: 11 | devices: 12 | - driver: nvidia 13 | count: 1 14 | capabilities: [gpu] 15 | restart: 'unless-stopped' 16 | volumes: 17 | - /home/user/research/code:/code 18 | - /home/user/research/data:/data 19 | - /home/user/research/output:/output 20 | - /home/user/research/devenv/matting/home:/home/user 21 | environment: 22 | - NVIDIA_DRIVER_CAPABILITIES=all 23 | working_dir: /code/OmnimatteRF 24 | shm_size: '8g' 25 | -------------------------------------------------------------------------------- /docs/docker-images.md: -------------------------------------------------------------------------------- 1 | # Containerized Environment 2 | 3 | ## Docker image 4 | 5 | Please use the pre-built Docker image from Docker Hub. It requires CUDA >= 11.7. 6 | 7 | ``` 8 | docker pull logchan/matting:20221229.01 9 | ``` 10 | 11 | ## Environment Setup 12 | 13 | - Map the following folders in your container: 14 | - `/code` that holds the `x3d_matting` folder 15 | - `/data` that contains your datasets (videos) 16 | - `/output` where you store training outputs 17 | - If you use a pvc, create `code`, `output`, `data` folders in the PVC and map them to the root of container. Transfer code (`x3d_matting`) to `pvc:code` and data (`matting`) to `pvc:data`. Then inside the container you can access code from `/code/x3d_matting` and data from `/data/matting/`. 18 | -------------------------------------------------------------------------------- /init.fish: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | export PYTHONPATH=(pwd) 4 | -------------------------------------------------------------------------------- /lib/hook.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Optional 4 | from lib.trainer import Trainer, TrainerEvents 5 | import logging 6 | 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | class Hook: 12 | def __init__(self, step_size: Optional[int] = None): 13 | self.step_size = step_size 14 | self._last_execution_step = -1 15 | 16 | def __call__(self, event: TrainerEvents, trainer: Trainer): 17 | if not self._valid_step(event, trainer.global_step): 18 | return 19 | logger.info( 20 | f"Executing {self.__class__.__name__} " 21 | f"(Current Step: {trainer.global_step} " 22 | f"Last Step {self._last_execution_step})" 23 | ) 24 | self._last_execution_step = trainer.global_step 25 | self.execute(trainer) 26 | 27 | def _valid_step(self, event: TrainerEvents, global_step: int): 28 | if self._last_execution_step == global_step: 29 | return False 30 | if event != TrainerEvents.POST_STEP: 31 | return True 32 | return self.step_size is None or global_step % self.step_size == 0 33 | 34 | def execute(self, trainer: Trainer): 35 | raise NotImplementedError 36 | -------------------------------------------------------------------------------- /lib/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from typing import Any, Dict, List, Optional, Tuple 5 | 6 | import torch 7 | import torch.nn as nn 8 | 9 | logger = logging.getLogger(__name__) 10 | 11 | 12 | class Loss(nn.Module): 13 | def __init__( 14 | self, 15 | alpha: float = 1.0, 16 | inputs: Optional[List[str]] = None, 17 | ): 18 | super().__init__() 19 | self.alpha = alpha 20 | self.inputs = inputs 21 | 22 | def __call__(self, workspace: Dict[str, Any]): 23 | if self.inputs is not None: 24 | missing = [k for k in self.inputs if k not in workspace] 25 | if len(missing) > 0: 26 | raise ValueError(f"Missing keys from workspace: {missing}") 27 | 28 | return super().__call__(*(workspace[k] for k in self.inputs)) 29 | else: 30 | return super().__call__(**workspace) 31 | 32 | 33 | class ComposedLoss(nn.ModuleDict): 34 | def validate(self): 35 | invalid = [name for name, loss in self.items() 36 | if not isinstance(loss, Loss)] 37 | if len(invalid) > 0: 38 | raise ValueError( 39 | f"Found loss that are not subclasses of Loss: {invalid}") 40 | 41 | def forward(self, workspace: Dict[str, Any]) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: 42 | loss_dict = {} 43 | for name, loss in self.items(): 44 | ret = loss(workspace) 45 | if ret is None: 46 | continue 47 | 48 | loss_dict[name] = ret 49 | 50 | total_loss = sum(loss_dict.values()) 51 | loss_dict["total_loss"] = total_loss 52 | return total_loss, loss_dict 53 | -------------------------------------------------------------------------------- /lib/registry.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import importlib 4 | import logging 5 | from pathlib import Path 6 | from typing import Any, Callable, Dict, Generic, Tuple, Type, TypeVar 7 | 8 | logger = logging.getLogger(__name__) 9 | T = TypeVar("T") 10 | TImpl = TypeVar("TImpl") 11 | 12 | 13 | class Registry(Generic[T]): 14 | def __init__(self, name: str, base: Callable[..., T]) -> None: 15 | super().__init__() 16 | self.name = name 17 | self.constructors: Dict[str, Callable[..., T]] = {} 18 | 19 | def add(self, name: str, constructor: Callable[..., T]) -> None: 20 | logger.info(f"Register {self.name}: {name}") 21 | self.constructors[name] = constructor 22 | 23 | def register(self, name: str): 24 | def adder(cls: Type[TImpl]) -> Type[TImpl]: 25 | self.add(name, cls) 26 | return cls 27 | return adder 28 | 29 | def build(self, name: str, kwargs: Dict[str, Any]) -> T: 30 | return self.constructors[name](**kwargs) 31 | 32 | 33 | def create_registry(name: str, base: Type[T]) -> Tuple[ 34 | Registry[T], 35 | Callable[[str], Callable[[Type[TImpl]], Type[TImpl]]], 36 | Callable[[str, Dict[str, Any]], T] 37 | ]: 38 | registry = Registry(name, base) 39 | return registry, registry.register, registry.build 40 | 41 | 42 | def import_children(path: str, module: str): 43 | folder = Path(path).parent 44 | for file in folder.glob("*.py"): 45 | if file.name == "__init__.py": 46 | continue 47 | name = module + "." + file.stem 48 | importlib.import_module(name) 49 | -------------------------------------------------------------------------------- /licenses/MiDaS: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/RAFT: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /licenses/RoDynRF: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) Meta Platforms, Inc. All Rights Reserved 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /licenses/TensoRF: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Anpei Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /preprocess/config/convert_segmentation.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | instances: ??? 4 | output: ??? 5 | indices: ??? 6 | width: 0 7 | height: 0 8 | -------------------------------------------------------------------------------- /preprocess/config/run_colmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | images: ??? 4 | masks: 5 | output: ??? 6 | colmap_binary: colmap 7 | colmap_options: 8 | feature_extractor: 9 | SiftExtraction.use_gpu: 0 10 | SiftExtraction.upright: 0 11 | ImageReader.camera_model: OPENCV 12 | ImageReader.single_camera: 1 13 | exhaustive_matcher: 14 | SiftMatching.use_gpu: 0 15 | mapper: 16 | Mapper.ba_refine_principal_point: 1 17 | Mapper.filter_max_reproj_error: 2 18 | Mapper.tri_complete_max_reproj_error: 2 19 | Mapper.min_num_matches: 32 20 | -------------------------------------------------------------------------------- /preprocess/config/run_depth.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | type: dpt_beit_large_512 4 | model: /data/matting/pretrained/midas/dpt_beit_large_512.pt 5 | device: cuda:0 6 | scale: 1 7 | input: ??? 8 | output: ??? 9 | -------------------------------------------------------------------------------- /preprocess/config/run_flow.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | input: ??? 4 | output: ??? 5 | model: raft-things.pth 6 | scale: 1 7 | device: cuda:0 8 | forward: true 9 | backward: true 10 | -------------------------------------------------------------------------------- /preprocess/config/run_homography.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | input: ??? 4 | output: ??? 5 | scale: -1 6 | device: cuda:0 7 | finder: 8 | matching: loftr 9 | loftr_pretrained: outdoor 10 | ransac_threshold: 3 11 | -------------------------------------------------------------------------------- /preprocess/config/run_motion_mask.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | dataset: ??? 4 | camera: ??? 5 | contraction: ??? 6 | scale: 1.0 7 | 8 | threshold: 5 9 | -------------------------------------------------------------------------------- /preprocess/config/run_segmentation.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | model: COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml 4 | device: cuda:0 5 | input: ??? 6 | output: ??? 7 | -------------------------------------------------------------------------------- /preprocess/config/video_to_images.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | input: ??? 4 | output: ??? 5 | step: 1 6 | limit: 200 7 | skip: 8 | mask: false 9 | -------------------------------------------------------------------------------- /preprocess/config/visualize_segmentation.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | instances: ??? 4 | output: ??? 5 | rgb: ??? 6 | draw_masks: true 7 | draw_mask_indices: [0,1] 8 | -------------------------------------------------------------------------------- /preprocess/run_colmap.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | import os 5 | import subprocess 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | from typing import Any, Dict, List, Optional 9 | 10 | import hydra 11 | import numpy as np 12 | from hydra.core.config_store import ConfigStore 13 | from omegaconf import OmegaConf 14 | 15 | from utils.colmap.colmap_utils import gen_poses 16 | from utils.io_utils import mkdir, multi_glob_sorted 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @dataclass 22 | class RunColmapConfig: 23 | images: str 24 | masks: Optional[str] 25 | output: str 26 | colmap_binary: str 27 | colmap_options: Dict[str, Dict[str, Any]] 28 | 29 | 30 | ConfigStore.instance().store(name="run_colmap_schema", node=RunColmapConfig) 31 | 32 | 33 | @hydra.main(version_base="1.2", config_path="config", config_name="run_colmap") 34 | def main(cfg: RunColmapConfig): 35 | image_dir = Path(cfg.images) 36 | mask_dir = Path(cfg.masks) if cfg.masks else None 37 | 38 | image_files = multi_glob_sorted(image_dir, ["*.png", "*.jpg"]) 39 | assert len(image_files) > 0, "No image is found!" 40 | 41 | colmap_cfg = OmegaConf.to_container(cfg.colmap_options) 42 | if mask_dir is not None: 43 | # check mask files before calling COLMAP 44 | for file in image_files: 45 | name = file.name 46 | assert os.path.isfile(mask_dir / f"{name}.png"), f"Mask image for {name} is not found!" 47 | 48 | colmap_cfg["feature_extractor"]["ImageReader.mask_path"] = str(mask_dir) 49 | 50 | out_root = mkdir(cfg.output) 51 | colmap_db_path = out_root / "database.db" 52 | colmap_out_path = out_root / "sparse" 53 | 54 | def run_colmap(cmd: List[str]) -> None: 55 | cmd = [str(v) for v in cmd] 56 | action = cmd[0] 57 | 58 | # apply additional configs 59 | options = colmap_cfg.get(action, {}) 60 | for key, value in options.items(): 61 | cmd += [f"--{key}", str(value)] 62 | 63 | logger.info("Run: colmap %s", " ".join(cmd)) 64 | log_dir = mkdir(out_root / "logs") 65 | 66 | stdout_file = open(log_dir / f"{action}.stdout.txt", "w", encoding="utf-8") 67 | stderr_file = open(log_dir / f"{action}.stderr.txt", "w", encoding="utf-8") 68 | try: 69 | subprocess.run( 70 | [cfg.colmap_binary, *cmd], stdout=stdout_file, stderr=stderr_file, check=True 71 | ) 72 | finally: 73 | stdout_file.close() 74 | stderr_file.close() 75 | 76 | try: 77 | run_colmap( 78 | [ 79 | "feature_extractor", 80 | "--database_path", 81 | colmap_db_path, 82 | "--image_path", 83 | image_dir, 84 | ], 85 | ) 86 | 87 | run_colmap( 88 | [ 89 | "exhaustive_matcher", 90 | "--database_path", 91 | colmap_db_path, 92 | ], 93 | ) 94 | 95 | colmap_out_path.mkdir(parents=True, exist_ok=True) 96 | run_colmap( 97 | [ 98 | "mapper", 99 | "--database_path", 100 | colmap_db_path, 101 | "--image_path", 102 | image_dir, 103 | "--output_path", 104 | colmap_out_path, 105 | ] 106 | ) 107 | 108 | gen_poses(str(out_root)) 109 | 110 | poses_file = out_root / "poses_bounds.npy" 111 | poses = np.load(poses_file) 112 | if len(poses) < len(image_files): 113 | logger.error(f"Colmap only recovered {len(poses)} for {len(image_files)} images") 114 | poses_file.unlink() 115 | except subprocess.CalledProcessError: 116 | logger.error(f"Colmap has failed, aborting") 117 | 118 | 119 | if __name__ == "__main__": 120 | main() 121 | -------------------------------------------------------------------------------- /preprocess/run_depth.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | 7 | import cv2 8 | import hydra 9 | import numpy as np 10 | import torch 11 | from hydra.core.config_store import ConfigStore 12 | from tqdm import tqdm 13 | 14 | from third_party.MiDaS.midas.model_loader import load_model 15 | from utils.image_utils import read_image_np, save_image_np, visualize_array 16 | from utils.io_utils import mkdir, multi_glob_sorted 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @dataclass 22 | class RunDepthConfig: 23 | type: str 24 | model: str 25 | device: str 26 | 27 | scale: float 28 | 29 | input: str 30 | output: str 31 | 32 | 33 | ConfigStore.instance().store(name="run_depth_schema", node=RunDepthConfig) 34 | 35 | 36 | @hydra.main(version_base="1.2", config_path="config", config_name="run_depth") 37 | def main(cfg: RunDepthConfig = None): 38 | device = cfg.device 39 | 40 | # scan input 41 | files = multi_glob_sorted(Path(cfg.input), ["*.png", "*.jpg"]) 42 | if len(files) < 1: 43 | raise ValueError("No image to process") 44 | logger.info(f"Process {len(files)} files") 45 | 46 | # output dir 47 | out_root = mkdir(cfg.output) 48 | depth_folder = mkdir(out_root / "depth") 49 | vis_folder = mkdir(out_root / "visualization") 50 | 51 | model, transform, net_w, net_h = load_model(device, cfg.model, cfg.type, optimize=False) 52 | logger.info(f"net_w, net_h = {net_w}, {net_h}") 53 | 54 | for file in tqdm(files): 55 | img = read_image_np(file, cfg.scale)[..., :3] 56 | img = transform({"image": img})["image"] 57 | img = torch.from_numpy(img).to(device)[None] 58 | 59 | with torch.no_grad(): 60 | prediction = model.forward(img) 61 | prediction = prediction[0].cpu().numpy() 62 | 63 | # save raw output, users should resize per use case 64 | np.save(depth_folder / f"{file.stem}.npy", prediction) 65 | 66 | # midas output is in disparity space, see 67 | # https://github.com/isl-org/MiDaS/issues/42#issuecomment-680801589 68 | vis = (prediction - prediction.min()) / (prediction.max() - prediction.min()) 69 | vis = visualize_array(vis, cv2.COLORMAP_TURBO) 70 | save_image_np(vis_folder / f"{file.stem}.png", vis) 71 | 72 | 73 | if __name__ == "__main__": 74 | main() 75 | -------------------------------------------------------------------------------- /preprocess/run_flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | import os 5 | from argparse import Namespace 6 | from dataclasses import dataclass 7 | from pathlib import Path 8 | 9 | import hydra 10 | import numpy as np 11 | import torch 12 | import torch.nn.functional as F 13 | from hydra.core.config_store import ConfigStore 14 | from PIL import Image 15 | from tqdm import tqdm 16 | 17 | from third_party.RAFT.core.raft import RAFT 18 | from third_party.RAFT.core.utils.frame_utils import writeFlow 19 | from utils.io_utils import mkdir 20 | 21 | 22 | @dataclass 23 | class RunFlowConfig: 24 | model: str 25 | scale: float 26 | device: str 27 | input: str 28 | output: str 29 | forward: bool 30 | backward: bool 31 | 32 | 33 | ConfigStore.instance().store(name="run_flow_schema", node=RunFlowConfig) 34 | 35 | 36 | @hydra.main(version_base="1.2", config_path="config", config_name="run_flow") 37 | def main(cfg: RunFlowConfig=None): 38 | device = cfg.device 39 | scale = cfg.scale 40 | 41 | out_root = mkdir(cfg.output) 42 | forward_dir = mkdir(out_root / "flow") 43 | backward_dir = mkdir(out_root / "flow_backward") 44 | 45 | # scan input 46 | in_root = Path(cfg.input) 47 | files = sorted(os.listdir(in_root)) 48 | logging.info(f"Process {len(files)} files") 49 | 50 | # load model 51 | model = create_raft_model(cfg.model, device) 52 | 53 | # read the first image and determine padding 54 | curr_image, pad_h, pad_w = load_image(in_root / files[0], scale, device) 55 | 56 | def save_flow(file, flow): 57 | """Save prediction from network with padding removed""" 58 | flow = flow.cpu().numpy()[0].transpose([1, 2, 0]) 59 | H, W = flow.shape[:2] 60 | flow = flow[0 : H - pad_h, pad_w // 2 : (W - pad_w + pad_w // 2)] 61 | 62 | writeFlow(file, flow) 63 | 64 | for i in tqdm(range(len(files) - 1)): 65 | next_image, _, _ = load_image(in_root / files[i + 1], scale, device) 66 | name = os.path.splitext(files[i])[0] + ".flo" 67 | 68 | with torch.no_grad(): 69 | if cfg.forward: 70 | _, forward = model(curr_image, next_image, iters=20, test_mode=True) 71 | save_flow(forward_dir / name, forward) 72 | 73 | if cfg.backward: 74 | _, backward = model(next_image, curr_image, iters=20, test_mode=True) 75 | save_flow(backward_dir / name, backward) 76 | 77 | curr_image = next_image 78 | 79 | 80 | def create_raft_model(ckpt: str, device: str) -> RAFT: 81 | cp = torch.load(ckpt, map_location="cpu") 82 | 83 | # remove DataParallel prefix "module." from dictionary 84 | cp = {k[len("module."):]: cp[k] for k in cp} 85 | 86 | args = Namespace() 87 | args.small = False 88 | args.mixed_precision = False 89 | args.alternate_corr = False 90 | 91 | model = RAFT(args) 92 | model.load_state_dict(cp) 93 | model = model.to(device).eval() 94 | return model 95 | 96 | 97 | def load_image(path: str, scale: float, device: str) -> torch.Tensor: 98 | """Read an image and convert to [1, 3, H, W] float tensor, keeping values in [0, 255]. 99 | Also pad the sides of the image to multiples of 8. 100 | 101 | returns: image, pad_h, pad_w 102 | """ 103 | img = Image.open(path) 104 | if scale != 1: 105 | img = img.resize( 106 | (int(np.round(scale * img.width)), int(np.round(scale * img.height))), 107 | Image.LANCZOS, 108 | ) 109 | 110 | img = np.array(img, dtype=np.float32)[..., :3] # drop alpha channel 111 | 112 | img = torch.from_numpy(img.transpose([2, 0, 1]))[None] 113 | 114 | H, W = img.shape[-2:] 115 | pH = (8 - H % 8) % 8 116 | pW = (8 - W % 8) % 8 117 | if pH != 0 or pW != 0: 118 | img = F.pad(img, [pW // 2, pW - pW // 2, 0, pH], mode="replicate") 119 | 120 | return img.to(device), pH, pW 121 | 122 | 123 | if __name__ == "__main__": 124 | main() 125 | -------------------------------------------------------------------------------- /preprocess/run_segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from zipfile import ZIP_DEFLATED, ZipFile 6 | 7 | import hydra 8 | import numpy as np 9 | import torch 10 | from detectron2 import model_zoo 11 | from hydra.core.config_store import ConfigStore 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | from utils.io_utils import mkdir, multi_glob_sorted 16 | 17 | 18 | @dataclass 19 | class RunSegmentationConfig: 20 | model: str 21 | device: str 22 | input: str 23 | output: str 24 | 25 | 26 | ConfigStore.instance().store(name="run_segmentation_schema", node=RunSegmentationConfig) 27 | 28 | 29 | @hydra.main(version_base="1.2", config_path="config", config_name="run_segmentation") 30 | def main(cfg: RunSegmentationConfig): 31 | device = cfg.device 32 | in_root = Path(cfg.input) 33 | 34 | write_to_zip = cfg.output.endswith(".zip") 35 | if write_to_zip: 36 | zip_folder = Path(cfg.output).stem 37 | out_zip = ZipFile(cfg.output, "w", compression=ZIP_DEFLATED) 38 | else: 39 | out_root = mkdir(cfg.output) 40 | 41 | in_files = multi_glob_sorted(in_root, ["*.png", "*.jpg"]) 42 | 43 | model = model_zoo.get(cfg.model, trained=True).eval().to(device) 44 | 45 | for file in tqdm(in_files): 46 | image = np.array(Image.open(file), dtype=np.float32) 47 | image = image[..., :3] # RGBA to RGB 48 | image = torch.from_numpy(image.transpose(2, 0, 1)) 49 | inputs = [ 50 | { 51 | "image": image.to(device), 52 | "height": image.shape[1], 53 | "width": image.shape[2], 54 | } 55 | ] 56 | with torch.no_grad(): 57 | instances = model(inputs)[0]["instances"].to("cpu") 58 | 59 | fields = instances.get_fields() 60 | result = { 61 | "pred_boxes": fields["pred_boxes"].tensor, 62 | **{k: fields[k] for k in ["scores", "pred_classes", "pred_masks"]} 63 | } 64 | 65 | if write_to_zip: 66 | with out_zip.open(f"{zip_folder}/{file.stem}.ckpt", mode="w") as f: 67 | torch.save(result, f) 68 | else: 69 | torch.save(result, out_root / f"{file.stem}.ckpt") 70 | 71 | if write_to_zip: 72 | out_zip.close() 73 | 74 | 75 | if __name__ == "__main__": 76 | main() 77 | -------------------------------------------------------------------------------- /preprocess/video_to_images.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | import shutil 5 | import subprocess 6 | import tempfile 7 | from dataclasses import dataclass 8 | from pathlib import Path 9 | import json 10 | from typing import Optional 11 | 12 | import hydra 13 | from hydra.core.config_store import ConfigStore 14 | 15 | from utils.io_utils import multi_glob_sorted 16 | from utils.image_utils import read_image_np, save_image_np 17 | 18 | logger = logging.getLogger(__name__) 19 | 20 | 21 | @dataclass 22 | class VideoToImagesConfig: 23 | input: str 24 | output: str 25 | step: int 26 | limit: int 27 | skip: Optional[str] 28 | mask: bool 29 | 30 | 31 | ConfigStore.instance().store(name="video_to_images_schema", node=VideoToImagesConfig) 32 | 33 | 34 | @hydra.main(version_base="1.2", config_path="config", config_name="video_to_images") 35 | def main(cfg: VideoToImagesConfig): 36 | out_dir = Path(cfg.output) 37 | if len(multi_glob_sorted(out_dir, ["*.png", "*.jpg"])) > 0: 38 | raise ValueError("Output folder is not empty") 39 | 40 | proc = subprocess.run([ 41 | "ffprobe", 42 | "-v", 43 | "quiet", 44 | "-of", 45 | "json", 46 | "-show_streams", 47 | "-select_streams", 48 | "v:0", 49 | cfg.input, 50 | ], capture_output=True, check=True) 51 | data = json.loads(proc.stdout.decode("utf-8"))["streams"][0] 52 | width = data["width"] 53 | height = data["height"] 54 | n_frames = data["nb_frames"] 55 | 56 | logger.info(f"Video frame is {width}x{height}, total {n_frames} frames") 57 | need_scale = None 58 | if width > height: 59 | if height > 1080: 60 | need_scale = "h" 61 | else: 62 | if width > 1080: 63 | need_scale = "w" 64 | 65 | with tempfile.TemporaryDirectory() as tmpdir: 66 | logger.info("Extract frames with ffmpeg") 67 | 68 | ss = [] 69 | if cfg.skip is not None: 70 | ss = ["-ss", cfg.skip] 71 | 72 | vf = [] 73 | if cfg.step > 1: 74 | vf.append(f"select='not(mod(n\\,{cfg.step}))'") 75 | if need_scale == "w": 76 | vf.append(f"scale=w=1080:h=-1") 77 | elif need_scale == "h": 78 | vf.append(f"scale=w=-1:h=1080") 79 | 80 | if len(vf) > 0: 81 | vf = ["-vf", ",".join(vf)] 82 | 83 | args = [ 84 | "ffmpeg", 85 | "-vsync", 86 | "drop", 87 | *ss, 88 | "-i", 89 | cfg.input, 90 | *vf, 91 | "-frames:v", 92 | str(cfg.limit), 93 | f"{tmpdir}/%05d.png" 94 | ] 95 | 96 | logger.info(" ".join(args)) 97 | subprocess.run(args, check=True) 98 | 99 | tmpdir = Path(tmpdir) 100 | files = multi_glob_sorted(tmpdir, "*.png") 101 | logger.info(f"Output has {len(files)} frames") 102 | 103 | out_dir.mkdir(parents=True, exist_ok=True) 104 | for i, file in enumerate(files): 105 | if not cfg.mask: 106 | shutil.move(file, out_dir / f"{i:05d}.png") 107 | continue 108 | 109 | # convert mask image 110 | img = read_image_np(file) 111 | if len(img.shape) == 3: 112 | img = img[..., 0] 113 | save_image_np(out_dir / f"{i:05d}.png", img) 114 | 115 | 116 | if __name__ == "__main__": 117 | main() 118 | -------------------------------------------------------------------------------- /preprocess/visualize_segmentation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | from dataclasses import dataclass 5 | from pathlib import Path 6 | from typing import List, Optional, Tuple 7 | 8 | import hydra 9 | import numpy as np 10 | import torch 11 | from hydra.core.config_store import ConfigStore 12 | from PIL import Image, ImageDraw 13 | from tqdm import tqdm 14 | 15 | from utils.io_utils import mkdir 16 | 17 | 18 | @dataclass 19 | class VisualizeSegmentationConfig: 20 | instances: str 21 | output: str 22 | rgb: str 23 | draw_masks: bool 24 | draw_mask_indices: Optional[List[int]] 25 | 26 | 27 | ConfigStore.instance().store(name="visualize_segmentation_schema", node=VisualizeSegmentationConfig) 28 | 29 | 30 | @hydra.main(version_base="1.2", config_path="config", config_name="visualize_segmentation") 31 | def main(cfg: VisualizeSegmentationConfig): 32 | in_root = Path(cfg.instances) 33 | rgb_root = Path(cfg.rgb) 34 | files = sorted(os.listdir(in_root)) 35 | out_root = mkdir(cfg.output) 36 | 37 | mask_frames = cfg.draw_mask_indices or [] 38 | 39 | for i_file, file in tqdm(list(enumerate(files))): 40 | name = os.path.splitext(file)[0] 41 | 42 | img = load_rgb(rgb_root, name) 43 | out_img = img.copy() 44 | draw = ImageDraw.Draw(out_img) 45 | w_rgb = 0.4 46 | 47 | boxes, masks = load_instance(in_root / file) 48 | for i in range(len(boxes)): 49 | x, y = boxes[i][:2] 50 | draw.rectangle(boxes[i], outline=(255, 0, 0)) 51 | draw.text((x + 2, y + 2), f"{i}", fill=(255, 0, 0)) 52 | 53 | if not cfg.draw_masks: 54 | continue 55 | if len(mask_frames) > 0 and i_file not in mask_frames: 56 | continue 57 | 58 | mask = masks[i] 59 | mask_img = np.array(img, dtype=np.float32) / 255 60 | mask_img[mask] = w_rgb * mask_img[mask] + (1 - w_rgb) 61 | mask_img = Image.fromarray((mask_img * 255).astype(np.uint8)) 62 | mask_img.save(mkdir(out_root / "masked" / name) / f"{i}.jpg", quality=90) 63 | 64 | out_img.save(mkdir(out_root / "boxes") / f"{name}.jpg", quality=90) 65 | 66 | 67 | def load_rgb(rgb_root: str, name: str) -> Image.Image: 68 | root = Path(rgb_root) 69 | for ext in [".png", ".jpg"]: 70 | file = root / f"{name}{ext}" 71 | if not os.path.isfile(file): 72 | continue 73 | 74 | return Image.open(file) 75 | 76 | raise RuntimeError(f"RGB file {name} not found") 77 | 78 | 79 | def load_instance(path: str) -> Tuple[np.ndarray, ...]: 80 | cp = torch.load(path, map_location="cpu") 81 | cp = {k: v.numpy() for k, v in cp.items()} 82 | return [cp[k] for k in ["pred_boxes", "pred_masks"]] 83 | 84 | 85 | if __name__ == "__main__": 86 | main() 87 | -------------------------------------------------------------------------------- /third_party/MiDaS/.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | *.png 107 | *.pfm 108 | *.jpg 109 | *.jpeg 110 | *.pt -------------------------------------------------------------------------------- /third_party/MiDaS/Dockerfile: -------------------------------------------------------------------------------- 1 | # enables cuda support in docker 2 | FROM nvidia/cuda:10.2-cudnn7-runtime-ubuntu18.04 3 | 4 | # install python 3.6, pip and requirements for opencv-python 5 | # (see https://github.com/NVIDIA/nvidia-docker/issues/864) 6 | RUN apt-get update && apt-get -y install \ 7 | python3 \ 8 | python3-pip \ 9 | libsm6 \ 10 | libxext6 \ 11 | libxrender-dev \ 12 | curl \ 13 | && rm -rf /var/lib/apt/lists/* 14 | 15 | # install python dependencies 16 | RUN pip3 install --upgrade pip 17 | RUN pip3 install torch~=1.8 torchvision opencv-python-headless~=3.4 timm 18 | 19 | # copy inference code 20 | WORKDIR /opt/MiDaS 21 | COPY ./midas ./midas 22 | COPY ./*.py ./ 23 | 24 | # download model weights so the docker image can be used offline 25 | RUN cd weights && {curl -OL https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt; cd -; } 26 | RUN python3 run.py --model_type dpt_hybrid; exit 0 27 | 28 | # entrypoint (dont forget to mount input and output directories) 29 | CMD python3 run.py --model_type dpt_hybrid 30 | -------------------------------------------------------------------------------- /third_party/MiDaS/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab) 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/MiDaS/README.third_party: -------------------------------------------------------------------------------- 1 | Source: https://github.com/isl-org/MiDaS 2 | Commit: 1645b7e1675301fdfac03640738fe5a6531e17d6 3 | -------------------------------------------------------------------------------- /third_party/MiDaS/environment.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: midas-py310 4 | channels: 5 | - pytorch 6 | - defaults 7 | dependencies: 8 | - nvidia::cudatoolkit=11.7 9 | - python=3.10.8 10 | - pytorch::pytorch=1.13.0 11 | - torchvision=0.14.0 12 | - pip=22.3.1 13 | - numpy=1.23.4 14 | - pip: 15 | - opencv-python==4.6.0.66 16 | - imutils==0.5.4 17 | - timm==0.6.12 18 | - einops==0.6.0 -------------------------------------------------------------------------------- /third_party/MiDaS/midas/backbones/levit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import timm 4 | import torch 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from .utils import activations, get_activation, Transpose 9 | 10 | 11 | def forward_levit(pretrained, x): 12 | pretrained.model.forward_features(x) 13 | 14 | layer_1 = pretrained.activations["1"] 15 | layer_2 = pretrained.activations["2"] 16 | layer_3 = pretrained.activations["3"] 17 | 18 | layer_1 = pretrained.act_postprocess1(layer_1) 19 | layer_2 = pretrained.act_postprocess2(layer_2) 20 | layer_3 = pretrained.act_postprocess3(layer_3) 21 | 22 | return layer_1, layer_2, layer_3 23 | 24 | 25 | def _make_levit_backbone( 26 | model, 27 | hooks=[3, 11, 21], 28 | patch_grid=[14, 14] 29 | ): 30 | pretrained = nn.Module() 31 | 32 | pretrained.model = model 33 | pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1")) 34 | pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2")) 35 | pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3")) 36 | 37 | pretrained.activations = activations 38 | 39 | patch_grid_size = np.array(patch_grid, dtype=int) 40 | 41 | pretrained.act_postprocess1 = nn.Sequential( 42 | Transpose(1, 2), 43 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 44 | ) 45 | pretrained.act_postprocess2 = nn.Sequential( 46 | Transpose(1, 2), 47 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist())) 48 | ) 49 | pretrained.act_postprocess3 = nn.Sequential( 50 | Transpose(1, 2), 51 | nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist())) 52 | ) 53 | 54 | return pretrained 55 | 56 | 57 | class ConvTransposeNorm(nn.Sequential): 58 | """ 59 | Modification of 60 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm 61 | such that ConvTranspose2d is used instead of Conv2d. 62 | """ 63 | 64 | def __init__( 65 | self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1, 66 | groups=1, bn_weight_init=1): 67 | super().__init__() 68 | self.add_module('c', 69 | nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False)) 70 | self.add_module('bn', nn.BatchNorm2d(out_chs)) 71 | 72 | nn.init.constant_(self.bn.weight, bn_weight_init) 73 | 74 | @torch.no_grad() 75 | def fuse(self): 76 | c, bn = self._modules.values() 77 | w = bn.weight / (bn.running_var + bn.eps) ** 0.5 78 | w = c.weight * w[:, None, None, None] 79 | b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5 80 | m = nn.ConvTranspose2d( 81 | w.size(1), w.size(0), w.shape[2:], stride=self.c.stride, 82 | padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups) 83 | m.weight.data.copy_(w) 84 | m.bias.data.copy_(b) 85 | return m 86 | 87 | 88 | def stem_b4_transpose(in_chs, out_chs, activation): 89 | """ 90 | Modification of 91 | https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16 92 | such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half. 93 | """ 94 | return nn.Sequential( 95 | ConvTransposeNorm(in_chs, out_chs, 3, 2, 1), 96 | activation(), 97 | ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1), 98 | activation()) 99 | 100 | 101 | def _make_pretrained_levit_384(pretrained, hooks=None): 102 | model = timm.create_model("levit_384", pretrained=pretrained) 103 | 104 | hooks = [3, 11, 21] if hooks == None else hooks 105 | return _make_levit_backbone( 106 | model, 107 | hooks=hooks 108 | ) 109 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/backbones/next_vit.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import timm 4 | 5 | import torch.nn as nn 6 | 7 | from pathlib import Path 8 | from .utils import activations, forward_default, get_activation 9 | 10 | from ..external.next_vit.classification.nextvit import * 11 | 12 | 13 | def forward_next_vit(pretrained, x): 14 | return forward_default(pretrained, x, "forward") 15 | 16 | 17 | def _make_next_vit_backbone( 18 | model, 19 | hooks=[2, 6, 36, 39], 20 | ): 21 | pretrained = nn.Module() 22 | 23 | pretrained.model = model 24 | pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1")) 25 | pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2")) 26 | pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3")) 27 | pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4")) 28 | 29 | pretrained.activations = activations 30 | 31 | return pretrained 32 | 33 | 34 | def _make_pretrained_next_vit_large_6m(hooks=None): 35 | model = timm.create_model("nextvit_large") 36 | 37 | hooks = [2, 6, 36, 39] if hooks == None else hooks 38 | return _make_next_vit_backbone( 39 | model, 40 | hooks=hooks, 41 | ) 42 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/backbones/swin.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import timm 4 | 5 | from .swin_common import _make_swin_backbone 6 | 7 | 8 | def _make_pretrained_swinl12_384(pretrained, hooks=None): 9 | model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained) 10 | 11 | hooks = [1, 1, 17, 1] if hooks == None else hooks 12 | return _make_swin_backbone( 13 | model, 14 | hooks=hooks 15 | ) 16 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/backbones/swin2.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import timm 4 | 5 | from .swin_common import _make_swin_backbone 6 | 7 | 8 | def _make_pretrained_swin2l24_384(pretrained, hooks=None): 9 | model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained) 10 | 11 | hooks = [1, 1, 17, 1] if hooks == None else hooks 12 | return _make_swin_backbone( 13 | model, 14 | hooks=hooks 15 | ) 16 | 17 | 18 | def _make_pretrained_swin2b24_384(pretrained, hooks=None): 19 | model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained) 20 | 21 | hooks = [1, 1, 17, 1] if hooks == None else hooks 22 | return _make_swin_backbone( 23 | model, 24 | hooks=hooks 25 | ) 26 | 27 | 28 | def _make_pretrained_swin2t16_256(pretrained, hooks=None): 29 | model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained) 30 | 31 | hooks = [1, 1, 5, 1] if hooks == None else hooks 32 | return _make_swin_backbone( 33 | model, 34 | hooks=hooks, 35 | patch_grid=[64, 64] 36 | ) 37 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/backbones/swin_common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | 5 | import torch.nn as nn 6 | import numpy as np 7 | 8 | from .utils import activations, forward_default, get_activation, Transpose 9 | 10 | 11 | def forward_swin(pretrained, x): 12 | return forward_default(pretrained, x) 13 | 14 | 15 | def _make_swin_backbone( 16 | model, 17 | hooks=[1, 1, 17, 1], 18 | patch_grid=[96, 96] 19 | ): 20 | pretrained = nn.Module() 21 | 22 | pretrained.model = model 23 | pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1")) 24 | pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2")) 25 | pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3")) 26 | pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4")) 27 | 28 | pretrained.activations = activations 29 | 30 | if hasattr(model, "patch_grid"): 31 | used_patch_grid = model.patch_grid 32 | else: 33 | used_patch_grid = patch_grid 34 | 35 | patch_grid_size = np.array(used_patch_grid, dtype=int) 36 | 37 | pretrained.act_postprocess1 = nn.Sequential( 38 | Transpose(1, 2), 39 | nn.Unflatten(2, torch.Size(patch_grid_size.tolist())) 40 | ) 41 | pretrained.act_postprocess2 = nn.Sequential( 42 | Transpose(1, 2), 43 | nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist())) 44 | ) 45 | pretrained.act_postprocess3 = nn.Sequential( 46 | Transpose(1, 2), 47 | nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist())) 48 | ) 49 | pretrained.act_postprocess4 = nn.Sequential( 50 | Transpose(1, 2), 51 | nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist())) 52 | ) 53 | 54 | return pretrained 55 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | 5 | 6 | class BaseModel(torch.nn.Module): 7 | def load(self, path): 8 | """Load model from file. 9 | 10 | Args: 11 | path (str): file path 12 | """ 13 | parameters = torch.load(path, map_location=torch.device('cpu')) 14 | 15 | if "optimizer" in parameters: 16 | parameters = parameters["model"] 17 | 18 | self.load_state_dict(parameters) 19 | -------------------------------------------------------------------------------- /third_party/MiDaS/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 4 | This file contains code that is adapted from 5 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .base_model import BaseModel 11 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 12 | 13 | 14 | class MidasNet(BaseModel): 15 | """Network for monocular depth estimation. 16 | """ 17 | 18 | def __init__(self, path=None, features=256, non_negative=True): 19 | """Init. 20 | 21 | Args: 22 | path (str, optional): Path to saved model. Defaults to None. 23 | features (int, optional): Number of features. Defaults to 256. 24 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 25 | """ 26 | print("Loading weights: ", path) 27 | 28 | super(MidasNet, self).__init__() 29 | 30 | use_pretrained = False if path is None else True 31 | 32 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 33 | 34 | self.scratch.refinenet4 = FeatureFusionBlock(features) 35 | self.scratch.refinenet3 = FeatureFusionBlock(features) 36 | self.scratch.refinenet2 = FeatureFusionBlock(features) 37 | self.scratch.refinenet1 = FeatureFusionBlock(features) 38 | 39 | self.scratch.output_conv = nn.Sequential( 40 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 41 | Interpolate(scale_factor=2, mode="bilinear"), 42 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 43 | nn.ReLU(True), 44 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 45 | nn.ReLU(True) if non_negative else nn.Identity(), 46 | ) 47 | 48 | if path: 49 | self.load(path) 50 | 51 | def forward(self, x): 52 | """Forward pass. 53 | 54 | Args: 55 | x (tensor): input data (image) 56 | 57 | Returns: 58 | tensor: depth 59 | """ 60 | 61 | layer_1 = self.pretrained.layer1(x) 62 | layer_2 = self.pretrained.layer2(layer_1) 63 | layer_3 = self.pretrained.layer3(layer_2) 64 | layer_4 = self.pretrained.layer4(layer_3) 65 | 66 | layer_1_rn = self.scratch.layer1_rn(layer_1) 67 | layer_2_rn = self.scratch.layer2_rn(layer_2) 68 | layer_3_rn = self.scratch.layer3_rn(layer_3) 69 | layer_4_rn = self.scratch.layer4_rn(layer_4) 70 | 71 | path_4 = self.scratch.refinenet4(layer_4_rn) 72 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 73 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 74 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 75 | 76 | out = self.scratch.output_conv(path_1) 77 | 78 | return torch.squeeze(out, dim=1) 79 | -------------------------------------------------------------------------------- /third_party/RAFT/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.egg-info 3 | dist 4 | datasets 5 | pytorch_env 6 | models 7 | build 8 | correlation.egg-info 9 | -------------------------------------------------------------------------------- /third_party/RAFT/LICENSE: -------------------------------------------------------------------------------- 1 | BSD 3-Clause License 2 | 3 | Copyright (c) 2020, princeton-vl 4 | All rights reserved. 5 | 6 | Redistribution and use in source and binary forms, with or without 7 | modification, are permitted provided that the following conditions are met: 8 | 9 | * Redistributions of source code must retain the above copyright notice, this 10 | list of conditions and the following disclaimer. 11 | 12 | * Redistributions in binary form must reproduce the above copyright notice, 13 | this list of conditions and the following disclaimer in the documentation 14 | and/or other materials provided with the distribution. 15 | 16 | * Neither the name of the copyright holder nor the names of its 17 | contributors may be used to endorse or promote products derived from 18 | this software without specific prior written permission. 19 | 20 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" 21 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE 22 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 23 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE 24 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL 25 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR 26 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER 27 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, 28 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE 29 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 30 | -------------------------------------------------------------------------------- /third_party/RAFT/METADATA: -------------------------------------------------------------------------------- 1 | Code from: https://github.com/princeton-vl/RAFT.git 2 | Commit: aac9dd54726caf2cf81d8661b07663e220c5586d 3 | -------------------------------------------------------------------------------- /third_party/RAFT/RAFT.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/facebookresearch/OmnimatteRF/f3bbc62a4df062af0409ac067910965372919107/third_party/RAFT/RAFT.png -------------------------------------------------------------------------------- /third_party/RAFT/README.md: -------------------------------------------------------------------------------- 1 | # RAFT 2 | This repository contains the source code for our paper: 3 | 4 | [RAFT: Recurrent All Pairs Field Transforms for Optical Flow](https://arxiv.org/pdf/2003.12039.pdf)
5 | ECCV 2020
6 | Zachary Teed and Jia Deng
7 | 8 | 9 | 10 | ## Requirements 11 | The code has been tested with PyTorch 1.6 and Cuda 10.1. 12 | ```Shell 13 | conda create --name raft 14 | conda activate raft 15 | conda install pytorch=1.6.0 torchvision=0.7.0 cudatoolkit=10.1 matplotlib tensorboard scipy opencv -c pytorch 16 | ``` 17 | 18 | ## Demos 19 | Pretrained models can be downloaded by running 20 | ```Shell 21 | ./download_models.sh 22 | ``` 23 | or downloaded from [google drive](https://drive.google.com/drive/folders/1sWDsfuZ3Up38EUQt7-JDTT1HcGHuJgvT?usp=sharing) 24 | 25 | You can demo a trained model on a sequence of frames 26 | ```Shell 27 | python demo.py --model=models/raft-things.pth --path=demo-frames 28 | ``` 29 | 30 | ## Required Data 31 | To evaluate/train RAFT, you will need to download the required datasets. 32 | * [FlyingChairs](https://lmb.informatik.uni-freiburg.de/resources/datasets/FlyingChairs.en.html#flyingchairs) 33 | * [FlyingThings3D](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html) 34 | * [Sintel](http://sintel.is.tue.mpg.de/) 35 | * [KITTI](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=flow) 36 | * [HD1K](http://hci-benchmark.iwr.uni-heidelberg.de/) (optional) 37 | 38 | 39 | By default `datasets.py` will search for the datasets in these locations. You can create symbolic links to wherever the datasets were downloaded in the `datasets` folder 40 | 41 | ```Shell 42 | ├── datasets 43 | ├── Sintel 44 | ├── test 45 | ├── training 46 | ├── KITTI 47 | ├── testing 48 | ├── training 49 | ├── devkit 50 | ├── FlyingChairs_release 51 | ├── data 52 | ├── FlyingThings3D 53 | ├── frames_cleanpass 54 | ├── frames_finalpass 55 | ├── optical_flow 56 | ``` 57 | 58 | ## Evaluation 59 | You can evaluate a trained model using `evaluate.py` 60 | ```Shell 61 | python evaluate.py --model=models/raft-things.pth --dataset=sintel --mixed_precision 62 | ``` 63 | 64 | ## Training 65 | We used the following training schedule in our paper (2 GPUs). Training logs will be written to the `runs` which can be visualized using tensorboard 66 | ```Shell 67 | ./train_standard.sh 68 | ``` 69 | 70 | If you have a RTX GPU, training can be accelerated using mixed precision. You can expect similiar results in this setting (1 GPU) 71 | ```Shell 72 | ./train_mixed.sh 73 | ``` 74 | 75 | ## (Optional) Efficent Implementation 76 | You can optionally use our alternate (efficent) implementation by compiling the provided cuda extension 77 | ```Shell 78 | cd alt_cuda_corr && python setup.py install && cd .. 79 | ``` 80 | and running `demo.py` and `evaluate.py` with the `--alternate_corr` flag Note, this implementation is somewhat slower than all-pairs, but uses significantly less GPU memory during the forward pass. 81 | -------------------------------------------------------------------------------- /third_party/RAFT/alt_cuda_corr/correlation.cpp: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #include 4 | #include 5 | 6 | // CUDA forward declarations 7 | std::vector corr_cuda_forward( 8 | torch::Tensor fmap1, 9 | torch::Tensor fmap2, 10 | torch::Tensor coords, 11 | int radius); 12 | 13 | std::vector corr_cuda_backward( 14 | torch::Tensor fmap1, 15 | torch::Tensor fmap2, 16 | torch::Tensor coords, 17 | torch::Tensor corr_grad, 18 | int radius); 19 | 20 | // C++ interface 21 | #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") 22 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 23 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 24 | 25 | std::vector corr_forward( 26 | torch::Tensor fmap1, 27 | torch::Tensor fmap2, 28 | torch::Tensor coords, 29 | int radius) { 30 | CHECK_INPUT(fmap1); 31 | CHECK_INPUT(fmap2); 32 | CHECK_INPUT(coords); 33 | 34 | return corr_cuda_forward(fmap1, fmap2, coords, radius); 35 | } 36 | 37 | 38 | std::vector corr_backward( 39 | torch::Tensor fmap1, 40 | torch::Tensor fmap2, 41 | torch::Tensor coords, 42 | torch::Tensor corr_grad, 43 | int radius) { 44 | CHECK_INPUT(fmap1); 45 | CHECK_INPUT(fmap2); 46 | CHECK_INPUT(coords); 47 | CHECK_INPUT(corr_grad); 48 | 49 | return corr_cuda_backward(fmap1, fmap2, coords, corr_grad, radius); 50 | } 51 | 52 | 53 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 54 | m.def("forward", &corr_forward, "CORR forward"); 55 | m.def("backward", &corr_backward, "CORR backward"); 56 | } -------------------------------------------------------------------------------- /third_party/RAFT/alt_cuda_corr/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from setuptools import setup 4 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 5 | 6 | 7 | setup( 8 | name='correlation', 9 | ext_modules=[ 10 | CUDAExtension('alt_cuda_corr', 11 | sources=['correlation.cpp', 'correlation_kernel.cu'], 12 | extra_compile_args={'cxx': [], 'nvcc': ['-O3']}), 13 | ], 14 | cmdclass={ 15 | 'build_ext': BuildExtension 16 | }) 17 | 18 | -------------------------------------------------------------------------------- /third_party/RAFT/core/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/RAFT/core/corr.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from .utils.utils import bilinear_sampler, coords_grid 6 | 7 | try: 8 | import alt_cuda_corr 9 | except: 10 | # alt_cuda_corr is not compiled 11 | pass 12 | 13 | 14 | class CorrBlock: 15 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 16 | self.num_levels = num_levels 17 | self.radius = radius 18 | self.corr_pyramid = [] 19 | 20 | # all pairs correlation 21 | corr = CorrBlock.corr(fmap1, fmap2) 22 | 23 | batch, h1, w1, dim, h2, w2 = corr.shape 24 | corr = corr.reshape(batch*h1*w1, dim, h2, w2) 25 | 26 | self.corr_pyramid.append(corr) 27 | for i in range(self.num_levels-1): 28 | corr = F.avg_pool2d(corr, 2, stride=2) 29 | self.corr_pyramid.append(corr) 30 | 31 | def __call__(self, coords): 32 | r = self.radius 33 | coords = coords.permute(0, 2, 3, 1) 34 | batch, h1, w1, _ = coords.shape 35 | 36 | out_pyramid = [] 37 | for i in range(self.num_levels): 38 | corr = self.corr_pyramid[i] 39 | dx = torch.linspace(-r, r, 2*r+1, device=coords.device) 40 | dy = torch.linspace(-r, r, 2*r+1, device=coords.device) 41 | delta = torch.stack(torch.meshgrid(dy, dx), axis=-1) 42 | 43 | centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i 44 | delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) 45 | coords_lvl = centroid_lvl + delta_lvl 46 | 47 | corr = bilinear_sampler(corr, coords_lvl) 48 | corr = corr.view(batch, h1, w1, -1) 49 | out_pyramid.append(corr) 50 | 51 | out = torch.cat(out_pyramid, dim=-1) 52 | return out.permute(0, 3, 1, 2).contiguous().float() 53 | 54 | @staticmethod 55 | def corr(fmap1, fmap2): 56 | batch, dim, ht, wd = fmap1.shape 57 | fmap1 = fmap1.view(batch, dim, ht*wd) 58 | fmap2 = fmap2.view(batch, dim, ht*wd) 59 | 60 | corr = torch.matmul(fmap1.transpose(1,2), fmap2) 61 | corr = corr.view(batch, ht, wd, 1, ht, wd) 62 | return corr / torch.sqrt(torch.tensor(dim).float()) 63 | 64 | 65 | class AlternateCorrBlock: 66 | def __init__(self, fmap1, fmap2, num_levels=4, radius=4): 67 | self.num_levels = num_levels 68 | self.radius = radius 69 | 70 | self.pyramid = [(fmap1, fmap2)] 71 | for i in range(self.num_levels): 72 | fmap1 = F.avg_pool2d(fmap1, 2, stride=2) 73 | fmap2 = F.avg_pool2d(fmap2, 2, stride=2) 74 | self.pyramid.append((fmap1, fmap2)) 75 | 76 | def __call__(self, coords): 77 | coords = coords.permute(0, 2, 3, 1) 78 | B, H, W, _ = coords.shape 79 | dim = self.pyramid[0][0].shape[1] 80 | 81 | corr_list = [] 82 | for i in range(self.num_levels): 83 | r = self.radius 84 | fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1).contiguous() 85 | fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1).contiguous() 86 | 87 | coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() 88 | corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r) 89 | corr_list.append(corr.squeeze(1)) 90 | 91 | corr = torch.stack(corr_list, dim=1) 92 | corr = corr.reshape(B, -1, H, W) 93 | return corr / torch.sqrt(torch.tensor(dim).float()) 94 | -------------------------------------------------------------------------------- /third_party/RAFT/core/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/RAFT/core/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from scipy import interpolate 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | def __init__(self, dims, mode='sintel'): 12 | self.ht, self.wd = dims[-2:] 13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 15 | if mode == 'sintel': 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 17 | else: 18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 19 | 20 | def pad(self, *inputs): 21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 22 | 23 | def unpad(self,x): 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | xgrid, ygrid = coords.split([1,1], dim=-1) 63 | xgrid = 2*xgrid/(W-1) - 1 64 | ygrid = 2*ygrid/(H-1) - 1 65 | 66 | grid = torch.cat([xgrid, ygrid], dim=-1) 67 | img = F.grid_sample(img, grid, align_corners=True) 68 | 69 | if mask: 70 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 71 | return img, mask.float() 72 | 73 | return img 74 | 75 | 76 | def coords_grid(batch, ht, wd, device): 77 | coords = torch.meshgrid(torch.arange(ht, device=device), torch.arange(wd, device=device)) 78 | coords = torch.stack(coords[::-1], dim=0).float() 79 | return coords[None].repeat(batch, 1, 1, 1) 80 | 81 | 82 | def upflow8(flow, mode='bilinear'): 83 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 84 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 85 | -------------------------------------------------------------------------------- /third_party/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import sys 4 | sys.path.append('core') 5 | 6 | import argparse 7 | import os 8 | import cv2 9 | import glob 10 | import numpy as np 11 | import torch 12 | from PIL import Image 13 | 14 | from raft import RAFT 15 | from utils import flow_viz 16 | from utils.utils import InputPadder 17 | 18 | 19 | 20 | DEVICE = 'cuda' 21 | 22 | def load_image(imfile): 23 | img = np.array(Image.open(imfile)).astype(np.uint8) 24 | img = torch.from_numpy(img).permute(2, 0, 1).float() 25 | return img[None].to(DEVICE) 26 | 27 | 28 | def viz(img, flo): 29 | img = img[0].permute(1,2,0).cpu().numpy() 30 | flo = flo[0].permute(1,2,0).cpu().numpy() 31 | 32 | # map flow to rgb image 33 | flo = flow_viz.flow_to_image(flo) 34 | img_flo = np.concatenate([img, flo], axis=0) 35 | 36 | # import matplotlib.pyplot as plt 37 | # plt.imshow(img_flo / 255.0) 38 | # plt.show() 39 | 40 | cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 41 | cv2.waitKey() 42 | 43 | 44 | def demo(args): 45 | model = torch.nn.DataParallel(RAFT(args)) 46 | model.load_state_dict(torch.load(args.model)) 47 | 48 | model = model.module 49 | model.to(DEVICE) 50 | model.eval() 51 | 52 | with torch.no_grad(): 53 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 54 | glob.glob(os.path.join(args.path, '*.jpg')) 55 | 56 | images = sorted(images) 57 | for imfile1, imfile2 in zip(images[:-1], images[1:]): 58 | image1 = load_image(imfile1) 59 | image2 = load_image(imfile2) 60 | 61 | padder = InputPadder(image1.shape) 62 | image1, image2 = padder.pad(image1, image2) 63 | 64 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 65 | viz(image1, flow_up) 66 | 67 | 68 | if __name__ == '__main__': 69 | parser = argparse.ArgumentParser() 70 | parser.add_argument('--model', help="restore checkpoint") 71 | parser.add_argument('--path', help="dataset for evaluation") 72 | parser.add_argument('--small', action='store_true', help='use small model') 73 | parser.add_argument('--mixed_precision', action='store_true', help='use mixed precision') 74 | parser.add_argument('--alternate_corr', action='store_true', help='use efficent correlation implementation') 75 | args = parser.parse_args() 76 | 77 | demo(args) 78 | -------------------------------------------------------------------------------- /third_party/RAFT/download_models.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #!/bin/bash 4 | wget https://www.dropbox.com/s/4j4z58wuv8o0mfz/models.zip 5 | unzip models.zip 6 | -------------------------------------------------------------------------------- /third_party/RAFT/train_mixed.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #!/bin/bash 4 | mkdir -p checkpoints 5 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 --num_steps 120000 --batch_size 8 --lr 0.00025 --image_size 368 496 --wdecay 0.0001 --mixed_precision 6 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 400 720 --wdecay 0.0001 --mixed_precision 7 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 --num_steps 120000 --batch_size 5 --lr 0.0001 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 --mixed_precision 8 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 --num_steps 50000 --batch_size 5 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 --mixed_precision 9 | -------------------------------------------------------------------------------- /third_party/RAFT/train_standard.sh: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | #!/bin/bash 4 | mkdir -p checkpoints 5 | python -u train.py --name raft-chairs --stage chairs --validation chairs --gpus 0 1 --num_steps 100000 --batch_size 10 --lr 0.0004 --image_size 368 496 --wdecay 0.0001 6 | python -u train.py --name raft-things --stage things --validation sintel --restore_ckpt checkpoints/raft-chairs.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 400 720 --wdecay 0.0001 7 | python -u train.py --name raft-sintel --stage sintel --validation sintel --restore_ckpt checkpoints/raft-things.pth --gpus 0 1 --num_steps 100000 --batch_size 6 --lr 0.000125 --image_size 368 768 --wdecay 0.00001 --gamma=0.85 8 | python -u train.py --name raft-kitti --stage kitti --validation kitti --restore_ckpt checkpoints/raft-sintel.pth --gpus 0 1 --num_steps 50000 --batch_size 6 --lr 0.0001 --image_size 288 960 --wdecay 0.00001 --gamma=0.85 9 | -------------------------------------------------------------------------------- /third_party/RoDynRF/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | __pycache__ 3 | dataset 4 | log 5 | weights 6 | /local 7 | -------------------------------------------------------------------------------- /third_party/RoDynRF/configs/DAVIS_CAM.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | dataset_name = nvidia_pose 4 | 5 | multiGPU = [0] 6 | downsample_train = 1.0 7 | ray_type = contract 8 | with_GT_poses = 0 9 | optimize_focal_length = 1 10 | step_ratio = 2.0 11 | 12 | N_voxel_t = -1 13 | n_iters = 10000 14 | batch_size = 1024 15 | 16 | N_voxel_init = 4096 # 16**3 # 32 O, 16 X, 8 O, 4 O 17 | N_voxel_final = 27000000 # 300**3 18 | upsamp_list = [2000, 4000, 6000, 8000, 12000, 16000, 22000] 19 | update_AlphaMask_list = [300000000] 20 | 21 | N_vis = 3 22 | vis_train_every = 0 23 | vis_full_every = 10000 24 | progress_refresh_rate = 1000 25 | save_every = 10000 26 | 27 | render_test = 1 28 | render_path = 0 29 | 30 | model_name = TensorVMSplit_TimeEmbedding 31 | n_lamb_sigma = [16, 4, 4] 32 | n_lamb_sh = [48, 12, 12] 33 | 34 | shadingMode = MLP_Fea_late_view 35 | 36 | fea2denseAct = relu 37 | 38 | view_pe = 0 39 | fea_pe = 0 40 | 41 | L1_weight_inital = 8e-5 42 | TV_weight_density = 0.0 43 | TV_weight_app = 0.0 44 | distortion_weight_static = 0.04 45 | distortion_weight_dynamic = 0.02 46 | monodepth_weight_static = 0.04 47 | 48 | optimize_poses = 1 49 | use_time_embedding = 0 50 | multiview_dataset = 0 51 | 52 | use_foreground_mask = epipolar_error_png 53 | use_disp = 1 54 | -------------------------------------------------------------------------------- /third_party/RoDynRF/configs/REALWORLD_CAM.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | dataset_name = nvidia_pose 4 | 5 | multiGPU = [0] 6 | downsample_train = 2.0 7 | ray_type = contract 8 | with_GT_poses = 0 9 | optimize_focal_length = 1 10 | step_ratio = 2.0 11 | 12 | N_voxel_t = -1 13 | n_iters = 10000 14 | batch_size = 1024 15 | 16 | N_voxel_init = 4096 # 16**3 # 32 O, 16 X, 8 O, 4 O 17 | N_voxel_final = 64000000 # 400**3 18 | upsamp_list = [2000, 4000, 6000, 8000, 12000, 16000, 22000] 19 | update_AlphaMask_list = [300000000] 20 | 21 | N_vis = 3 22 | vis_train_every = 0 23 | vis_full_every = 10000 24 | progress_refresh_rate = 1000 25 | save_every = 10000 26 | 27 | render_test = 1 28 | render_path = 0 29 | 30 | model_name = TensorVMSplit_TimeEmbedding 31 | n_lamb_sigma = [16, 4, 4] 32 | n_lamb_sh = [48, 12, 12] 33 | 34 | shadingMode = MLP_Fea_late_view 35 | shadingModeStatic = MLP_Fea 36 | fea2denseAct = relu 37 | 38 | view_pe = 0 39 | fea_pe = 0 40 | 41 | 42 | TV_weight_density = 0.0 43 | TV_weight_app = 0.0 44 | distortion_weight_static = 0.04 45 | distortion_weight_dynamic = 0.02 46 | monodepth_weight_static = 0.04 47 | 48 | optimize_poses = 1 49 | use_time_embedding = 0 50 | multiview_dataset = 0 51 | 52 | use_foreground_mask = epipolar_error_png 53 | use_disp = 1 54 | -------------------------------------------------------------------------------- /third_party/RoDynRF/configs/REALWORLD_CAM_NDC.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | dataset_name = nvidia_pose 4 | 5 | multiGPU = [0] 6 | downsample_train = 2.0 7 | ray_type = ndc 8 | with_GT_poses = 0 9 | optimize_focal_length = 1 10 | step_ratio = 2.0 11 | 12 | N_voxel_t = -1 13 | n_iters = 10000 14 | batch_size = 1024 15 | 16 | N_voxel_init = 4096 # 16**3 # 32 O, 16 X, 8 O, 4 O 17 | N_voxel_final = 64000000 # 400**3 18 | upsamp_list = [2000, 4000, 6000, 8000, 12000, 16000, 22000] 19 | update_AlphaMask_list = [300000000] 20 | 21 | N_vis = 3 22 | vis_train_every = 0 23 | vis_full_every = 10000 24 | progress_refresh_rate = 1000 25 | save_every = 10000 26 | 27 | render_test = 1 28 | render_path = 0 29 | 30 | model_name = TensorVMSplit_TimeEmbedding 31 | n_lamb_sigma = [16, 4, 4] 32 | n_lamb_sh = [48, 12, 12] 33 | 34 | shadingMode = MLP_Fea_late_view 35 | shadingModeStatic = MLP_Fea 36 | fea2denseAct = relu 37 | 38 | view_pe = 0 39 | fea_pe = 0 40 | 41 | TV_weight_density = 0.0 42 | TV_weight_app = 0.0 43 | distortion_weight_static = 0.04 44 | distortion_weight_dynamic = 0.02 45 | monodepth_weight_static = 0.04 46 | 47 | optimize_poses = 1 48 | use_time_embedding = 0 49 | multiview_dataset = 0 50 | 51 | use_foreground_mask = epipolar_error_png 52 | use_disp = 1 53 | -------------------------------------------------------------------------------- /third_party/RoDynRF/dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .nvidia_pose import NvidiaPoseDataset 4 | 5 | dataset_dict = {'nvidia_pose':NvidiaPoseDataset} 6 | -------------------------------------------------------------------------------- /third_party/RoDynRF/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/RAFT/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # from .demo import RAFT_infer 4 | from .raft import RAFT 5 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/RAFT/demo.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import sys 4 | import argparse 5 | import os 6 | import cv2 7 | import glob 8 | import numpy as np 9 | import torch 10 | from PIL import Image 11 | 12 | from .raft import RAFT 13 | from .utils import flow_viz 14 | from .utils.utils import InputPadder 15 | 16 | 17 | 18 | DEVICE = 'cuda' 19 | 20 | def load_image(imfile): 21 | img = np.array(Image.open(imfile)).astype(np.uint8) 22 | img = torch.from_numpy(img).permute(2, 0, 1).float() 23 | return img 24 | 25 | 26 | def load_image_list(image_files): 27 | images = [] 28 | for imfile in sorted(image_files): 29 | images.append(load_image(imfile)) 30 | 31 | images = torch.stack(images, dim=0) 32 | images = images.to(DEVICE) 33 | 34 | padder = InputPadder(images.shape) 35 | return padder.pad(images)[0] 36 | 37 | 38 | def viz(img, flo): 39 | img = img[0].permute(1,2,0).cpu().numpy() 40 | flo = flo[0].permute(1,2,0).cpu().numpy() 41 | 42 | # map flow to rgb image 43 | flo = flow_viz.flow_to_image(flo) 44 | # img_flo = np.concatenate([img, flo], axis=0) 45 | img_flo = flo 46 | 47 | cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) 48 | # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) 49 | # cv2.waitKey() 50 | 51 | 52 | def demo(args): 53 | model = torch.nn.DataParallel(RAFT(args)) 54 | model.load_state_dict(torch.load(args.model)) 55 | 56 | model = model.module 57 | model.to(DEVICE) 58 | model.eval() 59 | 60 | with torch.no_grad(): 61 | images = glob.glob(os.path.join(args.path, '*.png')) + \ 62 | glob.glob(os.path.join(args.path, '*.jpg')) 63 | 64 | images = load_image_list(images) 65 | for i in range(images.shape[0]-1): 66 | image1 = images[i,None] 67 | image2 = images[i+1,None] 68 | 69 | flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) 70 | viz(image1, flow_up) 71 | 72 | 73 | def RAFT_infer(args): 74 | model = torch.nn.DataParallel(RAFT(args)) 75 | model.load_state_dict(torch.load(args.model)) 76 | 77 | model = model.module 78 | model.to(DEVICE) 79 | model.eval() 80 | 81 | return model 82 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/RAFT/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .flow_viz import flow_to_image 4 | from .frame_utils import writeFlow 5 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/RAFT/utils/utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | import numpy as np 6 | from scipy import interpolate 7 | 8 | 9 | class InputPadder: 10 | """ Pads images such that dimensions are divisible by 8 """ 11 | def __init__(self, dims, mode='sintel'): 12 | self.ht, self.wd = dims[-2:] 13 | pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 14 | pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 15 | if mode == 'sintel': 16 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] 17 | else: 18 | self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] 19 | 20 | def pad(self, *inputs): 21 | return [F.pad(x, self._pad, mode='replicate') for x in inputs] 22 | 23 | def unpad(self,x): 24 | ht, wd = x.shape[-2:] 25 | c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] 26 | return x[..., c[0]:c[1], c[2]:c[3]] 27 | 28 | def forward_interpolate(flow): 29 | flow = flow.detach().cpu().numpy() 30 | dx, dy = flow[0], flow[1] 31 | 32 | ht, wd = dx.shape 33 | x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) 34 | 35 | x1 = x0 + dx 36 | y1 = y0 + dy 37 | 38 | x1 = x1.reshape(-1) 39 | y1 = y1.reshape(-1) 40 | dx = dx.reshape(-1) 41 | dy = dy.reshape(-1) 42 | 43 | valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) 44 | x1 = x1[valid] 45 | y1 = y1[valid] 46 | dx = dx[valid] 47 | dy = dy[valid] 48 | 49 | flow_x = interpolate.griddata( 50 | (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) 51 | 52 | flow_y = interpolate.griddata( 53 | (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) 54 | 55 | flow = np.stack([flow_x, flow_y], axis=0) 56 | return torch.from_numpy(flow).float() 57 | 58 | 59 | def bilinear_sampler(img, coords, mode='bilinear', mask=False): 60 | """ Wrapper for grid_sample, uses pixel coordinates """ 61 | H, W = img.shape[-2:] 62 | xgrid, ygrid = coords.split([1,1], dim=-1) 63 | xgrid = 2*xgrid/(W-1) - 1 64 | ygrid = 2*ygrid/(H-1) - 1 65 | 66 | grid = torch.cat([xgrid, ygrid], dim=-1) 67 | img = F.grid_sample(img, grid, align_corners=True) 68 | 69 | if mask: 70 | mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) 71 | return img, mask.float() 72 | 73 | return img 74 | 75 | 76 | def coords_grid(batch, ht, wd): 77 | coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) 78 | coords = torch.stack(coords[::-1], dim=0).float() 79 | return coords[None].repeat(batch, 1, 1, 1) 80 | 81 | 82 | def upflow8(flow, mode='bilinear'): 83 | new_size = (8 * flow.shape[2], 8 * flow.shape[3]) 84 | return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) 85 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/midas/base_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | 5 | 6 | class BaseModel(torch.nn.Module): 7 | def load(self, path): 8 | """Load model from file. 9 | Args: 10 | path (str): file path 11 | """ 12 | parameters = torch.load(path, map_location=torch.device('cpu')) 13 | 14 | if "optimizer" in parameters: 15 | parameters = parameters["model"] 16 | 17 | self.load_state_dict(parameters) 18 | -------------------------------------------------------------------------------- /third_party/RoDynRF/scripts/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 4 | This file contains code that is adapted from 5 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 6 | """ 7 | import torch 8 | import torch.nn as nn 9 | 10 | from .base_model import BaseModel 11 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 12 | 13 | 14 | class MidasNet(BaseModel): 15 | """Network for monocular depth estimation. 16 | """ 17 | 18 | def __init__(self, path=None, features=256, non_negative=True): 19 | """Init. 20 | 21 | Args: 22 | path (str, optional): Path to saved model. Defaults to None. 23 | features (int, optional): Number of features. Defaults to 256. 24 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 25 | """ 26 | print("Loading weights: ", path) 27 | 28 | super(MidasNet, self).__init__() 29 | 30 | use_pretrained = False if path is None else True 31 | 32 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 33 | 34 | self.scratch.refinenet4 = FeatureFusionBlock(features) 35 | self.scratch.refinenet3 = FeatureFusionBlock(features) 36 | self.scratch.refinenet2 = FeatureFusionBlock(features) 37 | self.scratch.refinenet1 = FeatureFusionBlock(features) 38 | 39 | self.scratch.output_conv = nn.Sequential( 40 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 41 | Interpolate(scale_factor=2, mode="bilinear"), 42 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 43 | nn.ReLU(True), 44 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 45 | nn.ReLU(True) if non_negative else nn.Identity(), 46 | ) 47 | 48 | if path: 49 | self.load(path) 50 | 51 | def forward(self, x): 52 | """Forward pass. 53 | 54 | Args: 55 | x (tensor): input data (image) 56 | 57 | Returns: 58 | tensor: depth 59 | """ 60 | 61 | layer_1 = self.pretrained.layer1(x) 62 | layer_2 = self.pretrained.layer2(layer_1) 63 | layer_3 = self.pretrained.layer3(layer_2) 64 | layer_4 = self.pretrained.layer4(layer_3) 65 | 66 | layer_1_rn = self.scratch.layer1_rn(layer_1) 67 | layer_2_rn = self.scratch.layer2_rn(layer_2) 68 | layer_3_rn = self.scratch.layer3_rn(layer_3) 69 | layer_4_rn = self.scratch.layer4_rn(layer_4) 70 | 71 | path_4 = self.scratch.refinenet4(layer_4_rn) 72 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 73 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 74 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 75 | 76 | out = self.scratch.output_conv(path_1) 77 | 78 | return torch.squeeze(out, dim=1) 79 | -------------------------------------------------------------------------------- /third_party/TensoRF/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Anpei Chen 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /third_party/TensoRF/METADATA: -------------------------------------------------------------------------------- 1 | Code from: https://github.com/apchenstu/TensoRF.git 2 | Commit: 17deeedae5ab4106b30a3295709ec3a8a654c7b1 3 | -------------------------------------------------------------------------------- /third_party/TensoRF/configs/flower.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | dataset_name = llff 5 | datadir = ./data/nerf_llff_data/flower 6 | expname = tensorf_flower_VM 7 | basedir = ./log 8 | 9 | downsample_train = 4.0 10 | ndc_ray = 1 11 | 12 | n_iters = 25000 13 | batch_size = 4096 14 | 15 | N_voxel_init = 2097156 # 128**3 16 | N_voxel_final = 262144000 # 640**3 17 | upsamp_list = [2000,3000,4000,5500] 18 | update_AlphaMask_list = [2500] 19 | 20 | N_vis = -1 # vis all testing images 21 | vis_every = 10000 22 | 23 | render_test = 1 24 | render_path = 1 25 | 26 | n_lamb_sigma = [16,4,4] 27 | n_lamb_sh = [48,12,12] 28 | 29 | shadingMode = MLP_Fea 30 | fea2denseAct = relu 31 | 32 | view_pe = 0 33 | fea_pe = 0 34 | 35 | TV_weight_density = 1.0 36 | TV_weight_app = 1.0 37 | 38 | -------------------------------------------------------------------------------- /third_party/TensoRF/configs/lego.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | dataset_name = blender 5 | datadir = ./data/nerf_synthetic/lego 6 | expname = tensorf_lego_VM 7 | basedir = ./log 8 | 9 | n_iters = 30000 10 | batch_size = 4096 11 | 12 | N_voxel_init = 2097156 # 128**3 13 | N_voxel_final = 27000000 # 300**3 14 | upsamp_list = [2000,3000,4000,5500,7000] 15 | update_AlphaMask_list = [2000,4000] 16 | 17 | N_vis = 5 18 | vis_every = 10000 19 | 20 | render_test = 1 21 | 22 | n_lamb_sigma = [16,16,16] 23 | n_lamb_sh = [48,48,48] 24 | model_name = TensorVMSplit 25 | 26 | 27 | shadingMode = MLP_Fea 28 | fea2denseAct = softplus 29 | 30 | view_pe = 2 31 | fea_pe = 2 32 | 33 | L1_weight_inital = 8e-5 34 | L1_weight_rest = 4e-5 35 | rm_weight_mask_thre = 1e-4 36 | 37 | ## please uncomment following configuration if hope to training on cp model 38 | #model_name = TensorCP 39 | #n_lamb_sigma = [96] 40 | #n_lamb_sh = [288] 41 | #N_voxel_final = 125000000 # 500**3 42 | #L1_weight_inital = 1e-5 43 | #L1_weight_rest = 1e-5 44 | -------------------------------------------------------------------------------- /third_party/TensoRF/configs/truck.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | 5 | dataset_name = tankstemple 6 | datadir = ./data/TanksAndTemple/Truck 7 | expname = tensorf_truck_VM 8 | basedir = ./log 9 | 10 | n_iters = 30000 11 | batch_size = 4096 12 | 13 | N_voxel_init = 2097156 # 128**3 14 | N_voxel_final = 27000000 # 300**3 15 | upsamp_list = [2000,3000,4000,5500,7000] 16 | update_AlphaMask_list = [2000,4000] 17 | 18 | N_vis = 5 19 | vis_every = 10000 20 | 21 | render_test = 1 22 | 23 | n_lamb_sigma = [16,16,16] 24 | n_lamb_sh = [48,48,48] 25 | 26 | shadingMode = MLP_Fea 27 | fea2denseAct = softplus 28 | 29 | view_pe = 2 30 | fea_pe = 2 31 | 32 | TV_weight_density = 0.1 33 | TV_weight_app = 0.01 34 | 35 | ## please uncomment following configuration if hope to training on cp model 36 | #model_name = TensorCP 37 | #n_lamb_sigma = [96] 38 | #n_lamb_sh = [288] 39 | #N_voxel_final = 125000000 # 500**3 40 | #L1_weight_inital = 1e-5 41 | #L1_weight_rest = 1e-5 42 | 43 | -------------------------------------------------------------------------------- /third_party/TensoRF/configs/wineholder.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | dataset_name = nsvf 5 | datadir = ./data/Synthetic_NSVF/Wineholder 6 | expname = tensorf_Wineholder_VM 7 | basedir = ./log 8 | 9 | n_iters = 30000 10 | batch_size = 4096 11 | 12 | N_voxel_init = 2097156 # 128**3 13 | N_voxel_final = 27000000 # 300**3 14 | upsamp_list = [2000,3000,4000,5500,7000] 15 | update_AlphaMask_list = [2000,4000] 16 | 17 | N_vis = 5 18 | vis_every = 10000 19 | 20 | render_test = 1 21 | 22 | n_lamb_sigma = [16,16,16] 23 | n_lamb_sh = [48,48,48] 24 | 25 | shadingMode = MLP_Fea 26 | fea2denseAct = softplus 27 | 28 | view_pe = 2 29 | fea_pe = 2 30 | 31 | L1_weight_inital = 8e-5 32 | L1_weight_rest = 4e-5 33 | rm_weight_mask_thre = 1e-4 34 | 35 | ## please uncomment following configuration if hope to training on cp model 36 | #model_name = TensorCP 37 | #n_lamb_sigma = [96] 38 | #n_lamb_sh = [288] 39 | #N_voxel_final = 125000000 # 500**3 40 | #L1_weight_inital = 1e-5 41 | #L1_weight_rest = 1e-5 42 | -------------------------------------------------------------------------------- /third_party/TensoRF/configs/your_own_data.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | 4 | dataset_name = own_data 5 | datadir = ./data/xxx 6 | expname = tensorf_xxx_VM 7 | basedir = ./log 8 | 9 | n_iters = 30000 10 | batch_size = 4096 11 | 12 | N_voxel_init = 2097156 # 128**3 13 | N_voxel_final = 27000000 # 300**3 14 | upsamp_list = [2000,3000,4000,5500,7000] 15 | update_AlphaMask_list = [2000,4000] 16 | 17 | N_vis = 5 18 | vis_every = 10000 19 | 20 | render_test = 1 21 | 22 | n_lamb_sigma = [16,16,16] 23 | n_lamb_sh = [48,48,48] 24 | model_name = TensorVMSplit 25 | 26 | 27 | shadingMode = MLP_Fea 28 | fea2denseAct = softplus 29 | 30 | view_pe = 2 31 | fea_pe = 2 32 | 33 | view_pe = 2 34 | fea_pe = 2 35 | 36 | TV_weight_density = 0.1 37 | TV_weight_app = 0.01 38 | 39 | rm_weight_mask_thre = 1e-4 40 | 41 | ## please uncomment following configuration if hope to training on cp model 42 | #model_name = TensorCP 43 | #n_lamb_sigma = [96] 44 | #n_lamb_sh = [288] 45 | #N_voxel_final = 125000000 # 500**3 46 | #L1_weight_inital = 1e-5 47 | #L1_weight_rest = 1e-5 48 | -------------------------------------------------------------------------------- /third_party/TensoRF/dataLoader/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .llff import LLFFDataset 4 | from .blender import BlenderDataset 5 | from .nsvf import NSVF 6 | from .tankstemple import TanksTempleDataset 7 | from .your_own_data import YourOwnDataset 8 | 9 | 10 | 11 | dataset_dict = {'blender': BlenderDataset, 12 | 'llff':LLFFDataset, 13 | 'tankstemple':TanksTempleDataset, 14 | 'nsvf':NSVF, 15 | 'own_data':YourOwnDataset} -------------------------------------------------------------------------------- /third_party/TensoRF/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/CITATION.cff: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | cff-version: 1.2.0 4 | message: "If you use this code for your research, please cite the following paper:" 5 | authors: 6 | - family-names: "Lu" 7 | given-names: "Erika" 8 | - family-names: "Cole" 9 | given-names: "Forrester" 10 | - family-names: "Dekel" 11 | given-names: "Tali" 12 | - family-names: "Zisserman" 13 | given-names: "Andrew" 14 | - family-names: "Freeman" 15 | given-names: "William T" 16 | - family-names: "Rubinstein" 17 | given-names: "Michael" 18 | title: "Omnimatte: Associating Objects and Their Effects in Video" 19 | year: 2021 20 | preferred-citation: 21 | type: inproceedings 22 | authors: 23 | - family-names: "Lu" 24 | given-names: "Erika" 25 | - family-names: "Cole" 26 | given-names: "Forrester" 27 | - family-names: "Dekel" 28 | given-names: "Tali" 29 | - family-names: "Zisserman" 30 | given-names: "Andrew" 31 | - family-names: "Freeman" 32 | given-names: "William T" 33 | - family-names: "Rubinstein" 34 | given-names: "Michael" 35 | title: "Omnimatte: Associating Objects and Their Effects in Video" 36 | year: 2021 37 | -------------------------------------------------------------------------------- /third_party/omnimatte/METADATA: -------------------------------------------------------------------------------- 1 | Code from: https://github.com/erikalu/omnimatte.git 2 | Commit: 2f45a6e479f2fa29c5a9d955d9b8b9b997491b64 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/README.md: -------------------------------------------------------------------------------- 1 | # Omnimatte in PyTorch 2 | 3 | This repository contains a re-implementation of the code for the CVPR 2021 paper "[Omnimatte: Associating Objects and Their Effects in Video](https://omnimatte.github.io/)." 4 | 5 | 6 | 7 | 8 | ## Prerequisites 9 | - Linux 10 | - Python 3.6+ 11 | - NVIDIA GPU + CUDA CuDNN 12 | 13 | ## Installation 14 | This code has been tested with PyTorch 1.8 and Python 3.8. 15 | 16 | - Install [PyTorch](http://pytorch.org) 1.8 and other dependencies. 17 | - For pip users, please type the command `pip install -r requirements.txt`. 18 | - For Conda users, you can create a new Conda environment using `conda env create -f environment.yml`. 19 | 20 | ## Demo 21 | To train a model on a video (e.g. "tennis"), run: 22 | ```bash 23 | python train.py --name tennis --dataroot ./datasets/tennis --gpu_ids 0,1 24 | ``` 25 | To view training results and loss plots, visit the URL http://localhost:8097. 26 | Intermediate results are also at `./checkpoints/tennis/web/index.html`. 27 | 28 | To save the omnimatte layer outputs of the trained model, run: 29 | ```bash 30 | python test.py --name tennis --dataroot ./datasets/tennis --gpu_ids 0 31 | ``` 32 | The results (RGBA layers, videos) will be saved to `./results/tennis/test_latest/`. 33 | 34 | ## Custom video 35 | To train on your own video, you will have to preprocess the data: 36 | 1. Extract the frames, e.g. 37 | ``` 38 | mkdir ./datasets/my_video && cd ./datasets/my_video 39 | mkdir rgb && ffmpeg -i video.mp4 rgb/%04d.png 40 | ``` 41 | 1. Resize the video to 256x448 and save the frames in `my_video/rgb`. 42 | 1. Get input object masks (e.g. using [Mask-RCNN](https://github.com/facebookresearch/detectron2) and [STM](https://github.com/seoungwugoh/STM)), save each object's masks in its own subdirectory, e.g. `my_video/mask/01/`, `my_video/mask/02/`, etc. 43 | 1. Compute flow (e.g. using [RAFT](https://github.com/princeton-vl/RAFT)), and save the forward .flo files to `my_video/flow` and backward flow to `my_video/flow_backward` 44 | 1. Compute the confidence maps from the forward/backward flows: 45 | ```bash 46 | python datasets/confidence.py --dataroot ./datasets/tennis 47 | ``` 48 | 1. Register the video and save the computed homographies in `my_video/homographies.txt`. 49 | See [here](docs/data.md#camera-registration) for details. 50 | 51 | **Note**: Videos that are suitable for our method have the following attributes: 52 | - Static camera or limited camera motion that can be represented with a homography. 53 | - Limited number of omnimatte layers, due to GPU memory limitations. We tested up to 6 layers. 54 | - Objects that move relative to the background (static objects will be absorbed into the background layer). 55 | - We tested a video length of up to 200 frames (~7 seconds). 56 | 57 | ## Citation 58 | If you use this code for your research, please cite the following paper: 59 | ``` 60 | @inproceedings{lu2021, 61 | title={Omnimatte: Associating Objects and Their Effects in Video}, 62 | author={Lu, Erika and Cole, Forrester and Dekel, Tali and Zisserman, Andrew and Freeman, William T and Rubinstein, Michael}, 63 | booktitle={CVPR}, 64 | year={2021} 65 | } 66 | ``` 67 | 68 | ## Acknowledgments 69 | This code is based on [retiming](https://github.com/google/retiming) and [pytorch-CycleGAN-and-pix2pix](https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix). 70 | -------------------------------------------------------------------------------- /third_party/omnimatte/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/datasets/homography.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | # Copyright 2021 Erika Lu 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | 18 | """Helper tools for computing the world bounds from homographies.""" 19 | import numpy as np 20 | 21 | 22 | def transform2h(x, y, m): 23 | """Applies 2d homogeneous transformation.""" 24 | A = np.dot(m, np.array([x, y, np.ones(len(x))])) 25 | xt = A[0, :] / A[2, :] 26 | yt = A[1, :] / A[2, :] 27 | return xt, yt 28 | 29 | 30 | def compute_world_bounds(homographies, height, width): 31 | """Compute minimum and maximum coordinates. 32 | 33 | homographies - list of 3x3 numpy arrays 34 | height, width - video dimensions 35 | """ 36 | xbounds = [0, width - 1] 37 | ybounds = [0, height - 1] 38 | 39 | for h in homographies: 40 | # find transformed image bounding box 41 | x = np.array([0, width - 1, 0, width - 1]) 42 | y = np.array([0, 0, height - 1, height - 1]) 43 | [xt, yt] = transform2h(x, y, np.linalg.inv(h)) 44 | xbounds[0] = min(xbounds[0], min(xt)) 45 | xbounds[1] = max(xbounds[1], max(xt)) 46 | ybounds[0] = min(ybounds[0], min(yt)) 47 | ybounds[1] = max(ybounds[1], max(yt)) 48 | 49 | return xbounds, ybounds 50 | 51 | 52 | def main(argv=None): 53 | import argparse 54 | 55 | arguments = argparse.ArgumentParser() 56 | arguments.add_argument( 57 | "--homography_path", type=str, help="path to file containing homographies" 58 | ) 59 | arguments.add_argument("--width", type=int, help="video width") 60 | arguments.add_argument("--height", type=int, help="video height") 61 | opt = arguments.parse_args(argv) 62 | 63 | if opt.homography_path.endswith(".npy"): 64 | homographies = np.load(opt.homography_path) 65 | lines = [ 66 | " ".join([str(v) for v in H.reshape([-1])]) + "\n" for H in homographies 67 | ] 68 | else: 69 | with open(opt.homography_path) as f: 70 | lines = f.readlines() 71 | homographies = [l.rstrip().split(" ") for l in lines] 72 | homographies = [[float(h) for h in l] for l in homographies] 73 | homographies = [np.array(H).reshape(3, 3) for H in homographies] 74 | 75 | xbounds, ybounds = compute_world_bounds(homographies, opt.height, opt.width) 76 | out_path = f"{opt.homography_path[:-4]}.txt" 77 | with open(out_path, "w") as f: 78 | f.write(f"size: {opt.width} {opt.height}\n") 79 | f.write(f"bounds: {xbounds[0]} {xbounds[1]} {ybounds[0]} {ybounds[1]}\n") 80 | f.writelines(lines) 81 | print(f"saved {out_path}") 82 | 83 | 84 | if __name__ == "__main__": 85 | main() 86 | -------------------------------------------------------------------------------- /third_party/omnimatte/docs/data.md: -------------------------------------------------------------------------------- 1 | ## Data 2 | The data directory for a video is structured as follows: 3 | ``` 4 | video_name/ 5 | |-- rgb/ 6 | | |-- 0001.png, ... 7 | |-- mask/ 8 | |-- |-- 01, ... 9 | |-- |-- |-- 0001.png, ... 10 | |-- flow/ 11 | | |-- 0001.flo, ... 12 | |-- confidence/ 13 | | |-- 0001.png, ... 14 | |-- homographies.txt 15 | ``` 16 | The `mask/` directory should contain a subdirectory for each omnimatte's input masks. 17 | 18 | ### Camera registration 19 | The method requires as input homographies computed between frames (e.g. using OpenCV). 20 | 21 | See `datasets/tennis/homographies.txt` for an example. 22 | 23 | The expected format for `homographies.txt` is: 24 | ``` 25 | size: width height # dimensions of video 26 | bounds: x_min x_max y_min y_max # world bounds 27 | 1 0 0 0 1 0 0 0 1 # homography for frame 1 28 | ... # homography for frame 2, etc. 29 | ``` 30 | After computing the homographies and saving to a text file, 31 | the helper script `datasets/homography.py` can be used to compute the world bounds 32 | and add the first 2 lines expected in the `homographies.txt` file: 33 | ``` 34 | python datasets/homography.py --homography_path path_to_homographies.txt --width vid_width --height vid_height 35 | ``` 36 | This will output `path_to_homographies-final.txt`, which should be renamed to `video_name/homographies.txt`. -------------------------------------------------------------------------------- /third_party/omnimatte/environment.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: omnimatte 4 | channels: 5 | - pytorch 6 | - defaults 7 | dependencies: 8 | - python=3.8 9 | - pytorch=1.8.0 10 | - pip: 11 | - dominate==2.4.0 12 | - torchvision 13 | - Pillow>=6.1.0 14 | - numpy==1.19.2 15 | - visdom==0.1.8 16 | - opencv-python==4.5.1 17 | - matplotlib 18 | 19 | -------------------------------------------------------------------------------- /third_party/omnimatte/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/options/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/options/test_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .base_options import BaseOptions 4 | 5 | 6 | class TestOptions(BaseOptions): 7 | """This class includes test options. 8 | 9 | It also includes shared options defined in BaseOptions. 10 | """ 11 | 12 | def initialize(self, parser): 13 | parser = BaseOptions.initialize(self, parser) # define shared options 14 | parser.add_argument('--results_dir', type=str, default='./results/', help='saves results here.') 15 | parser.add_argument('--aspect_ratio', type=float, default=1.0, help='aspect ratio of result images') 16 | parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc') 17 | parser.add_argument('--num_test', type=int, default=float("inf"), help='how many test images to run') 18 | parser.add_argument('--use_eval', action='store_true', help='Call model.eval() before test') 19 | self.isTrain = False 20 | return parser 21 | -------------------------------------------------------------------------------- /third_party/omnimatte/options/train_options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from .base_options import BaseOptions 4 | 5 | 6 | class TrainOptions(BaseOptions): 7 | """This class includes training options. 8 | 9 | It also includes shared options defined in BaseOptions. 10 | """ 11 | 12 | def initialize(self, parser): 13 | parser = BaseOptions.initialize(self, parser) 14 | # visdom and HTML visualization parameters 15 | parser.add_argument('--display_freq', type=int, default=20, help='frequency of showing training results on screen (in epochs)') 16 | parser.add_argument('--display_ncols', type=int, default=0, help='if positive, display all images in a single visdom web panel with certain number of images per row.') 17 | parser.add_argument('--display_id', type=int, default=1, help='window id of the web display') 18 | parser.add_argument('--display_server', type=str, default="http://localhost", help='visdom server of the web display') 19 | parser.add_argument('--display_env', type=str, default='main', help='visdom display environment name (default is "main")') 20 | parser.add_argument('--display_port', type=int, default=8097, help='visdom port of the web display') 21 | parser.add_argument('--update_html_freq', type=int, default=50, help='frequency of saving training results to html') 22 | parser.add_argument('--print_freq', type=int, default=10, help='frequency of showing training results on console (in steps per epoch)') 23 | parser.add_argument('--no_html', action='store_true', help='do not save intermediate training results to [opt.checkpoints_dir]/[opt.name]/web/') 24 | # network saving and loading parameters 25 | parser.add_argument('--save_latest_freq', type=int, default=50, help='frequency of saving the latest results (in epochs)') 26 | parser.add_argument('--save_by_epoch', action='store_true', help='whether saves model as "epoch" or "latest" (overwrites previous)') 27 | parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model') 28 | parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by , +, ...') 29 | parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc') 30 | # training parameters 31 | parser.add_argument('--n_steps', type=int, default=12000, help='number of training steps with the initial learning rate') 32 | parser.add_argument('--n_steps_decay', type=int, default=0, help='number of steps to linearly decay learning rate to zero') 33 | parser.add_argument('--lr', type=float, default=0.001, help='initial learning rate for adam') 34 | parser.add_argument('--lr_policy', type=str, default='linear', help='learning rate policy. [linear | step | plateau | cosine]') 35 | parser.add_argument('--lr_decay_iters', type=int, default=0) 36 | 37 | self.isTrain = True 38 | return parser 39 | -------------------------------------------------------------------------------- /third_party/omnimatte/requirements.txt: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | torch==1.8.0 4 | torchvision 5 | dominate>=2.4.0 6 | visdom>=0.1.8 7 | matplotlib>=3.2.1 8 | opencv-python==4.5.1 9 | Pillow>=6.1.0 10 | numpy>=1.19.2 11 | 12 | -------------------------------------------------------------------------------- /third_party/omnimatte/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """Script to save the full outputs of an Omnimatte model. 4 | 5 | Once you have trained the Omnimatte model with train.py, you can use this script to save the model's final omnimattes. 6 | It will load a saved model from '--checkpoints_dir' and save the results to '--results_dir'. 7 | 8 | It first creates a model and dataset given the options. It will hard-code some parameters. 9 | It then runs inference for '--num_test' images and save results to an HTML file. 10 | 11 | Example (after training a model): 12 | python test.py --dataroot ./datasets/tennis --name tennis 13 | 14 | Use '--results_dir ' to specify the results directory. 15 | 16 | See options/base_options.py and options/test_options.py for more test options. 17 | """ 18 | import os 19 | from options.test_options import TestOptions 20 | from third_party.data import create_dataset 21 | from third_party.models import create_model 22 | from third_party.util.visualizer import save_images, save_videos 23 | from third_party.util import html 24 | import torch 25 | 26 | 27 | if __name__ == '__main__': 28 | testopt = TestOptions() 29 | opt = testopt.parse() 30 | # hard-code some parameters for test 31 | opt.num_threads = 0 # test code only supports num_threads = 0 32 | opt.batch_size = 1 # test code only supports batch_size = 1 33 | opt.serial_batches = True # disable data shuffling; comment this line if results on randomly chosen images are needed. 34 | opt.display_id = -1 # no visdom display; the test code saves the results to a HTML file. 35 | dataset = create_dataset(opt) # create a dataset given opt.dataset_mode and other options 36 | model = create_model(opt) # create a model given opt.model and other options 37 | model.setup(opt) # regular setup: load and print networks; create schedulers 38 | if opt.use_eval: 39 | model.eval() 40 | # create a website 41 | web_dir = os.path.join(opt.results_dir, opt.name, '{}_{}'.format(opt.phase, opt.epoch)) # define the website directory 42 | print('creating web directory', web_dir) 43 | webpage = html.HTML(web_dir, 'Experiment = %s, Phase = %s, Epoch = %s' % (opt.name, opt.phase, opt.epoch)) 44 | video_visuals = None 45 | for i, data in enumerate(dataset): 46 | if i >= opt.num_test: # only apply our model to opt.num_test images. 47 | break 48 | model.set_input(data) # unpack data from data loader 49 | model.test() # run inference 50 | img_path = model.get_image_paths() # get image paths 51 | if i % 5 == 0: # save images to an HTML file 52 | print('processing (%04d)-th image... %s' % (i, img_path)) 53 | visuals = model.get_results() # rgba, reconstruction, original, mask 54 | if video_visuals is None: 55 | video_visuals = visuals 56 | else: 57 | for k in video_visuals: 58 | video_visuals[k] = torch.cat((video_visuals[k], visuals[k])) 59 | rgba = { k: visuals[k] for k in visuals if 'rgba' in k } 60 | # save RGBA layers 61 | save_images(webpage, rgba, img_path, aspect_ratio=opt.aspect_ratio, width=opt.display_winsize) 62 | save_videos(webpage, video_visuals, width=opt.display_winsize) 63 | webpage.save() # save the HTML of videos 64 | -------------------------------------------------------------------------------- /third_party/omnimatte/third_party/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | -------------------------------------------------------------------------------- /third_party/omnimatte/third_party/data/image_folder.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """A modified image folder class 4 | 5 | We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py) 6 | so that this class can load images from both current directory and its subdirectories. 7 | """ 8 | 9 | import torch.utils.data as data 10 | 11 | from PIL import Image 12 | import os 13 | 14 | IMG_EXTENSIONS = [ 15 | '.jpg', '.JPG', '.jpeg', '.JPEG', 16 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 17 | '.tif', '.TIF', '.tiff', '.TIFF', 18 | ] 19 | 20 | 21 | def is_image_file(filename): 22 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 23 | 24 | 25 | def make_dataset(dir, max_dataset_size=float("inf")): 26 | images = [] 27 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 28 | 29 | for root, _, fnames in sorted(os.walk(dir)): 30 | for fname in fnames: 31 | if is_image_file(fname): 32 | path = os.path.join(root, fname) 33 | images.append(path) 34 | images = sorted(images) 35 | return images[:min(max_dataset_size, len(images))] 36 | 37 | 38 | def default_loader(path): 39 | return Image.open(path).convert('RGB') 40 | 41 | 42 | class ImageFolder(data.Dataset): 43 | 44 | def __init__(self, root, transform=None, return_paths=False, 45 | loader=default_loader): 46 | imgs = make_dataset(root) 47 | if len(imgs) == 0: 48 | raise(RuntimeError("Found 0 images in: " + root + "\n" 49 | "Supported image extensions are: " + ",".join(IMG_EXTENSIONS))) 50 | 51 | self.root = root 52 | self.imgs = imgs 53 | self.transform = transform 54 | self.return_paths = return_paths 55 | self.loader = loader 56 | 57 | def __getitem__(self, index): 58 | path = self.imgs[index] 59 | img = self.loader(path) 60 | if self.transform is not None: 61 | img = self.transform(img) 62 | if self.return_paths: 63 | return img, path 64 | else: 65 | return img 66 | 67 | def __len__(self): 68 | return len(self.imgs) 69 | -------------------------------------------------------------------------------- /third_party/omnimatte/third_party/models/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """This package contains modules related to objective functions, optimizations, and network architectures. 4 | 5 | To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel. 6 | You need to implement the following five functions: 7 | -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt). 8 | -- : unpack data from dataset and apply preprocessing. 9 | -- : produce intermediate results. 10 | -- : calculate loss, gradients, and update network weights. 11 | -- : (optionally) add model-specific options and set default options. 12 | 13 | In the function <__init__>, you need to define four lists: 14 | -- self.loss_names (str list): specify the training losses that you want to plot and save. 15 | -- self.model_names (str list): define networks used in our training. 16 | -- self.visual_names (str list): specify the images that you want to display and save. 17 | -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage. 18 | 19 | Now you can use the model class by specifying flag '--model dummy'. 20 | See our template model class 'template_model.py' for more details. 21 | """ 22 | 23 | import importlib 24 | from .base_model import BaseModel 25 | 26 | 27 | def find_model_using_name(model_name): 28 | """Import the module "models/[model_name]_model.py". 29 | 30 | In the file, the class called DatasetNameModel() will 31 | be instantiated. It has to be a subclass of BaseModel, 32 | and it is case-insensitive. 33 | """ 34 | model_filename = "third_party.omnimatte.models." + model_name + "_model" 35 | modellib = importlib.import_module(model_filename) 36 | model = None 37 | target_model_name = model_name.replace('_', '') + 'model' 38 | for name, cls in modellib.__dict__.items(): 39 | if name.lower() == target_model_name.lower() \ 40 | and issubclass(cls, BaseModel): 41 | model = cls 42 | 43 | if model is None: 44 | print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name)) 45 | exit(0) 46 | 47 | return model 48 | 49 | 50 | def get_option_setter(model_name): 51 | """Return the static method of the model class.""" 52 | model_class = find_model_using_name(model_name) 53 | return model_class.modify_commandline_options 54 | 55 | 56 | def create_model(opt): 57 | """Create a model given the option. 58 | 59 | This function warps the class CustomDatasetDataLoader. 60 | This is the main interface between this package and 'train.py'/'test.py' 61 | 62 | Example: 63 | >>> from models import create_model 64 | >>> model = create_model(opt) 65 | """ 66 | model = find_model_using_name(opt.model) 67 | instance = model(opt) 68 | print("model [%s] was created" % type(instance).__name__) 69 | return instance 70 | -------------------------------------------------------------------------------- /third_party/omnimatte/third_party/models/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch.optim import lr_scheduler 6 | 7 | 8 | ############################################################################### 9 | # Helper Functions 10 | ############################################################################### 11 | def get_scheduler(optimizer, opt): 12 | """Return a learning rate scheduler 13 | 14 | Parameters: 15 | optimizer -- the optimizer of the network 16 | opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions.  17 | opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine 18 | 19 | For 'linear', we keep the same learning rate for the first epochs 20 | and linearly decay the rate to zero over the next epochs. 21 | For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers. 22 | See https://pytorch.org/docs/stable/optim.html for more details. 23 | """ 24 | if opt.lr_policy == 'linear': 25 | def lambda_rule(epoch): 26 | lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs_decay + 1) 27 | return lr_l 28 | 29 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule) 30 | elif opt.lr_policy == 'step': 31 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.1) 32 | elif opt.lr_policy == 'plateau': 33 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5) 34 | elif opt.lr_policy == 'cosine': 35 | scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0) 36 | else: 37 | return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy) 38 | return scheduler 39 | 40 | 41 | def init_net(net, gpu_ids=[]): 42 | """Initialize a network by registering CPU/GPU device (with multi-GPU support) 43 | Parameters: 44 | net (network) -- the network to be initialized 45 | gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 46 | 47 | Return an initialized network. 48 | """ 49 | if len(gpu_ids) > 0: 50 | assert (torch.cuda.is_available()) 51 | net.to(gpu_ids[0]) 52 | net = torch.nn.DataParallel(net, gpu_ids) # multi-GPUs 53 | return net 54 | -------------------------------------------------------------------------------- /third_party/omnimatte/third_party/util/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | """This package includes a miscellaneous collection of useful helper functions.""" 4 | -------------------------------------------------------------------------------- /tools/blender_to_matting.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import json 4 | import math 5 | import shutil 6 | from argparse import ArgumentParser 7 | from pathlib import Path 8 | 9 | 10 | def main(): 11 | parser = ArgumentParser() 12 | parser.add_argument("folder", help="Folder containing Blender render outputs") 13 | 14 | args = parser.parse_args() 15 | root = Path(args.folder) 16 | 17 | folders = sorted(root.glob("frame_*")) 18 | N = len(folders) 19 | assert N > 0, "No frame is found!" 20 | print(f"Found {N} frames") 21 | 22 | rgb_folder = root / "rgb_1x" 23 | rgb_folder.mkdir(parents=True, exist_ok=True) 24 | 25 | # move image files 26 | for i in range(N): 27 | src = folders[i] / "camera_0000.png" 28 | dst = rgb_folder / f"{i:04d}.png" 29 | if dst.exists(): 30 | continue 31 | shutil.move(src, dst) 32 | print("Moved all frames") 33 | 34 | frames = [] 35 | # load data 36 | for i in range(N): 37 | with open(folders[i] / "camera_0000.json", "r", encoding="utf-8") as f: 38 | cam = json.load(f) 39 | fov = 2 * math.atan2(1, 2 * cam["normalized_focal_length_x"]) 40 | frames.append({ 41 | "file_path": f"rgb_1x/{i:04d}", 42 | "transform_matrix": cam["camera_to_world"], 43 | "world_to_camera": cam["world_to_camera"], 44 | "near": cam["near_clip"], 45 | "far": cam["far_clip"], 46 | }) 47 | data = { 48 | "camera_angle_x": fov, 49 | "frames": frames, 50 | } 51 | with open(root / "poses.json", "w", encoding="utf-8") as f: 52 | json.dump(data, f, ensure_ascii=False, indent=2, sort_keys=True) 53 | print(f"Wrote data, fov = {fov}") 54 | 55 | ans = input("Delete source folders? (y,N) ") 56 | if ans.lower() == "y": 57 | for i in range(N): 58 | shutil.rmtree(folders[i]) 59 | print("Deleted source folders") 60 | 61 | 62 | if __name__ == "__main__": 63 | main() 64 | -------------------------------------------------------------------------------- /tools/nerfies_to_blender.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import json 4 | import math 5 | from argparse import ArgumentParser 6 | from pathlib import Path 7 | import os 8 | 9 | import numpy as np 10 | 11 | from utils.json_utils import write_json 12 | 13 | 14 | def load(path): 15 | with open(path, "r") as f: 16 | return json.load(f) 17 | 18 | 19 | parser = ArgumentParser() 20 | parser.add_argument("input", help="Path to nerfies video folder") 21 | parser.add_argument("output", help="Path to matting video folder") 22 | args = parser.parse_args() 23 | 24 | src_root = Path(args.input) 25 | if src_root.name == "freeze-test": 26 | scene = load(src_root / ".." / "scene_gt.json") 27 | else: 28 | scene = load(src_root / "scene.json") 29 | 30 | scale = scene["scale"] 31 | near = scene["near"] / scale 32 | far = scene["far"] / scale 33 | 34 | src_cam_folder = src_root / "camera" 35 | if not os.path.isdir(src_cam_folder): 36 | src_cam_folder = src_root / "camera-gt" 37 | 38 | src_cam_files = sorted(src_cam_folder.glob("*.json")) 39 | src_cams = [load(f) for f in src_cam_files] 40 | 41 | cam = src_cams[0] 42 | fov = 2 * math.atan2(1, 2 * cam["focal_length"] / cam["image_size"][0]) 43 | print(f"fov = {fov}") 44 | 45 | image_files = sorted((Path(args.output) / "rgb_1x").glob("*.png")) 46 | 47 | data = { 48 | "frames": [], 49 | "camera_angle_x": fov, 50 | } 51 | frames = data["frames"] 52 | 53 | 54 | for i, cam in enumerate(src_cams): 55 | mat = np.eye(4) 56 | mat[:3, :3] = np.array(cam["orientation"]).T 57 | mat[:3, 1:3] *= -1 58 | mat[:3, 3:] = np.array(cam["position"]).reshape(3, 1) 59 | 60 | frames.append({ 61 | "file_path": f"rgb_1x/{image_files[i].stem}", 62 | "transform_matrix": mat.tolist(), 63 | "near": near, 64 | "far": far, 65 | }) 66 | 67 | write_json(Path(args.output) / "poses.json", data) 68 | -------------------------------------------------------------------------------- /tools/simple_video.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | import subprocess 5 | from argparse import ArgumentParser 6 | from pathlib import Path 7 | 8 | parser = ArgumentParser() 9 | parser.add_argument("input") 10 | parser.add_argument("-o", "--output") 11 | parser.add_argument("-s", "--scale", type=float, default=0) 12 | args = parser.parse_args() 13 | 14 | folder = Path(args.input) 15 | if not os.path.isdir(folder): 16 | print(f"Not found: {folder}") 17 | exit(1) 18 | 19 | in_files = sorted(list(folder.glob("*.png")) + list(folder.glob("*.jpg"))) 20 | if len(in_files) == 0: 21 | print("No input file exists") 22 | exit(1) 23 | 24 | file = in_files[0] 25 | prefix_idx = file.name.find("0") 26 | prefix = "" if prefix_idx <= 0 else file.name[:prefix_idx] 27 | start_number = int(file.stem[len(prefix):]) 28 | 29 | name_length = len(file.stem) - len(prefix) 30 | name_ext = os.path.splitext(file.name)[1] 31 | 32 | output = args.output 33 | if output is None: 34 | output = folder.name + ".mp4" 35 | output = Path(output) 36 | assert output.suffix in { ".mp4", ".webm", ".mov" } 37 | 38 | filters = ["pad=ceil(iw/2)*2:ceil(ih/2)*2"] 39 | if args.scale > 0: 40 | filters = [f"scale=iw*{args.scale}:-2"] + filters 41 | if output.suffix == ".mp4": 42 | filters += ["format=yuv420p"] 43 | encoder_args = ["-crf", 17, "-preset", "veryslow"] 44 | elif output.suffix == ".webm": 45 | encoder_args = ["-c:v", "libvpx-vp9", "-crf", 17, "-b:v", 0] 46 | elif output.suffix == ".mov": 47 | encoder_args = ["-c:v", "prores_ks", "-profile:v", 4, "-pix_fmt", "yuva444p10le"] 48 | 49 | ffmpeg_args = [ 50 | "ffmpeg", 51 | "-y", 52 | "-r", 53 | 10, 54 | "-start_number", 55 | start_number, 56 | "-i", 57 | folder / f"{prefix}%0{name_length}d{name_ext}", 58 | "-vf", 59 | ",".join(filters), 60 | *encoder_args, 61 | "-r", 62 | 30, 63 | output 64 | ] 65 | ffmpeg_args = [str(v) for v in ffmpeg_args] 66 | 67 | print(" ".join(ffmpeg_args)) 68 | subprocess.call(ffmpeg_args) 69 | -------------------------------------------------------------------------------- /ui/cli.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import logging 4 | import sys 5 | from argparse import ArgumentParser 6 | from inspect import Parameter, signature 7 | 8 | from ui.commands import all_cli_commands 9 | from ui.common import create_data_manager 10 | from ui.data_manager import DataManager 11 | 12 | 13 | def main(): 14 | argv = sys.argv 15 | 16 | logging.basicConfig( 17 | format="[%(asctime)s][%(name)s][%(levelname)s] - %(message)s", 18 | handlers=[ 19 | logging.StreamHandler(), 20 | ], 21 | level=logging.INFO, 22 | ) 23 | 24 | if len(argv) < 2 or argv[1] not in all_cli_commands: 25 | cmd_names = sorted(all_cli_commands.keys()) 26 | print(f"Usage: {argv[0]} command [...args]") 27 | print(f"Commands:") 28 | print("\n".join([" " + cmd for cmd in cmd_names])) 29 | exit(1) 30 | 31 | func = all_cli_commands[argv[1]] 32 | sig = signature(func) 33 | dm_param = None 34 | has_extra_args = False 35 | 36 | parser = ArgumentParser(argv[1]) 37 | for param in sig.parameters.values(): 38 | if param.annotation == DataManager: 39 | dm_param = param.name 40 | continue 41 | if param.name == "extra_args": 42 | has_extra_args = True 43 | continue 44 | if param.annotation == list[str]: 45 | parser.add_argument(f"--{param.name}", type=str, default=param.default, action="append") 46 | continue 47 | 48 | if param.annotation == bool: 49 | assert param.default in { True, False }, f"bool param ({param.name}) must have default" 50 | if not param.default: 51 | parser.add_argument(f"--{param.name}", action="store_true") 52 | else: 53 | parser.add_argument(f"--no_{param.name}", dest=param.name, action="store_false") 54 | continue 55 | 56 | if param.default != Parameter.empty: 57 | parser.add_argument(f"--{param.name}", type=param.annotation, default=param.default) 58 | else: 59 | parser.add_argument(f"{param.name}", type=param.annotation) 60 | 61 | argv = argv[2:] 62 | extra_args = [] 63 | if has_extra_args and "--" in argv: 64 | split_idx = argv.index("--") 65 | extra_args = argv[split_idx+1:] 66 | argv = argv[:split_idx] 67 | 68 | args = parser.parse_args(argv) 69 | args = vars(args) 70 | if dm_param is not None: 71 | dm = create_data_manager() 72 | args[dm_param] = dm 73 | if has_extra_args: 74 | args["extra_args"] = extra_args 75 | 76 | func(**args) 77 | 78 | 79 | if __name__ == "__main__": 80 | main() 81 | -------------------------------------------------------------------------------- /ui/common.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from ui.data_manager import DataManager 4 | from pathlib import Path 5 | 6 | dm_config_file = Path(__file__).parent.parent / "data_manager.json" 7 | 8 | 9 | def create_data_manager() -> DataManager: 10 | if not dm_config_file.exists(): 11 | raise ValueError(f"Config file {dm_config_file} does not exist. Create one by copying data_manager_example.json.") 12 | 13 | return DataManager(dm_config_file) 14 | -------------------------------------------------------------------------------- /ui/data_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from dataclasses import dataclass, field 4 | from typing import List, Optional 5 | 6 | from dataclasses_json import dataclass_json 7 | 8 | 9 | @dataclass_json 10 | @dataclass 11 | class ExperimentAvailability: 12 | checkpoints: List[str] = field(default_factory=list) 13 | evals: List[str] = field(default_factory=list) 14 | other_artifacts: List[str] = field(default_factory=list) 15 | 16 | def __str__(self) -> str: 17 | data = [ 18 | f"{k}({', '.join(self.__dict__[k])})" 19 | for k in ["checkpoints", "evals", "other_artifacts"] 20 | if len(self.__dict__[k]) > 0 21 | ] 22 | return f"Experiment[{', '.join(data)}]" 23 | 24 | 25 | @dataclass_json 26 | @dataclass 27 | class DataAvailability: 28 | images: bool = False 29 | poses: List[str] = field(default_factory=list) 30 | masks: List[str] = field(default_factory=list) 31 | flow: bool = False 32 | depth: bool = False 33 | homography: bool = False 34 | segmentation: bool = False 35 | other_artifacts: List[str] = field(default_factory=list) 36 | other_formats: List[str] = field(default_factory=list) 37 | 38 | def __str__(self) -> str: 39 | data = [k for k in ["images", "flow", "depth", "homography", "segmentation"] if self.__dict__[k]] 40 | data += [ 41 | f"{k}({', '.join(self.__dict__[k])})" 42 | for k in ["poses", "masks", "other_artifacts", "other_formats"] 43 | if len(self.__dict__[k]) > 0 44 | ] 45 | return f"Data[{', '.join(data)}]" 46 | 47 | 48 | @dataclass_json 49 | @dataclass 50 | class Experiment: 51 | category: str 52 | video: str 53 | method: str 54 | name: str 55 | notes: str 56 | local: ExperimentAvailability = field(default_factory=ExperimentAvailability) 57 | remote: ExperimentAvailability = field(default_factory=ExperimentAvailability) 58 | 59 | 60 | @dataclass_json 61 | @dataclass 62 | class Dataset: 63 | category: str 64 | video: str 65 | local: DataAvailability = field(default_factory=DataAvailability) 66 | remote: DataAvailability = field(default_factory=DataAvailability) 67 | 68 | 69 | @dataclass_json 70 | @dataclass 71 | class RemoteConfig: 72 | endpoint: str 73 | bucket: str 74 | access_key: str 75 | secret_key: str 76 | 77 | 78 | @dataclass_json 79 | @dataclass 80 | class LocalConfig: 81 | data_root: str 82 | output_root: str 83 | training_folder: str = field(default="train") 84 | 85 | 86 | @dataclass_json 87 | @dataclass 88 | class DataManagerConfig: 89 | local: LocalConfig 90 | remote: Optional[RemoteConfig] 91 | -------------------------------------------------------------------------------- /utils/array_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import numpy as np 4 | 5 | 6 | def log_stats(logger_func, name: str, array: np.ndarray): 7 | """Log the statistics of an array""" 8 | keys = ["min", "max", "mean", "median"] 9 | values = [func(array) for func in [np.min, np.max, np.mean, np.median]] 10 | 11 | percentiles = [0.1, 1, 5, 10, 25, 50, 75, 90, 95, 99, 99.9] 12 | keys += [str(v) for v in percentiles] 13 | values += list(np.percentile(values, percentiles)) 14 | 15 | length = max([len(k) for k in keys]) 16 | message = [ 17 | f"stats of {name}:", 18 | *(f"{key:>{length}} {value}" for key, value in zip(keys, values)), 19 | ] 20 | 21 | logger_func("\n".join(message)) 22 | -------------------------------------------------------------------------------- /utils/dict_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import Any, Dict 4 | 5 | 6 | def inject_dict(cfg: Dict[str, Any], injection: Dict[str, Any]): 7 | """Merge item from injection if the key exists in cfg and value is None. 8 | 9 | Note: this is different from dictionary merging. 10 | Returns: input cfg object 11 | """ 12 | for k, v in injection.items(): 13 | if k in cfg and cfg[k] is None: 14 | cfg[k] = v 15 | return cfg 16 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import os 4 | 5 | import cv2 6 | import numpy as np 7 | from PIL import Image 8 | 9 | ckbd_cache = {} 10 | 11 | 12 | def checkerboard(H, W): 13 | key = f"{H}x{W}" 14 | if key in ckbd_cache: 15 | return ckbd_cache[key] 16 | 17 | board = np.ones([H, W, 3]) * 0.8 18 | sz = H // 16 19 | for y in range(0, H, sz): 20 | for x in range(0, W, sz): 21 | if (x//sz+y//sz) % 2 == 0: 22 | continue 23 | board[y:y+sz, x:x+sz] = 0.6 24 | ckbd_cache[key] = board 25 | return board 26 | 27 | 28 | def read_image(fp, scale: float = 1, resample: int = Image.BILINEAR) -> Image.Image: 29 | img = Image.open(fp) 30 | if scale == 1: 31 | return img 32 | 33 | w = int(np.round(img.width * scale)) 34 | h = int(np.round(img.height * scale)) 35 | return img.resize((w, h), resample=resample) 36 | 37 | 38 | def read_image_np(fp, scale: float = 1, resample: int = Image.BILINEAR) -> np.ndarray: 39 | img = read_image(fp, scale, resample) 40 | return np.array(img, dtype=np.float32) / 255 41 | 42 | 43 | def save_image(path, image: Image.Image): 44 | ext = os.path.splitext(path)[1] 45 | if ext == ".jpg": 46 | params = {"quality": 95} 47 | elif ext == ".png": 48 | params = {"optimize": True} 49 | elif ext == ".webp": 50 | params = {"quality": 95, "method": 6} 51 | 52 | image.save(path, **params) 53 | 54 | 55 | def save_image_np(path, data: np.ndarray): 56 | data = np.clip(data * 255, 0, 255).astype(np.uint8) 57 | image = Image.fromarray(data) 58 | save_image(path, image) 59 | 60 | 61 | def normalize_array(array: np.ndarray, pmin=0, pmax=100, gamma=1) -> np.ndarray: 62 | dmin, dmax = np.percentile(array, [pmin, pmax]) 63 | array = np.clip((array - dmin) / (dmax - dmin), 0, 1) 64 | if gamma != 1: 65 | array = np.power(array, gamma) 66 | return array 67 | 68 | 69 | def visualize_array(array, color_map=cv2.COLORMAP_JET): 70 | """ 71 | array: [H, W], values in range [0, 1] 72 | color_map: cv2.COLORMAP_* 73 | """ 74 | # NOTE: casting float array with NaNs to uint8 results in 0s 75 | scaled = (array**0.5 * 255).astype(np.uint8) 76 | colored = (cv2.applyColorMap(scaled, color_map) / 255) ** 2.2 * 255 77 | converted = cv2.cvtColor(colored.astype(np.float32), cv2.COLOR_BGR2RGB) / 255 78 | 79 | return converted 80 | -------------------------------------------------------------------------------- /utils/io_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import itertools 4 | import os 5 | from pathlib import Path 6 | from typing import List, Union 7 | 8 | 9 | def mkdir(path, exist_ok=True) -> Path: 10 | """Create the directory and returns a Path object.""" 11 | path = Path(path) 12 | os.makedirs(path, exist_ok=exist_ok) 13 | return path 14 | 15 | 16 | def multi_glob_sorted(path: Union[str, Path], appendices: Union[str, List[str]]) -> List[Path]: 17 | """List files in directory if its extension is in provided list. 18 | 19 | Returns: sorted Path objects 20 | """ 21 | if not isinstance(appendices, list): 22 | appendices = [appendices] 23 | 24 | path = Path(path) 25 | return sorted(itertools.chain(*[path.glob(app) for app in appendices])) 26 | 27 | 28 | def filter_dirs(paths: List[Path]) -> List[Path]: 29 | return [d for d in paths if os.path.isdir(d)] 30 | -------------------------------------------------------------------------------- /utils/json_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import json 4 | import os 5 | 6 | 7 | def write_json(file, obj): 8 | text = json.dumps(obj, ensure_ascii=False, indent=2, sort_keys=True) 9 | 10 | os.makedirs(os.path.split(file)[0], exist_ok=True) 11 | with open(file, "w", encoding="utf-8") as f: 12 | f.write(text + "\n") 13 | 14 | 15 | def read_json(file, default_factory): 16 | if not os.path.isfile(file): 17 | return default_factory() 18 | with open(file, "r", encoding="utf-8") as f: 19 | return json.load(f) 20 | 21 | 22 | def write_data_json(file, obj, cls): 23 | """Serialize an object of a dataclass_json class""" 24 | os.makedirs(os.path.split(file)[0], exist_ok=True) 25 | text = cls.schema().dumps( 26 | obj, many=isinstance(obj, list), indent=2, ensure_ascii=False, sort_keys=True 27 | ) 28 | with open(file, "w", encoding="utf-8") as f: 29 | f.write(text) 30 | 31 | 32 | def read_data_json(file, cls): 33 | with open(file, "r", encoding="utf-8") as f: 34 | text = f.read() 35 | return cls.schema().loads(text) 36 | -------------------------------------------------------------------------------- /utils/render_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | from typing import List, Tuple 4 | 5 | import numpy as np 6 | import torch 7 | from numpy import ndarray 8 | from torch import Tensor 9 | 10 | 11 | def get_coords(H: int, W: int) -> Tensor: 12 | """Get a [H, W, 2] tensor of coordinates""" 13 | return torch.stack( 14 | torch.meshgrid( 15 | torch.arange(W), 16 | torch.arange(H), 17 | indexing="xy", 18 | ), 19 | dim=-1, 20 | ) 21 | 22 | 23 | def get_rays(coords: Tensor, K: np.ndarray, c2w: Tensor) -> Tuple[Tensor, Tensor]: 24 | """Get ray origins and directions for coords 25 | 26 | coords: [*, 2] 27 | K: [3, 3] 28 | c2w: [3, 4] 29 | """ 30 | x, y = coords[..., 0], coords[..., 1] 31 | dirs = torch.stack( 32 | [(x - K[0, 2]) / K[0, 0], -(y - K[1, 2]) / K[1, 1], -torch.ones_like(x)], dim=-1 33 | ) 34 | rays_d = torch.sum(dirs[..., None, :] * c2w[:3, :3], dim=-1) 35 | rays_o = c2w[:3, -1].expand(rays_d.shape) 36 | return rays_o, rays_d 37 | 38 | 39 | def alpha_composite( 40 | alpha: Tensor, data: List[Tensor], bg_data: List[Tensor] 41 | ) -> List[Tensor]: 42 | """ 43 | alpha: [B, L, S, 1] in range [0, 1] 44 | data: list of [B, L, S, C] tensors of foreground (front to back) 45 | bg_data: list of [B, S, C] tensors of background (i.e. mulplied by cumulated (1-alpha)) 46 | 47 | return: list of [B, S, C] composited data 48 | """ 49 | L = alpha.shape[1] 50 | weights = torch.cat( 51 | [torch.ones_like(alpha[:, 0:1]), torch.cumprod(1 - alpha, dim=1)], dim=1 52 | ) # [B, L+1, S, 1] 53 | fg_weights = weights[:, :L] * alpha # [B, L, S, 1] 54 | bg_weights = weights[:, L] # [B, S, 1] 55 | 56 | return [ 57 | torch.sum(fg_weights * data[i], dim=1) + bg_weights * bg_data[i] 58 | for i in range(len(data)) 59 | ] 60 | 61 | 62 | def detail_transfer( 63 | target: ndarray, 64 | image: ndarray, 65 | rgba_layers: ndarray, 66 | ) -> ndarray: 67 | """ 68 | transfer residual to foreground layers 69 | 70 | target: [H, W, 3] 71 | image: [H, W, 3] 72 | rgba_layers: [L, H, W, 3] 73 | 74 | returns: a copy of rgba_layers with details added 75 | """ 76 | residual = target - image 77 | trans_comp = np.zeros_like(target[..., 0:1]) 78 | rgba_detail = rgba_layers.copy() 79 | n_layers = rgba_detail.shape[0] 80 | for i in range(n_layers): 81 | trans_i = 1 - trans_comp 82 | rgba_detail[i, ..., :3] += trans_i * residual 83 | alpha = rgba_detail[i, ..., 3:4] 84 | trans_comp = alpha + (1 - alpha) * trans_comp 85 | rgba_detail = np.clip(rgba_detail, 0, 1) 86 | return rgba_detail, trans_comp[..., 0] 87 | -------------------------------------------------------------------------------- /utils/string_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | def format_float(v: float) -> str: 4 | """ 5 | Display a float in scientific notation only if it's too small or large. 6 | Examples: P511286390 7 | """ 8 | if 1e-3 < abs(v) < 1e3: 9 | return f"{v:.5g}" 10 | return f"{v:.4e}" 11 | -------------------------------------------------------------------------------- /utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class Padder: 8 | def __init__(self, pad_to, H, W): 9 | self.pads = [ 10 | (pad_to - H % pad_to) % pad_to, 11 | (pad_to - W % pad_to) % pad_to, 12 | ] 13 | self.hw = [H, W] 14 | 15 | def pad(self, tensor): 16 | Hpad, Wpad = self.pads 17 | if Hpad == 0 and Wpad == 0: 18 | return tensor 19 | return F.pad( 20 | tensor, 21 | (Wpad // 2, Wpad - Wpad // 2, Hpad // 2, Hpad - Hpad // 2), 22 | "constant", 23 | ) 24 | 25 | def unpad(self, tensor): 26 | Hpad, Wpad = self.pads 27 | if Hpad == 0 and Wpad == 0: 28 | return tensor 29 | H, W = self.hw 30 | return tensor[ 31 | ..., Hpad // 2: Hpad // 2 + H, Wpad // 2: Wpad // 2 + W 32 | ].contiguous() 33 | 34 | 35 | class PositionalEncoder: 36 | 37 | def __init__( 38 | self, 39 | in_dims: int, 40 | num_freq: int, 41 | max_freq: int, 42 | include_input: bool = True, 43 | log_sampling: bool = True, 44 | ): 45 | self.in_dims = in_dims 46 | self.num_freq = num_freq 47 | self.max_freq = max_freq 48 | self.include_input = include_input 49 | self.log_sampling = log_sampling 50 | self.periodic_fns = [torch.sin, torch.cos] 51 | 52 | if log_sampling: 53 | self.freq_bands = 2 ** torch.linspace(0., max_freq, steps=num_freq) 54 | else: 55 | self.freq_bands = torch.linspace( 56 | 2. ** 0, 2. ** max_freq, steps=num_freq) 57 | 58 | @property 59 | def out_dims(self): 60 | dims = len(self.freq_bands) * len(self.periodic_fns) * self.in_dims 61 | if self.include_input: 62 | dims += self.in_dims 63 | return dims 64 | 65 | def __call__(self, x): 66 | x = x[..., :self.in_dims] 67 | encoding = [] 68 | if self.include_input: 69 | encoding.append(x) 70 | 71 | for freq in self.freq_bands: 72 | for p_fn in self.periodic_fns: 73 | encoding.append(p_fn(x * freq)) 74 | 75 | return torch.cat(encoding, dim=-1) 76 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/distortion.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | distortion: 4 | name: distortion 5 | config: 6 | inputs: [bg_weight, bg_z_vals] 7 | alpha: 0.02 8 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/recons.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | recons: 4 | name: mse 5 | config: 6 | inputs: [bg_rgb, bg_gt_rgb] 7 | alpha: 1.0 8 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/recons_coarse.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | recons_coarse: 4 | name: mse 5 | config: 6 | inputs: [bg_rgb_coarse, bg_gt_rgb] 7 | alpha: 1.0 8 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/recons_om.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | recons: 4 | name: l1 5 | config: 6 | inputs: [bg_rgb, bg_gt_rgb] 7 | alpha: 2.0 8 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/robust_depth_matching.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | robust_depth_matching: 4 | name: robust_depth_matching 5 | config: 6 | inputs: [bg_depths,gt_depths,depth_mask,global_step] 7 | alpha: 0.1 8 | start_step: 200 9 | -------------------------------------------------------------------------------- /workflows/config/bg_losses/tv_reg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | reg_tv_density: 4 | name: zero_reg 5 | config: 6 | inputs: [reg_tv_density] 7 | alpha: 1.0 8 | reg_tv_app: 9 | name: zero_reg 10 | config: 11 | inputs: [reg_tv_app] 12 | alpha: 1.0 13 | -------------------------------------------------------------------------------- /workflows/config/bg_model/dummy.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: dummy 4 | config: {} 5 | train: false 6 | optim: 7 | lr: 0.001 8 | -------------------------------------------------------------------------------- /workflows/config/bg_model/tensorf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: tensorf 4 | config: 5 | N_voxel_init: 2097156 6 | N_voxel_final: 262144000 7 | upsamp_list: [2000,3000,4000,5500] 8 | update_AlphaMask_list: [2500] 9 | density_n_comp: [16,4,4] 10 | appearance_n_comp: [48,12,12] 11 | app_dim: 27 12 | shadingMode: MLP_Fea 13 | alphaMask_thres: 0.0001 14 | density_shift: -10 15 | distance_scale: 25 16 | pos_pe: 6 17 | view_pe: 2 18 | fea_pe: 2 19 | featureC: 128 20 | step_ratio: 0.5 21 | fea2denseAct: relu 22 | nSamples: 1000000 23 | lr_upsample_reset: true 24 | viewpe_skip_steps: 0 25 | # dep injection 26 | contraction: 27 | prev_global_step: 28 | global_step_offset: 29 | aabb: 30 | near: 31 | far: 32 | hwf: 33 | device: 34 | train: true 35 | optim: 36 | lr: 0.02 37 | lr_basis: 0.001 38 | -------------------------------------------------------------------------------- /workflows/config/data_sources/blender.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | blender_camera: 4 | subpath: poses.json 5 | near_default: 2 6 | far_default: 6 7 | contraction: 8 | compute_ndc_aabb: true 9 | scene_scale: 1 10 | process_poses: true 11 | -------------------------------------------------------------------------------- /workflows/config/data_sources/colmap.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | llff_camera: 4 | subpath: colmap 5 | near_p: 1 6 | far_p: 99 7 | near_min: 0.01 8 | far_max: 1000 9 | contraction: 10 | compute_ndc_aabb: true 11 | scene_scale: 1 12 | -------------------------------------------------------------------------------- /workflows/config/data_sources/depths.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | depths: 4 | subpath: depth/depth 5 | -------------------------------------------------------------------------------- /workflows/config/data_sources/flow.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | flow: 4 | flow_path: flow/flow 5 | confidence_path: flow/confidence 6 | -------------------------------------------------------------------------------- /workflows/config/data_sources/homography.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | homography: 4 | subpath: homography 5 | -------------------------------------------------------------------------------- /workflows/config/data_sources/mask.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | mask: 4 | subpath: masks/mask 5 | trimap_width: 20 6 | bg_mask_erode_width: 0 7 | n_images: 0 8 | blank_layer_only: false 9 | extra_layers: 0 10 | indices: [] 11 | -------------------------------------------------------------------------------- /workflows/config/data_sources/nerfies_camera.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | nerfies_camera: 4 | subpath: nerfies 5 | contraction: 6 | compute_ndc_aabb: true 7 | -------------------------------------------------------------------------------- /workflows/config/data_sources/rgba_mask.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | rgba_mask: 4 | subpath: 5 | -------------------------------------------------------------------------------- /workflows/config/data_sources/rodynrf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | llff_camera: 4 | subpath: rodynrf 5 | near_p: 0 6 | far_p: 100 7 | near_min: 0.01 8 | far_max: 1000 9 | contraction: 10 | compute_ndc_aabb: true 11 | process_poses: false 12 | -------------------------------------------------------------------------------- /workflows/config/debug.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train_both 5 | - override validation: [] 6 | - _self_ 7 | 8 | validation: {} 9 | 10 | save_final_checkpoint: false 11 | 12 | n_steps: 10 13 | -------------------------------------------------------------------------------- /workflows/config/eval.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - trainer: matting 5 | - fg_model: dummy 6 | - bg_model: dummy 7 | - data_sources: [mask,flow,colmap] 8 | - _self_ 9 | 10 | output: ??? 11 | hydra: 12 | run: 13 | dir: ${output} 14 | 15 | dataset: 16 | path: 17 | scale: 18 | image_subpath: 19 | data_root: /data/matting/matting 20 | 21 | contraction: 22 | 23 | alpha_threshold: 0.5 24 | 25 | checkpoint: ??? 26 | device: cuda:0 27 | 28 | eval_bg_layer: false 29 | 30 | write_videos: true 31 | debug_count: 0 32 | raw_data_keys: [] 33 | raw_data_indices: [] 34 | 35 | # adding new output to existing evals 36 | migration_mode: 37 | 38 | # defaults to dataset.path basename 39 | dataset_name: 40 | # defaults to get /checkpoints/xxx.ckpt (checkpoint.parent.parent.name) 41 | experiment: 42 | # default loaded from checkpoint 43 | step: 44 | 45 | # defaults to /config.yaml 46 | train_config_file: 47 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/alpha_reg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | alpha_reg: 4 | name: alpha_reg 5 | config: 6 | inputs: [composite_alpha, global_step] 7 | alpha: 1 8 | lambda_alpha_l1: 0.01 9 | lambda_alpha_l0: 0.005 10 | l1_end_step: 1500 11 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/bg_distortion.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | bg_distortion: 4 | name: distortion 5 | config: 6 | inputs: [bg_weight, bg_z_vals] 7 | alpha: 0.02 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/bg_tv_reg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | reg_tv_density: 4 | name: zero_reg_optional 5 | config: 6 | alpha: 1.0 7 | optional_input: reg_tv_density 8 | reg_tv_app: 9 | name: zero_reg_optional 10 | config: 11 | alpha: 1.0 12 | optional_input: reg_tv_app 13 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/brightness_reg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | brightness_reg: 4 | name: l1 5 | config: 6 | inputs: [br_scale, br_target] 7 | alpha: 0.001 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/depth_matching.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | depth_matching: 4 | name: depth_matching 5 | config: 6 | inputs: [bg_depths,gt_depths,depth_mask,global_step] 7 | alpha: 0.1 8 | reg_alpha: 0.5 9 | reg_scales: 1 10 | start_step: 200 11 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/flow_recons.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | flow_recons: 4 | name: flow_recons 5 | config: 6 | inputs: [composite_flow, fg_gt_flow, fg_gt_flow_confidence] 7 | alpha: 1 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/mask.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | mask: 4 | name: mask_loss 5 | config: 6 | inputs: [alpha_layers, fg_gt_mask, global_step] 7 | alpha: 50.0 8 | reduce_threshold: 0.02 9 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/mean_flow_match.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | mean_flow_match: 4 | name: mean_flow_match 5 | config: 6 | inputs: [alpha_layers_extra, flow_mean_dist_map] 7 | alpha: 0.1 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/offset_reg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | offset_reg: 4 | name: l1 5 | config: 6 | inputs: [bg_offset, bg_offset_target] 7 | alpha: 0.001 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/recons.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | recons: 4 | name: l1 5 | config: 6 | inputs: [composite_rgb, fg_gt_rgb] 7 | alpha: 2.0 8 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/robust_depth_matching.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | robust_depth_matching: 4 | name: robust_depth_matching 5 | config: 6 | inputs: [bg_depths,gt_depths,depth_mask,global_step] 7 | alpha: 0.1 8 | start_step: 200 9 | -------------------------------------------------------------------------------- /workflows/config/fg_losses/warped_alpha.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | warped_alpha: 4 | name: l1 5 | config: 6 | inputs: [warped_alpha_layers, alpha_layers] 7 | alpha: 0.01 8 | -------------------------------------------------------------------------------- /workflows/config/fg_model/dummy.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: dummy 4 | config: {} 5 | train: false 6 | optim: 7 | lr: 0.001 8 | -------------------------------------------------------------------------------- /workflows/config/fg_model/omnimatte.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: omnimatte 4 | config: 5 | n_frames: 6 | image_hw: 7 | hidden_channels: 64 8 | network_normalization: batch 9 | max_frames: 200 10 | coarseness: 10 11 | feature_mode: xyt 12 | feature_config: 13 | pos_n_freq: 10 14 | save_feature_cache_to_disk: false 15 | feature_cache_device: 16 | train: true 17 | optim: 18 | lr: 0.001 19 | -------------------------------------------------------------------------------- /workflows/config/fg_model/omnimatte_noise.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: omnimatte 4 | config: 5 | n_frames: 6 | image_hw: 7 | homography_bounds: 8 | homography_size: 9 | network_normalization: batch 10 | max_frames: 200 11 | hidden_channels: 64 12 | coarseness: 10 13 | feature_mode: noise 14 | feature_config: 15 | channels: 13 16 | train: true 17 | optim: 18 | lr: 0.001 19 | -------------------------------------------------------------------------------- /workflows/config/profile.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train_both 5 | - override validation: [] 6 | - _self_ 7 | 8 | validation: {} 9 | 10 | trainer: 11 | config: 12 | img_batch_size: 4 13 | ray_batch_size: 1024 14 | 15 | save_final_checkpoint: false 16 | 17 | n_steps: 10 18 | profile: true 19 | -------------------------------------------------------------------------------- /workflows/config/train.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - fg_losses: [] 5 | - bg_losses: [] 6 | - validation: [training_dump] 7 | - fg_model: dummy 8 | - bg_model: dummy 9 | - trainer: matting 10 | - data_sources: [mask,flow,colmap] 11 | - _self_ 12 | 13 | output: ??? 14 | hydra: 15 | run: 16 | dir: ${output} 17 | 18 | fg_losses: {} 19 | bg_losses: {} 20 | 21 | dataset: 22 | path: ??? 23 | scale: 1 24 | image_subpath: rgb_1x 25 | 26 | contraction: 27 | 28 | scheduler: 29 | fg: 30 | name: exp_lr 31 | config: 32 | decay_start: 10000 33 | decay_rate: 0.1 34 | decay_steps: 10000 35 | min_rate: 0.1 36 | bg: 37 | name: exp_lr 38 | config: 39 | decay_start: 0 40 | decay_rate: 0.1 41 | decay_steps: 30000 42 | min_rate: 0.01 43 | 44 | save_checkpoint: 45 | step_size: 3000 46 | min_step: 9000 47 | folder: 48 | save_pretrain_checkpoint: false 49 | save_final_checkpoint: true 50 | 51 | checkpoint: 52 | load_fg: true 53 | load_bg: true 54 | reset_global_step: false 55 | reset_bg_optimization: false 56 | 57 | n_steps: ??? 58 | device: cuda:0 59 | debug: false 60 | profile: false 61 | seed: 3 62 | -------------------------------------------------------------------------------- /workflows/config/train_bg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train 5 | - bg_losses: [recons_om,tv_reg] 6 | - override bg_model: tensorf 7 | - override data_sources: [mask,colmap] 8 | - override validation: [training_dump] 9 | - _self_ 10 | 11 | n_steps: 15000 12 | 13 | save_checkpoint: 14 | step_size: 5000 15 | min_step: 10000 16 | folder: 17 | -------------------------------------------------------------------------------- /workflows/config/train_bg_rgba.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train_bg 5 | - override data_sources: [mask,rgba_mask,colmap] 6 | - _self_ 7 | 8 | data_sources: 9 | mask: 10 | blank_layer_only: true 11 | -------------------------------------------------------------------------------- /workflows/config/train_bg_rgba_blender.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train_bg 5 | - override data_sources: [mask,rgba_mask,blender] 6 | - _self_ 7 | 8 | data_sources: 9 | mask: 10 | blank_layer_only: true 11 | -------------------------------------------------------------------------------- /workflows/config/train_both.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train 5 | - fg_losses: [alpha_reg,brightness_reg,flow_recons,mask,recons,warped_alpha,bg_tv_reg] 6 | - bg_losses: [recons,tv_reg] 7 | - override fg_model: omnimatte 8 | - override bg_model: tensorf 9 | - _self_ 10 | 11 | n_steps: 15000 12 | -------------------------------------------------------------------------------- /workflows/config/train_both_davis.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train_both 5 | - _self_ 6 | 7 | n_steps: 10000 8 | 9 | save_checkpoint: 10 | step_size: 2000 11 | min_step: 6000 12 | -------------------------------------------------------------------------------- /workflows/config/train_om.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train 5 | - fg_losses: [alpha_reg,brightness_reg,flow_recons,mask,recons,warped_alpha] 6 | - override fg_model: omnimatte_noise 7 | - override trainer: omnimatte 8 | - override data_sources: [mask,flow,homography] 9 | - override validation: [] 10 | - _self_ 11 | 12 | validation: {} 13 | 14 | scheduler: 15 | fg: 16 | config: 17 | decay_start: 12000 18 | 19 | n_steps: 12000 20 | save_checkpoint: 21 | step_size: 3000 22 | min_step: 9000 23 | -------------------------------------------------------------------------------- /workflows/config/train_tf.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | defaults: 4 | - train 5 | - bg_losses: [recons,tv_reg] 6 | - override bg_model: tensorf 7 | - override data_sources: [mask,colmap] 8 | - override validation: [training_dump] 9 | - _self_ 10 | 11 | n_steps: 15000 12 | 13 | save_checkpoint: 14 | step_size: 5000 15 | min_step: 10000 16 | folder: 17 | -------------------------------------------------------------------------------- /workflows/config/trainer/matting.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: matting 4 | config: 5 | img_batch_size: 8 6 | ray_batch_size: 4096 7 | pbar_update_frequency: 10 8 | pretrain_bg_step: 0 9 | bg_composite_grad_step: 0 10 | fg_batch_size_fg: -1 11 | fg_batch_size_bg: -1 12 | fg_flow_mode: zeros 13 | depth_visualization_gamma: 1 14 | log_raw_mask_loss: true 15 | prerender_bg: false 16 | train_full_image: false 17 | fg_indexing_strategy: 18 | num_workers: 8 19 | # from dep injection 20 | output: 21 | writer_path: 22 | train_bg: 23 | train_fg: 24 | render_fg: 25 | prerender_bg_path: 26 | -------------------------------------------------------------------------------- /workflows/config/trainer/omnimatte.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | name: omnimatte 4 | config: 5 | img_batch_size: 24 6 | ray_batch_size: 0 7 | fg_batch_size_fg: 0 8 | fg_batch_size_bg: 0 9 | pbar_update_frequency: 10 10 | num_workers: 8 11 | # from dep injection 12 | output: 13 | writer_path: 14 | -------------------------------------------------------------------------------- /workflows/config/validation/render_all.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | render_all: 4 | config: 5 | name: render_all 6 | step_size: 3000 7 | use_upload_callback: true 8 | pre_train: false 9 | -------------------------------------------------------------------------------- /workflows/config/validation/render_all_bg.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | render_all: 4 | config: 5 | name: render_all 6 | step_size: 10000 7 | use_upload_callback: true 8 | pre_train: false 9 | -------------------------------------------------------------------------------- /workflows/config/validation/training_dump.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (c) Meta Platforms, Inc. and affiliates. 2 | 3 | training_dump: 4 | config: 5 | name: training_dump 6 | step_size: 1000 7 | n_frames: 1 8 | write_tensorboard: true 9 | pre_train: true 10 | --------------------------------------------------------------------------------