├── .gitignore ├── README.md ├── evaluate_on_testset.py ├── infer_custom.py ├── media ├── example_custom_data │ ├── 0.png │ ├── 1.png │ ├── 2.png │ ├── 3.png │ ├── 4.png │ └── 5.png └── teaser.jpg ├── requirements.txt └── src ├── .gitignore ├── Dataset ├── BaseDatabase.py ├── BaseDataset.py ├── custom.py ├── dataset_imports.py ├── gso.py ├── linemod.py ├── navi.py └── navi_util │ ├── data_util.py │ └── transformations.py ├── database_util.py ├── dataset_util.py ├── debug_util.py ├── elev_util.py ├── evaluate ├── eval_on_an_obj.py └── eval_test_set.py ├── exception_util.py ├── gen6d └── Gen6D │ ├── .gitignore │ ├── __init__.py │ ├── compute_align_poses.py │ ├── configs │ ├── detector │ │ ├── detector_pretrain.yaml │ │ └── detector_train.yaml │ ├── gen6d_pretrain.yaml │ ├── gen6d_train.yaml │ ├── refiner │ │ ├── refiner_pretrain.yaml │ │ └── refiner_train.yaml │ └── selector │ │ ├── selector_pretrain.yaml │ │ └── selector_train.yaml │ ├── dataset │ ├── __init__.py │ └── database.py │ ├── estimator.py │ ├── eval.py │ ├── network │ ├── __init__.py │ ├── attention.py │ ├── detector.py │ ├── loss.py │ ├── metrics.py │ ├── operator.py │ ├── pretrain_models.py │ ├── refiner.py │ └── selector.py │ ├── pipeline.py │ ├── scy │ ├── Config.py │ ├── DebugUtil.py │ ├── IntermediateResult.py │ ├── MyJSONEncoder.py │ ├── Zero123Detector.py │ ├── __init__.py │ └── gen6dGlobal.py │ └── utils │ ├── __init__.py │ ├── base_utils.py │ ├── bbox_utils.py │ ├── database_utils.py │ ├── dataset_utils.py │ ├── draw_utils.py │ ├── imgs_info.py │ ├── pose_utils.py │ └── read_write_model.py ├── image_util.py ├── import_util.py ├── imports.py ├── infer_pair.py ├── infer_pairs.py ├── logging_util.py ├── mask_util.py ├── misc_util.py ├── miscellaneous ├── EvalResult.py ├── MemoryCache.py ├── Zero123_BatchB_Input.py └── m.py ├── oee ├── .gitignore ├── models │ └── loftr │ │ ├── __init__.py │ │ ├── backbone │ │ ├── __init__.py │ │ └── resnet_fpn.py │ │ ├── loftr.py │ │ ├── loftr_module │ │ ├── __init__.py │ │ ├── fine_preprocess.py │ │ ├── linear_attention.py │ │ └── transformer.py │ │ └── utils │ │ ├── coarse_matching.py │ │ ├── cvpr_ds_config.py │ │ ├── fine_matching.py │ │ ├── geometry.py │ │ ├── position_encoding.py │ │ └── supervision.py └── utils │ ├── elev_est_api.py │ └── utils3d.py ├── path_configuration.py ├── pose_util.py ├── redirect_util.py ├── root_config.py ├── vis ├── InterVisualizer.py ├── cv2_util.py ├── extrinsic2pyramid │ ├── demo1.py │ ├── demo2.py │ └── util │ │ ├── camera_parameter_loader.py │ │ └── camera_pose_visualizer.py ├── plotly_scene_visualization.py ├── py_matplotlib_helper.py ├── vis_pose.py └── vis_rel_pose.py └── zero123 └── zero1 ├── __init__.py ├── configs └── sd-objaverse-finetune-c_concat-256.yaml ├── ldm ├── data │ ├── __init__.py │ ├── base.py │ ├── dummy.py │ ├── inpainting │ │ ├── __init__.py │ │ └── synthetic_mask.py │ └── simple.py ├── extras.py ├── guidance.py ├── lr_scheduler.py ├── models │ ├── autoencoder.py │ └── diffusion │ │ ├── __init__.py │ │ ├── classifier.py │ │ ├── ddim.py │ │ ├── ddpm.py │ │ ├── plms.py │ │ └── sampling_util.py ├── modules │ ├── attention.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ ├── model.py │ │ ├── openaimodel.py │ │ └── util.py │ ├── distributions │ │ ├── __init__.py │ │ └── distributions.py │ ├── ema.py │ ├── encoders │ │ ├── __init__.py │ │ └── modules.py │ ├── evaluate │ │ ├── adm_evaluator.py │ │ ├── evaluate_perceptualsim.py │ │ ├── frechet_video_distance.py │ │ ├── ssim.py │ │ └── torch_frechet_video_distance.py │ ├── losses │ │ ├── 
__init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ └── x_transformer.py ├── thirdp │ └── psp │ │ ├── helpers.py │ │ ├── id_loss.py │ │ └── model_irse.py └── util.py ├── main.py ├── run4gen6d.py ├── run_.py └── util_4_e2vg ├── CameraMatrixUtil.py ├── ImagePathUtil.py ├── IntermediateResult.py ├── Util.py ├── __init__.py ├── choose_j.py ├── crop.py ├── genIntermediateResult.py └── move_obj_to_center.py /.gitignore: -------------------------------------------------------------------------------- 1 | # !*/ 2 | !/.gitignore 3 | __pycache__ 4 | .idea 5 | .vscode 6 | # !/*.py 7 | 8 | *.zip 9 | *.pth 10 | *.ckpt 11 | *.exe 12 | *.pyc 13 | *.log 14 | *.csv 15 | *.jsonc 16 | *.json5 17 | *.xlsx 18 | *.xls 19 | *.pdf 20 | *.doc 21 | *.docx 22 | *.ppt 23 | *.pth 24 | *.html 25 | *.js 26 | *.png 27 | *.jpg 28 | # *.md 29 | *.txt 30 | *.css 31 | *.json 32 | *.jsonc 33 | *.json5 34 | 35 | 36 | 37 | __pycache__ 38 | .idea 39 | pytorch3d 40 | 41 | !requirements.txt 42 | !media 43 | !media/teaser.jpg 44 | !media/example_custom_data 45 | !media/example_custom_data/* -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models 2 | 3 | ## [Paper](https://arxiv.org/abs/2402.02800) 4 | 5 | ## Introduction 6 | 7 | Estimate the relative camera pose between two images that contain a co-visible object. 8 | ![teaser](media/teaser.jpg) 9 | 10 | ## Setup 11 | Python >= 3.8 12 | 1. Run: 13 | ``` 14 | git clone https://github.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models.git 15 | cd Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models 16 | pip install -r requirements.txt 17 | mkdir -p install 18 | cd install 19 | git clone https://github.com/CompVis/taming-transformers.git 20 | pip install -e taming-transformers/ 21 | git clone https://github.com/openai/CLIP.git 22 | pip install -e CLIP/ 23 | cd .. 24 | mkdir -p weight 25 | cd weight 26 | mkdir weight_gen6d 27 | wget https://cv.cs.columbia.edu/zero123/assets/105000.ckpt 28 | wget https://huggingface.co/One-2-3-45/code/resolve/main/one2345_elev_est/tools/weights/indoor_ds_new.ckpt 29 | cd .. 30 | ``` 31 | If you have trouble installing PyTorch3D this way (it is pulled in via requirements.txt), follow https://github.com/facebookresearch/pytorch3d/blob/main/INSTALL.md to install it instead. 32 | 33 | 2. Download the Gen6D weights 34 | - Follow https://github.com/liuyuan-pal/Gen6D#Download to download gen6d_pretrain.tar.gz 35 | - tar -xvf gen6d_pretrain.tar.gz 36 | - You should now have a folder called 'data'; move its subfolders 'detector_pretrain', 'selector_pretrain' and 'refiner_pretrain' to weight/weight_gen6d/ 37 | 38 | 39 | 40 | ## Usage 41 | ### Evaluate on the test sets 42 | Modify evaluate_on_testset.py and run it. 43 | Before evaluating on GSO, you need to: 44 | 1. Download gso-renderings.zip from https://drive.google.com/file/d/1fsMGFC3FdRFzWqClOT1jgbNRNqsE5sRv/view?usp=sharing and run 'unzip gso-renderings.zip' 45 | 2. Configure src/path_configuration.py: 46 | ``` 47 | # the parent folder of the GSO object folders (GSO_alarm, GSO_backpack, ...) 48 | dataPath_gso='path/to/gso-renderings' 49 | ``` 50 | Before evaluating on NAVI, you need to: 51 | 1. Follow https://github.com/google/navi/tree/49661e33598c4812584ef428a7b2019dbb318a3c to download navi_v1.tar.gz and extract it 52 | 2. 
Configure src/path_configuration.py: 53 | ``` 54 | # the parent folder of the NAVI object folders (3d_dollhouse_sink, bottle_vitamin_d_tablets, ...) 55 | dataPath_navi='' 56 | ``` 57 | ### Estimation on custom images 58 | Modify infer_custom.py and run it. 59 | 60 | 62 | 63 | ## Todo List 64 | - [ ] Check setup 65 | - [x] Upload GSO testset to a cloud drive 66 | - [ ] Remove unused code; better document and comment 67 | - [x] Remove unused packages from requirements.txt 68 | - [ ] Provide command line interface 69 | - [ ] ... 70 | 71 | ## Acknowledgements 72 | This repository uses code from the following repositories. We thank all the authors for sharing their great code. 73 | - [Gen6D](https://github.com/liuyuan-pal/Gen6D) 74 | - [zero123](https://github.com/cvlab-columbia/zero123) 75 | - [One-2-3-45](https://github.com/One-2-3-45/One-2-3-45) 76 | - [extrinsic2pyramid](https://github.com/demul/extrinsic2pyramid) 77 | - [LoFTR](https://github.com/zju3dv/LoFTR) 78 | 79 | ## Citation 80 | ``` 81 | @misc{sun2024extreme, 82 | title={Extreme Two-View Geometry From Object Poses with Diffusion Models}, 83 | author={Yujing Sun and Caiyi Sun and Yuan Liu and Yuexin Ma and Siu Ming Yiu}, 84 | year={2024}, 85 | eprint={2402.02800}, 86 | archivePrefix={arXiv}, 87 | primaryClass={cs.CV} 88 | } 89 | ``` 90 | 91 | ## Follow-up work 92 | [Generalizable Single-view Object Pose Estimation via Two-side Generation and Matching, WACV 2025 Oral](https://arxiv.org/abs/2411.15860) (improved accuracy but slower) 93 | 94 | -------------------------------------------------------------------------------- /evaluate_on_testset.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import sys,os 4 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.join(cur_dir, "src")) 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | import root_config 18 | from evaluate.eval_test_set import run 19 | 20 | def main(datasetName:str,rotate=False): 21 | #------------------configs---------------------- 22 | root_config.VIS=0 # do not visualize results, to save time; set it to True when debugging 23 | root_config.SKIP_EVAL_SEQ_IF_EVAL_RESULT_EXIST = 1 # skip evaluating a category if its eval result already exists 24 | # if the GPU runs out of memory, decrease the following values: 25 | root_config.SAMPLE_BATCH_SIZE = 32 26 | root_config.SAMPLE_BATCH_B_SIZE = 9 27 | 28 | 29 | 30 | if rotate:# add in-plane rotation to the images 31 | root_config.CONSIDER_IPR=True 32 | root_config.Q0Sipr=True 33 | root_config.Q1Sipr=True 34 | run( [datasetName] ) 35 | if __name__=='__main__': 36 | main(datasetName='gso',)# gso testset 37 | main(datasetName='navi',)# navi testset 38 | main(datasetName='gso',rotate=True)# rotated gso testset 39 | main(datasetName='navi',rotate=True)# rotated navi testset -------------------------------------------------------------------------------- /infer_custom.py: -------------------------------------------------------------------------------- 1 | 2 | import sys,os 3 | # os.environ["CUDA_VISIBLE_DEVICES"] = '1' 4 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.join(cur_dir, "src")) 6 | from pathlib import Path 7 | cur_dir=Path(cur_dir) 8 | import root_config 9 | from infer_pairs import infer_pairs_wrapper 10 | 11 | 12 | 13 | #------------------configs---------------------- 14 | root_config.CONSIDER_IPR=False # IPR means in-plane rotation. 
If the object in the reference image is not oriented correctly, set CONSIDER_IPR=True to enable the in-plane rotation predictor; if it is oriented correctly, set CONSIDER_IPR=False to skip in-plane rotation estimation and save time 15 | # If the GPU runs out of memory, decrease the following values: 16 | root_config.SAMPLE_BATCH_SIZE = 32 17 | root_config.SAMPLE_BATCH_B_SIZE = 9 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | referenceImage_path_bbox = (cur_dir/"media/example_custom_data/0.png", (27, 92, 192, 193),) 28 | queryImages_path_bbox = [ 29 | (cur_dir/"media/example_custom_data/1.png", (84, 79, 169, 202),), 30 | (cur_dir/"media/example_custom_data/2.png", (54, 36, 179, 218),), 31 | (cur_dir/"media/example_custom_data/3.png", (31, 69, 217, 194),), 32 | (cur_dir/"media/example_custom_data/4.png", (48, 89, 178, 202),), 33 | (cur_dir/"media/example_custom_data/5.png", (73, 59, 176, 229),), 34 | ] 35 | """ 36 | :param referenceImage_path_bbox: 37 | (path of the reference image, bbox of the object in the reference image) 38 | :param queryImages_path_bbox: 39 | queryImages_path_bbox=[ 40 | (path of query image 1, bbox of the object in this image), 41 | (path of query image 2, bbox), 42 | (path of query image 3, bbox), 43 | ... 44 | ] 45 | bbox=(x0,y0,x1,y1), either in pixels or relative to the image size. 46 | You should provide at least one query image. 47 | :param refId: 48 | refId identifies the building result of a reference image. 49 | If {refId} has been built before, the program reuses the existing building result to save time (building from a reference image takes >1 min on a single 3090 GPU) 50 | :return: 51 | relativePoses=[ 52 | relative pose from the reference image to query image 1, 53 | relative pose from the reference image to query image 2, 54 | relative pose from the reference image to query image 3, 55 | ... 56 | ] 57 | X_query_i = relativePoses[i] @ X_reference, where X denotes a point in the corresponding camera coordinate system. 
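A minimal usage sketch (hypothetical variable names, not part of this script; it assumes each returned pose is a 4x4 homogeneous [R|t] matrix, or a 3x4 matrix extended with a [0, 0, 0, 1] row):
    import numpy as np
    T = np.asarray(relativePoses[0])   # reference -> query image 1
    R, t = T[:3, :3], T[:3, 3]
    X_query_1 = R @ X_reference + t    # X_reference: a 3D point in the reference camera frame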
58 | The camera follows {cameraConvention} convention, cameraConvention is 'opencv' by default 59 | """ 60 | relativePoses = infer_pairs_wrapper( 61 | referenceImage_path_bbox, queryImages_path_bbox, 62 | refId='lion', 63 | ) 64 | print(relativePoses) -------------------------------------------------------------------------------- /media/example_custom_data/0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/0.png -------------------------------------------------------------------------------- /media/example_custom_data/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/1.png -------------------------------------------------------------------------------- /media/example_custom_data/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/2.png -------------------------------------------------------------------------------- /media/example_custom_data/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/3.png -------------------------------------------------------------------------------- /media/example_custom_data/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/4.png -------------------------------------------------------------------------------- /media/example_custom_data/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/example_custom_data/5.png -------------------------------------------------------------------------------- /media/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/media/teaser.jpg -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | carvekit_colab==4.1.0 2 | dl_ext==1.3.4 3 | einops==0.3.0 4 | imageio==2.9.0 5 | iopath==0.1.10 6 | ipython==8.12.3 7 | kornia==0.6.0 8 | loguru==0.7.0 9 | matplotlib==3.7.2 10 | numpy==1.24.4 11 | omegaconf==2.1.1 12 | opencv_python==4.5.5.64 13 | opencv_python_headless==4.8.0.74 14 | pandas==2.0.3 15 | Pillow==9.4.0 16 | plotly==5.13.1 17 | psutil==5.9.7 18 | pudb==2019.2 19 | pytorch_lightning==1.4.2 20 | pytz==2023.3 21 | PyYAML==6.0.1 22 | scipy==1.9.1 23 | six==1.16.0 24 | tabulate==0.9.0 25 | 
torch==1.12.1+cu116 26 | torchvision==0.13.1+cu116 27 | tqdm==4.65.0 28 | transformers==4.22.2 29 | transforms3d==0.4.1 30 | webdataset==0.2.5 31 | yacs==0.1.8 32 | pytorch3d 33 | -------------------------------------------------------------------------------- /src/.gitignore: -------------------------------------------------------------------------------- 1 | # !*/ 2 | !/.gitignore 3 | 4 | !/*.py 5 | 6 | *.pth 7 | *.ckpt 8 | *.exe 9 | *.pyc 10 | *.log 11 | *.csv 12 | *.jsonc 13 | *.json5 14 | *.xlsx 15 | *.xls 16 | *.pdf 17 | *.doc 18 | *.docx 19 | *.ppt 20 | *.pth 21 | *.html 22 | *.js 23 | *.png 24 | *.jpg 25 | *.md 26 | *.txt 27 | *.css 28 | *.json 29 | *.jsonc 30 | *.json5 31 | 32 | 33 | 34 | __pycache__ 35 | .idea -------------------------------------------------------------------------------- /src/Dataset/BaseDataset.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os.path as osp 3 | import random, os 4 | 5 | import numpy as np 6 | import torch 7 | from PIL import Image, ImageFile 8 | # from torch.utils.data import Dataset 9 | # try: 10 | # from utils.bbox import square_bbox 11 | # # from utils.misc import get_permutations 12 | # from utils.normalize_cameras import first_camera_transform, normalize_cameras 13 | # except ModuleNotFoundError: 14 | # from ..utils.bbox import square_bbox 15 | # # from ..utils.misc import get_permutations 16 | # from ..utils.normalize_cameras import first_camera_transform, normalize_cameras 17 | 18 | 19 | 20 | 21 | class BaseDataset: 22 | class ENUM_image_full_path_TYPE: 23 | raw=0 24 | resized=1 25 | 26 | def __init__(self): 27 | self.sequence_list=[] 28 | def __len__(self): 29 | return len(self.sequence_list) 30 | 31 | 32 | def get_data_4gen6d(self, index=None, sequence_name=None, ids=(0, 1), no_images=False): 33 | """ 34 | only need these field in batch: 35 | 1. image_not_transformed_full_path 36 | 2. relative_rotation;relative_t31 37 | 3. detection_outputs if ... else bbox 38 | 4. K 39 | """ 40 | pass 41 | -------------------------------------------------------------------------------- /src/Dataset/custom.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | 4 | import matplotlib.pyplot as plt 5 | import numpy as np 6 | import torch 7 | from PIL import Image 8 | from torch.utils.data import Dataset 9 | from torchvision import transforms 10 | from tqdm.auto import tqdm 11 | try: 12 | from utils.bbox import mask_to_bbox, square_bbox 13 | except ModuleNotFoundError: 14 | from ..utils.bbox import mask_to_bbox, square_bbox 15 | 16 | class CustomDataset(Dataset): 17 | def __init__( 18 | self, 19 | # image_dir, 20 | image_paths, 21 | mask_dir=None, 22 | bboxes=None, 23 | mask_images=False, 24 | ): 25 | assert mask_images==False 26 | assert bboxes is not None 27 | """ 28 | Dataset for custom images. If mask_dir is provided, bounding boxes are extracted 29 | from the masks. Otherwise, bboxes must be provided. 
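Each entry of bboxes is an (x0, y0, x1, y1) box for the corresponding entry of image_paths (same ordering); get_data crops around these corners before resizing to 224.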
30 | """ 31 | # self.image_dir = image_dir 32 | self.mask_dir = mask_dir 33 | self.mask_images = mask_images 34 | self.bboxes = [] 35 | self.images = [] 36 | 37 | """if mask_images: 38 | for image_name, mask_name in tqdm( 39 | zip(sorted(os.listdir(image_dir)), sorted(os.listdir(mask_dir))) 40 | ): 41 | image = Image.open(osp.join(image_dir, image_name)) 42 | mask = Image.open(osp.join(mask_dir, mask_name)).convert("L") 43 | white_image = Image.new("RGB", image.size, (255, 255, 255)) 44 | if mask.size != image.size: 45 | mask = mask.resize(image.size) 46 | mask = Image.fromarray(np.array(mask) > 125) 47 | image = Image.composite(image, white_image, mask) 48 | self.images.append(image) 49 | else: 50 | for image_path in sorted(os.listdir(image_dir)): 51 | self.images.append(Image.open(osp.join(image_dir, image_path)))""" 52 | for image_path in image_paths: 53 | self.images.append(Image.open(image_path)) 54 | self.n = len(self.images) 55 | if bboxes is None: 56 | for mask_path in sorted(os.listdir(mask_dir))[: self.n]: 57 | mask = plt.imread(osp.join(mask_dir, mask_path)) 58 | if len(mask.shape) == 3: 59 | mask = mask[:, :, :3] 60 | else: 61 | mask = np.dstack([mask, mask, mask]) 62 | self.bboxes.append(mask_to_bbox(mask)) 63 | else: 64 | self.bboxes = bboxes 65 | self.jitter_scale = [1.15, 1.15] 66 | self.jitter_trans = [0, 0] 67 | self.transform = transforms.Compose( 68 | [ 69 | transforms.ToTensor(), 70 | transforms.Resize(224), 71 | transforms.Normalize( 72 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 73 | ), 74 | ] 75 | ) 76 | 77 | def __len__(self): 78 | return 1 79 | 80 | 81 | def __getitem__(self, index): 82 | return self.get_data() 83 | 84 | def get_data(self, ids=(0, 1, 2, 3, 4, 5)): 85 | images = [self.images[i] for i in ids] 86 | bboxes = [self.bboxes[i] for i in ids] 87 | images_transformed = [] 88 | crop_parameters = [] 89 | for _, (bbox, image) in enumerate(zip(bboxes, images)): 90 | w, h = image.width, image.height 91 | bbox = np.array(bbox) 92 | bbox_jitter = self._jitter_bbox(bbox) 93 | image = self._crop_image(image, bbox_jitter, white_bg=self.mask_images) 94 | images_transformed.append(self.transform(image)) 95 | crop_center = (bbox_jitter[:2] + bbox_jitter[2:]) / 2 96 | cc = (2 * crop_center / min(h, w)) - 1 97 | crop_width = 2 * (bbox_jitter[2] - bbox_jitter[0]) / min(h, w) 98 | 99 | crop_parameters.append(torch.tensor([-cc[0], -cc[1], crop_width]).float()) 100 | images = images_transformed 101 | 102 | batch = {} 103 | batch["image"] = torch.stack(images) 104 | batch["n"] = len(images) 105 | batch["crop_params"] = torch.stack(crop_parameters) 106 | 107 | return batch 108 | -------------------------------------------------------------------------------- /src/Dataset/dataset_imports.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mask_util 3 | import root_config 4 | import PIL 5 | if __name__ == "__main__": 6 | import sys, os 7 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../../../../.."))) 8 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) 9 | # print("sys.path[0]:", os.path.abspath(sys.path[0])) 10 | # print("sys.path:", sys.path) 11 | from imports import * 12 | if __name__ == "__main__": 13 | from linemod import LinemodDataset 14 | from BaseDatabase import BaseDatabase 15 | else: 16 | from .linemod import LinemodDataset 17 | from .BaseDatabase import BaseDatabase 18 | from torchvision import transforms 19 | import glob 20 | from pathlib import Path 21 | 
import cv2 22 | import numpy as np 23 | import os 24 | import plyfile 25 | from skimage.io import imread, imsave 26 | import pickle 27 | import json 28 | import os.path as osp 29 | import random, os 30 | import torch 31 | from PIL import Image, ImageFile -------------------------------------------------------------------------------- /src/Dataset/gso.py: -------------------------------------------------------------------------------- 1 | import math 2 | import mask_util 3 | import root_config 4 | import PIL 5 | if __name__ == "__main__": 6 | import sys, os 7 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../../../../.."))) 8 | sys.path.insert(0, os.path.abspath(os.path.join(__file__, "../.."))) 9 | # print("sys.path[0]:", os.path.abspath(sys.path[0])) 10 | # print("sys.path:", sys.path) 11 | from imports import * 12 | if __name__ == "__main__": 13 | from linemod import LinemodDataset 14 | from BaseDatabase import BaseDatabase 15 | else: 16 | from .linemod import LinemodDataset 17 | from .BaseDatabase import BaseDatabase 18 | from torchvision import transforms 19 | import glob 20 | from pathlib import Path 21 | import cv2 22 | import numpy as np 23 | import os 24 | import plyfile 25 | from skimage.io import imread, imsave 26 | import pickle 27 | import json 28 | import os.path as osp 29 | import random, os 30 | import torch 31 | from PIL import Image, ImageFile 32 | 33 | 34 | 35 | class GsoDatabase(BaseDatabase) : 36 | """ 37 | _img_id ==imgInt. 38 | """ 39 | def __init__(self, obj:str,): 40 | assert obj.startswith("GSO_") or obj.startswith("gso_") 41 | DATASET_ROOT = os.path.abspath(root_config.dataPath_gso) 42 | self.obj = obj # bed001,bed002,.... 43 | self._dir = f'{DATASET_ROOT}/{self.obj}' 44 | assert os.path.exists(self._dir) ,f"{str(self._dir)} does not exist" 45 | self._img_ids=self._imgFullPaths_2_img_ids__A( glob.glob(f'{self._dir}/*.png')) 46 | self.poses, self.K = self.__get_poses_K() 47 | assert len(self.poses)==len(self._img_ids) 48 | assert self._img_ids==list(range(len(self._img_ids))) 49 | 50 | def __get_poses_K(self): 51 | """ 52 | pose=[R;t] 53 | xcam = R @ xw + t 54 | opencv坐标系 55 | """ 56 | def read_pickle(pkl_path): 57 | with open(pkl_path, 'rb') as f: 58 | return pickle.load(f) 59 | K, poses = read_pickle(os.path.join(self._dir,'meta.pkl')) 60 | 61 | 62 | return poses, K 63 | def get_K(self, img_id): 64 | return self.K 65 | def get_pose(self, img_id): 66 | return self.poses[img_id] 67 | 68 | def get_img_ids(self): 69 | return self._img_ids.copy() 70 | 71 | def _get_rgbaImage_full_path(self,img_id): 72 | fullpath=f'{self._dir}/{int(img_id):03}.png' 73 | return fullpath 74 | def _get_rgbaImage(self,img_id): 75 | fullpath=self._get_rgbaImage_full_path(img_id) 76 | img=imread(fullpath) 77 | return img 78 | def get_image_full_path(self, img_id): 79 | png= self._get_rgbaImage_full_path(img_id) 80 | rgb_folder=Path(f'{self._dir}/_RGB') 81 | rgb_folder.mkdir(exist_ok=1) 82 | jpg= f'{str(rgb_folder)}/{int(img_id):03}.png' 83 | if not os.path.exists(jpg): 84 | #png rgba 2 rgb(white bg) and save as jpg 85 | img=PIL.Image.open(png) 86 | assert img.mode=="RGBA" 87 | width = img.width 88 | height = img.height 89 | image = Image.new('RGB', size=(width, height), color=(255, 255, 255)) 90 | image.paste(img, (0, 0), mask=img) 91 | image.save(jpg) 92 | return jpg 93 | def get_mask_full_path(self, img_id): 94 | mask_dir= f'{self._dir}/mask' 95 | os.makedirs(mask_dir,exist_ok=1) 96 | mask_fullpath = f'{mask_dir}/{int(img_id):03}.png' 97 | if not 
os.path.exists(mask_fullpath): 98 | rgbaImage=self._get_rgbaImage(img_id) 99 | mask=mask_util.Mask.rgbaImage__2__hw0_255(rgbaImage) 100 | imsave(mask_fullpath,mask) 101 | return mask_fullpath 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | class GsoDataset(LinemodDataset): 111 | def __init__(self, category: str): 112 | # super().__init__(category) 113 | self.sequence_list = [""] 114 | self.database = GsoDatabase(obj=category, ) 115 | -------------------------------------------------------------------------------- /src/Dataset/navi.py: -------------------------------------------------------------------------------- 1 | import root_config 2 | from .dataset_imports import * 3 | from .navi_util.data_util import load_scene_data,camera_matrices_from_annotation 4 | 5 | class NaviDatabase(BaseDatabase) : 6 | """ 7 | """ 8 | def __init__(self, obj_with_scene :str,): 9 | """ 10 | obj_with_scene: 'obj(scene)' 11 | """ 12 | assert obj_with_scene.endswith(")") 13 | assert obj_with_scene.count('(')==1 14 | assert obj_with_scene.count(')')==1 15 | # 16 | obj,scene = obj_with_scene.split('(') 17 | scene = scene.replace(')','') 18 | # 19 | DATASET_ROOT = os.path.abspath ( root_config.dataPath_navi) 20 | self._obj_with_scene = obj_with_scene # bed001,bed002,.... 21 | self.obj = obj # bed001,bed002,.... 22 | self.scene = scene 23 | self._dir = Path(f'{DATASET_ROOT}/{self.obj}/{self.scene}') 24 | assert self._dir.exists(),f"{str(self._dir)} does not exist" 25 | # 26 | folder__images_after_exif_transpose=self._dir/'_images_after_exif_transpose' 27 | ttt355=folder__images_after_exif_transpose 28 | if os.path.exists(folder__images_after_exif_transpose): 29 | ttt355=None 30 | else: 31 | os.mkdir(folder__images_after_exif_transpose) 32 | # 33 | annotations, _, image_names = load_scene_data( 34 | obj, scene, DATASET_ROOT, max_num_images=None, 35 | folder__images_after_exif_transpose=ttt355, 36 | ) 37 | del ttt355 38 | assert len(self._imgFullPaths_2_img_ids__A( list(folder__images_after_exif_transpose.glob('*.jpg')),SUFFIX='.jpg' ))==len(image_names),'可能是创建folder__images_after_exif_transpose时还没创建完就被终止了,导致有文件夹但里面文件数目不对' 39 | self.imageFullpaths=[ str(folder__images_after_exif_transpose/i) for i in image_names] 40 | self.maskFullpaths=[ str(self._dir/'masks'/(i.replace('.jpg','.png'))) for i in image_names] 41 | self.poses=[] 42 | self.Ks=[] 43 | for i,image_name in enumerate(image_names): 44 | annotation=annotations[i] 45 | assert image_name==annotation['filename'] 46 | object_to_world, K = camera_matrices_from_annotation(annotation) 47 | object_to_world=object_to_world[:3,:] 48 | self.poses.append(object_to_world) 49 | self.Ks.append(K) 50 | del image_names 51 | self._img_ids = self._imgFullPaths_2_img_ids__A(self.imageFullpaths,check=True,SUFFIX=".jpg") 52 | 53 | def get_K(self, img_id): 54 | return self.Ks[img_id] 55 | def get_pose(self, img_id): 56 | return self.poses[img_id] 57 | def get_img_ids(self): 58 | return self._img_ids.copy() 59 | def get_image_full_path(self, img_id): 60 | return self.imageFullpaths[img_id] 61 | def get_mask_full_path(self, img_id): 62 | return self.maskFullpaths[img_id] 63 | 64 | 65 | 66 | 67 | class NaviDataset(LinemodDataset): 68 | def __init__(self, category: str): 69 | # super().__init__(category) 70 | self.sequence_list = [""] 71 | self.database = NaviDatabase(obj_with_scene=category, ) 72 | -------------------------------------------------------------------------------- /src/Dataset/navi_util/data_util.py: 
-------------------------------------------------------------------------------- 1 | """ 2 | from github proj navi 3 | modified 4 | """ 5 | # Copyright 2023 Google LLC 6 | # 7 | # Licensed under the Apache License, Version 2.0 (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at 10 | # 11 | # https://www.apache.org/licenses/LICENSE-2.0 12 | # 13 | # Unless required by applicable law or agreed to in writing, software 14 | # distributed under the License is distributed on an "AS IS" BASIS, 15 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 16 | # See the License for the specific language governing permissions and 17 | # limitations under the License. 18 | # 19 | # Author: kmaninis@google.com (Kevis-Kokitsi Maninis) 20 | """Useful functions for interfacing NAVI data.""" 21 | 22 | import json 23 | import torch 24 | import os 25 | # import trimesh 26 | from typing import Text, Optional 27 | import numpy as np 28 | from PIL import Image 29 | from PIL import ImageOps 30 | from . import transformations 31 | 32 | 33 | def read_image(image_path: Text) -> Image.Image: 34 | """Reads a NAVI image (and rotates it according to the metadata).""" 35 | ret= ImageOps.exif_transpose(Image.open(image_path)) 36 | return ret 37 | 38 | 39 | def decode_depth(depth_encoded: Image.Image, scale_factor: float = 10.): 40 | """Decodes depth (disparity) from an encoded image (with encode_depth). 41 | 42 | Args: 43 | depth_encoded: The encoded PIL uint16 image of the depth 44 | scale_factor: float, factor to reduce quantization error. MUST BE THE SAME 45 | as the value used to encode the depth. 46 | 47 | Returns: 48 | depth: float[h, w] image with decoded depth values. 49 | """ 50 | max_val = (2**16) - 1 51 | disparity = np.array(depth_encoded).astype('uint16') 52 | disparity = disparity.astype(np.float32) / (max_val * scale_factor) 53 | disparity[disparity == 0] = np.inf 54 | depth = 1 / disparity 55 | return depth 56 | 57 | 58 | def read_depth_from_png(depth_image_path: str) -> np.ndarray: 59 | """Reads encoded depth image from an uint16 png file.""" 60 | if not depth_image_path.endswith('.png'): 61 | raise ValueError(f'Path {depth_image_path} is not a valid png image path.') 62 | 63 | depth_image = Image.open(depth_image_path) 64 | # Don't change the scale_factor. 
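# (decode_depth inverts that encoding: disparity = raw_uint16 / ((2**16 - 1) * scale_factor), then depth = 1 / disparity, with pixels stored as 0 coming back as depth 0, i.e. invalid; the scale_factor passed here must therefore match the one used when the depth was encoded.)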
65 | depth = decode_depth(depth_image, scale_factor=10) 66 | return depth 67 | 68 | 69 | def convert_to_triangles(vertices: np.ndarray, faces: np.ndarray) -> np.ndarray: 70 | """Converts vertices and faces to triangle format float32[N, 3, 3].""" 71 | faces = faces.reshape([-1]) 72 | tri_flat = vertices[faces, :] 73 | return tri_flat.reshape((-1, 3, 3)).astype(np.float32) 74 | 75 | 76 | def camera_matrices_from_annotation(annotation): 77 | """Convert camera pose and intrinsics to 4x4 matrices.""" 78 | translation = transformations.translate(annotation['camera']['t']) 79 | rotation = transformations.quaternion_to_rotation_matrix( 80 | annotation['camera']['q']) 81 | object_to_world = translation @ rotation 82 | h, w = annotation['image_size'] 83 | focal_length_pixels = annotation['camera']['focal_length'] 84 | """intrinsics = transformations.gl_projection_matrix_from_intrinsics( 85 | w, h, focal_length_pixels, focal_length_pixels, w//2, h//2, zfar=1000) 86 | return object_to_world, intrinsics""" 87 | # width: torch.Tensor, height: torch.Tensor, fx: torch.Tensor, fy: torch.Tensor, cx: torch.Tensor, cy: torch.Tensor=w, h, focal_length_pixels, focal_length_pixels, w // 2, h // 2 88 | width, height,fx, fy, cx, cy = w, h, focal_length_pixels, focal_length_pixels, w // 2, h // 2 89 | K = np.array([[fx, 0, cx], 90 | [0, fy, cy], 91 | [0, 0, 1]]) 92 | return object_to_world, K 93 | 94 | 95 | def load_scene_data( 96 | # query: str, 97 | object_id: str, 98 | scene: str, 99 | 100 | navi_release_root: str, 101 | max_num_images: Optional[int] = None, 102 | folder__images_after_exif_transpose=None, 103 | ): 104 | """ 105 | # Loads the data of a certain scene from a query 106 | query_data = query.split(':') 107 | if len(query_data) == 3: 108 | object_id, scene_name, camera_model = query_data 109 | scene = f'{scene_name}_{camera_model}' 110 | elif len(query_data) == 2: 111 | object_id, scene_name = query_data 112 | scene = scene_name 113 | assert scene_name == 'wild_set' 114 | else: 115 | raise ValueError(f'Query {query} is not valid.') 116 | """ 117 | 118 | annotation_json_path = os.path.join( 119 | navi_release_root, object_id, scene, 120 | 'annotations.json') 121 | with open(annotation_json_path, 'r') as f: 122 | annotations = json.load(f) 123 | 124 | """# Load the 3D mesh. 125 | mesh_path = os.path.join( 126 | navi_release_root, object_id, '3d_scan', f'{object_id}.obj') 127 | mesh = trimesh.load(mesh_path)""" 128 | mesh=None 129 | 130 | # Load the images. 
131 | images = [] 132 | for i_anno, anno in enumerate(annotations): 133 | if max_num_images is not None and i_anno >=max_num_images: 134 | break 135 | filename=anno['filename'] 136 | image_path = os.path.join( 137 | navi_release_root, object_id, scene, 'images',filename) 138 | if folder__images_after_exif_transpose: 139 | image:Image.Image=read_image(image_path) 140 | image.save(folder__images_after_exif_transpose/filename) 141 | images.append(filename) 142 | return annotations, mesh, images -------------------------------------------------------------------------------- /src/database_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | #-----------------------Database-------------------------------------------- 3 | from Dataset.gso import GsoDatabase 4 | from Dataset.navi import NaviDatabase 5 | def datasetName_cate_seq__2__database(datasetName,cate, seq): 6 | if datasetName == 'gso': 7 | return GsoDatabase(cate) 8 | elif datasetName == 'navi': 9 | return NaviDatabase(cate) 10 | else: 11 | raise NotImplementedError 12 | @functools.cache 13 | def datasetName_cate_seq__2__database__cached(*args,**kw): 14 | return datasetName_cate_seq__2__database(*args,**kw) -------------------------------------------------------------------------------- /src/debug_util.py: -------------------------------------------------------------------------------- 1 | import root_config,os,sys,shutil 2 | from skimage.io import imread, imsave 3 | import numpy as np 4 | import os,sys,math,functools,inspect 5 | from pathlib import Path 6 | import PIL 7 | def debug_imsave(path__rel_to__path_4debug,arr): 8 | # 9 | if isinstance(path__rel_to__path_4debug,Path): 10 | path__rel_to__path_4debug=str(path__rel_to__path_4debug) 11 | assert isinstance(path__rel_to__path_4debug,str) 12 | if not(path__rel_to__path_4debug.endswith('.jpg') or 13 | path__rel_to__path_4debug.endswith('.png') ): 14 | print('[warning] incorrect image format') 15 | # 16 | if isinstance(arr,PIL.Image.Image): 17 | # Convert PIL image to NumPy array 18 | arr = np.asarray(arr) 19 | assert isinstance(arr,np.ndarray) 20 | # 21 | full_path=os.path.join(root_config.path_4debug,path__rel_to__path_4debug) 22 | os.makedirs(os.path.dirname(full_path),exist_ok=1) 23 | print(f"[debug_imsave]saving...",end=" ",flush = True) 24 | imsave(full_path,arr) 25 | print(f"save to \"{full_path}\"") 26 | -------------------------------------------------------------------------------- /src/elev_util.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import glob 3 | from imports import * 4 | import json 5 | if 1: 6 | from import_util import is_in_sysPath 7 | # if(is_in_sysPath(path=os.path.abspath(os.path.join(os.path.dirname(__file__),os.path.pardir)))): 8 | # from oee.utils.elev_est_api import elev_est_api 9 | # else: 10 | # from ..oee.utils.elev_est_api import elev_est_api 11 | from oee.utils.elev_est_api import elev_est_api,ElevEstHelper 12 | import numpy as np 13 | def _sample_sphere(num_samples, begin_elevation = 0): 14 | """ sample angles from the sphere 15 | reference: https://zhuanlan.zhihu.com/p/25988652?group_id=828963677192491008 16 | """ 17 | ratio = (begin_elevation + 90) / 180 18 | num_points = int(num_samples // (1 - ratio)) 19 | phi = (np.sqrt(5) - 1.0) / 2. 20 | azimuths = [] 21 | elevations = [] 22 | for n in range(num_points - num_samples, num_points): 23 | z = 2. * n / num_points - 1. 
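# z steps uniformly over [-1, 1]; pairing it with the golden-ratio azimuth increments below gives a roughly uniform, Fibonacci-style spread of viewpoints over the sphere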
24 | azimuths.append(2 * np.pi * n * phi % (2 * np.pi)) 25 | elevations.append(np.arcsin(z)) 26 | return np.array(azimuths), np.array(elevations) 27 | def _get_l_ele_azimuth_inRadian(num_samples, begin_elevation = 0): 28 | azimuths,elevations=_sample_sphere(num_samples, begin_elevation) 29 | l_ele_azimuth_inRadian=np.stack([elevations,azimuths],axis=1) 30 | return l_ele_azimuth_inRadian 31 | import math 32 | def eleRadian_2_baseXyz_lXyz(eleRadian:float):#xyz is in degree! 33 | eleDegree=eleRadian*180/math.pi 34 | base_xyz=(-eleDegree,0,0) 35 | l_xyz=[] 36 | l_ele_azimuth_inRadian=_get_l_ele_azimuth_inRadian( 37 | # num_samples=128 38 | num_samples=root_config.NUM_REF 39 | ) 40 | if root_config.ELEV_RANGE:#only keep elev in ELEV_RANGE 41 | assert len(root_config.ELEV_RANGE)==2 42 | # l_ele_azimuth_inRadian=np.array(l_ele_azimuth_inRadian) 43 | # to degree 44 | l_ele_azimuth_inDeg=np.rad2deg(l_ele_azimuth_inRadian) 45 | l_ele_azimuth_inDeg=l_ele_azimuth_inDeg[l_ele_azimuth_inDeg[:,0]>=root_config.ELEV_RANGE[0]] 46 | l_ele_azimuth_inDeg=l_ele_azimuth_inDeg[l_ele_azimuth_inDeg[:,0]<=root_config.ELEV_RANGE[1]] 47 | # to radian 48 | l_ele_azimuth_inRadian=np.deg2rad(l_ele_azimuth_inDeg) 49 | del l_ele_azimuth_inDeg 50 | # l_ele_azimuth_inRadian=to_list_to_primitive(l_ele_azimuth_inRadian) 51 | # 52 | for ele_azimuth_inRadian in l_ele_azimuth_inRadian: 53 | x0=base_xyz[0] 54 | y0=base_xyz[1] 55 | x1=-ele_azimuth_inRadian[0] 56 | y1=ele_azimuth_inRadian[1] 57 | x1=x1*180/math.pi 58 | y1=y1*180/math.pi 59 | l_xyz.append((x1-x0,y1-y0,0)) 60 | return base_xyz,l_xyz 61 | #------------one2345----------------------- 62 | 63 | 64 | 65 | def imgPath2elevRadian(K,input_image_path,run4gen6d_main,id_): 66 | id2 = f"4elev-{id_}-{os.path.basename(input_image_path)}" 67 | 68 | output_dir = os.path.join(root_config.dataPath_gen6d, f'{id2}/ref') 69 | def getFourNearImagePaths( ): 70 | # delta_x_2 = [-10, 10, 0, 0] 71 | # delta_y_2 = [0, 0, -10, 10] 72 | DELTA=10 73 | if 'tmp_4_ipr_ex1' not in Global.anything: 74 | assert DELTA==10 75 | delta_x_2 = [-DELTA, DELTA, 0, 0] 76 | delta_y_2 = [0, 0, -DELTA, DELTA] 77 | ElevEstHelper.DELTA=DELTA 78 | 79 | l_xyz=[(delta_x_2[i],delta_y_2[i],0) for i in range(4)] 80 | l__path_output_im=run4gen6d_main( 81 | id2, 82 | input_image_path, 83 | # output_dir=output_dir, 84 | output_dir=None, 85 | num_samples=1, 86 | l_xyz=l_xyz, 87 | base_xyz=(0,0,0), 88 | ddim_steps=75, 89 | K=K, 90 | only_gen=True, # dont crop,re center etc 91 | ) 92 | # ret=[os.path.join(output_dir,f"{i}.jpg") for i in range(4)] 93 | assert len(l_xyz)==len(l__path_output_im) 94 | return l__path_output_im 95 | if root_config.Cheat.force_elev: 96 | print(f"[warning] You enable {root_config.Cheat.force_elev=}") 97 | fourNearImagePaths='(Cheat.force_elev)' 98 | elev=root_config.Cheat.force_elev 99 | else: 100 | fourNearImagePaths=getFourNearImagePaths() 101 | elev = elev_est_api( fourNearImagePaths, 102 | # min_elev=30, max_elev=150, 103 | # min_elev=20, max_elev=160, 104 | min_elev=90-79, max_elev=90-0, 105 | # min_elev=1, max_elev=160, 106 | ) 107 | elev_deg:int=elev 108 | elev = np.deg2rad(elev) 109 | #info: rad,degree,output_dir,output_imgs,output_json 110 | os.makedirs(output_dir,exist_ok=1) 111 | output_json=os.path.join(output_dir,"elev.json") 112 | info={ 113 | "rad":elev, 114 | "degree":elev_deg, 115 | "input_image_path":input_image_path, 116 | "output_dir":output_dir, 117 | "output_imgs":fourNearImagePaths, 118 | "output_json":output_json, 119 | } 120 | 
print("[imgPath2elevRadian]",json.dumps(info,indent=4)) 121 | if 'tmp_4_ipr_ex1' in Global.anything: 122 | raise Exception("tmp_4_ipr_ex1",) 123 | with open(output_json,"w") as f: 124 | json.dump(info,f,indent=4) 125 | #tmp4SecondTimeDebugElev 126 | if not hasattr(Global,'tmp4SecondTimeDebugElev'): 127 | Global.tmp4SecondTimeDebugElev=[] 128 | Global.tmp4SecondTimeDebugElev.append(info) 129 | return elev 130 | def eleRadian_2_base_w2c(eleRadian): 131 | base_xyz, useless__l_xyz = eleRadian_2_baseXyz_lXyz(eleRadian=eleRadian) 132 | pose=xyz2pose(*base_xyz) 133 | assert pose.shape==(3,4) 134 | pose=np.concatenate([pose,np.array([[0,0,0,1]])],axis=0) 135 | return pose -------------------------------------------------------------------------------- /src/evaluate/eval_test_set.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import os,sys 4 | import warnings 5 | warnings.filterwarnings("ignore", category=DeprecationWarning) 6 | from imports import * 7 | from infer_pair import * 8 | import sys 9 | import gen6d.Gen6D.pipeline as pipeline 10 | from miscellaneous.EvalResult import EvalResult 11 | from evaluate.eval_on_an_obj import eval_on_an_obj 12 | 13 | 14 | 15 | 16 | 17 | #---------------------------------------------- 18 | 19 | def run(l_datasetName,model_name ="E2VG"): 20 | l__datasetName_cate_seq_Q0INDEX=get__l__datasetName_cate_seq_Q0INDEX(datasetNames=l_datasetName,datasetName_2_s=MyTestset.datasetName_2_s) 21 | # 22 | SUFFIX="" 23 | for datasetName,cate,seq, q0 in l__datasetName_cate_seq_Q0INDEX: 24 | if datasetName=='navi': 25 | root_config.tmp_batch_image__SUFFIX='.jpg' #save disk space 26 | assert seq==""# only co3d has the seq level 27 | root_config.DATASET = datasetName 28 | for Q0INDEX in [q0]: 29 | root_config.Q0INDEX = Q0INDEX 30 | root_config.refIdSuffix = f"+{Q0INDEX}" 31 | if root_config.Q0Sipr: 32 | SUFFIX='-rotated' 33 | root_config.idSuffix = f"(testset{SUFFIX}){root_config.SEED}+{Q0INDEX}" 34 | root_config.refIdSuffix += SUFFIX 35 | eval_on_an_obj( 36 | category=cate, 37 | model_name=model_name, 38 | vis_include_theOther=0, 39 | ) 40 | EvalResult.AllAcc.dump_average_acc(SUFFIX ) 41 | -------------------------------------------------------------------------------- /src/exception_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import root_config 3 | import loguru 4 | import time 5 | import traceback 6 | from misc_util import your_datetime 7 | def handle_exception(e ): 8 | print("\n-------[handle_exception]-------\n") 9 | print(f"{your_datetime():%Y.%m.%d-%H:%M:%S}") 10 | print(e) 11 | print(traceback.format_exc()) 12 | loguru.logger.exception (e) 13 | with open(os.path.join(root_config.path_root,"error.txt"), "a") as f: 14 | f.write("\n" + "\n" + "\n" + "\n") 15 | f.write(f"{your_datetime():%Y.%m.%d-%H:%M:%S}" + "\n") 16 | f.write(str(e) + "\n") 17 | f.write(traceback.format_exc() + "\n") -------------------------------------------------------------------------------- /src/gen6d/Gen6D/.gitignore: -------------------------------------------------------------------------------- 1 | # 忽略所有文件和文件夹 2 | * 3 | #!*/将所有子目录列入白名单;new bing 4 | !*/ 5 | # 不忽略.py文件 6 | !*.py 7 | !*.yaml 8 | !.gitignore 9 | !git-alias.txt 10 | .idea 11 | # .vscode 12 | 13 | 14 | !.vscode 15 | !.vscode/settings.json 16 | 17 | *.pkl 18 | *.ply 19 | *.db 20 | *.vis 21 | *.bin 22 | *.lnk 23 | 24 | #忽略scy/temp: 25 | #temp -------------------------------------------------------------------------------- 
/src/gen6d/Gen6D/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/gen6d/Gen6D/__init__.py -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/detector/detector_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: detector_pretrain 2 | network: detector 3 | detection_scales: [-1.0,-0.5,0.0,0.5] -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/detector/detector_train.yaml: -------------------------------------------------------------------------------- 1 | name: detector_train 2 | network: detector 3 | detection_scales: [-1.0,-0.5,0.0,0.5] 4 | 5 | ##########loss############## 6 | loss: [detection_softmax, detection_offset_scale] 7 | val_metric: [vis_bbox_scale] 8 | key_metric_name: mean_iou 9 | score_diff_thresh: 1.0 10 | use_ref_view_mask: false 11 | output_interval: 15 12 | use_offset_loss: true 13 | 14 | ###########dataset########## 15 | train_dataset_type: det_train 16 | train_dataset_cfg: 17 | use_database_sample_prob: true 18 | database_sample_prob: [ 100, 10, 30, 10, 10 ] 19 | database_names: ['co3d_train', 'gso_train_128', 'shapenet_train', 'linemod_train', 'genmop_train'] 20 | batch_size: 4 21 | 22 | ref_type: fps_32 23 | detector_scale_range: [-0.5, 1.2] 24 | detector_rotation_range: [-22.5, 22.5] 25 | 26 | resolution: 128 27 | reference_num: 32 28 | 29 | que_add_background_objects: true 30 | que_background_objects_num: 2 31 | que_background_objects_ratio: 0.3 32 | 33 | offset_type: random 34 | detector_offset_std: 3 35 | detector_real_aug_rot: true 36 | 37 | val_set_list: 38 | - 39 | name: cat_val 40 | type: det_val 41 | cfg: 42 | ref_database_name: linemod/cat 43 | test_database_name: linemod/cat 44 | ref_split_type: linemod_val 45 | test_split_type: linemod_val 46 | - 47 | name: warrior_val 48 | type: det_val 49 | cfg: 50 | ref_database_name: genmop/tformer-ref 51 | test_database_name: genmop/tformer-test 52 | ref_split_type: all 53 | test_split_type: all 54 | 55 | ##########optimizer########## 56 | optimizer_type: adam 57 | lr_type: exp_decay 58 | lr_cfg: 59 | lr_init: 1.0e-4 60 | decay_step: 100000 61 | decay_rate: 0.5 62 | total_step: 300000 63 | train_log_step: 50 64 | val_interval: 5000 65 | save_interval: 500 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/gen6d_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: gen6d_pretrain 2 | type: gen6d 3 | 4 | detector: configs/detector/detector_pretrain.yaml 5 | selector: configs/selector/selector_pretrain.yaml 6 | refiner: configs/refiner/refiner_pretrain.yaml 7 | 8 | # ref_resolution: 128 # reference image resolution 9 | # ref_view_num: 64 # view number used in selection 10 | ref_resolution: 128 # reference image resolution 11 | # ref_view_num: 128 # view number used in selection 12 | #ref_view_num: 128*4 13 | #ref_view_num: 512 14 | 15 | # read from root_config.NUM_REF 16 | ref_view_num: -1 17 | 18 | 19 | det_ref_view_num: 32 # view number used in detection 20 | refine_iter: 3 # refinement iteration number -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/gen6d_train.yaml: 
-------------------------------------------------------------------------------- 1 | name: gen6d_train 2 | type: gen6d 3 | 4 | detector: configs/detector/detector_train.yaml 5 | selector: configs/selector/selector_train.yaml 6 | refiner: configs/refiner/refiner_train.yaml 7 | 8 | ref_resolution: 128 # reference image resolution 9 | ref_view_num: 64 # view number used in selection 10 | det_ref_view_num: 32 # view number used in detection 11 | refine_iter: 3 # refinement iteration number -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/refiner/refiner_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: refiner_pretrain 2 | network: refiner 3 | refiner_sample_num: 32 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/refiner/refiner_train.yaml: -------------------------------------------------------------------------------- 1 | name: refiner_train 2 | network: refiner 3 | refiner_sample_num: 32 4 | 5 | ##########loss############## 6 | loss: [refiner_loss] 7 | val_metric: [ref_metrics] 8 | key_metric_name: pose_add 9 | output_interval: 15 10 | 11 | ###########dataset########## 12 | collate_fn: simple 13 | train_loader_batch_size: 2 14 | val_loader_batch_size: 1 15 | train_dataset_type: ref_train 16 | train_dataset_cfg: 17 | batch_size: 1 18 | use_database_sample_prob: true 19 | database_sample_prob: [ 40, 10, 10, 10 ] 20 | database_names: ['shapenet_train', 'gso_train_128', 'linemod_train', 'genmop_train' ] 21 | 22 | refine_scale_range: [-0.3, 0.3] 23 | refine_rotation_range: [-15, 15] 24 | refine_offset_std: 4 25 | refine_ref_num: 6 26 | refine_ref_resolution: 128 27 | refine_view_cfg: v3 28 | refine_ref_ids_version: fps 29 | 30 | val_set_list: 31 | - 32 | name: warrior_val 33 | type: ref_val 34 | cfg: 35 | ref_database_name: genmop/tformer-ref 36 | test_database_name: genmop/tformer-test 37 | ref_split_type: all 38 | test_split_type: all 39 | detector_name: detector_train 40 | selector_name: selector_train 41 | refine_ref_num: 6 42 | refine_ref_resolution: 128 43 | - 44 | name: cat_val 45 | type: ref_val 46 | cfg: 47 | ref_database_name: linemod/cat 48 | test_database_name: linemod/cat 49 | ref_split_type: linemod_val 50 | test_split_type: linemod_val 51 | detector_name: detector_train 52 | selector_name: selector_train 53 | refine_ref_num: 6 54 | refine_ref_resolution: 128 55 | 56 | ##########optimizer########## 57 | optimizer_type: adam 58 | lr_type: exp_decay 59 | lr_cfg: 60 | lr_init: 1.0e-4 61 | decay_step: 100000 62 | decay_rate: 0.5 63 | total_step: 300001 64 | train_log_step: 50 65 | val_interval: 5000 66 | save_interval: 500 67 | worker_num: 8 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/selector/selector_pretrain.yaml: -------------------------------------------------------------------------------- 1 | name: selector_pretrain 2 | network: selector 3 | selector_angle_num: 5 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/configs/selector/selector_train.yaml: -------------------------------------------------------------------------------- 1 | name: selector_train 2 | network: selector 3 | selector_angle_num: 5 4 | 5 | ##########loss############## 6 | loss: [selection_loss] 7 | val_metric: [vis_sel] 8 | key_metric_name: sel_ang_acc 9 | output_interval: 15 10 | 11 | ###########dataset########## 12 | 
train_dataset_type: sel_train 13 | train_dataset_cfg: 14 | use_database_sample_prob: true 15 | database_sample_prob: [ 10, 30, 10, 10 ] 16 | database_names: [ 'gso_train_128', 'shapenet_train', 'linemod_train', 'genmop_train' ] 17 | 18 | batch_size: 2 19 | ref_type: fps_64 20 | 21 | use_render: false 22 | resolution: 128 23 | reference_num: 64 24 | 25 | selector_scale_range: [-0.3, 0.3] 26 | selector_angle_range: [-90, 90] 27 | selector_angles: [-90, -45, 0, 45, 90] 28 | selector_real_aug: true 29 | 30 | val_set_list: 31 | - 32 | name: cat_val 33 | type: sel_val 34 | cfg: 35 | ref_database_name: linemod/cat 36 | test_database_name: linemod/cat 37 | ref_split_type: linemod_val 38 | test_split_type: linemod_val 39 | selector_angles: [-90, -45, 0, 45, 90] 40 | selector_ref_num: 64 41 | selector_ref_res: 128 42 | - 43 | name: warrior_val 44 | type: sel_val 45 | cfg: 46 | ref_database_name: genmop/tformer-ref 47 | test_database_name: genmop/tformer-test 48 | ref_split_type: all 49 | test_split_type: all 50 | selector_angles: [-90, -45, 0, 45, 90] 51 | selector_ref_num: 64 52 | selector_ref_res: 128 53 | 54 | ##########optimizer########## 55 | optimizer_type: adam 56 | lr_type: exp_decay 57 | lr_cfg: 58 | lr_init: 1.0e-4 59 | decay_step: 100000 60 | decay_rate: 0.5 61 | total_step: 300000 62 | train_log_step: 50 63 | val_interval: 5000 64 | save_interval: 500 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/dataset/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/gen6d/Gen6D/dataset/__init__.py -------------------------------------------------------------------------------- /src/gen6d/Gen6D/network/__init__.py: -------------------------------------------------------------------------------- 1 | from .detector import Detector 2 | from .refiner import VolumeRefiner 3 | from .selector import ViewpointSelector 4 | 5 | name2network={ 6 | 'refiner': VolumeRefiner, 7 | 'detector': Detector, 8 | 'selector': ViewpointSelector, 9 | } -------------------------------------------------------------------------------- /src/gen6d/Gen6D/network/attention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | def attention(query, key, value, key_mask=None, temperature=1.0): 5 | """ 6 | @param query: b,d,h,n 7 | @param key: b,d,h,m 8 | @param value: b,d,h,m 9 | @param key_mask: b,1,1,m 10 | @param temperature: 11 | @return: 12 | """ 13 | dim = query.shape[1] 14 | scores = torch.einsum('bdhn,bdhm->bhnm', query / temperature, key) / dim ** .5 # b,head,n0,n1 15 | if key_mask is not None: scores = scores.masked_fill(key_mask == 0, -1e7) 16 | prob = torch.nn.functional.softmax(scores, dim=-1) 17 | return torch.einsum('bhnm,bdhm->bdhn', prob, value), prob 18 | 19 | class SpecialLayerNorm(nn.Module): 20 | def __init__(self, in_dim): 21 | super().__init__() 22 | self.norm=nn.LayerNorm(in_dim) 23 | 24 | def forward(self,x): 25 | x = self.norm(x.permute(0,2,1)) 26 | return x.permute(0,2,1) 27 | 28 | class AttentionBlock(nn.Module): 29 | def __init__(self,in_dim,att_dim,out_dim,head_num=4,temperature=1.0,bias=True,skip_connect=True,norm='layer'): 30 | super().__init__() 31 | self.conv_key=nn.Conv1d(in_dim,att_dim,1,bias=bias) 32 | self.conv_query=nn.Conv1d(in_dim,att_dim,1,bias=bias) 33 | 
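# (the two 1x1 convs above project the inputs to multi-head queries/keys for the attention scores; the next two give the value projection and the post-attention head-merge)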
self.conv_feats=nn.Conv1d(in_dim,out_dim,1,bias=bias) 34 | self.conv_merge=nn.Conv1d(out_dim,out_dim,1,bias=bias) 35 | 36 | self.head_att_dim=att_dim//head_num 37 | self.head_out_dim=out_dim//head_num 38 | self.head_num=head_num 39 | self.temperature=temperature 40 | if norm=='layer': 41 | self.norm=SpecialLayerNorm(out_dim) 42 | elif norm=='instance': 43 | self.norm = nn.InstanceNorm1d(out_dim) 44 | else: 45 | raise NotImplementedError 46 | self.skip_connect = skip_connect 47 | if skip_connect: 48 | assert(in_dim==out_dim) 49 | 50 | def forward(self, feats_query, feats_key, key_mask=None): 51 | ''' 52 | :param feats_query: b,f,n0 53 | :param feats_key: b,f,n1 54 | :param key_mask: b,1,n1 55 | :return: b,f,n0 56 | ''' 57 | b,f,n0=feats_query.shape 58 | b,f,n1=feats_key.shape 59 | 60 | query=self.conv_query(feats_query).reshape(b, self.head_att_dim, self.head_num, n0) # b,had,hn,n0 61 | key=self.conv_key(feats_key).reshape(b, self.head_att_dim, self.head_num, n1) # b,had,hn,n1 62 | feats=self.conv_feats(feats_key).reshape(b, self.head_out_dim, self.head_num, n1) # b,hod,hn,n1 63 | if key_mask is not None: key_mask = key_mask.reshape(b, 1, 1, n1) # b,1,1,n1 64 | feats_out, weights = attention(query, key, feats, key_mask, self.temperature) 65 | feats_out = feats_out.reshape(b,self.head_out_dim*self.head_num,n0) # b,hod*hn,n0 66 | feats_out = self.conv_merge(feats_out) 67 | if self.skip_connect: feats_out=feats_out+feats_query 68 | feats_out = self.norm(feats_out) 69 | return feats_out -------------------------------------------------------------------------------- /src/gen6d/Gen6D/network/operator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | def normalize_coords(coords: torch.Tensor, h, w): 5 | """ 6 | normalzie coords to [-1,1] 7 | @param coords: 8 | @param h: 9 | @param w: 10 | @return: 11 | """ 12 | coords = torch.clone(coords) 13 | coords = coords + 0.5 14 | coords[...,0] = coords[...,0]/w 15 | coords[...,1] = coords[...,1]/h 16 | coords = (coords - 0.5)*2 17 | return coords 18 | 19 | def pose_apply_th(poses,pts): 20 | return pts @ poses[:,:,:3].permute(0,2,1) + poses[:,:,3:].permute(0,2,1) 21 | 22 | def generate_coords(h,w,device): 23 | coords=torch.stack(torch.meshgrid(torch.arange(h,device=device),torch.arange(w,device=device)),-1) 24 | return coords[...,(1,0)] -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/Config.py: -------------------------------------------------------------------------------- 1 | # SAVE_normalized_ref_imgs=1 2 | DO_NOT_load_model=0 3 | PR_AS_INIT=0 4 | LOCAL=0#run in local machine(redmi 5 | if(LOCAL): 6 | DO_NOT_load_model=1 -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/DebugUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import Image 3 | import cv2 4 | 5 | def show_img(img): 6 | if isinstance(img, np.ndarray): 7 | 8 | cv2.imshow('Image', img) 9 | cv2.waitKey(0) 10 | cv2.destroyAllWindows() 11 | elif isinstance(img, Image.Image): 12 | 13 | img.show() 14 | elif isinstance(img, str): 15 | 16 | image = cv2.imread(img) 17 | cv2.imshow('Image', image) 18 | cv2.waitKey(0) 19 | cv2.destroyAllWindows() 20 | else: 21 | print("Unsupported image type.") 22 | 23 | def save_img(img, path): 24 | if isinstance(img, np.ndarray): 25 | 26 | cv2.imwrite(path, img) 27 | elif isinstance(img, Image.Image): 28 | 29 
| img.save(path) 30 | else: 31 | print("Unsupported image type.") -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/IntermediateResult.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json, math 3 | 4 | 5 | class IntermediateResult: 6 | def __init__(s, ): 7 | s.data = {} 8 | 9 | def append(s, i, K, pose): 10 | s.data[i] = { 11 | "K": K, 12 | "pose": pose 13 | } 14 | 15 | def load(s, path): 16 | with open(path, "r") as f: 17 | s.data = json.load(f) 18 | 19 | for i in s.data: 20 | for key in s.data[i]: 21 | s.data[i][key] = np.array(s.data[i][key]) 22 | 23 | def dump(self, path): 24 | 25 | import json 26 | import numpy 27 | from torch import Tensor 28 | 29 | class MyJSONEncoder(json.JSONEncoder): 30 | def default(self, obj): 31 | if isinstance(obj, numpy.ndarray): 32 | return obj.tolist() 33 | if isinstance(obj, Tensor): 34 | return obj.cpu().data.numpy().tolist() 35 | elif (isinstance(obj, numpy.int32) or 36 | isinstance(obj, numpy.int64) or 37 | isinstance(obj, numpy.float32) or 38 | isinstance(obj, numpy.float64)): 39 | return obj.item() 40 | return json.JSONEncoder.default(self, obj) 41 | 42 | with open(path, "w") as f: 43 | json.dump(self.data, f, cls=MyJSONEncoder) 44 | -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/MyJSONEncoder.py: -------------------------------------------------------------------------------- 1 | 2 | import json 3 | import numpy 4 | from torch import Tensor 5 | 6 | 7 | def to_list_to_primitive(obj): 8 | if isinstance(obj, numpy.ndarray): 9 | return obj.tolist() 10 | if isinstance(obj, Tensor): 11 | return obj.cpu().data.numpy().tolist() 12 | if isinstance(obj, list): 13 | return [to_list_to_primitive(i) for i in obj] 14 | # if isinstance(obj, DataFrame): 15 | # return obj.values.tolist() 16 | elif (isinstance(obj, numpy.int32) or 17 | isinstance(obj, numpy.int64) or 18 | isinstance(obj, numpy.float32) or 19 | isinstance(obj, numpy.float64)): 20 | return obj.item() 21 | else: 22 | assert 0 23 | 24 | 25 | class MyJSONEncoder(json.JSONEncoder): 26 | def default(self, obj): 27 | if isinstance(obj, numpy.ndarray): 28 | return obj.tolist() 29 | if isinstance(obj, Tensor): 30 | return obj.cpu().data.numpy().tolist() 31 | elif (isinstance(obj, numpy.int32) or 32 | isinstance(obj, numpy.int64) or 33 | isinstance(obj, numpy.float32) or 34 | isinstance(obj, numpy.float64)): 35 | return obj.item() 36 | return json.JSONEncoder.default(self, obj) 37 | -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/Zero123Detector.py: -------------------------------------------------------------------------------- 1 | from .DebugUtil import * 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torchvision 7 | import numpy as np 8 | from PIL import Image 9 | import os 10 | from PIL import Image 11 | 12 | 13 | class Zero123Detector: 14 | def __init__(s): 15 | s.ref_h = None 16 | s.ref_w = None 17 | 18 | def load_ref_imgs(s, ref_imgs): 19 | """ 20 | @param ref_imgs: [an,rfn,h,w,3] in numpy 21 | @return: 22 | """ 23 | ref_imgs = torch.from_numpy(ref_imgs).permute(0, 3, 1, 2) # rfn,3,h,w 24 | rfn, _, h, w = ref_imgs.shape 25 | s.ref_h = h 26 | s.ref_w = w 27 | 28 | def detect_que_imgs(s, que_imgs): 29 | imgs = que_imgs 30 | """ 31 | io is same as raw gen6d detector: 32 | @param que_imgs: [qn,h,w,3] 33 | @return: 
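detection_outputs: {"positions": list of (x, y) crop-box centres (np.ndarray), one per query image; "scales": list of floats, each the detected box width divided by ref_w}, matching the output format of the raw Gen6D detector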
34 | """ 35 | positions = [] 36 | scales = [] 37 | 38 | def crop(img, bg_color, h, w, **kw): 39 | """ 40 | param: 41 | img: PIL Image 对象 42 | bg_color: 背景颜色,形如 (R, G, B) 的元组 43 | size: 裁剪后的目标尺寸,形如 (width, height) 的元组 44 | return: 45 | cropped_img: 裁剪后的物体图像,PIL Image 对象 46 | """ 47 | 48 | img_array = np.array(img) 49 | 50 | 51 | diff_pixels = np.any(np.abs(img_array - np.array(bg_color)) > 5, axis=2) 52 | 53 | 54 | rows = np.any(diff_pixels, axis=1) 55 | cols = np.any(diff_pixels, axis=0) 56 | ymin, ymax = np.where(rows)[0][[0, -1]] 57 | xmin, xmax = np.where(cols)[0][[0, -1]] 58 | 59 | obj_width = xmax - xmin 60 | obj_height = ymax - ymin 61 | # margin 62 | margin_w_px = obj_width * kw["margin_percent"] 63 | margin_h_px = obj_height * kw["margin_percent"] 64 | obj_width = obj_width + margin_w_px * 2 65 | obj_height = obj_height + margin_h_px * 2 66 | 67 | target_width = w 68 | target_height = h 69 | if (obj_width / obj_height > target_width / target_height): 70 | adjusted_width = obj_width 71 | adjusted_height = obj_width * target_height / target_width 72 | else: 73 | adjusted_width = obj_height * target_width / target_height 74 | adjusted_height = obj_height 75 | 76 | 77 | x = xmin + (obj_width - adjusted_width) / 2 78 | y = ymin - (adjusted_height - obj_height) / 2 79 | 80 | 81 | x_end = x + adjusted_width 82 | y_end = y + adjusted_height 83 | 84 | 85 | x = max(0, x) 86 | y = max(0, y) 87 | x_end = min(img_array.shape[1], x_end) 88 | y_end = min(img_array.shape[0], y_end) 89 | # to int 90 | x = int(x) 91 | y = int(y) 92 | x_end = int(x_end) 93 | y_end = int(y_end) 94 | def get_crop(img_array, x, y, x_end, y_end): 95 | cropped_img_array = img_array[y:y_end, x:x_end, :] 96 | 97 | cropped_img = Image.fromarray(cropped_img_array) 98 | 99 | cropped_img = cropped_img.resize((w, h)) 100 | return cropped_img 101 | CROP = False 102 | if (CROP): 103 | return get_crop(img_array, x, y, x_end, y_end) 104 | else: 105 | position = np.asarray([(x+x_end)/2,(y+y_end)/2]) 106 | scale = adjusted_width/target_width 107 | return position,scale 108 | show_img(get_crop(img_array, x, y, x_end, y_end)) 109 | 110 | for img in imgs: 111 | position,scale=crop(img, (255,255,255), s.ref_h, s.ref_w, margin_percent=0.1) 112 | positions.append(position) 113 | scales.append(scale) 114 | detection_outputs = { 115 | "positions": positions, 116 | "scales": scales, 117 | } 118 | return detection_outputs 119 | -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/gen6d/Gen6D/scy/__init__.py -------------------------------------------------------------------------------- /src/gen6d/Gen6D/scy/gen6dGlobal.py: -------------------------------------------------------------------------------- 1 | import root_config 2 | 3 | 4 | class gen6dGlobal: 5 | USE_Zero123Detector=root_config.USE_white_bg_Detector -------------------------------------------------------------------------------- /src/gen6d/Gen6D/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/gen6d/Gen6D/utils/__init__.py 
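(Illustrative note, not a file in the repo: a minimal sketch of how the white-background Zero123Detector above could be driven. The import path, the 128x128 reference size and the all-black query image are assumptions for illustration, with src/ taken to be on sys.path.)
import numpy as np
from gen6d.Gen6D.scy.Zero123Detector import Zero123Detector
det = Zero123Detector()
det.load_ref_imgs(np.full((1, 128, 128, 3), 255, dtype=np.uint8))   # only records ref_h / ref_w
queries = [np.zeros((256, 256, 3), dtype=np.uint8)]                 # any pixel differing from white by more than 5 counts as object
out = det.detect_que_imgs(queries)
# out["positions"][0] is the (x, y) centre of the detected box; out["scales"][0] is the
# detected box width (after margin and aspect adjustment) divided by ref_w,
# the same interface the Gen6D pipeline expects from its detector.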
-------------------------------------------------------------------------------- /src/gen6d/Gen6D/utils/bbox_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def bboxes_lthw_squared(bboxes): 5 | """ 6 | @param bboxes: b,4 in lthw 7 | @return: b,4 8 | """ 9 | bboxes_len = bboxes[:, 2:] 10 | bboxes_cen = bboxes[:, :2] + bboxes_len/2 11 | bboxes_max_len = torch.max(bboxes_len,1,keepdim=True)[0] # b,1 12 | bboxes_len = bboxes_max_len.repeat(1,2) 13 | bboxes_left_top = bboxes_cen - bboxes_len/2 14 | return torch.cat([bboxes_left_top,bboxes_len],1) 15 | 16 | def bboxes_area(bboxes): 17 | return (bboxes[...,2]-bboxes[...,0])*(bboxes[...,3]-bboxes[...,1]) 18 | 19 | def bboxes_iou(bboxes0, bboxes1,th=True): 20 | """ 21 | @param bboxes0: ...,4 22 | @param bboxes1: ...,4 23 | @return: ... 24 | """ 25 | if th: 26 | x0 = torch.max(torch.stack([bboxes0[..., 0], bboxes1[..., 0]], -1), -1)[0] 27 | y0 = torch.max(torch.stack([bboxes0[..., 1], bboxes1[..., 1]], -1), -1)[0] 28 | x1 = torch.min(torch.stack([bboxes0[..., 2], bboxes1[..., 2]], -1), -1)[0] 29 | y1 = torch.min(torch.stack([bboxes0[..., 3], bboxes1[..., 3]], -1), -1)[0] 30 | inter = torch.clip(x1 - x0, min=0) * torch.clip(y1 - y0, min=0) 31 | else: 32 | x0 = np.max(np.stack([bboxes0[..., 0], bboxes1[..., 0]], -1), -1)[0] 33 | y0 = np.max(np.stack([bboxes0[..., 1], bboxes1[..., 1]], -1), -1)[0] 34 | x1 = np.min(np.stack([bboxes0[..., 2], bboxes1[..., 2]], -1), -1)[0] 35 | y1 = np.min(np.stack([bboxes0[..., 3], bboxes1[..., 3]], -1), -1)[0] 36 | inter = np.clip(x1 - x0, a_min=0, a_max=999999) * np.clip(y1 - y0, a_min=0, a_max=999999) 37 | union = bboxes_area(bboxes0) + bboxes_area(bboxes1) - inter 38 | iou = inter / union 39 | return iou 40 | 41 | def lthw_to_ltrb(bboxes,th=True): 42 | if th: 43 | return torch.cat([bboxes[...,:2],bboxes[...,:2]+bboxes[...,2:]],-1) 44 | else: 45 | return np.concatenate([bboxes[..., :2], bboxes[..., :2] + bboxes[..., 2:]], -1) 46 | 47 | def cl_to_ltrb(bboxes_cl): 48 | bboxes_cen = bboxes_cl[...,:2] 49 | bboxes_len = bboxes_cl[...,2:] 50 | return torch.cat([bboxes_cen-bboxes_len/2,bboxes_cen+bboxes_len/2],-1) 51 | 52 | def ltrb_to_cl(bboxes_ltrb): 53 | bboxes_cen = (bboxes_ltrb[...,:2]+bboxes_ltrb[...,2:])/2 54 | bboxes_len = bboxes_ltrb[..., 2:]-bboxes_ltrb[...,:2] 55 | return torch.cat([bboxes_cen,bboxes_len],-1) 56 | 57 | def ltrb_to_lthw(bboxes,th=True): 58 | if th: 59 | raise NotImplementedError 60 | else: 61 | lt = bboxes[...,:2] 62 | hw = bboxes[...,2:] - lt 63 | return np.concatenate([lt,hw],-1) 64 | 65 | def cl_to_lthw(bboxes_cl,th=True): 66 | if th: 67 | lt = bboxes_cl[..., :2] - bboxes_cl[..., 2:] / 2 68 | return torch.cat([lt, bboxes_cl[..., 2:]], -1) 69 | else: 70 | lt = bboxes_cl[..., :2] - bboxes_cl[..., 2:] / 2 71 | return np.concatenate([lt,bboxes_cl[...,2:]],-1) 72 | 73 | def parse_bbox_from_scale_offset(que_select_id, scale_pr, select_offset, pool_ratio, ref_shape): 74 | """ 75 | 76 | @param que_select_id: [2] x,y 77 | @param scale_pr: [hq,wq] 78 | @param select_offset: [2,hq,wq] 79 | @param pool_ratio: int 80 | @param ref_shape: [2] h,w 81 | @return: 82 | """ 83 | hr, wr = ref_shape 84 | select_x, select_y = que_select_id 85 | scale_pr = scale_pr 86 | offset_pr = select_offset 87 | scale_pr = scale_pr[select_y,select_x] 88 | scale_pr = 2**scale_pr 89 | pool_ratio = pool_ratio 90 | offset_x, offset_y = offset_pr[:,select_y,select_x] 91 | center_x, center_y = select_x+offset_x, select_y+offset_y 92 | center_x = 
(center_x + 0.5) * pool_ratio - 0.5 93 | center_y = (center_y + 0.5) * pool_ratio - 0.5 94 | h_pr, w_pr = hr * scale_pr, wr * scale_pr 95 | bbox_pr = np.asarray([center_x - w_pr/2, center_y-h_pr/2, w_pr, h_pr]) 96 | return bbox_pr -------------------------------------------------------------------------------- /src/gen6d/Gen6D/utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import random 4 | import torch 5 | 6 | def dummy_collate_fn(data_list): 7 | return data_list[0] 8 | 9 | def simple_collate_fn(data_list): 10 | ks=data_list[0].keys() 11 | outputs={k:[] for k in ks} 12 | for k in ks: 13 | if isinstance(data_list[0][k], dict): 14 | outputs[k] = {k_: [] for k_ in data_list[0][k].keys()} 15 | for k_ in data_list[0][k].keys(): 16 | for data in data_list: 17 | outputs[k][k_].append(data[k][k_]) 18 | outputs[k][k_]=torch.stack(outputs[k][k_], 0) 19 | else: 20 | for data in data_list: 21 | outputs[k].append(data[k]) 22 | if isinstance(data_list[0][k], torch.Tensor): 23 | outputs[k]=torch.stack(outputs[k],0) 24 | return outputs 25 | 26 | def set_seed(index,is_train): 27 | if is_train: 28 | np.random.seed((index+int(time.time()))%(2**16)) 29 | random.seed((index+int(time.time()))%(2**16)+1) 30 | torch.random.manual_seed((index+int(time.time()))%(2**16)+1) 31 | else: 32 | np.random.seed(index % (2 ** 16)) 33 | random.seed(index % (2 ** 16) + 1) 34 | torch.random.manual_seed(index % (2 ** 16) + 1) -------------------------------------------------------------------------------- /src/gen6d/Gen6D/utils/imgs_info.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | from .base_utils import color_map_forward 5 | 6 | 7 | def imgs_info_to_torch(imgs_info): 8 | for k, v in imgs_info.items(): 9 | if isinstance(v,np.ndarray): 10 | imgs_info[k] = torch.from_numpy(v) 11 | return imgs_info 12 | 13 | def build_imgs_info(database, ref_ids, has_mask=True): 14 | ref_Ks = np.asarray([database.get_K(ref_id) for ref_id in ref_ids], dtype=np.float32) 15 | 16 | ref_imgs = [database.get_image(ref_id) for ref_id in ref_ids] 17 | if has_mask: ref_masks = [database.get_mask(ref_id) for ref_id in ref_ids] 18 | else: ref_masks = None 19 | 20 | ref_imgs = (np.stack(ref_imgs, 0)).transpose([0, 3, 1, 2]) 21 | ref_imgs = color_map_forward(ref_imgs) 22 | if has_mask: ref_masks = np.stack(ref_masks, 0)[:, None, :, :] 23 | ref_poses = np.asarray([database.get_pose(ref_id) for ref_id in ref_ids], dtype=np.float32) 24 | 25 | ref_imgs_info = {'imgs': ref_imgs, 'poses': ref_poses, 'Ks': ref_Ks} 26 | if has_mask: ref_imgs_info['masks'] = ref_masks 27 | return ref_imgs_info 28 | -------------------------------------------------------------------------------- /src/import_util.py: -------------------------------------------------------------------------------- 1 | 2 | import os,sys 3 | def is_in_sysPath(path): 4 | # print(os.getcwd()) 5 | # print(sys.path) 6 | tmp=[file for file in sys.path if os.path.exists(file)] 7 | return any([os.path.samefile(path,file) for file in tmp ]) 8 | def can_not_relative_import(file_path): 9 | path=os.path.abspath(os.path.join(os.path.dirname(file_path),os.path.pardir)) 10 | if(is_in_sysPath(path=path)): 11 | return 1 12 | path=os.path.abspath(os.path.join(os.path.dirname(file_path) )) 13 | if(is_in_sysPath(path=path)): 14 | return 1 15 | return 0 16 | def import_relposepp_evaluate_pairwise(): 17 | import sys,root_config 18 | 
sys.path.append( root_config.projPath_relposepp ) 19 | 20 | from eval.eval_rotation_util import evaluate_pairwise 21 | 22 | for i,p in enumerate(sys.path): 23 | if p==root_config.projPath_relposepp: 24 | del sys.path[i] 25 | break 26 | print("sys.path=",sys.path) 27 | return evaluate_pairwise -------------------------------------------------------------------------------- /src/imports.py: -------------------------------------------------------------------------------- 1 | import root_config,os,sys,shutil 2 | from skimage.io import imread, imsave 3 | from logging_util import * 4 | from pose_util import * 5 | from misc_util import * 6 | from exception_util import * 7 | from dataset_util import * 8 | import vis.cv2_util as cv2_util 9 | from vis.InterVisualizer import InterVisualizer 10 | from vis.vis_rel_pose import PoseVisualizer 11 | import numpy as np 12 | import os,sys,math,functools,inspect,PIL 13 | class Global: 14 | anything={} 15 | class ImagePair: 16 | def __init__(self,): 17 | self.l=[] 18 | def append(self,im): 19 | __N = 1 if root_config.one_SEQ_mul_Q0__one_Q0_mul_Q1 else 2 20 | self.l.append(im) 21 | if(len(self.l)>__N): 22 | self.l=self.l[-__N:] 23 | intermediate={ 24 | "E2VG":{ 25 | "inter_img":ImagePair(), 26 | }, 27 | } 28 | 29 | poseVisualizer1=PoseVisualizer() 30 | # interVisualizer=InterVisualizer() 31 | class RefinerInterPoses: 32 | """ 33 | """ 34 | __l=[]# list of refiner raw output pose 35 | """ 36 | @classmethod 37 | def from_dicValue(cls,l:list): 38 | assert cls.__l ==[] 39 | assert l 40 | cls.__l=l 41 | """ 42 | @classmethod 43 | def set_evalResult(cls,_evalResult, ): 44 | cls.__evalResult = _evalResult 45 | @classmethod 46 | def load_pair(cls,sequence_name, i, j): 47 | assert cls.__l ==[] 48 | l=cls.__evalResult .get_pair__in_dic(sequence_name, i, j) ['RefinerInterPoses'] 49 | assert l 50 | cls.__l=l 51 | @classmethod 52 | def to_dicValue_and_clear(cls, )->list: 53 | assert cls.__l 54 | ret=cls.__l 55 | cls.__l=[] 56 | ret=to_list_to_primitive(ret) 57 | return ret 58 | """ 59 | @classmethod 60 | def append(cls,i,raw ): 61 | assert len(cls.__l)==i 62 | cls.__l.append(raw) 63 | """ 64 | @classmethod 65 | def set(cls,l ):#l: [before refiner,after 1st refine,2nd,...] 
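# l holds one 3x4 pose per stage: the estimate before refinement plus the result of each refine iteration (hence len(l) == REFINE_ITER + 1); get() below returns l[root_config.ABLATE_REFINE_ITER] converted to a 4x4 matrix.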
66 | assert len(l)==root_config.REFINE_ITER+1 67 | if root_config.ABLATE_REFINE_ITER is not None: 68 | assert cls.__l==[] 69 | cls.__l=l 70 | @classmethod 71 | def get(cls, ): 72 | ret=cls.__l[root_config.ABLATE_REFINE_ITER] 73 | ret=np.array(ret) 74 | ret=Pose_R_t_Converter.pose34_2_pose44(ret) 75 | return ret 76 | from debug_util import debug_imsave 77 | -------------------------------------------------------------------------------- /src/logging_util.py: -------------------------------------------------------------------------------- 1 | # from rich import print 2 | import rich 3 | __primitive_print=print 4 | def print(*args, 5 | use_primitive=1, 6 | **kw): 7 | if use_primitive: 8 | return __primitive_print(*args,**kw) 9 | else: 10 | 11 | return rich.print(*[arg.replace('[',r'\[') if isinstance(arg,str) else arg for arg in args ],**kw, ) 12 | from redirect_util import * 13 | # ------------------------------------------------------ 14 | EXP_fp_and_lineNo="f'{os.path.abspath(__file__)}:{inspect.currentframe().f_lineno}'" 15 | EXP_print_fp_and_lineNo="eval(f'{os.path.abspath(__file__)}:{inspect.currentframe().f_lineno}')" 16 | #------------------ log_Util ---------------------------- V2022CFG230119 17 | from logging import debug as DEBUG 18 | ddd=DEBUG 19 | from logging import info as INFO 20 | from logging import warning as WARNING 21 | from logging import error as ERROR 22 | def pDEBUG(*args): 23 | DEBUG(" ".join([str(arg) for arg in args])) 24 | def pINFO(*args): 25 | INFO(" ".join([str(arg) for arg in args])) 26 | def pWARNING(*args): 27 | WARNING(" ".join([str(arg) for arg in args])) 28 | def pERROR(*args): 29 | ERROR(" ".join([str(arg) for arg in args])) 30 | import logging 31 | ## create logger with 'spam_application' 32 | # logger = logging.getLogger("My_app") 33 | logger = logging.root 34 | class _CustomFormatter(logging.Formatter): 35 | blue = "\x1b[1;34m" 36 | light_blue = "\x1b[1;36m" 37 | purple = "\x1b[1;35m" 38 | normal="\x1b[0;20m" 39 | white="\x1b[37;20m" 40 | cyan= "\x1b[36;20m"# 41 | grey = "\x1b[38;20m" 42 | yellow = "\x1b[33;20m" 43 | green = "\x1b[32;20m" 44 | red = "\x1b[31;20m" 45 | bold_red = "\x1b[31;1m" 46 | white_on_red_bg="\x1b[41;20m" 47 | reset = "\x1b[0m" 48 | 49 | # format = "%(asctime)s%(filename)s-%(lineno)d-%(funcName)s [%(levelname)s] %(message)s" 50 | format = "%(asctime)s %(filename)s:%(lineno)d %(funcName)s [%(levelname)s] %(message)s" 51 | 52 | FORMATS = { 53 | logging.DEBUG: cyan + format + reset, 54 | # logging.INFO: normal + format + reset, 55 | logging.INFO: white + format + reset, 56 | logging.WARNING: yellow + format + reset, 57 | logging.ERROR: white_on_red_bg + format + reset, 58 | logging.CRITICAL: white_on_red_bg + format + reset 59 | } 60 | 61 | def format(self, record): 62 | log_fmt = self.FORMATS.get(record.levelno) 63 | formatter = logging.Formatter(fmt=log_fmt,datefmt="%H:%M:%S") 64 | return formatter.format(record) 65 | 66 | 67 | 68 | 69 | def _configure_logging(level=logging.DEBUG): 70 | print("_configure_logging") 71 | # logging.basicConfig( 72 | 73 | 74 | 75 | 76 | 77 | # logging.basicConfig(datefmt='%M:%S') 78 | logger.setLevel(level) 79 | 80 | 81 | # create console handler with a higher log level 82 | ch = logging.StreamHandler() 83 | ch.setLevel(logging.DEBUG) 84 | 85 | ch.setFormatter(_CustomFormatter()) 86 | 87 | logger.addHandler(ch) 88 | print("_configure_logging over") 89 | 90 | _configure_logging(level=logging.INFO) 91 | 92 | if(__name__=="__main__"): 93 | INFO("INFO") 94 | ERROR("ERROR") 95 | DEBUG("DEBUG") 96 | 
WARNING("WARNING") -------------------------------------------------------------------------------- /src/mask_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys 3 | 4 | 5 | class Mask: 6 | @staticmethod 7 | def hw_255__2__hw_bool(arr:np.ndarray,THRES=125): 8 | """ 9 | hwAny 10 | """ 11 | assert arr.dtype==np.uint8 12 | arr=arr.astype(np.float32)/255 13 | arr[arr<THRES/255]=0 14 | arr[arr>=THRES/255]=1 15 | return arr 16 | 17 | @staticmethod 18 | def hw0__2__hw1(arr: np.ndarray,): 19 | assert len(arr.shape)==2 20 | arr=arr[:,:,None] 21 | return arr 22 | @staticmethod 23 | def hw0__2__hw3(arr: np.ndarray,): 24 | assert len(arr.shape)==2 25 | arr=arr[:,:,None] 26 | arr=np.concatenate([arr,arr,arr],axis=-1) 27 | assert len(arr.shape)==3 and arr.shape[-1]==3 28 | return arr 29 | @staticmethod 30 | def hw3__2__hw0(arr: np.ndarray,): 31 | assert len(arr.shape)==3 and arr.shape[-1]==3 32 | arr_c0=arr[:,:,0] 33 | arr_c1=arr[:,:,1] 34 | arr_c2=arr[:,:,2] 35 | THRES=0 36 | 37 | a01 =np.abs(arr_c0-arr_c1)>THRES 38 | if np.any(a01): 39 | print("np.where(a01)",np.where(a01)) 40 | print("corresponding arr_c0", arr_c0[np.where(a01)]) 41 | print("corresponding arr_c1", arr_c1[np.where(a01)]) 42 | assert 0 43 | a02 =np.abs(arr_c0-arr_c2)>THRES 44 | if np.any(a02): 45 | print("np.where(a02)",np.where(a02)) 46 | print("corresponding arr_c0", arr_c0[np.where(a02)]) 47 | print("corresponding arr_c2", arr_c2[np.where(a02)]) 48 | assert 0 49 | return arr_c0 50 | @staticmethod 51 | def get_blackBg_maskedImage_from_hw0_255( img: np.ndarray,mask: np.ndarray,THRES=125): 52 | assert len(img.shape)==3 53 | assert len(mask.shape)==2 54 | assert img.shape[:2]==mask.shape[:2] 55 | mask=Mask.hw_255__2__hw_bool(mask,THRES=THRES) 56 | mask=Mask.hw0__2__hw1(mask) 57 | assert img.shape[:2]==mask.shape[:2] 58 | masked_image=img*mask 59 | return masked_image 60 | 61 | @staticmethod 62 | def get_blackBg_maskedImage_from_hw1_255(img: np.ndarray, mask: np.ndarray, THRES=125): 63 | assert len(img.shape) == 3 64 | assert len(mask.shape) == 2 65 | mask = Mask.hw_255__2__hw_bool(mask, THRES=THRES) 66 | masked_image = img * mask 67 | return masked_image 68 | 69 | @staticmethod 70 | def get_whiteBg_maskedImage_from_hw0_255( img: np.ndarray,mask: np.ndarray,THRES=125): 71 | assert len(img.shape)==3 72 | assert len(mask.shape)==2 73 | mask=Mask.hw_255__2__hw_bool(mask,THRES=THRES) 74 | mask=Mask.hw0__2__hw1(mask) 75 | # mask==1, use img; else use white 76 | white_image = np.ones_like(img) * 255 77 | t=img * mask 78 | tt=np.array(t,dtype=np.uint8) 79 | t2=white_image * (1 - mask) 80 | masked_image = img * mask + white_image * (1 - mask) 81 | return masked_image 82 | 83 | 84 | @staticmethod 85 | def get_whiteBg_maskedImage_from_hw1_255(img: np.ndarray, mask: np.ndarray, THRES=125): 86 | assert len(img.shape) == 3 87 | assert len(mask.shape) == 2 88 | mask = Mask.hw_255__2__hw_bool(mask, THRES=THRES) 89 | # mask==1, use img; else use white 90 | white_image = np.ones_like(img) * 255 91 | masked_image = img * mask + white_image * (1 - mask) 92 | return masked_image 93 | @staticmethod 94 | def rgbaImage__2__hw0_255(img: np.ndarray , ALPHA_THRES=0): 95 | assert len(img.shape)==3 96 | assert img.shape[-1]==4 97 | alpha=img[:,:,3] 98 | assert np.all(alpha<=255) 99 | assert np.all(alpha>=0) 100 | hw0_255=alpha 101 | hw0_255[hw0_255<=ALPHA_THRES]=0 102 | hw0_255[hw0_255>ALPHA_THRES]=255 103 | return hw0_255 104 | @staticmethod 105 | def mask_hw0_bool__2__bbox(mask): # TODO check 106 | # check all
are 0 or 1; check shape 107 | unique_values = np.unique(mask) 108 | assert np.array_equal(unique_values, np.array([0, 1])) 109 | # 110 | assert len(mask.shape)==2 111 | # 112 | mask = np.where(mask == True) 113 | # x1, y1 = np.min(mask, axis=1) 114 | # x2, y2 = np.max(mask, axis=1) 115 | y1, x1 = np.min(mask, axis=1) 116 | y2, x2 = np.max(mask, axis=1) 117 | return [x1, y1, x2, y2] -------------------------------------------------------------------------------- /src/misc_util.py: -------------------------------------------------------------------------------- 1 | 2 | import os,time,root_config 3 | import numpy as np 4 | from pathlib import Path 5 | import pprint 6 | class ch_cwd_to_this_file: 7 | def __init__(self, _code_file_path): 8 | self._code_file_path = _code_file_path 9 | def __enter__(self): 10 | self._old_dir = os.getcwd() 11 | cwd=os.path.dirname(os.path.abspath(self._code_file_path)) 12 | os.chdir(cwd) 13 | def __exit__(self, exc_type, exc_val, exc_tb): 14 | os.chdir(self._old_dir) 15 | # def img_2_img_full_path(img,format='jpg',original_name_or_path=''): 16 | # """ 17 | # thread safe 18 | # """ 19 | # assert isinstance(img,np.ndarray) 20 | # assert img.shape[2]==3 or img.shape[2]==4 21 | # original_img_name_without_dir=os.path.basename(original_name_or_path) 22 | # full_path = os.path.join(root_config.path_root, f'./tmp_images/[{root_config.DATASET}][{tmp_cate_or_obj}][{sequence_name}]{img_name_without_suffix}.jpg') 23 | # if not os.path.exists(os.path.dirname(full_path)): 24 | # os.makedirs(os.path.dirname(full_path)) 25 | # print("get_data path:", full_path) 26 | # img.save(full_path) 27 | # return img_full_path 28 | 29 | import datetime 30 | import pytz 31 | def your_datetime()->datetime.datetime: 32 | """ 33 | """ 34 | 35 | local_tz = datetime.datetime.now(datetime.timezone.utc).astimezone().tzinfo 36 | 37 | 38 | now = datetime.datetime.now() 39 | 40 | local_time = now.astimezone(local_tz) 41 | 42 | return local_time 43 | def get_datetime_str( os_is_windows=False)->str: 44 | 45 | ret= f"{your_datetime():%m.%d-%H:%M:%S}" 46 | if os_is_windows: 47 | ret=ret.replace(':','-') 48 | return ret 49 | 50 | 51 | 52 | 53 | 54 | 55 | import json 56 | import numpy 57 | from torch import Tensor 58 | 59 | 60 | def to_list_to_primitive(obj): 61 | if isinstance(obj, numpy.ndarray): 62 | return obj.tolist() 63 | if isinstance(obj, Tensor): 64 | return obj.cpu().data.numpy().tolist() 65 | if isinstance(obj, list): 66 | return [to_list_to_primitive(i) for i in obj] 67 | # if isinstance(obj, DataFrame): 68 | # return obj.values.tolist() 69 | elif (isinstance(obj, numpy.int32) or 70 | isinstance(obj, numpy.int64) or 71 | isinstance(obj, numpy.float32) or 72 | isinstance(obj, numpy.float64)): 73 | return obj.item() 74 | elif (isinstance(obj, int) or 75 | isinstance(obj, float) 76 | ): 77 | return obj 78 | else: 79 | assert 0 80 | 81 | 82 | class MyJSONEncoder(json.JSONEncoder): 83 | def default(self, obj): 84 | if isinstance(obj, numpy.ndarray): 85 | return obj.tolist() 86 | if isinstance(obj, Tensor): 87 | return obj.cpu().data.numpy().tolist() 88 | elif (isinstance(obj, numpy.int32) or 89 | isinstance(obj, numpy.int64) or 90 | isinstance(obj, numpy.float32) or 91 | isinstance(obj, numpy.float64)): 92 | return obj.item() 93 | return json.JSONEncoder.default(self, obj) 94 | 95 | def truncate_str(string:str,MAX_LEN:int,suffix_if_truncate="......")->str: 96 | assert isinstance(string,str) 97 | if len(string)> MAX_LEN: 98 | string=string[:MAX_LEN]+suffix_if_truncate 99 | return string 100 | def 
map_string_to_int(string,MIN,MAX): 101 | """ 102 | """ 103 | assert isinstance(MIN,int) 104 | assert isinstance(MAX,int) 105 | assert MAX-MIN>=2 106 | 107 | sum = 0 108 | for char in string: 109 | sum += ord(char) 110 | # print("sum", sum) 111 | ret=2**sum 112 | ret += sum 113 | ret=ret%(MAX-MIN) 114 | ret+=MIN 115 | return ret 116 | 117 | 118 | def print_optimizer(optimizer): 119 | state_dict=optimizer.state_dict() 120 | param_groups=state_dict['param_groups'] 121 | # for i,param_group in enumerate(param_groups): 122 | pprint.pprint(param_groups) 123 | 124 | -------------------------------------------------------------------------------- /src/miscellaneous/MemoryCache.py: -------------------------------------------------------------------------------- 1 | """ 2 | for NFS, it takes several minutes to load zero123 weight from disk to mem. 3 | So I cache it in mem and sue IPC-socket to fetch it when running main program 4 | """ 5 | import socket 6 | import psutil 7 | 8 | def print_mem_occupied_by_me(): 9 | 10 | pid = psutil.Process().pid 11 | 12 | 13 | process = psutil.Process(pid) 14 | mem_info = process.memory_info() 15 | 16 | 17 | print(f"Memory occupied by current process (PID {pid}):") 18 | print(f"RSS (Resident Set Size): {mem_info.rss} bytes {mem_info.rss/1024/1024} MB {mem_info.rss/1024/1024/1024} GB") 19 | print(f"VMS (Virtual Memory Size): {mem_info.vms} bytes {mem_info.vms/1024/1024} MB {mem_info.vms/1024/1024/1024} GB") 20 | 21 | 22 | class MemoryCache: 23 | PORT = 40639 24 | bytes_ = None 25 | BYTES_OF_DATASIZE=639 26 | @classmethod 27 | def send_dataSize(cls,client_socket): 28 | 29 | size = len(cls.bytes_) 30 | print(f"dataSize = {size} B") 31 | size_bytes = size.to_bytes(cls.BYTES_OF_DATASIZE, byteorder='big') 32 | client_socket.sendall(size_bytes) 33 | @classmethod 34 | def recv_dataSize(cls, server_socket)->int: 35 | 36 | size_bytes = server_socket.recv(cls.BYTES_OF_DATASIZE) 37 | size = int.from_bytes(size_bytes, byteorder='big') 38 | print(f"dataSize = {size} B") 39 | return size 40 | @classmethod 41 | def run_as_server(cls, path_fileToKeepInMemory: str): 42 | """ 43 | From 1024 to 49151: These ports are known as the Registered ports. These ports can be used by ordinary user processes or programs executed by ordinary users. 44 | From 49152 to 65535: These ports are known as Dynamic Ports. 45 | """ 46 | 47 | 48 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 49 | server_socket.bind(('localhost', cls.PORT)) 50 | server_socket.listen(1) 51 | # 52 | print('reading...') 53 | with open(path_fileToKeepInMemory, 'rb') as f: 54 | cls.bytes_ = f.read() 55 | print('read over') 56 | print_mem_occupied_by_me() 57 | while True: 58 | print('等待连接...') 59 | client_socket, _ = server_socket.accept() 60 | try: 61 | 62 | print('有客户端连接.sending...') 63 | cls.send_dataSize(client_socket) 64 | 65 | client_socket.sendall(cls.bytes_) # eg. 
b'1010111', f.read() 66 | print('send over') 67 | except Exception as e: 68 | print(f"e=",e) 69 | finally: 70 | 71 | client_socket.close() 72 | @classmethod 73 | def receive(cls)->bytes: 74 | 75 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 76 | client_socket.connect(('localhost', cls.PORT)) 77 | 78 | 79 | SIZE=cls.recv_dataSize(client_socket) 80 | model_data = bytearray(SIZE) 81 | bytes_received = 0 82 | 83 | while bytes_received < SIZE: 84 | data = client_socket.recv(SIZE - bytes_received) 85 | if not data: 86 | break 87 | model_data[bytes_received:bytes_received + len(data)] = data 88 | bytes_received += len(data) 89 | model_data = bytes(model_data) 90 | print(f"Received {len(model_data)/(1024*1024)} MB, {len(model_data)/(1024*1024*1024)} GB") 91 | assert len(model_data)==SIZE 92 | print('len(model_data)==SIZE,恭喜!') 93 | 94 | client_socket.close() 95 | return model_data 96 | if __name__ == '__main__': 97 | import sys,os 98 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 99 | # os.chdir(cur_dir) 100 | sys.path.append(os.path.join(cur_dir, "..")) 101 | import root_config 102 | MemoryCache.run_as_server(path_fileToKeepInMemory=root_config.weightPath_zero123) -------------------------------------------------------------------------------- /src/miscellaneous/Zero123_BatchB_Input.py: -------------------------------------------------------------------------------- 1 | class Zero123_BatchB_Input: 2 | def __init__(self, 3 | id_:str, 4 | folder_outputIms:str, 5 | input_image_path:str, 6 | l_xyz:list, 7 | ): 8 | """ 9 | folder_outputIms: 10 | 1. initially, means: where to put Ig. name of folder, not path 11 | 2. after sample_model_batchB_wrapper ,its meaning turn from name to path 12 | input_image_path: 13 | 1. initially, means: full path of input image 14 | 2. after sample_model_batchB_wrapper , its meaning turn from path to tensor 15 | outputims: 16 | 1. None 17 | 2. list of output image path 18 | """ 19 | self.id_=id_ 20 | self.folder_outputIms=folder_outputIms 21 | self.input_image_path=input_image_path 22 | self.l_xyz=l_xyz 23 | self.outputims:list=None 24 | def __len__(self): 25 | return len(self.l_xyz) -------------------------------------------------------------------------------- /src/miscellaneous/m.py: -------------------------------------------------------------------------------- 1 | """ 2 | for NFS, it takes several minutes to load zero123 weight from disk to mem. 
3 | So I cache it in mem and sue IPC-socket to fetch wh 4 | """ 5 | import socket 6 | import psutil 7 | 8 | def print_mem_occupied_by_me(): 9 | 10 | pid = psutil.Process().pid 11 | 12 | 13 | process = psutil.Process(pid) 14 | mem_info = process.memory_info() 15 | 16 | 17 | print(f"Memory occupied by current process (PID {pid}):") 18 | print(f"RSS (Resident Set Size): {mem_info.rss} bytes {mem_info.rss/1024/1024} MB {mem_info.rss/1024/1024/1024} GB") 19 | print(f"VMS (Virtual Memory Size): {mem_info.vms} bytes {mem_info.vms/1024/1024} MB {mem_info.vms/1024/1024/1024} GB") 20 | 21 | 22 | class MemoryCache: 23 | PORT = 40639 24 | bytes_ = None 25 | BYTES_OF_DATASIZE=639 26 | @classmethod 27 | def send_dataSize(cls,client_socket): 28 | 29 | size = len(cls.bytes_) 30 | print(f"dataSize = {size} B") 31 | size_bytes = size.to_bytes(cls.BYTES_OF_DATASIZE, byteorder='big') 32 | client_socket.sendall(size_bytes) 33 | @classmethod 34 | def recv_dataSize(cls, server_socket)->int: 35 | 36 | size_bytes = server_socket.recv(cls.BYTES_OF_DATASIZE) 37 | size = int.from_bytes(size_bytes, byteorder='big') 38 | print(f"dataSize = {size} B") 39 | return size 40 | @classmethod 41 | def run_as_server(cls, path_fileToKeepInMemory: str): 42 | """ 43 | From 1024 to 49151: These ports are known as the Registered ports. These ports can be used by ordinary user processes or programs executed by ordinary users. 44 | From 49152 to 65535: These ports are known as Dynamic Ports. 45 | """ 46 | 47 | 48 | server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 49 | server_socket.bind(('localhost', cls.PORT)) 50 | server_socket.listen(1) 51 | # 52 | print('reading...') 53 | with open(path_fileToKeepInMemory, 'rb') as f: 54 | cls.bytes_ = f.read() 55 | print('read over') 56 | print_mem_occupied_by_me() 57 | while True: 58 | print('等待连接...') 59 | client_socket, _ = server_socket.accept() 60 | try: 61 | 62 | print('有客户端连接.sending...') 63 | cls.send_dataSize(client_socket) 64 | 65 | client_socket.sendall(cls.bytes_) # eg. 
b'1010111', f.read() 66 | print('send over') 67 | except Exception as e: 68 | print(f"e=",e) 69 | finally: 70 | 71 | client_socket.close() 72 | @classmethod 73 | def receive(cls)->bytes: 74 | 75 | client_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) 76 | client_socket.connect(('localhost', cls.PORT)) 77 | 78 | 79 | SIZE=cls.recv_dataSize(client_socket) 80 | model_data = bytearray(SIZE) 81 | bytes_received = 0 82 | 83 | while bytes_received < SIZE: 84 | data = client_socket.recv(SIZE - bytes_received) 85 | if not data: 86 | break 87 | model_data[bytes_received:bytes_received + len(data)] = data 88 | bytes_received += len(data) 89 | model_data = bytes(model_data) 90 | print(f"Received {len(model_data)/(1024*1024)} MB, {len(model_data)/(1024*1024*1024)} GB") 91 | assert len(model_data)==SIZE 92 | print('len(model_data)==SIZE,恭喜!') 93 | 94 | client_socket.close() 95 | return model_data 96 | if __name__ == '__main__': 97 | import sys,os 98 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 99 | # os.chdir(cur_dir) 100 | sys.path.append(os.path.join(cur_dir, "..")) 101 | import root_config 102 | MemoryCache.run_as_server(path_fileToKeepInMemory=root_config.weightPath_zero123) -------------------------------------------------------------------------------- /src/oee/.gitignore: -------------------------------------------------------------------------------- 1 | * 2 | !*/ 3 | !/.gitignore 4 | !/.rsync_exclude 5 | !*.py 6 | -------------------------------------------------------------------------------- /src/oee/models/loftr/__init__.py: -------------------------------------------------------------------------------- 1 | from .loftr import LoFTR 2 | from .utils.cvpr_ds_config import default_cfg 3 | -------------------------------------------------------------------------------- /src/oee/models/loftr/backbone/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_fpn import ResNetFPN_8_2, ResNetFPN_16_4 2 | 3 | 4 | def build_backbone(config): 5 | if config['backbone_type'] == 'ResNetFPN': 6 | if config['resolution'] == (8, 2): 7 | return ResNetFPN_8_2(config['resnetfpn']) 8 | elif config['resolution'] == (16, 4): 9 | return ResNetFPN_16_4(config['resnetfpn']) 10 | else: 11 | raise ValueError(f"LOFTR.BACKBONE_TYPE {config['backbone_type']} not supported.") 12 | -------------------------------------------------------------------------------- /src/oee/models/loftr/loftr.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from einops.einops import rearrange 4 | 5 | from .backbone import build_backbone 6 | from .utils.position_encoding import PositionEncodingSine 7 | from .loftr_module import LocalFeatureTransformer, FinePreprocess 8 | from .utils.coarse_matching import CoarseMatching 9 | from .utils.fine_matching import FineMatching 10 | 11 | 12 | class LoFTR(nn.Module): 13 | def __init__(self, config): 14 | super().__init__() 15 | # Misc 16 | self.config = config 17 | 18 | # Modules 19 | self.backbone = build_backbone(config) 20 | self.pos_encoding = PositionEncodingSine( 21 | config['coarse']['d_model'], 22 | temp_bug_fix=config['coarse']['temp_bug_fix']) 23 | self.loftr_coarse = LocalFeatureTransformer(config['coarse']) 24 | self.coarse_matching = CoarseMatching(config['match_coarse']) 25 | self.fine_preprocess = FinePreprocess(config) 26 | self.loftr_fine = LocalFeatureTransformer(config["fine"]) 27 | self.fine_matching = FineMatching() 28 | 29 | def 
forward(self, data): 30 | """ 31 | Update: 32 | data (dict): { 33 | 'image0': (torch.Tensor): (N, 1, H, W) 34 | 'image1': (torch.Tensor): (N, 1, H, W) 35 | 'mask0'(optional) : (torch.Tensor): (N, H, W) '0' indicates a padded position 36 | 'mask1'(optional) : (torch.Tensor): (N, H, W) 37 | } 38 | """ 39 | # 1. Local Feature CNN 40 | data.update({ 41 | 'bs': data['image0'].size(0), 42 | 'hw0_i': data['image0'].shape[2:], 'hw1_i': data['image1'].shape[2:] 43 | }) 44 | 45 | if data['hw0_i'] == data['hw1_i']: # faster & better BN convergence 46 | feats_c, feats_f = self.backbone(torch.cat([data['image0'], data['image1']], dim=0)) 47 | (feat_c0, feat_c1), (feat_f0, feat_f1) = feats_c.split(data['bs']), feats_f.split(data['bs']) 48 | else: # handle different input shapes 49 | (feat_c0, feat_f0), (feat_c1, feat_f1) = self.backbone(data['image0']), self.backbone(data['image1']) 50 | 51 | data.update({ 52 | 'hw0_c': feat_c0.shape[2:], 'hw1_c': feat_c1.shape[2:], 53 | 'hw0_f': feat_f0.shape[2:], 'hw1_f': feat_f1.shape[2:] 54 | }) 55 | 56 | # 2. coarse-level loftr module 57 | # add featmap with positional encoding, then flatten it to sequence [N, HW, C] 58 | feat_c0 = rearrange(self.pos_encoding(feat_c0), 'n c h w -> n (h w) c') 59 | feat_c1 = rearrange(self.pos_encoding(feat_c1), 'n c h w -> n (h w) c') 60 | 61 | mask_c0 = mask_c1 = None # mask is useful in training 62 | if 'mask0' in data: 63 | mask_c0, mask_c1 = data['mask0'].flatten(-2), data['mask1'].flatten(-2) 64 | feat_c0, feat_c1 = self.loftr_coarse(feat_c0, feat_c1, mask_c0, mask_c1) 65 | 66 | # 3. match coarse-level 67 | self.coarse_matching(feat_c0, feat_c1, data, mask_c0=mask_c0, mask_c1=mask_c1) 68 | 69 | # 4. fine-level refinement 70 | feat_f0_unfold, feat_f1_unfold = self.fine_preprocess(feat_f0, feat_f1, feat_c0, feat_c1, data) 71 | if feat_f0_unfold.size(0) != 0: # at least one coarse level predicted 72 | feat_f0_unfold, feat_f1_unfold = self.loftr_fine(feat_f0_unfold, feat_f1_unfold) 73 | 74 | # 5. 
match fine-level 75 | self.fine_matching(feat_f0_unfold, feat_f1_unfold, data) 76 | 77 | def load_state_dict(self, state_dict, *args, **kwargs): 78 | for k in list(state_dict.keys()): 79 | if k.startswith('matcher.'): 80 | state_dict[k.replace('matcher.', '', 1)] = state_dict.pop(k) 81 | return super().load_state_dict(state_dict, *args, **kwargs) 82 | -------------------------------------------------------------------------------- /src/oee/models/loftr/loftr_module/__init__.py: -------------------------------------------------------------------------------- 1 | from .transformer import LocalFeatureTransformer 2 | from .fine_preprocess import FinePreprocess 3 | -------------------------------------------------------------------------------- /src/oee/models/loftr/loftr_module/fine_preprocess.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from einops.einops import rearrange, repeat 5 | 6 | 7 | class FinePreprocess(nn.Module): 8 | def __init__(self, config): 9 | super().__init__() 10 | 11 | self.config = config 12 | self.cat_c_feat = config['fine_concat_coarse_feat'] 13 | self.W = self.config['fine_window_size'] 14 | 15 | d_model_c = self.config['coarse']['d_model'] 16 | d_model_f = self.config['fine']['d_model'] 17 | self.d_model_f = d_model_f 18 | if self.cat_c_feat: 19 | self.down_proj = nn.Linear(d_model_c, d_model_f, bias=True) 20 | self.merge_feat = nn.Linear(2*d_model_f, d_model_f, bias=True) 21 | 22 | self._reset_parameters() 23 | 24 | def _reset_parameters(self): 25 | for p in self.parameters(): 26 | if p.dim() > 1: 27 | nn.init.kaiming_normal_(p, mode="fan_out", nonlinearity="relu") 28 | 29 | def forward(self, feat_f0, feat_f1, feat_c0, feat_c1, data): 30 | W = self.W 31 | stride = data['hw0_f'][0] // data['hw0_c'][0] 32 | 33 | data.update({'W': W}) 34 | if data['b_ids'].shape[0] == 0: 35 | feat0 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 36 | feat1 = torch.empty(0, self.W**2, self.d_model_f, device=feat_f0.device) 37 | return feat0, feat1 38 | 39 | # 1. unfold(crop) all local windows 40 | feat_f0_unfold = F.unfold(feat_f0, kernel_size=(W, W), stride=stride, padding=W//2) 41 | feat_f0_unfold = rearrange(feat_f0_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 42 | feat_f1_unfold = F.unfold(feat_f1, kernel_size=(W, W), stride=stride, padding=W//2) 43 | feat_f1_unfold = rearrange(feat_f1_unfold, 'n (c ww) l -> n l ww c', ww=W**2) 44 | 45 | # 2. 
select only the predicted matches 46 | feat_f0_unfold = feat_f0_unfold[data['b_ids'], data['i_ids']] # [n, ww, cf] 47 | feat_f1_unfold = feat_f1_unfold[data['b_ids'], data['j_ids']] 48 | 49 | # option: use coarse-level loftr feature as context: concat and linear 50 | if self.cat_c_feat: 51 | feat_c_win = self.down_proj(torch.cat([feat_c0[data['b_ids'], data['i_ids']], 52 | feat_c1[data['b_ids'], data['j_ids']]], 0)) # [2n, c] 53 | feat_cf_win = self.merge_feat(torch.cat([ 54 | torch.cat([feat_f0_unfold, feat_f1_unfold], 0), # [2n, ww, cf] 55 | repeat(feat_c_win, 'n c -> n ww c', ww=W**2), # [2n, ww, cf] 56 | ], -1)) 57 | feat_f0_unfold, feat_f1_unfold = torch.chunk(feat_cf_win, 2, dim=0) 58 | 59 | return feat_f0_unfold, feat_f1_unfold 60 | -------------------------------------------------------------------------------- /src/oee/models/loftr/loftr_module/linear_attention.py: -------------------------------------------------------------------------------- 1 | """ 2 | Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" 3 | Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py 4 | """ 5 | 6 | import torch 7 | from torch.nn import Module, Dropout 8 | 9 | 10 | def elu_feature_map(x): 11 | return torch.nn.functional.elu(x) + 1 12 | 13 | 14 | class LinearAttention(Module): 15 | def __init__(self, eps=1e-6): 16 | super().__init__() 17 | self.feature_map = elu_feature_map 18 | self.eps = eps 19 | 20 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 21 | """ Multi-Head linear attention proposed in "Transformers are RNNs" 22 | Args: 23 | queries: [N, L, H, D] 24 | keys: [N, S, H, D] 25 | values: [N, S, H, D] 26 | q_mask: [N, L] 27 | kv_mask: [N, S] 28 | Returns: 29 | queried_values: (N, L, H, D) 30 | """ 31 | Q = self.feature_map(queries) 32 | K = self.feature_map(keys) 33 | 34 | # set padded position to zero 35 | if q_mask is not None: 36 | Q = Q * q_mask[:, :, None, None] 37 | if kv_mask is not None: 38 | K = K * kv_mask[:, :, None, None] 39 | values = values * kv_mask[:, :, None, None] 40 | 41 | v_length = values.size(1) 42 | values = values / v_length # prevent fp16 overflow 43 | KV = torch.einsum("nshd,nshv->nhdv", K, values) # (S,D)' @ S,V 44 | Z = 1 / (torch.einsum("nlhd,nhd->nlh", Q, K.sum(dim=1)) + self.eps) 45 | queried_values = torch.einsum("nlhd,nhdv,nlh->nlhv", Q, KV, Z) * v_length 46 | 47 | return queried_values.contiguous() 48 | 49 | 50 | class FullAttention(Module): 51 | def __init__(self, use_dropout=False, attention_dropout=0.1): 52 | super().__init__() 53 | self.use_dropout = use_dropout 54 | self.dropout = Dropout(attention_dropout) 55 | 56 | def forward(self, queries, keys, values, q_mask=None, kv_mask=None): 57 | """ Multi-head scaled dot-product attention, a.k.a full attention. 58 | Args: 59 | queries: [N, L, H, D] 60 | keys: [N, S, H, D] 61 | values: [N, S, H, D] 62 | q_mask: [N, L] 63 | kv_mask: [N, S] 64 | Returns: 65 | queried_values: (N, L, H, D) 66 | """ 67 | 68 | # Compute the unnormalized attention and apply the masks 69 | QK = torch.einsum("nlhd,nshd->nlsh", queries, keys) 70 | if kv_mask is not None: 71 | QK.masked_fill_(~(q_mask[:, :, None, None] * kv_mask[:, None, :, None]), float('-inf')) 72 | 73 | # Compute the attention and the weighted average 74 | softmax_temp = 1. 
/ queries.size(3)**.5 # sqrt(D) 75 | A = torch.softmax(softmax_temp * QK, dim=2) 76 | if self.use_dropout: 77 | A = self.dropout(A) 78 | 79 | queried_values = torch.einsum("nlsh,nshd->nlhd", A, values) 80 | 81 | return queried_values.contiguous() 82 | -------------------------------------------------------------------------------- /src/oee/models/loftr/loftr_module/transformer.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import torch 3 | import torch.nn as nn 4 | from .linear_attention import LinearAttention, FullAttention 5 | 6 | 7 | class LoFTREncoderLayer(nn.Module): 8 | def __init__(self, 9 | d_model, 10 | nhead, 11 | attention='linear'): 12 | super(LoFTREncoderLayer, self).__init__() 13 | 14 | self.dim = d_model // nhead 15 | self.nhead = nhead 16 | 17 | # multi-head attention 18 | self.q_proj = nn.Linear(d_model, d_model, bias=False) 19 | self.k_proj = nn.Linear(d_model, d_model, bias=False) 20 | self.v_proj = nn.Linear(d_model, d_model, bias=False) 21 | self.attention = LinearAttention() if attention == 'linear' else FullAttention() 22 | self.merge = nn.Linear(d_model, d_model, bias=False) 23 | 24 | # feed-forward network 25 | self.mlp = nn.Sequential( 26 | nn.Linear(d_model*2, d_model*2, bias=False), 27 | nn.ReLU(True), 28 | nn.Linear(d_model*2, d_model, bias=False), 29 | ) 30 | 31 | # norm and dropout 32 | self.norm1 = nn.LayerNorm(d_model) 33 | self.norm2 = nn.LayerNorm(d_model) 34 | 35 | def forward(self, x, source, x_mask=None, source_mask=None): 36 | """ 37 | Args: 38 | x (torch.Tensor): [N, L, C] 39 | source (torch.Tensor): [N, S, C] 40 | x_mask (torch.Tensor): [N, L] (optional) 41 | source_mask (torch.Tensor): [N, S] (optional) 42 | """ 43 | bs = x.size(0) 44 | query, key, value = x, source, source 45 | 46 | # multi-head attention 47 | query = self.q_proj(query).view(bs, -1, self.nhead, self.dim) # [N, L, (H, D)] 48 | key = self.k_proj(key).view(bs, -1, self.nhead, self.dim) # [N, S, (H, D)] 49 | value = self.v_proj(value).view(bs, -1, self.nhead, self.dim) 50 | message = self.attention(query, key, value, q_mask=x_mask, kv_mask=source_mask) # [N, L, (H, D)] 51 | message = self.merge(message.view(bs, -1, self.nhead*self.dim)) # [N, L, C] 52 | message = self.norm1(message) 53 | 54 | # feed-forward network 55 | message = self.mlp(torch.cat([x, message], dim=2)) 56 | message = self.norm2(message) 57 | 58 | return x + message 59 | 60 | 61 | class LocalFeatureTransformer(nn.Module): 62 | """A Local Feature Transformer (LoFTR) module.""" 63 | 64 | def __init__(self, config): 65 | super(LocalFeatureTransformer, self).__init__() 66 | 67 | self.config = config 68 | self.d_model = config['d_model'] 69 | self.nhead = config['nhead'] 70 | self.layer_names = config['layer_names'] 71 | encoder_layer = LoFTREncoderLayer(config['d_model'], config['nhead'], config['attention']) 72 | self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(len(self.layer_names))]) 73 | self._reset_parameters() 74 | 75 | def _reset_parameters(self): 76 | for p in self.parameters(): 77 | if p.dim() > 1: 78 | nn.init.xavier_uniform_(p) 79 | 80 | def forward(self, feat0, feat1, mask0=None, mask1=None): 81 | """ 82 | Args: 83 | feat0 (torch.Tensor): [N, L, C] 84 | feat1 (torch.Tensor): [N, S, C] 85 | mask0 (torch.Tensor): [N, L] (optional) 86 | mask1 (torch.Tensor): [N, S] (optional) 87 | """ 88 | 89 | assert self.d_model == feat0.size(2), "the feature number of src and transformer must be equal" 90 | 91 | for layer, name in 
zip(self.layers, self.layer_names): 92 | if name == 'self': 93 | feat0 = layer(feat0, feat0, mask0, mask0) 94 | feat1 = layer(feat1, feat1, mask1, mask1) 95 | elif name == 'cross': 96 | feat0 = layer(feat0, feat1, mask0, mask1) 97 | feat1 = layer(feat1, feat0, mask1, mask0) 98 | else: 99 | raise KeyError 100 | 101 | return feat0, feat1 102 | -------------------------------------------------------------------------------- /src/oee/models/loftr/utils/cvpr_ds_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | 4 | def lower_config(yacs_cfg): 5 | if not isinstance(yacs_cfg, CN): 6 | return yacs_cfg 7 | return {k.lower(): lower_config(v) for k, v in yacs_cfg.items()} 8 | 9 | 10 | _CN = CN() 11 | _CN.BACKBONE_TYPE = 'ResNetFPN' 12 | _CN.RESOLUTION = (8, 2) # options: [(8, 2), (16, 4)] 13 | _CN.FINE_WINDOW_SIZE = 5 # window_size in fine_level, must be odd 14 | _CN.FINE_CONCAT_COARSE_FEAT = True 15 | 16 | # 1. LoFTR-backbone (local feature CNN) config 17 | _CN.RESNETFPN = CN() 18 | _CN.RESNETFPN.INITIAL_DIM = 128 19 | _CN.RESNETFPN.BLOCK_DIMS = [128, 196, 256] # s1, s2, s3 20 | 21 | # 2. LoFTR-coarse module config 22 | _CN.COARSE = CN() 23 | _CN.COARSE.D_MODEL = 256 24 | _CN.COARSE.D_FFN = 256 25 | _CN.COARSE.NHEAD = 8 26 | _CN.COARSE.LAYER_NAMES = ['self', 'cross'] * 4 27 | _CN.COARSE.ATTENTION = 'linear' # options: ['linear', 'full'] 28 | _CN.COARSE.TEMP_BUG_FIX = False 29 | 30 | # 3. Coarse-Matching config 31 | _CN.MATCH_COARSE = CN() 32 | _CN.MATCH_COARSE.THR = 0.2 33 | _CN.MATCH_COARSE.BORDER_RM = 2 34 | _CN.MATCH_COARSE.MATCH_TYPE = 'dual_softmax' # options: ['dual_softmax, 'sinkhorn'] 35 | _CN.MATCH_COARSE.DSMAX_TEMPERATURE = 0.1 36 | _CN.MATCH_COARSE.SKH_ITERS = 3 37 | _CN.MATCH_COARSE.SKH_INIT_BIN_SCORE = 1.0 38 | _CN.MATCH_COARSE.SKH_PREFILTER = True 39 | _CN.MATCH_COARSE.TRAIN_COARSE_PERCENT = 0.4 # training tricks: save GPU memory 40 | _CN.MATCH_COARSE.TRAIN_PAD_NUM_GT_MIN = 200 # training tricks: avoid DDP deadlock 41 | 42 | # 4. 
LoFTR-fine module config 43 | _CN.FINE = CN() 44 | _CN.FINE.D_MODEL = 128 45 | _CN.FINE.D_FFN = 128 46 | _CN.FINE.NHEAD = 8 47 | _CN.FINE.LAYER_NAMES = ['self', 'cross'] * 1 48 | _CN.FINE.ATTENTION = 'linear' 49 | 50 | default_cfg = lower_config(_CN) 51 | -------------------------------------------------------------------------------- /src/oee/models/loftr/utils/fine_matching.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | 5 | from kornia.geometry.subpix import dsnt 6 | from kornia.utils.grid import create_meshgrid 7 | 8 | 9 | class FineMatching(nn.Module): 10 | """FineMatching with s2d paradigm""" 11 | 12 | def __init__(self): 13 | super().__init__() 14 | 15 | def forward(self, feat_f0, feat_f1, data): 16 | """ 17 | Args: 18 | feat0 (torch.Tensor): [M, WW, C] 19 | feat1 (torch.Tensor): [M, WW, C] 20 | data (dict) 21 | Update: 22 | data (dict):{ 23 | 'expec_f' (torch.Tensor): [M, 3], 24 | 'mkpts0_f' (torch.Tensor): [M, 2], 25 | 'mkpts1_f' (torch.Tensor): [M, 2]} 26 | """ 27 | M, WW, C = feat_f0.shape 28 | W = int(math.sqrt(WW)) 29 | scale = data['hw0_i'][0] / data['hw0_f'][0] 30 | self.M, self.W, self.WW, self.C, self.scale = M, W, WW, C, scale 31 | 32 | # corner case: if no coarse matches found 33 | if M == 0: 34 | assert self.training == False, "M is always >0, when training, see coarse_matching.py" 35 | # logger.warning('No matches found in coarse-level.') 36 | data.update({ 37 | 'expec_f': torch.empty(0, 3, device=feat_f0.device), 38 | 'mkpts0_f': data['mkpts0_c'], 39 | 'mkpts1_f': data['mkpts1_c'], 40 | }) 41 | return 42 | 43 | feat_f0_picked = feat_f0_picked = feat_f0[:, WW//2, :] 44 | sim_matrix = torch.einsum('mc,mrc->mr', feat_f0_picked, feat_f1) 45 | softmax_temp = 1. / C**.5 46 | heatmap = torch.softmax(softmax_temp * sim_matrix, dim=1).view(-1, W, W) 47 | 48 | # compute coordinates from heatmap 49 | coords_normalized = dsnt.spatial_expectation2d(heatmap[None], True)[0] # [M, 2] 50 | grid_normalized = create_meshgrid(W, W, True, heatmap.device).reshape(1, -1, 2) # [1, WW, 2] 51 | 52 | # compute std over 53 | var = torch.sum(grid_normalized**2 * heatmap.view(-1, WW, 1), dim=1) - coords_normalized**2 # [M, 2] 54 | std = torch.sum(torch.sqrt(torch.clamp(var, min=1e-10)), -1) # [M] clamp needed for numerical stability 55 | 56 | # for fine-level supervision 57 | data.update({'expec_f': torch.cat([coords_normalized, std.unsqueeze(1)], -1)}) 58 | 59 | # compute absolute kpt coords 60 | self.get_fine_match(coords_normalized, data) 61 | 62 | @torch.no_grad() 63 | def get_fine_match(self, coords_normed, data): 64 | W, WW, C, scale = self.W, self.WW, self.C, self.scale 65 | 66 | # mkpts0_f and mkpts1_f 67 | mkpts0_f = data['mkpts0_c'] 68 | scale1 = scale * data['scale1'][data['b_ids']] if 'scale0' in data else scale 69 | mkpts1_f = data['mkpts1_c'] + (coords_normed * (W // 2) * scale1)[:len(data['mconf'])] 70 | 71 | data.update({ 72 | "mkpts0_f": mkpts0_f, 73 | "mkpts1_f": mkpts1_f 74 | }) 75 | -------------------------------------------------------------------------------- /src/oee/models/loftr/utils/geometry.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | @torch.no_grad() 5 | def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1): 6 | """ Warp kpts0 from I0 to I1 with depth, K and Rt 7 | Also check covisibility and depth consistency. 8 | Depth is consistent if relative error < 0.2 (hard-coded). 
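Concretely, a keypoint (u, v) with sampled depth d is unprojected to x_cam0 = d * K0^{-1} [u, v, 1]^T, mapped to camera 1 by x_cam1 = R_0to1 x_cam0 + t_0to1, and re-projected as w_kpts0 = (K1 x_cam1)_{x,y} / (K1 x_cam1)_z; points with zero depth, points projecting outside image 1, and points whose sampled depth in image 1 disagrees with the computed depth are masked out.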
9 | 10 | Args: 11 | kpts0 (torch.Tensor): [N, L, 2] - , 12 | depth0 (torch.Tensor): [N, H, W], 13 | depth1 (torch.Tensor): [N, H, W], 14 | T_0to1 (torch.Tensor): [N, 3, 4], 15 | K0 (torch.Tensor): [N, 3, 3], 16 | K1 (torch.Tensor): [N, 3, 3], 17 | Returns: 18 | calculable_mask (torch.Tensor): [N, L] 19 | warped_keypoints0 (torch.Tensor): [N, L, 2] 20 | """ 21 | kpts0_long = kpts0.round().long() 22 | 23 | # Sample depth, get calculable_mask on depth != 0 24 | kpts0_depth = torch.stack( 25 | [depth0[i, kpts0_long[i, :, 1], kpts0_long[i, :, 0]] for i in range(kpts0.shape[0])], dim=0 26 | ) # (N, L) 27 | nonzero_mask = kpts0_depth != 0 28 | 29 | # Unproject 30 | kpts0_h = torch.cat([kpts0, torch.ones_like(kpts0[:, :, [0]])], dim=-1) * kpts0_depth[..., None] # (N, L, 3) 31 | kpts0_cam = K0.inverse() @ kpts0_h.transpose(2, 1) # (N, 3, L) 32 | 33 | # Rigid Transform 34 | w_kpts0_cam = T_0to1[:, :3, :3] @ kpts0_cam + T_0to1[:, :3, [3]] # (N, 3, L) 35 | w_kpts0_depth_computed = w_kpts0_cam[:, 2, :] 36 | 37 | # Project 38 | w_kpts0_h = (K1 @ w_kpts0_cam).transpose(2, 1) # (N, L, 3) 39 | w_kpts0 = w_kpts0_h[:, :, :2] / (w_kpts0_h[:, :, [2]] + 1e-4) # (N, L, 2), +1e-4 to avoid zero depth 40 | 41 | # Covisible Check 42 | h, w = depth1.shape[1:3] 43 | covisible_mask = (w_kpts0[:, :, 0] > 0) * (w_kpts0[:, :, 0] < w-1) * \ 44 | (w_kpts0[:, :, 1] > 0) * (w_kpts0[:, :, 1] < h-1) 45 | w_kpts0_long = w_kpts0.long() 46 | w_kpts0_long[~covisible_mask, :] = 0 47 | 48 | w_kpts0_depth = torch.stack( 49 | [depth1[i, w_kpts0_long[i, :, 1], w_kpts0_long[i, :, 0]] for i in range(w_kpts0_long.shape[0])], dim=0 50 | ) # (N, L) 51 | consistent_mask = ((w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth).abs() < 0.2 52 | valid_mask = nonzero_mask * covisible_mask * consistent_mask 53 | 54 | return valid_mask, w_kpts0 55 | -------------------------------------------------------------------------------- /src/oee/models/loftr/utils/position_encoding.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | 5 | 6 | class PositionEncodingSine(nn.Module): 7 | """ 8 | This is a sinusoidal position encoding that generalized to 2-dimensional images 9 | """ 10 | 11 | def __init__(self, d_model, max_shape=(256, 256), temp_bug_fix=True): 12 | """ 13 | Args: 14 | max_shape (tuple): for 1/8 featmap, the max length of 256 corresponds to 2048 pixels 15 | temp_bug_fix (bool): As noted in this [issue](https://github.com/zju3dv/LoFTR/issues/41), 16 | the original implementation of LoFTR includes a bug in the pos-enc impl, which has little impact 17 | on the final performance. For now, we keep both impls for backward compatability. 18 | We will remove the buggy impl after re-training all variants of our released models. 
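Concretely, the buggy variant computes -math.log(10000.0) / d_model//2, which Python parses as (-math.log(10000.0) / d_model) // 2; for typical d_model values this floor-divides a small negative float down to -1.0, so its frequency term differs from the intended -math.log(10000.0) / (d_model//2) used when temp_bug_fix is True.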
19 | """ 20 | super().__init__() 21 | 22 | pe = torch.zeros((d_model, *max_shape)) 23 | y_position = torch.ones(max_shape).cumsum(0).float().unsqueeze(0) 24 | x_position = torch.ones(max_shape).cumsum(1).float().unsqueeze(0) 25 | if temp_bug_fix: 26 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / (d_model//2))) 27 | else: # a buggy implementation (for backward compatability only) 28 | div_term = torch.exp(torch.arange(0, d_model//2, 2).float() * (-math.log(10000.0) / d_model//2)) 29 | div_term = div_term[:, None, None] # [C//4, 1, 1] 30 | pe[0::4, :, :] = torch.sin(x_position * div_term) 31 | pe[1::4, :, :] = torch.cos(x_position * div_term) 32 | pe[2::4, :, :] = torch.sin(y_position * div_term) 33 | pe[3::4, :, :] = torch.cos(y_position * div_term) 34 | 35 | self.register_buffer('pe', pe.unsqueeze(0), persistent=False) # [1, C, H, W] 36 | 37 | def forward(self, x): 38 | """ 39 | Args: 40 | x: [N, C, H, W] 41 | """ 42 | return x + self.pe[:, :, :x.size(2), :x.size(3)] 43 | -------------------------------------------------------------------------------- /src/oee/utils/utils3d.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | 5 | def cart_to_hom(pts): 6 | """ 7 | :param pts: (N, 3 or 2) 8 | :return pts_hom: (N, 4 or 3) 9 | """ 10 | if isinstance(pts, np.ndarray): 11 | pts_hom = np.concatenate((pts, np.ones([*pts.shape[:-1], 1], dtype=np.float32)), -1) 12 | else: 13 | ones = torch.ones([*pts.shape[:-1], 1], dtype=torch.float32, device=pts.device) 14 | pts_hom = torch.cat((pts, ones), dim=-1) 15 | return pts_hom 16 | 17 | 18 | def hom_to_cart(pts): 19 | return pts[..., :-1] / pts[..., -1:] 20 | 21 | 22 | def canonical_to_camera(pts, pose): 23 | pts = cart_to_hom(pts) 24 | pts = pts @ pose.transpose(-1, -2) 25 | pts = hom_to_cart(pts) 26 | return pts 27 | 28 | 29 | def rect_to_img(K, pts_rect): 30 | from dl_ext.vision_ext.datasets.kitti.structures import Calibration 31 | pts_2d_hom = pts_rect @ K.t() 32 | pts_img = Calibration.hom_to_cart(pts_2d_hom) 33 | return pts_img 34 | 35 | 36 | def calc_pose(phis, thetas, size, radius=1.2): 37 | import torch 38 | def normalize(vectors): 39 | return vectors / (torch.norm(vectors, dim=-1, keepdim=True) + 1e-10) 40 | 41 | thetas = torch.FloatTensor(thetas) 42 | phis = torch.FloatTensor(phis) 43 | 44 | centers = torch.stack([ 45 | radius * torch.sin(thetas) * torch.sin(phis), 46 | -radius * torch.cos(thetas) * torch.sin(phis), 47 | radius * torch.cos(phis), 48 | ], dim=-1) # [B, 3] 49 | 50 | # lookat 51 | forward_vector = normalize(centers).squeeze(0) 52 | up_vector = torch.FloatTensor([0, 0, 1]).unsqueeze(0).repeat(size, 1) 53 | right_vector = normalize(torch.cross(up_vector, forward_vector, dim=-1)) 54 | if right_vector.pow(2).sum() < 0.01: 55 | right_vector = torch.FloatTensor([0, 1, 0]).unsqueeze(0).repeat(size, 1) 56 | up_vector = normalize(torch.cross(forward_vector, right_vector, dim=-1)) 57 | 58 | poses = torch.eye(4, dtype=torch.float).unsqueeze(0).repeat(size, 1, 1) 59 | poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1) 60 | poses[:, :3, 3] = centers 61 | return poses 62 | -------------------------------------------------------------------------------- /src/path_configuration.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | # the parent folder of GSO objects folders (GSO_alarm,GSO_backpack,...) 
4 | dataPath_gso='path/to/gso-renderings' 5 | # the parent folder of NAVI objects folders (3d_dollhouse_sink,bottle_vitamin_d_tablets,...) 6 | dataPath_navi='' 7 | 8 | -------------------------------------------------------------------------------- /src/pose_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys 3 | from PIL import Image 4 | from zero123.zero1.util_4_e2vg import CameraMatrixUtil 5 | from zero123.zero1.util_4_e2vg.CameraMatrixUtil import xyz2pose,get_z_4_normObj 6 | # from gen6d.Gen6D.scy.ElevationUtil import eleRadian_2_baseXyz_lXyz 7 | def R_t_2_pose(R,t): 8 | if(isinstance(R,list)): 9 | R=np.array(R) 10 | if(isinstance(t,list)): 11 | t=np.array(t) 12 | if(t.shape==(3,)): 13 | t=t.reshape((3,1)) 14 | t=t.reshape((3,)) 15 | assert(t.shape==(3,)) 16 | assert(R.shape==(3,3)) 17 | pose=np.zeros((4,4)) 18 | pose[:3,:3]=R 19 | pose[:3,3]=t 20 | pose[3,3]=1 21 | return pose 22 | class Pose_R_t_Converter: 23 | @staticmethod 24 | def pose_2_Rt(pose44_or_pose34): 25 | assert pose44_or_pose34.shape==(4,4) or pose44_or_pose34.shape==(3,4) 26 | R = pose44_or_pose34[:3, :3] 27 | t = pose44_or_pose34[:3, 3] 28 | return R,t 29 | @staticmethod 30 | # def Rt_2_pose44(R,t): 31 | def R_t3np__2__pose44(R,t): 32 | assert R.shape==(3,3) 33 | assert t.shape==(3,) 34 | pose=np.eye(4) 35 | pose[:3,:3]=R 36 | pose[:3,3]=t 37 | return pose 38 | @staticmethod 39 | def R_t3np__2__pose34(R,t): 40 | assert R.shape==(3,3) 41 | assert t.shape==(3,) 42 | pose44=Pose_R_t_Converter.R_t3np__2__pose44(R,t) 43 | pose34=pose44[:3,:] 44 | return pose34 45 | @staticmethod 46 | def pose34_2_pose44(pose34): 47 | assert pose34.shape==(3,4) 48 | pose44=np.concatenate([pose34,np.array([[0,0,0,1]])],axis=0) 49 | return pose44 50 | @staticmethod 51 | def R__2__arbitrary_t_pose44(R): 52 | assert R.shape==(3,3) 53 | pose44=Pose_R_t_Converter.R_t3np__2__pose44(R,np.zeros((3,))) 54 | return pose44 55 | 56 | def opencv_2_pytorch3d__leftMulW2cR(R):#w2opencv to w2pytorch3d 57 | assert R.shape==(3,3) 58 | Rop = np.array([ 59 | [-1, 0, 0], 60 | [0, -1, 0], 61 | [0, 0, 1], 62 | ], dtype=np.float64)#o means OpenCV, p means pytorch3d 63 | R = Rop @ R 64 | return R 65 | def opencv_2_pytorch3d__leftMulW2cpose(pose):#TODO check correctness 66 | assert pose.shape==(4,4) 67 | Poseop = np.array([ 68 | [-1, 0, 0, 0], 69 | [0, -1, 0, 0], 70 | [0, 0, 1, 0], 71 | [0, 0, 0, 1], 72 | ], dtype=np.float64)#o means OpenCV, p means pytorch3d 73 | pose = Poseop @ pose 74 | return pose 75 | def opencv_2_pytorch3d__leftMulRelR(R): 76 | assert R.shape==(3,3) 77 | Rop = np.array([ 78 | [-1, 0, 0], 79 | [0, -1, 0], 80 | [0, 0, 1], 81 | ], dtype=np.float64) 82 | R = Rop @ R @ (Rop.T) 83 | return R 84 | # def pytorch3d_2_opencv__leftMulRelR(R): 85 | # assert R.shape==(3,3) 86 | # Rop = np.array([ 87 | # [-1, 0, 0], 88 | # [0, -1, 0], 89 | # [0, 0, 1], 90 | # ], dtype=np.float64) 91 | # R = Rop @ R @ (Rop.T) 92 | # return R 93 | def opencv_2_pytorch3d__leftMulRelPose(pose): 94 | assert pose.shape==(4,4) 95 | Pop = np.array([ 96 | [-1, 0, 0, 0], 97 | [0, -1, 0, 0], 98 | [0, 0, 1, 0], 99 | [0, 0, 0, 1], 100 | ], dtype=np.float64) 101 | pose = Pop @ pose @ np.linalg.inv(Pop) 102 | return pose 103 | def pytorch3d_2_opencv__leftMulRelPose(pose): 104 | return opencv_2_pytorch3d__leftMulRelPose(pose) 105 | def pytorch3d_2_opencv__leftMulW2cpose(pose): 106 | assert pose.shape==(4,4) 107 | pose=opencv_2_pytorch3d__leftMulW2cpose(pose) 108 | return pose 109 | def 
opengl_2_opencv__leftMulW2cpose(pose):#TODO check correctness 110 | assert pose.shape==(4,4) 111 | Posego = np.array([#GL to OpenCV 112 | [1, 0, 0, 0], 113 | [0, -1, 0, 0], 114 | [0, 0, -1, 0], 115 | [0, 0, 0, 1], 116 | ], dtype=np.float64) 117 | pose = Posego @ pose 118 | return pose 119 | 120 | 121 | def compute_angular_error(rotation1, rotation2): 122 | # R_rel = rotation1.T @ rotation2 123 | R_rel = rotation2 @ rotation1.T 124 | tr = (np.trace(R_rel) - 1) / 2 125 | theta = np.arccos(tr.clip(-1, 1)) 126 | return theta * 180 / np.pi 127 | def compute_translation_error(t31_1, t31_2): 128 | assert t31_1.shape==(3,1) 129 | assert t31_2.shape==(3,1) 130 | # ret=np.linalg.norm(t31_1 - t31_2, axis=1) 131 | # angle between two vectors 132 | ret=np.arccos(np.dot(t31_1.T,t31_2)/(np.linalg.norm(t31_1)*np.linalg.norm(t31_2))) 133 | ret=ret.item() 134 | ret=ret * 180 / np.pi 135 | return ret 136 | 137 | 138 | def in_plane_rotate_camera(degree_clockwise, pilImage: Image.Image, w2opencv_44,fillcolor=(255, 255, 255),): 139 | """ 140 | camera follows opencv convention 141 | rotating the camera clockwise by degree_clockwise° <--> rotating the image counter-clockwise by degree_clockwise° 142 | degree_clockwise is the θ in the draft note "12.26 for Q0Sipr" 143 | """ 144 | assert -360 <= degree_clockwise <= 360 145 | if isinstance(pilImage,np.ndarray): 146 | pilImage=Image.fromarray(pilImage) 147 | assert isinstance(pilImage,Image.Image) 148 | assert w2opencv_44.shape == (4, 4) 149 | img_rot: Image.Image = pilImage.rotate( 150 | degree_clockwise, 151 | fillcolor=fillcolor, 152 | resample=Image.BICUBIC, 153 | ) 154 | rad_clockwise = np.deg2rad(degree_clockwise) 155 | P_IPR = np.asarray([ 156 | [np.cos(rad_clockwise), np.sin(rad_clockwise), 0, 0], 157 | [-np.sin(rad_clockwise), np.cos(rad_clockwise), 0, 0], 158 | [0, 0, 1, 0], 159 | [0, 0, 0, 1], 160 | ], np.float32) 161 | w2opencv_44_rot = P_IPR @ w2opencv_44 162 | return img_rot,w2opencv_44_rot 163 | def in_plane_rotate_camera_wrap(degree_clockwise, pilImage_or_ndarray, w2opencv_44=None,fillcolor=(255, 255, 255),): 164 | """ 165 | camera follows opencv convention 166 | rotating the camera clockwise by degree_clockwise° <--> rotating the image counter-clockwise by degree_clockwise° 167 | """ 168 | if w2opencv_44 is None: 169 | w2opencv_44=np.eye(4) 170 | if isinstance(pilImage_or_ndarray,np.ndarray): 171 | pilImage=Image.fromarray(pilImage_or_ndarray) 172 | img_rot,w2opencv_44_rot=in_plane_rotate_camera(degree_clockwise, pilImage, w2opencv_44,fillcolor=fillcolor) 173 | if isinstance(pilImage_or_ndarray,np.ndarray): 174 | img_rot=np.array(img_rot) 175 | return img_rot,w2opencv_44_rot -------------------------------------------------------------------------------- /src/redirect_util.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import sys, time 3 | import sys,os 4 | import root_config 5 | import threading 6 | import root_config 7 | from misc_util import your_datetime 8 | def _get_logFilePath(dir_:str,log_file_prefix: str): 9 | pid = os.getpid() 10 | th = threading.currentThread() 11 | th_name = th.getName() 12 | log_file_name = f"{log_file_prefix}{os.path.basename(sys.argv[0])[:-3]}--{your_datetime():%m.%d-%H:%M:%S}--{pid}:{th_name}" 13 | 14 | def name2fp(name): 15 | full_path = os.path.join( 16 | dir_, 17 | f"{name}.log" 18 | ) 19 | return full_path 20 | os.makedirs(dir_,exist_ok=True) 21 | while (os.path.exists(name2fp(log_file_name))): 22 | log_file_name += f"_" 23 | # print(f"redirect stdout to:\n\"{name2fp(log_file_name)}\"") 24 | # print(f"redirect stdout to:\n{name2fp(log_file_name)}") 25 | print(f"\nredirect 
to:\n{name2fp(log_file_name)}:0\n") 26 | # print(f"redirect stdout to:\n<{name2fp(log_file_name)}>") 27 | # print(f"redirect stdout to:\n<{name2fp(log_file_name)}>:0") 28 | return name2fp(log_file_name) 29 | def redirectA(dir_:str,log_file_prefix: str): 30 | logFilePath=_get_logFilePath(dir_,log_file_prefix) 31 | sys.stdout = open(logFilePath, 'w') 32 | 33 | 34 | 35 | 36 | 37 | class Tee:#cursor 38 | FORCE_TO_FLUSH_INTERVAL = 10 39 | def __init__(self, *files): 40 | self.files = [] 41 | for file in files: 42 | if isinstance(file, str): 43 | self.files.append(open(file, 'w')) 44 | else: 45 | self.files.append(file) 46 | self.last_flush_time = time.time() 47 | def write(self, obj): 48 | for f in self.files: 49 | f.write(obj) 50 | if time.time() - self.last_flush_time > self.FORCE_TO_FLUSH_INTERVAL: 51 | f.flush() # Force flush every FORCE_TO_FLUSH_INTERVAL seconds 52 | self.last_flush_time = time.time() 53 | def flush(self) : 54 | for f in self.files: 55 | f.flush() 56 | class RedirectorB: 57 | def __init__(self,log_file_prefix: str, 58 | dir_:str=root_config.logPath, 59 | redirect_stderr=True,also_to_screen=True): 60 | logFilePath = _get_logFilePath(dir_, log_file_prefix) 61 | if also_to_screen: 62 | f=open(logFilePath,'w') 63 | sys.stdout = Tee(sys.stdout,f) 64 | if redirect_stderr: 65 | sys.stderr = Tee(sys.stderr,f) 66 | else: 67 | tee = Tee( logFilePath) 68 | sys.stdout=tee 69 | if redirect_stderr: 70 | sys.stderr = tee 71 | 72 | if __name__=='__main__': 73 | _=RedirectorB('./tmp/','ttt435', 74 | redirect_stderr=1, 75 | also_to_screen=1) 76 | print('aaaaaaaabbbb') 77 | print('h4t93qht0') 78 | exit(0) 79 | 80 | 81 | 82 | 83 | class HiddenPrints: 84 | def __enter__(self): 85 | self._original_stdout = sys.stdout 86 | sys.stdout = open(os.devnull, 'w') 87 | 88 | def __exit__(self, exc_type, exc_val, exc_tb): 89 | sys.stdout.close() 90 | sys.stdout = self._original_stdout 91 | 92 | 93 | class HiddenSpecified_OutAndErr: 94 | class FilterOut(object):# from https://stackoverflow.com/questions/34904946/how-to-filter-stdout-in-python-logging 95 | def __init__(self, stream, l__filter_the_line_that_contains): 96 | self.stream = stream 97 | self.l__filter_the_line_that_contains = l__filter_the_line_that_contains 98 | # self.pattern = re.compile(re_pattern) if isinstance(re_pattern, str) else re_pattern 99 | def __getattr__(self, attr_name): 100 | return getattr(self.stream, attr_name) 101 | def write(self, data): 102 | for string in self.l__filter_the_line_that_contains: 103 | if string in data: 104 | return 105 | self.stream.write(data) 106 | # self.stream.flush() 107 | def flush(self): 108 | self.stream.flush() 109 | def __init__(self, l__filter_the_line_that_contains ): 110 | self.l__filter_the_line_that_contains = l__filter_the_line_that_contains 111 | def __enter__(self): 112 | self._original_stdout =sys.stdout 113 | self._original_stderr =sys.stderr 114 | sys.stdout = self.FilterOut(sys.stdout , self.l__filter_the_line_that_contains) 115 | sys.stderr = self.FilterOut(sys.stderr , self.l__filter_the_line_that_contains) 116 | 117 | def __exit__(self, exc_type, exc_val, exc_tb): 118 | sys.stdout = self._original_stdout 119 | sys.stderr = self._original_stderr 120 | 121 | -------------------------------------------------------------------------------- /src/root_config.py: -------------------------------------------------------------------------------- 1 | import torch.cuda 2 | 3 | #------- 4 | # DATASET="gso" 5 | # DATASET="navi" 6 | DATASET=None 7 | #------- 8 | SAMPLE_BATCH_SIZE=16 9 | 
FORCE_zero123_render_even_img_exist=0 10 | SAMPLE_BATCH_B_SIZE=4 11 | #-------gen6d 12 | NUM_REF=64 13 | #------- 14 | SKIP_EVAL_SEQ_IF_EVAL_RESULT_EXIST=1 15 | MAX_PAIRS=20 16 | CONSIDER_IPR=False # in-plane rotation of q0 17 | Q0Sipr:bool=False 18 | Q0Sipr_range:int=45 19 | Q1Sipr:bool=False 20 | Q1Sipr_range:int=45 21 | #------- 22 | SEED=0 23 | Q0INDEX:int=None 24 | 25 | idSuffix=f"{SEED}+{Q0INDEX}" 26 | refIdSuffix=f"+{Q0INDEX}" 27 | # refIdSuffix=f"{SEED}" 28 | # refIdSuffix=f"{SEED}+" 29 | #-------misc 30 | tmp_batch_image__SUFFIX='.png' 31 | SHARE_tmp_batch_images=False #False or folder name 32 | NO_CARVEKIT:bool=True #CARVEKIT is a bg remover. if input img is masked, then no need to remove bg 33 | VIS=True ## visualize result; to save time, you can let it be 0. VIS: bool/int(0 or 1) /fp(0.0-1.0,vis ratio). 34 | #------- 35 | 36 | #-------path 37 | import os 38 | path_root=os.path.dirname(os.path.abspath(__file__)) #path of src 39 | path_4debug=os.path.join(path_root,"4debug") 40 | projPath_gen6d=os.path.join(path_root,"gen6d/Gen6D") 41 | projPath_zero123=os.path.join(path_root,"zero123/zero1") 42 | dataPath_zero123=os.path.join(path_root,"zero123/zero1/output_im") 43 | dataPath_gen6d=os.path.join(path_root,projPath_gen6d,"data/zero123") 44 | from path_configuration import * 45 | weightPath_zero123=os.path.join(path_root,"../weight/105000.ckpt") 46 | weightPath_gen6d='../weight/weight_gen6d' 47 | weightPath_selector=os.path.join( weightPath_gen6d ,"selector_pretrain/model_best.pth") 48 | weightPath_refiner=os.path.join( weightPath_gen6d ,"refiner_pretrain/model_best.pth") 49 | weightPath_loftr=os.path.join(path_root ,"../weight/indoor_ds_new.ckpt") 50 | evalResultPath_co3d=os.path.join(path_root,"result/eval_result") 51 | evalVisPath=os.path.join(path_root,"result/visual") 52 | logPath=os.path.join(path_root,"log") 53 | os.makedirs(path_4debug,exist_ok=True) 54 | os.makedirs(logPath,exist_ok=True) 55 | os.makedirs(evalResultPath_co3d,exist_ok=True) 56 | os.makedirs(evalVisPath,exist_ok=True) 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | class RefIdWhenNormal: 76 | @staticmethod 77 | def get_id(cate,seq,refIdSuffix_): 78 | return f"{cate}--{seq}--{refIdSuffix_}" 79 | 80 | 81 | #-----------------------unused conf-------------------- 82 | #-------debug 83 | NO_TRY=0 84 | class ForDebug: 85 | class forceIPRtoBe: 86 | enable:bool=False 87 | IPR:int=None 88 | class Cheat: 89 | force_elev=None# 90 | #-------GPU 91 | GPU_INDEX=0 92 | DEVICE=f"cuda:{GPU_INDEX}" if torch.cuda.is_available() else "cpu" 93 | #-------zero123 l_xyz 94 | ZERO123_MULTI_INPUT_IMAGE = 0 95 | _Z = 0 96 | #-------zero123 NUM_SAMPLE 97 | USE_ALL_SAMPLE=0 98 | NUM_SAMPLE=1 99 | # 100 | USE_white_bg_Detector=1 101 | Q0_MIN_ELEV=0 102 | ELEV_RANGE=None # if not None: in degree. lower and upper bound of absolute elev. eg. 
(-0.1,40) 103 | USE_CONFIDENCE=0 104 | REFINE_ITER:int=3 105 | #-------check,geometry 106 | LOOK_AT_CROP_OUTSIDE_GEN6D=1 107 | # 1:call gen6d_imgPaths2relativeRt_B(where perspective trans is performed); 0:give detection_outputs to gen6d so that perspective trans in refiner 108 | IGNORE_EXCEPTION= 0 109 | ONLY_GEN_DO_NOT_MEASURE=0 110 | LOG_WHEN_SAMPLING=0 111 | one_SEQ_mul_Q0__one_Q0_mul_Q1=1 112 | class CONF_one_SEQ_mul_Q0__one_Q0_mul_Q1: 113 | ONLY_CHECK_BASENAME=False 114 | FOR_PAPER=False 115 | LOAD_BY_IPC=False 116 | MARGIN_in_LOOK_AT=0.05 117 | #-------4 ablation 118 | MASK_ABLATION=None # None/'EROSION'/'DILATION' 119 | ABLATE_REFINE_ITER:int=None# None or int (when int, it can be 0, so must use 'if root_config.ABLATE_REFINE_ITER is (not) None:' instead of 'if (not) root_config.ABLATE_REFINE_ITER:') 120 | #-------4 val 121 | VALing=0 122 | SKIP_GEN_REF_IF_REF_FOLDER_EXIST=False -------------------------------------------------------------------------------- /src/vis/InterVisualizer.py: -------------------------------------------------------------------------------- 1 | from .cv2_util import * 2 | import cv2 3 | 4 | 5 | class InterVisualizer: 6 | def __init__(self, ): 7 | self.imgs = [] 8 | self.l_text = [] 9 | 10 | def append(self, img=None, text=None, row=None, column=None, kw_putText={}): 11 | assert column==None 12 | if (img): 13 | if (text): 14 | kw_putText = { 15 | "org": (10, 30), 16 | "fontFace": cv2.FONT_HERSHEY_SIMPLEX, 17 | "fontScale": 0.6, 18 | "color": (0, 0, 0), 19 | # thickness, 20 | # lineType, 21 | **kw_putText, 22 | } 23 | putText( 24 | img, 25 | text, 26 | **kw_putText, 27 | ) 28 | if(not row): 29 | row=len(self.imgs) 30 | while(len(self.imgs)-1 0 58 | mask = (255 * mask).astype(np.uint8) 59 | mask = Image.fromarray(mask) 60 | draw = ImageDraw.Draw(mask) 61 | draw.line([start, end], fill=255, width=brush_width, joint="curve") 62 | mask = np.array(mask) / 255 63 | return mask 64 | 65 | 66 | def gen_box_mask(mask, masked): 67 | x_0, y_0, w, h = masked 68 | mask[y_0:y_0 + h, x_0:x_0 + w] = 1 69 | return mask 70 | 71 | 72 | def gen_round_mask(mask, masked, radius): 73 | x_0, y_0, w, h = masked 74 | xy = [(x_0, y_0), (x_0 + w, y_0 + w)] 75 | 76 | mask = mask > 0 77 | mask = (255 * mask).astype(np.uint8) 78 | mask = Image.fromarray(mask) 79 | draw = ImageDraw.Draw(mask) 80 | draw.rounded_rectangle(xy, radius=radius, fill=255) 81 | mask = np.array(mask) / 255 82 | return mask 83 | 84 | 85 | def gen_large_mask(prng, img_h, img_w, 86 | marg, p_irr, min_n_irr, max_n_irr, max_l_irr, max_w_irr, 87 | min_n_box, max_n_box, min_s_box, max_s_box): 88 | """ 89 | img_h: int, an image height 90 | img_w: int, an image width 91 | marg: int, a margin for a box starting coordinate 92 | p_irr: float, 0 <= p_irr <= 1, a probability of a polygonal chain mask 93 | 94 | min_n_irr: int, min number of segments 95 | max_n_irr: int, max number of segments 96 | max_l_irr: max length of a segment in polygonal chain 97 | max_w_irr: max width of a segment in polygonal chain 98 | 99 | min_n_box: int, min bound for the number of box primitives 100 | max_n_box: int, max bound for the number of box primitives 101 | min_s_box: int, min length of a box side 102 | max_s_box: int, max length of a box side 103 | """ 104 | 105 | mask = np.zeros((img_h, img_w)) 106 | uniform = prng.randint 107 | 108 | if np.random.uniform(0, 1) < p_irr: # generate polygonal chain 109 | n = uniform(min_n_irr, max_n_irr) # sample number of segments 110 | 111 | for _ in range(n): 112 | y = uniform(0, img_h) # sample a starting 
point 113 | x = uniform(0, img_w) 114 | 115 | a = uniform(0, 360) # sample angle 116 | l = uniform(10, max_l_irr) # sample segment length 117 | w = uniform(5, max_w_irr) # sample a segment width 118 | 119 | # draw segment starting from (x,y) to (x_,y_) using brush of width w 120 | x_ = x + l * np.sin(a) 121 | y_ = y + l * np.cos(a) 122 | 123 | mask = gen_segment_mask(mask, start=(x, y), end=(x_, y_), brush_width=w) 124 | x, y = x_, y_ 125 | else: # generate Box masks 126 | n = uniform(min_n_box, max_n_box) # sample number of rectangles 127 | 128 | for _ in range(n): 129 | h = uniform(min_s_box, max_s_box) # sample box shape 130 | w = uniform(min_s_box, max_s_box) 131 | 132 | x_0 = uniform(marg, img_w - marg - w) # sample upper-left coordinates of box 133 | y_0 = uniform(marg, img_h - marg - h) 134 | 135 | if np.random.uniform(0, 1) < 0.5: 136 | mask = gen_box_mask(mask, masked=(x_0, y_0, w, h)) 137 | else: 138 | r = uniform(0, 60) # sample radius 139 | mask = gen_round_mask(mask, masked=(x_0, y_0, w, h), radius=r) 140 | return mask 141 | 142 | 143 | make_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256train"]) 144 | make_narrow_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["256narrow"]) 145 | make_512_lama_mask = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train"]) 146 | make_512_lama_mask_large = lambda prng, h, w: gen_large_mask(prng, h, w, **settings["512train-large"]) 147 | 148 | 149 | MASK_MODES = { 150 | "256train": make_lama_mask, 151 | "256narrow": make_narrow_lama_mask, 152 | "512train": make_512_lama_mask, 153 | "512train-large": make_512_lama_mask_large 154 | } 155 | 156 | if __name__ == "__main__": 157 | import sys 158 | 159 | out = sys.argv[1] 160 | 161 | prng = np.random.RandomState(1) 162 | kwargs = settings["256train"] 163 | mask = gen_large_mask(prng, 256, 256, **kwargs) 164 | mask = (255 * mask).astype(np.uint8) 165 | mask = Image.fromarray(mask) 166 | mask.save(out) 167 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/extras.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from omegaconf import OmegaConf 3 | import torch 4 | from ldm.util import instantiate_from_config 5 | import logging 6 | from contextlib import contextmanager 7 | 8 | from contextlib import contextmanager 9 | import logging 10 | 11 | @contextmanager 12 | def all_logging_disabled(highest_level=logging.CRITICAL): 13 | """ 14 | A context manager that will prevent any logging messages 15 | triggered during the body from being processed. 16 | 17 | :param highest_level: the maximum logging level in use. 18 | This would only need to be changed if a custom level greater than CRITICAL 19 | is defined. 20 | 21 | https://gist.github.com/simon-weber/7853144 22 | """ 23 | # two kind-of hacks here: 24 | # * can't get the highest logging level in effect => delegate to the user 25 | # * can't get the current module-level override => use an undocumented 26 | # (but non-private!) 
interface 27 | 28 | previous_level = logging.root.manager.disable 29 | 30 | logging.disable(highest_level) 31 | 32 | try: 33 | yield 34 | finally: 35 | logging.disable(previous_level) 36 | 37 | def load_training_dir(train_dir, device, epoch="last"): 38 | """Load a checkpoint and config from training directory""" 39 | train_dir = Path(train_dir) 40 | ckpt = list(train_dir.rglob(f"*{epoch}.ckpt")) 41 | assert len(ckpt) == 1, f"found {len(ckpt)} matching ckpt files" 42 | config = list(train_dir.rglob(f"*-project.yaml")) 43 | assert len(config) > 0, f"didn't find any config in {train_dir}" 44 | if len(config) > 1: 45 | print(f"found {len(config)} matching config files") 46 | config = sorted(config)[-1] 47 | print(f"selecting {config}") 48 | else: 49 | config = config[0] 50 | 51 | 52 | config = OmegaConf.load(config) 53 | return load_model_from_config(config, ckpt[0], device) 54 | 55 | def load_model_from_config(config, ckpt, device="cpu", verbose=False): 56 | """Loads a model from config and a ckpt 57 | if config is a path will use omegaconf to load 58 | """ 59 | if isinstance(config, (str, Path)): 60 | config = OmegaConf.load(config) 61 | 62 | with all_logging_disabled(): 63 | print(f"Loading model from {ckpt}") 64 | pl_sd = torch.load(ckpt, map_location="cpu") 65 | global_step = pl_sd["global_step"] 66 | sd = pl_sd["state_dict"] 67 | model = instantiate_from_config(config.model) 68 | m, u = model.load_state_dict(sd, strict=False) 69 | if len(m) > 0 and verbose: 70 | print("missing keys:") 71 | print(m) 72 | if len(u) > 0 and verbose: 73 | print("unexpected keys:", u) 74 | model.to(device) 75 | model.eval() 76 | model.cond_stage_model.device = device 77 | return model -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/guidance.py: -------------------------------------------------------------------------------- 1 | from typing import List, Tuple 2 | from scipy import interpolate 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from IPython.display import clear_output 7 | import abc 8 | 9 | 10 | class GuideModel(torch.nn.Module, abc.ABC): 11 | def __init__(self) -> None: 12 | super().__init__() 13 | 14 | @abc.abstractmethod 15 | def preprocess(self, x_img): 16 | pass 17 | 18 | @abc.abstractmethod 19 | def compute_loss(self, inp): 20 | pass 21 | 22 | 23 | class Guider(torch.nn.Module): 24 | def __init__(self, sampler, guide_model, scale=1.0, verbose=False): 25 | """Apply classifier guidance 26 | 27 | Specify a guidance scale as either a scalar 28 | Or a schedule as a list of tuples t = 0->1 and scale, e.g. 
29 | [(0, 10), (0.5, 20), (1, 50)] 30 | """ 31 | super().__init__() 32 | self.sampler = sampler 33 | self.index = 0 34 | self.show = verbose 35 | self.guide_model = guide_model 36 | self.history = [] 37 | 38 | if isinstance(scale, (Tuple, List)): 39 | times = np.array([x[0] for x in scale]) 40 | values = np.array([x[1] for x in scale]) 41 | self.scale_schedule = {"times": times, "values": values} 42 | else: 43 | self.scale_schedule = float(scale) 44 | 45 | self.ddim_timesteps = sampler.ddim_timesteps 46 | self.ddpm_num_timesteps = sampler.ddpm_num_timesteps 47 | 48 | 49 | def get_scales(self): 50 | if isinstance(self.scale_schedule, float): 51 | return len(self.ddim_timesteps)*[self.scale_schedule] 52 | 53 | interpolater = interpolate.interp1d(self.scale_schedule["times"], self.scale_schedule["values"]) 54 | fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps 55 | return interpolater(fractional_steps) 56 | 57 | def modify_score(self, model, e_t, x, t, c): 58 | 59 | # TODO look up index by t 60 | scale = self.get_scales()[self.index] 61 | 62 | if (scale == 0): 63 | return e_t 64 | 65 | sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device) 66 | with torch.enable_grad(): 67 | x_in = x.detach().requires_grad_(True) 68 | pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t) 69 | x_img = model.first_stage_model.decode((1/0.18215)*pred_x0) 70 | 71 | inp = self.guide_model.preprocess(x_img) 72 | loss = self.guide_model.compute_loss(inp) 73 | grads = torch.autograd.grad(loss.sum(), x_in)[0] 74 | correction = grads * scale 75 | 76 | if self.show: 77 | clear_output(wait=True) 78 | print(loss.item(), scale, correction.abs().max().item(), e_t.abs().max().item()) 79 | self.history.append([loss.item(), scale, correction.min().item(), correction.max().item()]) 80 | plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2) 81 | plt.axis('off') 82 | plt.show() 83 | plt.imshow(correction[0][0].detach().cpu()) 84 | plt.axis('off') 85 | plt.show() 86 | 87 | 88 | e_t_mod = e_t - sqrt_1ma*correction 89 | if self.show: 90 | fig, axs = plt.subplots(1, 3) 91 | axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2) 92 | axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2) 93 | axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2) 94 | plt.show() 95 | self.index += 1 96 | return e_t_mod -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/zero123/zero1/ldm/models/diffusion/__init__.py 
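(Note added for clarity; not part of the repository.) A minimal usage sketch of the warm-up schedulers defined in ldm/lr_scheduler.py above: they return a learning-rate multiplier, so they are meant to be wrapped in torch.optim.lr_scheduler.LambdaLR with a base learning rate of 1.0 (see the class docstrings). The import path assumes the zero1 directory is on sys.path, as with the other "ldm." imports in this repository, and all hyper-parameter values below are illustrative assumptions.

import torch
from ldm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(8, 8)  # stand-in module
optimizer = torch.optim.AdamW(model.parameters(), lr=1.0)  # base lr of 1.0, per the docstring
schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=1000, lr_min=0.01, lr_max=1.0, lr_start=0.0, max_decay_steps=10000)
lr_sched = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=schedule)
for step in range(3):
    optimizer.step()   # normal training step
    lr_sched.step()    # multiplier follows linear warm-up, then cosine decay towards lr_min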
-------------------------------------------------------------------------------- /src/zero123/zero1/ldm/models/diffusion/sampling_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | def append_dims(x, target_dims): 6 | """Appends dimensions to the end of a tensor until it has target_dims dimensions. 7 | From https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/utils.py""" 8 | dims_to_append = target_dims - x.ndim 9 | if dims_to_append < 0: 10 | raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') 11 | return x[(...,) + (None,) * dims_to_append] 12 | 13 | 14 | def renorm_thresholding(x0, value): 15 | # renorm 16 | pred_max = x0.max() 17 | pred_min = x0.min() 18 | pred_x0 = (x0 - pred_min) / (pred_max - pred_min) # 0 ... 1 19 | pred_x0 = 2 * pred_x0 - 1. # -1 ... 1 20 | 21 | s = torch.quantile( 22 | rearrange(pred_x0, 'b ... -> b (...)').abs(), 23 | value, 24 | dim=-1 25 | ) 26 | s.clamp_(min=1.0) 27 | s = s.view(-1, *((1,) * (pred_x0.ndim - 1))) 28 | 29 | # clip by threshold 30 | # pred_x0 = pred_x0.clamp(-s, s) / s # needs newer pytorch # TODO bring back to pure-gpu with min/max 31 | 32 | # temporary hack: numpy on cpu 33 | pred_x0 = np.clip(pred_x0.cpu().numpy(), -s.cpu().numpy(), s.cpu().numpy()) / s.cpu().numpy() 34 | pred_x0 = torch.tensor(pred_x0).to(self.model.device) 35 | 36 | # re.renorm 37 | pred_x0 = (pred_x0 + 1.) / 2. # 0 ... 1 38 | pred_x0 = (pred_max - pred_min) * pred_x0 + pred_min # orig range 39 | return pred_x0 40 | 41 | 42 | def norm_thresholding(x0, value): 43 | s = append_dims(x0.pow(2).flatten(1).mean(1).sqrt().clamp(min=value), x0.ndim) 44 | return x0 * (value / s) 45 | 46 | 47 | def spatial_norm_thresholding(x0, value): 48 | # b c h w 49 | s = x0.pow(2).mean(1, keepdim=True).sqrt().clamp(min=value) 50 | return x0 * (value / s) -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/zero123/zero1/ldm/modules/diffusionmodules/__init__.py -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/zero123/zero1/ldm/modules/distributions/__init__.py -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = 
parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/scy639/Extreme-Two-View-Geometry-From-Object-Poses-with-Diffusion-Models/49bf1508ef1c625069b18ec45931b0976f11b482/src/zero123/zero1/ldm/modules/encoders/__init__.py -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/evaluate/frechet_video_distance.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2022 The Google Research Authors. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | 16 | # Lint as: python2, python3 17 | """Minimal Reference implementation for the Frechet Video Distance (FVD). 18 | 19 | FVD is a metric for the quality of video generation models. It is inspired by 20 | the FID (Frechet Inception Distance) used for images, but uses a different 21 | embedding to be better suitable for videos. 22 | """ 23 | 24 | from __future__ import absolute_import 25 | from __future__ import division 26 | from __future__ import print_function 27 | 28 | 29 | import six 30 | import tensorflow.compat.v1 as tf 31 | import tensorflow_gan as tfgan 32 | import tensorflow_hub as hub 33 | 34 | 35 | def preprocess(videos, target_resolution): 36 | """Runs some preprocessing on the videos for I3D model. 37 | 38 | Args: 39 | videos: [batch_size, num_frames, height, width, depth] The videos to be 40 | preprocessed. We don't care about the specific dtype of the videos, it can 41 | be anything that tf.image.resize_bilinear accepts. Values are expected to 42 | be in the range 0-255. 43 | target_resolution: (width, height): target video resolution 44 | 45 | Returns: 46 | videos: [batch_size, num_frames, height, width, depth] 47 | """ 48 | videos_shape = list(videos.shape) 49 | all_frames = tf.reshape(videos, [-1] + videos_shape[-3:]) 50 | resized_videos = tf.image.resize_bilinear(all_frames, size=target_resolution) 51 | target_shape = [videos_shape[0], -1] + list(target_resolution) + [3] 52 | output_videos = tf.reshape(resized_videos, target_shape) 53 | scaled_videos = 2. * tf.cast(output_videos, tf.float32) / 255. - 1 54 | return scaled_videos 55 | 56 | 57 | def _is_in_graph(tensor_name): 58 | """Checks whether a given tensor does exists in the graph.""" 59 | try: 60 | tf.get_default_graph().get_tensor_by_name(tensor_name) 61 | except KeyError: 62 | return False 63 | return True 64 | 65 | 66 | def create_id3_embedding(videos,warmup=False,batch_size=16): 67 | """Embeds the given videos using the Inflated 3D Convolution ne twork. 68 | 69 | Downloads the graph of the I3D from tf.hub and adds it to the graph on the 70 | first call. 
71 | 72 | Args: 73 | videos: [batch_size, num_frames, height=224, width=224, depth=3]. 74 | Expected range is [-1, 1]. 75 | 76 | Returns: 77 | embedding: [batch_size, embedding_size]. embedding_size depends 78 | on the model used. 79 | 80 | Raises: 81 | ValueError: when a provided embedding_layer is not supported. 82 | """ 83 | 84 | # batch_size = 16 85 | module_spec = "https://tfhub.dev/deepmind/i3d-kinetics-400/1" 86 | 87 | 88 | # Making sure that we import the graph separately for 89 | # each different input video tensor. 90 | module_name = "fvd_kinetics-400_id3_module_" + six.ensure_str( 91 | videos.name).replace(":", "_") 92 | 93 | 94 | 95 | assert_ops = [ 96 | tf.Assert( 97 | tf.reduce_max(videos) <= 1.001, 98 | ["max value in frame is > 1", videos]), 99 | tf.Assert( 100 | tf.reduce_min(videos) >= -1.001, 101 | ["min value in frame is < -1", videos]), 102 | tf.assert_equal( 103 | tf.shape(videos)[0], 104 | batch_size, ["invalid frame batch size: ", 105 | tf.shape(videos)], 106 | summarize=6), 107 | ] 108 | with tf.control_dependencies(assert_ops): 109 | videos = tf.identity(videos) 110 | 111 | module_scope = "%s_apply_default/" % module_name 112 | 113 | # To check whether the module has already been loaded into the graph, we look 114 | # for a given tensor name. If this tensor name exists, we assume the function 115 | # has been called before and the graph was imported. Otherwise we import it. 116 | # Note: in theory, the tensor could exist, but have wrong shapes. 117 | # This will happen if create_id3_embedding is called with a frames_placehoder 118 | # of wrong size/batch size, because even though that will throw a tf.Assert 119 | # on graph-execution time, it will insert the tensor (with wrong shape) into 120 | # the graph. This is why we need the following assert. 121 | if warmup: 122 | video_batch_size = int(videos.shape[0]) 123 | assert video_batch_size in [batch_size, -1, None], f"Invalid batch size {video_batch_size}" 124 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 125 | if not _is_in_graph(tensor_name): 126 | i3d_model = hub.Module(module_spec, name=module_name) 127 | i3d_model(videos) 128 | 129 | # gets the kinetics-i3d-400-logits layer 130 | tensor_name = module_scope + "RGB/inception_i3d/Mean:0" 131 | tensor = tf.get_default_graph().get_tensor_by_name(tensor_name) 132 | return tensor 133 | 134 | 135 | def calculate_fvd(real_activations, 136 | generated_activations): 137 | """Returns a list of ops that compute metrics as funcs of activations. 138 | 139 | Args: 140 | real_activations: [num_samples, embedding_size] 141 | generated_activations: [num_samples, embedding_size] 142 | 143 | Returns: 144 | A scalar that contains the requested FVD. 
145 | """ 146 | return tfgan.eval.frechet_classifier_distance_from_activations( 147 | real_activations, generated_activations) 148 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/evaluate/ssim.py: -------------------------------------------------------------------------------- 1 | # MIT Licence 2 | 3 | # Methods to predict the SSIM, taken from 4 | # https://github.com/Po-Hsun-Su/pytorch-ssim/blob/master/pytorch_ssim/__init__.py 5 | 6 | from math import exp 7 | 8 | import torch 9 | import torch.nn.functional as F 10 | from torch.autograd import Variable 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor( 14 | [ 15 | exp(-((x - window_size // 2) ** 2) / float(2 * sigma ** 2)) 16 | for x in range(window_size) 17 | ] 18 | ) 19 | return gauss / gauss.sum() 20 | 21 | 22 | def create_window(window_size, channel): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = Variable( 26 | _2D_window.expand(channel, 1, window_size, window_size).contiguous() 27 | ) 28 | return window 29 | 30 | 31 | def _ssim( 32 | img1, img2, window, window_size, channel, mask=None, size_average=True 33 | ): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | 37 | mu1_sq = mu1.pow(2) 38 | mu2_sq = mu2.pow(2) 39 | mu1_mu2 = mu1 * mu2 40 | 41 | sigma1_sq = ( 42 | F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) 43 | - mu1_sq 44 | ) 45 | sigma2_sq = ( 46 | F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) 47 | - mu2_sq 48 | ) 49 | sigma12 = ( 50 | F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) 51 | - mu1_mu2 52 | ) 53 | 54 | C1 = (0.01) ** 2 55 | C2 = (0.03) ** 2 56 | 57 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ( 58 | (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2) 59 | ) 60 | 61 | if not (mask is None): 62 | b = mask.size(0) 63 | ssim_map = ssim_map.mean(dim=1, keepdim=True) * mask 64 | ssim_map = ssim_map.view(b, -1).sum(dim=1) / mask.view(b, -1).sum( 65 | dim=1 66 | ).clamp(min=1) 67 | return ssim_map 68 | 69 | import pdb 70 | 71 | pdb.set_trace 72 | 73 | if size_average: 74 | return ssim_map.mean() 75 | else: 76 | return ssim_map.mean(1).mean(1).mean(1) 77 | 78 | 79 | class SSIM(torch.nn.Module): 80 | def __init__(self, window_size=11, size_average=True): 81 | super(SSIM, self).__init__() 82 | self.window_size = window_size 83 | self.size_average = size_average 84 | self.channel = 1 85 | self.window = create_window(window_size, self.channel) 86 | 87 | def forward(self, img1, img2, mask=None): 88 | (_, channel, _, _) = img1.size() 89 | 90 | if ( 91 | channel == self.channel 92 | and self.window.data.type() == img1.data.type() 93 | ): 94 | window = self.window 95 | else: 96 | window = create_window(self.window_size, channel) 97 | 98 | if img1.is_cuda: 99 | window = window.cuda(img1.get_device()) 100 | window = window.type_as(img1) 101 | 102 | self.window = window 103 | self.channel = channel 104 | 105 | return _ssim( 106 | img1, 107 | img2, 108 | window, 109 | self.window_size, 110 | channel, 111 | mask, 112 | self.size_average, 113 | ) 114 | 115 | 116 | def ssim(img1, img2, window_size=11, mask=None, size_average=True): 117 | (_, channel, _, _) = img1.size() 118 | window = create_window(window_size, channel) 119 | 120 | if img1.is_cuda: 121 | 
window = window.cuda(img1.get_device()) 122 | window = window.type_as(img1) 123 | 124 | return _ssim(img1, img2, window, window_size, channel, mask, size_average) 125 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = 
self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/thirdp/psp/helpers.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from collections import namedtuple 4 | import torch 5 | from torch.nn import Conv2d, BatchNorm2d, PReLU, ReLU, Sigmoid, MaxPool2d, AdaptiveAvgPool2d, Sequential, Module 6 | 7 | """ 8 | ArcFace implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 9 | """ 10 | 11 | 12 | class Flatten(Module): 13 | def forward(self, input): 14 | return input.view(input.size(0), -1) 15 | 16 | 17 | def l2_norm(input, axis=1): 18 | norm = torch.norm(input, 2, axis, True) 19 | output = torch.div(input, norm) 20 | return output 21 | 22 | 23 | class Bottleneck(namedtuple('Block', ['in_channel', 'depth', 'stride'])): 24 | """ A named tuple describing a ResNet block. 
""" 25 | 26 | 27 | def get_block(in_channel, depth, num_units, stride=2): 28 | return [Bottleneck(in_channel, depth, stride)] + [Bottleneck(depth, depth, 1) for i in range(num_units - 1)] 29 | 30 | 31 | def get_blocks(num_layers): 32 | if num_layers == 50: 33 | blocks = [ 34 | get_block(in_channel=64, depth=64, num_units=3), 35 | get_block(in_channel=64, depth=128, num_units=4), 36 | get_block(in_channel=128, depth=256, num_units=14), 37 | get_block(in_channel=256, depth=512, num_units=3) 38 | ] 39 | elif num_layers == 100: 40 | blocks = [ 41 | get_block(in_channel=64, depth=64, num_units=3), 42 | get_block(in_channel=64, depth=128, num_units=13), 43 | get_block(in_channel=128, depth=256, num_units=30), 44 | get_block(in_channel=256, depth=512, num_units=3) 45 | ] 46 | elif num_layers == 152: 47 | blocks = [ 48 | get_block(in_channel=64, depth=64, num_units=3), 49 | get_block(in_channel=64, depth=128, num_units=8), 50 | get_block(in_channel=128, depth=256, num_units=36), 51 | get_block(in_channel=256, depth=512, num_units=3) 52 | ] 53 | else: 54 | raise ValueError("Invalid number of layers: {}. Must be one of [50, 100, 152]".format(num_layers)) 55 | return blocks 56 | 57 | 58 | class SEModule(Module): 59 | def __init__(self, channels, reduction): 60 | super(SEModule, self).__init__() 61 | self.avg_pool = AdaptiveAvgPool2d(1) 62 | self.fc1 = Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False) 63 | self.relu = ReLU(inplace=True) 64 | self.fc2 = Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False) 65 | self.sigmoid = Sigmoid() 66 | 67 | def forward(self, x): 68 | module_input = x 69 | x = self.avg_pool(x) 70 | x = self.fc1(x) 71 | x = self.relu(x) 72 | x = self.fc2(x) 73 | x = self.sigmoid(x) 74 | return module_input * x 75 | 76 | 77 | class bottleneck_IR(Module): 78 | def __init__(self, in_channel, depth, stride): 79 | super(bottleneck_IR, self).__init__() 80 | if in_channel == depth: 81 | self.shortcut_layer = MaxPool2d(1, stride) 82 | else: 83 | self.shortcut_layer = Sequential( 84 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 85 | BatchNorm2d(depth) 86 | ) 87 | self.res_layer = Sequential( 88 | BatchNorm2d(in_channel), 89 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), PReLU(depth), 90 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), BatchNorm2d(depth) 91 | ) 92 | 93 | def forward(self, x): 94 | shortcut = self.shortcut_layer(x) 95 | res = self.res_layer(x) 96 | return res + shortcut 97 | 98 | 99 | class bottleneck_IR_SE(Module): 100 | def __init__(self, in_channel, depth, stride): 101 | super(bottleneck_IR_SE, self).__init__() 102 | if in_channel == depth: 103 | self.shortcut_layer = MaxPool2d(1, stride) 104 | else: 105 | self.shortcut_layer = Sequential( 106 | Conv2d(in_channel, depth, (1, 1), stride, bias=False), 107 | BatchNorm2d(depth) 108 | ) 109 | self.res_layer = Sequential( 110 | BatchNorm2d(in_channel), 111 | Conv2d(in_channel, depth, (3, 3), (1, 1), 1, bias=False), 112 | PReLU(depth), 113 | Conv2d(depth, depth, (3, 3), stride, 1, bias=False), 114 | BatchNorm2d(depth), 115 | SEModule(depth, 16) 116 | ) 117 | 118 | def forward(self, x): 119 | shortcut = self.shortcut_layer(x) 120 | res = self.res_layer(x) 121 | return res + shortcut -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/thirdp/psp/id_loss.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | import 
torch 3 | from torch import nn 4 | from ldm.thirdp.psp.model_irse import Backbone 5 | 6 | 7 | class IDFeatures(nn.Module): 8 | def __init__(self, model_path): 9 | super(IDFeatures, self).__init__() 10 | print('Loading ResNet ArcFace') 11 | self.facenet = Backbone(input_size=112, num_layers=50, drop_ratio=0.6, mode='ir_se') 12 | self.facenet.load_state_dict(torch.load(model_path, map_location="cpu")) 13 | self.face_pool = torch.nn.AdaptiveAvgPool2d((112, 112)) 14 | self.facenet.eval() 15 | 16 | def forward(self, x, crop=False): 17 | # Not sure of the image range here 18 | if crop: 19 | x = torch.nn.functional.interpolate(x, (256, 256), mode="area") 20 | x = x[:, :, 35:223, 32:220] 21 | x = self.face_pool(x) 22 | x_feats = self.facenet(x) 23 | return x_feats 24 | -------------------------------------------------------------------------------- /src/zero123/zero1/ldm/thirdp/psp/model_irse.py: -------------------------------------------------------------------------------- 1 | # https://github.com/eladrich/pixel2style2pixel 2 | 3 | from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Dropout, Sequential, Module 4 | from ldm.thirdp.psp.helpers import get_blocks, Flatten, bottleneck_IR, bottleneck_IR_SE, l2_norm 5 | 6 | """ 7 | Modified Backbone implementation from [TreB1eN](https://github.com/TreB1eN/InsightFace_Pytorch) 8 | """ 9 | 10 | 11 | class Backbone(Module): 12 | def __init__(self, input_size, num_layers, mode='ir', drop_ratio=0.4, affine=True): 13 | super(Backbone, self).__init__() 14 | assert input_size in [112, 224], "input_size should be 112 or 224" 15 | assert num_layers in [50, 100, 152], "num_layers should be 50, 100 or 152" 16 | assert mode in ['ir', 'ir_se'], "mode should be ir or ir_se" 17 | blocks = get_blocks(num_layers) 18 | if mode == 'ir': 19 | unit_module = bottleneck_IR 20 | elif mode == 'ir_se': 21 | unit_module = bottleneck_IR_SE 22 | self.input_layer = Sequential(Conv2d(3, 64, (3, 3), 1, 1, bias=False), 23 | BatchNorm2d(64), 24 | PReLU(64)) 25 | if input_size == 112: 26 | self.output_layer = Sequential(BatchNorm2d(512), 27 | Dropout(drop_ratio), 28 | Flatten(), 29 | Linear(512 * 7 * 7, 512), 30 | BatchNorm1d(512, affine=affine)) 31 | else: 32 | self.output_layer = Sequential(BatchNorm2d(512), 33 | Dropout(drop_ratio), 34 | Flatten(), 35 | Linear(512 * 14 * 14, 512), 36 | BatchNorm1d(512, affine=affine)) 37 | 38 | modules = [] 39 | for block in blocks: 40 | for bottleneck in block: 41 | modules.append(unit_module(bottleneck.in_channel, 42 | bottleneck.depth, 43 | bottleneck.stride)) 44 | self.body = Sequential(*modules) 45 | 46 | def forward(self, x): 47 | x = self.input_layer(x) 48 | x = self.body(x) 49 | x = self.output_layer(x) 50 | return l2_norm(x) 51 | 52 | 53 | def IR_50(input_size): 54 | """Constructs a ir-50 model.""" 55 | model = Backbone(input_size, num_layers=50, mode='ir', drop_ratio=0.4, affine=False) 56 | return model 57 | 58 | 59 | def IR_101(input_size): 60 | """Constructs a ir-101 model.""" 61 | model = Backbone(input_size, num_layers=100, mode='ir', drop_ratio=0.4, affine=False) 62 | return model 63 | 64 | 65 | def IR_152(input_size): 66 | """Constructs a ir-152 model.""" 67 | model = Backbone(input_size, num_layers=152, mode='ir', drop_ratio=0.4, affine=False) 68 | return model 69 | 70 | 71 | def IR_SE_50(input_size): 72 | """Constructs a ir_se-50 model.""" 73 | model = Backbone(input_size, num_layers=50, mode='ir_se', drop_ratio=0.4, affine=False) 74 | return model 75 | 76 | 77 | def IR_SE_101(input_size): 78 | """Constructs a 
ir_se-101 model.""" 79 | model = Backbone(input_size, num_layers=100, mode='ir_se', drop_ratio=0.4, affine=False) 80 | return model 81 | 82 | 83 | def IR_SE_152(input_size): 84 | """Constructs a ir_se-152 model.""" 85 | model = Backbone(input_size, num_layers=152, mode='ir_se', drop_ratio=0.4, affine=False) 86 | return model -------------------------------------------------------------------------------- /src/zero123/zero1/run4gen6d.py: -------------------------------------------------------------------------------- 1 | import argparse, os, sys, datetime, glob 2 | import numpy as np 3 | import time 4 | import torch 5 | 6 | 7 | def get_parser(**parser_kwargs): 8 | 9 | parser = argparse.ArgumentParser(**parser_kwargs) 10 | 11 | parser.add_argument( 12 | "--id", 13 | type=str, 14 | ) 15 | parser.add_argument( 16 | "--input_image_path", 17 | type=str, 18 | ) 19 | parser.add_argument( 20 | "--output_dir", 21 | type=str, 22 | ) 23 | parser.add_argument( 24 | "--num_samples", 25 | type=int, 26 | default=4, 27 | ) 28 | 29 | return parser 30 | from run_ import sample_model_,sample_model_batchB_wrapper 31 | from util_4_e2vg.genIntermediateResult import genIntermediateResult 32 | from util_4_e2vg.Util import OutputIm_Name_Parser 33 | def main( 34 | id, 35 | input_image_path, 36 | output_dir, 37 | num_samples, 38 | K, 39 | ddim_steps=50, 40 | base_xyz=(0,0,0), 41 | **kw, 42 | ): 43 | parser = get_parser() 44 | args=parser.parse_args() 45 | args.id=id 46 | args.input_image_path=input_image_path 47 | args.output_dir=output_dir 48 | args.num_samples=num_samples 49 | 50 | 51 | #args: output_dir, input_image_path, num_samples 52 | folder_output_ims:str=sample_model_( input_image_path=input_image_path, num_samples=num_samples, id=id, 53 | batch_sample=True, 54 | ddim_steps=ddim_steps, 55 | **kw) 56 | if 'only_gen' in kw and kw['only_gen']: 57 | l__path_output_im:list=OutputIm_Name_Parser.parse_B(folder_output_ims) 58 | return l__path_output_im 59 | 60 | 61 | 62 | import sys,os 63 | cur_dir = os.path.dirname(os.path.abspath(__file__)) 64 | sys.path.append(os.path.join(cur_dir, "util_4_e2vg")) 65 | 66 | 67 | genIntermediateResult(path=folder_output_ims, path_save=args.output_dir, calib_xy=(0,0), base_xyz=base_xyz,called_by_run4gen6d=True,K=K, 68 | # id=args.id 69 | ) 70 | sys.path.pop() 71 | 72 | -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/CameraMatrixUtil.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json, math 3 | 4 | """ 5 | Intrinsic/extrinsic (K and pose) transformations that accompany image operations such as crop and resize 6 | """ 7 | 8 | 9 | def get_K(img_h, img_w): 10 | cx = img_w / 2 11 | cy = img_h / 2 12 | f = np.sqrt(img_h ** 2 + img_w ** 2) 13 | fx = f 14 | fy = f 15 | K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 16 | return K 17 | 18 | 19 | def xyz2pose(x, y, z): 20 | """ 21 | The camera's position relative to the obj, in spherical coordinates, is (x, y, z): 22 | x:Polar angle (vertical rotation in degrees) 23 | y:Azimuth angle (horizontal rotation in degrees) 24 | z:Zoom (relative distance from center) 25 | 26 | The camera looks along its +z axis; its +y axis points downward 27 | The camera faces the obj (the obj is at the image center) 28 | return: pose: the object pose, i.e. a translation t and a rotation R that transform object coordinates xobj to camera coordinates xcam = R xobj + t 29 | pose=[R,t;0,1] 30 | 31 | """ 32 | 33 | p = math.radians(-x) 34 | a = math.radians(y) 35 | R_AC = np.array([ 36 | [-np.sin(a), np.cos(a), 0], 37 | [0, 0, -1], 38 | [-np.cos(a), - np.sin(a), 0], 39 | ]) 40 | R_CD = np.array([ 41 | [1, 0, 0], 42 | [0, np.cos(p), -np.sin(p)], 43 | [0, np.sin(p), 
np.cos(p)], 44 | ]) 45 | R = R_CD @ R_AC 46 | 47 | z_m = z 48 | t = np.array([[0, 0, z_m]]).T 49 | pose = np.concatenate([R, t], 1) 50 | # pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], 0) 51 | return pose 52 | 53 | 54 | def get_z_4_normObj(fx, fy, obj_w_pixel, obj_h_pixel,img_w,img_h): 55 | """ 56 | obj_w_pixel: obj width in pixel 57 | normObj means the object fits in a unit cube (actually a 2x2x2 cube) 58 | 59 | """ 60 | # if (obj_w_pixel >= obj_h_pixel): 61 | if (obj_w_pixel/obj_h_pixel >= img_w/img_h): 62 | z = fx / obj_w_pixel *2 63 | else: 64 | z = fy / obj_h_pixel *2 65 | return z 66 | def get_fx_fy_4_normObj(z,obj_w_pixel,obj_h_pixel,img_w,img_h): 67 | if (obj_w_pixel/obj_h_pixel >= img_w/img_h): 68 | fx = z * obj_w_pixel/2 69 | fy = z * obj_w_pixel/2 70 | else: 71 | fx = z * obj_h_pixel/2 72 | fy = z * obj_h_pixel/2 73 | return fx,fy 74 | 75 | def crop(K, coord0: tuple, coord1: tuple): 76 | fx = K[0][0] 77 | fy = K[1][1] 78 | cx = K[0][2] 79 | cy = K[1][2] 80 | fx_new = fx 81 | fy_new = fy 82 | cx_new = cx - coord0[0] 83 | cy_new = cy - coord0[1] 84 | K_new = np.array([[fx_new, 0, cx_new], [0, fy_new, cy_new], [0, 0, 1]]) 85 | return K_new 86 | 87 | 88 | def resize(K, h_old, w_old, h_new, w_new): 89 | """ 90 | K_new=S@K,where S=[sx,0,0;0,sy,0;0,0,1],sx=w_new/w_old. 91 | Derivation done on scratch paper (it matches what copilot generates directly) 92 | """ 93 | fx = K[0][0] 94 | fy = K[1][1] 95 | cx = K[0][2] 96 | cy = K[1][2] 97 | fx_new = fx * w_new / w_old 98 | fy_new = fy * h_new / h_old 99 | cx_new = cx * w_new / w_old 100 | cy_new = cy * h_new / h_old 101 | K_new = np.array([[fx_new, 0, cx_new], [0, fy_new, cy_new], [0, 0, 1]]) 102 | return K_new 103 | -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/ImagePathUtil.py: -------------------------------------------------------------------------------- 1 | import os 2 | def get_path(i,j,x,y,z): 3 | return f"{i}-{j}(x={x},y={y},z={z}).jpg" 4 | def i2glob(i,): 5 | return f"{i}-*.jpg" 6 | 7 | 8 | #------zero123 output_im------------------------ 9 | def parse_path(path:str): 10 | path=os.path.basename(path) 11 | """ 12 | path=11-2(x=0,y=30.0,z=0).png=i-j(x=0,y=30.0,z=0).png 13 | """ 14 | file=path 15 | i = int(file.split('-')[0]) 16 | j = int(file.split('-')[1].split('(')[0]) # index of sample 17 | rest = file[len(f"{i}-j"):] 18 | x=float(rest.split('(')[1].split(',')[0].split('=')[1]) 19 | y=float(rest.split('(')[1].split(',')[1].split('=')[1]) 20 | z=float(rest.split('(')[1].split(',')[2].split('=')[1].split(')')[0]) 21 | return i,j,x,y,z 22 | def sort_outputIm_A(paths:list[str])->list[str]: 23 | def path2xy(imgPath): 24 | i,j,x,y,z=parse_path(imgPath) 25 | assert j==0 26 | return x,y 27 | l_xy=list(map(path2xy,paths)) 28 | new_imgPaths=[] 29 | 30 | N=3 31 | min_x,max_x=min(l_xy,key=lambda xy:xy[0])[0],max(l_xy,key=lambda xy:xy[0])[0] 32 | l__path_xy=zip(paths,l_xy) 33 | # first, sort by y 34 | l__path_xy=sorted(l__path_xy,key=lambda path_xy:path_xy[1][1],reverse=False) 35 | # create an empty list of length N 36 | l=[[] for i in range(N)] 37 | 38 | x_interval=(max_x-min_x)/N 39 | ranges=[[min_x+i*x_interval,min_x+i*x_interval+x_interval] for i in range(N)] 40 | ranges[-1][1]=max_x 41 | for path,xy in l__path_xy: 42 | x,y=xy 43 | level=int((x-min_x)//x_interval) 44 | if level==N: 45 | level-=1 46 | assert ranges[level][0]<=x<=ranges[level][1] 47 | l[level].append(path) 48 | l.reverse() 49 | # print(f"{l=}") 50 | # flatten l 51 | new_imgPaths=[] 52 | for i in range(N): 53 | new_imgPaths.extend(l[i]) 54 | return new_imgPaths 
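# Usage sketch: a minimal, illustrative combination of the helpers above with
# CameraMatrixUtil, assuming a 256x256 render and an example Zero123-style
# filename (both values are assumptions chosen for illustration, not taken from
# the pipeline). It parses the filename into its (i, j, x, y, z) viewpoint,
# builds intrinsics with get_K, converts the viewpoint into a 3x4 object pose
# with xyz2pose, and rescales K for a downsampled image with resize.
if __name__ == "__main__":
    import CameraMatrixUtil  # sibling import; assumes util_4_e2vg is on sys.path, as run4gen6d.py arranges

    example = "11-0(x=0,y=30.0,z=0).jpg"
    i, j, x, y, z = parse_path(example)                # -> 11, 0, 0.0, 30.0, 0.0
    K = CameraMatrixUtil.get_K(img_h=256, img_w=256)   # 3x3 intrinsics for a 256x256 image
    pose = CameraMatrixUtil.xyz2pose(x, y, z)          # 3x4 [R|t] object-to-camera pose
    K_half = CameraMatrixUtil.resize(K, h_old=256, w_old=256, h_new=128, w_new=128)
    print(i, j, pose.shape, K[0][2], K_half[0][2])     # cx halves from 128.0 to 64.0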
-------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/IntermediateResult.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import json, math 3 | 4 | 5 | class IntermediateResult: 6 | def __init__(s, ): 7 | s.data = {} 8 | 9 | def append(s, i, K, pose): 10 | s.data[i] = { 11 | "K": K, 12 | "pose": pose 13 | } 14 | 15 | def load(s, path): 16 | with open(path, "r") as f: 17 | s.data = json.load(f) 18 | 19 | for i in s.data: 20 | for key in s.data[i]: 21 | s.data[i][key] = np.array(s.data[i][key]) 22 | 23 | def dump(self, path): 24 | 25 | import json 26 | import numpy 27 | from torch import Tensor 28 | 29 | class MyJSONEncoder(json.JSONEncoder): 30 | def default(self, obj): 31 | if isinstance(obj, numpy.ndarray): 32 | return obj.tolist() 33 | if isinstance(obj, Tensor): 34 | return obj.cpu().data.numpy().tolist() 35 | elif (isinstance(obj, numpy.int32) or 36 | isinstance(obj, numpy.int64) or 37 | isinstance(obj, numpy.float32) or 38 | isinstance(obj, numpy.float64)): 39 | return obj.item() 40 | return json.JSONEncoder.default(self, obj) 41 | 42 | with open(path, "w") as f: 43 | json.dump(self.data, f, cls=MyJSONEncoder) 44 | -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/Util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from skimage.io import imsave 4 | 5 | # from util_4_e2vg import ImagePathUtil 6 | # from util_4_e2vg import CameraMatrixUtil 7 | # from util_4_e2vg.IntermediateResult import IntermediateResult 8 | 9 | 10 | # import ImagePathUtil 11 | # import CameraMatrixUtil 12 | # from IntermediateResult import IntermediateResult 13 | 14 | 15 | def get_xyzLinearGradient(xyzStops: tuple, N: int): 16 | """ 17 | param: 18 | xyzStops:((x=0,y=0,z=0,stop=0),(x=0,y=90,z=0,stop=0.2),(x=80,y=90,z=1,stop=0.3),...(,,,1.0)) 19 | N: length of xyzLinearGradient 20 | return: 21 | xyzLinearGradient:[(x=0,y=0,z=0),(x=,y=,z=),...] 
22 | """ 23 | xyzLinearGradient = [] 24 | 25 | 26 | num_stops = len(xyzStops) 27 | 28 | 29 | for i in range(N): 30 | stop = i / (N - 1) 31 | 32 | 33 | for j in range(num_stops - 1): 34 | if stop >= xyzStops[j][3] and stop <= xyzStops[j + 1][3]: 35 | 36 | t = (stop - xyzStops[j][3]) / (xyzStops[j + 1][3] - xyzStops[j][3]) 37 | x = xyzStops[j][0] + (xyzStops[j + 1][0] - xyzStops[j][0]) * t 38 | y = xyzStops[j][1] + (xyzStops[j + 1][1] - xyzStops[j][1]) * t 39 | z = xyzStops[j][2] + (xyzStops[j + 1][2] - xyzStops[j][2]) * t 40 | xyzLinearGradient.append((x, y, z)) 41 | break 42 | 43 | return xyzLinearGradient 44 | class OutputIm_Name_Parser: 45 | @staticmethod 46 | def parse_A(folder): 47 | # return i2samples 48 | files = os.listdir(folder) 49 | i2samples = {} 50 | for file in files: 51 | # if is jpg 52 | if (file.split('.')[-1] != 'jpg'): 53 | continue 54 | # file=11-2(x=0,y=30.0,z=0).png=i-j(x=0,y=30.0,z=0).png 55 | i = int(file.split('-')[0]) 56 | j = int(file.split('-')[1].split('(')[0]) # index of sample 57 | rest = file[len(f"{i}-j"):] 58 | if (i not in i2samples): 59 | i2samples[i] = [] 60 | i2samples[i].append(file) 61 | return i2samples 62 | @staticmethod 63 | def parse_B(folder,in_fullPath=True): 64 | i2samples=OutputIm_Name_Parser.parse_A(folder) 65 | ret=[] 66 | for i in range(len(i2samples)): 67 | samples=i2samples[i] 68 | assert len(samples)==1 69 | ret.append(samples[0]) 70 | if in_fullPath: 71 | ret=[os.path.join(folder,i) for i in ret] 72 | return ret -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/__init__.py: -------------------------------------------------------------------------------- 1 | import sys 2 | -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/choose_j.py: -------------------------------------------------------------------------------- 1 | import root_config,time 2 | # from imports import * 3 | from exception_util import handle_exception 4 | 5 | def main(read_path, save_path,ask): 6 | SRC = read_path 7 | new_path = save_path 8 | """ 9 | """ 10 | import os 11 | import cv2 12 | 13 | # SRC = "original-png-4samples" 14 | # SRC="test" 15 | 16 | def get_match_score(img0, img1) -> float: 17 | """ 18 | :param img0: 0-0(x=0,y=30.0,z=0).png 19 | :param img1: 0-1(x=0,y=30.0,z=0).png 20 | :return: match score 21 | 22 | """ 23 | 24 | """ 25 | # copilot version: 26 | # :logic: 1. get sift feature of img0 and img1 27 | # 2. match sift feature 28 | # 3. get match score 29 | sift = cv2.xfeatures2d.SIFT_create() 30 | kp0, des0 = sift.detectAndCompute(img0, None) 31 | kp1, des1 = sift.detectAndCompute(img1, None) 32 | bf = cv2.BFMatcher() 33 | matches = bf.knnMatch(des0, des1, k=2) 34 | good = [] 35 | for m, n in matches: 36 | if m.distance < 0.5 * n.distance: 37 | good.append([m]) 38 | return len(good) 39 | """ 40 | 41 | """ 42 | # conv version 43 | # calculate conv( matrix mul and accumulate all elements 44 | 45 | return (img0 * img1).sum() 46 | """ 47 | 48 | """ 49 | Compute the per-element distance and then sum 50 | """ 51 | # return -(img0 - img1).sum() 52 | return -((img0 - img1) ** 2).sum() 53 | 54 | def choose_j(pre_img, l_cur_img: list): 55 | """ 56 | :param pre_img: the image finally chosen at step i-1 57 | :param l_cur_img: all candidate images at step i 58 | :return: index (j) of the finally chosen image 59 | :logic: pick the candidate that best matches pre_img. 
60 | """ 61 | 62 | score = float('-inf')  # get_match_score returns non-positive scores, so start from -inf (0 would always keep j=0) 63 | j = 0 64 | for i in range(len(l_cur_img)): 65 | cur_img = l_cur_img[i] 66 | cur_score = get_match_score(pre_img, cur_img) 67 | if (cur_score > score): 68 | score = cur_score 69 | j = i 70 | return j 71 | 72 | # s1: get i2samples 73 | files = os.listdir(SRC) 74 | i2samples = {} 75 | for file in files: 76 | # if is jpg 77 | if (file.split('.')[-1] != 'jpg'): 78 | continue 79 | # file=11-2(x=0,y=30.0,z=0).png=i-j(x=0,y=30.0,z=0).png 80 | i = int(file.split('-')[0]) 81 | j = int(file.split('-')[1].split('(')[0]) # index of sample 82 | rest = file[len(f"{i}-j"):] 83 | if (i not in i2samples): 84 | i2samples[i] = [] 85 | i2samples[i].append(file) 86 | 87 | # s2: i2samples 2 l_copy 88 | def file2img(file: str): 89 | return cv2.imread(f"{SRC}/{file}") 90 | 91 | l_copy = [] 92 | pre_img = None 93 | for i in range(len(i2samples)): 94 | # for i in [44,45]: 95 | cur_fileNames = i2samples[i] 96 | if(root_config.USE_ALL_SAMPLE): 97 | for j in range(root_config.NUM_SAMPLE): 98 | l_copy.append(( 99 | cur_fileNames[j], 100 | # f"{i}.jpg" 101 | cur_fileNames[j] 102 | )) 103 | else: 104 | if (pre_img is None): 105 | pre_img = file2img(cur_fileNames[0]) 106 | # cur_fileNames 2 cur_imgs. read img to nd array 107 | cur_imgs = [file2img(cur_fileName) for cur_fileName in cur_fileNames] 108 | j = choose_j(pre_img, cur_imgs) 109 | pre_img = cur_imgs[j] 110 | l_copy.append(( 111 | cur_fileNames[j], 112 | # f"{i}.jpg" 113 | cur_fileNames[j] 114 | )) 115 | 116 | # print(l_copy) 117 | import shutil 118 | 119 | if not os.path.exists(new_path): 120 | os.makedirs(new_path) 121 | 122 | for copy in l_copy: 123 | # new_path = f"onlyI-png-1samples" 124 | try: 125 | shutil.copy(f"{SRC}/{copy[0]}", f"{new_path}/{copy[1]}") 126 | except BlockingIOError as e: 127 | time.sleep(1) 128 | print(f'--------BlockingIOError------shutil.copy({f"{SRC}/{copy[0]}", f"{new_path}/{copy[1]}"}-----')  # pERROR is not defined in this module, so fall back to print 129 | handle_exception(e) 130 | -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/genIntermediateResult.py: -------------------------------------------------------------------------------- 1 | def genIntermediateResult(K,path=None, 2 | path_save=None, #eg. 
/baseline/relpose_plus_plus_main/relpose/../../../gen6d/Gen6D/./data/zero123/GSO_alarm----+8/ref' 3 | calib_xy=(0,0) ,base_xyz=(0,0,0),called_by_run4gen6d=False ): 4 | ASK=0 5 | CHOOSE_J=False 6 | MOVE_OBJ_TO_CENTER=False 7 | import sys 8 | import os 9 | 10 | 11 | 12 | 13 | #---------------------------------------------- 14 | path0 = path 15 | if CHOOSE_J: 16 | import choose_j 17 | path1 = os.path.join(path, 'after_choose_j') 18 | choose_j.main(read_path=path0, save_path=path1, ask=ASK) 19 | else: 20 | path1 = path0 21 | 22 | #---------------------------------------------- 23 | if MOVE_OBJ_TO_CENTER: 24 | import move_obj_to_center 25 | path2 = os.path.join(path, 'after_move_obj_to_center') 26 | move_obj_to_center.main(read_path=path1, save_path=path2) 27 | else: 28 | path2 = path1 29 | 30 | #---------------------------------------------- 31 | import crop 32 | intermediateResult = crop.crop( 33 | read_path=path2, save_path=path_save, 34 | calib_xy=calib_xy, 35 | base_xyz=base_xyz, 36 | K=K, 37 | 38 | # **kw 39 | ask=ASK, 40 | save_image=1, 41 | do_not_crop=1, 42 | norm_obj_by_z=1, 43 | 44 | margin_percent=0, 45 | # margin_percent=0.1, 46 | 47 | DRAW_cropped_img=0, 48 | ) 49 | path_intermediateResult =os.path.join(path_save, "intermediateResult.json") 50 | intermediateResult.dump(path_intermediateResult) 51 | #print("intermediateResult saved to:", os.path.abspath(path_intermediateResult)) 52 | 53 | if (__name__ == "__main__"): 54 | genIntermediateResult( called_by_run4gen6d=False) -------------------------------------------------------------------------------- /src/zero123/zero1/util_4_e2vg/move_obj_to_center.py: -------------------------------------------------------------------------------- 1 | def main(read_path, save_path): 2 | bg_color=(255,255,255) 3 | 4 | from image_util import imgArr_2_objXminYminXmaxYmax 5 | import os 6 | import cv2 7 | """ 8 | For every .jpg under read_path: detect the obj's (xmin, ymin, xmax, ymax) box, crop out the obj, pad around it with bg_color so the image size stays unchanged, and save to save_path/jpg_name 9 | """ 10 | if(not os.path.exists(save_path)): 11 | os.makedirs(save_path) 12 | for jpg_name in os.listdir(read_path): 13 | if jpg_name.endswith(".jpg"): 14 | #print(jpg_name) 15 | img = cv2.imread(os.path.join(read_path, jpg_name)) 16 | h,w=img.shape[:2] 17 | xmin, ymin, xmax, ymax = imgArr_2_objXminYminXmaxYmax(img,bg_color) 18 | obj = img[ymin:ymax, xmin:xmax] 19 | obj = cv2.copyMakeBorder(obj, (h-(ymax-ymin))//2, (h-(ymax-ymin))//2, (w-(xmax-xmin))//2, (w-(xmax-xmin))//2, cv2.BORDER_CONSTANT, value=bg_color) 20 | obj = cv2.copyMakeBorder(obj, 0, h-obj.shape[0],0,w-obj.shape[1], cv2.BORDER_CONSTANT, value=bg_color) 21 | cv2.imwrite(os.path.join(save_path, jpg_name), obj) 22 | 23 | 24 | 25 | 26 | --------------------------------------------------------------------------------
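The files above form the tail of the Zero123-to-Gen6D bridge: run4gen6d.py renders novel views, genIntermediateResult.py (optionally via choose_j.py and move_obj_to_center.py) post-processes them with crop.crop, and the per-view intrinsics and poses are written to intermediateResult.json through IntermediateResult.dump. As a rough, hedged sketch of how that artifact could be consumed afterwards, the snippet below reloads it and inspects each reference view; the ref_dir path is a placeholder, and the import assumes src/zero123/zero1 is on sys.path (as run4gen6d.py arranges for its own directory).

```
# Illustrative consumer sketch (placeholder path): reload intermediateResult.json
# and print the shape of each reference view's intrinsics K and object pose [R|t].
import os
from util_4_e2vg.IntermediateResult import IntermediateResult

ref_dir = "path/to/ref"  # placeholder: the path_save previously passed to genIntermediateResult
ir = IntermediateResult()
ir.load(os.path.join(ref_dir, "intermediateResult.json"))

for i, entry in ir.data.items():
    K, pose = entry["K"], entry["pose"]   # np.ndarray after load(): 3x3 and 3x4
    print(f"view {i}: K {K.shape}, pose {pose.shape}")
```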