├── .gitignore ├── LICENSE ├── README.md ├── assets └── teaser.jpg ├── configs ├── pose │ └── depth_gcn.yaml ├── scale │ └── depth_feat.yaml └── shape │ └── nocs_embed.yaml ├── data ├── __init__.py ├── base.py ├── data_util.py ├── pose.py ├── scale_joint.py └── shape.py ├── env.yaml ├── lr_scheduler.py ├── models ├── diffusion │ ├── __init__.py │ ├── ddim_scale.py │ ├── ddpm_pose.py │ ├── ddpm_scale.py │ └── ddpm_shape.py ├── gcn3d.py ├── latentloader.py ├── pointembed.py └── unet_feature.py ├── modules ├── attention.py ├── diffusionmodules │ ├── __init__.py │ ├── model.py │ ├── openaimodel_pose.py │ ├── openaimodel_scale.py │ ├── openaimodel_shape.py │ └── util.py ├── distributions │ ├── __init__.py │ └── distributions.py ├── ema.py ├── encoders │ ├── __init__.py │ └── modules.py └── x_transformer.py ├── scripts ├── alignment_from_nocs.py ├── eval_alignments.py ├── generate_multi_nocs_candidates.py ├── generate_multi_scale_candidates.py └── generate_multi_shape_candidates.py ├── splits ├── pose │ ├── 02818832 │ │ ├── train.txt │ │ ├── val_02818832.txt │ │ └── val_nonocc_centroid_maskexist.txt │ ├── 02871439 │ │ ├── train.txt │ │ ├── val_02871439.txt │ │ └── val_nonocc_centroid_maskexist.txt │ ├── 02933112 │ │ ├── train.txt │ │ ├── val_02933112.txt │ │ └── val_nonocc_centroid_maskexist.txt │ ├── 03001627 │ │ ├── train.txt │ │ ├── val_03001627.txt │ │ └── val_nonocc_centroid_maskexist.txt │ ├── 04256520 │ │ ├── train.txt │ │ ├── val_04256520.txt │ │ └── val_nonocc_centroid_maskexist.txt │ └── 04379243 │ │ ├── train.txt │ │ ├── val_04379243.txt │ │ └── val_nonocc_centroid_maskexist.txt ├── scale │ ├── train_joint.txt │ └── val_joint.txt └── shape │ ├── 02818832 │ ├── train.txt │ └── val_nonocc_centroid_maskexist.txt │ ├── 02871439 │ ├── train.txt │ └── val_nonocc_centroid_maskexist.txt │ ├── 02933112 │ ├── train.txt │ └── val_nonocc_centroid_maskexist.txt │ ├── 03001627 │ ├── train.txt │ └── val_nonocc_centroid_maskexist.txt │ ├── 04256520 │ ├── train.txt │ └── 
val_nonocc_centroid_maskexist.txt │ └── 04379243 │ ├── train.txt │ └── val_nonocc_centroid_maskexist.txt ├── train_pose.py ├── train_scale.py ├── train_shape.py └── util.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__/ 2 | .DS_Store 3 | *.egg 4 | log/ 5 | wandb/ 6 | .vscode/ 7 | *.ckpt -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Creative Commons Legal Code 2 | 3 | CC0 1.0 Universal 4 | 5 | CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE 6 | LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN 7 | ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS 8 | INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES 9 | REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS 10 | PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM 11 | THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED 12 | HEREUNDER. 13 | 14 | Statement of Purpose 15 | 16 | The laws of most jurisdictions throughout the world automatically confer 17 | exclusive Copyright and Related Rights (defined below) upon the creator 18 | and subsequent owner(s) (each and all, an "owner") of an original work of 19 | authorship and/or a database (each, a "Work"). 20 | 21 | Certain owners wish to permanently relinquish those rights to a Work for 22 | the purpose of contributing to a commons of creative, cultural and 23 | scientific works ("Commons") that the public can reliably and without fear 24 | of later claims of infringement build upon, modify, incorporate in other 25 | works, reuse and redistribute as freely as possible in any form whatsoever 26 | and for any purposes, including without limitation commercial purposes. 
27 | These owners may contribute to the Commons to promote the ideal of a free 28 | culture and the further production of creative, cultural and scientific 29 | works, or to gain reputation or greater distribution for their Work in 30 | part through the use and efforts of others. 31 | 32 | For these and/or other purposes and motivations, and without any 33 | expectation of additional consideration or compensation, the person 34 | associating CC0 with a Work (the "Affirmer"), to the extent that he or she 35 | is an owner of Copyright and Related Rights in the Work, voluntarily 36 | elects to apply CC0 to the Work and publicly distribute the Work under its 37 | terms, with knowledge of his or her Copyright and Related Rights in the 38 | Work and the meaning and intended legal effect of CC0 on those rights. 39 | 40 | 1. Copyright and Related Rights. A Work made available under CC0 may be 41 | protected by copyright and related or neighboring rights ("Copyright and 42 | Related Rights"). Copyright and Related Rights include, but are not 43 | limited to, the following: 44 | 45 | i. the right to reproduce, adapt, distribute, perform, display, 46 | communicate, and translate a Work; 47 | ii. moral rights retained by the original author(s) and/or performer(s); 48 | iii. publicity and privacy rights pertaining to a person's image or 49 | likeness depicted in a Work; 50 | iv. rights protecting against unfair competition in regards to a Work, 51 | subject to the limitations in paragraph 4(a), below; 52 | v. rights protecting the extraction, dissemination, use and reuse of data 53 | in a Work; 54 | vi. database rights (such as those arising under Directive 96/9/EC of the 55 | European Parliament and of the Council of 11 March 1996 on the legal 56 | protection of databases, and under any national implementation 57 | thereof, including any amended or successor version of such 58 | directive); and 59 | vii. 
other similar, equivalent or corresponding rights throughout the 60 | world based on applicable law or treaty, and any national 61 | implementations thereof. 62 | 63 | 2. Waiver. To the greatest extent permitted by, but not in contravention 64 | of, applicable law, Affirmer hereby overtly, fully, permanently, 65 | irrevocably and unconditionally waives, abandons, and surrenders all of 66 | Affirmer's Copyright and Related Rights and associated claims and causes 67 | of action, whether now known or unknown (including existing as well as 68 | future claims and causes of action), in the Work (i) in all territories 69 | worldwide, (ii) for the maximum duration provided by applicable law or 70 | treaty (including future time extensions), (iii) in any current or future 71 | medium and for any number of copies, and (iv) for any purpose whatsoever, 72 | including without limitation commercial, advertising or promotional 73 | purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each 74 | member of the public at large and to the detriment of Affirmer's heirs and 75 | successors, fully intending that such Waiver shall not be subject to 76 | revocation, rescission, cancellation, termination, or any other legal or 77 | equitable action to disrupt the quiet enjoyment of the Work by the public 78 | as contemplated by Affirmer's express Statement of Purpose. 79 | 80 | 3. Public License Fallback. Should any part of the Waiver for any reason 81 | be judged legally invalid or ineffective under applicable law, then the 82 | Waiver shall be preserved to the maximum extent permitted taking into 83 | account Affirmer's express Statement of Purpose. 
In addition, to the 84 | extent the Waiver is so judged Affirmer hereby grants to each affected 85 | person a royalty-free, non transferable, non sublicensable, non exclusive, 86 | irrevocable and unconditional license to exercise Affirmer's Copyright and 87 | Related Rights in the Work (i) in all territories worldwide, (ii) for the 88 | maximum duration provided by applicable law or treaty (including future 89 | time extensions), (iii) in any current or future medium and for any number 90 | of copies, and (iv) for any purpose whatsoever, including without 91 | limitation commercial, advertising or promotional purposes (the 92 | "License"). The License shall be deemed effective as of the date CC0 was 93 | applied by Affirmer to the Work. Should any part of the License for any 94 | reason be judged legally invalid or ineffective under applicable law, such 95 | partial invalidity or ineffectiveness shall not invalidate the remainder 96 | of the License, and in such case Affirmer hereby affirms that he or she 97 | will not (i) exercise any of his or her remaining Copyright and Related 98 | Rights in the Work or (ii) assert any associated claims and causes of 99 | action with respect to the Work, in either case contrary to Affirmer's 100 | express Statement of Purpose. 101 | 102 | 4. Limitations and Disclaimers. 103 | 104 | a. No trademark or patent rights held by Affirmer are waived, abandoned, 105 | surrendered, licensed or otherwise affected by this document. 106 | b. Affirmer offers the Work as-is and makes no representations or 107 | warranties of any kind concerning the Work, express, implied, 108 | statutory or otherwise, including without limitation warranties of 109 | title, merchantability, fitness for a particular purpose, non 110 | infringement, or the absence of latent or other defects, accuracy, or 111 | the present or absence of errors, whether or not discoverable, all to 112 | the greatest extent permissible under applicable law. 113 | c. 
Affirmer disclaims responsibility for clearing rights of other persons 114 | that may apply to the Work or any use thereof, including without 115 | limitation any person's Copyright and Related Rights in the Work. 116 | Further, Affirmer disclaims responsibility for obtaining any necessary 117 | consents, permissions or other rights required for any use of the 118 | Work. 119 | d. Affirmer understands and acknowledges that Creative Commons is not a 120 | party to this document and has no duty or obligation with respect to 121 | this CC0 or use of the Work. 122 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DiffCAD: Weakly-Supervised Probabilistic CAD Model Retrieval and Alignment from an RGB Image 2 | [![arXiv](https://img.shields.io/badge/📃-arXiv%20-red.svg)](https://arxiv.org/abs/2311.18610) 3 | [![webpage](https://img.shields.io/badge/🌐-Website%20-blue.svg)](https://daoyig.github.io/DiffCAD/) 4 | 5 | [Daoyi Gao](https://daoyig.github.io/), [Dávid Rozenberszki](https://rozdavid.github.io/), [Stefan Leutenegger](https://srl.cit.tum.de/members/leuteneg), and [Angela Dai](https://www.3dunderstanding.org/index.html) 6 | 7 | DiffCAD proposes a weakly-supervised approach for CAD model retrieval and alignment from an RGB image. Our approach utilizes disentangled diffusion models to tackle the ambiguities in monocular perception, and achieves robust cross-domain performance while being trained only on synthetic data.
8 | 9 | ![DiffCAD](assets/teaser.jpg) 10 | 11 | 12 | ## Environment 13 | 14 | We tested with Ubuntu 20.04, Python 3.8, CUDA 11, PyTorch 2.0 15 | 16 | ### Dependencies 17 | 18 | We provide an Anaconda environment with the dependencies; to install it, run 19 | 20 | ``` 21 | conda env create -f env.yaml 22 | ``` 23 | 24 | ## Available Resources 25 | 26 | ### Data 27 | We provide our synthetic 3D-FRONT data renderings (RGB, rendered/predicted depth, masks, camera poses); processed watertight ([mesh-fusion](https://github.com/autonomousvision/occupancy_networks/tree/master/external/mesh-fusion)) and canonicalized meshes (ShapeNet and 3D-FUTURE) together with their encoded latent vectors; and machine-estimated depth and masks on the validation set of the ScanNet25k data. However, since the rendered data takes up a large amount of storage space, we also encourage you to generate the synthetic data renderings yourself following [BlenderProc](https://github.com/DLR-RM/BlenderProc) or [3DFront-Rendering](https://github.com/yinyunie/BlenderProc-3DFront). 28 | 29 | | **Source Dataset** | **Description** | 30 | |--------------------| --------------| 31 | | [3D-FRONT-CONFIG](https://syncandshare.lrz.de/getlink/fiMLEHNEu87SA4gTHcQkuB/3D-FRONT-CONFIG) | Scene configs for rendering; we also augment them with ShapeNet objects. | 32 | | [3D-FRONT-RENDERING](https://syncandshare.lrz.de/getlink/fiQUDhpSxJV3HJjQx66Ngb/3D-FRONT-RENDER) | Renderings of the 3D-FRONT dataset for each target category. | 33 | | [Object Meshes](https://syncandshare.lrz.de/getlink/fiQWpWzs5qSeXrt2JStEbT/Mesh) | Canonicalized, watertight meshes of ShapeNet and 3D-FUTURE. | 34 | | [Object Meshes - AUG](https://syncandshare.lrz.de/getlink/fiAhSmZduitQM8FeLEU4Yy/Mesh-AUG) | ShapeNet objects scaled to match the scale of their NN 3DF object; we use them to augment the synthetic dataset. | 35 | | [Object Latents](https://syncandshare.lrz.de/getlink/fi53KQjYS2MJgKdgc3zzAo/Latents) | Encoded object latents for retrieval.
| 36 | [Val ScanNet25k](https://syncandshare.lrz.de/getlink/fiKQasexdTsyRfqQV6YQSU/Scan2CAD_processed) | Predicted depth, GT and predicted masks, CAD pools, and pose GTs on the validation set. | 37 | [ScanNet25k data](https://drive.google.com/drive/folders/1JbPidWsfcLyUswYQsulZN8HDFBTdoQog) | The processed data from [ROCA](https://github.com/cangumeli/ROCA) | 38 | 39 | 40 | 41 | ## Pretrained Checkpoint 42 | We also provide the checkpoints for scene scale, object pose, and shape diffusion models. 43 | | **Source Dataset** | | 44 | |--------------------|-------------------| 45 | | Scale | [Joint category ldm model](https://syncandshare.lrz.de/getlink/fiEuyDe5EusDujuetyk9UN/scale) | 46 | | Pose | [Category-specific ldm model](https://syncandshare.lrz.de/getlink/fiSMR6RAwVS5ucGh2e9Mvu/pose) | 47 | | Shape | [Category-specific ldm model](https://syncandshare.lrz.de/getlink/fiEdb3iPSjPg8QdcAnJ7ou/shape) | 48 | 49 | 50 | ## Training 51 | For scene scale: 52 | 53 | ```python train_scale.py --base=configs/scale/depth_feat.yaml -t --gpus=0, --logdir=logs``` 54 | 55 | For object NOCS: 56 | 57 | ```python train_pose.py --base=configs/pose/depth_gcn.yaml -t --gpus=0, --logdir=logs``` 58 | 59 | For object latents: 60 | 61 | ```python train_shape.py --base=configs/shape/nocs_embed.yaml -t --gpus=0, --logdir=logs``` 62 | 63 | 64 | ## Inference 65 | For scene scale sampling: 66 | 67 | ```python scripts/generate_multi_scale_candidates.py``` 68 | 69 | For object NOCS generation: 70 | 71 | ```python scripts/generate_multi_nocs_candidates.py``` 72 | 73 | For object latent sampling: 74 | 75 | ```python scripts/generate_multi_shape_candidates.py``` 76 | 77 | ## BibTeX 78 | 79 | ``` 80 | @article{gao2023diffcad, 81 | title= {DiffCAD: Weakly-Supervised Probabilistic CAD Model Retrieval and Alignment from an RGB Image}, 82 | author={Gao, Daoyi and Rozenberszki, David and Leutenegger, Stefan and Dai, Angela}, 83 | journal={ArXiv Preprint}, 84 | year={2023} 85 | } 86 | 87 | ``` 88 | 89 | ##
Reference 90 | We borrow [latent-diffusion](https://arxiv.org/abs/2112.10752) from the [official implementation](https://github.com/CompVis/latent-diffusion). 91 | -------------------------------------------------------------------------------- /assets/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/assets/teaser.jpg -------------------------------------------------------------------------------- /configs/pose/depth_gcn.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: models.diffusion.ddpm_pose.LatentDiffusion 4 | params: 5 | parameterization: eps 6 | scale_by_std: False 7 | linear_start: 0.0015 8 | linear_end: 0.0195 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: nocs_gt 13 | cond_stage_trainable: True 14 | conditioning_key: concat 15 | cond_stage_key: depth_input 16 | monitor: val/nocs_diff 17 | 18 | unet_config: 19 | target: modules.diffusionmodules.openaimodel_pose.UNetModel 20 | params: 21 | image_size: 1024 22 | in_channels: 1283 23 | out_channels: 3 24 | model_channels: 256 25 | dropout: 0.0 26 | attention_resolutions: 27 | - 4 28 | - 2 29 | - 1 30 | num_res_blocks: 2 31 | channel_mult: 32 | - 1 33 | - 2 34 | - 4 35 | num_heads: 8 36 | resblock_updown: False 37 | use_spatial_transformer: False 38 | transformer_depth: 39 | context_dim: 40 | 41 | first_stage_config: 42 | target: models.latentloader.IdentityFirstStage 43 | cond_stage_config: 44 | target: models.gcn3d.GCN3D 45 | 46 | 47 | 48 | data: 49 | target: train_pose.DataModuleFromConfig 50 | params: 51 | batch_size: 2 52 | num_workers: 8 53 | wrap: false 54 | train: 55 | target: data.pose.DiffCADposeTrain 56 | validation: 57 | target: data.pose.DiffCADposeValidation 58 | 59 | lightning: 60 | callbacks: 61 | image_logger: 62 | target: 
train_pose.ImageLogger 63 | params: 64 | batch_frequency: 100 65 | increase_log_steps: False 66 | 67 | trainer: 68 | max_epochs: 1001 69 | benchmark: True 70 | max_steps: -------------------------------------------------------------------------------- /configs/scale/depth_feat.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 5.0e-05 3 | target: models.diffusion.ddpm_scale.LatentDiffusion 4 | params: 5 | parameterization: eps 6 | scale_by_std: false 7 | linear_start: 0.0015 8 | linear_end: 0.0195 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: depth_scale_img 13 | cond_stage_trainable: True 14 | conditioning_key: concat 15 | cond_stage_key: depth_input 16 | monitor: val/scale_diff 17 | unet_config: 18 | target: modules.diffusionmodules.openaimodel_scale.UNetModel 19 | params: 20 | image_size: [30, 40] 21 | in_channels: 129 22 | out_channels: 1 23 | model_channels: 128 24 | dropout: 0.1 25 | attention_resolutions: 26 | - 1 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | num_heads: 2 31 | use_spatial_transformer: false 32 | transformer_depth: null 33 | context_dim: null 34 | first_stage_config: 35 | target: models.latentloader.IdentityFirstStage 36 | cond_stage_config: 37 | target: models.unet_feature.FeatureCondStage 38 | data: 39 | target: train_scale.DataModuleFromConfig 40 | params: 41 | batch_size: 8 42 | num_workers: 8 43 | wrap: false 44 | train: 45 | target: data.scale_joint.DiffCADscaleTrain 46 | validation: 47 | target: data.scale_joint.DiffCADscaleValidation 48 | 49 | lightning: 50 | callbacks: 51 | image_logger: 52 | target: train_scale.ImageLogger 53 | params: 54 | batch_frequency: 200 55 | increase_log_steps: False 56 | 57 | trainer: 58 | max_epochs: 1001 59 | benchmark: True 60 | max_steps: -------------------------------------------------------------------------------- /configs/shape/nocs_embed.yaml: 
-------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-04 3 | target: models.diffusion.ddpm_shape.LatentDiffusion 4 | params: 5 | parameterization: x0 6 | scale_by_std: False 7 | linear_start: 0.0015 8 | linear_end: 0.0195 9 | num_timesteps_cond: 1 10 | log_every_t: 200 11 | timesteps: 1000 12 | first_stage_key: latent_gt 13 | cond_stage_trainable: True 14 | conditioning_key: crossattn 15 | cond_stage_key: nocs_pc 16 | monitor: val/ChamferDistance_L1 17 | 18 | unet_config: 19 | target: modules.diffusionmodules.openaimodel_shape.UNetModel 20 | params: 21 | image_size: 1 22 | in_channels: 256 23 | out_channels: 256 24 | model_channels: 256 25 | attention_resolutions: 26 | - 1 27 | num_res_blocks: 2 28 | channel_mult: 29 | - 1 30 | num_heads: 8 31 | resblock_updown: False 32 | use_spatial_transformer: true 33 | transformer_depth: 1 34 | context_dim: 512 35 | legacy: False 36 | 37 | first_stage_config: 38 | target: models.latentloader.IdentityFirstStage 39 | 40 | cond_stage_config: 41 | target: models.pointembed.PointEmbed 42 | 43 | 44 | 45 | data: 46 | target: train_shape.DataModuleFromConfig 47 | params: 48 | batch_size: 2 49 | num_workers: 8 50 | wrap: false 51 | train: 52 | target: data.shape.DiffCADshapeTrain 53 | validation: 54 | target: data.shape.DiffCADshapeValidation 55 | 56 | lightning: 57 | callbacks: 58 | image_logger: 59 | target: train_shape.ImageLogger 60 | params: 61 | batch_frequency: 100 62 | increase_log_steps: False 63 | 64 | trainer: 65 | max_epochs: 500 66 | benchmark: True 67 | max_steps: -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/data/__init__.py -------------------------------------------------------------------------------- /data/base.py: 
-------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /data/scale_joint.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import torch 8 | import json 9 | from glob import glob 10 | import cv2 11 | import pycocotools.mask as mask_util 12 | from .data_util import crop_resize_by_warp_affine 13 | import torchvision 14 | import kornia as kn 15 | 16 | 17 | class DiffCADscaleBase(Dataset): 18 | def __init__(self, 19 | category, 20 | txt_file, 21 | data_root, 22 | real_data_root, 23 | is_train=False, 24 | ): 25 | self.category = category 26 | self.data_paths = txt_file # split file 27 | self.data_root = data_root 28 | self.real_data_path = real_data_root 29 | self.augment = is_train 30 | 31 | with open(self.data_paths, "r") as f: 32 | self.image_paths = f.read().splitlines() 33 | self._length = len(self.image_paths) 34 | 35 | 36 | # load pre-computed scales 37 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '02818832.json'), 'r') as f: 38 | s2c_02818832_scales = 
json.load(f) 39 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '02871439.json'), 'r') as f: 40 | s2c_02871439_scales = json.load(f) 41 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '02933112.json'), 'r') as f: 42 | s2c_02933112_scales = json.load(f) 43 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '03001627.json'), 'r') as f: 44 | s2c_03001627_scales = json.load(f) 45 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '04256520.json'), 'r') as f: 46 | s2c_04256520_scales = json.load(f) 47 | with open(os.path.join(self.real_data_path, 'val_zoedepth_scales', '04379243.json'), 'r') as f: 48 | s2c_04379243_scales = json.load(f) 49 | 50 | # load pre-computed scales 51 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '02818832.json'), 'r') as f: 52 | syn_02818832_scales = json.load(f) 53 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '02871439.json'), 'r') as f: 54 | syn_02871439_scales = json.load(f) 55 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '02933112.json'), 'r') as f: 56 | syn_02933112_scales = json.load(f) 57 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '03001627.json'), 'r') as f: 58 | syn_03001627_scales = json.load(f) 59 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '04256520.json'), 'r') as f: 60 | syn_04256520_scales = json.load(f) 61 | with open(os.path.join(self.data_root, 'train_zoedepth_scales', '04379243.json'), 'r') as f: 62 | syn_04379243_scales = json.load(f) 63 | 64 | self.train_scales = { 65 | '02818832': syn_02818832_scales, 66 | '02871439': syn_02871439_scales, 67 | '02933112': syn_02933112_scales, 68 | '03001627': syn_03001627_scales, 69 | '04256520': syn_04256520_scales, 70 | '04379243': syn_04379243_scales, 71 | } 72 | 73 | self.s2c_scales = { 74 | '02818832': s2c_02818832_scales, 75 | '02871439': s2c_02871439_scales, 76 | '02933112': s2c_02933112_scales, 77 | 
'03001627': s2c_03001627_scales, 78 | '04256520': s2c_04256520_scales, 79 | '04379243': s2c_04379243_scales, 80 | } 81 | 82 | def __len__(self): 83 | return self._length 84 | 85 | def __getitem__(self, i): 86 | 87 | scene_idx, frame_idx = self.image_paths[i].split() 88 | scene_info = scene_idx + '_' + frame_idx 89 | 90 | dataset_label = self.get_dataset_name(i) 91 | 92 | latent_idx = None 93 | if dataset_label == '3DF': 94 | category_id = scene_idx.split('-')[0] 95 | scene_id = '-'.join(scene_idx.split('-')[1:6]) 96 | latent_idx = '-'.join(scene_idx.split('-')[6:]) 97 | elif dataset_label == 'S2C': 98 | category_id = scene_idx.split('_')[0] 99 | scene_id = '_'.join(scene_idx.split('_')[1:]) 100 | frame_idx_s2c, latent_idx, instance_id = frame_idx.split('_') 101 | else: 102 | raise ValueError 103 | 104 | assert latent_idx is not None 105 | 106 | if dataset_label == '3DF': 107 | example = dict(category_id=category_id, scene_idx=scene_id, frame_idx=frame_idx, scene_info=scene_info, latent_idx=latent_idx, 108 | dataset_label=dataset_label) 109 | elif dataset_label == 'S2C': 110 | example = dict(category_id=category_id, scene_idx=scene_id, frame_idx=frame_idx_s2c, scene_info=scene_info, latent_idx=latent_idx, 111 | dataset_label=dataset_label) 112 | 113 | depth_full = self.get_depth(i, dataset_label, ext='.png') / 1000.0 114 | self.orig_h, self.orig_w = depth_full.shape 115 | 116 | depth_pred = self.get_depth_pred(i, dataset_label) / 1000.0 117 | 118 | rz = self.get_resize(size=(30, 40)) 119 | 120 | depth_input = torch.from_numpy(depth_full)[None] 121 | depth_input = rz(depth_input) 122 | 123 | depth_pred_input = torch.from_numpy(depth_pred)[None] 124 | depth_pred_input = rz(depth_pred_input) 125 | 126 | if self.augment: 127 | scale_depth = self.train_scales[category_id]['-'.join(scene_idx.split('-')[1:])+'_'+frame_idx]["best_scale_from_renddepth"] 128 | else: 129 | scale_depth = self.s2c_scales[category_id][scene_id + '_'+frame_idx]["best_scale_from_sensordepth"] 130 | 
131 | if self.augment: 132 | # horizontal flip 133 | aug_flip = self.get_flip(p=0.5) 134 | depth_pred_input = aug_flip(depth_pred_input) 135 | 136 | # augmentation using the scale 137 | shift_prob = np.random.uniform(0.0, 1.0) 138 | if shift_prob > 0.5: 139 | rand_scale_shift = np.random.uniform(scale_depth-0.2, scale_depth+0.3) 140 | 141 | depth_pred_input = depth_pred_input / rand_scale_shift 142 | scale_depth = rand_scale_shift 143 | 144 | # rotation prob 145 | rot_prob = np.random.uniform(0.0, 1.0) 146 | if rot_prob > 0.7: 147 | aug_rot = self.get_roration(degrees=45) 148 | depth_pred_input = aug_rot(depth_pred_input) 149 | scale_depth = torch.tensor(scale_depth).float() - 1.0 150 | else: 151 | scale_depth = torch.tensor(scale_depth).float() - 1.0 152 | 153 | depth_scale_img = torch.ones_like(depth_pred_input) * scale_depth 154 | 155 | example.update( 156 | depth_input=depth_pred_input, 157 | depth_scale_img=depth_scale_img, 158 | ) 159 | 160 | return example 161 | 162 | def generate_occ_mask(self, orig_mask): 163 | h, w = orig_mask.shape 164 | h_start = np.random.uniform(0.1, 0.5) 165 | w_start = np.random.uniform(0.1, 0.5) 166 | occ_mask = np.zeros_like(orig_mask) 167 | occ_mask[int(h_start*h):, int(w_start*w):] = 1.0 168 | return occ_mask 169 | 170 | def get_gaussian_blur(self): 171 | kernel_size = 5 172 | return torchvision.transforms.GaussianBlur(kernel_size=kernel_size, sigma=(1.0, 2.0)) 173 | 174 | def get_resize(self, size): 175 | return torchvision.transforms.Resize(size=size) 176 | 177 | def get_flip(self, p): 178 | return torchvision.transforms.RandomHorizontalFlip(p=p) 179 | 180 | def get_roration(self, degrees): 181 | return torchvision.transforms.RandomRotation(degrees=degrees) 182 | 183 | def get_dataset_name(self, idx): 184 | line = self.image_paths[idx].split() 185 | scene_id, _ = line 186 | 187 | if '-' in scene_id: 188 | return '3DF' 189 | elif '_' in scene_id: 190 | return 'S2C' 191 | else: 192 | raise ValueError 193 | 194 | 195 | def 
get_depth(self, idx, dataset_label, ext='.png'): 196 | try: 197 | # Parse the split text 198 | line = self.image_paths[idx].split() 199 | scene_id, frame_id = line 200 | 201 | if dataset_label == "3DF": 202 | category_id = scene_id.split('-')[0] 203 | file = "{}{}".format(frame_id, ext) 204 | elif dataset_label == "S2C": 205 | category_id = scene_id.split('_')[0] 206 | file = "{}{}".format(frame_id.split('_')[0], ext) 207 | 208 | except Exception as e: 209 | print(line) 210 | raise e 211 | 212 | assert dataset_label in ["3DF", "S2C"] 213 | 214 | if dataset_label == "3DF": 215 | data_root = self.data_root + '/3D-FRONT-RENDER-{}'.format(category_id) 216 | # gt rendered depth 217 | depth_fname = os.path.join(data_root, '-'.join(scene_id.split('-')[1:]), 'bop_data/train_pbr/000000/depth', file) 218 | depth = cv2.imread(depth_fname, -1) 219 | 220 | else: 221 | # gt rendered depth 222 | depth_fname = os.path.join(self.real_data_path, "Rendering", '_'.join(scene_id.split('_')[1:]), 'depth', file) 223 | 224 | depth = cv2.imread(depth_fname, -1) # 360, 480 millimiter uint16 225 | 226 | assert depth.dtype == np.uint16 227 | 228 | return depth 229 | 230 | def get_depth_pred(self, idx, dataset_label): 231 | try: 232 | line = self.image_paths[idx].split() 233 | scene_id, frame_id = line 234 | if dataset_label == "3DF": 235 | category_id = scene_id.split('-')[0] 236 | file = "{}_pred_dmap.npy".format(frame_id) 237 | elif dataset_label == "S2C": 238 | category_id = scene_id.split('_')[0] 239 | file = "{}_pred_dmap.npy".format(frame_id.split('_')[0]) 240 | 241 | except Exception as e: 242 | print(line) 243 | raise e 244 | 245 | assert dataset_label in ["3DF", "S2C"] 246 | 247 | if dataset_label == "3DF": 248 | data_root = self.data_root + '/3D-FRONT-RENDER-{}'.format(category_id) 249 | # gt rendered depth 250 | depth_fname = os.path.join(data_root, '-'.join(scene_id.split('-')[1:]), 'bop_data/train_pbr/000000/zoedepth', file) 251 | 252 | else: 253 | # predicted depth 254 | 
depth_fname = os.path.join(self.real_data_path, "ZoeDepthPredictions", '_'.join(scene_id.split('_')[1:]), file) 255 | 256 | depth = np.load(depth_fname) 257 | depth = (depth * 1000).astype(np.uint16) 258 | 259 | assert depth.dtype == np.uint16 260 | 261 | return depth 262 | 263 | 264 | class DiffCADscaleTrain(DiffCADscaleBase): 265 | def __init__(self, **kwargs): 266 | super().__init__( 267 | category="joint", 268 | txt_file="path to train split", 269 | data_root="path to synthetic data with gt scale defined", 270 | real_data_root="path to real data with gt scale defined", 271 | is_train=True, 272 | **kwargs) 273 | 274 | 275 | class DiffCADscaleValidation(DiffCADscaleBase): 276 | def __init__(self, **kwargs): 277 | super().__init__( 278 | category="joint", 279 | txt_file="path to val split", 280 | data_root="path to synthetic data with gt scale defined", 281 | real_data_root="path to real data with gt scale defined", 282 | is_train=False, 283 | **kwargs) 284 | -------------------------------------------------------------------------------- /env.yaml: -------------------------------------------------------------------------------- 1 | name: diffcad 2 | channels: 3 | - pytorch 4 | - defaults 5 | dependencies: 6 | - _libgcc_mutex=0.1=main 7 | - _openmp_mutex=5.1=1_gnu 8 | - blas=1.0=mkl 9 | - ca-certificates=2023.01.10=h06a4308_0 10 | - cudatoolkit=11.0.221=h6bb024c_0 11 | - freetype=2.12.1=h4a9f257_0 12 | - giflib=5.2.1=h5eee18b_3 13 | - intel-openmp=2021.4.0=h06a4308_3561 14 | - jpeg=9e=h5eee18b_1 15 | - lcms2=2.12=h3be6417_0 16 | - ld_impl_linux-64=2.38=h1181459_1 17 | - lerc=3.0=h295c915_0 18 | - libdeflate=1.17=h5eee18b_0 19 | - libffi=3.3=he6710b0_2 20 | - libgcc-ng=11.2.0=h1234567_1 21 | - libgfortran-ng=11.2.0=h00389a5_1 22 | - libgfortran5=11.2.0=h1234567_1 23 | - libgomp=11.2.0=h1234567_1 24 | - libpng=1.6.39=h5eee18b_0 25 | - libstdcxx-ng=11.2.0=h1234567_1 26 | - libtiff=4.5.0=h6a678d5_2 27 | - libuv=1.44.2=h5eee18b_0 28 | - libwebp=1.2.4=h11a3e52_1 29 | - 
libwebp-base=1.2.4=h5eee18b_1 30 | - lz4-c=1.9.4=h6a678d5_0 31 | - mkl=2021.4.0=h06a4308_640 32 | - mkl-service=2.4.0=py38h7f8727e_0 33 | - mkl_fft=1.3.1=py38hd3c417c_0 34 | - mkl_random=1.2.2=py38h51133e4_0 35 | - ncurses=6.4=h6a678d5_0 36 | - ninja-base=1.10.2=hd09550d_5 37 | - openssl=1.1.1t=h7f8727e_0 38 | - pillow=9.4.0=py38h6a678d5_0 39 | - pip=20.3.3=py38h06a4308_0 40 | - python=3.8.5=h7579374_1 41 | - readline=8.2=h5eee18b_0 42 | - setuptools=67.8.0=py38h06a4308_0 43 | - six=1.16.0=pyhd3eb1b0_1 44 | - sqlite=3.41.2=h5eee18b_0 45 | - tk=8.6.12=h1ccaba5_0 46 | - torchvision=0.8.1=py38_cu110 47 | - typing_extensions=4.5.0=py38h06a4308_0 48 | - wheel=0.38.4=py38h06a4308_0 49 | - xz=5.4.2=h5eee18b_0 50 | - zlib=1.2.13=h5eee18b_0 51 | - zstd=1.5.5=hc292b87_0 52 | - pip: 53 | - absl-py==1.4.0 54 | - addict==2.4.0 55 | - aiohttp==3.8.4 56 | - aiosignal==1.3.1 57 | - albumentations==0.4.3 58 | - altair==4.2.2 59 | - antlr4-python3-runtime==4.8 60 | - appdirs==1.4.4 61 | - asttokens==2.2.1 62 | - async-timeout==4.0.2 63 | - attrs==23.1.0 64 | - backcall==0.2.0 65 | - backports-zoneinfo==0.2.1 66 | - blinker==1.6.2 67 | - cachetools==5.3.1 68 | - ccimport==0.4.2 69 | - certifi==2023.5.7 70 | - charset-normalizer==3.1.0 71 | - click==8.1.3 72 | - cmake==3.26.3 73 | - comm==0.1.3 74 | - configargparse==1.5.3 75 | - contourpy==1.0.7 76 | - cumm-cu113==0.4.11 77 | - cycler==0.11.0 78 | - dash==2.10.2 79 | - dash-core-components==2.0.0 80 | - dash-html-components==2.0.0 81 | - dash-table==5.0.0 82 | - debugpy==1.6.7 83 | - decorator==5.1.1 84 | - docker-pycreds==0.4.0 85 | - einops==0.3.0 86 | - entrypoints==0.4 87 | - executing==1.2.0 88 | - fastjsonschema==2.17.1 89 | - filelock==3.12.0 90 | - fire==0.5.0 91 | - flask==2.2.5 92 | - fonttools==4.39.4 93 | - frozenlist==1.3.3 94 | - fsspec==2023.5.0 95 | - ftfy==6.1.1 96 | - future==0.18.3 97 | - gitdb==4.0.10 98 | - gitpython==3.1.31 99 | - google-auth==2.19.0 100 | - google-auth-oauthlib==1.0.0 101 | - grpcio==1.54.2 102 
| - idna==3.4 103 | - imageio==2.9.0 104 | - imageio-ffmpeg==0.4.2 105 | - imgaug==0.2.6 106 | - importlib-metadata==6.6.0 107 | - importlib-resources==5.12.0 108 | - ipykernel==6.23.2 109 | - ipython==8.12.2 110 | - ipywidgets==8.0.6 111 | - itsdangerous==2.1.2 112 | - jedi==0.18.2 113 | - jinja2==3.1.2 114 | - joblib==1.2.0 115 | - jsonschema==4.17.3 116 | - jupyter-client==8.2.0 117 | - jupyter-core==5.3.1 118 | - jupyterlab-widgets==3.0.7 119 | - kiwisolver==1.4.4 120 | - kornia==0.6.12 121 | - lark==1.1.9 122 | - lazy-loader==0.2 123 | - lit==16.0.5 124 | - markdown==3.4.3 125 | - markdown-it-py==2.2.0 126 | - markupsafe==2.1.2 127 | - matplotlib==3.7.1 128 | - matplotlib-inline==0.1.6 129 | - mdurl==0.1.2 130 | - mpmath==1.3.0 131 | - multidict==6.0.4 132 | - nbformat==5.7.0 133 | - nest-asyncio==1.5.6 134 | - networkx==3.1 135 | - ninja==1.11.1.1 136 | - numpy==1.24.3 137 | - numpy-quaternion==2022.4.3 138 | - nvidia-cublas-cu11==11.10.3.66 139 | - nvidia-cuda-cupti-cu11==11.7.101 140 | - nvidia-cuda-nvrtc-cu11==11.7.99 141 | - nvidia-cuda-runtime-cu11==11.7.99 142 | - nvidia-cudnn-cu11==8.5.0.96 143 | - nvidia-cufft-cu11==10.9.0.58 144 | - nvidia-curand-cu11==10.2.10.91 145 | - nvidia-cusolver-cu11==11.4.0.1 146 | - nvidia-cusparse-cu11==11.7.4.91 147 | - nvidia-nccl-cu11==2.14.3 148 | - nvidia-nvtx-cu11==11.7.91 149 | - oauthlib==3.2.2 150 | - omegaconf==2.1.1 151 | - open3d==0.17.0 152 | - opencv-python==4.1.2.30 153 | - opencv-python-headless==4.7.0.72 154 | - packaging==21.3 155 | - pandas==2.0.2 156 | - parso==0.8.3 157 | - pathtools==0.1.2 158 | - pccm==0.4.11 159 | - pexpect==4.8.0 160 | - pickleshare==0.7.5 161 | - pkgutil-resolve-name==1.3.10 162 | - platformdirs==3.5.3 163 | - plotly==5.15.0 164 | - portalocker==2.8.2 165 | - prompt-toolkit==3.0.38 166 | - protobuf==3.20.3 167 | - psutil==5.9.5 168 | - ptyprocess==0.7.0 169 | - pudb==2019.2 170 | - pure-eval==0.2.2 171 | - pyarrow==12.0.0 172 | - pyasn1==0.5.0 173 | - pyasn1-modules==0.3.0 174 | - 
# pybind11==2.11.1 175 | - pycocotools==2.0.6 176 | - pydeck==0.8.1b0 177 | - pydeprecate==0.3.1 178 | - pygments==2.15.1 179 | - pympler==1.0.1 180 | - pyparsing==3.0.9 181 | - pyquaternion==0.9.9 182 | - pyrsistent==0.19.3 183 | - python-dateutil==2.8.2 184 | - pytorch-lightning==1.4.2 185 | - pytz==2023.3 186 | - pywavelets==1.4.1 187 | - pyyaml==6.0 188 | - pyzmq==25.1.0 189 | - quaternion==3.5.2.post4 190 | - regex==2023.5.5 191 | - requests==2.31.0 192 | - requests-oauthlib==1.3.1 193 | - rich==13.3.5 194 | - rsa==4.9 195 | - sacremoses==0.0.53 196 | - scikit-image==0.20.0 197 | - scikit-learn==1.2.2 198 | - scipy==1.9.1 199 | - sentry-sdk==1.24.0 200 | - setproctitle==1.3.2 201 | - smmap==5.0.0 202 | - spconv-cu113==2.3.6 203 | - stack-data==0.6.2 204 | - streamlit==1.22.0 205 | - sympy==1.12 206 | - tenacity==8.2.2 207 | - tensorboard==2.13.0 208 | - tensorboard-data-server==0.7.0 209 | - termcolor==2.4.0 210 | - test-tube==0.7.5 211 | - threadpoolctl==3.1.0 212 | - tifffile==2023.4.12 213 | - tokenizers==0.10.3 214 | - toml==0.10.2 215 | - toolz==0.12.0 216 | - torch==2.0.1 217 | - torch-fidelity==0.3.0 218 | - torchmetrics==0.7.3 219 | - tornado==6.3.2 220 | - tqdm==4.65.0 221 | - traitlets==5.9.0 222 | - transformers==4.3.1 223 | - trimesh==3.21.7 224 | - triton==2.0.0 225 | - tzdata==2023.3 226 | - tzlocal==5.0.1 227 | - urllib3==1.26.16 228 | - urwid==2.1.2 229 | - validators==0.20.0 230 | - wandb==0.15.3 231 | - watchdog==3.0.0 232 | - wcwidth==0.2.6 233 | - werkzeug==2.2.3 234 | - widgetsnbextension==4.0.7 235 | - yarl==1.9.2 236 | - zipp==3.15.0 237 |
# ---- /lr_scheduler.py ----
import numpy as np


class LambdaWarmUpCosineScheduler:
    """Learning-rate multiplier: linear warm-up followed by cosine decay.

    For step ``n < warm_up_steps`` the multiplier rises linearly from
    ``lr_start`` to ``lr_max``. After that it follows a half cosine from
    ``lr_max`` down to ``lr_min``, reaches ``lr_min`` at ``max_decay_steps``,
    and stays there.

    note: use with a base_lr of 1.0
    """

    def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
        self.lr_warm_up_steps = warm_up_steps
        self.lr_start = lr_start
        self.lr_min = lr_min
        self.lr_max = lr_max
        self.lr_max_decay_steps = max_decay_steps
        self.last_lr = 0.
        self.verbosity_interval = verbosity_interval

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n``; extra kwargs are ignored."""
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
        if n < self.lr_warm_up_steps:
            lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
            self.last_lr = lr
            return lr
        decay_steps = self.lr_max_decay_steps - self.lr_warm_up_steps
        # A zero/negative decay span means decay is already over once warm-up
        # ends: jump to lr_min instead of dividing by zero.
        t = (n - self.lr_warm_up_steps) / decay_steps if decay_steps > 0 else 1.0
        t = min(t, 1.0)
        lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi))
        self.last_lr = lr
        return lr

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaWarmUpCosineScheduler2:
    """Cyclic warm-up + cosine multiplier, configured per cycle via lists.

    Cycle ``i`` spans ``cycle_lengths[i]`` steps. The multiplier warms up
    linearly from ``f_start[i]`` to ``f_max[i]`` over ``warm_up_steps[i]``
    steps, then decays along a half cosine to ``f_min[i]`` at the end of the
    cycle. Steps past the last cycle are treated as part of the last cycle,
    so the multiplier is held at its final value (previously this raised a
    ``TypeError``).

    note: use with a base_lr of 1.0.
    """

    def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
        assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
        self.lr_warm_up_steps = warm_up_steps
        self.f_start = f_start
        self.f_min = f_min
        self.f_max = f_max
        self.cycle_lengths = cycle_lengths
        # cum_cycles[i] is the global step at which cycle i starts.
        self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
        self.last_f = 0.
        self.verbosity_interval = verbosity_interval

    def find_in_interval(self, n):
        """Return the index of the cycle that contains global step ``n``.

        A step equal to a cycle's cumulative end boundary still belongs to that
        cycle. Steps beyond the final boundary map to the last cycle.
        """
        idx = int(np.searchsorted(self.cum_cycles[1:], n, side="left"))
        return min(idx, len(self.cycle_lengths) - 1)

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n``; extra kwargs are ignored."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]  # step index local to the current cycle
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")
        warm_up = self.lr_warm_up_steps[cycle]
        if n < warm_up:
            f = (self.f_max[cycle] - self.f_start[cycle]) / warm_up * n + self.f_start[cycle]
            self.last_f = f
            return f
        decay_steps = self.cycle_lengths[cycle] - warm_up
        t = (n - warm_up) / decay_steps if decay_steps > 0 else 1.0
        t = min(t, 1.0)
        f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi))
        self.last_f = f
        return f

    def __call__(self, n, **kwargs):
        return self.schedule(n, **kwargs)


class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
    """Cyclic warm-up + linear decay multiplier (same list configuration as the parent).

    After warm-up, the multiplier for local step ``n`` in a cycle of length
    ``L`` is ``f_min + (f_max - f_min) * (L - n) / L``. The decay is measured
    from the cycle start, not from the end of warm-up. Past the end of the
    last cycle the multiplier is held at ``f_min`` instead of going negative.
    """

    def schedule(self, n, **kwargs):
        """Return the multiplier for global step ``n``; extra kwargs are ignored."""
        cycle = self.find_in_interval(n)
        n = n - self.cum_cycles[cycle]
        if self.verbosity_interval > 0:
            if n % self.verbosity_interval == 0:
                print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
                      f"current cycle {cycle}")

        if n < self.lr_warm_up_steps[cycle]:
            f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
            self.last_f = f
            return f
        cycle_length = self.cycle_lengths[cycle]
        # Only steps past the last cycle can exceed cycle_length; clamp so the
        # multiplier bottoms out at f_min rather than becoming negative.
        remaining = max(cycle_length - n, 0)
        f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * remaining / cycle_length
        self.last_f = f
        return f

# ---- /models/diffusion/__init__.py ----
# https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/models/diffusion/__init__.py
# ---- /models/diffusion/ddim_scale.py ----
"""SAMPLING ONLY."""

import torch
import numpy as np
from tqdm import tqdm

from modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like


class DDIMSampler(object):
    """DDIM sampler driving a trained diffusion ``model``.

    Reads from ``model``: ``num_timesteps``, ``alphas_cumprod``,
    ``alphas_cumprod_prev``, ``betas``, ``device``, ``apply_model``,
    ``q_sample`` (only when ``mask`` is given), ``parameterization`` (only with
    a score corrector) and ``first_stage_model`` (only with quantization).
    """

    def __init__(self, model, schedule="linear", **kwargs):
        super().__init__()
        self.model = model
        self.ddpm_num_timesteps = model.num_timesteps
        self.schedule = schedule  # stored for reference; not read by this class

    def register_buffer(self, name, attr):
        """Store ``attr`` as attribute ``name``; tensors are moved to the model's device.

        Previously tensors were always moved to ``cuda``, whatever the model's
        device, which broke CPU (and non-default-GPU) sampling. Non-tensor
        values such as NumPy arrays are stored unchanged.
        """
        if isinstance(attr, torch.Tensor):
            attr = attr.to(self.model.device)
        setattr(self, name, attr)

    def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
        """Precompute the DDIM timestep subset and every per-step coefficient.

        Args:
            ddim_num_steps: number of DDIM sampling steps.
            ddim_discretize: ``"uniform"`` or ``"quad"`` spacing (see ``make_ddim_timesteps``).
            ddim_eta: DDIM eta. 0 makes every sigma_t zero, so sampling adds no noise.
            verbose: print the selected timesteps / alphas.
        """
        self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
                                                  num_ddpm_timesteps=self.ddpm_num_timesteps, verbose=verbose)
        alphas_cumprod = self.model.alphas_cumprod
        assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'

        def to_torch(x):
            return x.clone().detach().to(torch.float32).to(self.model.device)

        self.register_buffer('betas', to_torch(self.model.betas))
        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
        self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))

        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
        self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
        self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))

        # ddim sampling parameters
        ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
                                                                                   ddim_timesteps=self.ddim_timesteps,
                                                                                   eta=ddim_eta, verbose=verbose)
        self.register_buffer('ddim_sigmas', ddim_sigmas)
        self.register_buffer('ddim_alphas', ddim_alphas)
        self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
        self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
        sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
            (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
                    1 - self.alphas_cumprod / self.alphas_cumprod_prev))
        self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)

    @torch.no_grad()
    def sample(self,
               S,
               batch_size,
               shape,
               conditioning=None,
               callback=None,
               normals_sequence=None,
               img_callback=None,
               quantize_x0=False,
               eta=0.,
               mask=None,
               x0=None,
               temperature=1.,
               noise_dropout=0.,
               score_corrector=None,
               corrector_kwargs=None,
               verbose=True,
               x_T=None,
               log_every_t=100,
               unconditional_guidance_scale=1.,
               unconditional_conditioning=None,
               # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
               **kwargs
               ):
        """Draw ``batch_size`` samples of per-item shape ``shape`` = (C, H, W) with ``S`` DDIM steps.

        A size mismatch between ``conditioning`` and ``batch_size`` is only
        reported with a printed warning. Returns ``(samples, intermediates)``,
        where ``intermediates`` holds the ``'x_inter'`` and ``'pred_x0'`` lists.
        ``normals_sequence`` and extra ``kwargs`` are accepted but unused.
        """
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for DDIM sampling is {size}, eta {eta}')

        samples, intermediates = self.ddim_sampling(conditioning, size,
                                                    callback=callback,
                                                    img_callback=img_callback,
                                                    quantize_denoised=quantize_x0,
                                                    mask=mask, x0=x0,
                                                    ddim_use_original_steps=False,
                                                    noise_dropout=noise_dropout,
                                                    temperature=temperature,
                                                    score_corrector=score_corrector,
                                                    corrector_kwargs=corrector_kwargs,
                                                    x_T=x_T,
                                                    log_every_t=log_every_t,
                                                    unconditional_guidance_scale=unconditional_guidance_scale,
                                                    unconditional_conditioning=unconditional_conditioning,
                                                    )
        return samples, intermediates

    @torch.no_grad()
    def ddim_sampling(self, cond, shape,
                      x_T=None, ddim_use_original_steps=False,
                      callback=None, timesteps=None, quantize_denoised=False,
                      mask=None, x0=None, img_callback=None, log_every_t=100,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None,):
        """Run the reverse DDIM loop from ``x_T`` (or fresh Gaussian noise) of full ``shape``.

        When ``mask`` is given, the region where ``mask == 1`` is re-noised from
        ``x0`` at every step (inpainting-style). Returns ``(img, intermediates)``.
        """
        device = self.model.betas.device
        b = shape[0]
        if x_T is None:
            img = torch.randn(shape, device=device)
        else:
            img = x_T

        if timesteps is None:
            timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
        elif timesteps is not None and not ddim_use_original_steps:
            # NOTE(review): the trailing ``- 1`` keeps roughly ``timesteps - 1``
            # DDIM steps when a smaller count is requested (inherited behaviour).
            subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
            timesteps = self.ddim_timesteps[:subset_end]

        intermediates = {'x_inter': [img], 'pred_x0': [img]}
        # Timesteps are visited from noisiest to cleanest.
        time_range = reversed(range(0, timesteps)) if ddim_use_original_steps else np.flip(timesteps)
        total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
        print(f"Running DDIM Sampling with {total_steps} timesteps")

        iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps)

        for i, step in enumerate(iterator):
            index = total_steps - i - 1
            ts = torch.full((b,), step, device=device, dtype=torch.long)

            if mask is not None:
                assert x0 is not None
                img_orig = self.model.q_sample(x0, ts)  # TODO: deterministic forward pass?
                img = img_orig * mask + (1. - mask) * img

            outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
                                      quantize_denoised=quantize_denoised, temperature=temperature,
                                      noise_dropout=noise_dropout, score_corrector=score_corrector,
                                      corrector_kwargs=corrector_kwargs,
                                      unconditional_guidance_scale=unconditional_guidance_scale,
                                      unconditional_conditioning=unconditional_conditioning)
            img, pred_x0 = outs
            if callback:
                callback(i)
            if img_callback:
                img_callback(pred_x0, i)

            if index % log_every_t == 0 or index == total_steps - 1:
                intermediates['x_inter'].append(img)
                intermediates['pred_x0'].append(pred_x0)

        return img, intermediates

    @torch.no_grad()
    def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
                      temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None,
                      unconditional_guidance_scale=1., unconditional_conditioning=None):
        """Perform one DDIM update x_t -> x_{t-1}; returns ``(x_prev, pred_x0)``.

        With ``unconditional_conditioning`` and a guidance scale != 1, the
        unconditional and conditional predictions are batched together and
        combined with classifier-free guidance.
        """
        b, *_, device = *x.shape, x.device

        if unconditional_conditioning is None or unconditional_guidance_scale == 1.:
            e_t = self.model.apply_model(x, t, c)
        else:
            x_in = torch.cat([x] * 2)
            t_in = torch.cat([t] * 2)
            c_in = torch.cat([unconditional_conditioning, c])
            e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
            e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

        if score_corrector is not None:
            assert self.model.parameterization == "eps"
            # corrector_kwargs defaults to None; unpacking None raised TypeError.
            e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **(corrector_kwargs or {}))

        alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
        alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
        sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
        sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
        # select parameters corresponding to the currently considered timestep
        a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
        a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
        sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
        sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device)

        # current prediction for x_0
        pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
        if quantize_denoised:
            pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
        # direction pointing to x_t
        dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
        noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
        if noise_dropout > 0.:
            noise = torch.nn.functional.dropout(noise, p=noise_dropout)
        x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
        return x_prev, pred_x0

# ---- /models/latentloader.py ----
import torch


class IdentityFirstStage(torch.nn.Module):
    """No-op first stage: ``encode``, ``decode`` and ``forward`` return ``x`` unchanged.

    ``vq_interface`` and any extra arguments are accepted only for signature
    compatibility and are ignored.
    """

    def __init__(self, *args, vq_interface=False, **kwargs):
        super().__init__()

    def encode(self, x, *args, **kwargs):
        return x

    def decode(self, x, *args, **kwargs):
        return x

    def forward(self, x, *args, **kwargs):
        return x

# ---- /models/pointembed.py ----
import torch
import pytorch_lightning as pl
import torch.nn.functional as F
from contextlib import contextmanager
import torch.nn as nn
import numpy as np


class PointEmbed(nn.Module):
    """Fourier-feature embedding of 3D points followed by a linear projection.

    Each coordinate is projected onto ``hidden_dim // 6`` frequencies
    ``2**k * pi``. The sin and cos of all projections (``hidden_dim`` values)
    are concatenated with the raw xyz and mapped to ``dim`` channels.
    """

    def __init__(self, hidden_dim=48, dim=512):
        super().__init__()

        assert hidden_dim % 6 == 0

        self.embedding_dim = hidden_dim
        num_freqs = self.embedding_dim // 6
        e = torch.pow(2, torch.arange(num_freqs)).float() * np.pi
        zeros = torch.zeros(num_freqs)
        # Block-diagonal basis: row j holds the frequencies for coordinate j.
        e = torch.stack([
            torch.cat([e, zeros, zeros]),
            torch.cat([zeros, e, zeros]),
            torch.cat([zeros, zeros, e]),
        ])
        self.register_buffer('basis', e)  # 3 x (hidden_dim // 2), i.e. 3 x 24 for the default 48

        self.mlp = nn.Linear(self.embedding_dim + 3, dim)

    @staticmethod
    def embed(input, basis):
        """Return ``[sin(input @ basis), cos(input @ basis)]`` concatenated on the last dim."""
        projections = torch.einsum(
            'bnd,de->bne', input, basis)
        embeddings = torch.cat([projections.sin(), projections.cos()], dim=2)
        return embeddings

    def forward(self, input):
        # input: B x N x 3
        embed = self.mlp(torch.cat([self.embed(input, self.basis), input], dim=2))  # B x N x C
        return embed


if __name__ == "__main__":
    model = PointEmbed()
    input = torch.randn(2, 1024, 3)
    output = model(input)
    print(output.shape)

# ---- /models/unet_feature.py ----
import torch
from torchvision.models.segmentation import fcn_resnet50


class FeatureCondStage(torch.nn.Module):
    """Dense feature extractor for single-channel images (presumably depth maps).

    Wraps torchvision's ``fcn_resnet50(pretrained=True)``. The first
    convolution is replaced by a newly initialised 1-channel conv, which
    discards its pretrained RGB weights. The last classifier layer is replaced
    by ``CustomUpsample(512, output_channels)``. ``forward`` returns the
    ``'out'`` entry of the FCN output dict.
    """

    def __init__(self, *args, output_channels=128, **kwargs):
        super().__init__()

        self.model = fcn_resnet50(pretrained=True)
        self.model.backbone.conv1 = torch.nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.model.classifier[4] = CustomUpsample(512, output_channels)

    def forward(self, x):
        return self.model(x)['out']


class CustomUpsample(torch.nn.Module):
    """8x bilinear upsampling followed by a 1x1 convolution to ``out_channels``.

    NOTE(review): torchvision's FCN forward also resizes the head output to
    the input resolution; confirm the extra 8x upsample here is intended.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.upsample = torch.nn.Upsample(
            scale_factor=8, mode='bilinear', align_corners=False)
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        x = self.upsample(x)
        x = self.conv(x)
        return x

# ---- /modules/attention.py ----
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn, einsum
from einops import rearrange, repeat

from modules.diffusionmodules.util import checkpoint


def exists(val):
    """Return True when ``val`` is not None."""
    return val is not None


def uniq(arr):
    """Return the distinct elements of ``arr`` in first-seen order, as a dict keys view."""
    return {el: True for el in arr}.keys()


def default(val, d):
    """Return ``val`` unless it is None; otherwise ``d`` (called first if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    """Return the most negative finite value of ``t``'s dtype."""
    return -torch.finfo(t.dtype).max


def init_(tensor):
    """Fill ``tensor`` in place from U(-1/sqrt(d), 1/sqrt(d)), d = last dim size; return it."""
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Linear projection to ``2 * dim_out`` split into value and GELU gate."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    """Transformer MLP: project in (Linear+GELU, or GEGLU if ``glu``), dropout, project out."""

    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    """GroupNorm with 32 groups (``in_channels`` must be divisible by 32)."""
    return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)


class LinearAttention(nn.Module):
    """Linear-complexity attention over the spatial positions of a 4D feature map."""

    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads=self.heads, qkv=3)
        k = k.softmax(dim=-1)
        context = torch.einsum('bhdn,bhen->bhde', k, v)
        out = torch.einsum('bhde,bhdn->bhen', context, q)
        out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    """Single-head self-attention over all H*W positions, with a residual connection."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, 'b c h w -> b (h w) c')
        k = rearrange(k, 'b c h w -> b c (h w)')
        w_ = torch.einsum('bij,bjk->bik', q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, 'b c h w -> b c (h w)')
        w_ = rearrange(w_, 'b i j -> b j i')
        h_ = torch.einsum('bij,bjk->bik', v, w_)
        h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    """Multi-head scaled dot-product attention from ``x`` to ``context``.

    Without ``context`` it is self-attention. ``mask`` (True = keep) is
    flattened per batch item and broadcast over heads and queries.
    """

    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, mask=None):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            mask = rearrange(mask, 'b ... -> b (...)')
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            # Use the module helper instead of shadowing it with a local.
            sim.masked_fill_(~mask, max_neg_value(sim))

        # attention, what we cannot get enough of
        attn = sim.softmax(dim=-1)

        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    """Pre-norm block: self-attention, cross-attention, feed-forward, each residual."""

    def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
        super().__init__()
        self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout)  # is a self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
                                    heads=n_heads, dim_head=d_head, dropout=dropout)  # is self-attn if context is none
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        # Gradient checkpointing (recompute in backward) when self.checkpoint is True.
        return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)

    def _forward(self, x, context=None):
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x


class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, project the input (aka embedding)
    and reshape to b, t, d.
    Then apply standard transformer action.
    Finally, reshape to image
    """

    def __init__(self, in_channels, n_heads, d_head,
                 depth=1, dropout=0., context_dim=None):
        super().__init__()
        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)

        self.transformer_blocks = nn.ModuleList(
            [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
             for _ in range(depth)]
        )

        # Zero-initialised output projection: the block starts as an identity (residual only).
        self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0))

    def forward(self, x, context=None):
        # note: if no context is given, cross-attention defaults to self-attention
        _, _, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)
        return x + x_in

# ---- /modules/diffusionmodules/__init__.py ----
# https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/modules/diffusionmodules/__init__.py
# ---- /modules/diffusionmodules/util.py ----
# adopted from
# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
# and
#
https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py 5 | # and 6 | # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py 7 | # 8 | # thanks! 9 | 10 | 11 | import os 12 | import math 13 | import torch 14 | import torch.nn as nn 15 | import numpy as np 16 | from einops import repeat 17 | 18 | from util import instantiate_from_config 19 | 20 | 21 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 22 | if schedule == "linear": 23 | betas = ( 24 | torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 25 | ) 26 | 27 | elif schedule == "cosine": 28 | timesteps = ( 29 | torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s 30 | ) 31 | alphas = timesteps / (1 + cosine_s) * np.pi / 2 32 | alphas = torch.cos(alphas).pow(2) 33 | alphas = alphas / alphas[0] 34 | betas = 1 - alphas[1:] / alphas[:-1] 35 | betas = np.clip(betas, a_min=0, a_max=0.999) 36 | 37 | elif schedule == "sqrt_linear": 38 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) 39 | elif schedule == "sqrt": 40 | betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 41 | else: 42 | raise ValueError(f"schedule '{schedule}' unknown.") 43 | return betas.numpy() 44 | 45 | 46 | def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): 47 | if ddim_discr_method == 'uniform': 48 | c = num_ddpm_timesteps // num_ddim_timesteps 49 | ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) 50 | elif ddim_discr_method == 'quad': 51 | ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) 52 | else: 53 | raise NotImplementedError(f'There is no ddim discretization method called 
"{ddim_discr_method}"') 54 | 55 | # assert ddim_timesteps.shape[0] == num_ddim_timesteps 56 | # add one to get the final alpha values right (the ones from first scale to data during sampling) 57 | steps_out = ddim_timesteps + 1 58 | if verbose: 59 | print(f'Selected timesteps for ddim sampler: {steps_out}') 60 | return steps_out 61 | 62 | 63 | def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): 64 | # select alphas for computing the variance schedule 65 | alphas = alphacums[ddim_timesteps] 66 | alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) 67 | 68 | # according the the formula provided in https://arxiv.org/abs/2010.02502 69 | sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) 70 | if verbose: 71 | print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') 72 | print(f'For the chosen value of eta, which is {eta}, ' 73 | f'this results in the following sigma_t schedule for ddim sampler {sigmas}') 74 | return sigmas, alphas, alphas_prev 75 | 76 | 77 | def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): 78 | """ 79 | Create a beta schedule that discretizes the given alpha_t_bar function, 80 | which defines the cumulative product of (1-beta) over time from t = [0,1]. 81 | :param num_diffusion_timesteps: the number of betas to produce. 82 | :param alpha_bar: a lambda that takes an argument t from 0 to 1 and 83 | produces the cumulative product of (1-beta) up to that 84 | part of the diffusion process. 85 | :param max_beta: the maximum beta to use; use values lower than 1 to 86 | prevent singularities. 
87 | """ 88 | betas = [] 89 | for i in range(num_diffusion_timesteps): 90 | t1 = i / num_diffusion_timesteps 91 | t2 = (i + 1) / num_diffusion_timesteps 92 | betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) 93 | return np.array(betas) 94 | 95 | 96 | def extract_into_tensor(a, t, x_shape): 97 | b, *_ = t.shape 98 | out = a.gather(-1, t) 99 | return out.reshape(b, *((1,) * (len(x_shape) - 1))) 100 | 101 | 102 | def checkpoint(func, inputs, params, flag): 103 | """ 104 | Evaluate a function without caching intermediate activations, allowing for 105 | reduced memory at the expense of extra compute in the backward pass. 106 | :param func: the function to evaluate. 107 | :param inputs: the argument sequence to pass to `func`. 108 | :param params: a sequence of parameters `func` depends on but does not 109 | explicitly take as arguments. 110 | :param flag: if False, disable gradient checkpointing. 111 | """ 112 | if flag: 113 | args = tuple(inputs) + tuple(params) 114 | return CheckpointFunction.apply(func, len(inputs), *args) 115 | else: 116 | return func(*inputs) 117 | 118 | 119 | class CheckpointFunction(torch.autograd.Function): 120 | @staticmethod 121 | def forward(ctx, run_function, length, *args): 122 | ctx.run_function = run_function 123 | ctx.input_tensors = list(args[:length]) 124 | ctx.input_params = list(args[length:]) 125 | 126 | with torch.no_grad(): 127 | output_tensors = ctx.run_function(*ctx.input_tensors) 128 | return output_tensors 129 | 130 | @staticmethod 131 | def backward(ctx, *output_grads): 132 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 133 | with torch.enable_grad(): 134 | # Fixes a bug where the first op in run_function modifies the 135 | # Tensor storage in place, which is not allowed for detach()'d 136 | # Tensors. 
137 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 138 | output_tensors = ctx.run_function(*shallow_copies) 139 | input_grads = torch.autograd.grad( 140 | output_tensors, 141 | ctx.input_tensors + ctx.input_params, 142 | output_grads, 143 | allow_unused=True, 144 | ) 145 | del ctx.input_tensors 146 | del ctx.input_params 147 | del output_tensors 148 | return (None, None) + input_grads 149 | 150 | 151 | def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): 152 | """ 153 | Create sinusoidal timestep embeddings. 154 | :param timesteps: a 1-D Tensor of N indices, one per batch element. 155 | These may be fractional. 156 | :param dim: the dimension of the output. 157 | :param max_period: controls the minimum frequency of the embeddings. 158 | :return: an [N x dim] Tensor of positional embeddings. 159 | """ 160 | if not repeat_only: 161 | half = dim // 2 162 | freqs = torch.exp( 163 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 164 | ).to(device=timesteps.device) 165 | args = timesteps[:, None].float() * freqs[None] 166 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 167 | if dim % 2: 168 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 169 | else: 170 | embedding = repeat(timesteps, 'b -> b d', d=dim) 171 | return embedding 172 | 173 | 174 | def zero_module(module): 175 | """ 176 | Zero out the parameters of a module and return it. 177 | """ 178 | for p in module.parameters(): 179 | p.detach().zero_() 180 | return module 181 | 182 | 183 | def scale_module(module, scale): 184 | """ 185 | Scale the parameters of a module and return it. 186 | """ 187 | for p in module.parameters(): 188 | p.detach().mul_(scale) 189 | return module 190 | 191 | 192 | def mean_flat(tensor): 193 | """ 194 | Take the mean over all non-batch dimensions. 
195 | """ 196 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 197 | 198 | 199 | def normalization(channels): 200 | """ 201 | Make a standard normalization layer. 202 | :param channels: number of input channels. 203 | :return: an nn.Module for normalization. 204 | """ 205 | return GroupNorm32(32, channels) 206 | 207 | 208 | # PyTorch 1.7 has SiLU, but we support PyTorch 1.5. 209 | class SiLU(nn.Module): 210 | def forward(self, x): 211 | return x * torch.sigmoid(x) 212 | 213 | 214 | class GroupNorm32(nn.GroupNorm): 215 | def forward(self, x): 216 | return super().forward(x.float()).type(x.dtype) 217 | 218 | def conv_nd(dims, *args, **kwargs): 219 | """ 220 | Create a 1D, 2D, or 3D convolution module. 221 | """ 222 | if dims == 1: 223 | return nn.Conv1d(*args, **kwargs) 224 | elif dims == 2: 225 | return nn.Conv2d(*args, **kwargs) 226 | elif dims == 3: 227 | return nn.Conv3d(*args, **kwargs) 228 | raise ValueError(f"unsupported dimensions: {dims}") 229 | 230 | 231 | def linear(*args, **kwargs): 232 | """ 233 | Create a linear module. 234 | """ 235 | return nn.Linear(*args, **kwargs) 236 | 237 | 238 | def avg_pool_nd(dims, *args, **kwargs): 239 | """ 240 | Create a 1D, 2D, or 3D average pooling module. 
241 | """ 242 | if dims == 1: 243 | return nn.AvgPool1d(*args, **kwargs) 244 | elif dims == 2: 245 | return nn.AvgPool2d(*args, **kwargs) 246 | elif dims == 3: 247 | return nn.AvgPool3d(*args, **kwargs) 248 | raise ValueError(f"unsupported dimensions: {dims}") 249 | 250 | 251 | class HybridConditioner(nn.Module): 252 | 253 | def __init__(self, c_concat_config, c_crossattn_config): 254 | super().__init__() 255 | self.concat_conditioner = instantiate_from_config(c_concat_config) 256 | self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) 257 | 258 | def forward(self, c_concat, c_crossattn): 259 | c_concat = self.concat_conditioner(c_concat) 260 | c_crossattn = self.crossattn_conditioner(c_crossattn) 261 | return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} 262 | 263 | 264 | def noise_like(shape, device, repeat=False): 265 | repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) 266 | noise = lambda: torch.randn(shape, device=device) 267 | return repeat_noise() if repeat else noise() -------------------------------------------------------------------------------- /modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/modules/distributions/__init__.py -------------------------------------------------------------------------------- /modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | 
return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 
71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = 
dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DaoyiG/DiffCAD/0496e93a8dc1110102a38352499ab35155ab1609/modules/encoders/__init__.py -------------------------------------------------------------------------------- /modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = 
self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast # TODO: add to reuquirements 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | 
def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 
141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 
195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /scripts/alignment_from_nocs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from tqdm import tqdm, trange 10 | from util import instantiate_from_config 11 | from data.data_util import estimate9DTransform, to_homo 12 | import json 13 | import open3d as o3d 14 | import quaternion 15 | 16 | 17 | def load_model_from_config(config, ckpt, verbose=False): 18 | print(f"Loading model from {ckpt}") 19 | pl_sd = torch.load(ckpt, map_location="cpu") 20 | sd = pl_sd["state_dict"] 21 | model = instantiate_from_config(config.model) 22 | m, u = model.load_state_dict(sd, strict=True) 23 | if len(m) > 0 and verbose: 24 | print("missing keys:") 25 | print(m) 26 | if len(u) > 0 and verbose: 27 | print("unexpected keys:") 28 | print(u) 29 | 30 | model.cuda() 31 | model.eval() 32 | return model 33 | 34 | 35 | def calc_rotation_diff(q: np.quaternion, q00: np.quaternion) -> float: 36 | np.seterr(all='raise') 37 | # rotation_dot = np.dot(quaternion.as_float_array(q00), quaternion.as_float_array(q)) 38 | rotation_dot = q00[0] * q[0] + q00[1] * q[1] + q00[2] * q[2] + q00[3] * q[3] 39 | 40 | rotation_dot_abs = np.abs(rotation_dot) 41 | try: 42 | error_rotation_rad = 2 * np.arccos(rotation_dot_abs) 43 | except: 44 | return 0.0 45 | error_rotation_rad = 2 * np.arccos(rotation_dot_abs) 46 | error_rotation = np.rad2deg(error_rotation_rad) 47 | 48 | return error_rotation 49 | 50 | 51 | def 
make_M_from_tqs(t: list, q: list, s: list, center=None) -> np.ndarray: 52 | if not isinstance(q, np.quaternion): 53 | q = np.quaternion(q[0], q[1], q[2], q[3]) 54 | T = np.eye(4) 55 | T[0:3, 3] = t 56 | R = np.eye(4) 57 | R[0:3, 0:3] = quaternion.as_rotation_matrix(q) 58 | S = np.eye(4) 59 | S[0:3, 0:3] = np.diag(s) 60 | 61 | C = np.eye(4) 62 | if center is not None: 63 | C[0:3, 3] = center 64 | 65 | M = T.dot(R).dot(S).dot(C) 66 | return M 67 | 68 | 69 | def decompose_mat4(M: np.ndarray) -> tuple: 70 | R = M[0:3, 0:3].copy() 71 | sx = np.linalg.norm(R[0:3, 0]) 72 | sy = np.linalg.norm(R[0:3, 1]) 73 | sz = np.linalg.norm(R[0:3, 2]) 74 | 75 | s = np.array([sx, sy, sz]) 76 | 77 | R[:, 0] /= sx 78 | R[:, 1] /= sy 79 | R[:, 2] /= sz 80 | 81 | q = quaternion.as_float_array(quaternion.from_rotation_matrix(R[0:3, 0:3])) 82 | # q = quaternion.from_float_array(quaternion_from_matrix(M, False)) 83 | 84 | t = M[0:3, 3] 85 | return t, q, s, R 86 | 87 | 88 | def re(R_est, R_gt): 89 | """Rotational Error. 90 | 91 | :param R_est: 3x3 ndarray with the estimated rotation matrix. 92 | :param R_gt: 3x3 ndarray with the ground-truth rotation matrix. 93 | :return: The calculated error. 94 | """ 95 | assert R_est.shape == R_gt.shape == (3, 3) 96 | rotation_diff = np.dot(R_est, R_gt.T) 97 | trace = np.trace(rotation_diff) 98 | trace = trace if trace <= 3 else 3 99 | # Avoid invalid values due to numerical errors 100 | error_cos = min(1.0, max(-1.0, 0.5 * (trace - 1.0))) 101 | rd_deg = np.rad2deg(np.arccos(error_cos)) 102 | 103 | return rd_deg 104 | 105 | 106 | def te(t_est, t_gt): 107 | """Translational Error. 108 | 109 | :param t_est: 3x1 ndarray with the estimated translation vector. 110 | :param t_gt: 3x1 ndarray with the ground-truth translation vector. 111 | :return: The calculated error. 
112 | """ 113 | t_est = t_est.flatten() 114 | t_gt = t_gt.flatten() 115 | assert t_est.size == t_gt.size == 3 116 | error = np.linalg.norm(t_gt - t_est) 117 | return error 118 | 119 | 120 | def se(s_est, s_gt): 121 | """Scale Error. 122 | 123 | :param s_est: 3x1 ndarray with the estimated scale vector. 124 | :param s_gt: 3x1 ndarray with the ground-truth scale vector. 125 | :return: The calculated error. 126 | """ 127 | s_est = s_est.flatten() 128 | s_gt = s_gt.flatten() 129 | assert s_est.size == s_gt.size == 3 130 | # error = np.abs(np.mean(s_est/s_gt) - 1) 131 | error = np.mean(np.abs((s_est/s_gt) - 1)) 132 | return error 133 | 134 | 135 | if __name__ == "__main__": 136 | parser = argparse.ArgumentParser() 137 | 138 | parser.add_argument( 139 | "--prediction_path", 140 | type=str, 141 | help="path to the predictions of NOCs of a category", 142 | ) 143 | 144 | parser.add_argument( 145 | "--pose_gt_root", 146 | type=str, 147 | help="path to the ground truth pose json", 148 | ) 149 | 150 | parser.add_argument( 151 | "--mesh_root", 152 | type=str, 153 | help="original shapenet meshes are not canonicalized (they have slight misalignment), the original ground truth pose is defined based on those meshes. 
We canonicalize those meshes and save the centroid offset into the mesh root", 154 | ) 155 | 156 | parser.add_argument( 157 | "--split_path", 158 | type=str, 159 | help="read the split", 160 | ) 161 | 162 | parser.add_argument( 163 | "--outdir", 164 | type=str, 165 | nargs="?", 166 | help="dir to write results to", 167 | ) 168 | 169 | parser.add_argument( 170 | "--num_iters", 171 | type=int, 172 | default=2, 173 | help="number of different input candidates", 174 | ) 175 | 176 | parser.add_argument( 177 | "--visualize", 178 | type=bool, 179 | default=False, 180 | help="visualize the alignment results (object point clouds with transformed NOCs)", 181 | ) 182 | 183 | opt = parser.parse_args() 184 | 185 | with open(opt.pose_gt_root, 'r') as f: 186 | pose_gts = json.load(f) 187 | 188 | with open(opt.split_path, 'r') as f: 189 | split_lines = f.read().splitlines() 190 | 191 | os.makedirs(opt.outdir, exist_ok=True) 192 | 193 | res = [] 194 | tes = [] 195 | ses = [] 196 | invalid_pose = [] 197 | prediction = {} 198 | 199 | for line in tqdm(split_lines): 200 | 201 | # read the target frame 202 | scene_id, frame_idx = line.split() 203 | scene_info = scene_id + '_' + frame_idx 204 | frame_id, mesh_id, inst_id = frame_idx.split('_') 205 | prediction[scene_info] = {} 206 | 207 | gt_pose_ws = np.asarray(pose_gts[scene_info]).reshape(4, 4) 208 | gt_pose_ws = np.asarray(gt_pose_ws).reshape(4, 4) 209 | 210 | with open(os.path.join(opt.mesh_root, mesh_id, 'mesh_info.json'), 'r') as f: 211 | centroid_offset = json.load(f)["centroid_offset"] 212 | 213 | centroid_offset = np.asarray(centroid_offset) 214 | offset = np.eye(4) 215 | offset[:3, 3] = -centroid_offset 216 | 217 | cam_K = np.asarray([434.98, 0.0, 239.36, 0.0, 434.05, 182.2, 0.0, 0.0, 1.0], dtype=np.float32).reshape(3, 3) 218 | 219 | pose_recalib = gt_pose_ws @ offset 220 | 221 | t_gt, q_gt, s_gt, R_gt = decompose_mat4(pose_recalib) 222 | 223 | scale_gt_frompose = np.linalg.norm(pose_recalib[:3, :3], axis=0) 224 | assert 
np.allclose(s_gt, scale_gt_frompose) 225 | 226 | best_transforms = [] 227 | best_ratios = [] 228 | best_nocss = [] 229 | best_errs = [] 230 | for i in tqdm(range(opt.num_iters)): 231 | 232 | prediction[scene_info]['gt_pose'] = pose_recalib.tolist() 233 | 234 | prediction[scene_info]['{}'.format(i)] = {} 235 | 236 | ori_cloud_ply = o3d.io.read_point_cloud(os.path.join(opt.prediction_path, scene_info + '_depth_input_{}.ply'.format(i))) 237 | pred_nocs_ply = o3d.io.read_point_cloud(os.path.join(opt.prediction_path, scene_info + '_nocs_pred_{}.ply'.format(i))) 238 | 239 | ori_cloud = np.asarray(ori_cloud_ply.points) 240 | 241 | pred_nocs = np.asarray(pred_nocs_ply.points) 242 | 243 | min_scale = [0.8, 0.8, 0.8] 244 | max_scale = [4.0, 4.0, 4.0] 245 | max_dimensions = np.array([1.2, 1.2, 1.2]) 246 | 247 | best_ratio = 0 248 | best_transform = None 249 | best_err = np.inf 250 | for thres in [0.001, 0.005, 0.01, 0.05, 0.1]: 251 | use_kdtree_for_eval = False 252 | kdtree_eval_resolution = 1 253 | transform, inliers = estimate9DTransform(source=pred_nocs, target=ori_cloud, PassThreshold=thres, max_iter=5000, 254 | use_kdtree_for_eval=use_kdtree_for_eval, kdtree_eval_resolution=kdtree_eval_resolution, 255 | max_scale=max_scale, min_scale=min_scale, max_dimensions=max_dimensions) 256 | 257 | if transform is not None: 258 | transformed = (transform@to_homo(pred_nocs).T).T[:, :3] 259 | errs = np.linalg.norm(transformed-ori_cloud, axis=1) 260 | total_err = errs.mean() 261 | if total_err < best_err: 262 | best_transform = transform.copy() 263 | best_err = total_err.copy() 264 | best_nocs = pred_nocs.copy() 265 | 266 | if best_transform is not None: 267 | best_transforms.append(best_transform) 268 | best_errs.append(best_err) 269 | best_nocss.append(best_nocs) 270 | prediction[scene_info]['{}'.format(i)]['best_transform'] = best_transform.tolist() 271 | 272 | if len(best_errs) != 0: 273 | # pick one with the best ratio 274 | best_rat_idx = np.argmin(np.asarray(best_errs)) 275 | 
best_transform_selected = best_transforms[best_rat_idx] 276 | best_pred_nocs = best_nocss[best_rat_idx] 277 | 278 | # save the depth 279 | depth_fname = os.path.join(opt.outdir, scene_info + '_depth_input.ply') 280 | 281 | o3d.io.write_point_cloud("{}".format(depth_fname), ori_cloud_ply, write_ascii=False) 282 | 283 | pred_fname = os.path.join(opt.outdir, scene_info + '_best_pred.ply') 284 | pcd_pred = o3d.geometry.PointCloud() 285 | pcd_pred.points = o3d.utility.Vector3dVector(best_pred_nocs) 286 | 287 | o3d.io.write_point_cloud("{}".format(pred_fname), pcd_pred, write_ascii=False) 288 | 289 | transformed = (best_transform_selected@to_homo(best_pred_nocs).T).T[:, :3] 290 | pred_trans_fname = os.path.join(opt.outdir, scene_info + '_best_pred_transformed.ply') 291 | pcd_pred_trans = o3d.geometry.PointCloud() 292 | pcd_pred_trans.points = o3d.utility.Vector3dVector(transformed) 293 | 294 | o3d.io.write_point_cloud("{}".format(pred_trans_fname), pcd_pred_trans, write_ascii=False) 295 | 296 | if opt.visualize: 297 | pcd_pred_trans.paint_uniform_color([1, 0.706, 0]) 298 | ori_cloud_ply.paint_uniform_color([0, 0.651, 0.929]) 299 | o3d.visualization.draw_geometries([pcd_pred_trans, ori_cloud_ply], window_name=scene_info) 300 | 301 | t_pred, q_pred, s_pred, R_pred = decompose_mat4(best_transform_selected) 302 | 303 | rot_err = re(R_pred, R_gt) 304 | rot_err_1 = calc_rotation_diff(q_pred, q_gt) 305 | trans_err = te(t_pred.reshape(3, 1), t_gt.reshape(3, 1)) 306 | scale_err = se(s_pred, s_gt) 307 | 308 | print("{} rot err {}; trans err {} ; scale err {} ".format(scene_info, rot_err, trans_err, scale_err)) 309 | 310 | res.append(rot_err) 311 | tes.append(trans_err) 312 | ses.append(scale_err) 313 | 314 | 315 | prediction[scene_info]['predicted_pose'] = best_transform_selected.tolist() 316 | prediction[scene_info]['gt_pose'] = pose_recalib.tolist() 317 | prediction[scene_info]['rot_err'] = rot_err.tolist() 318 | prediction[scene_info]['trans_err'] = trans_err.tolist() 319 | 
prediction[scene_info]['scale_err'] = scale_err.tolist() 320 | 321 | 322 | with open(os.path.join(opt.outdir, 'pose_predictions.json'), 'w', encoding='utf-8') as f: 323 | json.dump(prediction, f, ensure_ascii=False, indent=2) 324 | -------------------------------------------------------------------------------- /scripts/generate_multi_nocs_candidates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from PIL import Image 10 | from tqdm import tqdm 11 | from util import instantiate_from_config 12 | import cv2 13 | import json 14 | import kornia as kn 15 | import open3d as o3d 16 | import quaternion 17 | 18 | 19 | def load_model_from_config(config, ckpt, verbose=False): 20 | print(f"Loading model from {ckpt}") 21 | pl_sd = torch.load(ckpt, map_location="cpu") 22 | sd = pl_sd["state_dict"] 23 | model = instantiate_from_config(config.model) 24 | m, u = model.load_state_dict(sd, strict=True) 25 | if len(m) > 0 and verbose: 26 | print("missing keys:") 27 | print(m) 28 | if len(u) > 0 and verbose: 29 | print("unexpected keys:") 30 | print(u) 31 | 32 | model.cuda() 33 | model.eval() 34 | return model 35 | 36 | def make_M_from_tqs(t: list, q: list, s: list, center=None) -> np.ndarray: 37 | if not isinstance(q, np.quaternion): 38 | q = np.quaternion(q[0], q[1], q[2], q[3]) 39 | T = np.eye(4) 40 | T[0:3, 3] = t 41 | R = np.eye(4) 42 | R[0:3, 0:3] = quaternion.as_rotation_matrix(q) 43 | S = np.eye(4) 44 | S[0:3, 0:3] = np.diag(s) 45 | 46 | C = np.eye(4) 47 | if center is not None: 48 | C[0:3, 3] = center 49 | 50 | M = T.dot(R).dot(S).dot(C) 51 | return M 52 | 53 | def decompose_mat4(M: np.ndarray) -> tuple: 54 | R = M[0:3, 0:3].copy() 55 | sx = np.linalg.norm(R[0:3, 0]) 56 | sy = np.linalg.norm(R[0:3, 1]) 57 
| sz = np.linalg.norm(R[0:3, 2]) 58 | 59 | s = np.array([sx, sy, sz]) 60 | s = np.abs(s) 61 | 62 | R[:, 0] /= sx 63 | R[:, 1] /= sy 64 | R[:, 2] /= sz 65 | 66 | q = quaternion.as_float_array(quaternion.from_rotation_matrix(R[0:3, 0:3])) 67 | # q = quaternion.from_float_array(quaternion_from_matrix(M, False)) 68 | 69 | t = M[0:3, 3] 70 | return t, q, s, R 71 | 72 | 73 | if __name__ == "__main__": 74 | parser = argparse.ArgumentParser() 75 | 76 | parser.add_argument( 77 | "--category", 78 | type=str, 79 | default=None 80 | ) 81 | 82 | parser.add_argument( 83 | "--config_path", 84 | type=str, 85 | help="path to the model config file" 86 | ) 87 | 88 | parser.add_argument( 89 | "--model_path", 90 | type=str, 91 | help="path to the model checkpoint" 92 | ) 93 | 94 | parser.add_argument( 95 | "--normalized_depth", 96 | type=bool, 97 | default=False, 98 | help="whether using normalized depth as input, this should be consistent with the training setting" 99 | ) 100 | 101 | parser.add_argument( 102 | "--data_path", 103 | type=str, 104 | ) 105 | parser.add_argument( 106 | "--pose_gt_root", 107 | type=str, 108 | ) 109 | 110 | parser.add_argument( 111 | "--mesh_root", 112 | type=str, 113 | help="to get the centroid offset for the canonicalized mesh to align with the GT pose" 114 | ) 115 | 116 | parser.add_argument( 117 | "--split_path", 118 | type=str, 119 | help="read the split", 120 | ) 121 | 122 | parser.add_argument( 123 | "--outdir", 124 | type=str, 125 | nargs="?", 126 | help="dir to write results to", 127 | ) 128 | 129 | parser.add_argument( 130 | "--num_iters", 131 | type=int, 132 | default=3, 133 | help="number of different input subsampled pointcloud", 134 | ) 135 | 136 | parser.add_argument( 137 | "--gt_pose", 138 | type=bool, 139 | default=False, 140 | help="whether have access to ground truth pose&rendered depth", 141 | ) 142 | 143 | parser.add_argument( 144 | "--pred_scale_dir", 145 | type=str, 146 | ) 147 | 148 | opt = parser.parse_args() 149 | 150 | with 
open(opt.pred_scale_dir, 'r') as f: 151 | scales_fullset = json.load(f) 152 | 153 | config = OmegaConf.load(opt.config_path) 154 | 155 | print("evaluate model with parameterization of {}".format(config.model.params.parameterization)) 156 | 157 | model_path = opt.model_path 158 | 159 | if opt.gt_pose: 160 | with open(opt.pose_gt_root, 'r') as f: 161 | pose_gts = json.load(f) 162 | 163 | ckpt_name = model_path.split('/')[-1].split('.')[0] 164 | 165 | print("running generation on {}".format(model_path)) 166 | model = load_model_from_config(config, model_path) 167 | 168 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 169 | model = model.to(device) 170 | 171 | with open(opt.split_path, 'r') as f: 172 | split_lines = f.read().splitlines() 173 | 174 | output_path = os.path.join(opt.outdir, ckpt_name) 175 | print("Saving visuals to {}".format(output_path)) 176 | os.makedirs(output_path, exist_ok=True) 177 | 178 | for line in tqdm(split_lines): 179 | ori_cloud_cond_batch = [] 180 | 181 | # read the target frame 182 | scene_id, frame_idx = line.split() 183 | scene_info = scene_id + '_' + frame_idx 184 | frame_id, mesh_id, inst_id = frame_idx.split('_') 185 | 186 | scales_subset = [] 187 | 188 | for x in range(opt.num_iters): 189 | scales_subset.append(scales_fullset[scene_id + ' ' + frame_idx][str(x)][0]) 190 | 191 | scales_subset = sorted(scales_subset) 192 | print('number of scales per-scene: ', len(scales_subset)) 193 | 194 | for i, depth_aug_scalar in enumerate(scales_subset): 195 | if opt.gt_pose: 196 | depth_gt_fname = os.path.join(opt.data_path, "Rendering", scene_id, 'depth', "{}{}".format(frame_id, '.png')) 197 | depth_gt = cv2.imread(depth_gt_fname, -1) 198 | gt_pose_ws = np.asarray(pose_gts[scene_info]).reshape(4, 4) 199 | t_gt, q_gt, s_gt, R_gt = decompose_mat4(gt_pose_ws) 200 | scale_gt = np.linalg.norm(gt_pose_ws[:3, :3], axis=0) 201 | with open(os.path.join(opt.mesh_root, mesh_id, 'mesh_info.json'), 'r') as f: 202 | 
centroid_offset = json.load(f)["centroid_offset"] 203 | centroid_offset = np.asarray(centroid_offset) 204 | offset = np.eye(4) 205 | offset[:3, 3] = -centroid_offset 206 | pose_recalib = gt_pose_ws @ offset 207 | 208 | depth_fname = os.path.join(opt.data_path, 'ZoeDepthPredictions', scene_id, "{}_pred_dmap.npy".format(frame_id)) 209 | depth = np.load(depth_fname) 210 | depth = (depth * 1000) 211 | depth = (depth / depth_aug_scalar).astype(np.uint16) 212 | 213 | mask_fname = os.path.join(opt.data_path, "ODISEPredictions_NEW", scene_id, opt.category, frame_idx+'.png') 214 | mask = cv2.imread(mask_fname, -1) / 255 215 | 216 | mask = torch.from_numpy(mask).unsqueeze(0).unsqueeze(0) 217 | kernel = torch.ones(7, 7) 218 | mask = kn.morphology.erosion(mask, kernel) 219 | 220 | mask_ero = mask.squeeze(0).squeeze(0).detach().cpu().numpy() 221 | 222 | target_depth = depth * mask_ero.astype(np.uint8) 223 | 224 | if opt.gt_pose: 225 | target_depth_gt = depth_gt * mask_ero.astype(np.uint8) 226 | target_depth_gt = o3d.geometry.Image(target_depth_gt) 227 | pcd_trans_gt = o3d.geometry.PointCloud.create_from_depth_image(target_depth_gt, intr, pose_recalib) 228 | gt_nocs_rendered = np.asarray(pcd_trans_gt.points) 229 | 230 | target_depth = o3d.geometry.Image(target_depth) 231 | target_depth_full = o3d.geometry.Image(depth) 232 | cam_K = np.asarray([434.98, 0.0, 239.36, 0.0, 434.05, 182.2, 0.0, 0.0, 1.0], dtype=np.float32).reshape(3, 3) 233 | intr = o3d.camera.PinholeCameraIntrinsic(480, 360, cam_K) 234 | pcd_orig = o3d.geometry.PointCloud.create_from_depth_image(target_depth, intr) 235 | 236 | ori_cloud = np.asarray(pcd_orig.points) 237 | 238 | ori_cloud_normalized = (ori_cloud - ori_cloud.min(axis=0)) / ((ori_cloud.max(axis=0) - ori_cloud.min(axis=0)).max() + 1e-15) 239 | 240 | sample_num_pc = 1024 241 | indices_pc = np.random.choice(ori_cloud.shape[0], size=sample_num_pc, replace=False) 242 | ori_cloud = ori_cloud[indices_pc, :] 243 | ori_cloud_normalized = 
ori_cloud_normalized[indices_pc, :] 244 | 245 | if opt.normalized_depth: 246 | ori_cloud_cond = torch.from_numpy(ori_cloud_normalized).float() 247 | else: 248 | ori_cloud_cond = torch.from_numpy(ori_cloud).float() 249 | 250 | ori_cloud_cond = ori_cloud_cond.unsqueeze(0).to(device) 251 | ori_cloud_cond_batch.append(ori_cloud_cond) 252 | 253 | # save the original depth for pose solver 254 | depth_fname = os.path.join(output_path, scene_info + '_depth_input_{}.ply'.format(i)) 255 | pcd_depth = o3d.geometry.PointCloud() 256 | pcd_depth.points = o3d.utility.Vector3dVector(ori_cloud) 257 | 258 | o3d.io.write_point_cloud("{}".format(depth_fname), pcd_depth, write_ascii=False) 259 | ori_cloud_cond_batch = torch.cat(ori_cloud_cond_batch, dim=0) 260 | with torch.no_grad(): 261 | with model.ema_scope(): 262 | cond = model.get_learned_conditioning(ori_cloud_cond_batch) 263 | 264 | samples, _ = model.sample(cond=cond, batch_size=opt.num_iters, return_intermediates=True) 265 | samples = model.decode_first_stage(samples) 266 | 267 | for j in range(opt.num_iters): 268 | pred_nocs = samples[j].permute(1, 0) 269 | pred_nocs = pred_nocs.detach().cpu().numpy() 270 | pred_fname = os.path.join(output_path, scene_info + '_nocs_pred_{}.ply'.format(j)) 271 | pcd_pred = o3d.geometry.PointCloud() 272 | pcd_pred.points = o3d.utility.Vector3dVector(pred_nocs) 273 | 274 | o3d.io.write_point_cloud("{}".format(pred_fname), pcd_pred, write_ascii=False) 275 | -------------------------------------------------------------------------------- /scripts/generate_multi_scale_candidates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | import importlib 5 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 6 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 7 | import torch 8 | import numpy as np 9 | from omegaconf import OmegaConf 10 | from tqdm import tqdm, trange 11 | import torchvision 12 | from 
models.diffusion.ddim_scale import DDIMSampler 13 | import cv2 14 | import json 15 | import kornia as kn 16 | 17 | 18 | def instantiate_from_config(config): 19 | if not "target" in config: 20 | if config == '__is_first_stage__': 21 | return None 22 | elif config == "__is_unconditional__": 23 | return None 24 | raise KeyError("Expected key `target` to instantiate.") 25 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 26 | 27 | def get_obj_from_str(string, reload=False): 28 | module, cls = string.rsplit(".", 1) 29 | if reload: 30 | module_imp = importlib.import_module(module) 31 | importlib.reload(module_imp) 32 | return getattr(importlib.import_module(module, package=None), cls) 33 | 34 | def load_model_from_config(config, ckpt, verbose=False): 35 | print(f"Loading model from {ckpt}") 36 | pl_sd = torch.load(ckpt, map_location="cpu") 37 | sd = pl_sd["state_dict"] 38 | model = instantiate_from_config(config.model) 39 | m, u = model.load_state_dict(sd, strict=True) 40 | if len(m) > 0 and verbose: 41 | print("missing keys:") 42 | print(m) 43 | if len(u) > 0 and verbose: 44 | print("unexpected keys:") 45 | print(u) 46 | model.cuda() 47 | model.eval() 48 | return model 49 | 50 | 51 | if __name__ == "__main__": 52 | parser = argparse.ArgumentParser() 53 | 54 | parser.add_argument( 55 | "--category", 56 | type=str, 57 | default="02818832" 58 | ) 59 | 60 | parser.add_argument( 61 | "--config_path", 62 | type=str, 63 | help="path to the model config file" 64 | ) 65 | 66 | parser.add_argument( 67 | "--model_path", 68 | type=str, 69 | help="path to the model checkpoint" 70 | ) 71 | 72 | parser.add_argument( 73 | "--data_path", 74 | type=str, 75 | help="path to the data" 76 | ) 77 | 78 | parser.add_argument( 79 | "--split_path", 80 | type=str, 81 | help="read the test split", 82 | ) 83 | 84 | parser.add_argument( 85 | "--outdir", 86 | type=str, 87 | nargs="?", 88 | help="dir to write results to" 89 | ) 90 | 91 | parser.add_argument( 92 | 
"--num_iters", 93 | type=int, 94 | default=5, 95 | help="number of different input subsampled pointcloud", 96 | ) 97 | 98 | parser.add_argument( 99 | "--ddim_steps", 100 | type=int, 101 | default=200, 102 | help="number of ddim sampling steps", 103 | ) 104 | 105 | parser.add_argument( 106 | "--gt_scale", 107 | type=bool, 108 | default=False, 109 | help="whether have access to ground truth scale", 110 | ) 111 | 112 | 113 | opt = parser.parse_args() 114 | 115 | config = OmegaConf.load(opt.config_path) 116 | 117 | print("evaluate model with parameterization of {}".format(config.model.params.parameterization)) 118 | 119 | model_path = opt.model_path 120 | 121 | ckpt_name = model_path.split('/')[-1].split('.')[0] 122 | 123 | print("running generation on {}".format(model_path)) 124 | model = load_model_from_config(config, model_path) 125 | 126 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 127 | model = model.to(device) 128 | 129 | with open(opt.split_path, 'r') as f: 130 | split_lines = f.read().splitlines() 131 | 132 | sampler = DDIMSampler(model) 133 | 134 | output_path = os.path.join(opt.outdir, ckpt_name) 135 | print("Saving predictions to {}".format(output_path)) 136 | os.makedirs(output_path, exist_ok=True) 137 | 138 | all_samples = list() 139 | 140 | prediction = {} 141 | 142 | min_diffs = [] 143 | mean_diffs = [] 144 | 145 | scale_infos = {} 146 | for line in tqdm(split_lines): 147 | scale_infos[line] = {} 148 | 149 | rz = torchvision.transforms.Resize(size=(30, 40)) 150 | 151 | for i in range(opt.num_iters): 152 | with torch.no_grad(): 153 | with model.ema_scope(): 154 | for line in tqdm(split_lines): 155 | # read the target frame 156 | scene_id, frame_idx = line.split() 157 | scene_info = scene_id + '_' + frame_idx 158 | frame_id, mesh_id, inst_id = frame_idx.split('_') 159 | prediction[scene_info] = {} 160 | 161 | if opt.gt_scale: 162 | sensor_depth_path = os.path.join(opt.data_path, 'SensorDepth', scene_id, frame_id + 
'.png') 163 | sensor_depth = cv2.imread(sensor_depth_path, -1) / 1000.0 164 | 165 | sensor_depth = torch.from_numpy(sensor_depth).unsqueeze(0).unsqueeze(0) 166 | sensor_depth = torch.nn.functional.interpolate(sensor_depth, size=(360, 480)) 167 | sensor_depth = sensor_depth.squeeze(0).squeeze(0).detach().cpu().numpy() 168 | 169 | depth_input = torch.from_numpy(sensor_depth)[None] 170 | depth_input = rz(depth_input) 171 | 172 | depth_fname = os.path.join(opt.data_path, "ZoeDepthPredictions", scene_id, "{}_pred_dmap{}".format(frame_id, '.npy')) 173 | depth = np.load(depth_fname) 174 | 175 | depth_pred_input = torch.from_numpy(depth)[None] 176 | depth_pred_input = rz(depth_pred_input).to(device) 177 | 178 | mask_fname = os.path.join(opt.data_path, "ODISEPredictions_NEW", scene_id, opt.category, frame_idx+'.png') 179 | mask_full = cv2.imread(mask_fname, -1) / 255 180 | 181 | mask = torch.from_numpy(mask_full).unsqueeze(0).unsqueeze(0).float() 182 | kernel = torch.ones(3, 3) 183 | mask = kn.morphology.erosion(mask, kernel) 184 | mask_ero = mask.squeeze(0).squeeze(0).detach().cpu().numpy() 185 | 186 | pred_target_depth = depth * mask_ero 187 | 188 | if opt.gt_scale: 189 | target_depth = sensor_depth * mask_ero 190 | gt_scale = np.mean(pred_target_depth) / np.mean(target_depth) 191 | scale_infos[line]['gt'] = gt_scale 192 | 193 | cond = model.get_learned_conditioning(depth_pred_input.unsqueeze(0)).to(device) 194 | 195 | samples, _ = model.sample_log(cond=cond, batch_size=1, ddim=True, ddim_steps=200, eta=1.) 
196 | 197 | pred_scales = model.decode_first_stage(samples).squeeze(0).squeeze(0) 198 | 199 | pred_scales = pred_scales.detach().cpu().numpy() 200 | 201 | pred_scale_mean = np.mean(pred_scales) 202 | 203 | scale_infos[line][str(int(i))] = pred_scale_mean + 1.1 204 | 205 | 206 | with open('{}/predictions.json'.format(output_path), 'w', encoding='utf-8') as f: 207 | json.dump(scale_infos, f, ensure_ascii=False, indent=2) 208 | -------------------------------------------------------------------------------- /scripts/generate_multi_shape_candidates.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.dirname(SCRIPT_DIR)) 6 | import torch 7 | import numpy as np 8 | from omegaconf import OmegaConf 9 | from tqdm import tqdm 10 | from util import instantiate_from_config 11 | import cv2 12 | import json 13 | import open3d as o3d 14 | 15 | 16 | def load_model_from_config(config, ckpt, verbose=False): 17 | print(f"Loading model from {ckpt}") 18 | pl_sd = torch.load(ckpt, map_location="cpu") 19 | sd = pl_sd["state_dict"] 20 | model = instantiate_from_config(config.model) 21 | m, u = model.load_state_dict(sd, strict=True) 22 | if len(m) > 0 and verbose: 23 | print("missing keys:") 24 | print(m) 25 | if len(u) > 0 and verbose: 26 | print("unexpected keys:") 27 | print(u) 28 | 29 | model.cuda() 30 | model.eval() 31 | return model 32 | 33 | 34 | 35 | if __name__ == "__main__": 36 | parser = argparse.ArgumentParser() 37 | 38 | parser.add_argument( 39 | "--category", 40 | type=str, 41 | default="02818832" 42 | ) 43 | 44 | parser.add_argument( 45 | "--config_path", 46 | type=str, 47 | ) 48 | 49 | parser.add_argument( 50 | "--model_path", 51 | type=str, 52 | ) 53 | parser.add_argument( 54 | "--data_path", 55 | type=str, 56 | ) 57 | parser.add_argument( 58 | "--ply_path", 59 | type=str, 60 | help="path to the predicted 
NOCs, which are used as input to the model" 61 | ) 62 | parser.add_argument( 63 | "--num_iters", 64 | type=int, 65 | default=3, 66 | help="number of different input NOCs", 67 | ) 68 | 69 | parser.add_argument( 70 | "--latent_root", 71 | type=str, 72 | help="path to the latent codes for retrieval" 73 | ) 74 | 75 | 76 | parser.add_argument( 77 | "--split_path", 78 | type=str, 79 | help="read the split", 80 | ) 81 | 82 | parser.add_argument( 83 | "--outdir", 84 | type=str, 85 | nargs="?", 86 | help="dir to write results to", 87 | ) 88 | 89 | 90 | opt = parser.parse_args() 91 | 92 | config = OmegaConf.load(opt.config_path) 93 | 94 | print("evaluate model with parameterization of {}".format(config.model.params.parameterization)) 95 | 96 | model_basename = opt.model_path.split('/')[-3] 97 | model_path = opt.model_path 98 | 99 | ckpt_name = model_path.split('/')[-1].split('.')[0] 100 | 101 | print("running generation on {}".format(model_path)) 102 | model = load_model_from_config(config, model_path) 103 | 104 | device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") 105 | model = model.to(device) 106 | 107 | with open(opt.split_path, 'r') as f: 108 | split_lines = f.read().splitlines() 109 | 110 | with open(os.path.join(opt.data_path, 'CAD_Pools/scan2cad_cad_pool_{}.json'.format(opt.category)), 'r') as f: 111 | cad_pool_s2c = json.load(f) 112 | 113 | latent_joint = [] 114 | latent_joint_ids = [] 115 | for cad_id in cad_pool_s2c['train']: 116 | latent_joint_ids.append(cad_id) 117 | latent_joint.append(torch.load( 118 | os.path.join(opt.latent_root, cad_id+'.pt'), map_location='cpu').squeeze(0)) 119 | latent_joint = torch.stack(latent_joint).to(device) 120 | 121 | print("Loaded total {} latents".format(latent_joint.shape[0])) 122 | 123 | output_path = os.path.join(opt.outdir, model_basename, ckpt_name) 124 | os.makedirs(output_path, exist_ok=True) 125 | 126 | all_samples = list() 127 | 128 | cham_dists_l1 = [] 129 | cham_dists_l2 = [] 130 | 
cham_dists_l1_eu = [] 131 | cham_dists_l2_eu = [] 132 | inference_results = {} 133 | 134 | for line in tqdm(split_lines): 135 | scene_id, frame_idx = line.split() 136 | scene_info = scene_id + '_' + frame_idx 137 | inference_results[scene_info] = {} 138 | 139 | for line in tqdm(split_lines): 140 | cond_batch = [] 141 | for i in range(opt.num_iters): 142 | scene_id, frame_info = line.split() 143 | scene_info = scene_id + '_' + frame_info 144 | frame_id, latent_gt_idx, instance_id = frame_info.split('_') 145 | 146 | nocs_fname = os.path.join(opt.ply_path, scene_info + '_nocs_pred_{}.ply'.format(i)) 147 | nocs_ply = o3d.io.read_point_cloud(nocs_fname) 148 | points3d = np.asarray(nocs_ply.points) 149 | 150 | sample_num = 1024 151 | indices = np.random.randint(points3d.shape[0], size=sample_num) 152 | nocs_pc = points3d[indices, :] 153 | nocs_pc = torch.from_numpy(nocs_pc) # num_points, 3 154 | nocs_pc = nocs_pc[None].to(device).float() 155 | cond_batch.append(nocs_pc) 156 | 157 | cond_batch = torch.cat(cond_batch, dim=0) 158 | with torch.no_grad(): 159 | with model.ema_scope(): 160 | c = model.get_learned_conditioning(cond_batch).float() 161 | samples, _ = model.sample(cond=c, batch_size=opt.num_iters, return_intermediates=True) # B 256 1 1 162 | x_samples = model.decode_first_stage(samples) 163 | 164 | dists = torch.zeros(latent_joint.shape[0]).to(device) 165 | 166 | for x in range(opt.num_iters): 167 | for j in range(latent_joint.shape[0]): 168 | dists[j] = torch.nn.functional.cosine_similarity(x_samples[x].squeeze(-1).squeeze(-1), latent_joint[j], dim=0) 169 | top_idx = latent_joint_ids[torch.argmax(dists).item()] 170 | retrieved_id = top_idx 171 | gt_id = latent_gt_idx 172 | 173 | query_lat_eu = x_samples[x].unsqueeze(0).squeeze(-1).squeeze(-1) 174 | eu_dists = torch.cdist(latent_joint, query_lat_eu, p=2).squeeze(-1) # N 175 | ret_id_eu = latent_joint_ids[torch.argmin(eu_dists).item()] 176 | 177 | inference_results[scene_info]['gt_latent_idx'] = gt_id 178 | 
inference_results[scene_info]['retrieved_latent_idx_{}'.format(x)] = retrieved_id # retrieval based on cosine similarity 179 | inference_results[scene_info]['retrieved_latent_idx_{}_eu'.format(x)] = ret_id_eu # retrieval based on euclidean distance 180 | 181 | print("Saving results to {}".format(output_path)) 182 | 183 | with open(os.path.join(output_path, 'inference_results_.json'), 'w', encoding='utf-8') as f: 184 | json.dump(inference_results, f, ensure_ascii=False, indent=2) 185 | -------------------------------------------------------------------------------- /splits/pose/02818832/val_02818832.txt: -------------------------------------------------------------------------------- 1 | scene0699_00 000000_22b8e1805041fe56010a6840f668b41_0 2 | scene0699_00 000100_22b8e1805041fe56010a6840f668b41_0 3 | scene0699_00 000400_22b8e1805041fe56010a6840f668b41_0 4 | scene0144_01 000700_f10984c7255bc5b25519d54a714fac86_0 5 | scene0353_02 000500_698ef3d1a8c0c829c580fdeb5460f6d6_0 6 | scene0353_02 000700_698ef3d1a8c0c829c580fdeb5460f6d6_1 7 | scene0217_00 000000_3acfa3c60a03415643abcff1f32a8b0c_1 8 | scene0217_00 000500_3acfa3c60a03415643abcff1f32a8b0c_0 9 | scene0217_00 000700_3acfa3c60a03415643abcff1f32a8b0c_1 10 | scene0217_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 11 | scene0193_01 000300_51223b1b770ff5e72f38f5bd71072746_0 12 | scene0046_02 000100_d7b9238af2efa963c862eec8232fff1e_0 13 | scene0046_02 000200_d7b9238af2efa963c862eec8232fff1e_0 14 | scene0046_02 000700_d7b9238af2efa963c862eec8232fff1e_0 15 | scene0046_02 000800_d7b9238af2efa963c862eec8232fff1e_0 16 | scene0046_02 002200_d7b9238af2efa963c862eec8232fff1e_0 17 | scene0046_02 002500_d7b9238af2efa963c862eec8232fff1e_0 18 | scene0645_01 000100_4dbd37cb85686dea674ce64e4bf77aec_0 19 | scene0645_01 000300_4dbd37cb85686dea674ce64e4bf77aec_1 20 | scene0645_01 000400_4dbd37cb85686dea674ce64e4bf77aec_1 21 | scene0645_01 003500_4dbd37cb85686dea674ce64e4bf77aec_0 22 | scene0645_01 003600_4dbd37cb85686dea674ce64e4bf77aec_0 
23 | scene0645_01 005000_4dbd37cb85686dea674ce64e4bf77aec_0 24 | scene0645_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 25 | scene0645_00 000200_4dbd37cb85686dea674ce64e4bf77aec_0 26 | scene0645_00 000300_4dbd37cb85686dea674ce64e4bf77aec_1 27 | scene0645_00 000400_4dbd37cb85686dea674ce64e4bf77aec_1 28 | scene0645_00 000600_4dbd37cb85686dea674ce64e4bf77aec_1 29 | scene0193_00 000300_c758919a1fbe3d0c9cef17528faf7bc5_0 30 | scene0193_00 000700_c758919a1fbe3d0c9cef17528faf7bc5_0 31 | scene0658_00 000200_c2b65540d51fada22bfa21768064df9c_0 32 | scene0426_03 000600_946bd5aeda453b38b3a454ed6a7199e2_0 33 | scene0046_00 000300_f7d2cf0ebbf5453531cd8798c40e5949_0 34 | scene0046_00 000800_f7d2cf0ebbf5453531cd8798c40e5949_0 35 | scene0046_00 000900_f7d2cf0ebbf5453531cd8798c40e5949_0 36 | scene0221_00 000300_9fb6014c9944a98bd2096b2fa6f98cc7_0 37 | scene0697_02 000300_4dbd37cb85686dea674ce64e4bf77aec_0 38 | scene0697_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 39 | scene0697_02 001600_4dbd37cb85686dea674ce64e4bf77aec_0 40 | scene0697_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 41 | scene0697_02 002700_4dbd37cb85686dea674ce64e4bf77aec_0 42 | scene0144_00 000800_4dbd37cb85686dea674ce64e4bf77aec_0 43 | scene0426_00 001100_ce0f3c9d6a0b0cda71010004e0594e66_0 44 | scene0695_03 000400_1f11b3d9953fabcf8b4396b18c85cf0f_0 45 | scene0695_03 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 46 | scene0695_03 002600_1f11b3d9953fabcf8b4396b18c85cf0f_0 47 | scene0580_01 000400_5d12d1a313cff5ad66f379f51753f72b_0 48 | scene0580_01 000500_5d12d1a313cff5ad66f379f51753f72b_0 49 | scene0580_01 000600_5d12d1a313cff5ad66f379f51753f72b_0 50 | scene0580_01 000800_5d12d1a313cff5ad66f379f51753f72b_0 51 | scene0580_01 002100_5d12d1a313cff5ad66f379f51753f72b_0 52 | scene0580_01 003600_5d12d1a313cff5ad66f379f51753f72b_0 53 | scene0580_01 003800_5d12d1a313cff5ad66f379f51753f72b_0 54 | scene0353_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 55 | scene0353_00 001900_3acfa3c60a03415643abcff1f32a8b0c_0 56 | scene0435_00 
000800_d7b9238af2efa963c862eec8232fff1e_0 57 | scene0435_00 001000_d7b9238af2efa963c862eec8232fff1e_0 58 | scene0435_00 001100_d7b9238af2efa963c862eec8232fff1e_0 59 | scene0435_00 001200_76db17c76f828282dcb2f14e2e42ec8d_1 60 | scene0435_00 001300_76db17c76f828282dcb2f14e2e42ec8d_1 61 | scene0435_00 001400_76db17c76f828282dcb2f14e2e42ec8d_1 62 | scene0695_02 000000_48973f489d06e8139f9d5a5f7267a470_0 63 | scene0695_02 001800_48973f489d06e8139f9d5a5f7267a470_0 64 | scene0695_02 002000_48973f489d06e8139f9d5a5f7267a470_0 65 | scene0629_02 000100_946bd5aeda453b38b3a454ed6a7199e2_0 66 | scene0629_02 000200_946bd5aeda453b38b3a454ed6a7199e2_0 67 | scene0629_02 001400_946bd5aeda453b38b3a454ed6a7199e2_0 68 | scene0629_02 001500_946bd5aeda453b38b3a454ed6a7199e2_0 69 | scene0629_02 001600_946bd5aeda453b38b3a454ed6a7199e2_0 70 | scene0046_01 000500_d7b9238af2efa963c862eec8232fff1e_0 71 | scene0046_01 000600_d7b9238af2efa963c862eec8232fff1e_0 72 | scene0046_01 002200_d7b9238af2efa963c862eec8232fff1e_0 73 | scene0580_00 000100_946bd5aeda453b38b3a454ed6a7199e2_0 74 | scene0580_00 000700_946bd5aeda453b38b3a454ed6a7199e2_0 75 | scene0580_00 001300_946bd5aeda453b38b3a454ed6a7199e2_0 76 | scene0580_00 001400_946bd5aeda453b38b3a454ed6a7199e2_0 77 | scene0580_00 001500_946bd5aeda453b38b3a454ed6a7199e2_0 78 | scene0580_00 004600_946bd5aeda453b38b3a454ed6a7199e2_0 79 | scene0246_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 80 | scene0246_00 002000_4dbd37cb85686dea674ce64e4bf77aec_0 81 | scene0246_00 002600_4dbd37cb85686dea674ce64e4bf77aec_0 82 | scene0697_01 000300_4dbd37cb85686dea674ce64e4bf77aec_0 83 | scene0697_01 001400_4dbd37cb85686dea674ce64e4bf77aec_0 84 | scene0697_01 001600_4dbd37cb85686dea674ce64e4bf77aec_0 85 | scene0645_02 000100_4dbd37cb85686dea674ce64e4bf77aec_1 86 | scene0645_02 000200_4dbd37cb85686dea674ce64e4bf77aec_1 87 | scene0645_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 88 | scene0353_01 000400_8df7e58200ac5e6ab91b871e750ca615_0 89 | scene0353_01 
000600_8df7e58200ac5e6ab91b871e750ca615_0 90 | scene0353_01 000900_8df7e58200ac5e6ab91b871e750ca615_1 91 | scene0256_02 000000_e7d0920ba8d4b1be71424c004dd7ab2f_0 92 | scene0256_02 000100_e7d0920ba8d4b1be71424c004dd7ab2f_0 93 | scene0256_02 000600_e7d0920ba8d4b1be71424c004dd7ab2f_0 94 | scene0696_02 000100_946bd5aeda453b38b3a454ed6a7199e2_0 95 | scene0696_02 001300_946bd5aeda453b38b3a454ed6a7199e2_0 96 | scene0435_01 000700_e91c2df09de0d4b1ed4d676215f46734_0 97 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_0 98 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_1 99 | scene0435_01 000900_e91c2df09de0d4b1ed4d676215f46734_0 100 | scene0435_01 001000_e91c2df09de0d4b1ed4d676215f46734_0 101 | scene0435_01 001100_e91c2df09de0d4b1ed4d676215f46734_1 102 | scene0435_01 001200_e91c2df09de0d4b1ed4d676215f46734_1 103 | scene0222_01 000400_6e5f10f2574f8a285d64ca7820a9c2ca_1 104 | scene0222_01 001500_6e5f10f2574f8a285d64ca7820a9c2ca_1 105 | scene0222_01 003600_6e5f10f2574f8a285d64ca7820a9c2ca_0 106 | scene0222_01 004100_6e5f10f2574f8a285d64ca7820a9c2ca_0 107 | scene0277_01 000100_3acfa3c60a03415643abcff1f32a8b0c_0 108 | scene0277_01 000700_3acfa3c60a03415643abcff1f32a8b0c_0 109 | scene0277_01 000800_3acfa3c60a03415643abcff1f32a8b0c_0 110 | scene0435_02 000600_9fb6014c9944a98bd2096b2fa6f98cc7_0 111 | scene0435_02 000700_9fb6014c9944a98bd2096b2fa6f98cc7_0 112 | scene0435_02 000800_9fb6014c9944a98bd2096b2fa6f98cc7_1 113 | scene0435_02 000900_9fb6014c9944a98bd2096b2fa6f98cc7_0 114 | scene0435_02 001100_9fb6014c9944a98bd2096b2fa6f98cc7_1 115 | scene0697_00 001000_4dbd37cb85686dea674ce64e4bf77aec_0 116 | scene0696_01 000600_946bd5aeda453b38b3a454ed6a7199e2_0 117 | scene0696_01 001300_946bd5aeda453b38b3a454ed6a7199e2_0 118 | scene0696_01 001800_946bd5aeda453b38b3a454ed6a7199e2_0 119 | scene0652_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 120 | scene0652_00 001300_4dbd37cb85686dea674ce64e4bf77aec_0 121 | scene0356_02 000200_3acfa3c60a03415643abcff1f32a8b0c_0 122 | 
scene0356_02 000800_3acfa3c60a03415643abcff1f32a8b0c_0 123 | scene0207_02 000000_4dbd37cb85686dea674ce64e4bf77aec_0 124 | scene0207_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 125 | scene0207_02 000600_4dbd37cb85686dea674ce64e4bf77aec_0 126 | scene0207_00 000000_4dbd37cb85686dea674ce64e4bf77aec_0 127 | scene0207_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 128 | scene0207_00 000400_4dbd37cb85686dea674ce64e4bf77aec_0 129 | scene0648_00 000000_2d1a2be896054548997e2c877588ae24_0 130 | scene0648_00 000600_2d1a2be896054548997e2c877588ae24_0 131 | scene0648_00 001100_2d1a2be896054548997e2c877588ae24_1 132 | scene0648_00 002500_2d1a2be896054548997e2c877588ae24_1 133 | scene0648_00 002800_2d1a2be896054548997e2c877588ae24_0 134 | scene0648_00 003500_2d1a2be896054548997e2c877588ae24_1 135 | scene0648_00 003800_2d1a2be896054548997e2c877588ae24_0 136 | scene0648_00 003900_2d1a2be896054548997e2c877588ae24_0 137 | scene0277_00 000400_6e4707cac21b09f0531c83488903771b_0 138 | scene0277_00 001000_6e4707cac21b09f0531c83488903771b_0 139 | scene0435_03 000600_e91c2df09de0d4b1ed4d676215f46734_0 140 | scene0435_03 000700_e91c2df09de0d4b1ed4d676215f46734_1 141 | scene0435_03 000800_e91c2df09de0d4b1ed4d676215f46734_0 142 | scene0435_03 000900_e91c2df09de0d4b1ed4d676215f46734_0 143 | scene0435_03 001300_e91c2df09de0d4b1ed4d676215f46734_1 144 | scene0435_03 001400_e91c2df09de0d4b1ed4d676215f46734_1 145 | scene0648_01 001900_b9302be3dc846d834f0ba81bea651144_1 146 | scene0648_01 002200_b9302be3dc846d834f0ba81bea651144_1 147 | scene0648_01 002300_b9302be3dc846d834f0ba81bea651144_0 148 | scene0648_01 002900_b9302be3dc846d834f0ba81bea651144_1 149 | scene0648_01 003000_b9302be3dc846d834f0ba81bea651144_1 150 | scene0382_01 000500_edf13191dacf07af42d7295fb0533ac0_0 151 | scene0695_00 000800_3acfa3c60a03415643abcff1f32a8b0c_0 152 | scene0695_00 000900_3acfa3c60a03415643abcff1f32a8b0c_0 153 | scene0695_00 001800_3acfa3c60a03415643abcff1f32a8b0c_0 154 | scene0633_01 
000000_76db17c76f828282dcb2f14e2e42ec8d_0 155 | scene0633_01 000100_76db17c76f828282dcb2f14e2e42ec8d_0 156 | scene0633_01 001100_76db17c76f828282dcb2f14e2e42ec8d_0 157 | scene0356_00 000200_1f11b3d9953fabcf8b4396b18c85cf0f_0 158 | scene0356_00 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 159 | scene0697_03 001500_4dbd37cb85686dea674ce64e4bf77aec_0 160 | scene0697_03 001600_4dbd37cb85686dea674ce64e4bf77aec_0 161 | scene0697_03 001700_4dbd37cb85686dea674ce64e4bf77aec_0 162 | -------------------------------------------------------------------------------- /splits/pose/02818832/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0699_00 000000_22b8e1805041fe56010a6840f668b41_0 2 | scene0699_00 000100_22b8e1805041fe56010a6840f668b41_0 3 | scene0699_00 000400_22b8e1805041fe56010a6840f668b41_0 4 | scene0353_02 000500_698ef3d1a8c0c829c580fdeb5460f6d6_0 5 | scene0353_02 000700_698ef3d1a8c0c829c580fdeb5460f6d6_1 6 | scene0217_00 000000_3acfa3c60a03415643abcff1f32a8b0c_1 7 | scene0217_00 000500_3acfa3c60a03415643abcff1f32a8b0c_0 8 | scene0217_00 000700_3acfa3c60a03415643abcff1f32a8b0c_1 9 | scene0217_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 10 | scene0193_01 000300_51223b1b770ff5e72f38f5bd71072746_0 11 | scene0046_02 000100_d7b9238af2efa963c862eec8232fff1e_0 12 | scene0046_02 000200_d7b9238af2efa963c862eec8232fff1e_0 13 | scene0046_02 000700_d7b9238af2efa963c862eec8232fff1e_0 14 | scene0046_02 000800_d7b9238af2efa963c862eec8232fff1e_0 15 | scene0046_02 002200_d7b9238af2efa963c862eec8232fff1e_0 16 | scene0046_02 002500_d7b9238af2efa963c862eec8232fff1e_0 17 | scene0645_01 000100_4dbd37cb85686dea674ce64e4bf77aec_0 18 | scene0645_01 000300_4dbd37cb85686dea674ce64e4bf77aec_1 19 | scene0645_01 000400_4dbd37cb85686dea674ce64e4bf77aec_1 20 | scene0645_01 003500_4dbd37cb85686dea674ce64e4bf77aec_0 21 | scene0645_01 003600_4dbd37cb85686dea674ce64e4bf77aec_0 22 | scene0645_01 
005000_4dbd37cb85686dea674ce64e4bf77aec_0 23 | scene0645_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 24 | scene0645_00 000200_4dbd37cb85686dea674ce64e4bf77aec_0 25 | scene0645_00 000300_4dbd37cb85686dea674ce64e4bf77aec_1 26 | scene0645_00 000400_4dbd37cb85686dea674ce64e4bf77aec_1 27 | scene0645_00 000600_4dbd37cb85686dea674ce64e4bf77aec_1 28 | scene0193_00 000300_c758919a1fbe3d0c9cef17528faf7bc5_0 29 | scene0193_00 000700_c758919a1fbe3d0c9cef17528faf7bc5_0 30 | scene0658_00 000200_c2b65540d51fada22bfa21768064df9c_0 31 | scene0426_03 000600_946bd5aeda453b38b3a454ed6a7199e2_0 32 | scene0046_00 000300_f7d2cf0ebbf5453531cd8798c40e5949_0 33 | scene0046_00 000800_f7d2cf0ebbf5453531cd8798c40e5949_0 34 | scene0046_00 000900_f7d2cf0ebbf5453531cd8798c40e5949_0 35 | scene0221_00 000300_9fb6014c9944a98bd2096b2fa6f98cc7_0 36 | scene0697_02 000300_4dbd37cb85686dea674ce64e4bf77aec_0 37 | scene0697_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 38 | scene0697_02 001600_4dbd37cb85686dea674ce64e4bf77aec_0 39 | scene0697_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 40 | scene0697_02 002700_4dbd37cb85686dea674ce64e4bf77aec_0 41 | scene0144_00 000800_4dbd37cb85686dea674ce64e4bf77aec_0 42 | scene0426_00 001100_ce0f3c9d6a0b0cda71010004e0594e66_0 43 | scene0695_03 000400_1f11b3d9953fabcf8b4396b18c85cf0f_0 44 | scene0695_03 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 45 | scene0695_03 002600_1f11b3d9953fabcf8b4396b18c85cf0f_0 46 | scene0580_01 000400_5d12d1a313cff5ad66f379f51753f72b_0 47 | scene0580_01 000500_5d12d1a313cff5ad66f379f51753f72b_0 48 | scene0580_01 000600_5d12d1a313cff5ad66f379f51753f72b_0 49 | scene0580_01 000800_5d12d1a313cff5ad66f379f51753f72b_0 50 | scene0580_01 002100_5d12d1a313cff5ad66f379f51753f72b_0 51 | scene0580_01 003600_5d12d1a313cff5ad66f379f51753f72b_0 52 | scene0580_01 003800_5d12d1a313cff5ad66f379f51753f72b_0 53 | scene0353_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 54 | scene0353_00 001900_3acfa3c60a03415643abcff1f32a8b0c_0 55 | scene0435_00 
000800_d7b9238af2efa963c862eec8232fff1e_0 56 | scene0435_00 001000_d7b9238af2efa963c862eec8232fff1e_0 57 | scene0435_00 001100_d7b9238af2efa963c862eec8232fff1e_0 58 | scene0435_00 001200_76db17c76f828282dcb2f14e2e42ec8d_1 59 | scene0435_00 001300_76db17c76f828282dcb2f14e2e42ec8d_1 60 | scene0435_00 001400_76db17c76f828282dcb2f14e2e42ec8d_1 61 | scene0695_02 000000_48973f489d06e8139f9d5a5f7267a470_0 62 | scene0695_02 001800_48973f489d06e8139f9d5a5f7267a470_0 63 | scene0695_02 002000_48973f489d06e8139f9d5a5f7267a470_0 64 | scene0629_02 000200_946bd5aeda453b38b3a454ed6a7199e2_0 65 | scene0629_02 001400_946bd5aeda453b38b3a454ed6a7199e2_0 66 | scene0629_02 001500_946bd5aeda453b38b3a454ed6a7199e2_0 67 | scene0629_02 001600_946bd5aeda453b38b3a454ed6a7199e2_0 68 | scene0046_01 000500_d7b9238af2efa963c862eec8232fff1e_0 69 | scene0046_01 000600_d7b9238af2efa963c862eec8232fff1e_0 70 | scene0046_01 002200_d7b9238af2efa963c862eec8232fff1e_0 71 | scene0580_00 000100_946bd5aeda453b38b3a454ed6a7199e2_0 72 | scene0580_00 000700_946bd5aeda453b38b3a454ed6a7199e2_0 73 | scene0580_00 001300_946bd5aeda453b38b3a454ed6a7199e2_0 74 | scene0580_00 001400_946bd5aeda453b38b3a454ed6a7199e2_0 75 | scene0580_00 001500_946bd5aeda453b38b3a454ed6a7199e2_0 76 | scene0580_00 004600_946bd5aeda453b38b3a454ed6a7199e2_0 77 | scene0246_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 78 | scene0246_00 002000_4dbd37cb85686dea674ce64e4bf77aec_0 79 | scene0246_00 002600_4dbd37cb85686dea674ce64e4bf77aec_0 80 | scene0697_01 000300_4dbd37cb85686dea674ce64e4bf77aec_0 81 | scene0697_01 001400_4dbd37cb85686dea674ce64e4bf77aec_0 82 | scene0697_01 001600_4dbd37cb85686dea674ce64e4bf77aec_0 83 | scene0645_02 000100_4dbd37cb85686dea674ce64e4bf77aec_1 84 | scene0645_02 000200_4dbd37cb85686dea674ce64e4bf77aec_1 85 | scene0645_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 86 | scene0353_01 000400_8df7e58200ac5e6ab91b871e750ca615_0 87 | scene0353_01 000600_8df7e58200ac5e6ab91b871e750ca615_0 88 | scene0353_01 
000900_8df7e58200ac5e6ab91b871e750ca615_1 89 | scene0256_02 000000_e7d0920ba8d4b1be71424c004dd7ab2f_0 90 | scene0256_02 000100_e7d0920ba8d4b1be71424c004dd7ab2f_0 91 | scene0256_02 000600_e7d0920ba8d4b1be71424c004dd7ab2f_0 92 | scene0696_02 000100_946bd5aeda453b38b3a454ed6a7199e2_0 93 | scene0696_02 001300_946bd5aeda453b38b3a454ed6a7199e2_0 94 | scene0435_01 000700_e91c2df09de0d4b1ed4d676215f46734_0 95 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_0 96 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_1 97 | scene0435_01 000900_e91c2df09de0d4b1ed4d676215f46734_0 98 | scene0435_01 001000_e91c2df09de0d4b1ed4d676215f46734_0 99 | scene0435_01 001100_e91c2df09de0d4b1ed4d676215f46734_1 100 | scene0435_01 001200_e91c2df09de0d4b1ed4d676215f46734_1 101 | scene0222_01 000400_6e5f10f2574f8a285d64ca7820a9c2ca_1 102 | scene0222_01 001500_6e5f10f2574f8a285d64ca7820a9c2ca_1 103 | scene0222_01 003600_6e5f10f2574f8a285d64ca7820a9c2ca_0 104 | scene0222_01 004100_6e5f10f2574f8a285d64ca7820a9c2ca_0 105 | scene0277_01 000100_3acfa3c60a03415643abcff1f32a8b0c_0 106 | scene0277_01 000700_3acfa3c60a03415643abcff1f32a8b0c_0 107 | scene0277_01 000800_3acfa3c60a03415643abcff1f32a8b0c_0 108 | scene0435_02 000600_9fb6014c9944a98bd2096b2fa6f98cc7_0 109 | scene0435_02 000700_9fb6014c9944a98bd2096b2fa6f98cc7_0 110 | scene0435_02 000800_9fb6014c9944a98bd2096b2fa6f98cc7_1 111 | scene0435_02 000900_9fb6014c9944a98bd2096b2fa6f98cc7_0 112 | scene0435_02 001100_9fb6014c9944a98bd2096b2fa6f98cc7_1 113 | scene0697_00 001000_4dbd37cb85686dea674ce64e4bf77aec_0 114 | scene0652_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 115 | scene0652_00 001300_4dbd37cb85686dea674ce64e4bf77aec_0 116 | scene0356_02 000200_3acfa3c60a03415643abcff1f32a8b0c_0 117 | scene0356_02 000800_3acfa3c60a03415643abcff1f32a8b0c_0 118 | scene0207_02 000000_4dbd37cb85686dea674ce64e4bf77aec_0 119 | scene0207_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 120 | scene0207_02 000600_4dbd37cb85686dea674ce64e4bf77aec_0 121 | 
scene0207_00 000000_4dbd37cb85686dea674ce64e4bf77aec_0 122 | scene0207_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 123 | scene0207_00 000400_4dbd37cb85686dea674ce64e4bf77aec_0 124 | scene0648_00 000000_2d1a2be896054548997e2c877588ae24_0 125 | scene0648_00 000600_2d1a2be896054548997e2c877588ae24_0 126 | scene0648_00 001100_2d1a2be896054548997e2c877588ae24_1 127 | scene0648_00 002500_2d1a2be896054548997e2c877588ae24_1 128 | scene0648_00 002800_2d1a2be896054548997e2c877588ae24_0 129 | scene0648_00 003500_2d1a2be896054548997e2c877588ae24_1 130 | scene0648_00 003800_2d1a2be896054548997e2c877588ae24_0 131 | scene0648_00 003900_2d1a2be896054548997e2c877588ae24_0 132 | scene0277_00 000400_6e4707cac21b09f0531c83488903771b_0 133 | scene0277_00 001000_6e4707cac21b09f0531c83488903771b_0 134 | scene0435_03 000600_e91c2df09de0d4b1ed4d676215f46734_0 135 | scene0435_03 000700_e91c2df09de0d4b1ed4d676215f46734_1 136 | scene0435_03 000800_e91c2df09de0d4b1ed4d676215f46734_0 137 | scene0435_03 000900_e91c2df09de0d4b1ed4d676215f46734_0 138 | scene0435_03 001300_e91c2df09de0d4b1ed4d676215f46734_1 139 | scene0435_03 001400_e91c2df09de0d4b1ed4d676215f46734_1 140 | scene0648_01 001900_b9302be3dc846d834f0ba81bea651144_1 141 | scene0648_01 002200_b9302be3dc846d834f0ba81bea651144_1 142 | scene0648_01 002300_b9302be3dc846d834f0ba81bea651144_0 143 | scene0648_01 002900_b9302be3dc846d834f0ba81bea651144_1 144 | scene0648_01 003000_b9302be3dc846d834f0ba81bea651144_1 145 | scene0695_00 000800_3acfa3c60a03415643abcff1f32a8b0c_0 146 | scene0695_00 000900_3acfa3c60a03415643abcff1f32a8b0c_0 147 | scene0695_00 001800_3acfa3c60a03415643abcff1f32a8b0c_0 148 | scene0633_01 000000_76db17c76f828282dcb2f14e2e42ec8d_0 149 | scene0633_01 000100_76db17c76f828282dcb2f14e2e42ec8d_0 150 | scene0633_01 001100_76db17c76f828282dcb2f14e2e42ec8d_0 151 | scene0356_00 000200_1f11b3d9953fabcf8b4396b18c85cf0f_0 152 | scene0356_00 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 153 | scene0697_03 
001500_4dbd37cb85686dea674ce64e4bf77aec_0 154 | scene0697_03 001600_4dbd37cb85686dea674ce64e4bf77aec_0 155 | scene0697_03 001700_4dbd37cb85686dea674ce64e4bf77aec_0 156 | -------------------------------------------------------------------------------- /splits/pose/02871439/val_02871439.txt: -------------------------------------------------------------------------------- 1 | scene0593_00 000000_601359d274c2c00a1497d160eced5e7a_0 2 | scene0593_00 000700_601359d274c2c00a1497d160eced5e7a_0 3 | scene0593_00 000800_601359d274c2c00a1497d160eced5e7a_0 4 | scene0593_00 001300_601359d274c2c00a1497d160eced5e7a_0 5 | scene0593_00 001800_601359d274c2c00a1497d160eced5e7a_0 6 | scene0593_00 001900_601359d274c2c00a1497d160eced5e7a_0 7 | scene0700_01 000300_722624d2cd4b72018ac5263758737a81_0 8 | scene0700_01 003000_722624d2cd4b72018ac5263758737a81_0 9 | scene0700_01 003200_722624d2cd4b72018ac5263758737a81_0 10 | scene0700_01 003300_722624d2cd4b72018ac5263758737a81_0 11 | scene0558_01 000100_4e26a2e39e3b6d39961b70a6f96df2a4_3 12 | scene0558_01 000400_4e26a2e39e3b6d39961b70a6f96df2a4_2 13 | scene0378_02 001100_b1696f5b9b3926b1a523e28192de797e_1 14 | scene0378_02 001200_b1696f5b9b3926b1a523e28192de797e_1 15 | scene0378_02 001300_b1696f5b9b3926b1a523e28192de797e_1 16 | scene0378_02 001400_ae26bec5a79f51943da27ece6ae88fff_0 17 | scene0378_02 001500_ae26bec5a79f51943da27ece6ae88fff_0 18 | scene0378_02 001600_ae26bec5a79f51943da27ece6ae88fff_0 19 | scene0378_02 001700_ae26bec5a79f51943da27ece6ae88fff_0 20 | scene0353_02 000000_1d3014ad5c35944f9af68b4b3261e1f8_1 21 | scene0203_00 000400_c214e1d190cb87362a9cd5247487b619_3 22 | scene0704_00 000600_9ca929d28c1838d5e41994ffd448fd07_0 23 | scene0704_00 002000_9ca929d28c1838d5e41994ffd448fd07_0 24 | scene0704_00 002200_9ca929d28c1838d5e41994ffd448fd07_0 25 | scene0704_00 002300_9ca929d28c1838d5e41994ffd448fd07_0 26 | scene0329_01 000300_4996e0c9dac4ba3649fade4fe2abc936_0 27 | scene0329_01 000400_4996e0c9dac4ba3649fade4fe2abc936_0 28 | 
scene0591_01 000000_b6264b92d0cbd598d5919aa833abb5a_1 29 | scene0591_01 000100_ad808d321f7cd914c6d17bc3a482f77_2 30 | scene0591_01 000200_ad808d321f7cd914c6d17bc3a482f77_0 31 | scene0591_01 000300_ad808d321f7cd914c6d17bc3a482f77_2 32 | scene0591_01 000900_ad808d321f7cd914c6d17bc3a482f77_2 33 | scene0591_01 001700_ad808d321f7cd914c6d17bc3a482f77_0 34 | scene0474_00 000000_c007d1b4972f70102751be1fc72418bb_0 35 | scene0474_00 000100_c007d1b4972f70102751be1fc72418bb_0 36 | scene0474_00 000900_c007d1b4972f70102751be1fc72418bb_0 37 | scene0030_02 001700_89f1540881171ee3f4c4977ed0ba5296_4 38 | scene0598_02 000000_601b64e8e27159e157da56e4c9ff868d_3 39 | scene0598_02 000100_601b64e8e27159e157da56e4c9ff868d_3 40 | scene0598_02 000400_601b64e8e27159e157da56e4c9ff868d_4 41 | scene0598_02 000600_601b64e8e27159e157da56e4c9ff868d_3 42 | scene0598_02 000700_601b64e8e27159e157da56e4c9ff868d_3 43 | scene0598_02 001100_601b64e8e27159e157da56e4c9ff868d_0 44 | scene0598_02 001300_601b64e8e27159e157da56e4c9ff868d_4 45 | scene0231_02 000300_3edaff50f5355da2666ec9bdd91ab5f1_1 46 | scene0231_02 000500_c214e1d190cb87362a9cd5247487b619_0 47 | scene0231_02 000800_c214e1d190cb87362a9cd5247487b619_0 48 | scene0231_02 002400_c214e1d190cb87362a9cd5247487b619_0 49 | scene0231_02 002800_c214e1d190cb87362a9cd5247487b619_0 50 | scene0658_00 000000_c214e1d190cb87362a9cd5247487b619_1 51 | scene0658_00 000300_c214e1d190cb87362a9cd5247487b619_0 52 | scene0643_00 001600_3fa60816f15b58c4607974568e26586f_1 53 | scene0568_02 000400_4996e0c9dac4ba3649fade4fe2abc936_0 54 | scene0568_02 000500_4996e0c9dac4ba3649fade4fe2abc936_0 55 | scene0474_01 001600_c007d1b4972f70102751be1fc72418bb_0 56 | scene0527_00 000100_35afa1b806556803e99c79bad29da781_0 57 | scene0203_01 000500_d70d80e5855785a761b9fd1751b9fcb_4 58 | scene0203_01 001100_e4c42bcbba4ef5b09a2d4f7bacd6e0d8_0 59 | scene0203_01 001300_e3ae56f176b77359aa1bd50387389420_1 60 | scene0695_03 001200_465bb5a9ed42ad40a5817f81a1efa3cc_0 61 | scene0695_03 
002100_465bb5a9ed42ad40a5817f81a1efa3cc_0 62 | scene0598_00 000200_81b8784259e3331f94cdfc338037bd95_0 63 | scene0328_00 000900_b079feff448e925546c4f23965b7dd40_0 64 | scene0353_00 000000_ad808d321f7cd914c6d17bc3a482f77_0 65 | scene0353_00 001500_ad808d321f7cd914c6d17bc3a482f77_0 66 | scene0695_02 000300_1ab8202a944a6ff1de650492e45fb14f_0 67 | scene0695_02 000900_1ab8202a944a6ff1de650492e45fb14f_0 68 | scene0064_00 000500_b105d36dbf010903b022c94235bc8601_2 69 | scene0064_00 000600_b105d36dbf010903b022c94235bc8601_2 70 | scene0064_00 001000_b105d36dbf010903b022c94235bc8601_0 71 | scene0064_01 000200_46579c6050cac50a1c8c7b57a94dbb2e_2 72 | scene0064_01 000500_dfdc22f3fecbb57c5050d6a2f5d42f74_0 73 | scene0700_02 002000_ec882f5717b0f405b2bf4f773fe0e622_0 74 | scene0700_02 002100_ec882f5717b0f405b2bf4f773fe0e622_0 75 | scene0700_02 002200_ec882f5717b0f405b2bf4f773fe0e622_0 76 | scene0231_00 001400_4552f6002e96cb00205444155ae5c84d_0 77 | scene0231_00 002100_c214e1d190cb87362a9cd5247487b619_1 78 | scene0231_00 004100_4552f6002e96cb00205444155ae5c84d_0 79 | scene0580_00 001700_b079feff448e925546c4f23965b7dd40_0 80 | scene0580_00 001800_b079feff448e925546c4f23965b7dd40_0 81 | scene0697_01 002000_c214e1d190cb87362a9cd5247487b619_0 82 | scene0700_00 001100_c007d1b4972f70102751be1fc72418bb_0 83 | scene0700_00 001200_c007d1b4972f70102751be1fc72418bb_0 84 | scene0700_00 001300_c007d1b4972f70102751be1fc72418bb_0 85 | scene0593_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 86 | scene0307_02 000200_db5b6dc38fd82fa7cdfa73f789b383fe_1 87 | scene0307_02 000600_db5b6dc38fd82fa7cdfa73f789b383fe_1 88 | scene0307_02 001300_ec882f5717b0f405b2bf4f773fe0e622_2 89 | scene0307_02 001400_ec882f5717b0f405b2bf4f773fe0e622_3 90 | scene0591_02 000100_c214e1d190cb87362a9cd5247487b619_1 91 | scene0591_02 000200_1ab8202a944a6ff1de650492e45fb14f_3 92 | scene0591_02 000700_c214e1d190cb87362a9cd5247487b619_1 93 | scene0591_02 000800_c214e1d190cb87362a9cd5247487b619_0 94 | scene0591_02 
000900_1ab8202a944a6ff1de650492e45fb14f_3 95 | scene0663_02 001000_348d539dade47e8e664b3b9b23ddfcbc_2 96 | scene0663_02 001100_348d539dade47e8e664b3b9b23ddfcbc_2 97 | scene0203_02 000900_c40828710c91001546c4f23965b7dd40_6 98 | scene0203_02 001300_db5b6dc38fd82fa7cdfa73f789b383fe_1 99 | scene0025_01 000400_db5b6dc38fd82fa7cdfa73f789b383fe_0 100 | scene0025_01 000500_db5b6dc38fd82fa7cdfa73f789b383fe_0 101 | scene0025_01 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 102 | scene0025_01 001300_db5b6dc38fd82fa7cdfa73f789b383fe_0 103 | scene0353_01 001500_c214e1d190cb87362a9cd5247487b619_1 104 | scene0378_01 001000_c007d1b4972f70102751be1fc72418bb_0 105 | scene0378_01 001100_c007d1b4972f70102751be1fc72418bb_0 106 | scene0378_01 001200_c007d1b4972f70102751be1fc72418bb_0 107 | scene0378_01 001400_3d422581d2f439c74cc1952ae0d6e81a_1 108 | scene0378_01 001600_3d422581d2f439c74cc1952ae0d6e81a_1 109 | scene0591_00 000600_fb1a15143f6df0b35d3dc7a81f13d5e8_0 110 | scene0591_00 001100_1ab8202a944a6ff1de650492e45fb14f_1 111 | scene0591_00 001200_fb1a15143f6df0b35d3dc7a81f13d5e8_0 112 | scene0591_00 001300_88d2609fe01abd4a6b8fdc9bc8edc250_2 113 | scene0591_00 001400_fb1a15143f6df0b35d3dc7a81f13d5e8_0 114 | scene0696_02 001000_225da2a2b7a461e349edb0f98d2a2a29_0 115 | scene0222_01 002400_465bb5a9ed42ad40a5817f81a1efa3cc_0 116 | scene0222_01 002500_465bb5a9ed42ad40a5817f81a1efa3cc_0 117 | scene0222_01 002600_465bb5a9ed42ad40a5817f81a1efa3cc_0 118 | scene0222_01 002700_465bb5a9ed42ad40a5817f81a1efa3cc_0 119 | scene0558_02 000100_1ab8202a944a6ff1de650492e45fb14f_9 120 | scene0558_02 001700_1ab8202a944a6ff1de650492e45fb14f_8 121 | scene0644_00 000700_c40828710c91001546c4f23965b7dd40_0 122 | scene0644_00 001500_c40828710c91001546c4f23965b7dd40_0 123 | scene0357_01 000200_c214e1d190cb87362a9cd5247487b619_0 124 | scene0307_01 000100_ec882f5717b0f405b2bf4f773fe0e622_0 125 | scene0307_01 000200_ec882f5717b0f405b2bf4f773fe0e622_0 126 | scene0307_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 127 | 
scene0307_01 001800_ec882f5717b0f405b2bf4f773fe0e622_0 128 | scene0307_01 002200_ec882f5717b0f405b2bf4f773fe0e622_1 129 | scene0307_01 003000_ec882f5717b0f405b2bf4f773fe0e622_0 130 | scene0307_01 003100_ec882f5717b0f405b2bf4f773fe0e622_0 131 | scene0222_00 000100_e3ae56f176b77359aa1bd50387389420_1 132 | scene0222_00 000900_e3ae56f176b77359aa1bd50387389420_1 133 | scene0222_00 001000_e3ae56f176b77359aa1bd50387389420_1 134 | scene0222_00 001100_e3ae56f176b77359aa1bd50387389420_1 135 | scene0222_00 003200_e3ae56f176b77359aa1bd50387389420_0 136 | scene0222_00 004900_e3ae56f176b77359aa1bd50387389420_0 137 | scene0695_01 001100_465bb5a9ed42ad40a5817f81a1efa3cc_0 138 | scene0695_01 002000_465bb5a9ed42ad40a5817f81a1efa3cc_0 139 | scene0474_02 001800_8007cf8349381dada5817f81a1efa3cc_0 140 | scene0474_02 002000_8007cf8349381dada5817f81a1efa3cc_0 141 | scene0131_02 000200_b6264b92d0cbd598d5919aa833abb5a_0 142 | scene0277_02 000300_465bb5a9ed42ad40a5817f81a1efa3cc_0 143 | scene0356_02 000700_b6264b92d0cbd598d5919aa833abb5a_0 144 | scene0207_01 001700_ec882f5717b0f405b2bf4f773fe0e622_0 145 | scene0207_02 000200_ec882f5717b0f405b2bf4f773fe0e622_0 146 | scene0558_00 000000_b105d36dbf010903b022c94235bc8601_1 147 | scene0558_00 000300_b105d36dbf010903b022c94235bc8601_0 148 | scene0558_00 000700_b105d36dbf010903b022c94235bc8601_1 149 | scene0357_00 000200_d841664bb8c7b842cbf15a79411ef31b_0 150 | scene0648_00 000400_d841664bb8c7b842cbf15a79411ef31b_0 151 | scene0648_00 002000_e3ae56f176b77359aa1bd50387389420_1 152 | scene0648_00 002900_d841664bb8c7b842cbf15a79411ef31b_0 153 | scene0648_00 003100_d841664bb8c7b842cbf15a79411ef31b_0 154 | scene0648_00 003300_e3ae56f176b77359aa1bd50387389420_2 155 | scene0231_01 002000_5793619650b38e335acd449a2ae99009_1 156 | scene0704_01 001500_bf8c485ab29e0c3549a1e995ce589127_0 157 | scene0704_01 001600_bf8c485ab29e0c3549a1e995ce589127_0 158 | scene0535_00 000000_29b66fc9db2f1558e0e89fd83955713c_1 159 | scene0535_00 
000100_29b66fc9db2f1558e0e89fd83955713c_1 160 | scene0568_01 000300_9ca929d28c1838d5e41994ffd448fd07_0 161 | scene0568_01 000400_9ca929d28c1838d5e41994ffd448fd07_0 162 | scene0568_01 000500_9ca929d28c1838d5e41994ffd448fd07_0 163 | scene0568_01 001200_34128cd8b8ddbbedd92903da5c4b5ef6_1 164 | scene0663_00 001000_131b5b691ff5bff3945770bff82992ca_0 165 | scene0663_00 001300_364a2a1172dc0a827c6de7e52b00ebab_3 166 | scene0663_00 001500_364a2a1172dc0a827c6de7e52b00ebab_3 167 | scene0663_00 001900_131b5b691ff5bff3945770bff82992ca_0 168 | scene0663_00 002000_131b5b691ff5bff3945770bff82992ca_0 169 | scene0663_00 002100_131b5b691ff5bff3945770bff82992ca_0 170 | scene0663_00 002200_131b5b691ff5bff3945770bff82992ca_0 171 | scene0208_00 000200_908c0b3c235a81c08998b3b64a143d42_8 172 | scene0208_00 000700_908c0b3c235a81c08998b3b64a143d42_8 173 | scene0208_00 001100_908c0b3c235a81c08998b3b64a143d42_6 174 | scene0208_00 001600_908c0b3c235a81c08998b3b64a143d42_7 175 | scene0208_00 001700_908c0b3c235a81c08998b3b64a143d42_7 176 | scene0208_00 002000_908c0b3c235a81c08998b3b64a143d42_9 177 | scene0648_01 000200_b6264b92d0cbd598d5919aa833abb5a_2 178 | scene0648_01 000300_b6264b92d0cbd598d5919aa833abb5a_2 179 | scene0648_01 000400_b6264b92d0cbd598d5919aa833abb5a_2 180 | scene0648_01 003400_61ea75b808512124ae5607418ff674f1_1 181 | scene0378_00 001000_b1696f5b9b3926b1a523e28192de797e_1 182 | scene0378_00 001100_b1696f5b9b3926b1a523e28192de797e_1 183 | scene0378_00 001400_c007d1b4972f70102751be1fc72418bb_0 184 | scene0378_00 001500_c007d1b4972f70102751be1fc72418bb_0 185 | scene0378_00 001600_c007d1b4972f70102751be1fc72418bb_0 186 | scene0378_00 001800_c007d1b4972f70102751be1fc72418bb_0 187 | scene0356_01 000100_dc9af91c3831983597e6cb8e9bd5d9e5_0 188 | scene0025_00 000600_db5b6dc38fd82fa7cdfa73f789b383fe_0 189 | scene0025_00 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 190 | scene0695_00 001300_b6264b92d0cbd598d5919aa833abb5a_0 191 | scene0077_01 000200_1b3090bfb11cf834f06824a9291300d4_0 192 | 
scene0474_05 000100_8007cf8349381dada5817f81a1efa3cc_0 193 | scene0474_05 000200_8007cf8349381dada5817f81a1efa3cc_0 194 | scene0474_05 001900_8007cf8349381dada5817f81a1efa3cc_0 195 | scene0633_01 000600_34128cd8b8ddbbedd92903da5c4b5ef6_0 196 | scene0633_01 000700_34128cd8b8ddbbedd92903da5c4b5ef6_0 197 | scene0633_01 000700_fbc6b4e0aa9cc13d713e7f5d7ea85661_1 198 | scene0633_01 000800_34128cd8b8ddbbedd92903da5c4b5ef6_0 199 | scene0077_00 000100_8f0baf74e4d91b0d7b39f9a454e866f6_0 200 | scene0356_00 000700_b6264b92d0cbd598d5919aa833abb5a_0 201 | scene0568_00 000300_4996e0c9dac4ba3649fade4fe2abc936_0 202 | scene0568_00 000400_4996e0c9dac4ba3649fade4fe2abc936_0 203 | scene0568_00 000500_4996e0c9dac4ba3649fade4fe2abc936_0 204 | -------------------------------------------------------------------------------- /splits/pose/02871439/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0593_00 000000_601359d274c2c00a1497d160eced5e7a_0 2 | scene0593_00 000700_601359d274c2c00a1497d160eced5e7a_0 3 | scene0593_00 000800_601359d274c2c00a1497d160eced5e7a_0 4 | scene0593_00 001300_601359d274c2c00a1497d160eced5e7a_0 5 | scene0593_00 001800_601359d274c2c00a1497d160eced5e7a_0 6 | scene0593_00 001900_601359d274c2c00a1497d160eced5e7a_0 7 | scene0700_01 003200_722624d2cd4b72018ac5263758737a81_0 8 | scene0700_01 003300_722624d2cd4b72018ac5263758737a81_0 9 | scene0558_01 000100_4e26a2e39e3b6d39961b70a6f96df2a4_3 10 | scene0558_01 000400_4e26a2e39e3b6d39961b70a6f96df2a4_2 11 | scene0378_02 001100_b1696f5b9b3926b1a523e28192de797e_1 12 | scene0378_02 001200_b1696f5b9b3926b1a523e28192de797e_1 13 | scene0378_02 001300_b1696f5b9b3926b1a523e28192de797e_1 14 | scene0378_02 001400_ae26bec5a79f51943da27ece6ae88fff_0 15 | scene0378_02 001500_ae26bec5a79f51943da27ece6ae88fff_0 16 | scene0378_02 001600_ae26bec5a79f51943da27ece6ae88fff_0 17 | scene0378_02 001700_ae26bec5a79f51943da27ece6ae88fff_0 18 | scene0353_02 
000000_1d3014ad5c35944f9af68b4b3261e1f8_1 19 | scene0203_00 000400_c214e1d190cb87362a9cd5247487b619_3 20 | scene0704_00 000600_9ca929d28c1838d5e41994ffd448fd07_0 21 | scene0329_01 000300_4996e0c9dac4ba3649fade4fe2abc936_0 22 | scene0591_01 000100_ad808d321f7cd914c6d17bc3a482f77_2 23 | scene0591_01 000200_ad808d321f7cd914c6d17bc3a482f77_0 24 | scene0591_01 000300_ad808d321f7cd914c6d17bc3a482f77_2 25 | scene0591_01 000900_ad808d321f7cd914c6d17bc3a482f77_2 26 | scene0591_01 001700_ad808d321f7cd914c6d17bc3a482f77_0 27 | scene0474_00 000000_c007d1b4972f70102751be1fc72418bb_0 28 | scene0474_00 000100_c007d1b4972f70102751be1fc72418bb_0 29 | scene0474_00 000900_c007d1b4972f70102751be1fc72418bb_0 30 | scene0030_02 001700_89f1540881171ee3f4c4977ed0ba5296_4 31 | scene0598_02 000000_601b64e8e27159e157da56e4c9ff868d_3 32 | scene0598_02 000100_601b64e8e27159e157da56e4c9ff868d_3 33 | scene0598_02 000400_601b64e8e27159e157da56e4c9ff868d_4 34 | scene0598_02 000600_601b64e8e27159e157da56e4c9ff868d_3 35 | scene0598_02 000700_601b64e8e27159e157da56e4c9ff868d_3 36 | scene0598_02 001100_601b64e8e27159e157da56e4c9ff868d_0 37 | scene0598_02 001300_601b64e8e27159e157da56e4c9ff868d_4 38 | scene0231_02 000500_c214e1d190cb87362a9cd5247487b619_0 39 | scene0231_02 002800_c214e1d190cb87362a9cd5247487b619_0 40 | scene0658_00 000000_c214e1d190cb87362a9cd5247487b619_1 41 | scene0658_00 000300_c214e1d190cb87362a9cd5247487b619_0 42 | scene0643_00 001600_3fa60816f15b58c4607974568e26586f_1 43 | scene0568_02 000400_4996e0c9dac4ba3649fade4fe2abc936_0 44 | scene0568_02 000500_4996e0c9dac4ba3649fade4fe2abc936_0 45 | scene0474_01 001600_c007d1b4972f70102751be1fc72418bb_0 46 | scene0527_00 000100_35afa1b806556803e99c79bad29da781_0 47 | scene0203_01 000500_d70d80e5855785a761b9fd1751b9fcb_4 48 | scene0203_01 001100_e4c42bcbba4ef5b09a2d4f7bacd6e0d8_0 49 | scene0203_01 001300_e3ae56f176b77359aa1bd50387389420_1 50 | scene0695_03 002100_465bb5a9ed42ad40a5817f81a1efa3cc_0 51 | scene0598_00 
000200_81b8784259e3331f94cdfc338037bd95_0 52 | scene0353_00 000000_ad808d321f7cd914c6d17bc3a482f77_0 53 | scene0353_00 001500_ad808d321f7cd914c6d17bc3a482f77_0 54 | scene0695_02 000300_1ab8202a944a6ff1de650492e45fb14f_0 55 | scene0695_02 000900_1ab8202a944a6ff1de650492e45fb14f_0 56 | scene0064_00 000500_b105d36dbf010903b022c94235bc8601_2 57 | scene0064_00 000600_b105d36dbf010903b022c94235bc8601_2 58 | scene0064_00 001000_b105d36dbf010903b022c94235bc8601_0 59 | scene0064_01 000200_46579c6050cac50a1c8c7b57a94dbb2e_2 60 | scene0064_01 000500_dfdc22f3fecbb57c5050d6a2f5d42f74_0 61 | scene0700_02 002000_ec882f5717b0f405b2bf4f773fe0e622_0 62 | scene0700_02 002100_ec882f5717b0f405b2bf4f773fe0e622_0 63 | scene0700_02 002200_ec882f5717b0f405b2bf4f773fe0e622_0 64 | scene0231_00 001400_4552f6002e96cb00205444155ae5c84d_0 65 | scene0231_00 002100_c214e1d190cb87362a9cd5247487b619_1 66 | scene0231_00 004100_4552f6002e96cb00205444155ae5c84d_0 67 | scene0580_00 001700_b079feff448e925546c4f23965b7dd40_0 68 | scene0580_00 001800_b079feff448e925546c4f23965b7dd40_0 69 | scene0697_01 002000_c214e1d190cb87362a9cd5247487b619_0 70 | scene0700_00 001100_c007d1b4972f70102751be1fc72418bb_0 71 | scene0700_00 001200_c007d1b4972f70102751be1fc72418bb_0 72 | scene0700_00 001300_c007d1b4972f70102751be1fc72418bb_0 73 | scene0593_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 74 | scene0307_02 000200_db5b6dc38fd82fa7cdfa73f789b383fe_1 75 | scene0307_02 000600_db5b6dc38fd82fa7cdfa73f789b383fe_1 76 | scene0307_02 001300_ec882f5717b0f405b2bf4f773fe0e622_2 77 | scene0307_02 001400_ec882f5717b0f405b2bf4f773fe0e622_3 78 | scene0591_02 000100_c214e1d190cb87362a9cd5247487b619_1 79 | scene0591_02 000200_1ab8202a944a6ff1de650492e45fb14f_3 80 | scene0591_02 000700_c214e1d190cb87362a9cd5247487b619_1 81 | scene0591_02 000800_c214e1d190cb87362a9cd5247487b619_0 82 | scene0591_02 000900_1ab8202a944a6ff1de650492e45fb14f_3 83 | scene0663_02 001000_348d539dade47e8e664b3b9b23ddfcbc_2 84 | scene0663_02 
001100_348d539dade47e8e664b3b9b23ddfcbc_2 85 | scene0203_02 001300_db5b6dc38fd82fa7cdfa73f789b383fe_1 86 | scene0025_01 000400_db5b6dc38fd82fa7cdfa73f789b383fe_0 87 | scene0025_01 000500_db5b6dc38fd82fa7cdfa73f789b383fe_0 88 | scene0025_01 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 89 | scene0025_01 001300_db5b6dc38fd82fa7cdfa73f789b383fe_0 90 | scene0353_01 001500_c214e1d190cb87362a9cd5247487b619_1 91 | scene0378_01 001000_c007d1b4972f70102751be1fc72418bb_0 92 | scene0378_01 001100_c007d1b4972f70102751be1fc72418bb_0 93 | scene0378_01 001200_c007d1b4972f70102751be1fc72418bb_0 94 | scene0378_01 001400_3d422581d2f439c74cc1952ae0d6e81a_1 95 | scene0378_01 001600_3d422581d2f439c74cc1952ae0d6e81a_1 96 | scene0591_00 001100_1ab8202a944a6ff1de650492e45fb14f_1 97 | scene0591_00 001200_fb1a15143f6df0b35d3dc7a81f13d5e8_0 98 | scene0591_00 001300_88d2609fe01abd4a6b8fdc9bc8edc250_2 99 | scene0591_00 001400_fb1a15143f6df0b35d3dc7a81f13d5e8_0 100 | scene0696_02 001000_225da2a2b7a461e349edb0f98d2a2a29_0 101 | scene0222_01 002500_465bb5a9ed42ad40a5817f81a1efa3cc_0 102 | scene0222_01 002600_465bb5a9ed42ad40a5817f81a1efa3cc_0 103 | scene0222_01 002700_465bb5a9ed42ad40a5817f81a1efa3cc_0 104 | scene0558_02 001700_1ab8202a944a6ff1de650492e45fb14f_8 105 | scene0644_00 000700_c40828710c91001546c4f23965b7dd40_0 106 | scene0644_00 001500_c40828710c91001546c4f23965b7dd40_0 107 | scene0357_01 000200_c214e1d190cb87362a9cd5247487b619_0 108 | scene0307_01 000100_ec882f5717b0f405b2bf4f773fe0e622_0 109 | scene0307_01 000200_ec882f5717b0f405b2bf4f773fe0e622_0 110 | scene0307_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 111 | scene0307_01 001800_ec882f5717b0f405b2bf4f773fe0e622_0 112 | scene0307_01 002200_ec882f5717b0f405b2bf4f773fe0e622_1 113 | scene0307_01 003000_ec882f5717b0f405b2bf4f773fe0e622_0 114 | scene0307_01 003100_ec882f5717b0f405b2bf4f773fe0e622_0 115 | scene0222_00 000100_e3ae56f176b77359aa1bd50387389420_1 116 | scene0222_00 000900_e3ae56f176b77359aa1bd50387389420_1 117 | scene0222_00 
001000_e3ae56f176b77359aa1bd50387389420_1 118 | scene0222_00 001100_e3ae56f176b77359aa1bd50387389420_1 119 | scene0222_00 003200_e3ae56f176b77359aa1bd50387389420_0 120 | scene0222_00 004900_e3ae56f176b77359aa1bd50387389420_0 121 | scene0695_01 001100_465bb5a9ed42ad40a5817f81a1efa3cc_0 122 | scene0695_01 002000_465bb5a9ed42ad40a5817f81a1efa3cc_0 123 | scene0474_02 001800_8007cf8349381dada5817f81a1efa3cc_0 124 | scene0474_02 002000_8007cf8349381dada5817f81a1efa3cc_0 125 | scene0131_02 000200_b6264b92d0cbd598d5919aa833abb5a_0 126 | scene0277_02 000300_465bb5a9ed42ad40a5817f81a1efa3cc_0 127 | scene0356_02 000700_b6264b92d0cbd598d5919aa833abb5a_0 128 | scene0207_01 001700_ec882f5717b0f405b2bf4f773fe0e622_0 129 | scene0207_02 000200_ec882f5717b0f405b2bf4f773fe0e622_0 130 | scene0558_00 000000_b105d36dbf010903b022c94235bc8601_1 131 | scene0558_00 000300_b105d36dbf010903b022c94235bc8601_0 132 | scene0558_00 000700_b105d36dbf010903b022c94235bc8601_1 133 | scene0357_00 000200_d841664bb8c7b842cbf15a79411ef31b_0 134 | scene0648_00 000400_d841664bb8c7b842cbf15a79411ef31b_0 135 | scene0648_00 002000_e3ae56f176b77359aa1bd50387389420_1 136 | scene0648_00 002900_d841664bb8c7b842cbf15a79411ef31b_0 137 | scene0648_00 003100_d841664bb8c7b842cbf15a79411ef31b_0 138 | scene0648_00 003300_e3ae56f176b77359aa1bd50387389420_2 139 | scene0231_01 002000_5793619650b38e335acd449a2ae99009_1 140 | scene0535_00 000000_29b66fc9db2f1558e0e89fd83955713c_1 141 | scene0535_00 000100_29b66fc9db2f1558e0e89fd83955713c_1 142 | scene0568_01 000300_9ca929d28c1838d5e41994ffd448fd07_0 143 | scene0568_01 000400_9ca929d28c1838d5e41994ffd448fd07_0 144 | scene0568_01 000500_9ca929d28c1838d5e41994ffd448fd07_0 145 | scene0568_01 001200_34128cd8b8ddbbedd92903da5c4b5ef6_1 146 | scene0663_00 001000_131b5b691ff5bff3945770bff82992ca_0 147 | scene0663_00 001500_364a2a1172dc0a827c6de7e52b00ebab_3 148 | scene0663_00 001900_131b5b691ff5bff3945770bff82992ca_0 149 | scene0663_00 002000_131b5b691ff5bff3945770bff82992ca_0 150 | 
scene0663_00 002100_131b5b691ff5bff3945770bff82992ca_0 151 | scene0663_00 002200_131b5b691ff5bff3945770bff82992ca_0 152 | scene0208_00 000200_908c0b3c235a81c08998b3b64a143d42_8 153 | scene0208_00 000700_908c0b3c235a81c08998b3b64a143d42_8 154 | scene0208_00 001100_908c0b3c235a81c08998b3b64a143d42_6 155 | scene0208_00 001600_908c0b3c235a81c08998b3b64a143d42_7 156 | scene0208_00 001700_908c0b3c235a81c08998b3b64a143d42_7 157 | scene0208_00 002000_908c0b3c235a81c08998b3b64a143d42_9 158 | scene0648_01 000200_b6264b92d0cbd598d5919aa833abb5a_2 159 | scene0648_01 000300_b6264b92d0cbd598d5919aa833abb5a_2 160 | scene0648_01 000400_b6264b92d0cbd598d5919aa833abb5a_2 161 | scene0648_01 003400_61ea75b808512124ae5607418ff674f1_1 162 | scene0378_00 001000_b1696f5b9b3926b1a523e28192de797e_1 163 | scene0378_00 001100_b1696f5b9b3926b1a523e28192de797e_1 164 | scene0378_00 001400_c007d1b4972f70102751be1fc72418bb_0 165 | scene0378_00 001500_c007d1b4972f70102751be1fc72418bb_0 166 | scene0378_00 001600_c007d1b4972f70102751be1fc72418bb_0 167 | scene0378_00 001800_c007d1b4972f70102751be1fc72418bb_0 168 | scene0356_01 000100_dc9af91c3831983597e6cb8e9bd5d9e5_0 169 | scene0025_00 000600_db5b6dc38fd82fa7cdfa73f789b383fe_0 170 | scene0025_00 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 171 | scene0695_00 001300_b6264b92d0cbd598d5919aa833abb5a_0 172 | scene0077_01 000200_1b3090bfb11cf834f06824a9291300d4_0 173 | scene0474_05 000100_8007cf8349381dada5817f81a1efa3cc_0 174 | scene0474_05 000200_8007cf8349381dada5817f81a1efa3cc_0 175 | scene0474_05 001900_8007cf8349381dada5817f81a1efa3cc_0 176 | scene0633_01 000600_34128cd8b8ddbbedd92903da5c4b5ef6_0 177 | scene0633_01 000700_34128cd8b8ddbbedd92903da5c4b5ef6_0 178 | scene0633_01 000700_fbc6b4e0aa9cc13d713e7f5d7ea85661_1 179 | scene0633_01 000800_34128cd8b8ddbbedd92903da5c4b5ef6_0 180 | scene0077_00 000100_8f0baf74e4d91b0d7b39f9a454e866f6_0 181 | scene0356_00 000700_b6264b92d0cbd598d5919aa833abb5a_0 182 | scene0568_00 
000300_4996e0c9dac4ba3649fade4fe2abc936_0 183 | scene0568_00 000400_4996e0c9dac4ba3649fade4fe2abc936_0 184 | scene0568_00 000500_4996e0c9dac4ba3649fade4fe2abc936_0 185 | -------------------------------------------------------------------------------- /splits/pose/04256520/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0474_03 001200_58447a958c4af154942bb07caacf4df3_0 2 | scene0474_03 001300_58447a958c4af154942bb07caacf4df3_0 3 | scene0593_00 000900_93b421c66ff4529f37b2bb75885cfc44_0 4 | scene0593_00 001600_93b421c66ff4529f37b2bb75885cfc44_0 5 | scene0353_02 000200_61f828a545649e98f1d7342136779c0_0 6 | scene0353_02 001400_61f828a545649e98f1d7342136779c0_0 7 | scene0353_02 002100_61f828a545649e98f1d7342136779c0_0 8 | scene0329_00 000000_23833969c0011b8e98494085d68ad6a0_1 9 | scene0329_00 000100_556166f38429cdfe29bdd38dd4a1a461_0 10 | scene0329_00 000200_23833969c0011b8e98494085d68ad6a0_1 11 | scene0329_00 000600_23833969c0011b8e98494085d68ad6a0_1 12 | scene0329_00 001100_556166f38429cdfe29bdd38dd4a1a461_0 13 | scene0030_00 000000_7ab86358957e386d76de5cade2fd5247_0 14 | scene0030_00 000100_7ab86358957e386d76de5cade2fd5247_0 15 | scene0030_00 001600_7ab86358957e386d76de5cade2fd5247_0 16 | scene0518_00 000200_955d633562dff06f843e991acd39f432_0 17 | scene0518_00 000500_955d633562dff06f843e991acd39f432_0 18 | scene0518_00 001000_955d633562dff06f843e991acd39f432_0 19 | scene0203_00 000200_330d44833e1b4b168b38796afe7ee552_0 20 | scene0203_00 000600_330d44833e1b4b168b38796afe7ee552_0 21 | scene0574_01 001000_945a038c3e0c46ec19fb4103277a6b93_0 22 | scene0574_01 001100_945a038c3e0c46ec19fb4103277a6b93_0 23 | scene0574_01 001200_945a038c3e0c46ec19fb4103277a6b93_0 24 | scene0574_01 001500_945a038c3e0c46ec19fb4103277a6b93_0 25 | scene0329_01 000000_7c9e1876b1643e93f9377e1922a21892_1 26 | scene0329_01 000100_7c9e1876b1643e93f9377e1922a21892_1 27 | scene0329_01 000200_7c9e1876b1643e93f9377e1922a21892_1 28 
| scene0329_01 000800_7c9e1876b1643e93f9377e1922a21892_0 29 | scene0329_01 000900_7c9e1876b1643e93f9377e1922a21892_1 30 | scene0701_01 000100_41b02faaceadb39560fcec8f64d76ffb_3 31 | scene0701_01 000200_4653af854bf098f2d74aae0eb2ddb027_1 32 | scene0701_01 000200_4653af854bf098f2d74aae0eb2ddb027_2 33 | scene0701_01 000300_681d226acbeaaf08a4ee0fb6a51564c3_0 34 | scene0701_01 000600_41b02faaceadb39560fcec8f64d76ffb_3 35 | scene0701_01 000700_4653af854bf098f2d74aae0eb2ddb027_2 36 | scene0701_01 000800_4653af854bf098f2d74aae0eb2ddb027_1 37 | scene0591_01 001000_23833969c0011b8e98494085d68ad6a0_0 38 | scene0474_00 000400_41b02faaceadb39560fcec8f64d76ffb_0 39 | scene0474_00 001200_41b02faaceadb39560fcec8f64d76ffb_0 40 | scene0474_00 001300_41b02faaceadb39560fcec8f64d76ffb_0 41 | scene0608_01 000500_fb74336a6192c4787afee304cce81d6f_0 42 | scene0608_01 000700_fb74336a6192c4787afee304cce81d6f_0 43 | scene0608_01 000800_fb74336a6192c4787afee304cce81d6f_0 44 | scene0608_01 000900_fb74336a6192c4787afee304cce81d6f_0 45 | scene0608_01 001000_fb74336a6192c4787afee304cce81d6f_0 46 | scene0609_03 000000_cd10e95d1501ed6719fb4103277a6b93_3 47 | scene0609_03 000300_cd10e95d1501ed6719fb4103277a6b93_0 48 | scene0609_03 000600_cd10e95d1501ed6719fb4103277a6b93_2 49 | scene0609_03 000600_cd10e95d1501ed6719fb4103277a6b93_3 50 | scene0645_01 001100_fb74336a6192c4787afee304cce81d6f_0 51 | scene0645_01 001200_fb74336a6192c4787afee304cce81d6f_0 52 | scene0645_01 001300_fb74336a6192c4787afee304cce81d6f_0 53 | scene0645_01 001400_fb74336a6192c4787afee304cce81d6f_0 54 | scene0645_01 001700_fb74336a6192c4787afee304cce81d6f_0 55 | scene0559_02 000000_330d44833e1b4b168b38796afe7ee552_0 56 | scene0559_02 000100_330d44833e1b4b168b38796afe7ee552_0 57 | scene0645_00 001100_8a5a40fe10eb2b2eb022c94235bc8601_0 58 | scene0645_00 001200_8a5a40fe10eb2b2eb022c94235bc8601_0 59 | scene0645_00 001300_8a5a40fe10eb2b2eb022c94235bc8601_0 60 | scene0645_00 001400_8a5a40fe10eb2b2eb022c94235bc8601_0 61 | scene0645_00 
001500_8a5a40fe10eb2b2eb022c94235bc8601_0 62 | scene0645_00 001600_8a5a40fe10eb2b2eb022c94235bc8601_0 63 | scene0231_02 000300_cbd547bfb6b7d8e54b50faf1a96496ef_1 64 | scene0231_02 000700_cbd547bfb6b7d8e54b50faf1a96496ef_1 65 | scene0231_02 002500_3f5fdc05fc572730490ad276cd2af3a4_0 66 | scene0231_02 002600_cbd547bfb6b7d8e54b50faf1a96496ef_1 67 | scene0568_02 000000_60fc7123d6360e6d620ef1b4a95dca08_0 68 | scene0568_02 000100_60fc7123d6360e6d620ef1b4a95dca08_0 69 | scene0568_02 001400_60fc7123d6360e6d620ef1b4a95dca08_0 70 | scene0423_01 000200_68a1f95fed336299f51f77a6d7299806_0 71 | scene0423_01 000300_68a1f95fed336299f51f77a6d7299806_0 72 | scene0423_01 000400_68a1f95fed336299f51f77a6d7299806_0 73 | scene0423_01 000600_68a1f95fed336299f51f77a6d7299806_0 74 | scene0690_00 000200_23833969c0011b8e98494085d68ad6a0_0 75 | scene0050_02 000400_8458d6939967ac1bbc7a6acbd8f058b_0 76 | scene0050_02 001800_8458d6939967ac1bbc7a6acbd8f058b_0 77 | scene0050_02 001900_8458d6939967ac1bbc7a6acbd8f058b_0 78 | scene0050_02 002000_8458d6939967ac1bbc7a6acbd8f058b_0 79 | scene0050_02 004300_8458d6939967ac1bbc7a6acbd8f058b_0 80 | scene0608_02 000500_608936a307740f5df7628281ecb18112_0 81 | scene0608_02 000600_608936a307740f5df7628281ecb18112_0 82 | scene0608_02 000700_608936a307740f5df7628281ecb18112_0 83 | scene0608_02 000800_608936a307740f5df7628281ecb18112_0 84 | scene0608_02 000900_608936a307740f5df7628281ecb18112_0 85 | scene0608_02 002200_608936a307740f5df7628281ecb18112_0 86 | scene0608_02 002300_608936a307740f5df7628281ecb18112_0 87 | scene0608_02 002400_608936a307740f5df7628281ecb18112_0 88 | scene0608_02 002500_608936a307740f5df7628281ecb18112_0 89 | scene0474_01 000800_fee8e1e0161f69b0db039d8689a74349_0 90 | scene0474_01 000900_fee8e1e0161f69b0db039d8689a74349_0 91 | scene0474_01 001000_fee8e1e0161f69b0db039d8689a74349_0 92 | scene0474_01 001100_fee8e1e0161f69b0db039d8689a74349_0 93 | scene0203_01 000200_117f6ac4bcd75d8b4ad65adb06bbae49_1 94 | scene0203_01 
000800_117f6ac4bcd75d8b4ad65adb06bbae49_0 95 | scene0203_01 001200_117f6ac4bcd75d8b4ad65adb06bbae49_0 96 | scene0701_00 000000_60fc7123d6360e6d620ef1b4a95dca08_2 97 | scene0701_00 000100_60fc7123d6360e6d620ef1b4a95dca08_2 98 | scene0701_00 000200_f846fb7af63a5e838eec9023c5b97e00_0 99 | scene0701_00 000400_c3c5818cbe6d0903822a33e080d0e71c_1 100 | scene0701_00 000700_60fc7123d6360e6d620ef1b4a95dca08_2 101 | scene0701_00 000800_f846fb7af63a5e838eec9023c5b97e00_0 102 | scene0701_00 000900_c3c5818cbe6d0903822a33e080d0e71c_1 103 | scene0701_00 001000_f846fb7af63a5e838eec9023c5b97e00_0 104 | scene0334_01 000000_849ddda40bd6540efac8371a83e130ac_1 105 | scene0334_01 000100_849ddda40bd6540efac8371a83e130ac_0 106 | scene0334_01 000200_849ddda40bd6540efac8371a83e130ac_0 107 | scene0334_01 000900_849ddda40bd6540efac8371a83e130ac_0 108 | scene0334_01 001000_849ddda40bd6540efac8371a83e130ac_0 109 | scene0334_01 001000_849ddda40bd6540efac8371a83e130ac_1 110 | scene0353_00 000200_436a96f58ef9a6fdb039d8689a74349_0 111 | scene0353_00 000300_436a96f58ef9a6fdb039d8689a74349_0 112 | scene0435_00 000700_1230d31e3a6cbf309cd431573238602d_0 113 | scene0064_00 000000_c2d26d8c8d5917d443ba2b548bab2839_1 114 | scene0064_00 000100_c2d26d8c8d5917d443ba2b548bab2839_1 115 | scene0064_00 000300_c2d26d8c8d5917d443ba2b548bab2839_0 116 | scene0064_00 000400_c2d26d8c8d5917d443ba2b548bab2839_0 117 | scene0064_00 000500_c2d26d8c8d5917d443ba2b548bab2839_0 118 | scene0064_00 000800_c2d26d8c8d5917d443ba2b548bab2839_1 119 | scene0064_00 000900_c2d26d8c8d5917d443ba2b548bab2839_1 120 | scene0064_00 001000_c2d26d8c8d5917d443ba2b548bab2839_1 121 | scene0064_00 001200_c2d26d8c8d5917d443ba2b548bab2839_0 122 | scene0608_00 001000_556166f38429cdfe29bdd38dd4a1a461_0 123 | scene0608_00 002500_556166f38429cdfe29bdd38dd4a1a461_0 124 | scene0608_00 002600_556166f38429cdfe29bdd38dd4a1a461_0 125 | scene0231_00 000600_9ceb81a09813d5f3d2565bc39479705a_0 126 | scene0231_00 000700_9ceb81a09813d5f3d2565bc39479705a_0 127 | 
scene0231_00 000800_9ceb81a09813d5f3d2565bc39479705a_0 128 | scene0231_00 001500_9ceb81a09813d5f3d2565bc39479705a_0 129 | scene0231_00 003300_9ceb81a09813d5f3d2565bc39479705a_0 130 | scene0231_00 004200_9ceb81a09813d5f3d2565bc39479705a_0 131 | scene0050_00 000200_556166f38429cdfe29bdd38dd4a1a461_0 132 | scene0050_00 001700_556166f38429cdfe29bdd38dd4a1a461_0 133 | scene0050_00 002600_556166f38429cdfe29bdd38dd4a1a461_0 134 | scene0050_00 003000_556166f38429cdfe29bdd38dd4a1a461_0 135 | scene0645_02 001400_dcd9a34a9892fb11490ad276cd2af3a4_0 136 | scene0645_02 001500_dcd9a34a9892fb11490ad276cd2af3a4_0 137 | scene0645_02 001800_dcd9a34a9892fb11490ad276cd2af3a4_0 138 | scene0645_02 001900_dcd9a34a9892fb11490ad276cd2af3a4_0 139 | scene0591_02 000400_fb65fdcded332e4118039d66c0209ecb_0 140 | scene0591_02 001500_fb65fdcded332e4118039d66c0209ecb_0 141 | scene0591_02 002200_fb65fdcded332e4118039d66c0209ecb_0 142 | scene0203_02 000000_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 143 | scene0203_02 000800_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 144 | scene0203_02 001500_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 145 | scene0025_01 000000_f20e7f4f41f323a04b3c42e318f3affc_0 146 | scene0025_01 000400_d053e745b565fa391c1b3b2ed8d13bf8_1 147 | scene0025_01 000500_d053e745b565fa391c1b3b2ed8d13bf8_1 148 | scene0025_01 000700_d053e745b565fa391c1b3b2ed8d13bf8_1 149 | scene0025_01 001200_f20e7f4f41f323a04b3c42e318f3affc_0 150 | scene0025_01 001500_f20e7f4f41f323a04b3c42e318f3affc_0 151 | scene0549_00 000100_681d226acbeaaf08a4ee0fb6a51564c3_0 152 | scene0549_00 000300_ef479941cb60405f8cbd400aa99bee96_1 153 | scene0549_00 000400_ef479941cb60405f8cbd400aa99bee96_1 154 | scene0549_00 000500_ef479941cb60405f8cbd400aa99bee96_1 155 | scene0549_00 000800_681d226acbeaaf08a4ee0fb6a51564c3_0 156 | scene0353_01 000200_b2a9553d5d81060b36c9a52137c03278_0 157 | scene0353_01 001700_b2a9553d5d81060b36c9a52137c03278_0 158 | scene0353_01 002300_b2a9553d5d81060b36c9a52137c03278_0 159 | scene0591_00 
000800_289e520179ed1e397282872e507d5fff_0 160 | scene0591_00 000900_289e520179ed1e397282872e507d5fff_0 161 | scene0591_00 001000_289e520179ed1e397282872e507d5fff_0 162 | scene0591_00 001600_289e520179ed1e397282872e507d5fff_0 163 | scene0696_02 000500_7ae657b39aa2be68ccd1bcd57588acf8_0 164 | scene0549_01 000100_4e1ee66994a95492f2543b208c9ee8e2_1 165 | scene0549_01 000300_4e1ee66994a95492f2543b208c9ee8e2_1 166 | scene0549_01 000500_4e1ee66994a95492f2543b208c9ee8e2_1 167 | scene0549_01 000800_4e1ee66994a95492f2543b208c9ee8e2_0 168 | scene0549_01 000900_4e1ee66994a95492f2543b208c9ee8e2_0 169 | scene0549_01 001000_4e1ee66994a95492f2543b208c9ee8e2_0 170 | scene0474_02 000800_7cfccaf7557934911ee8243f54292d6_0 171 | scene0474_02 001200_7cfccaf7557934911ee8243f54292d6_0 172 | scene0559_00 000000_867d1e4a9f7cc110b8df7b9b18a5c81f_0 173 | scene0696_01 000300_bc6a3fa659dd7ec0c62ac18334863d36_0 174 | scene0696_01 001000_bc6a3fa659dd7ec0c62ac18334863d36_0 175 | scene0652_00 000200_cd249bd432c4bc75b82cf928f6ed5338_0 176 | scene0652_00 000300_cd249bd432c4bc75b82cf928f6ed5338_0 177 | scene0652_00 001000_cd249bd432c4bc75b82cf928f6ed5338_0 178 | scene0207_01 000000_fd4dd071f73ca07355eab99951962891_0 179 | scene0207_01 001600_fd4dd071f73ca07355eab99951962891_0 180 | scene0207_02 000300_8efa91e2f3e2eaf7bdc82a7932cd806_0 181 | scene0207_02 002300_8efa91e2f3e2eaf7bdc82a7932cd806_0 182 | scene0690_01 000100_556166f38429cdfe29bdd38dd4a1a461_0 183 | scene0690_01 000200_556166f38429cdfe29bdd38dd4a1a461_0 184 | scene0207_00 000300_330d44833e1b4b168b38796afe7ee552_0 185 | scene0207_00 000600_330d44833e1b4b168b38796afe7ee552_0 186 | scene0207_00 001700_330d44833e1b4b168b38796afe7ee552_0 187 | scene0231_01 001200_13b9cc6c187edb98afd316e82119b42_0 188 | scene0231_01 003700_13b9cc6c187edb98afd316e82119b42_0 189 | scene0025_02 000000_61f828a545649e98f1d7342136779c0_1 190 | scene0025_02 000600_7fd704652332a45b2ce025aebfea84a4_0 191 | scene0025_02 000700_61f828a545649e98f1d7342136779c0_1 192 | 
scene0025_02 000800_7fd704652332a45b2ce025aebfea84a4_0 193 | scene0568_01 000000_cceaeed0d8cf5bdbca68d7e2f215cb19_0 194 | scene0568_01 000100_cceaeed0d8cf5bdbca68d7e2f215cb19_0 195 | scene0187_01 000000_44854046021846f219fb4103277a6b93_0 196 | scene0187_01 000100_44854046021846f219fb4103277a6b93_0 197 | scene0187_01 001200_44854046021846f219fb4103277a6b93_1 198 | scene0187_01 001300_44854046021846f219fb4103277a6b93_1 199 | scene0187_01 001600_44854046021846f219fb4103277a6b93_0 200 | scene0334_00 000000_849ddda40bd6540efac8371a83e130ac_2 201 | scene0334_00 000100_849ddda40bd6540efac8371a83e130ac_1 202 | scene0334_00 000300_849ddda40bd6540efac8371a83e130ac_3 203 | scene0334_00 000500_849ddda40bd6540efac8371a83e130ac_0 204 | scene0334_00 001100_849ddda40bd6540efac8371a83e130ac_1 205 | scene0334_00 001100_849ddda40bd6540efac8371a83e130ac_2 206 | scene0461_00 000000_e9e5da988215f06513292732a7b1ed9a_0 207 | scene0461_00 000000_e9e5da988215f06513292732a7b1ed9a_1 208 | scene0461_00 000100_e9e5da988215f06513292732a7b1ed9a_0 209 | scene0461_00 000100_e9e5da988215f06513292732a7b1ed9a_1 210 | scene0461_00 000200_e9e5da988215f06513292732a7b1ed9a_0 211 | scene0461_00 000200_e9e5da988215f06513292732a7b1ed9a_1 212 | scene0461_00 000300_e9e5da988215f06513292732a7b1ed9a_1 213 | scene0461_00 000500_e9e5da988215f06513292732a7b1ed9a_1 214 | scene0025_00 000200_8659f0f422096e3d26f6c8b5b75f0ee9_1 215 | scene0025_00 000400_8659f0f422096e3d26f6c8b5b75f0ee9_1 216 | scene0025_00 000500_8659f0f422096e3d26f6c8b5b75f0ee9_1 217 | scene0025_00 000800_8659f0f422096e3d26f6c8b5b75f0ee9_0 218 | scene0025_00 001600_8659f0f422096e3d26f6c8b5b75f0ee9_0 219 | scene0050_01 001000_8659f0f422096e3d26f6c8b5b75f0ee9_1 220 | scene0050_01 001200_bf01483d8b58f0819767624530e7fce3_0 221 | scene0050_01 001900_8659f0f422096e3d26f6c8b5b75f0ee9_1 222 | scene0050_01 002000_8659f0f422096e3d26f6c8b5b75f0ee9_1 223 | scene0050_01 002100_8659f0f422096e3d26f6c8b5b75f0ee9_1 224 | scene0050_01 
003500_bf01483d8b58f0819767624530e7fce3_0 225 | scene0701_02 000000_679010d35da8193219fb4103277a6b93_0 226 | scene0701_02 000100_679010d35da8193219fb4103277a6b93_0 227 | scene0701_02 000100_bdd7a0eb66e8884dad04591c9486ec0_2 228 | scene0701_02 000200_bdd7a0eb66e8884dad04591c9486ec0_2 229 | scene0701_02 000300_62e90a6ed511a1b2d291861d5bc3e7c8_1 230 | scene0701_02 000400_679010d35da8193219fb4103277a6b93_0 231 | scene0701_02 000500_679010d35da8193219fb4103277a6b93_0 232 | scene0701_02 000700_679010d35da8193219fb4103277a6b93_0 233 | scene0701_02 000800_bdd7a0eb66e8884dad04591c9486ec0_2 234 | scene0701_02 000900_679010d35da8193219fb4103277a6b93_0 235 | scene0701_02 000900_62e90a6ed511a1b2d291861d5bc3e7c8_1 236 | scene0701_02 001000_62e90a6ed511a1b2d291861d5bc3e7c8_1 237 | scene0701_02 001100_62e90a6ed511a1b2d291861d5bc3e7c8_1 238 | scene0701_02 001200_62e90a6ed511a1b2d291861d5bc3e7c8_1 239 | scene0701_02 001200_bdd7a0eb66e8884dad04591c9486ec0_2 240 | scene0647_01 000500_c7f31b9900a1a7644785ad2feb797e_0 241 | scene0647_01 000500_354c37c168778a0bd4830313df3656b_1 242 | scene0647_01 000600_c7f31b9900a1a7644785ad2feb797e_0 243 | scene0647_01 000600_354c37c168778a0bd4830313df3656b_1 244 | scene0474_05 002700_3d164c442e5788e25c7a30510dbe4e9f_0 245 | scene0559_01 000200_ad0e50d6f1e9a16aefc579970fcfc006_0 246 | scene0559_01 000300_ad0e50d6f1e9a16aefc579970fcfc006_0 247 | scene0559_01 000400_ad0e50d6f1e9a16aefc579970fcfc006_0 248 | scene0329_02 000000_8659f0f422096e3d26f6c8b5b75f0ee9_0 249 | scene0329_02 000100_8659f0f422096e3d26f6c8b5b75f0ee9_1 250 | scene0329_02 000500_8659f0f422096e3d26f6c8b5b75f0ee9_1 251 | scene0329_02 000600_8659f0f422096e3d26f6c8b5b75f0ee9_1 252 | scene0329_02 001300_8659f0f422096e3d26f6c8b5b75f0ee9_0 253 | scene0474_04 000200_41b02faaceadb39560fcec8f64d76ffb_0 254 | scene0334_02 000000_c856e6b37c9e12ab8a3de2846876a3c7_0 255 | scene0334_02 000100_c856e6b37c9e12ab8a3de2846876a3c7_0 256 | scene0334_02 000200_c856e6b37c9e12ab8a3de2846876a3c7_0 257 | 
scene0334_02 000300_c856e6b37c9e12ab8a3de2846876a3c7_1 258 | scene0334_02 000400_c856e6b37c9e12ab8a3de2846876a3c7_1 259 | scene0334_02 001000_c856e6b37c9e12ab8a3de2846876a3c7_0 260 | scene0568_00 000000_60fc7123d6360e6d620ef1b4a95dca08_0 261 | scene0568_00 001600_60fc7123d6360e6d620ef1b4a95dca08_0 262 | -------------------------------------------------------------------------------- /splits/shape/02818832/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0699_00 000000_22b8e1805041fe56010a6840f668b41_0 2 | scene0699_00 000100_22b8e1805041fe56010a6840f668b41_0 3 | scene0699_00 000400_22b8e1805041fe56010a6840f668b41_0 4 | scene0353_02 000500_698ef3d1a8c0c829c580fdeb5460f6d6_0 5 | scene0353_02 000700_698ef3d1a8c0c829c580fdeb5460f6d6_1 6 | scene0217_00 000000_3acfa3c60a03415643abcff1f32a8b0c_1 7 | scene0217_00 000500_3acfa3c60a03415643abcff1f32a8b0c_0 8 | scene0217_00 000700_3acfa3c60a03415643abcff1f32a8b0c_1 9 | scene0217_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 10 | scene0193_01 000300_51223b1b770ff5e72f38f5bd71072746_0 11 | scene0046_02 000100_d7b9238af2efa963c862eec8232fff1e_0 12 | scene0046_02 000200_d7b9238af2efa963c862eec8232fff1e_0 13 | scene0046_02 000700_d7b9238af2efa963c862eec8232fff1e_0 14 | scene0046_02 000800_d7b9238af2efa963c862eec8232fff1e_0 15 | scene0046_02 002200_d7b9238af2efa963c862eec8232fff1e_0 16 | scene0046_02 002500_d7b9238af2efa963c862eec8232fff1e_0 17 | scene0645_01 000100_4dbd37cb85686dea674ce64e4bf77aec_0 18 | scene0645_01 000300_4dbd37cb85686dea674ce64e4bf77aec_1 19 | scene0645_01 000400_4dbd37cb85686dea674ce64e4bf77aec_1 20 | scene0645_01 003500_4dbd37cb85686dea674ce64e4bf77aec_0 21 | scene0645_01 003600_4dbd37cb85686dea674ce64e4bf77aec_0 22 | scene0645_01 005000_4dbd37cb85686dea674ce64e4bf77aec_0 23 | scene0645_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 24 | scene0645_00 000200_4dbd37cb85686dea674ce64e4bf77aec_0 25 | scene0645_00 
000300_4dbd37cb85686dea674ce64e4bf77aec_1 26 | scene0645_00 000400_4dbd37cb85686dea674ce64e4bf77aec_1 27 | scene0645_00 000600_4dbd37cb85686dea674ce64e4bf77aec_1 28 | scene0193_00 000300_c758919a1fbe3d0c9cef17528faf7bc5_0 29 | scene0193_00 000700_c758919a1fbe3d0c9cef17528faf7bc5_0 30 | scene0658_00 000200_c2b65540d51fada22bfa21768064df9c_0 31 | scene0426_03 000600_946bd5aeda453b38b3a454ed6a7199e2_0 32 | scene0046_00 000300_f7d2cf0ebbf5453531cd8798c40e5949_0 33 | scene0046_00 000800_f7d2cf0ebbf5453531cd8798c40e5949_0 34 | scene0046_00 000900_f7d2cf0ebbf5453531cd8798c40e5949_0 35 | scene0221_00 000300_9fb6014c9944a98bd2096b2fa6f98cc7_0 36 | scene0697_02 000300_4dbd37cb85686dea674ce64e4bf77aec_0 37 | scene0697_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 38 | scene0697_02 001600_4dbd37cb85686dea674ce64e4bf77aec_0 39 | scene0697_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 40 | scene0697_02 002700_4dbd37cb85686dea674ce64e4bf77aec_0 41 | scene0144_00 000800_4dbd37cb85686dea674ce64e4bf77aec_0 42 | scene0426_00 001100_ce0f3c9d6a0b0cda71010004e0594e66_0 43 | scene0695_03 000400_1f11b3d9953fabcf8b4396b18c85cf0f_0 44 | scene0695_03 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 45 | scene0695_03 002600_1f11b3d9953fabcf8b4396b18c85cf0f_0 46 | scene0580_01 000400_5d12d1a313cff5ad66f379f51753f72b_0 47 | scene0580_01 000500_5d12d1a313cff5ad66f379f51753f72b_0 48 | scene0580_01 000600_5d12d1a313cff5ad66f379f51753f72b_0 49 | scene0580_01 000800_5d12d1a313cff5ad66f379f51753f72b_0 50 | scene0580_01 002100_5d12d1a313cff5ad66f379f51753f72b_0 51 | scene0580_01 003600_5d12d1a313cff5ad66f379f51753f72b_0 52 | scene0580_01 003800_5d12d1a313cff5ad66f379f51753f72b_0 53 | scene0353_00 000800_3acfa3c60a03415643abcff1f32a8b0c_1 54 | scene0353_00 001900_3acfa3c60a03415643abcff1f32a8b0c_0 55 | scene0435_00 000800_d7b9238af2efa963c862eec8232fff1e_0 56 | scene0435_00 001000_d7b9238af2efa963c862eec8232fff1e_0 57 | scene0435_00 001100_d7b9238af2efa963c862eec8232fff1e_0 58 | scene0435_00 
001200_76db17c76f828282dcb2f14e2e42ec8d_1 59 | scene0435_00 001300_76db17c76f828282dcb2f14e2e42ec8d_1 60 | scene0435_00 001400_76db17c76f828282dcb2f14e2e42ec8d_1 61 | scene0695_02 000000_48973f489d06e8139f9d5a5f7267a470_0 62 | scene0695_02 001800_48973f489d06e8139f9d5a5f7267a470_0 63 | scene0695_02 002000_48973f489d06e8139f9d5a5f7267a470_0 64 | scene0629_02 000200_946bd5aeda453b38b3a454ed6a7199e2_0 65 | scene0629_02 001400_946bd5aeda453b38b3a454ed6a7199e2_0 66 | scene0629_02 001500_946bd5aeda453b38b3a454ed6a7199e2_0 67 | scene0629_02 001600_946bd5aeda453b38b3a454ed6a7199e2_0 68 | scene0046_01 000500_d7b9238af2efa963c862eec8232fff1e_0 69 | scene0046_01 000600_d7b9238af2efa963c862eec8232fff1e_0 70 | scene0046_01 002200_d7b9238af2efa963c862eec8232fff1e_0 71 | scene0580_00 000100_946bd5aeda453b38b3a454ed6a7199e2_0 72 | scene0580_00 000700_946bd5aeda453b38b3a454ed6a7199e2_0 73 | scene0580_00 001300_946bd5aeda453b38b3a454ed6a7199e2_0 74 | scene0580_00 001400_946bd5aeda453b38b3a454ed6a7199e2_0 75 | scene0580_00 001500_946bd5aeda453b38b3a454ed6a7199e2_0 76 | scene0580_00 004600_946bd5aeda453b38b3a454ed6a7199e2_0 77 | scene0246_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 78 | scene0246_00 002000_4dbd37cb85686dea674ce64e4bf77aec_0 79 | scene0246_00 002600_4dbd37cb85686dea674ce64e4bf77aec_0 80 | scene0697_01 000300_4dbd37cb85686dea674ce64e4bf77aec_0 81 | scene0697_01 001400_4dbd37cb85686dea674ce64e4bf77aec_0 82 | scene0697_01 001600_4dbd37cb85686dea674ce64e4bf77aec_0 83 | scene0645_02 000100_4dbd37cb85686dea674ce64e4bf77aec_1 84 | scene0645_02 000200_4dbd37cb85686dea674ce64e4bf77aec_1 85 | scene0645_02 002600_4dbd37cb85686dea674ce64e4bf77aec_0 86 | scene0353_01 000400_8df7e58200ac5e6ab91b871e750ca615_0 87 | scene0353_01 000600_8df7e58200ac5e6ab91b871e750ca615_0 88 | scene0353_01 000900_8df7e58200ac5e6ab91b871e750ca615_1 89 | scene0256_02 000000_e7d0920ba8d4b1be71424c004dd7ab2f_0 90 | scene0256_02 000100_e7d0920ba8d4b1be71424c004dd7ab2f_0 91 | scene0256_02 
000600_e7d0920ba8d4b1be71424c004dd7ab2f_0 92 | scene0696_02 000100_946bd5aeda453b38b3a454ed6a7199e2_0 93 | scene0696_02 001300_946bd5aeda453b38b3a454ed6a7199e2_0 94 | scene0435_01 000700_e91c2df09de0d4b1ed4d676215f46734_0 95 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_0 96 | scene0435_01 000800_e91c2df09de0d4b1ed4d676215f46734_1 97 | scene0435_01 000900_e91c2df09de0d4b1ed4d676215f46734_0 98 | scene0435_01 001000_e91c2df09de0d4b1ed4d676215f46734_0 99 | scene0435_01 001100_e91c2df09de0d4b1ed4d676215f46734_1 100 | scene0435_01 001200_e91c2df09de0d4b1ed4d676215f46734_1 101 | scene0222_01 000400_6e5f10f2574f8a285d64ca7820a9c2ca_1 102 | scene0222_01 001500_6e5f10f2574f8a285d64ca7820a9c2ca_1 103 | scene0222_01 003600_6e5f10f2574f8a285d64ca7820a9c2ca_0 104 | scene0222_01 004100_6e5f10f2574f8a285d64ca7820a9c2ca_0 105 | scene0277_01 000100_3acfa3c60a03415643abcff1f32a8b0c_0 106 | scene0277_01 000700_3acfa3c60a03415643abcff1f32a8b0c_0 107 | scene0277_01 000800_3acfa3c60a03415643abcff1f32a8b0c_0 108 | scene0435_02 000600_9fb6014c9944a98bd2096b2fa6f98cc7_0 109 | scene0435_02 000700_9fb6014c9944a98bd2096b2fa6f98cc7_0 110 | scene0435_02 000800_9fb6014c9944a98bd2096b2fa6f98cc7_1 111 | scene0435_02 000900_9fb6014c9944a98bd2096b2fa6f98cc7_0 112 | scene0435_02 001100_9fb6014c9944a98bd2096b2fa6f98cc7_1 113 | scene0697_00 001000_4dbd37cb85686dea674ce64e4bf77aec_0 114 | scene0652_00 000500_4dbd37cb85686dea674ce64e4bf77aec_0 115 | scene0652_00 001300_4dbd37cb85686dea674ce64e4bf77aec_0 116 | scene0356_02 000200_3acfa3c60a03415643abcff1f32a8b0c_0 117 | scene0356_02 000800_3acfa3c60a03415643abcff1f32a8b0c_0 118 | scene0207_02 000000_4dbd37cb85686dea674ce64e4bf77aec_0 119 | scene0207_02 000500_4dbd37cb85686dea674ce64e4bf77aec_0 120 | scene0207_02 000600_4dbd37cb85686dea674ce64e4bf77aec_0 121 | scene0207_00 000000_4dbd37cb85686dea674ce64e4bf77aec_0 122 | scene0207_00 000100_4dbd37cb85686dea674ce64e4bf77aec_0 123 | scene0207_00 000400_4dbd37cb85686dea674ce64e4bf77aec_0 124 | 
scene0648_00 000000_2d1a2be896054548997e2c877588ae24_0 125 | scene0648_00 000600_2d1a2be896054548997e2c877588ae24_0 126 | scene0648_00 001100_2d1a2be896054548997e2c877588ae24_1 127 | scene0648_00 002500_2d1a2be896054548997e2c877588ae24_1 128 | scene0648_00 002800_2d1a2be896054548997e2c877588ae24_0 129 | scene0648_00 003500_2d1a2be896054548997e2c877588ae24_1 130 | scene0648_00 003800_2d1a2be896054548997e2c877588ae24_0 131 | scene0648_00 003900_2d1a2be896054548997e2c877588ae24_0 132 | scene0277_00 000400_6e4707cac21b09f0531c83488903771b_0 133 | scene0277_00 001000_6e4707cac21b09f0531c83488903771b_0 134 | scene0435_03 000600_e91c2df09de0d4b1ed4d676215f46734_0 135 | scene0435_03 000700_e91c2df09de0d4b1ed4d676215f46734_1 136 | scene0435_03 000800_e91c2df09de0d4b1ed4d676215f46734_0 137 | scene0435_03 000900_e91c2df09de0d4b1ed4d676215f46734_0 138 | scene0435_03 001300_e91c2df09de0d4b1ed4d676215f46734_1 139 | scene0435_03 001400_e91c2df09de0d4b1ed4d676215f46734_1 140 | scene0648_01 001900_b9302be3dc846d834f0ba81bea651144_1 141 | scene0648_01 002200_b9302be3dc846d834f0ba81bea651144_1 142 | scene0648_01 002300_b9302be3dc846d834f0ba81bea651144_0 143 | scene0648_01 002900_b9302be3dc846d834f0ba81bea651144_1 144 | scene0648_01 003000_b9302be3dc846d834f0ba81bea651144_1 145 | scene0695_00 000800_3acfa3c60a03415643abcff1f32a8b0c_0 146 | scene0695_00 000900_3acfa3c60a03415643abcff1f32a8b0c_0 147 | scene0695_00 001800_3acfa3c60a03415643abcff1f32a8b0c_0 148 | scene0633_01 000000_76db17c76f828282dcb2f14e2e42ec8d_0 149 | scene0633_01 000100_76db17c76f828282dcb2f14e2e42ec8d_0 150 | scene0633_01 001100_76db17c76f828282dcb2f14e2e42ec8d_0 151 | scene0356_00 000200_1f11b3d9953fabcf8b4396b18c85cf0f_0 152 | scene0356_00 000900_1f11b3d9953fabcf8b4396b18c85cf0f_0 153 | scene0697_03 001500_4dbd37cb85686dea674ce64e4bf77aec_0 154 | scene0697_03 001600_4dbd37cb85686dea674ce64e4bf77aec_0 155 | scene0697_03 001700_4dbd37cb85686dea674ce64e4bf77aec_0 156 | 
-------------------------------------------------------------------------------- /splits/shape/02871439/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0593_00 000000_601359d274c2c00a1497d160eced5e7a_0 2 | scene0593_00 000700_601359d274c2c00a1497d160eced5e7a_0 3 | scene0593_00 000800_601359d274c2c00a1497d160eced5e7a_0 4 | scene0593_00 001300_601359d274c2c00a1497d160eced5e7a_0 5 | scene0593_00 001800_601359d274c2c00a1497d160eced5e7a_0 6 | scene0593_00 001900_601359d274c2c00a1497d160eced5e7a_0 7 | scene0700_01 003200_722624d2cd4b72018ac5263758737a81_0 8 | scene0700_01 003300_722624d2cd4b72018ac5263758737a81_0 9 | scene0558_01 000100_4e26a2e39e3b6d39961b70a6f96df2a4_3 10 | scene0558_01 000400_4e26a2e39e3b6d39961b70a6f96df2a4_2 11 | scene0378_02 001100_b1696f5b9b3926b1a523e28192de797e_1 12 | scene0378_02 001200_b1696f5b9b3926b1a523e28192de797e_1 13 | scene0378_02 001300_b1696f5b9b3926b1a523e28192de797e_1 14 | scene0378_02 001400_ae26bec5a79f51943da27ece6ae88fff_0 15 | scene0378_02 001500_ae26bec5a79f51943da27ece6ae88fff_0 16 | scene0378_02 001600_ae26bec5a79f51943da27ece6ae88fff_0 17 | scene0378_02 001700_ae26bec5a79f51943da27ece6ae88fff_0 18 | scene0353_02 000000_1d3014ad5c35944f9af68b4b3261e1f8_1 19 | scene0203_00 000400_c214e1d190cb87362a9cd5247487b619_3 20 | scene0704_00 000600_9ca929d28c1838d5e41994ffd448fd07_0 21 | scene0329_01 000300_4996e0c9dac4ba3649fade4fe2abc936_0 22 | scene0591_01 000100_ad808d321f7cd914c6d17bc3a482f77_2 23 | scene0591_01 000200_ad808d321f7cd914c6d17bc3a482f77_0 24 | scene0591_01 000300_ad808d321f7cd914c6d17bc3a482f77_2 25 | scene0591_01 000900_ad808d321f7cd914c6d17bc3a482f77_2 26 | scene0591_01 001700_ad808d321f7cd914c6d17bc3a482f77_0 27 | scene0474_00 000000_c007d1b4972f70102751be1fc72418bb_0 28 | scene0474_00 000100_c007d1b4972f70102751be1fc72418bb_0 29 | scene0474_00 000900_c007d1b4972f70102751be1fc72418bb_0 30 | scene0030_02 
001700_89f1540881171ee3f4c4977ed0ba5296_4 31 | scene0598_02 000000_601b64e8e27159e157da56e4c9ff868d_3 32 | scene0598_02 000100_601b64e8e27159e157da56e4c9ff868d_3 33 | scene0598_02 000400_601b64e8e27159e157da56e4c9ff868d_4 34 | scene0598_02 000600_601b64e8e27159e157da56e4c9ff868d_3 35 | scene0598_02 000700_601b64e8e27159e157da56e4c9ff868d_3 36 | scene0598_02 001100_601b64e8e27159e157da56e4c9ff868d_0 37 | scene0598_02 001300_601b64e8e27159e157da56e4c9ff868d_4 38 | scene0231_02 000500_c214e1d190cb87362a9cd5247487b619_0 39 | scene0231_02 002800_c214e1d190cb87362a9cd5247487b619_0 40 | scene0658_00 000000_c214e1d190cb87362a9cd5247487b619_1 41 | scene0658_00 000300_c214e1d190cb87362a9cd5247487b619_0 42 | scene0643_00 001600_3fa60816f15b58c4607974568e26586f_1 43 | scene0568_02 000400_4996e0c9dac4ba3649fade4fe2abc936_0 44 | scene0568_02 000500_4996e0c9dac4ba3649fade4fe2abc936_0 45 | scene0474_01 001600_c007d1b4972f70102751be1fc72418bb_0 46 | scene0527_00 000100_35afa1b806556803e99c79bad29da781_0 47 | scene0203_01 000500_d70d80e5855785a761b9fd1751b9fcb_4 48 | scene0203_01 001100_e4c42bcbba4ef5b09a2d4f7bacd6e0d8_0 49 | scene0203_01 001300_e3ae56f176b77359aa1bd50387389420_1 50 | scene0695_03 002100_465bb5a9ed42ad40a5817f81a1efa3cc_0 51 | scene0598_00 000200_81b8784259e3331f94cdfc338037bd95_0 52 | scene0353_00 000000_ad808d321f7cd914c6d17bc3a482f77_0 53 | scene0353_00 001500_ad808d321f7cd914c6d17bc3a482f77_0 54 | scene0695_02 000300_1ab8202a944a6ff1de650492e45fb14f_0 55 | scene0695_02 000900_1ab8202a944a6ff1de650492e45fb14f_0 56 | scene0064_00 000500_b105d36dbf010903b022c94235bc8601_2 57 | scene0064_00 000600_b105d36dbf010903b022c94235bc8601_2 58 | scene0064_00 001000_b105d36dbf010903b022c94235bc8601_0 59 | scene0064_01 000200_46579c6050cac50a1c8c7b57a94dbb2e_2 60 | scene0064_01 000500_dfdc22f3fecbb57c5050d6a2f5d42f74_0 61 | scene0700_02 002000_ec882f5717b0f405b2bf4f773fe0e622_0 62 | scene0700_02 002100_ec882f5717b0f405b2bf4f773fe0e622_0 63 | scene0700_02 
002200_ec882f5717b0f405b2bf4f773fe0e622_0 64 | scene0231_00 001400_4552f6002e96cb00205444155ae5c84d_0 65 | scene0231_00 002100_c214e1d190cb87362a9cd5247487b619_1 66 | scene0231_00 004100_4552f6002e96cb00205444155ae5c84d_0 67 | scene0580_00 001700_b079feff448e925546c4f23965b7dd40_0 68 | scene0580_00 001800_b079feff448e925546c4f23965b7dd40_0 69 | scene0697_01 002000_c214e1d190cb87362a9cd5247487b619_0 70 | scene0700_00 001100_c007d1b4972f70102751be1fc72418bb_0 71 | scene0700_00 001200_c007d1b4972f70102751be1fc72418bb_0 72 | scene0700_00 001300_c007d1b4972f70102751be1fc72418bb_0 73 | scene0593_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 74 | scene0307_02 000200_db5b6dc38fd82fa7cdfa73f789b383fe_1 75 | scene0307_02 000600_db5b6dc38fd82fa7cdfa73f789b383fe_1 76 | scene0307_02 001300_ec882f5717b0f405b2bf4f773fe0e622_2 77 | scene0307_02 001400_ec882f5717b0f405b2bf4f773fe0e622_3 78 | scene0591_02 000100_c214e1d190cb87362a9cd5247487b619_1 79 | scene0591_02 000200_1ab8202a944a6ff1de650492e45fb14f_3 80 | scene0591_02 000700_c214e1d190cb87362a9cd5247487b619_1 81 | scene0591_02 000800_c214e1d190cb87362a9cd5247487b619_0 82 | scene0591_02 000900_1ab8202a944a6ff1de650492e45fb14f_3 83 | scene0663_02 001000_348d539dade47e8e664b3b9b23ddfcbc_2 84 | scene0663_02 001100_348d539dade47e8e664b3b9b23ddfcbc_2 85 | scene0203_02 001300_db5b6dc38fd82fa7cdfa73f789b383fe_1 86 | scene0025_01 000400_db5b6dc38fd82fa7cdfa73f789b383fe_0 87 | scene0025_01 000500_db5b6dc38fd82fa7cdfa73f789b383fe_0 88 | scene0025_01 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 89 | scene0025_01 001300_db5b6dc38fd82fa7cdfa73f789b383fe_0 90 | scene0353_01 001500_c214e1d190cb87362a9cd5247487b619_1 91 | scene0378_01 001000_c007d1b4972f70102751be1fc72418bb_0 92 | scene0378_01 001100_c007d1b4972f70102751be1fc72418bb_0 93 | scene0378_01 001200_c007d1b4972f70102751be1fc72418bb_0 94 | scene0378_01 001400_3d422581d2f439c74cc1952ae0d6e81a_1 95 | scene0378_01 001600_3d422581d2f439c74cc1952ae0d6e81a_1 96 | scene0591_00 
001100_1ab8202a944a6ff1de650492e45fb14f_1 97 | scene0591_00 001200_fb1a15143f6df0b35d3dc7a81f13d5e8_0 98 | scene0591_00 001300_88d2609fe01abd4a6b8fdc9bc8edc250_2 99 | scene0591_00 001400_fb1a15143f6df0b35d3dc7a81f13d5e8_0 100 | scene0696_02 001000_225da2a2b7a461e349edb0f98d2a2a29_0 101 | scene0222_01 002500_465bb5a9ed42ad40a5817f81a1efa3cc_0 102 | scene0222_01 002600_465bb5a9ed42ad40a5817f81a1efa3cc_0 103 | scene0222_01 002700_465bb5a9ed42ad40a5817f81a1efa3cc_0 104 | scene0558_02 001700_1ab8202a944a6ff1de650492e45fb14f_8 105 | scene0644_00 000700_c40828710c91001546c4f23965b7dd40_0 106 | scene0644_00 001500_c40828710c91001546c4f23965b7dd40_0 107 | scene0357_01 000200_c214e1d190cb87362a9cd5247487b619_0 108 | scene0307_01 000100_ec882f5717b0f405b2bf4f773fe0e622_0 109 | scene0307_01 000200_ec882f5717b0f405b2bf4f773fe0e622_0 110 | scene0307_01 000500_ec882f5717b0f405b2bf4f773fe0e622_0 111 | scene0307_01 001800_ec882f5717b0f405b2bf4f773fe0e622_0 112 | scene0307_01 002200_ec882f5717b0f405b2bf4f773fe0e622_1 113 | scene0307_01 003000_ec882f5717b0f405b2bf4f773fe0e622_0 114 | scene0307_01 003100_ec882f5717b0f405b2bf4f773fe0e622_0 115 | scene0222_00 000100_e3ae56f176b77359aa1bd50387389420_1 116 | scene0222_00 000900_e3ae56f176b77359aa1bd50387389420_1 117 | scene0222_00 001000_e3ae56f176b77359aa1bd50387389420_1 118 | scene0222_00 001100_e3ae56f176b77359aa1bd50387389420_1 119 | scene0222_00 003200_e3ae56f176b77359aa1bd50387389420_0 120 | scene0222_00 004900_e3ae56f176b77359aa1bd50387389420_0 121 | scene0695_01 001100_465bb5a9ed42ad40a5817f81a1efa3cc_0 122 | scene0695_01 002000_465bb5a9ed42ad40a5817f81a1efa3cc_0 123 | scene0474_02 001800_8007cf8349381dada5817f81a1efa3cc_0 124 | scene0474_02 002000_8007cf8349381dada5817f81a1efa3cc_0 125 | scene0131_02 000200_b6264b92d0cbd598d5919aa833abb5a_0 126 | scene0277_02 000300_465bb5a9ed42ad40a5817f81a1efa3cc_0 127 | scene0356_02 000700_b6264b92d0cbd598d5919aa833abb5a_0 128 | scene0207_01 001700_ec882f5717b0f405b2bf4f773fe0e622_0 129 | 
scene0207_02 000200_ec882f5717b0f405b2bf4f773fe0e622_0 130 | scene0558_00 000000_b105d36dbf010903b022c94235bc8601_1 131 | scene0558_00 000300_b105d36dbf010903b022c94235bc8601_0 132 | scene0558_00 000700_b105d36dbf010903b022c94235bc8601_1 133 | scene0357_00 000200_d841664bb8c7b842cbf15a79411ef31b_0 134 | scene0648_00 000400_d841664bb8c7b842cbf15a79411ef31b_0 135 | scene0648_00 002000_e3ae56f176b77359aa1bd50387389420_1 136 | scene0648_00 002900_d841664bb8c7b842cbf15a79411ef31b_0 137 | scene0648_00 003100_d841664bb8c7b842cbf15a79411ef31b_0 138 | scene0648_00 003300_e3ae56f176b77359aa1bd50387389420_2 139 | scene0231_01 002000_5793619650b38e335acd449a2ae99009_1 140 | scene0535_00 000000_29b66fc9db2f1558e0e89fd83955713c_1 141 | scene0535_00 000100_29b66fc9db2f1558e0e89fd83955713c_1 142 | scene0568_01 000300_9ca929d28c1838d5e41994ffd448fd07_0 143 | scene0568_01 000400_9ca929d28c1838d5e41994ffd448fd07_0 144 | scene0568_01 000500_9ca929d28c1838d5e41994ffd448fd07_0 145 | scene0568_01 001200_34128cd8b8ddbbedd92903da5c4b5ef6_1 146 | scene0663_00 001000_131b5b691ff5bff3945770bff82992ca_0 147 | scene0663_00 001500_364a2a1172dc0a827c6de7e52b00ebab_3 148 | scene0663_00 001900_131b5b691ff5bff3945770bff82992ca_0 149 | scene0663_00 002000_131b5b691ff5bff3945770bff82992ca_0 150 | scene0663_00 002100_131b5b691ff5bff3945770bff82992ca_0 151 | scene0663_00 002200_131b5b691ff5bff3945770bff82992ca_0 152 | scene0208_00 000200_908c0b3c235a81c08998b3b64a143d42_8 153 | scene0208_00 000700_908c0b3c235a81c08998b3b64a143d42_8 154 | scene0208_00 001100_908c0b3c235a81c08998b3b64a143d42_6 155 | scene0208_00 001600_908c0b3c235a81c08998b3b64a143d42_7 156 | scene0208_00 001700_908c0b3c235a81c08998b3b64a143d42_7 157 | scene0208_00 002000_908c0b3c235a81c08998b3b64a143d42_9 158 | scene0648_01 000200_b6264b92d0cbd598d5919aa833abb5a_2 159 | scene0648_01 000300_b6264b92d0cbd598d5919aa833abb5a_2 160 | scene0648_01 000400_b6264b92d0cbd598d5919aa833abb5a_2 161 | scene0648_01 
003400_61ea75b808512124ae5607418ff674f1_1 162 | scene0378_00 001000_b1696f5b9b3926b1a523e28192de797e_1 163 | scene0378_00 001100_b1696f5b9b3926b1a523e28192de797e_1 164 | scene0378_00 001400_c007d1b4972f70102751be1fc72418bb_0 165 | scene0378_00 001500_c007d1b4972f70102751be1fc72418bb_0 166 | scene0378_00 001600_c007d1b4972f70102751be1fc72418bb_0 167 | scene0378_00 001800_c007d1b4972f70102751be1fc72418bb_0 168 | scene0356_01 000100_dc9af91c3831983597e6cb8e9bd5d9e5_0 169 | scene0025_00 000600_db5b6dc38fd82fa7cdfa73f789b383fe_0 170 | scene0025_00 000700_db5b6dc38fd82fa7cdfa73f789b383fe_0 171 | scene0695_00 001300_b6264b92d0cbd598d5919aa833abb5a_0 172 | scene0077_01 000200_1b3090bfb11cf834f06824a9291300d4_0 173 | scene0474_05 000100_8007cf8349381dada5817f81a1efa3cc_0 174 | scene0474_05 000200_8007cf8349381dada5817f81a1efa3cc_0 175 | scene0474_05 001900_8007cf8349381dada5817f81a1efa3cc_0 176 | scene0633_01 000600_34128cd8b8ddbbedd92903da5c4b5ef6_0 177 | scene0633_01 000700_34128cd8b8ddbbedd92903da5c4b5ef6_0 178 | scene0633_01 000700_fbc6b4e0aa9cc13d713e7f5d7ea85661_1 179 | scene0633_01 000800_34128cd8b8ddbbedd92903da5c4b5ef6_0 180 | scene0077_00 000100_8f0baf74e4d91b0d7b39f9a454e866f6_0 181 | scene0356_00 000700_b6264b92d0cbd598d5919aa833abb5a_0 182 | scene0568_00 000300_4996e0c9dac4ba3649fade4fe2abc936_0 183 | scene0568_00 000400_4996e0c9dac4ba3649fade4fe2abc936_0 184 | scene0568_00 000500_4996e0c9dac4ba3649fade4fe2abc936_0 185 | -------------------------------------------------------------------------------- /splits/shape/04256520/val_nonocc_centroid_maskexist.txt: -------------------------------------------------------------------------------- 1 | scene0474_03 001200_58447a958c4af154942bb07caacf4df3_0 2 | scene0474_03 001300_58447a958c4af154942bb07caacf4df3_0 3 | scene0593_00 000900_93b421c66ff4529f37b2bb75885cfc44_0 4 | scene0593_00 001600_93b421c66ff4529f37b2bb75885cfc44_0 5 | scene0353_02 000200_61f828a545649e98f1d7342136779c0_0 6 | scene0353_02 
001400_61f828a545649e98f1d7342136779c0_0 7 | scene0353_02 002100_61f828a545649e98f1d7342136779c0_0 8 | scene0329_00 000000_23833969c0011b8e98494085d68ad6a0_1 9 | scene0329_00 000100_556166f38429cdfe29bdd38dd4a1a461_0 10 | scene0329_00 000200_23833969c0011b8e98494085d68ad6a0_1 11 | scene0329_00 000600_23833969c0011b8e98494085d68ad6a0_1 12 | scene0329_00 001100_556166f38429cdfe29bdd38dd4a1a461_0 13 | scene0030_00 000000_7ab86358957e386d76de5cade2fd5247_0 14 | scene0030_00 000100_7ab86358957e386d76de5cade2fd5247_0 15 | scene0030_00 001600_7ab86358957e386d76de5cade2fd5247_0 16 | scene0518_00 000200_955d633562dff06f843e991acd39f432_0 17 | scene0518_00 000500_955d633562dff06f843e991acd39f432_0 18 | scene0518_00 001000_955d633562dff06f843e991acd39f432_0 19 | scene0203_00 000200_330d44833e1b4b168b38796afe7ee552_0 20 | scene0203_00 000600_330d44833e1b4b168b38796afe7ee552_0 21 | scene0574_01 001000_945a038c3e0c46ec19fb4103277a6b93_0 22 | scene0574_01 001100_945a038c3e0c46ec19fb4103277a6b93_0 23 | scene0574_01 001200_945a038c3e0c46ec19fb4103277a6b93_0 24 | scene0574_01 001500_945a038c3e0c46ec19fb4103277a6b93_0 25 | scene0329_01 000000_7c9e1876b1643e93f9377e1922a21892_1 26 | scene0329_01 000100_7c9e1876b1643e93f9377e1922a21892_1 27 | scene0329_01 000200_7c9e1876b1643e93f9377e1922a21892_1 28 | scene0329_01 000800_7c9e1876b1643e93f9377e1922a21892_0 29 | scene0329_01 000900_7c9e1876b1643e93f9377e1922a21892_1 30 | scene0701_01 000100_41b02faaceadb39560fcec8f64d76ffb_3 31 | scene0701_01 000200_4653af854bf098f2d74aae0eb2ddb027_1 32 | scene0701_01 000200_4653af854bf098f2d74aae0eb2ddb027_2 33 | scene0701_01 000300_681d226acbeaaf08a4ee0fb6a51564c3_0 34 | scene0701_01 000600_41b02faaceadb39560fcec8f64d76ffb_3 35 | scene0701_01 000700_4653af854bf098f2d74aae0eb2ddb027_2 36 | scene0701_01 000800_4653af854bf098f2d74aae0eb2ddb027_1 37 | scene0591_01 001000_23833969c0011b8e98494085d68ad6a0_0 38 | scene0474_00 000400_41b02faaceadb39560fcec8f64d76ffb_0 39 | scene0474_00 
001200_41b02faaceadb39560fcec8f64d76ffb_0 40 | scene0474_00 001300_41b02faaceadb39560fcec8f64d76ffb_0 41 | scene0608_01 000500_fb74336a6192c4787afee304cce81d6f_0 42 | scene0608_01 000700_fb74336a6192c4787afee304cce81d6f_0 43 | scene0608_01 000800_fb74336a6192c4787afee304cce81d6f_0 44 | scene0608_01 000900_fb74336a6192c4787afee304cce81d6f_0 45 | scene0608_01 001000_fb74336a6192c4787afee304cce81d6f_0 46 | scene0609_03 000000_cd10e95d1501ed6719fb4103277a6b93_3 47 | scene0609_03 000300_cd10e95d1501ed6719fb4103277a6b93_0 48 | scene0609_03 000600_cd10e95d1501ed6719fb4103277a6b93_2 49 | scene0609_03 000600_cd10e95d1501ed6719fb4103277a6b93_3 50 | scene0645_01 001100_fb74336a6192c4787afee304cce81d6f_0 51 | scene0645_01 001200_fb74336a6192c4787afee304cce81d6f_0 52 | scene0645_01 001300_fb74336a6192c4787afee304cce81d6f_0 53 | scene0645_01 001400_fb74336a6192c4787afee304cce81d6f_0 54 | scene0645_01 001700_fb74336a6192c4787afee304cce81d6f_0 55 | scene0559_02 000000_330d44833e1b4b168b38796afe7ee552_0 56 | scene0559_02 000100_330d44833e1b4b168b38796afe7ee552_0 57 | scene0645_00 001100_8a5a40fe10eb2b2eb022c94235bc8601_0 58 | scene0645_00 001200_8a5a40fe10eb2b2eb022c94235bc8601_0 59 | scene0645_00 001300_8a5a40fe10eb2b2eb022c94235bc8601_0 60 | scene0645_00 001400_8a5a40fe10eb2b2eb022c94235bc8601_0 61 | scene0645_00 001500_8a5a40fe10eb2b2eb022c94235bc8601_0 62 | scene0645_00 001600_8a5a40fe10eb2b2eb022c94235bc8601_0 63 | scene0231_02 000300_cbd547bfb6b7d8e54b50faf1a96496ef_1 64 | scene0231_02 000700_cbd547bfb6b7d8e54b50faf1a96496ef_1 65 | scene0231_02 002500_3f5fdc05fc572730490ad276cd2af3a4_0 66 | scene0231_02 002600_cbd547bfb6b7d8e54b50faf1a96496ef_1 67 | scene0568_02 000000_60fc7123d6360e6d620ef1b4a95dca08_0 68 | scene0568_02 000100_60fc7123d6360e6d620ef1b4a95dca08_0 69 | scene0568_02 001400_60fc7123d6360e6d620ef1b4a95dca08_0 70 | scene0423_01 000200_68a1f95fed336299f51f77a6d7299806_0 71 | scene0423_01 000300_68a1f95fed336299f51f77a6d7299806_0 72 | scene0423_01 
000400_68a1f95fed336299f51f77a6d7299806_0 73 | scene0423_01 000600_68a1f95fed336299f51f77a6d7299806_0 74 | scene0690_00 000200_23833969c0011b8e98494085d68ad6a0_0 75 | scene0050_02 000400_8458d6939967ac1bbc7a6acbd8f058b_0 76 | scene0050_02 001800_8458d6939967ac1bbc7a6acbd8f058b_0 77 | scene0050_02 001900_8458d6939967ac1bbc7a6acbd8f058b_0 78 | scene0050_02 002000_8458d6939967ac1bbc7a6acbd8f058b_0 79 | scene0050_02 004300_8458d6939967ac1bbc7a6acbd8f058b_0 80 | scene0608_02 000500_608936a307740f5df7628281ecb18112_0 81 | scene0608_02 000600_608936a307740f5df7628281ecb18112_0 82 | scene0608_02 000700_608936a307740f5df7628281ecb18112_0 83 | scene0608_02 000800_608936a307740f5df7628281ecb18112_0 84 | scene0608_02 000900_608936a307740f5df7628281ecb18112_0 85 | scene0608_02 002200_608936a307740f5df7628281ecb18112_0 86 | scene0608_02 002300_608936a307740f5df7628281ecb18112_0 87 | scene0608_02 002400_608936a307740f5df7628281ecb18112_0 88 | scene0608_02 002500_608936a307740f5df7628281ecb18112_0 89 | scene0474_01 000800_fee8e1e0161f69b0db039d8689a74349_0 90 | scene0474_01 000900_fee8e1e0161f69b0db039d8689a74349_0 91 | scene0474_01 001000_fee8e1e0161f69b0db039d8689a74349_0 92 | scene0474_01 001100_fee8e1e0161f69b0db039d8689a74349_0 93 | scene0203_01 000200_117f6ac4bcd75d8b4ad65adb06bbae49_1 94 | scene0203_01 000800_117f6ac4bcd75d8b4ad65adb06bbae49_0 95 | scene0203_01 001200_117f6ac4bcd75d8b4ad65adb06bbae49_0 96 | scene0701_00 000000_60fc7123d6360e6d620ef1b4a95dca08_2 97 | scene0701_00 000100_60fc7123d6360e6d620ef1b4a95dca08_2 98 | scene0701_00 000200_f846fb7af63a5e838eec9023c5b97e00_0 99 | scene0701_00 000400_c3c5818cbe6d0903822a33e080d0e71c_1 100 | scene0701_00 000700_60fc7123d6360e6d620ef1b4a95dca08_2 101 | scene0701_00 000800_f846fb7af63a5e838eec9023c5b97e00_0 102 | scene0701_00 000900_c3c5818cbe6d0903822a33e080d0e71c_1 103 | scene0701_00 001000_f846fb7af63a5e838eec9023c5b97e00_0 104 | scene0334_01 000000_849ddda40bd6540efac8371a83e130ac_1 105 | scene0334_01 
000100_849ddda40bd6540efac8371a83e130ac_0 106 | scene0334_01 000200_849ddda40bd6540efac8371a83e130ac_0 107 | scene0334_01 000900_849ddda40bd6540efac8371a83e130ac_0 108 | scene0334_01 001000_849ddda40bd6540efac8371a83e130ac_0 109 | scene0334_01 001000_849ddda40bd6540efac8371a83e130ac_1 110 | scene0353_00 000200_436a96f58ef9a6fdb039d8689a74349_0 111 | scene0353_00 000300_436a96f58ef9a6fdb039d8689a74349_0 112 | scene0435_00 000700_1230d31e3a6cbf309cd431573238602d_0 113 | scene0064_00 000000_c2d26d8c8d5917d443ba2b548bab2839_1 114 | scene0064_00 000100_c2d26d8c8d5917d443ba2b548bab2839_1 115 | scene0064_00 000300_c2d26d8c8d5917d443ba2b548bab2839_0 116 | scene0064_00 000400_c2d26d8c8d5917d443ba2b548bab2839_0 117 | scene0064_00 000500_c2d26d8c8d5917d443ba2b548bab2839_0 118 | scene0064_00 000800_c2d26d8c8d5917d443ba2b548bab2839_1 119 | scene0064_00 000900_c2d26d8c8d5917d443ba2b548bab2839_1 120 | scene0064_00 001000_c2d26d8c8d5917d443ba2b548bab2839_1 121 | scene0064_00 001200_c2d26d8c8d5917d443ba2b548bab2839_0 122 | scene0608_00 001000_556166f38429cdfe29bdd38dd4a1a461_0 123 | scene0608_00 002500_556166f38429cdfe29bdd38dd4a1a461_0 124 | scene0608_00 002600_556166f38429cdfe29bdd38dd4a1a461_0 125 | scene0231_00 000600_9ceb81a09813d5f3d2565bc39479705a_0 126 | scene0231_00 000700_9ceb81a09813d5f3d2565bc39479705a_0 127 | scene0231_00 000800_9ceb81a09813d5f3d2565bc39479705a_0 128 | scene0231_00 001500_9ceb81a09813d5f3d2565bc39479705a_0 129 | scene0231_00 003300_9ceb81a09813d5f3d2565bc39479705a_0 130 | scene0231_00 004200_9ceb81a09813d5f3d2565bc39479705a_0 131 | scene0050_00 000200_556166f38429cdfe29bdd38dd4a1a461_0 132 | scene0050_00 001700_556166f38429cdfe29bdd38dd4a1a461_0 133 | scene0050_00 002600_556166f38429cdfe29bdd38dd4a1a461_0 134 | scene0050_00 003000_556166f38429cdfe29bdd38dd4a1a461_0 135 | scene0645_02 001400_dcd9a34a9892fb11490ad276cd2af3a4_0 136 | scene0645_02 001500_dcd9a34a9892fb11490ad276cd2af3a4_0 137 | scene0645_02 001800_dcd9a34a9892fb11490ad276cd2af3a4_0 138 | 
scene0645_02 001900_dcd9a34a9892fb11490ad276cd2af3a4_0 139 | scene0591_02 000400_fb65fdcded332e4118039d66c0209ecb_0 140 | scene0591_02 001500_fb65fdcded332e4118039d66c0209ecb_0 141 | scene0591_02 002200_fb65fdcded332e4118039d66c0209ecb_0 142 | scene0203_02 000000_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 143 | scene0203_02 000800_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 144 | scene0203_02 001500_1824d5cfb7472fcf9d5cfc3a8d7af21d_0 145 | scene0025_01 000000_f20e7f4f41f323a04b3c42e318f3affc_0 146 | scene0025_01 000400_d053e745b565fa391c1b3b2ed8d13bf8_1 147 | scene0025_01 000500_d053e745b565fa391c1b3b2ed8d13bf8_1 148 | scene0025_01 000700_d053e745b565fa391c1b3b2ed8d13bf8_1 149 | scene0025_01 001200_f20e7f4f41f323a04b3c42e318f3affc_0 150 | scene0025_01 001500_f20e7f4f41f323a04b3c42e318f3affc_0 151 | scene0549_00 000100_681d226acbeaaf08a4ee0fb6a51564c3_0 152 | scene0549_00 000300_ef479941cb60405f8cbd400aa99bee96_1 153 | scene0549_00 000400_ef479941cb60405f8cbd400aa99bee96_1 154 | scene0549_00 000500_ef479941cb60405f8cbd400aa99bee96_1 155 | scene0549_00 000800_681d226acbeaaf08a4ee0fb6a51564c3_0 156 | scene0353_01 000200_b2a9553d5d81060b36c9a52137c03278_0 157 | scene0353_01 001700_b2a9553d5d81060b36c9a52137c03278_0 158 | scene0353_01 002300_b2a9553d5d81060b36c9a52137c03278_0 159 | scene0591_00 000800_289e520179ed1e397282872e507d5fff_0 160 | scene0591_00 000900_289e520179ed1e397282872e507d5fff_0 161 | scene0591_00 001000_289e520179ed1e397282872e507d5fff_0 162 | scene0591_00 001600_289e520179ed1e397282872e507d5fff_0 163 | scene0696_02 000500_7ae657b39aa2be68ccd1bcd57588acf8_0 164 | scene0549_01 000100_4e1ee66994a95492f2543b208c9ee8e2_1 165 | scene0549_01 000300_4e1ee66994a95492f2543b208c9ee8e2_1 166 | scene0549_01 000500_4e1ee66994a95492f2543b208c9ee8e2_1 167 | scene0549_01 000800_4e1ee66994a95492f2543b208c9ee8e2_0 168 | scene0549_01 000900_4e1ee66994a95492f2543b208c9ee8e2_0 169 | scene0549_01 001000_4e1ee66994a95492f2543b208c9ee8e2_0 170 | scene0474_02 
000800_7cfccaf7557934911ee8243f54292d6_0 171 | scene0474_02 001200_7cfccaf7557934911ee8243f54292d6_0 172 | scene0559_00 000000_867d1e4a9f7cc110b8df7b9b18a5c81f_0 173 | scene0696_01 000300_bc6a3fa659dd7ec0c62ac18334863d36_0 174 | scene0696_01 001000_bc6a3fa659dd7ec0c62ac18334863d36_0 175 | scene0652_00 000200_cd249bd432c4bc75b82cf928f6ed5338_0 176 | scene0652_00 000300_cd249bd432c4bc75b82cf928f6ed5338_0 177 | scene0652_00 001000_cd249bd432c4bc75b82cf928f6ed5338_0 178 | scene0207_01 000000_fd4dd071f73ca07355eab99951962891_0 179 | scene0207_01 001600_fd4dd071f73ca07355eab99951962891_0 180 | scene0207_02 000300_8efa91e2f3e2eaf7bdc82a7932cd806_0 181 | scene0207_02 002300_8efa91e2f3e2eaf7bdc82a7932cd806_0 182 | scene0690_01 000100_556166f38429cdfe29bdd38dd4a1a461_0 183 | scene0690_01 000200_556166f38429cdfe29bdd38dd4a1a461_0 184 | scene0207_00 000300_330d44833e1b4b168b38796afe7ee552_0 185 | scene0207_00 000600_330d44833e1b4b168b38796afe7ee552_0 186 | scene0207_00 001700_330d44833e1b4b168b38796afe7ee552_0 187 | scene0231_01 001200_13b9cc6c187edb98afd316e82119b42_0 188 | scene0231_01 003700_13b9cc6c187edb98afd316e82119b42_0 189 | scene0025_02 000000_61f828a545649e98f1d7342136779c0_1 190 | scene0025_02 000600_7fd704652332a45b2ce025aebfea84a4_0 191 | scene0025_02 000700_61f828a545649e98f1d7342136779c0_1 192 | scene0025_02 000800_7fd704652332a45b2ce025aebfea84a4_0 193 | scene0568_01 000000_cceaeed0d8cf5bdbca68d7e2f215cb19_0 194 | scene0568_01 000100_cceaeed0d8cf5bdbca68d7e2f215cb19_0 195 | scene0187_01 000000_44854046021846f219fb4103277a6b93_0 196 | scene0187_01 000100_44854046021846f219fb4103277a6b93_0 197 | scene0187_01 001200_44854046021846f219fb4103277a6b93_1 198 | scene0187_01 001300_44854046021846f219fb4103277a6b93_1 199 | scene0187_01 001600_44854046021846f219fb4103277a6b93_0 200 | scene0334_00 000000_849ddda40bd6540efac8371a83e130ac_2 201 | scene0334_00 000100_849ddda40bd6540efac8371a83e130ac_1 202 | scene0334_00 000300_849ddda40bd6540efac8371a83e130ac_3 203 | 
scene0334_00 000500_849ddda40bd6540efac8371a83e130ac_0 204 | scene0334_00 001100_849ddda40bd6540efac8371a83e130ac_1 205 | scene0334_00 001100_849ddda40bd6540efac8371a83e130ac_2 206 | scene0461_00 000000_e9e5da988215f06513292732a7b1ed9a_0 207 | scene0461_00 000000_e9e5da988215f06513292732a7b1ed9a_1 208 | scene0461_00 000100_e9e5da988215f06513292732a7b1ed9a_0 209 | scene0461_00 000100_e9e5da988215f06513292732a7b1ed9a_1 210 | scene0461_00 000200_e9e5da988215f06513292732a7b1ed9a_0 211 | scene0461_00 000200_e9e5da988215f06513292732a7b1ed9a_1 212 | scene0461_00 000300_e9e5da988215f06513292732a7b1ed9a_1 213 | scene0461_00 000500_e9e5da988215f06513292732a7b1ed9a_1 214 | scene0025_00 000200_8659f0f422096e3d26f6c8b5b75f0ee9_1 215 | scene0025_00 000400_8659f0f422096e3d26f6c8b5b75f0ee9_1 216 | scene0025_00 000500_8659f0f422096e3d26f6c8b5b75f0ee9_1 217 | scene0025_00 000800_8659f0f422096e3d26f6c8b5b75f0ee9_0 218 | scene0025_00 001600_8659f0f422096e3d26f6c8b5b75f0ee9_0 219 | scene0050_01 001000_8659f0f422096e3d26f6c8b5b75f0ee9_1 220 | scene0050_01 001200_bf01483d8b58f0819767624530e7fce3_0 221 | scene0050_01 001900_8659f0f422096e3d26f6c8b5b75f0ee9_1 222 | scene0050_01 002000_8659f0f422096e3d26f6c8b5b75f0ee9_1 223 | scene0050_01 002100_8659f0f422096e3d26f6c8b5b75f0ee9_1 224 | scene0050_01 003500_bf01483d8b58f0819767624530e7fce3_0 225 | scene0701_02 000000_679010d35da8193219fb4103277a6b93_0 226 | scene0701_02 000100_679010d35da8193219fb4103277a6b93_0 227 | scene0701_02 000100_bdd7a0eb66e8884dad04591c9486ec0_2 228 | scene0701_02 000200_bdd7a0eb66e8884dad04591c9486ec0_2 229 | scene0701_02 000300_62e90a6ed511a1b2d291861d5bc3e7c8_1 230 | scene0701_02 000400_679010d35da8193219fb4103277a6b93_0 231 | scene0701_02 000500_679010d35da8193219fb4103277a6b93_0 232 | scene0701_02 000700_679010d35da8193219fb4103277a6b93_0 233 | scene0701_02 000800_bdd7a0eb66e8884dad04591c9486ec0_2 234 | scene0701_02 000900_679010d35da8193219fb4103277a6b93_0 235 | scene0701_02 
000900_62e90a6ed511a1b2d291861d5bc3e7c8_1 236 | scene0701_02 001000_62e90a6ed511a1b2d291861d5bc3e7c8_1 237 | scene0701_02 001100_62e90a6ed511a1b2d291861d5bc3e7c8_1 238 | scene0701_02 001200_62e90a6ed511a1b2d291861d5bc3e7c8_1 239 | scene0701_02 001200_bdd7a0eb66e8884dad04591c9486ec0_2 240 | scene0647_01 000500_c7f31b9900a1a7644785ad2feb797e_0 241 | scene0647_01 000500_354c37c168778a0bd4830313df3656b_1 242 | scene0647_01 000600_c7f31b9900a1a7644785ad2feb797e_0 243 | scene0647_01 000600_354c37c168778a0bd4830313df3656b_1 244 | scene0474_05 002700_3d164c442e5788e25c7a30510dbe4e9f_0 245 | scene0559_01 000200_ad0e50d6f1e9a16aefc579970fcfc006_0 246 | scene0559_01 000300_ad0e50d6f1e9a16aefc579970fcfc006_0 247 | scene0559_01 000400_ad0e50d6f1e9a16aefc579970fcfc006_0 248 | scene0329_02 000000_8659f0f422096e3d26f6c8b5b75f0ee9_0 249 | scene0329_02 000100_8659f0f422096e3d26f6c8b5b75f0ee9_1 250 | scene0329_02 000500_8659f0f422096e3d26f6c8b5b75f0ee9_1 251 | scene0329_02 000600_8659f0f422096e3d26f6c8b5b75f0ee9_1 252 | scene0329_02 001300_8659f0f422096e3d26f6c8b5b75f0ee9_0 253 | scene0474_04 000200_41b02faaceadb39560fcec8f64d76ffb_0 254 | scene0334_02 000000_c856e6b37c9e12ab8a3de2846876a3c7_0 255 | scene0334_02 000100_c856e6b37c9e12ab8a3de2846876a3c7_0 256 | scene0334_02 000200_c856e6b37c9e12ab8a3de2846876a3c7_0 257 | scene0334_02 000300_c856e6b37c9e12ab8a3de2846876a3c7_1 258 | scene0334_02 000400_c856e6b37c9e12ab8a3de2846876a3c7_1 259 | scene0334_02 001000_c856e6b37c9e12ab8a3de2846876a3c7_0 260 | scene0568_00 000000_60fc7123d6360e6d620ef1b4a95dca08_0 261 | scene0568_00 001600_60fc7123d6360e6d620ef1b4a95dca08_0 262 | -------------------------------------------------------------------------------- /util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 
from threading import Thread
from queue import Queue

from inspect import isfunction
from PIL import Image, ImageDraw, ImageFont


def log_txt_as_img(wh, xc, size=10):
    """Render each caption in ``xc`` as black text on its own white RGB image.

    Args:
        wh: ``(width, height)`` of every rendered image.
        xc: list of caption strings; one image is produced per caption.
        size: font size passed to ``ImageFont.truetype``.

    Returns:
        ``torch.Tensor`` of shape ``(len(xc), 3, height, width)`` with values
        scaled from ``[0, 255]`` to ``[-1, 1]``.
    """
    # The font does not depend on the caption, so load it once instead of per image.
    font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
    # Wrap at ~40 characters per 256 px of width; keep at least 1 so the range step is valid.
    nc = max(1, int(40 * (wh[0] / 256)))
    txts = list()
    for caption in xc:
        txt = Image.new("RGB", wh, color="white")
        draw = ImageDraw.Draw(txt)
        lines = "\n".join(caption[start:start + nc] for start in range(0, len(caption), nc))

        try:
            draw.text((0, 0), lines, fill="black", font=font)
        except UnicodeEncodeError:
            print("Cant encode string for logging. Skipping.")

        # HWC uint8 -> CHW in [-1, 1]
        txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
        txts.append(txt)
    txts = np.stack(txts)
    txts = torch.tensor(txts)
    return txts


def ismap(x):
    """Return True if ``x`` is a 4-D tensor with more than 3 channels on dim 1."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] > 3)


def isimage(x):
    """Return True if ``x`` is a 4-D tensor with 1 or 3 channels on dim 1."""
    if not isinstance(x, torch.Tensor):
        return False
    return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)


def exists(x):
    """Return True if ``x`` is not None."""
    return x is not None


def default(val, d):
    """Return ``val`` if it is not None, otherwise ``d`` (called first if it is a function)."""
    if exists(val):
        return val
    return d() if isfunction(d) else d


def mean_flat(tensor):
    """
    https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))


def count_params(model, verbose=False):
    """Return the total number of elements across ``model.parameters()``.

    If ``verbose`` is True, also print the count in millions.
    """
    total_params = sum(p.numel() for p in model.parameters())
    if verbose:
        print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
    return total_params


def instantiate_from_config(config):
    """Build an object from a ``{"target": "pkg.mod.Name", "params": {...}}`` config.

    The special string configs ``'__is_first_stage__'`` and
    ``'__is_unconditional__'`` yield ``None``.

    Raises:
        KeyError: if ``config`` has no ``target`` entry and is not a special string.
    """
    if "target" not in config:
        if config in ('__is_first_stage__', "__is_unconditional__"):
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


def get_obj_from_str(string, reload=False):
    """Resolve a dotted path ``"package.module.attr"`` to the named attribute.

    If ``reload`` is True, the module is re-imported via ``importlib.reload`` first.
    """
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


# Tag marking a worker-failure message on the prefetch queue (must be picklable).
_PREFETCH_ERROR_TAG = "__prefetch_error__"


def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
    """Worker entry point: apply ``func`` to one chunk and report on ``Q``.

    Puts ``[idx, result]`` on success, or ``(_PREFETCH_ERROR_TAG, idx, traceback_text)``
    on failure. It always finishes with the ``"Done"`` sentinel, so the parent's
    gather loop cannot wait forever on a failed worker.
    """
    import traceback

    try:
        if idx_to_fn:
            res = func(data, worker_id=idx)
        else:
            res = func(data)
        Q.put([idx, res])
    except Exception:
        # Send the traceback as text: it is always picklable, unlike some exceptions.
        Q.put((_PREFETCH_ERROR_TAG, idx, traceback.format_exc()))
    finally:
        Q.put("Done")


def parallel_data_prefetch(
    func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
):
    """Split ``data`` into chunks, apply ``func`` to each chunk in parallel, and gather results.

    Args:
        func: callable applied to each chunk (called with ``worker_id=`` if
            ``use_worker_id`` is True).
        data: an ``np.ndarray`` or any iterable; for a dict only its values are used.
        n_proc: requested number of workers. In ``"list"`` mode, fewer workers
            are spawned when there are fewer chunks than ``n_proc``.
        target_data_type: ``"ndarray"`` (results concatenated along axis 0),
            ``"list"`` (results flattened into one list), or anything else
            (per-chunk results returned as a list).
        cpu_intensive: use processes (``multiprocessing``) if True, threads otherwise.
        use_worker_id: forward the chunk index to ``func`` as ``worker_id``.

    Raises:
        ValueError: if an ndarray is passed with ``target_data_type == "list"``.
        TypeError: if ``data`` is not iterable.
        RuntimeError: if any worker raised; the message holds its traceback.
    """
    if isinstance(data, np.ndarray) and target_data_type == "list":
        raise ValueError("list expected but function got ndarray.")
    elif isinstance(data, abc.Iterable):
        if isinstance(data, dict):
            print(
                f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
            )
            data = list(data.values())
        if target_data_type == "ndarray":
            data = np.asarray(data)
        else:
            data = list(data)
    else:
        raise TypeError(
            f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
        )

    if cpu_intensive:
        Q = mp.Queue(1000)
        proc = mp.Process
    else:
        Q = Queue(1000)
        proc = Thread

    if target_data_type == "ndarray":
        chunks = np.array_split(data, n_proc)
    else:
        # Ceil-sized chunks; this can yield fewer than n_proc chunks (or none for
        # empty data), so the worker count is derived from the chunks below.
        step = max(1, -(-len(data) // n_proc))
        chunks = [data[i: i + step] for i in range(0, len(data), step)]
    arguments = [[func, Q, part, i, use_worker_id] for i, part in enumerate(chunks)]
    n_workers = len(arguments)
    processes = [proc(target=_do_parallel_data_prefetch, args=args) for args in arguments]

    # start processes
    print(f"Start prefetching...")
    import time

    start = time.time()
    gather_res = [[] for _ in range(n_workers)]
    started = []
    try:
        for p in processes:
            p.start()
            started.append(p)

        finished = 0
        while finished < n_workers:
            res = Q.get()
            if isinstance(res, str) and res == "Done":
                finished += 1
            elif isinstance(res, tuple) and res and res[0] == _PREFETCH_ERROR_TAG:
                raise RuntimeError(f"Prefetch worker {res[1]} failed:\n{res[2]}")
            else:
                gather_res[res[0]] = res[1]

    except Exception as e:
        print("Exception: ", e)
        for p in started:
            # Threads cannot be killed; only processes support terminate().
            if hasattr(p, "terminate"):
                p.terminate()

        raise
    finally:
        # Only join workers that actually started; joining an unstarted one raises.
        for p in started:
            p.join()
        print(f"Prefetching complete. [{time.time() - start} sec.]")

    if target_data_type == 'ndarray':
        if not isinstance(gather_res[0], np.ndarray):
            return np.concatenate([np.asarray(r) for r in gather_res], axis=0)

        # order outputs
        return np.concatenate(gather_res, axis=0)
    elif target_data_type == 'list':
        out = []
        for r in gather_res:
            out.extend(r)
        return out
    else:
        return gather_res