├── mar3d
│   ├── utils
│   │   ├── __init__.py
│   │   ├── visualizers
│   │   │   ├── __init__.py
│   │   │   ├── html_util.py
│   │   │   └── color_util.py
│   │   ├── typing.py
│   │   ├── checkpoint.py
│   │   ├── base.py
│   │   ├── scheduler.py
│   │   ├── config.py
│   │   ├── misc.py
│   │   ├── callbacks.py
│   │   └── ops.py
│   ├── data
│   │   ├── __init__.py
│   │   └── objaverse.py
│   ├── systems
│   │   ├── __init__.py
│   │   ├── base.py
│   │   ├── diffloss.py
│   │   └── mar_diffusion.py
│   ├── models
│   │   ├── geometry
│   │   │   ├── __init__.py
│   │   │   └── base.py
│   │   ├── conditional_encoders
│   │   │   ├── __init__.py
│   │   │   ├── clip
│   │   │   │   └── modeling_conditional_clip.py
│   │   │   ├── base.py
│   │   │   └── clip_encoder.py
│   │   ├── __init__.py
│   │   ├── autoencoders
│   │   │   ├── __init__.py
│   │   │   └── utils.py
│   │   ├── transformers
│   │   │   ├── utils.py
│   │   │   ├── perceiver_1d.py
│   │   │   └── attention.py
│   │   └── diffloss.py
│   └── __init__.py
├── configs
│   ├── shape-autoencoder
│   │   └── occ.yaml
│   └── mar-diffusion
│       └── mar.yaml
├── diffusion
│   ├── __init__.py
│   ├── diffusion_utils.py
│   └── respace.py
├── train_diffusion.sh
├── README.md
└── launch.py

/mar3d/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from . import base
--------------------------------------------------------------------------------
/mar3d/data/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     objaverse
3 | )
--------------------------------------------------------------------------------
/mar3d/utils/visualizers/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
--------------------------------------------------------------------------------
/mar3d/systems/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     mar_diffusion
3 | )
--------------------------------------------------------------------------------
/mar3d/models/geometry/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     base
3 | )
4 | 
--------------------------------------------------------------------------------
/mar3d/models/conditional_encoders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     clip_encoder,
3 | )
4 | 
--------------------------------------------------------------------------------
/mar3d/models/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     autoencoders,
3 |     conditional_encoders,
4 | )
--------------------------------------------------------------------------------
/mar3d/models/autoencoders/__init__.py:
--------------------------------------------------------------------------------
1 | from . import (
2 |     michelangelo_autoencoder,
3 | )
4 | 
--------------------------------------------------------------------------------
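For orientation: the package `__init__.py` files above mainly import their submodules so that the classes defined in them register themselves with the lightweight module registry in `mar3d/__init__.py` (shown later in this dump), and the YAML configs then refer to components by registered name (e.g. `michelangelo-autoencoder`, `mar-diffusion-system`). Below is a minimal sketch of that pattern; only `register` and `find` come from the repo, while the class and the name "toy-shape-model" are hypothetical placeholders.

```python
# Hypothetical sketch of the registry pattern used throughout mar3d.
# Only mar3d.register / mar3d.find are real; the class and the name
# "toy-shape-model" are placeholders for illustration.
import mar3d


@mar3d.register("toy-shape-model")
class ToyShapeModel:
    def __init__(self, cfg=None):
        self.cfg = cfg


# Configs reference components by these registered strings
# (e.g. shape_model_type: "michelangelo-autoencoder" in the YAML files),
# and the training code resolves them via find():
shape_model_cls = mar3d.find("toy-shape-model")
shape_model = shape_model_cls(cfg={"num_latents": 256})
```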
/train_diffusion.sh:
--------------------------------------------------------------------------------
1 | export CUDA_VISIBLE_DEVICES=0 #0,1,2,3 #,4,5,6,7 #
2 | python launch.py --config ./configs/mar-diffusion/mar.yaml --train --gpu 0
3 | # python launch.py --config ./configs/image-to-shape-diffusion/clip-mvrgb-modln-l256-e64-ne8-nd16-nl6.yaml --train --gpu 0
--------------------------------------------------------------------------------
/mar3d/models/transformers/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | 
3 | def init_linear(l, stddev):
4 |     nn.init.normal_(l.weight, std=stddev)
5 |     if l.bias is not None:
6 |         nn.init.constant_(l.bias, 0.0)
7 | 
8 | class MLP(nn.Module):
9 |     def __init__(self, *,
10 |                  width: int,
11 |                  init_scale: float):
12 |         super().__init__()
13 |         self.width = width
14 |         self.c_fc = nn.Linear(width, width * 4)
15 |         self.c_proj = nn.Linear(width * 4, width)
16 |         self.gelu = nn.GELU()
17 |         init_linear(self.c_fc, init_scale)
18 |         init_linear(self.c_proj, init_scale)
19 | 
20 |     def forward(self, x):
21 |         return self.c_proj(self.gelu(self.c_fc(x)))
22 | 
--------------------------------------------------------------------------------
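The `MLP` above is the standard transformer feed-forward block (width to 4x width and back, with a GELU in between). A minimal sketch of exercising it in isolation; the batch and token counts are arbitrary, and `width=768` / `init_scale=0.25` simply mirror the values used in the configs.

```python
# Minimal sketch: running the feed-forward MLP defined above on a dummy tensor.
import torch

from mar3d.models.transformers.utils import MLP

mlp = MLP(width=768, init_scale=0.25)   # keyword-only arguments, as defined above
x = torch.randn(2, 256, 768)            # (batch, tokens, width)
y = mlp(x)                              # 768 -> 3072 -> 768 with a GELU in between
assert y.shape == x.shape
```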
/mar3d/utils/visualizers/html_util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import io
3 | import base64
4 | import numpy as np
5 | from PIL import Image
6 | 
7 | 
8 | def to_html_frame(content):
9 | 
10 |     html_frame = f"""
11 |     <html>
12 |     <body>
13 |     {content}
14 |     </body>
15 |     </html>
16 |     """
17 | 
18 |     return html_frame
19 | 
20 | 
21 | def to_single_row_table(caption: str, content: str):
22 | 
23 |     table_html = f"""
24 |     <table>
25 |         <caption>{caption}</caption>
26 |         <tr>
27 |             <td>{content}</td>
28 |         </tr>
29 |     </table>
30 |     """
31 | 
32 |     return table_html
33 | 
34 | 
35 | def to_image_embed_tag(image: np.ndarray):
36 | 
37 |     # Convert np.ndarray to bytes
38 |     img = Image.fromarray(image)
39 |     raw_bytes = io.BytesIO()
40 |     img.save(raw_bytes, "PNG")
41 | 
42 |     # Encode bytes to base64
43 |     image_base64 = base64.b64encode(raw_bytes.getvalue()).decode("utf-8")
44 | 
45 |     image_tag = f"""
46 |     <img src="data:image/png;base64,{image_base64}" alt="Embedded Image">
47 |     """
48 | 
49 |     return image_tag
50 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## [CVPR 2025 Highlight]
2 | ## **MAR-3D: Progressive Masked Auto-regressor for High-Resolution 3D Generation**
3 | 
4 | 
5 | Please prepare your preprocessed input points and supervision SDF/occupancy files following https://github.com/wyysf-98/CraftsMan3D
6 | (We are unable to release the preprocessing code and data.)
7 | 
8 | root_dir should contain the .npz files; the image file mapping is defined in objaverse.py.
9 | 
10 | The training pipeline also follows CraftsMan3D's codebase.
11 | 
12 | ## Training:
13 | 
14 | bash train_diffusion.sh
15 | 
16 | ### Links
17 | 📄 **Paper**: [arXiv:2503.20519](https://arxiv.org/abs/2503.20519)
18 | 🌐 **Project Page**: [https://jinnan-chen.github.io/projects/MAR-3D/](https://jinnan-chen.github.io/projects/MAR-3D/)
19 | 
20 | ### Authors
21 | Jinnan Chen, Lingting Zhu, Zeyu Hu, Shengju Qian, Yugang Chen, Xin Wang, Gim Hee Lee
22 | 
23 | ### Citation
24 | ```bibtex
25 | @article{chen2025mar3d,
26 |   title   = {MAR-3D: Progressive Masked Auto-regressor for High-Resolution 3D Generation},
27 |   author  = {Chen, Jinnan and Zhu, Lingting and Hu, Zeyu and Qian, Shengju and
28 |              Chen, Yugang and Wang, Xin and Lee, Gim Hee},
29 |   journal = {arXiv preprint arXiv:2503.20519},
30 |   year    = {2025}
31 | }
32 | ```
33 | 
--------------------------------------------------------------------------------
/mar3d/utils/typing.py:
--------------------------------------------------------------------------------
1 | """
2 | This module contains type annotations for the project, using
3 | 1. Python type hints (https://docs.python.org/3/library/typing.html) for Python objects
4 | 2. jaxtyping (https://github.com/google/jaxtyping/blob/main/API.md) for PyTorch tensors
5 | 
6 | Two types of type checking can be used:
7 | 1. Static type checking with mypy (install with pip and enabled as the default linter in VSCode)
8 | 2.
Runtime type checking with typeguard (install with pip and triggered at runtime, mainly for tensor dtype and shape checking) 9 | """ 10 | 11 | # Basic types 12 | from typing import ( 13 | Any, 14 | Callable, 15 | Dict, 16 | Iterable, 17 | List, 18 | Literal, 19 | NamedTuple, 20 | NewType, 21 | Optional, 22 | Sized, 23 | Tuple, 24 | Type, 25 | TypeVar, 26 | Union, 27 | Sequence, 28 | ) 29 | 30 | # Tensor dtype 31 | # for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md 32 | from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt 33 | 34 | # Config type 35 | from omegaconf import DictConfig 36 | 37 | # PyTorch Tensor type 38 | from torch import Tensor 39 | 40 | # Runtime type checking decorator 41 | from typeguard import typechecked as typechecker 42 | -------------------------------------------------------------------------------- /mar3d/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | __modules__ = {} 4 | 5 | 6 | def register(name): 7 | def decorator(cls): 8 | if name in __modules__: 9 | raise ValueError( 10 | f"Module {name} already exists! Names of extensions conflict!" 11 | ) 12 | else: 13 | __modules__[name] = cls 14 | return cls 15 | 16 | return decorator 17 | 18 | 19 | def find(name): 20 | if name in __modules__: 21 | return __modules__[name] 22 | else: 23 | try: 24 | module_string = ".".join(name.split(".")[:-1]) 25 | cls_name = name.split(".")[-1] 26 | module = importlib.import_module(module_string, package=None) 27 | return getattr(module, cls_name) 28 | except Exception as e: 29 | raise ValueError(f"Module {name} not found!") 30 | 31 | 32 | ### grammar sugar for logging utilities ### 33 | import logging 34 | 35 | logger = logging.getLogger("pytorch_lightning") 36 | 37 | from pytorch_lightning.utilities.rank_zero import ( 38 | rank_zero_debug, 39 | rank_zero_info, 40 | rank_zero_only, 41 | ) 42 | 43 | debug = rank_zero_debug 44 | info = rank_zero_info 45 | 46 | 47 | @rank_zero_only 48 | def warn(*args, **kwargs): 49 | logger.warn(*args, **kwargs) 50 | 51 | 52 | from . 
import data, models, systems 53 | -------------------------------------------------------------------------------- /mar3d/models/transformers/perceiver_1d.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mar3d.utils.typing import * 7 | from mar3d.utils.checkpoint import checkpoint 8 | 9 | from .utils import init_linear 10 | from .attention import ResidualAttentionBlock 11 | 12 | 13 | class Perceiver(nn.Module): 14 | def __init__( 15 | self, 16 | *, 17 | n_ctx: int, 18 | width: int, 19 | layers: int, 20 | heads: int, 21 | init_scale: float = 0.25, 22 | qkv_bias: bool = True, 23 | use_flash: bool = False, 24 | use_checkpoint: bool = False 25 | ): 26 | super().__init__() 27 | self.n_ctx = n_ctx 28 | self.width = width 29 | self.layers = layers 30 | self.resblocks = nn.ModuleList( 31 | [ 32 | ResidualAttentionBlock( 33 | n_ctx=n_ctx, 34 | width=width, 35 | heads=heads, 36 | init_scale=init_scale, 37 | qkv_bias=qkv_bias, 38 | use_flash=use_flash, 39 | use_checkpoint=use_checkpoint 40 | ) 41 | for _ in range(layers) 42 | ] 43 | ) 44 | 45 | def forward(self, x: torch.Tensor): 46 | for block in self.resblocks: 47 | x = block(x) 48 | return x -------------------------------------------------------------------------------- /mar3d/utils/visualizers/color_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | 4 | 5 | # Helper functions 6 | def get_colors(inp, colormap="viridis", normalize=True, vmin=None, vmax=None): 7 | colormap = plt.cm.get_cmap(colormap) 8 | if normalize: 9 | vmin = np.min(inp) 10 | vmax = np.max(inp) 11 | 12 | norm = plt.Normalize(vmin, vmax) 13 | return colormap(norm(inp))[:, :3] 14 | 15 | 16 | def gen_checkers(n_checkers_x, n_checkers_y, width=256, height=256): 17 | # tex dims need to be power of two. 18 | array = np.ones((width, height, 3), dtype='float32') 19 | 20 | # width in texels of each checker 21 | checker_w = width / n_checkers_x 22 | checker_h = height / n_checkers_y 23 | 24 | for y in range(height): 25 | for x in range(width): 26 | color_key = int(x / checker_w) + int(y / checker_h) 27 | if color_key % 2 == 0: 28 | array[x, y, :] = [1., 0.874, 0.0] 29 | else: 30 | array[x, y, :] = [0., 0., 0.] 31 | return array 32 | 33 | 34 | def gen_circle(width=256, height=256): 35 | xx, yy = np.mgrid[:width, :height] 36 | circle = (xx - width / 2 + 0.5) ** 2 + (yy - height / 2 + 0.5) ** 2 37 | array = np.ones((width, height, 4), dtype='float32') 38 | array[:, :, 0] = (circle <= width) 39 | array[:, :, 1] = (circle <= width) 40 | array[:, :, 2] = (circle <= width) 41 | array[:, :, 3] = circle <= width 42 | return array 43 | 44 | -------------------------------------------------------------------------------- /diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | # Adopted from DiT, which is modified from OpenAI's diffusion repos 2 | # DiT: https://github.com/facebookresearch/DiT/diffusion 3 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 4 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 5 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 6 | 7 | from . 
import gaussian_diffusion as gd 8 | from .respace import SpacedDiffusion, space_timesteps 9 | 10 | 11 | def create_diffusion( 12 | timestep_respacing, 13 | noise_schedule="linear", 14 | use_kl=False, 15 | sigma_small=False, 16 | predict_xstart=False, 17 | learn_sigma=True, 18 | rescale_learned_sigmas=False, 19 | diffusion_steps=1000 20 | ): 21 | betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) 22 | if use_kl: 23 | loss_type = gd.LossType.RESCALED_KL 24 | elif rescale_learned_sigmas: 25 | loss_type = gd.LossType.RESCALED_MSE 26 | else: 27 | loss_type = gd.LossType.MSE 28 | if timestep_respacing is None or timestep_respacing == "": 29 | timestep_respacing = [diffusion_steps] 30 | return SpacedDiffusion( 31 | use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), 32 | betas=betas, 33 | model_mean_type=( 34 | gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X 35 | ), 36 | model_var_type=( 37 | ( 38 | gd.ModelVarType.FIXED_LARGE 39 | if not sigma_small 40 | else gd.ModelVarType.FIXED_SMALL 41 | ) 42 | if not learn_sigma 43 | else gd.ModelVarType.LEARNED_RANGE 44 | ), 45 | loss_type=loss_type 46 | # rescale_timesteps=rescale_timesteps, 47 | ) 48 | -------------------------------------------------------------------------------- /configs/shape-autoencoder/occ.yaml: -------------------------------------------------------------------------------- 1 | exp_root_dir: "outputs" 2 | name: "michelangelo-autoencoder/mul-l64-256-e64-ne8-nd16" 3 | tag: "${rmspace:n${data.n_samples}+${data.supervision_type}+rot${data.rotate}+noise${data.noise_sigma}+${system.shape_model.embed_type}+dsample${system.shape_model.use_downsample}+pfeat${system.shape_model.point_feats}+logits${system.loss.lambda_logits}+kl${system.loss.lambda_kl}+lr${system.optimizer.args.lr},_}" 4 | seed: 0 5 | # resume: model-64.ckpt 6 | data_type: "objaverse-datamodule" 7 | data: 8 | root_dir: "" 9 | data_type: "occupancy" 10 | n_samples: 20480 11 | noise_sigma: 0. 12 | rotate: False 13 | load_supervision: True 14 | supervision_type: "occupancy" 15 | n_supervision: 25600 16 | load_image: False # whether to load images 17 | load_caption: False # whether to load captions 18 | batch_size: 32 19 | num_workers: 8 20 | 21 | system_type: "shape-autoencoder-system" 22 | system: 23 | sample_posterior: true 24 | shape_model_type: "michelangelo-autoencoder" 25 | shape_model: 26 | num_latents: 256 # 1024 27 | embed_dim: 64 28 | point_feats: 3 # xyz + normal 29 | out_dim: 1 # only occupancy 30 | embed_type: "fourier" 31 | num_freqs: 8 32 | include_pi: false 33 | heads: 12 34 | width: 768 35 | num_encoder_layers: 8 36 | num_decoder_layers: 16 37 | use_ln_post: true 38 | init_scale: 0.25 39 | qkv_bias: true 40 | use_flash: true 41 | use_checkpoint: true 42 | use_downsample: false 43 | 44 | loggers: 45 | wandb: 46 | enable: false 47 | project: "mar3d" 48 | name: shape-autoencoder+${name}+${tag} 49 | 50 | loss: 51 | lambda_logits: 1. 52 | lambda_kl: 0.001 53 | 54 | optimizer: 55 | name: AdamW 56 | args: 57 | lr: 1.e-4 58 | betas: [0.9, 0.99] 59 | eps: 1.e-6 60 | 61 | scheduler: 62 | name: SequentialLR 63 | interval: step 64 | schedulers: 65 | - name: LinearLR 66 | interval: step 67 | args: 68 | start_factor: 1e-6 69 | end_factor: 1.0 70 | total_iters: 5000 71 | - name: CosineAnnealingLR 72 | interval: step 73 | args: 74 | T_max: 5000 75 | eta_min: 0. 
76 | milestones: [5000] 77 | 78 | trainer: 79 | num_nodes: 1 80 | max_epochs: 100000 81 | log_every_n_steps: 10 82 | num_sanity_val_steps: 0 83 | # val_check_interval: 200 84 | check_val_every_n_epoch: 5 85 | enable_progress_bar: true 86 | precision: 16-mixed 87 | strategy: 'ddp_find_unused_parameters_true' 88 | 89 | checkpoint: 90 | save_last: true 91 | save_top_k: -1 92 | every_n_train_steps: 10000 93 | -------------------------------------------------------------------------------- /mar3d/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | """ 3 | Adapted from: https://github.com/openai/guided-diffusion/blob/22e0df8183507e13a7813f8d38d51b072ca1e67c/guided_diffusion/nn.py#L124 4 | """ 5 | 6 | import torch 7 | from mar3d.utils.typing import * 8 | 9 | def checkpoint( 10 | func: Callable[..., Union[torch.Tensor, Sequence[torch.Tensor]]], 11 | inputs: Sequence[torch.Tensor], 12 | params: Iterable[torch.Tensor], 13 | flag: bool, 14 | use_deepspeed: bool = False 15 | ): 16 | """ 17 | Evaluate a function without caching intermediate activations, allowing for 18 | reduced memory at the expense of extra compute in the backward pass. 19 | :param func: the function to evaluate. 20 | :param inputs: the argument sequence to pass to `func`. 21 | :param params: a sequence of parameters `func` depends on but does not 22 | explicitly take as arguments. 23 | :param flag: if False, disable gradient checkpointing. 24 | :param use_deepspeed: if True, use deepspeed 25 | """ 26 | if flag: 27 | if use_deepspeed: 28 | import deepspeed 29 | return deepspeed.checkpointing.checkpoint(func, *inputs) 30 | 31 | args = tuple(inputs) + tuple(params) 32 | return CheckpointFunction.apply(func, len(inputs), *args) 33 | else: 34 | return func(*inputs) 35 | 36 | 37 | class CheckpointFunction(torch.autograd.Function): 38 | @staticmethod 39 | @torch.cuda.amp.custom_fwd 40 | def forward(ctx, run_function, length, *args): 41 | ctx.run_function = run_function 42 | ctx.input_tensors = list(args[:length]) 43 | ctx.input_params = list(args[length:]) 44 | 45 | with torch.no_grad(): 46 | output_tensors = ctx.run_function(*ctx.input_tensors) 47 | return output_tensors 48 | 49 | @staticmethod 50 | @torch.cuda.amp.custom_bwd 51 | def backward(ctx, *output_grads): 52 | ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] 53 | with torch.enable_grad(): 54 | # Fixes a bug where the first op in run_function modifies the 55 | # Tensor storage in place, which is not allowed for detach()'d 56 | # Tensors. 
57 | shallow_copies = [x.view_as(x) for x in ctx.input_tensors] 58 | output_tensors = ctx.run_function(*shallow_copies) 59 | input_grads = torch.autograd.grad( 60 | output_tensors, 61 | ctx.input_tensors + ctx.input_params, 62 | output_grads, 63 | allow_unused=True, 64 | ) 65 | del ctx.input_tensors 66 | del ctx.input_params 67 | del output_tensors 68 | return (None, None) + input_grads -------------------------------------------------------------------------------- /diffusion/diffusion_utils.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import torch as th 7 | import numpy as np 8 | 9 | 10 | def normal_kl(mean1, logvar1, mean2, logvar2): 11 | """ 12 | Compute the KL divergence between two gaussians. 13 | Shapes are automatically broadcasted, so batches can be compared to 14 | scalars, among other use cases. 15 | """ 16 | tensor = None 17 | for obj in (mean1, logvar1, mean2, logvar2): 18 | if isinstance(obj, th.Tensor): 19 | tensor = obj 20 | break 21 | assert tensor is not None, "at least one argument must be a Tensor" 22 | 23 | # Force variances to be Tensors. Broadcasting helps convert scalars to 24 | # Tensors, but it does not work for th.exp(). 25 | logvar1, logvar2 = [ 26 | x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) 27 | for x in (logvar1, logvar2) 28 | ] 29 | 30 | return 0.5 * ( 31 | -1.0 32 | + logvar2 33 | - logvar1 34 | + th.exp(logvar1 - logvar2) 35 | + ((mean1 - mean2) ** 2) * th.exp(-logvar2) 36 | ) 37 | 38 | 39 | def approx_standard_normal_cdf(x): 40 | """ 41 | A fast approximation of the cumulative distribution function of the 42 | standard normal. 43 | """ 44 | return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) 45 | 46 | 47 | def discretized_gaussian_log_likelihood(x, *, means, log_scales): 48 | """ 49 | Compute the log-likelihood of a Gaussian distribution discretizing to a 50 | given image. 51 | :param x: the target images. It is assumed that this was uint8 values, 52 | rescaled to the range [-1, 1]. 53 | :param means: the Gaussian mean Tensor. 54 | :param log_scales: the Gaussian log stddev Tensor. 55 | :return: a tensor like x of log probabilities (in nats). 
56 | """ 57 | assert x.shape == means.shape == log_scales.shape 58 | centered_x = x - means 59 | inv_stdv = th.exp(-log_scales) 60 | plus_in = inv_stdv * (centered_x + 1.0 / 255.0) 61 | cdf_plus = approx_standard_normal_cdf(plus_in) 62 | min_in = inv_stdv * (centered_x - 1.0 / 255.0) 63 | cdf_min = approx_standard_normal_cdf(min_in) 64 | log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) 65 | log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) 66 | cdf_delta = cdf_plus - cdf_min 67 | log_probs = th.where( 68 | x < -0.999, 69 | log_cdf_plus, 70 | th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), 71 | ) 72 | assert log_probs.shape == x.shape 73 | return log_probs 74 | -------------------------------------------------------------------------------- /configs/mar-diffusion/mar.yaml: -------------------------------------------------------------------------------- 1 | exp_root_dir: "outputs" 2 | name: "image-to-shape-diffusion/mar" 3 | tag: "${rmspace:${system.shape_model_type}+n${data.n_samples}+noise${data.noise_sigma}+pfeat${system.shape_model.point_feats}+normemb${system.condition_model.normalize_embeds}+lr${system.optimizer.args.lr}+qkvbias${system.shape_model.qkv_bias}+nfreq${system.shape_model.num_freqs}+ln_post${system.shape_model.use_ln_post},_}" 4 | seed: 0 5 | data_type: "objaverse-datamodule" 6 | data: 7 | root_dir: [""] #sample points and sdf/occ.npz 8 | data_type: "sdf" 9 | n_samples: 20480 10 | noise_sigma: 0. 11 | load_supervision: false 12 | supervision_type: "occupancy" 13 | n_supervision: 4096 14 | load_image: True # whether to load images 15 | image_data_path: "" 16 | image_type: "rgb" # rgb, normal, mvrgb, mvnormal 17 | idx: [0,1,2,3,4,5,6,7] 18 | n_views: 1 19 | load_caption: false # whether to load captions 20 | batch_size: 12 21 | num_workers: 8 22 | 23 | system_type: "mar-diffusion-system" 24 | system: 25 | data_type: ${data_type} 26 | shape_model_type: "michelangelo-autoencoder" 27 | shape_model: 28 | num_latents: 256 #1024 #512 #256 29 | embed_dim: 64 30 | point_feats: 3 31 | out_dim: 1 32 | num_freqs: 8 33 | include_pi: false 34 | heads: 12 35 | width: 768 36 | num_encoder_layers: 16 #16 #8 37 | num_decoder_layers: 32 #32 #16 38 | use_ln_post: true 39 | init_scale: 0.25 40 | qkv_bias: true 41 | use_flash: true 42 | use_checkpoint: true 43 | use_downsample: false 44 | 45 | condition_model_type: "clip-embedder" 46 | condition_model: 47 | pretrained_model_name_or_path: "openai/clip-vit-large-patch14" 48 | encode_camera: false 49 | camera_embeds_dim: 32 # 16 * 2[sin, cos] 50 | n_views: ${data.n_views} 51 | empty_embeds_ratio: 0.1 52 | normalize_embeds: false 53 | dino: true 54 | zero_uncond_embeds: false 55 | loggers: 56 | wandb: 57 | enable: false 58 | project: "mar3d" 59 | name: image-to-shape-diffusion+${name}+${tag} 60 | loss: 61 | loss_type: "mse" 62 | lambda_diffusion: 1. 63 | optimizer: 64 | name: AdamW 65 | args: 66 | lr: 1.e-4 67 | betas: [0.9, 0.99] 68 | eps: 1.e-6 69 | scheduler: 70 | name: SequentialLR 71 | interval: step 72 | schedulers: 73 | - name: LinearLR 74 | interval: step 75 | args: 76 | start_factor: 1e-6 77 | end_factor: 1.0 78 | total_iters: 5000 79 | - name: CosineAnnealingLR 80 | interval: step 81 | args: 82 | T_max: 5000 83 | eta_min: 0. 
84 | milestones: [5000] 85 | 86 | trainer: 87 | num_nodes: 1 88 | max_epochs: 100000 89 | log_every_n_steps: 1 90 | num_sanity_val_steps: 0 91 | check_val_every_n_epoch: 2 92 | enable_progress_bar: true 93 | precision: 16-mixed 94 | strategy: 'ddp_find_unused_parameters_true' 95 | 96 | checkpoint: 97 | save_last: true 98 | save_top_k: -1 99 | every_n_train_steps: 5000 -------------------------------------------------------------------------------- /mar3d/utils/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | 6 | from mar3d.utils.config import parse_structured 7 | from mar3d.utils.misc import get_device, load_module_weights 8 | from mar3d.utils.typing import * 9 | 10 | 11 | class Configurable: 12 | @dataclass 13 | class Config: 14 | pass 15 | 16 | def __init__(self, cfg: Optional[dict] = None) -> None: 17 | super().__init__() 18 | self.cfg = parse_structured(self.Config, cfg) 19 | 20 | 21 | class Updateable: 22 | def do_update_step( 23 | self, epoch: int, global_step: int, on_load_weights: bool = False 24 | ): 25 | for attr in self.__dir__(): 26 | if attr.startswith("_"): 27 | continue 28 | try: 29 | module = getattr(self, attr) 30 | except: 31 | continue # ignore attributes like property, which can't be retrived using getattr? 32 | if isinstance(module, Updateable): 33 | module.do_update_step( 34 | epoch, global_step, on_load_weights=on_load_weights 35 | ) 36 | self.update_step(epoch, global_step, on_load_weights=on_load_weights) 37 | 38 | def do_update_step_end(self, epoch: int, global_step: int): 39 | for attr in self.__dir__(): 40 | if attr.startswith("_"): 41 | continue 42 | try: 43 | module = getattr(self, attr) 44 | except: 45 | continue # ignore attributes like property, which can't be retrived using getattr? 
46 | if isinstance(module, Updateable): 47 | module.do_update_step_end(epoch, global_step) 48 | self.update_step_end(epoch, global_step) 49 | 50 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 51 | # override this method to implement custom update logic 52 | # if on_load_weights is True, you should be careful doing things related to model evaluations, 53 | # as the models and tensors are not guarenteed to be on the same device 54 | pass 55 | 56 | def update_step_end(self, epoch: int, global_step: int): 57 | pass 58 | 59 | 60 | def update_if_possible(module: Any, epoch: int, global_step: int) -> None: 61 | if isinstance(module, Updateable): 62 | module.do_update_step(epoch, global_step) 63 | 64 | 65 | def update_end_if_possible(module: Any, epoch: int, global_step: int) -> None: 66 | if isinstance(module, Updateable): 67 | module.do_update_step_end(epoch, global_step) 68 | 69 | 70 | class BaseObject(Updateable): 71 | @dataclass 72 | class Config: 73 | pass 74 | 75 | cfg: Config # add this to every subclass of BaseObject to enable static type checking 76 | 77 | def __init__( 78 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 79 | ) -> None: 80 | super().__init__() 81 | self.cfg = parse_structured(self.Config, cfg) 82 | self.device = get_device() 83 | self.configure(*args, **kwargs) 84 | 85 | def configure(self, *args, **kwargs) -> None: 86 | pass 87 | 88 | 89 | class BaseModule(nn.Module, Updateable): 90 | @dataclass 91 | class Config: 92 | weights: Optional[str] = None 93 | 94 | cfg: Config # add this to every subclass of BaseModule to enable static type checking 95 | 96 | def __init__( 97 | self, cfg: Optional[Union[dict, DictConfig]] = None, *args, **kwargs 98 | ) -> None: 99 | super().__init__() 100 | self.cfg = parse_structured(self.Config, cfg) 101 | self.device = get_device() 102 | self.configure(*args, **kwargs) 103 | if self.cfg.weights is not None: 104 | # format: path/to/weights:module_name 105 | weights_path, module_name = self.cfg.weights.split(":") 106 | state_dict, epoch, global_step = load_module_weights( 107 | weights_path, module_name=module_name, map_location="cpu" 108 | ) 109 | self.load_state_dict(state_dict) 110 | self.do_update_step( 111 | epoch, global_step, on_load_weights=True 112 | ) # restore states 113 | # dummy tensor to indicate model state 114 | self._dummy: Float[Tensor, "..."] 115 | self.register_buffer("_dummy", torch.zeros(0).float(), persistent=False) 116 | 117 | def configure(self, *args, **kwargs) -> None: 118 | pass 119 | -------------------------------------------------------------------------------- /mar3d/utils/scheduler.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import warnings 3 | from bisect import bisect_right 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch.optim import lr_scheduler 8 | 9 | import mar3d 10 | 11 | 12 | def get_scheduler(name): 13 | if hasattr(lr_scheduler, name): 14 | return getattr(lr_scheduler, name) 15 | else: 16 | raise NotImplementedError 17 | 18 | 19 | def getattr_recursive(m, attr): 20 | for name in attr.split("."): 21 | m = getattr(m, name) 22 | return m 23 | 24 | 25 | def get_parameters(model, name): 26 | module = getattr_recursive(model, name) 27 | if isinstance(module, nn.Module): 28 | return module.parameters() 29 | elif isinstance(module, nn.Parameter): 30 | return module 31 | return [] 32 | from itertools import chain 33 | 34 | def parse_optimizer_control(config, model): 35 | # 
import ipdb 36 | # ipdb.set_trace() 37 | params = model.denoiser_model.blocks[0].cross.parameters() 38 | for num in range(len(model.denoiser_model.blocks)): 39 | if num>0: 40 | params=chain(params,model.denoiser_model.blocks[num].cross.parameters()) 41 | 42 | # if hasattr(config, "params"): 43 | # params = [ 44 | # {"params": get_parameters(model, name), "name": name, **args} 45 | # for name, args in config.params.items() 46 | # ] 47 | # mar3d.debug(f"Specify optimizer params: {config.params}") 48 | # else: 49 | # params = model.parameters() 50 | if config.name in ["FusedAdam"]: 51 | import apex 52 | optim = getattr(apex.optimizers, config.name)(params, **config.args) 53 | elif config.name in ["Adan"]: 54 | from mar3d.systems import optimizers 55 | 56 | optim = getattr(optimizers, config.name)(params, **config.args) 57 | else: 58 | optim = getattr(torch.optim, config.name)(params, **config.args) 59 | return optim 60 | 61 | def parse_optimizer(config, model): 62 | if hasattr(config, "params"): 63 | params = [ 64 | {"params": get_parameters(model, name), "name": name, **args} 65 | for name, args in config.params.items() 66 | ] 67 | mar3d.debug(f"Specify optimizer params: {config.params}") 68 | else: 69 | # params = model.parameters() 70 | params= [p for p in model.parameters() if p.requires_grad] 71 | # if config.name in ["FusedAdam"]: 72 | # import apex 73 | 74 | # optim = getattr(apex.optimizers, config.name)(params, **config.args) 75 | # elif config.name in ["Adan"]: 76 | # from mar3d.systems import optimizers 77 | 78 | # optim = getattr(optimizers, config.name)(params, **config.args) 79 | # else: 80 | optim = getattr(torch.optim, config.name)(params, **config.args) 81 | return optim 82 | 83 | 84 | def parse_scheduler_to_instance(config, optimizer): 85 | if config.name == "ChainedScheduler": 86 | schedulers = [ 87 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 88 | ] 89 | scheduler = lr_scheduler.ChainedScheduler(schedulers) 90 | elif config.name == "Sequential": 91 | schedulers = [ 92 | parse_scheduler_to_instance(conf, optimizer) for conf in config.schedulers 93 | ] 94 | scheduler = lr_scheduler.SequentialLR( 95 | optimizer, schedulers, milestones=config.milestones 96 | ) 97 | else: 98 | scheduler = getattr(lr_scheduler, config.name)(optimizer, **config.args) 99 | return scheduler 100 | 101 | 102 | def parse_scheduler(config, optimizer): 103 | interval = config.get("interval", "epoch") 104 | assert interval in ["epoch", "step"] 105 | if config.name == "SequentialLR": 106 | scheduler = { 107 | "scheduler": lr_scheduler.SequentialLR( 108 | optimizer, 109 | [ 110 | parse_scheduler(conf, optimizer)["scheduler"] 111 | for conf in config.schedulers 112 | ], 113 | milestones=config.milestones, 114 | ), 115 | "interval": interval, 116 | } 117 | elif config.name == "ChainedScheduler": 118 | scheduler = { 119 | "scheduler": lr_scheduler.ChainedScheduler( 120 | [ 121 | parse_scheduler(conf, optimizer)["scheduler"] 122 | for conf in config.schedulers 123 | ] 124 | ), 125 | "interval": interval, 126 | } 127 | else: 128 | scheduler = { 129 | "scheduler": get_scheduler(config.name)(optimizer, **config.args), 130 | "interval": interval, 131 | } 132 | return scheduler 133 | -------------------------------------------------------------------------------- /mar3d/utils/config.py: -------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | from datetime import datetime 4 | 5 | from omegaconf 
import OmegaConf 6 | 7 | import mar3d 8 | from mar3d.utils.typing import * 9 | 10 | # ============ Register OmegaConf Recolvers ============= # 11 | OmegaConf.register_new_resolver( 12 | "calc_exp_lr_decay_rate", lambda factor, n: factor ** (1.0 / n) 13 | ) 14 | OmegaConf.register_new_resolver("add", lambda a, b: a + b) 15 | OmegaConf.register_new_resolver("sub", lambda a, b: a - b) 16 | OmegaConf.register_new_resolver("mul", lambda a, b: a * b) 17 | OmegaConf.register_new_resolver("div", lambda a, b: a / b) 18 | OmegaConf.register_new_resolver("idiv", lambda a, b: a // b) 19 | OmegaConf.register_new_resolver("basename", lambda p: os.path.basename(p)) 20 | OmegaConf.register_new_resolver("rmspace", lambda s, sub: str(s).replace(" ", sub)) 21 | OmegaConf.register_new_resolver("tuple2", lambda s: [float(s), float(s)]) 22 | OmegaConf.register_new_resolver("gt0", lambda s: s > 0) 23 | OmegaConf.register_new_resolver("cmaxgt0", lambda s: C_max(s) > 0) 24 | OmegaConf.register_new_resolver("not", lambda s: not s) 25 | OmegaConf.register_new_resolver( 26 | "cmaxgt0orcmaxgt0", lambda a, b: C_max(a) > 0 or C_max(b) > 0 27 | ) 28 | # ======================================================= # 29 | 30 | 31 | def C_max(value: Any) -> float: 32 | if isinstance(value, int) or isinstance(value, float): 33 | pass 34 | else: 35 | value = config_to_primitive(value) 36 | if not isinstance(value, list): 37 | raise TypeError("Scalar specification only supports list, got", type(value)) 38 | if len(value) >= 6: 39 | max_value = value[2] 40 | for i in range(4, len(value), 2): 41 | max_value = max(max_value, value[i]) 42 | value = [value[0], value[1], max_value, value[3]] 43 | if len(value) == 3: 44 | value = [0] + value 45 | assert len(value) == 4 46 | start_step, start_value, end_value, end_step = value 47 | value = max(start_value, end_value) 48 | return value 49 | 50 | 51 | @dataclass 52 | class ExperimentConfig: 53 | name: str = "default" 54 | description: str = "" 55 | tag: str = "" 56 | seed: int = 0 57 | use_timestamp: bool = True 58 | timestamp: Optional[str] = None 59 | exp_root_dir: str = "outputs" 60 | 61 | ### these shouldn't be set manually 62 | exp_dir: str = "outputs/default" 63 | trial_name: str = "exp" 64 | trial_dir: str = "outputs/default/exp" 65 | n_gpus: int = 1 66 | ### 67 | 68 | resume: Optional[str] = None 69 | 70 | data_type: str = "" 71 | data: dict = field(default_factory=dict) 72 | 73 | system_type: str = "" 74 | system: dict = field(default_factory=dict) 75 | 76 | # accept pytorch-lightning trainer parameters 77 | # see https://lightning.ai/docs/pytorch/stable/common/trainer.html#trainer-class-api 78 | trainer: dict = field(default_factory=dict) 79 | 80 | # accept pytorch-lightning checkpoint callback parameters 81 | # see https://lightning.ai/docs/pytorch/stable/api/lightning.pytorch.callbacks.ModelCheckpoint.html#modelcheckpoint 82 | checkpoint: dict = field(default_factory=dict) 83 | 84 | def __post_init__(self): 85 | if not self.tag and not self.use_timestamp: 86 | raise ValueError("Either tag is specified or use_timestamp is True.") 87 | self.trial_name = self.tag 88 | # if resume from an existing config, self.timestamp should not be None 89 | if self.timestamp is None: 90 | self.timestamp = "" 91 | if self.use_timestamp: 92 | if self.n_gpus > 1: 93 | mar3d.warn( 94 | "Timestamp is disabled when using multiple GPUs, please make sure you have a unique tag." 
95 | ) 96 | else: 97 | self.timestamp = datetime.now().strftime("@%Y%m%d-%H%M%S") 98 | self.trial_name += self.timestamp 99 | self.exp_dir = os.path.join(self.exp_root_dir, self.name) 100 | self.trial_dir = os.path.join(self.exp_dir, self.trial_name) 101 | os.makedirs(self.trial_dir, exist_ok=True) 102 | 103 | 104 | def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any: 105 | if from_string: 106 | yaml_confs = [OmegaConf.create(s) for s in yamls] 107 | else: 108 | yaml_confs = [OmegaConf.load(f) for f in yamls] 109 | cli_conf = OmegaConf.from_cli(cli_args) 110 | cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs) 111 | OmegaConf.resolve(cfg) 112 | assert isinstance(cfg, DictConfig) 113 | scfg = parse_structured(ExperimentConfig, cfg) 114 | return scfg 115 | 116 | 117 | def config_to_primitive(config, resolve: bool = True) -> Any: 118 | return OmegaConf.to_container(config, resolve=resolve) 119 | 120 | 121 | def dump_config(path: str, config) -> None: 122 | with open(path, "w") as fp: 123 | OmegaConf.save(config=config, f=fp) 124 | 125 | 126 | def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any: 127 | scfg = OmegaConf.structured(fields(**cfg)) 128 | return scfg -------------------------------------------------------------------------------- /mar3d/utils/misc.py: -------------------------------------------------------------------------------- 1 | import gc 2 | import os 3 | import re 4 | 5 | import torch 6 | import torch.distributed as dist 7 | from packaging import version 8 | 9 | from mar3d.utils.config import config_to_primitive 10 | from mar3d.utils.typing import * 11 | 12 | 13 | 14 | def parse_version(ver: str): 15 | return version.parse(ver) 16 | 17 | 18 | def get_rank(): 19 | # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing, 20 | # therefore LOCAL_RANK needs to be checked first 21 | rank_keys = ("RANK", "LOCAL_RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK") 22 | for key in rank_keys: 23 | rank = os.environ.get(key) 24 | if rank is not None: 25 | return int(rank) 26 | return 0 27 | 28 | def get_world_size(): 29 | world_size_keys = ("WORLD_SIZE", "SLURM_NTASKS", "JSM_NAMESPACE_SIZE") 30 | for key in world_size_keys: 31 | world_size = os.environ.get(key) 32 | if world_size is not None: 33 | return int(world_size) 34 | return 1 35 | 36 | def get_device(): 37 | return torch.device(f"cuda:{get_rank()}") 38 | 39 | 40 | def load_module_weights( 41 | path, module_name=None, ignore_modules=None, map_location=None 42 | ) -> Tuple[dict, int, int]: 43 | if module_name is not None and ignore_modules is not None: 44 | raise ValueError("module_name and ignore_modules cannot be both set") 45 | if map_location is None: 46 | map_location = get_device() 47 | 48 | ckpt = torch.load(path, map_location=map_location) 49 | state_dict = ckpt["state_dict"] 50 | state_dict_to_load = state_dict 51 | 52 | if ignore_modules is not None: 53 | state_dict_to_load = {} 54 | for k, v in state_dict.items(): 55 | ignore = any( 56 | [k.startswith(ignore_module + ".") for ignore_module in ignore_modules] 57 | ) 58 | if ignore: 59 | continue 60 | state_dict_to_load[k] = v 61 | 62 | if module_name is not None: 63 | state_dict_to_load = {} 64 | for k, v in state_dict.items(): 65 | m = re.match(rf"^{module_name}\.(.*)$", k) 66 | if m is None: 67 | continue 68 | state_dict_to_load[m.group(1)] = v 69 | 70 | return state_dict_to_load, ckpt["epoch"], ckpt["global_step"] 71 | 72 | 73 | def C(value: Any, epoch: int, global_step: int) -> 
float: 74 | if isinstance(value, int) or isinstance(value, float): 75 | pass 76 | else: 77 | value = config_to_primitive(value) 78 | if not isinstance(value, list): 79 | raise TypeError("Scalar specification only supports list, got", type(value)) 80 | if len(value) == 3: 81 | value = [0] + value 82 | assert len(value) == 4 83 | start_step, start_value, end_value, end_step = value 84 | if isinstance(end_step, int): 85 | current_step = global_step 86 | value = start_value + (end_value - start_value) * max( 87 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 88 | ) 89 | elif isinstance(end_step, float): 90 | current_step = epoch 91 | value = start_value + (end_value - start_value) * max( 92 | min(1.0, (current_step - start_step) / (end_step - start_step)), 0.0 93 | ) 94 | return value 95 | 96 | 97 | def cleanup(): 98 | gc.collect() 99 | torch.cuda.empty_cache() 100 | 101 | 102 | 103 | def finish_with_cleanup(func: Callable): 104 | def wrapper(*args, **kwargs): 105 | out = func(*args, **kwargs) 106 | cleanup() 107 | return out 108 | 109 | return wrapper 110 | 111 | 112 | def _distributed_available(): 113 | return torch.distributed.is_available() and torch.distributed.is_initialized() 114 | 115 | 116 | def barrier(): 117 | if not _distributed_available(): 118 | return 119 | else: 120 | torch.distributed.barrier() 121 | 122 | 123 | def broadcast(tensor, src=0): 124 | if not _distributed_available(): 125 | return tensor 126 | else: 127 | torch.distributed.broadcast(tensor, src=src) 128 | return tensor 129 | 130 | 131 | def enable_gradient(model, enabled: bool = True) -> None: 132 | for param in model.parameters(): 133 | param.requires_grad_(enabled) 134 | 135 | 136 | def all_gather_batch(tensors): 137 | """ 138 | Performs all_gather operation on the provided tensors. 
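    Accepts either a single tensor or a list of tensors and returns the gathered
    result(s) concatenated along dim 0; with a world size of 1 the input is returned
    unchanged. A hedged usage sketch (assumes an already-initialized process group):

        feats = torch.randn(8, 128, device="cuda")   # per-rank batch
        all_feats = all_gather_batch(feats)          # shape [8 * world_size, 128]

    Note that tensors collected through dist.all_gather do not carry gradients from
    other ranks.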
139 | """ 140 | # Queue the gathered tensors 141 | world_size = get_world_size() 142 | # There is no need for reduction in the single-proc case 143 | if world_size == 1: 144 | if isinstance(tensors, list): 145 | return tensors 146 | return tensors 147 | if not isinstance(tensors, list): 148 | is_list = False 149 | tensors = [tensors] 150 | else: 151 | is_list = True 152 | output_tensor = [] 153 | tensor_list = [] 154 | for tensor in tensors: 155 | tensor_all = [torch.ones_like(tensor) for _ in range(world_size)] 156 | dist.all_gather( 157 | tensor_all, 158 | tensor, 159 | async_op=False # performance opt 160 | ) 161 | 162 | tensor_list.append(tensor_all) 163 | 164 | for tensor_all in tensor_list: 165 | output_tensor.append(torch.cat(tensor_all, dim=0)) 166 | if not is_list: 167 | return output_tensor[0] 168 | return output_tensor -------------------------------------------------------------------------------- /mar3d/utils/callbacks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import shutil 3 | import subprocess 4 | 5 | import pytorch_lightning 6 | 7 | from mar3d.utils.config import dump_config 8 | from mar3d.utils.misc import parse_version 9 | 10 | if parse_version(pytorch_lightning.__version__) > parse_version("1.8"): 11 | from pytorch_lightning.callbacks import Callback 12 | else: 13 | from pytorch_lightning.callbacks.base import Callback 14 | 15 | from pytorch_lightning.callbacks.progress import TQDMProgressBar 16 | from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn 17 | 18 | 19 | class VersionedCallback(Callback): 20 | def __init__(self, save_root, version=None, use_version=True): 21 | self.save_root = save_root 22 | self._version = version 23 | self.use_version = use_version 24 | 25 | @property 26 | def version(self) -> int: 27 | """Get the experiment version. 28 | 29 | Returns: 30 | The experiment version if specified else the next version. 
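        For example, if save_root already contains version_0 and version_1, the next
        version resolves to 2 and outputs are written to save_root/version_2 (see
        savedir below).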
31 | """ 32 | if self._version is None: 33 | self._version = self._get_next_version() 34 | return self._version 35 | 36 | def _get_next_version(self): 37 | existing_versions = [] 38 | if os.path.isdir(self.save_root): 39 | for f in os.listdir(self.save_root): 40 | bn = os.path.basename(f) 41 | if bn.startswith("version_"): 42 | dir_ver = os.path.splitext(bn)[0].split("_")[1].replace("/", "") 43 | existing_versions.append(int(dir_ver)) 44 | if len(existing_versions) == 0: 45 | return 0 46 | return max(existing_versions) + 1 47 | 48 | @property 49 | def savedir(self): 50 | if not self.use_version: 51 | return self.save_root 52 | return os.path.join( 53 | self.save_root, 54 | self.version 55 | if isinstance(self.version, str) 56 | else f"version_{self.version}", 57 | ) 58 | 59 | 60 | class CodeSnapshotCallback(VersionedCallback): 61 | def __init__(self, save_root, version=None, use_version=True): 62 | super().__init__(save_root, version, use_version) 63 | 64 | def get_file_list(self): 65 | return [ 66 | b.decode() 67 | for b in set( 68 | subprocess.check_output( 69 | 'git ls-files -- ":!:load/*"', shell=True 70 | ).splitlines() 71 | ) 72 | | set( # hard code, TODO: use config to exclude folders or files 73 | subprocess.check_output( 74 | "git ls-files --others --exclude-standard", shell=True 75 | ).splitlines() 76 | ) 77 | ] 78 | 79 | @rank_zero_only 80 | def save_code_snapshot(self): 81 | os.makedirs(self.savedir, exist_ok=True) 82 | for f in self.get_file_list(): 83 | if not os.path.exists(f) or os.path.isdir(f): 84 | continue 85 | os.makedirs(os.path.join(self.savedir, os.path.dirname(f)), exist_ok=True) 86 | shutil.copyfile(f, os.path.join(self.savedir, f)) 87 | 88 | def on_fit_start(self, trainer, pl_module): 89 | try: 90 | self.save_code_snapshot() 91 | except: 92 | rank_zero_warn( 93 | "Code snapshot is not saved. Please make sure you have git installed and are in a git repository." 
94 | ) 95 | 96 | 97 | class ConfigSnapshotCallback(VersionedCallback): 98 | def __init__(self, config_path, config, save_root, version=None, use_version=True): 99 | super().__init__(save_root, version, use_version) 100 | self.config_path = config_path 101 | self.config = config 102 | 103 | @rank_zero_only 104 | def save_config_snapshot(self): 105 | os.makedirs(self.savedir, exist_ok=True) 106 | dump_config(os.path.join(self.savedir, "parsed.yaml"), self.config) 107 | shutil.copyfile(self.config_path, os.path.join(self.savedir, "raw.yaml")) 108 | 109 | def on_fit_start(self, trainer, pl_module): 110 | self.save_config_snapshot() 111 | 112 | 113 | class CustomProgressBar(TQDMProgressBar): 114 | def get_metrics(self, *args, **kwargs): 115 | # don't show the version number 116 | items = super().get_metrics(*args, **kwargs) 117 | items.pop("v_num", None) 118 | return items 119 | 120 | 121 | class ProgressCallback(Callback): 122 | def __init__(self, save_path): 123 | super().__init__() 124 | self.save_path = save_path 125 | self._file_handle = None 126 | 127 | @property 128 | def file_handle(self): 129 | if self._file_handle is None: 130 | self._file_handle = open(self.save_path, "w") 131 | return self._file_handle 132 | 133 | @rank_zero_only 134 | def write(self, msg: str) -> None: 135 | self.file_handle.seek(0) 136 | self.file_handle.truncate() 137 | self.file_handle.write(msg) 138 | self.file_handle.flush() 139 | 140 | @rank_zero_only 141 | def on_train_batch_end(self, trainer, pl_module, *args, **kwargs): 142 | self.write( 143 | f"Generation progress: {pl_module.true_global_step / trainer.max_steps * 100:.2f}%" 144 | ) 145 | 146 | @rank_zero_only 147 | def on_validation_start(self, trainer, pl_module): 148 | self.write(f"Rendering validation image ...") 149 | 150 | @rank_zero_only 151 | def on_test_start(self, trainer, pl_module): 152 | self.write(f"Rendering video ...") 153 | 154 | @rank_zero_only 155 | def on_predict_start(self, trainer, pl_module): 156 | self.write(f"Exporting mesh assets ...") 157 | -------------------------------------------------------------------------------- /diffusion/respace.py: -------------------------------------------------------------------------------- 1 | # Modified from OpenAI's diffusion repos 2 | # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py 3 | # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion 4 | # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py 5 | 6 | import numpy as np 7 | import torch as th 8 | 9 | from .gaussian_diffusion import GaussianDiffusion 10 | 11 | 12 | def space_timesteps(num_timesteps, section_counts): 13 | """ 14 | Create a list of timesteps to use from an original diffusion process, 15 | given the number of timesteps we want to take from equally-sized portions 16 | of the original process. 17 | For example, if there's 300 timesteps and the section counts are [10,15,20] 18 | then the first 100 timesteps are strided to be 10 timesteps, the second 100 19 | are strided to be 15 timesteps, and the final 100 are strided to be 20. 20 | If the stride is a string starting with "ddim", then the fixed striding 21 | from the DDIM paper is used, and only one section is allowed. 22 | :param num_timesteps: the number of diffusion steps in the original 23 | process to divide up. 
24 | :param section_counts: either a list of numbers, or a string containing 25 | comma-separated numbers, indicating the step count 26 | per section. As a special case, use "ddimN" where N 27 | is a number of steps to use the striding from the 28 | DDIM paper. 29 | :return: a set of diffusion steps from the original process to use. 30 | """ 31 | if isinstance(section_counts, str): 32 | if section_counts.startswith("ddim"): 33 | desired_count = int(section_counts[len("ddim") :]) 34 | for i in range(1, num_timesteps): 35 | if len(range(0, num_timesteps, i)) == desired_count: 36 | return set(range(0, num_timesteps, i)) 37 | raise ValueError( 38 | f"cannot create exactly {num_timesteps} steps with an integer stride" 39 | ) 40 | section_counts = [int(x) for x in section_counts.split(",")] 41 | size_per = num_timesteps // len(section_counts) 42 | extra = num_timesteps % len(section_counts) 43 | start_idx = 0 44 | all_steps = [] 45 | for i, section_count in enumerate(section_counts): 46 | size = size_per + (1 if i < extra else 0) 47 | if size < section_count: 48 | raise ValueError( 49 | f"cannot divide section of {size} steps into {section_count}" 50 | ) 51 | if section_count <= 1: 52 | frac_stride = 1 53 | else: 54 | frac_stride = (size - 1) / (section_count - 1) 55 | cur_idx = 0.0 56 | taken_steps = [] 57 | for _ in range(section_count): 58 | taken_steps.append(start_idx + round(cur_idx)) 59 | cur_idx += frac_stride 60 | all_steps += taken_steps 61 | start_idx += size 62 | return set(all_steps) 63 | 64 | 65 | class SpacedDiffusion(GaussianDiffusion): 66 | """ 67 | A diffusion process which can skip steps in a base diffusion process. 68 | :param use_timesteps: a collection (sequence or set) of timesteps from the 69 | original diffusion process to retain. 70 | :param kwargs: the kwargs to create the base diffusion process. 71 | """ 72 | 73 | def __init__(self, use_timesteps, **kwargs): 74 | self.use_timesteps = set(use_timesteps) 75 | self.timestep_map = [] 76 | self.original_num_steps = len(kwargs["betas"]) 77 | 78 | base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa 79 | last_alpha_cumprod = 1.0 80 | new_betas = [] 81 | for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): 82 | if i in self.use_timesteps: 83 | new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) 84 | last_alpha_cumprod = alpha_cumprod 85 | self.timestep_map.append(i) 86 | kwargs["betas"] = np.array(new_betas) 87 | super().__init__(**kwargs) 88 | 89 | def p_mean_variance( 90 | self, model, *args, **kwargs 91 | ): # pylint: disable=signature-differs 92 | return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) 93 | 94 | def training_losses( 95 | self, model, *args, **kwargs 96 | ): # pylint: disable=signature-differs 97 | return super().training_losses(self._wrap_model(model), *args, **kwargs) 98 | 99 | def condition_mean(self, cond_fn, *args, **kwargs): 100 | return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) 101 | 102 | def condition_score(self, cond_fn, *args, **kwargs): 103 | return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) 104 | 105 | def _wrap_model(self, model): 106 | if isinstance(model, _WrappedModel): 107 | return model 108 | return _WrappedModel( 109 | model, self.timestep_map, self.original_num_steps 110 | ) 111 | 112 | def _scale_timesteps(self, t): 113 | # Scaling is done by the wrapped model. 
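        # Illustrative note (editor's sketch): with num_timesteps=1000 and
        # section_counts="ddim50", space_timesteps() keeps the 50 evenly strided steps
        # {0, 20, 40, ..., 980}. SpacedDiffusion recomputes betas for just those steps,
        # and _WrappedModel below maps each shortened index back to its original
        # timestep via new_ts = map_tensor[ts] before calling the underlying model.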
114 | return t 115 | 116 | 117 | class _WrappedModel: 118 | def __init__(self, model, timestep_map, original_num_steps): 119 | self.model = model 120 | self.timestep_map = timestep_map 121 | # self.rescale_timesteps = rescale_timesteps 122 | self.original_num_steps = original_num_steps 123 | 124 | def __call__(self, x, ts, **kwargs): 125 | map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) 126 | new_ts = map_tensor[ts] 127 | # if self.rescale_timesteps: 128 | # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) 129 | return self.model(x, new_ts, **kwargs) 130 | -------------------------------------------------------------------------------- /mar3d/utils/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | 9 | import mar3d 10 | from mar3d.utils.typing import * 11 | 12 | 13 | def dot(x, y): 14 | return torch.sum(x * y, -1, keepdim=True) 15 | 16 | 17 | def reflect(x, n): 18 | return 2 * dot(x, n) * n - x 19 | 20 | 21 | ValidScale = Union[Tuple[float, float], Num[Tensor, "2 D"]] 22 | 23 | 24 | def scale_tensor( 25 | dat: Num[Tensor, "... D"], inp_scale: ValidScale, tgt_scale: ValidScale 26 | ): 27 | if inp_scale is None: 28 | inp_scale = (0, 1) 29 | if tgt_scale is None: 30 | tgt_scale = (0, 1) 31 | if isinstance(tgt_scale, Tensor): 32 | assert dat.shape[-1] == tgt_scale.shape[-1] 33 | dat = (dat - inp_scale[0]) / (inp_scale[1] - inp_scale[0]) 34 | dat = dat * (tgt_scale[1] - tgt_scale[0]) + tgt_scale[0] 35 | return dat 36 | 37 | 38 | def chunk_batch(func: Callable, chunk_size: int, *args, **kwargs) -> Any: 39 | if chunk_size <= 0: 40 | return func(*args, **kwargs) 41 | B = None 42 | for arg in list(args) + list(kwargs.values()): 43 | if isinstance(arg, torch.Tensor): 44 | B = arg.shape[0] 45 | break 46 | assert ( 47 | B is not None 48 | ), "No tensor found in args or kwargs, cannot determine batch size." 49 | out = defaultdict(list) 50 | out_type = None 51 | # max(1, B) to support B == 0 52 | for i in range(0, max(1, B), chunk_size): 53 | out_chunk = func( 54 | *[ 55 | arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg 56 | for arg in args 57 | ], 58 | **{ 59 | k: arg[i : i + chunk_size] if isinstance(arg, torch.Tensor) else arg 60 | for k, arg in kwargs.items() 61 | }, 62 | ) 63 | if out_chunk is None: 64 | continue 65 | out_type = type(out_chunk) 66 | if isinstance(out_chunk, torch.Tensor): 67 | out_chunk = {0: out_chunk} 68 | elif isinstance(out_chunk, tuple) or isinstance(out_chunk, list): 69 | chunk_length = len(out_chunk) 70 | out_chunk = {i: chunk for i, chunk in enumerate(out_chunk)} 71 | elif isinstance(out_chunk, dict): 72 | pass 73 | else: 74 | print( 75 | f"Return value of func must be in type [torch.Tensor, list, tuple, dict], get {type(out_chunk)}." 
76 | ) 77 | exit(1) 78 | for k, v in out_chunk.items(): 79 | v = v if torch.is_grad_enabled() else v.detach() 80 | out[k].append(v) 81 | 82 | if out_type is None: 83 | return None 84 | 85 | out_merged: Dict[Any, Optional[torch.Tensor]] = {} 86 | for k, v in out.items(): 87 | if all([vv is None for vv in v]): 88 | # allow None in return value 89 | out_merged[k] = None 90 | elif all([isinstance(vv, torch.Tensor) for vv in v]): 91 | out_merged[k] = torch.cat(v, dim=0) 92 | else: 93 | raise TypeError( 94 | f"Unsupported types in return value of func: {[type(vv) for vv in v if not isinstance(vv, torch.Tensor)]}" 95 | ) 96 | 97 | if out_type is torch.Tensor: 98 | return out_merged[0] 99 | elif out_type in [tuple, list]: 100 | return out_type([out_merged[i] for i in range(chunk_length)]) 101 | elif out_type is dict: 102 | return out_merged 103 | 104 | 105 | def randn_tensor( 106 | shape: Union[Tuple, List], 107 | generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None, 108 | device: Optional["torch.device"] = None, 109 | dtype: Optional["torch.dtype"] = None, 110 | layout: Optional["torch.layout"] = None, 111 | ): 112 | """A helper function to create random tensors on the desired `device` with the desired `dtype`. When 113 | passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor 114 | is always created on the CPU. 115 | """ 116 | # device on which tensor is created defaults to device 117 | rand_device = device 118 | batch_size = shape[0] 119 | 120 | layout = layout or torch.strided 121 | device = device or torch.device("cpu") 122 | 123 | if generator is not None: 124 | gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type 125 | if gen_device_type != device.type and gen_device_type == "cpu": 126 | rand_device = "cpu" 127 | if device != "mps": 128 | logger.info( 129 | f"The passed generator was created on 'cpu' even though a tensor on {device} was expected." 130 | f" Tensors will be created on 'cpu' and then moved to {device}. Note that one can probably" 131 | f" slighly speed up this function by passing a generator that was created on the {device} device." 
132 | ) 133 | elif gen_device_type != device.type and gen_device_type == "cuda": 134 | raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.") 135 | 136 | # make sure generator list of length 1 is treated like a non-list 137 | if isinstance(generator, list) and len(generator) == 1: 138 | generator = generator[0] 139 | 140 | if isinstance(generator, list): 141 | shape = (1,) + shape[1:] 142 | latents = [ 143 | torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout) 144 | for i in range(batch_size) 145 | ] 146 | latents = torch.cat(latents, dim=0).to(device) 147 | else: 148 | latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device) 149 | 150 | return latents 151 | 152 | 153 | def generate_dense_grid_points( 154 | bbox_min: np.ndarray, 155 | bbox_max: np.ndarray, 156 | octree_depth: int, 157 | indexing: str = "ij" 158 | ): 159 | length = bbox_max - bbox_min 160 | num_cells = np.exp2(octree_depth) 161 | x = np.linspace(bbox_min[0], bbox_max[0], int(num_cells) + 1, dtype=np.float32) 162 | y = np.linspace(bbox_min[1], bbox_max[1], int(num_cells) + 1, dtype=np.float32) 163 | z = np.linspace(bbox_min[2], bbox_max[2], int(num_cells) + 1, dtype=np.float32) 164 | [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing) 165 | xyz = np.stack((xs, ys, zs), axis=-1) 166 | xyz = xyz.reshape(-1, 3) 167 | grid_size = [int(num_cells) + 1, int(num_cells) + 1, int(num_cells) + 1] 168 | 169 | return xyz, grid_size, length -------------------------------------------------------------------------------- /mar3d/models/conditional_encoders/base.py: -------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | import torch.nn as nn 4 | import numpy as np 5 | from PIL import Image 6 | from dataclasses import dataclass 7 | from torchvision.transforms import Normalize 8 | from torchvision.transforms import InterpolationMode 9 | from torchvision.transforms.transforms import _interpolation_modes_from_int 10 | 11 | from transformers import CLIPModel, CLIPTokenizer, CLIPImageProcessor 12 | from transformers.utils import ModelOutput 13 | from typing import Iterable, Optional, Union, List 14 | 15 | import mar3d 16 | from mar3d.utils.base import BaseModule 17 | from mar3d.utils.typing import * 18 | 19 | ImageType = Union[np.ndarray, torch.Tensor, Image.Image] 20 | 21 | 22 | class BaseEmbedder(BaseModule): 23 | @dataclass 24 | class Config(BaseModule.Config): 25 | pretrained_model_name_or_path: Optional[str] = None # the pretrained model name or path 26 | 27 | encode_camera: bool = False # whether to encode camera 28 | camera_embeds_type: str = "sincos" # the type of camera embeds 29 | camera_embeds_dim: Optional[int] = None # the dimension of camera embeds 30 | n_views: int = 1 # the number of views 31 | 32 | dino: bool = False 33 | nor: bool = False 34 | empty_embeds_ratio: float = 0.1 # the ratio of empty embeds 35 | zero_uncond_embeds: bool = True 36 | bbox: bool =False 37 | normalize_embeds: bool = False # whether to normalize the embeds 38 | 39 | cfg: Config 40 | 41 | def configure(self) -> None: 42 | super().configure() 43 | 44 | if self.cfg.encode_camera: 45 | self.distance = 1.0 46 | self.register_buffer( 47 | "cameras", 48 | torch.as_tensor([ 49 | [[1, 0, 0, 0], 50 | [0, 0, -1, -self.distance], 51 | [0, 1, 0, 0], 52 | [0, 0, 0, 1]], # front to back 53 | 54 | [[0, 0, 1, self.distance], 55 | [1, 0, 0, 0], 56 | [0, 1, 0, 0], 57 | [0, 0, 
0, 1]], # right to left 58 | 59 | [[-1, 0, 0, 0], 60 | [0, 0, 1, self.distance], 61 | [0, 1, 0, 0], 62 | [0, 0, 0, 1]], # back to front 63 | 64 | [[0, 0, -1, -self.distance], 65 | [-1, 0, 0, 0], 66 | [0, 1, 0, 0], 67 | [0, 0, 0, 1]], # left to right 68 | ], dtype=torch.float32), 69 | ) 70 | 71 | def encode_image(self, images: Iterable[Optional[ImageType]], camera_embeds: Optional[torch.Tensor] = None, **kwargs) -> torch.FloatTensor: 72 | pass 73 | 74 | def encode_text(self, texts: List[str], **kwargs) -> torch.FloatTensor: 75 | pass 76 | 77 | def encode_camera(self, c2ws: torch.Tensor): 78 | if self.cfg.camera_embeds_type == "sincos": 79 | assert c2ws.shape[-1] == 4 and c2ws.shape[-2] == 4, f"Invalid c2ws shape: {c2ws.shape}" 80 | c2ws = c2ws.view(-1, 16) 81 | return torch.cat([torch.sin(c2ws), torch.cos(c2ws)], dim=-1) 82 | else: 83 | raise NotImplementedError(f"Unknown camera_embeds_type: {self.cfg.camera_embeds_type}") 84 | 85 | def post_process_embeds(self, text_embeds, visual_embeds): 86 | bs = text_embeds.shape[0] if text_embeds is not None else visual_embeds.shape[0] 87 | 88 | if self.cfg.normalize_embeds: 89 | # post-process the text/visual embeds 90 | if text_embeds is not None: 91 | text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) 92 | if visual_embeds is not None: 93 | visual_embeds = visual_embeds / visual_embeds.norm(dim=-1, keepdim=True) 94 | 95 | assert text_embeds is not None or visual_embeds is not None 96 | 97 | # return text_embeds, visual_embeds 98 | if text_embeds is not None and visual_embeds is not None: 99 | return [text_embeds, visual_embeds] 100 | # return torch.cat([text_embeds, visual_embeds], dim=1) 101 | elif text_embeds is not None: 102 | return text_embeds 103 | else: 104 | return visual_embeds 105 | 106 | def forward(self, batch): 107 | bs = batch["surface"].shape[0] 108 | 109 | text_embeds, visual_embeds = None, None 110 | 111 | bbox_drop=False 112 | if random.random() < self.cfg.empty_embeds_ratio: 113 | if "text_input_ids" in batch or "text_embeds" in batch: 114 | if self.empty_text_embeds is None: 115 | if not self.cfg.zero_uncond_embeds: 116 | self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768] 117 | text_embeds = self.empty_text_embeds.repeat(bs, 1, 1) 118 | if "image" in batch or "image_embeds" in batch: 119 | visual_embeds = self.empty_image_embeds.repeat(bs, 1, 1) 120 | if "normal" in batch: 121 | visual_embeds_normal=self.empty_image_embeds.repeat(bs, 1, 1) 122 | if self.cfg.bbox: 123 | bbox_drop=True 124 | 125 | 126 | elif "mvimages" in batch or "mvimage_embeds" in batch: 127 | visual_embeds = self.empty_image_embeds.unsqueeze(1).repeat(bs, 1, 1, 1) 128 | else: 129 | # for text inputs 130 | if "text_input_ids" in batch: 131 | # import ipdb 132 | # ipdb.set_trace() 133 | text_embeds = self.encode_text(batch["text_input_ids"]) 134 | 135 | # for visual inputs 136 | if "image" in batch: 137 | if self.cfg.encode_camera: 138 | visual_embeds = self.encode_image(batch["image"], cameras=batch["c2w"]) 139 | else: 140 | visual_embeds = self.encode_image(batch["image"]) 141 | if "normal" in batch: 142 | # if self.cfg.encode_camera: 143 | # visual_embeds = self.encode_image(batch["image"], cameras=batch["c2w"]) 144 | # else: 145 | 146 | visual_embeds_normal = self.encode_image(batch["normal"]) 147 | 148 | elif "mvimages" in batch: 149 | 150 | n_views = batch["mvimages"].shape[1] 151 | if self.cfg.encode_camera: 152 | visual_embeds = self.encode_image( 153 | batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:]), \ 154 
| cameras=batch["c2ws"]).view(bs, n_views, *self.empty_image_embeds.shape[-2:]) 155 | else: 156 | visual_embeds = self.encode_image( 157 | batch["mvimages"].view(-1, *batch["mvimages"].shape[-3:])).view(bs, n_views, *self.empty_image_embeds.shape[-2:]) 158 | 159 | return self.post_process_embeds(text_embeds, visual_embeds) 160 | -------------------------------------------------------------------------------- /mar3d/models/geometry/base.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | import mar3d 9 | from .utils import ( 10 | Mesh, 11 | IsosurfaceHelper, 12 | MarchingCubeCPUHelper, 13 | MarchingTetrahedraHelper, 14 | ) 15 | 16 | from mar3d.utils.base import BaseModule 17 | from mar3d.utils.ops import chunk_batch, scale_tensor 18 | from mar3d.utils.typing import * 19 | 20 | class BaseGeometry(BaseModule): 21 | @dataclass 22 | class Config(BaseModule.Config): 23 | pass 24 | 25 | cfg: Config 26 | 27 | @staticmethod 28 | def create_from( 29 | other: "BaseGeometry", cfg: Optional[Union[dict, DictConfig]] = None, **kwargs 30 | ) -> "BaseGeometry": 31 | raise TypeError( 32 | f"Cannot create {BaseGeometry.__name__} from {other.__class__.__name__}" 33 | ) 34 | 35 | def export(self, *args, **kwargs) -> Dict[str, Any]: 36 | return {} 37 | 38 | 39 | class BaseImplicitGeometry(BaseGeometry): 40 | @dataclass 41 | class Config(BaseGeometry.Config): 42 | radius: float = 1.0 43 | isosurface: bool = True 44 | isosurface_method: str = "mt" 45 | isosurface_resolution: int = 128 46 | isosurface_threshold: Union[float, str] = 0.0 47 | isosurface_chunk: int = 0 48 | isosurface_coarse_to_fine: bool = True 49 | isosurface_deformable_grid: bool = False 50 | isosurface_remove_outliers: bool = True 51 | isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01 52 | 53 | cfg: Config 54 | 55 | def configure(self) -> None: 56 | self.bbox: Float[Tensor, "2 3"] 57 | self.register_buffer( 58 | "bbox", 59 | torch.as_tensor( 60 | [ 61 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 62 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 63 | ], 64 | dtype=torch.float32, 65 | ), 66 | ) 67 | self.isosurface_helper: Optional[IsosurfaceHelper] = None 68 | self.unbounded: bool = False 69 | 70 | def _initilize_isosurface_helper(self): 71 | if self.cfg.isosurface and self.isosurface_helper is None: 72 | if self.cfg.isosurface_method == "mc-cpu": 73 | self.isosurface_helper = MarchingCubeCPUHelper( 74 | self.cfg.isosurface_resolution 75 | ).to(self.device) 76 | elif self.cfg.isosurface_method == "mt": 77 | self.isosurface_helper = MarchingTetrahedraHelper( 78 | self.cfg.isosurface_resolution, 79 | f"load/tets/{self.cfg.isosurface_resolution}_tets.npz", 80 | ).to(self.device) 81 | else: 82 | raise AttributeError( 83 | "Unknown isosurface method {self.cfg.isosurface_method}" 84 | ) 85 | 86 | def forward( 87 | self, points: Float[Tensor, "*N Di"], output_normal: bool = False 88 | ) -> Dict[str, Float[Tensor, "..."]]: 89 | raise NotImplementedError 90 | 91 | def forward_field( 92 | self, points: Float[Tensor, "*N Di"] 93 | ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]: 94 | # return the value of the implicit field, could be density / signed distance 95 | # also return a deformation field if the grid vertices can be optimized 96 | raise NotImplementedError 97 | 98 | def forward_level( 99 | self, field: 
Float[Tensor, "*N 1"], threshold: float 100 | ) -> Float[Tensor, "*N 1"]: 101 | # return the value of the implicit field, where the zero level set represents the surface 102 | raise NotImplementedError 103 | 104 | def _isosurface(self, bbox: Float[Tensor, "2 3"], fine_stage: bool = False) -> Mesh: 105 | def batch_func(x): 106 | # scale to bbox as the input vertices are in [0, 1] 107 | field, deformation = self.forward_field( 108 | scale_tensor( 109 | x.to(bbox.device), self.isosurface_helper.points_range, bbox 110 | ), 111 | ) 112 | field = field.to( 113 | x.device 114 | ) # move to the same device as the input (could be CPU) 115 | if deformation is not None: 116 | deformation = deformation.to(x.device) 117 | return field, deformation 118 | 119 | assert self.isosurface_helper is not None 120 | 121 | field, deformation = chunk_batch( 122 | batch_func, 123 | self.cfg.isosurface_chunk, 124 | self.isosurface_helper.grid_vertices, 125 | ) 126 | 127 | threshold: float 128 | 129 | if isinstance(self.cfg.isosurface_threshold, float): 130 | threshold = self.cfg.isosurface_threshold 131 | elif self.cfg.isosurface_threshold == "auto": 132 | eps = 1.0e-5 133 | threshold = field[field > eps].mean().item() 134 | mar3d.info( 135 | f"Automatically determined isosurface threshold: {threshold}" 136 | ) 137 | else: 138 | raise TypeError( 139 | f"Unknown isosurface_threshold {self.cfg.isosurface_threshold}" 140 | ) 141 | 142 | level = self.forward_level(field, threshold) 143 | mesh: Mesh = self.isosurface_helper(level, deformation=deformation) 144 | mesh.v_pos = scale_tensor( 145 | mesh.v_pos, self.isosurface_helper.points_range, bbox 146 | ) # scale to bbox as the grid vertices are in [0, 1] 147 | mesh.add_extra("bbox", bbox) 148 | 149 | if self.cfg.isosurface_remove_outliers: 150 | # remove outliers components with small number of faces 151 | # only enabled when the mesh is not differentiable 152 | mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold) 153 | 154 | return mesh 155 | 156 | def isosurface(self) -> Mesh: 157 | if not self.cfg.isosurface: 158 | raise NotImplementedError( 159 | "Isosurface is not enabled in the current configuration" 160 | ) 161 | self._initilize_isosurface_helper() 162 | if self.cfg.isosurface_coarse_to_fine: 163 | mar3d.debug("First run isosurface to get a tight bounding box ...") 164 | with torch.no_grad(): 165 | mesh_coarse = self._isosurface(self.bbox) 166 | vmin, vmax = mesh_coarse.v_pos.amin(dim=0), mesh_coarse.v_pos.amax(dim=0) 167 | vmin_ = (vmin - (vmax - vmin) * 0.1).max(self.bbox[0]) 168 | vmax_ = (vmax + (vmax - vmin) * 0.1).min(self.bbox[1]) 169 | mar3d.debug("Run isosurface again with the tight bounding box ...") 170 | mesh = self._isosurface(torch.stack([vmin_, vmax_], dim=0), fine_stage=True) 171 | else: 172 | mesh = self._isosurface(self.bbox) 173 | return mesh 174 | 175 | 176 | class BaseExplicitGeometry(BaseGeometry): 177 | @dataclass 178 | class Config(BaseGeometry.Config): 179 | radius: float = 1.0 180 | 181 | cfg: Config 182 | 183 | def configure(self) -> None: 184 | self.bbox: Float[Tensor, "2 3"] 185 | self.register_buffer( 186 | "bbox", 187 | torch.as_tensor( 188 | [ 189 | [-self.cfg.radius, -self.cfg.radius, -self.cfg.radius], 190 | [self.cfg.radius, self.cfg.radius, self.cfg.radius], 191 | ], 192 | dtype=torch.float32, 193 | ), 194 | ) -------------------------------------------------------------------------------- /mar3d/models/conditional_encoders/clip_encoder.py: 
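Before clip_encoder.py below, a minimal illustrative sketch (editor's addition, not code from this
repo) of how the dense-grid helper generate_dense_grid_points from mar3d/utils/ops.py is typically
combined with marching cubes to mesh an implicit field. The toy sphere_field stands in for a learned
occupancy/SDF decoder, and the mcubes import assumes the PyMCubes package is installed:

import numpy as np
import torch
import mcubes  # assumption: PyMCubes is available

from mar3d.utils.ops import generate_dense_grid_points

def sphere_field(pts: torch.Tensor) -> torch.Tensor:
    # toy field: positive inside a sphere of radius 0.5, negative outside
    return 0.5 - pts.norm(dim=-1, keepdim=True)

bbox_min = np.array([-1.0, -1.0, -1.0], dtype=np.float32)
bbox_max = np.array([1.0, 1.0, 1.0], dtype=np.float32)
xyz, grid_size, length = generate_dense_grid_points(bbox_min, bbox_max, octree_depth=7)

with torch.no_grad():
    values = sphere_field(torch.from_numpy(xyz))               # (N, 1) field values
grid = values.view(*grid_size).numpy()                         # (129, 129, 129) volume
verts, faces = mcubes.marching_cubes(grid, 0.0)                # extract the zero level set
verts = verts / (np.array(grid_size) - 1) * length + bbox_min  # grid indices -> world coords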
-------------------------------------------------------------------------------- 1 | import random 2 | import torch 3 | from torch import nn 4 | import numpy as np 5 | from PIL import Image 6 | from einops import rearrange 7 | from dataclasses import dataclass 8 | from torchvision.transforms import Normalize 9 | from torchvision.transforms import InterpolationMode 10 | from torchvision.transforms.transforms import _interpolation_modes_from_int 11 | from torchvision import transforms 12 | 13 | from transformers import CLIPTokenizer, CLIPImageProcessor 14 | from transformers.utils import ModelOutput 15 | from typing import Iterable, Optional, Union, List 16 | 17 | import mar3d 18 | from mar3d.utils.typing import * 19 | from .clip.modeling_clip import CLIPModel 20 | from .clip.modeling_conditional_clip import ConditionalCLIPModel 21 | from .base import BaseEmbedder, ImageType 22 | @dataclass 23 | class CLIPEmbedOutput(ModelOutput): 24 | last_hidden_state: torch.FloatTensor = None 25 | pooler_output: torch.FloatTensor = None 26 | embeds: torch.FloatTensor = None 27 | 28 | @mar3d.register("clip-embedder") 29 | class CLIPEmbedder(BaseEmbedder): 30 | 31 | @dataclass 32 | class Config(BaseEmbedder.Config): 33 | freeze_modulation: bool = False 34 | config_path: str = '' 35 | 36 | cfg: Config 37 | 38 | def configure(self) -> None: 39 | super().configure() 40 | if not self.cfg.encode_camera: 41 | self.model: CLIPModel = CLIPModel.from_pretrained(self.cfg.pretrained_model_name_or_path) 42 | else: 43 | if self.cfg.pretrained_model_name_or_path == '': 44 | assert self.cfg.config_path is not None, "The config path should be provided" 45 | conditional_clip_config = ConditionalCLIPModel.config_class.from_json_file(self.cfg.config_path) 46 | conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim 47 | self.model: CLIPModel = ConditionalCLIPModel(conditional_clip_config) 48 | else: 49 | conditional_clip_config = ConditionalCLIPModel.config_class.from_pretrained( 50 | self.cfg.pretrained_model_name_or_path, 51 | ) 52 | conditional_clip_config.vision_config.modulation_dim = self.cfg.camera_embeds_dim 53 | self.model: CLIPModel = ConditionalCLIPModel.from_pretrained( 54 | self.cfg.pretrained_model_name_or_path, 55 | vision_config=conditional_clip_config.vision_config 56 | ) 57 | 58 | self.tokenizer = None 59 | self.image_preprocess = CLIPImageProcessor() 60 | self.transform = transforms.Compose( 61 | [ 62 | transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True), 63 | transforms.CenterCrop(224), # crop a (224, 224) square 64 | transforms.Normalize( 65 | mean=[0.48145466, 0.4578275, 0.40821073], 66 | std=[0.26862954, 0.26130258, 0.27577711], 67 | ), 68 | ] 69 | ) 70 | 71 | self.logit_scale = self.model.logit_scale.exp() 72 | 73 | if self.cfg.zero_uncond_embeds: 74 | self.empty_text_embeds = torch.zeros((1, 77, 768)).detach() 75 | self.empty_image_embeds = torch.zeros((self.cfg.n_views, 257, 1024)).detach() 76 | 77 | else: 78 | try: 79 | self.empty_text_embeds = self.encode_text([""]).detach() # [1, 77, 768] 80 | except: 81 | self.empty_text_embeds = None 82 | 83 | 84 | if self.cfg.encode_camera: 85 | self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3), self.cameras[:self.cfg.n_views]).detach() 86 | else: 87 | self.empty_image_embeds = self.encode_image(torch.zeros(self.cfg.n_views, 224, 224, 3)).detach() 88 | 89 | # Freeze the model parameters 90 | 91 | self.model.eval() 92 | for k, p in self.model.named_parameters(): 93 | ks = 
k.split('.') 94 | if 'mod_norm1' in ks or 'mod_norm2' in ks and not self.cfg.freeze_modulation: 95 | p.requires_grad_(True) 96 | else: 97 | p.requires_grad_(False) 98 | # def encode_image_normal(self, images): 99 | 100 | # return normal_feats 101 | def encode_image(self, images: Iterable[Optional[ImageType]], cameras: Optional[torch.Tensor] = None, force_none_camera_embeds: bool = False, return_dict: bool = False, **kwargs) -> torch.FloatTensor: 102 | camera_embeds = None 103 | 104 | if isinstance(images, (np.ndarray, torch.Tensor)): # for training process 105 | assert images.min() >= 0.0 and images.max() <= 1.0, "The pixel values should be in the range of [0, 1]" 106 | do_rescale = False 107 | if self.cfg.encode_camera: 108 | assert cameras is not None, "The cameras should be provided" 109 | camera_embeds = self.encode_camera(cameras) 110 | pixel_values = self.transform(images.permute(0, 3, 1, 2)) 111 | else: # for inference process 112 | do_rescale = True 113 | if self.cfg.encode_camera: 114 | if cameras is None: 115 | bs = len(images) // self.cfg.n_views 116 | cameras = self.cameras[:self.cfg.n_views].repeat(bs, 1, 1).to(self.model.device) 117 | camera_embeds = self.encode_camera(cameras) 118 | pixel_values = self.image_preprocess.preprocess(images, return_tensors='pt', do_rescale=do_rescale).pixel_values 119 | 120 | if force_none_camera_embeds: 121 | camera_embeds = None 122 | 123 | 124 | packed = False 125 | 126 | 127 | if pixel_values.ndim == 4: 128 | packed = True 129 | pixel_values = pixel_values.unsqueeze(1) 130 | if camera_embeds is not None: 131 | camera_embeds = camera_embeds.unsqueeze(1) 132 | 133 | if self.cfg.encode_camera and camera_embeds is not None: 134 | vision_outputs = self.model.vision_model( 135 | pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"), 136 | condition=rearrange(camera_embeds, "B N C -> (B N) C") 137 | ) 138 | else: 139 | vision_outputs = self.model.vision_model( 140 | pixel_values=rearrange(pixel_values.to(self.model.device), "B N C H W -> (B N) C H W"), 141 | ) 142 | 143 | 144 | 145 | if return_dict: 146 | pooler_output = vision_outputs[1] # pooled_output 147 | image_features = self.model.visual_projection(pooler_output) 148 | return CLIPEmbedOutput( 149 | last_hidden_state=vision_outputs.last_hidden_state, 150 | pooler_output=pooler_output, 151 | embeds=image_features 152 | ) 153 | 154 | else: 155 | return vision_outputs.last_hidden_state 156 | 157 | 158 | @torch.no_grad() 159 | def encode_text(self, text_inputs: torch.Tensor, return_dict: bool = False) -> torch.FloatTensor: 160 | if self.tokenizer is None: 161 | self.tokenizer = CLIPTokenizer.from_pretrained(self.cfg.pretrained_model_name_or_path) 162 | 163 | if isinstance(text_inputs, list): 164 | text_inputs = self.tokenizer( 165 | text_inputs, 166 | max_length=self.tokenizer.model_max_length, 167 | padding="max_length", 168 | return_tensors="pt" 169 | ).input_ids 170 | text_outputs = self.model.text_model(input_ids=text_inputs.to(self.model.device)) 171 | 172 | pooler_output = text_outputs[1] # pooled_output 173 | text_features = self.model.text_projection(pooler_output) 174 | 175 | if return_dict: 176 | return CLIPEmbedOutput( 177 | last_hidden_state=text_outputs.last_hidden_state, 178 | pooler_output=pooler_output, 179 | embeds=text_features 180 | ) 181 | else: 182 | return text_outputs.last_hidden_state -------------------------------------------------------------------------------- /mar3d/systems/base.py: 
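Before systems/base.py below, a short illustrative sketch (editor's addition) of the two image
preprocessing paths used by CLIPEmbedder.encode_image above: batched float tensors in [0, 1] go
through the torchvision transform, while PIL images at inference time are rescaled by the
CLIPImageProcessor itself. The normalization constants mirror the transform defined in
clip_encoder.py; pil_images is a hypothetical placeholder:

import torch
from torchvision import transforms
from transformers import CLIPImageProcessor

clip_transform = transforms.Compose([
    transforms.Resize(224, transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                         std=[0.26862954, 0.26130258, 0.27577711]),
])

# training path: float tensors in [0, 1] with layout (B, H, W, 3)
images = torch.rand(4, 512, 512, 3)
pixel_values = clip_transform(images.permute(0, 3, 1, 2))  # (4, 3, 224, 224)

# inference path: PIL images in [0, 255]; the processor rescales them itself
processor = CLIPImageProcessor()
# pixel_values = processor.preprocess(pil_images, return_tensors="pt", do_rescale=True).pixel_values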
-------------------------------------------------------------------------------- 1 | import os 2 | from dataclasses import dataclass, field 3 | 4 | import pytorch_lightning as pl 5 | import torch.nn.functional as F 6 | 7 | import mar3d 8 | from mar3d.utils.base import ( 9 | Updateable, 10 | update_end_if_possible, 11 | update_if_possible, 12 | ) 13 | from mar3d.utils.scheduler import parse_optimizer, parse_scheduler, parse_optimizer_control 14 | from mar3d.utils.config import parse_structured 15 | from mar3d.utils.misc import C, cleanup, get_device, load_module_weights 16 | from mar3d.utils.saving import SaverMixin 17 | from mar3d.utils.typing import * 18 | 19 | 20 | class BaseSystem(pl.LightningModule, Updateable, SaverMixin): 21 | @dataclass 22 | class Config: 23 | loggers: dict = field(default_factory=dict) 24 | loss: dict = field(default_factory=dict) 25 | optimizer: dict = field(default_factory=dict) 26 | scheduler: Optional[dict] = None 27 | weights: Optional[str] = None 28 | weights_ignore_modules: Optional[List[str]] = None 29 | cleanup_after_validation_step: bool = False 30 | cleanup_after_test_step: bool = False 31 | 32 | pretrained_model_path: Optional[str] = None 33 | strict_load: bool = True 34 | cfg: Config 35 | 36 | def __init__(self, cfg, resumed=False) -> None: 37 | super().__init__() 38 | self.cfg = parse_structured(self.Config, cfg) 39 | self._save_dir: Optional[str] = None 40 | self._resumed: bool = resumed 41 | self._resumed_eval: bool = False 42 | self._resumed_eval_status: dict = {"global_step": 0, "current_epoch": 0} 43 | if "loggers" in cfg: 44 | self.create_loggers(cfg.loggers) 45 | 46 | self.configure() 47 | if self.cfg.weights is not None: 48 | self.load_weights(self.cfg.weights, self.cfg.weights_ignore_modules) 49 | self.post_configure() 50 | 51 | 52 | 53 | 54 | def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True): 55 | """Custom state-dict loading that keeps only parameters whose shapes match the current model.""" 56 | model_state_dict = self.state_dict() 57 | 58 | # new state dict holding only the shape-matching parameters 59 | new_state_dict = {} 60 | 61 | for name, param in state_dict.items(): 62 | if name in model_state_dict: 63 | if param.shape == model_state_dict[name].shape: 64 | new_state_dict[name] = param 65 | else: 66 | print(f"Skipping parameter {name} due to shape mismatch. " 67 | f"Loaded shape: {param.shape}, " 68 | f"Model shape: {model_state_dict[name].shape}") 69 | else: 70 | print(f"Parameter {name} not found in current model") 71 | 72 | print(f"Loaded {len(new_state_dict)} parameters with matching shapes.") 73 | 74 | # call the parent load_state_dict with the filtered state dict 75 | return super().load_state_dict(new_state_dict, strict=False) 76 | 77 | 78 | def load_weights(self, weights: str, ignore_modules: Optional[List[str]] = None): 79 | state_dict, epoch, global_step = load_module_weights( 80 | weights, ignore_modules=ignore_modules, map_location="cpu" 81 | ) 82 | self.load_state_dict(state_dict, strict=False) 83 | # restore step-dependent states 84 | self.do_update_step(epoch, global_step, on_load_weights=True) 85 | 86 | def set_resume_status(self, current_epoch: int, global_step: int): 87 | # restore correct epoch and global step in eval 88 | self._resumed_eval = True 89 | self._resumed_eval_status["current_epoch"] = current_epoch 90 | self._resumed_eval_status["global_step"] = global_step 91 | 92 | @property 93 | def resumed(self): 94 | # whether from resumed checkpoint 95 | return self._resumed 96 | 97 | @property 98 | def true_global_step(self): 99 | if self._resumed_eval: 100 | return self._resumed_eval_status["global_step"] 101 | else: 102 | return self.global_step 103 | 104 | @property 105 | def true_current_epoch(self): 106 | if self._resumed_eval: 107 | return self._resumed_eval_status["current_epoch"] 108 | else: 109 | return self.current_epoch 110 | 111 | def configure(self) -> None: 112 | pass 113 | 114 | def post_configure(self) -> None: 115 | """ 116 | executed after weights are loaded 117 | """ 118 | pass 119 | 120 | def C(self, value: Any) -> float: 121 | return C(value, self.true_current_epoch, self.true_global_step) 122 | 123 | def configure_optimizers(self): 124 | 125 | 126 | 127 | # if self.cfg.control: 128 | # optim = parse_optimizer_control(self.cfg.optimizer, self) 129 | # else: 130 | optim = parse_optimizer(self.cfg.optimizer, self) 131 | ret = { 132 | "optimizer": optim, 133 | } 134 | 135 | if self.cfg.scheduler is not None: 136 | ret.update( 137 | { 138 | "lr_scheduler": parse_scheduler(self.cfg.scheduler, optim), 139 | } 140 | ) 141 | return ret 142 | 143 | def training_step(self, batch, batch_idx): 144 | raise NotImplementedError 145 | 146 | def validation_step(self, batch, batch_idx): 147 | raise NotImplementedError 148 | 149 | def on_train_batch_end(self, outputs, batch, batch_idx): 150 | self.dataset = self.trainer.train_dataloader.dataset 151 | update_end_if_possible( 152 | self.dataset, self.true_current_epoch, self.true_global_step 153 | ) 154 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 155 | 156 | def on_validation_batch_end(self, outputs, batch, batch_idx): 157 | self.dataset = self.trainer.val_dataloaders.dataset 158 | update_end_if_possible( 159 | self.dataset, self.true_current_epoch, self.true_global_step 160 | ) 161 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 162 | if self.cfg.cleanup_after_validation_step: 163 | # cleanup to save vram 164 | cleanup() 165 | 166 | def on_validation_epoch_end(self): 167 | raise NotImplementedError 168 | 169 | def test_step(self, batch, batch_idx): 170 | raise NotImplementedError 171 | 172 | def on_test_batch_end(self, outputs, batch, batch_idx): 173 | self.dataset = self.trainer.test_dataloaders.dataset 174 | update_end_if_possible( 175 | self.dataset, self.true_current_epoch, self.true_global_step 176 | ) 177 | 
self.do_update_step_end(self.true_current_epoch, self.true_global_step) 178 | if self.cfg.cleanup_after_test_step: 179 | # cleanup to save vram 180 | cleanup() 181 | 182 | def on_test_epoch_end(self): 183 | pass 184 | 185 | def predict_step(self, batch, batch_idx): 186 | raise NotImplementedError 187 | 188 | def on_predict_batch_end(self, outputs, batch, batch_idx): 189 | self.dataset = self.trainer.predict_dataloaders.dataset 190 | update_end_if_possible( 191 | self.dataset, self.true_current_epoch, self.true_global_step 192 | ) 193 | self.do_update_step_end(self.true_current_epoch, self.true_global_step) 194 | if self.cfg.cleanup_after_test_step: 195 | # cleanup to save vram 196 | cleanup() 197 | 198 | def on_predict_epoch_end(self): 199 | pass 200 | 201 | def preprocess_data(self, batch, stage): 202 | pass 203 | 204 | """ 205 | Implementing on_after_batch_transfer of DataModule does the same. 206 | But on_after_batch_transfer does not support DP. 207 | """ 208 | 209 | def on_train_batch_start(self, batch, batch_idx, unused=0): 210 | self.preprocess_data(batch, "train") 211 | self.dataset = self.trainer.train_dataloader.dataset 212 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 213 | self.do_update_step(self.true_current_epoch, self.true_global_step) 214 | 215 | def on_validation_batch_start(self, batch, batch_idx, dataloader_idx=0): 216 | self.preprocess_data(batch, "validation") 217 | self.dataset = self.trainer.val_dataloaders.dataset 218 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 219 | self.do_update_step(self.true_current_epoch, self.true_global_step) 220 | 221 | def on_test_batch_start(self, batch, batch_idx, dataloader_idx=0): 222 | self.preprocess_data(batch, "test") 223 | self.dataset = self.trainer.test_dataloaders.dataset 224 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 225 | self.do_update_step(self.true_current_epoch, self.true_global_step) 226 | 227 | def on_predict_batch_start(self, batch, batch_idx, dataloader_idx=0): 228 | self.preprocess_data(batch, "predict") 229 | self.dataset = self.trainer.predict_dataloaders.dataset 230 | update_if_possible(self.dataset, self.true_current_epoch, self.true_global_step) 231 | self.do_update_step(self.true_current_epoch, self.true_global_step) 232 | 233 | def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): 234 | pass 235 | 236 | def on_before_optimizer_step(self, optimizer): 237 | """ 238 | # some gradient-related debugging goes here, example: 239 | from lightning.pytorch.utilities import grad_norm 240 | norms = grad_norm(self.geometry, norm_type=2) 241 | print(norms) 242 | """ 243 | pass 244 | -------------------------------------------------------------------------------- /mar3d/models/diffloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | 6 | from diffusion import create_diffusion 7 | 8 | 9 | class DiffLoss(nn.Module): 10 | """Diffusion Loss""" 11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False): 12 | super(DiffLoss, self).__init__() 13 | self.in_channels = target_channels 14 | self.net = SimpleMLPAdaLN( 15 | in_channels=target_channels, 16 | model_channels=width, 17 | out_channels=target_channels * 2, # for vlb loss 18 | z_channels=z_channels, 19 | num_res_blocks=depth, 20 | 
grad_checkpointing=grad_checkpointing 21 | ) 22 | 23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine") 24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine") 25 | 26 | def forward(self, target, z, mask=None): 27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device) 28 | model_kwargs = dict(c=z) 29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 30 | loss = loss_dict["loss"] 31 | if mask is not None: 32 | loss = (loss * mask).sum() / mask.sum() 33 | return loss.mean() 34 | 35 | def sample(self, z, temperature=1.0, cfg=1.0): 36 | # diffusion loss sampling 37 | if not cfg == 1.0: 38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 39 | noise = torch.cat([noise, noise], dim=0) 40 | model_kwargs = dict(c=z, cfg_scale=cfg) 41 | sample_fn = self.net.forward_with_cfg 42 | else: 43 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 44 | model_kwargs = dict(c=z) 45 | sample_fn = self.net.forward 46 | 47 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False, 49 | temperature=temperature 50 | ) 51 | 52 | return sampled_token_latent 53 | 54 | 55 | def modulate(x, shift, scale): 56 | return x * (1 + scale) + shift 57 | 58 | 59 | class TimestepEmbedder(nn.Module): 60 | """ 61 | Embeds scalar timesteps into vector representations. 62 | """ 63 | def __init__(self, hidden_size, frequency_embedding_size=256): 64 | super().__init__() 65 | self.mlp = nn.Sequential( 66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 67 | nn.SiLU(), 68 | nn.Linear(hidden_size, hidden_size, bias=True), 69 | ) 70 | self.frequency_embedding_size = frequency_embedding_size 71 | 72 | @staticmethod 73 | def timestep_embedding(t, dim, max_period=10000): 74 | """ 75 | Create sinusoidal timestep embeddings. 76 | :param t: a 1-D Tensor of N indices, one per batch element. 77 | These may be fractional. 78 | :param dim: the dimension of the output. 79 | :param max_period: controls the minimum frequency of the embeddings. 80 | :return: an (N, D) Tensor of positional embeddings. 81 | """ 82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 83 | half = dim // 2 84 | freqs = torch.exp( 85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 86 | ).to(device=t.device) 87 | args = t[:, None].float() * freqs[None] 88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 89 | if dim % 2: 90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 91 | return embedding 92 | 93 | def forward(self, t): 94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 95 | t_emb = self.mlp(t_freq) 96 | return t_emb 97 | 98 | 99 | class ResBlock(nn.Module): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | :param channels: the number of input channels. 
103 | """ 104 | 105 | def __init__( 106 | self, 107 | channels 108 | ): 109 | super().__init__() 110 | self.channels = channels 111 | 112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 113 | self.mlp = nn.Sequential( 114 | nn.Linear(channels, channels, bias=True), 115 | nn.SiLU(), 116 | nn.Linear(channels, channels, bias=True), 117 | ) 118 | 119 | self.adaLN_modulation = nn.Sequential( 120 | nn.SiLU(), 121 | nn.Linear(channels, 3 * channels, bias=True) 122 | ) 123 | 124 | def forward(self, x, y): 125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 127 | h = self.mlp(h) 128 | return x + gate_mlp * h 129 | 130 | 131 | class FinalLayer(nn.Module): 132 | """ 133 | The final layer of DiT. 134 | """ 135 | def __init__(self, model_channels, out_channels): 136 | super().__init__() 137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 138 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 139 | self.adaLN_modulation = nn.Sequential( 140 | nn.SiLU(), 141 | nn.Linear(model_channels, 2 * model_channels, bias=True) 142 | ) 143 | 144 | def forward(self, x, c): 145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 146 | x = modulate(self.norm_final(x), shift, scale) 147 | x = self.linear(x) 148 | return x 149 | 150 | 151 | class SimpleMLPAdaLN(nn.Module): 152 | """ 153 | The MLP for Diffusion Loss. 154 | :param in_channels: channels in the input Tensor. 155 | :param model_channels: base channel count for the model. 156 | :param out_channels: channels in the output Tensor. 157 | :param z_channels: channels in the condition. 158 | :param num_res_blocks: number of residual blocks per downsample. 159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | model_channels, 165 | out_channels, 166 | z_channels, 167 | num_res_blocks, 168 | grad_checkpointing=False 169 | ): 170 | super().__init__() 171 | 172 | self.in_channels = in_channels 173 | self.model_channels = model_channels 174 | self.out_channels = out_channels 175 | self.num_res_blocks = num_res_blocks 176 | self.grad_checkpointing = grad_checkpointing 177 | 178 | self.time_embed = TimestepEmbedder(model_channels) 179 | self.cond_embed = nn.Linear(z_channels, model_channels) 180 | 181 | self.input_proj = nn.Linear(in_channels, model_channels) 182 | 183 | res_blocks = [] 184 | for i in range(num_res_blocks): 185 | res_blocks.append(ResBlock( 186 | model_channels, 187 | )) 188 | 189 | self.res_blocks = nn.ModuleList(res_blocks) 190 | self.final_layer = FinalLayer(model_channels, out_channels) 191 | 192 | self.initialize_weights() 193 | 194 | def initialize_weights(self): 195 | def _basic_init(module): 196 | if isinstance(module, nn.Linear): 197 | torch.nn.init.xavier_uniform_(module.weight) 198 | if module.bias is not None: 199 | nn.init.constant_(module.bias, 0) 200 | self.apply(_basic_init) 201 | 202 | # Initialize timestep embedding MLP 203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 205 | 206 | # Zero-out adaLN modulation layers 207 | for block in self.res_blocks: 208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 210 | 211 | # Zero-out output layers 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 
0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | def forward(self, x, t, c): 218 | """ 219 | Apply the model to an input batch. 220 | :param x: an [N x C x ...] Tensor of inputs. 221 | :param t: a 1-D batch of timesteps. 222 | :param c: conditioning from AR transformer. 223 | :return: an [N x C x ...] Tensor of outputs. 224 | """ 225 | 226 | x = self.input_proj(x) 227 | t = self.time_embed(t) 228 | c = self.cond_embed(c) 229 | 230 | y = t + c 231 | 232 | if self.grad_checkpointing and not torch.jit.is_scripting(): 233 | for block in self.res_blocks: 234 | x = checkpoint(block, x, y) 235 | else: 236 | for block in self.res_blocks: 237 | x = block(x, y) 238 | 239 | return self.final_layer(x, y) 240 | 241 | def forward_with_cfg(self, x, t, c, cfg_scale): 242 | half = x[: len(x) // 2] 243 | combined = torch.cat([half, half], dim=0) 244 | model_out = self.forward(combined, t, c) 245 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 246 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 247 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 248 | eps = torch.cat([half_eps, half_eps], dim=0) 249 | return torch.cat([eps, rest], dim=1) 250 | -------------------------------------------------------------------------------- /mar3d/systems/diffloss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.utils.checkpoint import checkpoint 4 | import math 5 | 6 | from diffusion import create_diffusion 7 | 8 | 9 | class DiffLoss(nn.Module): 10 | """Diffusion Loss""" 11 | def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, grad_checkpointing=False): 12 | super(DiffLoss, self).__init__() 13 | self.in_channels = target_channels 14 | self.net = SimpleMLPAdaLN( 15 | in_channels=target_channels, 16 | model_channels=width, 17 | out_channels=target_channels * 2, # for vlb loss 18 | z_channels=z_channels, 19 | num_res_blocks=depth, 20 | grad_checkpointing=grad_checkpointing 21 | ) 22 | 23 | self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine") 24 | self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine") 25 | 26 | def forward(self, target, z, mask=None): 27 | t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device) 28 | model_kwargs = dict(c=z) 29 | loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs) 30 | loss = loss_dict["loss"] 31 | if mask is not None: 32 | loss = (loss * mask).sum() / mask.sum() 33 | return loss.mean() 34 | 35 | def sample(self, z, temperature=1.0, cfg=1.0): 36 | # diffusion loss sampling 37 | if not cfg == 1.0: 38 | noise = torch.randn(z.shape[0] // 2, self.in_channels).cuda() 39 | noise = torch.cat([noise, noise], dim=0) 40 | model_kwargs = dict(c=z, cfg_scale=cfg) 41 | sample_fn = self.net.forward_with_cfg 42 | else: 43 | noise = torch.randn(z.shape[0], self.in_channels).cuda() 44 | model_kwargs = dict(c=z) 45 | sample_fn = self.net.forward 46 | 47 | sampled_token_latent = self.gen_diffusion.p_sample_loop( 48 | sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False, 49 | temperature=temperature 50 | ) 51 | 52 | return sampled_token_latent 53 | 54 | 55 | def modulate(x, shift, scale): 56 | return x * (1 + scale) + shift 57 | 58 | 59 | class TimestepEmbedder(nn.Module): 60 | """ 61 | Embeds scalar timesteps into 
vector representations. 62 | """ 63 | def __init__(self, hidden_size, frequency_embedding_size=256): 64 | super().__init__() 65 | self.mlp = nn.Sequential( 66 | nn.Linear(frequency_embedding_size, hidden_size, bias=True), 67 | nn.SiLU(), 68 | nn.Linear(hidden_size, hidden_size, bias=True), 69 | ) 70 | self.frequency_embedding_size = frequency_embedding_size 71 | 72 | @staticmethod 73 | def timestep_embedding(t, dim, max_period=10000): 74 | """ 75 | Create sinusoidal timestep embeddings. 76 | :param t: a 1-D Tensor of N indices, one per batch element. 77 | These may be fractional. 78 | :param dim: the dimension of the output. 79 | :param max_period: controls the minimum frequency of the embeddings. 80 | :return: an (N, D) Tensor of positional embeddings. 81 | """ 82 | # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py 83 | half = dim // 2 84 | freqs = torch.exp( 85 | -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half 86 | ).to(device=t.device) 87 | args = t[:, None].float() * freqs[None] 88 | embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) 89 | if dim % 2: 90 | embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) 91 | return embedding 92 | 93 | def forward(self, t): 94 | t_freq = self.timestep_embedding(t, self.frequency_embedding_size) 95 | t_emb = self.mlp(t_freq) 96 | return t_emb 97 | 98 | 99 | class ResBlock(nn.Module): 100 | """ 101 | A residual block that can optionally change the number of channels. 102 | :param channels: the number of input channels. 103 | """ 104 | 105 | def __init__( 106 | self, 107 | channels 108 | ): 109 | super().__init__() 110 | self.channels = channels 111 | 112 | self.in_ln = nn.LayerNorm(channels, eps=1e-6) 113 | self.mlp = nn.Sequential( 114 | nn.Linear(channels, channels, bias=True), 115 | nn.SiLU(), 116 | nn.Linear(channels, channels, bias=True), 117 | ) 118 | 119 | self.adaLN_modulation = nn.Sequential( 120 | nn.SiLU(), 121 | nn.Linear(channels, 3 * channels, bias=True) 122 | ) 123 | 124 | def forward(self, x, y): 125 | shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1) 126 | h = modulate(self.in_ln(x), shift_mlp, scale_mlp) 127 | h = self.mlp(h) 128 | return x + gate_mlp * h 129 | 130 | 131 | class FinalLayer(nn.Module): 132 | """ 133 | The final layer of DiT. 134 | """ 135 | def __init__(self, model_channels, out_channels): 136 | super().__init__() 137 | self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6) 138 | self.linear = nn.Linear(model_channels, out_channels, bias=True) 139 | self.adaLN_modulation = nn.Sequential( 140 | nn.SiLU(), 141 | nn.Linear(model_channels, 2 * model_channels, bias=True) 142 | ) 143 | 144 | def forward(self, x, c): 145 | shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1) 146 | x = modulate(self.norm_final(x), shift, scale) 147 | x = self.linear(x) 148 | return x 149 | 150 | 151 | class SimpleMLPAdaLN(nn.Module): 152 | """ 153 | The MLP for Diffusion Loss. 154 | :param in_channels: channels in the input Tensor. 155 | :param model_channels: base channel count for the model. 156 | :param out_channels: channels in the output Tensor. 157 | :param z_channels: channels in the condition. 158 | :param num_res_blocks: number of residual blocks per downsample. 
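    :param grad_checkpointing: if True, each residual block is wrapped in torch.utils.checkpoint to trade recomputation for lower activation memory.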
159 | """ 160 | 161 | def __init__( 162 | self, 163 | in_channels, 164 | model_channels, 165 | out_channels, 166 | z_channels, 167 | num_res_blocks, 168 | grad_checkpointing=False 169 | ): 170 | super().__init__() 171 | 172 | self.in_channels = in_channels 173 | self.model_channels = model_channels 174 | self.out_channels = out_channels 175 | self.num_res_blocks = num_res_blocks 176 | self.grad_checkpointing = grad_checkpointing 177 | 178 | self.time_embed = TimestepEmbedder(model_channels) 179 | self.cond_embed = nn.Linear(z_channels, model_channels) 180 | 181 | self.input_proj = nn.Linear(in_channels, model_channels) 182 | 183 | res_blocks = [] 184 | for i in range(num_res_blocks): 185 | res_blocks.append(ResBlock( 186 | model_channels, 187 | )) 188 | 189 | self.res_blocks = nn.ModuleList(res_blocks) 190 | self.final_layer = FinalLayer(model_channels, out_channels) 191 | 192 | self.initialize_weights() 193 | 194 | def initialize_weights(self): 195 | def _basic_init(module): 196 | if isinstance(module, nn.Linear): 197 | torch.nn.init.xavier_uniform_(module.weight) 198 | if module.bias is not None: 199 | nn.init.constant_(module.bias, 0) 200 | self.apply(_basic_init) 201 | 202 | # Initialize timestep embedding MLP 203 | nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02) 204 | nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02) 205 | 206 | # Zero-out adaLN modulation layers 207 | for block in self.res_blocks: 208 | nn.init.constant_(block.adaLN_modulation[-1].weight, 0) 209 | nn.init.constant_(block.adaLN_modulation[-1].bias, 0) 210 | 211 | # Zero-out output layers 212 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) 213 | nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) 214 | nn.init.constant_(self.final_layer.linear.weight, 0) 215 | nn.init.constant_(self.final_layer.linear.bias, 0) 216 | 217 | def forward(self, x, t, c): 218 | """ 219 | Apply the model to an input batch. 220 | :param x: an [N x C x ...] Tensor of inputs. 221 | :param t: a 1-D batch of timesteps. 222 | :param c: conditioning from AR transformer. 223 | :return: an [N x C x ...] Tensor of outputs. 
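        Note: as constructed by DiffLoss above, out_channels is 2 * in_channels, so the output packs the predicted noise together with the per-channel values used by the vlb term.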
224 | """ 225 | 226 | x = self.input_proj(x) 227 | t = self.time_embed(t) 228 | c = self.cond_embed(c) 229 | 230 | y = t + c 231 | 232 | if self.grad_checkpointing and not torch.jit.is_scripting(): 233 | for block in self.res_blocks: 234 | x = checkpoint(block, x, y) 235 | else: 236 | for block in self.res_blocks: 237 | x = block(x, y) 238 | 239 | return self.final_layer(x, y) 240 | 241 | def forward_with_cfg(self, x, t, c, cfg_scale): 242 | half = x[: len(x) // 2] 243 | combined = torch.cat([half, half], dim=0) 244 | model_out = self.forward(combined, t, c) 245 | eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:] 246 | cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) 247 | half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) 248 | eps = torch.cat([half_eps, half_eps], dim=0) 249 | return torch.cat([eps, rest], dim=1) 250 | -------------------------------------------------------------------------------- /mar3d/models/transformers/attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from mar3d.utils.typing import * 7 | from mar3d.utils.checkpoint import checkpoint 8 | 9 | from .utils import init_linear, MLP 10 | 11 | class MultiheadAttention(nn.Module): 12 | def __init__( 13 | self, 14 | *, 15 | n_ctx: int, 16 | width: int, 17 | heads: int, 18 | init_scale: float, 19 | qkv_bias: bool, 20 | use_flash: bool = False 21 | ): 22 | super().__init__() 23 | self.n_ctx = n_ctx 24 | self.width = width 25 | self.heads = heads 26 | self.c_qkv = nn.Linear(width, width * 3, bias=qkv_bias) 27 | self.c_proj = nn.Linear(width, width) 28 | self.attention = QKVMultiheadAttention(heads=heads, n_ctx=n_ctx, use_flash=use_flash) 29 | init_linear(self.c_qkv, init_scale) 30 | init_linear(self.c_proj, init_scale) 31 | 32 | def forward(self, x): 33 | x = self.c_qkv(x) 34 | x = checkpoint(self.attention, (x,), (), True) 35 | x = self.c_proj(x) 36 | return x 37 | 38 | 39 | class QKVMultiheadAttention(nn.Module): 40 | def __init__(self, *, heads: int, n_ctx: int, use_flash: bool = False): 41 | super().__init__() 42 | self.heads = heads 43 | self.n_ctx = n_ctx 44 | self.use_flash = use_flash 45 | 46 | def forward(self, qkv): 47 | bs, n_ctx, width = qkv.shape 48 | attn_ch = width // self.heads // 3 49 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 50 | qkv = qkv.view(bs, n_ctx, self.heads, -1) 51 | q, k, v = torch.split(qkv, attn_ch, dim=-1) 52 | 53 | if self.use_flash: 54 | q = q.permute(0, 2, 1, 3) 55 | k = k.permute(0, 2, 1, 3) 56 | v = v.permute(0, 2, 1, 3) 57 | out = F.scaled_dot_product_attention(q, k, v).permute(0, 2, 1, 3).reshape(bs, n_ctx, -1) 58 | else: 59 | weight = torch.einsum( 60 | "bthc,bshc->bhts", q * scale, k * scale 61 | ) # More stable with f16 than dividing afterwards 62 | wdtype = weight.dtype 63 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 64 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 65 | 66 | return out 67 | 68 | class ResidualAttentionBlock(nn.Module): 69 | def __init__( 70 | self, 71 | *, 72 | n_ctx: int, 73 | width: int, 74 | heads: int, 75 | init_scale: float = 1.0, 76 | qkv_bias: bool = True, 77 | use_flash: bool = False, 78 | use_checkpoint: bool = False 79 | ): 80 | super().__init__() 81 | 82 | self.use_checkpoint = use_checkpoint 83 | 84 | self.attn = MultiheadAttention( 85 | n_ctx=n_ctx, 86 | width=width, 87 | heads=heads, 88 | init_scale=init_scale, 89 | 
qkv_bias=qkv_bias, 90 | use_flash=use_flash 91 | ) 92 | self.ln_1 = nn.LayerNorm(width) 93 | self.mlp = MLP(width=width, init_scale=init_scale) 94 | self.ln_2 = nn.LayerNorm(width) 95 | 96 | def _forward(self, x: torch.Tensor): 97 | x = x + self.attn(self.ln_1(x)) 98 | x = x + self.mlp(self.ln_2(x)) 99 | return x 100 | 101 | def forward(self, x: torch.Tensor): 102 | return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint) 103 | 104 | 105 | class MultiheadCrossAttention(nn.Module): 106 | def __init__( 107 | self, 108 | *, 109 | width: int, 110 | heads: int, 111 | init_scale: float, 112 | qkv_bias: bool = True, 113 | use_flash: bool = False, 114 | n_data: Optional[int] = None, 115 | data_width: Optional[int] = None, 116 | ): 117 | super().__init__() 118 | self.n_data = n_data 119 | self.width = width 120 | self.heads = heads 121 | self.data_width = width if data_width is None else data_width 122 | self.c_q = nn.Linear(width, width, bias=qkv_bias) 123 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias) 124 | self.c_proj = nn.Linear(width, width) 125 | self.attention = QKVMultiheadCrossAttention( 126 | heads=heads, n_data=n_data, use_flash=use_flash 127 | ) 128 | init_linear(self.c_q, init_scale) 129 | init_linear(self.c_kv, init_scale) 130 | init_linear(self.c_proj, init_scale) 131 | 132 | def forward(self, x, data): 133 | x = self.c_q(x) 134 | data = self.c_kv(data) 135 | x = checkpoint(self.attention, (x, data), (), True) 136 | x = self.c_proj(x) 137 | return x 138 | 139 | 140 | class QKVMultiheadCrossAttention(nn.Module): 141 | def __init__(self, *, heads: int, use_flash: bool = False, n_data: Optional[int] = None): 142 | 143 | super().__init__() 144 | self.heads = heads 145 | self.n_data = n_data 146 | self.use_flash = use_flash 147 | 148 | def forward(self, q, kv): 149 | _, n_ctx, _ = q.shape 150 | bs, n_data, width = kv.shape 151 | attn_ch = width // self.heads // 2 152 | scale = 1 / math.sqrt(math.sqrt(attn_ch)) 153 | q = q.view(bs, n_ctx, self.heads, -1) 154 | kv = kv.view(bs, n_data, self.heads, -1) 155 | k, v = torch.split(kv, attn_ch, dim=-1) 156 | 157 | if self.use_flash: 158 | q = q.permute(0, 2, 1, 3) 159 | k = k.permute(0, 2, 1, 3) 160 | v = v.permute(0, 2, 1, 3) 161 | out = F.scaled_dot_product_attention(q, k, v).permute(0, 2, 1, 3).reshape(bs, n_ctx, -1) 162 | else: 163 | weight = torch.einsum( 164 | "bthc,bshc->bhts", q * scale, k * scale 165 | ) # More stable with f16 than dividing afterwards 166 | wdtype = weight.dtype 167 | weight = torch.softmax(weight.float(), dim=-1).type(wdtype) 168 | out = torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1) 169 | 170 | return out 171 | 172 | 173 | class ResidualCrossAttentionBlock(nn.Module): 174 | def __init__( 175 | self, 176 | *, 177 | n_data: Optional[int] = None, 178 | width: int, 179 | heads: int, 180 | data_width: Optional[int] = None, 181 | init_scale: float = 0.25, 182 | qkv_bias: bool = True, 183 | use_flash: bool = False 184 | ): 185 | super().__init__() 186 | 187 | if data_width is None: 188 | data_width = width 189 | 190 | self.attn = MultiheadCrossAttention( 191 | n_data=n_data, 192 | width=width, 193 | heads=heads, 194 | data_width=data_width, 195 | init_scale=init_scale, 196 | qkv_bias=qkv_bias, 197 | use_flash=use_flash, 198 | ) 199 | self.ln_1 = nn.LayerNorm(width) 200 | self.ln_2 = nn.LayerNorm(data_width) 201 | self.mlp = MLP(width=width, init_scale=init_scale) 202 | self.ln_3 = nn.LayerNorm(width) 203 | 204 | def forward(self, x: torch.Tensor, data: 
torch.Tensor): 205 | x = x + self.attn(self.ln_1(x), self.ln_2(data)) 206 | x = x + self.mlp(self.ln_3(x)) 207 | return x 208 | 209 | class MultiheadCrossAttention_(nn.Module): 210 | def __init__( 211 | self, 212 | *, 213 | widthq:int, 214 | width: int, 215 | heads: int, 216 | init_scale: float, 217 | qkv_bias: bool = True, 218 | use_flash: bool = False, 219 | n_data: Optional[int] = None, 220 | kv_width: Optional[int] = None, 221 | out_width: int, 222 | ): 223 | super().__init__() 224 | self.n_data = n_data 225 | self.width = width 226 | self.width1 = widthq 227 | self.heads = heads 228 | self.data_width = width if kv_width is None else kv_width 229 | self.c_q = nn.Linear(widthq, width, bias=qkv_bias) 230 | self.c_kv = nn.Linear(self.data_width, width * 2, bias=qkv_bias) 231 | self.c_proj = nn.Linear(width, out_width) 232 | self.attention = QKVMultiheadCrossAttention( 233 | heads=heads, n_data=n_data, use_flash=use_flash 234 | ) 235 | init_linear(self.c_q, init_scale) 236 | init_linear(self.c_kv, init_scale) 237 | init_linear(self.c_proj, init_scale) 238 | 239 | def forward(self, x, data): 240 | 241 | x = self.c_q(x) 242 | data = self.c_kv(data) 243 | x = checkpoint(self.attention, (x, data), (), True) 244 | x = self.c_proj(x) 245 | return x 246 | class CrossAttentionBlock(nn.Module): 247 | def __init__( 248 | self, 249 | *, 250 | n_data: Optional[int] = None, 251 | widq: int, 252 | width: int, 253 | heads: int, 254 | kv_width: Optional[int] = None, 255 | out_width: int, 256 | init_scale: float = 0.25, 257 | qkv_bias: bool = True, 258 | use_flash: bool = False 259 | ): 260 | super().__init__() 261 | 262 | # if data_width is None: 263 | # data_width = width 264 | 265 | self.attn = MultiheadCrossAttention_( 266 | n_data=n_data, 267 | widthq=widq, 268 | width=width, 269 | heads=heads, 270 | kv_width=kv_width, 271 | out_width=out_width, 272 | init_scale=init_scale, 273 | qkv_bias=qkv_bias, 274 | use_flash=use_flash, 275 | ) 276 | self.ln_1 = nn.LayerNorm(widq) 277 | self.ln_2 = nn.LayerNorm(kv_width) 278 | self.mlp = MLP(width=out_width, init_scale=init_scale) 279 | self.ln_3 = nn.LayerNorm( out_width) 280 | 281 | def forward(self, x: torch.Tensor, data: torch.Tensor): 282 | x = self.attn(self.ln_1(x), self.ln_2(data)) 283 | x = self.mlp(self.ln_3(x)) 284 | return x -------------------------------------------------------------------------------- /launch.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import contextlib 3 | import importlib 4 | import logging 5 | import os 6 | import sys 7 | import time 8 | import traceback 9 | import pytorch_lightning as pl 10 | import torch 11 | from pytorch_lightning import Trainer 12 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 13 | from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger 14 | from pytorch_lightning.utilities.rank_zero import rank_zero_only 15 | import mar3d 16 | from mar3d.systems.base import BaseSystem 17 | from mar3d.utils.callbacks import ( 18 | CodeSnapshotCallback, 19 | ConfigSnapshotCallback, 20 | CustomProgressBar, 21 | ProgressCallback, 22 | ) 23 | from mar3d.utils.config import ExperimentConfig, load_config 24 | from mar3d.utils.misc import get_rank 25 | from mar3d.utils.typing import Optional 26 | class ColoredFilter(logging.Filter): 27 | """ 28 | A logging filter to add color to certain log levels. 
29 | """ 30 | 31 | RESET = "\033[0m" 32 | RED = "\033[31m" 33 | GREEN = "\033[32m" 34 | YELLOW = "\033[33m" 35 | BLUE = "\033[34m" 36 | MAGENTA = "\033[35m" 37 | CYAN = "\033[36m" 38 | 39 | COLORS = { 40 | "WARNING": YELLOW, 41 | "INFO": GREEN, 42 | "DEBUG": BLUE, 43 | "CRITICAL": MAGENTA, 44 | "ERROR": RED, 45 | } 46 | 47 | RESET = "\x1b[0m" 48 | 49 | def __init__(self): 50 | super().__init__() 51 | 52 | def filter(self, record): 53 | if record.levelname in self.COLORS: 54 | color_start = self.COLORS[record.levelname] 55 | record.levelname = f"{color_start}[{record.levelname}]" 56 | record.msg = f"{record.msg}{self.RESET}" 57 | return True 58 | 59 | 60 | def load_custom_module(module_path): 61 | module_name = os.path.basename(module_path) 62 | if os.path.isfile(module_path): 63 | sp = os.path.splitext(module_path) 64 | module_name = sp[0] 65 | try: 66 | if os.path.isfile(module_path): 67 | module_spec = importlib.util.spec_from_file_location( 68 | module_name, module_path 69 | ) 70 | else: 71 | module_spec = importlib.util.spec_from_file_location( 72 | module_name, os.path.join(module_path, "__init__.py") 73 | ) 74 | 75 | module = importlib.util.module_from_spec(module_spec) 76 | sys.modules[module_name] = module 77 | module_spec.loader.exec_module(module) 78 | return True 79 | except Exception as e: 80 | print(traceback.format_exc()) 81 | print(f"Cannot import {module_path} module for custom nodes:", e) 82 | return False 83 | 84 | 85 | def load_custom_modules(): 86 | node_paths = ["custom"] 87 | node_import_times = [] 88 | if not os.path.exists("node_paths"): 89 | return 90 | for custom_node_path in node_paths: 91 | possible_modules = os.listdir(custom_node_path) 92 | if "__pycache__" in possible_modules: 93 | possible_modules.remove("__pycache__") 94 | 95 | for possible_module in possible_modules: 96 | module_path = os.path.join(custom_node_path, possible_module) 97 | if ( 98 | os.path.isfile(module_path) 99 | and os.path.splitext(module_path)[1] != ".py" 100 | ): 101 | continue 102 | if module_path.endswith(".disabled"): 103 | continue 104 | time_before = time.perf_counter() 105 | success = load_custom_module(module_path) 106 | node_import_times.append( 107 | (time.perf_counter() - time_before, module_path, success) 108 | ) 109 | 110 | if len(node_import_times) > 0: 111 | print("\nImport times for custom modules:") 112 | for n in sorted(node_import_times): 113 | if n[2]: 114 | import_message = "" 115 | else: 116 | import_message = " (IMPORT FAILED)" 117 | print("{:6.1f} seconds{}:".format(n[0], import_message), n[1]) 118 | print() 119 | 120 | 121 | def main(args, extras) -> None: 122 | # set CUDA_VISIBLE_DEVICES if needed, then import pytorch-lightning 123 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 124 | env_gpus_str = os.environ.get("CUDA_VISIBLE_DEVICES", None) 125 | env_gpus = list(env_gpus_str.split(",")) if env_gpus_str else [] 126 | selected_gpus = [0] 127 | torch.set_float32_matmul_precision("high") 128 | 129 | # Always rely on CUDA_VISIBLE_DEVICES if specific GPU ID(s) are specified. 130 | # As far as Pytorch Lightning is concerned, we always use all available GPUs 131 | # (possibly filtered by CUDA_VISIBLE_DEVICES). 
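    # For example, `python launch.py --config configs/mar-diffusion/mar.yaml --train --gpu 0,1`
    # selects the first two visible GPUs via this argument, while launching with
    # CUDA_VISIBLE_DEVICES=2,3 set causes --gpu to be ignored and all visible GPUs to be used.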
132 | devices = -1 133 | if len(env_gpus) > 0: 134 | n_gpus = len(env_gpus) 135 | else: 136 | selected_gpus = list(args.gpu.split(",")) 137 | n_gpus = len(selected_gpus) 138 | print(f"Using {n_gpus} GPUs: {selected_gpus}") 139 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 140 | 141 | if args.typecheck: 142 | from jaxtyping import install_import_hook 143 | 144 | install_import_hook("mar3d", "typeguard.typechecked") 145 | 146 | logger = logging.getLogger("pytorch_lightning") 147 | if args.verbose: 148 | logger.setLevel(logging.DEBUG) 149 | 150 | for handler in logger.handlers: 151 | if handler.stream == sys.stderr: # type: ignore 152 | if not args.gradio: 153 | handler.setFormatter(logging.Formatter("%(levelname)s %(message)s")) 154 | handler.addFilter(ColoredFilter()) 155 | else: 156 | handler.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) 157 | 158 | load_custom_modules() 159 | 160 | # parse YAML config to OmegaConf 161 | cfg: ExperimentConfig 162 | cfg = load_config(args.config, cli_args=extras, n_gpus=n_gpus) 163 | 164 | # set a different seed for each device 165 | pl.seed_everything(cfg.seed + get_rank(), workers=True) 166 | 167 | dm = mar3d.find(cfg.data_type)(cfg.data) 168 | system: BaseSystem = mar3d.find(cfg.system_type)( 169 | cfg.system, resumed=cfg.resume is not None 170 | ) 171 | 172 | system.set_save_dir(os.path.join(cfg.trial_dir, "save")) 173 | 174 | if args.gradio: 175 | fh = logging.FileHandler(os.path.join(cfg.trial_dir, "logs")) 176 | fh.setLevel(logging.INFO) 177 | if args.verbose: 178 | fh.setLevel(logging.DEBUG) 179 | fh.setFormatter(logging.Formatter("[%(levelname)s] %(message)s")) 180 | logger.addHandler(fh) 181 | 182 | callbacks = [] 183 | if args.train: 184 | callbacks += [ 185 | ModelCheckpoint( 186 | dirpath=os.path.join(cfg.trial_dir, "ckpts"), **cfg.checkpoint 187 | ), 188 | LearningRateMonitor(logging_interval="step"), 189 | # CodeSnapshotCallback( 190 | # os.path.join(cfg.trial_dir, "code"), use_version=False 191 | # ), 192 | ConfigSnapshotCallback( 193 | args.config, 194 | cfg, 195 | os.path.join(cfg.trial_dir, "configs"), 196 | use_version=False, 197 | ), 198 | ] 199 | if args.gradio: 200 | callbacks += [ 201 | ProgressCallback(save_path=os.path.join(cfg.trial_dir, "progress")) 202 | ] 203 | else: 204 | callbacks += [CustomProgressBar(refresh_rate=1)] 205 | 206 | def write_to_text(file, lines): 207 | with open(file, "w") as f: 208 | for line in lines: 209 | f.write(line + "\n") 210 | 211 | loggers = [] 212 | if args.train: 213 | # make tensorboard logging dir to suppress warning 214 | rank_zero_only( 215 | lambda: os.makedirs(os.path.join(cfg.trial_dir, "tb_logs"), exist_ok=True) 216 | )() 217 | loggers += [ 218 | TensorBoardLogger(cfg.trial_dir, name="tb_logs"), 219 | CSVLogger(cfg.trial_dir, name="csv_logs"), 220 | ] + system.get_loggers() 221 | rank_zero_only( 222 | lambda: write_to_text( 223 | os.path.join(cfg.trial_dir, "cmd.txt"), 224 | ["python " + " ".join(sys.argv), str(args)], 225 | ) 226 | )() 227 | 228 | trainer = Trainer( 229 | callbacks=callbacks, 230 | logger=loggers, 231 | inference_mode=False, 232 | accelerator="gpu", 233 | devices=devices, 234 | # profiler="advanced", 235 | **cfg.trainer, 236 | ) 237 | 238 | def set_system_status(system: BaseSystem, ckpt_path: Optional[str]): 239 | if ckpt_path is None: 240 | return 241 | ckpt = torch.load(ckpt_path, map_location="cpu") 242 | system.set_resume_status(ckpt["epoch"], ckpt["global_step"]) 243 | if args.train: 244 | trainer.fit(system, datamodule=dm, ckpt_path=cfg.resume) 245 
| trainer.test(system, datamodule=dm) 246 | if args.gradio: 247 | # also export assets if in gradio mode 248 | trainer.predict(system, datamodule=dm) 249 | elif args.validate: 250 | # manually set epoch and global_step as they cannot be automatically resumed 251 | set_system_status(system, cfg.resume) 252 | trainer.validate(system, datamodule=dm, ckpt_path=cfg.resume) 253 | elif args.test: 254 | # manually set epoch and global_step as they cannot be automatically resumed 255 | set_system_status(system, cfg.resume) 256 | trainer.test(system, datamodule=dm, ckpt_path=cfg.resume) 257 | elif args.export: 258 | set_system_status(system, cfg.resume) 259 | trainer.predict(system, datamodule=dm, ckpt_path=cfg.resume) 260 | 261 | 262 | if __name__ == "__main__": 263 | parser = argparse.ArgumentParser() 264 | parser.add_argument("--config", required=True, help="path to config file") 265 | parser.add_argument( 266 | "--gpu", 267 | default="0", 268 | help="GPU(s) to be used. 0 means use the 1st available GPU. " 269 | "1,2 means use the 2nd and 3rd available GPU. " 270 | "If CUDA_VISIBLE_DEVICES is set before calling `launch.py`, " 271 | "this argument is ignored and all available GPUs are always used.", 272 | ) 273 | 274 | group = parser.add_mutually_exclusive_group(required=True) 275 | group.add_argument("--train", action="store_true") 276 | group.add_argument("--validate", action="store_true") 277 | group.add_argument("--test", action="store_true") 278 | group.add_argument("--export", action="store_true") 279 | 280 | parser.add_argument( 281 | "--gradio", action="store_true", help="if true, run in gradio mode" 282 | ) 283 | 284 | parser.add_argument( 285 | "--verbose", action="store_true", help="if true, set logging level to DEBUG" 286 | ) 287 | 288 | parser.add_argument( 289 | "--typecheck", 290 | action="store_true", 291 | help="whether to enable dynamic type checking", 292 | ) 293 | 294 | args, extras = parser.parse_known_args() 295 | 296 | if args.gradio: 297 | with contextlib.redirect_stdout(sys.stderr): 298 | main(args, extras) 299 | else: 300 | main(args, extras) 301 | -------------------------------------------------------------------------------- /mar3d/models/autoencoders/utils.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | 3 | import torch 4 | import torch.nn as nn 5 | from torch import distributed as tdist 6 | from torch.nn import functional as F 7 | import math 8 | import numpy as np 9 | from einops import repeat, rearrange 10 | from skimage import measure 11 | 12 | from mar3d.utils.base import BaseModule 13 | from mar3d.utils.typing import * 14 | from mar3d.utils.misc import get_world_size 15 | from mar3d.utils.ops import generate_dense_grid_points 16 | 17 | VALID_EMBED_TYPES = ["identity", "fourier", "hashgrid", "sphere_harmonic", "triplane_fourier"] 18 | 19 | class FourierEmbedder(nn.Module): 20 | def __init__(self, 21 | num_freqs: int = 6, 22 | logspace: bool = True, 23 | input_dim: int = 3, 24 | include_input: bool = True, 25 | include_pi: bool = True) -> None: 26 | super().__init__() 27 | 28 | if logspace: 29 | frequencies = 2.0 ** torch.arange( 30 | num_freqs, 31 | dtype=torch.float32 32 | ) 33 | else: 34 | frequencies = torch.linspace( 35 | 1.0, 36 | 2.0 ** (num_freqs - 1), 37 | num_freqs, 38 | dtype=torch.float32 39 | ) 40 | 41 | if include_pi: 42 | frequencies *= torch.pi 43 | 44 | self.register_buffer("frequencies", frequencies, persistent=False) 45 | self.include_input = include_input 46 | 
self.num_freqs = num_freqs 47 | 48 | self.out_dim = self.get_dims(input_dim) 49 | 50 | def get_dims(self, input_dim): 51 | temp = 1 if self.include_input or self.num_freqs == 0 else 0 52 | out_dim = input_dim * (self.num_freqs * 2 + temp) 53 | 54 | return out_dim 55 | 56 | def forward(self, x: torch.Tensor) -> torch.Tensor: 57 | if self.num_freqs > 0: 58 | embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1) 59 | if self.include_input: 60 | return torch.cat((x, embed.sin(), embed.cos()), dim=-1) 61 | else: 62 | return torch.cat((embed.sin(), embed.cos()), dim=-1) 63 | else: 64 | return x 65 | 66 | 67 | class LearnedFourierEmbedder(nn.Module): 68 | def __init__(self, input_dim, dim): 69 | super().__init__() 70 | assert (dim % 2) == 0 71 | half_dim = dim // 2 72 | per_channel_dim = half_dim // input_dim 73 | self.weights = nn.Parameter(torch.randn(per_channel_dim)) 74 | 75 | self.out_dim = self.get_dims(input_dim) 76 | 77 | def forward(self, x): 78 | # [b, t, c, 1] * [1, d] = [b, t, c, d] -> [b, t, c * d] 79 | freqs = (x[..., None] * self.weights[None] * 2 * np.pi).view(*x.shape[:-1], -1) 80 | fouriered = torch.cat((x, freqs.sin(), freqs.cos()), dim=-1) 81 | return fouriered 82 | 83 | def get_dims(self, input_dim): 84 | return input_dim * (self.weights.shape[0] * 2 + 1) 85 | 86 | class Sine(nn.Module): 87 | def __init__(self, w0 = 1.): 88 | super().__init__() 89 | self.w0 = w0 90 | def forward(self, x): 91 | return torch.sin(self.w0 * x) 92 | 93 | class Siren(nn.Module): 94 | def __init__( 95 | self, 96 | in_dim, 97 | out_dim, 98 | w0 = 1., 99 | c = 6., 100 | is_first = False, 101 | use_bias = True, 102 | activation = None, 103 | dropout = 0. 104 | ): 105 | super().__init__() 106 | self.in_dim = in_dim 107 | self.out_dim = out_dim 108 | self.is_first = is_first 109 | 110 | weight = torch.zeros(out_dim, in_dim) 111 | bias = torch.zeros(out_dim) if use_bias else None 112 | self.init_(weight, bias, c = c, w0 = w0) 113 | 114 | self.weight = nn.Parameter(weight) 115 | self.bias = nn.Parameter(bias) if use_bias else None 116 | self.activation = Sine(w0) if activation is None else activation 117 | self.dropout = nn.Dropout(dropout) 118 | 119 | def init_(self, weight, bias, c, w0): 120 | dim = self.in_dim 121 | 122 | w_std = (1 / dim) if self.is_first else (math.sqrt(c / dim) / w0) 123 | weight.uniform_(-w_std, w_std) 124 | 125 | if bias is not None: 126 | bias.uniform_(-w_std, w_std) 127 | 128 | def forward(self, x): 129 | out = F.linear(x, self.weight, self.bias) 130 | out = self.activation(out) 131 | out = self.dropout(out) 132 | return out 133 | 134 | def get_embedder(embed_type="fourier", num_freqs=-1, input_dim=3, include_pi=True): 135 | if embed_type == "identity" or (embed_type == "fourier" and num_freqs == -1): 136 | return nn.Identity(), input_dim 137 | 138 | elif embed_type == "fourier": 139 | embedder_obj = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi) 140 | 141 | elif embed_type == "learned_fourier": 142 | embedder_obj = LearnedFourierEmbedder(in_channels=input_dim, dim=num_freqs) 143 | 144 | elif embed_type == "siren": 145 | embedder_obj = Siren(in_dim=input_dim, out_dim=num_freqs * input_dim * 2 + input_dim) 146 | 147 | elif embed_type == "hashgrid": 148 | raise NotImplementedError 149 | 150 | elif embed_type == "sphere_harmonic": 151 | raise NotImplementedError 152 | 153 | else: 154 | raise ValueError(f"{embed_type} is not valid. 
Currently only supports {VALID_EMBED_TYPES}") 155 | return embedder_obj 156 | 157 | 158 | ###################### AutoEncoder 159 | class AutoEncoder(BaseModule): 160 | @dataclass 161 | class Config(BaseModule.Config): 162 | pretrained_model_name_or_path: str = "" 163 | num_latents: int = 256 164 | embed_dim: int = 64 165 | width: int = 768 166 | upsample: bool = False 167 | cfg: Config 168 | 169 | def configure(self) -> None: 170 | super().configure() 171 | 172 | def encode(self, x: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]: 173 | raise NotImplementedError 174 | 175 | def decode(self, z: torch.FloatTensor) -> torch.FloatTensor: 176 | raise NotImplementedError 177 | 178 | def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True): 179 | posterior = None 180 | if self.cfg.embed_dim > 0: 181 | moments = self.pre_kl(latents) 182 | 183 | posterior = DiagonalGaussianDistribution(moments, feat_dim=-1) 184 | if sample_posterior: 185 | kl_embed = posterior.sample() 186 | else: 187 | kl_embed = posterior.mode() 188 | else: 189 | kl_embed = latents 190 | return kl_embed, posterior 191 | 192 | def forward(self, 193 | surface: torch.FloatTensor, 194 | surface2: torch.FloatTensor =None, 195 | surface3: torch.FloatTensor =None, 196 | queries: torch.FloatTensor =None, 197 | sample_posterior: bool = True, 198 | res: torch.FloatTensor =None, 199 | ): 200 | shape_latents, kl_embed, posterior = self.encode(surface,surface2,surface3, sample_posterior=sample_posterior) 201 | 202 | # logs=[] 203 | # de_latents=[] 204 | # if type(kl_embed) == list: 205 | 206 | # # for kl_ in kl_embed: 207 | 208 | # latents_0 = self.decode(kl_embed[0]) # [B, num_latents, width] 209 | # latents_1 = self.decode(kl_embed[1]) 210 | # # import ipdb 211 | # # ipdb.set_trace() 212 | 213 | # latents_0_up=self.cross_level(latents_1,latents_0 ) 214 | # latents_2=latents_1+latents_0_up 215 | 216 | # de_latents=[latents_0_up,latents_2] 217 | # logits = self.query(queries, latents_0_up) # [B,] 218 | # logits_2 = self.query(queries, latents_2) # [B,] 219 | # logs=[logits,logits_2] 220 | # else: 221 | de_latents = self.decode(kl_embed) # [B, num_latents, width] 222 | logs = self.query(queries, de_latents) # [B,] 223 | 224 | 225 | return kl_embed, de_latents, posterior, logs 226 | 227 | def query(self, queries: torch.FloatTensor, latents: torch.FloatTensor) -> torch.FloatTensor: 228 | raise NotImplementedError 229 | 230 | @torch.no_grad() 231 | def extract_geometry(self, 232 | latents: torch.FloatTensor, 233 | bounds: Union[Tuple[float], List[float], float] = (-1.05, -1.05, -1.05, 1.05, 1.05, 1.05), 234 | octree_depth: int = 8, 235 | num_chunks: int = 10000, 236 | ): 237 | 238 | if isinstance(bounds, float): 239 | bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds] 240 | 241 | bbox_min = np.array(bounds[0:3]) 242 | bbox_max = np.array(bounds[3:6]) 243 | bbox_size = bbox_max - bbox_min 244 | 245 | xyz_samples, grid_size, length = generate_dense_grid_points( 246 | bbox_min=bbox_min, 247 | bbox_max=bbox_max, 248 | octree_depth=octree_depth, 249 | indexing="ij" 250 | ) 251 | xyz_samples = torch.FloatTensor(xyz_samples) 252 | 253 | 254 | batch_size = latents.shape[0] 255 | 256 | batch_logits = [] 257 | for start in range(0, xyz_samples.shape[0], num_chunks): 258 | queries = xyz_samples[start: start + num_chunks, :].to(latents) 259 | batch_queries = repeat(queries, "p c -> b p c", b=batch_size) 260 | 261 | logits = self.query(batch_queries, latents) 262 | batch_logits.append(logits.cpu()) 263 | 
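        # Stitch the chunked logits back into a dense (grid_size[0] x grid_size[1] x grid_size[2]) volume,
        # then extract a triangle mesh per batch element with marching cubes below.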
264 | grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2])).float().numpy() 265 | 266 | mesh_v_f = [] 267 | has_surface = np.zeros((batch_size,), dtype=np.bool_) 268 | for i in range(batch_size): 269 | try: 270 | vertices, faces, normals, _ = measure.marching_cubes(grid_logits[i], 0, method="lewiner") 271 | vertices = vertices / grid_size * bbox_size + bbox_min 272 | faces = faces[:, [2, 1, 0]] 273 | mesh_v_f.append((vertices.astype(np.float32), np.ascontiguousarray(faces))) 274 | has_surface[i] = True 275 | except Exception: 276 | mesh_v_f.append((None, None)) 277 | has_surface[i] = False 278 | 279 | return mesh_v_f, has_surface 280 | 281 | class DiagonalGaussianDistribution(object): 282 | def __init__(self, parameters: Union[torch.Tensor, List[torch.Tensor]], deterministic=False, feat_dim=1): 283 | self.feat_dim = feat_dim 284 | self.parameters = parameters 285 | 286 | if isinstance(parameters, list): 287 | self.mean = parameters[0] 288 | self.logvar = parameters[1] 289 | else: 290 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=feat_dim) 291 | 292 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 293 | self.deterministic = deterministic 294 | self.std = torch.exp(0.5 * self.logvar) 295 | self.var = torch.exp(self.logvar) 296 | if self.deterministic: 297 | self.var = self.std = torch.zeros_like(self.mean) 298 | 299 | def sample(self): 300 | x = self.mean + self.std * torch.randn_like(self.mean) 301 | return x 302 | 303 | def kl(self, other=None, dims=(1, 2)): 304 | if self.deterministic: 305 | return torch.Tensor([0.]) 306 | else: 307 | if other is None: 308 | return 0.5 * torch.mean(torch.pow(self.mean, 2) 309 | + self.var - 1.0 - self.logvar, 310 | dim=dims) 311 | else: 312 | return 0.5 * torch.mean( 313 | torch.pow(self.mean - other.mean, 2) / other.var 314 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 315 | dim=dims) 316 | 317 | def nll(self, sample, dims=(1, 2)): 318 | if self.deterministic: 319 | return torch.Tensor([0.]) 320 | logtwopi = np.log(2.0 * np.pi) 321 | return 0.5 * torch.sum( 322 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 323 | dim=dims) 324 | 325 | def mode(self): 326 | return self.mean 327 | -------------------------------------------------------------------------------- /mar3d/data/objaverse.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import json 4 | from dataclasses import dataclass, field 5 | 6 | import random 7 | import imageio 8 | import numpy as np 9 | import pytorch_lightning as pl 10 | import torch 11 | import torch.nn.functional as F 12 | from torch.utils.data import DataLoader, Dataset 13 | from torchvision import transforms 14 | from PIL import Image 15 | from transformers import CLIPImageProcessor, CLIPTokenizer 16 | import pickle 17 | from mar3d import register 18 | from mar3d.utils.base import Updateable 19 | from mar3d.utils.config import parse_structured 20 | from mar3d.utils.typing import * 21 | from plyfile import PlyData, PlyElement 22 | import pandas as pd 23 | def save_ply_plyfile(points, filename): 24 | # build a structured array holding the vertex coordinates 25 | vertex = np.zeros(len(points), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 26 | vertex['x'] = points[:, 0] 27 | vertex['y'] = points[:, 1] 28 | vertex['z'] = points[:, 2] 29 | 30 | # create the PlyElement 31 | el = PlyElement.describe(vertex, 'vertex') 32 | 33 | # write the PLY file 34 | PlyData([el], text=True).write(filename) 35 | def rot2eul(R): 36 | beta = 
-np.arcsin(R[2,0]) 37 | alpha = np.arctan2(R[2,1]/np.cos(beta),R[2,2]/np.cos(beta)) 38 | gamma = np.arctan2(R[1,0]/np.cos(beta),R[0,0]/np.cos(beta)) 39 | return np.array((alpha, beta, gamma)) 40 | 41 | def eul2rot(theta) : 42 | R = np.array([[np.cos(theta[1])*np.cos(theta[2]), np.sin(theta[0])*np.sin(theta[1])*np.cos(theta[2]) - np.sin(theta[2])*np.cos(theta[0]), np.sin(theta[1])*np.cos(theta[0])*np.cos(theta[2]) + np.sin(theta[0])*np.sin(theta[2])], 43 | [np.sin(theta[2])*np.cos(theta[1]), np.sin(theta[0])*np.sin(theta[1])*np.sin(theta[2]) + np.cos(theta[0])*np.cos(theta[2]), np.sin(theta[1])*np.sin(theta[2])*np.cos(theta[0]) - np.sin(theta[0])*np.cos(theta[2])], 44 | [-np.sin(theta[1]), np.sin(theta[0])*np.cos(theta[1]), np.cos(theta[0])*np.cos(theta[1])]]) 45 | return R 46 | 47 | @dataclass 48 | class ObjaverseDataModuleConfig: 49 | root_dir: str = None 50 | data_type: str = "occupancy" # occupancy or sdf 51 | n_samples: int = 4096 # number of points in input point cloud 52 | scale: float = 1.0 # scale of the input point cloud and target supervision 53 | noise_sigma: float = 0.0 # noise level of the input point cloud 54 | 55 | load_supervision: bool = True # whether to load supervision 56 | supervision_type: str = "occupancy" # occupancy, sdf, tsdf, tsdf_w_surface 57 | 58 | n_supervision: int = 10000 # number of points in supervision 59 | 60 | load_image: bool = False # whether to load images 61 | image_data_path: str = "" # path to the image data 62 | image_type: str = "rgb" # rgb, normal 63 | background_color: Tuple[float, float, float] = field( 64 | default_factory=lambda: (1.0, 1.0, 1.0) 65 | ) 66 | idx: Optional[List[int]] = None # index of the image to load 67 | n_views: int = 1 # number of views 68 | rotate: bool = False # whether to rotate the input point cloud and the supervision 69 | 70 | load_caption: bool = False # whether to load captions 71 | caption_type: str = "text" # text, clip_embeds 72 | tokenizer_pretrained_model_name_or_path: str = "" 73 | 74 | batch_size: int = 32 75 | num_workers: int = 0 76 | 77 | 78 | class ObjaverseDataset(Dataset): 79 | def __init__(self, cfg: Any, split: str) -> None: 80 | super().__init__() 81 | self.cfg = cfg 82 | self.split = split 83 | 84 | # make sure root_dir is list 85 | if isinstance(self.cfg.root_dir, str): 86 | self.cfg.root_dir = [self.cfg.root_dir] 87 | 88 | # cache file 89 | cache_file = '' 90 | if os.path.exists(cache_file): 91 | 92 | with open(cache_file, 'rb') as f: 93 | self.uids_og = pickle.load(f) 94 | else: 95 | self.uids_og = self._scan_files() 96 | with open(cache_file, 'wb') as f: 97 | pickle.dump(self.uids_og, f) 98 | 99 | self.background_color = torch.as_tensor(self.cfg.background_color) 100 | 101 | if self.cfg.load_image: 102 | 103 | mapping='' 104 | 105 | 106 | df = pd.read_csv(mapping, sep=',', header=None, names=['number', 'hash'], skiprows=1) 107 | 108 | 109 | self.mapping_dict = dict(zip(df['number'], df['hash'])) 110 | self.uids= [] 111 | for uid_idx, uid_tuple in enumerate(self.uids_og): 112 | # Extract the path and filename 113 | _, filename = uid_tuple 114 | # Get the part after the underscore 115 | if '_' in filename: 116 | search_str = filename.split('_', 1)[1] 117 | 118 | if search_str in self.mapping_dict.keys(): 119 | 120 | self.uids.append(uid_tuple) 121 | 122 | else: 123 | self.uids=self.uids_og 124 | print(f"Loaded {len(self.uids)} {split} usable uids") 125 | def _scan_files(self): 126 | uids = [] 127 | total_files = [] 128 | 129 | for root_dir in self.cfg.root_dir: 130 | files = 
os.listdir(root_dir) 131 | # pair each file with its root directory 132 | files = [(root_dir, file) for file in files] 133 | total_files.extend(files) 134 | 135 | 136 | if self.split == 'train': 137 | total_files = total_files[150:] 138 | else: 139 | total_files = total_files[:150] 140 | 141 | 142 | return [ 143 | (root_dir, file) for root_dir, file in total_files 144 | if os.path.exists( 145 | f'{root_dir}/{file}/xxx.npz' 146 | ) 147 | ] 148 | 149 | def __len__(self): 150 | 151 | return len(self.uids) 152 | 153 | def _load_shape(self, index: int) -> Dict[str, Any]: 154 | 155 | if self.cfg.supervision_type == "sdf": 156 | 157 | sdfs3 = np.asarray(pointcloud['clean_surface_sdf']) 158 | ind3 = rng.choice(surface_og_n.shape[0], self.cfg.n_supervision//3, replace=False) 159 | 160 | rand_points3=surface_og[ind3] 161 | sdfs3 =sdfs3[ind3] 162 | normal3=surface_og_n[ind3] 163 | 164 | rand_points=np.concatenate((rand_points1,rand_points2,rand_points3),axis=0) 165 | sdfs=np.concatenate((sdfs1,sdfs2,sdfs3),axis=0) 166 | 167 | 168 | else: 169 | rand_points=np.concatenate((rand_points1,rand_points2),axis=0) 170 | sdfs=np.concatenate((sdfs1,sdfs2),axis=0) 171 | 172 | ret = { 173 | "uid": self.uids[index][1], 174 | "surface": surface.astype(np.float32), 175 | 176 | } 177 | 178 | 179 | ret["rand_points"] = rand_points.astype(np.float32) 180 | 181 | if self.cfg.supervision_type == "sdf": 182 | sdfs=np.nan_to_num(sdfs, nan=1.0, posinf=1.0, neginf=-1.0) 183 | # ret["sdf"] = sdfs.flatten().astype(np.float32).clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold) / self.cfg.tsdf_threshold 184 | # ret["sdf"] = sdfs.flatten().astype(np.float32).clip(-self.cfg.tsdf_threshold, self.cfg.tsdf_threshold) 185 | ret["sdf"] = sdfs.flatten().astype(np.float32) 186 | # ret["sdf"] = sdfs[ind2].flatten().astype(np.float32) 187 | ret['surface_normal']=normal3 188 | elif self.cfg.supervision_type == "occupancy": 189 | # ret["occupancies"] = np.where(sdfs[ind].flatten() < 1e-3, 0, 1).astype(np.float32) 190 | ret["occupancies"] = np.where(sdfs.flatten() < 0, 0, 1).astype(np.float32) 191 | else: 192 | raise NotImplementedError(f"Supervision type {self.cfg.supervision_type} not implemented") 193 | 194 | return ret 195 | 196 | def _load_image(self, index: int) -> Dict[str, Any]: 197 | name=self.uids[index][1].split('_')[1] 198 | file_path=self.mapping_dict[name] 199 | # image_paths=os.path.join(images_root,file_path,file_path,name) 200 | def _load_single_image(img_path): 201 | img = torch.from_numpy( 202 | np.asarray( 203 | Image.fromarray(imageio.v2.imread(img_path)) 204 | .convert("RGBA") 205 | ) 206 | / 255.0 207 | ).float() 208 | mask: Float[Tensor, "H W 1"] = img[:, :, -1:] 209 | image: Float[Tensor, "H W 3"] = img[:, :, :3] * mask + self.background_color[ 210 | None, None, : 211 | ] * (1 - mask) 212 | return image 213 | ret = {} 214 | if self.cfg.image_type == "rgb" or self.cfg.image_type == "normal": 215 | assert self.cfg.n_views == 1, "Only single view is supported for single image" 216 | sel_idx = random.choice(self.cfg.idx) 217 | ret["sel_image_idx"] = sel_idx 218 | 219 | 220 | img_path=file_path 221 | ret["image"] = _load_single_image(img_path) 222 | else: 223 | raise NotImplementedError(f"Image type {self.cfg.image_type} not implemented") 224 | 225 | return ret 226 | 227 | def _load_caption(self, index: int, drop_text_embed: bool = False) -> Dict[str, Any]: 228 | ret = {} 229 | if self.cfg.caption_type == "text": 230 | caption = eval(json.load(open(f'{self.cfg.image_data_path}/' + "/".join(self.uids[index].split('/')[-2:]) + 
f'/annotation.json'))) 231 | texts = [v for k, v in caption.items()] 232 | sel_idx = random.randint(0, len(texts) - 1) 233 | ret["sel_caption_idx"] = sel_idx 234 | ret['text_input_ids'] = self.tokenizer( 235 | texts[sel_idx] if not drop_text_embed else "", 236 | max_length=self.tokenizer.model_max_length, 237 | padding="max_length", 238 | truncation=True, 239 | return_tensors="pt" 240 | ).input_ids.detach() 241 | else: 242 | raise NotImplementedError(f"Caption type {self.cfg.caption_type} not implemented") 243 | 244 | return ret 245 | 246 | def get_data(self, index): 247 | # load shape 248 | ret = self._load_shape(index) 249 | 250 | 251 | if self.cfg.load_image: 252 | ret.update(self._load_image(index)) 253 | return ret 254 | 255 | def __getitem__(self, index): 256 | try: 257 | return self.get_data(index) 258 | except Exception as e: 259 | print(f"Error in {self.uids[index]}: {e}") 260 | return self.__getitem__(np.random.randint(len(self))) 261 | 262 | 263 | def collate(self, batch): 264 | batch = torch.utils.data.default_collate(batch) 265 | return batch 266 | 267 | 268 | 269 | @register("objaverse-datamodule") 270 | class ObjaverseDataModule(pl.LightningDataModule): 271 | cfg: ObjaverseDataModuleConfig 272 | 273 | def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None: 274 | super().__init__() 275 | self.cfg = parse_structured(ObjaverseDataModuleConfig, cfg) 276 | 277 | def setup(self, stage=None) -> None: 278 | if stage in [None, "fit"]: 279 | self.train_dataset = ObjaverseDataset(self.cfg, "train") 280 | if stage in [None, "fit", "validate"]: 281 | self.val_dataset = ObjaverseDataset(self.cfg, "val") 282 | if stage in [None, "test", "predict"]: 283 | self.test_dataset = ObjaverseDataset(self.cfg, "test") 284 | 285 | def prepare_data(self): 286 | pass 287 | 288 | def general_loader(self, dataset, batch_size, collate_fn=None, num_workers=0) -> DataLoader: 289 | return DataLoader( 290 | dataset, batch_size=batch_size, shuffle=True,collate_fn=collate_fn, num_workers=num_workers 291 | ) 292 | def train_dataloader(self) -> DataLoader: 293 | return self.general_loader( 294 | self.train_dataset, 295 | batch_size=self.cfg.batch_size, 296 | collate_fn=self.train_dataset.collate, 297 | num_workers=self.cfg.num_workers 298 | ) 299 | 300 | def val_dataloader(self) -> DataLoader: 301 | return self.general_loader(self.val_dataset, batch_size=1) 302 | 303 | def test_dataloader(self) -> DataLoader: 304 | return self.general_loader(self.test_dataset, batch_size=1) 305 | 306 | def predict_dataloader(self) -> DataLoader: 307 | return self.general_loader(self.test_dataset, batch_size=1) -------------------------------------------------------------------------------- /mar3d/models/conditional_encoders/clip/modeling_conditional_clip.py: -------------------------------------------------------------------------------- 1 | # coding=utf-8 2 | # Copyright 2023 Meta AI and The HuggingFace Inc. team. All rights reserved. 3 | # 4 | # Licensed under the Apache License, Version 2.0 (the "License"); 5 | # you may not use this file except in compliance with the License. 6 | # You may obtain a copy of the License at 7 | # 8 | # http://www.apache.org/licenses/LICENSE-2.0 9 | # 10 | # Unless required by applicable law or agreed to in writing, software 11 | # distributed under the License is distributed on an "AS IS" BASIS, 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
13 | # See the License for the specific language governing permissions and 14 | # limitations under the License. 15 | # Reference: 16 | # * transformers/models/dinov2/modeling_dinov2.py 17 | # * https://github.com/facebookresearch/DiT/blob/main/models.py#L101 18 | # * https://github.com/3DTopia/OpenLRM/tree/main/openlrm/models/encoders/dinov2 19 | """ PyTorch CLIP model.""" 20 | 21 | from typing import Dict, List, Optional, Set, Tuple, Union 22 | 23 | import torch 24 | import torch.nn as nn 25 | 26 | from .modeling_clip import ( 27 | CLIPConfig, 28 | CLIPTextConfig, 29 | CLIPVisionConfig, 30 | CLIPEncoderLayer, 31 | CLIPTextTransformer, 32 | CLIPVisionTransformer, 33 | CLIPModel, 34 | CLIPVisionEmbeddings, 35 | CLIPVisionModel, 36 | CLIPOutput, 37 | BaseModelOutput, 38 | BaseModelOutputWithPooling 39 | ) 40 | 41 | 42 | class ModLN(nn.Module): 43 | def __init__(self, inner_dim: int, mod_dim: int = 32): 44 | super().__init__() 45 | self.mlp = nn.Sequential( 46 | nn.SiLU(), 47 | nn.Linear(mod_dim, inner_dim * 2), 48 | ) 49 | 50 | for m in self.modules(): 51 | if isinstance(m, nn.Linear): 52 | nn.init.zeros_(m.weight) 53 | nn.init.zeros_(m.bias) 54 | 55 | def forward(self, x:torch.Tensor, condition:torch.Tensor): 56 | ''' 57 | x: [N, M, C_in], M: num of tokens 58 | condition: [N, C_mod] 59 | ''' 60 | shift, scale = self.mlp(condition).unsqueeze(1).chunk(2, dim=-1) 61 | return x * (1 + scale) + shift 62 | 63 | 64 | class ConditionalCLIPVisionConfig(CLIPVisionConfig): 65 | def __init__(self, modulation_dim: int = 32, *args, **kwargs): 66 | super().__init__(*args, **kwargs) 67 | self.modulation_dim = modulation_dim 68 | 69 | 70 | class ConditionalCLIPEncoderLayer(CLIPEncoderLayer): 71 | """This corresponds to the Block class in the original implementation.""" 72 | 73 | def __init__(self, config: ConditionalCLIPVisionConfig) -> None: 74 | super().__init__(config) 75 | self.mod_norm1 = ModLN(config.hidden_size, config.modulation_dim) 76 | self.mod_norm2 = ModLN(config.hidden_size, config.modulation_dim) 77 | 78 | def forward( 79 | self, 80 | hidden_states: torch.Tensor, 81 | attention_mask: torch.Tensor, 82 | causal_attention_mask: torch.Tensor, 83 | condition: Optional[torch.Tensor] = None, 84 | output_attentions: bool = False, 85 | ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: 86 | residual = hidden_states 87 | 88 | hidden_states = self.mod_norm1(self.layer_norm1(hidden_states), condition) 89 | hidden_states, attn_weights = self.self_attn( 90 | hidden_states=hidden_states, 91 | attention_mask=attention_mask, 92 | causal_attention_mask=causal_attention_mask, 93 | output_attentions=output_attentions, 94 | ) 95 | hidden_states = residual + hidden_states 96 | 97 | residual = hidden_states 98 | hidden_states = self.mod_norm2(self.layer_norm2(hidden_states), condition) 99 | hidden_states = self.mlp(hidden_states) 100 | hidden_states = residual + hidden_states 101 | 102 | outputs = (hidden_states,) 103 | 104 | if output_attentions: 105 | outputs += (attn_weights,) 106 | 107 | return outputs 108 | 109 | 110 | class ConditionalCLIPEncoder(nn.Module): 111 | def __init__(self, config: CLIPConfig) -> None: 112 | super().__init__() 113 | self.config = config 114 | self.layers = nn.ModuleList([ConditionalCLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) 115 | self.gradient_checkpointing = False 116 | 117 | def forward( 118 | self, 119 | inputs_embeds, 120 | attention_mask: Optional[torch.Tensor] = None, 121 | causal_attention_mask: Optional[torch.Tensor] = None, 122 | 
output_attentions: Optional[bool] = None, 123 | output_hidden_states: Optional[bool] = None, 124 | condition: Optional[torch.Tensor] = None, 125 | return_dict: Optional[bool] = None, 126 | ) -> Union[tuple, BaseModelOutput]: 127 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 128 | output_hidden_states = ( 129 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 130 | ) 131 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 132 | 133 | encoder_states = () if output_hidden_states else None 134 | all_attentions = () if output_attentions else None 135 | 136 | hidden_states = inputs_embeds 137 | for idx, encoder_layer in enumerate(self.layers): 138 | if output_hidden_states: 139 | encoder_states = encoder_states + (hidden_states,) 140 | if self.gradient_checkpointing and self.training: 141 | layer_outputs = self._gradient_checkpointing_func( 142 | encoder_layer.__call__, 143 | hidden_states, 144 | attention_mask, 145 | causal_attention_mask, 146 | condition=condition, 147 | output_attentions=output_attentions, 148 | ) 149 | else: 150 | layer_outputs = encoder_layer( 151 | hidden_states, 152 | attention_mask, 153 | causal_attention_mask, 154 | condition=condition, 155 | output_attentions=output_attentions, 156 | ) 157 | 158 | hidden_states = layer_outputs[0] 159 | 160 | if output_attentions: 161 | all_attentions = all_attentions + (layer_outputs[1],) 162 | 163 | if output_hidden_states: 164 | encoder_states = encoder_states + (hidden_states,) 165 | 166 | if not return_dict: 167 | return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) 168 | return BaseModelOutput( 169 | last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions 170 | ) 171 | 172 | 173 | class ConditionalCLIPVisionTransformer(CLIPVisionTransformer): 174 | def __init__(self, config: ConditionalCLIPVisionConfig): 175 | super().__init__(config) 176 | self.config = config 177 | embed_dim = config.hidden_size 178 | 179 | self.embeddings = CLIPVisionEmbeddings(config) 180 | self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 181 | self.encoder = ConditionalCLIPEncoder(config) 182 | self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) 183 | 184 | def forward( 185 | self, 186 | pixel_values: Optional[torch.FloatTensor] = None, 187 | condition: Optional[torch.Tensor] = None, 188 | output_attentions: Optional[bool] = None, 189 | output_hidden_states: Optional[bool] = None, 190 | return_dict: Optional[bool] = None, 191 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 192 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 193 | output_hidden_states = ( 194 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 195 | ) 196 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 197 | 198 | if pixel_values is None: 199 | raise ValueError("You have to specify pixel_values") 200 | 201 | hidden_states = self.embeddings(pixel_values) 202 | hidden_states = self.pre_layrnorm(hidden_states) 203 | 204 | encoder_outputs = self.encoder( 205 | inputs_embeds=hidden_states, 206 | output_attentions=output_attentions, 207 | output_hidden_states=output_hidden_states, 208 | condition=condition, 209 | return_dict=return_dict, 210 | ) 211 | 212 | last_hidden_state = encoder_outputs[0] 213 | 
pooled_output = last_hidden_state[:, 0, :] 214 | pooled_output = self.post_layernorm(pooled_output) 215 | 216 | if not return_dict: 217 | return (last_hidden_state, pooled_output) + encoder_outputs[1:] 218 | 219 | return BaseModelOutputWithPooling( 220 | last_hidden_state=last_hidden_state, 221 | pooler_output=pooled_output, 222 | hidden_states=encoder_outputs.hidden_states, 223 | attentions=encoder_outputs.attentions, 224 | ) 225 | 226 | 227 | class ConditionalCLIPVisionModel(CLIPVisionModel): 228 | config_class = ConditionalCLIPVisionConfig 229 | 230 | def __init__(self, config: ConditionalCLIPVisionConfig): 231 | super().__init__(config) 232 | self.vision_model = ConditionalCLIPVisionTransformer(config) 233 | # Initialize weights and apply final processing 234 | self.post_init() 235 | 236 | def forward( 237 | self, 238 | pixel_values: Optional[torch.FloatTensor] = None, 239 | condition: Optional[torch.Tensor] = None, 240 | output_attentions: Optional[bool] = None, 241 | output_hidden_states: Optional[bool] = None, 242 | return_dict: Optional[bool] = None, 243 | ) -> Union[Tuple, BaseModelOutputWithPooling]: 244 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 245 | 246 | return self.vision_model( 247 | pixel_values=pixel_values, 248 | condition=condition, 249 | output_attentions=output_attentions, 250 | output_hidden_states=output_hidden_states, 251 | return_dict=return_dict, 252 | ) 253 | 254 | 255 | class ConditionalCLIPModel(CLIPModel): 256 | config_class = CLIPConfig 257 | 258 | def __init__(self, config: CLIPConfig): 259 | super().__init__(config) 260 | 261 | if not isinstance(config.text_config, CLIPTextConfig): 262 | raise ValueError( 263 | "config.text_config is expected to be of type CLIPTextConfig but is of type" 264 | f" {type(config.text_config)}." 265 | ) 266 | 267 | if not isinstance(config.vision_config, CLIPVisionConfig): 268 | raise ValueError( 269 | "config.vision_config is expected to be of type CLIPVisionConfig but is of type" 270 | f" {type(config.vision_config)}." 271 | ) 272 | 273 | text_config = config.text_config 274 | vision_config = config.vision_config 275 | 276 | self.projection_dim = config.projection_dim 277 | self.text_embed_dim = text_config.hidden_size 278 | self.vision_embed_dim = vision_config.hidden_size 279 | 280 | self.text_model = CLIPTextTransformer(text_config) 281 | self.vision_model = ConditionalCLIPVisionTransformer(vision_config) 282 | 283 | self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) 284 | self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) 285 | self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value)) 286 | 287 | # Initialize weights and apply final processing 288 | self.post_init() 289 | 290 | def get_image_features( 291 | self, 292 | pixel_values: Optional[torch.FloatTensor] = None, 293 | condition: Optional[torch.Tensor] = None, 294 | output_attentions: Optional[bool] = None, 295 | output_hidden_states: Optional[bool] = None, 296 | return_dict: Optional[bool] = None, 297 | ) -> torch.FloatTensor: 298 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 
299 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 300 | output_hidden_states = ( 301 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 302 | ) 303 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 304 | 305 | vision_outputs = self.vision_model( 306 | pixel_values=pixel_values, 307 | condition=condition, 308 | output_attentions=output_attentions, 309 | output_hidden_states=output_hidden_states, 310 | return_dict=return_dict, 311 | ) 312 | 313 | pooled_output = vision_outputs[1] # pooled_output 314 | image_features = self.visual_projection(pooled_output) 315 | 316 | return image_features 317 | 318 | def forward( 319 | self, 320 | input_ids: Optional[torch.LongTensor] = None, 321 | pixel_values: Optional[torch.FloatTensor] = None, 322 | condition: Optional[torch.Tensor] = None, 323 | attention_mask: Optional[torch.Tensor] = None, 324 | position_ids: Optional[torch.LongTensor] = None, 325 | return_loss: Optional[bool] = None, 326 | output_attentions: Optional[bool] = None, 327 | output_hidden_states: Optional[bool] = None, 328 | return_dict: Optional[bool] = None, 329 | ) -> Union[Tuple, CLIPOutput]: 330 | # Use CLIP model's config for some fields (if specified) instead of those of vision & text components. 331 | output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions 332 | output_hidden_states = ( 333 | output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states 334 | ) 335 | return_dict = return_dict if return_dict is not None else self.config.use_return_dict 336 | 337 | vision_outputs = self.vision_model( 338 | pixel_values=pixel_values, 339 | condition=condition, 340 | output_attentions=output_attentions, 341 | output_hidden_states=output_hidden_states, 342 | return_dict=return_dict, 343 | ) 344 | 345 | text_outputs = self.text_model( 346 | input_ids=input_ids, 347 | attention_mask=attention_mask, 348 | position_ids=position_ids, 349 | output_attentions=output_attentions, 350 | output_hidden_states=output_hidden_states, 351 | return_dict=return_dict, 352 | ) 353 | 354 | image_embeds = vision_outputs[1] 355 | image_embeds = self.visual_projection(image_embeds) 356 | 357 | text_embeds = text_outputs[1] 358 | text_embeds = self.text_projection(text_embeds) 359 | 360 | # normalized features 361 | image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True) 362 | text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True) 363 | 364 | # cosine similarity as logits 365 | logit_scale = self.logit_scale.exp() 366 | logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale 367 | logits_per_image = logits_per_text.t() 368 | 369 | loss = None 370 | if return_loss: 371 | loss = clip_loss(logits_per_text) 372 | 373 | if not return_dict: 374 | output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) 375 | return ((loss,) + output) if loss is not None else output 376 | 377 | return CLIPOutput( 378 | loss=loss, 379 | logits_per_image=logits_per_image, 380 | logits_per_text=logits_per_text, 381 | text_embeds=text_embeds, 382 | image_embeds=image_embeds, 383 | text_model_output=text_outputs, 384 | vision_model_output=vision_outputs, 385 | ) 386 | -------------------------------------------------------------------------------- /mar3d/systems/mar_diffusion.py: 
-------------------------------------------------------------------------------- 1 | from dataclasses import dataclass, field 2 | import os 3 | import numpy as np 4 | import json 5 | import copy 6 | import torch 7 | import torch.nn.functional as F 8 | from skimage import measure 9 | from einops import repeat 10 | from tqdm import tqdm 11 | from PIL import Image 12 | 13 | from transformers import CLIPImageProcessor, CLIPTokenizer 14 | from diffusers import ( 15 | DDPMScheduler, 16 | DDIMScheduler, 17 | UniPCMultistepScheduler, 18 | KarrasVeScheduler, 19 | DPMSolverMultistepScheduler 20 | ) 21 | import scipy.stats as stats 22 | 23 | from plyfile import PlyData, PlyElement 24 | import mar3d 25 | from mar3d.systems.base import BaseSystem 26 | from mar3d.utils.ops import generate_dense_grid_points 27 | from mar3d.utils.typing import * 28 | import torchvision 29 | import pandas as pd 30 | import math 31 | from timm.models.vision_transformer import Block 32 | from mar3d.systems.diffloss import DiffLoss 33 | import torch.nn as nn 34 | 35 | 36 | class MAR(nn.Module): 37 | """ Masked Autoencoder with VisionTransformer backbone 38 | """ 39 | def __init__(self, latent_size=256, patch_size=1, 40 | encoder_embed_dim=1024, encoder_depth=16, encoder_num_heads=16, 41 | decoder_embed_dim=1024, decoder_depth=16, decoder_num_heads=16, 42 | mlp_ratio=4., norm_layer=nn.LayerNorm, 43 | vae_embed_dim=64, 44 | mask_ratio_min=0.7, 45 | label_drop_prob=0.1, 46 | attn_dropout=0.1, 47 | proj_dropout=0.1, 48 | buffer_size=256, 49 | diffloss_d=6, 50 | diffloss_w=1024, 51 | num_sampling_steps='100', 52 | diffusion_batch_mul=4, 53 | grad_checkpointing=False, 54 | ): 55 | super().__init__() 56 | 57 | # -------------------------------------------------------------------------- 58 | # VAE and patchify specifics 59 | self.vae_embed_dim = vae_embed_dim 60 | # self.vae_stride = vae_stride 61 | self.patch_size = patch_size 62 | # self.seq_h = self.seq_w = img_size // vae_stride // patch_size 63 | self.seq_len =latent_size// patch_size 64 | self.token_embed_dim = vae_embed_dim * patch_size 65 | self.grad_checkpointing = grad_checkpointing 66 | 67 | 68 | # -------------------------------------------------------------------------- 69 | # Class Embedding 70 | # self.num_classes = class_num 71 | # self.class_emb = nn.Embedding(1000, encoder_embed_dim) 72 | self.label_drop_prob = label_drop_prob 73 | # Fake class embedding for CFG's unconditional generation 74 | self.fake_latent = nn.Parameter(torch.zeros(buffer_size, encoder_embed_dim)) 75 | 76 | 77 | # -------------------------------------------------------------------------- 78 | # MAR variant masking ratio, a left-half truncated Gaussian centered at 100% masking ratio with std 0.25 79 | self.mask_ratio_generator = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25) 80 | 81 | # -------------------------------------------------------------------------- 82 | # MAR encoder specifics 83 | self.z_proj = nn.Linear(self.token_embed_dim, encoder_embed_dim, bias=True) 84 | self.z_proj_ln = nn.LayerNorm(encoder_embed_dim, eps=1e-6) 85 | self.buffer_size = buffer_size 86 | self.encoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, encoder_embed_dim)) 87 | 88 | self.encoder_blocks = nn.ModuleList([ 89 | Block(encoder_embed_dim, encoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 90 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(encoder_depth)]) 91 | self.encoder_norm = norm_layer( encoder_embed_dim) 92 | 93 
| # -------------------------------------------------------------------------- 94 | # MAR decoder specifics 95 | self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True) 96 | self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim)) 97 | self.decoder_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len + self.buffer_size, decoder_embed_dim)) 98 | 99 | self.decoder_blocks = nn.ModuleList([ 100 | Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer, 101 | proj_drop=proj_dropout, attn_drop=attn_dropout) for _ in range(decoder_depth)]) 102 | 103 | self.decoder_norm = norm_layer(decoder_embed_dim) 104 | self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, self.seq_len, decoder_embed_dim)) 105 | 106 | self.initialize_weights() 107 | 108 | # -------------------------------------------------------------------------- 109 | # Diffusion Loss 110 | self.diffloss = DiffLoss( 111 | target_channels=self.token_embed_dim, 112 | z_channels=decoder_embed_dim, 113 | width=diffloss_w, 114 | depth=diffloss_d, 115 | num_sampling_steps=num_sampling_steps, 116 | grad_checkpointing=grad_checkpointing 117 | ) 118 | self.diffusion_batch_mul = diffusion_batch_mul 119 | 120 | def initialize_weights(self): 121 | # parameters 122 | # torch.nn.init.normal_(self.class_emb.weight, std=.02) 123 | torch.nn.init.normal_(self.fake_latent, std=.02) 124 | 125 | torch.nn.init.normal_(self.mask_token, std=.02) 126 | torch.nn.init.normal_(self.encoder_pos_embed_learned, std=.02) 127 | torch.nn.init.normal_(self.decoder_pos_embed_learned, std=.02) 128 | torch.nn.init.normal_(self.diffusion_pos_embed_learned, std=.02) 129 | 130 | # initialize nn.Linear and nn.LayerNorm 131 | self.apply(self._init_weights) 132 | 133 | def _init_weights(self, m): 134 | if isinstance(m, nn.Linear): 135 | # we use xavier_uniform following official JAX ViT: 136 | torch.nn.init.xavier_uniform_(m.weight) 137 | if isinstance(m, nn.Linear) and m.bias is not None: 138 | nn.init.constant_(m.bias, 0) 139 | elif isinstance(m, nn.LayerNorm): 140 | if m.bias is not None: 141 | nn.init.constant_(m.bias, 0) 142 | if m.weight is not None: 143 | nn.init.constant_(m.weight, 1.0) 144 | def sample_orders(self, bsz): 145 | # generate a batch of random generation orders 146 | orders = [] 147 | for _ in range(bsz): 148 | order = np.array(list(range(self.seq_len))) 149 | np.random.shuffle(order) 150 | orders.append(order) 151 | orders = torch.Tensor(np.array(orders)).cuda().long() 152 | return orders 153 | 154 | def random_masking(self, x, orders): 155 | # generate token mask 156 | bsz, seq_len, embed_dim = x.shape 157 | mask_rate = self.mask_ratio_generator.rvs(1)[0] 158 | num_masked_tokens = int(np.ceil(seq_len * mask_rate)) 159 | mask = torch.zeros(bsz, seq_len, device=x.device) 160 | mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens], 161 | src=torch.ones(bsz, seq_len, device=x.device)) 162 | return mask 163 | 164 | def forward_mae_encoder(self, x, mask, class_embedding): 165 | 166 | 167 | x = self.z_proj(x) 168 | bsz, seq_len, embed_dim = x.shape 169 | 170 | # concat buffer 171 | x = torch.cat([torch.zeros(bsz, self.buffer_size, embed_dim, device=x.device), x], dim=1) 172 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 173 | 174 | # random drop class embedding during training 175 | if self.training: 176 | 177 | drop_latent_mask = torch.rand(bsz) < self.label_drop_prob 178 | drop_latent_mask 
=drop_latent_mask .to(x.device) 179 | drop_latent_mask = drop_latent_mask.unsqueeze(-1).to(x.dtype) 180 | 181 | class_embedding = drop_latent_mask[:,:,None] * self.fake_latent[None] + (1 - drop_latent_mask) [:,:,None]* class_embedding 182 | 183 | 184 | x[:, :self.buffer_size] = class_embedding 185 | 186 | # encoder position embedding 187 | x = x + self.encoder_pos_embed_learned 188 | x = self.z_proj_ln(x) 189 | 190 | # dropping 191 | x = x[ (1-mask_with_buffer).nonzero(as_tuple=True) ].reshape(bsz, -1, embed_dim) 192 | 193 | # apply Transformer blocks 194 | if self.grad_checkpointing and not torch.jit.is_scripting(): 195 | for block in self.encoder_blocks: 196 | x = checkpoint(block, x) 197 | else: 198 | for block in self.encoder_blocks: 199 | x = block(x) 200 | x = self.encoder_norm(x) 201 | 202 | return x 203 | 204 | def forward_mae_decoder(self, x, mask): 205 | 206 | x = self.decoder_embed(x) 207 | mask_with_buffer = torch.cat([torch.zeros(x.size(0), self.buffer_size, device=x.device), mask], dim=1) 208 | 209 | # pad mask tokens 210 | mask_tokens = self.mask_token.repeat(mask_with_buffer.shape[0], mask_with_buffer.shape[1], 1).to(x.dtype) 211 | x_after_pad = mask_tokens.clone() 212 | x_after_pad[(1 - mask_with_buffer).nonzero(as_tuple=True)] = x.reshape(x.shape[0] * x.shape[1], x.shape[2]) 213 | 214 | # decoder position embedding 215 | x = x_after_pad + self.decoder_pos_embed_learned 216 | 217 | # apply Transformer blocks 218 | if self.grad_checkpointing and not torch.jit.is_scripting(): 219 | for block in self.decoder_blocks: 220 | x = checkpoint(block, x) 221 | else: 222 | for block in self.decoder_blocks: 223 | x = block(x) 224 | x = self.decoder_norm(x) 225 | 226 | x = x[:, self.buffer_size:] 227 | x = x + self.diffusion_pos_embed_learned 228 | return x 229 | 230 | def forward_loss(self, z, target, mask): 231 | bsz, seq_len, _ = target.shape 232 | target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1) 233 | z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1) 234 | mask = mask.reshape(bsz*seq_len).repeat(self.diffusion_batch_mul) 235 | loss = self.diffloss(z=z, target=target, mask=mask) 236 | return loss 237 | 238 | def forward(self, x, labels,low_res=None): 239 | 240 | 241 | gt_latents = x.clone().detach() 242 | orders = self.sample_orders(bsz=x.size(0)) 243 | mask = self.random_masking(x, orders) 244 | 245 | # mae encoder 246 | x = self.forward_mae_encoder(x, mask,labels) 247 | 248 | # mae decoder 249 | z = self.forward_mae_decoder(x, mask) 250 | # print(z.shape) 251 | # diffloss 252 | loss = self.forward_loss(z=z, target=gt_latents, mask=mask) 253 | 254 | return loss 255 | 256 | def sample_tokens(self, bsz, num_iter=64, cfg=1.0, cfg_schedule="linear", labels=None, temperature=1.0, progress=True): 257 | 258 | # init and sample generation orders 259 | mask = torch.ones(bsz, self.seq_len).cuda() 260 | tokens = torch.zeros(bsz, self.seq_len, self.token_embed_dim).cuda() 261 | orders = self.sample_orders(bsz) 262 | 263 | indices = list(range(num_iter)) 264 | if progress: 265 | indices = tqdm(indices) 266 | # generate latents 267 | for step in indices: 268 | cur_tokens = tokens.clone() 269 | if not cfg == 1.0: 270 | # ipdb.set_trace() 271 | tokens = torch.cat([tokens, tokens], dim=0) 272 | class_embedding = torch.cat([labels, self.fake_latent[None].repeat(bsz, 1,1)], dim=0) 273 | mask = torch.cat([mask, mask], dim=0) 274 | # mae encoder 275 | x = self.forward_mae_encoder(tokens, mask, class_embedding) 276 | 277 | # mae decoder 278 | z = 
self.forward_mae_decoder(x, mask) 279 | # ipdb.set_trace() 280 | 281 | # mask ratio for the next round, following MaskGIT and MAGE. 282 | mask_ratio = np.cos(math.pi / 2. * (step + 1) / num_iter) 283 | mask_len = torch.Tensor([np.floor(self.seq_len * mask_ratio)]).cuda() 284 | 285 | # masks out at least one for the next iteration 286 | mask_len = torch.maximum(torch.Tensor([1]).cuda(), 287 | torch.minimum(torch.sum(mask, dim=-1, keepdims=True) - 1, mask_len)) 288 | 289 | # get masking for next iteration and locations to be predicted in this iteration 290 | mask_next = mask_by_order(mask_len[0], orders, bsz, self.seq_len) 291 | if step >= num_iter - 1: 292 | mask_to_pred = mask[:bsz].bool() 293 | else: 294 | mask_to_pred = torch.logical_xor(mask[:bsz].bool(), mask_next.bool()) 295 | mask = mask_next 296 | if not cfg == 1.0: 297 | mask_to_pred = torch.cat([mask_to_pred, mask_to_pred], dim=0) 298 | 299 | # sample token latents for this step 300 | z = z[mask_to_pred.nonzero(as_tuple=True)] 301 | # cfg schedule follow Muse 302 | if cfg_schedule == "linear": 303 | cfg_iter = 1 + (cfg - 1) * (self.seq_len - mask_len[0]) / self.seq_len 304 | elif cfg_schedule == "constant": 305 | cfg_iter = cfg 306 | else: 307 | raise NotImplementedError 308 | # ipdb.set_trace() 309 | # print(z.shape) 310 | sampled_token_latent = self.diffloss.sample(z, temperature, cfg_iter) 311 | if not cfg == 1.0: 312 | sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0) # Remove null class samples 313 | mask_to_pred, _ = mask_to_pred.chunk(2, dim=0) 314 | 315 | cur_tokens[mask_to_pred.nonzero(as_tuple=True)] = sampled_token_latent 316 | tokens = cur_tokens.clone() 317 | 318 | return tokens 319 | 320 | 321 | @mar3d.register("mar-diffusion-system") 322 | class ShapeDiffusionSystem(BaseSystem): 323 | @dataclass 324 | class Config(BaseSystem.Config): 325 | data_type: str = None 326 | val_samples_json: str = None 327 | # shape vae model 328 | shape_model_type: str = None 329 | # condition model 330 | shape_model: dict = field(default_factory=dict) 331 | condition_model_type: str = None 332 | condition_model: dict = field(default_factory=dict) 333 | 334 | cfg: Config 335 | 336 | def configure(self): 337 | super().configure() 338 | self.shape_model = mar3d.find(self.cfg.shape_model_type)(self.cfg.shape_model) 339 | self.shape_model.eval() 340 | self.sigma_min = 0.0 341 | self.condition = mar3d.find(self.cfg.condition_model_type)(self.cfg.condition_model) 342 | self.condition.requires_grad_(False) 343 | self.mar=MAR(patch_size=1,latent_size=256,buffer_size=257) 344 | 345 | def forward(self, batch: Dict[str, Any]) -> Dict[str, Any]: 346 | 347 | shape_embeds, kl_embed, posterior = self.shape_model.encode( 348 | batch["surface"][..., :3 + self.cfg.shape_model.point_feats], 349 | sample_posterior=True) 350 | 351 | latents=kl_embed 352 | cond_latents = self.condition(batch).to(latents.device) 353 | 354 | loss=self.mar(latents,cond_latents) 355 | return { 356 | "loss_diffusion": loss, 357 | "latents": latents, 358 | } 359 | 360 | def training_step(self, batch, batch_idx): 361 | out = self(batch) 362 | loss = 0. 
363 | for name, value in out.items(): 364 | if name.startswith("loss_"): 365 | self.log(f"train/{name}", value) 366 | loss += value * self.C(self.cfg.loss[name.replace("loss_", "lambda_")]) 367 | 368 | for name, value in self.cfg.loss.items(): 369 | if name.startswith("lambda_"): 370 | self.log(f"train_params/{name}", self.C(value)) 371 | 372 | return {"loss": loss} 373 | 374 | @torch.no_grad() 375 | def validation_step(self, batch, batch_idx): 376 | self.eval() 377 | oct_num=9 378 | with torch.no_grad(): 379 | 380 | cond = self.condition.encode_image(batch['image']) 381 | 382 | 383 | latents = self.mar.sample_tokens(bsz=cond.shape[0], num_iter=64, cfg=3.0 , 384 | cfg_schedule="linear", labels=cond, temperature=1.0) 385 | output=self.shape_model.decode(latents ) 386 | 387 | mesh_v_f, has_surface = self.shape_model.extract_geometry(output, octree_depth=oct_num) 388 | for j in range(len(mesh_v_f)): 389 | self.save_mesh( 390 | f"it{self.true_global_step}/{batch['uid'][0]}_cfg{3.0}_oct{oct_num}.obj", 391 | mesh_v_f[j][0], mesh_v_f[j][1] 392 | ) 393 | 394 | self.save_image(f"it{self.true_global_step}/{batch['uid'][0]}_gt.jpg", (batch['image']*255).int() ) 395 | 396 | torch.cuda.empty_cache() --------------------------------------------------------------------------------
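The _load_shape method in mar3d/data/objaverse.py builds supervision targets from sampled signed distances: under "sdf" supervision the values are kept after NaN/inf cleanup, under "occupancy" supervision they are thresholded at zero. Below is a minimal sketch of just that labeling step, with a hypothetical sdfs array standing in for the point-cloud data the dataset actually loads.

import numpy as np

# Hypothetical signed distances for eight query points; negative means inside the surface.
sdfs = np.array([-0.02, 0.4, np.nan, -0.15, 0.0, np.inf, -np.inf, 0.07])

# "sdf" supervision: replace non-finite values, as in _load_shape.
sdf_target = np.nan_to_num(sdfs, nan=1.0, posinf=1.0, neginf=-1.0).flatten().astype(np.float32)

# "occupancy" supervision: label 0 where the signed distance is negative (inside), 1 elsewhere.
occupancies = np.where(sdfs.flatten() < 0, 0, 1).astype(np.float32)

print(sdf_target)    # -0.02, 0.4, 1.0, -0.15, 0.0, 1.0, -1.0, 0.07
print(occupancies)   # 0, 1, 1, 0, 1, 1, 0, 1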
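_load_single_image in the same file composites an RGBA render over a solid background using the alpha channel before the image is handed to the conditioning encoder. Here is the blend in isolation as a rough sketch; the white background_color is an assumption, the dataset reads it from its config.

import torch

H, W = 4, 4
rgba = torch.rand(H, W, 4)                         # toy RGBA image in [0, 1]
background_color = torch.tensor([1.0, 1.0, 1.0])   # assumed white background

mask = rgba[:, :, -1:]                             # alpha channel, shape (H, W, 1)
rgb = rgba[:, :, :3]

# alpha * foreground + (1 - alpha) * background, matching the expression in _load_single_image
image = rgb * mask + background_color[None, None, :] * (1 - mask)
print(image.shape)                                 # torch.Size([4, 4, 3])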
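ModLN in mar3d/models/conditional_encoders/clip/modeling_conditional_clip.py injects the condition embedding into every CLIP encoder layer as a per-sample shift and scale applied after layer norm (adaLN-style modulation). Because its MLP is zero-initialized, the module acts as the identity at initialization, so the pretrained CLIP features pass through unchanged at the start of training. A small usage sketch, assuming the mar3d package is importable; the dimensions are arbitrary:

import torch
from mar3d.models.conditional_encoders.clip.modeling_conditional_clip import ModLN

inner_dim, mod_dim = 8, 32
mod = ModLN(inner_dim, mod_dim)

x = torch.randn(2, 5, inner_dim)      # [N, M, C_in]: 2 samples, 5 tokens each
condition = torch.randn(2, mod_dim)   # [N, C_mod]: one condition vector per sample

out = mod(x, condition)
print(out.shape)                      # torch.Size([2, 5, 8])

# Zero-initialized MLP => shift = scale = 0, so the module is initially an identity map.
print(torch.allclose(out, x))         # True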
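The mask_ratio_generator created in MAR.__init__ draws masking ratios from a Gaussian centered at 100% masking with standard deviation 0.25, truncated to [mask_ratio_min, 1.0]. scipy's truncnorm expects the truncation limits in standard-deviation units relative to loc, which is why the code passes (mask_ratio_min - 1.0) / 0.25 and 0. A quick check of the resulting support, using the default mask_ratio_min of 0.7:

import scipy.stats as stats

mask_ratio_min = 0.7
gen = stats.truncnorm((mask_ratio_min - 1.0) / 0.25, 0, loc=1.0, scale=0.25)

samples = gen.rvs(10000)
print(samples.min() >= mask_ratio_min, samples.max() <= 1.0)   # True True
print(round(samples.mean(), 2))   # roughly 0.87: ratios are biased toward full masking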
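random_masking then turns one sampled ratio into a binary token mask per sample: it takes a random generation order (from MAR.sample_orders) and scatters ones into the first num_masked_tokens positions of that order. The same logic on CPU with a fixed ratio, as a sketch of the mechanism rather than the module itself:

import numpy as np
import torch

bsz, seq_len = 2, 8
mask_rate = 0.75
num_masked_tokens = int(np.ceil(seq_len * mask_rate))   # 6 of the 8 tokens get masked

# One random generation order per sample, as in MAR.sample_orders.
orders = torch.stack([torch.randperm(seq_len) for _ in range(bsz)])

# Scatter ones into the first num_masked_tokens positions of each order.
mask = torch.zeros(bsz, seq_len)
mask = torch.scatter(mask, dim=-1, index=orders[:, :num_masked_tokens],
                     src=torch.ones(bsz, seq_len))
print(mask.sum(dim=-1))   # tensor([6., 6.]): exactly num_masked_tokens ones per row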
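MAR.sample_tokens decodes the latent sequence over num_iter rounds. The fraction of tokens still masked after each round follows a cosine schedule (as in MaskGIT and MAGE), and when cfg_schedule == "linear" the classifier-free-guidance scale is ramped up with the number of revealed tokens (following Muse). A worked sketch of both schedules with toy settings; the real loop additionally clamps mask_len so that at least one token is predicted per round and all remaining tokens are predicted in the final round:

import math
import numpy as np

seq_len, num_iter, cfg = 16, 4, 3.0

for step in range(num_iter):
    # Cosine schedule: fraction of tokens still masked after this step.
    mask_ratio = np.cos(math.pi / 2.0 * (step + 1) / num_iter)
    mask_len = int(np.floor(seq_len * mask_ratio))

    # Linear CFG schedule: guidance strength grows as more tokens are revealed.
    cfg_iter = 1 + (cfg - 1) * (seq_len - mask_len) / seq_len

    print(step, mask_len, round(cfg_iter, 2))

# step 0: mask_len = 14, cfg_iter = 1.25
# step 1: mask_len = 11, cfg_iter = 1.62
# step 2: mask_len = 6,  cfg_iter = 2.25
# step 3: mask_len = 0,  cfg_iter = 3.0 (everything is revealed in the last round)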
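Finally, training_step in ShapeDiffusionSystem sums every loss_* entry returned by forward, weighted by the matching lambda_* value from the system's loss config (self.C resolves values that may be scheduled over training). A hedged sketch of that bookkeeping with plain dictionaries; lambda_diffusion is an assumed config key that would pair with the loss_diffusion output above:

# Output of a forward pass, shaped like ShapeDiffusionSystem.forward's return value.
out = {"loss_diffusion": 0.42, "latents": None}

# Loss weights as they would appear under the system's loss config.
loss_cfg = {"lambda_diffusion": 1.0}

total = 0.0
for name, value in out.items():
    if name.startswith("loss_"):
        # "loss_diffusion" -> "lambda_diffusion"
        total += value * loss_cfg[name.replace("loss_", "lambda_")]

print(total)   # 0.42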