├── LICENSE ├── README.md ├── checkpoints ├── README.md └── big-lama-config.yaml ├── create_SDFT_pairs.py ├── demo ├── GS_depth_video.gif ├── GS_render_video.gif ├── Overview.png ├── Teaser.png └── input_panorama.png ├── gaussian_renderer ├── __init__.py └── network_gui.py ├── input ├── Camera_Trajectory │ ├── camera_pose_frame000000.txt │ ├── camera_pose_frame000001.txt │ ├── camera_pose_frame000002.txt │ ├── camera_pose_frame000003.txt │ ├── camera_pose_frame000004.txt │ ├── camera_pose_frame000005.txt │ ├── camera_pose_frame000006.txt │ ├── camera_pose_frame000007.txt │ ├── camera_pose_frame000008.txt │ ├── camera_pose_frame000009.txt │ ├── camera_pose_frame000010.txt │ ├── camera_pose_frame000011.txt │ ├── camera_pose_frame000012.txt │ ├── camera_pose_frame000013.txt │ ├── camera_pose_frame000014.txt │ ├── camera_pose_frame000015.txt │ ├── camera_pose_frame000016.txt │ ├── camera_pose_frame000017.txt │ ├── camera_pose_frame000018.txt │ ├── camera_pose_frame000019.txt │ ├── camera_pose_frame000020.txt │ ├── camera_pose_frame000021.txt │ ├── camera_pose_frame000022.txt │ ├── camera_pose_frame000023.txt │ ├── camera_pose_frame000024.txt │ ├── camera_pose_frame000025.txt │ ├── camera_pose_frame000026.txt │ ├── camera_pose_frame000027.txt │ ├── camera_pose_frame000028.txt │ ├── camera_pose_frame000029.txt │ ├── camera_pose_frame000030.txt │ ├── camera_pose_frame000031.txt │ ├── camera_pose_frame000032.txt │ ├── camera_pose_frame000033.txt │ ├── camera_pose_frame000034.txt │ ├── camera_pose_frame000035.txt │ ├── camera_pose_frame000036.txt │ ├── camera_pose_frame000037.txt │ ├── camera_pose_frame000038.txt │ ├── camera_pose_frame000039.txt │ ├── camera_pose_frame000040.txt │ ├── camera_pose_frame000041.txt │ ├── camera_pose_frame000042.txt │ ├── camera_pose_frame000043.txt │ ├── camera_pose_frame000044.txt │ ├── camera_pose_frame000045.txt │ ├── camera_pose_frame000046.txt │ ├── camera_pose_frame000047.txt │ ├── camera_pose_frame000048.txt │ ├── camera_pose_frame000049.txt │ ├── camera_pose_frame000050.txt │ ├── camera_pose_frame000051.txt │ ├── camera_pose_frame000052.txt │ ├── camera_pose_frame000053.txt │ ├── camera_pose_frame000054.txt │ ├── camera_pose_frame000055.txt │ ├── camera_pose_frame000056.txt │ ├── camera_pose_frame000057.txt │ ├── camera_pose_frame000058.txt │ ├── camera_pose_frame000059.txt │ ├── camera_pose_frame000060.txt │ ├── camera_pose_frame000061.txt │ ├── camera_pose_frame000062.txt │ ├── camera_pose_frame000063.txt │ ├── camera_pose_frame000064.txt │ ├── camera_pose_frame000065.txt │ ├── camera_pose_frame000081.txt │ ├── camera_pose_frame000082.txt │ ├── camera_pose_frame000083.txt │ ├── camera_pose_frame000084.txt │ ├── camera_pose_frame000085.txt │ ├── camera_pose_frame000086.txt │ ├── camera_pose_frame000087.txt │ ├── camera_pose_frame000088.txt │ ├── camera_pose_frame000089.txt │ ├── camera_pose_frame000090.txt │ ├── camera_pose_frame000091.txt │ ├── camera_pose_frame000092.txt │ ├── camera_pose_frame000093.txt │ ├── camera_pose_frame000094.txt │ ├── camera_pose_frame000095.txt │ ├── camera_pose_frame000096.txt │ ├── camera_pose_frame000097.txt │ ├── camera_pose_frame000098.txt │ ├── camera_pose_frame000099.txt │ ├── camera_pose_frame000100.txt │ ├── camera_pose_frame000101.txt │ ├── camera_pose_frame000102.txt │ ├── camera_pose_frame000103.txt │ ├── camera_pose_frame000104.txt │ ├── camera_pose_frame000105.txt │ ├── camera_pose_frame000106.txt │ ├── camera_pose_frame000107.txt │ ├── camera_pose_frame000108.txt │ ├── camera_pose_frame000109.txt │ ├── 
camera_pose_frame000110.txt │ ├── camera_pose_frame000111.txt │ ├── camera_pose_frame000112.txt │ ├── camera_pose_frame000113.txt │ ├── camera_pose_frame000114.txt │ ├── camera_pose_frame000115.txt │ ├── camera_pose_frame000116.txt │ ├── camera_pose_frame000117.txt │ ├── camera_pose_frame000118.txt │ ├── camera_pose_frame000119.txt │ ├── camera_pose_frame000120.txt │ ├── camera_pose_frame000121.txt │ ├── camera_pose_frame000122.txt │ ├── camera_pose_frame000123.txt │ ├── camera_pose_frame000124.txt │ ├── camera_pose_frame000125.txt │ ├── camera_pose_frame000126.txt │ ├── camera_pose_frame000127.txt │ ├── camera_pose_frame000128.txt │ ├── camera_pose_frame000129.txt │ ├── camera_pose_frame000130.txt │ ├── camera_pose_frame000131.txt │ ├── camera_pose_frame000132.txt │ ├── camera_pose_frame000133.txt │ ├── camera_pose_frame000134.txt │ ├── camera_pose_frame000135.txt │ ├── camera_pose_frame000136.txt │ ├── camera_pose_frame000137.txt │ ├── camera_pose_frame000138.txt │ ├── camera_pose_frame000139.txt │ ├── camera_pose_frame000140.txt │ ├── camera_pose_frame000141.txt │ ├── camera_pose_frame000142.txt │ ├── camera_pose_frame000143.txt │ ├── camera_pose_frame000144.txt │ ├── camera_pose_frame000145.txt │ ├── camera_pose_frame000146.txt │ ├── camera_pose_frame000147.txt │ ├── camera_pose_frame000148.txt │ ├── camera_pose_frame000149.txt │ ├── camera_pose_frame000150.txt │ ├── camera_pose_frame000151.txt │ ├── camera_pose_frame000152.txt │ ├── camera_pose_frame000153.txt │ ├── camera_pose_frame000154.txt │ └── camera_pose_frame000155.txt ├── another_input_panorama.png └── input_panorama.png ├── modules ├── equilib │ ├── __init__.py │ ├── cube2equi │ │ ├── __init__.py │ │ ├── base.py │ │ ├── numpy.py │ │ └── torch.py │ ├── equi2cube │ │ ├── __init__.py │ │ ├── base.py │ │ ├── numpy.py │ │ └── torch.py │ ├── equi2equi │ │ ├── __init__.py │ │ ├── base.py │ │ ├── numpy.py │ │ └── torch.py │ ├── equi2pers │ │ ├── __init__.py │ │ ├── base.py │ │ ├── numpy.py │ │ └── torch.py │ ├── grid_sample │ │ ├── __init__.py │ │ ├── cpp │ │ │ ├── __init__.py │ │ │ └── setup.py │ │ ├── numpy │ │ │ ├── __init__.py │ │ │ ├── bicubic.py │ │ │ ├── bilinear.py │ │ │ ├── grid_sample.py │ │ │ └── nearest.py │ │ └── torch │ │ │ ├── __init__.py │ │ │ ├── bicubic.py │ │ │ ├── bilinear.py │ │ │ ├── grid_sample.py │ │ │ ├── native.py │ │ │ └── nearest.py │ ├── numpy_utils │ │ ├── __init__.py │ │ ├── grid.py │ │ ├── intrinsic.py │ │ └── rotation.py │ └── torch_utils │ │ ├── __init__.py │ │ ├── func.py │ │ ├── grid.py │ │ ├── intrinsic.py │ │ └── rotation.py ├── geo_predictors │ ├── PanoFusionDistancePredictor.py │ ├── __init__.py │ ├── geo_predictor.py │ ├── networks.py │ ├── omnidata │ │ ├── modules │ │ │ ├── __init__.py │ │ │ ├── channel_attention.py │ │ │ ├── midas │ │ │ │ ├── __init__.py │ │ │ │ ├── base_model.py │ │ │ │ ├── blocks.py │ │ │ │ ├── dpt_depth.py │ │ │ │ ├── midas_net.py │ │ │ │ ├── midas_net_custom.py │ │ │ │ ├── transforms.py │ │ │ │ └── vit.py │ │ │ └── unet.py │ │ ├── omnidata_normal_predictor.py │ │ ├── omnidata_predictor.py │ │ ├── task_configs.py │ │ └── transforms.py │ ├── pano_fusion_inv_predictor.py │ ├── pano_fusion_normal_predictor.py │ ├── pano_geo_refiner.py │ └── pano_joint_predictor.py ├── inpainters │ ├── SDFT_inpainter.py │ ├── __init__.py │ ├── inpainter.py │ ├── lama │ │ ├── __init__.py │ │ ├── models │ │ │ ├── __init__.py │ │ │ └── ade20k │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── color150.mat │ │ │ │ ├── mobilenet.py │ │ │ │ ├── object150_info.csv │ │ │ │ ├── resnet.py │ │ │ │ 
├── segm_lib │ │ │ │ ├── __init__.py │ │ │ │ ├── nn │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── modules │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── batchnorm.py │ │ │ │ │ │ ├── comm.py │ │ │ │ │ │ ├── replicate.py │ │ │ │ │ │ ├── tests │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── test_numeric_batchnorm.py │ │ │ │ │ │ │ └── test_sync_batchnorm.py │ │ │ │ │ │ └── unittest.py │ │ │ │ │ └── parallel │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ └── data_parallel.py │ │ │ │ └── utils │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── th.py │ │ │ │ └── utils.py │ │ ├── predict_config.yaml │ │ └── saicinpainting │ │ │ ├── __init__.py │ │ │ ├── evaluation │ │ │ ├── __init__.py │ │ │ ├── data.py │ │ │ ├── evaluator.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── base_loss.py │ │ │ │ ├── fid │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── fid_score.py │ │ │ │ │ └── inception.py │ │ │ │ ├── lpips.py │ │ │ │ └── ssim.py │ │ │ ├── masks │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── countless │ │ │ │ │ ├── README.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── countless2d.py │ │ │ │ │ ├── countless3d.py │ │ │ │ │ ├── images │ │ │ │ │ │ ├── gcim.jpg │ │ │ │ │ │ ├── gray_segmentation.png │ │ │ │ │ │ ├── segmentation.png │ │ │ │ │ │ └── sparse.png │ │ │ │ │ ├── memprof │ │ │ │ │ │ ├── countless2d_gcim_N_1000.png │ │ │ │ │ │ ├── countless2d_quick_gcim_N_1000.png │ │ │ │ │ │ ├── countless3d.png │ │ │ │ │ │ ├── countless3d_dynamic.png │ │ │ │ │ │ ├── countless3d_dynamic_generalized.png │ │ │ │ │ │ └── countless3d_generalized.png │ │ │ │ │ ├── requirements.txt │ │ │ │ │ └── test.py │ │ │ │ └── mask.py │ │ │ ├── refinement.py │ │ │ ├── utils.py │ │ │ └── vis.py │ │ │ ├── training │ │ │ ├── __init__.py │ │ │ ├── data │ │ │ │ ├── __init__.py │ │ │ │ ├── aug.py │ │ │ │ ├── datasets.py │ │ │ │ └── masks.py │ │ │ ├── losses │ │ │ │ ├── __init__.py │ │ │ │ ├── adversarial.py │ │ │ │ ├── constants.py │ │ │ │ ├── distance_weighting.py │ │ │ │ ├── feature_matching.py │ │ │ │ ├── perceptual.py │ │ │ │ ├── segmentation.py │ │ │ │ └── style_loss.py │ │ │ ├── modules │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── depthwise_sep_conv.py │ │ │ │ ├── fake_fakes.py │ │ │ │ ├── ffc.py │ │ │ │ ├── multidilated_conv.py │ │ │ │ ├── multiscale.py │ │ │ │ ├── pix2pixhd.py │ │ │ │ ├── spatial_transform.py │ │ │ │ └── squeeze_excitation.py │ │ │ ├── trainers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ └── default.py │ │ │ └── visualizers │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── colors.py │ │ │ │ ├── directory.py │ │ │ │ └── noop.py │ │ │ └── utils.py │ ├── lama_inpainter.py │ └── pano_pers_fusion_inpainter.py └── mesh_fusion │ ├── __init__.py │ ├── render.py │ ├── sup_info.py │ └── util.py ├── pano2room.py ├── requirements.txt ├── scene ├── __init__.py ├── arguments.py ├── cameras.py ├── dataset_readers.py └── gaussian_model.py ├── scripts ├── accelerate.yaml ├── create_SDFT_pairs.sh ├── run_Pano2Room.sh └── train_SDFT.sh ├── train_SDFT.py └── utils ├── __init__.py ├── camera_utils.py ├── common_utils.py ├── functions.py ├── general.py ├── generic_utils.py ├── geo_utils.py ├── graphics.py ├── loss.py ├── sh.py ├── system.py ├── trajectory.py └── warp_utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Tricky 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including 
without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Welcome to Pano2Room! 2 | 3 | [Pano2Room: Novel View Synthesis from a Single Indoor Panorama (SIGGRAPH Asia 2024)](https://arxiv.org/abs/2408.11413). 4 | 5 | ## Overview 6 | #### In short, Pano2Room converts an input panorama into a 3D Gaussian Splatting (3DGS) scene. 7 | 8 | 9 | 10 | 11 | 12 | ## Demo 13 | In this demo, the input panorama is specified as: 14 | 15 | 16 | Then, Pano2Room generates the corresponding 3DGS and renders novel views: 17 | 18 | 19 | 20 | And the corresponding rendered depth: 21 | 22 | 23 | 24 | ## Quick Run 25 | ### 0. Set up the environment 26 | (1) Create a new conda environment and install [Pytorch3D](https://github.com/facebookresearch/pytorch3d) (for mesh rendering) and [diff-gaussian-rasterization-w-depth](https://github.com/JonathonLuiten/diff-gaussian-rasterization-w-depth) (for 3DGS rendering with depth) accordingly. Other requirements are specified in \. 27 | 28 | (2) Download the pretrained weights into \ (for image inpainting and depth estimation). See \ for instructions. 29 | 30 | ### 1. Run Demo 31 | ``` 32 | sh scripts/run_Pano2Room.sh 33 | ``` 34 | This demo converts \ to 3DGS and renders novel views as in \. 35 | 36 | ### (Optional) 0.5. Fine-tune Inpainter (SDFT) 37 | 38 | Before running step 1, you can also fine-tune the SD inpainter model for better inpainting performance on a specific panorama. To do this: 39 | 40 | (1) Create self-supervised training pairs: 41 | ``` 42 | sh scripts/create_SDFT_pairs.sh 43 | ``` 44 | The pairs are then stored at \. 45 | 46 | (2) Training: 47 | ``` 48 | sh scripts/train_SDFT.sh 49 | ``` 50 | The SDFT weights are then stored at \. 51 | 52 | Then, when running step 1, the SDFT weights will be loaded automatically. 53 | 54 | Note that this step needs to be performed for each new panorama. If you don't want to train SDFT for a new panorama, delete the previous \ if it exists, so that stale weights from another panorama are not loaded. 55 | 56 | ## Try it on your own panorama 57 | 58 | Simply replace \ with your own panorama and run the previous steps! 59 | 60 | #### Camera Trajectory 61 | We provide a camera trajectory at \ as used in the above demo. Each file contains the 4x4 [R|T] matrix of a frame (a minimal loading sketch follows the first few pose files below). Feel free to use other camera trajectories. 62 | 63 | 64 | ## Cite our paper 65 | 66 | If you find our work helpful, please cite our paper. Thank you! 67 | 68 | ACM Reference Format: 69 | ``` 70 | Guo Pu, Yiming Zhao, and Zhouhui Lian. 2024. Pano2Room: Novel View Synthesis from a Single Indoor Panorama.
In SIGGRAPH Asia 2024 Conference Papers (SA Conference Papers '24), December 3--6, 2024, Tokyo, Japan. ACM, New York, NY, USA, 10 pages. 71 | https://doi.org/10.1145/3680528.3687616 72 | ``` 73 | -------------------------------------------------------------------------------- /checkpoints/README.md: -------------------------------------------------------------------------------- 1 | This folder should contain the checkpoint files of the pretrained inpainting and geometry estimation models: 2 | - big-lama.ckpt 3 | - omnidata_dpt_depth_v2.ckpt 4 | - omnidata_dpt_normal_v2.ckpt 5 | 6 | The checkpoint files can be found at this [dropbox link](https://www.dropbox.com/scl/fo/348s01x0trt0yxb934cwe/h?rlkey=a96g2incso7g53evzamzo0j0y&dl=0). -------------------------------------------------------------------------------- /demo/GS_depth_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/demo/GS_depth_video.gif -------------------------------------------------------------------------------- /demo/GS_render_video.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/demo/GS_render_video.gif -------------------------------------------------------------------------------- /demo/Overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/demo/Overview.png -------------------------------------------------------------------------------- /demo/Teaser.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/demo/Teaser.png -------------------------------------------------------------------------------- /demo/input_panorama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/demo/input_panorama.png -------------------------------------------------------------------------------- /gaussian_renderer/__init__.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import math 14 | from depth_diff_gaussian_rasterization_min import GaussianRasterizationSettings, GaussianRasterizer 15 | from scene.gaussian_model import GaussianModel 16 | from utils.sh import eval_sh 17 | 18 | def render(viewpoint_camera, pc: GaussianModel, opt, bg_color: torch.Tensor, scaling_modifier=1.0, override_color=None, render_only=False): 19 | """ 20 | Render the scene. 21 | 22 | Background tensor (bg_color) must be on GPU! 23 | """ 24 | 25 | # Create zero tensor.
We will use it to make pytorch return gradients of the 2D (screen-space) means 26 | screenspace_points = torch.zeros_like(pc.get_xyz, dtype=pc.get_xyz.dtype, requires_grad=True, device="cuda") + 0 27 | try: 28 | screenspace_points.retain_grad() 29 | except: 30 | pass 31 | 32 | # Set up rasterization configuration 33 | tanfovx = math.tan(viewpoint_camera.FoVx * 0.5) 34 | tanfovy = math.tan(viewpoint_camera.FoVy * 0.5) 35 | 36 | raster_settings = GaussianRasterizationSettings( 37 | image_height=int(viewpoint_camera.image_height), 38 | image_width=int(viewpoint_camera.image_width), 39 | tanfovx=tanfovx, 40 | tanfovy=tanfovy, 41 | bg=bg_color, 42 | scale_modifier=scaling_modifier, 43 | viewmatrix=viewpoint_camera.world_view_transform, 44 | projmatrix=viewpoint_camera.full_proj_transform, 45 | sh_degree=pc.active_sh_degree, 46 | campos=viewpoint_camera.camera_center, 47 | prefiltered=False, 48 | debug=opt.debug 49 | ) 50 | 51 | rasterizer = GaussianRasterizer(raster_settings=raster_settings) 52 | 53 | means3D = pc.get_xyz 54 | means2D = screenspace_points 55 | opacity = pc.get_opacity 56 | 57 | # If precomputed 3d covariance is provided, use it. If not, then it will be computed from 58 | # scaling / rotation by the rasterizer. 59 | scales = None 60 | rotations = None 61 | cov3D_precomp = None 62 | if opt.compute_cov3D_python: 63 | cov3D_precomp = pc.get_covariance(scaling_modifier) 64 | else: 65 | scales = pc.get_scaling 66 | rotations = pc.get_rotation 67 | 68 | # If precomputed colors are provided, use them. Otherwise, if it is desired to precompute colors 69 | # from SHs in Python, do it. If not, then SH -> RGB conversion will be done by rasterizer. 70 | shs = None 71 | colors_precomp = None 72 | if override_color is None: 73 | if opt.convert_SHs_python: 74 | shs_view = pc.get_features.transpose(1, 2).view(-1, 3, (pc.max_sh_degree+1)**2) 75 | dir_pp = (pc.get_xyz - viewpoint_camera.camera_center.repeat(pc.get_features.shape[0], 1)) 76 | dir_pp_normalized = dir_pp/dir_pp.norm(dim=1, keepdim=True) 77 | sh2rgb = eval_sh(pc.active_sh_degree, shs_view, dir_pp_normalized) 78 | colors_precomp = torch.clamp_min(sh2rgb + 0.5, 0.0) 79 | else: 80 | shs = pc.get_features 81 | else: 82 | colors_precomp = override_color 83 | 84 | # Rasterize visible Gaussians to image, obtain their radii (on screen). 85 | rendered_image, radii, depth = rasterizer( 86 | means3D = means3D, 87 | means2D = means2D, 88 | shs = shs, 89 | colors_precomp = colors_precomp, 90 | opacities = opacity, 91 | scales = scales, 92 | rotations = rotations, 93 | cov3D_precomp = cov3D_precomp) 94 | 95 | # Those Gaussians that were frustum culled or had a radius of 0 were not visible. 96 | # They will be excluded from value updates used in the splitting criteria. 97 | if render_only: 98 | return {"render": rendered_image, "depth": depth} 99 | else: 100 | return {"render": rendered_image, 101 | "viewspace_points": screenspace_points, 102 | "visibility_filter" : radii > 0, 103 | "radii": radii, 104 | "depth": depth} 105 | -------------------------------------------------------------------------------- /gaussian_renderer/network_gui.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | 12 | import torch 13 | import traceback 14 | import socket 15 | import json 16 | from scene.cameras import MiniCam 17 | 18 | host = "127.0.0.1" 19 | port = 6009 20 | 21 | conn = None 22 | addr = None 23 | 24 | listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 25 | 26 | def init(wish_host, wish_port): 27 | global host, port, listener 28 | host = wish_host 29 | port = wish_port 30 | listener.bind((host, port)) 31 | listener.listen() 32 | listener.settimeout(0) 33 | 34 | def try_connect(): 35 | global conn, addr, listener 36 | try: 37 | conn, addr = listener.accept() 38 | print(f"\nConnected by {addr}") 39 | conn.settimeout(None) 40 | except Exception as inst: 41 | pass 42 | 43 | def read(): 44 | global conn 45 | messageLength = conn.recv(4) 46 | messageLength = int.from_bytes(messageLength, 'little') 47 | message = conn.recv(messageLength) 48 | return json.loads(message.decode("utf-8")) 49 | 50 | def send(message_bytes, verify): 51 | global conn 52 | if message_bytes != None: 53 | conn.sendall(message_bytes) 54 | conn.sendall(len(verify).to_bytes(4, 'little')) 55 | conn.sendall(bytes(verify, 'ascii')) 56 | 57 | def receive(): 58 | message = read() 59 | 60 | width = message["resolution_x"] 61 | height = message["resolution_y"] 62 | 63 | if width != 0 and height != 0: 64 | try: 65 | do_training = bool(message["train"]) 66 | fovy = message["fov_y"] 67 | fovx = message["fov_x"] 68 | znear = message["z_near"] 69 | zfar = message["z_far"] 70 | do_shs_python = bool(message["shs_python"]) 71 | do_rot_scale_python = bool(message["rot_scale_python"]) 72 | keep_alive = bool(message["keep_alive"]) 73 | scaling_modifier = message["scaling_modifier"] 74 | world_view_transform = torch.reshape(torch.tensor(message["view_matrix"]), (4, 4)).cuda() 75 | world_view_transform[:,1] = -world_view_transform[:,1] 76 | world_view_transform[:,2] = -world_view_transform[:,2] 77 | full_proj_transform = torch.reshape(torch.tensor(message["view_projection_matrix"]), (4, 4)).cuda() 78 | full_proj_transform[:,1] = -full_proj_transform[:,1] 79 | custom_cam = MiniCam(width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform) 80 | except Exception as e: 81 | print("") 82 | traceback.print_exc() 83 | raise e 84 | return custom_cam, do_training, do_shs_python, do_rot_scale_python, keep_alive, scaling_modifier 85 | else: 86 | return None, None, None, None, None, None -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000000.txt: -------------------------------------------------------------------------------- 1 | 1 2.44929e-16 1.49976e-32 -1 2 | 0 6.12323e-17 -1 -6.12323e-17 3 | -2.44929e-16 1 6.12323e-17 -1 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000001.txt: -------------------------------------------------------------------------------- 1 | 3.06162e-16 -1 -6.12323e-17 1 2 | 0 6.12323e-17 -1 -6.12323e-17 3 | 1 3.06162e-16 1.8747e-32 -1 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000002.txt: -------------------------------------------------------------------------------- 1 | -1 -3.67394e-16 -2.24964e-32 1 2 | 0 6.12323e-17 -1 -6.12323e-17 3 | 3.67394e-16 -1 -6.12323e-17 1 4 | 0 0 0 1 5 | 
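The network_gui module above exchanges length-prefixed JSON messages over a TCP socket: each incoming message is a 4-byte little-endian length followed by a UTF-8 JSON body, and the keys expected by receive() are listed in that function. Below is a minimal client-side sketch under those assumptions; the host/port defaults and the identity placeholder matrices are illustrative only, and reading the server's reply (raw image bytes plus a length-prefixed verify string) is omitted.

```python
# Minimal sketch of a client for the length-prefixed JSON protocol in network_gui.py.
# Assumptions: a server has called network_gui.init("127.0.0.1", 6009) and polls
# try_connect()/receive(); the message keys mirror receive() above.
import json
import socket
import struct

def send_json(sock: socket.socket, payload: dict) -> None:
    # 4-byte little-endian length prefix, then the UTF-8 JSON body (mirrors read()).
    body = json.dumps(payload).encode("utf-8")
    sock.sendall(struct.pack("<I", len(body)) + body)

camera_message = {
    "resolution_x": 800, "resolution_y": 600,
    "train": True, "fov_y": 0.8, "fov_x": 1.2,
    "z_near": 0.01, "z_far": 100.0,
    "shs_python": False, "rot_scale_python": False,
    "keep_alive": True, "scaling_modifier": 1.0,
    # Flattened 4x4 matrices; identity here is only a placeholder.
    "view_matrix": [1.0 if i % 5 == 0 else 0.0 for i in range(16)],
    "view_projection_matrix": [1.0 if i % 5 == 0 else 0.0 for i in range(16)],
}

with socket.create_connection(("127.0.0.1", 6009)) as sock:
    send_json(sock, camera_message)
    # The reply (optional image buffer + length-prefixed ASCII verify string)
    # would be read here in a full client.
```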
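The camera_pose_frame*.txt files above (and those that follow) each store the 4x4 [R|T] pose matrix described in the README, one row per line. A minimal loading sketch, assuming numpy is available; whether the matrix is camera-to-world or world-to-camera is not stated here, so that interpretation is left as an assumption of this sketch.

```python
# Minimal sketch: parse one camera_pose_frame*.txt into a 4x4 pose matrix.
# Assumption: each file holds 4 whitespace-separated rows of 4 floats, as shown above.
import numpy as np

def load_pose(path: str) -> np.ndarray:
    pose = np.loadtxt(path, dtype=np.float64)
    assert pose.shape == (4, 4), f"unexpected pose shape {pose.shape} in {path}"
    return pose

pose = load_pose("input/Camera_Trajectory/camera_pose_frame000000.txt")
R = pose[:3, :3]   # 3x3 rotation block
t = pose[:3, 3]    # translation column
# Inverting swaps between the two conventions (world-to-camera <-> camera-to-world),
# whichever one the renderer expects.
pose_inv = np.linalg.inv(pose)
```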
-------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000003.txt: -------------------------------------------------------------------------------- 1 | -4.28626e-16 1 6.12323e-17 -1 2 | 0 6.12323e-17 -1 -6.12323e-17 3 | -1 -4.28626e-16 -2.62458e-32 1 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000004.txt: -------------------------------------------------------------------------------- 1 | -4.28626e-16 1 6.12323e-17 -1 2 | 1 4.28626e-16 -6.12323e-17 -1 3 | -6.12323e-17 6.12323e-17 -1 1.2326e-32 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000005.txt: -------------------------------------------------------------------------------- 1 | -4.28626e-16 1 6.12323e-17 -1 2 | -1 -4.28626e-16 1.83697e-16 1 3 | 1.83697e-16 -6.12323e-17 1 -1.22465e-16 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000006.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.99 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000007.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.965 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000008.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.94 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000009.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.915 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000010.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.89 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000011.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.865 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000012.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.84 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000013.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.815 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000014.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.79 4 | 0 0 0 1 5 | 
-------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000015.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.765 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000016.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.74 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000017.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.715 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000018.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.69 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000019.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.665 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000020.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.64 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000021.txt: -------------------------------------------------------------------------------- 1 | 0 -1 0 1 2 | 0 0 -1 -0 3 | 1 0 0 -0.615 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000022.txt: -------------------------------------------------------------------------------- 1 | -0.0324504 -0.999473 0 1.03192 2 | 0 -0 -1 -0 3 | 0.999473 -0.0324504 0 -0.58182 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000023.txt: -------------------------------------------------------------------------------- 1 | -0.0647986 -0.997898 0 1.0627 2 | 0 -0 -1 -0 3 | 0.997898 -0.0647986 0 -0.547289 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000024.txt: -------------------------------------------------------------------------------- 1 | -0.0969438 -0.99529 0 1.09223 2 | 0 -0 -1 -0 3 | 0.99529 -0.0969438 0 -0.511524 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000025.txt: -------------------------------------------------------------------------------- 1 | -0.128789 -0.991672 0 1.12046 2 | 0 -0 -1 -0 3 | 0.991672 -0.128789 0 -0.47465 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000026.txt: -------------------------------------------------------------------------------- 1 | -0.16024 -0.987078 0 1.14732 2 | 0 -0 -1 -0 3 | 0.987078 -0.16024 0 -0.436798 4 | 0 0 0 1 5 | 
-------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000027.txt: -------------------------------------------------------------------------------- 1 | -0.191211 -0.981549 0 1.17276 2 | 0 -0 -1 -0 3 | 0.981549 -0.191211 0 -0.398101 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000028.txt: -------------------------------------------------------------------------------- 1 | -0.221621 -0.975133 0 1.19675 2 | 0 -0 -1 -0 3 | 0.975133 -0.221621 0 -0.358694 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000029.txt: -------------------------------------------------------------------------------- 1 | -0.251398 -0.967884 0 1.21928 2 | 0 -0 -1 -0 3 | 0.967884 -0.251398 0 -0.31871 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000030.txt: -------------------------------------------------------------------------------- 1 | -0.280479 -0.95986 0 1.24034 2 | 0 -0 -1 -0 3 | 0.95986 -0.280479 0 -0.278282 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000031.txt: -------------------------------------------------------------------------------- 1 | -0.308807 -0.951125 0 1.25993 2 | 0 -0 -1 -0 3 | 0.951125 -0.308807 0 -0.237534 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000032.txt: -------------------------------------------------------------------------------- 1 | -0.336336 -0.941742 0 1.27808 2 | 0 -0 -1 -0 3 | 0.941742 -0.336336 0 -0.196589 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000033.txt: -------------------------------------------------------------------------------- 1 | -0.36303 -0.931777 0 1.29481 2 | 0 -0 -1 -0 3 | 0.931777 -0.36303 0 -0.155558 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000034.txt: -------------------------------------------------------------------------------- 1 | -0.388859 -0.921297 0 1.31016 2 | 0 -0 -1 -0 3 | 0.921297 -0.388859 0 -0.114549 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000035.txt: -------------------------------------------------------------------------------- 1 | -0.413803 -0.910366 0 1.32417 2 | 0 -0 -1 -0 3 | 0.910366 -0.413803 0 -0.0736569 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000036.txt: -------------------------------------------------------------------------------- 1 | -0.437848 -0.899049 0 1.3369 2 | 0 -0 -1 -0 3 | 0.899049 -0.437848 0 -0.03297 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000037.txt: -------------------------------------------------------------------------------- 1 | -0.488405 -0.872617 0 1.36102 2 | 0 -0 -1 -0 3 | 0.872617 -0.488405 0 -0.000309317 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- 
/input/Camera_Trajectory/camera_pose_frame000038.txt: -------------------------------------------------------------------------------- 1 | -0.549617 -0.835417 0 1.38503 2 | 0 -0 -1 -0 3 | 0.835417 -0.549617 0 0.0553464 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000039.txt: -------------------------------------------------------------------------------- 1 | -0.62368 -0.781679 0 1.40536 2 | 0 -0 -1 -0 3 | 0.781679 -0.62368 0 0.142636 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000040.txt: -------------------------------------------------------------------------------- 1 | -0.711836 -0.702345 0 1.41418 2 | 0 -0 -1 -0 3 | 0.702345 -0.711836 0 0.272894 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000041.txt: -------------------------------------------------------------------------------- 1 | -0.811534 -0.584305 0 1.39584 2 | 0 -0 -1 -0 3 | 0.584305 -0.811534 0 0.458273 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000042.txt: -------------------------------------------------------------------------------- 1 | -0.910782 -0.412888 0 1.32367 2 | 0 -0 -1 -0 3 | 0.412888 -0.910782 0 0.703761 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000043.txt: -------------------------------------------------------------------------------- 1 | -0.98302 -0.183497 0 1.16652 2 | 0 -0 -1 -0 3 | 0.183497 -0.98302 0 0.990262 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000044.txt: -------------------------------------------------------------------------------- 1 | -0.996815 0.0797452 0 0.91707 2 | -0 0 -1 -0 3 | -0.0797452 -0.996815 0 1.26466 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000045.txt: -------------------------------------------------------------------------------- 1 | -0.944836 0.327543 0 0.617293 2 | -0 0 -1 -0 3 | -0.327543 -0.944836 0 1.47083 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000046.txt: -------------------------------------------------------------------------------- 1 | -0.852438 0.522829 0 0.329609 2 | -0 0 -1 -0 3 | -0.522829 -0.852438 0 1.59522 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000047.txt: -------------------------------------------------------------------------------- 1 | -0.750714 0.660628 0 0.0900856 2 | -0 0 -1 -0 3 | -0.660628 -0.750714 0 1.6611 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000048.txt: -------------------------------------------------------------------------------- 1 | -0.657263 0.753661 0 -0.0963986 2 | -0 0 -1 -0 3 | -0.753661 -0.657263 0 1.6962 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000049.txt: 
-------------------------------------------------------------------------------- 1 | -0.57759 0.816327 0 -0.238737 2 | -0 0 -1 -0 3 | -0.816327 -0.57759 0 1.71854 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000050.txt: -------------------------------------------------------------------------------- 1 | -0.511484 0.859293 0 -0.347809 2 | -0 0 -1 -0 3 | -0.859293 -0.511484 0 1.73736 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000051.txt: -------------------------------------------------------------------------------- 1 | -0.456935 0.8895 0 -0.432565 2 | -0 0 -1 -0 3 | -0.8895 -0.456935 0 1.75678 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000052.txt: -------------------------------------------------------------------------------- 1 | -0.411734 0.911304 0 -0.49957 2 | -0 0 -1 -0 3 | -0.911304 -0.411734 0 1.77843 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000053.txt: -------------------------------------------------------------------------------- 1 | -0.373968 0.927441 0 -0.553473 2 | -0 0 -1 -0 3 | -0.927441 -0.373968 0 1.80279 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000054.txt: -------------------------------------------------------------------------------- 1 | -0.342109 0.93966 0 -0.597551 2 | -0 0 -1 -0 3 | -0.93966 -0.342109 0 1.82984 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000055.txt: -------------------------------------------------------------------------------- 1 | -0.314968 0.949102 0 -0.634135 2 | -0 0 -1 -0 3 | -0.949102 -0.314968 0 1.85937 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000056.txt: -------------------------------------------------------------------------------- 1 | -0.291626 0.956532 0 -0.664907 2 | -0 0 -1 -0 3 | -0.956532 -0.291626 0 1.89111 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000057.txt: -------------------------------------------------------------------------------- 1 | -0.271374 0.962474 0 -0.6911 2 | -0 0 -1 -0 3 | -0.962474 -0.271374 0 1.92478 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000058.txt: -------------------------------------------------------------------------------- 1 | -0.253661 0.967293 0 -0.713632 2 | -0 0 -1 -0 3 | -0.967293 -0.253661 0 1.96013 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000059.txt: -------------------------------------------------------------------------------- 1 | -0.238052 0.971252 0 -0.7332 2 | -0 0 -1 -0 3 | -0.971252 -0.238052 0 1.99695 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000060.txt: -------------------------------------------------------------------------------- 1 | 
-0.224204 0.974542 0 -0.750338 2 | -0 0 -1 -0 3 | -0.974542 -0.224204 0 2.03504 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000061.txt: -------------------------------------------------------------------------------- 1 | -0.211843 0.977304 0 -0.76546 2 | -0 0 -1 -0 3 | -0.977304 -0.211843 0 2.07424 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000062.txt: -------------------------------------------------------------------------------- 1 | -0.200747 0.979643 0 -0.778897 2 | -0 0 -1 -0 3 | -0.979643 -0.200747 0 2.1144 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000063.txt: -------------------------------------------------------------------------------- 1 | -0.190734 0.981642 0 -0.790908 2 | -0 0 -1 -0 3 | -0.981642 -0.190734 0 2.15542 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000064.txt: -------------------------------------------------------------------------------- 1 | -0.181656 0.983362 0 -0.801707 2 | -0 0 -1 -0 3 | -0.983362 -0.181656 0 2.19719 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000065.txt: -------------------------------------------------------------------------------- 1 | -0.17339 0.984853 0 -0.811464 2 | -0 0 -1 -0 3 | -0.984853 -0.17339 0 2.23962 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000081.txt: -------------------------------------------------------------------------------- 1 | 0.165833 0.986154 -0 -1.15199 2 | -0 0 -1 -0 3 | -0.986154 0.165833 0 1.95098 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000082.txt: -------------------------------------------------------------------------------- 1 | 0.187213 0.982319 -0 -1.16953 2 | -0 0 -1 -0 3 | -0.982319 0.187213 0 1.93017 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000083.txt: -------------------------------------------------------------------------------- 1 | 0.208331 0.978058 -0 -1.18639 2 | -0 0 -1 -0 3 | -0.978058 0.208331 0 1.90974 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000084.txt: -------------------------------------------------------------------------------- 1 | 0.229161 0.973389 -0 -1.20255 2 | -0 0 -1 -0 3 | -0.973389 0.229161 0 1.88971 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000085.txt: -------------------------------------------------------------------------------- 1 | 0.249681 0.968328 -0 -1.21801 2 | -0 0 -1 -0 3 | -0.968328 0.249681 0 1.87012 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000086.txt: -------------------------------------------------------------------------------- 1 | 0.26987 0.962897 -0 -1.23277 2 | -0 0 -1 -0 3 | -0.962897 0.26987 0 1.85099 4 | 0 0 0 1 5 | 
-------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000087.txt: -------------------------------------------------------------------------------- 1 | 0.28971 0.957115 -0 -1.24682 2 | -0 0 -1 -0 3 | -0.957115 0.28971 0 1.83236 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000088.txt: -------------------------------------------------------------------------------- 1 | 0.309183 0.951003 -0 -1.26019 2 | -0 0 -1 -0 3 | -0.951003 0.309183 0 1.81427 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000089.txt: -------------------------------------------------------------------------------- 1 | 0.328274 0.944582 -0 -1.27286 2 | -0 0 -1 -0 3 | -0.944582 0.328274 0 1.79672 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000090.txt: -------------------------------------------------------------------------------- 1 | 0.346972 0.937876 -0 -1.28485 2 | -0 0 -1 -0 3 | -0.937876 0.346972 0 1.77976 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000091.txt: -------------------------------------------------------------------------------- 1 | 0.365265 0.930904 -0 -1.29617 2 | -0 0 -1 -0 3 | -0.930904 0.365265 0 1.7634 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000092.txt: -------------------------------------------------------------------------------- 1 | 0.383144 0.923688 -0 -1.30683 2 | -0 0 -1 -0 3 | -0.923688 0.383144 0 1.74766 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000093.txt: -------------------------------------------------------------------------------- 1 | 0.400603 0.916252 -0 -1.31685 2 | -0 0 -1 -0 3 | -0.916252 0.400603 0 1.73256 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000094.txt: -------------------------------------------------------------------------------- 1 | 0.417637 0.908614 -0 -1.32625 2 | -0 0 -1 -0 3 | -0.908614 0.417637 0 1.71812 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000095.txt: -------------------------------------------------------------------------------- 1 | 0.434241 0.900797 -0 -1.33504 2 | -0 0 -1 -0 3 | -0.900797 0.434241 0 1.70435 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000096.txt: -------------------------------------------------------------------------------- 1 | 0.450414 0.89282 -0 -1.34323 2 | -0 0 -1 -0 3 | -0.89282 0.450414 0 1.69126 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000097.txt: -------------------------------------------------------------------------------- 1 | 0.467029 0.884242 -0 -1.35127 2 | -0 0 -1 -0 3 | -0.884242 0.467029 0 1.62163 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- 
/input/Camera_Trajectory/camera_pose_frame000098.txt: -------------------------------------------------------------------------------- 1 | 0.484728 0.874665 -0 -1.35939 2 | -0 0 -1 -0 3 | -0.874665 0.484728 0 1.55038 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000099.txt: -------------------------------------------------------------------------------- 1 | 0.503593 0.863941 -0 -1.36753 2 | -0 0 -1 -0 3 | -0.863941 0.503593 0 1.47732 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000100.txt: -------------------------------------------------------------------------------- 1 | 0.523708 0.851898 -0 -1.37561 2 | -0 0 -1 -0 3 | -0.851898 0.523708 0 1.40226 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000101.txt: -------------------------------------------------------------------------------- 1 | 0.545159 0.838333 -0 -1.38349 2 | -0 0 -1 -0 3 | -0.838333 0.545159 0 1.32498 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000102.txt: -------------------------------------------------------------------------------- 1 | 0.568028 0.823009 -0 -1.39104 2 | -0 0 -1 -0 3 | -0.823009 0.568028 0 1.24525 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000103.txt: -------------------------------------------------------------------------------- 1 | 0.59239 0.805651 -0 -1.39804 2 | -0 0 -1 -0 3 | -0.805651 0.59239 0 1.1628 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000104.txt: -------------------------------------------------------------------------------- 1 | 0.618307 0.785937 -0 -1.40424 2 | -0 0 -1 -0 3 | -0.785937 0.618307 0 1.07737 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000105.txt: -------------------------------------------------------------------------------- 1 | 0.645814 0.763495 -0 -1.40931 2 | -0 0 -1 -0 3 | -0.763495 0.645814 0 0.988676 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000106.txt: -------------------------------------------------------------------------------- 1 | 0.674909 0.737901 -0 -1.41281 2 | -0 0 -1 -0 3 | -0.737901 0.674909 0 0.896437 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000107.txt: -------------------------------------------------------------------------------- 1 | 0.705537 0.708673 -0 -1.41421 2 | -0 0 -1 -0 3 | -0.708673 0.705537 0 0.800401 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000108.txt: -------------------------------------------------------------------------------- 1 | 0.737562 0.675279 -0 -1.41284 2 | -0 0 -1 -0 3 | -0.675279 0.737562 0 0.700365 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000109.txt: 
-------------------------------------------------------------------------------- 1 | 0.770742 0.637147 -0 -1.40789 2 | -0 0 -1 -0 3 | -0.637147 0.770742 0 0.596221 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000110.txt: -------------------------------------------------------------------------------- 1 | 0.804696 0.593687 -0 -1.39838 2 | -0 0 -1 -0 3 | -0.593687 0.804696 0 0.488012 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000111.txt: -------------------------------------------------------------------------------- 1 | 0.838869 0.544333 -0 -1.3832 2 | -0 0 -1 -0 3 | -0.544333 0.838869 0 0.376009 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000112.txt: -------------------------------------------------------------------------------- 1 | 0.872506 0.488603 -0 -1.36111 2 | -0 0 -1 -0 3 | -0.488603 0.872506 0 0.260792 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000113.txt: -------------------------------------------------------------------------------- 1 | 0.904636 0.426184 -0 -1.33082 2 | -0 0 -1 -0 3 | -0.426184 0.904636 0 0.143345 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000114.txt: -------------------------------------------------------------------------------- 1 | 0.934093 0.357031 -0 -1.29112 2 | -0 0 -1 -0 3 | -0.357031 0.934093 0 0.0251271 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000115.txt: -------------------------------------------------------------------------------- 1 | 0.959569 0.281474 -0 -1.24104 2 | -0 0 -1 -0 3 | -0.281474 0.959569 0 -0.0918947 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000116.txt: -------------------------------------------------------------------------------- 1 | 0.979734 0.200301 -0 -1.18004 2 | -0 0 -1 -0 3 | -0.200301 0.979734 0 -0.205298 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000117.txt: -------------------------------------------------------------------------------- 1 | 0.99339 0.114792 -0 -1.10818 2 | -0 0 -1 -0 3 | -0.114792 0.99339 0 -0.312355 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000118.txt: -------------------------------------------------------------------------------- 1 | 0.999645 0.0266572 -0 -1.0263 2 | -0 0 -1 -0 3 | -0.0266572 0.999645 0 -0.410287 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000119.txt: -------------------------------------------------------------------------------- 1 | 0.99807 -0.0621021 0 -0.935968 2 | 0 0 -1 -0 3 | 0.0621021 0.99807 0 -0.496584 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000120.txt: -------------------------------------------------------------------------------- 1 | 
0.988775 -0.149415 0 -0.83936 2 | 0 0 -1 -0 3 | 0.149415 0.988775 0 -0.569303 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000121.txt: -------------------------------------------------------------------------------- 1 | 0.972387 -0.233373 0 -0.739014 2 | 0 0 -1 -0 3 | 0.233373 0.972387 0 -0.627287 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000122.txt: -------------------------------------------------------------------------------- 1 | 0.949942 -0.312425 0 -0.637517 2 | 0 0 -1 -0 3 | 0.312425 0.949942 0 -0.670226 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000123.txt: -------------------------------------------------------------------------------- 1 | 0.922713 -0.385489 0 -0.537224 2 | 0 0 -1 -0 3 | 0.385489 0.922713 0 -0.698586 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000124.txt: -------------------------------------------------------------------------------- 1 | 0.892036 -0.451965 0 -0.440071 2 | 0 0 -1 -0 3 | 0.451965 0.892036 0 -0.71342 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000125.txt: -------------------------------------------------------------------------------- 1 | 0.859173 -0.511685 0 -0.347488 2 | 0 0 -1 -0 3 | 0.511685 0.859173 0 -0.716159 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000126.txt: -------------------------------------------------------------------------------- 1 | 0.825217 -0.564815 0 -0.260402 2 | 0 0 -1 -0 3 | 0.564815 0.825217 0 -0.708394 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000127.txt: -------------------------------------------------------------------------------- 1 | 0.819232 -0.573462 0 -0.24577 2 | 0 0 -1 -0 3 | 0.573462 0.819232 0 -0.721334 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000128.txt: -------------------------------------------------------------------------------- 1 | 0.812966 -0.582311 0 -0.230656 2 | 0 0 -1 -0 3 | 0.582311 0.812966 0 -0.734118 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000129.txt: -------------------------------------------------------------------------------- 1 | 0.806405 -0.591364 0 -0.215041 2 | 0 0 -1 -0 3 | 0.591364 0.806405 0 -0.746731 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000130.txt: -------------------------------------------------------------------------------- 1 | 0.799532 -0.600624 0 -0.198908 2 | 0 0 -1 -0 3 | 0.600624 0.799532 0 -0.759155 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000131.txt: -------------------------------------------------------------------------------- 1 | 0.792329 -0.610094 0 -0.182236 2 | 0 0 -1 -0 3 | 0.610094 0.792329 0 -0.771372 4 | 0 0 0 1 5 
| -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000132.txt: -------------------------------------------------------------------------------- 1 | 0.78478 -0.619775 0 -0.165005 2 | 0 0 -1 -0 3 | 0.619775 0.78478 0 -0.783361 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000133.txt: -------------------------------------------------------------------------------- 1 | 0.776864 -0.629669 0 -0.147195 2 | 0 0 -1 -0 3 | 0.629669 0.776864 0 -0.7951 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000134.txt: -------------------------------------------------------------------------------- 1 | 0.768562 -0.639776 0 -0.128786 2 | 0 0 -1 -0 3 | 0.639776 0.768562 0 -0.806564 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000135.txt: -------------------------------------------------------------------------------- 1 | 0.759852 -0.650096 0 -0.109756 2 | 0 0 -1 -0 3 | 0.650096 0.759852 0 -0.817728 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000136.txt: -------------------------------------------------------------------------------- 1 | 0.750714 -0.660628 0 -0.0900856 2 | 0 0 -1 -0 3 | 0.660628 0.750714 0 -0.828563 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000137.txt: -------------------------------------------------------------------------------- 1 | 0.741123 -0.67137 0 -0.0697527 2 | 0 0 -1 -0 3 | 0.67137 0.741123 0 -0.839038 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000138.txt: -------------------------------------------------------------------------------- 1 | 0.731055 -0.682318 0 -0.048737 2 | 0 0 -1 -0 3 | 0.682318 0.731055 0 -0.849121 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000139.txt: -------------------------------------------------------------------------------- 1 | 0.720487 -0.693469 0 -0.0270183 2 | 0 0 -1 -0 3 | 0.693469 0.720487 0 -0.858775 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000140.txt: -------------------------------------------------------------------------------- 1 | 0.709391 -0.704815 0 -0.00457672 2 | 0 0 -1 -0 3 | 0.704815 0.709391 0 -0.867963 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000141.txt: -------------------------------------------------------------------------------- 1 | 0.697742 -0.716349 0 0.0186065 2 | 0 0 -1 -0 3 | 0.716349 0.697742 0 -0.876643 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000142.txt: -------------------------------------------------------------------------------- 1 | 0.721387 -0.692532 0 -0.0288555 2 | 0 0 -1 -0 3 | 0.692532 0.721387 0 -0.894087 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- 
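Each file under input/Camera_Trajectory holds a single 4x4 homogeneous pose matrix for one rendered frame: an orthonormal rotation block in the upper-left 3x3, a translation in the last column, and a fixed bottom row of 0 0 0 1. A minimal sketch for reading one of these files with NumPy follows; whether the matrix is camera-to-world or world-to-camera is an assumption that depends on how the renderer consumes it.

import numpy as np

# Minimal sketch: parse one whitespace-separated pose file into a (4, 4) array.
pose = np.loadtxt("input/Camera_Trajectory/camera_pose_frame000000.txt")
assert pose.shape == (4, 4)
R = pose[:3, :3]   # 3x3 rotation block
t = pose[:3, 3]    # translation vector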
/input/Camera_Trajectory/camera_pose_frame000143.txt: -------------------------------------------------------------------------------- 1 | 0.745761 -0.666213 0 -0.0795479 2 | 0 0 -1 -0 3 | 0.666213 0.745761 0 -0.909133 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000144.txt: -------------------------------------------------------------------------------- 1 | 0.770742 -0.637147 0 -0.133595 2 | 0 0 -1 -0 3 | 0.637147 0.770742 0 -0.921345 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000145.txt: -------------------------------------------------------------------------------- 1 | 0.796162 -0.605083 0 -0.191079 2 | 0 0 -1 -0 3 | 0.605083 0.796162 0 -0.930236 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000146.txt: -------------------------------------------------------------------------------- 1 | 0.821798 -0.56978 0 -0.252018 2 | 0 0 -1 -0 3 | 0.56978 0.821798 0 -0.93526 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000147.txt: -------------------------------------------------------------------------------- 1 | 0.847363 -0.531014 0 -0.316349 2 | 0 0 -1 -0 3 | 0.531014 0.847363 0 -0.935828 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000148.txt: -------------------------------------------------------------------------------- 1 | 0.872506 -0.488603 0 -0.383903 2 | 0 0 -1 -0 3 | 0.488603 0.872506 0 -0.931313 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000149.txt: -------------------------------------------------------------------------------- 1 | 0.896806 -0.442424 0 -0.454382 2 | 0 0 -1 -0 3 | 0.442424 0.896806 0 -0.921079 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000150.txt: -------------------------------------------------------------------------------- 1 | 0.919778 -0.392439 0 -0.527339 2 | 0 0 -1 -0 3 | 0.392439 0.919778 0 -0.90451 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000151.txt: -------------------------------------------------------------------------------- 1 | 0.940887 -0.338719 0 -0.602168 2 | 0 0 -1 -0 3 | 0.338719 0.940887 0 -0.881047 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000152.txt: -------------------------------------------------------------------------------- 1 | 0.959569 -0.281474 0 -0.678095 2 | 0 0 -1 -0 3 | 0.281474 0.959569 0 -0.850242 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000153.txt: -------------------------------------------------------------------------------- 1 | 0.97526 -0.221059 0 -0.754201 2 | 0 0 -1 -0 3 | 0.221059 0.97526 0 -0.811807 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000154.txt: 
-------------------------------------------------------------------------------- 1 | 0.987441 -0.157991 0 -0.82945 2 | 0 0 -1 -0 3 | 0.157991 0.987441 0 -0.765661 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/Camera_Trajectory/camera_pose_frame000155.txt: -------------------------------------------------------------------------------- 1 | 0.995673 -0.0929295 0 -0.902743 2 | 0 0 -1 -0 3 | 0.0929295 0.995673 0 -0.711972 4 | 0 0 0 1 5 | -------------------------------------------------------------------------------- /input/another_input_panorama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/input/another_input_panorama.png -------------------------------------------------------------------------------- /input/input_panorama.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/input/input_panorama.png -------------------------------------------------------------------------------- /modules/equilib/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from modules.equilib.cube2equi.base import Cube2Equi, cube2equi 4 | from modules.equilib.equi2cube.base import Equi2Cube, equi2cube 5 | from modules.equilib.equi2equi.base import Equi2Equi, equi2equi 6 | from modules.equilib.equi2pers.base import Equi2Pers, equi2pers 7 | 8 | __all__ = [ 9 | "Cube2Equi", 10 | "Equi2Cube", 11 | "Equi2Equi", 12 | "Equi2Pers", 13 | "cube2equi", 14 | "equi2cube", 15 | "equi2equi", 16 | "equi2pers", 17 | ] 18 | -------------------------------------------------------------------------------- /modules/equilib/cube2equi/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /modules/equilib/equi2cube/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /modules/equilib/equi2cube/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Dict, List, Union 4 | 5 | import numpy as np 6 | 7 | import torch 8 | 9 | from .numpy import run as run_numpy 10 | from .torch import run as run_torch 11 | 12 | __all__ = ["Equi2Cube", "equi2cube"] 13 | 14 | ArrayLike = Union[np.ndarray, torch.Tensor] 15 | Rot = Union[Dict[str, float], List[Dict[str, float]]] 16 | CubeMaps = Union[ 17 | # single/batch 'horizon' or 'dice' 18 | np.ndarray, 19 | torch.Tensor, 20 | # single 'list' 21 | List[np.ndarray], 22 | List[torch.Tensor], 23 | # batch 'list' 24 | List[List[np.ndarray]], 25 | List[List[torch.Tensor]], 26 | # single 'dict' 27 | Dict[str, np.ndarray], 28 | Dict[str, np.ndarray], 29 | # batch 'dict' 30 | List[Dict[str, np.ndarray]], 31 | List[Dict[str, np.ndarray]], 32 | ] 33 | 34 | 35 | class Equi2Cube(object): 36 | """ 37 | params: 38 | - w_face (int): cube face width 39 | - cube_format (str): ("dice", "horizon", "dict", "list") 40 | - mode (str) 41 | - z_down (bool) 42 | 43 | inputs: 44 | - equi (np.ndarray, torch.Tensor) 45 | - rots (dict, list[dict]): {"roll", "pitch", 
"yaw"} 46 | 47 | returns: 48 | - cube (np.ndarray, torch.Tensor, list, dict) 49 | """ 50 | 51 | def __init__( 52 | self, 53 | w_face: int, 54 | cube_format: str, 55 | z_down: bool = False, 56 | mode: str = "bilinear", 57 | ) -> None: 58 | self.w_face = w_face 59 | self.cube_format = cube_format 60 | self.z_down = z_down 61 | self.mode = mode 62 | 63 | def __call__(self, equi: ArrayLike, rots: Rot) -> CubeMaps: 64 | return equi2cube( 65 | equi=equi, 66 | rots=rots, 67 | w_face=self.w_face, 68 | cube_format=self.cube_format, 69 | z_down=self.z_down, 70 | mode=self.mode, 71 | ) 72 | 73 | 74 | def equi2cube( 75 | equi: ArrayLike, 76 | rots: Rot, 77 | w_face: int, 78 | cube_format: str, 79 | z_down: bool = False, 80 | mode: str = "bilinear", 81 | **kwargs, 82 | ) -> CubeMaps: 83 | """ 84 | params: 85 | - equi (np.ndarray, torch.Tensor) 86 | - rot (dict, list[dict]): {"roll", "pitch", "yaw"} 87 | - w_face (int): cube face width 88 | - cube_format (str): ("dice", "horizon", "dict", "list") 89 | - z_down (bool) 90 | - mode (str) 91 | 92 | returns: 93 | - cube (np.ndarray, torch.Tensor, dict, list) 94 | 95 | """ 96 | 97 | _type = None 98 | if isinstance(equi, np.ndarray): 99 | _type = "numpy" 100 | elif torch.is_tensor(equi): 101 | _type = "torch" 102 | else: 103 | raise ValueError 104 | 105 | is_single = False 106 | if len(equi.shape) == 3 and isinstance(rots, dict): 107 | # probably the input was a single image 108 | equi = equi[None, ...] 109 | rots = [rots] 110 | is_single = True 111 | elif len(equi.shape) == 3: 112 | # probably a grayscale image 113 | equi = equi[:, None, ...] 114 | 115 | assert isinstance(rots, list), "ERR: rots is not a list" 116 | if _type == "numpy": 117 | out = run_numpy( 118 | equi=equi, 119 | rots=rots, 120 | w_face=w_face, 121 | cube_format=cube_format, 122 | z_down=z_down, 123 | mode=mode, 124 | **kwargs, 125 | ) 126 | elif _type == "torch": 127 | out = run_torch( 128 | equi=equi, 129 | rots=rots, 130 | w_face=w_face, 131 | cube_format=cube_format, 132 | z_down=z_down, 133 | mode=mode, 134 | **kwargs, 135 | ) 136 | else: 137 | raise ValueError 138 | 139 | # make sure that the output batch dim is removed if it's only a single cubemap 140 | if is_single: 141 | out = out[0] 142 | 143 | return out 144 | -------------------------------------------------------------------------------- /modules/equilib/equi2equi/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /modules/equilib/equi2equi/base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Dict, List, Optional, Union 4 | 5 | import numpy as np 6 | 7 | import torch 8 | 9 | from .numpy import run as run_numpy 10 | from .torch import run as run_torch 11 | 12 | __all__ = ["Equi2Equi", "equi2equi"] 13 | 14 | ArrayLike = Union[np.ndarray, torch.Tensor] 15 | Rot = Union[Dict[str, float], List[Dict[str, float]]] 16 | 17 | 18 | class Equi2Equi(object): 19 | """ 20 | params: 21 | - w_out, h_out (optional int): equi image size 22 | - sampling_method (str): defaults to "default" 23 | - mode (str): interpolation mode, defaults to "bilinear" 24 | - z_down (bool) 25 | 26 | input params: 27 | - src (np.ndarray, torch.Tensor) 28 | - rots (dict, list[dict]) 29 | 30 | return: 31 | - equi (np.ndarray, torch.Tensor) 32 | """ 33 | 34 | def __init__( 35 | self, 36 | height: Optional[int] = None, 37 | 
width: Optional[int] = None, 38 | mode: str = "bilinear", 39 | z_down: bool = False, 40 | ) -> None: 41 | self.height = height 42 | self.width = width 43 | self.mode = mode 44 | self.z_down = z_down 45 | 46 | def __call__(self, src: ArrayLike, rots: Rot, **kwargs) -> ArrayLike: 47 | return equi2equi( 48 | src=src, rots=rots, mode=self.mode, z_down=self.z_down, **kwargs 49 | ) 50 | 51 | 52 | def equi2equi( 53 | src: ArrayLike, 54 | rots: Rot, 55 | mode: str = "bilinear", 56 | z_down: bool = False, 57 | height: Optional[int] = None, 58 | width: Optional[int] = None, 59 | **kwargs, 60 | ) -> ArrayLike: 61 | """ 62 | params: 63 | - src 64 | - rots 65 | - mode (str): interpolation mode, defaults to "bilinear" 66 | - z_down (bool) 67 | - height, width (optional int): output image size 68 | 69 | returns: 70 | - out 71 | 72 | """ 73 | 74 | _type = None 75 | if isinstance(src, np.ndarray): 76 | _type = "numpy" 77 | elif torch.is_tensor(src): 78 | _type = "torch" 79 | else: 80 | raise ValueError 81 | 82 | is_single = False 83 | if len(src.shape) == 3 and isinstance(rots, dict): 84 | # probably the input was a single image 85 | src = src[None, ...] 86 | rots = [rots] 87 | is_single = True 88 | elif len(src.shape) == 3: 89 | # probably a grayscale image 90 | src = src[:, None, ...] 91 | 92 | assert isinstance(rots, list), "ERR: rots is not a list" 93 | if _type == "numpy": 94 | out = run_numpy( 95 | src=src, 96 | rots=rots, 97 | mode=mode, 98 | z_down=z_down, 99 | height=height, 100 | width=width, 101 | **kwargs, 102 | ) 103 | elif _type == "torch": 104 | out = run_torch( 105 | src=src, 106 | rots=rots, 107 | mode=mode, 108 | z_down=z_down, 109 | height=height, 110 | width=width, 111 | **kwargs, 112 | ) 113 | else: 114 | raise ValueError 115 | 116 | # make sure that the output batch dim is removed if it's only a single image 117 | if is_single: 118 | out = out.squeeze(0) 119 | 120 | return out 121 | -------------------------------------------------------------------------------- /modules/equilib/equi2pers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .numpy import grid_sample as numpy_grid_sample 4 | from .torch import grid_sample as torch_grid_sample 5 | 6 | __all__ = ["numpy_grid_sample", "torch_grid_sample"] 7 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/cpp/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/cpp/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from setuptools import Extension, setup 4 | 5 | from torch.utils import cpp_extension 6 | 7 | 8 | setup( 9 | name="grid_sample_cpp", 10 | ext_modules=[ 11 | cpp_extension.CppExtension("grid_sample", ["grid_sample.cpp"]) 12 | ], 13 | cmdclass={"build_ext": cpp_extension.BuildExtension}, 14 | ) 15 | 16 | Extension( 17 | name="grid_sample_cpp", 18 | sources=["grid_sample.cpp"], 19 | include_dirs=cpp_extension.include_paths(), 20 | language="c++", 21 | ) 22 | 
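Taken together, the base modules above expose a small functional API (cube2equi, equi2cube, equi2equi, equi2pers). A minimal usage sketch for the equi2cube entry point, assuming a single CHW float tensor panorama; the face width, rotation values, and cube format below are illustrative, not prescribed by this repo.

import torch
from modules.equilib import equi2cube

# Minimal sketch (illustrative values): turn an equirectangular panorama into cube faces.
equi = torch.rand(3, 1024, 2048)                  # single panorama, (C, H, W)
rots = {"roll": 0.0, "pitch": 0.0, "yaw": 0.0}    # single-image rotation dict
cube = equi2cube(equi=equi, rots=rots, w_face=256, cube_format="dict", mode="bilinear")
# with cube_format="dict", `cube` maps face names to (3, 256, 256) tensors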
-------------------------------------------------------------------------------- /modules/equilib/grid_sample/numpy/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from .grid_sample import grid_sample 4 | 5 | __all__ = ["grid_sample"] 6 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/numpy/bicubic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | 5 | __all__ = ["bicubic"] 6 | 7 | 8 | def kernel( 9 | s: np.ndarray, a: float = -0.75, dtype: np.dtype = np.dtype(np.float32) 10 | ) -> np.ndarray: 11 | out = np.zeros_like(s, dtype) 12 | s = np.abs(s) 13 | mask1 = np.logical_and(0 <= s, s <= 1) 14 | mask2 = np.logical_and(1 < s, s <= 2) 15 | out[mask1] = (a + 2) * (s[mask1] ** 3) - (a + 3) * (s[mask1] ** 2) + 1 16 | out[mask2] = ( 17 | a * (s[mask2] ** 3) 18 | - (5 * a) * (s[mask2] ** 2) 19 | + (8 * a) * s[mask2] 20 | - 4 * a 21 | ) 22 | return out 23 | 24 | 25 | def bicubic(img: np.ndarray, grid: np.ndarray, out: np.ndarray) -> np.ndarray: 26 | """Bicubic Interpolation""" 27 | 28 | b_in, c_in, h_in, w_in = img.shape 29 | b_out, _, h_out, w_out = out.shape 30 | dtype = out.dtype 31 | # NOTE: this is hardcoded since pytorch is also -0.75 32 | a = -0.75 33 | 34 | int_dtype = np.dtype(np.int64) 35 | min_grid = np.floor(grid).astype(int_dtype) 36 | 37 | d1 = 1 + (grid - min_grid) # (b, 2, h, w) 38 | d2 = grid - min_grid 39 | d3 = min_grid + 1 - grid 40 | d4 = min_grid + 2 - grid 41 | 42 | c1 = (grid - d1).astype(int_dtype) # (b, 2, h, w) 43 | c2 = (grid - d2).astype(int_dtype) 44 | c3 = (grid + d3).astype(int_dtype) 45 | c4 = (grid + d4).astype(int_dtype) 46 | 47 | c1[:, 0, ...] %= h_in 48 | c1[:, 1, ...] %= w_in 49 | c2[:, 0, ...] %= h_in 50 | c2[:, 1, ...] %= w_in 51 | c3[:, 0, ...] %= h_in 52 | c3[:, 1, ...] %= w_in 53 | c4[:, 0, ...] %= h_in 54 | c4[:, 1, ...] %= w_in 55 | 56 | # FIXME: this part is slow 57 | k1 = kernel(d1, a, dtype) # (b, 2, h, w) 58 | k2 = kernel(d2, a, dtype) 59 | k3 = kernel(d3, a, dtype) 60 | k4 = kernel(d4, a, dtype) 61 | 62 | mat_l = np.stack( 63 | [k1[:, 1, ...], k2[:, 1, ...], k3[:, 1, ...], k4[:, 1, ...]], axis=-1 64 | ) 65 | mat_r = np.stack( 66 | [k1[:, 0, ...], k2[:, 0, ...], k3[:, 0, ...], k4[:, 0, ...]], axis=-1 67 | ) 68 | 69 | # FIXME: this part is slow 70 | mat_m = np.empty((b_out, c_in, h_out, w_out, 4, 4), dtype=dtype) 71 | for b in range(b_out): 72 | y1 = c1[b, 0, ...] # (h, w) 73 | y2 = c2[b, 0, ...] 74 | y3 = c3[b, 0, ...] 75 | y4 = c4[b, 0, ...] 76 | 77 | x1 = c1[b, 1, ...] 78 | x2 = c2[b, 1, ...] 79 | x3 = c3[b, 1, ...] 80 | x4 = c4[b, 1, ...] 81 | 82 | mat_m_x1 = np.stack( 83 | [ 84 | img[b][:, y1, x1], # (c, h, w) 85 | img[b][:, y2, x1], 86 | img[b][:, y3, x1], 87 | img[b][:, y4, x1], 88 | ], 89 | axis=-1, 90 | ) 91 | mat_m_x2 = np.stack( 92 | [ 93 | img[b][:, y1, x2], 94 | img[b][:, y2, x2], 95 | img[b][:, y3, x2], 96 | img[b][:, y4, x2], 97 | ], 98 | axis=-1, 99 | ) 100 | mat_m_x3 = np.stack( 101 | [ 102 | img[b][:, y1, x3], 103 | img[b][:, y2, x3], 104 | img[b][:, y3, x3], 105 | img[b][:, y4, x3], 106 | ], 107 | axis=-1, 108 | ) 109 | mat_m_x4 = np.stack( 110 | [ 111 | img[b][:, y1, x4], 112 | img[b][:, y2, x4], 113 | img[b][:, y3, x4], 114 | img[b][:, y4, x4], 115 | ], 116 | axis=-1, 117 | ) 118 | 119 | mat_m[b, ...] 
= np.stack( 120 | [mat_m_x1, mat_m_x2, mat_m_x3, mat_m_x4], axis=-2 121 | ) 122 | 123 | mat_l = mat_l[:, np.newaxis, ..., np.newaxis, :] 124 | mat_r = mat_r[:, np.newaxis, ..., np.newaxis] 125 | out = (mat_l @ mat_m @ mat_r).squeeze(-1).squeeze(-1) 126 | 127 | return out 128 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/numpy/bilinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | 5 | __all__ = ["bilinear"] 6 | 7 | 8 | def interp(v0, v1, d, L): 9 | return v0 * (1 - d) / L + v1 * d / L 10 | 11 | 12 | def interp2d(q00, q10, q01, q11, dy, dx): 13 | f0 = interp(q00, q01, dx, 1) 14 | f1 = interp(q10, q11, dx, 1) 15 | return interp(f0, f1, dy, 1) 16 | 17 | 18 | def bilinear(img: np.ndarray, grid: np.ndarray, out: np.ndarray) -> np.ndarray: 19 | """Bilinear Interpolation 20 | 21 | NOTE: asserts are removed 22 | """ 23 | 24 | b, _, h, w = img.shape 25 | 26 | min_grid = np.floor(grid).astype(np.int64) 27 | max_grid = min_grid + 1 28 | d_grid = grid - min_grid 29 | 30 | min_grid[:, 0, :, :] %= h 31 | min_grid[:, 1, :, :] %= w 32 | max_grid[:, 0, :, :] %= h 33 | max_grid[:, 1, :, :] %= w 34 | 35 | # FIXME: any way to do efficient batch? 36 | for i in range(b): 37 | dy = d_grid[i, 0, ...] 38 | dx = d_grid[i, 1, ...] 39 | min_ys = min_grid[i, 0, ...] 40 | min_xs = min_grid[i, 1, ...] 41 | max_ys = max_grid[i, 0, ...] 42 | max_xs = max_grid[i, 1, ...] 43 | 44 | p00 = img[i][:, min_ys, min_xs] 45 | p10 = img[i][:, max_ys, min_xs] 46 | p01 = img[i][:, min_ys, max_xs] 47 | p11 = img[i][:, max_ys, max_xs] 48 | 49 | out[i, ...] = interp2d(p00, p10, p01, p11, dy, dx) 50 | 51 | return out 52 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/numpy/grid_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import warnings 4 | 5 | import numpy as np 6 | 7 | from .bicubic import bicubic 8 | from .bilinear import bilinear 9 | from .nearest import nearest 10 | 11 | 12 | def grid_sample( 13 | img: np.ndarray, grid: np.ndarray, out: np.ndarray, mode: str = "bilinear" 14 | ) -> np.ndarray: 15 | """Numpy grid sampling algorithm 16 | 17 | params: 18 | - img (np.ndarray) 19 | - grid (np.ndarray) 20 | - out (np.ndarray) 21 | - mode (str): ('bilinear', 'bicubic', 'nearest') 22 | 23 | return: 24 | - out (np.ndarray) 25 | 26 | NOTE: 27 | - assumes that `img`, `grid`, and `out` have the same dimension of 28 | (batch, channel, height, width). 29 | - channel for `grid` should be 2 (yx) 30 | 31 | """ 32 | 33 | if mode == "nearest": 34 | out = nearest(img, grid, out) 35 | elif mode == "bilinear": 36 | out = bilinear(img, grid, out) 37 | elif mode == "bicubic": 38 | # FIXME: bicubic algorithm is not perfect yet 39 | warnings.warn( 40 | "Bicubic interpolation is not perfect (especially when upsampling). Use with care!" 
41 | ) 42 | out = bicubic(img, grid, out) 43 | else: 44 | raise ValueError(f"ERR: {mode} is not supported") 45 | 46 | return out 47 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/numpy/nearest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | 5 | __all__ = ["nearest"] 6 | 7 | 8 | def nearest(img: np.ndarray, grid: np.ndarray, out: np.ndarray) -> np.ndarray: 9 | """Nearest Neightbor Sampling""" 10 | 11 | b, _, h, w = img.shape 12 | 13 | round_grid = np.rint(grid).astype(np.int64) 14 | round_grid[:, 0, ...] %= h 15 | round_grid[:, 1, ...] %= w 16 | 17 | for i in range(b): 18 | y = round_grid[i, 0, ...] 19 | x = round_grid[i, 1, ...] 20 | out[i, ...] = img[i][:, y, x] 21 | 22 | return out 23 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .grid_sample import grid_sample 4 | 5 | __all__ = ["grid_sample"] 6 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/bicubic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | from modules.equilib.torch_utils.func import get_device 6 | 7 | __all__ = ["bicubic"] 8 | 9 | 10 | def kernel(s, a): 11 | out = torch.zeros_like(s) 12 | s = torch.abs(s) 13 | mask1 = torch.logical_and(0 <= s, s <= 1) 14 | mask2 = torch.logical_and(1 < s, s <= 2) 15 | out[mask1] = (a + 2) * (s[mask1] ** 3) - (a + 3) * (s[mask1] ** 2) + 1 16 | out[mask2] = ( 17 | a * (s[mask2] ** 3) 18 | - (5 * a) * (s[mask2] ** 2) 19 | + (8 * a) * s[mask2] 20 | - 4 * a 21 | ) 22 | return out 23 | 24 | 25 | def bicubic( 26 | img: torch.Tensor, grid: torch.Tensor, out: torch.Tensor 27 | ) -> torch.Tensor: 28 | 29 | # FIXME: out being initialized doesn't really matter? 30 | 31 | b_in, c_in, h_in, w_in = img.shape 32 | b_out, _, h_out, w_out = out.shape 33 | dtype = out.dtype 34 | device = get_device(out) 35 | 36 | a = -0.75 37 | 38 | int_dtype = torch.int64 39 | 40 | min_grid = torch.floor(grid).type(int_dtype) 41 | 42 | d1 = 1 + (grid - min_grid) 43 | d2 = grid - min_grid 44 | d3 = min_grid + 1 - grid 45 | d4 = min_grid + 2 - grid 46 | 47 | c1 = (grid - d1).type(int_dtype) 48 | c2 = (grid - d2).type(int_dtype) 49 | c3 = (grid + d3).type(int_dtype) 50 | c4 = (grid + d4).type(int_dtype) 51 | 52 | c1[:, 0, ...] %= h_in 53 | c1[:, 1, ...] %= w_in 54 | c2[:, 0, ...] %= h_in 55 | c2[:, 1, ...] %= w_in 56 | c3[:, 0, ...] %= h_in 57 | c3[:, 1, ...] %= w_in 58 | c4[:, 0, ...] %= h_in 59 | c4[:, 1, ...] %= w_in 60 | 61 | k1 = kernel(d1, a).type(dtype) 62 | k2 = kernel(d2, a).type(dtype) 63 | k3 = kernel(d3, a).type(dtype) 64 | k4 = kernel(d4, a).type(dtype) 65 | 66 | mat_l = torch.stack( 67 | [k1[:, 1, ...], k2[:, 1, ...], k3[:, 1, ...], k4[:, 1, ...]], dim=-1 68 | ).to(device) 69 | mat_r = torch.stack( 70 | [k1[:, 0, ...], k2[:, 0, ...], k3[:, 0, ...], k4[:, 0, ...]], dim=-1 71 | ).to(device) 72 | 73 | mat_m = torch.empty( 74 | (b_out, c_in, h_out, w_out, 4, 4), dtype=dtype, device=device 75 | ) 76 | for b in range(b_out): 77 | y1 = c1[b, 0, ...] # (h, w) 78 | y2 = c2[b, 0, ...] 79 | y3 = c3[b, 0, ...] 80 | y4 = c4[b, 0, ...] 81 | 82 | x1 = c1[b, 1, ...] 83 | x2 = c2[b, 1, ...] 84 | x3 = c3[b, 1, ...] 
85 | x4 = c4[b, 1, ...] 86 | 87 | mat_m_x1 = torch.stack( 88 | [ 89 | img[b][:, y1, x1], # (c, h, w) 90 | img[b][:, y2, x1], 91 | img[b][:, y3, x1], 92 | img[b][:, y4, x1], 93 | ], 94 | dim=-1, 95 | ) 96 | mat_m_x2 = torch.stack( 97 | [ 98 | img[b][:, y1, x2], 99 | img[b][:, y2, x2], 100 | img[b][:, y3, x2], 101 | img[b][:, y4, x2], 102 | ], 103 | dim=-1, 104 | ) 105 | mat_m_x3 = torch.stack( 106 | [ 107 | img[b][:, y1, x3], 108 | img[b][:, y2, x3], 109 | img[b][:, y3, x3], 110 | img[b][:, y4, x3], 111 | ], 112 | dim=-1, 113 | ) 114 | mat_m_x4 = torch.stack( 115 | [ 116 | img[b][:, y1, x4], 117 | img[b][:, y2, x4], 118 | img[b][:, y3, x4], 119 | img[b][:, y4, x4], 120 | ], 121 | dim=-1, 122 | ) 123 | 124 | mat_m[b, ...] = torch.stack( 125 | [mat_m_x1, mat_m_x2, mat_m_x3, mat_m_x4], dim=-2 126 | ) 127 | 128 | mat_l = mat_l.unsqueeze(1).unsqueeze(-2) 129 | mat_r = mat_r.unsqueeze(1).unsqueeze(-1) 130 | out = (mat_l @ mat_m @ mat_r).squeeze(-1).squeeze(-1) 131 | 132 | return out 133 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/bilinear.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | __all__ = ["bilinear"] 6 | 7 | 8 | def linear_interp(v0, v1, d, L): 9 | return v0 * (1 - d) / L + v1 * d / L 10 | 11 | 12 | def interp2d(q00, q10, q01, q11, dy, dx): 13 | f0 = linear_interp(q00, q01, dx, 1) 14 | f1 = linear_interp(q10, q11, dx, 1) 15 | return linear_interp(f0, f1, dy, 1) 16 | 17 | 18 | def bilinear( 19 | img: torch.Tensor, grid: torch.Tensor, out: torch.Tensor 20 | ) -> torch.Tensor: 21 | 22 | b, _, h, w = img.shape 23 | 24 | min_grid = torch.floor(grid).type(torch.int64) 25 | max_grid = min_grid + 1 26 | d_grid = grid - min_grid 27 | 28 | min_grid[:, 0, :, :] %= h 29 | min_grid[:, 1, :, :] %= w 30 | max_grid[:, 0, :, :] %= h 31 | max_grid[:, 1, :, :] %= w 32 | 33 | # FIXME: any way to do efficient batch? 34 | for i in range(b): 35 | dy = d_grid[i, 0, ...] 36 | dx = d_grid[i, 1, ...] 37 | min_ys = min_grid[i, 0, ...] 38 | min_xs = min_grid[i, 1, ...] 39 | max_ys = max_grid[i, 0, ...] 40 | max_xs = max_grid[i, 1, ...] 41 | 42 | min_ys %= h 43 | min_xs %= w 44 | 45 | p00 = img[i][:, min_ys, min_xs] 46 | p10 = img[i][:, max_ys, min_xs] 47 | p01 = img[i][:, min_ys, max_xs] 48 | p11 = img[i][:, max_ys, max_xs] 49 | 50 | out[i, ...]
= interp2d(p00, p10, p01, p11, dy, dx) 51 | 52 | return out 53 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/grid_sample.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from typing import Optional 4 | import warnings 5 | 6 | import torch 7 | 8 | from .native import native_bicubic, native_bilinear, native_nearest 9 | from .nearest import nearest 10 | from .bilinear import bilinear 11 | from .bicubic import bicubic 12 | 13 | DTYPES = (torch.uint8, torch.float16, torch.float32, torch.float64) 14 | 15 | 16 | def grid_sample( 17 | img: torch.Tensor, 18 | grid: torch.Tensor, 19 | out: Optional[torch.Tensor] = None, 20 | mode: str = "bilinear", 21 | backend: str = "native", 22 | ) -> torch.Tensor: 23 | """Torch grid sampling algorithm 24 | 25 | params: 26 | - img (torch.Tensor) 27 | - grid (torch.Tensor) 28 | - out (Optional[torch.Tensor]): defaults to None 29 | - mode (str): ('bilinear', 'bicubic', 'nearest') 30 | - backend (str): ('native', 'pure') 31 | 32 | return: 33 | - out (torch.Tensor) 34 | 35 | NOTE: for `backend`, `pure` is relatively efficient since grid doesn't need 36 | to be on the same device as the `img`. However, `native` is faster. 37 | 38 | NOTE: for `pure` backends, we need to pass reference to `out`. 39 | 40 | NOTE: for `native` backends, we should not pass anything for `out` 41 | 42 | """ 43 | 44 | if backend == "native": 45 | if out is not None: 46 | # NOTE: out is created 47 | warnings.warn( 48 | "don't need to pass preallocated `out` to `grid_sample`" 49 | ) 50 | assert img.device == grid.device, ( 51 | f"ERR: when using {backend}, the devices of `img` and `grid` need " 52 | "to be on the same device" 53 | ) 54 | if mode == "nearest": 55 | out = native_nearest(img, grid) 56 | elif mode == "bilinear": 57 | out = native_bilinear(img, grid) 58 | elif mode == "bicubic": 59 | out = native_bicubic(img, grid) 60 | else: 61 | raise ValueError(f"ERR: {mode} is not supported") 62 | elif backend == "pure": 63 | # NOTE: img and grid can be on different devices, but grid should be on the cpu 64 | # FIXME: since bilinear implementation depends on `grid` being on device, I'm removing 65 | # this warning and will put `grid` onto the same device until a fix is found 66 | # if grid.device.type == "cuda": 67 | # warnings.warn("input `grid` should be on the cpu, but got a cuda tensor") 68 | assert ( 69 | out is not None 70 | ), "ERR: need to pass reference to `out`, but got None" 71 | assert img.device == grid.device, ( 72 | f"ERR: when using {backend}, the devices of `img` and `grid` need " 73 | "to be on the same device" 74 | ) 75 | if mode == "nearest": 76 | out = nearest(img, grid, out) 77 | elif mode == "bilinear": 78 | out = bilinear(img, grid, out) 79 | elif mode == "bicubic": 80 | out = bicubic(img, grid, out) 81 | else: 82 | raise ValueError(f"ERR: {mode} is not supported") 83 | else: 84 | raise ValueError(f"ERR: {backend} is not supported") 85 | 86 | return out 87 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/native.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from functools import partial 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | __all__ = ["native", "native_bicubic", "native_bilinear", "native_nearest"] 9 | 10 | 11 | def native( 12 | img: torch.Tensor, grid: torch.Tensor, mode:
str = "bilinear" 13 | ) -> torch.Tensor: 14 | """Torch Grid Sample (default) 15 | 16 | - Uses `torch.nn.functional.grid_sample` 17 | - By far the best way to sample 18 | 19 | params: 20 | - img (torch.Tensor): Tensor[B, C, H, W] or Tensor[C, H, W] 21 | - grid (torch.Tensor): Tensor[B, 2, H, W] or Tensor[2, H, W] 22 | - device (int or str): torch.device 23 | - mode (str): (`bilinear`, `bicubic`, `nearest`) 24 | 25 | returns: 26 | - out (torch.Tensor): Tensor[B, C, H, W] or Tensor[C, H, W] 27 | where H, W are grid size 28 | 29 | NOTE: `img` and `grid` needs to be on the same device 30 | 31 | NOTE: `img` and `grid` is somehow mutated (inplace?), so if you need 32 | to reuse `img` and `grid` somewhere else, use `.clone()` before 33 | passing it to this function 34 | 35 | NOTE: this method is different from other grid sampling that 36 | the padding cannot be wrapped. There might be pixel inaccuracies 37 | when sampling from the boundaries of the image (the seam). 38 | 39 | I hope later on, we can add wrap padding to this since the function 40 | is super fast. 41 | 42 | """ 43 | 44 | assert ( 45 | grid.dtype == img.dtype 46 | ), "ERR: img and grid should have the same dtype" 47 | 48 | _, _, h, w = img.shape 49 | 50 | # grid in shape: (batch, channel, h_out, w_out) 51 | # grid out shape: (batch, h_out, w_out, channel) 52 | grid = grid.permute(0, 2, 3, 1) 53 | 54 | """Preprocess for grid_sample 55 | normalize grid -1 ~ 1 56 | 57 | assumptions: 58 | - values of `grid` is between `0 ~ (h-1)` and `0 ~ (w-1)` 59 | - input of `grid_sample` need to be between `-1 ~ 1` 60 | - maybe lose some precision when we map the values (int to float)? 61 | 62 | mapping (e.g. mapping of height): 63 | 1. 0 <= y <= (h-1) 64 | 2. -1/2 <= y' <= 1/2 <- y' = y/(h-1) - 1/2 65 | 3. -1 <= y" <= 1 <- y" = 2y' 66 | """ 67 | 68 | # FIXME: this is not necessary when we are already preprocessing grid before 69 | # this method is called 70 | # grid[..., 0] %= h 71 | # grid[..., 1] %= w 72 | 73 | norm_uj = torch.clamp(2 * grid[..., 0] / (h - 1) - 1, -1, 1) 74 | norm_ui = torch.clamp(2 * grid[..., 1] / (w - 1) - 1, -1, 1) 75 | 76 | # reverse: grid sample takes xy, not (height, width) 77 | grid[..., 0] = norm_ui 78 | grid[..., 1] = norm_uj 79 | 80 | # grid.requires_grad = True 81 | 82 | out = F.grid_sample( 83 | img, 84 | grid, 85 | mode=mode, 86 | # use center of pixel instead of corner 87 | align_corners=True, 88 | # padding mode defaults to 'zeros' and there is no 'wrapping' mode 89 | padding_mode="reflection", 90 | ) 91 | 92 | return out 93 | 94 | 95 | # aliases 96 | native_nearest = partial(native, mode="nearest") 97 | native_bilinear = partial(native, mode="bilinear") 98 | native_bicubic = partial(native, mode="bicubic") 99 | -------------------------------------------------------------------------------- /modules/equilib/grid_sample/torch/nearest.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | __all__ = ["nearest"] 6 | 7 | 8 | def nearest( 9 | img: torch.Tensor, grid: torch.Tensor, out: torch.Tensor 10 | ) -> torch.Tensor: 11 | """Nearest Neighbor Interpolation 12 | 13 | Merit of using this nearest instead is that the grid doesn't need to be a 14 | cuda tensor. Although it is a little bit slow since it is iterating batches 15 | """ 16 | 17 | b, _, h, w = img.shape 18 | 19 | round_grid = torch.round(grid).type(torch.int64) 20 | round_grid[:, 0, ...] %= h 21 | round_grid[:, 1, ...] 
%= w 22 | 23 | # FIXME: find a better way of sampling batches 24 | for i in range(b): 25 | y = round_grid[i, 0, :, :] 26 | x = round_grid[i, 1, :, :] 27 | out[i, ...] = img[i][:, y, x] 28 | 29 | return out 30 | -------------------------------------------------------------------------------- /modules/equilib/numpy_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .grid import create_grid, create_normalized_grid, create_xyz_grid 4 | from .intrinsic import create_intrinsic_matrix 5 | from .rotation import ( 6 | create_global2camera_rotation_matrix, 7 | create_rotation_matrices, 8 | create_rotation_matrix, 9 | create_rotation_matrix_at_once, 10 | ) 11 | 12 | __all__ = [ 13 | "create_grid", 14 | "create_intrinsic_matrix", 15 | "create_global2camera_rotation_matrix", 16 | "create_normalized_grid", 17 | "create_rotation_matrices", 18 | "create_rotation_matrix", 19 | "create_rotation_matrix_at_once", 20 | "create_xyz_grid", 21 | ] 22 | -------------------------------------------------------------------------------- /modules/equilib/numpy_utils/intrinsic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import numpy as np 4 | 5 | 6 | def create_intrinsic_matrix( 7 | height: int, 8 | width: int, 9 | fov_x: float, 10 | skew: float, 11 | dtype: np.dtype = np.dtype(np.float32), 12 | ) -> np.ndarray: 13 | """Create intrinsic matrix 14 | 15 | params: 16 | - height, width (int) 17 | - fov_x (float): make sure it's in degrees 18 | - skew (float): 0.0 19 | - dtype (np.dtype): np.float32 20 | 21 | returns: 22 | - K (np.ndarray): 3x3 intrinsic matrix 23 | """ 24 | 25 | # perspective projection (focal length) 26 | f = width / (2.0 * np.tan(np.radians(fov_x).astype(dtype) / 2.0)) 27 | # transform between camera frame and pixel coordinates 28 | K = np.array( 29 | [[f, skew, width / 2], [0.0, f, height / 2], [0.0, 0.0, 1.0]], 30 | dtype=dtype, 31 | ) 32 | 33 | return K 34 | -------------------------------------------------------------------------------- /modules/equilib/torch_utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | from .grid import create_grid, create_normalized_grid, create_xyz_grid 4 | from .intrinsic import create_intrinsic_matrix, pi 5 | from .rotation import ( 6 | create_global2camera_rotation_matrix, 7 | create_rotation_matrices, 8 | create_rotation_matrix, 9 | create_rotation_matrix_at_once, 10 | ) 11 | from .func import get_device, sizeof 12 | 13 | __all__ = [ 14 | "create_global2camera_rotation_matrix", 15 | "create_grid", 16 | "create_intrinsic_matrix", 17 | "create_normalized_grid", 18 | "create_rotation_matrices", 19 | "create_rotation_matrix", 20 | "create_rotation_matrix_at_once", 21 | "create_xyz_grid", 22 | "get_device", 23 | "sizeof", 24 | "pi", 25 | ] 26 | -------------------------------------------------------------------------------- /modules/equilib/torch_utils/func.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | 6 | def sizeof(tensor: torch.Tensor) -> float: 7 | """Get the size of a tensor""" 8 | assert torch.is_tensor(tensor), "ERR: is not tensor" 9 | return tensor.element_size() * tensor.nelement() 10 | 11 | 12 | def get_device(a: torch.Tensor) -> torch.device: 13 | """Get device of a Tensor""" 14 | return torch.device(a.get_device() if 
a.get_device() >= 0 else "cpu") 15 | -------------------------------------------------------------------------------- /modules/equilib/torch_utils/intrinsic.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | 3 | import torch 4 | 5 | pi = torch.Tensor([3.14159265358979323846]) 6 | 7 | 8 | def deg2rad(tensor: torch.Tensor) -> torch.Tensor: 9 | """Function that converts angles from degrees to radians""" 10 | return tensor * pi.to(tensor.device).type(tensor.dtype) / 180.0 11 | 12 | 13 | def create_intrinsic_matrix( 14 | height: int, 15 | width: int, 16 | fov_x: float, 17 | skew: float, 18 | dtype: torch.dtype = torch.float32, 19 | device: torch.device = torch.device("cpu"), 20 | ) -> torch.Tensor: 21 | """Create intrinsic matrix 22 | 23 | params: 24 | - height, width (int) 25 | - fov_x (float): make sure it's in degrees 26 | - skew (float): 0.0 27 | - dtype (torch.dtype): torch.float32 28 | - device (torch.device): torch.device("cpu") 29 | 30 | returns: 31 | - K (torch.tensor): 3x3 intrinsic matrix 32 | """ 33 | f = width / (2 * torch.tan(deg2rad(torch.tensor(fov_x, dtype=dtype)) / 2)) 34 | K = torch.tensor( 35 | [[f, skew, width / 2], [0.0, f, height / 2], [0.0, 0.0, 1.0]], 36 | dtype=dtype, 37 | device=device, 38 | ) 39 | return K 40 | -------------------------------------------------------------------------------- /modules/geo_predictors/PanoFusionDistancePredictor.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import numpy as np 3 | import torch 4 | import trimesh 5 | 6 | from utils.camera_utils import * 7 | from modules.geo_predictors import PanoFusionInvPredictor, PanoFusionNormalPredictor, PanoGeoRefiner, PanoJointPredictor 8 | 9 | import torchvision 10 | from PIL import Image 11 | 12 | class PanoFusionDistance: 13 | def __init__(self): 14 | self.image_path = None 15 | self.ref_distance_path = None 16 | self.ref_normal_path = None 17 | self.ref_geometry_path = None 18 | self.image = None 19 | self.gt_distance = None 20 | self.ref_distance = None 21 | self.ref_normal = None 22 | self.pano_width, self.pano_height = 2048, 1024 23 | self.data_dir = None 24 | self.case_name = 'wp' 25 | 26 | def get_ref_distance(self): 27 | assert self.image is not None 28 | assert self.ref_distance_path is not None 29 | assert self.height > 0 and self.width > 0 30 | 31 | ref_distance = None 32 | if os.path.exists(self.ref_distance_path): 33 | ref_distance = np.load(self.ref_distance_path) 34 | ref_distance = torch.from_numpy(ref_distance.astype(np.float32)).cuda() 35 | else: 36 | distance_predictor = PanoFusionInvPredictor() 37 | ref_distance, _ = distance_predictor(self.image, 38 | torch.zeros([self.height, self.width]), 39 | torch.ones([self.height, self.width])) 40 | return ref_distance 41 | 42 | def get_ref_normal(self): 43 | 44 | normal_predictor = PanoFusionNormalPredictor() 45 | ref_normal = normal_predictor.inpaint_normal(self.image, 46 | torch.ones([self.height, self.width, 3]) / np.sqrt(3.), 47 | torch.ones([self.height, self.width])) 48 | 49 | return ref_normal 50 | 51 | def refine_geometry(self, distance_map, normal_map): 52 | refiner = PanoGeoRefiner() 53 | return refiner.refine(distance_map, normal_map) 54 | 55 | def get_joint_distance_normal(self, init_distance=None, init_mask=None): 56 | 57 | joint_predictor = PanoJointPredictor() 58 | idx = 0 59 | ref_distance, ref_normal = joint_predictor(idx, self.image, 60 | torch.ones([self.pano_height, self.pano_width, 1]), 61 | 
torch.ones([self.pano_height, self.pano_width])) 62 | 63 | return ref_distance, ref_normal 64 | 65 | def normalization(self): 66 | scale = self.ref_distance.max().item() * 1.05 67 | self.ref_distance /= scale 68 | 69 | def save_ref_geometry(self): 70 | if self.ref_distance_path is not None: 71 | np.save(self.ref_distance_path, self.ref_distance.cpu().numpy()) 72 | if self.ref_normal_path is not None: 73 | np.save(self.ref_normal_path, self.ref_normal.cpu().numpy()) 74 | 75 | # Save point cloud 76 | pano_dirs = img_coord_to_pano_direction(img_coord_from_hw(self.height, self.width)) 77 | pts = pano_dirs * self.ref_distance.squeeze()[..., None] 78 | pts = pts.cpu().numpy().reshape(-1, 3) 79 | if self.image is not None: 80 | pcd = trimesh.PointCloud(pts, vertex_colors=self.image.reshape(-1, 3).cpu().numpy()) 81 | else: 82 | pcd = trimesh.PointCloud(pts) 83 | 84 | assert self.ref_geometry_path is not None and self.ref_geometry_path[-4:] == '.ply' 85 | pcd.export(self.ref_geometry_path) 86 | 87 | @torch.no_grad() 88 | def ref_point_cloud(self): 89 | pano_dirs = img_coord_to_pano_direction(img_coord_from_hw(self.height, self.width)) 90 | pts = pano_dirs * self.ref_distance.squeeze()[..., None] 91 | return pts 92 | 93 | 94 | class PanoFusionDistancePredictor(PanoFusionDistance): 95 | def __init__(self): 96 | super().__init__() 97 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 98 | 99 | def predict(self, pano_tensor, init_distance=None, init_mask=None, pano_width=2048, pano_height=1024): 100 | self.pano_width, self.pano_height = pano_width, pano_height 101 | self.image = pano_tensor.cuda() 102 | 103 | self.ref_distance, self.ref_normal = self.get_joint_distance_normal(init_distance, init_mask) 104 | 105 | return self.ref_distance.squeeze(-1) 106 | -------------------------------------------------------------------------------- /modules/geo_predictors/__init__.py: -------------------------------------------------------------------------------- 1 | from modules.geo_predictors.omnidata.omnidata_predictor import OmnidataPredictor 2 | from .pano_fusion_inv_predictor import PanoFusionInvPredictor 3 | from .pano_fusion_normal_predictor import PanoFusionNormalPredictor 4 | from .pano_geo_refiner import PanoGeoRefiner 5 | from .pano_joint_predictor import PanoJointPredictor 6 | -------------------------------------------------------------------------------- /modules/geo_predictors/geo_predictor.py: -------------------------------------------------------------------------------- 1 | 2 | class GeoPredictor: 3 | def __init__(self): 4 | pass 5 | 6 | def inpaint_distance(self, img, ref_distance, mask): 7 | raise NotImplementedError 8 | -------------------------------------------------------------------------------- /modules/geo_predictors/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | 7 | def get_activation(activation): 8 | if activation == 'identity': 9 | return lambda x: x 10 | elif activation == 'relu': 11 | return lambda x: F.relu(x) 12 | else: 13 | raise NotImplementedError 14 | 15 | 16 | class VanillaMLP(nn.Module): 17 | def __init__(self, 18 | dim_in, 19 | dim_out, 20 | n_neurons, 21 | n_hidden_layers, 22 | sphere_init=False, 23 | weight_norm=False, 24 | sphere_init_radius=0.5, 25 | output_activation='identity'): 26 | super().__init__() 27 | self.n_neurons, self.n_hidden_layers = n_neurons, n_hidden_layers 28 | self.sphere_init, 
self.weight_norm = sphere_init, weight_norm 29 | self.sphere_init_radius = sphere_init_radius 30 | self.layers = [self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False), self.make_activation()] 31 | for i in range(self.n_hidden_layers - 1): 32 | self.layers += [self.make_linear(self.n_neurons, self.n_neurons, is_first=False, is_last=False), 33 | self.make_activation()] 34 | self.layers += [self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)] 35 | self.layers = nn.Sequential(*self.layers) 36 | 37 | def forward(self, x): 38 | x = self.layers(x.float()) 39 | return -x 40 | 41 | def make_linear(self, dim_in, dim_out, is_first, is_last): 42 | layer = nn.Linear(dim_in, dim_out, bias=True) # network without bias will degrade quality 43 | if self.sphere_init: 44 | if is_last: 45 | torch.nn.init.constant_(layer.bias, -self.sphere_init_radius) 46 | torch.nn.init.normal_(layer.weight, mean=np.sqrt(np.pi) / np.sqrt(dim_in), std=0.0001) 47 | elif is_first: 48 | torch.nn.init.constant_(layer.bias, 0.0) 49 | torch.nn.init.constant_(layer.weight[:, 3:], 0.0) 50 | torch.nn.init.normal_(layer.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(dim_out)) 51 | else: 52 | torch.nn.init.constant_(layer.bias, 0.0) 53 | torch.nn.init.normal_(layer.weight, 0.0, np.sqrt(2) / np.sqrt(dim_out)) 54 | else: 55 | torch.nn.init.constant_(layer.bias, 0.0) 56 | torch.nn.init.kaiming_uniform_(layer.weight, nonlinearity='relu') 57 | 58 | if self.weight_norm: 59 | layer = nn.utils.weight_norm(layer) 60 | return layer 61 | 62 | def make_activation(self): 63 | if self.sphere_init: 64 | return nn.Softplus(beta=100) 65 | else: 66 | return nn.ReLU(inplace=True) 67 | -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/geo_predictors/omnidata/modules/__init__.py -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/modules/midas/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/geo_predictors/omnidata/modules/midas/__init__.py -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/modules/midas/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device('cpu')) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/modules/midas/dpt_depth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | ): 36 | 37 | super(DPT, self).__init__() 38 | 39 | self.channels_last = channels_last 40 | 41 | hooks = { 42 | "vitb_rn50_384": [0, 1, 8, 11], 43 | "vitb16_384": [2, 5, 8, 11], 44 | "vitl16_384": [5, 11, 17, 23], 45 | } 46 | 47 | # Instantiate backbone and reassemble blocks 48 | self.pretrained, self.scratch = _make_encoder( 49 | backbone, 50 | features, 51 | True, # Set to true of you want to train from scratch, uses ImageNet weights 52 | groups=1, 53 | expand=False, 54 | exportable=False, 55 | hooks=hooks[backbone], 56 | use_readout=readout, 57 | ) 58 | 59 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 60 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 61 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 63 | 64 | self.scratch.output_conv = head 65 | 66 | 67 | def forward(self, x): 68 | if self.channels_last == True: 69 | x.contiguous(memory_format=torch.channels_last) 70 | 71 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 72 | 73 | layer_1_rn = self.scratch.layer1_rn(layer_1) 74 | layer_2_rn = self.scratch.layer2_rn(layer_2) 75 | layer_3_rn = self.scratch.layer3_rn(layer_3) 76 | layer_4_rn = self.scratch.layer4_rn(layer_4) 77 | 78 | path_4 = self.scratch.refinenet4(layer_4_rn) 79 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 80 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 81 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 82 | 83 | out = self.scratch.output_conv(path_1) 84 | 85 | return out 86 | 87 | class DPTDepthModel(DPT): 88 | def __init__(self, path=None, non_negative=True, num_channels=1, **kwargs): 89 | features = kwargs["features"] if "features" in kwargs else 256 90 | 91 | head = nn.Sequential( 92 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 93 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 94 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 95 | nn.ReLU(True), 96 | nn.Conv2d(32, num_channels, kernel_size=1, stride=1, padding=0), 97 | nn.ReLU(True) if non_negative else nn.Identity(), 98 | nn.Identity(), 99 | ) 100 | 101 | super().__init__(head, **kwargs) 102 | 103 | if path is not None: 104 | self.load(path) 105 | 106 | def forward(self, x): 107 | return super().forward(x).squeeze(dim=1) 108 | 
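The DPTDepthModel class above is the backbone that the omnidata predictors load their checkpoints into. A minimal construction sketch, mirroring the OmnidataPredictor code further below; the checkpoint path and the 'state_dict' prefix-stripping are taken from that code and should be treated as illustrative rather than the only valid setup.

import torch
from modules.geo_predictors.omnidata.modules.midas.dpt_depth import DPTDepthModel

# Minimal sketch: build the DPT depth head and load an omnidata checkpoint.
model = DPTDepthModel(backbone="vitb_rn50_384", num_channels=1)
ckpt = torch.load("./checkpoints/omnidata_dpt_depth_v2.ckpt", map_location="cpu")
state_dict = {k[6:]: v for k, v in ckpt["state_dict"].items()} if "state_dict" in ckpt else ckpt
model.load_state_dict(state_dict)
model.eval()
with torch.no_grad():
    depth = model(torch.rand(1, 3, 384, 384))   # (1, 384, 384): forward squeezes the channel dim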
-------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/modules/midas/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet(BaseModel): 13 | """Network for monocular depth estimation. 14 | """ 15 | 16 | def __init__(self, path=None, features=256, non_negative=True): 17 | """Init. 18 | 19 | Args: 20 | path (str, optional): Path to saved model. Defaults to None. 21 | features (int, optional): Number of features. Defaults to 256. 22 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 23 | """ 24 | print("Loading weights: ", path) 25 | 26 | super(MidasNet, self).__init__() 27 | 28 | use_pretrained = False if path is None else True 29 | 30 | self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained) 31 | 32 | self.scratch.refinenet4 = FeatureFusionBlock(features) 33 | self.scratch.refinenet3 = FeatureFusionBlock(features) 34 | self.scratch.refinenet2 = FeatureFusionBlock(features) 35 | self.scratch.refinenet1 = FeatureFusionBlock(features) 36 | 37 | self.scratch.output_conv = nn.Sequential( 38 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 39 | Interpolate(scale_factor=2, mode="bilinear"), 40 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 41 | nn.ReLU(True), 42 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 43 | nn.ReLU(True) if non_negative else nn.Identity(), 44 | ) 45 | 46 | if path: 47 | self.load(path) 48 | 49 | def forward(self, x): 50 | """Forward pass. 
51 | 52 | Args: 53 | x (tensor): input data (image) 54 | 55 | Returns: 56 | tensor: depth 57 | """ 58 | 59 | layer_1 = self.pretrained.layer1(x) 60 | layer_2 = self.pretrained.layer2(layer_1) 61 | layer_3 = self.pretrained.layer3(layer_2) 62 | layer_4 = self.pretrained.layer4(layer_3) 63 | 64 | layer_1_rn = self.scratch.layer1_rn(layer_1) 65 | layer_2_rn = self.scratch.layer2_rn(layer_2) 66 | layer_3_rn = self.scratch.layer3_rn(layer_3) 67 | layer_4_rn = self.scratch.layer4_rn(layer_4) 68 | 69 | path_4 = self.scratch.refinenet4(layer_4_rn) 70 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 71 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 72 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 73 | 74 | out = self.scratch.output_conv(path_1) 75 | 76 | return torch.squeeze(out, dim=1) 77 | -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/omnidata_normal_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | 5 | import PIL 6 | from PIL import Image 7 | 8 | from modules.geo_predictors.geo_predictor import GeoPredictor 9 | from .modules.unet import UNet 10 | from .modules.midas.dpt_depth import DPTDepthModel 11 | from .transforms import get_transform 12 | 13 | 14 | class OmnidataNormalPredictor(GeoPredictor): 15 | def __init__(self): 16 | super().__init__() 17 | self.img_size = 384 18 | ckpt_path = './checkpoints/omnidata_dpt_normal_v2.ckpt' 19 | self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=3) 20 | self.model.to(torch.device('cpu')) 21 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 22 | if 'state_dict' in checkpoint: 23 | state_dict = {} 24 | for k, v in checkpoint['state_dict'].items(): 25 | state_dict[k[6:]] = v 26 | else: 27 | state_dict = checkpoint 28 | 29 | self.model.load_state_dict(state_dict) 30 | self.trans_totensor = transforms.Compose([transforms.Resize(self.img_size, interpolation=Image.BILINEAR), 31 | transforms.CenterCrop(self.img_size), 32 | transforms.Normalize(mean=0.5, std=0.5)]) 33 | 34 | def predict_normal(self, img): 35 | self.model.to(torch.device('cuda')) 36 | img_tensor = self.trans_totensor(img) 37 | output = self.model(img_tensor) 38 | self.model.to(torch.device('cpu')) 39 | output = F.interpolate(output, size=(512, 512), mode='bicubic') 40 | return output 41 | -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/omnidata_predictor.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torchvision import transforms 4 | 5 | import PIL 6 | from PIL import Image 7 | 8 | from modules.geo_predictors.geo_predictor import GeoPredictor 9 | from .modules.unet import UNet 10 | from .modules.midas.dpt_depth import DPTDepthModel 11 | from .transforms import get_transform 12 | 13 | 14 | def standardize_depth_map(img, mask_valid=None, trunc_value=0.1): 15 | if mask_valid is not None: 16 | img[~mask_valid] = torch.nan 17 | sorted_img = torch.sort(torch.flatten(img))[0] 18 | # Remove nan, nan at the end of sort 19 | num_nan = sorted_img.isnan().sum() 20 | if num_nan > 0: 21 | sorted_img = sorted_img[:-num_nan] 22 | # Remove outliers 23 | trunc_img = sorted_img[int(trunc_value * len(sorted_img)): int((1 - trunc_value) * len(sorted_img))] 24 | trunc_mean = trunc_img.mean() 25 
| trunc_var = trunc_img.var() 26 | eps = 1e-6 27 | # Replace nan by mean 28 | img = torch.nan_to_num(img, nan=trunc_mean.item()) 29 | 30 | # Standardize 31 | img = (img - trunc_mean) / torch.sqrt(trunc_var + eps) 32 | return img 33 | 34 | class OmnidataPredictor(GeoPredictor): 35 | def __init__(self): 36 | super().__init__() 37 | self.img_size = 384 38 | ckpt_path = './checkpoints/omnidata_dpt_depth_v2.ckpt' 39 | self.model = DPTDepthModel(backbone='vitb_rn50_384', num_channels=1) 40 | self.model.to(torch.device('cpu')) 41 | checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu')) 42 | if 'state_dict' in checkpoint: 43 | state_dict = {} 44 | for k, v in checkpoint['state_dict'].items(): 45 | state_dict[k[6:]] = v 46 | else: 47 | state_dict = checkpoint 48 | 49 | self.model.load_state_dict(state_dict) 50 | self.trans_totensor = transforms.Compose([transforms.Resize(self.img_size, interpolation=Image.BILINEAR), 51 | transforms.CenterCrop(self.img_size), 52 | transforms.Normalize(mean=0.5, std=0.5)]) 53 | 54 | 55 | def predict_disparity(self, img, **kwargs): 56 | self.model.to(torch.device('cuda')) 57 | img_tensor = self.trans_totensor(img) 58 | output = self.model(img_tensor).clip(0., 1.) 59 | self.model.to(torch.device('cpu')) 60 | output = output.clip(0., 1.) 61 | output = 1. / (output + 1e-6) 62 | return output[:, None] 63 | 64 | def predict_depth(self, img, **kwargs): 65 | self.model.to(torch.device('cuda')) 66 | img_tensor = self.trans_totensor(img) 67 | output = self.model(img_tensor).clip(0., 1.) 68 | self.model.to(torch.device('cpu')) 69 | output = F.interpolate(output[:, None], size=(512, 512), mode='bicubic') 70 | output = output.clip(0., 1.) 71 | return output 72 | -------------------------------------------------------------------------------- /modules/geo_predictors/omnidata/task_configs.py: -------------------------------------------------------------------------------- 1 | #################### 2 | # Tasks 3 | #################### 4 | 5 | task_parameters = { 6 | 'autoencoding':{ 7 | 'out_channels': 3, 8 | 9 | }, 10 | 11 | 'denoising':{ 12 | 'out_channels': 3, 13 | 14 | }, 15 | 'colorization': { 16 | 'out_channels': 3, 17 | 18 | }, 19 | 'class_object':{ 20 | 'out_channels': 1000, 21 | 22 | }, 23 | 'class_scene':{ 24 | 'out_channels': 365, 25 | 26 | }, 27 | 'depth_zbuffer':{ 28 | 'out_channels': 1, 29 | 'mask_val': 1.0, 30 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)) # Same as consistency 31 | }, 32 | 'depth_euclidean':{ 33 | 'out_channels': 1, 34 | 'clamp_to': (0.0, 8000.0 / (2**16 - 1)) # Same as consistency 35 | # 'mask_val': 1.0, 36 | }, 37 | 'edge_texture': { 38 | 'out_channels': 1, 39 | 'clamp_to': (0.0, 0.25) 40 | }, 41 | 'edge_occlusion': { 42 | 'out_channels': 1, 43 | 44 | }, 45 | 'inpainting':{ 46 | 'out_channels': 3, 47 | 48 | }, 49 | 'keypoints3d': { 50 | 'out_channels': 1, 51 | 52 | }, 53 | 'keypoints2d':{ 54 | 'out_channels': 1, 55 | 56 | }, 57 | 'principal_curvature':{ 58 | 'out_channels': 2, 59 | 'mask_val': 0.0, 60 | }, 61 | 62 | 'reshading':{ 63 | 'out_channels': 1, 64 | 65 | }, 66 | 'normal':{ 67 | 'out_channels': 3, 68 | # 'mask_val': 0.004, 69 | 'mask_val': 0.502, 70 | }, 71 | 'mask_valid':{ 72 | 'out_channels': 1, 73 | 'mask_val': 0.0, 74 | }, 75 | 'rgb':{ 76 | 'out_channels': 3, 77 | }, 78 | 'segment_semantic': { 79 | 'out_channels': 17, 80 | }, 81 | 'segment_unsup2d': { 82 | 'out_channels': 64, 83 | }, 84 | 'segment_unsup25d': { 85 | 'out_channels': 64, 86 | }, 87 | 'segment_instance': { 88 | }, 89 | 'segment_panoptic': { 90 | 'out_channels': 2, 
91 | }, 92 | 'fragments': { 93 | 'out_channels': 1 94 | } 95 | } 96 | 97 | 98 | PIX_TO_PIX_TASKS = ['colorization', 'edge_texture', 'edge_occlusion', 'keypoints3d', 'keypoints2d', 'reshading', 'depth_zbuffer', 'depth_euclidean', 'curvature', 'autoencoding', 'denoising', 'normal', 'inpainting', 'segment_unsup2d', 'segment_unsup25d', 'segment_semantic', 'segment_instance'] 99 | FEED_FORWARD_TASKS = ['class_object', 'class_scene', 'room_layout', 'vanishing_point'] 100 | SINGLE_IMAGE_TASKS = PIX_TO_PIX_TASKS + FEED_FORWARD_TASKS 101 | SIAMESE_TASKS = ['fix_pose', 'jigsaw', 'ego_motion', 'point_match', 'non_fixated_pose'] 102 | 103 | -------------------------------------------------------------------------------- /modules/inpainters/SDFT_inpainter.py: -------------------------------------------------------------------------------- 1 | import utils.functions as functions 2 | import torch 3 | from .inpainter import Inpainter 4 | from PIL import Image 5 | from diffusers import StableDiffusionInpaintPipeline 6 | import numpy as np 7 | import os 8 | 9 | class SDFTInpainter(Inpainter): 10 | def __init__(self, subset_name=None): 11 | super().__init__() 12 | 13 | SD_path = "stabilityai/stable-diffusion-2-inpainting" 14 | pipe = StableDiffusionInpaintPipeline.from_pretrained(SD_path, torch_dtype=torch.float16, variant="fp16").to("cuda") 15 | 16 | SDFT_path = f"output/SDFT_weights" 17 | if os.path.exists(SDFT_path): 18 | pipe.load_lora_weights(SDFT_path) 19 | self.inpaint_pipe = pipe 20 | 21 | @torch.no_grad() 22 | def inpaint(self, img, mask): 23 | ''' 24 | :param img: B C H W? 25 | :param mask: 26 | :return: 27 | ''' 28 | inpaint_mask_pil = Image.fromarray(mask.detach().cpu().squeeze(0).squeeze(0).float().numpy() * 255).convert("RGB") 29 | 30 | rendered_image_pil = functions.tensor_to_pil(img) 31 | 32 | prompt = "" 33 | generator = torch.Generator(device="cuda").manual_seed(0) 34 | 35 | inpainted_image_pil = self.inpaint_pipe( 36 | prompt=prompt, 37 | image=rendered_image_pil, 38 | mask_image=inpaint_mask_pil, 39 | guidance_scale=7.5, 40 | num_inference_steps=30, 41 | generator=generator, 42 | ).images[0] 43 | result = functions.pil_to_tensor(inpainted_image_pil) 44 | 45 | return result.to(torch.float32) 46 | -------------------------------------------------------------------------------- /modules/inpainters/__init__.py: -------------------------------------------------------------------------------- 1 | from .inpainter import Inpainter 2 | from .lama_inpainter import LamaInpainter 3 | from .pano_pers_fusion_inpainter import PanoPersFusionInpainter 4 | -------------------------------------------------------------------------------- /modules/inpainters/inpainter.py: -------------------------------------------------------------------------------- 1 | class Inpainter: 2 | def __init__(self): 3 | pass 4 | 5 | def inpaint(self, img, mask): 6 | ''' 7 | :param img: 8 | :param mask: 9 | :return: img 10 | ''' 11 | raise NotImplementedError 12 | 13 | def inpaint_rgbd(self, img, distance, mask): 14 | raise NotImplementedError 15 | 16 | def encode(self, img, mask): 17 | ''' 18 | :param img: [B, 3, H, W] 19 | :param mask: [B, 1, H, W] 20 | :return: z 21 | ''' 22 | raise NotImplementedError 23 | 24 | -------------------------------------------------------------------------------- /modules/inpainters/lama/__init__.py: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/models/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/__init__.py: -------------------------------------------------------------------------------- 1 | from .base import * -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/color150.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/models/ade20k/color150.mat -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/models/ade20k/segm_lib/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .modules import * 2 | from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 3 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : __init__.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d 12 | from .replicate import DataParallelWithCallback, patch_replication_callback 13 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/replicate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : replicate.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 
10 | 11 | import functools 12 | 13 | from torch.nn.parallel.data_parallel import DataParallel 14 | 15 | __all__ = [ 16 | 'CallbackContext', 17 | 'execute_replication_callbacks', 18 | 'DataParallelWithCallback', 19 | 'patch_replication_callback' 20 | ] 21 | 22 | 23 | class CallbackContext(object): 24 | pass 25 | 26 | 27 | def execute_replication_callbacks(modules): 28 | """ 29 | Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication. 30 | 31 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 32 | 33 | Note that, as all replicas are isomorphic, we assign each sub-module a context 34 | (shared among multiple copies of this module on different devices). 35 | Through this context, different copies can share some information. 36 | 37 | We guarantee that the callback on the master copy (the first copy) will be called before the callbacks 38 | of any slave copies. 39 | """ 40 | master_copy = modules[0] 41 | nr_modules = len(list(master_copy.modules())) 42 | ctxs = [CallbackContext() for _ in range(nr_modules)] 43 | 44 | for i, module in enumerate(modules): 45 | for j, m in enumerate(module.modules()): 46 | if hasattr(m, '__data_parallel_replicate__'): 47 | m.__data_parallel_replicate__(ctxs[j], i) 48 | 49 | 50 | class DataParallelWithCallback(DataParallel): 51 | """ 52 | Data Parallel with a replication callback. 53 | 54 | A replication callback `__data_parallel_replicate__` of each module will be invoked after the replicas are created by the 55 | original `replicate` function. 56 | The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)` 57 | 58 | Examples: 59 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 60 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 61 | # sync_bn.__data_parallel_replicate__ will be invoked. 62 | """ 63 | 64 | def replicate(self, module, device_ids): 65 | modules = super(DataParallelWithCallback, self).replicate(module, device_ids) 66 | execute_replication_callbacks(modules) 67 | return modules 68 | 69 | 70 | def patch_replication_callback(data_parallel): 71 | """ 72 | Monkey-patch an existing `DataParallel` object. Add the replication callback. 73 | Useful when you have a customized `DataParallel` implementation.
74 | 75 | Examples: 76 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 77 | > sync_bn = DataParallel(sync_bn, device_ids=[0, 1]) 78 | > patch_replication_callback(sync_bn) 79 | # this is equivalent to 80 | > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 81 | > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 82 | """ 83 | 84 | assert isinstance(data_parallel, DataParallel) 85 | 86 | old_replicate = data_parallel.replicate 87 | 88 | @functools.wraps(old_replicate) 89 | def new_replicate(module, device_ids): 90 | modules = old_replicate(module, device_ids) 91 | execute_replication_callbacks(modules) 92 | return modules 93 | 94 | data_parallel.replicate = new_replicate 95 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/tests/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/tests/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_numeric_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm.unittest import TorchTestCase 16 | 17 | 18 | def handy_var(a, unbias=True): 19 | n = a.size(0) 20 | asum = a.sum(dim=0) 21 | as_sum = (a ** 2).sum(dim=0) # a square sum 22 | sumvar = as_sum - asum * asum / n 23 | if unbias: 24 | return sumvar / (n - 1) 25 | else: 26 | return sumvar / n 27 | 28 | 29 | class NumericTestCase(TorchTestCase): 30 | def testNumericBatchNorm(self): 31 | a = torch.rand(16, 10) 32 | bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False) 33 | bn.train() 34 | 35 | a_var1 = Variable(a, requires_grad=True) 36 | b_var1 = bn(a_var1) 37 | loss1 = b_var1.sum() 38 | loss1.backward() 39 | 40 | a_var2 = Variable(a, requires_grad=True) 41 | a_mean2 = a_var2.mean(dim=0, keepdim=True) 42 | a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5)) 43 | # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5) 44 | b_var2 = (a_var2 - a_mean2) / a_std2 45 | loss2 = b_var2.sum() 46 | loss2.backward() 47 | 48 | self.assertTensorClose(bn.running_mean, a.mean(dim=0)) 49 | self.assertTensorClose(bn.running_var, handy_var(a)) 50 | self.assertTensorClose(a_var1.data, a_var2.data) 51 | self.assertTensorClose(b_var1.data, b_var2.data) 52 | self.assertTensorClose(a_var1.grad, a_var2.grad) 53 | 54 | 55 | if __name__ == '__main__': 56 | unittest.main() 57 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : test_sync_batchnorm.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 
8 | 9 | import unittest 10 | 11 | import torch 12 | import torch.nn as nn 13 | from torch.autograd import Variable 14 | 15 | from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback 16 | from sync_batchnorm.unittest import TorchTestCase 17 | 18 | 19 | def handy_var(a, unbias=True): 20 | n = a.size(0) 21 | asum = a.sum(dim=0) 22 | as_sum = (a ** 2).sum(dim=0) # a square sum 23 | sumvar = as_sum - asum * asum / n 24 | if unbias: 25 | return sumvar / (n - 1) 26 | else: 27 | return sumvar / n 28 | 29 | 30 | def _find_bn(module): 31 | for m in module.modules(): 32 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)): 33 | return m 34 | 35 | 36 | class SyncTestCase(TorchTestCase): 37 | def _syncParameters(self, bn1, bn2): 38 | bn1.reset_parameters() 39 | bn2.reset_parameters() 40 | if bn1.affine and bn2.affine: 41 | bn2.weight.data.copy_(bn1.weight.data) 42 | bn2.bias.data.copy_(bn1.bias.data) 43 | 44 | def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False): 45 | """Check the forward and backward for the customized batch normalization.""" 46 | bn1.train(mode=is_train) 47 | bn2.train(mode=is_train) 48 | 49 | if cuda: 50 | input = input.cuda() 51 | 52 | self._syncParameters(_find_bn(bn1), _find_bn(bn2)) 53 | 54 | input1 = Variable(input, requires_grad=True) 55 | output1 = bn1(input1) 56 | output1.sum().backward() 57 | input2 = Variable(input, requires_grad=True) 58 | output2 = bn2(input2) 59 | output2.sum().backward() 60 | 61 | self.assertTensorClose(input1.data, input2.data) 62 | self.assertTensorClose(output1.data, output2.data) 63 | self.assertTensorClose(input1.grad, input2.grad) 64 | self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean) 65 | self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var) 66 | 67 | def testSyncBatchNormNormalTrain(self): 68 | bn = nn.BatchNorm1d(10) 69 | sync_bn = SynchronizedBatchNorm1d(10) 70 | 71 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True) 72 | 73 | def testSyncBatchNormNormalEval(self): 74 | bn = nn.BatchNorm1d(10) 75 | sync_bn = SynchronizedBatchNorm1d(10) 76 | 77 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False) 78 | 79 | def testSyncBatchNormSyncTrain(self): 80 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 81 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 82 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 83 | 84 | bn.cuda() 85 | sync_bn.cuda() 86 | 87 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True) 88 | 89 | def testSyncBatchNormSyncEval(self): 90 | bn = nn.BatchNorm1d(10, eps=1e-5, affine=False) 91 | sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False) 92 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 93 | 94 | bn.cuda() 95 | sync_bn.cuda() 96 | 97 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True) 98 | 99 | def testSyncBatchNorm2DSyncTrain(self): 100 | bn = nn.BatchNorm2d(10) 101 | sync_bn = SynchronizedBatchNorm2d(10) 102 | sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1]) 103 | 104 | bn.cuda() 105 | sync_bn.cuda() 106 | 107 | self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True) 108 | 109 | 110 | if __name__ == '__main__': 111 | unittest.main() 112 | -------------------------------------------------------------------------------- 
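The synchronized-batchnorm tests above rely on the `__data_parallel_replicate__(ctx, copy_id)` hook documented in replicate.py: every replica of a sub-module receives the same shared context object plus its own copy index, with the master copy called first. Below is a minimal, CPU-only sketch of that protocol; the `ToyModule` class and the import path are assumptions based on the repository layout shown here, not code from the repository.

import copy
import torch.nn as nn
# Import path assumed from the repository layout above.
from modules.inpainters.lama.models.ade20k.segm_lib.nn.modules.replicate import execute_replication_callbacks

class ToyModule(nn.Module):  # hypothetical module, used only for illustration
    def __data_parallel_replicate__(self, ctx, copy_id):
        # ctx is shared by every replica of this sub-module; copy_id is the replica index
        ctx.seen = getattr(ctx, 'seen', 0) + 1
        print(f'replica {copy_id} registered, {ctx.seen} replica(s) seen so far')

master = ToyModule()
replicas = [master, copy.deepcopy(master)]  # stand-ins for per-GPU copies
execute_replication_callbacks(replicas)     # the callback on the master copy (index 0) runs first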
/modules/inpainters/lama/models/ade20k/segm_lib/nn/modules/unittest.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # File : unittest.py 3 | # Author : Jiayuan Mao 4 | # Email : maojiayuan@gmail.com 5 | # Date : 27/01/2018 6 | # 7 | # This file is part of Synchronized-BatchNorm-PyTorch. 8 | # https://github.com/vacancy/Synchronized-BatchNorm-PyTorch 9 | # Distributed under MIT License. 10 | 11 | import unittest 12 | 13 | import numpy as np 14 | from torch.autograd import Variable 15 | 16 | 17 | def as_numpy(v): 18 | if isinstance(v, Variable): 19 | v = v.data 20 | return v.cpu().numpy() 21 | 22 | 23 | class TorchTestCase(unittest.TestCase): 24 | def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3): 25 | npa, npb = as_numpy(a), as_numpy(b) 26 | self.assertTrue( 27 | np.allclose(npa, npb, atol=atol), 28 | 'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max()) 29 | ) 30 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/parallel/__init__.py: -------------------------------------------------------------------------------- 1 | from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to 2 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/nn/parallel/data_parallel.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf8 -*- 2 | 3 | import torch.cuda as cuda 4 | import torch.nn as nn 5 | import torch 6 | import collections 7 | from torch.nn.parallel._functions import Gather 8 | 9 | 10 | __all__ = ['UserScatteredDataParallel', 'user_scattered_collate', 'async_copy_to'] 11 | 12 | 13 | def async_copy_to(obj, dev, main_stream=None): 14 | if torch.is_tensor(obj): 15 | v = obj.cuda(dev, non_blocking=True) 16 | if main_stream is not None: 17 | v.data.record_stream(main_stream) 18 | return v 19 | elif isinstance(obj, collections.Mapping): 20 | return {k: async_copy_to(o, dev, main_stream) for k, o in obj.items()} 21 | elif isinstance(obj, collections.Sequence): 22 | return [async_copy_to(o, dev, main_stream) for o in obj] 23 | else: 24 | return obj 25 | 26 | 27 | def dict_gather(outputs, target_device, dim=0): 28 | """ 29 | Gathers variables from different GPUs on a specified device 30 | (-1 means the CPU), with dictionary support. 
31 | """ 32 | def gather_map(outputs): 33 | out = outputs[0] 34 | if torch.is_tensor(out): 35 | # MJY(20180330) HACK:: force nr_dims > 0 36 | if out.dim() == 0: 37 | outputs = [o.unsqueeze(0) for o in outputs] 38 | return Gather.apply(target_device, dim, *outputs) 39 | elif out is None: 40 | return None 41 | elif isinstance(out, collections.Mapping): 42 | return {k: gather_map([o[k] for o in outputs]) for k in out} 43 | elif isinstance(out, collections.Sequence): 44 | return type(out)(map(gather_map, zip(*outputs))) 45 | return gather_map(outputs) 46 | 47 | 48 | class DictGatherDataParallel(nn.DataParallel): 49 | def gather(self, outputs, output_device): 50 | return dict_gather(outputs, output_device, dim=self.dim) 51 | 52 | 53 | class UserScatteredDataParallel(DictGatherDataParallel): 54 | def scatter(self, inputs, kwargs, device_ids): 55 | assert len(inputs) == 1 56 | inputs = inputs[0] 57 | inputs = _async_copy_stream(inputs, device_ids) 58 | inputs = [[i] for i in inputs] 59 | assert len(kwargs) == 0 60 | kwargs = [{} for _ in range(len(inputs))] 61 | 62 | return inputs, kwargs 63 | 64 | 65 | def user_scattered_collate(batch): 66 | return batch 67 | 68 | 69 | def _async_copy(inputs, device_ids): 70 | nr_devs = len(device_ids) 71 | assert type(inputs) in (tuple, list) 72 | assert len(inputs) == nr_devs 73 | 74 | outputs = [] 75 | for i, dev in zip(inputs, device_ids): 76 | with cuda.device(dev): 77 | outputs.append(async_copy_to(i, dev)) 78 | 79 | return tuple(outputs) 80 | 81 | 82 | def _async_copy_stream(inputs, device_ids): 83 | nr_devs = len(device_ids) 84 | assert type(inputs) in (tuple, list) 85 | assert len(inputs) == nr_devs 86 | 87 | outputs = [] 88 | streams = [_get_stream(d) for d in device_ids] 89 | for i, dev, stream in zip(inputs, device_ids, streams): 90 | with cuda.device(dev): 91 | main_stream = cuda.current_stream() 92 | with cuda.stream(stream): 93 | outputs.append(async_copy_to(i, dev, main_stream=main_stream)) 94 | main_stream.wait_stream(stream) 95 | 96 | return outputs 97 | 98 | 99 | """Adapted from: torch/nn/parallel/_functions.py""" 100 | # background streams used for copying 101 | _streams = None 102 | 103 | 104 | def _get_stream(device): 105 | """Gets a background stream for copying between CPU and GPU""" 106 | global _streams 107 | if device == -1: 108 | return None 109 | if _streams is None: 110 | _streams = [None] * cuda.device_count() 111 | if _streams[device] is None: _streams[device] = cuda.Stream(device) 112 | return _streams[device] 113 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .th import * 2 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/segm_lib/utils/th.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import collections 5 | 6 | __all__ = ['as_variable', 'as_numpy', 'mark_volatile'] 7 | 8 | def as_variable(obj): 9 | if isinstance(obj, Variable): 10 | return obj 11 | if isinstance(obj, collections.Sequence): 12 | return [as_variable(v) for v in obj] 13 | elif isinstance(obj, collections.Mapping): 14 | return {k: as_variable(v) for k, v in obj.items()} 15 | else: 16 | return Variable(obj) 17 | 18 | def as_numpy(obj): 19 | if isinstance(obj, collections.Sequence): 20 
| return [as_numpy(v) for v in obj] 21 | elif isinstance(obj, collections.Mapping): 22 | return {k: as_numpy(v) for k, v in obj.items()} 23 | elif isinstance(obj, Variable): 24 | return obj.data.cpu().numpy() 25 | elif torch.is_tensor(obj): 26 | return obj.cpu().numpy() 27 | else: 28 | return np.array(obj) 29 | 30 | def mark_volatile(obj): 31 | if torch.is_tensor(obj): 32 | obj = Variable(obj) 33 | if isinstance(obj, Variable): 34 | obj.no_grad = True 35 | return obj 36 | elif isinstance(obj, collections.Mapping): 37 | return {k: mark_volatile(o) for k, o in obj.items()} 38 | elif isinstance(obj, collections.Sequence): 39 | return [mark_volatile(o) for o in obj] 40 | else: 41 | return obj 42 | -------------------------------------------------------------------------------- /modules/inpainters/lama/models/ade20k/utils.py: -------------------------------------------------------------------------------- 1 | """Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch""" 2 | 3 | import os 4 | import sys 5 | 6 | import numpy as np 7 | import torch 8 | 9 | try: 10 | from urllib import urlretrieve 11 | except ImportError: 12 | from urllib.request import urlretrieve 13 | 14 | 15 | def load_url(url, model_dir='./pretrained', map_location=None): 16 | if not os.path.exists(model_dir): 17 | os.makedirs(model_dir) 18 | filename = url.split('/')[-1] 19 | cached_file = os.path.join(model_dir, filename) 20 | if not os.path.exists(cached_file): 21 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 22 | urlretrieve(url, cached_file) 23 | return torch.load(cached_file, map_location=map_location) 24 | 25 | 26 | def color_encode(labelmap, colors, mode='RGB'): 27 | labelmap = labelmap.astype('int') 28 | labelmap_rgb = np.zeros((labelmap.shape[0], labelmap.shape[1], 3), 29 | dtype=np.uint8) 30 | for label in np.unique(labelmap): 31 | if label < 0: 32 | continue 33 | labelmap_rgb += (labelmap == label)[:, :, np.newaxis] * \ 34 | np.tile(colors[label], 35 | (labelmap.shape[0], labelmap.shape[1], 1)) 36 | 37 | if mode == 'BGR': 38 | return labelmap_rgb[:, :, ::-1] 39 | else: 40 | return labelmap_rgb 41 | -------------------------------------------------------------------------------- /modules/inpainters/lama/predict_config.yaml: -------------------------------------------------------------------------------- 1 | indir: no # to be overriden in CLI 2 | outdir: no # to be overriden in CLI 3 | 4 | model: 5 | path: no # to be overriden in CLI 6 | checkpoint: best.ckpt 7 | 8 | dataset: 9 | kind: default 10 | img_suffix: .png 11 | pad_out_to_modulo: 8 12 | 13 | device: cuda 14 | # out_key: inpainted 15 | out_key: predicted_image 16 | 17 | refine: False # refiner will only run if this is True 18 | refiner: 19 | gpu_ids: "0," # the GPU ids of the machine to use. If only single GPU, use: "0," 20 | modulo: ${dataset.pad_out_to_modulo} 21 | n_iters: 15 # number of iterations of refinement for each scale 22 | lr: 0.002 # learning rate 23 | min_side: 512 # all sides of image on all scales should be >= min_side / sqrt(2) 24 | max_scales: 3 # max number of downscaling scales for the image-mask pyramid 25 | px_budget: 1800000 # pixels budget. 
Any image will be resized to satisfy height*width <= px_budget 26 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import torch 4 | 5 | from modules.inpainters.lama.saicinpainting.evaluation.evaluator import InpaintingEvaluatorOnline, ssim_fid100_f1, lpips_fid100_f1 6 | from modules.inpainters.lama.saicinpainting.evaluation.losses.base_loss import SSIMScore, LPIPSScore, FIDScore 7 | 8 | 9 | def make_evaluator(kind='default', ssim=True, lpips=True, fid=True, integral_kind=None, **kwargs): 10 | logging.info(f'Make evaluator {kind}') 11 | device = "cuda" if torch.cuda.is_available() else "cpu" 12 | metrics = {} 13 | if ssim: 14 | metrics['ssim'] = SSIMScore() 15 | if lpips: 16 | metrics['lpips'] = LPIPSScore() 17 | if fid: 18 | metrics['fid'] = FIDScore().to(device) 19 | 20 | if integral_kind is None: 21 | integral_func = None 22 | elif integral_kind == 'ssim_fid100_f1': 23 | integral_func = ssim_fid100_f1 24 | elif integral_kind == 'lpips_fid100_f1': 25 | integral_func = lpips_fid100_f1 26 | else: 27 | raise ValueError(f'Unexpected integral_kind={integral_kind}') 28 | 29 | if kind == 'default': 30 | return InpaintingEvaluatorOnline(scores=metrics, 31 | integral_func=integral_func, 32 | integral_title=integral_kind, 33 | **kwargs) 34 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/losses/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/losses/fid/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/losses/fid/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/losses/ssim.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn.functional as F 4 | 5 | 6 | class SSIM(torch.nn.Module): 7 | """SSIM. 
Modified from: 8 | https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 9 | """ 10 | 11 | def __init__(self, window_size=11, size_average=True): 12 | super().__init__() 13 | self.window_size = window_size 14 | self.size_average = size_average 15 | self.channel = 1 16 | self.register_buffer('window', self._create_window(window_size, self.channel)) 17 | 18 | def forward(self, img1, img2): 19 | assert len(img1.shape) == 4 20 | 21 | channel = img1.size()[1] 22 | 23 | if channel == self.channel and self.window.data.type() == img1.data.type(): 24 | window = self.window 25 | else: 26 | window = self._create_window(self.window_size, channel) 27 | 28 | # window = window.to(img1.get_device()) 29 | window = window.type_as(img1) 30 | 31 | self.window = window 32 | self.channel = channel 33 | 34 | return self._ssim(img1, img2, window, self.window_size, channel, self.size_average) 35 | 36 | def _gaussian(self, window_size, sigma): 37 | gauss = torch.Tensor([ 38 | np.exp(-(x - (window_size // 2)) ** 2 / float(2 * sigma ** 2)) for x in range(window_size) 39 | ]) 40 | return gauss / gauss.sum() 41 | 42 | def _create_window(self, window_size, channel): 43 | _1D_window = self._gaussian(window_size, 1.5).unsqueeze(1) 44 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 45 | return _2D_window.expand(channel, 1, window_size, window_size).contiguous() 46 | 47 | def _ssim(self, img1, img2, window, window_size, channel, size_average=True): 48 | mu1 = F.conv2d(img1, window, padding=(window_size // 2), groups=channel) 49 | mu2 = F.conv2d(img2, window, padding=(window_size // 2), groups=channel) 50 | 51 | mu1_sq = mu1.pow(2) 52 | mu2_sq = mu2.pow(2) 53 | mu1_mu2 = mu1 * mu2 54 | 55 | sigma1_sq = F.conv2d( 56 | img1 * img1, window, padding=(window_size // 2), groups=channel) - mu1_sq 57 | sigma2_sq = F.conv2d( 58 | img2 * img2, window, padding=(window_size // 2), groups=channel) - mu2_sq 59 | sigma12 = F.conv2d( 60 | img1 * img2, window, padding=(window_size // 2), groups=channel) - mu1_mu2 61 | 62 | C1 = 0.01 ** 2 63 | C2 = 0.03 ** 2 64 | 65 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / \ 66 | ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 67 | 68 | if size_average: 69 | return ssim_map.mean() 70 | 71 | return ssim_map.mean(1).mean(1).mean(1) 72 | 73 | def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): 74 | return 75 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/README.md: -------------------------------------------------------------------------------- 1 | # Current algorithm 2 | 3 | ## Choice of mask objects 4 | 5 | To identify objects that are suitable for mask generation, we use a panoptic segmentation model 6 | from [detectron2](https://github.com/facebookresearch/detectron2) trained on COCO. Categories of the detected instances 7 | belong to either the "stuff" or the "things" type. We only keep instances whose category belongs 8 | to "things". Besides, we set an upper bound on the area taken up by the object: a very large 9 | area indicates that the instance is either the background or a main object, which should not be removed. 10 | 11 | ## Choice of position for mask 12 | 13 | We assume that the input image has size 2^n x 2^m.
We downsample it using 14 | [COUNTLESS](https://github.com/william-silversmith/countless) algorithm so the width is equal to 15 | 64 = 2^8 = 2^{downsample_levels}. 16 | 17 | ### Augmentation 18 | 19 | There are several parameters for augmentation: 20 | - Scaling factor. We limit scaling to the case when a mask after scaling with pivot point in its center fits inside the 21 | image completely. 22 | - 23 | 24 | ### Shift 25 | 26 | 27 | ## Select 28 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/README.md: -------------------------------------------------------------------------------- 1 | [![Build Status](https://travis-ci.org/william-silversmith/countless.svg?branch=master)](https://travis-ci.org/william-silversmith/countless) 2 | 3 | Python COUNTLESS Downsampling 4 | ============================= 5 | 6 | To install: 7 | 8 | `pip install -r requirements.txt` 9 | 10 | To test: 11 | 12 | `python test.py` 13 | 14 | To benchmark countless2d: 15 | 16 | `python python/countless2d.py python/images/gray_segmentation.png` 17 | 18 | To benchmark countless3d: 19 | 20 | `python python/countless3d.py` 21 | 22 | Adjust N and the list of algorithms inside each script to modify the run parameters. 23 | 24 | 25 | Python3 is slightly faster than Python2. -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/gcim.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/gcim.jpg -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/gray_segmentation.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/segmentation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/segmentation.png 
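The masks README above selects inpainting-mask candidates by keeping panoptic "things" instances whose area is not too large. A schematic NumPy sketch of that selection rule follows; the function name, the segment-record format, and the 0.3 area threshold are illustrative assumptions, not the repository's actual code.

import numpy as np

def pick_mask_candidates(panoptic_map, segments, max_area_frac=0.3):
    """Keep binary masks of 'thing' instances that are not too large."""
    h, w = panoptic_map.shape
    masks = []
    for seg in segments:                      # assumed record: {'id': int, 'isthing': bool}
        if not seg['isthing']:
            continue                          # 'stuff' categories are skipped
        mask = panoptic_map == seg['id']
        if mask.sum() / (h * w) > max_area_frac:
            continue                          # too big: likely background or the main object
        masks.append(mask)
    return masks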
-------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/sparse.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/images/sparse.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless2d_gcim_N_1000.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless2d_quick_gcim_N_1000.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_dynamic_generalized.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/evaluation/masks/countless/memprof/countless3d_generalized.png -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/masks/countless/requirements.txt: -------------------------------------------------------------------------------- 1 | Pillow>=6.2.0 2 | numpy>=1.16 3 | scipy 4 | tqdm 5 | memory_profiler 6 | six 7 | pytest 
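The requirements above belong to the COUNTLESS downsampling code that the masks README uses to shrink label images before placing masks. Below is a deliberately naive NumPy sketch of the 2x2 majority-vote idea only; the real countless2d implementation is fully vectorised and treats zero labels specially, so treat this as an approximation for intuition rather than the repository's algorithm.

import numpy as np

def countless2d_naive(labels):
    """Downsample a 2D label image by 2x: per 2x2 block, keep a value that occurs
    at least twice, otherwise fall back to the bottom-right pixel (simplified sketch)."""
    h, w = labels.shape
    assert h % 2 == 0 and w % 2 == 0
    out = np.empty((h // 2, w // 2), dtype=labels.dtype)
    for i in range(0, h, 2):
        for j in range(0, w, 2):
            a, b, c, d = labels[i, j], labels[i, j + 1], labels[i + 1, j], labels[i + 1, j + 1]
            if a == b or a == c:
                out[i // 2, j // 2] = a
            elif b == c:
                out[i // 2, j // 2] = b
            else:
                out[i // 2, j // 2] = d
    return out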
-------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/utils.py: -------------------------------------------------------------------------------- 1 | from enum import Enum 2 | 3 | import yaml 4 | from easydict import EasyDict as edict 5 | import torch.nn as nn 6 | import torch 7 | 8 | 9 | def load_yaml(path): 10 | with open(path, 'r') as f: 11 | return edict(yaml.safe_load(f)) 12 | 13 | 14 | def move_to_device(obj, device): 15 | if isinstance(obj, nn.Module): 16 | return obj.to(device) 17 | if torch.is_tensor(obj): 18 | return obj.to(device) 19 | if isinstance(obj, (tuple, list)): 20 | return [move_to_device(el, device) for el in obj] 21 | if isinstance(obj, dict): 22 | return {name: move_to_device(val, device) for name, val in obj.items()} 23 | raise ValueError(f'Unexpected type {type(obj)}') 24 | 25 | 26 | class SmallMode(Enum): 27 | DROP = "drop" 28 | UPSCALE = "upscale" 29 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/evaluation/vis.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from skimage import io 3 | from skimage.segmentation import mark_boundaries 4 | 5 | 6 | def save_item_for_vis(item, out_file): 7 | mask = item['mask'] > 0.5 8 | if mask.ndim == 3: 9 | mask = mask[0] 10 | img = mark_boundaries(np.transpose(item['image'], (1, 2, 0)), 11 | mask, 12 | color=(1., 0., 0.), 13 | outline_color=(1., 1., 1.), 14 | mode='thick') 15 | 16 | if 'inpainted' in item: 17 | inp_img = mark_boundaries(np.transpose(item['inpainted'], (1, 2, 0)), 18 | mask, 19 | color=(1., 0., 0.), 20 | mode='outer') 21 | img = np.concatenate((img, inp_img), axis=1) 22 | 23 | img = np.clip(img * 255, 0, 255).astype('uint8') 24 | io.imsave(out_file, img) 25 | 26 | 27 | def save_mask_for_sidebyside(item, out_file): 28 | mask = item['mask']# > 0.5 29 | if mask.ndim == 3: 30 | mask = mask[0] 31 | mask = np.clip(mask * 255, 0, 255).astype('uint8') 32 | io.imsave(out_file, mask) 33 | 34 | def save_img_for_sidebyside(item, out_file): 35 | img = np.transpose(item['image'], (1, 2, 0)) 36 | img = np.clip(img * 255, 0, 255).astype('uint8') 37 | io.imsave(out_file, img) -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/training/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/training/data/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/data/aug.py: -------------------------------------------------------------------------------- 1 | from albumentations import DualIAATransform, to_tuple 2 | import imgaug.augmenters as iaa 3 | 4 | class IAAAffine2(DualIAATransform): 5 | """Place a regular grid of points on the input and randomly move the neighbourhood of these point around 6 | via affine 
transformations. 7 | 8 | Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} 9 | 10 | Args: 11 | p (float): probability of applying the transform. Default: 0.5. 12 | 13 | Targets: 14 | image, mask 15 | """ 16 | 17 | def __init__( 18 | self, 19 | scale=(0.7, 1.3), 20 | translate_percent=None, 21 | translate_px=None, 22 | rotate=0.0, 23 | shear=(-0.1, 0.1), 24 | order=1, 25 | cval=0, 26 | mode="reflect", 27 | always_apply=False, 28 | p=0.5, 29 | ): 30 | super(IAAAffine2, self).__init__(always_apply, p) 31 | self.scale = dict(x=scale, y=scale) 32 | self.translate_percent = to_tuple(translate_percent, 0) 33 | self.translate_px = to_tuple(translate_px, 0) 34 | self.rotate = to_tuple(rotate) 35 | self.shear = dict(x=shear, y=shear) 36 | self.order = order 37 | self.cval = cval 38 | self.mode = mode 39 | 40 | @property 41 | def processor(self): 42 | return iaa.Affine( 43 | self.scale, 44 | self.translate_percent, 45 | self.translate_px, 46 | self.rotate, 47 | self.shear, 48 | self.order, 49 | self.cval, 50 | self.mode, 51 | ) 52 | 53 | def get_transform_init_args_names(self): 54 | return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode") 55 | 56 | 57 | class IAAPerspective2(DualIAATransform): 58 | """Perform a random four point perspective transform of the input. 59 | 60 | Note: This class introduce interpolation artifacts to mask if it has values other than {0;1} 61 | 62 | Args: 63 | scale ((float, float): standard deviation of the normal distributions. These are used to sample 64 | the random distances of the subimage's corners from the full image's corners. Default: (0.05, 0.1). 65 | p (float): probability of applying the transform. Default: 0.5. 66 | 67 | Targets: 68 | image, mask 69 | """ 70 | 71 | def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5, 72 | order=1, cval=0, mode="replicate"): 73 | super(IAAPerspective2, self).__init__(always_apply, p) 74 | self.scale = to_tuple(scale, 1.0) 75 | self.keep_size = keep_size 76 | self.cval = cval 77 | self.mode = mode 78 | 79 | @property 80 | def processor(self): 81 | return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval) 82 | 83 | def get_transform_init_args_names(self): 84 | return ("scale", "keep_size") 85 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/losses/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/inpainters/lama/saicinpainting/training/losses/__init__.py -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/losses/constants.py: -------------------------------------------------------------------------------- 1 | weights = {"ade20k": 2 | [6.34517766497462, 3 | 9.328358208955224, 4 | 11.389521640091116, 5 | 16.10305958132045, 6 | 20.833333333333332, 7 | 22.22222222222222, 8 | 25.125628140703515, 9 | 43.29004329004329, 10 | 50.5050505050505, 11 | 54.6448087431694, 12 | 55.24861878453038, 13 | 60.24096385542168, 14 | 62.5, 15 | 66.2251655629139, 16 | 84.74576271186442, 17 | 90.90909090909092, 18 | 91.74311926605505, 19 | 96.15384615384616, 20 | 96.15384615384616, 21 | 97.08737864077669, 22 | 102.04081632653062, 23 | 135.13513513513513, 24 | 149.2537313432836, 25 | 153.84615384615384, 
26 | 163.93442622950818, 27 | 166.66666666666666, 28 | 188.67924528301887, 29 | 192.30769230769232, 30 | 217.3913043478261, 31 | 227.27272727272725, 32 | 227.27272727272725, 33 | 227.27272727272725, 34 | 303.03030303030306, 35 | 322.5806451612903, 36 | 333.3333333333333, 37 | 370.3703703703703, 38 | 384.61538461538464, 39 | 416.6666666666667, 40 | 416.6666666666667, 41 | 434.7826086956522, 42 | 434.7826086956522, 43 | 454.5454545454545, 44 | 454.5454545454545, 45 | 500.0, 46 | 526.3157894736842, 47 | 526.3157894736842, 48 | 555.5555555555555, 49 | 555.5555555555555, 50 | 555.5555555555555, 51 | 555.5555555555555, 52 | 555.5555555555555, 53 | 555.5555555555555, 54 | 555.5555555555555, 55 | 588.2352941176471, 56 | 588.2352941176471, 57 | 588.2352941176471, 58 | 588.2352941176471, 59 | 588.2352941176471, 60 | 666.6666666666666, 61 | 666.6666666666666, 62 | 666.6666666666666, 63 | 666.6666666666666, 64 | 714.2857142857143, 65 | 714.2857142857143, 66 | 714.2857142857143, 67 | 714.2857142857143, 68 | 714.2857142857143, 69 | 769.2307692307693, 70 | 769.2307692307693, 71 | 769.2307692307693, 72 | 833.3333333333334, 73 | 833.3333333333334, 74 | 833.3333333333334, 75 | 833.3333333333334, 76 | 909.090909090909, 77 | 1000.0, 78 | 1111.111111111111, 79 | 1111.111111111111, 80 | 1111.111111111111, 81 | 1111.111111111111, 82 | 1111.111111111111, 83 | 1250.0, 84 | 1250.0, 85 | 1250.0, 86 | 1250.0, 87 | 1250.0, 88 | 1428.5714285714287, 89 | 1428.5714285714287, 90 | 1428.5714285714287, 91 | 1428.5714285714287, 92 | 1428.5714285714287, 93 | 1428.5714285714287, 94 | 1428.5714285714287, 95 | 1666.6666666666667, 96 | 1666.6666666666667, 97 | 1666.6666666666667, 98 | 1666.6666666666667, 99 | 1666.6666666666667, 100 | 1666.6666666666667, 101 | 1666.6666666666667, 102 | 1666.6666666666667, 103 | 1666.6666666666667, 104 | 1666.6666666666667, 105 | 1666.6666666666667, 106 | 2000.0, 107 | 2000.0, 108 | 2000.0, 109 | 2000.0, 110 | 2000.0, 111 | 2000.0, 112 | 2000.0, 113 | 2000.0, 114 | 2000.0, 115 | 2000.0, 116 | 2000.0, 117 | 2000.0, 118 | 2000.0, 119 | 2000.0, 120 | 2000.0, 121 | 2000.0, 122 | 2000.0, 123 | 2500.0, 124 | 2500.0, 125 | 2500.0, 126 | 2500.0, 127 | 2500.0, 128 | 2500.0, 129 | 2500.0, 130 | 2500.0, 131 | 2500.0, 132 | 2500.0, 133 | 2500.0, 134 | 2500.0, 135 | 2500.0, 136 | 3333.3333333333335, 137 | 3333.3333333333335, 138 | 3333.3333333333335, 139 | 3333.3333333333335, 140 | 3333.3333333333335, 141 | 3333.3333333333335, 142 | 3333.3333333333335, 143 | 3333.3333333333335, 144 | 3333.3333333333335, 145 | 3333.3333333333335, 146 | 3333.3333333333335, 147 | 3333.3333333333335, 148 | 3333.3333333333335, 149 | 5000.0, 150 | 5000.0, 151 | 5000.0] 152 | } -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/losses/feature_matching.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def masked_l2_loss(pred, target, mask, weight_known, weight_missing): 8 | per_pixel_l2 = F.mse_loss(pred, target, reduction='none') 9 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 10 | return (pixel_weights * per_pixel_l2).mean() 11 | 12 | 13 | def masked_l1_loss(pred, target, mask, weight_known, weight_missing): 14 | per_pixel_l1 = F.l1_loss(pred, target, reduction='none') 15 | pixel_weights = mask * weight_missing + (1 - mask) * weight_known 16 | return (pixel_weights * per_pixel_l1).mean() 17 | 18 | 19 | 
def feature_matching_loss(fake_features: List[torch.Tensor], target_features: List[torch.Tensor], mask=None): 20 | if mask is None: 21 | res = torch.stack([F.mse_loss(fake_feat, target_feat) 22 | for fake_feat, target_feat in zip(fake_features, target_features)]).mean() 23 | else: 24 | res = 0 25 | norm = 0 26 | for fake_feat, target_feat in zip(fake_features, target_features): 27 | cur_mask = F.interpolate(mask, size=fake_feat.shape[-2:], mode='bilinear', align_corners=False) 28 | error_weights = 1 - cur_mask 29 | cur_val = ((fake_feat - target_feat).pow(2) * error_weights).mean() 30 | res = res + cur_val 31 | norm += 1 32 | res = res / norm 33 | return res 34 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/losses/segmentation.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .constants import weights as constant_weights 6 | 7 | 8 | class CrossEntropy2d(nn.Module): 9 | def __init__(self, reduction="mean", ignore_label=255, weights=None, *args, **kwargs): 10 | """ 11 | weight (Tensor, optional): a manual rescaling weight given to each class. 12 | If given, has to be a Tensor of size "nclasses" 13 | """ 14 | super(CrossEntropy2d, self).__init__() 15 | self.reduction = reduction 16 | self.ignore_label = ignore_label 17 | self.weights = weights 18 | if self.weights is not None: 19 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 20 | self.weights = torch.FloatTensor(constant_weights[weights]).to(device) 21 | 22 | def forward(self, predict, target): 23 | """ 24 | Args: 25 | predict:(n, c, h, w) 26 | target:(n, 1, h, w) 27 | """ 28 | target = target.long() 29 | assert not target.requires_grad 30 | assert predict.dim() == 4, "{0}".format(predict.size()) 31 | assert target.dim() == 4, "{0}".format(target.size()) 32 | assert predict.size(0) == target.size(0), "{0} vs {1} ".format(predict.size(0), target.size(0)) 33 | assert target.size(1) == 1, "{0}".format(target.size(1)) 34 | assert predict.size(2) == target.size(2), "{0} vs {1} ".format(predict.size(2), target.size(2)) 35 | assert predict.size(3) == target.size(3), "{0} vs {1} ".format(predict.size(3), target.size(3)) 36 | target = target.squeeze(1) 37 | n, c, h, w = predict.size() 38 | target_mask = (target >= 0) * (target != self.ignore_label) 39 | target = target[target_mask] 40 | predict = predict.transpose(1, 2).transpose(2, 3).contiguous() 41 | predict = predict[target_mask.view(n, h, w, 1).repeat(1, 1, 1, c)].view(-1, c) 42 | loss = F.cross_entropy(predict, target, weight=self.weights, reduction=self.reduction) 43 | return loss 44 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from modules.inpainters.lama.saicinpainting.training.modules.ffc import FFCResNetGenerator 4 | from modules.inpainters.lama.saicinpainting.training.modules.pix2pixhd import GlobalGenerator, MultiDilatedGlobalGenerator, \ 5 | NLayerDiscriminator, MultidilatedNLayerDiscriminator 6 | 7 | def make_generator(config, kind, **kwargs): 8 | logging.info(f'Make generator {kind}') 9 | 10 | if kind == 'pix2pixhd_multidilated': 11 | return MultiDilatedGlobalGenerator(**kwargs) 12 | 13 | if kind == 'pix2pixhd_global': 14 | return 
GlobalGenerator(**kwargs) 15 | 16 | if kind == 'ffc_resnet': 17 | return FFCResNetGenerator(**kwargs) 18 | 19 | raise ValueError(f'Unknown generator kind {kind}') 20 | 21 | 22 | def make_discriminator(kind, **kwargs): 23 | logging.info(f'Make discriminator {kind}') 24 | 25 | if kind == 'pix2pixhd_nlayer_multidilated': 26 | return MultidilatedNLayerDiscriminator(**kwargs) 27 | 28 | if kind == 'pix2pixhd_nlayer': 29 | return NLayerDiscriminator(**kwargs) 30 | 31 | raise ValueError(f'Unknown discriminator kind {kind}') 32 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Tuple, List 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | from modules.inpainters.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 8 | from modules.inpainters.lama.saicinpainting.training.modules.multidilated_conv import MultidilatedConv 9 | 10 | 11 | class BaseDiscriminator(nn.Module): 12 | @abc.abstractmethod 13 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]: 14 | """ 15 | Predict scores and get intermediate activations. Useful for feature matching loss 16 | :return tuple (scores, list of intermediate activations) 17 | """ 18 | raise NotImplemented() 19 | 20 | 21 | def get_conv_block_ctor(kind='default'): 22 | if not isinstance(kind, str): 23 | return kind 24 | if kind == 'default': 25 | return nn.Conv2d 26 | if kind == 'depthwise': 27 | return DepthWiseSeperableConv 28 | if kind == 'multidilated': 29 | return MultidilatedConv 30 | raise ValueError(f'Unknown convolutional block kind {kind}') 31 | 32 | 33 | def get_norm_layer(kind='bn'): 34 | if not isinstance(kind, str): 35 | return kind 36 | if kind == 'bn': 37 | return nn.BatchNorm2d 38 | if kind == 'in': 39 | return nn.InstanceNorm2d 40 | raise ValueError(f'Unknown norm block kind {kind}') 41 | 42 | 43 | def get_activation(kind='tanh'): 44 | if kind == 'tanh': 45 | return nn.Tanh() 46 | if kind == 'sigmoid': 47 | return nn.Sigmoid() 48 | if kind is False: 49 | return nn.Identity() 50 | raise ValueError(f'Unknown activation kind {kind}') 51 | 52 | 53 | class SimpleMultiStepGenerator(nn.Module): 54 | def __init__(self, steps: List[nn.Module]): 55 | super().__init__() 56 | self.steps = nn.ModuleList(steps) 57 | 58 | def forward(self, x): 59 | cur_in = x 60 | outs = [] 61 | for step in self.steps: 62 | cur_out = step(cur_in) 63 | outs.append(cur_out) 64 | cur_in = torch.cat((cur_in, cur_out), dim=1) 65 | return torch.cat(outs[::-1], dim=1) 66 | 67 | def deconv_factory(kind, ngf, mult, norm_layer, activation, max_features): 68 | if kind == 'convtranspose': 69 | return [nn.ConvTranspose2d(min(max_features, ngf * mult), 70 | min(max_features, int(ngf * mult / 2)), 71 | kernel_size=3, stride=2, padding=1, output_padding=1), 72 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 73 | elif kind == 'bilinear': 74 | return [nn.Upsample(scale_factor=2, mode='bilinear'), 75 | DepthWiseSeperableConv(min(max_features, ngf * mult), 76 | min(max_features, int(ngf * mult / 2)), 77 | kernel_size=3, stride=1, padding=1), 78 | norm_layer(min(max_features, int(ngf * mult / 2))), activation] 79 | else: 80 | raise Exception(f"Invalid deconv kind: {kind}") -------------------------------------------------------------------------------- 
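The `deconv_factory` helper in base.py above builds one upsampling stage in two flavours: a strided `ConvTranspose2d`, or bilinear upsampling followed by a depthwise-separable 3x3 convolution. Either way the spatial size doubles and the channel count halves (capped by `max_features`). A small sketch exercising both kinds; the concrete numbers are arbitrary and the import path assumes the repository layout shown here.

import torch
import torch.nn as nn
# Import path assumed from the repository layout above.
from modules.inpainters.lama.saicinpainting.training.modules.base import deconv_factory

x = torch.randn(1, 128, 32, 32)  # ngf * mult = 64 * 2 = 128 input channels
for kind in ('convtranspose', 'bilinear'):
    up = nn.Sequential(*deconv_factory(kind, ngf=64, mult=2,
                                       norm_layer=nn.BatchNorm2d,
                                       activation=nn.ReLU(True),
                                       max_features=1024))
    y = up(x)
    print(kind, tuple(y.shape))  # both variants give (1, 64, 64, 64): channels halved, resolution doubled

Both variants produce the same output shape; the bilinear path is the usual way to avoid the checkerboard artifacts commonly associated with transposed convolutions, at the cost of an extra depthwise-separable convolution.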
/modules/inpainters/lama/saicinpainting/training/modules/depthwise_sep_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | class DepthWiseSeperableConv(nn.Module): 5 | def __init__(self, in_dim, out_dim, *args, **kwargs): 6 | super().__init__() 7 | if 'groups' in kwargs: 8 | # ignoring groups for Depthwise Sep Conv 9 | del kwargs['groups'] 10 | 11 | self.depthwise = nn.Conv2d(in_dim, in_dim, *args, groups=in_dim, **kwargs) 12 | self.pointwise = nn.Conv2d(in_dim, out_dim, kernel_size=1) 13 | 14 | def forward(self, x): 15 | out = self.depthwise(x) 16 | out = self.pointwise(out) 17 | return out -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/fake_fakes.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from kornia.constants import SamplePadding 3 | from kornia.augmentation import RandomAffine, CenterCrop 4 | 5 | 6 | class FakeFakesGenerator: 7 | def __init__(self, aug_proba=0.5, img_aug_degree=30, img_aug_translate=0.2): 8 | self.grad_aug = RandomAffine(degrees=360, 9 | translate=0.2, 10 | padding_mode=SamplePadding.REFLECTION, 11 | keepdim=False, 12 | p=1) 13 | self.img_aug = RandomAffine(degrees=img_aug_degree, 14 | translate=img_aug_translate, 15 | padding_mode=SamplePadding.REFLECTION, 16 | keepdim=True, 17 | p=1) 18 | self.aug_proba = aug_proba 19 | 20 | def __call__(self, input_images, masks): 21 | blend_masks = self._fill_masks_with_gradient(masks) 22 | blend_target = self._make_blend_target(input_images) 23 | result = input_images * (1 - blend_masks) + blend_target * blend_masks 24 | return result, blend_masks 25 | 26 | def _make_blend_target(self, input_images): 27 | batch_size = input_images.shape[0] 28 | permuted = input_images[torch.randperm(batch_size)] 29 | augmented = self.img_aug(input_images) 30 | is_aug = (torch.rand(batch_size, device=input_images.device)[:, None, None, None] < self.aug_proba).float() 31 | result = augmented * is_aug + permuted * (1 - is_aug) 32 | return result 33 | 34 | def _fill_masks_with_gradient(self, masks): 35 | batch_size, _, height, width = masks.shape 36 | grad = torch.linspace(0, 1, steps=width * 2, device=masks.device, dtype=masks.dtype) \ 37 | .view(1, 1, 1, -1).expand(batch_size, 1, height * 2, width * 2) 38 | grad = self.grad_aug(grad) 39 | grad = CenterCrop((height, width))(grad) 40 | grad *= masks 41 | 42 | grad_for_min = grad + (1 - masks) * 10 43 | grad -= grad_for_min.view(batch_size, -1).min(-1).values[:, None, None, None] 44 | grad /= grad.view(batch_size, -1).max(-1).values[:, None, None, None] + 1e-6 45 | grad.clamp_(min=0, max=1) 46 | 47 | return grad 48 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/multidilated_conv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import random 4 | from modules.inpainters.lama.saicinpainting.training.modules.depthwise_sep_conv import DepthWiseSeperableConv 5 | 6 | class MultidilatedConv(nn.Module): 7 | def __init__(self, in_dim, out_dim, kernel_size, dilation_num=3, comb_mode='sum', equal_dim=True, 8 | shared_weights=False, padding=1, min_dilation=1, shuffle_in_channels=False, use_depthwise=False, **kwargs): 9 | super().__init__() 10 | convs = [] 11 | self.equal_dim = equal_dim 12 | assert 
comb_mode in ('cat_out', 'sum', 'cat_in', 'cat_both'), comb_mode 13 | if comb_mode in ('cat_out', 'cat_both'): 14 | self.cat_out = True 15 | if equal_dim: 16 | assert out_dim % dilation_num == 0 17 | out_dims = [out_dim // dilation_num] * dilation_num 18 | self.index = sum([[i + j * (out_dims[0]) for j in range(dilation_num)] for i in range(out_dims[0])], []) 19 | else: 20 | out_dims = [out_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 21 | out_dims.append(out_dim - sum(out_dims)) 22 | index = [] 23 | starts = [0] + out_dims[:-1] 24 | lengths = [out_dims[i] // out_dims[-1] for i in range(dilation_num)] 25 | for i in range(out_dims[-1]): 26 | for j in range(dilation_num): 27 | index += list(range(starts[j], starts[j] + lengths[j])) 28 | starts[j] += lengths[j] 29 | self.index = index 30 | assert(len(index) == out_dim) 31 | self.out_dims = out_dims 32 | else: 33 | self.cat_out = False 34 | self.out_dims = [out_dim] * dilation_num 35 | 36 | if comb_mode in ('cat_in', 'cat_both'): 37 | if equal_dim: 38 | assert in_dim % dilation_num == 0 39 | in_dims = [in_dim // dilation_num] * dilation_num 40 | else: 41 | in_dims = [in_dim // 2 ** (i + 1) for i in range(dilation_num - 1)] 42 | in_dims.append(in_dim - sum(in_dims)) 43 | self.in_dims = in_dims 44 | self.cat_in = True 45 | else: 46 | self.cat_in = False 47 | self.in_dims = [in_dim] * dilation_num 48 | 49 | conv_type = DepthWiseSeperableConv if use_depthwise else nn.Conv2d 50 | dilation = min_dilation 51 | for i in range(dilation_num): 52 | if isinstance(padding, int): 53 | cur_padding = padding * dilation 54 | else: 55 | cur_padding = padding[i] 56 | convs.append(conv_type( 57 | self.in_dims[i], self.out_dims[i], kernel_size, padding=cur_padding, dilation=dilation, **kwargs 58 | )) 59 | if i > 0 and shared_weights: 60 | convs[-1].weight = convs[0].weight 61 | convs[-1].bias = convs[0].bias 62 | dilation *= 2 63 | self.convs = nn.ModuleList(convs) 64 | 65 | self.shuffle_in_channels = shuffle_in_channels 66 | if self.shuffle_in_channels: 67 | # shuffle list as shuffling of tensors is nondeterministic 68 | in_channels_permute = list(range(in_dim)) 69 | random.shuffle(in_channels_permute) 70 | # save as buffer so it is saved and loaded with checkpoint 71 | self.register_buffer('in_channels_permute', torch.tensor(in_channels_permute)) 72 | 73 | def forward(self, x): 74 | if self.shuffle_in_channels: 75 | x = x[:, self.in_channels_permute] 76 | 77 | outs = [] 78 | if self.cat_in: 79 | if self.equal_dim: 80 | x = x.chunk(len(self.convs), dim=1) 81 | else: 82 | new_x = [] 83 | start = 0 84 | for dim in self.in_dims: 85 | new_x.append(x[:, start:start+dim]) 86 | start += dim 87 | x = new_x 88 | for i, conv in enumerate(self.convs): 89 | if self.cat_in: 90 | input = x[i] 91 | else: 92 | input = x 93 | outs.append(conv(input)) 94 | if self.cat_out: 95 | out = torch.cat(outs, dim=1)[:, self.index] 96 | else: 97 | out = sum(outs) 98 | return out 99 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/spatial_transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from kornia.geometry.transform import rotate 5 | 6 | 7 | class LearnableSpatialTransformWrapper(nn.Module): 8 | def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True): 9 | super().__init__() 10 | self.impl = impl 11 | self.angle = torch.rand(1) * angle_init_range 
12 | if train_angle: 13 | self.angle = nn.Parameter(self.angle, requires_grad=True) 14 | self.pad_coef = pad_coef 15 | 16 | def forward(self, x): 17 | if torch.is_tensor(x): 18 | return self.inverse_transform(self.impl(self.transform(x)), x) 19 | elif isinstance(x, tuple): 20 | x_trans = tuple(self.transform(elem) for elem in x) 21 | y_trans = self.impl(x_trans) 22 | return tuple(self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)) 23 | else: 24 | raise ValueError(f'Unexpected input type {type(x)}') 25 | 26 | def transform(self, x): 27 | height, width = x.shape[2:] 28 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 29 | x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode='reflect') 30 | x_padded_rotated = rotate(x_padded, angle=self.angle.to(x_padded)) 31 | return x_padded_rotated 32 | 33 | def inverse_transform(self, y_padded_rotated, orig_x): 34 | height, width = orig_x.shape[2:] 35 | pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef) 36 | 37 | y_padded = rotate(y_padded_rotated, angle=-self.angle.to(y_padded_rotated)) 38 | y_height, y_width = y_padded.shape[2:] 39 | y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w] 40 | return y 41 | 42 | 43 | if __name__ == '__main__': 44 | layer = LearnableSpatialTransformWrapper(nn.Identity()) 45 | x = torch.arange(2* 3 * 15 * 15).view(2, 3, 15, 15).float() 46 | y = layer(x) 47 | assert x.shape == y.shape 48 | assert torch.allclose(x[:, :, 1:, 1:][:, :, :-1, :-1], y[:, :, 1:, 1:][:, :, :-1, :-1]) 49 | print('all ok') 50 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/modules/squeeze_excitation.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class SELayer(nn.Module): 5 | def __init__(self, channel, reduction=16): 6 | super(SELayer, self).__init__() 7 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 8 | self.fc = nn.Sequential( 9 | nn.Linear(channel, channel // reduction, bias=False), 10 | nn.ReLU(inplace=True), 11 | nn.Linear(channel // reduction, channel, bias=False), 12 | nn.Sigmoid() 13 | ) 14 | 15 | def forward(self, x): 16 | b, c, _, _ = x.size() 17 | y = self.avg_pool(x).view(b, c) 18 | y = self.fc(y).view(b, c, 1, 1) 19 | res = x * y.expand_as(x) 20 | return res 21 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/trainers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from modules.inpainters.lama.saicinpainting.training.trainers.default import DefaultInpaintingTrainingModule 4 | 5 | 6 | def get_training_model_class(kind): 7 | if kind == 'default': 8 | return DefaultInpaintingTrainingModule 9 | 10 | raise ValueError(f'Unknown trainer module {kind}') 11 | 12 | 13 | def make_training_model(config): 14 | kind = config.training_model.kind 15 | kwargs = dict(config.training_model) 16 | kwargs.pop('kind') 17 | kwargs['use_ddp'] = config.trainer.kwargs.get('accelerator', None) == 'ddp' 18 | 19 | logging.info(f'Make training model {kind}') 20 | 21 | cls = get_training_model_class(kind) 22 | return cls(config, **kwargs) 23 | 24 | 25 | def load_checkpoint(train_config, path, map_location='cuda', strict=True): 26 | model: torch.nn.Module = make_training_model(train_config) 27 | state = torch.load(path, map_location=map_location) 28 | 
model.load_state_dict(state['state_dict'], strict=strict) 29 | model.on_load_checkpoint(state) 30 | return model 31 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/visualizers/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | from modules.inpainters.lama.saicinpainting.training.visualizers.directory import DirectoryVisualizer 4 | from modules.inpainters.lama.saicinpainting.training.visualizers.noop import NoopVisualizer 5 | 6 | 7 | def make_visualizer(kind, **kwargs): 8 | logging.info(f'Make visualizer {kind}') 9 | 10 | if kind == 'directory': 11 | return DirectoryVisualizer(**kwargs) 12 | if kind == 'noop': 13 | return NoopVisualizer() 14 | 15 | raise ValueError(f'Unknown visualizer kind {kind}') 16 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/visualizers/base.py: -------------------------------------------------------------------------------- 1 | import abc 2 | from typing import Dict, List 3 | 4 | import numpy as np 5 | import torch 6 | from skimage import color 7 | from skimage.segmentation import mark_boundaries 8 | 9 | from . import colors 10 | 11 | COLORS, _ = colors.generate_colors(151) # 151 - max classes for semantic segmentation 12 | 13 | 14 | class BaseVisualizer: 15 | @abc.abstractmethod 16 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 17 | """ 18 | Take a batch, make an image from it and visualize 19 | """ 20 | raise NotImplementedError() 21 | 22 | 23 | def visualize_mask_and_images(images_dict: Dict[str, np.ndarray], keys: List[str], 24 | last_without_mask=True, rescale_keys=None, mask_only_first=None, 25 | black_mask=False) -> np.ndarray: 26 | mask = images_dict['mask'] > 0.5 27 | result = [] 28 | for i, k in enumerate(keys): 29 | img = images_dict[k] 30 | img = np.transpose(img, (1, 2, 0)) 31 | 32 | if rescale_keys is not None and k in rescale_keys: 33 | img = img - img.min() 34 | img /= img.max() + 1e-5 35 | if len(img.shape) == 2: 36 | img = np.expand_dims(img, 2) 37 | 38 | if img.shape[2] == 1: 39 | img = np.repeat(img, 3, axis=2) 40 | elif (img.shape[2] > 3): 41 | img_classes = img.argmax(2) 42 | img = color.label2rgb(img_classes, colors=COLORS) 43 | 44 | if mask_only_first: 45 | need_mark_boundaries = i == 0 46 | else: 47 | need_mark_boundaries = i < len(keys) - 1 or not last_without_mask 48 | 49 | if need_mark_boundaries: 50 | if black_mask: 51 | img = img * (1 - mask[0][..., None]) 52 | img = mark_boundaries(img, 53 | mask[0], 54 | color=(1., 0., 0.), 55 | outline_color=(1., 1., 1.), 56 | mode='thick') 57 | result.append(img) 58 | return np.concatenate(result, axis=1) 59 | 60 | 61 | def visualize_mask_and_images_batch(batch: Dict[str, torch.Tensor], keys: List[str], max_items=10, 62 | last_without_mask=True, rescale_keys=None) -> np.ndarray: 63 | batch = {k: tens.detach().cpu().numpy() for k, tens in batch.items() 64 | if k in keys or k == 'mask'} 65 | 66 | batch_size = next(iter(batch.values())).shape[0] 67 | items_to_vis = min(batch_size, max_items) 68 | result = [] 69 | for i in range(items_to_vis): 70 | cur_dct = {k: tens[i] for k, tens in batch.items()} 71 | result.append(visualize_mask_and_images(cur_dct, keys, last_without_mask=last_without_mask, 72 | rescale_keys=rescale_keys)) 73 | return np.concatenate(result, axis=0) 74 | 
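visualize_mask_and_images_batch expects NCHW tensors plus a binary 'mask' entry and returns one HxWx3 numpy grid, with batch items stacked vertically and keys placed side by side. A minimal sketch with random data; the keys follow the image / predicted_image / inpainted convention used by the directory visualizer, and the shapes are illustrative:

import torch
from modules.inpainters.lama.saicinpainting.training.visualizers.base import visualize_mask_and_images_batch

# Toy batch in the layout the trainer produces: NCHW tensors in [0, 1] plus a binary mask.
batch = {
    'image': torch.rand(2, 3, 64, 64),
    'predicted_image': torch.rand(2, 3, 64, 64),
    'inpainted': torch.rand(2, 3, 64, 64),
    'mask': (torch.rand(2, 1, 64, 64) > 0.7).float(),
}
grid = visualize_mask_and_images_batch(batch, ['image', 'predicted_image', 'inpainted'], max_items=2)
print(grid.shape)   # (128, 192, 3): 2 items stacked vertically, 3 keys side by side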
-------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/visualizers/colors.py: -------------------------------------------------------------------------------- 1 | import random 2 | import colorsys 3 | 4 | import numpy as np 5 | import matplotlib 6 | matplotlib.use('agg') 7 | import matplotlib.pyplot as plt 8 | from matplotlib.colors import LinearSegmentedColormap 9 | 10 | 11 | def generate_colors(nlabels, type='bright', first_color_black=False, last_color_black=True, verbose=False): 12 | # https://stackoverflow.com/questions/14720331/how-to-generate-random-colors-in-matplotlib 13 | """ 14 | Creates a random colormap to be used together with matplotlib. Useful for segmentation tasks 15 | :param nlabels: Number of labels (size of colormap) 16 | :param type: 'bright' for strong colors, 'soft' for pastel colors 17 | :param first_color_black: Option to use first color as black, True or False 18 | :param last_color_black: Option to use last color as black, True or False 19 | :param verbose: Prints the number of labels and shows the colormap. True or False 20 | :return: colormap for matplotlib 21 | """ 22 | if type not in ('bright', 'soft'): 23 | print ('Please choose "bright" or "soft" for type') 24 | return 25 | 26 | if verbose: 27 | print('Number of labels: ' + str(nlabels)) 28 | 29 | # Generate color map for bright colors, based on hsv 30 | if type == 'bright': 31 | randHSVcolors = [(np.random.uniform(low=0.0, high=1), 32 | np.random.uniform(low=0.2, high=1), 33 | np.random.uniform(low=0.9, high=1)) for i in range(nlabels)] 34 | 35 | # Convert HSV list to RGB 36 | randRGBcolors = [] 37 | for HSVcolor in randHSVcolors: 38 | randRGBcolors.append(colorsys.hsv_to_rgb(HSVcolor[0], HSVcolor[1], HSVcolor[2])) 39 | 40 | if first_color_black: 41 | randRGBcolors[0] = [0, 0, 0] 42 | 43 | if last_color_black: 44 | randRGBcolors[-1] = [0, 0, 0] 45 | 46 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 47 | 48 | # Generate soft pastel colors, by limiting the RGB spectrum 49 | if type == 'soft': 50 | low = 0.6 51 | high = 0.95 52 | randRGBcolors = [(np.random.uniform(low=low, high=high), 53 | np.random.uniform(low=low, high=high), 54 | np.random.uniform(low=low, high=high)) for i in range(nlabels)] 55 | 56 | if first_color_black: 57 | randRGBcolors[0] = [0, 0, 0] 58 | 59 | if last_color_black: 60 | randRGBcolors[-1] = [0, 0, 0] 61 | random_colormap = LinearSegmentedColormap.from_list('new_map', randRGBcolors, N=nlabels) 62 | 63 | # Display colorbar 64 | if verbose: 65 | from matplotlib import colors, colorbar 66 | from matplotlib import pyplot as plt 67 | fig, ax = plt.subplots(1, 1, figsize=(15, 0.5)) 68 | 69 | bounds = np.linspace(0, nlabels, nlabels + 1) 70 | norm = colors.BoundaryNorm(bounds, nlabels) 71 | 72 | cb = colorbar.ColorbarBase(ax, cmap=random_colormap, norm=norm, spacing='proportional', ticks=None, 73 | boundaries=bounds, format='%1i', orientation=u'horizontal') 74 | 75 | return randRGBcolors, random_colormap 76 | 77 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/visualizers/directory.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import cv2 4 | import numpy as np 5 | 6 | from modules.inpainters.lama.saicinpainting.training.visualizers.base import BaseVisualizer, visualize_mask_and_images_batch 7 | from 
modules.inpainters.lama.saicinpainting.utils import check_and_warn_input_range 8 | 9 | 10 | class DirectoryVisualizer(BaseVisualizer): 11 | DEFAULT_KEY_ORDER = 'image predicted_image inpainted'.split(' ') 12 | 13 | def __init__(self, outdir, key_order=DEFAULT_KEY_ORDER, max_items_in_batch=10, 14 | last_without_mask=True, rescale_keys=None): 15 | self.outdir = outdir 16 | os.makedirs(self.outdir, exist_ok=True) 17 | self.key_order = key_order 18 | self.max_items_in_batch = max_items_in_batch 19 | self.last_without_mask = last_without_mask 20 | self.rescale_keys = rescale_keys 21 | 22 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 23 | check_and_warn_input_range(batch['image'], 0, 1, 'DirectoryVisualizer target image') 24 | vis_img = visualize_mask_and_images_batch(batch, self.key_order, max_items=self.max_items_in_batch, 25 | last_without_mask=self.last_without_mask, 26 | rescale_keys=self.rescale_keys) 27 | 28 | vis_img = np.clip(vis_img * 255, 0, 255).astype('uint8') 29 | 30 | curoutdir = os.path.join(self.outdir, f'epoch{epoch_i:04d}{suffix}') 31 | os.makedirs(curoutdir, exist_ok=True) 32 | rank_suffix = f'_r{rank}' if rank is not None else '' 33 | out_fname = os.path.join(curoutdir, f'batch{batch_i:07d}{rank_suffix}.jpg') 34 | 35 | vis_img = cv2.cvtColor(vis_img, cv2.COLOR_RGB2BGR) 36 | cv2.imwrite(out_fname, vis_img) 37 | -------------------------------------------------------------------------------- /modules/inpainters/lama/saicinpainting/training/visualizers/noop.py: -------------------------------------------------------------------------------- 1 | from modules.inpainters.lama.saicinpainting.training.visualizers.base import BaseVisualizer 2 | 3 | 4 | class NoopVisualizer(BaseVisualizer): 5 | def __init__(self, *args, **kwargs): 6 | pass 7 | 8 | def __call__(self, epoch_i, batch_i, batch, suffix='', rank=None): 9 | pass 10 | -------------------------------------------------------------------------------- /modules/inpainters/lama_inpainter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | import os 4 | from .inpainter import Inpainter 5 | from .lama.saicinpainting.training.data.datasets import make_default_val_dataset 6 | from .lama.saicinpainting.training.trainers import load_checkpoint 7 | 8 | from omegaconf import OmegaConf 9 | 10 | 11 | class LamaInpainter(Inpainter): 12 | def __init__(self): 13 | super().__init__() 14 | predict_config = OmegaConf.load('./modules/inpainters/lama/predict_config.yaml') 15 | train_config = OmegaConf.load('./checkpoints/big-lama-config.yaml') 16 | 17 | train_config.training_model.predict_only = True 18 | train_config.visualizer.kind = 'noop' 19 | 20 | out_ext = predict_config.get('out_ext', '.png') 21 | 22 | checkpoint_path = './checkpoints/big-lama.ckpt' 23 | self.model = load_checkpoint(train_config, checkpoint_path, strict=False, map_location='cpu') 24 | self.model.freeze() 25 | 26 | @torch.no_grad() 27 | def inpaint(self, img, mask, out_key='inpainted'): 28 | self.model = self.model.to('cuda') 29 | batch = { 'image': img, 'mask': mask } 30 | batch['image'] = (batch['image'] * 255.).to(torch.uint8).float() / 255. 
31 | batch['mask'] = (batch['mask'] > 0) * 1 32 | batch = self.model(batch) 33 | cur_res = batch[out_key] 34 | unpad_to_size = batch.get('unpad_to_size', None) 35 | 36 | if unpad_to_size is not None: 37 | orig_height, orig_width = unpad_to_size 38 | cur_res = cur_res[:, :orig_height, :orig_width] 39 | 40 | self.model = self.model.to('cpu') 41 | return cur_res 42 | -------------------------------------------------------------------------------- /modules/inpainters/pano_pers_fusion_inpainter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import cv2 as cv 5 | from tqdm import tqdm 6 | from kornia.morphology import erosion, dilation 7 | 8 | from .inpainter import Inpainter 9 | from .lama_inpainter import LamaInpainter 10 | from .SDFT_inpainter import SDFTInpainter 11 | 12 | from utils.geo_utils import panorama_to_pers_directions 13 | from utils.camera_utils import img_coord_to_sample_coord,\ 14 | direction_to_img_coord, img_coord_to_pano_direction, direction_to_pers_img_coord 15 | 16 | from PIL import Image, ImageDraw 17 | 18 | class PanoPersFusionInpainter(Inpainter): 19 | def __init__(self, save_path, subset_name=None): 20 | super().__init__() 21 | 22 | self.diff_inpainter = SDFTInpainter(subset_name) 23 | 24 | self.lama_inpainter = LamaInpainter() 25 | 26 | self.save_path = save_path 27 | 28 | @torch.no_grad() 29 | def inpaint(self, idx, img, mask): 30 | img = img.squeeze().permute(2, 0, 1) 31 | mask = mask.squeeze()[None] 32 | inpainted_img = img.clone() 33 | 34 | pers_dirs, pers_ratios, to_vecs, down_vecs, right_vecs = panorama_to_pers_directions(gen_res=512, ratio=1.4) 35 | 36 | n_pers = len(pers_dirs) 37 | img_coords = direction_to_img_coord(pers_dirs) 38 | sample_coords = img_coord_to_sample_coord(img_coords) 39 | 40 | _, pano_height, pano_width = img.shape 41 | pano_img_coords = torch.meshgrid(torch.linspace(.5 / pano_height, 1. - .5 / pano_height, pano_height), 42 | torch.linspace(.5 / pano_width, 1. 
- .5 / pano_width, pano_width), 43 | indexing='ij') 44 | pano_img_coords = torch.stack(list(pano_img_coords), dim=-1) 45 | 46 | pano_dirs = img_coord_to_pano_direction(pano_img_coords) 47 | 48 | for i in tqdm(range(n_pers)): 49 | cur_sample_coords = sample_coords[i] 50 | pers_image = F.grid_sample(inpainted_img[None], cur_sample_coords[None], padding_mode='border')[0] 51 | pers_mask = F.grid_sample(mask[None, :, :], cur_sample_coords[None], padding_mode='border')[0] 52 | pers_mask = (pers_mask > 0.5).float() #CHW 53 | if self.lama_inpainter is not None: 54 | kernel = torch.from_numpy(cv.getStructuringElement(cv.MORPH_ELLIPSE, (11, 11))).float().to(pers_mask.device) 55 | smooth_mask = pers_mask 56 | smooth_mask = erosion(pers_mask[None], kernel=kernel)[0] 57 | smooth_mask = dilation(smooth_mask[None], kernel=kernel)[0] 58 | smooth_mask = torch.minimum(smooth_mask, pers_mask) 59 | lama_inpainted = self.lama_inpainter.inpaint(pers_image[None], pers_mask[None])[0] 60 | if smooth_mask.max().item() > .5: 61 | cur_inpainted = self.diff_inpainter.inpaint(lama_inpainted[None], smooth_mask[None])[0] 62 | else: 63 | cur_inpainted = lama_inpainted 64 | else: 65 | if pers_mask.max().item() > .5: 66 | cur_inpainted = self.diff_inpainter.inpaint(pers_image[None], pers_mask[None])[0] 67 | else: 68 | cur_inpainted = pers_image 69 | 70 | cur_inpainted = pers_image * (1 - pers_mask) + cur_inpainted * pers_mask 71 | 72 | proj_coord, proj_mask = direction_to_pers_img_coord(pano_dirs, to_vecs[i], down_vecs[i], right_vecs[i]) 73 | proj_coord = img_coord_to_sample_coord(proj_coord) 74 | 75 | cur_inpainted_pano_img = F.grid_sample(cur_inpainted[None], proj_coord[None], padding_mode='border')[0] 76 | proj_mask = proj_mask.permute(2, 0, 1).float() 77 | inpainted_img = inpainted_img * (1. - proj_mask) + cur_inpainted_pano_img * proj_mask 78 | mask = mask * (1. - proj_mask) + 0. * proj_mask 79 | 80 | inpainted_img = img * mask + inpainted_img * (1 - mask) 81 | return inpainted_img.permute(1, 2, 0) 82 | -------------------------------------------------------------------------------- /modules/mesh_fusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/TrickyGo/Pano2Room/bbf93ae57086ed700edc6ee445852d4457a9d704/modules/mesh_fusion/__init__.py -------------------------------------------------------------------------------- /scene/__init__.py: -------------------------------------------------------------------------------- 1 | ### 2 | # Copyright (C) 2023, Computer Vision Lab, Seoul National University, https://cv.snu.ac.kr 3 | # For permission requests, please contact robot0321@snu.ac.kr, esw0116@snu.ac.kr, namhj28@gmail.com, jarin.lee@gmail.com. 4 | # All rights reserved. 
5 | ### 6 | import os 7 | import random 8 | 9 | from scene.arguments import GSParams 10 | from utils.system import searchForMaxIteration 11 | from scene.dataset_readers import readDataInfo 12 | from scene.gaussian_model import GaussianModel 13 | 14 | 15 | class Scene: 16 | gaussians: GaussianModel 17 | 18 | def __init__(self, traindata, gaussians: GaussianModel, opt: GSParams): 19 | self.traindata = traindata 20 | self.gaussians = gaussians 21 | 22 | info = readDataInfo(traindata, opt.white_background) 23 | # random.shuffle(info.train_cameras) # Multi-res consistent random shuffling 24 | self.cameras_extent = info.nerf_normalization["radius"] 25 | 26 | print("Loading Training Cameras") 27 | self.train_cameras = info.train_cameras 28 | print("Loading Preset Cameras") 29 | self.preset_cameras = {} 30 | for campath in info.preset_cameras.keys(): 31 | self.preset_cameras[campath] = info.preset_cameras[campath] 32 | 33 | self.gaussians.create_from_pcd(info.point_cloud, self.cameras_extent) 34 | self.gaussians.training_setup(opt) 35 | 36 | def getTrainCameras(self): 37 | return self.train_cameras 38 | 39 | def getPresetCameras(self, preset): 40 | assert preset in self.preset_cameras 41 | return self.preset_cameras[preset] -------------------------------------------------------------------------------- /scene/arguments.py: -------------------------------------------------------------------------------- 1 | ### 2 | # Copyright (C) 2023, Computer Vision Lab, Seoul National University, https://cv.snu.ac.kr 3 | # For permission requests, please contact robot0321@snu.ac.kr, esw0116@snu.ac.kr, namhj28@gmail.com, jarin.lee@gmail.com. 4 | # All rights reserved. 5 | ### 6 | import numpy as np 7 | 8 | 9 | class GSParams: 10 | def __init__(self): 11 | self.sh_degree = 1 12 | self.images = "images" 13 | self.resolution = -1 14 | self.white_background = False 15 | self.data_device = "cuda" 16 | self.eval = False 17 | self.use_depth = False 18 | 19 | self.iterations = 3000 20 | self.position_lr_init = 0.00016 21 | self.position_lr_final = 0.0000016 22 | self.position_lr_delay_mult = 0.01 23 | self.position_lr_max_steps = 2990#3_000 24 | self.feature_lr = 0.0025 25 | self.opacity_lr = 0.05 #0.05 26 | self.scaling_lr = 0.005 27 | self.rotation_lr = 0.001 28 | self.percent_dense = 0.01 29 | self.lambda_dssim = 0.2 30 | self.densification_interval = 100 31 | self.opacity_reset_interval = 999999 #3000 32 | self.densify_from_iter = 500 33 | self.densify_until_iter = 15_000 34 | self.densify_grad_threshold = 0.0002 35 | 36 | self.convert_SHs_python = False 37 | self.compute_cov3D_python = False 38 | self.debug = False 39 | 40 | 41 | class CameraParams: 42 | def __init__(self, H: int = 512, W: int = 512, angle = 90.0): 43 | self.H = H 44 | self.W = W 45 | 46 | self.fov = ((angle/180.0)*np.pi, (angle/180.0)*np.pi) 47 | self.focal = ((0.5*W/np.tan(self.fov[0]/2)), (0.5*H/np.tan(self.fov[0]/2))) 48 | 49 | self.K = np.array([ 50 | [self.focal[0], 0., self.W/2], 51 | [0., self.focal[1], self.H/2], 52 | [0., 0., 1.], 53 | ]).astype(np.float32) -------------------------------------------------------------------------------- /scene/cameras.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import numpy as np 12 | 13 | import torch 14 | from torch import nn 15 | 16 | from utils.graphics import getWorld2View2, getProjectionMatrix 17 | from utils.loss import image2canny 18 | 19 | 20 | class Camera(nn.Module): 21 | def __init__(self, colmap_id, R, T, FoVx, FoVy, image, gt_alpha_mask, 22 | image_name, uid, 23 | trans=np.array([0.0, 0.0, 0.0]), scale=1.0, data_device = "cuda" 24 | ): 25 | super(Camera, self).__init__() 26 | 27 | self.uid = uid 28 | self.colmap_id = colmap_id 29 | self.R = R 30 | self.T = T 31 | self.FoVx = FoVx 32 | self.FoVy = FoVy 33 | self.image_name = image_name 34 | 35 | try: 36 | self.data_device = torch.device(data_device) 37 | except Exception as e: 38 | print(e) 39 | print(f"[Warning] Custom device {data_device} failed, fallback to default cuda device" ) 40 | self.data_device = torch.device("cuda") 41 | 42 | self.original_image = image.clamp(0.0, 1.0).to(self.data_device) 43 | self.canny_mask = image2canny(self.original_image.permute(1,2,0), 50, 150, isEdge1=False).detach().to(self.data_device) 44 | self.image_width = self.original_image.shape[2] 45 | self.image_height = self.original_image.shape[1] 46 | 47 | if gt_alpha_mask is not None: 48 | self.original_image *= gt_alpha_mask.to(self.data_device) 49 | else: 50 | self.original_image *= torch.ones((1, self.image_height, self.image_width), device=self.data_device) 51 | 52 | self.zfar = 100.0 53 | self.znear = 0.01 54 | 55 | self.trans = trans 56 | self.scale = scale 57 | 58 | self.world_view_transform = torch.tensor(getWorld2View2(R, T, trans, scale)).transpose(0, 1).cuda() 59 | self.projection_matrix = getProjectionMatrix(znear=self.znear, zfar=self.zfar, fovX=self.FoVx, fovY=self.FoVy).transpose(0,1).cuda() 60 | self.full_proj_transform = (self.world_view_transform.unsqueeze(0).bmm(self.projection_matrix.unsqueeze(0))).squeeze(0) 61 | self.camera_center = self.world_view_transform.inverse()[3, :3] 62 | 63 | 64 | class MiniCam: 65 | def __init__(self, width, height, fovy, fovx, znear, zfar, world_view_transform, full_proj_transform): 66 | self.image_width = width 67 | self.image_height = height 68 | self.FoVy = fovy 69 | self.FoVx = fovx 70 | self.znear = znear 71 | self.zfar = zfar 72 | self.world_view_transform = world_view_transform 73 | self.full_proj_transform = full_proj_transform 74 | view_inv = torch.inverse(self.world_view_transform) 75 | self.camera_center = view_inv[3][:3] 76 | 77 | -------------------------------------------------------------------------------- /scripts/accelerate.yaml: -------------------------------------------------------------------------------- 1 | compute_environment: LOCAL_MACHINE 2 | debug: false 3 | distributed_type: 'NO' 4 | downcast_bf16: 'no' 5 | gpu_ids: '4' 6 | machine_rank: 0 7 | main_training_function: main 8 | # mixed_precision: fp16 9 | num_machines: 1 10 | num_processes: 1 11 | rdzv_backend: static 12 | same_network: true 13 | tpu_env: [] 14 | tpu_use_cluster: false 15 | tpu_use_sudo: false 16 | use_cpu: false 17 | -------------------------------------------------------------------------------- /scripts/create_SDFT_pairs.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python create_SDFT_pairs.py -------------------------------------------------------------------------------- /scripts/run_Pano2Room.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=0 python3 
pano2room.py -------------------------------------------------------------------------------- /scripts/train_SDFT.sh: -------------------------------------------------------------------------------- 1 | accelerate launch --config_file scripts/accelerate.yaml train_SDFT.py -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/graphics.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | import math 12 | from typing import NamedTuple 13 | import numpy as np 14 | import torch 15 | 16 | 17 | class BasicPointCloud(NamedTuple): 18 | points : np.array 19 | colors : np.array 20 | normals : np.array 21 | 22 | 23 | def geom_transform_points(points, transf_matrix): 24 | P, _ = points.shape 25 | ones = torch.ones(P, 1, dtype=points.dtype, device=points.device) 26 | points_hom = torch.cat([points, ones], dim=1) 27 | points_out = torch.matmul(points_hom, transf_matrix.unsqueeze(0)) 28 | 29 | denom = points_out[..., 3:] + 0.0000001 30 | return (points_out[..., :3] / denom).squeeze(dim=0) 31 | 32 | 33 | def getWorld2View(R, t): 34 | Rt = np.zeros((4, 4)) 35 | Rt[:3, :3] = R.transpose() 36 | Rt[:3, 3] = t 37 | Rt[3, 3] = 1.0 38 | return np.float32(Rt) 39 | 40 | 41 | def getWorld2View2(R, t, translate=np.array([.0, .0, .0]), scale=1.0): 42 | Rt = np.zeros((4, 4)) 43 | Rt[:3, :3] = R.transpose() 44 | Rt[:3, 3] = t 45 | Rt[3, 3] = 1.0 46 | 47 | C2W = np.linalg.inv(Rt) 48 | cam_center = C2W[:3, 3] 49 | cam_center = (cam_center + translate) * scale 50 | C2W[:3, 3] = cam_center 51 | Rt = np.linalg.inv(C2W) 52 | return np.float32(Rt) 53 | 54 | 55 | def getProjectionMatrix(znear, zfar, fovX, fovY): 56 | tanHalfFovY = math.tan((fovY / 2)) 57 | tanHalfFovX = math.tan((fovX / 2)) 58 | 59 | top = tanHalfFovY * znear 60 | bottom = -top 61 | right = tanHalfFovX * znear 62 | left = -right 63 | 64 | P = torch.zeros(4, 4) 65 | 66 | z_sign = 1.0 67 | 68 | P[0, 0] = 2.0 * znear / (right - left) 69 | P[1, 1] = 2.0 * znear / (top - bottom) 70 | P[0, 2] = (right + left) / (right - left) 71 | P[1, 2] = (top + bottom) / (top - bottom) 72 | P[3, 2] = z_sign 73 | P[2, 2] = z_sign * zfar / (zfar - znear) 74 | P[2, 3] = -(zfar * znear) / (zfar - znear) 75 | return P 76 | 77 | 78 | def fov2focal(fov, pixels): 79 | return pixels / (2 * math.tan(fov / 2)) 80 | 81 | 82 | def focal2fov(focal, pixels): 83 | return 2*math.atan(pixels/(2*focal)) -------------------------------------------------------------------------------- /utils/loss.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 
8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | from math import exp 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | from torch.autograd import Variable 16 | 17 | 18 | def l1_loss(network_output, gt): 19 | return torch.abs((network_output - gt)).mean() 20 | 21 | 22 | def l2_loss(network_output, gt): 23 | return ((network_output - gt) ** 2).mean() 24 | 25 | 26 | def gaussian(window_size, sigma): 27 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 28 | return gauss / gauss.sum() 29 | 30 | 31 | def create_window(window_size, channel): 32 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 33 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 34 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 35 | return window 36 | 37 | 38 | def ssim(img1, img2, window_size=11, size_average=True): 39 | channel = img1.size(-3) 40 | window = create_window(window_size, channel) 41 | 42 | if img1.is_cuda: 43 | window = window.cuda() 44 | 45 | window = window.type_as(img1) 46 | 47 | return _ssim(img1, img2, window, window_size, channel, size_average) 48 | 49 | 50 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 51 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 52 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 53 | 54 | mu1_sq = mu1.pow(2) 55 | mu2_sq = mu2.pow(2) 56 | mu1_mu2 = mu1 * mu2 57 | 58 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 59 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 60 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 61 | 62 | C1 = 0.01 ** 2 63 | C2 = 0.03 ** 2 64 | 65 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 66 | 67 | if size_average: 68 | return ssim_map.mean() 69 | else: 70 | return ssim_map.mean(1).mean(1).mean(1) 71 | 72 | 73 | import numpy as np 74 | import cv2 75 | def image2canny(image, thres1, thres2, isEdge1=True): 76 | """ image: (H, W, 3)""" 77 | canny_mask = torch.from_numpy(cv2.Canny((image.detach().cpu().numpy()*255.).astype(np.uint8), thres1, thres2)/255.) 78 | if not isEdge1: 79 | canny_mask = 1. - canny_mask 80 | return canny_mask.float() 81 | 82 | with torch.no_grad(): 83 | kernelsize=3 84 | conv = torch.nn.Conv2d(1, 1, kernel_size=kernelsize, padding=(kernelsize//2)) 85 | kernel = torch.tensor([[0.,1.,0.],[1.,0.,1.],[0.,1.,0.]]).reshape(1,1,kernelsize,kernelsize) 86 | conv.weight.data = kernel #torch.ones((1,1,kernelsize,kernelsize)) 87 | conv.bias.data = torch.tensor([0.]) 88 | conv.requires_grad_(False) 89 | conv = conv.cuda() 90 | 91 | 92 | def nearMean_map(array, mask, kernelsize=3): 93 | """ array: (H,W) / mask: (H,W) """ 94 | cnt_map = torch.ones_like(array) 95 | 96 | nearMean_map = conv((array * mask)[None,None]) 97 | cnt_map = conv((cnt_map * mask)[None,None]) 98 | nearMean_map = (nearMean_map / (cnt_map+1e-8)).squeeze() 99 | 100 | return nearMean_map -------------------------------------------------------------------------------- /utils/system.py: -------------------------------------------------------------------------------- 1 | # 2 | # Copyright (C) 2023, Inria 3 | # GRAPHDECO research group, https://team.inria.fr/graphdeco 4 | # All rights reserved. 
5 | # 6 | # This software is free for non-commercial, research and evaluation use 7 | # under the terms of the LICENSE.md file. 8 | # 9 | # For inquiries contact george.drettakis@inria.fr 10 | # 11 | from errno import EEXIST 12 | from os import makedirs, path 13 | import os 14 | 15 | 16 | def mkdir_p(folder_path): 17 | # Creates a directory. equivalent to using mkdir -p on the command line 18 | try: 19 | makedirs(folder_path) 20 | except OSError as exc: # Python >2.5 21 | if exc.errno == EEXIST and path.isdir(folder_path): 22 | pass 23 | else: 24 | raise 25 | 26 | 27 | def searchForMaxIteration(folder): 28 | saved_iters = [int(fname.split("_")[-1]) for fname in os.listdir(folder)] 29 | return max(saved_iters) 30 | -------------------------------------------------------------------------------- /utils/warp_utils.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from PIL import Image 3 | import torch 4 | import torch.nn.functional as F 5 | from torchvision.utils import save_image 6 | from torchvision import transforms 7 | 8 | ####################### 9 | # some helper I/O functions 10 | ####################### 11 | def image_to_tensor(img_path, unsqueeze=True): 12 | rgb = transforms.ToTensor()(Image.open(img_path)) 13 | if unsqueeze: 14 | rgb = rgb.unsqueeze(0) 15 | return rgb 16 | 17 | 18 | def disparity_to_tensor(disp_path, unsqueeze=True): 19 | disp = cv2.imread(disp_path, -1) / (2 ** 16 - 1) 20 | disp = torch.from_numpy(disp)[None, ...] 21 | if unsqueeze: 22 | disp = disp.unsqueeze(0) 23 | return disp.float() 24 | 25 | 26 | ####################### 27 | # some helper geometry functions 28 | # adapt from https://github.com/mattpoggi/depthstillation 29 | ####################### 30 | def transformation_from_parameters(axisangle, translation, invert=False): 31 | R = rot_from_axisangle(axisangle) 32 | t = translation.clone() 33 | 34 | if invert: 35 | R = R.transpose(1, 2) 36 | t *= -1 37 | 38 | T = get_translation_matrix(t) 39 | 40 | if invert: 41 | M = torch.matmul(R, T) 42 | else: 43 | M = torch.matmul(T, R) 44 | 45 | return M 46 | 47 | 48 | def get_translation_matrix(translation_vector): 49 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 50 | t = translation_vector.contiguous().view(-1, 3, 1) 51 | T[:, 0, 0] = 1 52 | T[:, 1, 1] = 1 53 | T[:, 2, 2] = 1 54 | T[:, 3, 3] = 1 55 | T[:, :3, 3, None] = t 56 | return T 57 | 58 | 59 | def rot_from_axisangle(vec): 60 | angle = torch.norm(vec, 2, 2, True) 61 | axis = vec / (angle + 1e-7) 62 | 63 | ca = torch.cos(angle) 64 | sa = torch.sin(angle) 65 | C = 1 - ca 66 | 67 | x = axis[..., 0].unsqueeze(1) 68 | y = axis[..., 1].unsqueeze(1) 69 | z = axis[..., 2].unsqueeze(1) 70 | 71 | xs = x * sa 72 | ys = y * sa 73 | zs = z * sa 74 | xC = x * C 75 | yC = y * C 76 | zC = z * C 77 | xyC = x * yC 78 | yzC = y * zC 79 | zxC = z * xC 80 | 81 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 82 | 83 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 84 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 85 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 86 | rot[:, 1, 0] = torch.squeeze(xyC + zs) 87 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 88 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 89 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 90 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 91 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 92 | rot[:, 3, 3] = 1 93 | 94 | return rot 95 | 96 | --------------------------------------------------------------------------------
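The axis-angle helpers at the end of warp_utils.py follow the depthstillation convention they were adapted from: both the axis-angle vector and the translation have shape (B, 1, 3), and transformation_from_parameters composes them into a batch of 4x4 pose matrices, with invert=True producing the inverse transform. A small sketch with made-up values:

import math
import torch
from utils.warp_utils import transformation_from_parameters

# Rotate 30 degrees about the y axis and translate 0.2 units along +x.
axisangle = torch.tensor([[[0.0, math.pi / 6, 0.0]]])     # (1, 1, 3)
translation = torch.tensor([[[0.2, 0.0, 0.0]]])           # (1, 1, 3)

pose = transformation_from_parameters(axisangle, translation)                   # (1, 4, 4) = T @ R
pose_inv = transformation_from_parameters(axisangle, translation, invert=True)  # (1, 4, 4) = R^T @ T(-t)

# The inverted call is the matrix inverse of the forward one (up to float error).
print(torch.allclose(pose @ pose_inv, torch.eye(4)[None], atol=1e-5))           # True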