├── LICENSE
├── README.md
├── assets
│   ├── demo
│   │   ├── i2mv
│   │   │   ├── A_decorative_figurine_of_a_young_anime-style_girl.png
│   │   │   ├── A_juvenile_emperor_penguin_chick.png
│   │   │   └── A_striped_tabby_cat_with_white_fur_sitting_upright.png
│   │   ├── ig2mv
│   │   │   ├── 1ccd5c1563ea4f5fb8152eac59dabd5c.glb
│   │   │   ├── 1ccd5c1563ea4f5fb8152eac59dabd5c.jpeg
│   │   │   ├── 1ccd5c1563ea4f5fb8152eac59dabd5c_mv.png
│   │   │   ├── cartoon_style_table.glb
│   │   │   ├── cartoon_style_table.png
│   │   │   ├── cartoon_style_table_mv.png
│   │   │   ├── ebffbc1ef9994b41a1841bbe9492d012.glb
│   │   │   └── ebffbc1ef9994b41a1841bbe9492d012.jpeg
│   │   ├── scribble2mv
│   │   │   ├── color_0000.webp
│   │   │   ├── color_0001.webp
│   │   │   ├── color_0002.webp
│   │   │   ├── color_0003.webp
│   │   │   ├── color_0004.webp
│   │   │   └── color_0005.webp
│   │   └── tg2mv
│   │       ├── ac9d4e4f44f34775ad46878ba8fbfd86.glb
│   │       ├── ac9d4e4f44f34775ad46878ba8fbfd86_mv.png
│   │       └── b5f0f0f33e3644d1ba73576ceb486d42.glb
│   └── doc
│       ├── comfyui_i2mv.png
│       ├── comfyui_t2mv.png
│       ├── comfyui_t2mv_lora.png
│       ├── demo_i2mv_1.png
│       ├── demo_i2mv_2.png
│       ├── demo_t2mv_1.png
│       ├── demo_t2mv_2.png
│       ├── demo_t2mv_anime_1.png
│       ├── demo_t2mv_anime_2.png
│       ├── demo_t2mv_dreamshaper_1.png
│       └── teaser.jpg
├── configs
│   ├── geometry-guidance
│   │   ├── mvadapter_ig2mv_partialimg_sdxl.yaml
│   │   ├── mvadapter_ig2mv_sdxl.yaml
│   │   └── mvadapter_tg2mv_sdxl.yaml
│   └── view-guidance
│       ├── mvadapter_i2mv_sd21.yaml
│       ├── mvadapter_i2mv_sdxl.yaml
│       ├── mvadapter_i2mv_sdxl_aug_quantity.yaml
│       ├── mvadapter_t2mv_sd21.yaml
│       └── mvadapter_t2mv_sdxl.yaml
├── launch.py
├── mvadapter
│   ├── __init__.py
│   ├── data
│   │   └── multiview.py
│   ├── loaders
│   │   ├── __init__.py
│   │   └── custom_adapter.py
│   ├── models
│   │   ├── __init__.py
│   │   └── attention_processor.py
│   ├── pipelines
│   │   ├── pipeline_mvadapter_i2mv_sd.py
│   │   ├── pipeline_mvadapter_i2mv_sdxl.py
│   │   ├── pipeline_mvadapter_t2mv_sd.py
│   │   ├── pipeline_mvadapter_t2mv_sdxl.py
│   │   └── pipeline_texture.py
│   ├── schedulers
│   │   ├── scheduler_utils.py
│   │   └── scheduling_shift_snr.py
│   ├── systems
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── mvadapter_image_sd.py
│   │   ├── mvadapter_image_sdxl.py
│   │   ├── mvadapter_text_sd.py
│   │   ├── mvadapter_text_sdxl.py
│   │   └── utils.py
│   └── utils
│       ├── __init__.py
│       ├── base.py
│       ├── callbacks.py
│       ├── config.py
│       ├── core.py
│       ├── geometry.py
│       ├── logging.py
│       ├── mesh_utils
│       │   ├── __init__.py
│       │   ├── blend.py
│       │   ├── camera.py
│       │   ├── cv_ops.py
│       │   ├── mesh.py
│       │   ├── mesh_process.py
│       │   ├── projection.py
│       │   ├── render.py
│       │   ├── seg.py
│       │   ├── smart_paint.py
│       │   ├── utils.py
│       │   ├── uv.py
│       │   └── warp.py
│       ├── misc.py
│       ├── ops.py
│       ├── saving.py
│       └── typing.py
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── gradio_demo_i2mv.py
│   ├── gradio_demo_t2mv.py
│   ├── inference_i2mv_sd.py
│   ├── inference_i2mv_sdxl.py
│   ├── inference_ig2mv_partial_sdxl.py
│   ├── inference_ig2mv_sdxl.py
│   ├── inference_scribble2mv_sdxl.py
│   ├── inference_t2mv_sd.py
│   ├── inference_t2mv_sdxl.py
│   ├── inference_tg2mv_sdxl.py
│   ├── texture_i2tex.py
│   └── texture_t2tex.py
└── setup.py
/assets/demo/i2mv/A_decorative_figurine_of_a_young_anime-style_girl.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/i2mv/A_decorative_figurine_of_a_young_anime-style_girl.png
--------------------------------------------------------------------------------
/assets/demo/i2mv/A_juvenile_emperor_penguin_chick.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/i2mv/A_juvenile_emperor_penguin_chick.png -------------------------------------------------------------------------------- /assets/demo/i2mv/A_striped_tabby_cat_with_white_fur_sitting_upright.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/i2mv/A_striped_tabby_cat_with_white_fur_sitting_upright.png -------------------------------------------------------------------------------- /assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c.glb -------------------------------------------------------------------------------- /assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c.jpeg -------------------------------------------------------------------------------- /assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c_mv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/1ccd5c1563ea4f5fb8152eac59dabd5c_mv.png -------------------------------------------------------------------------------- /assets/demo/ig2mv/cartoon_style_table.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/cartoon_style_table.glb -------------------------------------------------------------------------------- /assets/demo/ig2mv/cartoon_style_table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/cartoon_style_table.png -------------------------------------------------------------------------------- /assets/demo/ig2mv/cartoon_style_table_mv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/cartoon_style_table_mv.png -------------------------------------------------------------------------------- /assets/demo/ig2mv/ebffbc1ef9994b41a1841bbe9492d012.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/ebffbc1ef9994b41a1841bbe9492d012.glb -------------------------------------------------------------------------------- /assets/demo/ig2mv/ebffbc1ef9994b41a1841bbe9492d012.jpeg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/ig2mv/ebffbc1ef9994b41a1841bbe9492d012.jpeg -------------------------------------------------------------------------------- 
/assets/demo/scribble2mv/color_0000.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0000.webp -------------------------------------------------------------------------------- /assets/demo/scribble2mv/color_0001.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0001.webp -------------------------------------------------------------------------------- /assets/demo/scribble2mv/color_0002.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0002.webp -------------------------------------------------------------------------------- /assets/demo/scribble2mv/color_0003.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0003.webp -------------------------------------------------------------------------------- /assets/demo/scribble2mv/color_0004.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0004.webp -------------------------------------------------------------------------------- /assets/demo/scribble2mv/color_0005.webp: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/scribble2mv/color_0005.webp -------------------------------------------------------------------------------- /assets/demo/tg2mv/ac9d4e4f44f34775ad46878ba8fbfd86.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/tg2mv/ac9d4e4f44f34775ad46878ba8fbfd86.glb -------------------------------------------------------------------------------- /assets/demo/tg2mv/ac9d4e4f44f34775ad46878ba8fbfd86_mv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/tg2mv/ac9d4e4f44f34775ad46878ba8fbfd86_mv.png -------------------------------------------------------------------------------- /assets/demo/tg2mv/b5f0f0f33e3644d1ba73576ceb486d42.glb: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/demo/tg2mv/b5f0f0f33e3644d1ba73576ceb486d42.glb -------------------------------------------------------------------------------- /assets/doc/comfyui_i2mv.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/comfyui_i2mv.png -------------------------------------------------------------------------------- /assets/doc/comfyui_t2mv.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/comfyui_t2mv.png -------------------------------------------------------------------------------- /assets/doc/comfyui_t2mv_lora.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/comfyui_t2mv_lora.png -------------------------------------------------------------------------------- /assets/doc/demo_i2mv_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_i2mv_1.png -------------------------------------------------------------------------------- /assets/doc/demo_i2mv_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_i2mv_2.png -------------------------------------------------------------------------------- /assets/doc/demo_t2mv_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_t2mv_1.png -------------------------------------------------------------------------------- /assets/doc/demo_t2mv_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_t2mv_2.png -------------------------------------------------------------------------------- /assets/doc/demo_t2mv_anime_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_t2mv_anime_1.png -------------------------------------------------------------------------------- /assets/doc/demo_t2mv_anime_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_t2mv_anime_2.png -------------------------------------------------------------------------------- /assets/doc/demo_t2mv_dreamshaper_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/demo_t2mv_dreamshaper_1.png -------------------------------------------------------------------------------- /assets/doc/teaser.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/assets/doc/teaser.jpg -------------------------------------------------------------------------------- /configs/geometry-guidance/mvadapter_ig2mv_partialimg_sdxl.yaml: -------------------------------------------------------------------------------- 1 | name: ig2mv 2 | tag: "r768-ortho-nv6-ig2mv-partialimg-dcrowcol-sdxl" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: data/texture_ortho6view_easylight 9 | scene_list: 
data/objaverse_list_6w.json 10 | background_color: gray 11 | image_names: ["0000", "0001", "0002", "0003", "0008", "0009"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["position", "normal"] 21 | position_offset: 0.5 22 | position_scale: 1.0 23 | 24 | reference_root_dir: ["data/texture_rand_easylight_objaverse"] 25 | reference_scene_list: ["data/objaverse_list_6w.json"] 26 | reference_image_modality: color 27 | reference_image_names: ["0000", "0001", "0002", "0003", "0004"] 28 | reference_mask_aug: true 29 | 30 | train_indices: [0, -8] 31 | val_indices: [-8, null] 32 | test_indices: [-8, null] 33 | 34 | height: 768 35 | width: 768 36 | 37 | batch_size: 1 38 | num_workers: 16 39 | 40 | system_cls: mvadapter.systems.mvadapter_image_sdxl.MVAdapterImageSDXLSystem 41 | system: 42 | check_train_every_n_steps: 1000 43 | cleanup_after_validation_step: true 44 | cleanup_after_test_step: true 45 | 46 | # Model / Adapter 47 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 48 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 49 | pretrained_adapter_name_or_path: null 50 | init_adapter_kwargs: 51 | # Multi-view adapter 52 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowColSelfAttnProcessor2_0" 53 | # Condition encoder 54 | cond_in_channels: 6 55 | # For training 56 | copy_attn_weights: true 57 | zero_init_module_keys: ["to_out_mv", "to_out_ref"] 58 | 59 | # Training 60 | train_cond_encoder: true 61 | trainable_modules: ["_mv", "_ref"] 62 | prompt_drop_prob: 0.1 63 | image_drop_prob: 0.1 64 | cond_drop_prob: 0.1 65 | 66 | # Noise sampler 67 | shift_noise: true 68 | shift_noise_mode: interpolated 69 | shift_noise_scale: 8 70 | 71 | # Evaluation 72 | eval_seed: 42 73 | eval_num_inference_steps: 30 74 | eval_guidance_scale: 3.0 75 | eval_height: ${data.height} 76 | eval_width: ${data.width} 77 | 78 | # optimizer definition 79 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 80 | optimizer: 81 | name: AdamW 82 | args: 83 | lr: 5e-5 84 | betas: [0.9, 0.999] 85 | weight_decay: 0.01 86 | params: 87 | cond_encoder: 88 | lr: 5e-5 89 | unet: 90 | lr: 5e-5 91 | 92 | scheduler: 93 | name: SequentialLR 94 | interval: step 95 | schedulers: 96 | - name: LinearLR 97 | interval: step 98 | args: 99 | start_factor: 1e-6 100 | end_factor: 1.0 101 | total_iters: 2000 102 | - name: ConstantLR 103 | interval: step 104 | args: 105 | factor: 1.0 106 | total_iters: 9999999 107 | milestones: [2000] 108 | 109 | trainer: 110 | max_epochs: 10 111 | log_every_n_steps: 10 112 | num_sanity_val_steps: 1 113 | val_check_interval: 2000 114 | enable_progress_bar: true 115 | precision: bf16-mixed 116 | gradient_clip_val: 1.0 117 | strategy: ddp 118 | accumulate_grad_batches: 1 119 | 120 | checkpoint: 121 | save_last: true # whether to save at each validation time 122 | save_top_k: -1 123 | every_n_epochs: 9999 # do not save at all for debug purpose 124 | -------------------------------------------------------------------------------- /configs/geometry-guidance/mvadapter_ig2mv_sdxl.yaml: -------------------------------------------------------------------------------- 1 | name: ig2mv 2 | tag: "r768-ortho-nv6-ig2mv-dcrowcol-sdxl" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: 
mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: data/texture_ortho6view_easylight 9 | scene_list: data/objaverse_list_6w.json 10 | background_color: gray 11 | image_names: ["0000", "0001", "0002", "0003", "0008", "0009"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["position", "normal"] 21 | position_offset: 0.5 22 | position_scale: 1.0 23 | 24 | reference_root_dir: ["data/texture_rand_easylight_objaverse"] 25 | reference_scene_list: ["data/objaverse_list_6w.json"] 26 | reference_image_modality: color 27 | reference_image_names: ["0000", "0001", "0002", "0003", "0004"] 28 | 29 | train_indices: [0, -8] 30 | val_indices: [-8, null] 31 | test_indices: [-8, null] 32 | 33 | height: 768 34 | width: 768 35 | 36 | batch_size: 1 37 | num_workers: 16 38 | 39 | system_cls: mvadapter.systems.mvadapter_image_sdxl.MVAdapterImageSDXLSystem 40 | system: 41 | check_train_every_n_steps: 1000 42 | cleanup_after_validation_step: true 43 | cleanup_after_test_step: true 44 | 45 | # Model / Adapter 46 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 47 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 48 | pretrained_adapter_name_or_path: null 49 | init_adapter_kwargs: 50 | # Multi-view adapter 51 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowColSelfAttnProcessor2_0" 52 | # Condition encoder 53 | cond_in_channels: 6 54 | # For training 55 | copy_attn_weights: true 56 | zero_init_module_keys: ["to_out_mv", "to_out_ref"] 57 | 58 | # Training 59 | train_cond_encoder: true 60 | trainable_modules: ["_mv", "_ref"] 61 | prompt_drop_prob: 0.1 62 | image_drop_prob: 0.1 63 | cond_drop_prob: 0.1 64 | 65 | # Noise sampler 66 | shift_noise: true 67 | shift_noise_mode: interpolated 68 | shift_noise_scale: 8 69 | 70 | # Evaluation 71 | eval_seed: 42 72 | eval_num_inference_steps: 30 73 | eval_guidance_scale: 3.0 74 | eval_height: ${data.height} 75 | eval_width: ${data.width} 76 | 77 | # optimizer definition 78 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 79 | optimizer: 80 | name: AdamW 81 | args: 82 | lr: 5e-5 83 | betas: [0.9, 0.999] 84 | weight_decay: 0.01 85 | params: 86 | cond_encoder: 87 | lr: 5e-5 88 | unet: 89 | lr: 5e-5 90 | 91 | scheduler: 92 | name: SequentialLR 93 | interval: step 94 | schedulers: 95 | - name: LinearLR 96 | interval: step 97 | args: 98 | start_factor: 1e-6 99 | end_factor: 1.0 100 | total_iters: 2000 101 | - name: ConstantLR 102 | interval: step 103 | args: 104 | factor: 1.0 105 | total_iters: 9999999 106 | milestones: [2000] 107 | 108 | trainer: 109 | max_epochs: 10 110 | log_every_n_steps: 10 111 | num_sanity_val_steps: 1 112 | val_check_interval: 2000 113 | enable_progress_bar: true 114 | precision: bf16-mixed 115 | gradient_clip_val: 1.0 116 | strategy: ddp 117 | accumulate_grad_batches: 1 118 | 119 | checkpoint: 120 | save_last: true # whether to save at each validation time 121 | save_top_k: -1 122 | every_n_epochs: 9999 # do not save at all for debug purpose 123 | -------------------------------------------------------------------------------- /configs/geometry-guidance/mvadapter_tg2mv_sdxl.yaml: -------------------------------------------------------------------------------- 1 | name: tg2mv 2 | tag: "r768-ortho-nv6-tg2mv-dcrowcol-sdxl" 3 | 
exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: data/texture_ortho10view_easylight_objaverse 9 | scene_list: data/objaverse_list_6w.json 10 | background_color: gray 11 | image_names: ["0000", "0001", "0002", "0003", "0008", "0009"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["position", "normal"] 21 | position_offset: 0.5 22 | position_scale: 1.0 23 | 24 | train_indices: [0, -8] 25 | val_indices: [-8, null] 26 | test_indices: [-8, null] 27 | 28 | height: 768 29 | width: 768 30 | 31 | batch_size: 1 32 | num_workers: 16 33 | 34 | system_cls: mvadapter.systems.mvadapter_text_sdxl.MVAdapterTextSDXLSystem 35 | system: 36 | check_train_every_n_steps: 1000 37 | cleanup_after_validation_step: true 38 | cleanup_after_test_step: true 39 | 40 | # Model / Adapter 41 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 42 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 43 | pretrained_adapter_name_or_path: null 44 | init_adapter_kwargs: 45 | # Multi-view adapter 46 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowColSelfAttnProcessor2_0" 47 | # Condition encoder 48 | cond_in_channels: 6 49 | # For training 50 | copy_attn_weights: true 51 | zero_init_module_keys: ["to_out_mv"] 52 | 53 | # Training 54 | train_cond_encoder: true 55 | trainable_modules: ["_mv"] 56 | prompt_drop_prob: 0.1 57 | image_drop_prob: 0.1 58 | cond_drop_prob: 0.1 59 | 60 | # Noise sampler 61 | shift_noise: true 62 | shift_noise_mode: interpolated 63 | shift_noise_scale: 8 64 | 65 | # Evaluation 66 | eval_seed: 42 67 | eval_num_inference_steps: 30 68 | eval_guidance_scale: 3.0 69 | eval_height: ${data.height} 70 | eval_width: ${data.width} 71 | 72 | # optimizer definition 73 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 74 | optimizer: 75 | name: AdamW 76 | args: 77 | lr: 5e-5 78 | betas: [0.9, 0.999] 79 | weight_decay: 0.01 80 | params: 81 | cond_encoder: 82 | lr: 5e-5 83 | unet: 84 | lr: 5e-5 85 | 86 | scheduler: 87 | name: SequentialLR 88 | interval: step 89 | schedulers: 90 | - name: LinearLR 91 | interval: step 92 | args: 93 | start_factor: 1e-6 94 | end_factor: 1.0 95 | total_iters: 2000 96 | - name: ConstantLR 97 | interval: step 98 | args: 99 | factor: 1.0 100 | total_iters: 9999999 101 | milestones: [2000] 102 | 103 | trainer: 104 | max_epochs: 10 105 | log_every_n_steps: 10 106 | num_sanity_val_steps: 1 107 | val_check_interval: 2000 108 | enable_progress_bar: true 109 | precision: bf16-mixed 110 | gradient_clip_val: 1.0 111 | strategy: ddp 112 | accumulate_grad_batches: 1 113 | 114 | checkpoint: 115 | save_last: true # whether to save at each validation time 116 | save_top_k: -1 117 | every_n_epochs: 9999 # do not save at all for debug purpose 118 | -------------------------------------------------------------------------------- /configs/view-guidance/mvadapter_i2mv_sd21.yaml: -------------------------------------------------------------------------------- 1 | name: i2mv 2 | tag: "r512-ortho-nv6-ele0-sd21" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: "data/texture_ortho10view_easylight_objaverse" 9 | scene_list: "data/objaverse_list_6w.json" 
10 | background_color: gray 11 | image_names: ["0000", "0004", "0001", "0002", "0003", "0005"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["plucker"] 21 | plucker_offset: 1.0 22 | plucker_scale: 2.0 23 | 24 | reference_root_dir: ["data/texture_rand_easylight_objaverse"] 25 | reference_scene_list: ["data/objaverse_list_6w.json"] 26 | reference_image_modality: color 27 | reference_image_names: ["0000", "0001", "0002", "0003", "0004"] 28 | 29 | train_indices: [0, -8] 30 | val_indices: [-8, null] 31 | test_indices: [-8, null] 32 | 33 | height: 512 34 | width: 512 35 | 36 | batch_size: 1 37 | num_workers: 16 38 | 39 | system_cls: mvadapter.systems.mvadapter_image_sd.MVAdapterImageSDSystem 40 | system: 41 | check_train_every_n_steps: 1000 42 | cleanup_after_validation_step: true 43 | cleanup_after_test_step: true 44 | 45 | # Model / Adapter 46 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 47 | pretrained_vae_name_or_path: null 48 | pretrained_adapter_name_or_path: null 49 | init_adapter_kwargs: 50 | # Multi-view adapter 51 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowSelfAttnProcessor2_0" 52 | # Condition encoder 53 | cond_in_channels: 6 54 | # For training 55 | copy_attn_weights: true 56 | zero_init_module_keys: ["to_out_mv", "to_out_ref"] 57 | 58 | # Training 59 | train_cond_encoder: true 60 | trainable_modules: ["_mv", "_ref"] 61 | prompt_drop_prob: 0.1 62 | image_drop_prob: 0.1 63 | cond_drop_prob: 0.1 64 | 65 | # Noise sampler 66 | shift_noise: true 67 | shift_noise_mode: interpolated 68 | shift_noise_scale: 8 69 | 70 | # Evaluation 71 | eval_seed: 42 72 | eval_num_inference_steps: 30 73 | eval_guidance_scale: 3.0 74 | eval_height: ${data.height} 75 | eval_width: ${data.width} 76 | 77 | # optimizer definition 78 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 79 | optimizer: 80 | name: AdamW 81 | args: 82 | lr: 5e-5 83 | betas: [0.9, 0.999] 84 | weight_decay: 0.01 85 | params: 86 | cond_encoder: 87 | lr: 5e-5 88 | unet: 89 | lr: 5e-5 90 | 91 | scheduler: 92 | name: SequentialLR 93 | interval: step 94 | schedulers: 95 | - name: LinearLR 96 | interval: step 97 | args: 98 | start_factor: 1e-6 99 | end_factor: 1.0 100 | total_iters: 2000 101 | - name: ConstantLR 102 | interval: step 103 | args: 104 | factor: 1.0 105 | total_iters: 9999999 106 | milestones: [2000] 107 | 108 | trainer: 109 | max_epochs: 5 110 | log_every_n_steps: 10 111 | num_sanity_val_steps: 1 112 | val_check_interval: 2000 113 | enable_progress_bar: true 114 | precision: bf16-mixed 115 | gradient_clip_val: 1.0 116 | strategy: ddp 117 | accumulate_grad_batches: 1 118 | 119 | checkpoint: 120 | save_last: true # whether to save at each validation time 121 | save_top_k: -1 122 | every_n_epochs: 9999 # do not save at all for debug purpose 123 | -------------------------------------------------------------------------------- /configs/view-guidance/mvadapter_i2mv_sdxl.yaml: -------------------------------------------------------------------------------- 1 | name: i2mv 2 | tag: "r768-ortho-nv6-ele0-sdxl" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: "data/texture_ortho10view_easylight_objaverse" 9 | scene_list: "data/objaverse_list_6w.json" 
10 | background_color: gray 11 | image_names: ["0000", "0004", "0001", "0002", "0003", "0005"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["plucker"] 21 | plucker_offset: 1.0 22 | plucker_scale: 2.0 23 | 24 | reference_root_dir: ["data/texture_rand_easylight_objaverse"] 25 | reference_scene_list: ["data/objaverse_list_6w.json"] 26 | reference_image_modality: color 27 | reference_image_names: ["0000", "0001", "0002", "0003", "0004"] 28 | 29 | train_indices: [0, -8] 30 | val_indices: [-8, null] 31 | test_indices: [-8, null] 32 | 33 | height: 768 34 | width: 768 35 | 36 | batch_size: 1 37 | num_workers: 16 38 | 39 | system_cls: mvadapter.systems.mvadapter_image_sdxl.MVAdapterImageSDXLSystem 40 | system: 41 | check_train_every_n_steps: 1000 42 | cleanup_after_validation_step: true 43 | cleanup_after_test_step: true 44 | 45 | # Model / Adapter 46 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 47 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 48 | pretrained_adapter_name_or_path: null 49 | init_adapter_kwargs: 50 | # Multi-view adapter 51 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowSelfAttnProcessor2_0" 52 | # Condition encoder 53 | cond_in_channels: 6 54 | # For training 55 | copy_attn_weights: true 56 | zero_init_module_keys: ["to_out_mv", "to_out_ref"] 57 | 58 | # Training 59 | train_cond_encoder: true 60 | trainable_modules: ["_mv", "_ref"] 61 | prompt_drop_prob: 0.1 62 | image_drop_prob: 0.1 63 | cond_drop_prob: 0.1 64 | 65 | # Noise sampler 66 | shift_noise: true 67 | shift_noise_mode: interpolated 68 | shift_noise_scale: 8 69 | 70 | # Evaluation 71 | eval_seed: 42 72 | eval_num_inference_steps: 30 73 | eval_guidance_scale: 3.0 74 | eval_height: ${data.height} 75 | eval_width: ${data.width} 76 | 77 | # optimizer definition 78 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 79 | optimizer: 80 | name: AdamW 81 | args: 82 | lr: 5e-5 83 | betas: [0.9, 0.999] 84 | weight_decay: 0.01 85 | params: 86 | cond_encoder: 87 | lr: 5e-5 88 | unet: 89 | lr: 5e-5 90 | 91 | scheduler: 92 | name: SequentialLR 93 | interval: step 94 | schedulers: 95 | - name: LinearLR 96 | interval: step 97 | args: 98 | start_factor: 1e-6 99 | end_factor: 1.0 100 | total_iters: 2000 101 | - name: ConstantLR 102 | interval: step 103 | args: 104 | factor: 1.0 105 | total_iters: 9999999 106 | milestones: [2000] 107 | 108 | trainer: 109 | max_epochs: 10 110 | log_every_n_steps: 10 111 | num_sanity_val_steps: 1 112 | val_check_interval: 2000 113 | enable_progress_bar: true 114 | precision: bf16-mixed 115 | gradient_clip_val: 1.0 116 | strategy: ddp 117 | accumulate_grad_batches: 1 118 | 119 | checkpoint: 120 | save_last: true # whether to save at each validation time 121 | save_top_k: -1 122 | every_n_epochs: 9999 # do not save at all for debug purpose 123 | -------------------------------------------------------------------------------- /configs/view-guidance/mvadapter_i2mv_sdxl_aug_quantity.yaml: -------------------------------------------------------------------------------- 1 | name: i2mv 2 | tag: "r768-ortho-nv6-ele0-sdxl-aug-quantity" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: 
"data/texture_ortho10view_easylight_objaverse" 9 | scene_list: "data/objaverse_list_6w.json" 10 | background_color: gray 11 | image_names: ["0000", "0004", "0001", "0002", "0003", "0005"] 12 | image_modality: color 13 | num_views: 6 14 | random_view_list: [[0, 1, 2, 3, 4, 5], [0, 2, 3, 4], [0, 2, 3], [0, 3]] 15 | 16 | prompt_db_path: data/objaverse_short_captions.json 17 | return_prompt: true 18 | 19 | projection_type: ORTHO 20 | 21 | source_image_modality: ["plucker"] 22 | plucker_offset: 1.0 23 | plucker_scale: 2.0 24 | 25 | reference_root_dir: ["data/texture_rand_easylight_objaverse"] 26 | reference_scene_list: ["data/objaverse_list_6w.json"] 27 | reference_image_modality: color 28 | reference_image_names: ["0000", "0001", "0002", "0003", "0004"] 29 | 30 | train_indices: [0, -8] 31 | val_indices: [-8, null] 32 | test_indices: [-8, null] 33 | 34 | height: 768 35 | width: 768 36 | 37 | batch_size: 1 38 | num_workers: 16 39 | 40 | system_cls: mvadapter.systems.mvadapter_image_sdxl.MVAdapterImageSDXLSystem 41 | system: 42 | check_train_every_n_steps: 1000 43 | cleanup_after_validation_step: true 44 | cleanup_after_test_step: true 45 | 46 | # Model / Adapter 47 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 48 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 49 | pretrained_adapter_name_or_path: null 50 | init_adapter_kwargs: 51 | # Multi-view adapter 52 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowSelfAttnProcessor2_0" 53 | # Condition encoder 54 | cond_in_channels: 6 55 | # For training 56 | copy_attn_weights: true 57 | zero_init_module_keys: ["to_out_mv", "to_out_ref"] 58 | 59 | # Training 60 | train_cond_encoder: true 61 | trainable_modules: ["_mv", "_ref"] 62 | prompt_drop_prob: 0.1 63 | image_drop_prob: 0.1 64 | cond_drop_prob: 0.1 65 | 66 | # Noise sampler 67 | shift_noise: true 68 | shift_noise_mode: interpolated 69 | shift_noise_scale: 8 70 | 71 | # Evaluation 72 | eval_seed: 42 73 | eval_num_inference_steps: 30 74 | eval_guidance_scale: 3.0 75 | eval_height: ${data.height} 76 | eval_width: ${data.width} 77 | 78 | # optimizer definition 79 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 80 | optimizer: 81 | name: AdamW 82 | args: 83 | lr: 5e-5 84 | betas: [0.9, 0.999] 85 | weight_decay: 0.01 86 | params: 87 | cond_encoder: 88 | lr: 5e-5 89 | unet: 90 | lr: 5e-5 91 | 92 | scheduler: 93 | name: SequentialLR 94 | interval: step 95 | schedulers: 96 | - name: LinearLR 97 | interval: step 98 | args: 99 | start_factor: 1e-6 100 | end_factor: 1.0 101 | total_iters: 2000 102 | - name: ConstantLR 103 | interval: step 104 | args: 105 | factor: 1.0 106 | total_iters: 9999999 107 | milestones: [2000] 108 | 109 | trainer: 110 | max_epochs: 10 111 | log_every_n_steps: 10 112 | num_sanity_val_steps: 1 113 | val_check_interval: 2000 114 | enable_progress_bar: true 115 | precision: bf16-mixed 116 | gradient_clip_val: 1.0 117 | strategy: ddp 118 | accumulate_grad_batches: 1 119 | 120 | checkpoint: 121 | save_last: true # whether to save at each validation time 122 | save_top_k: -1 123 | every_n_epochs: 9999 # do not save at all for debug purpose 124 | -------------------------------------------------------------------------------- /configs/view-guidance/mvadapter_t2mv_sd21.yaml: -------------------------------------------------------------------------------- 1 | name: t2mv 2 | tag: "r512-ortho-nv6-ele0-sd" 3 | 
exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: "data/texture_ortho10view_easylight_objaverse" 9 | scene_list: "data/objaverse_list_6w.json" 10 | background_color: gray 11 | image_names: ["0000", "0004", "0001", "0002", "0003", "0005"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["plucker"] 21 | plucker_offset: 1.0 22 | plucker_scale: 2.0 23 | 24 | train_indices: [0, -8] 25 | val_indices: [-8, null] 26 | test_indices: [-8, null] 27 | 28 | height: 512 29 | width: 512 30 | 31 | batch_size: 1 32 | num_workers: 16 33 | 34 | system_cls: mvadapter.systems.mvadapter_text_sd.MVAdapterTextSDSystem 35 | system: 36 | check_train_every_n_steps: 1000 37 | cleanup_after_validation_step: true 38 | cleanup_after_test_step: true 39 | 40 | # Model / Adapter 41 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-2-1-base" 42 | pretrained_vae_name_or_path: null 43 | pretrained_adapter_name_or_path: null 44 | init_adapter_kwargs: 45 | # Multi-view adapter 46 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowSelfAttnProcessor2_0" 47 | # Condition encoder 48 | cond_in_channels: 6 49 | # For training 50 | copy_attn_weights: true 51 | zero_init_module_keys: ["to_out_mv"] 52 | 53 | # Training 54 | train_cond_encoder: true 55 | trainable_modules: ["_mv"] 56 | prompt_drop_prob: 0.1 57 | image_drop_prob: 0.1 58 | cond_drop_prob: 0.1 59 | 60 | # Noise sampler 61 | shift_noise: true 62 | shift_noise_mode: interpolated 63 | shift_noise_scale: 8 64 | 65 | # Evaluation 66 | eval_seed: 42 67 | eval_num_inference_steps: 30 68 | eval_guidance_scale: 3.0 69 | eval_height: ${data.height} 70 | eval_width: ${data.width} 71 | 72 | # optimizer definition 73 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 74 | optimizer: 75 | name: AdamW 76 | args: 77 | lr: 5e-5 78 | betas: [0.9, 0.999] 79 | weight_decay: 0.01 80 | params: 81 | cond_encoder: 82 | lr: 5e-5 83 | unet: 84 | lr: 5e-5 85 | 86 | scheduler: 87 | name: SequentialLR 88 | interval: step 89 | schedulers: 90 | - name: LinearLR 91 | interval: step 92 | args: 93 | start_factor: 1e-6 94 | end_factor: 1.0 95 | total_iters: 2000 96 | - name: ConstantLR 97 | interval: step 98 | args: 99 | factor: 1.0 100 | total_iters: 9999999 101 | milestones: [2000] 102 | 103 | trainer: 104 | max_epochs: 10 105 | log_every_n_steps: 10 106 | num_sanity_val_steps: 1 107 | val_check_interval: 2000 108 | enable_progress_bar: true 109 | precision: bf16-mixed 110 | gradient_clip_val: 1.0 111 | strategy: ddp 112 | accumulate_grad_batches: 1 113 | 114 | checkpoint: 115 | save_last: true # whether to save at each validation time 116 | save_top_k: -1 117 | every_n_epochs: 9999 # do not save at all for debug purpose 118 | -------------------------------------------------------------------------------- /configs/view-guidance/mvadapter_t2mv_sdxl.yaml: -------------------------------------------------------------------------------- 1 | name: t2mv 2 | tag: "r768-ortho-nv6-ele0-sdxl" 3 | exp_root_dir: "outputs" 4 | seed: 42 5 | 6 | data_cls: mvadapter.data.multiview.MultiviewDataModule 7 | data: 8 | root_dir: "data/texture_ortho10view_easylight_objaverse" 9 | scene_list: "data/objaverse_list_6w.json" 10 | background_color: gray 11 | image_names: 
["0000", "0004", "0001", "0002", "0003", "0005"] 12 | image_modality: color 13 | num_views: 6 14 | 15 | prompt_db_path: data/objaverse_short_captions.json 16 | return_prompt: true 17 | 18 | projection_type: ORTHO 19 | 20 | source_image_modality: ["plucker"] 21 | plucker_offset: 1.0 22 | plucker_scale: 2.0 23 | 24 | train_indices: [0, -8] 25 | val_indices: [-8, null] 26 | test_indices: [-8, null] 27 | 28 | height: 768 29 | width: 768 30 | 31 | batch_size: 1 32 | num_workers: 16 33 | 34 | system_cls: mvadapter.systems.mvadapter_text_sdxl.MVAdapterTextSDXLSystem 35 | system: 36 | check_train_every_n_steps: 1000 37 | cleanup_after_validation_step: true 38 | cleanup_after_test_step: true 39 | 40 | # Model / Adapter 41 | pretrained_model_name_or_path: "stabilityai/stable-diffusion-xl-base-1.0" 42 | pretrained_vae_name_or_path: "madebyollin/sdxl-vae-fp16-fix" 43 | pretrained_adapter_name_or_path: null 44 | init_adapter_kwargs: 45 | # Multi-view adapter 46 | self_attn_processor: "mvadapter.models.attention_processor.DecoupledMVRowSelfAttnProcessor2_0" 47 | # Condition encoder 48 | cond_in_channels: 6 49 | # For training 50 | copy_attn_weights: true 51 | zero_init_module_keys: ["to_out_mv"] 52 | 53 | # Training 54 | train_cond_encoder: true 55 | trainable_modules: ["_mv"] 56 | prompt_drop_prob: 0.1 57 | image_drop_prob: 0.1 58 | cond_drop_prob: 0.1 59 | 60 | # Noise sampler 61 | shift_noise: true 62 | shift_noise_mode: interpolated 63 | shift_noise_scale: 8 64 | 65 | # Evaluation 66 | eval_seed: 42 67 | eval_num_inference_steps: 30 68 | eval_guidance_scale: 3.0 69 | eval_height: ${data.height} 70 | eval_width: ${data.width} 71 | 72 | # optimizer definition 73 | # you can set different learning rates separately for each group of parameters, but note that if you do this you should specify EVERY trainable parameters 74 | optimizer: 75 | name: AdamW 76 | args: 77 | lr: 5e-5 78 | betas: [0.9, 0.999] 79 | weight_decay: 0.01 80 | params: 81 | cond_encoder: 82 | lr: 5e-5 83 | unet: 84 | lr: 5e-5 85 | 86 | scheduler: 87 | name: SequentialLR 88 | interval: step 89 | schedulers: 90 | - name: LinearLR 91 | interval: step 92 | args: 93 | start_factor: 1e-6 94 | end_factor: 1.0 95 | total_iters: 2000 96 | - name: ConstantLR 97 | interval: step 98 | args: 99 | factor: 1.0 100 | total_iters: 9999999 101 | milestones: [2000] 102 | 103 | trainer: 104 | max_epochs: 10 105 | log_every_n_steps: 10 106 | num_sanity_val_steps: 1 107 | val_check_interval: 2000 108 | enable_progress_bar: true 109 | precision: bf16-mixed 110 | gradient_clip_val: 1.0 111 | strategy: ddp 112 | accumulate_grad_batches: 1 113 | 114 | checkpoint: 115 | save_last: true # whether to save at each validation time 116 | save_top_k: -1 117 | every_n_epochs: 9999 # do not save at all for debug purpose 118 | -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import logging 4 | import os 5 | import sys 6 | 7 | 8 | class ColoredFilter(logging.Filter): 9 | """ 10 | A logging filter to add color to certain log levels. 
11 | """ 12 | 13 | RESET = "\033[0m" 14 | RED = "\033[31m" 15 | GREEN = "\033[32m" 16 | YELLOW = "\033[33m" 17 | BLUE = "\033[34m" 18 | MAGENTA = "\033[35m" 19 | CYAN = "\033[36m" 20 | 21 | COLORS = { 22 | "WARNING": YELLOW, 23 | "INFO": GREEN, 24 | "DEBUG": BLUE, 25 | "CRITICAL": MAGENTA, 26 | "ERROR": RED, 27 | } 28 | 29 | RESET = "\x1b[0m" 30 | 31 | def __init__(self): 32 | super().__init__() 33 | 34 | def filter(self, record): 35 | if record.levelname in self.COLORS: 36 | color_start = self.COLORS[record.levelname] 37 | record.levelname = f"{color_start}[{record.levelname}]" 38 | record.msg = f"{record.msg}{self.RESET}" 39 | return True 40 | 41 | 42 | def main(args, extras) -> None: 43 | # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning 44 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 45 | env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) 46 | env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else [] 47 | selected_gpus = [0] 48 | 49 | # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified. 50 | # As far as Pytorch Lightning is concerned, we always use all available GPUs 51 | # (possibly filtered by CUDA_VISIBLE_DEVICES). 52 | devices = -1 53 | if len(env_gpus) > 0: 54 | # CUDA_VISIBLE_DEVICES was set already, e.g. within SLURM srun or higher-level script. 55 | n_gpus = len(env_gpus) 56 | else: 57 | selected_gpus = list(args.gpu.split(",")) 58 | n_gpus = len(selected_gpus) 59 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 60 | 61 | import pytorch_lightning as pl 62 | import torch 63 | from pytorch_lightning import Trainer 64 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 65 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger, WandbLogger 66 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 67 | 68 | if args.typecheck: 69 | from jaxtyping import install_import_hook 70 | 71 | install_import_hook("mvadapter", "typeguard.typechecked") 72 | 73 | from mvadapter.systems.base import BaseSystem 74 | from mvadapter.utils.callbacks import ( 75 | CodeSnapshotCallback, 76 | ConfigSnapshotCallback, 77 | CustomProgressBar, 78 | ProgressCallback, 79 | ) 80 | from mvadapter.utils.config import ExperimentConfig, load_config 81 | from mvadapter.utils.core import find 82 | from mvadapter.utils.misc import get_rank, time_recorder 83 | from mvadapter.utils.typing import Optional 84 | 85 | logger = logging.getLogger("pytorch_lightning") 86 | if args.verbose: 87 | logger.setLevel(logging.DEBUG) 88 | 89 | if args.benchmark: 90 | time_recorder.enable(True) 91 | 92 | for handler in logger.handlers: 93 | if handler.stream == sys.stderr: # type: ignore 94 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) 95 | handler.addFilter(ColoredFilter()) 96 | 97 | # parse YAML config to OmegaConf 98 | cfg: ExperimentConfig 99 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) 100 | 101 | dm = find(cfg.data_cls)(cfg.data) 102 | system: BaseSystem = find(cfg.system_cls)( 103 | cfg.system, resumed=cfg.resume is not None 104 | ) 105 | system.set_save_dir(os.path.join(cfg.trial_dir, "save")) 106 | 107 | callbacks = [] 108 | if args.train: 109 | callbacks += [ 110 | ModelCheckpoint( 111 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint 112 | ), 113 | LearningRateMonitor(logging_interval="step"), 114 | # CodeSnapshotCallback( 115 | # os.path.join(cfg.trial_dir, "code"), use_version=False 116 | # ), 117 | ConfigSnapshotCallback( 118 | args.config, 119 | cfg, 120 
| os.path.join(cfg.trial_dir, "configs"), 121 | use_version=False, 122 | ), 123 | CustomProgressBar(refresh_rate=1), 124 | ] 125 | 126 | def write_to_text(file, lines): 127 | with open(file, "w") as f: 128 | for line in lines: 129 | f.write(line + "\n") 130 | 131 | loggers = [] 132 | if args.train: 133 | # make tensorboard logging dir to suppress warning 134 | rank_zero_only( 135 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) 136 | )() 137 | loggers += [ 138 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"), 139 | ] 140 | if args.wandb: 141 | wandb_logger = WandbLogger( 142 | project="MV-Adapter", name=f"{cfg.name}-{cfg.tag}" 143 | ) 144 | system._wandb_logger = wandb_logger 145 | loggers += [wandb_logger] 146 | rank_zero_only( 147 | lambda: write_to_text( 148 | os.path.join(cfg.trial_dir, "cmd.txt"), 149 | ["python " + " ".join(sys.argv), str(args)], 150 | ) 151 | )() 152 | 153 | trainer = Trainer( 154 | callbacks=callbacks, 155 | logger=loggers, 156 | inference_mode=False, 157 | accelerator="gpu", 158 | devices=devices, 159 | **cfg.trainer, 160 | ) 161 | 162 | # set a different seed for each device 163 | # NOTE: use trainer.global_rank instead of get_rank() to avoid getting the local rank 164 | pl.seed_everything(cfg.seed + trainer.global_rank, workers=True) 165 | 166 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): 167 | if ckpt_path is None: 168 | return 169 | ckpt = torch.load(ckpt_path, map_location="cpu") 170 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) 171 | 172 | if args.train: 173 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) 174 | trainer.test(system, datamodule=dm) 175 | elif args.validate: 176 | # manually set epoch and global_step as they cannot be automatically resumed 177 | set_system_status(system, cfg.resume) 178 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) 179 | elif args.test: 180 | # manually set epoch and global_step as they cannot be automatically resumed 181 | set_system_status(system, cfg.resume) 182 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) 183 | 184 | 185 | if __name__ == "__main__": 186 | parser = argparse.ArgumentParser() 187 | parser.add_argument("--config", required=True, help="path to config file") 188 | parser.add_argument( 189 | "--gpu", 190 | default="0", 191 | help="GPU(s) to be used. 0 means use the 1st available GPU. " 192 | "1,2 means use the 2nd and 3rd available GPU. 
" 193 | "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " 194 | "this argument is ignored and all available GPUs are always used.", 195 | ) 196 | 197 | group = parser.add_mutually_exclusive_group(required=True) 198 | group.add_argument("--train", action="store_true") 199 | group.add_argument("--validate", action="store_true") 200 | group.add_argument("--test", action="store_true") 201 | 202 | parser.add_argument("--wandb", action="store_true", help="if true, log to wandb") 203 | 204 | parser.add_argument( 205 | "--verbose", action="store_true", help="if true, set logging level to DEBUG" 206 | ) 207 | 208 | parser.add_argument( 209 | "--benchmark", 210 | action="store_true", 211 | help="if true, set to benchmark mode to record running times", 212 | ) 213 | 214 | parser.add_argument( 215 | "--typecheck", 216 | action="store_true", 217 | help="whether to enable dynamic type checking", 218 | ) 219 | 220 | args, extras = parser.parse_known_args() 221 | main(args, extras) 222 | -------------------------------------------------------------------------------- /mvadapter/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/mvadapter/__init__.py -------------------------------------------------------------------------------- /mvadapter/loaders/__init__.py: -------------------------------------------------------------------------------- 1 | from .custom_adapter import CustomAdapterMixin 2 | -------------------------------------------------------------------------------- /mvadapter/loaders/custom_adapter.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Dict, Optional, Union 3 | 4 | import safetensors 5 | import torch 6 | from diffusers.utils import _get_model_file, logging 7 | from safetensors import safe_open 8 | 9 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name 10 | 11 | 12 | class CustomAdapterMixin: 13 | def init_custom_adapter(self, *args, **kwargs): 14 | self._init_custom_adapter(*args, **kwargs) 15 | 16 | def _init_custom_adapter(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | def load_custom_adapter( 20 | self, 21 | pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], 22 | weight_name: str, 23 | subfolder: Optional[str] = None, 24 | **kwargs, 25 | ): 26 | # Load the main state dict first. 
27 | cache_dir = kwargs.pop("cache_dir", None) 28 | force_download = kwargs.pop("force_download", False) 29 | proxies = kwargs.pop("proxies", None) 30 | local_files_only = kwargs.pop("local_files_only", None) 31 | token = kwargs.pop("token", None) 32 | revision = kwargs.pop("revision", None) 33 | 34 | user_agent = { 35 | "file_type": "attn_procs_weights", 36 | "framework": "pytorch", 37 | } 38 | 39 | if not isinstance(pretrained_model_name_or_path_or_dict, dict): 40 | model_file = _get_model_file( 41 | pretrained_model_name_or_path_or_dict, 42 | weights_name=weight_name, 43 | subfolder=subfolder, 44 | cache_dir=cache_dir, 45 | force_download=force_download, 46 | proxies=proxies, 47 | local_files_only=local_files_only, 48 | token=token, 49 | revision=revision, 50 | user_agent=user_agent, 51 | ) 52 | if weight_name.endswith(".safetensors"): 53 | state_dict = {} 54 | with safe_open(model_file, framework="pt", device="cpu") as f: 55 | for key in f.keys(): 56 | state_dict[key] = f.get_tensor(key) 57 | else: 58 | state_dict = torch.load(model_file, map_location="cpu") 59 | else: 60 | state_dict = pretrained_model_name_or_path_or_dict 61 | 62 | self._load_custom_adapter(state_dict) 63 | 64 | def _load_custom_adapter(self, state_dict): 65 | raise NotImplementedError 66 | 67 | def save_custom_adapter( 68 | self, 69 | save_directory: Union[str, os.PathLike], 70 | weight_name: str, 71 | safe_serialization: bool = False, 72 | **kwargs, 73 | ): 74 | if os.path.isfile(save_directory): 75 | logger.error( 76 | f"Provided path ({save_directory}) should be a directory, not a file" 77 | ) 78 | return 79 | 80 | if safe_serialization: 81 | 82 | def save_function(weights, filename): 83 | return safetensors.torch.save_file( 84 | weights, filename, metadata={"format": "pt"} 85 | ) 86 | 87 | else: 88 | save_function = torch.save 89 | 90 | # Save the model 91 | state_dict = self._save_custom_adapter(**kwargs) 92 | save_function(state_dict, os.path.join(save_directory, weight_name)) 93 | logger.info( 94 | f"Custom adapter weights saved in {os.path.join(save_directory, weight_name)}" 95 | ) 96 | 97 | def _save_custom_adapter(self): 98 | raise NotImplementedError 99 | -------------------------------------------------------------------------------- /mvadapter/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/mvadapter/models/__init__.py -------------------------------------------------------------------------------- /mvadapter/schedulers/scheduler_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | def get_sigmas(noise_scheduler, timesteps, n_dim=4, dtype=torch.float32, device=None): 5 | sigmas = noise_scheduler.sigmas.to(device=device, dtype=dtype) 6 | schedule_timesteps = noise_scheduler.timesteps.to(device) 7 | timesteps = timesteps.to(device) 8 | step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] 9 | sigma = sigmas[step_indices].flatten() 10 | while len(sigma.shape) < n_dim: 11 | sigma = sigma.unsqueeze(-1) 12 | return sigma 13 | 14 | 15 | def SNR_to_betas(snr): 16 | """ 17 | Converts SNR to betas 18 | """ 19 | # alphas_cumprod = pass 20 | # snr = (alpha / ) ** 2 21 | # alpha_t^2 / (1 - alpha_t^2) = snr 22 | alpha_t = (snr / (1 + snr)) ** 0.5 23 | alphas_cumprod = alpha_t**2 24 | alphas = alphas_cumprod / torch.cat( 25 | [torch.ones(1, device=snr.device), alphas_cumprod[:-1]] 
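        # Dividing by the right-shifted cumulative product recovers the per-step
        # alphas: alpha_t = alphas_cumprod_t / alphas_cumprod_{t-1}, with the
        # first entry left as alphas_cumprod_0.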
26 | ) 27 | betas = 1 - alphas 28 | return betas 29 | 30 | 31 | def compute_snr(timesteps, noise_scheduler): 32 | """ 33 | Computes SNR as per Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5 34 | """ 35 | alphas_cumprod = noise_scheduler.alphas_cumprod 36 | sqrt_alphas_cumprod = alphas_cumprod**0.5 37 | sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 38 | 39 | # Expand the tensors. 40 | # Adapted from Min-SNR-Diffusion-Training/guided_diffusion/gaussian_diffusion.py at 521b624bd70c67cee4bdf49225915f5 41 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ 42 | timesteps 43 | ].float() 44 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 45 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 46 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 47 | 48 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( 49 | device=timesteps.device 50 | )[timesteps].float() 51 | while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): 52 | sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] 53 | sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) 54 | 55 | # Compute SNR. 56 | snr = (alpha / sigma) ** 2 57 | return snr 58 | 59 | 60 | def compute_alpha(timesteps, noise_scheduler): 61 | alphas_cumprod = noise_scheduler.alphas_cumprod 62 | sqrt_alphas_cumprod = alphas_cumprod**0.5 63 | sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ 64 | timesteps 65 | ].float() 66 | while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): 67 | sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] 68 | alpha = sqrt_alphas_cumprod.expand(timesteps.shape) 69 | 70 | return alpha 71 | -------------------------------------------------------------------------------- /mvadapter/schedulers/scheduling_shift_snr.py: -------------------------------------------------------------------------------- 1 | from typing import Any 2 | 3 | import torch 4 | 5 | from .scheduler_utils import SNR_to_betas, compute_snr 6 | 7 | 8 | class ShiftSNRScheduler: 9 | def __init__( 10 | self, 11 | noise_scheduler: Any, 12 | timesteps: Any, 13 | shift_scale: float, 14 | scheduler_class: Any, 15 | ): 16 | self.noise_scheduler = noise_scheduler 17 | self.timesteps = timesteps 18 | self.shift_scale = shift_scale 19 | self.scheduler_class = scheduler_class 20 | 21 | def _get_shift_scheduler(self): 22 | """ 23 | Prepare scheduler for shifted betas. 24 | 25 | :return: A scheduler object configured with shifted betas 26 | """ 27 | snr = compute_snr(self.timesteps, self.noise_scheduler) 28 | shifted_betas = SNR_to_betas(snr / self.shift_scale) 29 | 30 | return self.scheduler_class.from_config( 31 | self.noise_scheduler.config, trained_betas=shifted_betas.numpy() 32 | ) 33 | 34 | def _get_interpolated_shift_scheduler(self): 35 | """ 36 | Prepare scheduler for shifted betas and interpolate with the original betas in log space. 
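        The interpolation weight grows linearly with the timestep, w = t / (T - 1),
        and the betas are derived from exp((1 - w) * log(SNR_t) + w * log(SNR_t / shift_scale)),
        so early timesteps keep the original schedule while late timesteps approach the fully shifted one.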
37 | 38 | :return: A scheduler object configured with interpolated shifted betas 39 | """ 40 | snr = compute_snr(self.timesteps, self.noise_scheduler) 41 | shifted_snr = snr / self.shift_scale 42 | 43 | weighting = self.timesteps.float() / ( 44 | self.noise_scheduler.config.num_train_timesteps - 1 45 | ) 46 | interpolated_snr = torch.exp( 47 | torch.log(snr) * (1 - weighting) + torch.log(shifted_snr) * weighting 48 | ) 49 | 50 | shifted_betas = SNR_to_betas(interpolated_snr) 51 | 52 | return self.scheduler_class.from_config( 53 | self.noise_scheduler.config, trained_betas=shifted_betas.numpy() 54 | ) 55 | 56 | @classmethod 57 | def from_scheduler( 58 | cls, 59 | noise_scheduler: Any, 60 | shift_mode: str = "default", 61 | timesteps: Any = None, 62 | shift_scale: float = 1.0, 63 | scheduler_class: Any = None, 64 | ): 65 | # Check input 66 | if timesteps is None: 67 | timesteps = torch.arange(0, noise_scheduler.config.num_train_timesteps) 68 | if scheduler_class is None: 69 | scheduler_class = noise_scheduler.__class__ 70 | 71 | # Create scheduler 72 | shift_scheduler = cls( 73 | noise_scheduler=noise_scheduler, 74 | timesteps=timesteps, 75 | shift_scale=shift_scale, 76 | scheduler_class=scheduler_class, 77 | ) 78 | 79 | if shift_mode == "default": 80 | return shift_scheduler._get_shift_scheduler() 81 | elif shift_mode == "interpolated": 82 | return shift_scheduler._get_interpolated_shift_scheduler() 83 | else: 84 | raise ValueError(f"Unknown shift_mode: {shift_mode}") 85 | 86 | 87 | if __name__ == "__main__": 88 | """ 89 | Compare the alpha values for different noise schedulers. 90 | """ 91 | import matplotlib.pyplot as plt 92 | from diffusers import DDPMScheduler 93 | 94 | from .scheduler_utils import compute_alpha 95 | 96 | # Base 97 | timesteps = torch.arange(0, 1000) 98 | noise_scheduler_base = DDPMScheduler.from_pretrained( 99 | "runwayml/stable-diffusion-v1-5", subfolder="scheduler" 100 | ) 101 | alpha = compute_alpha(timesteps, noise_scheduler_base) 102 | plt.plot(timesteps.numpy(), alpha.numpy(), label="Base") 103 | 104 | # Kolors 105 | num_train_timesteps_ = 1100 106 | timesteps_ = torch.arange(0, num_train_timesteps_) 107 | noise_kwargs = {"beta_end": 0.014, "num_train_timesteps": num_train_timesteps_} 108 | noise_scheduler_kolors = DDPMScheduler.from_config( 109 | noise_scheduler_base.config, **noise_kwargs 110 | ) 111 | alpha = compute_alpha(timesteps_, noise_scheduler_kolors) 112 | plt.plot(timesteps_.numpy(), alpha.numpy(), label="Kolors") 113 | 114 | # Shift betas 115 | shift_scale = 8.0 116 | noise_scheduler_shift = ShiftSNRScheduler.from_scheduler( 117 | noise_scheduler_base, shift_mode="default", shift_scale=shift_scale 118 | ) 119 | alpha = compute_alpha(timesteps, noise_scheduler_shift) 120 | plt.plot(timesteps.numpy(), alpha.numpy(), label="Shift Noise (scale 8.0)") 121 | 122 | # Shift betas (interpolated) 123 | noise_scheduler_inter = ShiftSNRScheduler.from_scheduler( 124 | noise_scheduler_base, shift_mode="interpolated", shift_scale=shift_scale 125 | ) 126 | alpha = compute_alpha(timesteps, noise_scheduler_inter) 127 | plt.plot(timesteps.numpy(), alpha.numpy(), label="Interpolated (scale 8.0)") 128 | 129 | # ZeroSNR 130 | noise_scheduler = DDPMScheduler.from_config( 131 | noise_scheduler_base.config, rescale_betas_zero_snr=True 132 | ) 133 | alpha = compute_alpha(timesteps, noise_scheduler) 134 | plt.plot(timesteps.numpy(), alpha.numpy(), label="ZeroSNR") 135 | 136 | plt.legend() 137 | plt.grid() 138 | plt.savefig("check_alpha.png") 139 | 
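Usage sketch (not part of the repository): assuming the package is importable and `diffusers` is installed, and that the SD-1.5 scheduler config referenced in the `__main__` block above is still reachable, a shifted scheduler can be built from any base scheduler and sanity-checked against the expected `snr / shift_scale`:

import torch
from diffusers import DDPMScheduler

from mvadapter.schedulers.scheduler_utils import compute_snr
from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler

# Base scheduler (same config as the __main__ check above)
base = DDPMScheduler.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="scheduler"
)
timesteps = torch.arange(0, base.config.num_train_timesteps)

# Lower the SNR of every timestep by a factor of 4
shifted = ShiftSNRScheduler.from_scheduler(
    base, shift_mode="default", shift_scale=4.0
)

# The shifted scheduler is a regular DDPMScheduler built from trained_betas,
# so its per-timestep SNR should equal the base SNR divided by shift_scale
# (up to floating-point error in the betas round-trip).
ratio = compute_snr(timesteps, base) / compute_snr(timesteps, shifted)
print(ratio.mean().item())  # expected to be close to 4.0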
-------------------------------------------------------------------------------- /mvadapter/systems/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/mvadapter/systems/__init__.py -------------------------------------------------------------------------------- /mvadapter/systems/base.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import pytorch_lightning as pl 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from ..utils.base import Updateable, update_end_if_possible, update_if_possible 9 | from ..utils.config import parse_structured 10 | from ..utils.core import debug, find, info, warn 11 | from ..utils.misc import ( 12 | C, 13 | cleanup, 14 | get_device, 15 | get_rank, 16 | load_module_weights, 17 | show_vram_usage, 18 | ) 19 | from ..utils.saving import SaverMixin 20 | from ..utils.typing import * 21 | from .utils import parse_optimizer, parse_scheduler 22 | 23 | 24 | class BaseSystem(pl.LightningModule, Updateable, SaverMixin): 25 | @dataclass 26 | class Config: 27 | optimizer: dict = field(default_factory=dict) 28 | scheduler: Optional[dict] = None 29 | weights: Optional[str] = None 30 | weights_ignore_modules: Optional[List[str]] = None 31 | weights_mapping: Optional[List[Dict[str, str]]] = None 32 | check_train_every_n_steps: int = 0 33 | check_val_limit_rank: int = 8 34 | cleanup_after_validation_step: bool = False 35 | cleanup_after_test_step: bool = False 36 | allow_tf32: bool = True 37 | 38 | cfg: Config 39 | 40 | def __init__(self, cfg, resumed=False) -> None: 41 | super().__init__() 42 | self.cfg = parse_structured(self.Config, cfg) 43 | self._save_dir: Optional[str] = None 44 | self._resumed: bool = resumed 45 | self._resumed_eval: bool = False 46 | self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} 47 | 48 | # weird fix for extra VRAM usage on rank 0 49 | # credit: https://discuss.pytorch.org/t/extra-10gb-memory-on-gpu-0-in-ddp-tutorial/118113 50 | torch.cuda.set_device(get_rank()) 51 | torch.cuda.empty_cache() 52 | 53 | torch.backends.cuda.matmul.allow_tf32 = self.cfg.allow_tf32 54 | 55 | self.configure() 56 | if self.cfg.weights is not None: 57 | self.load_weights( 58 | self.cfg.weights, 59 | self.cfg.weights_ignore_modules, 60 | self.cfg.weights_mapping, 61 | ) 62 | self.post_configure() 63 | 64 | def load_weights( 65 | self, 66 | weights: str, 67 | ignore_modules: Optional[List[str]] = None, 68 | mapping: Optional[List[Dict[str, str]]] = None, 69 | ): 70 | state_dict, epoch, global_step = load_module_weights( 71 | weights, 72 | ignore_modules=ignore_modules, 73 | mapping=mapping, 74 | map_location="cpu", 75 | ) 76 | self.load_state_dict(state_dict, strict=False) 77 | # restore step-dependent states 78 | self.do_update_step(epoch, global_step, on_load_weights=True) 79 | 80 | def set_resume_status(self, current_epoch: int, global_step: int): 81 | # restore correct epoch and global step in eval 82 | self._resumed_eval = True 83 | self._resumed_eval_status["current_epoch"] = current_epoch 84 | self._resumed_eval_status["global_step"] = global_step 85 | 86 | @property 87 | def resumed(self): 88 | # whether from resumed checkpoint 89 | return self._resumed 90 | 91 | @property 92 | def true_global_step(self): 93 | if self._resumed_eval: 94 | return self._resumed_eval_status["global_step"] 95 | else: 96 | 
return self.global_step 97 | 98 | @property 99 | def true_current_epoch(self): 100 | if self._resumed_eval: 101 | return self._resumed_eval_status["current_epoch"] 102 | else: 103 | return self.current_epoch 104 | 105 | def configure(self) -> None: 106 | pass 107 | 108 | def post_configure(self) -> None: 109 | """ 110 | executed after weights are loaded 111 | """ 112 | pass 113 | 114 | def C(self, value: Any) -> float: 115 | return C(value, self.true_current_epoch, self.true_global_step) 116 | 117 | def configure_optimizers(self): 118 | optim = parse_optimizer(self.cfg.optimizer, self) 119 | ret = { 120 | "optimizer": optim, 121 | } 122 | if self.cfg.scheduler is not None: 123 | ret.update( 124 | { 125 | "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), 126 | } 127 | ) 128 | return ret 129 | 130 | def on_fit_start(self) -> None: 131 | if self._save_dir is not None: 132 | info(f"Validation results will be saved to {self._save_dir}") 133 | else: 134 | warn( 135 | f"Saving directory not set for the system, visualization results will not be saved" 136 | ) 137 | 138 | def training_step(self, batch, batch_idx): 139 | raise NotImplementedError 140 | 141 | def check_train(self, batch, **kwargs): 142 | if ( 143 | self.global_rank == 0 144 | and self.cfg.check_train_every_n_steps > 0 145 | and self.true_global_step % self.cfg.check_train_every_n_steps == 0 146 | ): 147 | self.on_check_train(batch, **kwargs) 148 | 149 | def on_check_train(self, batch, outputs, **kwargs): 150 | pass 151 | 152 | def validation_step(self, batch, batch_idx): 153 | raise NotImplementedError 154 | 155 | def on_validation_epoch_end(self): 156 | pass 157 | 158 | def test_step(self, batch, batch_idx): 159 | raise NotImplementedError 160 | 161 | def on_test_epoch_end(self): 162 | pass 163 | 164 | def on_test_end(self) -> None: 165 | if self._save_dir is not None: 166 | info(f"Test results saved to {self._save_dir}") 167 | 168 | def on_predict_start(self) -> None: 169 | pass 170 | 171 | def predict_step(self, batch, batch_idx): 172 | pass 173 | 174 | def on_predict_epoch_end(self) -> None: 175 | pass 176 | 177 | def on_predict_end(self) -> None: 178 | pass 179 | 180 | def preprocess_data(self, batch, stage): 181 | pass 182 | 183 | """ 184 | Implementing on_after_batch_transfer of DataModule does the same. 185 | But on_after_batch_transfer does not support DP. 
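preprocess_data is therefore called explicitly from the per-stage on_*_batch_start hooks below.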
186 | """ 187 | 188 | def on_train_batch_start(self, batch, batch_idx, unused=0): 189 | self.preprocess_data(batch, "train") 190 | self.dataset = self.trainer.train_dataloader.dataset 191 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 192 | self.do_update_step(self.true_current_epoch, self.true_global_step) 193 | 194 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): 195 | self.preprocess_data(batch, "validation") 196 | self.dataset = self.trainer.val_dataloaders.dataset 197 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 198 | self.do_update_step(self.true_current_epoch, self.true_global_step) 199 | 200 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): 201 | self.preprocess_data(batch, "test") 202 | self.dataset = self.trainer.test_dataloaders.dataset 203 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 204 | self.do_update_step(self.true_current_epoch, self.true_global_step) 205 | 206 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): 207 | self.preprocess_data(batch, "predict") 208 | self.dataset = self.trainer.predict_dataloaders.dataset 209 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 210 | self.do_update_step(self.true_current_epoch, self.true_global_step) 211 | 212 | def on_train_batch_end(self, outputs, batch, batch_idx): 213 | self.dataset = self.trainer.train_dataloader.dataset 214 | update_end_if_possible( 215 | self.dataset, self.true_current_epoch, self.true_global_step 216 | ) 217 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 218 | 219 | def on_validation_batch_end(self, outputs, batch, batch_idx): 220 | self.dataset = self.trainer.val_dataloaders.dataset 221 | update_end_if_possible( 222 | self.dataset, self.true_current_epoch, self.true_global_step 223 | ) 224 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 225 | if self.cfg.cleanup_after_validation_step: 226 | # cleanup to save vram 227 | cleanup() 228 | 229 | def on_test_batch_end(self, outputs, batch, batch_idx): 230 | self.dataset = self.trainer.test_dataloaders.dataset 231 | update_end_if_possible( 232 | self.dataset, self.true_current_epoch, self.true_global_step 233 | ) 234 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 235 | if self.cfg.cleanup_after_test_step: 236 | # cleanup to save vram 237 | cleanup() 238 | 239 | def on_predict_batch_end(self, outputs, batch, batch_idx): 240 | self.dataset = self.trainer.predict_dataloaders.dataset 241 | update_end_if_possible( 242 | self.dataset, self.true_current_epoch, self.true_global_step 243 | ) 244 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 245 | if self.cfg.cleanup_after_test_step: 246 | # cleanup to save vram 247 | cleanup() 248 | 249 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 250 | pass 251 | 252 | def on_before_optimizer_step(self, optimizer): 253 | """ 254 | # some gradient-related debugging goes here, example: 255 | from lightning.pytorch.utilities import grad_norm 256 | norms = grad_norm(self.geometry, norm_type=2) 257 | print(norms) 258 | for name, p in self.named_parameters(): 259 | if p.grad is None: 260 | info(f"{name} does not receive gradients!") 261 | """ 262 | pass 263 | -------------------------------------------------------------------------------- /mvadapter/systems/utils.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | from diffusers import AutoencoderKL 7 | from torch.optim import lr_scheduler 8 | 9 | from ..utils.core import debug, find, info, warn 10 | from ..utils.typing import * 11 | 12 | """Diffusers Model Utils""" 13 | 14 | 15 | def vae_encode( 16 | vae: AutoencoderKL, 17 | pixel_values: Float[Tensor, "B 3 H W"], 18 | sample: bool = True, 19 | apply_scale: bool = True, 20 | ): 21 | latent_dist = vae.encode(pixel_values).latent_dist 22 | latents = latent_dist.sample() if sample else latent_dist.mode() 23 | if apply_scale: 24 | latents = latents * vae.config.scaling_factor 25 | return latents 26 | 27 | 28 | # Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt 29 | def encode_prompt( 30 | prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True 31 | ): 32 | prompt_embeds_list = [] 33 | 34 | captions = [] 35 | for caption in prompt_batch: 36 | if random.random() < proportion_empty_prompts: 37 | captions.append("") 38 | elif isinstance(caption, str): 39 | captions.append(caption) 40 | elif isinstance(caption, (list, np.ndarray)): 41 | # take a random caption if there are multiple 42 | captions.append(random.choice(caption) if is_train else caption[0]) 43 | 44 | with torch.no_grad(): 45 | for tokenizer, text_encoder in zip(tokenizers, text_encoders): 46 | text_inputs = tokenizer( 47 | captions, 48 | padding="max_length", 49 | max_length=tokenizer.model_max_length, 50 | truncation=True, 51 | return_tensors="pt", 52 | ) 53 | text_input_ids = text_inputs.input_ids 54 | prompt_embeds = text_encoder( 55 | text_input_ids.to(text_encoder.device), 56 | output_hidden_states=True, 57 | ) 58 | 59 | # We are only ALWAYS interested in the pooled output of the final text encoder 60 | pooled_prompt_embeds = prompt_embeds[0] 61 | prompt_embeds = prompt_embeds.hidden_states[-2] 62 | bs_embed, seq_len, _ = prompt_embeds.shape 63 | prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) 64 | prompt_embeds_list.append(prompt_embeds) 65 | 66 | prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) 67 | pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) 68 | return prompt_embeds, pooled_prompt_embeds 69 | 70 | 71 | CLIP_INPUT_MEAN = torch.as_tensor( 72 | [0.48145466, 0.4578275, 0.40821073], dtype=torch.float32 73 | )[None, :, None, None] 74 | CLIP_INPUT_STD = torch.as_tensor( 75 | [0.26862954, 0.26130258, 0.27577711], dtype=torch.float32 76 | )[None, :, None, None] 77 | 78 | 79 | def normalize_image_for_clip(image: Float[Tensor, "B C H W"]): 80 | return (image - CLIP_INPUT_MEAN.to(image)) / CLIP_INPUT_STD.to(image) 81 | 82 | 83 | """Training""" 84 | 85 | 86 | def get_scheduler(name): 87 | if hasattr(lr_scheduler, name): 88 | return getattr(lr_scheduler, name) 89 | else: 90 | raise NotImplementedError 91 | 92 | 93 | def getattr_recursive(m, attr): 94 | for name in attr.split("."): 95 | m = getattr(m, name) 96 | return m 97 | 98 | 99 | def get_parameters(model, name): 100 | module = getattr_recursive(model, name) 101 | if isinstance(module, nn.Module): 102 | return module.parameters() 103 | elif isinstance(module, nn.Parameter): 104 | return module 105 | return [] 106 | 107 | 108 | def parse_optimizer(config, model): 109 | if hasattr(config, "params"): 110 | params = [ 111 | {"params": get_parameters(model, name), "name": name, **args} 112 | for name, args in config.params.items() 113 | ] 114 | 
debug(f"Specify optimizer params: {config.params}") 115 | else: 116 | params = model.parameters() 117 | if config.name in ["FusedAdam"]: 118 | import apex 119 | 120 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 121 | elif config.name in ["Adam8bit", "AdamW8bit"]: 122 | import bitsandbytes as bnb 123 | 124 | optim = bnb.optim.Adam8bit(params, **config.args) 125 | else: 126 | optim = getattr(torch.optim, config.name)(params, **config.args) 127 | return optim 128 | 129 | 130 | def parse_scheduler_to_instance(config, optimizer): 131 | if config.name == "ChainedScheduler": 132 | schedulers = [ 133 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 134 | ] 135 | scheduler = lr_scheduler.ChainedScheduler(schedulers) 136 | elif config.name == "Sequential": 137 | schedulers = [ 138 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 139 | ] 140 | scheduler = lr_scheduler.SequentialLR( 141 | optimizer, schedulers, milestones=config.milestones 142 | ) 143 | else: 144 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) 145 | return scheduler 146 | 147 | 148 | def parse_scheduler(config, optimizer): 149 | interval = config.get("interval", "epoch") 150 | assert interval in ["epoch", "step"] 151 | if config.name == "SequentialLR": 152 | scheduler = { 153 | "scheduler": lr_scheduler.SequentialLR( 154 | optimizer, 155 | [ 156 | parse_scheduler(conf, optimizer)["scheduler"] 157 | for conf in config.schedulers 158 | ], 159 | milestones=config.milestones, 160 | ), 161 | "interval": interval, 162 | } 163 | elif config.name == "ChainedScheduler": 164 | scheduler = { 165 | "scheduler": lr_scheduler.ChainedScheduler( 166 | [ 167 | parse_scheduler(conf, optimizer)["scheduler"] 168 | for conf in config.schedulers 169 | ] 170 | ), 171 | "interval": interval, 172 | } 173 | else: 174 | scheduler = { 175 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 176 | "interval": interval, 177 | } 178 | return scheduler 179 | -------------------------------------------------------------------------------- /mvadapter/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .mesh_utils.camera import get_camera, get_orthogonal_camera 2 | from .geometry import get_plucker_embeds_from_cameras_ortho 3 | from .saving import make_image_grid, tensor_to_image, image_to_tensor 4 | -------------------------------------------------------------------------------- /mvadapter/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from .config import parse_structured 7 | from .misc import get_device, load_module_weights 8 | from .typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 
32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def do_update_step_end(self, epoch: int, global_step: int): 39 | for attr in self.__dir__(): 40 | if attr.startswith("_"): 41 | continue 42 | try: 43 | module = getattr(self, attr) 44 | except: 45 | continue # ignore attributes like property, which can't be retrived using getattr? 46 | if isinstance(module, Updateable): 47 | module.do_update_step_end(epoch, global_step) 48 | self.update_step_end(epoch, global_step) 49 | 50 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 51 | # override this method to implement custom update logic 52 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 53 | # as the models and tensors are not guarenteed to be on the same device 54 | pass 55 | 56 | def update_step_end(self, epoch: int, global_step: int): 57 | pass 58 | 59 | 60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 61 | if isinstance(module, Updateable): 62 | module.do_update_step(epoch, global_step) 63 | 64 | 65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 66 | if isinstance(module, Updateable): 67 | module.do_update_step_end(epoch, global_step) 68 | 69 | 70 | class BaseObject(Updateable): 71 | @dataclass 72 | class Config: 73 | pass 74 | 75 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 76 | 77 | def __init__( 78 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 79 | ) -> None: 80 | super().__init__() 81 | self.cfg = parse_structured(self.Config, cfg) 82 | self.device = get_device() 83 | self.configure(*args, **kwargs) 84 | 85 | def configure(self, *args, **kwargs) -> None: 86 | pass 87 | 88 | 89 | class BaseModule(nn.Module, Updateable): 90 | @dataclass 91 | class Config: 92 | weights: Optional[str] = None 93 | 94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 95 | 96 | def __init__( 97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 | ) -> None: 99 | super().__init__() 100 | self.cfg = parse_structured(self.Config, cfg) 101 | self.device = get_device() 102 | self._non_modules = {} 103 | self.configure(*args, **kwargs) 104 | if self.cfg.weights is not None: 105 | # format: path/to/weights:module_name 106 | weights_path, module_name = self.cfg.weights.split(":") 107 | state_dict, epoch, global_step = load_module_weights( 108 | weights_path, module_name=module_name, map_location="cpu" 109 | ) 110 | self.load_state_dict(state_dict) 111 | self.do_update_step( 112 | epoch, global_step, on_load_weights=True 113 | ) # restore states 114 | 115 | def configure(self, *args, **kwargs) -> None: 116 | pass 117 | 118 | def register_non_module(self, name: str, module: nn.Module) -> None: 119 | # non-modules won't be treated as model parameters 120 | self._non_modules[name] = module 121 | 122 | def non_module(self, name: str): 123 | return self._non_modules.get(name, None) 124 | -------------------------------------------------------------------------------- /mvadapter/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | 5 | import pytorch_lightning 6 | 7 | from .config import dump_config 8 | from .misc import parse_version 
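# the import location of Callback differs across pytorch_lightning versions; the check below picks the path matching the installed version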
9 | 10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): 11 | from pytorch_lightning.callbacks import Callback 12 | else: 13 | from pytorch_lightning.callbacks.base import Callback 14 | 15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 17 | 18 | 19 | class VersionedCallback(Callback): 20 | def __init__(self, save_root, version=None, use_version=True): 21 | self.save_root = save_root 22 | self._version = version 23 | self.use_version = use_version 24 | 25 | @property 26 | def version(self) -> int: 27 | """Get the experiment version. 28 | 29 | Returns: 30 | The experiment version if specified else the next version. 31 | """ 32 | if self._version is None: 33 | self._version = self._get_next_version() 34 | return self._version 35 | 36 | def _get_next_version(self): 37 | existing_versions = [] 38 | if os.path.isdir(self.save_root): 39 | for f in os.listdir(self.save_root): 40 | bn = os.path.basename(f) 41 | if bn.startswith("version_"): 42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 43 | existing_versions.append(int(dir_ver)) 44 | if len(existing_versions) == 0: 45 | return 0 46 | return max(existing_versions) + 1 47 | 48 | @property 49 | def savedir(self): 50 | if not self.use_version: 51 | return self.save_root 52 | return os.path.join( 53 | self.save_root, 54 | ( 55 | self.version 56 | if isinstance(self.version, str) 57 | else f"version_{self.version}" 58 | ), 59 | ) 60 | 61 | 62 | class CodeSnapshotCallback(VersionedCallback): 63 | def __init__(self, save_root, version=None, use_version=True): 64 | super().__init__(save_root, version, use_version) 65 | 66 | def get_file_list(self): 67 | return [ 68 | b.decode() 69 | for b in set( 70 | subprocess.check_output( 71 | 'git ls-files -- ":!:load/*"', shell=True 72 | ).splitlines() 73 | ) 74 | | set( # hard code, TODO: use config to exclude folders or files 75 | subprocess.check_output( 76 | "git ls-files --others --exclude-standard", shell=True 77 | ).splitlines() 78 | ) 79 | ] 80 | 81 | @rank_zero_only 82 | def save_code_snapshot(self): 83 | os.makedirs(self.savedir, exist_ok=True) 84 | for f in self.get_file_list(): 85 | if not os.path.exists(f) or os.path.isdir(f): 86 | continue 87 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 88 | shutil.copyfile(f, os.path.join(self.savedir, f)) 89 | 90 | def on_fit_start(self, trainer, pl_module): 91 | try: 92 | self.save_code_snapshot() 93 | except: 94 | rank_zero_warn( 95 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
96 | ) 97 | 98 | 99 | class ConfigSnapshotCallback(VersionedCallback): 100 | def __init__(self, config_path, config, save_root, version=None, use_version=True): 101 | super().__init__(save_root, version, use_version) 102 | self.config_path = config_path 103 | self.config = config 104 | 105 | @rank_zero_only 106 | def save_config_snapshot(self): 107 | os.makedirs(self.savedir, exist_ok=True) 108 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) 109 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) 110 | 111 | def on_fit_start(self, trainer, pl_module): 112 | self.save_config_snapshot() 113 | 114 | 115 | class CustomProgressBar(TQDMProgressBar): 116 | def get_metrics(self, *args, **kwargs): 117 | # don't show the version number 118 | items = super().get_metrics(*args, **kwargs) 119 | items.pop("v_num", None) 120 | return items 121 | 122 | 123 | class ProgressCallback(Callback): 124 | def __init__(self, save_path): 125 | super().__init__() 126 | self.save_path = save_path 127 | self._file_handle = None 128 | 129 | @property 130 | def file_handle(self): 131 | if self._file_handle is None: 132 | self._file_handle = open(self.save_path, "w") 133 | return self._file_handle 134 | 135 | @rank_zero_only 136 | def write(self, msg: str) -> None: 137 | self.file_handle.seek(0) 138 | self.file_handle.truncate() 139 | self.file_handle.write(msg) 140 | self.file_handle.flush() 141 | 142 | @rank_zero_only 143 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): 144 | self.write( 145 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" 146 | ) 147 | 148 | @rank_zero_only 149 | def on_validation_start(self, trainer, pl_module): 150 | self.write(f"Rendering validation image ...") 151 | 152 | @rank_zero_only 153 | def on_test_start(self, trainer, pl_module): 154 | self.write(f"Rendering video ...") 155 | 156 | @rank_zero_only 157 | def on_predict_start(self, trainer, pl_module): 158 | self.write(f"Exporting mesh assets ...") 159 | -------------------------------------------------------------------------------- /mvadapter/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf import OmegaConf 6 | 7 | from .core import debug, find, info, warn 8 | from .typing import * 9 | 10 | # ============ Register OmegaConf Resolvers ============= # 11 | OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: s.replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("not", lambda s: not s) 24 | 25 | 26 | def calc_num_train_steps(num_data, batch_size, max_epochs, num_nodes, num_cards=8): 27 | return int(num_data / (num_nodes * num_cards * batch_size)) * max_epochs 28 | 29 | 30 | 
OmegaConf.register_new_resolver("calc_num_train_steps", calc_num_train_steps) 31 | 32 | # ======================================================= # 33 | 34 | 35 | # ============== Automatic Name Resolvers =============== # 36 | def get_naming_convention(cfg): 37 | # TODO 38 | name = f"lrm_{cfg.system.backbone.num_layers}" 39 | return name 40 | 41 | 42 | # ======================================================= # 43 | 44 | 45 | @dataclass 46 | class ExperimentConfig: 47 | name: str = "default" 48 | description: str = "" 49 | tag: str = "" 50 | seed: int = 0 51 | use_timestamp: bool = True 52 | timestamp: Optional[str] = None 53 | exp_root_dir: str = "outputs" 54 | 55 | ### these shouldn't be set manually 56 | exp_dir: str = "outputs/default" 57 | trial_name: str = "exp" 58 | trial_dir: str = "outputs/default/exp" 59 | n_gpus: int = 1 60 | ### 61 | 62 | resume: Optional[str] = None 63 | 64 | data_cls: str = "" 65 | data: dict = field(default_factory=dict) 66 | 67 | system_cls: str = "" 68 | system: dict = field(default_factory=dict) 69 | 70 | # accept pytorch-lightning trainer parameters 71 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 72 | trainer: dict = field(default_factory=dict) 73 | 74 | # accept pytorch-lightning checkpoint callback parameters 75 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 76 | checkpoint: dict = field(default_factory=dict) 77 | 78 | 79 | def load_config( 80 | *yamls: str, cli_args: list = [], from_string=False, makedirs=True, **kwargs 81 | ) -> Any: 82 | if from_string: 83 | parse_func = OmegaConf.create 84 | else: 85 | parse_func = OmegaConf.load 86 | yaml_confs = [] 87 | for y in yamls: 88 | conf = parse_func(y) 89 | extends = conf.pop("extends", None) 90 | if extends: 91 | assert os.path.exists(extends), f"File {extends} does not exist." 92 | yaml_confs.append(OmegaConf.load(extends)) 93 | yaml_confs.append(conf) 94 | cli_conf = OmegaConf.from_cli(cli_args) 95 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 96 | OmegaConf.resolve(cfg) 97 | assert isinstance(cfg, DictConfig) 98 | scfg: ExperimentConfig = parse_structured(ExperimentConfig, cfg) 99 | 100 | # post processing 101 | # auto naming 102 | if scfg.name == "auto": 103 | scfg.name = get_naming_convention(scfg) 104 | # add timestamp 105 | if not scfg.tag and not scfg.use_timestamp: 106 | raise ValueError("Either tag is specified or use_timestamp is True.") 107 | scfg.trial_name = scfg.tag 108 | # if resume from an existing config, scfg.timestamp should not be None 109 | if scfg.timestamp is None: 110 | scfg.timestamp = "" 111 | if scfg.use_timestamp: 112 | if scfg.n_gpus > 1: 113 | warn( 114 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 
115 | ) 116 | else: 117 | scfg.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 118 | # make directories 119 | scfg.trial_name += scfg.timestamp 120 | scfg.exp_dir = os.path.join(scfg.exp_root_dir, scfg.name) 121 | scfg.trial_dir = os.path.join(scfg.exp_dir, scfg.trial_name) 122 | 123 | if makedirs: 124 | os.makedirs(scfg.trial_dir, exist_ok=True) 125 | 126 | return scfg 127 | 128 | 129 | def config_to_primitive(config, resolve: bool = True) -> Any: 130 | return OmegaConf.to_container(config, resolve=resolve) 131 | 132 | 133 | def dump_config(path: str, config) -> None: 134 | with open(path, "w") as fp: 135 | OmegaConf.save(config=config, f=fp) 136 | 137 | 138 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 139 | scfg = OmegaConf.merge(OmegaConf.structured(fields), cfg) 140 | return scfg 141 | -------------------------------------------------------------------------------- /mvadapter/utils/core.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | ### grammar sugar for logging utilities ### 4 | import logging 5 | 6 | logger = logging.getLogger("pytorch_lightning") 7 | 8 | from pytorch_lightning.utilities.rank_zero import ( 9 | rank_zero_debug, 10 | rank_zero_info, 11 | rank_zero_only, 12 | ) 13 | 14 | 15 | def find(cls_string): 16 | module_string = ".".join(cls_string.split(".")[:-1]) 17 | cls_name = cls_string.split(".")[-1] 18 | module = importlib.import_module(module_string, package=None) 19 | cls = getattr(module, cls_name) 20 | return cls 21 | 22 | 23 | debug = rank_zero_debug 24 | info = rank_zero_info 25 | 26 | 27 | @rank_zero_only 28 | def warn(*args, **kwargs): 29 | logger.warn(*args, **kwargs) 30 | -------------------------------------------------------------------------------- /mvadapter/utils/geometry.py: -------------------------------------------------------------------------------- 1 | from typing import List, Optional, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from torch.nn import functional as F 6 | 7 | 8 | def get_position_map_from_depth(depth, mask, intrinsics, extrinsics, image_wh=None): 9 | """Compute the position map from the depth map and the camera parameters for a batch of views. 10 | 11 | Args: 12 | depth (torch.Tensor): The depth maps with the shape (B, H, W, 1). 13 | mask (torch.Tensor): The masks with the shape (B, H, W, 1). 14 | intrinsics (torch.Tensor): The camera intrinsics matrices with the shape (B, 3, 3). 15 | extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4). 16 | image_wh (Tuple[int, int]): The image width and height. 17 | 18 | Returns: 19 | torch.Tensor: The position maps with the shape (B, H, W, 3). 
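Each pixel is back-projected with x = (u - cx) * depth / fx, y = (v - cy) * depth / fy, z = depth, then transformed with the extrinsics and zeroed outside the mask.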
20 | """ 21 | if image_wh is None: 22 | image_wh = depth.shape[2], depth.shape[1] 23 | 24 | B, H, W, _ = depth.shape 25 | depth = depth.squeeze(-1) 26 | 27 | u_coord, v_coord = torch.meshgrid( 28 | torch.arange(image_wh[0]), torch.arange(image_wh[1]), indexing="xy" 29 | ) 30 | u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) 31 | v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) 32 | 33 | # Compute the position map by back-projecting depth pixels to 3D space 34 | x = ( 35 | (u_coord - intrinsics[:, 0, 2].unsqueeze(-1).unsqueeze(-1)) 36 | * depth 37 | / intrinsics[:, 0, 0].unsqueeze(-1).unsqueeze(-1) 38 | ) 39 | y = ( 40 | (v_coord - intrinsics[:, 1, 2].unsqueeze(-1).unsqueeze(-1)) 41 | * depth 42 | / intrinsics[:, 1, 1].unsqueeze(-1).unsqueeze(-1) 43 | ) 44 | z = depth 45 | 46 | # Concatenate to form the 3D coordinates in the camera frame 47 | camera_coords = torch.stack([x, y, z], dim=-1) 48 | 49 | # Apply the extrinsic matrix to get coordinates in the world frame 50 | coords_homogeneous = torch.nn.functional.pad( 51 | camera_coords, (0, 1), "constant", 1.0 52 | ) # Add a homogeneous coordinate 53 | world_coords = torch.matmul( 54 | coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2) 55 | ).view(B, H, W, 4) 56 | 57 | # Apply the mask to the position map 58 | position_map = world_coords[..., :3] * mask 59 | 60 | return position_map 61 | 62 | 63 | def get_position_map_from_depth_ortho( 64 | depth, mask, extrinsics, ortho_scale, image_wh=None 65 | ): 66 | """Compute the position map from the depth map and the camera parameters for a batch of views 67 | using orthographic projection with a given ortho_scale. 68 | 69 | Args: 70 | depth (torch.Tensor): The depth maps with the shape (B, H, W, 1). 71 | mask (torch.Tensor): The masks with the shape (B, H, W, 1). 72 | extrinsics (torch.Tensor): The camera extrinsics matrices with the shape (B, 4, 4). 73 | ortho_scale (torch.Tensor): The scaling factor for the orthographic projection with the shape (B, 1, 1, 1). 74 | image_wh (Tuple[int, int]): Optional. The image width and height. 75 | 76 | Returns: 77 | torch.Tensor: The position maps with the shape (B, H, W, 3). 
78 | """ 79 | if image_wh is None: 80 | image_wh = depth.shape[2], depth.shape[1] 81 | 82 | B, H, W, _ = depth.shape 83 | depth = depth.squeeze(-1) 84 | 85 | # Generating grid of coordinates in the image space 86 | u_coord, v_coord = torch.meshgrid( 87 | torch.arange(0, image_wh[0]), torch.arange(0, image_wh[1]), indexing="xy" 88 | ) 89 | u_coord = u_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) 90 | v_coord = v_coord.type_as(depth).unsqueeze(0).expand(B, -1, -1) 91 | 92 | # Compute the position map using orthographic projection with ortho_scale 93 | x = (u_coord - image_wh[0] / 2) * ortho_scale / image_wh[0] 94 | y = (v_coord - image_wh[1] / 2) * ortho_scale / image_wh[1] 95 | z = depth 96 | 97 | # Concatenate to form the 3D coordinates in the camera frame 98 | camera_coords = torch.stack([x, y, z], dim=-1) 99 | 100 | # Apply the extrinsic matrix to get coordinates in the world frame 101 | coords_homogeneous = torch.nn.functional.pad( 102 | camera_coords, (0, 1), "constant", 1.0 103 | ) # Add a homogeneous coordinate 104 | world_coords = torch.matmul( 105 | coords_homogeneous.view(B, -1, 4), extrinsics.transpose(1, 2) 106 | ).view(B, H, W, 4) 107 | 108 | # Apply the mask to the position map 109 | position_map = world_coords[..., :3] * mask 110 | 111 | return position_map 112 | 113 | 114 | def get_opencv_from_blender(matrix_world, fov=None, image_size=None): 115 | # convert matrix_world to opencv format extrinsics 116 | opencv_world_to_cam = matrix_world.inverse() 117 | opencv_world_to_cam[1, :] *= -1 118 | opencv_world_to_cam[2, :] *= -1 119 | R, T = opencv_world_to_cam[:3, :3], opencv_world_to_cam[:3, 3] 120 | 121 | if fov is None: # orthographic camera 122 | return R, T 123 | 124 | R, T = R.unsqueeze(0), T.unsqueeze(0) 125 | # convert fov to opencv format intrinsics 126 | focal = 1 / np.tan(fov / 2) 127 | intrinsics = np.diag(np.array([focal, focal, 1])).astype(np.float32) 128 | opencv_cam_matrix = ( 129 | torch.from_numpy(intrinsics).unsqueeze(0).float().to(matrix_world.device) 130 | ) 131 | opencv_cam_matrix[:, :2, -1] += torch.tensor([image_size / 2, image_size / 2]).to( 132 | matrix_world.device 133 | ) 134 | opencv_cam_matrix[:, [0, 1], [0, 1]] *= image_size / 2 135 | 136 | return R, T, opencv_cam_matrix 137 | 138 | 139 | def get_ray_directions( 140 | H: int, 141 | W: int, 142 | focal: float, 143 | principal: Optional[Tuple[float, float]] = None, 144 | use_pixel_centers: bool = True, 145 | ) -> torch.Tensor: 146 | """ 147 | Get ray directions for all pixels in camera coordinate. 
148 | Args: 149 | H, W, focal, principal, use_pixel_centers: image height, width, focal length, principal point and whether use pixel centers 150 | Outputs: 151 | directions: (H, W, 3), the direction of the rays in camera coordinate 152 | """ 153 | pixel_center = 0.5 if use_pixel_centers else 0 154 | cx, cy = W / 2, H / 2 if principal is None else principal 155 | i, j = torch.meshgrid( 156 | torch.arange(W, dtype=torch.float32) + pixel_center, 157 | torch.arange(H, dtype=torch.float32) + pixel_center, 158 | indexing="xy", 159 | ) 160 | directions = torch.stack( 161 | [(i - cx) / focal, -(j - cy) / focal, -torch.ones_like(i)], -1 162 | ) 163 | return F.normalize(directions, dim=-1) 164 | 165 | 166 | def get_rays( 167 | directions: torch.Tensor, c2w: torch.Tensor 168 | ) -> Tuple[torch.Tensor, torch.Tensor]: 169 | """ 170 | Get ray origins and directions from camera coordinates to world coordinates 171 | Args: 172 | directions: (H, W, 3) ray directions in camera coordinates 173 | c2w: (4, 4) camera-to-world transformation matrix 174 | Outputs: 175 | rays_o, rays_d: (H, W, 3) ray origins and directions in world coordinates 176 | """ 177 | # Rotate ray directions from camera coordinate to the world coordinate 178 | rays_d = directions @ c2w[:3, :3].T 179 | rays_o = c2w[:3, 3].expand(rays_d.shape) 180 | return rays_o, rays_d 181 | 182 | 183 | def compute_plucker_embed( 184 | c2w: torch.Tensor, image_width: int, image_height: int, focal: float 185 | ) -> torch.Tensor: 186 | """ 187 | Computes Plucker coordinates for a camera. 188 | Args: 189 | c2w: (4, 4) camera-to-world transformation matrix 190 | image_width: Image width 191 | image_height: Image height 192 | focal: Focal length of the camera 193 | Returns: 194 | plucker: (6, H, W) Plucker embedding 195 | """ 196 | directions = get_ray_directions(image_height, image_width, focal) 197 | rays_o, rays_d = get_rays(directions, c2w) 198 | # Cross product to get Plucker coordinates 199 | cross = torch.cross(rays_o, rays_d, dim=-1) 200 | plucker = torch.cat((rays_d, cross), dim=-1) 201 | return plucker.permute(2, 0, 1) 202 | 203 | 204 | def get_plucker_embeds_from_cameras( 205 | c2w: List[torch.Tensor], fov: List[float], image_size: int 206 | ) -> torch.Tensor: 207 | """ 208 | Given lists of camera transformations and fov, returns the batched plucker embeddings. 209 | Args: 210 | c2w: list of camera-to-world transformation matrices 211 | fov: list of field of view values 212 | image_size: size of the image 213 | Returns: 214 | plucker_embeds: (B, 6, H, W) batched plucker embeddings 215 | """ 216 | plucker_embeds = [] 217 | for cam_matrix, cam_fov in zip(c2w, fov): 218 | focal = 0.5 * image_size / np.tan(0.5 * cam_fov) 219 | plucker = compute_plucker_embed(cam_matrix, image_size, image_size, focal) 220 | plucker_embeds.append(plucker) 221 | return torch.stack(plucker_embeds) 222 | 223 | 224 | def get_plucker_embeds_from_cameras_ortho( 225 | c2w: List[torch.Tensor], ortho_scale: List[float], image_size: int 226 | ): 227 | """ 228 | Given lists of camera transformations and fov, returns the batched plucker embeddings. 
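For the orthographic case each view's embedding is built from its view direction and normalized camera position (rather than per-pixel rays) and broadcast over an image_size x image_size grid.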
229 | 230 | Parameters: 231 | c2w: list of camera-to-world transformation matrices 232 | fov: list of field of view values 233 | image_size: size of the image 234 | 235 | Returns: 236 | plucker_embeds: plucker embeddings (B, 6, H, W) 237 | """ 238 | plucker_embeds = [] 239 | # compute pairwise mask and plucker embeddings 240 | for cam_matrix, scale in zip(c2w, ortho_scale): 241 | # blender to opencv to pytorch3d 242 | R, T = get_opencv_from_blender(cam_matrix) 243 | cam_pos = -R.T @ T 244 | view_dir = R.T @ torch.tensor([0, 0, 1]).float().to(cam_matrix.device) 245 | # normalize camera position 246 | cam_pos = F.normalize(cam_pos, dim=0) 247 | plucker = torch.concat([view_dir, cam_pos]) 248 | plucker = plucker.unsqueeze(-1).unsqueeze(-1).repeat(1, image_size, image_size) 249 | plucker_embeds.append(plucker) 250 | 251 | plucker_embeds = torch.stack(plucker_embeds) 252 | 253 | return plucker_embeds 254 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .camera import ( 2 | Camera, 3 | get_c2w, 4 | get_camera, 5 | get_orthogonal_camera, 6 | get_orthogonal_projection_matrix, 7 | get_projection_matrix, 8 | ) 9 | from .mesh import TexturedMesh, load_mesh, replace_mesh_texture_and_save 10 | from .projection import CameraProjection, CameraProjectionOutput 11 | from .render import ( 12 | DepthControlNetNormalization, 13 | DepthNormalizationStrategy, 14 | NVDiffRastContextWrapper, 15 | RenderOutput, 16 | SimpleNormalization, 17 | Zero123PlusPlusNormalization, 18 | render, 19 | ) 20 | from .smart_paint import SmartPainter 21 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/blend.py: -------------------------------------------------------------------------------- 1 | import time 2 | from abc import ABC, abstractmethod 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | import triton 7 | import triton.language as tl 8 | from torch.utils.cpp_extension import load_inline 9 | 10 | from .utils import SINGLE_IMAGE_TYPE, image_to_tensor 11 | 12 | pb_solver_cpp_source = """ 13 | #include 14 | 15 | #include 16 | 17 | // CUDA forward declarations 18 | 19 | void pb_solver_run_cuda( 20 | torch::Tensor A, 21 | torch::Tensor X, 22 | torch::Tensor B, 23 | torch::Tensor Xbuf, 24 | int num_iters 25 | ); 26 | 27 | // C++ interface 28 | 29 | #define CHECK_CUDA(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") 30 | #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") 31 | #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) 32 | 33 | void pb_solver_run( 34 | torch::Tensor A, 35 | torch::Tensor X, 36 | torch::Tensor B, 37 | torch::Tensor Xbuf, 38 | int num_iters 39 | ) { 40 | CHECK_INPUT(A); 41 | CHECK_INPUT(X); 42 | CHECK_INPUT(B); 43 | CHECK_INPUT(Xbuf); 44 | 45 | pb_solver_run_cuda(A, X, B, Xbuf, num_iters); 46 | return; 47 | } 48 | 49 | """ 50 | 51 | pb_solver_cuda_source = """ 52 | #include 53 | 54 | #include 55 | #include 56 | 57 | #include 58 | 59 | 60 | __global__ void pb_solver_run_cuda_kernel( 61 | torch::PackedTensorAccessor32 A, 62 | torch::PackedTensorAccessor32 X, 63 | torch::PackedTensorAccessor32 B, 64 | torch::PackedTensorAccessor32 Xbuf 65 | ) { 66 | const int index = blockIdx.x * blockDim.x + threadIdx.x; 67 | 68 | if (index < A.size(0)) { 69 | Xbuf[index] = (X[A[index][0]] + X[A[index][1]] + X[A[index][2]] + 
X[A[index][3]] + B[index]) / 4.; 70 | } 71 | } 72 | 73 | void pb_solver_run_cuda( 74 | torch::Tensor A, 75 | torch::Tensor X, 76 | torch::Tensor B, 77 | torch::Tensor Xbuf, 78 | int num_iters 79 | ) { 80 | int batch_size = A.size(0); 81 | 82 | const int threads = 1024; 83 | const dim3 blocks((batch_size + threads - 1) / threads); 84 | 85 | auto A_ptr = A.packed_accessor32(); 86 | auto X_ptr = X.packed_accessor32(); 87 | auto B_ptr = B.packed_accessor32(); 88 | auto Xbuf_ptr = Xbuf.packed_accessor32(); 89 | 90 | for (int i = 0; i < num_iters; ++i) { 91 | pb_solver_run_cuda_kernel<<>>( 92 | A_ptr, 93 | X_ptr, 94 | B_ptr, 95 | Xbuf_ptr 96 | ); 97 | cudaDeviceSynchronize(); 98 | std::swap(X_ptr, Xbuf_ptr); 99 | } 100 | // we may waste an iteration here, but it's fine 101 | return; 102 | } 103 | """ 104 | 105 | 106 | class PBBackend(ABC): 107 | def __init__(self, *args, **kwargs) -> None: 108 | pass 109 | 110 | def solve(self, num_iters, A, X, B, Xbuf) -> None: 111 | pass 112 | 113 | 114 | try: 115 | 116 | @triton.jit 117 | def pb_triton_step_kernel( 118 | A_ptr, 119 | X_ptr, 120 | B_ptr, 121 | Xbuf_ptr, 122 | A_row_stride: tl.constexpr, 123 | n_elements: tl.constexpr, 124 | BLOCK_SIZE: tl.constexpr, 125 | ): 126 | pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0. 127 | block_start = pid * BLOCK_SIZE 128 | offsets = tl.arange(0, BLOCK_SIZE) + block_start 129 | 130 | A_start_ptr = A_ptr + block_start * A_row_stride 131 | A_ptrs = A_start_ptr + ( 132 | tl.arange(0, BLOCK_SIZE)[:, None] * A_row_stride 133 | + tl.arange(0, A_row_stride)[None, :] 134 | ) 135 | B_ptrs = B_ptr + offsets 136 | Xbuf_ptrs = Xbuf_ptr + offsets 137 | 138 | mask = offsets < n_elements 139 | 140 | A = tl.load(A_ptrs, mask=mask[:, None], other=0) 141 | X = tl.load(X_ptr + A) 142 | B = tl.load(B_ptrs, mask=mask, other=0.0) 143 | 144 | Xout = (tl.sum(X, axis=1) + B) / 4 145 | tl.store(Xbuf_ptrs, Xout, mask=mask) 146 | 147 | except: 148 | pb_triton_step_kernel = None 149 | 150 | 151 | class PBTorchCUDAKernelBackend(PBBackend): 152 | def __init__(self) -> None: 153 | self.kernel = load_inline( 154 | name="pb_solver", 155 | cpp_sources=[pb_solver_cpp_source], 156 | cuda_sources=[pb_solver_cuda_source], 157 | functions=["pb_solver_run"], 158 | verbose=True, 159 | ) 160 | 161 | def solve(self, num_iters, A, X, B, Xbuf) -> None: 162 | self.kernel.pb_solver_run(A, X, B, Xbuf, num_iters) 163 | 164 | 165 | class PBTritonBackend(PBBackend): 166 | def step( 167 | self, X: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Xbuf: torch.Tensor 168 | ) -> None: 169 | n_elements = X.shape[0] 170 | grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) 171 | assert pb_triton_step_kernel is not None 172 | pb_triton_step_kernel[grid](A, X, B, Xbuf, A.stride(0), n_elements, 1024) 173 | 174 | def solve(self, num_iters, A, X, B, Xbuf) -> None: 175 | for _ in range(num_iters): 176 | self.step(X, A, B, Xbuf) 177 | X, Xbuf = Xbuf, X 178 | 179 | 180 | class PBTorchNativeBackend(PBBackend): 181 | def step( 182 | self, X: torch.Tensor, A: torch.Tensor, B: torch.Tensor, Xbuf: torch.Tensor 183 | ) -> None: 184 | X.copy_(X[A].sum(-1).add_(B).div_(4)) 185 | 186 | def solve(self, num_iters, A, X, B, Xbuf) -> None: 187 | for _ in range(num_iters): 188 | self.step(X, A, B, Xbuf) 189 | 190 | 191 | class PoissonBlendingSolver: 192 | def __init__(self, backend: str, device: str): 193 | self.backend = backend 194 | if backend == "torch-native": 195 | self.pb_solver = PBTorchNativeBackend() 196 | elif backend == "torch-cuda": 197 | 
self.pb_solver = PBTorchCUDAKernelBackend() 198 | elif backend == "triton": 199 | self.pb_solver = PBTritonBackend() 200 | else: 201 | raise ValueError(f"Unknown backend: {backend}") 202 | 203 | self.device = device 204 | self.lap_kernel = torch.tensor( 205 | [[0, -1, 0], [-1, 4, -1], [0, -1, 0]], device=device, dtype=torch.float32 206 | ).view(1, 1, 3, 3) 207 | self.lap_kernel4 = torch.tensor( 208 | [ 209 | [[0, -1, 0], [0, 1, 0], [0, 0, 0]], 210 | [[0, 0, 0], [0, 1, 0], [0, -1, 0]], 211 | [[0, 0, 0], [-1, 1, 0], [0, 0, 0]], 212 | [[0, 0, 0], [0, 1, -1], [0, 0, 0]], 213 | ], 214 | device=device, 215 | dtype=torch.float32, 216 | ).view(4, 1, 3, 3) 217 | self.neighbor_kernel = torch.tensor( 218 | [[0, 1, 0], [1, 0, 1], [0, 1, 0]], device=device, dtype=torch.float32 219 | ).view(1, 1, 3, 3) 220 | 221 | def __call__( 222 | self, 223 | src: SINGLE_IMAGE_TYPE, 224 | mask: SINGLE_IMAGE_TYPE, 225 | tgt: SINGLE_IMAGE_TYPE, 226 | num_iters: int, 227 | inplace: bool = True, 228 | grad_mode: str = "src", 229 | ): 230 | src = image_to_tensor(src, device=self.device) 231 | mask = image_to_tensor(mask, device=self.device) 232 | tgt = image_to_tensor(tgt, device=self.device) 233 | 234 | assert src.ndim == 3 and tgt.ndim == 3 and mask.ndim in [2, 3] 235 | 236 | if mask.ndim == 3: 237 | mask = mask.mean(-1) > 0.5 238 | else: 239 | mask = mask > 0.5 240 | mask[0, :] = 0 241 | mask[-1, :] = 0 242 | mask[:, 0] = 0 243 | mask[:, -1] = 0 244 | 245 | tgt_masked = torch.where(mask[..., None], torch.zeros_like(tgt), tgt) 246 | 247 | x, y = mask.nonzero(as_tuple=True) 248 | N = x.shape[0] 249 | index_map = torch.cumsum(mask.reshape(-1).long(), dim=-1).reshape(mask.shape) 250 | index_map[~mask] = 0 251 | 252 | if grad_mode == "src": 253 | src_lap = F.conv2d( 254 | src.permute(2, 0, 1)[:, None], 255 | weight=self.lap_kernel, 256 | padding=1, 257 | )[:, 0].permute(1, 2, 0) 258 | lap = src_lap 259 | elif grad_mode == "max": 260 | src_lap = F.conv2d( 261 | src.permute(2, 0, 1)[:, None], 262 | weight=self.lap_kernel4, 263 | padding=1, 264 | ) 265 | tgt_lap = F.conv2d( 266 | tgt.permute(2, 0, 1)[:, None], 267 | weight=self.lap_kernel4, 268 | padding=1, 269 | ) 270 | lap = ( 271 | torch.where(src_lap.abs() > tgt_lap.abs(), src_lap, tgt_lap) 272 | .sum(1) 273 | .permute(1, 2, 0) 274 | ) 275 | elif grad_mode == "avg": 276 | src_lap = F.conv2d( 277 | src.permute(2, 0, 1)[:, None], 278 | weight=self.lap_kernel4, 279 | padding=1, 280 | ) 281 | tgt_lap = F.conv2d( 282 | tgt.permute(2, 0, 1)[:, None], 283 | weight=self.lap_kernel4, 284 | padding=1, 285 | ) 286 | lap = (src_lap + tgt_lap).mul(0.5).sum(1).permute(1, 2, 0) 287 | 288 | fq_star = F.conv2d( 289 | tgt_masked.permute(2, 0, 1)[:, None], 290 | weight=self.neighbor_kernel, 291 | padding=1, 292 | )[:, 0].permute(1, 2, 0) 293 | 294 | A = torch.zeros(N + 1, 4, device=self.device, dtype=torch.long) 295 | X = torch.zeros(N + 1, 3, device=self.device, dtype=torch.float32) 296 | B = torch.zeros(N + 1, 3, device=self.device, dtype=torch.float32) 297 | 298 | A[1:] = torch.stack( 299 | [ 300 | index_map[x - 1, y], 301 | index_map[x + 1, y], 302 | index_map[x, y - 1], 303 | index_map[x, y + 1], 304 | ], 305 | dim=-1, 306 | ) 307 | X[1:] = tgt[x, y] 308 | B[1:] = lap[x, y] + fq_star[x, y] 309 | 310 | X_flatten = X.flatten() 311 | B_flatten = B.flatten() 312 | A_flatten = torch.stack([3 * A, 3 * A + 1, 3 * A + 2], dim=1).reshape(-1, 4) 313 | 314 | buffer = torch.zeros_like(X_flatten) 315 | 316 | self.pb_solver.solve(num_iters, A_flatten, X_flatten, B_flatten, buffer) 317 | 318 | if 
inplace: 319 | tgt[x, y] = X_flatten.view(-1, 3)[1:].clamp(0.0, 1.0) 320 | else: 321 | tgt = tgt.clone() 322 | tgt[x, y] = X_flatten.view(-1, 3)[1:].clamp(0.0, 1.0) 323 | 324 | return tgt 325 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/camera.py: -------------------------------------------------------------------------------- 1 | import math 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | import torch.nn.functional as F 8 | import trimesh 9 | from PIL import Image 10 | from torch import BoolTensor, FloatTensor 11 | 12 | from .utils import LIST_TYPE 13 | 14 | 15 | def list_to_pt( 16 | x: LIST_TYPE, dtype: Optional[torch.dtype] = None, device: Optional[str] = None 17 | ) -> torch.Tensor: 18 | if isinstance(x, list) or isinstance(x, np.ndarray): 19 | return torch.tensor(x, dtype=dtype, device=device) 20 | return x.to(dtype=dtype) 21 | 22 | 23 | def get_c2w( 24 | elevation_deg: LIST_TYPE, 25 | distance: LIST_TYPE, 26 | azimuth_deg: Optional[LIST_TYPE], 27 | num_views: Optional[int] = 1, 28 | device: Optional[str] = None, 29 | ) -> torch.FloatTensor: 30 | if azimuth_deg is None: 31 | assert ( 32 | num_views is not None 33 | ), "num_views must be provided if azimuth_deg is None." 34 | azimuth_deg = torch.linspace( 35 | 0, 360, num_views + 1, dtype=torch.float32, device=device 36 | )[:-1] 37 | else: 38 | num_views = len(azimuth_deg) 39 | azimuth_deg = list_to_pt(azimuth_deg, dtype=torch.float32, device=device) 40 | elevation_deg = list_to_pt(elevation_deg, dtype=torch.float32, device=device) 41 | camera_distances = list_to_pt(distance, dtype=torch.float32, device=device) 42 | elevation = elevation_deg * math.pi / 180 43 | azimuth = azimuth_deg * math.pi / 180 44 | camera_positions = torch.stack( 45 | [ 46 | camera_distances * torch.cos(elevation) * torch.cos(azimuth), 47 | camera_distances * torch.cos(elevation) * torch.sin(azimuth), 48 | camera_distances * torch.sin(elevation), 49 | ], 50 | dim=-1, 51 | ) 52 | center = torch.zeros_like(camera_positions) 53 | up = torch.tensor([0, 0, 1], dtype=torch.float32, device=device)[None, :].repeat( 54 | num_views, 1 55 | ) 56 | lookat = F.normalize(center - camera_positions, dim=-1) 57 | right = F.normalize(torch.cross(lookat, up, dim=-1), dim=-1) 58 | up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1) 59 | c2w3x4 = torch.cat( 60 | [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]], 61 | dim=-1, 62 | ) 63 | c2w = torch.cat([c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1) 64 | c2w[:, 3, 3] = 1.0 65 | return c2w 66 | 67 | 68 | def get_projection_matrix( 69 | fovy_deg: LIST_TYPE, 70 | aspect_wh: float = 1.0, 71 | near: float = 0.1, 72 | far: float = 100.0, 73 | device: Optional[str] = None, 74 | ) -> torch.FloatTensor: 75 | fovy_deg = list_to_pt(fovy_deg, dtype=torch.float32, device=device) 76 | batch_size = fovy_deg.shape[0] 77 | fovy = fovy_deg * math.pi / 180 78 | tan_half_fovy = torch.tan(fovy / 2) 79 | projection_matrix = torch.zeros( 80 | batch_size, 4, 4, dtype=torch.float32, device=device 81 | ) 82 | projection_matrix[:, 0, 0] = 1 / (aspect_wh * tan_half_fovy) 83 | projection_matrix[:, 1, 1] = -1 / tan_half_fovy 84 | projection_matrix[:, 2, 2] = -(far + near) / (far - near) 85 | projection_matrix[:, 2, 3] = -2 * far * near / (far - near) 86 | projection_matrix[:, 3, 2] = -1 87 | return projection_matrix 88 | 89 | 90 | def get_orthogonal_projection_matrix( 91 | 
batch_size: int, 92 | left: float, 93 | right: float, 94 | bottom: float, 95 | top: float, 96 | near: float = 0.1, 97 | far: float = 100.0, 98 | device: Optional[str] = None, 99 | ) -> torch.FloatTensor: 100 | projection_matrix = torch.zeros( 101 | batch_size, 4, 4, dtype=torch.float32, device=device 102 | ) 103 | projection_matrix[:, 0, 0] = 2 / (right - left) 104 | projection_matrix[:, 1, 1] = -2 / (top - bottom) 105 | projection_matrix[:, 2, 2] = -2 / (far - near) 106 | projection_matrix[:, 0, 3] = -(right + left) / (right - left) 107 | projection_matrix[:, 1, 3] = -(top + bottom) / (top - bottom) 108 | projection_matrix[:, 2, 3] = -(far + near) / (far - near) 109 | projection_matrix[:, 3, 3] = 1 110 | return projection_matrix 111 | 112 | 113 | @dataclass 114 | class Camera: 115 | c2w: Optional[torch.FloatTensor] 116 | w2c: torch.FloatTensor 117 | proj_mtx: torch.FloatTensor 118 | mvp_mtx: torch.FloatTensor 119 | cam_pos: Optional[torch.FloatTensor] 120 | 121 | def __getitem__(self, index): 122 | if isinstance(index, int): 123 | sl = slice(index, index + 1) 124 | elif isinstance(index, slice): 125 | sl = index 126 | elif isinstance(index, list): 127 | sl = index 128 | else: 129 | raise NotImplementedError 130 | 131 | return Camera( 132 | c2w=self.c2w[sl] if self.c2w is not None else None, 133 | w2c=self.w2c[sl], 134 | proj_mtx=self.proj_mtx[sl], 135 | mvp_mtx=self.mvp_mtx[sl], 136 | cam_pos=self.cam_pos[sl] if self.cam_pos is not None else None, 137 | ) 138 | 139 | def to(self, device: Optional[str] = None): 140 | if self.c2w is not None: 141 | self.c2w = self.c2w.to(device) 142 | self.w2c = self.w2c.to(device) 143 | self.proj_mtx = self.proj_mtx.to(device) 144 | self.mvp_mtx = self.mvp_mtx.to(device) 145 | if self.cam_pos is not None: 146 | self.cam_pos = self.cam_pos.to(device) 147 | 148 | def __len__(self): 149 | return self.c2w.shape[0] 150 | 151 | 152 | def get_camera( 153 | elevation_deg: Optional[LIST_TYPE] = None, 154 | distance: Optional[LIST_TYPE] = None, 155 | fovy_deg: Optional[LIST_TYPE] = None, 156 | azimuth_deg: Optional[LIST_TYPE] = None, 157 | num_views: Optional[int] = 1, 158 | c2w: Optional[torch.FloatTensor] = None, 159 | w2c: Optional[torch.FloatTensor] = None, 160 | proj_mtx: Optional[torch.FloatTensor] = None, 161 | aspect_wh: float = 1.0, 162 | near: float = 0.1, 163 | far: float = 100.0, 164 | perturb_camera_position: Optional[float] = None, 165 | device: Optional[str] = None, 166 | ): 167 | if w2c is None: 168 | if c2w is None: 169 | c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device) 170 | if perturb_camera_position is not None: 171 | perturbed_pos = ( 172 | c2w[:, :3, 3] 173 | + torch.randn_like(c2w[:, :3, 3]) * perturb_camera_position 174 | ) 175 | perturbed_pos = ( 176 | F.normalize(perturbed_pos, dim=-1) 177 | * ((c2w[:, :3, 3] ** 2).sum(-1) ** 0.5)[:, None] 178 | ) 179 | camera_positions = c2w[:, :3, 3] 180 | w2c = torch.linalg.inv(c2w) 181 | else: 182 | camera_positions = None 183 | c2w = None 184 | if proj_mtx is None: 185 | proj_mtx = get_projection_matrix( 186 | fovy_deg, aspect_wh=aspect_wh, near=near, far=far, device=device 187 | ) 188 | mvp_mtx = proj_mtx @ w2c 189 | return Camera( 190 | c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions 191 | ) 192 | 193 | 194 | def get_orthogonal_camera( 195 | elevation_deg: LIST_TYPE, 196 | distance: LIST_TYPE, 197 | left: float, 198 | right: float, 199 | bottom: float, 200 | top: float, 201 | azimuth_deg: Optional[LIST_TYPE] = None, 202 | num_views: Optional[int] = 
1, 203 | near: float = 0.1, 204 | far: float = 100.0, 205 | device: Optional[str] = None, 206 | ): 207 | c2w = get_c2w(elevation_deg, distance, azimuth_deg, num_views, device) 208 | camera_positions = c2w[:, :3, 3] 209 | w2c = torch.linalg.inv(c2w) 210 | proj_mtx = get_orthogonal_projection_matrix( 211 | batch_size=c2w.shape[0], 212 | left=left, 213 | right=right, 214 | bottom=bottom, 215 | top=top, 216 | near=near, 217 | far=far, 218 | device=device, 219 | ) 220 | mvp_mtx = proj_mtx @ w2c 221 | return Camera( 222 | c2w=c2w, w2c=w2c, proj_mtx=proj_mtx, mvp_mtx=mvp_mtx, cam_pos=camera_positions 223 | ) 224 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/cv_ops.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import cvcuda 4 | import torch 5 | 6 | torch_to_cvc = lambda x, layout: cvcuda.as_tensor(x, layout) 7 | 8 | cvc_to_torch = lambda x, device: torch.tensor(x.cuda(), device=device) 9 | 10 | 11 | def inpaint_cvc( 12 | image: torch.Tensor, 13 | mask: torch.Tensor, 14 | padding_size: int, 15 | return_dtype: Optional[torch.dtype] = None, 16 | ): 17 | input_dtype = image.dtype 18 | input_device = image.device 19 | 20 | image = image.detach() 21 | mask = mask.detach() 22 | 23 | if image.dtype != torch.uint8: 24 | image = (image * 255).to(torch.uint8) 25 | if mask.dtype != torch.uint8: 26 | mask = (mask * 255).to(torch.uint8) 27 | 28 | image_cvc = torch_to_cvc(image, "HWC") 29 | mask_cvc = torch_to_cvc(mask, "HW") 30 | output_cvc = cvcuda.inpaint(image_cvc, mask_cvc, padding_size) 31 | output = cvc_to_torch(output_cvc, device=input_device) 32 | 33 | if return_dtype == torch.uint8 or input_dtype == torch.uint8: 34 | return output 35 | return output.to(dtype=input_dtype) / 255.0 36 | 37 | 38 | def batch_inpaint_cvc( 39 | images: torch.Tensor, 40 | masks: torch.Tensor, 41 | padding_size: int, 42 | return_dtype: Optional[torch.dtype] = None, 43 | ): 44 | output = torch.stack( 45 | [ 46 | inpaint_cvc(image, mask, padding_size, return_dtype) 47 | for (image, mask) in zip(images, masks) 48 | ], 49 | axis=0, 50 | ) 51 | return output 52 | 53 | 54 | def batch_erode( 55 | masks: torch.Tensor, kernel_size: int, return_dtype: Optional[torch.dtype] = None 56 | ): 57 | input_dtype = masks.dtype 58 | input_device = masks.device 59 | masks = masks.detach() 60 | if masks.dtype != torch.uint8: 61 | masks = (masks.float() * 255).to(torch.uint8) 62 | masks_cvc = torch_to_cvc(masks[..., None], "NHWC") 63 | masks_erode_cvc = cvcuda.morphology( 64 | masks_cvc, 65 | cvcuda.MorphologyType.ERODE, 66 | maskSize=(kernel_size, kernel_size), 67 | anchor=(-1, -1), 68 | ) 69 | masks_erode = cvc_to_torch(masks_erode_cvc, device=input_device)[..., 0] 70 | if return_dtype == torch.uint8 or input_dtype == torch.uint8: 71 | return masks_erode 72 | return (masks_erode > 0).to(dtype=input_dtype) 73 | 74 | 75 | def batch_dilate( 76 | masks: torch.Tensor, kernel_size: int, return_dtype: Optional[torch.dtype] = None 77 | ): 78 | input_dtype = masks.dtype 79 | input_device = masks.device 80 | masks = masks.detach() 81 | if masks.dtype != torch.uint8: 82 | masks = (masks.float() * 255).to(torch.uint8) 83 | masks_cvc = torch_to_cvc(masks[..., None], "NHWC") 84 | masks_dilate_cvc = cvcuda.morphology( 85 | masks_cvc, 86 | cvcuda.MorphologyType.DILATE, 87 | maskSize=(kernel_size, kernel_size), 88 | anchor=(-1, -1), 89 | ) 90 | masks_dilate = cvc_to_torch(masks_dilate_cvc, device=input_device)[..., 0] 91 | 
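# Note: cvcuda.morphology operates on uint8 NHWC tensors, which is why float or
# bool masks are scaled to [0, 255] above; the result is returned as uint8 when
# requested, or thresholded back to a {0, 1} mask in the caller's original dtype.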
if return_dtype == torch.uint8 or input_dtype == torch.uint8: 92 | return masks_dilate 93 | return (masks_dilate > 0).to(dtype=input_dtype) 94 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/projection.py: -------------------------------------------------------------------------------- 1 | import glob 2 | from dataclasses import dataclass 3 | from typing import List, Optional, Union 4 | 5 | import numpy as np 6 | import torch 7 | from einops import rearrange 8 | from PIL import Image 9 | 10 | from .blend import PoissonBlendingSolver 11 | from .camera import Camera, get_camera, get_orthogonal_camera 12 | from .mesh import TexturedMesh, load_mesh, replace_mesh_texture_and_save 13 | from .render import NVDiffRastContextWrapper, SimpleNormalization, render 14 | from .seg import SegmentationModel 15 | from .utils import ( 16 | IMAGE_TYPE, 17 | LIST_TYPE, 18 | image_to_tensor, 19 | make_image_grid, 20 | tensor_to_image, 21 | ) 22 | from .uv import ( 23 | ExponentialBlend, 24 | SimpleUVValidityStrategy, 25 | uv_blend, 26 | uv_precompute, 27 | uv_render_attr, 28 | uv_render_geometry, 29 | ) 30 | from .warp import compute_warp_field 31 | 32 | 33 | @dataclass 34 | class CameraProjectionOutput: 35 | uv_proj: torch.FloatTensor 36 | uv_proj_mask: torch.BoolTensor 37 | uv_depth_grad: Optional[torch.FloatTensor] 38 | uv_aoi_cos: Optional[torch.FloatTensor] 39 | 40 | 41 | class CameraProjection: 42 | def __init__( 43 | self, 44 | pb_backend: str, 45 | bg_remover: Optional[SegmentationModel], 46 | device: str, 47 | context_type: str = "gl", 48 | ) -> None: 49 | self.pb_solver = PoissonBlendingSolver(pb_backend, device) 50 | self.ctx = NVDiffRastContextWrapper(device, context_type) 51 | self.bg_remover = bg_remover 52 | self.device = device 53 | 54 | def __call__( 55 | self, 56 | images: IMAGE_TYPE, 57 | mesh: TexturedMesh, 58 | cam: Optional[Camera] = None, 59 | fovy_deg: Optional[LIST_TYPE] = None, 60 | masks: Optional[IMAGE_TYPE] = None, 61 | remove_bg: bool = False, 62 | c2w: Optional[torch.FloatTensor] = None, 63 | elevation_deg: Optional[LIST_TYPE] = None, 64 | distance: Optional[LIST_TYPE] = None, 65 | azimuth_deg: Optional[LIST_TYPE] = None, 66 | num_views: Optional[int] = None, 67 | uv_size: int = 2048, 68 | warp_images: bool = False, 69 | images_background: Optional[float] = None, 70 | iou_rejection_threshold: Optional[float] = 0.8, 71 | aoi_cos_valid_threshold: float = 0.3, 72 | depth_grad_dilation: int = 5, 73 | depth_grad_threshold: float = 0.1, 74 | uv_exp_blend_alpha: float = 6, 75 | uv_exp_blend_view_weight: Optional[torch.FloatTensor] = None, 76 | poisson_blending: bool = True, 77 | pb_num_iters: int = 1000, 78 | pb_keep_original_border: bool = True, 79 | from_scratch: bool = False, 80 | uv_padding: bool = True, 81 | return_uv_projection_mask: bool = False, 82 | return_dict: bool = False, 83 | ) -> Optional[torch.FloatTensor]: 84 | images_pt = image_to_tensor(images, device=self.device) 85 | assert images_pt.ndim == 4 86 | Nv, H, W, _ = images_pt.shape 87 | 88 | if masks is not None: 89 | masks_pt = image_to_tensor(masks, device=self.device) 90 | else: 91 | if remove_bg: 92 | assert self.bg_remover is not None 93 | masks_pt = self.bg_remover(images_pt) 94 | else: 95 | masks_pt = None 96 | 97 | if masks_pt is not None and masks_pt.ndim == 4: 98 | masks_pt = masks_pt.mean(-1) 99 | 100 | if cam is None: 101 | cam = get_camera( 102 | elevation_deg, 103 | distance, 104 | fovy_deg, 105 | azimuth_deg, 106 | num_views, 107 | c2w, 108 | 
aspect_wh=W / H, 109 | device=self.device, 110 | ) 111 | uv_precompute_output = uv_precompute( 112 | self.ctx, mesh, height=uv_size, width=uv_size 113 | ) 114 | uv_render_geometry_output = uv_render_geometry( 115 | self.ctx, 116 | mesh, 117 | cam, 118 | view_height=H, 119 | view_width=W, 120 | uv_precompute_output=uv_precompute_output, 121 | compute_depth_grad=True, 122 | depth_grad_dilation=depth_grad_dilation, 123 | ) 124 | 125 | # IoU rejection 126 | if masks_pt is not None and iou_rejection_threshold is not None: 127 | given_masks = (masks_pt > 0.5).float() 128 | render_masks = uv_render_geometry_output.view_mask.float() 129 | intersection = given_masks * render_masks 130 | union = given_masks + render_masks - intersection 131 | iou = intersection.sum((1, 2)) / union.sum((1, 2)) 132 | iou_min = iou.min() 133 | print(f"Debug: Per view IoU: {iou.tolist()}") 134 | if iou_min < iou_rejection_threshold: 135 | print( 136 | f"Warning: Minimum view IoU {iou_min} below threshold {iou_rejection_threshold}, skipping camera projection!" 137 | ) 138 | return None 139 | 140 | if warp_images: 141 | # TODO: clean code 142 | assert images_background is not None 143 | render_attr = render( 144 | self.ctx, 145 | mesh, 146 | cam, 147 | height=H, 148 | width=W, 149 | render_attr=True, 150 | attr_background=images_background, 151 | ).attr 152 | images_pt = compute_warp_field( 153 | self.ctx.ctx, 154 | images_pt, 155 | render_attr, 156 | n_grid=10, 157 | optim_res=[64, 128], 158 | optim_step_per_res=20, 159 | lambda_reg=2.0, 160 | temp_dir="debug_warp", 161 | verbose=False, 162 | device=self.device, 163 | ) 164 | 165 | uv_render_attr_output = uv_render_attr( 166 | images=images_pt, 167 | masks=masks_pt, 168 | uv_render_geometry_output=uv_render_geometry_output, 169 | ) 170 | uv_blend_output = uv_blend( 171 | uv_precompute_output, 172 | uv_render_geometry_output, 173 | uv_render_attr_output, 174 | uv_validity_strategy=SimpleUVValidityStrategy( 175 | aoi_cos_thresh=aoi_cos_valid_threshold, 176 | depth_grad_thresh=depth_grad_threshold, 177 | ), 178 | uv_blend_weight_strategy=ExponentialBlend( 179 | alpha=uv_exp_blend_alpha, view_weight=uv_exp_blend_view_weight 180 | ), 181 | empty_value=1.0, 182 | do_uv_padding=uv_padding, 183 | pad_unseen_area=from_scratch, 184 | poisson_blending=poisson_blending, 185 | pb_solver=self.pb_solver, 186 | pb_num_iters=pb_num_iters, 187 | pb_keep_original_border=pb_keep_original_border, 188 | ) 189 | 190 | if return_dict: 191 | # recommonded new way to get return value 192 | return CameraProjectionOutput( 193 | uv_proj=uv_blend_output.uv_attr_blend, 194 | uv_proj_mask=uv_blend_output.uv_valid_mask_blend, 195 | uv_depth_grad=uv_render_geometry_output.uv_depth_grad, 196 | uv_aoi_cos=uv_render_geometry_output.uv_aoi_cos, 197 | ) 198 | else: 199 | if return_uv_projection_mask: 200 | return ( 201 | uv_blend_output.uv_attr_blend, 202 | uv_blend_output.uv_valid_mask_blend, 203 | ) 204 | return uv_blend_output.uv_attr_blend 205 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/seg.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | import torch 4 | import transformers 5 | 6 | from .utils import IMAGE_TYPE, image_to_tensor 7 | 8 | 9 | class SegmentationModel(ABC): 10 | def __init__(self, *args, **kwargs): 11 | pass 12 | 13 | def __call__(self, images: IMAGE_TYPE) -> torch.FloatTensor: 14 | pass 15 | 16 | 17 | class RMBGModel(SegmentationModel): 18 | 
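# Background remover backed by transformers' AutoModelForImageSegmentation
# (loaded with trust_remote_code, so any compatible segmentation checkpoint
# should work). __call__ accepts float images in [0, 1] with (H, W, C) or
# (N, H, W, C) layout and returns a soft foreground mask clamped to [0, 1].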
def __init__(self, pretrained_model_name_or_path: str, device: str): 19 | self.model = transformers.AutoModelForImageSegmentation.from_pretrained( 20 | pretrained_model_name_or_path, trust_remote_code=True 21 | ).to(device) 22 | self.device = device 23 | 24 | def __call__(self, images: IMAGE_TYPE) -> torch.FloatTensor: 25 | images = image_to_tensor(images, device=self.device) 26 | batched = True 27 | if images.ndim == 3: 28 | images = images.unsqueeze(0) 29 | batched = False 30 | 31 | out = ( 32 | self.model(images.permute(0, 3, 1, 2) - 0.5)[0][0] 33 | .clamp(0.0, 1.0) 34 | .permute(0, 2, 3, 1) 35 | ) 36 | if not batched: 37 | out = out.squeeze(0) 38 | return out 39 | -------------------------------------------------------------------------------- /mvadapter/utils/mesh_utils/utils.py: -------------------------------------------------------------------------------- 1 | import io 2 | import math 3 | from abc import ABC, abstractmethod 4 | from contextlib import contextmanager 5 | from dataclasses import dataclass 6 | from datetime import datetime 7 | from typing import List, Optional, Union 8 | 9 | import numpy as np 10 | import nvdiffrast.torch as dr 11 | import torch 12 | import torch.nn.functional as F 13 | import trimesh 14 | from PIL import Image 15 | from torch import BoolTensor, FloatTensor 16 | 17 | LIST_TYPE = Union[list, np.ndarray, torch.Tensor] 18 | IMAGE_TYPE = Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor] 19 | SINGLE_IMAGE_TYPE = Union[Image.Image, np.ndarray, torch.Tensor] 20 | 21 | 22 | def tensor_to_image( 23 | data: Union[Image.Image, torch.Tensor, np.ndarray], 24 | batched: bool = False, 25 | format: str = "HWC", 26 | ) -> Union[Image.Image, List[Image.Image]]: 27 | if isinstance(data, Image.Image): 28 | return data 29 | if isinstance(data, torch.Tensor): 30 | data = data.detach().cpu().numpy() 31 | if data.dtype == np.float32 or data.dtype == np.float16: 32 | data = (data * 255).astype(np.uint8) 33 | elif data.dtype == np.bool_: 34 | data = data.astype(np.uint8) * 255 35 | assert data.dtype == np.uint8 36 | if format == "CHW": 37 | if batched and data.ndim == 4: 38 | data = data.transpose((0, 2, 3, 1)) 39 | elif not batched and data.ndim == 3: 40 | data = data.transpose((1, 2, 0)) 41 | 42 | if batched: 43 | return [Image.fromarray(d) for d in data] 44 | return Image.fromarray(data) 45 | 46 | 47 | def image_to_tensor(image: IMAGE_TYPE, return_type="pt", device: Optional[str] = None): 48 | assert return_type in ["np", "pt"] 49 | batched = True 50 | if isinstance(image, Image.Image): 51 | batched = False 52 | image = [image] 53 | if isinstance(image, list): 54 | image = np.stack([np.array(img) for img in image], axis=0) 55 | image = image.astype(np.float32) / 255.0 56 | if isinstance(image, np.ndarray) and return_type == "pt": 57 | image = torch.tensor(image, device=device) 58 | if isinstance(image, torch.Tensor): 59 | image = image.to(dtype=torch.float32, device=device) 60 | 61 | if not batched: 62 | image = image[0] 63 | return image 64 | 65 | 66 | def largest_factor_near_sqrt(n: int) -> int: 67 | """ 68 | Finds the largest factor of n that is closest to the square root of n. 69 | 70 | Args: 71 | n (int): The integer for which to find the largest factor near its square root. 72 | 73 | Returns: 74 | int: The largest factor of n that is closest to the square root of n. 
75 | """ 76 | sqrt_n = int(math.sqrt(n)) # Get the integer part of the square root 77 | 78 | # First, check if the square root itself is a factor 79 | if sqrt_n * sqrt_n == n: 80 | return sqrt_n 81 | 82 | # Otherwise, find the largest factor by iterating from sqrt_n downwards 83 | for i in range(sqrt_n, 0, -1): 84 | if n % i == 0: 85 | return i 86 | 87 | # If n is 1, return 1 88 | return 1 89 | 90 | 91 | def make_image_grid( 92 | images: List[Image.Image], 93 | rows: Optional[int] = None, 94 | cols: Optional[int] = None, 95 | resize: Optional[int] = None, 96 | ) -> Image.Image: 97 | """ 98 | Prepares a single grid of images. Useful for visualization purposes. 99 | """ 100 | if rows is None and cols is not None: 101 | assert len(images) % cols == 0 102 | rows = len(images) // cols 103 | elif cols is None and rows is not None: 104 | assert len(images) % rows == 0 105 | cols = len(images) // rows 106 | elif rows is None and cols is None: 107 | rows = largest_factor_near_sqrt(len(images)) 108 | cols = len(images) // rows 109 | 110 | assert len(images) == rows * cols 111 | 112 | if resize is not None: 113 | images = [img.resize((resize, resize)) for img in images] 114 | 115 | w, h = images[0].size 116 | grid = Image.new("RGB", size=(cols * w, rows * h)) 117 | 118 | for i, img in enumerate(images): 119 | grid.paste(img, box=(i % cols * w, i // cols * h)) 120 | return grid 121 | 122 | 123 | def get_current_timestamp(fmt: str = "%Y%m%d%H%M%S") -> str: 124 | return datetime.now().strftime(fmt) 125 | 126 | 127 | def get_clip_space_position(pos: torch.FloatTensor, mvp_mtx: torch.FloatTensor): 128 | pos_homo = torch.cat([pos, torch.ones([pos.shape[0], 1]).to(pos)], dim=-1) 129 | return torch.matmul(pos_homo, mvp_mtx.permute(0, 2, 1)) 130 | 131 | 132 | def transform_points_homo(pos: torch.FloatTensor, mtx: torch.FloatTensor): 133 | batch_size = pos.shape[0] 134 | pos_shape = pos.shape[1:-1] 135 | pos = pos.reshape(batch_size, -1, 3) 136 | pos_homo = torch.cat([pos, torch.ones_like(pos[..., 0:1])], dim=-1) 137 | pos = (pos_homo.unsqueeze(2) * mtx.unsqueeze(1)).sum(-1)[..., :3] 138 | pos = pos.reshape(batch_size, *pos_shape, 3) 139 | return pos 140 | -------------------------------------------------------------------------------- /mvadapter/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import re 4 | import time 5 | from collections import defaultdict 6 | from contextlib import contextmanager 7 | 8 | import psutil 9 | import torch 10 | from packaging import version 11 | 12 | from .config import config_to_primitive 13 | from .core import debug, find, info, warn 14 | from .typing import * 15 | 16 | 17 | def parse_version(ver: str): 18 | return version.parse(ver) 19 | 20 | 21 | def get_rank(): 22 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 23 | # therefore LOCAL_RANK needs to be checked first 24 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 25 | for key in rank_keys: 26 | rank = os.environ.get(key) 27 | if rank is not None: 28 | return int(rank) 29 | return 0 30 | 31 | 32 | def get_device(): 33 | return torch.device(f"cuda:{get_rank()}") 34 | 35 | 36 | def load_module_weights( 37 | path, module_name=None, ignore_modules=None, mapping=None, map_location=None 38 | ) -> Tuple[dict, int, int]: 39 | if module_name is not None and ignore_modules is not None: 40 | raise ValueError("module_name and ignore_modules cannot be both set") 41 | if map_location is None: 42 | 
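# Default to the CUDA device of the current process rank (see get_device above).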
map_location = get_device() 43 | 44 | ckpt = torch.load(path, map_location=map_location) 45 | state_dict = ckpt["state_dict"] 46 | 47 | if mapping is not None: 48 | state_dict_to_load = {} 49 | for k, v in state_dict.items(): 50 | if any([k.startswith(m["to"]) for m in mapping]): 51 | pass 52 | else: 53 | state_dict_to_load[k] = v 54 | for k, v in state_dict.items(): 55 | for m in mapping: 56 | if k.startswith(m["from"]): 57 | k_dest = k.replace(m["from"], m["to"]) 58 | info(f"Mapping {k} => {k_dest}") 59 | state_dict_to_load[k_dest] = v.clone() 60 | state_dict = state_dict_to_load 61 | 62 | state_dict_to_load = state_dict 63 | 64 | if ignore_modules is not None: 65 | state_dict_to_load = {} 66 | for k, v in state_dict.items(): 67 | ignore = any( 68 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 69 | ) 70 | if ignore: 71 | continue 72 | state_dict_to_load[k] = v 73 | 74 | if module_name is not None: 75 | state_dict_to_load = {} 76 | for k, v in state_dict.items(): 77 | m = re.match(rf"^{module_name}\.(.*)$", k) 78 | if m is None: 79 | continue 80 | state_dict_to_load[m.group(1)] = v 81 | 82 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 83 | 84 | 85 | def C(value: Any, epoch: int, global_step: int) -> float: 86 | if isinstance(value, int) or isinstance(value, float): 87 | pass 88 | else: 89 | value = config_to_primitive(value) 90 | if not isinstance(value, list): 91 | raise TypeError("Scalar specification only supports list, got", type(value)) 92 | if len(value) == 3: 93 | value = [0] + value 94 | assert len(value) == 4 95 | start_step, start_value, end_value, end_step = value 96 | if isinstance(end_step, int): 97 | current_step = global_step 98 | value = start_value + (end_value - start_value) * max( 99 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 100 | ) 101 | elif isinstance(end_step, float): 102 | current_step = epoch 103 | value = start_value + (end_value - start_value) * max( 104 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 105 | ) 106 | return value 107 | 108 | 109 | def cleanup(): 110 | gc.collect() 111 | torch.cuda.empty_cache() 112 | try: 113 | import tinycudann as tcnn 114 | 115 | tcnn.free_temporary_memory() 116 | except: 117 | pass 118 | 119 | 120 | def finish_with_cleanup(func: Callable): 121 | def wrapper(*args, **kwargs): 122 | out = func(*args, **kwargs) 123 | cleanup() 124 | return out 125 | 126 | return wrapper 127 | 128 | 129 | def _distributed_available(): 130 | return torch.distributed.is_available() and torch.distributed.is_initialized() 131 | 132 | 133 | def barrier(): 134 | if not _distributed_available(): 135 | return 136 | else: 137 | torch.distributed.barrier() 138 | 139 | 140 | def broadcast(tensor, src=0): 141 | if not _distributed_available(): 142 | return tensor 143 | else: 144 | torch.distributed.broadcast(tensor, src=src) 145 | return tensor 146 | 147 | 148 | def enable_gradient(model, enabled: bool = True) -> None: 149 | for param in model.parameters(): 150 | param.requires_grad_(enabled) 151 | 152 | 153 | class TimeRecorder: 154 | _instance = None 155 | 156 | def __init__(self): 157 | self.items = {} 158 | self.accumulations = defaultdict(list) 159 | self.time_scale = 1000.0 # ms 160 | self.time_unit = "ms" 161 | self.enabled = False 162 | 163 | def __new__(cls): 164 | # singleton 165 | if cls._instance is None: 166 | cls._instance = super(TimeRecorder, cls).__new__(cls) 167 | return cls._instance 168 | 169 | def enable(self, enabled: bool) -> None: 170 | 
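# Illustrative usage of the global recorder defined at the bottom of this file
# (a sketch; the section name "denoise_step" is made up):
#   time_recorder.enable(True)
#   time_recorder.start("denoise_step")
#   ...  # timed work
#   time_recorder.end("denoise_step", accumulate=True)
#   time_recorder.get_accumulation("denoise_step", average=True)
# or enable it temporarily via the time_recorder_enabled() context manager below.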
self.enabled = enabled 171 | 172 | def start(self, name: str) -> None: 173 | if not self.enabled: 174 | return 175 | torch.cuda.synchronize() 176 | self.items[name] = time.time() 177 | 178 | def end(self, name: str, accumulate: bool = False) -> float: 179 | if not self.enabled or name not in self.items: 180 | return 181 | torch.cuda.synchronize() 182 | start_time = self.items.pop(name) 183 | delta = time.time() - start_time 184 | if accumulate: 185 | self.accumulations[name].append(delta) 186 | t = delta * self.time_scale 187 | info(f"{name}: {t:.2f}{self.time_unit}") 188 | 189 | def get_accumulation(self, name: str, average: bool = False) -> float: 190 | if not self.enabled or name not in self.accumulations: 191 | return 192 | acc = self.accumulations.pop(name) 193 | total = sum(acc) 194 | if average: 195 | t = total / len(acc) * self.time_scale 196 | else: 197 | t = total * self.time_scale 198 | info(f"{name} for {len(acc)} times: {t:.2f}{self.time_unit}") 199 | 200 | 201 | ### global time recorder 202 | time_recorder = TimeRecorder() 203 | 204 | 205 | @contextmanager 206 | def time_recorder_enabled(): 207 | enabled = time_recorder.enabled 208 | time_recorder.enable(enabled=True) 209 | try: 210 | yield 211 | finally: 212 | time_recorder.enable(enabled=enabled) 213 | 214 | 215 | def show_vram_usage(name): 216 | available, total = torch.cuda.mem_get_info() 217 | used = total - available 218 | print( 219 | f"{name}: {used / 1024**2:.1f}MB, {psutil.Process(os.getpid()).memory_info().rss / 1024**2:.1f}MB" 220 | ) 221 | -------------------------------------------------------------------------------- /mvadapter/utils/typing.py: -------------------------------------------------------------------------------- 1 | """ 2 | This module contains type annotations for the project, using 3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects 4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors 5 | 6 | Two types of typing checking can be used: 7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode) 8 | 2. 
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | ) 28 | 29 | # Tensor dtype 30 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 31 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 32 | 33 | # Config type 34 | from omegaconf import DictConfig, ListConfig 35 | 36 | # PyTorch Tensor type 37 | from torch import Tensor 38 | 39 | # Runtime type checking decorator 40 | from typeguard import typechecked as typechecker 41 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | controlnet_aux 2 | diffusers 3 | transformers 4 | peft 5 | numpy 6 | huggingface_hub 7 | accelerate 8 | opencv-python 9 | safetensors 10 | pillow 11 | omegaconf 12 | trimesh 13 | einops 14 | gradio 15 | timm 16 | kornia 17 | scikit-image 18 | sentencepiece 19 | git+https://github.com/NVlabs/nvdiffrast.git 20 | 21 | # new added for training 22 | pytorch-lightning 23 | 24 | # new added for texturing 25 | spandrel==0.4.1 26 | open3d 27 | pymeshlab 28 | cvcuda_cu12 -------------------------------------------------------------------------------- /scripts/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/huanngzh/MV-Adapter/029c8186208eb860b18ab4cfe6c2ab70bf54909a/scripts/__init__.py -------------------------------------------------------------------------------- /scripts/gradio_demo_i2mv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gradio as gr 5 | import numpy as np 6 | 7 | # import spaces 8 | import torch 9 | from torchvision import transforms 10 | from transformers import AutoModelForImageSegmentation 11 | 12 | from .inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline 13 | 14 | # Device and dtype 15 | dtype = torch.bfloat16 16 | device = "cuda" if torch.cuda.is_available() else "cpu" 17 | 18 | # Hyperparameters 19 | NUM_VIEWS = 6 20 | HEIGHT = 768 21 | WIDTH = 768 22 | MAX_SEED = np.iinfo(np.int32).max 23 | 24 | pipe = prepare_pipeline( 25 | base_model="stabilityai/stable-diffusion-xl-base-1.0", 26 | vae_model="madebyollin/sdxl-vae-fp16-fix", 27 | unet_model=None, 28 | lora_model=None, 29 | adapter_path="huanngzh/mv-adapter", 30 | scheduler=None, 31 | num_views=NUM_VIEWS, 32 | device=device, 33 | dtype=dtype, 34 | ) 35 | 36 | # remove bg 37 | birefnet = AutoModelForImageSegmentation.from_pretrained( 38 | "ZhengPeng7/BiRefNet", trust_remote_code=True 39 | ) 40 | birefnet.to(device) 41 | transform_image = transforms.Compose( 42 | [ 43 | transforms.Resize((1024, 1024)), 44 | transforms.ToTensor(), 45 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 46 | ] 47 | ) 48 | 49 | 50 | # @spaces.GPU() 51 | def infer( 52 | prompt, 53 | image, 54 | do_rembg=True, 55 | seed=42, 56 | randomize_seed=False, 57 | guidance_scale=3.0, 58 | num_inference_steps=50, 59 | reference_conditioning_scale=1.0, 60 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 61 | progress=gr.Progress(track_tqdm=True), 62 | ): 
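# Optionally strips the background with BiRefNet, then calls run_pipeline to
# generate NUM_VIEWS views at HEIGHT x WIDTH from the reference image and prompt.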
63 | if do_rembg: 64 | remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, device) 65 | else: 66 | remove_bg_fn = None 67 | if randomize_seed: 68 | seed = random.randint(0, MAX_SEED) 69 | images, preprocessed_image = run_pipeline( 70 | pipe, 71 | num_views=NUM_VIEWS, 72 | text=prompt, 73 | image=image, 74 | height=HEIGHT, 75 | width=WIDTH, 76 | num_inference_steps=num_inference_steps, 77 | guidance_scale=guidance_scale, 78 | seed=seed, 79 | remove_bg_fn=remove_bg_fn, 80 | reference_conditioning_scale=reference_conditioning_scale, 81 | negative_prompt=negative_prompt, 82 | device=device, 83 | ) 84 | return images, preprocessed_image, seed 85 | 86 | 87 | examples = [ 88 | [ 89 | "A decorative figurine of a young anime-style girl", 90 | "assets/demo/i2mv/A_decorative_figurine_of_a_young_anime-style_girl.png", 91 | True, 92 | 21, 93 | ], 94 | [ 95 | "A juvenile emperor penguin chick", 96 | "assets/demo/i2mv/A_juvenile_emperor_penguin_chick.png", 97 | True, 98 | 0, 99 | ], 100 | [ 101 | "A striped tabby cat with white fur sitting upright", 102 | "assets/demo/i2mv/A_striped_tabby_cat_with_white_fur_sitting_upright.png", 103 | True, 104 | 0, 105 | ], 106 | ] 107 | 108 | 109 | with gr.Blocks() as demo: 110 | with gr.Row(): 111 | gr.Markdown( 112 | f"""# MV-Adapter [Image-to-Multi-View] 113 | Generate 768x768 multi-view images from a single image using SDXL
114 | [[page](https://huanngzh.github.io/MV-Adapter-Page/)] [[repo](https://github.com/huanngzh/MV-Adapter)] 115 | """ 116 | ) 117 | 118 | with gr.Row(): 119 | with gr.Column(): 120 | with gr.Row(): 121 | input_image = gr.Image( 122 | label="Input Image", 123 | sources=["upload", "webcam", "clipboard"], 124 | type="pil", 125 | ) 126 | preprocessed_image = gr.Image(label="Preprocessed Image", type="pil") 127 | 128 | prompt = gr.Textbox( 129 | label="Prompt", placeholder="Enter your prompt", value="high quality" 130 | ) 131 | do_rembg = gr.Checkbox(label="Remove background", value=True) 132 | run_button = gr.Button("Run") 133 | 134 | with gr.Accordion("Advanced Settings", open=False): 135 | seed = gr.Slider( 136 | label="Seed", 137 | minimum=0, 138 | maximum=MAX_SEED, 139 | step=1, 140 | value=0, 141 | ) 142 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 143 | 144 | with gr.Row(): 145 | num_inference_steps = gr.Slider( 146 | label="Number of inference steps", 147 | minimum=1, 148 | maximum=50, 149 | step=1, 150 | value=50, 151 | ) 152 | 153 | with gr.Row(): 154 | guidance_scale = gr.Slider( 155 | label="CFG scale", 156 | minimum=0.0, 157 | maximum=10.0, 158 | step=0.1, 159 | value=3.0, 160 | ) 161 | 162 | with gr.Row(): 163 | reference_conditioning_scale = gr.Slider( 164 | label="Image conditioning scale", 165 | minimum=0.0, 166 | maximum=2.0, 167 | step=0.1, 168 | value=1.0, 169 | ) 170 | 171 | with gr.Row(): 172 | negative_prompt = gr.Textbox( 173 | label="Negative prompt", 174 | placeholder="Enter your negative prompt", 175 | value="watermark, ugly, deformed, noisy, blurry, low contrast", 176 | ) 177 | 178 | with gr.Column(): 179 | result = gr.Gallery( 180 | label="Result", 181 | show_label=False, 182 | columns=[3], 183 | rows=[2], 184 | object_fit="contain", 185 | height="auto", 186 | ) 187 | 188 | with gr.Row(): 189 | gr.Examples( 190 | examples=examples, 191 | fn=infer, 192 | inputs=[prompt, input_image, do_rembg, seed], 193 | outputs=[result, preprocessed_image, seed], 194 | cache_examples=True, 195 | ) 196 | 197 | gr.on( 198 | triggers=[run_button.click, prompt.submit], 199 | fn=infer, 200 | inputs=[ 201 | prompt, 202 | input_image, 203 | do_rembg, 204 | seed, 205 | randomize_seed, 206 | guidance_scale, 207 | num_inference_steps, 208 | reference_conditioning_scale, 209 | negative_prompt, 210 | ], 211 | outputs=[result, preprocessed_image, seed], 212 | ) 213 | 214 | demo.launch() 215 | -------------------------------------------------------------------------------- /scripts/gradio_demo_t2mv.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import gradio as gr 5 | import numpy as np 6 | 7 | # import spaces 8 | import torch 9 | 10 | from .inference_t2mv_sdxl import prepare_pipeline, run_pipeline 11 | 12 | # Base model 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument( 15 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 16 | ) 17 | parser.add_argument("--scheduler", type=str, default=None) 18 | args = parser.parse_args() 19 | base_model = args.base_model 20 | scheduler = args.scheduler 21 | 22 | # Device and dtype 23 | dtype = torch.bfloat16 24 | device = "cuda" if torch.cuda.is_available() else "cpu" 25 | 26 | # Hyperparameters 27 | NUM_VIEWS = 6 28 | HEIGHT = 768 29 | WIDTH = 768 30 | MAX_SEED = np.iinfo(np.int32).max 31 | 32 | pipe = prepare_pipeline( 33 | base_model=base_model, 34 | vae_model="madebyollin/sdxl-vae-fp16-fix", 35 | 
unet_model=None, 36 | lora_model=None, 37 | adapter_path="huanngzh/mv-adapter", 38 | scheduler=scheduler, 39 | num_views=NUM_VIEWS, 40 | device=device, 41 | dtype=dtype, 42 | ) 43 | 44 | 45 | # @spaces.GPU() 46 | def infer( 47 | prompt, 48 | seed=42, 49 | randomize_seed=False, 50 | guidance_scale=7.0, 51 | num_inference_steps=50, 52 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 53 | progress=gr.Progress(track_tqdm=True), 54 | ): 55 | if randomize_seed: 56 | seed = random.randint(0, MAX_SEED) 57 | images = run_pipeline( 58 | pipe, 59 | num_views=NUM_VIEWS, 60 | text=prompt, 61 | height=HEIGHT, 62 | width=WIDTH, 63 | num_inference_steps=num_inference_steps, 64 | guidance_scale=guidance_scale, 65 | seed=seed, 66 | negative_prompt=negative_prompt, 67 | device=device, 68 | ) 69 | return images, seed 70 | 71 | 72 | examples = { 73 | "stabilityai/stable-diffusion-xl-base-1.0": [ 74 | ["An astronaut riding a horse", 42], 75 | ["A DSLR photo of a frog wearing a sweater", 42], 76 | ], 77 | "cagliostrolab/animagine-xl-3.1": [ 78 | [ 79 | "1girl, izayoi sakuya, touhou, solo, maid headdress, maid, apron, short sleeves, dress, closed mouth, white apron, serious face, upper body, masterpiece, best quality, very aesthetic, absurdres", 80 | 0, 81 | ], 82 | [ 83 | "1boy, male focus, ikari shinji, neon genesis evangelion, solo, serious face,(masterpiece), (best quality), (ultra-detailed), very aesthetic, illustration, disheveled hair, moist skin, intricate details", 84 | 0, 85 | ], 86 | [ 87 | "1girl, pink hair, pink shirts, smile, shy, masterpiece, anime", 88 | 0, 89 | ], 90 | ], 91 | "Lykon/dreamshaper-xl-1-0": [ 92 | ["the warrior Aragorn from Lord of the Rings, film grain, 8k hd", 0], 93 | [ 94 | "Oil painting, masterpiece, regal, fancy. A well-dressed dog named Puproy Doggerson III wearing reading glasses types an important letter on a typewriter and enjoys a cup of coffee with the newspaper.", 95 | 42, 96 | ], 97 | ], 98 | } 99 | 100 | css = """ 101 | #col-container { 102 | margin: 0 auto; 103 | max-width: 600px; 104 | } 105 | """ 106 | 107 | with gr.Blocks(css=css) as demo: 108 | 109 | with gr.Column(elem_id="col-container"): 110 | gr.Markdown( 111 | f"""# MV-Adapter [Text-to-Multi-View] 112 | Generate 768x768 multi-view images using {base_model}
113 | [[page](https://huanngzh.github.io/MV-Adapter-Page/)] [[repo](https://github.com/huanngzh/MV-Adapter)] 114 | """ 115 | ) 116 | 117 | with gr.Row(): 118 | prompt = gr.Text( 119 | label="Prompt", 120 | show_label=False, 121 | max_lines=1, 122 | placeholder="Enter your prompt", 123 | container=False, 124 | ) 125 | 126 | run_button = gr.Button("Run", scale=0) 127 | 128 | result = gr.Gallery( 129 | label="Result", 130 | show_label=False, 131 | columns=[3], 132 | rows=[2], 133 | object_fit="contain", 134 | height="auto", 135 | ) 136 | 137 | with gr.Accordion("Advanced Settings", open=False): 138 | seed = gr.Slider( 139 | label="Seed", 140 | minimum=0, 141 | maximum=MAX_SEED, 142 | step=1, 143 | value=0, 144 | ) 145 | randomize_seed = gr.Checkbox(label="Randomize seed", value=True) 146 | 147 | with gr.Row(): 148 | num_inference_steps = gr.Slider( 149 | label="Number of inference steps", 150 | minimum=1, 151 | maximum=50, 152 | step=1, 153 | value=50, 154 | ) 155 | 156 | with gr.Row(): 157 | guidance_scale = gr.Slider( 158 | label="CFG scale", 159 | minimum=0.0, 160 | maximum=10.0, 161 | step=0.1, 162 | value=7.0, 163 | ) 164 | 165 | with gr.Row(): 166 | negative_prompt = gr.Textbox( 167 | label="Negative prompt", 168 | placeholder="Enter your negative prompt", 169 | value="watermark, ugly, deformed, noisy, blurry, low contrast", 170 | ) 171 | 172 | if base_model in examples: 173 | gr.Examples( 174 | examples=examples[base_model], 175 | fn=infer, 176 | inputs=[prompt, seed], 177 | outputs=[result, seed], 178 | cache_examples=True, 179 | ) 180 | 181 | gr.on( 182 | triggers=[run_button.click, prompt.submit], 183 | fn=infer, 184 | inputs=[ 185 | prompt, 186 | seed, 187 | randomize_seed, 188 | guidance_scale, 189 | num_inference_steps, 190 | negative_prompt, 191 | ], 192 | outputs=[result, seed], 193 | ) 194 | 195 | demo.launch() 196 | -------------------------------------------------------------------------------- /scripts/inference_i2mv_sd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 6 | from PIL import Image 7 | from torchvision import transforms 8 | from transformers import AutoModelForImageSegmentation 9 | 10 | from mvadapter.pipelines.pipeline_mvadapter_i2mv_sd import MVAdapterI2MVSDPipeline 11 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 12 | from mvadapter.utils.mesh_utils import get_orthogonal_camera 13 | from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho 14 | from mvadapter.utils import make_image_grid 15 | 16 | 17 | def prepare_pipeline( 18 | base_model, 19 | vae_model, 20 | unet_model, 21 | lora_model, 22 | adapter_path, 23 | scheduler, 24 | num_views, 25 | device, 26 | dtype, 27 | ): 28 | # Load vae and unet if provided 29 | pipe_kwargs = {} 30 | if vae_model is not None: 31 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 32 | if unet_model is not None: 33 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 34 | 35 | # Prepare pipeline 36 | pipe: MVAdapterI2MVSDPipeline 37 | pipe = MVAdapterI2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs) 38 | 39 | # Load scheduler if provided 40 | scheduler_class = None 41 | if scheduler == "ddpm": 42 | scheduler_class = DDPMScheduler 43 | elif scheduler == "lcm": 44 | scheduler_class = LCMScheduler 45 | 46 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 
47 | pipe.scheduler, 48 | shift_mode="interpolated", 49 | shift_scale=8.0, 50 | scheduler_class=scheduler_class, 51 | ) 52 | pipe.init_custom_adapter(num_views=num_views) 53 | pipe.load_custom_adapter( 54 | adapter_path, weight_name="mvadapter_i2mv_sd21.safetensors" 55 | ) 56 | 57 | pipe.to(device=device, dtype=dtype) 58 | pipe.cond_encoder.to(device=device, dtype=dtype) 59 | 60 | # load lora if provided 61 | if lora_model is not None: 62 | model_, name_ = lora_model.rsplit("/", 1) 63 | pipe.load_lora_weights(model_, weight_name=name_) 64 | 65 | # vae slicing for lower memory usage 66 | pipe.enable_vae_slicing() 67 | 68 | return pipe 69 | 70 | 71 | def remove_bg(image, net, transform, device): 72 | image_size = image.size 73 | input_images = transform(image).unsqueeze(0).to(device) 74 | with torch.no_grad(): 75 | preds = net(input_images)[-1].sigmoid().cpu() 76 | pred = preds[0].squeeze() 77 | pred_pil = transforms.ToPILImage()(pred) 78 | mask = pred_pil.resize(image_size) 79 | image.putalpha(mask) 80 | return image 81 | 82 | 83 | def preprocess_image(image: Image.Image, height, width): 84 | image = np.array(image) 85 | alpha = image[..., 3] > 0 86 | H, W = alpha.shape 87 | # get the bounding box of alpha 88 | y, x = np.where(alpha) 89 | y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) 90 | x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) 91 | image_center = image[y0:y1, x0:x1] 92 | # resize the longer side to H * 0.9 93 | H, W, _ = image_center.shape 94 | if H > W: 95 | W = int(W * (height * 0.9) / H) 96 | H = int(height * 0.9) 97 | else: 98 | H = int(H * (width * 0.9) / W) 99 | W = int(width * 0.9) 100 | image_center = np.array(Image.fromarray(image_center).resize((W, H))) 101 | # pad to H, W 102 | start_h = (height - H) // 2 103 | start_w = (width - W) // 2 104 | image = np.zeros((height, width, 4), dtype=np.uint8) 105 | image[start_h : start_h + H, start_w : start_w + W] = image_center 106 | image = image.astype(np.float32) / 255.0 107 | image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 108 | image = (image * 255).clip(0, 255).astype(np.uint8) 109 | image = Image.fromarray(image) 110 | 111 | return image 112 | 113 | 114 | def run_pipeline( 115 | pipe, 116 | num_views, 117 | text, 118 | image, 119 | height, 120 | width, 121 | num_inference_steps, 122 | guidance_scale, 123 | seed, 124 | remove_bg_fn=None, 125 | reference_conditioning_scale=1.0, 126 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 127 | lora_scale=1.0, 128 | device="cuda", 129 | ): 130 | # Prepare cameras 131 | cameras = get_orthogonal_camera( 132 | elevation_deg=[0, 0, 0, 0, 0, 0], 133 | distance=[1.8] * num_views, 134 | left=-0.55, 135 | right=0.55, 136 | bottom=-0.55, 137 | top=0.55, 138 | azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]], 139 | device=device, 140 | ) 141 | 142 | plucker_embeds = get_plucker_embeds_from_cameras_ortho( 143 | cameras.c2w, [1.1] * num_views, width 144 | ) 145 | control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) 146 | 147 | # Prepare image 148 | reference_image = Image.open(image) if isinstance(image, str) else image 149 | if remove_bg_fn is not None: 150 | reference_image = remove_bg_fn(reference_image) 151 | reference_image = preprocess_image(reference_image, height, width) 152 | elif reference_image.mode == "RGBA": 153 | reference_image = preprocess_image(reference_image, height, width) 154 | 155 | pipe_kwargs = {} 156 | if seed != -1 and isinstance(seed, int): 157 | pipe_kwargs["generator"] = 
torch.Generator(device=device).manual_seed(seed) 158 | 159 | images = pipe( 160 | text, 161 | height=height, 162 | width=width, 163 | num_inference_steps=num_inference_steps, 164 | guidance_scale=guidance_scale, 165 | num_images_per_prompt=num_views, 166 | control_image=control_images, 167 | control_conditioning_scale=1.0, 168 | reference_image=reference_image, 169 | reference_conditioning_scale=reference_conditioning_scale, 170 | negative_prompt=negative_prompt, 171 | cross_attention_kwargs={"scale": lora_scale}, 172 | **pipe_kwargs, 173 | ).images 174 | 175 | return images, reference_image 176 | 177 | 178 | if __name__ == "__main__": 179 | parser = argparse.ArgumentParser() 180 | # Models 181 | parser.add_argument( 182 | "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base" 183 | ) 184 | parser.add_argument("--vae_model", type=str, default=None) 185 | parser.add_argument("--unet_model", type=str, default=None) 186 | parser.add_argument("--scheduler", type=str, default=None) 187 | parser.add_argument("--lora_model", type=str, default=None) 188 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 189 | parser.add_argument("--num_views", type=int, default=6) 190 | # Device 191 | parser.add_argument("--device", type=str, default="cuda") 192 | # Inference 193 | parser.add_argument("--image", type=str, required=True) 194 | parser.add_argument("--text", type=str, default="high quality") 195 | parser.add_argument("--num_inference_steps", type=int, default=50) 196 | parser.add_argument("--guidance_scale", type=float, default=3.0) 197 | parser.add_argument("--seed", type=int, default=-1) 198 | parser.add_argument("--lora_scale", type=float, default=1.0) 199 | parser.add_argument("--reference_conditioning_scale", type=float, default=1.0) 200 | parser.add_argument( 201 | "--negative_prompt", 202 | type=str, 203 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 204 | ) 205 | parser.add_argument("--output", type=str, default="output.png") 206 | # Extra 207 | parser.add_argument("--remove_bg", action="store_true", help="Remove background") 208 | args = parser.parse_args() 209 | 210 | pipe = prepare_pipeline( 211 | base_model=args.base_model, 212 | vae_model=args.vae_model, 213 | unet_model=args.unet_model, 214 | lora_model=args.lora_model, 215 | adapter_path=args.adapter_path, 216 | scheduler=args.scheduler, 217 | num_views=args.num_views, 218 | device=args.device, 219 | dtype=torch.float16, 220 | ) 221 | 222 | if args.remove_bg: 223 | birefnet = AutoModelForImageSegmentation.from_pretrained( 224 | "ZhengPeng7/BiRefNet", trust_remote_code=True 225 | ) 226 | birefnet.to(args.device) 227 | transform_image = transforms.Compose( 228 | [ 229 | transforms.Resize((1024, 1024)), 230 | transforms.ToTensor(), 231 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 232 | ] 233 | ) 234 | remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device) 235 | else: 236 | remove_bg_fn = None 237 | 238 | images, reference_image = run_pipeline( 239 | pipe, 240 | num_views=args.num_views, 241 | text=args.text, 242 | image=args.image, 243 | height=512, 244 | width=512, 245 | num_inference_steps=args.num_inference_steps, 246 | guidance_scale=args.guidance_scale, 247 | seed=args.seed, 248 | lora_scale=args.lora_scale, 249 | reference_conditioning_scale=args.reference_conditioning_scale, 250 | negative_prompt=args.negative_prompt, 251 | device=args.device, 252 | remove_bg_fn=remove_bg_fn, 253 | ) 254 | make_image_grid(images, 
rows=1).save(args.output) 255 | reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png") 256 | -------------------------------------------------------------------------------- /scripts/inference_i2mv_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 6 | from PIL import Image 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | from transformers import AutoModelForImageSegmentation 10 | 11 | from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline 12 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 13 | from mvadapter.utils.mesh_utils import get_orthogonal_camera 14 | from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho 15 | from mvadapter.utils import make_image_grid 16 | 17 | 18 | def prepare_pipeline( 19 | base_model, 20 | vae_model, 21 | unet_model, 22 | lora_model, 23 | adapter_path, 24 | scheduler, 25 | num_views, 26 | device, 27 | dtype, 28 | ): 29 | # Load vae and unet if provided 30 | pipe_kwargs = {} 31 | if vae_model is not None: 32 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 33 | if unet_model is not None: 34 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 35 | 36 | # Prepare pipeline 37 | pipe: MVAdapterI2MVSDXLPipeline 38 | pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) 39 | 40 | # Load scheduler if provided 41 | scheduler_class = None 42 | if scheduler == "ddpm": 43 | scheduler_class = DDPMScheduler 44 | elif scheduler == "lcm": 45 | scheduler_class = LCMScheduler 46 | 47 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 48 | pipe.scheduler, 49 | shift_mode="interpolated", 50 | shift_scale=8.0, 51 | scheduler_class=scheduler_class, 52 | ) 53 | pipe.init_custom_adapter(num_views=num_views) 54 | pipe.load_custom_adapter( 55 | adapter_path, weight_name="mvadapter_i2mv_sdxl.safetensors" 56 | ) 57 | 58 | pipe.to(device=device, dtype=dtype) 59 | pipe.cond_encoder.to(device=device, dtype=dtype) 60 | 61 | # load lora if provided 62 | if lora_model is not None: 63 | model_, name_ = lora_model.rsplit("/", 1) 64 | pipe.load_lora_weights(model_, weight_name=name_) 65 | 66 | # vae slicing for lower memory usage 67 | pipe.enable_vae_slicing() 68 | 69 | return pipe 70 | 71 | 72 | def remove_bg(image, net, transform, device): 73 | image_size = image.size 74 | input_images = transform(image).unsqueeze(0).to(device) 75 | with torch.no_grad(): 76 | preds = net(input_images)[-1].sigmoid().cpu() 77 | pred = preds[0].squeeze() 78 | pred_pil = transforms.ToPILImage()(pred) 79 | mask = pred_pil.resize(image_size) 80 | image.putalpha(mask) 81 | return image 82 | 83 | 84 | def preprocess_image(image: Image.Image, height, width): 85 | image = np.array(image) 86 | alpha = image[..., 3] > 0 87 | H, W = alpha.shape 88 | # get the bounding box of alpha 89 | y, x = np.where(alpha) 90 | y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) 91 | x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) 92 | image_center = image[y0:y1, x0:x1] 93 | # resize the longer side to H * 0.9 94 | H, W, _ = image_center.shape 95 | if H > W: 96 | W = int(W * (height * 0.9) / H) 97 | H = int(height * 0.9) 98 | else: 99 | H = int(H * (width * 0.9) / W) 100 | W = int(width * 0.9) 101 | image_center = np.array(Image.fromarray(image_center).resize((W, H))) 102 | # 
pad to H, W 103 | start_h = (height - H) // 2 104 | start_w = (width - W) // 2 105 | image = np.zeros((height, width, 4), dtype=np.uint8) 106 | image[start_h : start_h + H, start_w : start_w + W] = image_center 107 | image = image.astype(np.float32) / 255.0 108 | image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 109 | image = (image * 255).clip(0, 255).astype(np.uint8) 110 | image = Image.fromarray(image) 111 | 112 | return image 113 | 114 | 115 | def run_pipeline( 116 | pipe, 117 | num_views, 118 | text, 119 | image, 120 | height, 121 | width, 122 | num_inference_steps, 123 | guidance_scale, 124 | seed, 125 | remove_bg_fn=None, 126 | reference_conditioning_scale=1.0, 127 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 128 | lora_scale=1.0, 129 | device="cuda", 130 | azimuth_deg=None, 131 | ): 132 | # Prepare cameras 133 | if azimuth_deg is None: 134 | azimuth_deg = [0, 45, 90, 180, 270, 315] 135 | cameras = get_orthogonal_camera( 136 | elevation_deg=[0] * num_views, 137 | distance=[1.8] * num_views, 138 | left=-0.55, 139 | right=0.55, 140 | bottom=-0.55, 141 | top=0.55, 142 | azimuth_deg=[x - 90 for x in azimuth_deg], 143 | device=device, 144 | ) 145 | 146 | plucker_embeds = get_plucker_embeds_from_cameras_ortho( 147 | cameras.c2w, [1.1] * num_views, width 148 | ) 149 | control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) 150 | 151 | # Prepare image 152 | reference_image = Image.open(image) if isinstance(image, str) else image 153 | if remove_bg_fn is not None: 154 | reference_image = remove_bg_fn(reference_image) 155 | reference_image = preprocess_image(reference_image, height, width) 156 | elif reference_image.mode == "RGBA": 157 | reference_image = preprocess_image(reference_image, height, width) 158 | 159 | pipe_kwargs = {} 160 | if seed != -1 and isinstance(seed, int): 161 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 162 | 163 | images = pipe( 164 | text, 165 | height=height, 166 | width=width, 167 | num_inference_steps=num_inference_steps, 168 | guidance_scale=guidance_scale, 169 | num_images_per_prompt=num_views, 170 | control_image=control_images, 171 | control_conditioning_scale=1.0, 172 | reference_image=reference_image, 173 | reference_conditioning_scale=reference_conditioning_scale, 174 | negative_prompt=negative_prompt, 175 | cross_attention_kwargs={"scale": lora_scale}, 176 | **pipe_kwargs, 177 | ).images 178 | 179 | return images, reference_image 180 | 181 | 182 | if __name__ == "__main__": 183 | parser = argparse.ArgumentParser() 184 | # Models 185 | parser.add_argument( 186 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 187 | ) 188 | parser.add_argument( 189 | "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" 190 | ) 191 | parser.add_argument("--unet_model", type=str, default=None) 192 | parser.add_argument("--scheduler", type=str, default=None) 193 | parser.add_argument("--lora_model", type=str, default=None) 194 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 195 | # Device 196 | parser.add_argument("--device", type=str, default="cuda") 197 | # Inference 198 | parser.add_argument("--num_views", type=int, default=6) # not used 199 | parser.add_argument( 200 | "--azimuth_deg", type=int, nargs="+", default=[0, 45, 90, 180, 270, 315] 201 | ) 202 | parser.add_argument("--image", type=str, required=True) 203 | parser.add_argument("--text", type=str, default="high quality") 204 | 
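# Example invocation (illustrative; the demo image ships with the repo and the
# remaining options use the defaults defined in this block):
#   python scripts/inference_i2mv_sdxl.py \
#     --image assets/demo/i2mv/A_juvenile_emperor_penguin_chick.png \
#     --text "high quality" --seed 42 --remove_bg --output output.png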
parser.add_argument("--num_inference_steps", type=int, default=50) 205 | parser.add_argument("--guidance_scale", type=float, default=3.0) 206 | parser.add_argument("--seed", type=int, default=-1) 207 | parser.add_argument("--lora_scale", type=float, default=1.0) 208 | parser.add_argument("--reference_conditioning_scale", type=float, default=1.0) 209 | parser.add_argument( 210 | "--negative_prompt", 211 | type=str, 212 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 213 | ) 214 | parser.add_argument("--output", type=str, default="output.png") 215 | # Extra 216 | parser.add_argument("--remove_bg", action="store_true", help="Remove background") 217 | args = parser.parse_args() 218 | 219 | num_views = len(args.azimuth_deg) 220 | 221 | pipe = prepare_pipeline( 222 | base_model=args.base_model, 223 | vae_model=args.vae_model, 224 | unet_model=args.unet_model, 225 | lora_model=args.lora_model, 226 | adapter_path=args.adapter_path, 227 | scheduler=args.scheduler, 228 | num_views=num_views, 229 | device=args.device, 230 | dtype=torch.float16, 231 | ) 232 | 233 | if args.remove_bg: 234 | birefnet = AutoModelForImageSegmentation.from_pretrained( 235 | "ZhengPeng7/BiRefNet", trust_remote_code=True 236 | ) 237 | birefnet.to(args.device) 238 | transform_image = transforms.Compose( 239 | [ 240 | transforms.Resize((1024, 1024)), 241 | transforms.ToTensor(), 242 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 243 | ] 244 | ) 245 | remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device) 246 | else: 247 | remove_bg_fn = None 248 | 249 | images, reference_image = run_pipeline( 250 | pipe, 251 | num_views=num_views, 252 | text=args.text, 253 | image=args.image, 254 | height=768, 255 | width=768, 256 | num_inference_steps=args.num_inference_steps, 257 | guidance_scale=args.guidance_scale, 258 | seed=args.seed, 259 | lora_scale=args.lora_scale, 260 | reference_conditioning_scale=args.reference_conditioning_scale, 261 | negative_prompt=args.negative_prompt, 262 | device=args.device, 263 | remove_bg_fn=remove_bg_fn, 264 | azimuth_deg=args.azimuth_deg, 265 | ) 266 | make_image_grid(images, rows=1).save(args.output) 267 | reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png") 268 | -------------------------------------------------------------------------------- /scripts/inference_ig2mv_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import numpy as np 4 | import torch 5 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 6 | from PIL import Image 7 | from torchvision import transforms 8 | from tqdm import tqdm 9 | from transformers import AutoModelForImageSegmentation 10 | 11 | from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0 12 | from mvadapter.pipelines.pipeline_mvadapter_i2mv_sdxl import MVAdapterI2MVSDXLPipeline 13 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 14 | from mvadapter.utils import make_image_grid, tensor_to_image 15 | from mvadapter.utils.mesh_utils import ( 16 | NVDiffRastContextWrapper, 17 | get_orthogonal_camera, 18 | load_mesh, 19 | render, 20 | ) 21 | 22 | 23 | def prepare_pipeline( 24 | base_model, 25 | vae_model, 26 | unet_model, 27 | lora_model, 28 | adapter_path, 29 | scheduler, 30 | num_views, 31 | device, 32 | dtype, 33 | ): 34 | # Load vae and unet if provided 35 | pipe_kwargs = {} 36 | if vae_model is not None: 37 | pipe_kwargs["vae"] = 
AutoencoderKL.from_pretrained(vae_model) 38 | if unet_model is not None: 39 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 40 | 41 | # Prepare pipeline 42 | pipe: MVAdapterI2MVSDXLPipeline 43 | pipe = MVAdapterI2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) 44 | 45 | # Load scheduler if provided 46 | scheduler_class = None 47 | if scheduler == "ddpm": 48 | scheduler_class = DDPMScheduler 49 | elif scheduler == "lcm": 50 | scheduler_class = LCMScheduler 51 | 52 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 53 | pipe.scheduler, 54 | shift_mode="interpolated", 55 | shift_scale=8.0, 56 | scheduler_class=scheduler_class, 57 | ) 58 | pipe.init_custom_adapter( 59 | num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0 60 | ) 61 | pipe.load_custom_adapter( 62 | adapter_path, weight_name="mvadapter_ig2mv_sdxl.safetensors" 63 | ) 64 | 65 | pipe.to(device=device, dtype=dtype) 66 | pipe.cond_encoder.to(device=device, dtype=dtype) 67 | 68 | # load lora if provided 69 | if lora_model is not None: 70 | model_, name_ = lora_model.rsplit("/", 1) 71 | pipe.load_lora_weights(model_, weight_name=name_) 72 | 73 | return pipe 74 | 75 | 76 | def remove_bg(image, net, transform, device): 77 | image_size = image.size 78 | input_images = transform(image).unsqueeze(0).to(device) 79 | with torch.no_grad(): 80 | preds = net(input_images)[-1].sigmoid().cpu() 81 | pred = preds[0].squeeze() 82 | pred_pil = transforms.ToPILImage()(pred) 83 | mask = pred_pil.resize(image_size) 84 | image.putalpha(mask) 85 | return image 86 | 87 | 88 | def preprocess_image(image: Image.Image, height, width): 89 | image = np.array(image) 90 | alpha = image[..., 3] > 0 91 | H, W = alpha.shape 92 | # get the bounding box of alpha 93 | y, x = np.where(alpha) 94 | y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H) 95 | x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W) 96 | image_center = image[y0:y1, x0:x1] 97 | # resize the longer side to H * 0.9 98 | H, W, _ = image_center.shape 99 | if H > W: 100 | W = int(W * (height * 0.9) / H) 101 | H = int(height * 0.9) 102 | else: 103 | H = int(H * (width * 0.9) / W) 104 | W = int(width * 0.9) 105 | image_center = np.array(Image.fromarray(image_center).resize((W, H))) 106 | # pad to H, W 107 | start_h = (height - H) // 2 108 | start_w = (width - W) // 2 109 | image = np.zeros((height, width, 4), dtype=np.uint8) 110 | image[start_h : start_h + H, start_w : start_w + W] = image_center 111 | image = image.astype(np.float32) / 255.0 112 | image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5 113 | image = (image * 255).clip(0, 255).astype(np.uint8) 114 | image = Image.fromarray(image) 115 | 116 | return image 117 | 118 | 119 | def run_pipeline( 120 | pipe, 121 | mesh_path, 122 | num_views, 123 | text, 124 | image, 125 | height, 126 | width, 127 | num_inference_steps, 128 | guidance_scale, 129 | seed, 130 | remove_bg_fn=None, 131 | reference_conditioning_scale=1.0, 132 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 133 | lora_scale=1.0, 134 | device="cuda", 135 | ): 136 | # Prepare cameras 137 | cameras = get_orthogonal_camera( 138 | elevation_deg=[0, 0, 0, 0, 89.99, -89.99], 139 | distance=[1.8] * num_views, 140 | left=-0.55, 141 | right=0.55, 142 | bottom=-0.55, 143 | top=0.55, 144 | azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]], 145 | device=device, 146 | ) 147 | ctx = NVDiffRastContextWrapper(device=device) 148 | 149 | mesh = load_mesh(mesh_path, rescale=True, device=device) 
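# Render per-view geometry from the mesh: position maps (shifted into [0, 1]) and
# normal maps (rescaled into [0, 1]) are concatenated below into a 6-channel tensor
# in NCHW layout, which serves as the geometric control signal for the ig2mv
# adapter; the reference image is passed to the pipeline separately afterwards.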
150 | render_out = render( 151 | ctx, 152 | mesh, 153 | cameras, 154 | height=height, 155 | width=width, 156 | render_attr=False, 157 | normal_background=0.0, 158 | ) 159 | pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True) 160 | normal_images = tensor_to_image( 161 | (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True 162 | ) 163 | control_images = ( 164 | torch.cat( 165 | [ 166 | (render_out.pos + 0.5).clamp(0, 1), 167 | (render_out.normal / 2 + 0.5).clamp(0, 1), 168 | ], 169 | dim=-1, 170 | ) 171 | .permute(0, 3, 1, 2) 172 | .to(device) 173 | ) 174 | 175 | # Prepare image 176 | reference_image = Image.open(image) if isinstance(image, str) else image 177 | if remove_bg_fn is not None: 178 | reference_image = remove_bg_fn(reference_image) 179 | reference_image = preprocess_image(reference_image, height, width) 180 | elif reference_image.mode == "RGBA": 181 | reference_image = preprocess_image(reference_image, height, width) 182 | 183 | pipe_kwargs = {} 184 | if seed != -1 and isinstance(seed, int): 185 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 186 | 187 | images = pipe( 188 | text, 189 | height=height, 190 | width=width, 191 | num_inference_steps=num_inference_steps, 192 | guidance_scale=guidance_scale, 193 | num_images_per_prompt=num_views, 194 | control_image=control_images, 195 | control_conditioning_scale=1.0, 196 | reference_image=reference_image, 197 | reference_conditioning_scale=reference_conditioning_scale, 198 | negative_prompt=negative_prompt, 199 | cross_attention_kwargs={"scale": lora_scale}, 200 | **pipe_kwargs, 201 | ).images 202 | 203 | return images, pos_images, normal_images, reference_image 204 | 205 | 206 | if __name__ == "__main__": 207 | parser = argparse.ArgumentParser() 208 | # Models 209 | parser.add_argument( 210 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 211 | ) 212 | parser.add_argument( 213 | "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" 214 | ) 215 | parser.add_argument("--unet_model", type=str, default=None) 216 | parser.add_argument("--scheduler", type=str, default=None) 217 | parser.add_argument("--lora_model", type=str, default=None) 218 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 219 | parser.add_argument("--num_views", type=int, default=6) 220 | # Device 221 | parser.add_argument("--device", type=str, default="cuda") 222 | # Inference 223 | parser.add_argument("--mesh", type=str, required=True) 224 | parser.add_argument("--image", type=str, required=True) 225 | parser.add_argument("--text", type=str, required=False, default="high quality") 226 | parser.add_argument("--num_inference_steps", type=int, default=50) 227 | parser.add_argument("--guidance_scale", type=float, default=3.0) 228 | parser.add_argument("--seed", type=int, default=-1) 229 | parser.add_argument("--lora_scale", type=float, default=1.0) 230 | parser.add_argument("--reference_conditioning_scale", type=float, default=1.0) 231 | parser.add_argument( 232 | "--negative_prompt", 233 | type=str, 234 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 235 | ) 236 | parser.add_argument("--output", type=str, default="output.png") 237 | # Extra 238 | parser.add_argument("--remove_bg", action="store_true", help="Remove background") 239 | args = parser.parse_args() 240 | 241 | pipe = prepare_pipeline( 242 | base_model=args.base_model, 243 | vae_model=args.vae_model, 244 | unet_model=args.unet_model, 245 | 
lora_model=args.lora_model, 246 | adapter_path=args.adapter_path, 247 | scheduler=args.scheduler, 248 | num_views=args.num_views, 249 | device=args.device, 250 | dtype=torch.float16, 251 | ) 252 | 253 | if args.remove_bg: 254 | birefnet = AutoModelForImageSegmentation.from_pretrained( 255 | "ZhengPeng7/BiRefNet", trust_remote_code=True 256 | ) 257 | birefnet.to(args.device) 258 | transform_image = transforms.Compose( 259 | [ 260 | transforms.Resize((1024, 1024)), 261 | transforms.ToTensor(), 262 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 263 | ] 264 | ) 265 | remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device) 266 | else: 267 | remove_bg_fn = None 268 | 269 | images, pos_images, normal_images, reference_image = run_pipeline( 270 | pipe, 271 | mesh_path=args.mesh, 272 | num_views=args.num_views, 273 | text=args.text, 274 | image=args.image, 275 | height=768, 276 | width=768, 277 | num_inference_steps=args.num_inference_steps, 278 | guidance_scale=args.guidance_scale, 279 | seed=args.seed, 280 | lora_scale=args.lora_scale, 281 | reference_conditioning_scale=args.reference_conditioning_scale, 282 | negative_prompt=args.negative_prompt, 283 | device=args.device, 284 | remove_bg_fn=remove_bg_fn, 285 | ) 286 | make_image_grid(images, rows=1).save(args.output) 287 | make_image_grid(pos_images, rows=1).save(args.output.rsplit(".", 1)[0] + "_pos.png") 288 | make_image_grid(normal_images, rows=1).save( 289 | args.output.rsplit(".", 1)[0] + "_nor.png" 290 | ) 291 | reference_image.save(args.output.rsplit(".", 1)[0] + "_reference.png") 292 | -------------------------------------------------------------------------------- /scripts/inference_scribble2mv_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | from controlnet_aux import HEDdetector, PidiNetDetector 8 | from diffusers import ( 9 | AutoencoderKL, 10 | ControlNetModel, 11 | DDPMScheduler, 12 | LCMScheduler, 13 | UNet2DConditionModel, 14 | ) 15 | from PIL import Image 16 | 17 | from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline 18 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 19 | from mvadapter.utils.mesh_utils import get_orthogonal_camera 20 | from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho 21 | from mvadapter.utils import make_image_grid 22 | 23 | 24 | def prepare_pipeline( 25 | base_model, 26 | vae_model, 27 | unet_model, 28 | lora_model, 29 | adapter_path, 30 | scheduler, 31 | num_views, 32 | device, 33 | dtype, 34 | ): 35 | # Load vae and unet if provided 36 | pipe_kwargs = {} 37 | if vae_model is not None: 38 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 39 | if unet_model is not None: 40 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 41 | 42 | # Prepare pipeline 43 | pipe: MVAdapterT2MVSDXLPipeline 44 | pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) 45 | 46 | # Load scheduler if provided 47 | scheduler_class = None 48 | if scheduler == "ddpm": 49 | scheduler_class = DDPMScheduler 50 | elif scheduler == "lcm": 51 | scheduler_class = LCMScheduler 52 | 53 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 54 | pipe.scheduler, 55 | shift_mode="interpolated", 56 | shift_scale=8.0, 57 | scheduler_class=scheduler_class, 58 | ) 59 | pipe.init_custom_adapter(num_views=num_views) 60 | 
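# Scribble-to-multiview reuses the text-to-multiview adapter weights loaded just
# below (mvadapter_t2mv_sdxl.safetensors); the scribble conditioning itself comes
# from the off-the-shelf SDXL scribble ControlNet attached a few lines further on.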
pipe.load_custom_adapter( 61 | adapter_path, weight_name="mvadapter_t2mv_sdxl.safetensors" 62 | ) 63 | 64 | # ControlNet 65 | pipe.controlnet = ControlNetModel.from_pretrained( 66 | "xinsir/controlnet-scribble-sdxl-1.0" 67 | ) 68 | 69 | pipe.to(device=device, dtype=dtype) 70 | pipe.cond_encoder.to(device=device, dtype=dtype) 71 | pipe.controlnet.to(device=device, dtype=dtype) 72 | 73 | # load lora if provided 74 | if lora_model is not None: 75 | model_, name_ = lora_model.rsplit("/", 1) 76 | pipe.load_lora_weights(model_, weight_name=name_) 77 | 78 | # vae slicing for lower memory usage 79 | pipe.enable_vae_slicing() 80 | 81 | return pipe 82 | 83 | 84 | def nms(x, t, s): 85 | x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s) 86 | 87 | f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8) 88 | f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8) 89 | f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8) 90 | f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8) 91 | 92 | y = np.zeros_like(x) 93 | 94 | for f in [f1, f2, f3, f4]: 95 | np.putmask(y, cv2.dilate(x, kernel=f) == x, x) 96 | 97 | z = np.zeros_like(y, dtype=np.uint8) 98 | z[y > t] = 255 99 | return z 100 | 101 | 102 | def preprocess_controlnet_image(image_path, height, width): 103 | image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED) 104 | image = cv2.resize(image, (width, height)) 105 | 106 | if image.shape[2] == 4: 107 | alpha_channel = image[:, :, 3] / 255.0 108 | rgb_channels = image[:, :, :3] / 255.0 109 | 110 | gray_background = np.ones_like(rgb_channels) * 0.5 111 | 112 | image = ( 113 | alpha_channel[..., None] * rgb_channels 114 | + (1 - alpha_channel[..., None]) * gray_background 115 | ) 116 | image = (image * 255).astype(np.uint8) 117 | 118 | processor = HEDdetector.from_pretrained("lllyasviel/Annotators") 119 | image = processor(image, scribble=False) 120 | 121 | # the following steps simulate a hand-drawn sketch; different thresholds produce different line widths 122 | image = np.array(image) 123 | image = nms(image, 127, 3) 124 | image = cv2.GaussianBlur(image, (0, 0), 3) 125 | 126 | # higher threshold, thinner line 127 | random_val = int(round(random.uniform(0.01, 0.10), 2) * 255) 128 | image[image > random_val] = 255 129 | image[image < 255] = 0 130 | 131 | return Image.fromarray(image) 132 | 133 | 134 | def run_pipeline( 135 | pipe, 136 | num_views, 137 | text, 138 | height, 139 | width, 140 | num_inference_steps, 141 | guidance_scale, 142 | seed, 143 | controlnet_images, 144 | controlnet_conditioning_scale, 145 | lora_scale=1.0, 146 | device="cuda", 147 | ): 148 | # Prepare cameras 149 | cameras = get_orthogonal_camera( 150 | elevation_deg=[0, 0, 0, 0, 0, 0], 151 | distance=[1.8] * num_views, 152 | left=-0.55, 153 | right=0.55, 154 | bottom=-0.55, 155 | top=0.55, 156 | azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]], 157 | device=device, 158 | ) 159 | 160 | plucker_embeds = get_plucker_embeds_from_cameras_ortho( 161 | cameras.c2w, [1.1] * num_views, width 162 | ) 163 | control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) 164 | 165 | pipe_kwargs = {} 166 | if seed != -1: 167 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 168 | 169 | # Prepare controlnet images 170 | controlnet_image = [ 171 | preprocess_controlnet_image(path, height, width) for path in controlnet_images 172 | ] 173 | pipe_kwargs.update( 174 | { 175 | "controlnet_image": controlnet_image, 176 | "controlnet_conditioning_scale":
controlnet_conditioning_scale, 177 | } 178 | ) 179 | 180 | images = pipe( 181 | text, 182 | height=height, 183 | width=width, 184 | num_inference_steps=num_inference_steps, 185 | guidance_scale=guidance_scale, 186 | num_images_per_prompt=num_views, 187 | control_image=control_images, 188 | control_conditioning_scale=1.0, 189 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 190 | cross_attention_kwargs={"scale": lora_scale}, 191 | **pipe_kwargs, 192 | ).images 193 | 194 | return images, controlnet_image 195 | 196 | 197 | if __name__ == "__main__": 198 | parser = argparse.ArgumentParser() 199 | # Models 200 | parser.add_argument( 201 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 202 | ) 203 | parser.add_argument( 204 | "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" 205 | ) 206 | parser.add_argument("--unet_model", type=str, default=None) 207 | parser.add_argument("--scheduler", type=str, default=None) 208 | parser.add_argument("--lora_model", type=str, default=None) 209 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 210 | parser.add_argument("--num_views", type=int, default=6) 211 | # Device 212 | parser.add_argument("--device", type=str, default="cuda") 213 | # Inference 214 | parser.add_argument("--text", type=str, required=True) 215 | parser.add_argument("--num_inference_steps", type=int, default=50) 216 | parser.add_argument("--guidance_scale", type=float, default=7.0) 217 | parser.add_argument("--seed", type=int, default=-1) 218 | parser.add_argument("--lora_scale", type=float, default=1.0) 219 | parser.add_argument("--output", type=str, default="output.png") 220 | parser.add_argument("--controlnet_images", type=str, nargs="+", required=True) 221 | parser.add_argument("--controlnet_conditioning_scale", type=float, default=1.0) 222 | args = parser.parse_args() 223 | 224 | pipe = prepare_pipeline( 225 | base_model=args.base_model, 226 | vae_model=args.vae_model, 227 | unet_model=args.unet_model, 228 | lora_model=args.lora_model, 229 | adapter_path=args.adapter_path, 230 | scheduler=args.scheduler, 231 | num_views=args.num_views, 232 | device=args.device, 233 | dtype=torch.float16, 234 | ) 235 | images, controlnet_images = run_pipeline( 236 | pipe, 237 | num_views=args.num_views, 238 | text=args.text, 239 | height=768, 240 | width=768, 241 | num_inference_steps=args.num_inference_steps, 242 | guidance_scale=args.guidance_scale, 243 | seed=args.seed, 244 | controlnet_images=args.controlnet_images, 245 | controlnet_conditioning_scale=args.controlnet_conditioning_scale, 246 | lora_scale=args.lora_scale, 247 | device=args.device, 248 | ) 249 | make_image_grid(images, rows=1).save(args.output) 250 | make_image_grid(controlnet_images, rows=1).save( 251 | args.output.rsplit(".", 1)[0] + "_controlnet.png" 252 | ) 253 | -------------------------------------------------------------------------------- /scripts/inference_t2mv_sd.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 5 | 6 | from mvadapter.pipelines.pipeline_mvadapter_t2mv_sd import MVAdapterT2MVSDPipeline 7 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 8 | from mvadapter.utils.mesh_utils import get_orthogonal_camera 9 | from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho 10 | from mvadapter.utils import 
make_image_grid 11 | 12 | 13 | def prepare_pipeline( 14 | base_model, 15 | vae_model, 16 | unet_model, 17 | lora_model, 18 | adapter_path, 19 | scheduler, 20 | num_views, 21 | device, 22 | dtype, 23 | ): 24 | # Load vae and unet if provided 25 | pipe_kwargs = {} 26 | if vae_model is not None: 27 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 28 | if unet_model is not None: 29 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 30 | 31 | # Prepare pipeline 32 | pipe: MVAdapterT2MVSDPipeline 33 | pipe = MVAdapterT2MVSDPipeline.from_pretrained(base_model, **pipe_kwargs) 34 | 35 | # Load scheduler if provided 36 | scheduler_class = None 37 | if scheduler == "ddpm": 38 | scheduler_class = DDPMScheduler 39 | elif scheduler == "lcm": 40 | scheduler_class = LCMScheduler 41 | 42 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 43 | pipe.scheduler, 44 | shift_mode="interpolated", 45 | shift_scale=8.0, 46 | scheduler_class=scheduler_class, 47 | ) 48 | pipe.init_custom_adapter(num_views=num_views) 49 | pipe.load_custom_adapter( 50 | adapter_path, weight_name="mvadapter_t2mv_sd21.safetensors" 51 | ) 52 | 53 | pipe.to(device=device, dtype=dtype) 54 | pipe.cond_encoder.to(device=device, dtype=dtype) 55 | 56 | # load lora if provided 57 | if lora_model is not None: 58 | model_, name_ = lora_model.rsplit("/", 1) 59 | pipe.load_lora_weights(model_, weight_name=name_) 60 | 61 | # vae slicing for lower memory usage 62 | pipe.enable_vae_slicing() 63 | 64 | return pipe 65 | 66 | 67 | def run_pipeline( 68 | pipe, 69 | num_views, 70 | text, 71 | height, 72 | width, 73 | num_inference_steps, 74 | guidance_scale, 75 | seed, 76 | negative_prompt, 77 | lora_scale=1.0, 78 | device="cuda", 79 | ): 80 | # Prepare cameras 81 | cameras = get_orthogonal_camera( 82 | elevation_deg=[0, 0, 0, 0, 0, 0], 83 | distance=[1.8] * num_views, 84 | left=-0.55, 85 | right=0.55, 86 | bottom=-0.55, 87 | top=0.55, 88 | azimuth_deg=[x - 90 for x in [0, 45, 90, 180, 270, 315]], 89 | device=device, 90 | ) 91 | 92 | plucker_embeds = get_plucker_embeds_from_cameras_ortho( 93 | cameras.c2w, [1.1] * num_views, width 94 | ) 95 | control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) 96 | 97 | pipe_kwargs = {} 98 | if seed != -1: 99 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 100 | 101 | images = pipe( 102 | text, 103 | height=height, 104 | width=width, 105 | num_inference_steps=num_inference_steps, 106 | guidance_scale=guidance_scale, 107 | num_images_per_prompt=num_views, 108 | control_image=control_images, 109 | control_conditioning_scale=1.0, 110 | negative_prompt=negative_prompt, 111 | cross_attention_kwargs={"scale": lora_scale}, 112 | **pipe_kwargs, 113 | ).images 114 | 115 | return images 116 | 117 | 118 | if __name__ == "__main__": 119 | parser = argparse.ArgumentParser() 120 | # Models 121 | parser.add_argument( 122 | "--base_model", type=str, default="stabilityai/stable-diffusion-2-1-base" 123 | ) 124 | parser.add_argument("--vae_model", type=str, default=None) 125 | parser.add_argument("--unet_model", type=str, default=None) 126 | parser.add_argument("--scheduler", type=str, default=None) 127 | parser.add_argument("--lora_model", type=str, default=None) 128 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 129 | parser.add_argument("--num_views", type=int, default=6) 130 | # Device 131 | parser.add_argument("--device", type=str, default="cuda") 132 | # Inference 133 | parser.add_argument("--text", type=str, required=True) 134 
| parser.add_argument("--num_inference_steps", type=int, default=50) 135 | parser.add_argument("--guidance_scale", type=float, default=7.0) 136 | parser.add_argument("--seed", type=int, default=-1) 137 | parser.add_argument( 138 | "--negative_prompt", 139 | type=str, 140 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 141 | ) 142 | parser.add_argument("--lora_scale", type=float, default=1.0) 143 | parser.add_argument("--output", type=str, default="output.png") 144 | args = parser.parse_args() 145 | 146 | pipe = prepare_pipeline( 147 | base_model=args.base_model, 148 | vae_model=args.vae_model, 149 | unet_model=args.unet_model, 150 | lora_model=args.lora_model, 151 | adapter_path=args.adapter_path, 152 | scheduler=args.scheduler, 153 | num_views=args.num_views, 154 | device=args.device, 155 | dtype=torch.float16, 156 | ) 157 | images = run_pipeline( 158 | pipe, 159 | num_views=args.num_views, 160 | text=args.text, 161 | height=512, 162 | width=512, 163 | num_inference_steps=args.num_inference_steps, 164 | guidance_scale=args.guidance_scale, 165 | seed=args.seed, 166 | negative_prompt=args.negative_prompt, 167 | lora_scale=args.lora_scale, 168 | device=args.device, 169 | ) 170 | make_image_grid(images, rows=1).save(args.output) 171 | -------------------------------------------------------------------------------- /scripts/inference_t2mv_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 5 | 6 | from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline 7 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 8 | from mvadapter.utils.mesh_utils import get_orthogonal_camera 9 | from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho 10 | from mvadapter.utils import make_image_grid 11 | 12 | 13 | def prepare_pipeline( 14 | base_model, 15 | vae_model, 16 | unet_model, 17 | lora_model, 18 | adapter_path, 19 | scheduler, 20 | num_views, 21 | device, 22 | dtype, 23 | ): 24 | # Load vae and unet if provided 25 | pipe_kwargs = {} 26 | if vae_model is not None: 27 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 28 | if unet_model is not None: 29 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 30 | 31 | # Prepare pipeline 32 | pipe: MVAdapterT2MVSDXLPipeline 33 | pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) 34 | 35 | # Load scheduler if provided 36 | scheduler_class = None 37 | if scheduler == "ddpm": 38 | scheduler_class = DDPMScheduler 39 | elif scheduler == "lcm": 40 | scheduler_class = LCMScheduler 41 | 42 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 43 | pipe.scheduler, 44 | shift_mode="interpolated", 45 | shift_scale=8.0, 46 | scheduler_class=scheduler_class, 47 | ) 48 | pipe.init_custom_adapter(num_views=num_views) 49 | pipe.load_custom_adapter( 50 | adapter_path, weight_name="mvadapter_t2mv_sdxl.safetensors" 51 | ) 52 | 53 | pipe.to(device=device, dtype=dtype) 54 | pipe.cond_encoder.to(device=device, dtype=dtype) 55 | 56 | # load lora if provided 57 | if lora_model is not None: 58 | model_, name_ = lora_model.rsplit("/", 1) 59 | pipe.load_lora_weights(model_, weight_name=name_) 60 | 61 | # vae slicing for lower memory usage 62 | pipe.enable_vae_slicing() 63 | 64 | return pipe 65 | 66 | 67 | def run_pipeline( 68 | pipe, 69 | num_views, 70 | text, 71 | height, 72 | 
width, 73 | num_inference_steps, 74 | guidance_scale, 75 | seed, 76 | negative_prompt, 77 | lora_scale=1.0, 78 | device="cuda", 79 | azimuth_deg=None, 80 | ): 81 | # Prepare cameras 82 | if azimuth_deg is None: 83 | azimuth_deg = [0, 45, 90, 180, 270, 315] 84 | cameras = get_orthogonal_camera( 85 | elevation_deg=[0] * num_views, 86 | distance=[1.8] * num_views, 87 | left=-0.55, 88 | right=0.55, 89 | bottom=-0.55, 90 | top=0.55, 91 | azimuth_deg=[x - 90 for x in azimuth_deg], 92 | device=device, 93 | ) 94 | 95 | plucker_embeds = get_plucker_embeds_from_cameras_ortho( 96 | cameras.c2w, [1.1] * num_views, width 97 | ) 98 | control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1) 99 | 100 | pipe_kwargs = {} 101 | if seed != -1: 102 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 103 | 104 | images = pipe( 105 | text, 106 | height=height, 107 | width=width, 108 | num_inference_steps=num_inference_steps, 109 | guidance_scale=guidance_scale, 110 | num_images_per_prompt=num_views, 111 | control_image=control_images, 112 | control_conditioning_scale=1.0, 113 | negative_prompt=negative_prompt, 114 | cross_attention_kwargs={"scale": lora_scale}, 115 | **pipe_kwargs, 116 | ).images 117 | 118 | return images 119 | 120 | 121 | if __name__ == "__main__": 122 | parser = argparse.ArgumentParser() 123 | # Models 124 | parser.add_argument( 125 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 126 | ) 127 | parser.add_argument( 128 | "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" 129 | ) 130 | parser.add_argument("--unet_model", type=str, default=None) 131 | parser.add_argument("--scheduler", type=str, default=None) 132 | parser.add_argument("--lora_model", type=str, default=None) 133 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 134 | # Device 135 | parser.add_argument("--device", type=str, default="cuda") 136 | # Inference 137 | parser.add_argument("--num_views", type=int, default=6) 138 | parser.add_argument( 139 | "--azimuth_deg", type=int, nargs="+", default=[0, 45, 90, 180, 270, 315] 140 | ) 141 | parser.add_argument("--text", type=str, required=True) 142 | parser.add_argument("--num_inference_steps", type=int, default=50) 143 | parser.add_argument("--guidance_scale", type=float, default=7.0) 144 | parser.add_argument("--seed", type=int, default=-1) 145 | parser.add_argument( 146 | "--negative_prompt", 147 | type=str, 148 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 149 | ) 150 | parser.add_argument("--lora_scale", type=float, default=1.0) 151 | parser.add_argument("--output", type=str, default="output.png") 152 | args = parser.parse_args() 153 | 154 | num_views = len(args.azimuth_deg) 155 | 156 | pipe = prepare_pipeline( 157 | base_model=args.base_model, 158 | vae_model=args.vae_model, 159 | unet_model=args.unet_model, 160 | lora_model=args.lora_model, 161 | adapter_path=args.adapter_path, 162 | scheduler=args.scheduler, 163 | num_views=num_views, 164 | device=args.device, 165 | dtype=torch.float16, 166 | ) 167 | images = run_pipeline( 168 | pipe, 169 | num_views=num_views, 170 | text=args.text, 171 | height=768, 172 | width=768, 173 | num_inference_steps=args.num_inference_steps, 174 | guidance_scale=args.guidance_scale, 175 | seed=args.seed, 176 | negative_prompt=args.negative_prompt, 177 | lora_scale=args.lora_scale, 178 | device=args.device, 179 | azimuth_deg=args.azimuth_deg, 180 | ) 181 | make_image_grid(images, rows=1).save(args.output) 182 | 
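# Usage sketch (illustrative, not part of inference_t2mv_sdxl.py): prepare_pipeline and
# run_pipeline above can also be driven from Python instead of the CLI. This assumes the
# repository root is on PYTHONPATH, a CUDA device is available, and the prompt string is
# only an example:
#
#     import torch
#     from mvadapter.utils import make_image_grid
#     from scripts.inference_t2mv_sdxl import prepare_pipeline, run_pipeline
#
#     pipe = prepare_pipeline(
#         base_model="stabilityai/stable-diffusion-xl-base-1.0",
#         vae_model="madebyollin/sdxl-vae-fp16-fix",
#         unet_model=None, lora_model=None, adapter_path="huanngzh/mv-adapter",
#         scheduler=None, num_views=6, device="cuda", dtype=torch.float16,
#     )
#     images = run_pipeline(
#         pipe, num_views=6, text="a wooden rocking chair", height=768, width=768,
#         num_inference_steps=50, guidance_scale=7.0, seed=42,
#         negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
#         device="cuda",
#     )
#     make_image_grid(images, rows=1).save("output.png")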
-------------------------------------------------------------------------------- /scripts/inference_tg2mv_sdxl.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | import torch 4 | from diffusers import AutoencoderKL, DDPMScheduler, LCMScheduler, UNet2DConditionModel 5 | 6 | from mvadapter.models.attention_processor import DecoupledMVRowColSelfAttnProcessor2_0 7 | from mvadapter.pipelines.pipeline_mvadapter_t2mv_sdxl import MVAdapterT2MVSDXLPipeline 8 | from mvadapter.schedulers.scheduling_shift_snr import ShiftSNRScheduler 9 | from mvadapter.utils import get_orthogonal_camera, make_image_grid, tensor_to_image 10 | from mvadapter.utils.mesh_utils import NVDiffRastContextWrapper, load_mesh, render 11 | 12 | 13 | def prepare_pipeline( 14 | base_model, 15 | vae_model, 16 | unet_model, 17 | lora_model, 18 | adapter_path, 19 | scheduler, 20 | num_views, 21 | device, 22 | dtype, 23 | ): 24 | # Load vae and unet if provided 25 | pipe_kwargs = {} 26 | if vae_model is not None: 27 | pipe_kwargs["vae"] = AutoencoderKL.from_pretrained(vae_model) 28 | if unet_model is not None: 29 | pipe_kwargs["unet"] = UNet2DConditionModel.from_pretrained(unet_model) 30 | 31 | # Prepare pipeline 32 | pipe: MVAdapterT2MVSDXLPipeline 33 | pipe = MVAdapterT2MVSDXLPipeline.from_pretrained(base_model, **pipe_kwargs) 34 | 35 | # Load scheduler if provided 36 | scheduler_class = None 37 | if scheduler == "ddpm": 38 | scheduler_class = DDPMScheduler 39 | elif scheduler == "lcm": 40 | scheduler_class = LCMScheduler 41 | 42 | pipe.scheduler = ShiftSNRScheduler.from_scheduler( 43 | pipe.scheduler, 44 | shift_mode="interpolated", 45 | shift_scale=8.0, 46 | scheduler_class=scheduler_class, 47 | ) 48 | pipe.init_custom_adapter( 49 | num_views=num_views, self_attn_processor=DecoupledMVRowColSelfAttnProcessor2_0 50 | ) 51 | pipe.load_custom_adapter( 52 | adapter_path, weight_name="mvadapter_tg2mv_sdxl.safetensors" 53 | ) 54 | 55 | pipe.to(device=device, dtype=dtype) 56 | pipe.cond_encoder.to(device=device, dtype=dtype) 57 | 58 | # load lora if provided 59 | if lora_model is not None: 60 | model_, name_ = lora_model.rsplit("/", 1) 61 | pipe.load_lora_weights(model_, weight_name=name_) 62 | 63 | return pipe 64 | 65 | 66 | def run_pipeline( 67 | pipe, 68 | mesh_path, 69 | num_views, 70 | text, 71 | height, 72 | width, 73 | num_inference_steps, 74 | guidance_scale, 75 | seed, 76 | negative_prompt, 77 | lora_scale=1.0, 78 | device="cuda", 79 | ): 80 | # Prepare cameras 81 | cameras = get_orthogonal_camera( 82 | elevation_deg=[0, 0, 0, 0, 89.99, -89.99], 83 | distance=[1.8] * num_views, 84 | left=-0.55, 85 | right=0.55, 86 | bottom=-0.55, 87 | top=0.55, 88 | azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]], 89 | device=device, 90 | ) 91 | ctx = NVDiffRastContextWrapper(device=device) 92 | 93 | mesh = load_mesh(mesh_path, rescale=True, device=device) 94 | render_out = render( 95 | ctx, 96 | mesh, 97 | cameras, 98 | height=height, 99 | width=width, 100 | render_attr=False, 101 | normal_background=0.0, 102 | ) 103 | pos_images = tensor_to_image((render_out.pos + 0.5).clamp(0, 1), batched=True) 104 | normal_images = tensor_to_image( 105 | (render_out.normal / 2 + 0.5).clamp(0, 1), batched=True 106 | ) 107 | control_images = ( 108 | torch.cat( 109 | [ 110 | (render_out.pos + 0.5).clamp(0, 1), 111 | (render_out.normal / 2 + 0.5).clamp(0, 1), 112 | ], 113 | dim=-1, 114 | ) 115 | .permute(0, 3, 1, 2) 116 | .to(device) 117 | ) 118 | 119 | pipe_kwargs = {} 120 | if seed != -1: 
121 | pipe_kwargs["generator"] = torch.Generator(device=device).manual_seed(seed) 122 | 123 | images = pipe( 124 | text, 125 | height=height, 126 | width=width, 127 | num_inference_steps=num_inference_steps, 128 | guidance_scale=guidance_scale, 129 | num_images_per_prompt=num_views, 130 | control_image=control_images, 131 | control_conditioning_scale=1.0, 132 | negative_prompt=negative_prompt, 133 | cross_attention_kwargs={"scale": lora_scale}, 134 | **pipe_kwargs, 135 | ).images 136 | 137 | return images, pos_images, normal_images 138 | 139 | 140 | if __name__ == "__main__": 141 | parser = argparse.ArgumentParser() 142 | # Models 143 | parser.add_argument( 144 | "--base_model", type=str, default="stabilityai/stable-diffusion-xl-base-1.0" 145 | ) 146 | parser.add_argument( 147 | "--vae_model", type=str, default="madebyollin/sdxl-vae-fp16-fix" 148 | ) 149 | parser.add_argument("--unet_model", type=str, default=None) 150 | parser.add_argument("--scheduler", type=str, default=None) 151 | parser.add_argument("--lora_model", type=str, default=None) 152 | parser.add_argument("--adapter_path", type=str, default="huanngzh/mv-adapter") 153 | parser.add_argument("--num_views", type=int, default=6) 154 | # Device 155 | parser.add_argument("--device", type=str, default="cuda") 156 | # Inference 157 | parser.add_argument("--mesh", type=str, required=True) 158 | parser.add_argument("--text", type=str, required=True) 159 | parser.add_argument("--num_inference_steps", type=int, default=50) 160 | parser.add_argument("--guidance_scale", type=float, default=7.0) 161 | parser.add_argument("--seed", type=int, default=-1) 162 | parser.add_argument( 163 | "--negative_prompt", 164 | type=str, 165 | default="watermark, ugly, deformed, noisy, blurry, low contrast", 166 | ) 167 | parser.add_argument("--lora_scale", type=float, default=1.0) 168 | parser.add_argument("--output", type=str, default="output.png") 169 | args = parser.parse_args() 170 | 171 | pipe = prepare_pipeline( 172 | base_model=args.base_model, 173 | vae_model=args.vae_model, 174 | unet_model=args.unet_model, 175 | lora_model=args.lora_model, 176 | adapter_path=args.adapter_path, 177 | scheduler=args.scheduler, 178 | num_views=args.num_views, 179 | device=args.device, 180 | dtype=torch.float16, 181 | ) 182 | images, pos_images, normal_images = run_pipeline( 183 | pipe, 184 | mesh_path=args.mesh, 185 | num_views=args.num_views, 186 | text=args.text, 187 | height=768, 188 | width=768, 189 | num_inference_steps=args.num_inference_steps, 190 | guidance_scale=args.guidance_scale, 191 | seed=args.seed, 192 | negative_prompt=args.negative_prompt, 193 | lora_scale=args.lora_scale, 194 | device=args.device, 195 | ) 196 | make_image_grid(images, rows=1).save(args.output) 197 | make_image_grid(pos_images, rows=1).save(args.output.rsplit(".", 1)[0] + "_pos.png") 198 | make_image_grid(normal_images, rows=1).save( 199 | args.output.rsplit(".", 1)[0] + "_nor.png" 200 | ) 201 | -------------------------------------------------------------------------------- /scripts/texture_i2tex.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import torch 6 | from torchvision import transforms 7 | from transformers import AutoModelForImageSegmentation 8 | 9 | from mvadapter.pipelines.pipeline_texture import ModProcessConfig, TexturePipeline 10 | from mvadapter.utils import make_image_grid 11 | 12 | from .inference_ig2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline 13 | 14 | if 
__name__ == "__main__": 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument("--device", type=str, default="cuda") 17 | # I/O 18 | parser.add_argument("--mesh", type=str, required=True) 19 | parser.add_argument("--image", type=str, required=True) 20 | parser.add_argument("--text", type=str, default="high quality") 21 | parser.add_argument("--seed", type=int, default=-1) 22 | parser.add_argument("--save_dir", type=str, default="./output") 23 | parser.add_argument("--save_name", type=str, default="i2tex_sample") 24 | # Extra 25 | parser.add_argument("--reference_conditioning_scale", type=float, default=1.0) 26 | parser.add_argument("--preprocess_mesh", action="store_true") 27 | parser.add_argument("--remove_bg", action="store_true") 28 | args = parser.parse_args() 29 | 30 | device = args.device 31 | num_views = 6 32 | 33 | # Prepare pipelines 34 | pipe = prepare_pipeline( 35 | base_model="stabilityai/stable-diffusion-xl-base-1.0", 36 | vae_model="madebyollin/sdxl-vae-fp16-fix", 37 | unet_model=None, 38 | lora_model=None, 39 | adapter_path="huanngzh/mv-adapter", 40 | scheduler=None, 41 | num_views=num_views, 42 | device=device, 43 | dtype=torch.float16, 44 | ) 45 | if args.remove_bg: 46 | birefnet = AutoModelForImageSegmentation.from_pretrained( 47 | "ZhengPeng7/BiRefNet", trust_remote_code=True 48 | ) 49 | birefnet.to(args.device) 50 | transform_image = transforms.Compose( 51 | [ 52 | transforms.Resize((1024, 1024)), 53 | transforms.ToTensor(), 54 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 55 | ] 56 | ) 57 | remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, args.device) 58 | else: 59 | remove_bg_fn = None 60 | 61 | texture_pipe = TexturePipeline( 62 | upscaler_ckpt_path="./checkpoints/RealESRGAN_x2plus.pth", 63 | inpaint_ckpt_path="./checkpoints/big-lama.pt", 64 | device=device, 65 | ) 66 | print("Pipeline ready.") 67 | 68 | os.makedirs(args.save_dir, exist_ok=True) 69 | 70 | # 1. run MV-Adapter to generate multi-view images 71 | images, _, _, _ = run_pipeline( 72 | pipe, 73 | mesh_path=args.mesh, 74 | num_views=num_views, 75 | text=args.text, 76 | image=args.image, 77 | height=768, 78 | width=768, 79 | num_inference_steps=50, 80 | guidance_scale=3.0, 81 | seed=args.seed, 82 | reference_conditioning_scale=args.reference_conditioning_scale, 83 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 84 | device=device, 85 | remove_bg_fn=remove_bg_fn, 86 | ) 87 | mv_path = os.path.join(args.save_dir, f"{args.save_name}.png") 88 | make_image_grid(images, rows=1).save(mv_path) 89 | 90 | torch.cuda.empty_cache() 91 | 92 | # 2. 
un-project and complete texture 93 | out = texture_pipe( 94 | mesh_path=args.mesh, 95 | save_dir=args.save_dir, 96 | save_name=args.save_name, 97 | uv_unwarp=True, 98 | preprocess_mesh=args.preprocess_mesh, 99 | uv_size=4096, 100 | rgb_path=mv_path, 101 | rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"), 102 | camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]], 103 | ) 104 | print(f"Output saved to {out.shaded_model_save_path}") 105 | -------------------------------------------------------------------------------- /scripts/texture_t2tex.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import sys 4 | 5 | import torch 6 | 7 | from mvadapter.pipelines.pipeline_texture import ModProcessConfig, TexturePipeline 8 | from mvadapter.utils import make_image_grid 9 | 10 | from .inference_tg2mv_sdxl import prepare_pipeline, run_pipeline 11 | 12 | if __name__ == "__main__": 13 | parser = argparse.ArgumentParser() 14 | parser.add_argument("--device", type=str, default="cuda") 15 | # I/O 16 | parser.add_argument("--mesh", type=str, required=True) 17 | parser.add_argument("--text", type=str, required=True) 18 | parser.add_argument("--seed", type=int, default=-1) 19 | parser.add_argument("--save_dir", type=str, default="./output") 20 | parser.add_argument("--save_name", type=str, default="t2tex_sample") 21 | # Extra 22 | parser.add_argument("--preprocess_mesh", action="store_true") 23 | args = parser.parse_args() 24 | 25 | device = args.device 26 | num_views = 6 27 | 28 | # Prepare pipelines 29 | pipe = prepare_pipeline( 30 | base_model="stabilityai/stable-diffusion-xl-base-1.0", 31 | vae_model="madebyollin/sdxl-vae-fp16-fix", 32 | unet_model=None, 33 | lora_model=None, 34 | adapter_path="huanngzh/mv-adapter", 35 | scheduler=None, 36 | num_views=num_views, 37 | device=device, 38 | dtype=torch.float16, 39 | ) 40 | texture_pipe = TexturePipeline( 41 | upscaler_ckpt_path="./checkpoints/RealESRGAN_x2plus.pth", 42 | inpaint_ckpt_path="./checkpoints/big-lama.pt", 43 | device=device, 44 | ) 45 | print("Pipeline ready.") 46 | 47 | os.makedirs(args.save_dir, exist_ok=True) 48 | 49 | # 1. run MV-Adapter to generate multi-view images 50 | images, pos_images, normal_images = run_pipeline( 51 | pipe, 52 | mesh_path=args.mesh, 53 | num_views=num_views, 54 | text=args.text, 55 | height=768, 56 | width=768, 57 | num_inference_steps=50, 58 | guidance_scale=7.0, 59 | seed=args.seed, 60 | negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast", 61 | device=device, 62 | ) 63 | mv_path = os.path.join(args.save_dir, f"{args.save_name}.png") 64 | make_image_grid(images, rows=1).save(mv_path) 65 | 66 | torch.cuda.empty_cache() 67 | 68 | # 2. 
un-project and complete texture 69 | out = texture_pipe( 70 | mesh_path=args.mesh, 71 | save_dir=args.save_dir, 72 | save_name=args.save_name, 73 | uv_unwarp=True, 74 | preprocess_mesh=args.preprocess_mesh, 75 | uv_size=4096, 76 | rgb_path=mv_path, 77 | rgb_process_config=ModProcessConfig(view_upscale=True, inpaint_mode="view"), 78 | camera_azimuth_deg=[x - 90 for x in [0, 90, 180, 270, 180, 180]], 79 | ) 80 | print(f"Output saved to {out.shaded_model_save_path}") 81 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="mvadapter", 5 | version="0.1.0", 6 | description="MV-Adapter: A Multi-View Adapter for 3D Generation", 7 | author="Zehuan Huang", 8 | packages=[ 9 | "mvadapter", 10 | "mvadapter.loaders", 11 | "mvadapter.models", 12 | "mvadapter.pipelines", 13 | "mvadapter.schedulers", 14 | "mvadapter.utils", 15 | "mvadapter.utils.mesh_utils", 16 | ], 17 | install_requires=[ 18 | "torch>=2.0.0", 19 | "torchvision", 20 | "controlnet_aux", 21 | "diffusers", 22 | "transformers", 23 | "peft", 24 | "numpy", 25 | "huggingface_hub", 26 | "accelerate", 27 | "opencv-python", 28 | "safetensors", 29 | "pillow", 30 | "omegaconf", 31 | "trimesh", 32 | "einops", 33 | "gradio", 34 | "timm", 35 | "kornia", 36 | "scikit-image", 37 | "sentencepiece", 38 | "spandrel", 39 | "open3d", 40 | "pymeshlab", 41 | "cvcuda_cu12", 42 | ], 43 | python_requires=">=3.8", 44 | classifiers=[ 45 | "Development Status :: 3 - Alpha", 46 | "Intended Audience :: Science/Research", 47 | "License :: OSI Approved :: MIT License", 48 | "Programming Language :: Python :: 3", 49 | "Programming Language :: Python :: 3.10", 50 | ], 51 | long_description=open("README.md").read(), 52 | long_description_content_type="text/markdown", 53 | url="https://github.com/huanngzh/MV-Adapter", 54 | project_urls={ 55 | "Project Page": "https://huanngzh.github.io/MV-Adapter-Page/", 56 | "Paper": "https://arxiv.org/abs/2412.03632", 57 | "Model Weights": "https://huggingface.co/huanngzh/mv-adapter", 58 | }, 59 | ) 60 | --------------------------------------------------------------------------------
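The view-guided scripts above (t2mv, i2mv, scribble2mv) all build their camera conditioning the same way; the sketch below condenses that shared setup, using the same values and helper imports as the scripts and assuming a CUDA device and 768-pixel views.

from mvadapter.utils.geometry import get_plucker_embeds_from_cameras_ortho
from mvadapter.utils.mesh_utils import get_orthogonal_camera

num_views = 6
azimuth_deg = [0, 45, 90, 180, 270, 315]

# Orthographic cameras on a ring at zero elevation; the scripts shift azimuth by -90 degrees.
cameras = get_orthogonal_camera(
    elevation_deg=[0] * num_views,
    distance=[1.8] * num_views,
    left=-0.55,
    right=0.55,
    bottom=-0.55,
    top=0.55,
    azimuth_deg=[x - 90 for x in azimuth_deg],
    device="cuda",
)

# Plücker ray embeddings, rescaled from [-1, 1] to [0, 1], become the control images.
plucker_embeds = get_plucker_embeds_from_cameras_ortho(
    cameras.c2w, [1.1] * num_views, 768
)
control_images = ((plucker_embeds + 1.0) / 2.0).clamp(0, 1)

The geometry-guided scripts (ig2mv, tg2mv) replace this raymap with rendered position and normal maps of the input mesh, as shown in their run_pipeline functions.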